In [None]:
from google.colab import drive

# Mount Google Drive so the dataset, checkpoint and output paths under /content/drive resolve.
drive.mount(mountpoint='/content/drive')

In [None]:
!pip install git+https://github.com/ChaoningZhang/MobileSAM.git
!mkdir -p weights
!wget -nc https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt -P ./weights/

In [None]:
!pip install peft

In [None]:
import os
import cv2
import ast
import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


from peft import LoraConfig, get_peft_model, PeftModel
from mobile_sam import sam_model_registry

## Custom Dataset classes

Two custom dataset classes are defined below. Run only the one that suits your use case: both cells assign to the same `dataset` variable, so whichever one runs last is what the inference loop uses.

In [None]:
# Per-channel RGB normalization statistics (ImageNet), applied after scaling pixels to [0, 1].
IMAGENET_MEAN, IMAGENET_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

### SegmentationDataset class

This class is for datasets that include ground-truth segmentation masks, from which the bounding box(es) are extracted. Just pass the directories containing the images and the ground-truth masks; the bounding box coordinates (one box per connected region of the mask) are computed automatically.

In [None]:
class SegmentationDataset(Dataset):
    """Wound images paired with ground-truth masks; box prompts are derived from the masks.

    Images and masks are paired by index after sorting each directory listing,
    so the i-th image file (in sorted order) must correspond to the i-th mask file.

    Each item is a dict:
        "image":      float32 tensor [3, H, W], scaled to [0, 1] then normalized
                      per channel with ``mean`` / ``std``.
        "mask":       float32 tensor [1, H, W], 1.0 where the mask is non-zero.
        "bboxes":     float32 tensor [N, 4] of (x_min, y_min, x_max, y_max) boxes in
                      the resized H x W frame, one per connected mask region
                      (an empty [0, 4] tensor when the mask has no regions).
        "image_name": file name of the image.
    where (H, W) = ``img_size``.
    """

    def __init__(self, image_dir, mask_dir, img_size=(1024, 1024), mean=IMAGENET_MEAN, std=IMAGENET_STD):
        """
        Args:
            image_dir: directory containing the images.
            mask_dir: directory containing the masks (read as grayscale).
            img_size: output (height, width) for images, masks and box coordinates.
            mean, std: per-channel RGB normalization statistics.

        Raises:
            ValueError: if the directories hold different numbers of files, which
                would make the index-based image/mask pairing misalign.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} files in {image_dir} but {len(self.masks)} "
                f"files in {mask_dir}; images and masks are paired by sorted order, "
                "so the counts must match."
            )
        self.img_size = img_size
        # float32 statistics keep the normalized image float32 (float64 would upcast it).
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def __len__(self):
        return len(self.images)

    def get_bounding_boxes_from_mask(self, mask, padding_factor=0.1, min_area_threshold=5):
        """Get one bounding box per connected foreground region of a mask.

        Args:
            mask: 2-D mask array; any non-zero pixel counts as foreground.
            padding_factor: each box is grown on every side by this fraction of its
                longer side (truncated to whole pixels), then clipped to the mask.
            min_area_threshold: regions with fewer pixels than this are ignored.

        Returns:
            List of [x_min, y_min, x_max, y_max] boxes in mask pixel coordinates.
        """
        # Foreground is "> 0", matching the mask tensor built in __getitem__, so
        # masks stored with foreground value 1 (not only 255) produce boxes too.
        binary_mask = (mask > 0).astype(np.uint8) * 255
        num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary_mask, 8, cv2.CV_32S)

        boxes = []
        # Label 0 is the background component, so start from 1.
        for i in range(1, num_labels):
            x, y, w, h, area = (
                stats[i, cv2.CC_STAT_LEFT],
                stats[i, cv2.CC_STAT_TOP],
                stats[i, cv2.CC_STAT_WIDTH],
                stats[i, cv2.CC_STAT_HEIGHT],
                stats[i, cv2.CC_STAT_AREA],
            )
            if area < min_area_threshold:
                continue

            pad = int(max(w, h) * padding_factor)
            x_min, y_min = max(0, x - pad), max(0, y - pad)
            x_max, y_max = min(mask.shape[1], x + w + pad), min(mask.shape[0], y + h + pad)
            boxes.append([x_min, y_min, x_max, y_max])

        return boxes

    def __getitem__(self, idx):
        img_name = self.images[idx]
        mask_name = self.masks[idx]

        image_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, mask_name)

        # Load image and mask; cv2.imread returns None instead of raising on failure.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # Boxes are extracted at the mask's native resolution and rescaled below.
        original_H, original_W = mask.shape
        boxes = self.get_bounding_boxes_from_mask(mask)

        # img_size is (height, width) but cv2.resize expects dsize as (width, height).
        target_H, target_W = self.img_size
        if image.shape[:2] != (target_H, target_W):
            image = cv2.resize(image, (target_W, target_H))
        if mask.shape[:2] != (target_H, target_W):
            # Nearest-neighbour keeps the mask binary.
            mask = cv2.resize(mask, (target_W, target_H), interpolation=cv2.INTER_NEAREST)

        scale_x = target_W / original_W
        scale_y = target_H / original_H
        boxes_rescaled = [
            [x_min * scale_x, y_min * scale_y, x_max * scale_x, y_max * scale_y]
            for x_min, y_min, x_max, y_max in boxes
        ]

        # Image normalization
        image = image.astype(np.float32) / 255.0
        image = (image - self.mean) / self.std
        image_tensor = torch.from_numpy(image).permute(2, 0, 1)

        mask_tensor = torch.from_numpy((mask > 0).astype(np.float32)).unsqueeze(0)

        # Return all boxes, not just one (SAM can handle multiple).
        # reshape(-1, 4) keeps a box-less image as an empty [0, 4] tensor, not 1-D [0].
        return {
            "image": image_tensor,
            "mask": mask_tensor,
            "bboxes": torch.tensor(boxes_rescaled, dtype=torch.float32).reshape(-1, 4),
            "image_name": img_name,
        }

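**Note:** Please modify the paths below to point to your own directories containing the wound images and their ground-truth masks. Images and masks are paired by sorted filename order, so each mask must sort to the same position as its image (e.g. give them identical base names).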

In [None]:
# Directories holding the wound images and their ground-truth masks.
dataset_images_path, dataset_masks_path = (
    "/content/drive/MyDrive/FYP/Datasets/test_inference/images/",
    "/content/drive/MyDrive/FYP/Datasets/test_inference/masks/",
)

In [None]:
# Mask-based dataset: box prompts are derived from the ground-truth masks.
dataset = SegmentationDataset(image_dir=dataset_images_path, mask_dir=dataset_masks_path)

### RealWorldInferenceDataset class
This class is for real-world datasets that have no ground-truth segmentation masks from which bounding boxes could be extracted. Instead, the bounding box(es) of each wound image must be annotated manually in PASCAL VOC format (x1, y1, x2, y2) and recorded in a CSV file.

The CSV file must contain two columns: **image_name** and **bbox**. If your column names differ, modify the code below to match them. Make sure each image_name exactly matches a file name in the image directory, and that each bbox is a list of boxes given in pixel coordinates of the original (un-resized) image: [[x1, y1, x2, y2]] for one box, or [[x1, y1, x2, y2], [x1, y1, x2, y2]] for two.

In [None]:
class RealWorldInferenceDataset(Dataset):
    """Images without ground-truth masks, with box prompts read from a CSV file.

    The CSV must have an ``image_name`` column (file names inside ``image_dir``)
    and a ``bbox`` column holding a Python-literal list of boxes in pixel
    coordinates of the original image, e.g. "[[x1, y1, x2, y2]]" or
    "[[x1, y1, x2, y2], [x1, y1, x2, y2]]". A single flat "[x1, y1, x2, y2]" is
    also accepted, and an empty cell yields no boxes.

    Each item is a dict:
        "image":      float32 tensor [3, H, W], scaled to [0, 1] then normalized
                      per channel with ``mean`` / ``std``.
        "bboxes":     float32 tensor [N, 4] of boxes rescaled to the H x W frame
                      (an empty [0, 4] tensor when there are no boxes).
        "image_name": file name of the image.
    where (H, W) = ``img_size``. There is no "mask" key (pure inference).
    """

    def __init__(self, image_dir, csv_path, img_size=(1024, 1024), mean=IMAGENET_MEAN, std=IMAGENET_STD):
        """
        Args:
            image_dir: directory containing the images.
            csv_path: path of the CSV with 'image_name' and 'bbox' columns.
            img_size: output (height, width) for images and box coordinates.
            mean, std: per-channel RGB normalization statistics.
        """
        self.image_dir = image_dir
        self.img_size = img_size
        # float32 statistics keep the normalized image float32 (float64 would upcast it).
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

        # 1. Load CSV (columns: 'image_name' and 'bbox').
        df = pd.read_csv(csv_path)

        # Fast lookup, e.g. {'image_01.jpg': "[[10, 10, 100, 100]]", ...}.
        # If a name appears in several rows, the last row wins.
        self.bbox_map = dict(zip(df['image_name'], df['bbox']))

        # 2. Keep only images present in BOTH the folder AND the CSV.
        # dict.fromkeys de-duplicates while preserving CSV order, so a repeated
        # row does not make the same image be processed twice.
        available_files = set(os.listdir(image_dir))
        self.images = [img for img in dict.fromkeys(df['image_name']) if img in available_files]

        print(f"Found {len(self.images)} images with matching bounding boxes in CSV.")

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _parse_boxes(raw):
        """Normalize one CSV 'bbox' cell into a list of [x1, y1, x2, y2] boxes."""
        if isinstance(raw, str):
            # literal_eval only accepts Python literals, so the CSV cannot run code.
            raw = ast.literal_eval(raw)
        elif raw is None or (isinstance(raw, float) and np.isnan(raw)):
            # An empty CSV cell is read by pandas as NaN: no prompt for this image.
            return []
        boxes = list(raw)
        # Accept a single flat box "[x1, y1, x2, y2]" as well as a nested list.
        if len(boxes) == 4 and all(isinstance(v, (int, float)) for v in boxes):
            boxes = [boxes]
        return boxes

    def __getitem__(self, idx):
        img_name = self.images[idx]
        image_path = os.path.join(self.image_dir, img_name)

        # Load Image; cv2.imread returns None instead of raising on failure.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Original dimensions, needed to rescale the CSV boxes.
        original_H, original_W = image.shape[:2]

        boxes_raw = self._parse_boxes(self.bbox_map[img_name])

        # Resize to the model input size. img_size is (height, width) but
        # cv2.resize expects dsize as (width, height).
        target_H, target_W = self.img_size
        if image.shape[:2] != (target_H, target_W):
            image = cv2.resize(image, (target_W, target_H))

        # --- CRITICAL: Rescale Bounding Boxes ---
        # If we resize the image, we MUST resize the box coordinates too.
        scale_x = target_W / original_W
        scale_y = target_H / original_H
        boxes_rescaled = [
            [x_min * scale_x, y_min * scale_y, x_max * scale_x, y_max * scale_y]
            for x_min, y_min, x_max, y_max in boxes_raw
        ]

        # Normalize Image
        image = image.astype(np.float32) / 255.0
        image = (image - self.mean) / self.std
        image_tensor = torch.from_numpy(image).permute(2, 0, 1)

        # Return dict (Note: No 'mask' key needed for pure inference).
        # reshape(-1, 4) keeps a box-less image as an empty [0, 4] tensor, not 1-D [0].
        return {
            "image": image_tensor,
            "bboxes": torch.tensor(boxes_rescaled, dtype=torch.float32).reshape(-1, 4),
            "image_name": img_name,
        }

**Note:** Please modify the paths below to your own paths that contain the real-world dataset and the CSV file.

In [None]:
# Folder of real-world images and the CSV of manually annotated boxes.
real_world_dataset_path, csv_file_path = (
    "/content/drive/MyDrive/FYP/Datasets/test_inference/real_world/",
    "/content/drive/MyDrive/FYP/Datasets/test_inference/real_world/real_world_dataset.csv",
)

In [None]:
# CSV-based dataset: box prompts come from manual annotations.
dataset = RealWorldInferenceDataset(image_dir=real_world_dataset_path, csv_path=csv_file_path)

## Model Initialization

### MobileSAM

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Build the base MobileSAM model from the "vit_t" registry entry and the
# checkpoint downloaded into ./weights/.
model_type, checkpoint_path = "vit_t", "./weights/mobile_sam.pt"
mobile_sam = sam_model_registry[model_type](checkpoint=checkpoint_path)

In [None]:
# Load the fine-tuned mask decoder weights into the MobileSAM decoder.
decoder_path = "/content/drive/MyDrive/FYP/MobileSAM_Finetuning/checkpoints/Run_20251214-093848/mask_decoder.pth"

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code when loaded.
decoder_state_dict = torch.load(decoder_path, map_location=device, weights_only=True)

# strict=False makes load_state_dict report key mismatches in its return value
# instead of raising, so the diagnostics below can actually run (with the default
# strict=True the warning branch was unreachable). Any missing parameters keep the
# values loaded from the base MobileSAM checkpoint; shape mismatches still raise.
msg = mobile_sam.mask_decoder.load_state_dict(decoder_state_dict, strict=False)

# Check for issues
if not msg.missing_keys and not msg.unexpected_keys:
    print("Mask Decoder: All weights loaded successfully with no mismatches.")
else:
    print("Mask Decoder Load Warning:")
    print(f"  Missing Keys: {msg.missing_keys}")
    print(f"  Unexpected Keys: {msg.unexpected_keys}")

In [None]:
# Wrap the MobileSAM image encoder with the fine-tuned LoRA adapter saved in lora_path.
lora_path = "/content/drive/MyDrive/FYP/MobileSAM_Finetuning/checkpoints/Run_20251214-093848/lora_image_encoder"
mobile_sam.image_encoder = PeftModel.from_pretrained(mobile_sam.image_encoder, lora_path)

# Verification 1: report which adapter(s) are active on the wrapped encoder.
active_adapters = mobile_sam.image_encoder.active_adapters
print(f"Active LoRA Adapters: {active_adapters}")

# Verification 2: loading problems may only produce warnings, so explicitly look
# for injected LoRA submodules (qualified names containing "lora_").
has_lora = any("lora_" in qualified_name for qualified_name, _ in mobile_sam.image_encoder.named_modules())
print(
    "LoRA layers detected in the Image Encoder."
    if has_lora
    else "ERROR: No LoRA layers found. The adapter was not applied correctly."
)

# Verification 3: trainable-parameter summary (expected to be 0 for inference,
# but it also confirms the wrapped structure).
mobile_sam.image_encoder.print_trainable_parameters()

In [None]:
# Move all MobileSAM weights (including the LoRA-wrapped encoder) to the device.
# The trailing ';' stops Jupyter from printing the full model architecture,
# which Module.to() returns.
mobile_sam.to(device);

### Finetuner Wrapper Class

In [None]:
class MobileSAMFineTuner(nn.Module):
    """Run box-prompted MobileSAM on a batch and return per-image lists of masks.

    The image encoder is applied once to the whole batch; the prompt encoder and
    mask decoder then run once per image, with all of that image's boxes together.
    """

    def __init__(self, sam_model):
        super().__init__()
        # Object exposing image_encoder, prompt_encoder and mask_decoder.
        self.sam = sam_model

    def forward(self, images: torch.Tensor, bboxes: list):
        """Predict one mask per box prompt.

        Args:
            images: [B, 3, H, W] batch of normalized images.
            bboxes: one entry per image; entry i is an [N_i, 4] box tensor, presumably
                (x_min, y_min, x_max, y_max) in the pixel frame of ``images``.

        Returns:
            (masks, iou_preds), two lists with one entry per image. masks[i] is the
            decoder output for image i bilinearly upsampled to [N_i, 1, H, W];
            iou_preds[i] is the decoder's IoU prediction for those masks. Images
            without boxes get empty [0, 1, H, W] / [0, 1] tensors so list indices
            stay aligned with ``bboxes``.
        """
        _, _, height, width = images.shape
        dev = images.device

        # Encoder work shared by every prompt of every image.
        embeddings = self.sam.image_encoder(images)
        image_pe = self.sam.prompt_encoder.get_dense_pe()

        masks_out, ious_out = [], []
        for idx, boxes in enumerate(bboxes):
            if boxes.shape[0] == 0:
                # No prompt for this image: empty placeholders keep indices aligned.
                masks_out.append(torch.zeros(0, 1, height, width, device=dev))
                ious_out.append(torch.zeros(0, 1, device=dev))
                continue

            sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(
                points=None,
                boxes=boxes,
                masks=None,
            )
            low_res_masks, iou_predictions = self.sam.mask_decoder(
                image_embeddings=embeddings[idx : idx + 1],
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )

            # Upsample this image's low-resolution masks back to the input size.
            masks_out.append(
                F.interpolate(
                    low_res_masks,
                    size=(height, width),
                    mode="bilinear",
                    align_corners=False,
                )
            )
            ious_out.append(iou_predictions)

        return masks_out, ious_out

In [None]:
# Wrap the LoRA-augmented MobileSAM for batched, box-prompted inference.
finetuner = MobileSAMFineTuner(mobile_sam)

In [None]:
# Move to the device and switch to eval mode (affects layers such as dropout/batch norm).
# The trailing ';' stops Jupyter from printing the full module architecture that
# eval() returns.
finetuner.to(device).eval();

## Masks Generation & Save

**Note:** Please modify the `output_dir` below to your own desired path.

In [None]:
# Directory that receives one PNG mask per processed image (created if missing).
output_dir = "/content/drive/MyDrive/FYP/Datasets/test_inference/inference_masks/"
os.makedirs(name=output_dir, exist_ok=True)
print(f"Saving to {output_dir}")

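**Note:**
- Some padding is added by dilating each predicted mask with a `k_size` × `k_size` square kernel: `k_size = 20` enlarges the wound mask by roughly 10 pixels on each side (measured at the model input resolution, before the final resize). Increase `k_size` if you want more padding.
- The saved masks are already resized to 224 × 224.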

In [None]:
# batch_size=1 because images can carry different numbers of boxes.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Loop-invariant settings, built once.
k_size = 20  # dilation kernel side; ~k_size/2 pixels of padding around the mask
kernel = np.ones((k_size, k_size), np.uint8)
output_mask_size = (224, 224)  # (width, height) passed to cv2.resize

with torch.no_grad():
    # Wrap loader with tqdm for a progress bar; tqdm.write keeps the bar intact.
    for batch in tqdm(dataloader):
        images = batch["image"].to(device).float()
        # Single image per batch: its [N, 4] box tensor.
        bboxes = [batch["bboxes"][0].to(device).float()]
        image_name = batch["image_name"][0]

        # SAM needs at least one prompt; images without boxes get no mask file.
        if bboxes[0].numel() == 0:
            tqdm.write(f"Skipping {image_name}: No bounding boxes detected.")
            continue

        # --- Forward pass (Generate Mask) ---
        pred_masks_list, _ = finetuner(images, bboxes)
        pred_masks = pred_masks_list[0]  # [N, 1, H, W], one mask per box

        # Union of the per-box masks: a pixel is foreground if any box's
        # sigmoid probability exceeds 0.5. Squeeze to [H, W].
        mask_tensor = (torch.sigmoid(pred_masks) > 0.5).any(dim=0).squeeze()

        # --- Convert to Image Format (0 / 255 uint8) ---
        mask_np = mask_tensor.cpu().numpy().astype(np.uint8) * 255

        # --- Add Padding ---
        mask_dilated = cv2.dilate(mask_np, kernel, iterations=1)

        # --- Resize to the output size; nearest-neighbour keeps it binary ---
        mask_final = cv2.resize(mask_dilated, output_mask_size, interpolation=cv2.INTER_NEAREST)

        # --- Save the Mask ---
        # PNG avoids JPEG compression artifacts on masks.
        filename_no_ext = os.path.splitext(image_name)[0]
        save_path = os.path.join(output_dir, f"{filename_no_ext}.png")

        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(save_path, mask_final):
            raise OSError(f"Failed to write mask to {save_path}")
        tqdm.write(f"Successfully saved the mask to {save_path}\n")