In [None]:
# Mount Google Drive at /content/drive (Google Colab only); the dataset and
# checkpoint paths used later in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt

from torchvision import models
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

In [None]:
# Standard ImageNet per-channel statistics, in (R, G, B) order, used to
# normalise images after they are scaled to [0, 1].
IMAGENET_MEAN = (
    0.485,  # R
    0.456,  # G
    0.406,  # B
)
IMAGENET_STD = (
    0.229,  # R
    0.224,  # G
    0.225,  # B
)

In [None]:
class EvaluationSegmentationDataset(Dataset):
    """Paired image/mask dataset for evaluating a binary segmentation model.

    Images and masks are paired purely by position after sorting each
    directory's file names, so both directories must hold the same number of
    files whose names sort into matching order.

    Each item is ``(image, mask)``:
      * ``image``: float32 tensor of shape (3, H, W), RGB, scaled to [0, 1] and
        then standardised with ``mean`` / ``std``.
      * ``mask``: float32 tensor of shape (1, H, W) holding 0.0 / 1.0; any
        non-zero mask pixel counts as foreground.

    Args:
        image_dir (str): Directory of images readable by ``cv2.imread``.
        mask_dir (str): Directory of masks readable as grayscale by ``cv2.imread``.
        img_size (tuple[int, int]): Target ``(height, width)``; images and masks
            of any other size are resized to it.
        mean (sequence of float): Per-channel RGB mean (in [0, 1] units).
        std (sequence of float): Per-channel RGB standard deviation.

    Raises:
        ValueError: If the two directories contain different numbers of files.
    """

    def __init__(self, image_dir, mask_dir, img_size=(1024, 1024), mean=IMAGENET_MEAN, std=IMAGENET_STD):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        # Pairing is positional, so differing counts would silently pair images
        # with the wrong masks (or hit an IndexError part-way through evaluation).
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {image_dir!r} but "
                f"{len(self.masks)} masks in {mask_dir!r}; they must pair one-to-one."
            )
        self.img_size = img_size
        # float32 statistics keep the normalised image float32; float64 arrays
        # would silently upcast every image to float64.
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # cv2.imread returns None (it does not raise) for missing/unreadable files.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")

        # 1. Resize to img_size = (height, width). cv2.resize expects dsize as
        #    (width, height), so the order must be swapped for non-square sizes.
        target_h, target_w = self.img_size
        if image.shape[:2] != (target_h, target_w):
            image = cv2.resize(image, (target_w, target_h))
        if mask.shape[:2] != (target_h, target_w):
            # Nearest-neighbour keeps mask values discrete (no interpolated labels).
            mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)

        # 2. Scale to [0, 1] and standardise each channel.
        image = image.astype(np.float32) / 255.0
        image = (image - self.mean) / self.std

        # 3. HWC -> CHW tensor.
        image = torch.from_numpy(image).permute(2, 0, 1)

        # 4. Binarise the mask (0.0 / 1.0) and add a channel dimension: HW -> 1HW.
        mask = (mask > 0).astype(np.float32)
        mask = torch.from_numpy(mask).unsqueeze(0)

        return image, mask

**Note:** Replace these test-set paths with the paths to your own data.

In [None]:
# Root folder holding all test datasets; adjust to your own Drive layout.
test_sets_base_path = "/content/drive/MyDrive/FYP/Datasets/"

# Test set 1 - FUSeg: images and their label masks
test_set_1_images_path, test_set_1_masks_path = (
    os.path.join(test_sets_base_path, rel) for rel in ("FUSeg/test/images", "FUSeg/test/labels")
)

# Test set 2 - DFUC2022: images and their masks
test_set_2_images_path, test_set_2_masks_path = (
    os.path.join(test_sets_base_path, rel) for rel in ("DFUC2022/test/images", "DFUC2022/test/masks")
)

In [None]:
# (images_dir, masks_dir) per test split, keyed by the name shown in the results
test_set_dirs = {
    "test_set_fuseg": (test_set_1_images_path, test_set_1_masks_path),
    "test_set_dfuc2022": (test_set_2_images_path, test_set_2_masks_path),
}
test_sets = {
    name: EvaluationSegmentationDataset(images_dir, masks_dir)
    for name, (images_dir, masks_dir) in test_set_dirs.items()
}

## Evaluation Metrics

