# Drone Images Segmentation Using SegNet

In [None]:
!pip install torchmetrics

In [None]:
import torch
import torchmetrics
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms import v2
from torchvision import tv_tensors
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import InterpolationMode

from sklearn.model_selection import train_test_split

import os
import time
import random
import numpy as np
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt


# Reproducibility Settings

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print("Computation Device: ", device)

In [None]:
# Single seed shared by every random number generator used in this notebook.
SEED = 42

random.seed(SEED)        # Python's built-in RNG
np.random.seed(SEED)     # NumPy's global RNG
torch.manual_seed(SEED)  # PyTorch's default generator

# Seed every visible CUDA device (this also covers the current device, so a
# separate torch.cuda.manual_seed call is not needed).
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Ask cuDNN for deterministic kernels. Benchmark mode must stay OFF here:
# when enabled, cuDNN auto-tunes and may select a different algorithm from
# run to run, which defeats the reproducibility settings above.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Dataset Class

In [None]:
class DroneImagesSegmentationDataset(Dataset):
    """Dataset of drone images paired with remapped segmentation masks.

    Every sample is read from ``images_dir`` and ``masks_dir`` using the same
    file name. Raw mask pixel values are translated to contiguous class
    indices through a lookup table before any transform is applied.

    Args:
        images_dir: Directory that holds the input images.
        masks_dir: Directory that holds the masks (same file names as images).
        filenames: File names that make up this dataset split.
        joint_transforms: Optional callable applied to ``(image, mask)`` together.
        image_transforms: Optional callable applied to the image only.
    """

    def __init__(self, images_dir, masks_dir, filenames, joint_transforms=None, image_transforms=None):
        self.images_dir, self.masks_dir = images_dir, masks_dir
        self.filenames = filenames
        self.joint_transforms = joint_transforms
        self.image_transforms = image_transforms

        # Raw mask value -> contiguous class index (position in this list).
        # Raw values in 0..22 that are not listed keep the table's initial 0.
        original_classes = [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 19, 20, 21, 22]
        self.lookup_table = np.zeros(23, dtype=np.int64)
        self.lookup_table[original_classes] = np.arange(len(original_classes))

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load, remap and transform the sample at ``index``.

        Returns:
            Tuple ``(image, mask)``; the mask is cast to ``torch.long`` and a
            leading singleton channel dimension, if present, is removed.
        """
        filename = self.filenames[index]

        image = Image.open(os.path.join(self.images_dir, filename)).convert("RGB")
        raw_mask = Image.open(os.path.join(self.masks_dir, filename)).convert("L")

        # Translate raw pixel values to class indices before building the tensor
        class_indices = self.lookup_table[np.array(raw_mask)]

        image = tv_tensors.Image(image)
        mask = tv_tensors.Mask(class_indices)

        if self.joint_transforms:
            image, mask = self.joint_transforms(image, mask)

        if self.image_transforms:
            image = self.image_transforms(image)

        mask = mask.to(torch.long)

        # Cross Entropy expects targets without a channel dimension
        has_singleton_channel = mask.ndim == 3 and mask.shape[0] == 1
        return image, (mask.squeeze(0) if has_singleton_channel else mask)

# Helper Functions

In [None]:
def get_transforms(is_train=True, image_size=(256, 256)):
    """Build the joint (image + mask) and image-only transform pipelines.

    The joint pipeline is meant to be called as ``joint(image, mask)`` with the
    image wrapped as ``tv_tensors.Image`` and the mask as ``tv_tensors.Mask``.
    In torchvision transforms v2 the ``interpolation`` argument only affects
    images; masks are always resampled with nearest-neighbour so class indices
    are never blended. The image therefore uses bilinear interpolation here,
    since nearest-neighbour sampling of the RGB image only adds aliasing.

    Args:
        is_train: When True, add random flips, rotation, resized crop and
            colour jitter; otherwise only resize, convert and normalise.
        image_size: ``(height, width)`` of the produced samples.

    Returns:
        Tuple ``(joint_transforms, image_transforms)``.
    """
    # Interpolation applied to the image only (see docstring).
    image_interpolation = InterpolationMode.BILINEAR

    # Channel statistics used for normalisation (the standard ImageNet values).
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    if is_train:
        # Joint Spatial Transforms For Both Images & Masks
        # NOTE(review): RandomRotation fills the exposed corners with 0, which
        # for the mask coincides with class index 0 -- confirm that is intended.
        joint_transforms = v2.Compose([
            v2.Resize(image_size, interpolation=image_interpolation),

            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomVerticalFlip(p=0.5),

            v2.RandomRotation(
                degrees=45,
                interpolation=image_interpolation
            ),

            v2.RandomResizedCrop(
                size=image_size,
                scale=(0.8, 1.0),
                interpolation=image_interpolation
            ),

            v2.ToImage(),  # Image to tensor (Modern ToTensor() Alternative)
            v2.ToDtype(torch.float32, scale=True)
        ])

        # Image Only Transforms
        image_transforms = v2.Compose([
            v2.ColorJitter(
                brightness=0.2,
                contrast=0.2,
                saturation=0.2
            ),
            v2.Normalize(mean=norm_mean, std=norm_std)
        ])
    else:
        # Deterministic pipeline for validation / test data
        joint_transforms = v2.Compose([
            v2.Resize(image_size, interpolation=image_interpolation),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ])

        image_transforms = v2.Compose([
            v2.Normalize(mean=norm_mean, std=norm_std)
        ])

    return joint_transforms, image_transforms

In [None]:
def get_loaders(images_dir, masks_dir, batch_size=8):
    """Create the train / validation / test DataLoaders.

    Only file names present in both ``images_dir`` and ``masks_dir`` are used.
    The sorted names are split 80 / 10 / 10 with a fixed ``random_state`` of
    42 before any dataset is created.

    Args:
        images_dir: Directory that holds the input images.
        masks_dir: Directory that holds the masks.
        batch_size: Batch size shared by all three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``; only the training
        loader shuffles its samples.
    """
    # (joint_transforms, image_transforms) pairs for each mode
    train_tfms = get_transforms(is_train=True)
    eval_tfms = get_transforms(is_train=False)

    # Keep only the names that exist as both an image and a mask
    image_names = set(os.listdir(images_dir))
    mask_names = set(os.listdir(masks_dir))
    common_names = sorted(image_names & mask_names)

    # 80% train, then the held-out 20% is halved into validation and test
    train_files, holdout_files = train_test_split(
        common_names, test_size=0.2, random_state=42
    )
    val_files, test_files = train_test_split(
        holdout_files, test_size=0.5, random_state=42
    )

    def build_loader(files, tfms, shuffle):
        # One dataset + loader per split; everything but the shuffle flag is shared
        joint_tfms, image_tfms = tfms
        dataset = DroneImagesSegmentationDataset(
            images_dir, masks_dir, files, joint_tfms, image_tfms
        )
        return DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, num_workers=2, pin_memory=True
        )

    train_loader = build_loader(train_files, train_tfms, True)
    val_loader = build_loader(val_files, eval_tfms, False)
    test_loader = build_loader(test_files, eval_tfms, False)

    return train_loader, val_loader, test_loader

In [None]:
def get_metrics(num_classes, device):
    """Build the collection of segmentation metrics and move it to ``device``.

    The collection holds, under these keys:
        ``"dice"``      - per-class Dice score (``average=None``).
        ``"miou"``      - IoU reported per class (``per_class=True``).
        ``"pixel_acc"`` - multiclass accuracy.

    Args:
        num_classes: Number of segmentation classes.
        device: Device the metric states are moved to.

    Returns:
        A ``torchmetrics.MetricCollection``.
    """
    # Import the subpackage explicitly instead of relying on
    # ``torchmetrics.segmentation`` being reachable as an attribute after a
    # plain ``import torchmetrics``.
    from torchmetrics.segmentation import DiceScore, MeanIoU

    # NOTE(review): a MetricCollection passes the same (preds, target) to every
    # metric. Dice / IoU are configured for one-hot inputs here -- confirm the
    # tensors supplied by the training loop are also valid for the multiclass
    # Accuracy metric, or switch the two segmentation metrics to "index".
    metrics = torchmetrics.MetricCollection({
        "dice": DiceScore(
            # The accepted spelling is "one-hot"; "one_hot" is rejected.
            num_classes=num_classes, include_background=True, average=None, input_format="one-hot"
        ),
        "miou": MeanIoU(
            num_classes=num_classes, include_background=True, per_class=True, input_format="one-hot"
        ),
        "pixel_acc": torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
    }).to(device)

    return metrics

In [None]:
# TODO: Loss Function

# Model Architecture

In [None]:
class Encoder(nn.Module):
    """Encoder stage: stacked Conv-BatchNorm-ReLU layers, then 2x2 max pooling.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by every convolution in the stage.
        num_conv_layers: Number of Conv-BatchNorm-ReLU layers before pooling.
    """

    def __init__(self, in_channels, out_channels, num_conv_layers):
        super().__init__()

        # Only the first convolution sees in_channels; later ones are
        # out_channels -> out_channels.
        conv_inputs = [in_channels if idx == 0 else out_channels for idx in range(num_conv_layers)]

        self.enc_block = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(n_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
            for n_in in conv_inputs
        ])

        # return_indices=True so the matching decoder can unpool with them
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    def forward(self, x):
        """Return ``(pooled_features, pooling_indices)`` for input ``x``."""
        features = x
        for conv_bn_relu in self.enc_block:
            features = conv_bn_relu(features)

        pooled, pool_indices = self.max_pool(features)
        return pooled, pool_indices

In [None]:
class Decoder(nn.Module):
    """Decoder stage: 2x2 max unpooling, then stacked Conv-BatchNorm-ReLU layers.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the last convolution of the stage.
        num_conv_layers: Number of Conv-BatchNorm-ReLU layers after unpooling.
    """

    def __init__(self, in_channels, out_channels, num_conv_layers):
        super().__init__()

        self.max_unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

        # Every layer except the last keeps in_channels, so each convolution
        # receives in_channels; only the final one projects to out_channels.
        last_idx = num_conv_layers - 1
        conv_outputs = [
            out_channels if idx == last_idx else in_channels
            for idx in range(num_conv_layers)
        ]

        self.dec_block = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, n_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(n_out),
                nn.ReLU(inplace=True),
            )
            for n_out in conv_outputs
        ])

    def forward(self, x, indices):
        """Unpool ``x`` with the encoder's pooling ``indices`` and convolve."""
        upsampled = self.max_unpool(x, indices)

        for conv_bn_relu in self.dec_block:
            upsampled = conv_bn_relu(upsampled)

        return upsampled

