# CNN Image Segmentation Tutorial

## Table of Contents
- [CNN Image Segmentation Tutorial](#cnn-image-segmentation-tutorial)
  - [Table of Contents](#table-of-contents)
  - [Introduction](#introduction)
    - [What is Image Segmentation?](#what-is-image-segmentation)
    - [Types of Segmentation](#types-of-segmentation)
    - [Real-world Applications](#real-world-applications)
  - [Prerequisites](#prerequisites)
    - [Required Libraries](#required-libraries)
    - [Dataset Overview](#dataset-overview)
  - [Data Preparation](#data-preparation)
    - [Dataset Organization](#dataset-organization)
    - [Custom Dataset Classes](#custom-dataset-classes)
    - [Data Preprocessing](#data-preprocessing)
    - [Data Augmentation](#data-augmentation)
    - [DataLoader Creation](#dataloader-creation)
  - [Model Architecture](#model-architecture)
    - [U-Net with MobileNetV2](#u-net-with-mobilenetv2)
      - [Architecture Components:](#architecture-components)
    - [Training Components](#training-components)
      - [Loss Function](#loss-function)
      - [Optimizer](#optimizer)
      - [Learning Rate Scheduler](#learning-rate-scheduler)
    - [Model Configuration](#model-configuration)
    - [Device Management](#device-management)
  - [Metrics (`metrics.py`)](#metrics-metricspy)
    - [Pixel Accuracy](#pixel-accuracy)
    - [Mean IoU (Intersection over Union)](#mean-iou-intersection-over-union)
  - [Training and Inference (`model.py`)](#training-and-inference-modelpy)
    - [Training Pipeline](#training-pipeline)
    - [Inference Pipeline](#inference-pipeline)
  - [Visualization (`viz.py`)](#visualization-vizpy)
    - [Training Progress Visualization](#training-progress-visualization)
    - [Prediction Visualization](#prediction-visualization)
  - [Downloading the Dataset](#downloading-the-dataset)
  - [Main Training Script](#main-training-script)
  - [Main Inference Script](#main-inference-script)
  - [Exercises](#exercises)

## Introduction

### What is Image Segmentation?
Image segmentation is a fundamental computer vision task that involves dividing an
image into multiple segments or regions, where each pixel in the image is assigned
a class label. Unlike classification, which provides a single label for an entire
image, or object detection, which identifies object locations with bounding boxes,
segmentation provides pixel-level understanding of the image content.

In this tutorial, we'll be working with semantic segmentation, where we assign
each pixel to a predefined class category. For example, in our drone imagery
dataset, pixels might belong to classes such as 'building', 'vegetation', 'road',
or 'vehicle'.

### Types of Segmentation
There are several types of image segmentation, each serving different purposes:

1. **Semantic Segmentation**
   - Assigns each pixel to a class category
   - Doesn't distinguish between instances of the same class
   - Example: All 'car' pixels get the same label, regardless of how many cars
     are present

2. **Instance Segmentation**
   - Identifies individual instances of objects
   - Distinguishes between different instances of the same class
   - Example: Each car in the image gets a unique instance ID

3. **Panoptic Segmentation**
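   - Applies semantic segmentation to uncountable background regions ("stuff", e.g. road or vegetation)
   - Applies instance segmentation to countable objects ("things", e.g. cars or people)
   - Gives every pixel a class label and, where applicable, an instance ID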

### Real-world Applications
Image segmentation has numerous practical applications across various fields:

1. **Aerial and Satellite Imagery**
   - Urban planning and development
   - Agricultural monitoring
   - Disaster response and damage assessment
   - Environmental monitoring

2. **Medical Imaging**
   - Tumor detection and measurement
   - Organ segmentation
   - Cell counting and analysis
   - Surgical planning

3. **Autonomous Vehicles**
   - Road and lane detection
   - Obstacle identification
   - Pedestrian segmentation
   - Traffic analysis

4. **Industrial Applications**
   - Quality control and inspection


In this tutorial, we'll focus on semantic segmentation of aerial drone imagery,
which has important applications in urban planning, mapping, and environmental
monitoring. Our implementation uses a U-Net architecture with a MobileNetV2
backbone, providing an efficient and effective solution for real-world
segmentation tasks.



### Dataset Overview

We'll be working with the Semantic Drone Dataset, which consists of:
- Aerial imagery captured by drones
- High-resolution RGB images (6000x4000px)
- Pixel-wise semantic segmentation masks
- 23 different class categories including:
  - Buildings
  - Roads
  - Vegetation
  - Vehicles
  - People
  - Other urban features

![Semantic Drone Dataset Example](https://www.kaggleusercontent.com/kf/55737032/eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..x_-exPHMZnlgjHP-2ibtug.GSbJPnhFB0nT3LkxG2ODuIkwkAO44uhpRSDSvSDWsX_JVxZML6hT1QcRVhiknaMFNUgzfshyJ0oLQdRlSP92On2e9Nzv8Zxo2BNJuWNvfINFHXb_1R1QJbVsK8HPJgF9tgEUOaudVM9fvJCDvC8Id2lw7vBiXGPdWRzF8r7xDDMzk-jX-vL5SVOBTXc4UgYltiwEbRatnykFZ8LEq03j2vb7hb5XhkgCXAPl0FOkrpJGuxKwKHSSzyviZqO3g9dP2_kUk4_bfzIbKn-hek3k3MNHxjZhEfrEmcNkgLBzoLuhbxVDRYd7tyheAyga3UA3qggzBE3YDIDJUJgWE3Qf3uzoPL5o2DzB2PirGapRoE387CjR8tZzxnrC3H1c6nYNS7UcByrOnpLgszotGjNoAwFT_l9UpQMblXw_ln2MP1Cqw3o0Ab1JhasCIS7kbY9VurenuulbpwEJH6QJDeliLO9Q3tUvvqPITYjyvjxTgOb8kSF9X8Q93KpJKbRFGO_KyelKe0Yb1ak8fjyP-m0HiXcfnjXEb0Z0aGhFQ4jTv1CwpiBSwxtoPGNVgzm_xb50lyUgXRFxCsTSPUR-BDT9w8q4KephJSoDuZy-Jb43pZaHH2r1fTh6gWnbFAJJVXC_IW0v5a1t8ZLbaAqL8C-E7sh3tJUJwGfmdXLNdVJ3kgQ.6S4uje2IxtNyDT9Db6YPww/__results___files/__results___33_0.png)

The dataset structure:
```
data/
├── dataset/
│   └── semantic_drone_dataset/
│       ├── original_images/     # RGB images
│       └── label_images_semantic/  # Segmentation masks
```



## Prerequisites

Before diving into the implementation, let's review the essential libraries and tools we'll be using in this project.

### Required Libraries

1. **PyTorch**
   - Deep learning framework
   - Provides neural network building blocks and GPU acceleration
   - Version requirement: >= 1.7.0

2. **OpenCV (cv2)**
   - Image loading and preprocessing
   - Color space conversions
   - Basic image operations
   ```bash
   pip install opencv-python
   ```

3. **Albumentations**
   - High-performance data augmentation
   - Implements various image transformations
   - Specially designed for segmentation tasks
   ```bash
   pip install albumentations
   ```

4. **Segmentation Models PyTorch (smp)**
   - Pre-implemented segmentation architectures
   - Provides pre-trained encoders
   - Easy-to-use interface for segmentation models
   ```bash
   pip install segmentation-models-pytorch
   ```

5. **Supporting Libraries**
   - NumPy: Numerical operations and array manipulation
   - Pandas: Data organization and splitting
   - Matplotlib: Visualization and result plotting
   - tqdm: Progress bar for training loops



## Data Preparation

The data preparation pipeline is implemented in `cnn_data.py` and handles dataset organization, loading, and preprocessing.

### Dataset Organization

First, we define the paths to our image and mask directories:

```python
IMAGE_PATH = "data/dataset/semantic_drone_dataset/original_images/"
MASK_PATH = "data/dataset/semantic_drone_dataset/label_images_semantic/"
```

We create a DataFrame to organize our data and split it into training, validation, and test sets:

```python
def create_df():
    name = []
    for dirname, _, filenames in os.walk(IMAGE_PATH):
        for filename in filenames:
            name.append(filename.split(".")[0])
    return pd.DataFrame({"id": name}, index=np.arange(0, len(name)))

def get_data_splits():
    df = create_df()
    # Split: 10% test, then 15% of the remaining 90% for validation
    # (about 76.5% train, 13.5% validation, 10% test overall)
    X_trainval, X_test = train_test_split(df["id"].values, test_size=0.1, random_state=19)
    X_train, X_val = train_test_split(X_trainval, test_size=0.15, random_state=19)
    return X_train, X_val, X_test
```
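
Because the second `train_test_split` call operates on what is left after the test set is removed, the effective fractions differ slightly from the nominal 15% validation share. The arithmetic can be checked with a short sketch; the helper name `effective_split` is illustrative and not part of `cnn_data.py`:

```python
def effective_split(test_size=0.1, val_size=0.15):
    """Return (train, val, test) fractions of the full dataset for a two-stage split."""
    test = test_size
    remaining = 1.0 - test_size      # share left after the test split
    val = val_size * remaining       # validation is carved out of the remainder
    train = remaining - val
    return train, val, test


train, val, test = effective_split()
print(f"train={train:.3f}, val={val:.3f}, test={test:.3f}")
# train=0.765, val=0.135, test=0.100
```

The exact image counts also depend on how `train_test_split` rounds, so treat these fractions as approximate.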

### Custom Dataset Classes

We implement two custom PyTorch Dataset classes:

1. **DroneDataset**: For training and validation data
```python
class DroneDataset(Dataset):
    def __init__(self, img_path, mask_path, X, mean, std, transform=None, patch=False):
        self.img_path = img_path
        self.mask_path = mask_path
        self.X = X
        self.transform = transform
        self.patches = patch
        self.mean = mean
        self.std = std

    def __getitem__(self, idx):
        # Load image and mask
        img = cv2.imread(self.img_path + self.X[idx] + ".jpg")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
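        mask = cv2.imread(self.mask_path + self.X[idx] + ".png",
                         cv2.IMREAD_GRAYSCALE)
        # ... augmentation, normalization, and tensor conversion follow
        # (see the full listing at the end of this section)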
```

2. **DroneTestDataset**: Used for test data. It returns the resized image as a PIL image without tensor conversion or normalization; both are applied later, at prediction time.

### Data Preprocessing

Our preprocessing pipeline includes:

1. **Image Loading**
   - Read images using OpenCV
   - Convert from BGR to RGB color space
   - Load masks as grayscale images

2. **Normalization**
   - Standardize images using ImageNet statistics:
   ```python
   mean = [0.485, 0.456, 0.406]
   std = [0.229, 0.224, 0.225]
   ```

3. **Tensor Conversion**
   - Convert images to PyTorch tensors
   - Convert masks to long tensors for classification

### Data Augmentation

We use the Albumentations library for efficient data augmentation:

```python
t_train = A.Compose([
    A.Resize(704, 1056, interpolation=cv2.INTER_NEAREST),
    A.HorizontalFlip(),
    A.VerticalFlip(),
    A.GridDistortion(p=0.2),
    A.RandomBrightnessContrast((0, 0.5), (0, 0.5)),
    A.GaussNoise(),
])

t_val = A.Compose([
    A.Resize(704, 1056, interpolation=cv2.INTER_NEAREST),
    A.HorizontalFlip(),
    A.GridDistortion(p=0.2),
])
```

The augmentation pipeline includes:
- Resizing to a consistent size
- Random horizontal and vertical flips
- Grid distortion for geometric variety
- Brightness and contrast adjustments
- Gaussian noise for robustness
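
Geometric transforms must be applied identically to the image and its mask, and masks must be resized without blending neighbouring labels. Albumentations takes care of this when both are passed as `transform(image=img, mask=mask)`. The sketch below shows the idea in plain NumPy; the helper names are hypothetical and only serve the illustration:

```python
import numpy as np


def hflip_pair(image, mask):
    """Flip an HxWxC image and its HxW mask left-to-right together."""
    return image[:, ::-1].copy(), mask[:, ::-1].copy()


def resize_mask_nearest(mask, out_h, out_w):
    """Nearest-neighbour resize: every output pixel copies exactly one input label."""
    rows = np.arange(out_h) * mask.shape[0] // out_h
    cols = np.arange(out_w) * mask.shape[1] // out_w
    return mask[rows[:, None], cols[None, :]]


image = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)
mask = np.array([[0, 5, 22],
                 [0, 5, 22]], dtype=np.uint8)

flipped_img, flipped_mask = hflip_pair(image, mask)
resized = resize_mask_nearest(mask, 4, 6)

print(flipped_mask.tolist())      # [[22, 5, 0], [22, 5, 0]]
print(np.unique(resized))         # only the original labels 0, 5 and 22 appear
```

An interpolating resize (bilinear, for example) would produce in-between values such as 13 that correspond to no real class, which is why nearest-neighbour is the safe choice for label masks.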

### DataLoader Creation

Finally, we create PyTorch DataLoaders for efficient batch processing:

```python
def get_data_loaders(batch_size=16):
    X_train, X_val, X_test = get_data_splits()

    train_set = DroneDataset(IMAGE_PATH, MASK_PATH, X_train,
                            mean, std, t_train, patch=False)
    val_set = DroneDataset(IMAGE_PATH, MASK_PATH, X_val,
                          mean, std, t_val, patch=False)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True)

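    # The test set is only resized and is returned as a Dataset, not a DataLoader
    t_test = A.Resize(768, 1152, interpolation=cv2.INTER_NEAREST)
    test_set = DroneTestDataset(IMAGE_PATH, MASK_PATH, X_test, transform=t_test)

    return train_loader, val_loader, test_set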
```

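This setup provides:
- Batched loading for the training and validation sets
- Shuffling on every epoch (`shuffle=True`; shuffling the validation loader is harmless but not required)
- The test set as a plain `Dataset`, so test images are evaluated one at a time
- Optional parallel loading: pass `num_workers` to `DataLoader` (the default of 0 loads data in the main process)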

The complete `cnn_data.py` listing:

```python
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import cv2
import albumentations as A
from PIL import Image

IMAGE_PATH = "data/dataset/semantic_drone_dataset/original_images/"
MASK_PATH = "data/dataset/semantic_drone_dataset/label_images_semantic/"


def create_df():
    name = []
    for dirname, _, filenames in os.walk(IMAGE_PATH):
        for filename in filenames:
            name.append(filename.split(".")[0])
    return pd.DataFrame({"id": name}, index=np.arange(0, len(name)))


def get_data_splits():
    df = create_df()
    X_trainval, X_test = train_test_split(
        df["id"].values, test_size=0.1, random_state=19
    )
    X_train, X_val = train_test_split(X_trainval, test_size=0.15, random_state=19)
    return X_train, X_val, X_test


class DroneDataset(Dataset):
    def __init__(self, img_path, mask_path, X, mean, std, transform=None, patch=False):
        self.img_path = img_path
        self.mask_path = mask_path
        self.X = X
        self.transform = transform
        self.patches = patch
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        img = cv2.imread(self.img_path + self.X[idx] + ".jpg")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_path + self.X[idx] + ".png", cv2.IMREAD_GRAYSCALE)

        if self.transform is not None:
            aug = self.transform(image=img, mask=mask)
            img = Image.fromarray(aug["image"])
            mask = aug["mask"]

        if self.transform is None:
            img = Image.fromarray(img)

        t = T.Compose([T.ToTensor(), T.Normalize(self.mean, self.std)])
        img = t(img)
        mask = torch.from_numpy(mask).long()

        if self.patches:
            img, mask = self.tiles(img, mask)

        return img, mask

    def tiles(self, img, mask):
        img_patches = img.unfold(1, 512, 512).unfold(2, 768, 768)
        img_patches = img_patches.contiguous().view(3, -1, 512, 768)
        img_patches = img_patches.permute(1, 0, 2, 3)

        mask_patches = mask.unfold(0, 512, 512).unfold(1, 768, 768)
        mask_patches = mask_patches.contiguous().view(-1, 512, 768)

        return img_patches, mask_patches


class DroneTestDataset(Dataset):
    def __init__(self, img_path, mask_path, X, transform=None):
        self.img_path = img_path
        self.mask_path = mask_path
        self.X = X
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        img = cv2.imread(self.img_path + self.X[idx] + ".jpg")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_path + self.X[idx] + ".png", cv2.IMREAD_GRAYSCALE)

        if self.transform is not None:
            aug = self.transform(image=img, mask=mask)
            img = Image.fromarray(aug["image"])
            mask = aug["mask"]

        if self.transform is None:
            img = Image.fromarray(img)

        mask = torch.from_numpy(mask).long()

        return img, mask


def get_data_loaders(batch_size=16):
    X_train, X_val, X_test = get_data_splits()

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    t_train = A.Compose(
        [
            A.Resize(704, 1056, interpolation=cv2.INTER_NEAREST),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.GridDistortion(p=0.2),
            A.RandomBrightnessContrast((0, 0.5), (0, 0.5)),
            A.GaussNoise(),
        ]
    )

    t_val = A.Compose(
        [
            A.Resize(704, 1056, interpolation=cv2.INTER_NEAREST),
            A.HorizontalFlip(),
            A.GridDistortion(p=0.2),
        ]
    )

    train_set = DroneDataset(
        IMAGE_PATH, MASK_PATH, X_train, mean, std, t_train, patch=False
    )
    val_set = DroneDataset(IMAGE_PATH, MASK_PATH, X_val, mean, std, t_val, patch=False)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True)

    t_test = A.Resize(768, 1152, interpolation=cv2.INTER_NEAREST)
    test_set = DroneTestDataset(IMAGE_PATH, MASK_PATH, X_test, transform=t_test)

    return train_loader, val_loader, test_set
```


## Metrics (`metrics.py`)

Our implementation uses two key metrics to evaluate segmentation performance:

### Pixel Accuracy
```python
def pixel_accuracy(output, mask):
    with torch.no_grad():
        output = torch.argmax(F.softmax(output, dim=1), dim=1)
        correct = torch.eq(output, mask).int()
        accuracy = float(correct.sum()) / float(correct.numel())
    return accuracy
```

- Measures percentage of correctly classified pixels
- Simple but potentially misleading for imbalanced classes
- Range: [0, 1], where 1 is perfect classification

### Mean IoU (Intersection over Union)
```python
def mIoU(pred_mask, mask, smooth=1e-10, n_classes=23):
    with torch.no_grad():
        pred_mask = F.softmax(pred_mask, dim=1)
        pred_mask = torch.argmax(pred_mask, dim=1)
        pred_mask = pred_mask.contiguous().view(-1)
        mask = mask.contiguous().view(-1)

        iou_per_class = []
        for clas in range(0, n_classes):
            true_class = pred_mask == clas
            true_label = mask == clas

            if true_label.long().sum().item() == 0:
                iou_per_class.append(np.nan)
            else:
                intersect = torch.logical_and(true_class, true_label).sum().float().item()
                union = torch.logical_or(true_class, true_label).sum().float().item()

                iou = (intersect + smooth) / (union + smooth)
                iou_per_class.append(iou)
        return np.nanmean(iou_per_class)
```

- Calculates intersection over union for each class
- Better metric for imbalanced datasets
- Classes absent from the ground-truth mask are recorded as NaN and excluded from the mean by `np.nanmean`
- Range: [0, 1], where 1 is perfect segmentation
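
To see why pixel accuracy can mislead on imbalanced data, consider a toy mask where a rare class covers 5% of the pixels and the model predicts background everywhere. The sketch below uses simplified NumPy versions of the two metrics (the `_np` function names are illustrative and work on class-ID arrays rather than logits):

```python
import numpy as np


def pixel_accuracy_np(pred, target):
    """Fraction of pixels whose predicted class equals the ground truth."""
    return float((pred == target).mean())


def mean_iou_np(pred, target, n_classes):
    """Mean IoU over the classes that occur in the ground truth."""
    ious = []
    for c in range(n_classes):
        pred_c, target_c = pred == c, target == c
        if target_c.sum() == 0:          # class absent from ground truth: skip it
            continue
        inter = np.logical_and(pred_c, target_c).sum()
        union = np.logical_or(pred_c, target_c).sum()
        ious.append(inter / union)
    return float(np.mean(ious))


target = np.zeros((10, 10), dtype=np.int64)
target[0, :5] = 1                        # 5 of 100 pixels belong to class 1
pred = np.zeros_like(target)             # the model predicts background everywhere

acc = pixel_accuracy_np(pred, target)            # 0.95
miou = mean_iou_np(pred, target, n_classes=2)    # (0.95 + 0.0) / 2 = 0.475
print(acc, miou)
```

The accuracy looks strong even though the rare class is missed entirely, while mIoU drops to roughly half because each class contributes equally to the mean.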

The complete `metrics.py` listing:

```python
import numpy as np
import torch
import torch.nn.functional as F


def pixel_accuracy(output, mask):
    with torch.no_grad():
        output = torch.argmax(F.softmax(output, dim=1), dim=1)
        correct = torch.eq(output, mask).int()
        accuracy = float(correct.sum()) / float(correct.numel())
    return accuracy


def mIoU(pred_mask, mask, smooth=1e-10, n_classes=23):
    with torch.no_grad():
        pred_mask = F.softmax(pred_mask, dim=1)
        pred_mask = torch.argmax(pred_mask, dim=1)
        pred_mask = pred_mask.contiguous().view(-1)
        mask = mask.contiguous().view(-1)

        iou_per_class = []
        for clas in range(0, n_classes):  # loop over the pixel classes
            true_class = pred_mask == clas
            true_label = mask == clas

            if true_label.long().sum().item() == 0:  # class absent from the ground truth
                iou_per_class.append(np.nan)
            else:
                intersect = (
                    torch.logical_and(true_class, true_label).sum().float().item()
                )
                union = torch.logical_or(true_class, true_label).sum().float().item()

                iou = (intersect + smooth) / (union + smooth)
                iou_per_class.append(iou)
        return np.nanmean(iou_per_class)
```

## Model Architecture

Our implementation uses a U-Net architecture with a MobileNetV2 backbone, leveraging the segmentation-models-pytorch (smp) library for efficient implementation.

### U-Net with MobileNetV2

The model is created using the following configuration:

```python
def create_model():
    return smp.Unet(
        "mobilenet_v2",          # Encoder backbone
        encoder_weights="imagenet",  # Pre-trained weights
        classes=23,              # Number of output classes
        activation=None,         # No activation (handled by loss function)
        encoder_depth=5,         # Number of encoder stages
        decoder_channels=[256, 128, 64, 32, 16]  # Decoder channel sizes
    )
```

#### Architecture Components:

1. **Encoder (MobileNetV2)**
   - Pre-trained on ImageNet
   - Efficient mobile-first architecture
   - Feature extraction at multiple scales
   - 5 stages of downsampling

2. **Decoder**
   - Progressive upsampling path
   - Channel sizes: [256, 128, 64, 32, 16]
   - Skip connections from encoder
   - Feature refinement at each stage

3. **Output Layer**
   - 23 output channels (one per class)
   - No activation (handled by CrossEntropyLoss)
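
With five downsampling stages that each halve the resolution, the deepest feature map is 1/32 of the input size. Input height and width therefore generally need to be multiples of 32 so that upsampled decoder features line up with the encoder features they are merged with. The resize targets used in this tutorial (704x1056 and 768x1152) satisfy this. A small check, with a hypothetical helper name:

```python
def is_valid_input_size(height, width, encoder_depth=5):
    """True if both sides are divisible by the total downsampling factor 2**depth."""
    factor = 2 ** encoder_depth
    return height % factor == 0 and width % factor == 0


for h, w in [(704, 1056), (768, 1152), (700, 1050)]:
    print((h, w), is_valid_input_size(h, w))
# (704, 1056) True
# (768, 1152) True
# (700, 1050) False
```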

### Training Components

#### Loss Function
```python
criterion = nn.CrossEntropyLoss()
```
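- Combines `LogSoftmax` and `NLLLoss` in a single call
- Expects raw logits of shape `(N, C, H, W)` and integer class masks of shape `(N, H, W)`
- Explains why the model is created with `activation=None`: no softmax is applied before the loss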
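
The log-softmax plus negative log-likelihood combination can be written out for a single image in NumPy. This is a sketch of the computation under the default mean reduction, not the PyTorch implementation; `cross_entropy_np` is an illustrative name:

```python
import numpy as np


def cross_entropy_np(logits, target):
    """Mean per-pixel cross-entropy.

    logits: (C, H, W) raw scores, target: (H, W) integer class IDs.
    """
    shifted = logits - logits.max(axis=0, keepdims=True)   # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=0, keepdims=True))  # log-softmax
    h_idx, w_idx = np.indices(target.shape)
    picked = log_probs[target, h_idx, w_idx]               # log p(true class) per pixel
    return float(-picked.mean())                           # negative log-likelihood


n_classes = 23
logits = np.zeros((n_classes, 2, 2))     # uniform scores: every class equally likely
target = np.array([[0, 5], [22, 7]])
loss = cross_entropy_np(logits, target)
print(loss, np.log(n_classes))           # both are log(23)
```

A loss near `log(23)` (about 3.14) at the start of training is therefore what an untrained 23-class model that guesses uniformly would produce.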

#### Optimizer
```python
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=max_lr,
    weight_decay=weight_decay
)
```
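- AdamW applies weight decay directly to the weights, decoupled from the gradient update, unlike L2 regularization in plain Adam
- `weight_decay` sets the regularization strength
- Per-parameter adaptive step sizes, as in Adam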

#### Learning Rate Scheduler
```python
sched = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr,
    epochs=epochs,
    steps_per_epoch=len(train_loader)
)
```
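- One-cycle policy: the learning rate ramps up to `max_lr`, then anneals to a value far below the starting rate
- The scheduler is stepped once per batch, so the cycle spans `epochs * len(train_loader)` steps
- The warmup phase can help stabilize early training when fine-tuning a pre-trained encoder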
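
The shape of the schedule can be reproduced in a few lines of plain Python. This is a simplified sketch that assumes the documented `OneCycleLR` defaults (`pct_start=0.3`, cosine annealing, `div_factor=25`, `final_div_factor=1e4`); the function name and the `steps_per_epoch` value are assumptions for illustration:

```python
import math


def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3,
                 div_factor=25.0, final_div_factor=1e4):
    """Learning rate at a given step of a two-phase cosine one-cycle schedule."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1      # last step of the warmup phase

    def cos_anneal(start, end, pct):
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    pct = (step - warmup_end) / (total_steps - 1 - warmup_end)
    return cos_anneal(max_lr, min_lr, pct)


epochs, steps_per_epoch, max_lr = 30, 20, 1e-3    # steps_per_epoch is an assumed value
total = epochs * steps_per_epoch
schedule = [one_cycle_lr(s, total, max_lr) for s in range(total)]

print(f"start={schedule[0]:.1e}, peak={max(schedule):.1e}, end={schedule[-1]:.1e}")
```

Under these assumptions the rate starts at `max_lr / 25`, peaks at `max_lr` about 30% of the way through training, and ends several orders of magnitude below the starting value.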

### Model Configuration

Key hyperparameters for the model:
```python
max_lr = 1e-3        # Maximum learning rate
epochs = 30          # Number of training epochs
weight_decay = 1e-4  # Weight decay coefficient for AdamW (decoupled from the gradient)
batch_size = 16      # Images per batch
```

### Device Management
```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # Move model to GPU if available
```

This architecture provides:
- Efficient feature extraction through MobileNetV2
- High-resolution detail through skip connections
- Memory-efficient training
- Fast inference capabilities
- Good balance of accuracy and computational cost



### U-Net Architecture in More Detail

![U-Net Architecture](https://media.geeksforgeeks.org/wp-content/uploads/20220614121231/Group14.jpg)

The U-Net architecture is a powerful neural network design specifically crafted for precise image segmentation tasks. Its distinctive U-shaped structure consists of two main paths:

**Contracting Path (Encoder):**
- Begins with the input image and progressively reduces spatial dimensions
- Uses consecutive convolutional layers followed by max pooling operations
- Each downsampling step doubles the number of feature channels
- Captures increasingly abstract features and broader contextual information
- Final encoded representation contains high-level semantic information

**Expanding Path (Decoder) with Up-Convolutions:**
- Up-convolutions (transposed convolutions) gradually restore spatial dimensions
- Each up-convolution halves the number of feature channels
- Transforms low-resolution, abstract features back into detailed spatial information
- Enables the network to generate full-resolution segmentation maps
- Learns to reconstruct spatial details from compressed representations

**Skip Connections - The Critical Bridge:**
- Direct connections between corresponding encoder and decoder layers
- Concatenates high-resolution features from encoder with upsampled features
- Preserves fine spatial details that would otherwise be lost in compression
- Helps combat the vanishing gradient problem during training
- Enables precise boundary detection in segmentation masks
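
The concatenation step of a skip connection can be sketched with NumPy arrays standing in for feature maps. Nearest-neighbour repetition is used here as a simple stand-in for the learned upsampling, and the function names and channel counts are chosen only for illustration:

```python
import numpy as np


def upsample_nearest_2x(x):
    """Double the spatial size of a (C, H, W) feature map by repeating pixels."""
    return x.repeat(2, axis=1).repeat(2, axis=2)


def skip_connect(decoder_feat, encoder_feat):
    """Upsample decoder features and concatenate encoder features on the channel axis."""
    up = upsample_nearest_2x(decoder_feat)
    assert up.shape[1:] == encoder_feat.shape[1:], "spatial sizes must match"
    return np.concatenate([up, encoder_feat], axis=0)


decoder_feat = np.random.rand(64, 8, 8)      # coarse, semantically rich features
encoder_feat = np.random.rand(32, 16, 16)    # finer features saved from the encoder
merged = skip_connect(decoder_feat, encoder_feat)
print(merged.shape)  # (96, 16, 16)
```

The merged tensor carries both sets of channels, so the convolutions that follow can combine coarse context with the fine spatial detail preserved by the encoder.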

The architecture excels at segmentation because:
1. Multi-scale feature processing captures both fine details and global context
2. Skip connections maintain spatial precision throughout the network
3. Gradual upsampling allows the network to learn optimal feature reconstruction
4. The symmetric structure balances feature extraction and reconstruction

In our implementation with MobileNetV2:
- The efficient MobileNetV2 backbone serves as the encoder
- The decoder has five blocks (`decoder_channels=[256, 128, 64, 32, 16]`), one per encoder stage
- Multiple skip connections at different scales preserve spatial information
- Final output layer produces per-pixel class predictions
- Architecture optimized for both accuracy and computational efficiency

### What is Upconvolution?

Upconvolution (also called transposed convolution, or loosely "deconvolution") is a learnable upsampling operation: it maps a small feature map to a larger one, the opposite direction of a strided convolution.

While a strided convolution reduces spatial dimensions by sliding a kernel over the input:
  e.g., 4x4 input -> 2x2 output with 3x3 kernel, stride 2, and padding 1

Upconvolution increases spatial dimensions by:
1. Inserting zeros between input elements (based on stride)
2. Performing regular convolution with a learnable kernel
3. Producing a larger output
  e.g., 2x2 input -> 4x4 output with 3x3 kernel, stride 2, padding 1, and output padding 1

Key differences from regular convolution:
- Increases spatial dimensions instead of reducing them
- Still uses learnable kernels but applies them differently
- Often used in decoders to restore resolution lost in encoding

The name "deconvolution" is technically incorrect since it's not truly inverting convolution,
but rather learning an upsampling transformation that best reconstructs the desired output resolution.
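
The output sizes of both operations follow from simple formulas (shown here for dilation 1, matching the shape formulas documented for PyTorch's `Conv2d` and `ConvTranspose2d`). The helper names below are illustrative. Note that the exact sizes depend on padding: a 3x3 kernel with stride 2 needs padding 1 to take a 4x4 input to 2x2, and the transposed version needs padding 1 plus output padding 1 to get back to 4x4.

```python
def conv_out_size(n, kernel, stride=1, padding=0):
    """Output length of a convolution along one axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def transposed_conv_out_size(n, kernel, stride=1, padding=0, output_padding=0):
    """Output length of a transposed convolution along one axis (dilation 1)."""
    return (n - 1) * stride - 2 * padding + kernel + output_padding


down = conv_out_size(4, kernel=3, stride=2, padding=1)
up = transposed_conv_out_size(down, kernel=3, stride=2, padding=1, output_padding=1)
print(down, up)  # 2 4

# Without padding the sizes come out differently:
print(conv_out_size(4, kernel=3, stride=2))             # 1
print(transposed_conv_out_size(2, kernel=3, stride=2))  # 5
```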




## Training and Inference (`model.py`)

### Training Pipeline
```python
def train_model(train_loader, val_loader, epochs=30, checkpoint_dir="checkpoints"):
    model = create_model()
    criterion = nn.CrossEntropyLoss()
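    # max_lr and weight_decay are the hyperparameters listed under Model Configuration
    optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr,
                                  weight_decay=weight_decay)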
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs,
                                                  steps_per_epoch=len(train_loader))

    history = fit(epochs, model, train_loader, val_loader, criterion,
                 optimizer, scheduler, checkpoint_dir=checkpoint_dir)
    return model, history
```

Key components:
- Model creation with MobileNetV2 backbone
- CrossEntropyLoss for multi-class segmentation
- AdamW optimizer with OneCycleLR scheduler
- Checkpoint management for model saving/loading

### Inference Pipeline
```python
def inference(model_path, image_path):
    model = create_model()
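    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    with torch.no_grad():
        # load_image is a helper (not shown) expected to return a normalized (3, H, W) tensor
        image = load_image(image_path)
        prediction = model(image.unsqueeze(0))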
        prediction = torch.argmax(prediction.squeeze(), dim=0)
    return prediction
```

Features:
- Model loading from checkpoint
- Single image prediction
- No gradient computation during inference
- Returns class predictions per pixel
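
The final `argmax` step turns a stack of per-class score maps into a single map of class IDs. A minimal NumPy sketch of that step, with made-up scores and an illustrative function name:

```python
import numpy as np


def logits_to_class_map(logits):
    """Convert (1, C, H, W) raw scores to an (H, W) map of predicted class IDs."""
    return logits[0].argmax(axis=0)


logits = np.zeros((1, 23, 2, 3))
logits[0, 4, :, :] = 1.0       # class 4 scores highest everywhere...
logits[0, 17, 0, 0] = 2.0      # ...except the top-left pixel, where class 17 wins
class_map = logits_to_class_map(logits)
print(class_map)
# [[17  4  4]
#  [ 4  4  4]]
```

No softmax is needed before the `argmax`: softmax preserves the ordering of the scores, so the winning class is the same either way.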

import torch
import torch.nn as nn
import torch.optim as optim
import segmentation_models_pytorch as smp

from tqdm.notebook import tqdm

import time
import numpy as np
from metrics import mIoU, pixel_accuracy
import os
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def create_model():
    return smp.Unet(
        "mobilenet_v2",
        encoder_weights="imagenet",
        classes=23,
        activation=None,
        encoder_depth=5,
        decoder_channels=[256, 128, 64, 32, 16],
    )


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]


def fit(
    epochs,
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    patch=False,
    checkpoint_dir="checkpoints",
):
    torch.cuda.empty_cache()
    train_losses = []
    test_losses = []
    val_iou = []
    val_acc = []
    train_iou = []
    train_acc = []
    lrs = []
    min_loss = np.inf
    decrease = 1
    not_improve = 0

    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)

    model.to(device)
    fit_time = time.time()
    for e in range(epochs):
        since = time.time()
        running_loss = 0
        iou_score = 0
        accuracy = 0
        # training loop
        model.train()
        for i, data in enumerate(tqdm(train_loader)):
            # training phase
            image_tiles, mask_tiles = data
            if patch:
                bs, n_tiles, c, h, w = image_tiles.size()
                image_tiles = image_tiles.view(-1, c, h, w)
                mask_tiles = mask_tiles.view(-1, h, w)

            image = image_tiles.to(device)
            mask = mask_tiles.to(device)
            # forward
            output = model(image)
            loss = criterion(output, mask)
            # evaluation metrics
            iou_score += mIoU(output, mask)
            accuracy += pixel_accuracy(output, mask)
            # backward
            loss.backward()
            optimizer.step()  # update weight
            optimizer.zero_grad()  # reset gradient

            # step the learning rate
            lrs.append(get_lr(optimizer))
            scheduler.step()

            running_loss += loss.item()

        else:
            model.eval()
            test_loss = 0
            test_accuracy = 0
            val_iou_score = 0
            # validation loop
            with torch.no_grad():
                for i, data in enumerate(tqdm(val_loader)):
                    # reshape to 9 patches from single image, delete batch size
                    image_tiles, mask_tiles = data

                    if patch:
                        bs, n_tiles, c, h, w = image_tiles.size()
                        image_tiles = image_tiles.view(-1, c, h, w)
                        mask_tiles = mask_tiles.view(-1, h, w)

                    image = image_tiles.to(device)
                    mask = mask_tiles.to(device)
                    output = model(image)
                    # evaluation metrics
                    val_iou_score += mIoU(output, mask)
                    test_accuracy += pixel_accuracy(output, mask)
                    # loss
                    loss = criterion(output, mask)
                    test_loss += loss.item()

            # calculation mean for each batch
            train_losses.append(running_loss / len(train_loader))
            test_losses.append(test_loss / len(val_loader))

            if min_loss > (test_loss / len(val_loader)):
                print(
                    "Loss Decreasing.. {:.3f} >> {:.3f} ".format(
                        min_loss, (test_loss / len(val_loader))
                    )
                )
                min_loss = test_loss / len(val_loader)
                decrease += 1
                if decrease % 5 == 0:
                    print("saving model...")
                    torch.save(
                        model,
                        "Unet-Mobilenet_v2_mIoU-{:.3f}.pt".format(
                            val_iou_score / len(val_loader)
                        ),
                    )

            if (test_loss / len(val_loader)) > min_loss:
                # No improvement: keep min_loss at the best validation loss seen so far
                not_improve += 1
                print(f"Loss did not decrease ({not_improve} time(s) so far)")
                if not_improve == 7:
                    print("Loss did not decrease 7 times, stopping training")
                    break

            # iou
            val_iou.append(val_iou_score / len(val_loader))
            train_iou.append(iou_score / len(train_loader))
            train_acc.append(accuracy / len(train_loader))
            val_acc.append(test_accuracy / len(val_loader))
            print(
                "Epoch:{}/{}..".format(e + 1, epochs),
                "Train Loss: {:.3f}..".format(running_loss / len(train_loader)),
                "Val Loss: {:.3f}..".format(test_loss / len(val_loader)),
                "Train mIoU:{:.3f}..".format(iou_score / len(train_loader)),
                "Val mIoU: {:.3f}..".format(val_iou_score / len(val_loader)),
                "Train Acc:{:.3f}..".format(accuracy / len(train_loader)),
                "Val Acc:{:.3f}..".format(test_accuracy / len(val_loader)),
                "Time: {:.2f}m".format((time.time() - since) / 60),
            )

        # Save checkpoint
        history = {
            "train_loss": train_losses,
            "val_loss": test_losses,
            "train_miou": train_iou,
            "val_miou": val_iou,
            "train_acc": train_acc,
            "val_acc": val_acc,
            "lrs": lrs,
        }
        save_checkpoint(
            model, optimizer, scheduler, e + 1, min_loss, history, checkpoint_dir
        )

    print("Total time: {:.2f} m".format((time.time() - fit_time) / 60))
    return history


def predict_image_mask_miou(
    model, image, mask, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
):
    model.eval()
    t = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)]
    )
    image = t(image)
    model.to(device)
    image = image.to(device)
    mask = mask.to(device)
    with torch.no_grad():
        image = image.unsqueeze(0)
        mask = mask.unsqueeze(0)
        output = model(image)
        score = mIoU(output, mask)
        masked = torch.argmax(output, dim=1)
        masked = masked.cpu().squeeze(0)
    return masked, score


def predict_image_mask_pixel(
    model, image, mask, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
):
    model.eval()
    # Convert the input image to a tensor and normalize it with the given statistics
    t = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)]
    )
    image = t(image)
    model.to(device)
    image = image.to(device)
    mask = mask.to(device)
    with torch.no_grad():
        image = image.unsqueeze(0)
        mask = mask.unsqueeze(0)
        output = model(image)
        acc = pixel_accuracy(output, mask)
        masked = torch.argmax(output, dim=1)
        masked = masked.cpu().squeeze(0)
    return masked, acc


def evaluate_model(model, test_set):
    score_iou = []
    accuracy = []
    for i in tqdm(range(len(test_set))):
        img, mask = test_set[i]
        pred_mask, score = predict_image_mask_miou(model, img, mask)
        score_iou.append(score)
        _, acc = predict_image_mask_pixel(model, img, mask)
        accuracy.append(acc)
    return np.mean(score_iou), np.mean(accuracy)


def load_checkpoint(model, optimizer, scheduler, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    epoch = checkpoint["epoch"]
    loss = checkpoint["loss"]
    history = {
        "train_loss": checkpoint["train_loss"],
        "val_loss": checkpoint["val_loss"],
        "train_miou": checkpoint["train_miou"],
        "val_miou": checkpoint["val_miou"],
        "train_acc": checkpoint["train_acc"],
        "val_acc": checkpoint["val_acc"],
        "lrs": checkpoint["lrs"],
    }
    return model, optimizer, scheduler, epoch, loss, history


def get_latest_checkpoint(checkpoint_dir):
    checkpoints = [
        f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_epoch_")
    ]
    if not checkpoints:
        return None
    latest_checkpoint = max(
        checkpoints, key=lambda x: int(x.split("_")[-1].split(".")[0])
    )
    return os.path.join(checkpoint_dir, latest_checkpoint)


def resume_from_checkpoint(model, optimizer, scheduler, checkpoint_dir):
    latest_checkpoint = get_latest_checkpoint(checkpoint_dir)
    if latest_checkpoint:
        model, optimizer, scheduler, start_epoch, min_loss, history = load_checkpoint(
            model, optimizer, scheduler, latest_checkpoint
        )
        print(f"Resuming training from epoch {start_epoch}")
        return model, optimizer, scheduler, start_epoch, min_loss, history
    else:
        print("No checkpoint found. Starting from scratch.")
        return model, optimizer, scheduler, 0, float("inf"), {}


def save_checkpoint(model, optimizer, scheduler, epoch, loss, history, checkpoint_dir):
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": loss,
        "train_loss": history.get("train_loss", []),
        "val_loss": history.get("val_loss", []),
        "train_miou": history.get("train_miou", []),
        "val_miou": history.get("val_miou", []),
        "train_acc": history.get("train_acc", []),
        "val_acc": history.get("val_acc", []),
        "lrs": history.get("lrs", []),
    }
    torch.save(
        checkpoint, os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
    )
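
`get_latest_checkpoint` picks the newest file by parsing the epoch number out of the file name rather than comparing the names as strings. The sketch below uses made-up file names in the `checkpoint_epoch_{epoch}.pth` pattern written by `save_checkpoint` (the helper `epoch_of` is hypothetical) to show why that matters: a plain string comparison ranks epoch 9 above epoch 10.

```python
# Hypothetical file names following the checkpoint_epoch_{epoch}.pth pattern
names = [
    "checkpoint_epoch_9.pth",
    "checkpoint_epoch_10.pth",
    "checkpoint_epoch_2.pth",
]


def epoch_of(filename):
    """Extract the integer epoch from a name like 'checkpoint_epoch_10.pth'."""
    return int(filename.split("_")[-1].split(".")[0])


latest_by_string = max(names)               # lexicographic comparison
latest_by_epoch = max(names, key=epoch_of)  # numeric comparison, as in get_latest_checkpoint

print(latest_by_string)  # checkpoint_epoch_9.pth
print(latest_by_epoch)   # checkpoint_epoch_10.pth
```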



## Visualization (`viz.py`)

### Training Progress Visualization
```python
import matplotlib.pyplot as plt


def plot_training_progress(history):
    """Plot loss, mIoU, accuracy, and learning rate from a training history dict."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Loss curves
    axes[0, 0].plot(history['train_loss'], label='Training Loss')
    axes[0, 0].plot(history['val_loss'], label='Validation Loss')
    axes[0, 0].set_title('Loss vs Epochs')
    axes[0, 0].legend()  # the labels above are only shown once a legend is drawn

    # IoU curves
    axes[0, 1].plot(history['train_miou'], label='Training mIoU')
    axes[0, 1].plot(history['val_miou'], label='Validation mIoU')
    axes[0, 1].set_title('mIoU vs Epochs')
    axes[0, 1].legend()

    # Accuracy curves
    axes[1, 0].plot(history['train_acc'], label='Training Accuracy')
    axes[1, 0].plot(history['val_acc'], label='Validation Accuracy')
    axes[1, 0].set_title('Accuracy vs Epochs')
    axes[1, 0].legend()

    # Learning rate
    axes[1, 1].plot(history['lrs'], label='Learning Rate')
    axes[1, 1].set_title('Learning Rate vs Steps')
    axes[1, 1].legend()

    plt.tight_layout()
    plt.show()
```

### Prediction Visualization
```python
def visualize_prediction(image, mask, prediction):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(image)
    axes[0].set_title('Original Image')

    axes[1].imshow(mask)
    axes[1].set_title('Ground Truth')

    axes[2].imshow(prediction)
    axes[2].set_title('Prediction')

    plt.tight_layout()
    plt.show()
```

Features:
- Training metrics visualization
  - Loss curves
  - IoU progress
  - Accuracy tracking
  - Learning rate schedule
- Prediction visualization
  - Side-by-side comparison
  - Original image
  - Ground truth mask
  - Model prediction

In [2]:
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np
import torch
from torchvision import transforms as T


def plot_loss(history):
    plt.figure(figsize=(10, 5))
    plt.plot(history["val_loss"], label="val", marker="o")
    plt.plot(history["train_loss"], label="train", marker="o")
    plt.title("Loss per epoch")
    plt.ylabel("loss")
    plt.xlabel("epoch")
    plt.legend()
    plt.grid()
    plt.show()


def plot_score(history):
    plt.figure(figsize=(10, 5))
    plt.plot(history["train_miou"], label="train_mIoU", marker="*")
    plt.plot(history["val_miou"], label="val_mIoU", marker="*")
    plt.title("Score per epoch")
    plt.ylabel("mean IoU")
    plt.xlabel("epoch")
    plt.legend()
    plt.grid()
    plt.show()


def plot_acc(history):
    plt.figure(figsize=(10, 5))
    plt.plot(history["train_acc"], label="train_accuracy", marker="*")
    plt.plot(history["val_acc"], label="val_accuracy", marker="*")
    plt.title("Accuracy per epoch")
    plt.ylabel("Accuracy")
    plt.xlabel("epoch")
    plt.legend()
    plt.grid()
    plt.show()


def visualize_predictions(
    model,
    test_set,
    output_pdf,
    num_classes=23,
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
):
    model.eval()
    device = next(model.parameters()).device

    # Create a colormap for the segmentation mask
    cmap = plt.get_cmap("tab20")

    with PdfPages(output_pdf) as pdf:
        for i, (img, mask) in enumerate(test_set):
            # Prepare the image
            img_tensor = T.Compose([T.ToTensor(), T.Normalize(mean, std)])(img)
            img_tensor = img_tensor.unsqueeze(0).to(device)

            # Get the prediction
            with torch.no_grad():
                output = model(img_tensor)
                pred_mask = torch.argmax(output, dim=1).squeeze().cpu().numpy()

            # Create a figure with three subplots
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
            fig.suptitle(f"Test Image {i+1}")

            # Plot original image
            ax1.imshow(img)
            ax1.set_title("Original Image")
            ax1.axis("off")

            # Plot ground truth mask
            ax2.imshow(mask, cmap=cmap, vmin=0, vmax=num_classes - 1)
            ax2.set_title("Ground Truth")
            ax2.axis("off")

            # Plot predicted mask
            ax3.imshow(pred_mask, cmap=cmap, vmin=0, vmax=num_classes - 1)
            ax3.set_title("Prediction")
            ax3.axis("off")

            # Add the plot to the PDF
            pdf.savefig(fig)
            plt.close(fig)

    # Report only after the PdfPages context has closed and the file is finalized
    print(f"Visualizations saved to {output_pdf}")



def visualize_sample(dataloader, num_classes=23):
    """Visualize a single sample image and its segmentation mask from a dataloader."""
    # Get a single batch from the dataloader
    images, masks = next(iter(dataloader))

    # Take the first image and mask from the batch
    img = images[0].permute(1, 2, 0).numpy()  # Convert from CxHxW to HxWxC
    mask = masks[0].numpy()

    # Denormalize the image
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = std * img + mean
    img = np.clip(img, 0, 1)

    # Create a colormap for the segmentation mask
    cmap = plt.get_cmap("tab20")

    # Create a figure with three subplots
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

    # Plot original image
    ax1.imshow(img)
    ax1.set_title("Sample Image")
    ax1.axis("off")

    # Plot segmentation mask
    ax2.imshow(mask, cmap=cmap, vmin=0, vmax=num_classes - 1)
    ax2.set_title("Segmentation Mask")
    ax2.axis("off")

    # Plot overlay
    ax3.imshow(img)
    ax3.imshow(mask, cmap=cmap, vmin=0, vmax=num_classes - 1, alpha=0.5)
    ax3.set_title("Overlay")
    ax3.axis("off")

    plt.tight_layout()
    plt.show()
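
The visualization functions above draw masks with the `tab20` colormap, which has 20 entries, while `num_classes` defaults to 23, so a few classes end up sharing a color. One possible workaround, sketched below with the standard-library `colorsys` module, is to generate one color per class by spacing hues evenly. The helper name `distinct_colors` and the saturation and value settings are assumptions made for this sketch, not part of the tutorial code.

```python
import colorsys


def distinct_colors(num_classes=23, saturation=0.65, value=0.9):
    """Return num_classes RGB tuples with evenly spaced hues (components in [0, 1])."""
    return [
        colorsys.hsv_to_rgb(i / num_classes, saturation, value)
        for i in range(num_classes)
    ]


colors = distinct_colors(23)
print(len(colors), len(set(colors)))  # number of colors, number of unique colors
```

The resulting list can be passed to `matplotlib.colors.ListedColormap(colors)` and used as the `cmap` argument of `imshow`. Neighboring hues are still fairly close with 23 classes, but no two classes map to the same color.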


## Downloading the Dataset

The next cell downloads the drone dataset from Kaggle as a zip archive and extracts it into the `data/` directory.
If that directory already exists, the download is skipped.



In [5]:
import requests
import zipfile
from pathlib import Path

# Path to the data folder
DATA_PATH = Path("data/")

# If the data folder doesn't exist, create it and download the dataset
if DATA_PATH.is_dir():
    print(f"{DATA_PATH} directory exists.")
else:
    print(f"Did not find {DATA_PATH} directory, creating one...")
    Path(DATA_PATH).mkdir(parents=True, exist_ok=True)

    # Download the drone dataset archive.
    # The signed URL below embeds X-Goog-Date=20241023 and X-Goog-Expires=259200
    # (seconds, i.e. 3 days), so it may need to be regenerated from Kaggle.
    with open(DATA_PATH / "archive.zip", "wb") as f:
        request = requests.get("https://storage.googleapis.com/kaggle-data-sets/333968/1834160/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20241023%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20241023T113757Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=3e3bbcf7cb80c007d26471d7f7be115d075367ab5a6e241b83823607ac7683cb813a1757c5e71ee5052498b967758686f92be595272f684bf1a90bd8a21681ba7ba7a34e074464ac5e2d3b944af4ebf34d425d50281034b3fd3c17f1f15320f27eaf578cfbead4c6e40b721f1209333e55c6185b157001d9afd3762fd3f6eadb67ee4841ba059b999775c14615537f31e44b0f3e2cea010e3c13b612d18d952cf22c7d101962cdefe0da4d4e6a03345f9d3ceb14048de01e987345e318361b9d2f8cea7c9fb749de9c78eea4795da2e71ae5d8e065206627970bebb1eb523d7cf03d413978eb542f3f0500538b53bda10f198ca97f5f85d267bbe40b269487dd")
        request.raise_for_status()  # stop on an HTTP error instead of saving an error page
        print("Downloading drone dataset ...")
        f.write(request.content)

    # Unzip the drone dataset archive
    with zipfile.ZipFile(DATA_PATH / "archive.zip", "r") as zip_ref:
        print("Unzipping drone dataset ...")
        zip_ref.extractall(DATA_PATH)

data directory exists.
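
The cell above treats an existing `data/` directory as proof that the dataset is present, and it writes whatever the server returns into `archive.zip`. As a minimal sketch of a more defensive extraction step, the helper below checks that the file is a valid zip archive before unpacking it. `extract_archive` is a hypothetical name introduced for this example and is not part of the tutorial code.

```python
import zipfile
from pathlib import Path


def extract_archive(zip_path, dest_dir):
    """Extract zip_path into dest_dir and return the number of archive entries.

    Raises ValueError if zip_path is missing or is not a valid zip archive,
    for example when an error page was saved instead of the dataset.
    """
    zip_path, dest_dir = Path(zip_path), Path(dest_dir)
    if not zip_path.is_file() or not zipfile.is_zipfile(zip_path):
        raise ValueError(f"{zip_path} is not a valid zip archive")
    dest_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(dest_dir)
        return len(zip_ref.namelist())
```

It could be called as `extract_archive(DATA_PATH / "archive.zip", DATA_PATH)` in place of the unzip step, so that a failed download surfaces as an error instead of leaving a partly populated directory.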


## Main Training Script

In [6]:
import torch
import torch.nn as nn
# from cnn_data import get_data_loaders
# from model import create_model, fit, evaluate_model, resume_from_checkpoint
# from viz import plot_loss, plot_score, plot_acc, visualize_predictions
from tqdm import tqdm
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inference_only = True

# Get data loaders
train_loader, val_loader, test_set = get_data_loaders(batch_size=16)

# Create model
model = create_model()

# Training parameters
max_lr = 1e-3
epochs = 4
weight_decay = 1e-4

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
sched = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr, epochs=epochs, steps_per_epoch=len(train_loader)
)
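
`OneCycleLR` ramps the learning rate up to `max_lr` and then anneals it far below the starting value. As a rough illustration, the pure-Python function below approximates that shape using the default values documented for PyTorch's `OneCycleLR` (`pct_start=0.3`, `div_factor=25`, `final_div_factor=1e4`, cosine annealing). `one_cycle_lr` is a hypothetical helper written for this sketch, not a PyTorch API; it assumes enough steps for a non-empty warm-up phase and may differ from the library in edge cases.

```python
import math


def one_cycle_lr(step, total_steps, max_lr=1e-3, pct_start=0.3,
                 div_factor=25.0, final_div_factor=1e4):
    """Approximate learning rate at `step` (0-based) of a one-cycle schedule."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1  # last step of the warm-up phase
    last_step = total_steps - 1

    def cosine(start, end, pct):
        # Cosine interpolation from start (pct=0) to end (pct=1)
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

    if step <= warmup_end:
        return cosine(initial_lr, max_lr, step / warmup_end)
    return cosine(max_lr, min_lr, (step - warmup_end) / (last_step - warmup_end))


# epochs * steps_per_epoch; 25 steps per epoch is an assumed value for illustration
total_steps = 4 * 25
for step in (0, 29, 99):
    print(step, one_cycle_lr(step, total_steps))
```

With these settings the rate starts at `max_lr / 25`, peaks at `max_lr` after roughly 30% of the steps, and ends at `max_lr / 25 / 1e4`. In PyTorch the schedule length is fixed when the scheduler is constructed, as `epochs * steps_per_epoch`.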


In [None]:

if not inference_only:
    # Check for existing checkpoints and resume training if possible
    checkpoint_dir = "checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    model, optimizer, sched, start_epoch, min_loss, history = resume_from_checkpoint(
        model, optimizer, sched, checkpoint_dir
    )

    # Train the model
    history = fit(
        epochs - start_epoch,
        model,
        train_loader,
        val_loader,
        criterion,
        optimizer,
        sched,
        checkpoint_dir=checkpoint_dir,
    )

    # Save the final model
    torch.save(model, "Unet-Mobilenet.pt")

    # Plot training results
    plot_loss(history)
    plot_score(history)
    plot_acc(history)

## Main Inference Script

In [None]:
if inference_only:
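    # Load the full pickled model saved by the training cell.
    # weights_only=False is needed on PyTorch >= 2.6, where torch.load defaults to
    # weights-only loading; only use it with files you trust.
    model = torch.load("Unet-Mobilenet.pt", map_location=device, weights_only=False)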

    # Evaluate on test set
    test_miou, test_accuracy = evaluate_model(model, test_set)
    print("Test Set mIoU:", test_miou)
    print("Test Set Pixel Accuracy:", test_accuracy)

    # Visualize predictions
    visualize_predictions(model, test_set, "test_predictions.pdf")
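
The two numbers printed above are averages of per-image scores. As a toy illustration of what each score measures, the pure-Python functions below compute pixel accuracy and mean IoU on a 2×2 mask. They are simplified stand-ins written for this sketch (the names `toy_pixel_accuracy` and `toy_mean_iou` are hypothetical) and work on nested lists of class IDs rather than model logits, so they are not the tutorial's `pixel_accuracy` and `mIoU` functions. Skipping classes that appear in neither mask is one common convention and is assumed here.

```python
def toy_pixel_accuracy(pred, target):
    """Fraction of pixels whose predicted class matches the ground truth."""
    flat_pred = [p for row in pred for p in row]
    flat_target = [t for row in target for t in row]
    correct = sum(p == t for p, t in zip(flat_pred, flat_target))
    return correct / len(flat_target)


def toy_mean_iou(pred, target, num_classes):
    """Mean of per-class intersection over union, skipping absent classes."""
    flat_pred = [p for row in pred for p in row]
    flat_target = [t for row in target for t in row]
    ious = []
    for c in range(num_classes):
        intersection = sum(p == c and t == c for p, t in zip(flat_pred, flat_target))
        union = sum(p == c or t == c for p, t in zip(flat_pred, flat_target))
        if union == 0:
            continue  # class c appears in neither mask
        ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else float("nan")


pred = [[0, 0], [1, 1]]
target = [[0, 1], [1, 1]]
print(toy_pixel_accuracy(pred, target))            # 0.75 (3 of 4 pixels match)
print(toy_mean_iou(pred, target, num_classes=3))   # (1/2 + 2/3) / 2, about 0.583
```

The example shows why the two metrics can diverge: one wrong pixel costs 25% accuracy here, but it lowers the IoU of both classes it touches.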


## Exercises

### Problem 1

In [None]:
# Problem 1


### Problem 2

In [None]:
# Problem 2

### Problem 3

In [None]:
# Problem 3

### Problem 4

In [None]:
# Problem 4

### Problem 5

In [None]:
# Problem 5

### Problem 6

In [None]:
# Problem 6

### Problem 7

In [None]:
# Problem 7

### Problem 8

In [None]:
# Problem 8

### Problem 9

In [None]:
# Problem 9