In [1]:
import os
import numpy as np
from PIL import Image
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from tqdm.notebook import tqdm

In [None]:
# --- Configuration: everything a reader might want to tune lives here ---
path = Path("./preprocessed_patches")  # root folder scanned for per-city patch directories
batch_size = 2
learning_rate = 0.0001
max_epochs = 25
rgb_channels = 3  # input channels handed to the model
segmentation_classes = 1  # output channels handed to the model

# Seed the RNGs used for weight initialisation and data shuffling so that a
# "Restart Kernel -> Run All" reproduces the same training run.
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


In [3]:
class SatellitePatchDataset(Dataset):
    """Dataset of RGB image patches paired with their ground-truth masks.

    Every sub-directory of ``path`` is scanned for ``patch_*.png`` files.
    A file whose name does not contain ``_gt.png`` is treated as an image and
    paired with ``<stem>_gt.png`` from the same directory. Images without a
    matching mask are reported with a warning and skipped.

    Args:
        path: Root directory containing one sub-directory per city.
        transform: Optional callable applied to the image tensor and then to
            the mask tensor. It must therefore accept both tensors. The torch
            RNG state is rewound between the two calls so that transforms
            drawing from torch's global RNG make the same random choice for
            image and mask.

    Attributes:
        image_paths: Image file paths, in sorted order.
        mask_paths: Mask file paths, index-aligned with ``image_paths``.
    """

    def __init__(self, path, transform=None):
        self.path = Path(path)
        self.image_paths = []
        self.mask_paths = []
        self.transform = transform

        # Directory listing order is filesystem-dependent; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        for city_dir in sorted(self.path.iterdir()):
            if not city_dir.is_dir():
                continue
            for img_path in sorted(city_dir.glob("patch_*.png")):
                if "_gt.png" in img_path.name:
                    # The glob also matches mask files; they are not inputs.
                    continue
                mask_path = img_path.parent / f"{img_path.stem}_gt.png"
                if mask_path.exists():
                    self.image_paths.append(img_path)
                    self.mask_paths.append(mask_path)
                else:
                    print(f"Warning: Corresponding mask not found for {img_path}")

        print(f"Found {len(self.image_paths)} image-mask pairs across all cities.")

    def __len__(self):
        """Return the number of image-mask pairs found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one pair and return ``(image_tensor, mask_tensor)``.

        The image is a float32 tensor of shape (3, H, W) and the mask a
        float32 tensor of shape (1, H, W); both are scaled by 1/255.
        """
        # Context managers close the underlying files right after decoding
        # instead of leaving that to garbage collection.
        with Image.open(self.image_paths[idx]) as img_file:
            image = np.array(img_file.convert("RGB"), dtype=np.float32) / 255.0
        with Image.open(self.mask_paths[idx]) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32) / 255.0

        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        mask = np.expand_dims(mask, axis=0)  # HW -> 1HW
        image_tensor = torch.from_numpy(image)
        mask_tensor = torch.from_numpy(mask)

        if self.transform is not None:
            # Replay the same random draws for the mask so that random
            # augmentations stay aligned with the image.
            rng_state = torch.get_rng_state()
            image_tensor = self.transform(image_tensor)
            torch.set_rng_state(rng_state)
            mask_tensor = self.transform(mask_tensor)

        return image_tensor, mask_tensor


In [None]:
# --- 4. U-Net Model Architecture ---

# Double Convolution Block
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> batch-norm -> ReLU stages.

    The convolutions use ``padding=1`` and carry no bias term. The first
    stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        stages = []
        # Same three-layer stage twice; only the input width differs.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages and return the result."""
        features = self.double_conv(x)
        return features

# Downsampling Block
class Down(nn.Module):
    """Encoder step: 2x2 max-pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and then convolve it."""
        pooled_features = self.maxpool_conv(x)
        return pooled_features

# Upsampling Block
class Up(nn.Module):
    """Decoder step: upsample, align with the skip tensor, concatenate, convolve."""

    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        # The transposed convolution halves the channel count; concatenating
        # the skip tensor afterwards is what feeds `in_channels` to the conv.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s height/width and fuse the two.

        Args:
            x1: Tensor coming from the deeper layer.
            x2: Skip-connection tensor whose spatial size is the target.
        """
        upsampled = self.up(x1)

        # Split any size gap as evenly as possible between the two sides.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        padding = [left, gap_w - left, top, gap_h - top]
        upsampled = nn.functional.pad(upsampled, padding)

        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        projected = self.conv(x)
        return projected

# Full U-Net Model
class UNet(nn.Module):
    """U-Net with four downsampling and four upsampling steps.

    Args:
        n_channels: Number of channels of the input tensor.
        n_classes: Number of channels of the returned logits.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder: channel width doubles at every step.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)

        # Decoder: mirrors the encoder, halving the width at every step.
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)

        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits for ``x``."""
        # Encoder outputs are stacked so the decoder can pop them in reverse.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        x = skips.pop()  # bottleneck features
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop())

        return self.outc(x)


In [None]:
dataset = SatellitePatchDataset(path)

# Page-locked host memory only speeds up host-to-GPU copies; asking for it on
# a CPU-only machine just produces a warning, so tie it to the chosen device.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=(device.type == "cuda"),
)

In [None]:
# Build the network and place its parameters on the selected device.
model = UNet(n_channels=rgb_channels, n_classes=segmentation_classes).to(device)

# The model returns raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
for epoch in range(max_epochs):
    model.train()  # training mode: batch-norm layers update their running statistics
    running_loss = 0.0

    # tqdm renders a per-epoch progress bar in Jupyter.
    for images, masks in tqdm(dataloader, desc=f"Epoch {epoch+1}/{max_epochs}"):
        # non_blocking lets the copy overlap with compute when the loader
        # yields pinned memory; it is a no-op otherwise.
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass and loss.
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        # `loss` is a batch mean; weight it by the batch size so the epoch
        # figure below is a true per-sample average even if the last batch
        # is smaller.
        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(dataloader.dataset)
    print(f"Epoch {epoch+1}/{max_epochs}, Loss: {epoch_loss:.4f}")


print("Training complete!")