In [None]:
class SegNet(nn.Module):
    """SegNet-style encoder/decoder network for semantic segmentation.

    Five encoder stages (2, 2, 3, 3, 3 convolutions) each end in max pooling;
    five decoder stages undo the pooling with the stored indices. The final
    class scores come from a plain convolution with no BatchNorm or ReLU, so
    the network returns raw logits. Sending the scores through ReLU would
    clamp them to be non-negative, which distorts a softmax / cross-entropy
    objective.

    NOTE: parameter names of the last stage are ``dec1.*`` (64 -> 64) and
    ``classifier.*`` (64 -> num_classes).
    NOTE(review): five 2x2 poolings shrink the input by a factor of 32;
    heights/widths that are not multiples of 32 presumably break the
    unpooling shape agreement -- confirm against the data pipeline.

    Args:
        num_classes: Number of output channels (one logit map per class).
    """

    def __init__(self, num_classes):
        super().__init__()

        # Number of Conv2d layers per encoder block
        num_conv_layer_enc = [2, 2, 3, 3, 3]

        self.enc1 = Encoder(3, 64, num_conv_layer_enc[0])
        self.enc2 = Encoder(64, 128, num_conv_layer_enc[1])
        self.enc3 = Encoder(128, 256, num_conv_layer_enc[2])
        self.enc4 = Encoder(256, 512, num_conv_layer_enc[3])
        self.enc5 = Encoder(512, 512, num_conv_layer_enc[4])

        # Number of Conv-BatchNorm-ReLU layers per decoder block. The last
        # block has one such layer; its second convolution is the classifier.
        num_conv_layer_dec = [3, 3, 3, 2, 1]

        self.dec5 = Decoder(512, 512, num_conv_layer_dec[0])
        self.dec4 = Decoder(512, 256, num_conv_layer_dec[1])
        self.dec3 = Decoder(256, 128, num_conv_layer_dec[2])
        self.dec2 = Decoder(128, 64, num_conv_layer_dec[3])
        self.dec1 = Decoder(64, 64, num_conv_layer_dec[4])

        # Classification head: raw per-class logits (no BatchNorm / ReLU)
        self.classifier = nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return per-pixel class logits with ``num_classes`` channels."""
        # Encoder path: keep each stage's pooling indices for its decoder
        x, ind1 = self.enc1(x)
        x, ind2 = self.enc2(x)
        x, ind3 = self.enc3(x)
        x, ind4 = self.enc4(x)
        x, ind5 = self.enc5(x)

        # Decoder path: unpool in reverse order with the matching indices
        x = self.dec5(x, ind5)
        x = self.dec4(x, ind4)
        x = self.dec3(x, ind3)
        x = self.dec2(x, ind2)
        x = self.dec1(x, ind1)

        return self.classifier(x)