In [None]:
def get_confusion_matrix_components(y_true, y_pred, threshold=0.5):
    """
    Counts pixel-level confusion-matrix entries (TP, FP, FN, TN) for a batch.

    Args:
        y_true (torch.Tensor): Ground-truth masks containing 0s and 1s. Elements
            equal to neither 0 nor 1 are not counted in any entry.
        y_pred (torch.Tensor): Continuous model outputs (e.g. sigmoid
            probabilities) with the same number of elements as ``y_true``.
        threshold (float): Predictions strictly greater than this are positive.

    Returns:
        tuple[int, int, int, int]: (true_positives, false_positives,
        false_negatives, true_negatives).
    """
    # reshape (not view) so non-contiguous inputs, e.g. sliced or permuted
    # tensors, are accepted instead of raising a RuntimeError.
    pred_pos = (y_pred > threshold).reshape(-1)
    pred_neg = ~pred_pos
    gt_flat = y_true.reshape(-1)
    gt_pos = gt_flat == 1
    gt_neg = gt_flat == 0

    true_positives = (pred_pos & gt_pos).sum().item()
    false_positives = (pred_pos & gt_neg).sum().item()
    false_negatives = (pred_neg & gt_pos).sum().item()
    true_negatives = (pred_neg & gt_neg).sum().item()

    return true_positives, false_positives, false_negatives, true_negatives

# Final metrics calculation function
def calculate_final_metrics(tp, fp, fn, tn, smooth=1e-6):
    """
    Turns accumulated confusion-matrix counts into segmentation metrics.

    ``smooth`` is added to every denominator, so a metric whose denominator is
    zero evaluates to 0.0 instead of raising ZeroDivisionError.

    Args:
        tp, fp, fn, tn: True-positive, false-positive, false-negative and
            true-negative counts.
        smooth (float): Small constant added to each denominator.

    Returns:
        tuple: (iou, dice, recall, precision, accuracy).
    """
    def safe_ratio(numerator, denominator):
        return numerator / (denominator + smooth)

    return (
        safe_ratio(tp, tp + fp + fn),            # IoU (Jaccard index)
        safe_ratio(2 * tp, 2 * tp + fp + fn),    # Dice coefficient
        safe_ratio(tp, tp + fn),                 # Recall (sensitivity)
        safe_ratio(tp, tp + fp),                 # Precision (positive predictive value)
        safe_ratio(tp + tn, tp + tn + fp + fn),  # Pixel accuracy
    )

## DeepLabv3+ with MobileNetV2

