##### Install Dependencies

In [3]:
!pip install opencv-python torch torchvision torchmetrics albumentations matplotlib



##### Required dependencies

In [1]:
import os
import cv2
import torch
from typing import List, Tuple
from torch.utils.data import Dataset
import numpy as np
import torch
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2  # np.array -> torch.tensor
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchmetrics.segmentation import GeneralizedDiceScore, MeanIoU

##### Dataset preparation

In [14]:
# Download and prepare the dataset
import urllib.request
import tarfile

os.makedirs("/content/data", exist_ok=True)

# Download Oxford-IIIT Pet Dataset
url = "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz"
urllib.request.urlretrieve(url, "/content/data/images.tar.gz")

url_annotations = "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz"
urllib.request.urlretrieve(url_annotations, "/content/data/annotations.tar.gz")

with tarfile.open("/content/data/images.tar.gz", "r:gz") as tar:
    tar.extractall("/content/data")

with tarfile.open("/content/data/annotations.tar.gz", "r:gz") as tar:
    tar.extractall("/content/data")

root = "/content/data"
saved_directory = "/content/saved"

  tar.extractall("/content/data")
  tar.extractall("/content/data")


In [19]:
!ls /content/data

annotations  annotations.tar.gz  images  images.tar.gz


In [11]:
class OxfordIIIPetDataset(Dataset):
    def __init__(
        self,
        root: str,
        is_train: bool = True,
        transform=None,
    ):
        self.root = root
        self.transform = transform
        self.classes = ["background", "animal"]
        self.image_names: List[str] = []

        if is_train:
            annotations = os.path.join(root, "annotations", "trainval.txt")
        else:
            annotations = os.path.join(root, "annotations", "test.txt")

        # Read the annotation file and extract image names
        with open(annotations, "r") as f:
            self.image_names = [image.split(' ')[0] for image in f.readlines()]

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, item) -> Tuple[torch.Tensor, torch.Tensor]:
        image_name = self.image_names[item]
        image_path = os.path.join(self.root, "images", image_name + ".jpg")
        mask_path = os.path.join(self.root, "annotations", "trimaps", image_name + ".png")

        # Read the image and convert it from BGR to RGB
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # type: ignore

        # Read the mask and adjust its values
        # 0.299 x Red + 0.587 x Green + 0.114 x Blue
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask[mask == 2] = 0  # type: ignore
        mask[mask == 3] = 1  # type: ignore

        # Apply transformations if provided
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]

        return image, mask  # type: ignore

##### Build U-Net model

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DualConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mid_channels: int | None = None,
    ):
        super(DualConv, self).__init__()
        if not mid_channels:
            mid_channels = out_channels

        self.sequential = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.sequential(x)


