In [None]:
from google.colab import drive
# Mount Google Drive so the zipped datasets and the log/checkpoint folders under /content/drive are reachable.
drive.mount('/content/drive')

## Import

In [None]:
!pip install torchinfo

In [None]:
!pip install monai

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt

import cv2
from torchinfo import summary
from torchvision import models
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR # learning rate scheduler
from monai.losses import DiceCELoss

# Augmentation
import albumentations as A

# TensorBoard logging
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

# Garbage collection
import gc

### MobileSAM dependencies

In [None]:
!pip install segment-anything
!pip install git+https://github.com/ChaoningZhang/MobileSAM.git
!mkdir -p weights
!wget -nc https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt -P ./weights/

In [None]:
from mobile_sam import sam_model_registry # MobileSAM utilities

## Move data to local disk

**Note:** Replace `zip_train_source_path` and `zip_val_source_path` with the paths to your own zipped datasets.

In [None]:
# Zipped dataset archives on Google Drive (edit these for your own setup).
zip_train_source_path = "/content/drive/MyDrive/FYP/Datasets/zipped/train.zip"
zip_val_source_path = "/content/drive/MyDrive/FYP/Datasets/zipped/validation.zip"

# Local Colab disk; the archives are copied here and extracted by the next cell.
local_data_dir = "/content/data"

!mkdir -p "$local_data_dir"

print(f"Copying {zip_train_source_path} to {local_data_dir}")
!cp "$zip_train_source_path" "$local_data_dir/"

print(f"Copying {zip_val_source_path} to {local_data_dir}")
!cp "$zip_val_source_path" "$local_data_dir/"

# NOTE(review): `!` shell commands do not stop the cell on failure, so this prints even if a copy failed.
print("Copying complete.")

In [None]:
# Archives copied to local disk by the previous cell.
local_zip_train_path = f"{local_data_dir}/train.zip"
local_zip_val_path = f"{local_data_dir}/validation.zip"

# Extract next to the archives; the cells below expect <local_data_dir>/train/{images,masks}
# and <local_data_dir>/validation/{images,masks} after this step.
unzip_destination_path = local_data_dir

print(f"Unzipping {local_zip_train_path} to {unzip_destination_path}")
# NOTE(review): without -o, unzip may prompt before overwriting existing files when this cell is re-run.
!unzip -q "$local_zip_train_path" -d "$unzip_destination_path"

print(f"Unzipping {local_zip_val_path} to {unzip_destination_path}")
!unzip -q "$local_zip_val_path" -d "$unzip_destination_path"

# NOTE(review): `!` shell commands do not stop the cell on failure, so this prints even if unzip failed.
print("Unzipping complete.")

In [None]:
local_train_image_path = os.path.join(local_data_dir, "train", "images")
local_train_mask_path = os.path.join(local_data_dir, "train", "masks")

# Sanity-check the extracted training split: report file counts, or flag a missing folder.
if not os.path.exists(local_train_image_path):
  print(f"Directory {local_train_image_path} does not exist.")
else:
  num_images = len(os.listdir(local_train_image_path))
  print(f"Number of images in {local_train_image_path}: {num_images}")

if not os.path.exists(local_train_mask_path):
  print(f"Directory {local_train_mask_path} does not exist.")
else:
  num_masks = len(os.listdir(local_train_mask_path))
  print(f"Number of masks in {local_train_mask_path}: {num_masks}")

In [None]:
local_val_image_path = os.path.join(local_data_dir, "validation", "images")
local_val_mask_path = os.path.join(local_data_dir, "validation", "masks")

# Sanity-check the extracted validation split: report file counts, or flag a missing folder.
if not os.path.exists(local_val_image_path):
  print(f"Directory {local_val_image_path} does not exist.")
else:
  num_images = len(os.listdir(local_val_image_path))
  print(f"Number of images in {local_val_image_path}: {num_images}")

if not os.path.exists(local_val_mask_path):
  print(f"Directory {local_val_mask_path} does not exist.")
else:
  num_masks = len(os.listdir(local_val_mask_path))
  print(f"Number of masks in {local_val_mask_path}: {num_masks}")

## Data augmentation pipeline

In [None]:
# Standard ImageNet per-channel RGB statistics on a [0, 1] pixel scale. SegmentationDataset divides
# pixels by 255 and then standardizes with these, so its outputs are already fully normalized.
# NOTE(review): these presumably equal SAM's pixel_mean / pixel_std divided by 255 — the model wrapper
# must therefore not apply SAM's own normalization a second time; verify against the SAM source.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

In [None]:
# Training-time augmentation. Geometric ops move the image, the mask and the prompt box together;
# photometric / noise / dropout ops presumably alter only the image (albumentations defaults — verify).
train_transform = A.Compose(
    transforms=[
        # --- Geometric ---
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.RandomRotate90(p=0.3),
        A.Affine(
            scale=(0.95, 1.05),               # +/-5% zoom
            translate_percent=(-0.05, 0.05),  # shift by up to 5% of the image size
            rotate=(-15, 15),                 # degrees
            border_mode=cv2.BORDER_REFLECT,   # mirror content into exposed borders instead of black
            p=0.5,
        ),

        # --- Photometric (brightness & contrast) ---
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.5),

        # --- Noise or blur: at most one of the two, 30% of the time ---
        A.OneOf(
            [
                A.GaussNoise(std_range=(0.1, 0.15)),  # noise std given as a fraction of the max pixel value (albumentations 2.x)
                A.MedianBlur(blur_limit=3),
            ],
            p=0.3,
        ),

        # --- Occlusion: cut out 1 to 4 rectangular holes ---
        A.CoarseDropout(num_holes_range=(1, 4), p=0.4),
    ],
    # Boxes are absolute-pixel [x_min, y_min, x_max, y_max]; `category_ids` carries one label per box.
    bbox_params=A.BboxParams(format='pascal_voc', label_fields=['category_ids']),
)


## Custom SegmentationDataset