In [None]:
class SeparableConv2d(nn.Module):
    """
    Depthwise separable convolution: a depthwise convolution followed by a
    pointwise (1x1) convolution, each followed by BatchNorm and ReLU.

    With stride 1 and an odd ``kernel_size`` the spatial size is preserved for
    any ``dilation``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Side length of the square depthwise kernel (odd).
        stride (int): Stride of the depthwise convolution.
        dilation (int): Dilation (atrous rate) of the depthwise convolution.
        bias (bool): Whether both convolutions use a bias term.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()

        # "Same" padding for an odd kernel: p = d * (k - 1) / 2.
        # This equals `dilation` for the default kernel_size=3.
        padding = dilation * (kernel_size - 1) // 2

        # Depthwise convolution: a separate filter per input channel (groups=in_channels)
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=bias)
        self.bn_depth = nn.BatchNorm2d(in_channels)
        self.relu_depth = nn.ReLU(inplace=True)

        # Pointwise convolution: a 1x1 convolution that mixes the channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
        self.bn_point = nn.BatchNorm2d(out_channels)
        self.relu_point = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.bn_depth(x)
        x = self.relu_depth(x)

        x = self.pointwise(x)
        x = self.bn_point(x)
        x = self.relu_point(x)

        return x

In [None]:
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (ASPP) head used by DeepLabv3+.

    Five parallel branches process the same input:
      * a 1x1 conv + BN + ReLU,
      * three atrous depthwise-separable 3x3 convs with rates 6, 12 and 18,
      * an image-level branch (global average pool -> 1x1 conv + BN + ReLU)
        whose 1x1 output is bilinearly upsampled to the input's spatial size.
    The branch outputs are concatenated (5 * out_channels) and fused by a
    1x1 conv + BN + ReLU + Dropout(p=0.5).

    Args:
        in_channels (int): Channels of the input feature map.
        out_channels (int): Channels produced by each branch and by the module.

    NOTE: the image-level branch applies BatchNorm to a 1x1 map, so in training
    mode it needs a batch size greater than 1.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv1x1_bn_relu(cin, cout):
            # 1x1 conv (with bias) -> BatchNorm -> ReLU, returned as a layer list
            return [nn.Conv2d(cin, cout, kernel_size=1), nn.BatchNorm2d(cout), nn.ReLU(inplace=True)]

        # Branch 1: plain 1x1 convolution
        self.conv1x1 = nn.Sequential(*conv1x1_bn_relu(in_channels, out_channels))

        # Branches 2-4: atrous separable convolutions. The attribute names
        # (atrous_block6/12/18) form the state_dict keys, so they must stay fixed.
        for rate in (6, 12, 18):
            setattr(self, f"atrous_block{rate}",
                    SeparableConv2d(in_channels, out_channels, kernel_size=3, dilation=rate))

        # Branch 5: image-level context via global average pooling
        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), *conv1x1_bn_relu(in_channels, out_channels))

        # Fuse the five concatenated branches back down to out_channels
        self.final_conv = nn.Sequential(*conv1x1_bn_relu(out_channels * 5, out_channels), nn.Dropout(p=0.5))

    def forward(self, x):
        features = [
            self.conv1x1(x),
            self.atrous_block6(x),
            self.atrous_block12(x),
            self.atrous_block18(x),
        ]
        # Broadcast the pooled 1x1 context back to the input's H x W
        image_context = self.global_avg_pool(x)
        features.append(F.interpolate(image_context, size=x.shape[2:], mode='bilinear', align_corners=False))

        return self.final_conv(torch.cat(features, dim=1))

In [None]:
class DeepLabV3Plus(nn.Module):
    """
    DeepLabv3+ segmentation network with a MobileNetV2 encoder.

    The torchvision MobileNetV2 feature extractor (``weights="DEFAULT"``) is
    modified for output stride 16. Its final 1280-channel map feeds an ASPP
    head; the 24-channel output of backbone block 3 provides low-level detail
    to the decoder. Logits are bilinearly upsampled to the input resolution.

    Args:
        num_classes (int): Number of output channels (1 for binary segmentation,
            where the caller applies a sigmoid).
    """

    def __init__(self, num_classes=1):
        super().__init__()
        mobilenet = models.mobilenet_v2(weights="DEFAULT")
        self.backbone = mobilenet.features  # feature layers only; classifier dropped

        # Output stride 16: make block 14's depthwise conv (conv[1][0]) stride 1...
        self.backbone[14].conv[1][0].stride = (1, 1)

        # ...and dilate every 3x3 conv in blocks 14-18 by 2 to keep the receptive
        # field. For 3x3 kernels, padding == dilation preserves the spatial size.
        for block_idx in range(14, 19):
            three_by_three = [
                module for module in self.backbone[block_idx].modules()
                if isinstance(module, nn.Conv2d) and module.kernel_size == (3, 3)
            ]
            for conv in three_by_three:
                conv.dilation = (2, 2)
                conv.padding = (2, 2)

        # Skip connection for the decoder: output of backbone block 3 (24 channels)
        self.low_level_idx = 3
        self.low_level_channels = 24

        # ASPP on the 1280-channel output of the last backbone layer
        self.aspp = ASPP(in_channels=1280, out_channels=256)

        # Project low-level features to 48 channels before concatenation
        self.low_level_project = nn.Sequential(
            nn.Conv2d(self.low_level_channels, 48, kernel_size=1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )

        # Two 3x3 refinement convs on the fused (256 + 48)-channel map, then a 1x1 classifier
        self.decoder = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1),
        )

    def forward(self, x):
        """Return per-pixel logits at the input's spatial size (assumes x is (B, 3, H, W))."""
        out_size = x.shape[2:]

        # Run the backbone once, capturing the low-level features along the way
        skip = None
        deep = x
        for idx, stage in enumerate(self.backbone):
            deep = stage(deep)
            if idx == self.low_level_idx:
                skip = deep

        # `deep` is the final backbone map (expected ~H/16 x W/16 after the
        # output-stride-16 change). Upsample the ASPP output to the skip's size.
        context = self.aspp(deep)
        context = F.interpolate(context, size=skip.shape[2:], mode='bilinear', align_corners=False)

        fused = torch.cat([context, self.low_level_project(skip)], dim=1)
        logits = self.decoder(fused)
        return F.interpolate(logits, size=out_size, mode='bilinear', align_corners=False)

In [None]:
# Prefer the GPU when available; the bare name on the last line displays the choice.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Build the single-output (binary) network; Module.to moves it in place and returns it.
model = DeepLabV3Plus(num_classes=1).to(device)

## Load the trained model

**Note:** Replace the checkpoint path below with the path to your own trained model checkpoint file.

In [None]:
# Define the path to your saved model file (expected to contain a state_dict)
model_path = "/content/drive/MyDrive/FYP/Model_Training/MobileNet/checkpoints/Run_20260117-0531372/best_DeepLabv3PlusModelwithMobileNetV2.pth"

# Load the model state dictionary onto the selected device.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code on load. It is already the
# default from PyTorch 2.6; passing it explicitly protects older versions too.
state_dict = torch.load(model_path, map_location=device, weights_only=True)

In [None]:
# Strict load: every checkpoint key must match a parameter/buffer name of `model`.
model.load_state_dict(state_dict, strict=True)

## Model Evaluation

