# CREATE A COPY FIRST!

# **Question 3: Multi-Class Segmentation for Animal Parts**

In this question, you should fine-tune a **pretrained U-Net** model for **multi-class segmentation** of animal parts. Your segmentation model will classify each pixel into **one of 5 classes** (Tail, Body, Legs, Head, Background).

Complete the code cells below.

---

In [1]:
# Colab setup: keep KaggleHub's cache in a directory inside /content/.
import os

# KaggleHub reads its cache location from the KAGGLEHUB_CACHE environment variable.
os.environ.update({"KAGGLEHUB_CACHE": "/content/data"})

In [2]:
import kagglehub

# Fetch the latest version of the dataset; the returned value is the local
# path where the dataset files were placed.
_DATASET_HANDLE = "mohammad2012191/segmentation"
path = kagglehub.dataset_download(_DATASET_HANDLE)

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/mohammad2012191/segmentation?dataset_version_number=1...


100%|██████████| 192M/192M [00:09<00:00, 21.0MB/s]

Extracting files...





Path to dataset files: /content/data/datasets/mohammad2012191/segmentation/versions/1


## **TASK 1: Dataset Class**
- Build a custom dataset class to load images and masks.
- Prepare your train and validation dataloaders.
- Display some images and their corresponding masks.



In [37]:

import os
import glob
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
import random


class SEG(Dataset):
    """Paired image/mask dataset for one split of a segmentation dataset.

    Images are discovered under ``<root_dir>/<split>/images/*.jpg`` and masks
    under ``<root_dir>/<split>/masks/*.png``. Both listings are sorted so that
    index ``i`` refers to the same position in each list.
    NOTE(review): this pairing assumes an image and its mask share the same
    file stem -- confirm against the dataset layout.

    Args:
        root_dir: Main dataset directory.
        split: Name of the split sub-directory (e.g. ``"train"`` or ``"val"``).
        transform: Optional callable applied to the image.
        target_transform: Optional callable applied to the mask.
    """

    def __init__(self, root_dir, split, transform=None, target_transform=None):
        self.split = split
        self.root_dir = root_dir  # main dataset path
        self.transform = transform  # image transformations
        self.target_transform = target_transform  # mask transformations
        # glob returns paths in arbitrary order, so without sorting the image
        # at index i and the mask at index i need not belong together.
        self.masks = sorted(glob.glob(f"{self.root_dir}/{self.split}/masks/*.png"))
        self.image = sorted(glob.glob(f"{self.root_dir}/{self.split}/images/*.jpg"))

    def __len__(self):
        """Return the number of images found for this split."""
        return len(self.image)

    def __getitem__(self, idx):
        """Load and return the ``(image, mask)`` pair at position ``idx``."""
        # Force three channels so a grayscale or CMYK JPEG cannot break a
        # 3-channel image transform downstream.
        image = Image.open(self.image[idx]).convert("RGB")
        # The mask is left in its stored mode so its pixel values are untouched.
        mask = Image.open(self.masks[idx])

        # Apply transformations (if any)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)

        return image, mask





# Image pipeline: resize, convert to a float tensor, then normalise each
# channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize([128, 128]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Mask pipeline: a mask stores class labels, not intensities, so it must not
# be interpolated, rescaled, or normalised.
#   * NEAREST resizing never invents in-between label values.
#   * PILToTensor keeps the raw integer pixel values (ToTensor would divide
#     them by 255, and a later cast to long would collapse every label).
# NOTE(review): assumes mask pixel values are the class indices 0..4 -- confirm.
target_transform = transforms.Compose([
    transforms.Resize(
        [128, 128],
        interpolation=transforms.InterpolationMode.NEAREST,
    ),
    transforms.PILToTensor(),
])

train_dataset = SEG(path, "train", transform=transform, target_transform=target_transform)
test_dataset = SEG(path, "val", transform=transform, target_transform=target_transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=32)



## **TASK 2: Model Class**
- **Use a pretrained U-Net** (from `segmentation_models_pytorch`) with "efficientnet-b0" as an encoder.

In [41]:
import torch
import segmentation_models_pytorch as smp

# Define U-Net Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = smp.Unet(
    encoder_name="efficientnet-b0",  # Pretrained encoder (backbone)
    encoder_weights="imagenet",  # Use ImageNet weights
    in_channels=3,  # RGB images
    classes=5,  # One output channel per class (Tail, Body, Legs, Head, Background)
    # Emit raw logits. nn.CrossEntropyLoss applies log-softmax internally and
    # the Dice loss/metric in this notebook call F.softmax themselves, so a
    # softmax inside the model would be applied twice.
    activation=None,
).to(device)