In [None]:
class SegmentationDataset(Dataset):
  """Box-prompted segmentation dataset with one sample per connected mask region.

  Every (image, mask) pair is expanded into one sample per connected foreground
  component of the mask (components under 5 pixels are dropped). Each sample carries
  that component's bounding box, padded by 10% of its longer side. `__getitem__` returns:

    image_tensor: float32 [3, H, W], RGB scaled to [0, 1], then standardized with mean/std.
    mask_tensor:  float32 [1, H, W], 1 on every mask component whose bbox centre lies
                  inside the prompt box, 0 elsewhere.
    bbox_tensor:  float32 [4], prompt box (x_min, y_min, x_max, y_max) in pixels of the
                  resized image.

  Here (H, W) == img_size.

  Images and masks are paired by position after sorting each directory listing,
  so both directories must hold matching files in the same sorted order.
  """

  def __init__(self, image_dir, mask_dir, transform=None, img_size=(1024, 1024), mean=IMAGENET_MEAN, std=IMAGENET_STD):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    # Callable taking image/mask/bboxes/category_ids keywords and returning a dict (albumentations-style).
    self.transform = transform
    self.img_size = img_size  # (height, width) of the returned tensors
    self.mean = np.array(mean, dtype=np.float32)
    self.std = np.array(std, dtype=np.float32)

    # Expand 1 image -> N samples (one per wound) up front, before training starts.
    self.samples = self._prepare_samples()

  def _prepare_samples(self):
    """Scan every mask once and build a flat list of {image_name, mask_name, bbox_original} dicts.

    Boxes are in the mask's original pixel coordinates; __getitem__ rescales them.
    Masks that cv2 cannot read are skipped.
    """
    image_files = sorted(os.listdir(self.image_dir))
    mask_files = sorted(os.listdir(self.mask_dir))

    if len(image_files) != len(mask_files):
      # zip() below silently drops the surplus, and every file after a missing one may be mis-paired.
      import warnings
      warnings.warn(
          f"{self.image_dir} has {len(image_files)} files but {self.mask_dir} has {len(mask_files)}; "
          "image/mask pairing by sorted name may be misaligned."
      )

    all_samples = []
    for img_name, mask_name in zip(image_files, mask_files):
      mask_original = cv2.imread(os.path.join(self.mask_dir, mask_name), cv2.IMREAD_GRAYSCALE)
      if mask_original is None:
        continue

      # Boxes on the ORIGINAL image dimensions
      original_boxes = self.get_bounding_boxes_from_mask(mask_original, padding_factor=0.1, min_area_threshold=5)
      for bbox in original_boxes:
        all_samples.append({
            'image_name': img_name,
            'mask_name': mask_name,
            'bbox_original': np.array(bbox, dtype=np.int32),  # [xmin, ymin, xmax, ymax] on original size
        })

    return all_samples

  @staticmethod
  def get_bounding_boxes_from_mask(mask, padding_factor=0.1, min_area_threshold=5):
    """
    Convert a 2-D mask with possibly several disconnected regions into one xyxy box per region.

    Any pixel > 0 counts as foreground (so 0/1 and 0/255 masks both work), matching the
    `mask > 0` test used in __getitem__. Regions (8-connectivity) with fewer than
    `min_area_threshold` pixels are ignored. Each box is grown by int(max(w, h) * padding_factor)
    pixels on every side and clipped to [0, W] x [0, H].

    Returns:
      list of [x_min, y_min, x_max, y_max] Python ints; [] when `mask` is None.
    """
    if mask is None:
      return []

    H, W = mask.shape

    # Threshold at 0: every non-zero pixel becomes 255 (a threshold of 1 would erase 0/1 masks).
    _, binary_mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY)

    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary_mask, 8, cv2.CV_32S)

    bounding_boxes = []
    for i in range(1, num_labels):  # label 0 is the background
      x = stats[i, cv2.CC_STAT_LEFT]
      y = stats[i, cv2.CC_STAT_TOP]
      w = stats[i, cv2.CC_STAT_WIDTH]
      h = stats[i, cv2.CC_STAT_HEIGHT]
      if stats[i, cv2.CC_STAT_AREA] < min_area_threshold:
        continue

      pad = int(max(w, h) * padding_factor)
      bounding_boxes.append([
          int(max(0, x - pad)),
          int(max(0, y - pad)),
          int(min(W, x + w + pad)),
          int(min(H, y + h + pad)),
      ])

    return bounding_boxes

  def __len__(self):
    # Total number of (image, mask, bbox) triples
    return len(self.samples)

  @staticmethod
  def _bbox_visible(bbox, mask, min_iou=0.1):
    """Return True if the filled box has IoU >= min_iou with the mask foreground (mask > 0).

    Box coordinates are truncated to int32, then x is clipped to [0, W-1] and y to [0, H-1].
    NOTE(review): IoU is measured against *all* foreground pixels, so a box around one small
    region of a multi-region mask can fall below min_iou and be skipped — confirm intended.
    """
    H, W = mask.shape[:2]
    x_min, y_min, x_max, y_max = np.array(bbox, dtype=np.int32)
    # Clip x against the width and y against the height (they differ for non-square masks).
    x_min, x_max = (int(np.clip(v, 0, W - 1)) for v in (x_min, x_max))
    y_min, y_max = (int(np.clip(v, 0, H - 1)) for v in (y_min, y_max))

    box_mask = np.zeros_like(mask, dtype=np.uint8)
    cv2.rectangle(box_mask, (x_min, y_min), (x_max, y_max), 1, -1)
    foreground = mask > 0
    intersection = np.logical_and(box_mask, foreground).sum()
    union = foreground.sum() + box_mask.sum() - intersection
    iou = intersection / union if union > 0 else 0
    return iou >= min_iou

  def __getitem__(self, idx):
    """Return (image_tensor, mask_tensor, bbox_tensor) for sample `idx`.

    If the sample is unusable, the following samples idx+1, idx+2, ... (wrapping around)
    are tried instead. A sample is unusable when augmentation drops its box or empties the
    mask, or when the box's IoU with the mask is below 0.1. The retry is iterative, so a
    long run of bad samples cannot exceed Python's recursion limit.

    Raises:
      IndexError: if `idx` is out of range (keeps Python's sequence protocol working).
      RuntimeError: if no usable sample is found after trying every index once.
      FileNotFoundError: if an image or mask file cannot be read.
    """
    num_samples = len(self.samples)
    current = idx
    for _ in range(max(num_samples, 1)):
      item = self._load_sample(current)  # self.samples[current] raises IndexError if out of range
      if item is not None:
        return item
      current = (current + 1) % num_samples
    raise RuntimeError(f"No usable sample found starting from index {idx} after {num_samples} attempts.")

  def _load_sample(self, idx):
    """Build the (image, mask, bbox) tensors for one sample, or return None if it must be skipped."""
    sample = self.samples[idx]
    image_path = os.path.join(self.image_dir, sample['image_name'])
    mask_path = os.path.join(self.mask_dir, sample['mask_name'])

    image = cv2.imread(image_path)
    if image is None:
      raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
      raise FileNotFoundError(f"Could not read mask: {mask_path}")

    bbox_original = sample['bbox_original']  # box on the original image size
    bbox_transformed = bbox_original.astype(np.float32)

    # Augment image, mask and box together (geometric ops move all three).
    if self.transform:
      try:
        augmented = self.transform(image=image, mask=mask, bboxes=[bbox_original], category_ids=[0])
        image, mask, bboxes = augmented['image'], augmented['mask'], augmented['bboxes']
        if len(bboxes) == 0:
          raise ValueError("Box lost")  # e.g. cropped out of frame
        if np.max(mask) == 0:
          raise ValueError("Mask empty")  # wound left the frame while a box edge remained
        bbox_transformed = np.array(bboxes[0], dtype=np.float32)
      except ValueError:
        return None

    if not self._bbox_visible(bbox_transformed, mask, min_iou=0.1):
      return None

    # Resize to img_size = (H, W). cv2.resize takes dsize as (width, height).
    current_H, current_W = image.shape[:2]
    target_H, target_W = self.img_size
    if current_H != target_H or current_W != target_W:
      image = cv2.resize(image, (target_W, target_H))
      mask = cv2.resize(mask, (target_W, target_H), interpolation=cv2.INTER_NEAREST)

    # Rescale the box from the CURRENT size to the TARGET size.
    scale_x = target_W / current_W
    scale_y = target_H / current_H
    bbox_resized = np.array([
        bbox_transformed[0] * scale_x,
        bbox_transformed[1] * scale_y,
        bbox_transformed[2] * scale_x,
        bbox_transformed[3] * scale_y,
    ], dtype=np.float32)
    x_min, y_min, x_max, y_max = bbox_resized.astype(np.int32)

    # Isolate ground truth: keep components whose bbox centre lies in [x_min, x_max) x [y_min, y_max).
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats((mask > 0).astype(np.uint8), 8, cv2.CV_32S)
    keep = [
        i for i in range(1, num_labels)  # 0 is the background
        if x_min <= stats[i, cv2.CC_STAT_LEFT] + stats[i, cv2.CC_STAT_WIDTH] // 2 < x_max
        and y_min <= stats[i, cv2.CC_STAT_TOP] + stats[i, cv2.CC_STAT_HEIGHT] // 2 < y_max
    ]
    isolated_mask = np.isin(labels, keep).astype(np.float32)  # binary {0, 1}

    # Standardization and tensor conversion
    image = image.astype("float32") / 255.0
    image = (image - self.mean) / self.std
    image_tensor = torch.from_numpy(image).permute(2, 0, 1)  # HWC -> CHW
    mask_tensor = torch.from_numpy(isolated_mask).unsqueeze(0)  # HW -> 1HW
    bbox_tensor = torch.from_numpy(bbox_resized)

    return image_tensor, mask_tensor, bbox_tensor