In [None]:
def denormalize_image(image_tensor, mean, std):
    """
    Undo per-channel normalisation and return a displayable uint8 HWC image.

    Computes ``pixel * std + mean`` per channel, clips to [0, 1], scales by 255
    and truncates to uint8.

    Args:
        image_tensor (torch.Tensor): Normalised image of shape (C, H, W), C == 3.
        mean (list or tuple): Per-channel means used for normalisation.
        std (list or tuple): Per-channel standard deviations used for normalisation.

    Returns:
        np.ndarray: Array of shape (H, W, C) and dtype uint8.
    """
    channel_mean = np.asarray(mean).reshape(1, 1, 3)
    channel_std = np.asarray(std).reshape(1, 1, 3)

    hwc = np.transpose(image_tensor.cpu().numpy(), (1, 2, 0))
    restored = hwc * channel_std + channel_mean

    return (np.clip(restored, 0, 1) * 255).astype(np.uint8)

In [None]:
# Evaluate the checkpoint on every test set. Confusion-matrix counts are summed
# over all pixels of a set, so the reported metrics are dataset-level
# (micro-averaged) rather than means of per-image scores. One sample per set is
# then visualised.
EVAL_BATCH_SIZE = 32    # lower this if the GPU runs out of memory
EVAL_THRESHOLD = 0.5    # probability above which a pixel is predicted foreground
VIS_SEED = 42           # fixes which samples are plotted, so re-runs match

# Set the model to evaluation mode (BatchNorm uses running stats, Dropout is off)
model.eval()

# One seeded generator for the whole cell: re-running it plots the same samples
vis_generator = torch.Generator().manual_seed(VIS_SEED)

# Evaluate on each test set
for test_set_name, test_set in test_sets.items():
    # Skip empty sets up front: there is nothing to score or plot, and a
    # shuffled DataLoader over an empty dataset raises ValueError.
    if len(test_set) == 0:
        print(f"Test set {test_set_name} is empty.")
        continue

    test_loader = DataLoader(test_set, batch_size=EVAL_BATCH_SIZE, shuffle=False)

    total_tp = total_fp = total_fn = total_tn = 0

    with torch.no_grad():
        for X_test_batch, y_test_batch in test_loader:
            # Cast before the transfer so float64 batches are never copied to the device
            X_test_batch = X_test_batch.float().to(device)
            y_test_batch = y_test_batch.float().to(device)

            predictions_logits = model(X_test_batch)
            predictions_probs = torch.sigmoid(predictions_logits)

            tp, fp, fn, tn = get_confusion_matrix_components(y_test_batch, predictions_probs, threshold=EVAL_THRESHOLD)

            total_tp += tp
            total_fp += fp
            total_fn += fn
            total_tn += tn

    # Calculate final metrics from the dataset-wide counts
    avg_iou, avg_dice, avg_recall, avg_precision, avg_accuracy = calculate_final_metrics(total_tp, total_fp, total_fn, total_tn)

    print(f"\nMetrics for {test_set_name}:")
    print(f"  IoU: {avg_iou:.4f}")
    print(f"  Dice: {avg_dice:.4f}")
    print(f"  Recall: {avg_recall:.4f}")
    print(f"  Precision: {avg_precision:.4f}")
    print(f"  Accuracy: {avg_accuracy:.4f}")

    # --- Visualization (one reproducibly chosen sample, shown as overlays) ---
    vis_idx = int(torch.randint(len(test_set), (1,), generator=vis_generator))
    sample_image, sample_label = test_set[vis_idx]
    # Add the batch dimension the model expects: (C, H, W) -> (1, C, H, W)
    vis_image = sample_image.unsqueeze(0).float().to(device)
    vis_label = sample_label.unsqueeze(0).float().to(device)

    with torch.no_grad():
        test_pred_logits = model(vis_image)
        test_pred_mask = (torch.sigmoid(test_pred_logits) > EVAL_THRESHOLD).float()

    # Create the 1x3 display
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f"Sample Prediction - {test_set_name}", fontsize=14)

    # 1. Original image
    original_img = denormalize_image(vis_image[0], IMAGENET_MEAN, IMAGENET_STD)
    axes[0].imshow(original_img)
    axes[0].set_title("Original Image")
    axes[0].axis("off")

    # 2. Ground Truth Overlay (Reds)
    gt_mask = vis_label[0].squeeze().cpu().numpy()
    axes[1].imshow(original_img)
    axes[1].imshow(gt_mask, cmap="Reds", alpha=0.5)  # Overlay effect
    axes[1].set_title("Ground Truth Mask")
    axes[1].axis("off")

    # 3. Predicted Mask Overlay (Blues)
    pred_mask = test_pred_mask[0].squeeze().cpu().numpy()
    axes[2].imshow(original_img)
    axes[2].imshow(pred_mask, cmap="Blues", alpha=0.5)  # Overlay effect
    axes[2].set_title("Predicted Mask")
    axes[2].axis("off")

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure; open figures accumulate inside the loop