In [None]:
# Mount Google Drive in this Colab runtime so the paths under /content/drive/MyDrive
# used below (dataset archives, logs, checkpoints) are reachable.
from google.colab import drive
drive.mount('/content/drive')

## Import

In [None]:
!pip install monai

In [None]:
!pip install torchinfo

In [29]:
import os
import numpy as np
import matplotlib.pyplot as plt

import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset
from torchinfo import summary
from torchvision import models
from torch.optim.lr_scheduler import CosineAnnealingLR # learning rate scheduler
from monai.losses import DiceCELoss

# Augmentation
import albumentations as A

# TensorBoard logging
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

# Garbage collection
import gc

# To accelerate 16-bit matrix multiplication
from torch.cuda.amp import autocast, GradScaler

## Move data to local disk

Replace `zip_train_source_path` and `zip_val_source_path` with your own paths.

In [None]:
# Source archives on the mounted Google Drive -- edit these to point at your own copies.
zip_train_source_path = "/content/drive/MyDrive/VerdaSense/Dataset/zipped/train.zip"
zip_val_source_path = "/content/drive/MyDrive/VerdaSense/Dataset/zipped/validation.zip"

# Destination on the runtime's local disk (presumably faster to read than the Drive mount).
local_data_dir = "/content/data"

!mkdir -p "$local_data_dir"

# NOTE(review): `!cp` failures do not raise in IPython, so "Copying complete." is printed
# even if a source path is wrong -- check the shell output above it.
print(f"Copying {zip_train_source_path} to {local_data_dir}")
!cp "$zip_train_source_path" "$local_data_dir/"

print(f"Copying {zip_val_source_path} to {local_data_dir}")
!cp "$zip_val_source_path" "$local_data_dir/"

print("Copying complete.")

Copying /content/drive/MyDrive/VerdaSense/Dataset/zipped/train-v2.zip to /content/data
Copying /content/drive/MyDrive/VerdaSense/Dataset/zipped/validation-v2.zip to /content/data
Copying complete.


In [None]:
local_zip_train_path = f"{local_data_dir}/train.zip"
local_zip_val_path = f"{local_data_dir}/validation.zip"

unzip_destination_path = local_data_dir

print(f"Unzipping {local_zip_train_path} to {unzip_destination_path}")
!unzip -q "$local_zip_train_path" -d "$unzip_destination_path"

print(f"Unzipping {local_zip_val_path} to {unzip_destination_path}")
!unzip -q "$local_zip_val_path" -d "$unzip_destination_path"

print("Unzipping complete.")

In [None]:
local_train_image_path = os.path.join(local_data_dir, "train", "images")
local_train_mask_path = os.path.join(local_data_dir, "train", "masks")

# Sanity-check the extracted training split: report how many files each
# directory holds, or flag a directory that is missing.
if not os.path.exists(local_train_image_path):
  print(f"Directory {local_train_image_path} does not exist.")
else:
  num_images = len(os.listdir(local_train_image_path))
  print(f"Number of images in {local_train_image_path}: {num_images}")

if not os.path.exists(local_train_mask_path):
  print(f"Directory {local_train_mask_path} does not exist.")
else:
  num_masks = len(os.listdir(local_train_mask_path))
  print(f"Number of masks in {local_train_mask_path}: {num_masks}")

In [None]:
local_val_image_path = os.path.join(local_data_dir, "validation", "images")
local_val_mask_path = os.path.join(local_data_dir, "validation", "masks")

# Sanity-check the extracted validation split: report how many files each
# directory holds, or flag a directory that is missing.
if not os.path.exists(local_val_image_path):
  print(f"Directory {local_val_image_path} does not exist.")
else:
  num_images = len(os.listdir(local_val_image_path))
  print(f"Number of images in {local_val_image_path}: {num_images}")

if not os.path.exists(local_val_mask_path):
  print(f"Directory {local_val_mask_path} does not exist.")
else:
  num_masks = len(os.listdir(local_val_mask_path))
  print(f"Number of masks in {local_val_mask_path}: {num_masks}")

## Data augmentation pipeline

In [9]:
# Standard ImageNet Normalization
# Standard ImageNet per-channel normalization statistics, in RGB order
# (the dataset converts OpenCV's BGR images to RGB before normalizing).
IMAGENET_MEAN = (
    0.485,  # R
    0.456,  # G
    0.406,  # B
)
IMAGENET_STD = (
    0.229,  # R
    0.224,  # G
    0.225,  # B
)