In [None]:
# Training split gets the on-the-fly augmentation pipeline; validation samples are left un-augmented.
train_dataset = SegmentationDataset(
    local_train_image_path,
    local_train_mask_path,
    transform=train_transform,
)
val_dataset = SegmentationDataset(local_val_image_path, local_val_mask_path, transform=None)

### Helper functions

In [None]:
def denormalize_image(image_tensor, mean, std):
    """Undo per-channel standardization so a tensor can be displayed.

    Args:
        image_tensor: [C, H, W] torch.Tensor standardized as (x - mean) / std. It may live on
            any device and may be part of an autograd graph.
        mean, std: per-channel statistics (length-C sequences or arrays).

    Returns:
        np.ndarray [H, W, C] with values clipped to [0, 1].
    """
    # detach() so tensors that require grad (e.g. model outputs) can be converted to NumPy.
    image_np = image_tensor.detach().permute(1, 2, 0).cpu().numpy()  # CHW -> HWC
    image_np = std * image_np + mean
    return np.clip(image_np, 0, 1)  # Clip values to be in the [0, 1] range

In [None]:
def overlay_bbox_on_image(image_tensor, bbox_tensor, mean, std):
    """Render a standardized image tensor as RGB uint8 with its prompt box drawn on top.

    Args:
        image_tensor: [3, H, W] torch.Tensor, standardized with `mean` / `std`.
        bbox_tensor: [4] torch.Tensor (x_min, y_min, x_max, y_max) in pixels.
        mean, std: per-channel statistics used for the standardization.

    Returns:
        np.ndarray [H, W, 3] RGB uint8 image with a red, 2-pixel-thick rectangle.
    """
    rgb = denormalize_image(image_tensor, mean, std)  # HWC floats in [0, 1]
    canvas = (rgb * 255).astype(np.uint8).copy()      # writable uint8 buffer for OpenCV

    x0, y0, x1, y1 = bbox_tensor.cpu().numpy().astype(int)[:4]
    cv2.rectangle(canvas, (x0, y0), (x1, y1), color=(255, 0, 0), thickness=2)
    return canvas

