In [None]:
# Mount Google Drive so the dataset and checkpoint paths under /content/drive used below are readable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install git+https://github.com/ChaoningZhang/MobileSAM.git

In [None]:
# Download the pretrained MobileSAM checkpoint into ./weights (-nc: skip the download if the file already exists).
!mkdir -p weights
!wget -nc https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt -P ./weights/

In [None]:
import torch
import torch.nn as nn
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt

from torchvision import models
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

In [None]:
from mobile_sam import sam_model_registry, SamPredictor

## Evaluation Dataset class

Because MobileSAM is a promptable segmentation model, every image must be accompanied by prompts. Here the prompts are bounding boxes, one per connected component of the ground-truth mask.

In [None]:
# Standard ImageNet per-channel (RGB) normalization statistics, for pixel values scaled to [0, 1].
# The evaluation dataset uses them to normalize images; denormalize_image uses them to undo it for plotting.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

In [None]:
class EvaluationPromptableSegmentationDataset(Dataset):
    """Image/mask pairs plus bounding-box prompts derived from each ground-truth mask.

    Images and masks are paired by position after sorting each directory listing,
    so both directories must hold matching files that sort in the same order
    (e.g. identical basenames) and nothing else.

    Each item is a dict with:
        "image":      float32 tensor [3, H, W], RGB normalized with `mean`/`std`.
        "mask":       float32 tensor [1, H, W] with values in {0, 1}.
        "bboxes":     float32 tensor [N, 4] of (x_min, y_min, x_max, y_max) boxes in
                      resized-image pixel coordinates; N is 0 when no box was found.
        "image_name": the image file name.
    where (H, W) == img_size.
    """

    def __init__(self, image_dir, mask_dir, img_size=(1024, 1024), mean=IMAGENET_MEAN, std=IMAGENET_STD):
        """
        Args:
            image_dir: directory containing the input images.
            mask_dir: directory containing the ground-truth masks (non-zero pixel = foreground).
            img_size: output size as (height, width).
            mean, std: per-channel normalization statistics for RGB values scaled to [0, 1].
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        self.img_size = img_size
        self.mean = np.array(mean)
        self.std = np.array(std)

    def __len__(self):
        return len(self.images)

    def get_bounding_boxes_from_mask(self, mask, padding_factor=0.1, min_area_threshold=5):
        """Return one padded box per connected foreground component of `mask`.

        Args:
            mask: 2-D uint8 mask; every non-zero pixel is foreground, matching the
                `mask > 0` binarization of the ground-truth tensor in __getitem__.
            padding_factor: each box grows on every side by this fraction of the
                component's larger side, then is clipped to the mask bounds.
            min_area_threshold: components with fewer pixels than this are ignored.

        Returns:
            list of [x_min, y_min, x_max, y_max] in mask pixel coordinates.
        """
        # Threshold at 0 (pixels > 0 become foreground) so masks stored as 0/1 also
        # produce boxes; a threshold of 1 would drop every pixel with value 1.
        _, binary_mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY)
        num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary_mask, 8, cv2.CV_32S)

        boxes = []
        for label in range(1, num_labels):  # label 0 is the background
            x = stats[label, cv2.CC_STAT_LEFT]
            y = stats[label, cv2.CC_STAT_TOP]
            w = stats[label, cv2.CC_STAT_WIDTH]
            h = stats[label, cv2.CC_STAT_HEIGHT]
            area = stats[label, cv2.CC_STAT_AREA]
            if area < min_area_threshold:
                continue

            pad = int(max(w, h) * padding_factor)
            x_min, y_min = max(0, x - pad), max(0, y - pad)
            x_max, y_max = min(mask.shape[1], x + w + pad), min(mask.shape[0], y + h + pad)
            boxes.append([x_min, y_min, x_max, y_max])

        return boxes

    def __getitem__(self, idx):
        img_name = self.images[idx]
        mask_name = self.masks[idx]

        image_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, mask_name)

        # cv2.imread returns None instead of raising on unreadable files.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # Boxes are extracted at the mask's native resolution and rescaled below.
        original_H, original_W = mask.shape
        boxes = self.get_bounding_boxes_from_mask(mask)

        target_H, target_W = self.img_size
        # cv2.resize takes dsize as (width, height), whereas img_size is (height, width).
        dsize = (target_W, target_H)
        if image.shape[:2] != (target_H, target_W):
            image = cv2.resize(image, dsize)
        if mask.shape[:2] != (target_H, target_W):
            mask = cv2.resize(mask, dsize, interpolation=cv2.INTER_NEAREST)

        scale_x = target_W / original_W
        scale_y = target_H / original_H
        boxes_rescaled = [
            [x_min * scale_x, y_min * scale_y, x_max * scale_x, y_max * scale_y]
            for x_min, y_min, x_max, y_max in boxes
        ]

        # Image normalization; cast back to float32 because the float64 mean/std arrays upcast.
        image = image.astype(np.float32) / 255.0
        image = ((image - self.mean) / self.std).astype(np.float32)
        image_tensor = torch.from_numpy(image).permute(2, 0, 1)

        mask_tensor = torch.from_numpy((mask > 0).astype(np.float32)).unsqueeze(0)

        # Return all boxes, not just one (SAM can handle multiple).
        # reshape keeps an empty box list as shape [0, 4] instead of [0].
        return {
            "image": image_tensor,
            "mask": mask_tensor,
            "bboxes": torch.tensor(boxes_rescaled, dtype=torch.float32).reshape(-1, 4),
            "image_name": img_name,
        }

**Note:** Replace the paths below with the locations of your own test sets.

In [None]:
test_sets_base_path = "/content/drive/MyDrive/FYP/Datasets/"

# FUSeg test split: images and their label masks.
test_set_1_images_path, test_set_1_masks_path = (
    os.path.join(test_sets_base_path, "FUSeg/test/images"),
    os.path.join(test_sets_base_path, "FUSeg/test/labels"),
)

# DFUC2022 test split: images and their masks.
test_set_2_images_path, test_set_2_masks_path = (
    os.path.join(test_sets_base_path, "DFUC2022/test/images"),
    os.path.join(test_sets_base_path, "DFUC2022/test/masks"),
)

In [None]:
# One evaluation dataset per test split, keyed by the name used in the printed metrics.
test_sets = {
    set_name: EvaluationPromptableSegmentationDataset(images_dir, masks_dir)
    for set_name, images_dir, masks_dir in (
        ("test_set_FUSeg", test_set_1_images_path, test_set_1_masks_path),
        ("test_set_DFUC2022", test_set_2_images_path, test_set_2_masks_path),
    )
}

### Helper functions for computing the evaluation metrics

In [None]:
def get_confusion_matrix_components(y_true, y_pred, threshold=0.5):
    """
    Count pixel-level confusion-matrix components for a batch.

    Args:
        y_true (torch.Tensor): Ground-truth masks, a tensor of 0s and 1s.
        y_pred (torch.Tensor): Predicted masks (continuous scores or already-binary
            values); entries strictly greater than `threshold` count as positive.
        threshold (float): The binarization threshold.

    Returns:
        tuple: (true_positives, false_positives, false_negatives, true_negatives) as Python ints.

    Both tensors are flattened before comparison, so only their element counts must
    match; their shapes and memory layouts may differ.
    """
    # Binarize predictions using the specified threshold
    y_pred = (y_pred > threshold).float()

    # reshape (unlike view) also works on non-contiguous tensors, e.g. transposed or sliced masks.
    y_true_flat = y_true.reshape(-1)
    y_pred_flat = y_pred.reshape(-1)

    # Calculate confusion matrix components
    true_positives = ((y_pred_flat == 1) & (y_true_flat == 1)).sum().item()
    false_positives = ((y_pred_flat == 1) & (y_true_flat == 0)).sum().item()
    false_negatives = ((y_pred_flat == 0) & (y_true_flat == 1)).sum().item()
    true_negatives = ((y_pred_flat == 0) & (y_true_flat == 0)).sum().item()

    return true_positives, false_positives, false_negatives, true_negatives

# Final metrics calculation function
def calculate_final_metrics(tp, fp, fn, tn, smooth=1e-6):
    """
    Convert accumulated confusion-matrix counts into summary metrics.

    `smooth` is added to every denominator, so all-zero counts give 0.0 instead of
    raising ZeroDivisionError.

    Returns:
        tuple: (iou, dice, recall, precision, accuracy).
    """
    total = tp + tn + fp + fn
    return (
        tp / (tp + fp + fn + smooth),            # IoU: intersection over union
        (2 * tp) / (2 * tp + fp + fn + smooth),  # Dice coefficient
        tp / (tp + fn + smooth),                 # Recall (sensitivity)
        tp / (tp + fp + smooth),                 # Precision (positive predictive value)
        (tp + tn) / (total + smooth),            # Pixel accuracy
    )

## MobileSAM + Domain Adapter Module

In [None]:
class DecoderAdapter(nn.Module):
    """
    Residual bottleneck adapter for parameter-efficient fine-tuning.

    Computes ``x + up(gelu(down(x)))`` where ``down`` maps in_dim -> adapter_dim and
    ``up`` maps adapter_dim -> in_dim. ``up`` starts with tiny weights (std 1e-4) and
    a zero bias, so a freshly built adapter is almost the identity.
    """

    def __init__(self, in_dim: int, adapter_dim: int):
        super().__init__()
        # Submodule names (down / non_linearity / up) determine state-dict keys; keep them stable.
        self.down = nn.Linear(in_dim, adapter_dim)
        self.non_linearity = nn.GELU()
        self.up = nn.Linear(adapter_dim, in_dim)

        # Near-zero start for the up-projection.
        nn.init.normal_(self.up.weight, std=1e-4)
        nn.init.zeros_(self.up.bias)

    def forward(self, x):
        bottleneck = self.non_linearity(self.down(x))
        return x + self.up(bottleneck)

In [None]:
def inject_domain_adapter(mask_decoder, adapter_dim=64):
    """
    Attach a DecoderAdapter after every square nn.Linear inside `mask_decoder`.

    The set of target layers is every nn.Linear with in_features == out_features,
    taken from a snapshot of `named_modules()` made before any adapter is added.
    For the k-th such layer, in traversal order:
      * a DecoderAdapter(in_dim, adapter_dim) is registered on `mask_decoder` as
        ``domain_adapter_<k>``. As a submodule, it is tracked by `.to(device)`,
        `.parameters()` and `state_dict()`.
      * the layer's forward is replaced by ``adapter(original_forward(x))``.

    NOTE(review): this wraps *every* square Linear in the decoder, not only MLP
    blocks. The adapter names and order must match the checkpoint being loaded.

    Returns:
        The same `mask_decoder` object, modified in place.
    """
    square_linears = [
        module
        for _, module in list(mask_decoder.named_modules())
        if isinstance(module, nn.Linear) and module.out_features == module.in_features
    ]

    for adapter_idx, linear in enumerate(square_linears):
        adapter = DecoderAdapter(linear.out_features, adapter_dim)
        setattr(mask_decoder, f"domain_adapter_{adapter_idx}", adapter)

        # Default arguments bind this iteration's forward/adapter (avoids late binding in the loop).
        def wrapped_forward(x, inner_forward=linear.forward, layer_adapter=adapter):
            return layer_adapter(inner_forward(x))

        linear.forward = wrapped_forward

    return mask_decoder


In [None]:
MODEL_TYPE = "vit_t"  # sam_model_registry key for MobileSAM

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

**Note:** Replace the path below with the location of your own fine-tuned checkpoint.

In [None]:
# Fine-tuned MobileSAM + domain-adapter checkpoint; its weights are loaded into `finetuner` further down.
MODEL_PATH = "/content/drive/MyDrive/FYP/Model_Training/SAM/best_finetunedMobileSAM_DA.pth"

In [None]:
# Load the model: build MobileSAM from the public pretrained checkpoint, then inject the
# decoder adapters so the module structure matches the fine-tuned checkpoint loaded later.
sam_checkpoint = "./weights/mobile_sam.pt"

# Must match the adapter_dim used during fine-tuning, otherwise load_state_dict raises a size mismatch.
ADAPTER_DIM = 64

mobile_sam = sam_model_registry[MODEL_TYPE](checkpoint=sam_checkpoint)
mobile_sam.mask_decoder = inject_domain_adapter(mobile_sam.mask_decoder, adapter_dim=ADAPTER_DIM)

# The adapters are registered submodules, so they move to `device` as well.
# The trailing `;` suppresses the very long module repr in the cell output.
mobile_sam.to(device);

In [None]:
TRAINING_IMG_SIZE = 1024
SAM_EMBEDDING_SIZE = 1024

class MobileSAMFineTuner(nn.Module):
    """
    Box-prompted mask prediction with a frozen SAM image encoder.

    Every parameter whose name starts with ``image_encoder`` is frozen. All others
    (prompt encoder, mask decoder, injected adapters) are set trainable. The wrapped
    model lives at ``self.sam``, so state-dict keys carry a ``sam.`` prefix.
    """

    def __init__(self, sam_model, train_img_size, sam_emb_size):
        super().__init__()
        self.sam = sam_model
        self.train_img_size = train_img_size
        self.sam_emb_size = sam_emb_size
        # Box coordinates are multiplied by this (1.0 when both sizes are equal).
        self.scale_factor = sam_emb_size / train_img_size

        for param_name, param in self.sam.named_parameters():
            param.requires_grad = not param_name.startswith('image_encoder')

    def forward(self, images: torch.Tensor, batch_bboxes: list[torch.Tensor]):
        """
        Args:
            images: Tensor [B, 3, H, W].
            batch_bboxes: list of length B; element i is a Tensor [N_i, 4] holding
                the (x_min, y_min, x_max, y_max) boxes for image i.

        Returns:
            all_masks: list of B tensors, the decoder outputs for each image's boxes,
                bilinearly upsampled to (train_img_size, train_img_size); presumably
                [N_i, 1, ...] since multimask_output=False.
            all_iou_preds: list of B tensors, the decoder's IoU predictions per box.
        """
        batch_size, _, _, _ = images.shape

        # NOTE(review): the evaluation dataset already ImageNet-normalizes `images`, and
        # sam.preprocess (described in this notebook as "rescale + normalize") presumably
        # normalizes again with SAM's own pixel statistics. Confirm this matches how the
        # checkpoint was trained before changing either side.
        input_images = torch.stack([self.sam.preprocess(img) for img in images], dim=0)

        # The image encoder is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            image_embeddings = self.sam.image_encoder(input_images)  # [B, C, h, w] — dims depend on the encoder

        all_masks = []
        all_iou_preds = []

        for i in range(batch_size):
            embedding = image_embeddings[i].unsqueeze(0)  # keep a batch dim of 1
            scaled_boxes = batch_bboxes[i] * self.scale_factor

            # All boxes of this image are encoded at once. Presumably sparse is
            # [N_i, 2, C] (two box corners) and dense is [N_i, C, h, w] — TODO confirm.
            sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(
                points=None,
                boxes=scaled_boxes,
                masks=None,
            )

            low_res_masks, iou_predictions = self.sam.mask_decoder(
                image_embeddings=embedding,
                image_pe=self.sam.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )

            full_res_masks = F.interpolate(
                low_res_masks,
                size=(self.train_img_size, self.train_img_size),
                mode="bilinear",
                align_corners=False,
            )

            all_masks.append(full_res_masks)
            all_iou_preds.append(iou_predictions)

        return all_masks, all_iou_preds


In [None]:
# Wrap the adapted model (state-dict keys get a `sam.` prefix). Use the size constants defined
# next to MobileSAMFineTuner rather than repeating the literal 1024.
finetuner = MobileSAMFineTuner(
    sam_model=mobile_sam,
    train_img_size=TRAINING_IMG_SIZE,
    sam_emb_size=SAM_EMBEDDING_SIZE,
)
finetuner.to(device);  # trailing `;` suppresses the long module repr

In [None]:
# --- Load the saved model weights ---
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(MODEL_PATH, map_location=device)

# --- Load the weights into the model ---
# strict=False tolerates keys absent from the checkpoint. If the frozen image encoder were
# omitted, it would simply keep the pretrained MobileSAM weights. A missing key anywhere else
# (prompt encoder, mask decoder, adapters) means part of the fine-tuned model was NOT loaded.
missing, unexpected = finetuner.load_state_dict(state_dict, strict=False)

missing_non_encoder = [key for key in missing if not key.startswith("sam.image_encoder.")]
if missing_non_encoder:
    print(
        f"⚠️ {len(missing_non_encoder)} non-encoder parameters/buffers were not found in the checkpoint; "
        "the evaluated model may not be the fine-tuned one."
    )
else:
    print("✅ Model loaded successfully.")
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

In [None]:
# Switch to inference behaviour (equivalent to finetuner.eval()).
finetuner.train(False)

## Model Evaluation

In [None]:
def denormalize_image(image_tensor, mean, std):
    """
    Undo per-channel normalization and return a displayable uint8 image.

    Args:
        image_tensor (torch.Tensor): Normalized image of shape (C, H, W) with C == 3.
        mean (list or tuple): Per-channel means that were subtracted.
        std (list or tuple): Per-channel standard deviations that were divided by.

    Returns:
        np.ndarray: (H, W, C) uint8 array; values are clipped to [0, 1] and scaled by 255
        (the fractional part is truncated).
    """
    channel_std = np.array(std).reshape(1, 1, 3)
    channel_mean = np.array(mean).reshape(1, 1, 3)

    hwc = np.transpose(image_tensor.cpu().numpy(), (1, 2, 0))
    restored = (hwc * channel_std) + channel_mean

    return (np.clip(restored, 0, 1) * 255).astype(np.uint8)

In [None]:
def _predict_union_mask(images, bboxes):
    """Predict one mask per box with `finetuner` and return their pixel-wise union.

    Args:
        images: single-image batch already on `device` ([1, 3, H, W]).
        bboxes: one-element list holding that image's [N, 4] box tensor (N >= 1).

    Returns:
        Float tensor [1, 1, H, W] with values in {0, 1}.
    """
    pred_masks_list, _ = finetuner(images, bboxes)
    pred_masks = pred_masks_list[0]  # [N, 1, H, W], one logit map per box
    # A pixel is foreground if any box's predicted probability exceeds 0.5.
    return (torch.sigmoid(pred_masks) > 0.5).any(dim=0, keepdim=True).float()


def _sample_with_boxes(dataset, seed=0):
    """Return a random single-item batch that has at least one box, or None if there is none.

    A seeded generator makes the choice reproducible across runs.
    """
    generator = torch.Generator().manual_seed(seed)
    for batch in DataLoader(dataset, batch_size=1, shuffle=True, generator=generator):
        if batch["bboxes"][0].numel() > 0:
            return batch
    return None


def _show_sample_prediction(test_set_name, vis_batch):
    """Plot the image, the ground-truth overlay and the predicted-mask overlay side by side."""
    vis_images = vis_batch["image"].to(device).float()
    vis_masks_gt = vis_batch["mask"].to(device).float()
    vis_bboxes = [vis_batch["bboxes"][0].to(device).float()]
    vis_image_name = vis_batch["image_name"][0]

    with torch.no_grad():
        vis_pred_combined = _predict_union_mask(vis_images, vis_bboxes)

    original_img = denormalize_image(vis_images[0], IMAGENET_MEAN, IMAGENET_STD)
    panels = [
        ("Original Image", None, None),
        ("Ground Truth Mask", vis_masks_gt[0].squeeze().cpu().numpy(), "Reds"),
        ("Predicted Mask", vis_pred_combined[0].squeeze().cpu().numpy(), "Blues"),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(12, 6))
    fig.suptitle(f"Sample Predictions - {test_set_name} ({vis_image_name})", fontsize=14)
    for ax, (title, overlay, cmap) in zip(axes, panels):
        ax.imshow(original_img)
        if overlay is not None:
            ax.imshow(overlay, cmap=cmap, alpha=0.5)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# Evaluate on each test set
for test_set_name, test_set in test_sets.items():
    # batch_size=1 because each image can have a different number of boxes.
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

    # Counts are summed over every pixel of the test set (micro-averaging) and
    # converted to metrics once at the end.
    total_tp = total_fp = total_fn = total_tn = 0

    with torch.no_grad():
        for batch in test_loader:
            mask_gt = batch["mask"].to(device).float()        # [1, 1, H, W]
            bboxes = [batch["bboxes"][0].to(device).float()]  # list([N, 4])
            image_name = batch["image_name"][0]

            if bboxes[0].numel() == 0:
                # No prompt means nothing is segmented. Score the image with an empty
                # prediction instead of dropping it. Ground-truth pixels that produced no
                # box (components below min_area_threshold) then count as false negatives,
                # and background pixels count as true negatives.
                print(f"No bounding boxes for {image_name} from {test_set_name}: scoring it as an empty prediction.")
                combined_pred_mask = torch.zeros_like(mask_gt)
            else:
                images = batch["image"].to(device).float()    # [1, 3, H, W]
                combined_pred_mask = _predict_union_mask(images, bboxes)

            tp, fp, fn, tn = get_confusion_matrix_components(mask_gt, combined_pred_mask)
            total_tp += tp
            total_fp += fp
            total_fn += fn
            total_tn += tn

    # --- Final metrics for the test set ---
    avg_iou, avg_dice, avg_recall, avg_precision, avg_accuracy = calculate_final_metrics(
        total_tp, total_fp, total_fn, total_tn
    )

    print(f"\nMetrics for {test_set_name}:")
    print(f"  IoU: {avg_iou:.4f}")
    print(f"  Dice: {avg_dice:.4f}")
    print(f"  Recall: {avg_recall:.4f}")
    print(f"  Precision: {avg_precision:.4f}")
    print(f"  Accuracy: {avg_accuracy:.4f}")

    # --- Visualization: one reproducible random sample that has at least one box ---
    vis_batch = _sample_with_boxes(test_set)
    if vis_batch is None:
        print(f"Could not find a sample with bounding boxes in {test_set_name} for visualization.")
        continue

    _show_sample_prediction(test_set_name, vis_batch)