# Drone Image Segmentation Using SegNet

In [None]:
!pip install torchmetrics

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms import v2
from torchvision import tv_tensors
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import InterpolationMode

from sklearn.model_selection import train_test_split

import os
import time
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt


# Dataset Class

In [None]:
class DroneImagesSegmentationDataset(Dataset):
    """Paired drone image / segmentation-mask dataset.

    Sample ``i`` is read from ``images_dir/filenames[i]`` and
    ``masks_dir/filenames[i]``, so an image and its mask must share the same
    file name.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the masks.
        filenames: File names (relative to both directories) forming this split.
        joint_transforms: Optional callable invoked as
            ``joint_transforms(image, mask)`` so that random spatial
            augmentations are applied identically to both.
        image_transforms: Optional callable applied to the image only
            (e.g. colour jitter, normalization).

    Each item is a tuple ``(image, mask)``. ``mask`` is a float32 tensor holding
    1.0 wherever the (transformed) mask value is non-zero and 0.0 elsewhere.
    """

    def __init__(self, images_dir, masks_dir, filenames, joint_transforms=None, image_transforms=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        # Copy so later changes to the caller's sequence cannot alter this split.
        self.filenames = list(filenames)
        self.joint_transforms = joint_transforms
        self.image_transforms = image_transforms

    def __len__(self):
        # One sample per paired file name (not the length of a directory path).
        return len(self.filenames)

    def __getitem__(self, index):
        filename = self.filenames[index]
        image_path = os.path.join(self.images_dir, filename)
        mask_path = os.path.join(self.masks_dir, filename)

        # Wrapping in tv_tensors lets torchvision v2 transforms recognise which
        # input is the image and which is the mask; the `with` blocks release
        # the file handles once the pixels have been copied into tensors.
        with Image.open(image_path) as pil_image:
            image = tv_tensors.Image(pil_image.convert("RGB"))
        with Image.open(mask_path) as pil_mask:
            # NOTE(review): the mask is read as 3-channel RGB, so the binary
            # target below has 3 channels; confirm this matches the model's
            # output channels (a single-channel "L" mask is usual for binary
            # segmentation).
            mask = tv_tensors.Mask(pil_mask.convert("RGB"))

        if self.joint_transforms is not None:
            image, mask = self.joint_transforms(image, mask)

        if self.image_transforms is not None:
            image = self.image_transforms(image)

        # Binarize: any non-zero mask value counts as foreground.
        mask = (mask > 0).to(torch.float32)

        return image, mask

# Helper Functions

In [None]:
def get_transforms(is_train=True, image_size=(256, 256)):
    """Build the joint (image + mask) and image-only transform pipelines.

    Args:
        is_train: When True, the joint pipeline adds random horizontal/vertical
            flips, a random rotation of up to 45 degrees and a random resized
            crop (scale 0.8-1.0), and the image-only pipeline adds colour
            jitter. When False, samples are only resized, converted and
            normalized.
        image_size: Output ``(height, width)``; a single int means a square
            output. Defaults to 256x256.

    Returns:
        Tuple ``(joint_transforms, image_transforms)`` of ``v2.Compose``
        pipelines. ``joint_transforms`` must be called with the image and mask
        together so random spatial ops are applied identically to both;
        ``image_transforms`` is applied to the image alone.
    """
    if isinstance(image_size, int):
        image_size = (image_size, image_size)

    # Nearest-neighbour interpolation keeps mask values discrete.
    # NOTE(review): torchvision v2 appears to resample tv_tensors.Mask inputs
    # with nearest-neighbour regardless of this setting, so images could use
    # BILINEAR for smoother results - confirm masks are always wrapped as
    # tv_tensors.Mask before changing it.
    interpolation = InterpolationMode.NEAREST

    joint_steps = [v2.Resize(image_size, interpolation=interpolation)]
    image_steps = []

    if is_train:
        # Joint spatial augmentation.
        joint_steps += [
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomVerticalFlip(p=0.5),
            v2.RandomRotation(degrees=45, interpolation=interpolation),
            v2.RandomResizedCrop(size=image_size, scale=(0.8, 1.0), interpolation=interpolation),
        ]
        # Photometric augmentation must not touch the mask.
        image_steps.append(v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2))

    # ToImage is the v2 replacement for ToTensor; ToDtype(scale=True) rescales
    # integer pixel values into float32 [0, 1].
    joint_steps += [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]

    # Normalization with the commonly used ImageNet channel statistics runs
    # last, after any colour jitter.
    image_steps.append(
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    )

    return v2.Compose(joint_steps), v2.Compose(image_steps)

In [None]:
def _visible_files(directory):
    """Return the names of regular, non-hidden files directly inside *directory*."""
    return {
        name
        for name in os.listdir(directory)
        # Skip dot-entries such as .DS_Store or Jupyter's .ipynb_checkpoints folder.
        if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
    }


def get_loaders(images_dir, masks_dir, batch_size=8, *, num_workers=2, seed=42):
    """Split the paired files 80/10/10 and return train/val/test DataLoaders.

    Images and masks are paired by identical file name. The file names are
    split before any dataset is built, so no file appears in two splits.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the masks.
        batch_size: Batch size shared by all three loaders.
        num_workers: Worker processes per DataLoader (default 2).
        seed: ``random_state`` for both ``train_test_split`` calls, making the
            split reproducible (default 42).

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles and uses the augmenting transforms.

    Raises:
        ValueError: If no file name appears in both directories.
    """
    paired_filenames = sorted(_visible_files(images_dir) & _visible_files(masks_dir))
    if not paired_filenames:
        raise ValueError(
            f"No matching filenames between {images_dir!r} and {masks_dir!r}; "
            "images and masks must share the same file name (including extension)."
        )

    # 80% train; the remaining 20% is halved into validation and test.
    train_files, val_test_files = train_test_split(
        paired_filenames, test_size=0.2, random_state=seed
    )
    val_files, test_files = train_test_split(
        val_test_files, test_size=0.5, random_state=seed
    )

    train_joint_transforms, train_image_transforms = get_transforms(is_train=True)
    eval_joint_transforms, eval_image_transforms = get_transforms(is_train=False)

    # Pinned host memory only helps host-to-CUDA copies.
    pin_memory = torch.cuda.is_available()

    def make_loader(files, joint_transforms, image_transforms, shuffle):
        # Build one split's dataset and wrap it in a DataLoader.
        dataset = DroneImagesSegmentationDataset(
            images_dir, masks_dir, files, joint_transforms, image_transforms
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

    train_loader = make_loader(train_files, train_joint_transforms, train_image_transforms, shuffle=True)
    val_loader = make_loader(val_files, eval_joint_transforms, eval_image_transforms, shuffle=False)
    test_loader = make_loader(test_files, eval_joint_transforms, eval_image_transforms, shuffle=False)

    return train_loader, val_loader, test_loader

# Model Architecture