In [None]:
def get_confusion_matrix_components(y_true, y_pred):
    """
    Count pixel-level confusion-matrix entries over a whole batch.

    Args:
        y_true (torch.Tensor): Ground-truth masks containing 0s and 1s.
        y_pred (torch.Tensor): Binary predictions (e.g. after sigmoid + threshold) with the
            same number of elements as `y_true`.

    Returns:
        tuple[int, int, int, int]: (true_positives, false_positives, false_negatives,
        true_negatives). Elements that are not exactly 0 or 1 fall into no bucket.
    """
    # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced tensors, also work.
    y_true_flat = y_true.reshape(-1)
    y_pred_flat = y_pred.reshape(-1)

    pred_pos, pred_neg = y_pred_flat == 1, y_pred_flat == 0
    true_pos, true_neg = y_true_flat == 1, y_true_flat == 0

    true_positives = (pred_pos & true_pos).sum().item()
    false_positives = (pred_pos & true_neg).sum().item()
    false_negatives = (pred_neg & true_pos).sum().item()
    true_negatives = (pred_neg & true_neg).sum().item()

    return true_positives, false_positives, false_negatives, true_negatives


def calculate_final_metrics(tp, fp, fn, tn, smooth=1e-6):
    """
    Turn accumulated confusion-matrix counts into summary metrics.

    `smooth` is added to every denominator, so all-zero counts yield 0.0 instead of a
    ZeroDivisionError.

    Returns:
        (iou, dice, recall, precision, accuracy)
    """
    recall = tp / (tp + fn + smooth)     # sensitivity
    precision = tp / (tp + fp + smooth)  # positive predictive value

    return (
        tp / (tp + fp + fn + smooth),                              # IoU (Jaccard)
        (2 * precision * recall) / (precision + recall + smooth),  # Dice / F1 from P and R
        recall,
        precision,
        (tp + tn) / (tp + tn + fp + fn + smooth),                  # pixel accuracy
    )

In [None]:
# Overlay helper for TensorBoard images
def prepare_mask_for_tb(image, mask):
  """
  Paint a binary mask onto an RGB image tensor for TensorBoard.

  Args:
    image: [3, H, W] tensor; values above 1 are treated as 0-255 and divided by 255.
    mask: mask broadcastable to [H, W] after squeeze() (e.g. [1, H, W]).

  Returns:
    [3, H, W] tensor: where mask > 0, 40% image + 60% pure red; elsewhere the image unchanged.
  """
  img = image / 255.0 if image.max() > 1 else image
  flat_mask = mask.squeeze()

  red_layer = torch.zeros_like(img)
  red_layer[0] = flat_mask  # mask values go into the red channel only

  tinted = img * 0.4 + red_layer * 0.6
  inside = flat_mask.unsqueeze(0) > 0
  return torch.where(inside, tinted, img)

In [None]:
def create_comparison_grid(images, bboxes, labels, preds, mean=IMAGENET_MEAN, std=IMAGENET_STD, max_rows=None):
  """
  Build a TensorBoard grid with one row per sample: [BBox Input | Ground Truth | Prediction].

  Args:
    images: standardized images [B, 3, H, W]
    bboxes: prompt boxes [B, 4]
    labels: ground-truth masks [B, 1, H, W]
    preds: predicted masks [B, 1, H, W] (any value > 0 is painted)
    mean, std: statistics used to undo the image standardization
    max_rows: optional cap on the number of samples (rows) shown

  Returns:
    A single [3, H_grid, W_grid] tensor from torchvision's make_grid.
  """
  images, bboxes, labels, preds = (t.cpu() for t in (images, bboxes, labels, preds))

  n_rows = images.shape[0]
  if max_rows is not None:
    n_rows = min(n_rows, max_rows)

  tiles = []
  for row in range(n_rows):
    # Clean [3, H, W] image in [0, 1]; base for both mask overlays
    clean = torch.from_numpy(denormalize_image(images[row], mean, std)).permute(2, 0, 1).float()

    # Left: image with prompt box (uint8 HWC -> float CHW in [0, 1])
    boxed = torch.from_numpy(overlay_bbox_on_image(images[row], bboxes[row], mean, std)).permute(2, 0, 1).float() / 255.0

    # Left, middle (ground truth), right (prediction)
    tiles += [boxed, prepare_mask_for_tb(clean, labels[row]), prepare_mask_for_tb(clean, preds[row])]

  # nrow=3 -> the grid wraps after every third tile, i.e. one sample per row
  return make_grid(tiles, nrow=3, padding=5)

## MobileSAM model

In [None]:
class DecoderAdapter(nn.Module):
  """
  Residual bottleneck adapter for parameter-efficient fine-tuning (PEFT).

  Computes x + up(GELU(down(x))), with down: in_dim -> adapter_dim and up: adapter_dim -> in_dim.
  `up` starts with near-zero weights (std 1e-4) and zero bias, so a fresh adapter is close
  to the identity and does not disturb the pretrained layer it follows.
  """

  def __init__(self, in_dim: int, adapter_dim: int):
    super().__init__()
    self.down = nn.Linear(in_dim, adapter_dim)  # squeeze into the bottleneck
    self.non_linearity = nn.GELU()
    self.up = nn.Linear(adapter_dim, in_dim)    # expand back to the model width

    # Near-identity start
    nn.init.normal_(self.up.weight, mean=0.0, std=1e-4)
    nn.init.zeros_(self.up.bias)

  def forward(self, x):
    bottleneck = self.non_linearity(self.down(x))
    residual = self.up(bottleneck)
    return x + residual

