In [18]:
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights

In [3]:
# Root folder of the dataset, given relative to the notebook's working directory.
# CityscapesDataset joins it with "<split>/img" and "<split>/label".
base_dir = "CityscapesDataset"

In [4]:
# Create custom dataset class
class CityscapesDataset(Dataset):
    """Paired image / label dataset read from two sibling folders.

    Images are read from ``<root_dir>/<split>/img`` and labels from
    ``<root_dir>/<split>/label``. The two folders are paired by position in
    their sorted directory listings, so matching files must sort into the
    same position in both folders.

    Args:
        root_dir: Folder that contains one sub-folder per split.
        split: Name of the split sub-folder (default ``'train'``).
        transform: Optional callable applied to the RGB PIL image.
        target_transform: Optional callable applied to the PIL label image.

    Raises:
        ValueError: If the image and label folders hold a different number
            of files (index-based pairing would be wrong or fail later).
    """

    def __init__(self, root_dir, split='train', transform=None, target_transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        # Paths for images and labels
        self.img_dir = os.path.join(root_dir, split, "img")
        self.label_dir = os.path.join(root_dir, split, "label")

        # Sorted listings; entry i of one list is paired with entry i of the other
        self.img_filenames = sorted(os.listdir(self.img_dir))
        self.label_filenames = sorted(os.listdir(self.label_dir))

        # Fail early instead of silently mis-pairing (or hitting an IndexError mid-epoch)
        if len(self.img_filenames) != len(self.label_filenames):
            raise ValueError(
                f"Found {len(self.img_filenames)} files in {self.img_dir} but "
                f"{len(self.label_filenames)} files in {self.label_dir}; "
                "images and labels are paired by sorted position and must match one-to-one."
            )

    def __len__(self):
        """Return the number of image / label pairs."""
        return len(self.img_filenames)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(image, label)`` after the optional transforms.

        When both transforms are set, the label transform is run from the same
        torch RNG state the image transform started from. Random geometric ops
        that draw from the global torch RNG (e.g. ``RandomHorizontalFlip``)
        therefore make the same decision for image and label, provided they
        are the first random ops, in the same order, in both pipelines.
        """
        img_path = os.path.join(self.img_dir, self.img_filenames[idx])
        label_path = os.path.join(self.label_dir, self.label_filenames[idx])

        # Load image and label
        image = Image.open(img_path).convert("RGB")
        label = Image.open(label_path)

        # Apply transforms
        if self.transform and self.target_transform:
            rng_before = torch.get_rng_state()
            image = self.transform(image)
            rng_after = torch.get_rng_state()

            # Replay the random stream for the label so shared random ops agree ...
            torch.set_rng_state(rng_before)
            try:
                label = self.target_transform(label)
            finally:
                # ... then continue from where the image pipeline left off, so the
                # global RNG advances exactly as if only the image had been transformed.
                torch.set_rng_state(rng_after)
        elif self.transform:
            image = self.transform(image)
        elif self.target_transform:
            label = self.target_transform(label)

        return image, label

In [10]:
# Define transformations. Since the images are pretty small (256x96), random cropping would remove
# too much valuable information. If blurring the image, a smaller kernel should be applied. Color jitter
# should be applied conservatively to avoid excessive distortion. Scaling also should be limited since
# the images are small.

# NOTE(review): transforms.Resize takes (height, width), so this produces outputs 256 tall and
# 96 wide. Confirm that this is the intended orientation of the "256x96" images.
IMAGE_SIZE = (256, 96)

# Channel statistics passed to Normalize for the RGB inputs
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Labels are resized with nearest-neighbour sampling so that no new, blended label values
# are invented along region borders (Resize defaults to bilinear interpolation).
LABEL_INTERPOLATION = transforms.InterpolationMode.NEAREST

train_image_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),  # Mild color jitter
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)
])

train_label_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, interpolation=LABEL_INTERPOLATION),
    # RandomHorizontalFlip draws a fresh random number on every call, so this flip matches the
    # image's flip only if both pipelines are run from the same torch RNG state for a given sample.
    transforms.RandomHorizontalFlip(p=0.5),
    # NOTE(review): ToTensor rescales 8-bit pixel values to [0, 1]; if the labels are class-id
    # masks, the consumer has to account for that scaling - confirm.
    transforms.ToTensor()
])

val_image_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)
])

val_label_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, interpolation=LABEL_INTERPOLATION),
    transforms.ToTensor()
])

In [11]:
# Build one CityscapesDataset per split; both read from the same root folder,
# each with its own image / label transform pair
train_dataset = CityscapesDataset(
    base_dir, "train",
    transform=train_image_transform, target_transform=train_label_transform,
)

val_dataset = CityscapesDataset(
    base_dir, "val",
    transform=val_image_transform, target_transform=val_label_transform,
)

In [12]:
# Wrap both datasets in batch loaders; only the training loader shuffles its samples
BATCH_SIZE = 8
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [19]:
# Load the DeepLabV3 model with the ResNet50 backbone.
# Uses the factory function imported at the top of the notebook; the pretrained
# weights selected by DEFAULT are downloaded to the torch hub cache on first use.
model = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT)

Downloading: "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth" to /home/abey/.cache/torch/hub/checkpoints/deeplabv3_resnet50_coco-cd0a2569.pth
100%|████████████████████████████████████████| 161M/161M [00:29<00:00, 5.63MB/s]
