The Rellis-3D dataset is split into:
1) Full Image (11GB) -> the RGB inputs.
2) Full Image Annotations (94MB, ID Format) -> segmentation labels where each pixel value is a class ID.
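
The evaluation code in this notebook treats label values as indices from 0 to `n_classes - 1`. If the raw ID-format annotations contain gaps or IDs outside that range, they would need to be remapped first. The sketch below shows one way to do that with a lookup table, assuming 8-bit masks. The `ID_TO_TRAIN` mapping and `IGNORE_INDEX` value are hypothetical placeholders, not the actual Rellis-3D ontology.

```python
import numpy as np

# Hypothetical raw-ID -> contiguous-index mapping (NOT the real Rellis-3D ontology).
ID_TO_TRAIN = {0: 0, 1: 1, 3: 2, 4: 3}
IGNORE_INDEX = 255  # assumed marker for raw IDs that are missing from the mapping


def remap_labels(mask, id_to_train=ID_TO_TRAIN, ignore_index=IGNORE_INDEX):
    """Map raw class IDs in an 8-bit integer mask to contiguous training indices."""
    lut = np.full(256, ignore_index, dtype=np.uint8)  # lookup table over all 8-bit IDs
    for raw_id, train_id in id_to_train.items():
        lut[raw_id] = train_id
    return lut[mask]  # fancy indexing applies the table to every pixel


raw = np.array([[0, 3], [4, 9]], dtype=np.uint8)
remapped = remap_labels(raw)
print(remapped)
```

Pixels set to the ignore value fall outside `0 <= label < num_classes`, so the masking step in `compute_confusion_matrix` would skip them.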

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from PIL import Image

We create a custom PyTorch dataset class called `Rellis3DDataset` that loads images and segmentation masks from the Rellis-3D dataset.

In the `__init__` method, we store the root directory of the dataset, collect sorted file paths for images and labels, set the target resize dimensions, and optionally accept a transform function for preprocessing. Images and labels are paired by their position in the sorted lists, so the `images` and `labels` folders must contain matching files.

The `__len__` method returns the number of available images, while the `__getitem__` method loads a single sample. For each index, it loads the RGB image and segmentation mask, resizes the image with bilinear interpolation (to preserve smoothness) and the mask with nearest-neighbor interpolation (to keep class IDs intact), applies any transform to the image, and converts the label to a tensor of class IDs. It then returns the image and its label as a pair.

The transform pipeline converts the image to a tensor and normalizes it with the ImageNet mean and standard deviation, which is standard practice for models whose encoders were pretrained on ImageNet.

We then create a dataset instance pointing to the validation folder of Rellis-3D and wrap it with a `DataLoader`, which batches the data (batch size of 4 in this example) and handles iteration. This gives us a clean way to load, preprocess, and batch Rellis-3D samples for training or evaluation.
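
To see why the mask is resized with nearest-neighbor rather than bilinear interpolation, here is a small one-dimensional sketch. It uses NumPy's `np.interp` as a stand-in for bilinear resizing, not Pillow's actual resampling code, and the class IDs 2 and 7 are arbitrary.

```python
import numpy as np

row = np.array([2, 2, 7, 7], dtype=np.float64)  # one row of a label mask (class IDs 2 and 7)
src = np.arange(row.size)                       # source pixel positions 0..3
dst = np.linspace(0, row.size - 1, 7)           # 7 target positions for upsampling

linear = np.interp(dst, src, row)               # linear interpolation blends neighbouring IDs
nearest = row[np.rint(dst).astype(int)]         # nearest neighbour copies an existing ID

print(sorted(set(linear.tolist())))
print(sorted(set(nearest.tolist())))
```

The linear result contains 4.5 at the class boundary, which is not a valid class ID, while the nearest-neighbor result contains only the original IDs.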

In [None]:
class Rellis3DDataset(Dataset):
    def __init__(self, root_dir, image_size=(512, 512), transform=None):
        self.root_dir = Path(root_dir)
        self.image_paths = sorted((self.root_dir / "images").glob("*.png"))
        self.label_paths = sorted((self.root_dir / "labels").glob("*.png"))
        self.transform = transform
        self.image_size = image_size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        label = Image.open(self.label_paths[idx])


        img = img.resize(self.image_size, Image.BILINEAR)
        label = label.resize(self.image_size, Image.NEAREST)

        if self.transform:
            img = self.transform(img)

        label = torch.from_numpy(np.array(label)).long()
        return img, label



transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


dataset = Rellis3DDataset("path/to/Rellis-3D/val", transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)


# UNet Model Explanation (PyTorch)

This document explains the structure and working of a **UNet model** implemented in PyTorch for image segmentation.

---

## 1. DoubleConv Block

- **Purpose:** Performs feature extraction through **two consecutive convolution layers**.
- **Operations:**
  - **Conv2d:** 2D convolution extracts features from input images.
  - **BatchNorm2d:** Normalizes the convolution outputs for faster, more stable training.
  - **ReLU:** Introduces non-linearity.
- **Padding:** `padding=1` ensures the output spatial dimensions are the same as the input.
- **Benefit:** Two convolutions per block allow the network to learn richer feature representations at each level.
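
The effect of `padding=1` follows from the standard output-size formula for a convolution. The helper below is a hypothetical illustration, not part of the notebook, and assumes dilation 1.

```python
def conv2d_out_size(size, kernel_size=3, padding=1, stride=1):
    """Spatial output size of a convolution along one dimension (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


same = conv2d_out_size(512)                # 3x3 kernel with padding=1 keeps the size
shrunk = conv2d_out_size(512, padding=0)   # without padding each conv loses 2 pixels
print(same, shrunk)
```

With a 3x3 kernel and no padding, each `DoubleConv` block would shrink the feature map by 4 pixels per dimension, and the skip connections would no longer line up without cropping.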

---

## 2. UNet Architecture

UNet is an **encoder-decoder network** with skip connections, specifically designed for image segmentation.

### Encoder (Downsampling Path)

- Comprised of **DoubleConv blocks** followed by **Max Pooling**.
- Reduces spatial dimensions while increasing the number of feature channels.
- Example progression: `3 → 64 → 128 → 256 → 512` channels.
- **MaxPool2d(2):** Reduces width and height by half at each step.

### Bottleneck

- The **deepest layer** in the network.
- Uses a **DoubleConv block** to capture high-level features with maximum channels (`512 → 1024`).
- Acts as a bridge between encoder and decoder.

### Decoder (Upsampling Path)

- Upsamples the feature maps using **ConvTranspose2d**, increasing spatial resolution.
- **Skip Connections:** Each upsampled feature map is concatenated with the corresponding encoder output.
  - Helps the network preserve fine-grained spatial information.
- Reduces channel dimension progressively: `1024 → 512 → 256 → 128 → 64`.

### Output Layer

- **Final Conv2d:** Reduces channels to the number of classes (`n_classes`), producing the segmentation map.
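
The channel and resolution changes described above can be traced with plain arithmetic. This is a sketch, assuming a 512x512 input and the channel widths listed in this section; `unet_shapes` is a hypothetical helper, not part of the model code.

```python
def unet_shapes(height, width, widths=(64, 128, 256, 512), bottleneck=1024, n_classes=19):
    """Return (stage, channels, height, width) tuples for the UNet described above."""
    shapes = []
    h, w = height, width
    for i, c in enumerate(widths, start=1):
        shapes.append((f"enc{i}", c, h, w))
        h, w = h // 2, w // 2  # MaxPool2d(2) halves each spatial dimension
    shapes.append(("bottleneck", bottleneck, h, w))
    for i, c in reversed(list(enumerate(widths, start=1))):
        h, w = h * 2, w * 2    # ConvTranspose2d(kernel_size=2, stride=2) doubles them
        shapes.append((f"dec{i}", c, h, w))
    shapes.append(("final", n_classes, h, w))
    return shapes


shapes = unet_shapes(512, 512)
for stage in shapes:
    print(stage)
```

The trace only matches the real network when height and width are divisible by 16. Otherwise the pooling layers round down and the upsampled maps no longer match the encoder outputs at concatenation time.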

---

## 3. Forward Pass Workflow

1. **Encoding:** Input passes through encoder blocks, producing feature maps (`e1, e2, e3, e4`).
2. **Pooling:** Max pooling downsamples feature maps between encoder layers.
3. **Bottleneck:** Deepest layer extracts high-level features.
4. **Decoding:**
   - Upsample bottleneck features.
   - Concatenate with encoder features (skip connections).
   - Refine features using DoubleConv blocks.
5. **Final Output:** Last convolution layer outputs the segmentation map with `n_classes` channels.
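
A single decoding step (upsample, then concatenate with the skip connection) can be checked in isolation. The sketch below uses small, arbitrary channel counts and random tensors as stand-ins for real feature maps.

```python
import torch
import torch.nn as nn

up = nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2)  # halves channels, doubles H and W
bottleneck = torch.randn(1, 16, 4, 4)                    # stand-in for the deepest features
skip = torch.randn(1, 8, 8, 8)                           # stand-in for the matching encoder output

upsampled = up(bottleneck)                               # shape (1, 8, 8, 8)
merged = torch.cat([upsampled, skip], dim=1)             # shape (1, 16, 8, 8): channels add up

print(upsampled.shape, merged.shape)
```

The same doubling explains why `dec4` in the model takes 1024 input channels: 512 from the upsampled bottleneck plus 512 from `e4`.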

---

## 4. Model Initialization and Pretrained Weights

- `n_classes = 19`: Number of segmentation classes.
- **Loading Weights:** Pretrained weights are loaded using `load_state_dict`.
- **Evaluation Mode:** `eval()` switches training-specific layers to inference behavior: dropout is disabled and batch normalization uses its stored running statistics instead of per-batch statistics.
- **GPU Usage:** `cuda()` moves the model to GPU for faster inference.
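
On a machine without a GPU, the same steps can be written in a device-agnostic way. The sketch below uses a tiny stand-in network and a temporary checkpoint file so that it runs on its own. In the notebook the network would be `UNet(n_classes=19)` and the checkpoint would be the pretrained weights file.

```python
import os
import tempfile

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tiny stand-in network; in the notebook this would be UNet(n_classes=19).
net = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.BatchNorm2d(4))

# Save and reload a checkpoint in a temporary directory to keep the sketch self-contained.
with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "weights.pth")  # hypothetical file name
    torch.save(net.state_dict(), ckpt_path)
    state = torch.load(ckpt_path, map_location=device)

net.load_state_dict(state)
net.to(device).eval()  # eval(): BatchNorm uses running statistics, dropout is disabled

print(net.training, next(net.parameters()).device.type)
```

Passing `map_location=device` lets a checkpoint saved on a GPU load on a CPU-only machine as well.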

---

## 5. Key Points

- **Skip Connections:** Preserve spatial details lost during downsampling.
- **DoubleConv Blocks:** Improve feature extraction at each level.
- **Encoder-Decoder Symmetry:** Ensures feature information is combined effectively during reconstruction.
- **UNet Use Case:** Commonly used in semantic segmentation tasks such as satellite imagery, medical imaging, and autonomous driving.


In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.enc1 = DoubleConv(3, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)

        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(512, 1024)

        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        self.final = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.up4(b)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))
        d3 = self.up3(d4)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        return self.final(d1)


n_classes = 19
model = UNet(n_classes=n_classes)
model.load_state_dict(torch.load("unet_rellis3d.pth", map_location="cuda"))
model.eval().cuda()


# Evaluation of UNet Model for Segmentation

This document explains the evaluation functions used to measure the performance of a **UNet segmentation model** using metrics like **mIoU** and **Pixel Accuracy**.

---

## 1. Confusion Matrix

**Function:** `compute_confusion_matrix(pred, label, num_classes)`

- **Purpose:** Builds a confusion matrix for a single prediction-label pair.
- **Process:**
  - **Masking:** Ensures only valid labels (0 ≤ label < num_classes) are considered.
  - **Indexing:** Each pair of `(label, pred)` is mapped to a unique index: `num_classes * label + pred`.
  - **Bincount:** Counts occurrences of each `(label, pred)` pair.
  - **Reshape:** Converts the flat count vector into a `(num_classes x num_classes)` confusion matrix.

**Key Concept:**

- **Rows:** Ground truth classes.
- **Columns:** Predicted classes.
- **Diagonal:** Correct predictions.
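
The indexing trick is easiest to follow on a tiny example. The sketch below is a NumPy counterpart of the function described above, applied to six made-up pixels with three classes.

```python
import numpy as np


def confusion_matrix_np(pred, label, num_classes):
    """NumPy counterpart of the bincount trick: rows are ground truth, columns are predictions."""
    mask = (label >= 0) & (label < num_classes)     # drop invalid or ignored labels
    index = num_classes * label[mask] + pred[mask]  # unique index per (label, pred) pair
    return np.bincount(index, minlength=num_classes ** 2).reshape(num_classes, num_classes)


label = np.array([0, 0, 1, 2, 2, 2])
pred = np.array([0, 1, 1, 2, 0, 2])
cm = confusion_matrix_np(pred, label, num_classes=3)
print(cm)
```

Row 2 of the result reads `[1, 0, 2]`: of the three pixels whose true class is 2, one was predicted as class 0 and two were predicted correctly.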

---

## 2. Mean Intersection over Union (mIoU)

**Function:** `compute_mIoU(confusion_matrix)`

- **Purpose:** Measures the overlap between predicted and true segmentation regions.
- **Computation:**
  - **Intersection:** Diagonal of the confusion matrix → correctly predicted pixels for each class.
  - **Union:** Sum of row + sum of column - intersection → total pixels covered by either prediction or ground truth for each class.
  - **IoU per class:** `intersection / union`.
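  - **Mean IoU:** Average of the per-class IoU values. A class that appears in neither the predictions nor the ground truth has a union of zero and should be excluded to avoid a division by zero.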

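**Significance:** mIoU is a standard metric for semantic segmentation. Every class contributes equally to the mean regardless of how many pixels it covers, so rare classes are not drowned out by frequent ones as they are in pixel accuracy.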
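
As a worked example, the sketch below computes per-class IoU from a small made-up confusion matrix using NumPy. Classes with an empty union are marked as NaN and skipped in the mean.

```python
import numpy as np


def per_class_iou(cm):
    """Per-class IoU from a confusion matrix; classes with an empty union become NaN."""
    intersection = np.diag(cm).astype(np.float64)
    union = cm.sum(axis=1) + cm.sum(axis=0) - intersection
    iou = np.full(cm.shape[0], np.nan)
    valid = union > 0
    iou[valid] = intersection[valid] / union[valid]
    return iou


cm = np.array([[1, 1, 0],
               [0, 1, 0],
               [1, 0, 2]])
iou = per_class_iou(cm)
miou = np.nanmean(iou)  # mean over the classes that actually occur
print(iou, miou)
```

Here the unions are 3, 2, and 3, so the per-class IoU values are 1/3, 1/2, and 2/3, and their mean is 0.5.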

---

## 3. Evaluation Function

**Function:** `evaluate(model, dataloader, num_classes=19)`

- **Purpose:** Computes **mIoU** and **Pixel Accuracy** over the entire dataset.
- **Steps:**
  1. Initialize a zeroed confusion matrix.
  2. Disable gradient computation with `torch.no_grad()` (faster evaluation).
  3. Iterate over batches from `dataloader`:
     - Move images and labels to GPU.
     - Run the model to get predictions.
     - Convert raw outputs to class predictions using `argmax`.
     - Update the confusion matrix for each image in the batch.
  4. Compute **mIoU** using `compute_mIoU`.
  5. Compute **Pixel Accuracy**: total correctly predicted pixels divided by total pixels.

**Outputs:**

- **mIoU:** Mean Intersection over Union (range: 0–1)
- **Pixel Accuracy:** Overall fraction of correctly classified pixels (range: 0–1)
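
The two metrics can disagree sharply when classes are imbalanced. The sketch below uses a made-up confusion matrix for a model that predicts the majority class everywhere.

```python
import numpy as np

# 100 pixels: 90 of class 0, 10 of class 1; the model predicts class 0 everywhere.
cm = np.array([[90, 0],
               [10, 0]])

pixel_acc = np.diag(cm).sum() / cm.sum()
intersection = np.diag(cm)
union = cm.sum(axis=1) + cm.sum(axis=0) - intersection
miou = (intersection / union).mean()

print(f"pixel accuracy: {pixel_acc:.2f}, mIoU: {miou:.2f}")
```

Pixel accuracy is 0.90 because most pixels belong to class 0, while mIoU is 0.45 because class 1 is never predicted and contributes an IoU of 0.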

---

## 4. Example Usage

- Call the evaluation function with the model and dataloader:

```python
mIoU, pixel_acc = evaluate(model, dataloader, num_classes=19)
print(f"mIoU: {mIoU:.4f}, Pixel Accuracy: {pixel_acc:.4f}")
```


In [None]:
def compute_confusion_matrix(pred, label, num_classes):
    mask = (label >= 0) & (label < num_classes)
    hist = torch.bincount(
        num_classes * label[mask] + pred[mask],
        minlength=num_classes ** 2
    ).reshape(num_classes, num_classes)
    return hist

def compute_mIoU(confusion_matrix):
    intersection = torch.diag(confusion_matrix)
    union = confusion_matrix.sum(1) + confusion_matrix.sum(0) - intersection
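    # Skip classes absent from both prediction and ground truth (union == 0),
    # which would otherwise produce NaN and poison the mean.
    valid = union > 0
    IoU = intersection[valid].float() / union[valid].float()
    return IoU.mean().item()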

def evaluate(model, dataloader, num_classes=19):
    confusion_matrix = torch.zeros((num_classes, num_classes), dtype=torch.int64)

    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs, labels = imgs.cuda(), labels.cuda()
            outputs = model(imgs)
            preds = torch.argmax(outputs, dim=1)

            for p, l in zip(preds, labels):
                # The per-image histogram lives on the GPU; move it to the CPU accumulator.
                confusion_matrix += compute_confusion_matrix(p.view(-1), l.view(-1), num_classes).cpu()

    mIoU = compute_mIoU(confusion_matrix)
    pixel_acc = torch.diag(confusion_matrix).sum().item() / confusion_matrix.sum().item()
    return mIoU, pixel_acc


mIoU, pixel_acc = evaluate(model, dataloader, num_classes=19)
print(f"mIoU: {mIoU:.4f}, Pixel Accuracy: {pixel_acc:.4f}")


### Visualization

In [None]:
def visualize(model, dataset, idx=0):
    model.eval()
    img, label = dataset[idx]
    img_input = img.unsqueeze(0).cuda()

    with torch.no_grad():
        output = model(img_input)
        pred = torch.argmax(output, dim=1).squeeze().cpu().numpy()

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
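    # Undo the ImageNet normalization per channel and clip to [0, 1] for display
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    axes[0].imshow(np.clip(img.permute(1, 2, 0).cpu().numpy() * std + mean, 0, 1))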
    axes[0].set_title("Input Image")
    axes[1].imshow(label.numpy())
    axes[1].set_title("Ground Truth")
    axes[2].imshow(pred)
    axes[2].set_title("Prediction")
    plt.show()



visualize(model, dataset, idx=5)