In [None]:
def inject_domain_adapter(mask_decoder, adapter_dim=64):
    """
    Attach a residual DecoderAdapter after every square nn.Linear inside `mask_decoder`.

    Every nn.Linear with in_features == out_features (found via named_modules) gets its own
    DecoderAdapter(out_features, adapter_dim). Each adapter is registered on `mask_decoder` as
    `domain_adapter_<k>` (k = 0, 1, ... in traversal order), so `.to(device)`, `.parameters()`
    and `state_dict()` include it. The layer's `forward` is replaced on the instance by
    `adapter(original_forward(x))`.

    NOTE(review): this targets *all* square Linear layers, not only MLP blocks; in SAM's
    decoder that presumably includes some attention projections too — verify. A second call
    reuses the names domain_adapter_0, 1, ..., replacing the earlier registrations, and wraps
    the layers again.

    Returns:
        The same `mask_decoder` object, modified in place.
    """

    def _wrap(original_forward, adapter):
        # A fresh closure per layer, so each layer keeps its own forward/adapter pair.
        def forward_with_adapter(x):
            return adapter(original_forward(x))
        return forward_with_adapter

    # Snapshot the targets first, so adapters registered below are never themselves visited.
    square_linears = [
        layer
        for _, layer in mask_decoder.named_modules()
        if isinstance(layer, nn.Linear) and layer.out_features == layer.in_features
    ]

    for k, layer in enumerate(square_linears):
        adapter = DecoderAdapter(layer.out_features, adapter_dim)
        setattr(mask_decoder, f"domain_adapter_{k}", adapter)
        layer.forward = _wrap(layer.forward, adapter)

    return mask_decoder


In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Build MobileSAM ("vit_t" registry entry) from the checkpoint downloaded into ./weights.
sam_checkpoint, model_type = "./weights/mobile_sam.pt", "vit_t"

mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)

# Insert a trainable adapter after every square Linear layer of the mask decoder.
mobile_sam.mask_decoder = inject_domain_adapter(mobile_sam.mask_decoder, adapter_dim=64)

In [None]:
# Move MobileSAM, including the adapters registered on its mask decoder, to the selected device.
mobile_sam.to(device)

In [None]:
TRAINING_IMG_SIZE = 1024
SAM_EMBEDDING_SIZE = 1024  # frame the prompt boxes are scaled into (side of the SAM encoder input)

class MobileSAMFineTuner(nn.Module):
    """
    Box-prompted MobileSAM wrapper in which only the decoder adapters are trainable.

    All SAM parameters are frozen; mask-decoder parameters whose name contains "adapter"
    are re-enabled. The image encoder always runs under torch.no_grad().

    Input normalization: SegmentationDataset returns images already standardized with
    IMAGENET_MEAN / IMAGENET_STD on a [0, 1] scale. segment_anything / MobileSAM's
    `Sam.preprocess` subtracts `pixel_mean` and divides by `pixel_std` — the same ImageNet
    statistics on a 0-255 scale — before padding. Applying it to already-standardized tensors
    normalizes twice and squashes every image to a nearly constant input. By default
    (`inputs_normalized=True`) this wrapper therefore only pads. Pass `inputs_normalized=False`
    to feed raw 0-255 pixels through `Sam.preprocess` instead.

    Args:
        sam_model: a SAM/MobileSAM model exposing image_encoder, prompt_encoder, mask_decoder, preprocess.
        train_img_size: side of the square training images; also the side of the returned masks.
        sam_emb_size: side of the frame that bbox coordinates are scaled into.
        inputs_normalized: whether `images` are already standardized (see above).
    """

    def __init__(self, sam_model, train_img_size, sam_emb_size, inputs_normalized=True):
        super().__init__()
        self.sam = sam_model
        self.train_img_size = train_img_size
        self.sam_emb_size = sam_emb_size
        # NOTE(review): scaling boxes assumes images were *resized* to sam_emb_size, but the encoder
        # input is *padded*; only train_img_size == sam_emb_size (1024/1024 here) is consistent.
        self.scale_factor = sam_emb_size / train_img_size
        self.inputs_normalized = inputs_normalized

        for param in self.sam.parameters():
            param.requires_grad = False  # Freeze all SAM parameters

        # Unfreeze adapters
        for name, param in self.sam.mask_decoder.named_parameters():
            if "adapter" in name:
                param.requires_grad = True

    def _prepare_encoder_input(self, images):
        """Return the [B, 3, S, S] batch fed to the image encoder (S = encoder input side)."""
        if not self.inputs_normalized:
            # Raw pixel input: let SAM normalize and pad each image itself.
            return torch.stack([self.sam.preprocess(img) for img in images], dim=0)

        # Already standardized: only zero-pad bottom/right up to the encoder's square input size.
        H, W = images.shape[-2:]
        target = getattr(self.sam.image_encoder, "img_size", max(H, W))
        return F.pad(images, (0, target - W, 0, target - H))

    def forward(self, images: torch.Tensor, bboxes: torch.Tensor):
        """
        Args:
            images: [B, 3, H, W] images (standardized unless inputs_normalized=False).
            bboxes: [B, 4] boxes (x_min, y_min, x_max, y_max) in image pixels.

        Returns:
            (mask_logits [B, 1, train_img_size, train_img_size], iou_predictions [B, 1])
        """
        B = images.shape[0]

        scaled_bboxes = bboxes * self.scale_factor  # [B, 4]

        input_images = self._prepare_encoder_input(images)

        # --- Image encoding (frozen, no autograd graph) ---
        with torch.no_grad():
            image_embeddings = self.sam.image_encoder(input_images)  # presumably [B, 256, 64, 64] for 1024 input

        # --- Prompt encoding + mask decoding, one sample at a time ---
        # The decoder pairs one image embedding with its prompts, so samples are decoded individually.
        final_mask_logits = []
        iou_preds = []
        for i in range(B):
            image_embedding_i = image_embeddings[i].unsqueeze(0)  # [1, C_emb, h, w]
            box_i = scaled_bboxes[i].unsqueeze(0)                 # [1, 4]

            sparse_embeddings_i, dense_embeddings_i = self.sam.prompt_encoder(
                points=None,
                boxes=box_i,
                masks=None,
            )

            low_res_masks_i, iou_predictions_i = self.sam.mask_decoder(
                image_embeddings=image_embedding_i,
                image_pe=self.sam.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings_i,
                dense_prompt_embeddings=dense_embeddings_i,
                multimask_output=False,
            )

            final_mask_logits.append(low_res_masks_i)
            iou_preds.append(iou_predictions_i)

        low_res_logits_stacked = torch.cat(final_mask_logits, dim=0)
        iou_preds_stacked = torch.cat(iou_preds, dim=0)

        # Upsample low-resolution logits to [B, 1, train_img_size, train_img_size]
        final_mask_logits_stacked = F.interpolate(
            low_res_logits_stacked,
            size=(self.train_img_size, self.train_img_size),
            mode='bilinear',
            align_corners=False
        )

        return final_mask_logits_stacked, iou_preds_stacked

