In [1]:
# Import packages
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os

In [2]:
# Create DoubleConv class for UNet
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        #  Apply 2D convolution to extract spatial features, normalisation to speed up training and reduce over fitting, ReLU to learn complex functions
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    # Define data flow
    def forward(self, x):
        return self.conv(x)

In [3]:
# Create DownSample class for UNet
class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    # Define data flow
    def forward(self, x):
        skip = self.conv(x)
        x = self.pool(skip)
        return x, skip

In [4]:
# Create UpSample class for UNet
class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(out_channels * 2, out_channels)  # Concatenate skip connection

    # Define data flow
    def forward(self, x, skip):
        x = self.upconv(x)
        # Align the skip connection's shape with the upsampled x
        x = torch.cat([x, skip], dim=1)  # Concatenate along the channel dimension
        x = self.conv(x)
        return x

In [5]:
# Create UNet class
class UNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        
        # Downsampling
        self.downs = nn.ModuleList([
            DownSample(in_channels, 64),
            DownSample(64, 128),
            DownSample(128, 256),
            DownSample(256, 512)
        ])
        
        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)

        # Upsampling
        self.ups = nn.ModuleList([
            UpSample(1024, 512),
            UpSample(512, 256),
            UpSample(256, 128),
            UpSample(128, 64)
        ])

        # Final output layer
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    # Define data flow
    def forward(self, x):
        # Downsampling path
        skip_connections = []
        for down_conv in self.downs:
            x, skip = down_conv(x)
            skip_connections.append(skip)
        
        # Bottleneck
        x = self.bottleneck(x)
        
        # Upsampling path with skip connections
        for up_conv, skip in zip(self.ups, reversed(skip_connections)):
            x = up_conv(x, skip)
        
        # Final output layer
        out = self.out(x)
        return out

In [6]:
# Data Loader
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_files = sorted([f for f in os.listdir(image_dir) if f.endswith('.pt')])
        self.mask_files = sorted([f for f in os.listdir(mask_dir) if f.endswith('.pt')])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image = torch.load(os.path.join(self.image_dir, self.image_files[idx]))
        mask = torch.load(os.path.join(self.mask_dir, self.mask_files[idx]))

        # Fix image shape (C, H, W)
        image = image.squeeze(0).permute(2, 0, 1)
        image = image.float()

        # Fix mask shape (C, H, W)
        mask = (mask > 0).float()

        return image, mask

In [7]:
# Train paths
train_image_dir = r'E:\processed_tiles\train\images'
train_mask_dir = r'E:\processed_tiles\train\masks'

train_dataset = SegmentationDataset(train_image_dir, train_mask_dir)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Test paths
test_image_dir = r'E:\processed_tiles\test\images'
test_mask_dir = r'E:\processed_tiles\test\masks'

test_dataset = SegmentationDataset(test_image_dir, test_mask_dir)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [8]:
# Confirm shape
for images, masks in train_loader:
    print("Image batch shape:", images.shape)
    print("Mask batch shape:", masks.shape)
    break

Image batch shape: torch.Size([8, 3, 128, 128])
Mask batch shape: torch.Size([8, 1, 128, 128])


This is how we want the data to be formated for the UNet.

In [9]:
# Confirm between 0 and 1
for images, masks in train_loader:
    print("Image min/max:", images.min().item(), images.max().item())
    print("Mask unique values:", torch.unique(masks))
    break

Image min/max: 0.003921568859368563 0.5411764979362488
Mask unique values: tensor([0., 1.])


Max is a bit lower than 1 but, it should still work well.

In [13]:
# Initialise UNet model
model = UNet(in_channels=3, num_classes=1).cuda()  # Use GPU

# Define loss function and optimiser
criterion = torch.nn.BCEWithLogitsLoss()  # For binary segmentation
optimiser = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    
    for images, masks in train_loader:
        images, masks = images.cuda(), masks.cuda()  # Move data to GPU if available
        
        optimiser.zero_grad()  # Zero out the gradients
        
        # Forward pass
        outputs = model(images)
        
        # Calculate loss
        loss = criterion(outputs, masks)
        
        # Backward pass
        loss.backward()
        
        # Update weights
        optimiser.step()
        
        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader)}")

Epoch [1/10], Loss: 0.3031165373445211
Epoch [2/10], Loss: 0.14142448758839254
Epoch [3/10], Loss: 0.08348095979561972
Epoch [4/10], Loss: 0.05804115713623368
Epoch [5/10], Loss: 0.045095034434002104
Epoch [6/10], Loss: 0.03731173060330438
Epoch [7/10], Loss: 0.03338015533581894
Epoch [8/10], Loss: 0.028566366440396913
Epoch [9/10], Loss: 0.025337434817323465
Epoch [10/10], Loss: 0.02500653194290819