In [7]:
# Install segmentation_models_pytorch quietly (-q).
# NOTE(review): this cell's execution count is lower than that of the cell
# importing the package, so it presumably ran first -- keep it ahead of that
# import when re-running the notebook top to bottom.
!pip install -q segmentation_models_pytorch

  Preparing metadata (setup.py) ... [?25l[?25hdone
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/58.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.3/121.3 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m108.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m83.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m53.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## **TASK 3: Training and Validation Loops**
- Define the training and validation loops.

In [20]:
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

# 🔹 Training Loop
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Network called as ``model(images)``.
            NOTE(review): assumed to return per-class scores shaped
            [N, C, H, W] -- confirm.
        dataloader: Iterable of ``(images, masks)`` batches that supports ``len``.
        criterion: Callable ``criterion(outputs, masks)`` returning a scalar
            loss; it receives integer class-index masks of shape [N, H, W].
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        The mean of the per-batch losses.
    """
    model.train()
    total_loss = 0

    for images, masks in tqdm(dataloader):
        images = images.to(device)
        masks = masks.to(device)
        # Drop the channel axis of [N, 1, H, W] masks; [N, H, W] masks pass through.
        if masks.ndim == 4:
            masks = masks.squeeze(dim=1)
        # Class-index targets must be integer (long) tensors. Casting them to
        # float makes CrossEntropyLoss treat them as class probabilities of the
        # wrong shape and makes F.one_hot reject them.
        masks = masks.to(torch.long)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

# 🔹 Validation Loop
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Args:
        model: Network called as ``model(images)``.
            NOTE(review): assumed to return per-class scores shaped
            [N, C, H, W] -- confirm.
        dataloader: Iterable of ``(images, masks)`` batches that supports ``len``.
        criterion: Callable ``criterion(outputs, masks)`` returning a scalar
            loss; it receives integer class-index masks of shape [N, H, W].
        device: Device the batches are moved to.

    Returns:
        The mean of the per-batch losses.
    """
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)
            # Drop the channel axis of [N, 1, H, W] masks; [N, H, W] masks pass through.
            if masks.ndim == 4:
                masks = masks.squeeze(dim=1)
            # Class-index targets must be integer (long) tensors. Casting them
            # to float makes CrossEntropyLoss treat them as class probabilities
            # of the wrong shape and makes F.one_hot reject them.
            masks = masks.to(torch.long)

            outputs = model(images)
            loss = criterion(outputs, masks)
            total_loss += loss.item()

    return total_loss / len(dataloader)

## **TASK 4: Running Training**
- Define the loss and the optimizer.
- Train the model for 10 epochs.
- Print the training and validation losses.
- Plot the loss curve.

In [40]:
import torch
from torch import nn

# Define loss function and optimizer.
# The loss must be an *instance*: with the bare class, `criterion(outputs, masks)`
# would construct a new CrossEntropyLoss module instead of computing a loss.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001)

num_epochs = 10  # The task statement asks for 10 epochs
train_losses = []
val_losses = []

# Training Loop
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss = validate(model, test_loader, criterion, device)

    # Keep the per-epoch history so the loss curve can be plotted afterwards.
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print(f"Epoch {epoch+1}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")


  0%|          | 0/55 [00:00<?, ?it/s]


RuntimeError: Given groups=1, weight of size [32, 1, 3, 3], expected input[32, 3, 129, 129] to have 1 channels, but got 3 channels instead

## **TASK 5: Visualizing Predictions**
- Visualize your model's predictions against the ground truth for several images.

# **BONUS Task: Using Dice Loss & Dice Coefficient**

**Dice loss** and **Dice coefficient** are widely used metrics for evaluating segmentation models. We typically use **Dice loss** during training because it is **differentiable**, and then calculate the **Dice coefficient** as a metric to measure performance—similar to how we use cross-entropy loss for training and accuracy for evaluation.


## **Your Bonus Tasks:**
- **Retrain** your previously built segmentation model using the **Dice loss** provided below.
- **Modify the validation function** to evaluate your retrained model using the **Dice coefficient** metric provided below.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Dice Loss for multiclass segmentation (Lower is better)
class DiceLoss(nn.Module):
    """Multiclass Dice loss: one minus the mean soft Dice score.

    Args:
        smooth: Constant added to numerator and denominator of the Dice ratio.
    """

    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return ``1 - mean Dice score`` for a batch.

        ``inputs`` are per-class scores with classes on dim 1; softmax is
        applied here. ``targets`` are integer class indices that become a
        4-D tensor once one-hot encoded (i.e. a 3-D index tensor).
        """
        probs = F.softmax(inputs, dim=1)
        n_classes = probs.shape[1]
        # one_hot appends the class axis last; move it next to the batch axis.
        one_hot = F.one_hot(targets, num_classes=n_classes)
        one_hot = one_hot.permute(0, 3, 1, 2).float()

        reduce_dims = (2, 3)
        overlap = torch.sum(probs * one_hot, dim=reduce_dims)
        denominator = torch.sum(probs, dim=reduce_dims) + torch.sum(one_hot, dim=reduce_dims)

        per_class_dice = (2 * overlap + self.smooth) / (denominator + self.smooth)
        return 1 - per_class_dice.mean()

# Dice Coefficient Metric (Higher is better)
def dice_coefficient(inputs, targets, smooth=1e-5):
    """Return the mean soft Dice score of a batch as a Python float.

    ``inputs`` are per-class scores with classes on dim 1; softmax is applied
    here. ``targets`` are integer class indices; a size-1 dim 1 is squeezed
    away before one-hot encoding. ``smooth`` is added to the numerator and
    denominator of the Dice ratio.
    """
    probs = F.softmax(inputs, dim=1)
    labels = torch.squeeze(targets, 1)
    # one_hot appends the class axis last; move it next to the batch axis.
    one_hot = F.one_hot(labels, num_classes=probs.shape[1])
    one_hot = one_hot.permute(0, 3, 1, 2).float()

    reduce_dims = (2, 3)
    overlap = torch.sum(probs * one_hot, dim=reduce_dims)
    denominator = torch.sum(probs, dim=reduce_dims) + torch.sum(one_hot, dim=reduce_dims)

    scores = (2 * overlap + smooth) / (denominator + smooth)
    return scores.mean().item()


In [None]:
# TO DO