## Model training configuration & preparation

In [None]:
# Same device choice as earlier, repeated so this section can be run on its own.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

### Training configs

In [None]:
# Tunable training hyper-parameters, read by the loaders, optimizer, scheduler and training loop.
config = dict(
    batch_size=4,
    epochs=40,
    initial_lr=1e-3,
    weight_decay=1e-5,
    min_lr=1e-6,
    threshold=0.5,                     # presumably the sigmoid cut-off for binary masks — see training loop
    optimizer="AdamW",                 # descriptive; the optimizer itself is constructed explicitly below
    loss_function="DiceCELoss",        # descriptive
    lr_scheduler="CosineAnnealingLR",  # descriptive
)

### Dataloaders

In [None]:
# Shuffled batches of augmented training samples; pinned host memory speeds up copies to the GPU.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

print(f"Number of batches in train_loader: {len(train_loader)}")

In [None]:
# Validation batches in a fixed order (no shuffling) so epochs are comparable.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=config["batch_size"],
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

print(f"Number of batches in val_loader: {len(val_loader)}")

#### Visualize the data before passing it to the model



In [None]:
train_images, train_labels, train_bboxes = next(iter(train_loader))

# Show every sample of one training batch: input with its prompt box, and the isolated GT mask.
batch_size = train_images.shape[0]

for index in range(batch_size):
    # Denormalized RGB image with the prompt box drawn on it
    image = overlay_bbox_on_image(train_images[index].cpu(), train_bboxes[index], mean=IMAGENET_MEAN, std=IMAGENET_STD)

    # Ground-truth mask: (1, H, W) -> (H, W)
    mask = train_labels[index].cpu().squeeze().numpy()

    fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))

    ax_img.imshow(image)
    ax_img.set_title(f'Sample {index} | Input Image with BBox')
    ax_img.axis('off')

    # An all-zero mask means the box isolated no component — flag it in red.
    if np.all(mask == 0):
        ax_mask.set_title(f'Sample {index} | Ground Truth Mask (ALL BLACK - FAIL)', color='red')
    else:
        ax_mask.set_title(f'Sample {index} | Ground Truth Mask (PASS)')
    ax_mask.imshow(mask, cmap='viridis')
    ax_mask.axis('off')

    fig.tight_layout()
    plt.show()
    plt.close(fig)  # free the figure; this loop creates one per sample

In [None]:
vis_images, vis_labels, vis_bboxes = next(iter(val_loader))

# Show every sample of one validation batch: input with its prompt box, and the isolated GT mask.
batch_size = vis_images.shape[0]

for index in range(batch_size):
    # Denormalized RGB image with the prompt box drawn on it
    image = overlay_bbox_on_image(vis_images[index].cpu(), vis_bboxes[index], mean=IMAGENET_MEAN, std=IMAGENET_STD)

    # Ground-truth mask: (1, H, W) -> (H, W)
    mask = vis_labels[index].cpu().squeeze().numpy()

    fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))

    ax_img.imshow(image)
    ax_img.set_title(f'Sample {index} | Input Image with BBox')
    ax_img.axis('off')

    # An all-zero mask means the box isolated no component — flag it in red.
    if np.all(mask == 0):
        ax_mask.set_title(f'Sample {index} | Ground Truth Mask (ALL BLACK - FAIL)', color='red')
    else:
        ax_mask.set_title(f'Sample {index} | Ground Truth Mask (PASS)')
    ax_mask.imshow(mask, cmap='viridis')
    ax_mask.axis('off')

    fig.tight_layout()
    plt.show()
    plt.close(fig)  # free the figure; this loop creates one per sample

### Model

In [None]:
# Wrap MobileSAM (frozen except the decoder adapters) and move the wrapper to the target device.
model = MobileSAMFineTuner(
    sam_model=mobile_sam,
    train_img_size=TRAINING_IMG_SIZE,
    sam_emb_size=SAM_EMBEDDING_SIZE,
).to(device)

In [None]:
SUMMARY_BATCH_SIZE = config["batch_size"]
IMAGE_SIZE = TRAINING_IMG_SIZE  # keep the summary input in sync with the training resolution

# Dummy inputs with the training batch size: images (B, C, H, W) and boxes (B, 4).
dummy_images = torch.randn(SUMMARY_BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)
# Sort two random corners per box so every box is a valid (x_min, y_min, x_max, y_max);
# four independent random integers can give x_min > x_max.
corners = torch.randint(0, IMAGE_SIZE, (SUMMARY_BATCH_SIZE, 2, 2)).float()  # [B, corner, (x, y)]
dummy_bboxes = torch.cat([corners.min(dim=1).values, corners.max(dim=1).values], dim=1)  # (B, 4)

# Move to the device the model's weights live on
device = model.sam.pixel_mean.device
dummy_images = dummy_images.to(device)
dummy_bboxes = dummy_bboxes.to(device)

summary(model, input_data=[dummy_images, dummy_bboxes], device=device)

In [None]:
# Loss: Dice + cross-entropy on sigmoid outputs (binary masks, so no one-hot conversion).
criterion = DiceCELoss(to_onehot_y=False, sigmoid=True)