class Down(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super(Down, self).__init__()
        self.sequential = nn.Sequential(
            nn.MaxPool2d(2),
            DualConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.sequential(x)


class Up(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True):
        super(Up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DualConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(
                in_channels,
                in_channels // 2,
                kernel_size=2,
                stride=2,
            )
            self.conv = DualConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        x1: from the previous layer - decoder
        x2: from the skip connection - encoder
        """
        x1 = self.up(x1)
        diffY = x2.size()[2] - x1.size()[2]  # height
        diffX = x2.size()[3] - x1.size()[3]  # width

        # pad function: (L, R, T, B)
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

        # Concatenate along the channels axis
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class UNetBaseline(nn.Module):
    def __init__(self, in_channels: int, num_classes: int):
        super(UNetBaseline, self).__init__()

        # Encoder
        self.inc = DualConv(in_channels, 32)
        self.down1 = Down(32, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)

        # Bottleneck
        self.down4 = Down(256, 512)

        # Decoder
        self.up1 = Up(512, 256, bilinear=False)
        self.up2 = Up(256, 128, bilinear=False)
        self.up3 = Up(128, 64, bilinear=False)
        self.up4 = Up(64, 32, bilinear=False)

        # Output layer
        self.outc = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder with skip connections
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)  # Bottleneck

        # Decoder with skip connections
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)

        return x

In [13]:
LEARNING_RATE = 0.0001
BATCH_SIZE = 10
EPOCHS = 50
NUM_WORKERS = 4

ROOT = "/Users/hinsun/Workspace/ComputerScience/UNetWithBraTS"
ROOT_DATASET = ROOT + "/data/OxfordIIITPet/oxford-iiit-pet"
MODEL_CHECK_POINT = ROOT + "/checkpoints/oxford_iiit_pet"
if not os.path.exists(MODEL_CHECK_POINT):
    os.makedirs(MODEL_CHECK_POINT)

model_path = MODEL_CHECK_POINT

train_transform = A.Compose([
    A.Resize(width=224, height=224),
    A.HorizontalFlip(),
    A.RandomBrightnessContrast(),
    A.Blur(),
    A.Sharpen(),
    A.RGBShift(),
    ToTensorV2(),
])

train_dataset = OxfordIIIPetDataset(
    root=ROOT_DATASET,
    is_train=True,
    transform=train_transform
)

val_dataset = OxfordIIIPetDataset(
    root=ROOT_DATASET,
    is_train=False,
    transform=train_transform
)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
    drop_last=True,
)

val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    drop_last=True,
)

##### Training the model

In [14]:
model = UNetBaseline(in_channels=3, num_classes=1)
device = torch.device("mps" if torch.mps.is_available() else "cpu")

# Initialize optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()

# Metrics
miou_metric = MeanIoU(num_classes=2)
dice_metric = GeneralizedDiceScore(num_classes=2)

# Best validation IoU for saving the best model
best_predict = -1
current_epoch = 0

# Training loop
for epoch in range(EPOCHS):
    # Training Phase
    model.train()
    train_progress = tqdm(train_dataloader, colour="cyan")

    for idx, img_mask in enumerate(train_progress):
        # B, C, H, W
        img = img_mask[0].float().to(device)  # type: ignore
        # B, H, W
        mask = img_mask[1].float().to(device)

        y_pred = model(img)  # B, 1, H, W
        y_pred = y_pred.squeeze()  # B, H, W
        optimizer.zero_grad()

        # Calculate Loss
        loss = criterion(y_pred, mask)

        # Backpropagation
        loss.backward()
        optimizer.step()
        train_progress.set_description("TRAIN| Epoch: {}/{}| Loss: {:0.4f}".format(epoch, EPOCHS, loss))

    # Validation Phase
    model.eval()

    all_losses = []
    all_ious = []
    all_dices = []

    with torch.no_grad():
        for idx, img_mask in enumerate(val_dataloader):
            img = img_mask[0].float().to(device)  # type: ignore
            mask = img_mask[1].float().to(device)  # B W H

            y_pred = model(img)
            y_pred = y_pred.squeeze()  # B H W

            loss = criterion(y_pred, mask)

            mask = mask.long().cpu()
            y_pred[y_pred > 0] = 1  # BWH
            y_pred[y_pred < 0] = 0  # BWH
            y_pred = y_pred.long().cpu()

            miou = miou_metric(y_pred, mask)
            dice = dice_metric(y_pred, mask)

            all_losses.append(loss.cpu().item())
            all_ious.append(miou.cpu().item())
            all_dices.append(dice.cpu().item())

            if idx == 40: break

    # Compute mean IoU for the epoch
    loss = np.mean(all_losses)
    miou = np.mean(all_ious)
    dice = np.mean(all_dices)

    print("VAL| Loss: {:0.4f} | mIOU: {:0.4f} | Dice: {:0.4f}".format(loss, miou, dice))

    checkpoint = {
        "model_state_dict": model.state_dict(),
        "epoch": epoch,
        "optimizer_state_dict": optimizer.state_dict(),
        "miou": miou
    }

    # Save Last Checkpoint
    torch.save(checkpoint, os.path.join(model_path, "last.h5"))

    # Save best checkpoint based on IoU
    if miou > best_predict:
        torch.save(checkpoint, os.path.join(model_path, "best.pth"))
        best_predict = miou

  0%|[36m          [0m| 0/368 [00:00<?, ?it/s]Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File [35m"<string>"[0m, line [35m1[0m, in [35m<module>[0m
    from multiprocessing.spawn import spawn_main; [31mspawn_main[0m[1;31m(tracker_fd=95, pipe_handle=119)[0m
                                                  [31m~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/Users/hinsun/Workspace/ComputerScience/UNetWithBraTS/.venv/lib/python3.13/multiprocessing/spawn.py"[0m, line [35m122[0m, in [35mspawn_main[0m
    exitcode = _main(fd, parent_sentinel)
  File [35m"/Users/hinsun/Workspace/ComputerScience/UNetWithBraTS/.venv/lib/python3.13/multiprocessing/spawn.py"[0m, line [35m132[0m, in [35m_main[0m
    self = reduction.pickle.load(from_parent)
[1;35mAttributeError[0m: [35mCan't get attribute 'OxfordIIIPetDataset' on <module '__main__' (<class '_frozen_importlib.BuiltinImporter'>)>[0m
  F

RuntimeError: DataLoader worker (pid(s) 12704, 12705) exited unexpectedly