In [10]:
# Training-time augmentation, applied on the fly in SegmentationDataset.__getitem__
# (the validation dataset is built without a transform).
train_transform = A.Compose([
    # Geometric transforms
    # NOTE(review): when called with mask=..., albumentations presumably applies these
    # to the mask too, keeping image and mask aligned -- confirm for the installed version.
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.RandomRotate90(p=0.3),
    A.Affine(
        scale=(0.95, 1.05),
        translate_percent=(-0.05, 0.05),
        rotate=(-15, 15),
        border_mode=cv2.BORDER_REFLECT,  # fill exposed borders by reflection rather than a constant
        p=0.5
    ),

    # Photometric transforms (brightness & contrast)
    A.RandomBrightnessContrast(
        brightness_limit=0.15,
        contrast_limit=0.15,
        p=0.5
    ),

    # Noise & blur: OneOf picks a single one of these (the group is applied with p=0.3)
    A.OneOf([
        A.GaussNoise(std_range=(0.1,0.15)),  # Adds noise (not blur). NOTE(review): std_range is presumably a fraction of the max pixel value -- confirm
        A.MedianBlur(blur_limit=3),
    ], p=0.3),

    # Occlusion: drop 1-4 rectangular regions
    A.CoarseDropout(
        num_holes_range = (1, 4),
        p=0.4
    ),
])


## Custom SegmentationDataset

In [11]:
class SegmentationDataset(Dataset):
    """
    Binary segmentation dataset that pairs images with masks by sorted filename order.

    Each item is ``(image, mask)``:
      * ``image`` -- float32 tensor ``[3, H, W]``, RGB, scaled to 0-1 then standardized
        with ``mean`` / ``std``.
      * ``mask``  -- float32 tensor ``[1, H, W]`` of 0.0 / 1.0 (any non-zero mask pixel
        counts as foreground).

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks. The i-th mask (sorted by filename) is
            paired with the i-th image (sorted by filename), so both directories must
            list the samples in the same order.
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` on the uint8 RGB image and uint8 grayscale
            mask. It may contain geometric and photometric transforms.
        img_size: Target ``(width, height)`` (cv2 order) that image and mask are resized
            to when they don't already have that size.
        mean, std: Per-channel RGB statistics used for standardization.

    Raises:
        ValueError: If ``image_dir`` and ``mask_dir`` contain a different number of entries.
        FileNotFoundError: From ``__getitem__`` if an image or mask file cannot be read.
    """
    def __init__(self, image_dir, mask_dir, transform=None, img_size=(1024, 1024), mean=IMAGENET_MEAN, std=IMAGENET_STD):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform  # Any albumentations pipeline (geometric and/or photometric)
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        # Pairing is positional, so a count mismatch would silently misalign samples
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {image_dir} "
                f"but {len(self.masks)} masks in {mask_dir}"
            )
        self.img_size = tuple(img_size)
        # float32 statistics keep the standardized image float32. float64 arrays would
        # silently upcast every sample (2x memory and host-to-GPU traffic).
        self.mean = np.asarray(mean, dtype=np.float32)
        self.std = np.asarray(std, dtype=np.float32)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # cv2.imread returns None (instead of raising) on missing/unreadable files
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")

        # Augment on the uint8 arrays, as albumentations expects
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Resize each array independently, so a mask whose size differs from its image
        # is fixed as well. Nearest-neighbour keeps the mask values discrete.
        if (image.shape[1], image.shape[0]) != self.img_size:
            image = cv2.resize(image, self.img_size)
        if (mask.shape[1], mask.shape[0]) != self.img_size:
            mask = cv2.resize(mask, self.img_size, interpolation=cv2.INTER_NEAREST)

        # Scale to 0-1 and standardize, staying in float32
        image = (image.astype(np.float32) / 255.0 - self.mean) / self.std
        image = torch.from_numpy(image).permute(2, 0, 1)  # HWC -> CHW

        # Binarize (any non-zero pixel is foreground) and add a channel dimension
        mask = (mask > 0).astype(np.float32)
        mask = torch.from_numpy(mask).unsqueeze(0)  # HW -> 1HW

        return image, mask

In [12]:
# Only the training split is augmented (on the fly, per __getitem__ call);
# validation images are just resized and normalized.
train_dataset = SegmentationDataset(
    local_train_image_path,
    local_train_mask_path,
    transform=train_transform,
)
val_dataset = SegmentationDataset(local_val_image_path, local_val_mask_path)

## Helper functions

