## Data Loader

In [1]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

class DisasterDataset(Dataset):
    """Paired (input, target) image dataset for image-to-image training.

    Every ``*.png`` file in ``input_dir`` is paired with the file in
    ``target_dir`` that has the same stem plus ``_target`` before the
    extension (``a.png`` -> ``a_target.png``). Both images are loaded as RGB.

    Args:
        input_dir: Directory containing the input ``.png`` images.
        target_dir: Directory containing the matching ``*_target.png`` images.
        transform: Optional callable applied to the input PIL image. Defaults to
            resize to 1024x1024, ``ToTensor`` and ImageNet mean/std normalization.
        target_transform: Optional callable applied to the target PIL image.
            Defaults to ``transform`` (same pipeline as the input).

    Note: with the default pipeline the targets are ImageNet-normalized as
    well, so a model trained against them predicts in that normalized space and
    its outputs must be de-normalized before being viewed as images. If the
    targets are label masks, consider a ``target_transform`` that resizes with
    nearest-neighbour interpolation instead of the default bilinear resize.
    """

    def __init__(self, input_dir, target_dir, transform=None, target_transform=None):
        self.input_dir = input_dir
        self.target_dir = target_dir
        # Sorted so the index -> file mapping is deterministic; os.listdir
        # returns entries in arbitrary, platform-dependent order.
        self.filenames = sorted(f for f in os.listdir(input_dir) if f.endswith('.png'))
        self.transform = transform if transform is not None else self._default_transform()
        self.target_transform = (
            target_transform if target_transform is not None else self.transform
        )

    @staticmethod
    def _default_transform():
        """Resize to 1024x1024, convert to a tensor and apply ImageNet normalization."""
        return transforms.Compose([
            transforms.Resize((1024, 1024)),  # Resize if not already 1024x1024
            transforms.ToTensor(),  # Convert images to tensors
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _target_name(filename):
        """Map an input filename to its target filename (``a.png`` -> ``a_target.png``)."""
        # splitext only touches the final extension, unlike str.replace which
        # would rewrite every '.png' occurrence inside the name.
        stem, ext = os.path.splitext(filename)
        return f'{stem}_target{ext}'

    @staticmethod
    def _load_rgb(path):
        """Open an image file, convert it to RGB and close the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return the transformed ``(input_image, target_image)`` pair at ``idx``."""
        filename = self.filenames[idx]
        input_path = os.path.join(self.input_dir, filename)
        target_path = os.path.join(self.target_dir, self._target_name(filename))

        input_image = self._load_rgb(input_path)
        target_image = self._load_rgb(target_path)

        return self.transform(input_image), self.target_transform(target_image)

# Paths to the processed data. Set the DISASTER_DATA_ROOT environment variable
# to point elsewhere instead of editing the hard-coded absolute path.
DATA_ROOT = os.environ.get(
    'DISASTER_DATA_ROOT',
    '/Users/willianribeiro/Documents/GitHub/disaster-responder-satellite-images/data/processed/building_identification',
)
BATCH_SIZE = 4

train_input_folder_path = os.path.join(DATA_ROOT, 'train', 'input_images')
train_output_folder_path = os.path.join(DATA_ROOT, 'train', 'output_images')

test_input_folder_path = os.path.join(DATA_ROOT, 'validation', 'input_images')
test_output_folder_path = os.path.join(DATA_ROOT, 'validation', 'output_images')

dataset = DisasterDataset(train_input_folder_path, train_output_folder_path)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# The validation loader must wrap the validation split (it previously wrapped the
# training dataset, so "validation" loss was measured on training data).
# Order does not matter for evaluation, so no shuffling.
val_dataset = DisasterDataset(test_input_folder_path, test_output_folder_path)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


## Model: U-Net in PyTorch

1. Loss function: Since this is an image-to-image translation task, a combination of Mean Squared Error (MSE) and a perceptual loss (e.g. a VGG-feature-based loss) could be effective. The training loop below currently uses plain MSE only.
2. Optimizer: Adam, a common default for this kind of model.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two stacked (3x3 convolution -> BatchNorm -> ReLU) stages.

    ``padding=1`` keeps height and width unchanged; channels go
    ``in_channels -> out_channels -> out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order define the state_dict keys; keep them stable.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Halve height and width with 2x2 max-pooling, then apply a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upsample ``x`` by 2, align it to the skip tensor, concatenate, then DoubleConv.

    Args:
        in_channels: Channel count after concatenating the skip tensor with the
            upsampled tensor; this is the DoubleConv's input width.
        out_channels: Channel count produced by the DoubleConv.
        bilinear: If True, upsample with bilinear interpolation (no weights);
            otherwise with a ``ConvTranspose2d`` that maps
            ``in_channels // 2`` channels to ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        self.up = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(half, half, kernel_size=2, stride=2)
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x, skip_x):
        upsampled = self.up(x)
        # Tensors are (N, C, H, W); pad (or crop, for negative amounts) the
        # upsampled map so its H/W match the skip tensor, e.g. after odd sizes.
        # See https://discuss.pytorch.org/t/semantic-segmentation-unet-padding/2971
        diff_h = skip_x.size(2) - upsampled.size(2)
        diff_w = skip_x.size(3) - upsampled.size(3)
        top, left = diff_h // 2, diff_w // 2
        upsampled = F.pad(upsampled, [left, diff_w - left, top, diff_h - top])
        # Skip tensor goes first along the channel axis.
        merged = torch.cat([skip_x, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """1x1 convolution projecting ``in_channels`` feature maps to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """U-Net with four down-sampling and four up-sampling stages.

    Args:
        n_channels: Channels of the input image.
        n_classes: Channels of the output map. It is returned straight from a
            1x1 convolution, with no final activation.
        bilinear: Forwarded to every ``Up`` block: bilinear upsampling (True)
            or ``ConvTranspose2d`` (False).

    Encoder widths are 64-128-256-512-512. Each ``Up`` block's ``in_channels``
    is the skip width plus the upsampled width (e.g. 512 + 512 = 1024 for ``up1``).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels, self.n_classes, self.bilinear = n_channels, n_classes, bilinear

        # Encoder (creation order kept: it fixes RNG consumption and state_dict order)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # Decoder
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        decoded = self.down4(skip4)  # bottleneck

        decoder_steps = (
            (self.up1, skip4),
            (self.up2, skip3),
            (self.up3, skip2),
            (self.up4, skip1),
        )
        for up_block, skip in decoder_steps:
            decoded = up_block(decoded, skip)
        return self.outc(decoded)

# Seed the global torch RNG so that weight initialisation (and later draws from
# the global generator, such as DataLoader shuffling) is reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

# Build the model on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(n_channels=3, n_classes=3).to(device)


In [3]:
import torch
from torch import nn
from torch.optim import Adam
from tqdm import tqdm

criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=0.001)

best_loss = float('inf')
patience = 5
patience_counter = 0

num_epochs = 50  # Start with more epochs, let early stopping handle the cutoff

# Fail fast: an empty loader would otherwise surface later as an undefined
# `loss` or a division by zero.
if len(loader) == 0 or len(val_loader) == 0:
    raise ValueError("Training and validation loaders must both yield at least one batch.")

for epoch in range(num_epochs):
    # model.eval() is called for validation at the end of each epoch, so
    # training mode (BatchNorm batch statistics) must be restored every epoch.
    model.train()
    train_loss_sum, train_samples = 0.0, 0
    with tqdm(loader, unit="batch", desc=f"Epoch {epoch+1}") as tepoch:
        for inputs, targets in tepoch:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # criterion returns a per-batch mean; weight by batch size so the
            # epoch figure is a per-sample mean even when the last batch is short.
            batch_size = inputs.size(0)
            train_loss_sum += loss.item() * batch_size
            train_samples += batch_size
            tepoch.set_postfix(loss=loss.item())
    train_loss = train_loss_sum / train_samples

    # Validation step
    model.eval()
    val_loss_sum, val_samples = 0.0, 0
    with torch.no_grad():
        for val_inputs, val_targets in val_loader:
            val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
            val_outputs = model(val_inputs)
            batch_size = val_inputs.size(0)
            val_loss_sum += criterion(val_outputs, val_targets).item() * batch_size
            val_samples += batch_size
    val_loss = val_loss_sum / val_samples

    print(f'Epoch {epoch+1}, Training Loss: {train_loss}, Validation Loss: {val_loss}')

    # Early stopping logic
    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')  # Save the best model
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break


Epoch 1:   0%|          | 0/700 [00:00<?, ?batch/s]