# Connect with Google Drive


In [2]:
# from google.colab import drive
# drive.mount('/content/drive')
import os

os.getcwd()

'/users/wrmod/rafbar/Projs/mambo-dl/Hawthron_DL'


# Attention UNet


## Attention UNet Architecture

In [1]:
import torch
import torch.nn as nn

class conv_block(nn.Module):
    """
    Convolution Block
    """
    def __init__(self, in_ch, out_ch):
        super(conv_block, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.conv(x)
        return x

class up_conv(nn.Module):
    """
    Up Convolution Block
    """
    def __init__(self, in_ch, out_ch):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x

class Attention_block(nn.Module):
    """
    Attention Block
    """
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()

        self.W_g = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        out = x * psi
        return out

class AttUNet(nn.Module):
    """
    Attention Unet implementation
    Paper: https://arxiv.org/abs/1804.03999
    """
    def __init__(self, img_ch=3, output_ch=1):
        super(AttUNet, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(img_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])

        self.Up5 = up_conv(filters[4], filters[3])
        self.Att5 = Attention_block(F_g=filters[3], F_l=filters[3], F_int=filters[2])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Att4 = Attention_block(F_g=filters[2], F_l=filters[2], F_int=filters[1])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Att3 = Attention_block(F_g=filters[1], F_l=filters[1], F_int=filters[0])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Att2 = Attention_block(F_g=filters[0], F_l=filters[0], F_int=32)
        self.Up_conv2 = conv_block(filters[1], filters[0])

        self.Conv = nn.Conv2d(filters[0], output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        e1 = self.Conv1(x)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)

        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)

        d5 = self.Up5(e5)
        x4 = self.Att5(g=d5, x=e4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=e3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=e2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=e1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        out = self.Conv(d2)

        return out

# Test the Attention UNet network
model = AttUNet(img_ch=3, output_ch=1).to('cuda' if torch.cuda.is_available() else 'cpu')
print(model)


AttUNet(
  (Maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Maxpool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv1): conv_block(
    (conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (Conv2): conv_block(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affin

## Custom Dataset and DataLoader

In [5]:
import torch
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np

class RSDataset(Dataset):
    def __init__(self, images_dir, labels_dir, transform=None, max_samples=None):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.transform = transform
        self.image_files = sorted(os.listdir(images_dir))
        self.label_files = sorted(os.listdir(labels_dir))
        if max_samples:
            self.image_files = self.image_files[:max_samples]
            self.label_files = self.label_files[:max_samples]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.images_dir, self.image_files[idx])
        label_path = os.path.join(self.labels_dir, self.label_files[idx])

        image = Image.open(image_path).convert("RGB")
        label = Image.open(label_path).convert("L")

        image = np.array(image)
        label = np.array(label)

        # Normalize image to [0, 1]
        image = image / 255.0

        # Normalize label to [0, 1]
        label = label / 255.0

        # Add channel dimension to label
        label = np.expand_dims(label, axis=0)

        # Adjust the dimensions of the image to [channels, height, width]
        image = np.transpose(image, (2, 0, 1))

        if self.transform:
            image = self.transform(image)
            label = self.transform(label)

        image = torch.tensor(image, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.float32)

        return image, label

# Example usage:
images_dir = 'resized_train_image_no_flip'
labels_dir = 'binarized_resized_train_label_no_flip'

dataset = RSDataset(images_dir, labels_dir, max_samples=2500)

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Get a batch of images and labels
images, labels = next(iter(dataloader))

# Print the shape of images and labels
print(f'Images shape: {images.shape}')
print(f'Labels shape: {labels.shape}')

# Check if the images are in the expected format (e.g., [batch_size, channels, height, width])
expected_channels = 3  # Adjust this if your images have a different number of channels

if images.shape[1] != expected_channels:
    print(f'Error: Expected {expected_channels} channels but got {images.shape[1]} channels')

# Check if labels are in the expected format
expected_label_channels = 1  # Adjust this if your labels have a different number of channels

if labels.shape[1] != expected_label_channels:
    print(f'Error: Expected {expected_label_channels} channels in labels but got {labels.shape[1]} channels')

# Check the range of pixel values in images and labels
print(f'Images min: {images.min()}, max: {images.max()}')
print(f'Labels min: {labels.min()}, max: {labels.max()}')


Images shape: torch.Size([4, 3, 256, 256])
Labels shape: torch.Size([4, 1, 256, 256])
Images min: 0.0, max: 1.0
Labels min: 0.0, max: 1.0


## Training and Validation


In [6]:
import time
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score

# Assuming RSDataset and AttUNet classes are defined elsewhere and images_dir, labels_dir are provided
# Use 3000 images per epoch
max_samples = 3000
dataset = RSDataset(images_dir, labels_dir, max_samples=max_samples)

# Split dataset into training and validation sets
train_size = int(0.8 * len(dataset))  # 80% for training
val_size = len(dataset) - train_size  # 20% for validation
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Adjust batch size if GPU memory allows
batch_size = 20  # Further reduce batch size

trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Initialize the AttUNet model with 3 input channels and 1 output channel
model = AttUNet(img_ch=3, output_ch=1)

# Use binary cross-entropy as the loss function
criterion = nn.BCELoss()

# Set the Adam optimizer with an initial learning rate of 0.0001
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Set the learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Set the total number of training epochs
num_epochs = 50

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model.to(device)

def calculate_metrics(outputs, labels, threshold=0.5):
    preds = (outputs > threshold).float().cpu().numpy().flatten()
    labels = labels.float().cpu().numpy().flatten()
    precision = precision_score(labels, preds)
    recall = recall_score(labels, preds)
    f1 = f1_score(labels, preds)

    return precision, recall, f1

def calculate_accuracy(outputs, labels, threshold=0.5):
    preds = (outputs > threshold).float()
    correct = (preds == labels).float().sum()
    accuracy = correct / labels.numel()
    return accuracy.item()

def compute_iou(preds, labels, threshold=0.5):
    preds = (preds > threshold).astype(float)  # Convert to float using NumPy's astype
    intersection = (preds * labels).sum()
    union = preds.sum() + labels.sum() - intersection
    iou = intersection / union if union != 0 else 0.0
    return iou

# Gradient accumulation steps
accumulation_steps = 4

# Start training the model for a total of 50 epochs
for epoch in range(num_epochs):
    start_time = time.time()  # Start time of the epoch
    model.train()  # Set model to training mode
    epoch_loss = 0
    epoch_acc = 0
    epoch_precision = 0
    epoch_recall = 0
    epoch_f1 = 0
    progress_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=f'Epoch {epoch+1}/{num_epochs}')

    optimizer.zero_grad()  # Zero the gradients

    for i, (images, labels) in progress_bar:
        images = images.float().to(device, non_blocking=True)  # Convert images to float and move to device
        labels = labels.float().to(device, non_blocking=True)  # Normalize labels to [0, 1] and move to device

        # Ensure labels are within [0, 1]
        labels = torch.clamp(labels, 0, 1)

        outputs = model(images)  # Get model outputs
        outputs = torch.sigmoid(outputs)  # Apply sigmoid to ensure outputs are in [0, 1]

        # Adjust labels shape to match outputs shape if needed
        if labels.ndim == 3:
            labels = labels.unsqueeze(1)

        # Check if labels are in the correct range
        if labels.max() > 1 or labels.min() < 0:
            raise ValueError(f"Labels out of bounds: min={labels.min()}, max={labels.max()}")

        loss = criterion(outputs, labels)  # Calculate loss
        loss.backward()  # Backpropagate

        if (i + 1) % accumulation_steps == 0:
            optimizer.step()  # Update weights
            optimizer.zero_grad()  # Zero the gradients

        epoch_loss += loss.item()  # Accumulate the loss
        epoch_acc += calculate_accuracy(outputs, labels)  # Accumulate accuracy

        # Calculate precision, recall, and F1 score
        precision, recall, f1 = calculate_metrics(outputs, labels)
        epoch_precision += precision
        epoch_recall += recall
        epoch_f1 += f1

        progress_bar.set_postfix(loss=loss.item(), acc=epoch_acc / (i + 1), precision=epoch_precision / (i + 1), recall=epoch_recall / (i + 1), f1=epoch_f1 / (i + 1))  # Update progress bar with current loss and accuracy

    # Calculate average training loss and accuracy
    avg_train_loss = epoch_loss / len(trainloader)
    avg_train_acc = epoch_acc / len(trainloader)
    avg_train_precision = epoch_precision / len(trainloader)
    avg_train_recall = epoch_recall / len(trainloader)
    avg_train_f1 = epoch_f1 / len(trainloader)

    model.eval()  # Set model to evaluation mode
    val_loss = 0
    val_acc = 0
    val_precision = 0
    val_recall = 0
    val_f1 = 0

    with torch.no_grad():
        for images, labels in valloader:
            images = images.float().to(device, non_blocking=True)  # Convert images to float and move to device
            labels = labels.float().to(device, non_blocking=True)  # Normalize labels to [0, 1] and move to device

            # Ensure labels are within [0, 1]
            labels = torch.clamp(labels, 0, 1)

            outputs = model(images)  # Get model outputs
            outputs = torch.sigmoid(outputs)  # Apply sigmoid to ensure outputs are in [0, 1]

            # Adjust labels shape to match outputs shape if needed
            if labels.ndim == 3:
                labels = labels.unsqueeze(1)

            # Check if labels are in the correct range
            if labels.max() > 1 or labels.min() < 0:
                raise ValueError(f"Labels out of bounds: min={labels.min()}, max={labels.max()}")

            loss = criterion(outputs, labels)  # Calculate loss
            val_loss += loss.item()  # Accumulate the loss
            val_acc += calculate_accuracy(outputs, labels)  # Accumulate accuracy

            # Calculate precision, recall, and F1 score
            precision, recall, f1 = calculate_metrics(outputs, labels)
            val_precision += precision
            val_recall += recall
            val_f1 += f1

    # Calculate average validation loss and accuracy
    avg_val_loss = val_loss / len(valloader)
    avg_val_acc = val_acc / len(valloader)
    avg_val_precision = val_precision / len(valloader)
    avg_val_recall = val_recall / len(valloader)
    avg_val_f1 = val_f1 / len(valloader)

    end_time = time.time()  # End time of the epoch
    epoch_duration = end_time - start_time  # Duration of the epoch

    tqdm.write(f'Epoch {epoch+1}/{num_epochs} completed. '
               f'Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f}, '
               f'Train Precision: {avg_train_precision:.4f}, Train Recall: {avg_train_recall:.4f}, Train F1: {avg_train_f1:.4f}, '
               f'Val Loss: {avg_val_loss:.4f}, Val Acc: {avg_val_acc:.4f}, '
               f'Val Precision: {avg_val_precision:.4f}, Val Recall: {avg_val_recall:.4f}, Val F1: {avg_val_f1:.4f}, '
               f'Duration: {epoch_duration:.2f} seconds')  # Print epoch summary

    # Step the learning rate scheduler
    scheduler.step()

    # Save the model after each epoch
    model_path = f'attention_unet_model/models_building_epoch_{epoch+1}.pth'
    torch.save(model.state_dict(), model_path)

# Save the final model after training is complete
# torch.save(model.state_dict(), 'attention_unet_model/models_building_final.pth')

# Evaluation on the validation set
all_labels = []
all_preds = []
all_iou = []

with torch.no_grad():
    for images, labels in valloader:
        images = images.float().to(device, non_blocking=True)
        labels = labels.float().to(device, non_blocking=True)

        # Ensure labels are within [0, 1]
        labels = torch.clamp(labels, 0, 1)

        outputs = model(images)
        outputs = torch.sigmoid(outputs)

        # Adjust labels shape to match outputs shape if needed
        if labels.ndim == 3:
            labels = labels.unsqueeze(1)

        preds = (outputs > 0.5).float()

        all_labels.append(labels.cpu().numpy())
        all_preds.append(preds.cpu().numpy())

        for i in range(len(labels)):
            iou = compute_iou(all_preds[-1][i], all_labels[-1][i])
            all_iou.append(iou)

# Flatten lists to compute metrics
all_labels = np.concatenate(all_labels).flatten()
all_preds = np.concatenate(all_preds).flatten()

# Compute precision, recall, and F1 score
precision = precision_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)
# Compute mean IoU
mean_iou = np.mean(all_iou)

print(f'Validation Mean IoU: {mean_iou:.4f}')
print(f'Validation Precision: {precision:.4f}')
print(f'Validation Recall: {recall:.4f}')
print(f'Validation F1 Score: {f1:.4f}')

# Markdown formatted summary
markdown_summary = f"""
# Model Evaluation Metrics

**Precision:** {precision:.4f}

**Recall:** {recall:.4f}

**F1 Score:** {f1:.4f}

**Mean IoU:** {mean_iou:.4f}
"""

print(markdown_summary)


Using device: cpu


Epoch 1/50: 100%|██████████| 120/120 [26:27<00:00, 13.23s/it, acc=0.746, f1=0.621, loss=0.447, precision=0.541, recall=0.774]


Epoch 1/50 completed. Train Loss: 0.5524, Train Acc: 0.7460, Train Precision: 0.5413, Train Recall: 0.7739, Train F1: 0.6206, Val Loss: 0.6744, Val Acc: 0.8078, Val Precision: 0.5946, Val Recall: 0.7853, Val F1: 0.6741, Duration: 1735.56 seconds


Epoch 2/50:  50%|█████     | 60/120 [13:15<13:15, 13.25s/it, acc=0.838, f1=0.684, loss=0.431, precision=0.661, recall=0.727]


KeyboardInterrupt: 

## Prediction

In [7]:
import os
import torch
import torch.nn as nn
from osgeo import gdal, gdal_array, osr
import numpy as np
import cv2

# Define ConvBlock and AttentionUNet models
class conv_block(nn.Module):
    """
    Convolution Block
    """
    def __init__(self, in_ch, out_ch):
        super(conv_block, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.conv(x)
        return x

class up_conv(nn.Module):
    """
    Up Convolution Block
    """
    def __init__(self, in_ch, out_ch):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x

class Attention_block(nn.Module):
    """
    Attention Block
    """
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()

        self.W_g = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        out = x * psi
        return out

class AttUNet(nn.Module):
    """
    Attention Unet implementation
    Paper: https://arxiv.org/abs/1804.03999
    """
    def __init__(self, img_ch=3, output_ch=1):
        super(AttUNet, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(img_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])

        self.Up5 = up_conv(filters[4], filters[3])
        self.Att5 = Attention_block(F_g=filters[3], F_l=filters[3], F_int=filters[2])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Att4 = Attention_block(F_g=filters[2], F_l=filters[2], F_int=filters[1])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Att3 = Attention_block(F_g=filters[1], F_l=filters[1], F_int=filters[0])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Att2 = Attention_block(F_g=filters[0], F_l=filters[0], F_int=32)
        self.Up_conv2 = conv_block(filters[1], filters[0])

        self.Conv = nn.Conv2d(filters[0], output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        e1 = self.Conv1(x)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)

        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)

        d5 = self.Up5(e5)
        x4 = self.Att5(g=d5, x=e4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=e3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=e2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=e1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        out = self.Conv(d2)

        return out

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the pre-trained AttUNet model
model = AttUNet(img_ch=3, output_ch=1)
model.load_state_dict(torch.load('attention_unet_model/models_building_final.pth', map_location=device))
model.to(device)
model.eval()

# Set input and output file paths
input_image_path = 'test_area_no_flip/test_area.tif'
output_image_path = 'test_area_no_flip/attention_unet/test_area_predict.tif'
intermediate_folder = 'test_area_no_flip/attention_unet/intermediate'

# Create intermediate results folder (if not exists)
if not os.path.exists(intermediate_folder):
    os.makedirs(intermediate_folder)

# Open input image
rsdataset = gdal.Open(input_image_path)

# Read the first three bands and stack them into a numpy array
images = np.stack([rsdataset.GetRasterBand(i).ReadAsArray() for i in range(1, 4)], axis=0)

# Normalize to the range [0, 1]
images = images / 255.0

# Function: Overlap cropping
def sliding_window(image, step_size, window_size):
    for y in range(0, image.shape[1] - window_size[1] + 1, step_size[1]):
        for x in range(0, image.shape[2] - window_size[0] + 1, step_size[0]):
            yield (x, y, image[:, y:y + window_size[1], x:x + window_size[0]])

# Initialize stitched image
stitched_image = np.zeros((images.shape[1], images.shape[2]), dtype=np.uint8)

# Overlap cropping and prediction
window_size = (512, 512)
step_size = (256, 256)
batch_size = 8
batch_windows = []
batch_coords = []

for (x, y, window) in sliding_window(images, step_size, window_size):
    # Resize window to 256x256
    resized_window = cv2.resize(window.transpose(1, 2, 0), (256, 256)).transpose(2, 0, 1)
    batch_windows.append(resized_window)
    batch_coords.append((x, y))

    # If batch size is reached, perform prediction
    if len(batch_windows) == batch_size:
        batch_windows_tensor = torch.tensor(batch_windows).float().to(device)
        with torch.no_grad():
            outputs = model(batch_windows_tensor)

        for i in range(batch_size):
            output = (outputs[i] > 0.8).float().cpu().numpy()
            prediction = cv2.resize(output.squeeze(), window_size)
            prediction = (prediction > 0.5).astype(np.uint8) * 255

            # Update stitched image
            x, y = batch_coords[i]
            stitched_image[y:y + window_size[1], x:x + window_size[0]] = np.maximum(
                stitched_image[y:y + window_size[1], x:x + window_size[0]], prediction
            )

        batch_windows = []
        batch_coords = []

# Process remaining windows
if batch_windows:
    batch_windows_tensor = torch.tensor(batch_windows).float().to(device)
    with torch.no_grad():
        outputs = model(batch_windows_tensor)

    for i in range(len(batch_windows)):
        output = (outputs[i] > 0.8).float().cpu().numpy()
        prediction = cv2.resize(output.squeeze(), window_size)
        prediction = (prediction > 0.5).astype(np.uint8) * 255

        # Update stitched image
        x, y = batch_coords[i]
        stitched_image[y:y + window_size[1], x:x + window_size[0]] = np.maximum(
            stitched_image[y:y + window_size[1], x:x + window_size[0]], prediction
        )

# Create output image file
driver = gdal.GetDriverByName('GTiff')
out_raster = driver.Create(output_image_path, rsdataset.RasterXSize, rsdataset.RasterYSize, 1, gdal.GDT_Byte)

# Set georeferencing and projection information
out_raster.SetGeoTransform(rsdataset.GetGeoTransform())
out_raster.SetProjection(rsdataset.GetProjectionRef())

# Write data to output image
out_raster.GetRasterBand(1).WriteArray(stitched_image)

# Close output image file
out_raster.FlushCache()
out_raster = None

print(f"Processed large image and saved prediction to {output_image_path}")


  batch_windows_tensor = torch.tensor(batch_windows).float().to(device)


Processed large image and saved prediction to test_area_no_flip/attention_unet/test_area_predict.tif