In [13]:
def denormalize_image(image_tensor, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """
    Undo ``(x - mean) / std`` standardization so an image can be displayed.

    Args:
        image_tensor: Standardized ``[C, H, W]`` torch tensor (any device).
        mean, std: The per-channel statistics that were used for standardization.

    Returns:
        ``[H, W, C]`` NumPy array with values clipped to ``[0, 1]``.
    """
    hwc = image_tensor.cpu().permute(1, 2, 0).numpy()  # CHW -> HWC for matplotlib
    restored = hwc * std + mean
    return restored.clip(0, 1)

In [14]:
def get_confusion_matrix_components(y_true, y_pred):
    """
    Count the binary confusion-matrix cells over a whole batch.

    Args:
        y_true (torch.Tensor): Ground-truth masks, a tensor of 0s and 1s (any shape).
        y_pred (torch.Tensor): Binary predictions (e.g. sigmoid output thresholded) with
            the same number of elements as ``y_true``.

    Returns:
        tuple[int, int, int, int]: (true_positives, false_positives, false_negatives,
        true_negatives) as Python ints. Elements that are neither 0 nor 1 are not counted.
    """
    # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced tensors, also work
    true_flat = y_true.reshape(-1)
    pred_flat = y_pred.reshape(-1)

    # Build each comparison mask once and reuse it for the four cells
    gt_pos, gt_neg = true_flat == 1, true_flat == 0
    pred_pos, pred_neg = pred_flat == 1, pred_flat == 0

    true_positives = (pred_pos & gt_pos).sum().item()
    false_positives = (pred_pos & gt_neg).sum().item()
    false_negatives = (pred_neg & gt_pos).sum().item()
    true_negatives = (pred_neg & gt_neg).sum().item()

    return true_positives, false_positives, false_negatives, true_negatives


def calculate_final_metrics(tp, fp, fn, tn, smooth=1e-6):
    """
    Compute segmentation metrics from (accumulated) confusion-matrix counts.

    Args:
        tp, fp, fn, tn: True positive / false positive / false negative / true negative counts.
        smooth: Small constant added to denominators to avoid division by zero.

    Returns:
        tuple: (iou, dice, recall, precision, accuracy) as floats.

    Notes:
        If there are no positives at all (``tp == fp == fn == 0``, i.e. an empty ground
        truth that was correctly predicted empty), IoU and Dice are 1.0, a perfect match,
        instead of 0.0. Recall and precision are undefined in that case and stay 0.0.
    """
    union = tp + fp + fn

    if union == 0:
        # Nothing to segment and nothing predicted: perfect agreement
        iou = 1.0
        dice = 1.0
    else:
        # IoU = TP / (TP + FP + FN)
        iou = tp / (union + smooth)
        # Dice / F1 = 2TP / (2TP + FP + FN), computed from counts directly rather than
        # from the already-smoothed precision and recall
        dice = (2 * tp) / (2 * tp + fp + fn + smooth)

    # Recall (Sensitivity)
    recall = tp / (tp + fn + smooth)

    # Precision (Positive Predictive Value)
    precision = tp / (tp + fp + smooth)

    # Accuracy
    accuracy = (tp + tn) / (tp + tn + fp + fn + smooth)

    return iou, dice, recall, precision, accuracy

In [15]:
# To prepare masks for TensorBoard
def prepare_mask_for_tb(image, mask, alpha=0.6):
  """
  Tint the masked pixels of an RGB image red so the mask can be viewed in TensorBoard.

  Args:
    image: Float tensor [3, H, W]. If its max is > 1 it is assumed to be 0-255 and
      rescaled to 0-1.
    mask: Tensor with H*W elements ([H, W] or [1, H, W]); pixels > 0 are foreground.
    alpha: Weight of the red overlay inside the mask. The default 0.6 gives
      40% original image + 60% red.

  Returns:
    Tensor [3, H, W]: blended pixels where the mask is set, the original image elsewhere.
  """
  # Normalize image to 0-1 if it isn't already
  if image.max() > 1:
    image = image / 255.0

  # reshape rather than squeeze(): squeeze() would also drop H or W if either is 1.
  # Binarize so the overlay is full red wherever the mask is set.
  foreground = mask.reshape(image.shape[-2:]) > 0

  # Red-only overlay image
  overlay = torch.zeros_like(image)
  overlay[0] = foreground.to(image.dtype)

  blended = image * (1 - alpha) + overlay * alpha

  # Blend only inside the mask; keep the original pixels elsewhere
  return torch.where(foreground.unsqueeze(0), blended, image)

In [16]:
def create_comparison_grid(images, labels, preds, mean=IMAGENET_MEAN, std=IMAGENET_STD, max_rows=None):
  """
  Build a TensorBoard-ready grid with one row per sample: [Input | Ground Truth | Prediction].

  Args:
    images: Batch of normalized images [B, 3, H, W]
    labels: Batch of GT masks [B, 1, H, W]
    preds: Batch of predicted masks [B, 1, H, W]
    mean, std: Statistics used to undo the input standardization
    max_rows: Optional cap on the number of samples (rows); None means the whole batch

  Returns:
    The tensor produced by torchvision's make_grid (3 tiles per row, 5 px padding).
  """
  images, labels, preds = images.cpu(), labels.cpu(), preds.cpu()

  n_rows = images.shape[0]
  if max_rows is not None:
    n_rows = min(n_rows, max_rows)

  tiles = []
  for row in range(n_rows):
    # Clean 0-1 [3, H, W] image that both mask overlays are drawn on
    base = torch.from_numpy(denormalize_image(images[row], mean, std)).permute(2, 0, 1).float()
    tiles += [
        base,                                    # left: input
        prepare_mask_for_tb(base, labels[row]),  # middle: ground-truth overlay
        prepare_mask_for_tb(base, preds[row]),   # right: prediction overlay
    ]

  # nrow=3 starts a new grid row after every (input, GT, prediction) triple
  return make_grid(tiles, nrow=3, padding=5)

## DeepLabV3+ with MobileNetV2

In [17]:
class SeparableConv2d(nn.Module):
    """
    Depthwise separable convolution: a per-channel (depthwise) convolution followed by
    a pointwise (1x1) convolution, each followed by BatchNorm and ReLU.

    Padding is ``dilation * (kernel_size - 1) // 2``, so for odd kernel sizes and
    stride 1 the output has the same spatial size as the input. Even kernel sizes
    cannot be padded symmetrically to preserve size exactly.

    Args:
        in_channels: Number of input channels (also the depthwise group count).
        out_channels: Number of output channels produced by the pointwise conv.
        kernel_size: Depthwise kernel size (int).
        stride: Depthwise stride.
        dilation: Depthwise dilation (atrous rate).
        bias: Whether the two convolutions have a bias term (BatchNorm follows both).
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()

        # "Same" padding for odd kernels: p = d * (k - 1) / 2 (equals `dilation` for k=3)
        padding = dilation * (kernel_size - 1) // 2

        # Depthwise convolution: Applies a separate filter to each input channel
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=bias)
        self.bn_depth = nn.BatchNorm2d(in_channels)
        self.relu_depth = nn.ReLU(inplace=True)

        # Pointwise convolution: A 1x1 convolution to mix the channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
        self.bn_point = nn.BatchNorm2d(out_channels)
        self.relu_point = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.bn_depth(x)
        x = self.relu_depth(x)

        x = self.pointwise(x)
        x = self.bn_point(x)
        x = self.relu_point(x)

        return x

In [18]:
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling head.

    Five parallel branches run on the input feature map: a 1x1 conv, three atrous
    separable 3x3 convs (dilation 6, 12, 18) and an image-level global-average-pool
    branch. Their outputs are concatenated along channels and fused by a
    1x1 conv + BatchNorm + ReLU + Dropout(0.5).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by every branch and by the fused output.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Branch 1: plain 1x1 convolution
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        # Branches 2-4: atrous separable convolutions, registered (in this order) as
        # atrous_block6, atrous_block12 and atrous_block18
        for rate in (6, 12, 18):
            setattr(
                self,
                f"atrous_block{rate}",
                SeparableConv2d(in_channels, out_channels, kernel_size=3, dilation=rate),
            )

        # Branch 5: image-level context (pooled to 1x1, projected, later upsampled)
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        # Fuse the 5 concatenated branches back down to out_channels
        self.final_conv = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
        )

    def forward(self, x):
        spatial_size = x.size()[2:]

        branch_outputs = [
            self.conv1x1(x),
            self.atrous_block6(x),
            self.atrous_block12(x),
            self.atrous_block18(x),
        ]

        # The pooled branch is 1x1; stretch it back to the feature-map size
        pooled = self.global_avg_pool(x)
        branch_outputs.append(
            F.interpolate(pooled, size=spatial_size, mode='bilinear', align_corners=False)
        )

        return self.final_conv(torch.cat(branch_outputs, dim=1))

In [19]:
class DeepLabV3Plus(nn.Module):
    """
    DeepLabV3+ segmentation network on a torchvision MobileNetV2 feature extractor.

    The backbone module at index 14 has its stride set to 1, and the 3x3 convs of
    modules 14-18 are dilated by 2. This targets output stride 16 (the stock backbone
    is assumed to be output stride 32). ASPP runs on the final 1280-channel features.
    The decoder concatenates the upsampled ASPP output with a 48-channel projection of
    the low-level features taken after module index 3 (24 channels), then upsamples
    the logits back to the input size.

    Args:
        num_classes: Number of output channels (1 = single-channel binary logits).

    Note:
        ``models.mobilenet_v2(weights="DEFAULT")`` loads pretrained weights, which may
        need to be downloaded on first use.
    """
    def __init__(self, num_classes=1):
        super(DeepLabV3Plus, self).__init__()
        backbone = models.mobilenet_v2(weights="DEFAULT")
        self.backbone = backbone.features  # Get all layers except classifier

        # Modify MobileNetV2 for Output Stride 16
        # Remove the downsampling of the module at index 14 (the 15th module of `features`).
        # NOTE(review): conv[1][0] is assumed to be that block's 3x3 depthwise conv in
        # torchvision's InvertedResidual layout -- re-verify after torchvision upgrades.
        self.backbone[14].conv[1][0].stride = (1, 1)

        # Dilate the 3x3 convs of modules 14-18 (including 14 itself) by 2 to keep the
        # receptive field the removed stride provided. padding=2 keeps the spatial size
        # unchanged for a 3x3 kernel with dilation 2.
        for i in range (14, 19):
          for m in self.backbone[i].modules():
            if isinstance(m, nn.Conv2d):
              # Only 3x3 convs are changed (presumably only the depthwise convs are 3x3 here)
              if m.kernel_size == (3, 3):
                m.dilation = (2, 2)
                m.padding = (2, 2)

        # Low-level features come from early layer (for decoder)
        # Output of backbone module index 3, which must have 24 channels for low_level_project
        self.low_level_idx = 3
        self.low_level_channels = 24

        # ASPP expects 1280 channels from the last MobileNetV2 layer
        self.aspp = ASPP(in_channels=1280, out_channels=256)

        # Decoder
        # 1x1 projection of the low-level features to 48 channels
        self.low_level_project = nn.Sequential(
            nn.Conv2d(self.low_level_channels, 48, kernel_size=1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        # Refines [ASPP (256) + low-level (48)] features and maps them to num_classes logits
        self.decoder = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1)
        )

    def forward(self, x):
      """Return raw logits of shape [B, num_classes, H, W] for an input of shape [B, 3, H, W]."""
      input_size = x.size()[2:]

      # Extract low-level and high-level features
      low_level_feat = None
      feat = x
      for i, layer in enumerate(self.backbone):
          feat = layer(feat)
          if i == self.low_level_idx:
              low_level_feat = feat  # Save for decoder

      high_level_feat = feat  # Final backbone output; expected [B, 1280, H/16, W/16] after the stride change in __init__

      # ASPP on high-level features, upsampled to the low-level feature resolution
      x = self.aspp(high_level_feat)
      x = F.interpolate(x, size=low_level_feat.shape[2:], mode='bilinear', align_corners=False)

      # Decoder
      low_level = self.low_level_project(low_level_feat)
      x = torch.cat([x, low_level], dim=1)
      x = self.decoder(x)
      # Back to the input resolution (logits, no sigmoid applied here)
      x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
      return x

## Model Training: Configuration & Preparation

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

### Training configs

In [32]:
# Run hyperparameters. The three string entries only label which optimizer, loss and
# scheduler are used; the objects themselves are constructed directly further down.
config = dict(
    batch_size=2,
    epochs=40,
    initial_lr=2e-4,
    weight_decay=1e-5,
    min_lr=1e-6,
    threshold=0.5,  # sigmoid probability cut-off for binarizing predictions
    optimizer="AdamW",
    loss_function="DiceCELoss",
    lr_scheduler="CosineAnnealingLR",
)

### Dataloaders

In [None]:
# Shuffled training loader; the incomplete final batch is dropped
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    drop_last=True,   # Drop the last batch if it's smaller than batch_size
    num_workers=2,
    pin_memory=True,  # Speeds up data transfer to GPU
)

print(f"Number of batches in train_loader: {len(train_loader)}")

In [None]:
# Deterministic-order validation loader that keeps every sample (no drop_last)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=config["batch_size"],
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

print(f"Number of batches in val_loader: {len(val_loader)}")

### Visualize the data before passing it to the model

In [None]:
train_images, train_labels = next(iter(train_loader))


# One figure per sample in the batch: de-normalized input next to its ground-truth mask
batch_size = train_images.shape[0]

for index in range(batch_size):
    image = denormalize_image(train_images[index].cpu(), mean=IMAGENET_MEAN, std=IMAGENET_STD)
    # Label is (1, H, W); drop the channel axis so imshow gets (H, W)
    mask = train_labels[index].cpu().squeeze().numpy()

    fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))

    ax_image.imshow(image)
    ax_image.set_title(f'Sample {index} | Input Image')
    ax_image.axis('off')

    # Flag masks that contain no foreground pixels at all
    if np.all(mask == 0):
        ax_mask.set_title(f'Sample {index} | Ground Truth Mask (ALL BLACK - FAIL)', color='red')
    else:
        ax_mask.set_title(f'Sample {index} | Ground Truth Mask (PASS)')
    ax_mask.imshow(mask, cmap='viridis')
    ax_mask.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
vis_images, vis_labels = next(iter(val_loader))


# One figure per sample in the batch: de-normalized input next to its ground-truth mask
batch_size = vis_images.shape[0]

for index in range(batch_size):
    image = denormalize_image(vis_images[index].cpu(), mean=IMAGENET_MEAN, std=IMAGENET_STD)
    # Label is (1, H, W); drop the channel axis so imshow gets (H, W)
    mask = vis_labels[index].cpu().squeeze().numpy()

    fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))

    ax_image.imshow(image)
    ax_image.set_title(f'Sample {index} | Input Image')
    ax_image.axis('off')

    # Flag masks that contain no foreground pixels at all
    if np.all(mask == 0):
        ax_mask.set_title(f'Sample {index} | Ground Truth Mask (ALL BLACK - FAIL)', color='red')
    else:
        ax_mask.set_title(f'Sample {index} | Ground Truth Mask (PASS)')
    ax_mask.imshow(mask, cmap='viridis')
    ax_mask.axis('off')

    fig.tight_layout()
    plt.show()

### DeepLabV3+ with MobileNetV2 backbone

In [46]:
# Single-logit binary segmentation model, moved to the selected device (Module.to returns the module)
model = DeepLabV3Plus(num_classes=1).to(device)

In [None]:
# Layer-by-layer summary for one training-sized batch of 3x1024x1024 inputs
# (1024 matches SegmentationDataset's default img_size). NOTE(review): torchinfo
# presumably runs a real forward pass here, so it needs about as much memory as a training forward.
summary(model, input_size=(config["batch_size"], 3, 1024, 1024))

In [48]:
# Loss: MONAI Dice + cross-entropy on raw logits (sigmoid=True: the loss applies the
# sigmoid itself); targets are already single-channel binary masks, so no one-hot.
criterion = DiceCELoss(sigmoid=True, to_onehot_y=False)

# AdamW, with the LR cosine-annealed from initial_lr down to min_lr over `epochs` scheduler steps
optimizer = optim.AdamW(
    model.parameters(),
    lr=config["initial_lr"],
    weight_decay=config["weight_decay"],
)
scheduler = CosineAnnealingLR(optimizer, T_max=config["epochs"], eta_min=config["min_lr"])

## Model Training

Set `log_dir` and `save_dir` in the next cell to the paths where you want the TensorBoard logs and model checkpoints to be saved.

In [None]:
from pathlib import Path
from datetime import datetime


# Timestamped run name, so every run gets its own log and checkpoint folders
run_name = f"Run_{datetime.now():%Y%m%d-%H%M%S}"

# Output locations on Google Drive -- edit these to change where logs/checkpoints go
log_dir = f"/content/drive/MyDrive/FYP/Model_Training/MobileNet/logs/{run_name}"
save_dir = f"/content/drive/MyDrive/FYP/Model_Training/MobileNet/checkpoints/{run_name}"

# Create both folders, including missing parents; no error if they already exist
Path(log_dir).mkdir(exist_ok=True, parents=True)
Path(save_dir).mkdir(exist_ok=True, parents=True)

# TensorBoard writer for this run and the checkpoint file for the best epoch
writer = SummaryWriter(log_dir=log_dir)
best_model_path = os.path.join(save_dir, "best_deeplabv3_mobilenetv2.pth")

print("Directories verified/created.")
print(f"TensorBoard logging to: {log_dir}")
print(f"Best DeepLabV3+ with MobileNetV2 model's weights will be saved to: {best_model_path}")

In [None]:
# Gradients are accumulated over this many batches per optimizer step
accumulation_steps = 2

best_val_iou = -100
patience = 20  # epochs without validation-IoU improvement before early stopping
counter = 0
threshold = config["threshold"]

train_losses = []
train_ious = []
train_dices = []
train_recalls = []
train_precisions = []
train_accs = []

val_losses = []
val_ious = []
val_dices = []
val_recalls = []
val_precisions = []
val_accs = []


# One fixed validation batch is re-predicted every few epochs, so progress on the
# same samples can be compared across epochs in TensorBoard
vis_images, vis_labels = next(iter(val_loader))

optimizer.zero_grad()

# try/finally: flush and close the TensorBoard writer even if training is
# interrupted (KeyboardInterrupt, CUDA OOM, ...)
try:
  for epoch in range(config["epochs"]):
    # --- Training Phase ---
    model.train()

    # Running loss, sample count and confusion-matrix totals for this epoch
    train_running_loss = 0.0
    train_samples_seen = 0  # drop_last=True means this can be < len(train_dataset)
    train_total_tp, train_total_fp, train_total_fn, train_total_tn = 0, 0, 0, 0

    for i, (X_batch, y_batch) in enumerate(train_loader):
      X_batch, y_batch = X_batch.to(device).float(), y_batch.to(device).float()

      mask_logits = model(X_batch)

      # MONAI DiceCELoss takes [B, 1, H, W] logits and targets. Dividing by
      # accumulation_steps makes the accumulated gradient an average over the group.
      loss = criterion(mask_logits, y_batch) / accumulation_steps
      loss.backward()

      # Undo the accumulation scaling so the logged loss is the real per-batch loss
      train_running_loss += (loss.item() * accumulation_steps) * X_batch.size(0)
      train_samples_seen += X_batch.size(0)

      # Step every `accumulation_steps` batches, and flush any leftover gradient at epoch end
      if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
        optimizer.step()
        optimizer.zero_grad()

      # Metrics only: binarize detached probabilities outside autograd ([B, 1, H, W])
      with torch.no_grad():
        train_preds = (torch.sigmoid(mask_logits.detach()) > threshold).float()

      tp, fp, fn, tn = get_confusion_matrix_components(y_batch, train_preds)
      train_total_tp += tp
      train_total_fp += fp
      train_total_fn += fn
      train_total_tn += tn

    # Epoch metrics (loss averaged over the samples actually seen)
    train_loss = train_running_loss / max(train_samples_seen, 1)
    train_iou, train_dice, train_recall, train_precision, train_acc = calculate_final_metrics(train_total_tp, train_total_fp, train_total_fn, train_total_tn)

    train_losses.append(train_loss)
    train_ious.append(train_iou)
    train_dices.append(train_dice)
    train_recalls.append(train_recall)
    train_precisions.append(train_precision)
    train_accs.append(train_acc)

    # --- Validation Phase ---
    model.eval()

    val_running_loss = 0.0
    val_samples_seen = 0
    val_total_tp, val_total_fp, val_total_fn, val_total_tn = 0, 0, 0, 0

    # Best and worst validation batch of this epoch (by batch IoU), kept on CPU for visualization
    min_iou_in_epoch = float('inf')
    max_iou_in_epoch = float('-inf')
    worst_batch_data, worst_batch_labels, worst_batch_preds = None, None, None
    best_batch_data, best_batch_labels, best_batch_preds = None, None, None

    with torch.no_grad():
      for X_val_batch, y_val_batch in val_loader:
        X_val_batch, y_val_batch = X_val_batch.to(device).float(), y_val_batch.to(device).float()

        val_mask_logits = model(X_val_batch)

        val_loss = criterion(val_mask_logits, y_val_batch)
        val_running_loss += val_loss.item() * X_val_batch.size(0)
        val_samples_seen += X_val_batch.size(0)

        # Binary predictions [B, 1, H, W]
        val_preds = (torch.sigmoid(val_mask_logits) > threshold).float()

        tp, fp, fn, tn = get_confusion_matrix_components(y_val_batch, val_preds)
        val_total_tp += tp
        val_total_fp += fp
        val_total_fn += fn
        val_total_tn += tn

        # Per-batch IoU, used only to pick the best/worst batch to visualize
        batch_iou, _, _, _, _ = calculate_final_metrics(tp, fp, fn, tn)

        if batch_iou < min_iou_in_epoch:
          min_iou_in_epoch = batch_iou
          worst_batch_data = X_val_batch.cpu()
          worst_batch_labels = y_val_batch.cpu()
          worst_batch_preds = val_preds.cpu()

        if batch_iou > max_iou_in_epoch:
          max_iou_in_epoch = batch_iou
          best_batch_data = X_val_batch.cpu()
          best_batch_labels = y_val_batch.cpu()
          best_batch_preds = val_preds.cpu()

    val_loss = val_running_loss / max(val_samples_seen, 1)
    val_iou, val_dice, val_recall, val_precision, val_acc = calculate_final_metrics(val_total_tp, val_total_fp, val_total_fn, val_total_tn)

    val_losses.append(val_loss)
    val_ious.append(val_iou)
    val_dices.append(val_dice)
    val_recalls.append(val_recall)
    val_precisions.append(val_precision)
    val_accs.append(val_acc)

    # Read the LR that was actually used for this epoch *before* stepping the scheduler;
    # after scheduler.step() the param groups already hold the next epoch's LR
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()  # CosineAnnealingLR, stepped once per epoch

    # TensorBoard Logging
    writer.add_scalars('Loss', {'Train': train_loss, 'Validation': val_loss}, epoch)
    writer.add_scalars('IoU', {'Train': train_iou, 'Validation': val_iou}, epoch)
    writer.add_scalars('Dice', {'Train': train_dice, 'Validation': val_dice}, epoch)
    writer.add_scalars('Recall', {'Train': train_recall, 'Validation': val_recall}, epoch)
    writer.add_scalars('Precision', {'Train': train_precision, 'Validation': val_precision}, epoch)
    writer.add_scalars('Accuracy', {'Train': train_acc, 'Validation': val_acc}, epoch)
    writer.add_scalar('Learning Rate', current_lr, epoch)

    # Log the visualizations every 5 epochs
    if epoch % 5 == 0:
      # Fixed validation batch (model is still in eval mode here)
      temp_vis_images = vis_images.to(device).float()
      with torch.no_grad():
        vis_mask_logits = model(temp_vis_images)
      vis_predictions = (torch.sigmoid(vis_mask_logits) > threshold).float()

      fixed_grid = create_comparison_grid(
          vis_images, vis_labels, vis_predictions, IMAGENET_MEAN, IMAGENET_STD, max_rows=vis_images.shape[0]
      )
      writer.add_image("Vis/Fixed", fixed_grid, epoch)

      # Best batch of this epoch
      best_grid = create_comparison_grid(
          best_batch_data, best_batch_labels, best_batch_preds, IMAGENET_MEAN, IMAGENET_STD, max_rows=best_batch_data.shape[0]
      )
      writer.add_image("Vis/Best", best_grid, epoch)

      # Worst batch of this epoch
      worst_grid = create_comparison_grid(
          worst_batch_data, worst_batch_labels, worst_batch_preds, IMAGENET_MEAN, IMAGENET_STD, max_rows=worst_batch_data.shape[0]
      )
      writer.add_image("Vis/Worst", worst_grid, epoch)

    print(f"Epoch {epoch}:")
    print(f"  Current LR: {current_lr:.6f}")
    print(f"  Train Metrics: Loss: {train_loss:.4f} | IoU: {train_iou:.4f} | Dice: {train_dice:.4f} | Precision: {train_precision:.4f} | Recall: {train_recall:.4f} | Acc: {train_acc:.4f}")
    print(f"  Val Metrics:   Loss: {val_loss:.4f} | IoU: {val_iou:.4f} | Dice: {val_dice:.4f} | Precision: {val_precision:.4f} | Recall: {val_recall:.4f} | Acc: {val_acc:.4f}")
    print("-" * 100) # This line adds a separator

    # Checkpoint on validation-IoU improvement; otherwise count towards early stopping
    if val_iou > best_val_iou:
      best_val_iou = val_iou
      counter = 0

      torch.save(model.state_dict(), best_model_path)
      print(f"Saved new best model at IoU: {best_val_iou: .4f} to {best_model_path}")
    else:
      counter += 1
      print(f"No improvement in Validation IoU for {counter} epoch(s)")

    if counter >= patience:
      print(f"Early stopping at epoch {epoch}. Best Validation IoU: {best_val_iou}")
      break

    gc.collect()
    torch.cuda.empty_cache()
finally:
  # End of run
  writer.close()

In [None]:
# Plot the per-epoch curves recorded by the training loop in a 2x2 grid
epochs_range = range(1, len(train_losses) + 1)
plt.figure(figsize=(15, 10))  # large enough to keep four panels readable

# One tuple per panel, placed left-to-right, top-to-bottom:
# (train series, validation series, train label, validation label, title, y-label)
for panel_idx, (train_series, val_series, train_label, val_label, title, ylabel) in enumerate([
    (train_losses, val_losses, 'Train Loss', 'Validation Loss', 'Training and Validation Loss', 'Loss'),
    (train_ious, val_ious, 'Train IoU', 'Validation IoU', 'Training and Validation IoU', 'IoU'),
    (train_dices, val_dices, 'Train Dice Coef', 'Validation Dice Coef', 'Training and Validation Dice Coefficient', 'Score'),
    (train_accs, val_accs, 'Train Accuracy', 'Validation Accuracy', 'Training and Validation Accuracy', 'Score'),
], start=1):
    plt.subplot(2, 2, panel_idx)
    plt.plot(epochs_range, train_series, label=train_label)
    plt.plot(epochs_range, val_series, label=val_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()

plt.tight_layout()
plt.show()