# Optimize only the parameters left trainable (the decoder adapters). A list rather than a
# one-shot `filter` iterator, so `trainable_params` stays usable after AdamW has consumed it
# (e.g. for gradient clipping or parameter counting).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=config["initial_lr"], weight_decay=config["weight_decay"])

# Cosine decay from initial_lr to min_lr over T_max = epochs scheduler steps (presumably one step per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=config["epochs"], eta_min=config["min_lr"])

## Model Training

**Note:** Change `log_dir` and `save_dir` if you want TensorBoard logs and checkpoints stored elsewhere on your Drive.

In [None]:
from pathlib import Path
from datetime import datetime

# Each run gets its own timestamped sub-folder for logs and checkpoints.
run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
run_name = f"Run_{run_stamp}"

# Output root on Google Drive -- edit to write elsewhere.
run_root = "/content/drive/MyDrive/FYP/Model_Training/SAM"
log_dir = f"{run_root}/logs/{run_name}"
save_dir = f"{run_root}/checkpoints/{run_name}"

# `mkdir -p` semantics: create missing parents, tolerate folders that already exist.
for run_dir in (log_dir, save_dir):
    Path(run_dir).mkdir(parents=True, exist_ok=True)

# TensorBoard writer and the checkpoint path used by the training loop
writer = SummaryWriter(log_dir=log_dir)
best_model_path = os.path.join(save_dir, "best_MobileSAMwithAdapter.pth")

print("Directories verified/created.")
print(f"TensorBoard logging to: {log_dir}")

In [None]:
# Training loop: per-epoch training + validation, TensorBoard logging,
# best-model checkpointing (by validation IoU) and patience-based early stopping.
best_val_iou = -100  # below any attainable IoU, so the first epoch always checkpoints
patience = 20  # epochs without validation-IoU improvement before stopping
counter = 0  # epochs since the last validation-IoU improvement
threshold = config["threshold"]  # sigmoid-probability cut-off for binary predictions

# Per-epoch metric histories (plotted in the Plotting section)
train_losses = []
train_ious = []
train_dices = []
train_recalls = []
train_precisions = []
train_accs = []

val_losses = []
val_ious = []
val_dices = []
val_recalls = []
val_precisions = []
val_accs = []


def _log_comparison_grid(tag, images, bboxes, labels, preds, step):
  """Render an image/bbox/label/prediction grid and write it to TensorBoard.

  Args:
    tag: TensorBoard image tag, e.g. "Vis/Best".
    images, bboxes, labels, preds: one batch, forwarded unchanged to
      `create_comparison_grid`. `images` may be None when no batch was captured.
    step: step (epoch) to log the image under.
  """
  if images is None:
    # e.g. an empty validation loader, or batch IoUs that never compared (NaN)
    print(f"  Skipping {tag}: no batch captured this epoch.")
    return
  grid = create_comparison_grid(
      images, bboxes, labels, preds, IMAGENET_MEAN, IMAGENET_STD, max_rows=images.shape[0]
  )
  writer.add_image(tag, grid, step)


# One fixed validation batch, re-predicted periodically to track progress on the same samples
vis_images, vis_labels, vis_bboxes = next(iter(val_loader))

