In [None]:
# Mount Google Drive at /content/drive; the dataset and checkpoint paths used
# further down in this notebook all live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the MobileSAM package straight from its GitHub repository, then fetch
# the pretrained checkpoint into ./weights.
# NOTE(review): the install is not pinned to a commit or tag, so a re-run may
# pick up a different MobileSAM revision.
!pip install git+https://github.com/ChaoningZhang/MobileSAM.git
!mkdir -p weights
# -nc (no-clobber) skips the download when the file is already present.
!wget -nc https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt -P ./weights/

In [None]:
# PEFT provides the LoRA classes (LoraConfig, get_peft_model, PeftModel) imported below.
# NOTE(review): no version is pinned here.
!pip install peft

In [None]:
import os
import cv2
import ast
import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


from peft import LoraConfig, get_peft_model, PeftModel
from mobile_sam import sam_model_registry

# Tissue Classification
from typing import Any
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import silhouette_score
from scipy.spatial.distance import cdist

## Custom Dataset classes

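Two custom dataset classes are defined below. Use the one that suits your data, and run only that class's instantiation cell: both cells assign their result to the same variable `dataset`, so whichever one runs last is the dataset used for inference.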

In [None]:
# Per-channel (R, G, B) mean and standard deviation used as the default
# normalization statistics by both dataset classes. They are applied to pixel
# values that have already been scaled to [0, 1], and are reused at inference
# time to undo the normalization.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

### SegmentationDataset class

This class is for datasets that include ground-truth segmentation masks. You only need to pass the paths of the folders containing the images and the ground-truth masks; the bounding box(es) are extracted automatically from each mask. Images and masks are paired by sorted file name, so the two folders must contain matching files.

In [None]:
class SegmentationDataset(Dataset):
    """Wound images paired with ground-truth masks; box prompts are derived from the masks.

    Images and masks are paired by position after sorting the file names of
    ``image_dir`` and ``mask_dir``, so both folders must contain matching files
    that sort into the same order.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the ground-truth masks (read as grayscale).
        img_size: Target size as ``(height, width)``.
        mean: Per-channel (R, G, B) mean applied after scaling pixels to [0, 1].
        std: Per-channel (R, G, B) standard deviation.

    Each item is a dict with:
        ``image``      float32 tensor ``[3, H, W]`` (normalized RGB),
        ``mask``       float32 tensor ``[1, H, W]`` with values 0/1,
        ``bboxes``     float32 tensor ``[N, 4]`` of ``(x_min, y_min, x_max, y_max)``
                       in resized-image coordinates (``N`` may be 0),
        ``image_name`` the image file name.
    """

    def __init__(self, image_dir, mask_dir, img_size=(1024, 1024), mean=IMAGENET_MEAN, std=IMAGENET_STD):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        # Stored as a tuple so that it compares equal to ``array.shape[:2]``
        # even when the caller passes a list.
        self.img_size = tuple(img_size)
        # float32 statistics keep the normalized image in float32; float64
        # arrays would silently promote every image tensor to double precision.
        self.mean = np.asarray(mean, dtype=np.float32)
        self.std = np.asarray(std, dtype=np.float32)

    def __len__(self):
        return len(self.images)

    def get_bounding_boxes_from_mask(self, mask, padding_factor=0.1, min_area_threshold=5):
        """Return one padded box per connected foreground component of ``mask``.

        Args:
            mask: Single-channel uint8 mask; every non-zero pixel is foreground.
            padding_factor: Padding added on each side, as a fraction of the
                component's longer side.
            min_area_threshold: Components with fewer pixels than this are ignored.

        Returns:
            List of ``[x_min, y_min, x_max, y_max]`` boxes in the mask's own
            pixel coordinates, clipped to the mask bounds.
        """
        # Threshold at 0 so "foreground" means the same thing here as in the
        # returned mask tensor (mask > 0); thresholding at 1 would drop 0/1 masks.
        _, binary_mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY)
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask, 8, cv2.CV_32S)

        height, width = mask.shape[:2]
        boxes = []
        # Label 0 is the background component, so start at 1.
        for i in range(1, num_labels):
            x = int(stats[i, cv2.CC_STAT_LEFT])
            y = int(stats[i, cv2.CC_STAT_TOP])
            w = int(stats[i, cv2.CC_STAT_WIDTH])
            h = int(stats[i, cv2.CC_STAT_HEIGHT])
            area = int(stats[i, cv2.CC_STAT_AREA])
            if area < min_area_threshold:
                continue

            pad = int(max(w, h) * padding_factor)
            x_min, y_min = max(0, x - pad), max(0, y - pad)
            x_max, y_max = min(width, x + w + pad), min(height, y + h + pad)
            boxes.append([x_min, y_min, x_max, y_max])

        return boxes

    def __getitem__(self, idx):
        img_name = self.images[idx]
        mask_name = self.masks[idx]

        image_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, mask_name)

        # cv2.imread returns None instead of raising when a file is missing or
        # cannot be decoded, so fail here with a message that names the file.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image (missing or unsupported file): {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask (missing or unsupported file): {mask_path}")

        # Boxes are extracted at the mask's native resolution and rescaled below.
        original_H, original_W = mask.shape
        boxes = self.get_bounding_boxes_from_mask(mask)

        # img_size is (height, width) but cv2.resize expects dsize=(width, height).
        target_H, target_W = self.img_size
        if image.shape[:2] != self.img_size:
            image = cv2.resize(image, (target_W, target_H))
        if mask.shape[:2] != self.img_size:
            # Nearest-neighbour keeps the mask free of interpolated values.
            mask = cv2.resize(mask, (target_W, target_H), interpolation=cv2.INTER_NEAREST)

        scale_x = target_W / original_W
        scale_y = target_H / original_H
        boxes_rescaled = [
            [x_min * scale_x, y_min * scale_y, x_max * scale_x, y_max * scale_y]
            for x_min, y_min, x_max, y_max in boxes
        ]

        # Scale to [0, 1], normalize per channel, then HWC -> CHW.
        image = image.astype(np.float32) / 255.0
        image = (image - self.mean) / self.std
        image_tensor = torch.from_numpy(image).permute(2, 0, 1)

        mask_tensor = torch.from_numpy((mask > 0).astype(np.float32)).unsqueeze(0)

        return {
            "image": image_tensor,
            "mask": mask_tensor,
            # reshape keeps the [N, 4] layout even when no box was found (N == 0).
            "bboxes": torch.tensor(boxes_rescaled, dtype=torch.float32).reshape(-1, 4),
            "image_name": img_name,
        }

**Note:** Please modify the paths below to your own paths that contain the wound images and their ground truth masks

In [None]:
# Folders holding the wound images and their ground-truth masks
# (inputs to SegmentationDataset).
dataset_images_path = "/content/drive/MyDrive/FYP/Datasets/test_inference/images/"
dataset_masks_path = "/content/drive/MyDrive/FYP/Datasets/test_inference/masks/"

In [None]:
# Build the dataset whose box prompts come from ground-truth masks.
# NOTE: the RealWorldInferenceDataset cell further down assigns the same name
# `dataset`; whichever of the two cells runs last is the one used for inference.
dataset = SegmentationDataset(dataset_images_path, dataset_masks_path)

### RealWorldInferenceDataset class
This class is for real-world datasets that have no ground-truth segmentation masks from which bounding boxes could be extracted. Instead, the bounding box(es) of each wound image must be annotated manually in Pascal VOC format (x1, y1, x2, y2) and recorded in a CSV file.

The CSV file must contain two columns: **image_name** and **bbox**. If your column names are different, modify the code below to match them. Make sure each image_name matches a file name in the image folder and that the bounding boxes are given in pixel coordinates of the original (un-resized) image, in the format [[x1, y1, x2, y2]]. For two bounding boxes, the format is [[x1, y1, x2, y2], [x1, y1, x2, y2]].

In [None]:
class RealWorldInferenceDataset(Dataset):
    """Images without ground-truth masks; box prompts are read from a CSV file.

    The CSV must have the columns ``image_name`` and ``bbox``. ``bbox`` holds a
    list of ``[x1, y1, x2, y2]`` boxes in the pixel coordinates of the original
    (un-resized) image, e.g. ``[[10, 10, 100, 100]]``. A single flat box
    ``[x1, y1, x2, y2]`` is also accepted, and an empty/missing cell means
    "no boxes".

    Args:
        image_dir: Directory containing the images.
        csv_path: Path to the annotation CSV.
        img_size: Target size as ``(height, width)``.
        mean: Per-channel (R, G, B) mean applied after scaling pixels to [0, 1].
        std: Per-channel (R, G, B) standard deviation.

    Each item is a dict with:
        ``image``      float32 tensor ``[3, H, W]`` (normalized RGB),
        ``bboxes``     float32 tensor ``[N, 4]`` in resized-image coordinates,
        ``image_name`` the image file name.
    """

    def __init__(self, image_dir, csv_path, img_size=(1024, 1024), mean=IMAGENET_MEAN, std=IMAGENET_STD):
        self.image_dir = image_dir
        # Stored as a tuple so that it compares equal to ``array.shape[:2]``.
        self.img_size = tuple(img_size)
        # float32 statistics keep the normalized image in float32.
        self.mean = np.asarray(mean, dtype=np.float32)
        self.std = np.asarray(std, dtype=np.float32)

        df = pd.read_csv(csv_path)

        # image_name -> raw bbox cell, e.g. {'image_01.jpg': "[[10, 10, 100, 100]]", ...}
        self.bbox_map = dict(zip(df['image_name'], df['bbox']))

        # Keep only images that are listed in the CSV AND present in the folder.
        available_files = set(os.listdir(image_dir))
        self.images = [img for img in df['image_name'] if img in available_files]

        print(f"Found {len(self.images)} images with matching bounding boxes in CSV.")

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _parse_boxes(boxes_raw):
        """Turn a raw CSV ``bbox`` cell into a list of ``[x1, y1, x2, y2]`` boxes."""
        if isinstance(boxes_raw, str):
            # Cells are stored as text; literal_eval parses them without executing code.
            boxes_raw = ast.literal_eval(boxes_raw)
        elif boxes_raw is None or (isinstance(boxes_raw, float) and np.isnan(boxes_raw)):
            # pandas represents an empty cell as NaN.
            return []

        boxes = list(boxes_raw)
        # Tolerate a single un-nested box: [x1, y1, x2, y2] -> [[x1, y1, x2, y2]].
        if len(boxes) == 4 and not isinstance(boxes[0], (list, tuple)):
            boxes = [boxes]
        return boxes

    def __getitem__(self, idx):
        img_name = self.images[idx]
        image_path = os.path.join(self.image_dir, img_name)

        # cv2.imread returns None instead of raising when a file cannot be read.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image (missing or unsupported file): {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # The annotated boxes refer to this original resolution.
        original_H, original_W = image.shape[:2]

        boxes_raw = self._parse_boxes(self.bbox_map[img_name])

        # img_size is (height, width) but cv2.resize expects dsize=(width, height).
        target_H, target_W = self.img_size
        if image.shape[:2] != self.img_size:
            image = cv2.resize(image, (target_W, target_H))

        # The boxes must follow the image into the resized coordinate frame.
        scale_x = target_W / original_W
        scale_y = target_H / original_H
        boxes_rescaled = [
            [x_min * scale_x, y_min * scale_y, x_max * scale_x, y_max * scale_y]
            for x_min, y_min, x_max, y_max in boxes_raw
        ]

        # Scale to [0, 1], normalize per channel, then HWC -> CHW.
        image = image.astype(np.float32) / 255.0
        image = (image - self.mean) / self.std
        image_tensor = torch.from_numpy(image).permute(2, 0, 1)

        # No 'mask' key: this dataset is for pure inference.
        return {
            "image": image_tensor,
            # reshape keeps the [N, 4] layout even when there are no boxes (N == 0).
            "bboxes": torch.tensor(boxes_rescaled, dtype=torch.float32).reshape(-1, 4),
            "image_name": img_name,
        }

**Note:** Please modify the paths below to your own paths for the real-world dataset folder and its CSV file.

In [None]:
# Folder with the real-world images and the CSV holding their manually
# annotated bounding boxes (inputs to RealWorldInferenceDataset).
real_world_dataset_path = "/content/drive/MyDrive/FYP/Datasets/test_inference/real_world/"
csv_file_path = "/content/drive/MyDrive/FYP/Datasets/test_inference/real_world/real_world_dataset.csv"

In [None]:
# Build the dataset whose box prompts come from the CSV file.
# NOTE: this reuses the name `dataset` that the SegmentationDataset cell also
# assigns; whichever of the two cells runs last is the one used for inference.
dataset = RealWorldInferenceDataset(real_world_dataset_path, csv_file_path)

## Model Initialization

### MobileSAM

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the MobileSAM model
# Build MobileSAM from the registry entry "vit_t" and load the checkpoint that
# the setup cell downloaded into ./weights.
model_type = "vit_t"
checkpoint_path = "./weights/mobile_sam.pt"

mobile_sam = sam_model_registry[model_type](checkpoint=checkpoint_path)

In [None]:
# Load trained mask decoder
# Replace the mask decoder weights with the fine-tuned ones saved during training.
decoder_path = "/content/drive/MyDrive/FYP/MobileSAM_Finetuning/checkpoints/Run_20251214-093848/mask_decoder.pth"

# load_state_dict returns an object exposing `missing_keys` and `unexpected_keys`.
# NOTE(review): load_state_dict is called with its default `strict` setting; with
# strict loading a key mismatch presumably raises before this returns, in which
# case the warning branch below would never run - confirm whether strict=False
# was intended.
msg = mobile_sam.mask_decoder.load_state_dict(torch.load(decoder_path, map_location=device))

# Report whether every key in the checkpoint matched the decoder.
if len(msg.missing_keys) == 0 and len(msg.unexpected_keys) == 0:
    print("Mask Decoder: All weights loaded successfully with no mismatches.")
else:
    print("Mask Decoder Load Warning:")
    print(f"  Missing Keys: {msg.missing_keys}")
    print(f"  Unexpected Keys: {msg.unexpected_keys}")

In [None]:
# Load LoRA into the image encoder
# Wrap the image encoder with the LoRA adapter saved during fine-tuning.
# The encoder attribute is replaced in place by the PEFT-wrapped module.
lora_path = "/content/drive/MyDrive/FYP/MobileSAM_Finetuning/checkpoints/Run_20251214-093848/lora_image_encoder"
mobile_sam.image_encoder = PeftModel.from_pretrained(mobile_sam.image_encoder, lora_path)

# Verification Steps:
# 1. Check if an adapter is active
active_adapters = mobile_sam.image_encoder.active_adapters
print(f"Active LoRA Adapters: {active_adapters}")

# 2. Confirm that modules whose name contains "lora_" were injected into the encoder
has_lora = any("lora_" in name for name, _ in mobile_sam.image_encoder.named_modules())
if has_lora:
    print("LoRA layers detected in the Image Encoder.")
else:
    print("ERROR: No LoRA layers found. The adapter was not applied correctly.")

# 3. Print trainable parameters (should be 0 for inference, but confirms structure)
mobile_sam.image_encoder.print_trainable_parameters()

In [None]:
# Move the assembled model (LoRA-wrapped encoder + fine-tuned decoder) to the selected device.
# As the last expression of the cell, this also displays the model structure.
mobile_sam.to(device)

### Finetuner Wrapper Class

In [None]:
class MobileSAMFineTuner(nn.Module):
    """Box-prompted forward pass around a SAM-style model.

    The image encoder runs once for the whole batch; the prompt encoder and the
    mask decoder then run once per image with all of that image's boxes.
    """

    def __init__(self, sam_model):
        super().__init__()
        self.sam = sam_model

    def forward(self, images: torch.Tensor, bboxes: list):
        """Predict one mask per box prompt.

        Args:
            images: Batch of images, ``[B, 3, H, W]``.
            bboxes: One tensor per image; ``bboxes[i]`` is ``[N_i, 4]``.

        Returns:
            ``(masks, iou_predictions)`` as two lists with one entry per image.
            ``masks[i]`` holds mask logits upsampled to ``[N_i, 1, H, W]`` and
            ``iou_predictions[i]`` is whatever the mask decoder returned for
            that image. Images without boxes get empty ``[0, 1, H, W]`` and
            ``[0, 1]`` tensors so that list positions stay aligned with the batch.
        """
        _batch, _channels, height, width = images.shape

        # Shared across every image in the batch.
        embeddings = self.sam.image_encoder(images)
        positional_encoding = self.sam.prompt_encoder.get_dense_pe()

        masks_per_image = []
        ious_per_image = []

        for index, boxes in enumerate(bboxes):
            if boxes.shape[0] != 0:
                sparse, dense = self.sam.prompt_encoder(points=None, boxes=boxes, masks=None)

                # One low-resolution mask per box (multimask output disabled).
                low_res, iou = self.sam.mask_decoder(
                    image_embeddings=embeddings[index][None],
                    image_pe=positional_encoding,
                    sparse_prompt_embeddings=sparse,
                    dense_prompt_embeddings=dense,
                    multimask_output=False,
                )

                # Bring the masks back to the input resolution, image by image.
                full_res = F.interpolate(low_res, size=(height, width), mode="bilinear", align_corners=False)
            else:
                # Nothing to prompt with: emit empty placeholders.
                full_res = torch.zeros(0, 1, height, width, device=images.device)
                iou = torch.zeros(0, 1, device=images.device)

            masks_per_image.append(full_res)
            ious_per_image.append(iou)

        # Lists (not stacked tensors) because the number of boxes differs per image.
        return masks_per_image, ious_per_image

In [None]:
# Wrap the loaded MobileSAM (LoRA encoder + fine-tuned decoder) in the box-prompt forward wrapper.
finetuner = MobileSAMFineTuner(sam_model=mobile_sam)

In [None]:
# Place the wrapper on the selected device and switch it to evaluation mode for inference.
finetuner.to(device)
finetuner.eval()

## Tissue Classification Setup

In [None]:
class KMeansClusterer:
    """Colour-based tissue classification of a segmented wound region.

    Wound pixels are clustered with K-Means in Lab colour space (with the L
    channel down-weighted to reduce the influence of shadows), each cluster is
    matched to the nearest reference tissue colour, and the result is blended
    over the input image.

    Args:
        max_clusters: Largest K that is tried (K=2 .. max_clusters).
        sample_size: Maximum number of wound pixels used to fit K-Means.
        l_weight: Factor applied to the L channel before clustering.
        random_state: Seed for pixel sampling, K-Means and silhouette
            subsampling. Pass ``None`` for non-deterministic behaviour.
    """

    # Minimum summed variance of the A and B channels before K >= 2 is tried.
    MIN_CHROMATIC_VARIANCE = 95.0
    # Minimum silhouette score for a K >= 2 model to be accepted.
    MIN_SILHOUETTE_SCORE = 0.25
    # Average-lightness threshold used to resolve a K=2 mapping conflict.
    LIGHTNESS_SPLIT = 60.0

    def __init__(self, max_clusters: int = 3, sample_size: int = 50000, l_weight: float = 0.3,
                 random_state: int = 42):
        self.max_clusters = max_clusters
        self.sample_size = sample_size
        self.l_weight = l_weight
        self.random_state = random_state

        # Reference colours as (L, A, B); row order matches REF_NAMES.
        # NOTE(review): values look like OpenCV's 8-bit Lab scale (0-255 with
        # A/B centred on 128) - confirm against how they were measured.
        self.REF_LAB = np.array([
            [20, 128, 128],    # Necrotic (Idx 0)
            [130, 170, 130],   # Granulation (Idx 1)
            [200, 128, 170]    # Slough (Idx 2)
        ])

        self.REF_NAMES = ['Necrotic', 'Granulation', 'Slough']

        # Label id written into the tissue mask and overlay colour in BGR order.
        self.TISSUE_MAP = {
          'Slough': {'id': 1, 'color': [0, 255, 255]},      # Yellow (Green + Red)
          'Granulation': {'id': 2, 'color': [0, 0, 255]},   # Red
          'Necrotic': {'id': 3, 'color': [0, 0, 0]},       # Black
        }

    def _preprocess_features(self, lab_pixels):
        """Return a float32 copy of ``lab_pixels`` (N, 3) with the L column scaled by ``l_weight``."""
        pixels = lab_pixels.astype(np.float32)
        # l_weight is the shadow-suppression factor: how much the L channel counts.
        pixels[:, 0] *= self.l_weight
        return pixels

    def _select_model(self, training_data):
        """Fit K-Means for K=2..max_clusters and return the best accepted model.

        Falls back to a single-cluster model when the chromatic variance is too
        low or no K reaches the minimum silhouette score.
        """
        best_kmeans = None
        best_score = -1.0
        best_k = 1

        # Variance of the A and B channels only (column 0 is the weighted L).
        chromatic_var = np.var(training_data[:, 1:], axis=0).sum()
        print(f"\n[STEP 2] Variance Check")
        print(f"  > Chromatic Variance (A+B): {chromatic_var:.2f}")

        if chromatic_var > self.MIN_CHROMATIC_VARIANCE:
            print(f"  > Variance is high enough. Testing K=2 to K={self.max_clusters}...")

            for k in range(2, self.max_clusters + 1):
                # K-Means cannot be fitted with more clusters than samples.
                if k > len(training_data):
                    break

                kmeans = MiniBatchKMeans(n_clusters=k, batch_size=256, random_state=self.random_state, n_init=3)
                labels = kmeans.fit_predict(training_data)

                try:
                    score = silhouette_score(
                        training_data, labels, sample_size=1000, random_state=self.random_state
                    )
                except ValueError:
                    # Raised when the score is undefined (e.g. only one label present).
                    score = 0

                print(f"    [TEST] K={k} -> Silhouette Score: {score:.4f}")

                if score > self.MIN_SILHOUETTE_SCORE:
                    if score > best_score:
                        print(f"       -> ACCEPTED. (Reason: {score:.4f} > {best_score:.4f})")
                        best_score = score
                        best_kmeans = kmeans
                        best_k = k
                    else:
                        print(f"       -> REJECTED. (Reason: {score:.4f} < {best_score:.4f})")
                else:
                    print(f"       -> REJECTED. (Score too low)")

        if best_kmeans is None:
            print("\n[RESULT] Variance/Score too low. Fallback to K=1.")
            best_kmeans = MiniBatchKMeans(n_clusters=1, random_state=self.random_state).fit(training_data)
        else:
            print(f"\n[RESULT] Selected Best Model: K={best_k}")

        return best_kmeans

    def cluster(self, wound_image: np.ndarray, mask: np.ndarray):
        """Classify the wound pixels of ``wound_image`` and return a colour overlay.

        Args:
            wound_image: Colour image in BGR channel order (the OpenCV
                convention, matching the overlay colours in ``TISSUE_MAP``).
                Presumably uint8 - confirm against the caller.
            mask: Wound mask with the same height and width; non-zero pixels
                are wound. A 3-channel mask is reduced to its first channel.

        Returns:
            The input image blended (60/40) with the tissue colours, or the
            unmodified ``wound_image`` when the mask contains no wound pixels.
        """
        print("\n" + "="*40)
        print(" [START] K-Means Clustering Process")
        print("="*40)

        # 1. Preprocessing
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        # cv2.bitwise_and below needs an 8-bit single-channel mask.
        mask = (mask > 0).astype(np.uint8)

        # Blur first to smooth out specular highlights (tiny white dots from flash).
        blurred_img = cv2.GaussianBlur(wound_image, (5, 5), 0)
        # The image is BGR, so it must be converted with the BGR (not RGB) code;
        # otherwise red and blue are swapped before matching against REF_LAB.
        lab_image = cv2.cvtColor(blurred_img, cv2.COLOR_BGR2Lab)

        wound_indices = np.where(mask > 0)          # (rows, cols) of every wound pixel
        wound_pixels = lab_image[wound_indices]     # shape (n_pixels, 3)
        n_pixels = wound_pixels.shape[0]

        print(f"[STEP 1] Data Extraction")
        print(f"  > Total Wound Pixels: {n_pixels}")

        if n_pixels == 0:
            return wound_image

        # 2. Weighting
        weighted_pixels = self._preprocess_features(wound_pixels)

        # 3. Sampling (seeded so repeated runs fit on the same subset)
        if n_pixels > self.sample_size:
            rng = np.random.default_rng(self.random_state)
            indices = rng.choice(n_pixels, self.sample_size, replace=False)
            training_data = weighted_pixels[indices]
            print(f"  > Sampling: Reduced {n_pixels} -> {self.sample_size} pixels for training.")
        else:
            training_data = weighted_pixels
            print(f"  > Sampling: Using all {n_pixels} pixels.")

        # 4. Model selection
        best_kmeans = self._select_model(training_data)

        # 5. Prediction: one cluster label per wound pixel
        all_labels = best_kmeans.predict(weighted_pixels)
        weighted_centroids = best_kmeans.cluster_centers_

        # 6. Undo the L weighting so centroids are comparable with REF_LAB
        real_centroids = weighted_centroids.copy()
        real_centroids[:, 0] /= self.l_weight

        # 7. Mapping
        print(f"\n[STEP 3] Mapping Clusters to Tissues")
        mapped_labels_flat = self._map_clusters_to_tissues(all_labels, real_centroids)

        # 8. Reconstruction: write the tissue ids back at the wound coordinates
        clustered_mask = np.zeros_like(mask, dtype=np.uint8)
        clustered_mask[wound_indices] = mapped_labels_flat

        clustered_mask = cv2.medianBlur(clustered_mask, 5)                              # remove speckle noise
        clustered_mask = cv2.bitwise_and(clustered_mask, clustered_mask, mask=mask)     # keep labels inside the wound

        # Paint each tissue id with its colour
        overlay = wound_image.copy()
        for _, props in self.TISSUE_MAP.items():
            overlay[clustered_mask == props['id']] = props['color']

        # 60% original photo, 40% tissue colours
        blended = cv2.addWeighted(wound_image, 0.6, overlay, 0.4, 0)
        print("="*40 + "\n")
        return blended

    def _map_clusters_to_tissues(self, labels, centroids):
        """Translate cluster labels into tissue ids from ``TISSUE_MAP``.

        Args:
            labels: Integer cluster label per pixel, values in ``[0, K)``.
            centroids: ``(K, 3)`` cluster centres in un-weighted (L, A, B).

        Returns:
            uint8 array with the same shape as ``labels`` holding tissue ids.
        """
        labels = np.asarray(labels)

        for i, c in enumerate(centroids):
            print(f"  > Cluster {i} Centroid (LAB): [{c[0]:.1f}, {c[1]:.1f}, {c[2]:.1f}]")

        # Halve L during matching so colour (A and B) counts twice as much.
        match_weight = np.array([0.5, 1.0, 1.0])
        w_centroids = centroids * match_weight
        w_refs = self.REF_LAB * match_weight

        # Distance from every cluster to every reference colour; pick the nearest.
        dists = cdist(w_centroids, w_refs, metric='euclidean')
        closest_ref_indices = np.argmin(dists, axis=1)

        print(f"  > Initial Assignments (Indices): {closest_ref_indices}")

        # --- CONFLICT RESOLUTION ---
        unique_assignments = np.unique(closest_ref_indices)

        if len(unique_assignments) < len(centroids):
            print("  > [!] CONFLICT DETECTED: Multiple clusters mapped to same tissue.")

            # Case K=2
            if len(centroids) == 2 and closest_ref_indices[0] == closest_ref_indices[1]:
                dup_tissue = self.REF_NAMES[closest_ref_indices[0]]
                print(f"    > Conflict Type: Both K=2 clusters mapped to '{dup_tissue}'")

                # Order the two clusters by centroid lightness.
                if centroids[0][0] < centroids[1][0]:
                    idx_dark, idx_light = 0, 1
                else:
                    idx_dark, idx_light = 1, 0

                avg_lightness = (centroids[0][0] + centroids[1][0]) / 2.0
                print(f"    > Action: Splitting based on Average Lightness ({avg_lightness})")

                if avg_lightness < self.LIGHTNESS_SPLIT:
                    print(f"    -> Low Lightness detected: Splitting into Necrotic + Granulation")
                    closest_ref_indices[idx_dark] = 0   # Necrotic (Black)
                    closest_ref_indices[idx_light] = 1  # Granulation (Red)
                else:
                    print(f"    -> High Lightness detected. Splitting into Granulation + Slough")
                    closest_ref_indices[idx_dark] = 1   # Granulation (Red)
                    closest_ref_indices[idx_light] = 2  # Slough (Yellow)

            # Case K=3
            elif len(centroids) == 3:
                print("    > Conflict Type: K=3 Overlap. Running Greedy Assignment...")

                # All (distance, cluster, tissue) pairs, nearest first.
                flat_dists = sorted(
                    ((dists[r, c], r, c) for r in range(3) for c in range(3)),
                    key=lambda x: x[0],
                )

                assigned_clusters = set()
                assigned_tissues = set()
                new_indices = [0, 0, 0]

                # Greedily give each cluster the nearest tissue that is still free.
                for d, r, c in flat_dists:
                    if r not in assigned_clusters and c not in assigned_tissues:
                        print(f"      -> Assigning Cluster {r} to {self.REF_NAMES[c]} (Dist={d:.1f})")
                        new_indices[r] = c
                        assigned_clusters.add(r)
                        assigned_tissues.add(c)
                closest_ref_indices = np.array(new_indices)

        # Look-up table: cluster index -> tissue id
        print("  > Final Mapping:")

        # Size the table by the number of clusters, not by max(labels): a cluster
        # that received no pixels would otherwise index past the end of the table.
        lut_size = len(closest_ref_indices)
        if labels.size:
            lut_size = max(lut_size, int(labels.max()) + 1)
        lut = np.zeros(lut_size, dtype=np.uint8)

        # E.g. closest_ref_indices == [1, 2, 0] means cluster 0 is reference 1 (Granulation), etc.
        for cluster_idx, ref_idx in enumerate(closest_ref_indices):
            tissue_name = self.REF_NAMES[ref_idx]
            lut[cluster_idx] = self.TISSUE_MAP[tissue_name]['id']
            print(f"    -> Cluster {cluster_idx} ==> {tissue_name}")

        # Replace every cluster label with its tissue id.
        return lut[labels]

## Segmentation + Tissue Classification Masks Generation & Save

**Note:** Please modify the `output_dir` below to your own desired paths.

In [None]:
# Define where to save the masks
# Output folders: one for the binary segmentation masks, one for the
# tissue-classification overlays. Both are created if they do not exist yet.
output_dir_mask = "/content/drive/MyDrive/FYP/Datasets/inference/masks/"
output_dir_tissue = "/content/drive/MyDrive/FYP/Datasets/inference/tissue_classification/"
os.makedirs(output_dir_mask, exist_ok=True)
os.makedirs(output_dir_tissue, exist_ok=True)

print(f"Saving the segmentation masks to {output_dir_mask}")
print(f"Saving the tissue classification overlays to {output_dir_tissue}")

**Note:**
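- The predicted mask is padded by dilation before it is saved: a `k_size` of 20 enlarges the wound mask by roughly 10 pixels on each side, measured at the model input resolution (before the resize to 224x224). Increase `k_size` if you want more padding.
- The saved segmentation mask is resized to 224x224.
- The tissue classification overlay is saved at the model input resolution (1024x1024 by default); it is neither restored to the original image size nor resized to 224x224.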

In [None]:
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
clusterer = KMeansClusterer()

# Square structuring element used to pad (dilate) every predicted mask.
# It does not depend on the image, so it is built once before the loop.
k_size = 20
kernel = np.ones((k_size, k_size), np.uint8)

with torch.no_grad():
    # tqdm wraps the loader to show a progress bar
    for batch in tqdm(dataloader):
        image_name = batch["image_name"][0]
        images = batch["image"].to(device).float()
        bboxes = [batch["bboxes"][0].to(device).float()]

        # SAM needs a box prompt: images without any box are skipped and
        # nothing is written for them.
        if bboxes[0].numel() == 0:
            print(f"Skipping {image_name}: No bounding boxes detected.")
            continue

        # ------------------------------------
        # --- Segmentation Mask Generation ---
        # ------------------------------------
        pred_masks_list, _ = finetuner(images, bboxes)
        pred_masks = pred_masks_list[0]     # batch size is 1, so take the only entry

        # Union of the per-box masks: a pixel is wound if any box's mask
        # probability exceeds 0.5. The result is a 0/1 uint8 array of shape [H, W].
        wound_region = (torch.sigmoid(pred_masks) > 0.5).any(dim=0).squeeze()
        binary_mask = wound_region.cpu().numpy().astype(np.uint8)

        # 0/255 image -> dilate for padding -> shrink to 224x224
        mask_np = binary_mask * 255
        mask_dilated = cv2.dilate(mask_np, kernel, iterations=1)
        mask_final = cv2.resize(mask_dilated, (224, 224), interpolation=cv2.INTER_NEAREST)

        # Masks are always written as .png to avoid JPEG compression artifacts.
        filename_no_ext = os.path.splitext(image_name)[0]
        save_path = os.path.join(output_dir_mask, f"{filename_no_ext}.png")
        cv2.imwrite(save_path, mask_final)

        # ------------------------------------
        # --- Tissue Classification ---
        # ------------------------------------
        # Recover a displayable image from the normalized tensor:
        # (img * std + mean) * 255, clipped to uint8, then RGB -> BGR.
        raw_img = images[0].permute(1, 2, 0).cpu().numpy()
        raw_img = (raw_img * IMAGENET_STD + IMAGENET_MEAN) * 255
        raw_img = np.clip(raw_img, 0, 255).astype(np.uint8)
        raw_img_bgr = cv2.cvtColor(raw_img, cv2.COLOR_RGB2BGR)

        # K-Means runs on the un-dilated wound mask.
        tissue_mask = clusterer.cluster(raw_img_bgr, binary_mask)

        if tissue_mask is not None:
            save_path_overlay = os.path.join(output_dir_tissue, f"{filename_no_ext}_overlay.png")
            cv2.imwrite(save_path_overlay, tissue_mask)

print("Inference and Tissue Classification complete.")