for epoch in range(config["epochs"]):
  # LR in effect during this epoch (the scheduler is stepped once, after validation)
  current_lr = optimizer.param_groups[0]['lr']

  # --- Training Phase ---
  model.train()

  train_running_loss = 0.0
  # Count samples explicitly so the mean loss stays correct if the loader drops a partial batch
  train_num_samples = 0
  train_total_tp, train_total_fp, train_total_fn, train_total_tn = 0, 0, 0, 0

  for X_batch, y_batch, bbox_batch in train_loader:
    X_batch, y_batch = X_batch.to(device).float(), y_batch.to(device).float()
    bbox_batch = bbox_batch.to(device).float()

    # Model returns (mask_logits, iou_preds); mask logits are compared directly with y_batch
    mask_logits, _ = model(X_batch, bbox_batch)
    loss = criterion(mask_logits, y_batch)

    # Sum of per-sample losses; averaged over samples at the end of the epoch
    train_running_loss += loss.item() * X_batch.size(0)
    train_num_samples += X_batch.size(0)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Metrics only: detach so no autograd graph is built on top of the logits
    train_preds = (torch.sigmoid(mask_logits.detach()) > threshold).float()

    tp, fp, fn, tn = get_confusion_matrix_components(y_batch, train_preds)
    train_total_tp += tp
    train_total_fp += fp
    train_total_fn += fn
    train_total_tn += tn

  train_loss = train_running_loss / max(train_num_samples, 1)
  train_iou, train_dice, train_recall, train_precision, train_acc = calculate_final_metrics(
      train_total_tp, train_total_fp, train_total_fn, train_total_tn
  )

  train_losses.append(train_loss)
  train_ious.append(train_iou)
  train_dices.append(train_dice)
  train_recalls.append(train_recall)
  train_precisions.append(train_precision)
  train_accs.append(train_acc)

  # --- Validation Phase ---
  model.eval()

  val_running_loss = 0.0
  val_num_samples = 0
  val_total_tp, val_total_fp, val_total_fn, val_total_tn = 0, 0, 0, 0

  # Best / worst validation batch (by batch IoU) of this epoch, kept on the CPU for visualization
  min_iou_in_epoch = float('inf')
  max_iou_in_epoch = float('-inf')
  worst_batch_data, worst_batch_labels, worst_batch_preds, worst_batch_bboxes = None, None, None, None
  best_batch_data, best_batch_labels, best_batch_preds, best_batch_bboxes = None, None, None, None

  with torch.no_grad():
    for X_val_batch, y_val_batch, bbox_val_batch in val_loader:
      X_val_batch, y_val_batch = X_val_batch.to(device).float(), y_val_batch.to(device).float()
      bbox_val_batch = bbox_val_batch.to(device).float()

      val_mask_logits, _ = model(X_val_batch, bbox_val_batch)

      val_batch_loss = criterion(val_mask_logits, y_val_batch)
      val_running_loss += val_batch_loss.item() * X_val_batch.size(0)
      val_num_samples += X_val_batch.size(0)

      val_preds = (torch.sigmoid(val_mask_logits) > threshold).float()

      tp, fp, fn, tn = get_confusion_matrix_components(y_val_batch, val_preds)
      val_total_tp += tp
      val_total_fp += fp
      val_total_fn += fn
      val_total_tn += tn

      # Per-batch IoU, used only to pick the best / worst batch to visualize
      batch_iou, _, _, _, _ = calculate_final_metrics(tp, fp, fn, tn)

      if batch_iou < min_iou_in_epoch:
        min_iou_in_epoch = batch_iou
        worst_batch_data = X_val_batch.cpu()
        worst_batch_labels = y_val_batch.cpu()
        worst_batch_preds = val_preds.cpu()
        worst_batch_bboxes = bbox_val_batch.cpu()

      if batch_iou > max_iou_in_epoch:
        max_iou_in_epoch = batch_iou
        best_batch_data = X_val_batch.cpu()
        best_batch_labels = y_val_batch.cpu()
        best_batch_preds = val_preds.cpu()
        best_batch_bboxes = bbox_val_batch.cpu()

  val_loss = val_running_loss / max(val_num_samples, 1)
  val_iou, val_dice, val_recall, val_precision, val_acc = calculate_final_metrics(
      val_total_tp, val_total_fp, val_total_fn, val_total_tn
  )

  val_losses.append(val_loss)
  val_ious.append(val_iou)
  val_dices.append(val_dice)
  val_recalls.append(val_recall)
  val_precisions.append(val_precision)
  val_accs.append(val_acc)

  # Advance the cosine schedule once per epoch (after optimizer.step() calls)
  scheduler.step()

  # TensorBoard scalars
  writer.add_scalars('Loss', {'Train': train_loss, 'Validation': val_loss}, epoch)
  writer.add_scalars('IoU', {'Train': train_iou, 'Validation': val_iou}, epoch)
  writer.add_scalars('Dice', {'Train': train_dice, 'Validation': val_dice}, epoch)
  writer.add_scalars('Recall', {'Train': train_recall, 'Validation': val_recall}, epoch)
  writer.add_scalars('Precision', {'Train': train_precision, 'Validation': val_precision}, epoch)
  writer.add_scalars('Accuracy', {'Train': train_acc, 'Validation': val_acc}, epoch)
  writer.add_scalar('Learning Rate', current_lr, epoch)

  # Prediction grids every 5 epochs
  if epoch % 5 == 0:
    # Fixed batch: inference only, so no autograd graph (model is still in eval mode here)
    with torch.no_grad():
      vis_mask_logits, _ = model(vis_images.to(device).float(), vis_bboxes.to(device).float())
    # Move to the CPU, matching the best/worst batches
    val_predictions = (torch.sigmoid(vis_mask_logits) > threshold).float().cpu()

    _log_comparison_grid("Vis/Fixed", vis_images, vis_bboxes, vis_labels, val_predictions, epoch)
    _log_comparison_grid("Vis/Best", best_batch_data, best_batch_bboxes, best_batch_labels, best_batch_preds, epoch)
    _log_comparison_grid("Vis/Worst", worst_batch_data, worst_batch_bboxes, worst_batch_labels, worst_batch_preds, epoch)

  print(f"Epoch {epoch}:")
  print(f"  Current LR: {current_lr:.6f}")
  print(f"  Train Metrics: Loss: {train_loss:.4f} | IoU: {train_iou:.4f} | Dice: {train_dice:.4f} | Precision: {train_precision:.4f} | Recall: {train_recall:.4f} | Acc: {train_acc:.4f}")
  print(f"  Val Metrics:   Loss: {val_loss:.4f} | IoU: {val_iou:.4f} | Dice: {val_dice:.4f} | Precision: {val_precision:.4f} | Recall: {val_recall:.4f} | Acc: {val_acc:.4f}")
  print("-" * 100)

  # Checkpoint on validation-IoU improvement, otherwise count towards early stopping
  if val_iou > best_val_iou:
    best_val_iou = val_iou
    counter = 0

    # Saves the full model state_dict (all submodules, not only the mask decoder)
    torch.save(model.state_dict(), best_model_path)
    print(f"Saved new best model at IoU: {best_val_iou: .4f} to {best_model_path}")
  else:
    counter += 1
    print(f"No improvement in Validation IoU for {counter} epoch(s)")

  if counter >= patience:
    print(f"Early stopping at epoch {epoch}. Best Validation IoU: {best_val_iou}")
    break

  # Release per-epoch tensors (e.g. the stored best/worst batches) between epochs
  gc.collect()
  torch.cuda.empty_cache()

# End of run: flush and close the TensorBoard event files
writer.close()

## Plotting

In [None]:
# Plot the per-epoch train/validation curves collected during training
# as a 2x2 grid: loss, IoU, Dice and accuracy (epochs numbered from 1).
epochs_range = range(1, len(train_losses) + 1)

# One entry per panel:
# (title, y-axis label, train legend, train series, validation legend, validation series)
curve_panels = [
    ('Training and Validation Loss', 'Loss',
     'Train Loss', train_losses, 'Validation Loss', val_losses),
    ('Training and Validation IoU', 'IoU',
     'Train IoU', train_ious, 'Validation IoU', val_ious),
    ('Training and Validation Dice Coefficient', 'Score',
     'Train Dice Coef', train_dices, 'Validation Dice Coef', val_dices),
    ('Training and Validation Accuracy', 'Score',
     'Train Accuracy', train_accs, 'Validation Accuracy', val_accs),
]

plt.figure(figsize=(15, 10))
for panel_no, (panel_title, y_label, train_label, train_series, val_label, val_series) in enumerate(curve_panels, start=1):
    plt.subplot(2, 2, panel_no)
    plt.plot(epochs_range, train_series, label=train_label)
    plt.plot(epochs_range, val_series, label=val_label)
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()

plt.tight_layout()
plt.show()