In [4]:
#############################################
#          Preprocessing Definition         #
#############################################

import os
import sys
import csv
import cv2
import time
import torch
import random
import shutil
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
from sklearn.model_selection import train_test_split


# --------------------------
# Resizing and Padding Classes
# --------------------------

class ResizeWithPadding:
    """
    For color images. Uses bilinear interpolation.
    """
    def __init__(self, target_size=512, padding_mode='mean', force_resize=True, resize_dims=(128, 128)):
        self.target_size = target_size
        self.force_resize = force_resize
        self.resize_dims = resize_dims
        assert padding_mode in ['mean', 'reflect', 'hybrid'], \
            "Padding mode must be 'mean', 'reflect', or 'hybrid'"
        self.padding_mode = padding_mode

    def resize_image(self, image):
        h, w = image.shape[:2]
        scale = self.resize_dims[0] / max(h, w)
        new_w, new_h = int(w * scale), int(h * scale)
        resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        return resized_image

    def pad_image(self, image):
        h, w = image.shape[:2]
        # Determine target dimensions
        target_h, target_w = self.resize_dims if self.force_resize else (self.target_size, self.target_size)
        delta_w = max(0, target_w - w)
        delta_h = max(0, target_h - h)
        top = delta_h // 2
        bottom = delta_h - top
        left = delta_w // 2
        right = delta_w - left

        if self.padding_mode == 'mean':
            mean_pixel = np.mean(image, axis=(0, 1), dtype=int)
            padded_image = cv2.copyMakeBorder(image, top, bottom, left, right,
                                              cv2.BORDER_CONSTANT, value=mean_pixel.tolist())
        elif self.padding_mode == 'reflect':
            padded_image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_REFLECT)
        else:  # hybrid
            reflected = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_REFLECT)
            mean_pixel = np.mean(image, axis=(0, 1), dtype=int)
            padded_image = cv2.copyMakeBorder(reflected, top, bottom, left, right,
                                              cv2.BORDER_CONSTANT, value=mean_pixel.tolist())
        return padded_image

    def __call__(self, img):
        # Convert PIL image to numpy array, process, and convert back
        arr = np.array(img)
        resized = self.resize_image(arr)
        padded = self.pad_image(resized)
        return Image.fromarray(padded)

class ResizeWithPaddingLabel:
    """
    For labels/masks. Uses nearest-neighbor interpolation.
    Assumes labels are single-channel images.
    """
    def __init__(self, target_size=512, force_resize=False, resize_dims=(256, 256)):
        self.target_size = target_size
        self.force_resize = force_resize
        self.resize_dims = resize_dims

    def resize_image(self, image):
        h, w = image.shape[:2]
        scale = self.resize_dims[0] / max(h, w)
        new_w, new_h = int(w * scale), int(h * scale)
        resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        return resized_image

    def pad_image(self, image):
        h, w = image.shape[:2]
        target_h, target_w = self.resize_dims if self.force_resize else (self.target_size, self.target_size)
        delta_w = max(0, target_w - w)
        delta_h = max(0, target_h - h)
        top = delta_h // 2
        bottom = delta_h - top
        left = delta_w // 2
        right = delta_w - left
        # For labels, use a constant value (e.g., 0 for background)
        padded_image = cv2.copyMakeBorder(image, top, bottom, left, right,
                                          cv2.BORDER_CONSTANT, value=0)
        return padded_image

    def __call__(self, label_img):
        arr = np.array(label_img)
        resized = self.resize_image(arr)
        padded = self.pad_image(resized)
        return Image.fromarray(padded)

# --------------------------
# Augmentation Function for Synchronized Transforms
# --------------------------

def elastic_transform_pair(image, label, alpha=34, sigma=4):
    """Apply elastic transformation to both image and label."""
    np_img = np.array(image)
    np_label = np.array(label)
    random_state = np.random.RandomState(None)
    shape = np_img.shape[:2]
    dx = (random_state.rand(*shape) * 2 - 1)
    dy = (random_state.rand(*shape) * 2 - 1)
    dx = cv2.GaussianBlur(dx, (17, 17), sigma) * alpha
    dy = cv2.GaussianBlur(dy, (17, 17), sigma) * alpha
    x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
    map_x = (x + dx).astype(np.float32)
    map_y = (y + dy).astype(np.float32)
    transformed_img = cv2.remap(np_img, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
    transformed_label = cv2.remap(np_label, map_x, map_y, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REFLECT)
    return Image.fromarray(transformed_img), Image.fromarray(transformed_label)

def augment_pair(image, label, size=(256, 256), apply_color=True, **params):
    """
    Applies a random subset of augmentations in random order.
    Each augmentation is stored as a function with its own probability.

    To disable an augmentation, set its probability to 0.
    To guarantee an augmentation, set its probability to 1.
    Modify parameters (e.g., rotation_angle, translate_factor) to adjust intensity.
    """
    # Build list of augmentation functions with their associated probability.
    augmentations = []

    # Horizontal flip
    augmentations.append((
        lambda img, lbl: (F.hflip(img), F.hflip(lbl)),
        params.get("flip_prob", 0.5)  # default: 50% chance
    ))

    # Rotation (you can control intensity via "rotation_angle")
    def rotate_aug(img, lbl):
        angle = random.uniform(-params.get("rotation_angle", 5), params.get("rotation_angle", 45))
        return F.rotate(img, angle, interpolation=Image.BILINEAR), F.rotate(lbl, angle, interpolation=Image.NEAREST)
    augmentations.append((
        rotate_aug,
        params.get("rotation_prob", 0.25)  # default: always apply rotation
    ))

    # Translation (intensity via "translate_factor")
    def translate_aug(img, lbl):
        max_translate = params.get("translate_factor", 0.05) * size[0]
        tx = int(random.uniform(-max_translate, max_translate))
        ty = int(random.uniform(-max_translate, max_translate))
        return (F.affine(img, angle=0, translate=(tx, ty), scale=1.0, shear=0, interpolation=Image.BILINEAR),
                F.affine(lbl, angle=0, translate=(tx, ty), scale=1.0, shear=0, interpolation=Image.NEAREST))
    augmentations.append((
        translate_aug,
        params.get("translate_prob", 0.05)  # default: always apply translation
    ))

    # Random Resized Crop (controls framing; intensity via crop_scale_range and crop_ratio_range)
    def crop_aug(img, lbl):
        i, j, h, w = transforms.RandomResizedCrop.get_params(
            img,
            scale=params.get("crop_scale_range", (0.9, 1.0)),
            ratio=params.get("crop_ratio_range", (1.0, 1.0))
        )
        return (F.resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR),
                F.resized_crop(lbl, i, j, h, w, size, interpolation=Image.NEAREST))
    augmentations.append((
        crop_aug,
        params.get("crop_prob", 0.01)  # default: always apply crop
    ))

    # Elastic Transformation (intensity via "elastic_alpha" and "elastic_sigma")
    def elastic_aug(img, lbl):
        return elastic_transform_pair(
            img, lbl,
            alpha=params.get("elastic_alpha", 15),
            sigma=params.get("elastic_sigma", 2)
        )
    augmentations.append((
        elastic_aug,
        params.get("elastic_prob", 0.01)  # default: always apply elastic transformation
    ))

    # Random Scaling (intensity via "scaling_range")
    def scaling_aug(img, lbl):
        scale_factor = random.uniform(*params.get("scaling_range", (0.0, 0.0)))
        new_size = (int(size[0] * scale_factor), int(size[1] * scale_factor))
        return (F.resize(img, new_size, interpolation=Image.BILINEAR),
                F.resize(lbl, new_size, interpolation=Image.NEAREST))
    augmentations.append((
        scaling_aug,
        params.get("scaling_prob", 0.0)  # default: always apply scaling
    ))

    # Color augmentations: Gaussian blur and Color jitter.
    def color_aug(img, lbl):
        if random.random() < params.get("blur_prob", 0.0):
            img = img.filter(
                ImageFilter.GaussianBlur(
                    radius=random.uniform(*params.get("blur_radius_range", (0.5, 1.5)))
                )
            )
        color_jitter = transforms.ColorJitter(**params.get("color_jitter_params", {
            'brightness': 0.2, 'contrast': 0.2, 'saturation': 0.2, 'hue': 0.1
        }))
        img = color_jitter(img)
        return img, lbl
    augmentations.append((
        color_aug,
        params.get("color_prob", 0.25)  # default: always apply color adjustments if apply_color is True
    ))

    # Optionally randomize the order of augmentations.
    random.shuffle(augmentations)

    # Apply each augmentation based on its probability.
    for func, prob in augmentations:
        if random.random() < prob:
            image, label = func(image, label)

    # Optionally, enforce a final center crop to ensure the image is the desired size.
    image = F.center_crop(image, size)
    label = F.center_crop(label, size)

    return image, label

# --------------------------
# OOP Preprocessor Class
# --------------------------

class Preprocessor:
    def __init__(self, raw_color_path, raw_label_path,
                 proc_color_path, proc_label_path,
                 resize_dim=128, do_augmentation=True,
                 is_train=True, max_images=None, aug_count=10, aug_params=None):

        """
        Parameters:
          - raw_color_path: Relative path to raw color images.
          - raw_label_path: Relative path to raw label/mask images.
          - proc_color_path: Relative path to save processed color images.
          - proc_label_path: Relative path to save processed label images.
          - resize_dim: Target dimension for resizing/padding.
          - do_augmentation: Whether to augment (only for training).
          - is_train: True if processing training data.
          - max_images: Process only a subset of images (for testing the pipeline).
          - aug_count: Number of augmentations to create per image (train only).
        """
        self.aug_params = aug_params if aug_params is not None else {}

        self.raw_color_path = Path(raw_color_path)
        self.raw_label_path = Path(raw_label_path)
        self.proc_color_path = Path(proc_color_path)
        self.proc_label_path = Path(proc_label_path)
        self.resize_dim = resize_dim
        self.do_augmentation = do_augmentation and is_train
        self.is_train = is_train
        self.max_images = max_images
        self.aug_count = aug_count

        self.proc_color_path.mkdir(parents=True, exist_ok=True)
        self.proc_label_path.mkdir(parents=True, exist_ok=True)

        # Create transforms for images and labels
        self.transform_img = ResizeWithPadding(target_size=224, padding_mode='mean',
                                               force_resize=True, resize_dims=(resize_dim, resize_dim))
        self.transform_label = ResizeWithPaddingLabel(target_size=224, force_resize=True,
                                                      resize_dims=(resize_dim, resize_dim))

    def process(self):
      # Find all color images (assume image file names match label file names)
      image_extensions = (".jpg", ".jpeg", ".png")
      image_files = [f for f in self.raw_color_path.rglob("*") if f.suffix.lower() in image_extensions]
      if self.max_images is not None:
          image_files = image_files[:self.max_images]

      if not image_files:
          print("❌ No images found in", self.raw_color_path)
          return

      for img_file in image_files:
          label_file = self.raw_label_path / f"{img_file.stem}.png"
          if not label_file.exists():
              print(f"⚠️  Label for {img_file.name} not found, skipping.")
              continue

          # Determine new filename with 'cat_' or 'dog_' prefix
          original_filename = img_file.name
          new_filename = f"cat_{original_filename}" if any(char.isupper() for char in original_filename) else f"dog_{original_filename}"

          # Update label filename accordingly
          new_label_filename = new_filename.replace(img_file.suffix, ".png")

          # Open image and label (assume label is a segmentation mask in grayscale)
          img = Image.open(img_file).convert("RGB")
          label = Image.open(label_file).convert("L")

          # Apply resizing and padding to both image and label
          proc_img = self.transform_img(img)
          proc_label = self.transform_label(label)

          # Save processed (base) image and label with the new filename
          proc_img.save(self.proc_color_path / new_filename)
          proc_label.save(self.proc_label_path / new_label_filename)
          print(f"Processed and Renamed {original_filename} → {new_filename}")

          # If training and augmentation is enabled, create additional augmented pairs
          if self.is_train and self.do_augmentation:
              for i in range(self.aug_count):
                  aug_img, aug_label = augment_pair(proc_img, proc_label, size=(self.resize_dim, self.resize_dim), **self.aug_params)

                  # Augmented filenames
                  aug_img_filename = new_filename.replace(img_file.suffix, f"_aug_{i}{img_file.suffix}")
                  aug_label_filename = new_label_filename.replace(".png", f"_aug_{i}.png")

                  aug_img.save(self.proc_color_path / aug_img_filename)
                  aug_label.save(self.proc_label_path / aug_label_filename)
                  print(f"Augmented {new_filename} → {aug_img_filename}")

# --------------------------
# Command Line Interface
# --------------------------

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Preprocess train/test images with labels/masks using relative paths."
    )
    parser.add_argument("--raw_color", type=str, required=True, default='raw/TrainVal/color',
                        help="Relative path to raw color images (e.g., Dataset/raw/TrainVal/color).")
    parser.add_argument("--raw_label", type=str, required=True, default='raw/TrainVal/label',
                        help="Relative path to raw label images (e.g., Dataset/raw/TrainVal/label).")
    parser.add_argument("--proc_color", type=str, required=True, default='Dataset/processed/TrainVal/color',
                        help="Relative path to save processed color images (e.g., Dataset/processed/TrainVal/color).")
    parser.add_argument("--proc_label", type=str, required=True, default='Dataset/processed/TrainVal/label',
                        help="Relative path to save processed label images (e.g., Dataset/processed/TrainVal/label).")
    parser.add_argument("--resize_dim", type=int, default=128,
                        help="Output dimension for resizing (e.g., 128 or 256).")
    parser.add_argument("--no_augment", action="store_true",
                        help="Disable augmentation (for test sets).")
    parser.add_argument("--max_images", type=int, default=None,
                        help="Process only a subset of images (default: all).")
    parser.add_argument("--aug_count", type=int, default=10,
                        help="Number of augmentations per image (train only).")
    parser.add_argument("--set_type", type=str, choices=["TrainVal", "Test"], default="TrainVal",
                        help="Dataset type: TrainVal or Test.")
    parser.add_argument("--flip_prob", type=float, default=0.5, help="Probability for horizontal flip.")
    parser.add_argument("--rotation_angle", type=float, default=5, help="Maximum rotation angle (in degrees).")
    parser.add_argument("--translate_factor", type=float, default=0.05, help="Translation factor (fraction of image size).")
    parser.add_argument("--crop_scale_min", type=float, default=0.9, help="Minimum scale for random resized crop.")
    parser.add_argument("--crop_scale_max", type=float, default=1.0, help="Maximum scale for random resized crop.")
    parser.add_argument("--crop_ratio_min", type=float, default=1.0, help="Minimum aspect ratio for random resized crop.")
    parser.add_argument("--crop_ratio_max", type=float, default=1.0, help="Maximum aspect ratio for random resized crop.")
    parser.add_argument("--elastic_alpha", type=float, default=15, help="Elastic transformation alpha parameter.")
    parser.add_argument("--elastic_sigma", type=float, default=2, help="Elastic transformation sigma parameter.")
    parser.add_argument("--scaling_min", type=float, default=0.0, help="Minimum scaling factor for random scaling.")
    parser.add_argument("--scaling_max", type=float, default=0.0, help="Maximum scaling factor for random scaling.")
    parser.add_argument("--blur_prob", type=float, default=0.3, help="Probability of applying Gaussian blur.")
    parser.add_argument("--blur_radius_min", type=float, default=0.5, help="Minimum radius for Gaussian blur.")
    parser.add_argument("--blur_radius_max", type=float, default=1.5, help="Maximum radius for Gaussian blur.")
    parser.add_argument("--color_jitter_brightness", type=float, default=0.2, help="Brightness for color jitter.")
    parser.add_argument("--color_jitter_contrast", type=float, default=0.2, help="Contrast for color jitter.")
    parser.add_argument("--color_jitter_saturation", type=float, default=0.2, help="Saturation for color jitter.")
    parser.add_argument("--color_jitter_hue", type=float, default=0.1, help="Hue for color jitter.")

    ## Test processing (resizing)
    '''sys.argv = ['preprocessing.py', '--raw_color', '/raw/Test/color',
                '--raw_label', '/raw/Test/label',
                '--proc_color', '/Dataset/processed/Test/color',
                '--proc_label', '/Dataset/processed/Test/label',
                '--resize_dim', '128',
                '--no_augment',
                '--set_type', 'Test']'''

    ## TrainVal augmentation
    sys.argv = ['preprocessing.py',
                '--raw_color', '/raw/TrainVal/color',
                '--raw_label', '/raw/TrainVal/label',
                '--proc_color', '/Dataset/processed/TrainVal/color',
                '--proc_label', '/Dataset/processed/TrainVal/label',
                '--resize_dim', '128',
                '--aug_count', '10',
                '--set_type', 'TrainVal',
                '--translate_factor', '0.02',
                '--elastic_alpha', '0.0',
                '--elastic_sigma', '0.05',
                '--blur_radius_min', '0.0',
                '--blur_radius_max', '0.05']

    args = parser.parse_args()

    # Define relative paths based on set type
    if args.set_type == "TrainVal":
        raw_color = Path("/notebooks/raw_Dataset/TrainVal/color")
        raw_label = Path("/notebooks/raw_Dataset/TrainVal/label")
        proc_color = Path("/notebooks/Dataset/processed2/TrainVal/color")
        proc_label = Path("/notebooks/Dataset/processed2/TrainVal/label")
    else:  # Test set
        raw_color = Path("/notebooks/raw_Dataset/Test/color")
        raw_label = Path("/notebooks/raw_Dataset/Test/label")
        proc_color = Path("/notebooks/Dataset/processed2/Test/color")
        proc_label = Path("/notebooks/Dataset/processed2/Test/label")

    aug_params = {
        "flip_prob": args.flip_prob,  # Already exists.
        "rotation_angle": args.rotation_angle,
        "rotation_prob": 0.25,         # Always apply rotation, for instance.
        "translate_factor": args.translate_factor,
        "translate_prob": 0.05,        # Always apply translation.
        "crop_scale_range": (args.crop_scale_min, args.crop_scale_max),
        "crop_ratio_range": (args.crop_ratio_min, args.crop_ratio_max),
        "crop_prob": 0.01,             # Always apply crop.
        "elastic_alpha": args.elastic_alpha,
        "elastic_sigma": args.elastic_sigma,
        "elastic_prob": 0.01,          # Always apply elastic transform.
        "scaling_range": (0.8, 1.2),
        "scaling_prob": 0.05,          # Always apply scaling.
        "blur_prob": args.blur_prob,
        "blur_radius_range": (args.blur_radius_min, args.blur_radius_max),
        "color_jitter_params": {
            "brightness": args.color_jitter_brightness,
            "contrast": args.color_jitter_contrast,
            "saturation": args.color_jitter_saturation,
            "hue": args.color_jitter_hue,
        },
        "color_prob": 1.0            # Always apply color adjustments if enabled.
    }

    preprocessor = Preprocessor(raw_color, raw_label,
                                proc_color, proc_label,
                                resize_dim=args.resize_dim,
                                do_augmentation=not args.no_augment,
                                is_train=(args.set_type == "TrainVal"),
                                max_images=args.max_images,
                                aug_count=args.aug_count,
                                aug_params=aug_params)

    preprocessor.process()

Processed and Renamed newfoundland_183.jpg → dog_newfoundland_183.jpg
Augmented dog_newfoundland_183.jpg → dog_newfoundland_183_aug_0.jpg
Augmented dog_newfoundland_183.jpg → dog_newfoundland_183_aug_1.jpg
Augmented dog_newfoundland_183.jpg → dog_newfoundland_183_aug_2.jpg
Augmented dog_newfoundland_183.jpg → dog_newfoundland_183_aug_3.jpg
Augmented dog_newfoundland_183.jpg → dog_newfoundland_183_aug_4.jpg
Augmented dog_newfoundland_183.jpg → dog_newfoundland_183_aug_5.jpg
Augmented dog_newfoundland_183.jpg → dog_newfoundland_183_aug_6.jpg
Augmented dog_newfoundland_183.jpg → dog_newfoundland_183_aug_7.jpg
Augmented dog_newfoundland_183.jpg → dog_newfoundland_183_aug_8.jpg
Augmented dog_newfoundland_183.jpg → dog_newfoundland_183_aug_9.jpg
Processed and Renamed pomeranian_102.jpg → dog_pomeranian_102.jpg
Augmented dog_pomeranian_102.jpg → dog_pomeranian_102_aug_0.jpg
Augmented dog_pomeranian_102.jpg → dog_pomeranian_102_aug_1.jpg
Augmented dog_pomeranian_102.jpg → dog_pomeranian_102_au

In [1]:
# Core dependencies
import os
import re
import cv2
import random
import numpy as np

import time, sys                                
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt

# PyTorch related
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, Subset
from torch.utils.checkpoint import checkpoint as ckpt
import torch.optim.lr_scheduler as lr_scheduler


# CLIP related
# !pip install ftfy regex tqdm clip
# pip uninstall clip
%pip install git+https://github.com/openai/CLIP.git
import clip
from clip.simple_tokenizer import SimpleTokenizer

# Sklearn for dataset splitting
from sklearn.model_selection import train_test_split

torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
random.seed(42)

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /private/var/folders/w7/6rbmdvg163x5kscfm1zh1t8h0000gn/T/pip-req-build-r9veja5h
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /private/var/folders/w7/6rbmdvg163x5kscfm1zh1t8h0000gn/T/pip-req-build-r9veja5h
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25ldone
Note: you may need to restart the kernel to use updated packages.


In [2]:
# If you have a GPU, set device accordingly:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Simple CLIP tokenizer for optional debugging:
clip_tokenizer = SimpleTokenizer()

def denormalize(image_tensor):
    """
    Undo CLIP's normalization for visualization.
    CLIP uses mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711].
    """
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1).to(image_tensor.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1).to(image_tensor.device)
    return torch.clamp(image_tensor * std + mean, 0, 1)

def compute_iou_per_class(preds, targets, num_classes=3):
    """
    Compute Intersection over Union (IoU) for each class individually.
    Returns a dict with IoU for each class index and the mean IoU overall.
    """
    iou_per_class = {}
    for cls in range(num_classes):
        pred_inds = (preds == cls)
        target_inds = (targets == cls)
        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()
        if union == 0:
            iou_per_class[cls] = float('nan')
        else:
            iou_per_class[cls] = intersection / union
    mean_iou = np.nanmean(list(iou_per_class.values()))
    return iou_per_class, mean_iou

Using device: cpu


In [3]:
class PromptBasedSegmentationDataset(Dataset):
    """
    Each item will be (image, mask, text_token_ids, prompt_heatmap).
    The prompt_heatmap is a single-channel 2D map where the location
    of the user’s point is highlighted.
    """
    def __init__(
        self,
        root_dir,
        clip_preprocess,     # The CLIP image transform (resize, normalize, etc.)
        mask_transform=None, # e.g. resize to 224, encode classes
        use_text=True,
        gaussian_radius=5
    ):
        """
        Args:
            root_dir: directory with subfolders "color" and "label"
            clip_preprocess: CLIP transform for the image
            mask_transform: transform for the mask (e.g. resize, int->long)
            use_text: if True, we also generate a text prompt from the filename
            gaussian_radius: size of the Gaussian kernel around the prompt point
        """
        self.root_dir = Path(root_dir)
        self.image_dir = self.root_dir / 'color'
        self.mask_dir  = self.root_dir / 'label'
        
        self.image_paths = sorted([
            f for f in os.listdir(self.image_dir)
            if f.lower().endswith(('.jpg', '.png', '.jpeg'))
        ])
        
        self.clip_preprocess = clip_preprocess
        self.mask_transform  = mask_transform
        self.use_text = use_text
        self.gaussian_radius = gaussian_radius
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_name = self.image_paths[idx]
        
        # Load the image
        img_path = self.image_dir / img_name
        image_pil = Image.open(img_path).convert("RGB")
        # Transform the image for CLIP (224x224 with CLIP normalization)
        image = self.clip_preprocess(image_pil)
        
        # Load the mask
        mask_path = self.mask_dir / img_name.replace('.jpg', '.png')
        mask_pil = Image.open(mask_path).convert("L")
        mask = torch.from_numpy(np.array(mask_pil, dtype=np.int64))
        
        # Apply a transform if specified (e.g. resizing to 224, class-encoding)
        if self.mask_transform is not None:
            # Some transforms expect shape [C,H,W], so add a channel if necessary
            if mask.ndim == 2:
                mask = mask.unsqueeze(0)
            mask = self.mask_transform(mask)
            # Now shape is back to [H, W]
            if mask.ndim == 3:
                mask = mask.squeeze(0)
        
        # Optionally generate a text prompt (the same approach as your prior code):
        text_prompt = "a photo of an animal"  # fallback
        # We can parse filename if it has the form "cat_breed_1.jpg", etc.
        base = os.path.splitext(img_name)[0]
        match = re.match(r"^(cat|dog)_([A-Za-z_]+)_(\d+)", base)
        if match and self.use_text:
            animal_type = match.group(1)
            breed = match.group(2).replace("_", " ").title()
            text_prompt = f"a photo of a {breed} {animal_type}"
        # Convert the text prompt into CLIP token IDs:
        token_ids = clip.tokenize([text_prompt])[0] if self.use_text else torch.zeros(77, dtype=torch.long)
        
        # -------------- Prompt Point Generation -------------
        # We'll pick a random pixel from the *foreground* region of the mask 
        # (where mask != 0) so the prompt is somewhere on the object.
        # If the mask is empty, fallback to anywhere in the image.
        # mask shape: [H, W] after transforms. 
        mask_np = mask.cpu().numpy()
        h, w = mask_np.shape
        foreground_indices = np.where(mask_np != 0)
        # We'll randomly choose between point, box, or scribble prompts
        prompt_type = random.choice(["point", "box", "scribble"])
        prompt_map = np.zeros((h, w), dtype=np.float32)

        if prompt_type == "point":
            if len(foreground_indices[0]) == 0:
                py = random.randint(0, h-1)
                px = random.randint(0, w-1)
            else:
                idx_rand = random.randint(0, len(foreground_indices[0]) - 1)
                py = foreground_indices[0][idx_rand]
                px = foreground_indices[1][idx_rand]
            for gy in range(-self.gaussian_radius, self.gaussian_radius+1):
                for gx in range(-self.gaussian_radius, self.gaussian_radius+1):
                    ny = py + gy
                    nx = px + gx
                    if 0 <= ny < h and 0 <= nx < w:
                        dist_sq = gx*gx + gy*gy
                        sigma = self.gaussian_radius / 2.0
                        val = np.exp(-dist_sq/(2*sigma*sigma))
                        prompt_map[ny, nx] = max(prompt_map[ny, nx], val)

        elif prompt_type == "box":
            ys, xs = np.where(mask_np != 0)
            if len(ys) > 0:
                y1, y2 = ys.min(), ys.max()
                x1, x2 = xs.min(), xs.max()
                prompt_map[y1:y2+1, x1:x2+1] = 1.0
            else:
                prompt_map[...] = 1.0

        elif prompt_type == "scribble":
            ys, xs = np.where(mask_np != 0)
            if len(ys) > 0:
                num_points = 3
                scribble_pts = []
                for _ in range(num_points):
                    i = random.randint(0, len(ys)-1)
                    scribble_pts.append((ys[i], xs[i]))
                for i in range(len(scribble_pts)-1):
                    (yA, xA) = scribble_pts[i]
                    (yB, xB) = scribble_pts[i+1]
                    steps = 20
                    for s in range(steps+1):
                        t = s / steps
                        cy = int(yA + t*(yB-yA))
                        cx = int(xA + t*(xB-xA))
                        for rry in range(-2, 3):
                            for rrx in range(-2, 3):
                                ny = cy + rry
                                nx = cx + rrx
                                if 0 <= ny < h and 0 <= nx < w:
                                    prompt_map[ny, nx] = 1.0
        # Convert to tensor:
        prompt_map = torch.from_numpy(prompt_map).unsqueeze(0)
        
        return image, mask, token_ids, prompt_map

In [4]:
class ChannelAttention(nn.Module):
    """
    Same channel-attention block from your existing code.
    """
    def __init__(self, in_channels, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class ImprovedCLIPSegmentationHead(nn.Module):
    """
    Slightly modified to handle 3 inputs: 
    - Image features (768 channels) 
    - Text features (768 channels) 
    - Prompt features (768 channels)
    => total 768*3 = 2304 input channels to fuse_conv1.
    """
    def __init__(self, use_attention=True, num_classes=3):
        super(ImprovedCLIPSegmentationHead, self).__init__()
        self.use_attention = use_attention
        
        in_channels = 768 * 3  # image(768) + text(768) + prompt(768)
        mid_channels = 256
        
        self.fuse_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.25)
        )
        self.residual_block = nn.Sequential(
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels)
        )
        self.relu = nn.ReLU(inplace=True)
        
        if use_attention:
            self.attention = ChannelAttention(mid_channels)
        else:
            self.attention = nn.Identity()
        
        self.fuse_conv2 = nn.Conv2d(mid_channels, num_classes, kernel_size=1)
        
    def forward(self, img_features, txt_features, prm_features):
        """
        Inputs: 
          - img_features: [B, 768, H, W]
          - txt_features: [B, 768, H, W]
          - prm_features: [B, 768, H, W]
        Output:
          - seg_logits:   [B, num_classes, H, W]
        """
        # Concatenate
        fused = torch.cat([img_features, txt_features, prm_features], dim=1)  # [B, 2304, H, W]
        
        x = ckpt(self.fuse_conv1, fused, use_reentrant=False)
        res = ckpt(self.residual_block, x, use_reentrant=False)
        x = self.relu(x + res)
        x = self.attention(x)
        seg_logits = self.fuse_conv2(x)
        return seg_logits

class PromptEncoder(nn.Module):
    """
    A small encoder to transform a 1-channel prompt map into 768-dim feature.
    We keep it simple: a few conv layers, up to 768 channels.
    """
    def __init__(self, out_channels=768):
        super(PromptEncoder, self).__init__()
        # We'll do a small stack of conv->relu->conv->relu->... => 768
        # Each conv can double the channels until we reach 768.
        # Because 768 is large, we do it in steps: 1->32->64->128->256->768
        # You can tweak or replace with a more advanced UNet-like approach.
        
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, out_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )
        
    def forward(self, prompt_map):
        # prompt_map shape: [B, 1, H, W]
        return self.net(prompt_map)  # [B, 768, H, W]

class CLIPPointSegModel(nn.Module):
    """
    Combines:
      - CLIP image backbone
      - CLIP text encoder
      - Additional prompt encoder
      - Segmentation head to fuse all 3.
    """
    def __init__(self, clip_model, num_classes=3):
        super(CLIPPointSegModel, self).__init__()
        
        # The clip_model.visual output dimension is typically 768 for ViT-L/14
        self.clip_model = clip_model
        self.num_classes = num_classes
        
        # Our additional prompt-map encoder
        self.prompt_encoder = PromptEncoder(out_channels=768)
        
        # The final segmentation head
        self.seg_head = ImprovedCLIPSegmentationHead(
            use_attention=True,
            num_classes=self.num_classes
        )
    
    def get_visual_features(self, image):
        """
        Slightly lower-level extraction from CLIP ViT model. 
        We do the same steps your code does: conv1 => flatten => cat CLS => add pos emb => transformer => ...
        """
        visual = self.clip_model.visual
        # Convert to same type as CLIP weights (might be fp16).
        dtype = visual.conv1.weight.dtype
        image = image.to(dtype)
        
        x = visual.conv1(image)  # [B, width, H', W']
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [B, width, tokens]
        x = x.permute(0, 2, 1)   # [B, tokens, width]
        
        cls_tokens = visual.class_embedding.to(dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=dtype, device=x.device
        )
        x = torch.cat([cls_tokens, x], dim=1)  # [B, tokens+1, width]
        
        x = x + visual.positional_embedding.to(dtype)
        x = visual.ln_pre(x)
        x = x.permute(1, 0, 2)  # [tokens+1, B, width]
        x = visual.transformer(x)
        x = x.permute(1, 0, 2)  # [B, tokens+1, transformer_width]
        if hasattr(visual, "proj"):
            x = x @ visual.proj
        
        return x  # shape [B, tokens+1, 768]
    
    def get_text_features(self, token_ids):
        """
        Use CLIP text encoder to get text embedding => shape [B, 768].
        """
        # By default, encode_text returns [B, 768]
        txt_feat = self.clip_model.encode_text(token_ids).float()
        return txt_feat
    
    def forward(self, image, token_ids, prompt_map):
        """
        1) Extract visual features => shape [B, tokens+1, 768]
        2) Discard CLS => reshape => [B, 768, H, W]
        3) Encode text => shape [B,768], expand to [B,768,H,W]
        4) Encode prompt => shape [B,768,H,W]
        5) Seg Head => [B, num_classes, H, W]
        6) (Optionally) upsample to input resolution
        """
        # image shape: [B, 3, 224, 224]
        # prompt_map shape: [B, 1, 224, 224]
        
        B, _, h, w = image.shape
        
        # (1) Visual features from CLIP
        visual_tokens = self.get_visual_features(image)  # [B, tokens+1, 768]
        # (2) Discard CLS
        tokens = visual_tokens[:, 1:, :]                # [B, tokens, 768]
        N = tokens.shape[1]
        grid_size = int(np.sqrt(N))  # for ViT-L/14 => 196 = 14x14
        
        img_features = tokens.reshape(B, grid_size, grid_size, 768).permute(0, 3, 1, 2)
        # => shape [B, 768, 14, 14] (assuming 224x224 input => 14x14 tokens for the patchify)
        
        # (3) Encode text
        txt_feat = self.get_text_features(token_ids)  # [B, 768]
        # expand to [B,768,14,14]
        txt_feat = txt_feat.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, grid_size, grid_size)
        
        # (4) Encode prompt
        prm_feat = ckpt(lambda x: self.prompt_encoder(x), prompt_map, use_reentrant=False)  # [B, 768, 224, 224] if we didn't downsample
        # But note:  The image features are 14x14. We can either:
        # (A) Downsample the prompt_map to 14x14 before prompt_encoder, or 
        # (B) Or we can just do a stride in the prompt_encoder, or
        # (C) Interpolate the result to 14x14.
        #
        # For simplicity, let's downsample after we get 768-ch. We'll do a bilinear down to 14x14:
        prm_feat_14 = F.interpolate(prm_feat, size=(grid_size, grid_size), mode='bilinear', align_corners=False)
        
        # (5) Seg Head
        seg_logits_14 = self.seg_head(img_features, txt_feat, prm_feat_14)
        # => shape [B, num_classes, 14, 14]
        
        # (6) Upsample back to 224x224 so it aligns with the original image
        seg_logits = F.interpolate(seg_logits_14, size=(h, w), mode='bilinear', align_corners=False)
        
        return seg_logits

In [5]:
def encode_mask_values(mask):
    """
    Example: you had separate logic that converts mask intensities to [0,1,2].
    E.g., 0-> background, 1-> one class, 2-> another class, etc.
    Adjust as needed for your dataset.
    """
    # We do the same threshold-based approach from your snippet:
    # <36 => 0, [36..192) => 1, >=192 => 2
    mask = torch.where(mask < 36, torch.tensor(0, dtype=mask.dtype), mask)
    mask = torch.where((mask >= 36) & (mask < 192), torch.tensor(1, dtype=mask.dtype), mask)
    mask = torch.where(mask >= 192, torch.tensor(2, dtype=mask.dtype), mask)
    return mask

class MaskEncodeTransform(nn.Module):
    """
    A small wrapper that you can pass as 'mask_transform' to your dataset
    so it resizes to 224x224 then encodes classes.
    """
    def __init__(self):
        super().__init__()
        self.resize = T.Resize((224,224), interpolation=T.InterpolationMode.NEAREST)
    def forward(self, mask):
        mask = self.resize(mask)  # shape [1,224,224]
        mask = encode_mask_values(mask)
        return mask

def train_point_based_model(
    dataset_dir,
    save_path="clip_pointseg_best.pth",
    epochs=2000,
    batch_size=4,
    lr=2e-4,
    resume=False,
    resume_path=None
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # 1) Load the CLIP model
    clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
    clip_model.eval()
    # Freeze CLIP parameters
    for param in clip_model.parameters():
        param.requires_grad = False
    
    # 2) Create our prompt-based dataset
    #    We'll do train/test split from the same folder for demonstration.
    #    In practice, you might have separate TrainVal/ Test folders.
    full_dataset = PromptBasedSegmentationDataset(
        root_dir=dataset_dir,
        clip_preprocess=clip_preprocess,
        mask_transform=MaskEncodeTransform(),
        use_text=True,    # generate text from filename
        gaussian_radius=5  # how big the prompt Gaussian
    )
    indices = list(range(len(full_dataset)))
    train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
    train_dataset = Subset(full_dataset, train_idx)
    val_dataset   = Subset(full_dataset, val_idx)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4, pin_memory=True)
    val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    
    # 3) Build our model
    model = CLIPPointSegModel(clip_model, num_classes=3).to(device)
    
    # 4) Define an optimizer that only trains the newly introduced layers:
    trainable_params = list(model.prompt_encoder.parameters()) + list(model.seg_head.parameters())
    optimizer = torch.optim.Adam(trainable_params, lr=lr)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
    scaler = torch.cuda.amp.GradScaler() 
    criterion = nn.CrossEntropyLoss()
    
    start_epoch = 1
    best_val_loss = float('inf')
    if resume and resume_path is not None and os.path.isfile(resume_path):
        print(f"Resuming from checkpoint: {resume_path}")
        checkpoint_data = torch.load(resume_path, map_location=device)
        if 'model_state_dict' in checkpoint_data:
           # New format: Full checkpoint with optimizer, epoch, best_val_loss
            model.load_state_dict(checkpoint_data['model_state_dict'])
            optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])
            start_epoch = checkpoint_data['epoch'] + 1
            best_val_loss = checkpoint_data['best_val_loss']
            print(f"Resumed training from epoch {start_epoch}, best val loss: {best_val_loss:.4f}")
        else:
            # Old format: Only model weights stored
            model.load_state_dict(checkpoint_data)
            print("Warning: Loaded only model weights, optimizer state is not restored!")
    patience = 50
    patience_counter = 0
    
    for epoch in range(start_epoch, epochs+1):
        epoch_start = time.time()
        model.train()
        running_loss = 0.0
        correct_pixels = 0
        total_pixels  = 0
        
        for i, (images, masks, token_ids, prompt_maps) in enumerate(train_loader):
            images = images.to(device)
            masks  = masks.to(device)
            token_ids = token_ids.to(device)
            prompt_maps = prompt_maps.to(device)
            
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                seg_logits = model(images, token_ids, prompt_maps)
            
                loss = criterion(seg_logits, masks)
            scaler.scale(loss).backward()    # 3) Scale the loss before backward
            scaler.step(optimizer)           
            scaler.update()
            
            running_loss += loss.item() * images.size(0)
            
            preds = torch.argmax(seg_logits, dim=1)
            correct_pixels += (preds == masks).sum().item()
            total_pixels   += masks.numel()
                        
            # 4) Live percentage progress + flush (for Jupyter)
            progress = 100.0 * (i + 1) / len(train_loader)
            elapsed_mins = (time.time() - epoch_start) / 60.0
            sys.stdout.write(
                f"\rEpoch {epoch}/{epochs} - {progress:.1f}% done, "
                f"{elapsed_mins:.2f} min elapsed..."
            )
            sys.stdout.flush()
         
        print()
        train_loss = running_loss / len(train_dataset)
        train_acc = correct_pixels / total_pixels
        
        # Validation
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total   = 0
        with torch.no_grad():
            all_preds = []
            all_targs = []
            for (images, masks, token_ids, prompt_maps) in val_loader:
                images = images.to(device)
                masks  = masks.to(device)
                token_ids = token_ids.to(device)
                prompt_maps = prompt_maps.to(device)
                
                seg_logits = model(images, token_ids, prompt_maps)
                loss = criterion(seg_logits, masks)
                val_loss += loss.item() * images.size(0)
                
                preds = torch.argmax(seg_logits, dim=1)
                val_correct += (preds == masks).sum().item()
                val_total   += masks.numel()
                
                all_preds.append(preds.cpu())
                all_targs.append(masks.cpu())
        
        val_loss /= len(val_dataset)
        val_acc = val_correct / val_total
        
        # Compute IoU
        all_preds = torch.cat(all_preds, dim=0)
        all_targs = torch.cat(all_targs, dim=0)
        iou_dict, mean_iou = compute_iou_per_class(all_preds, all_targs, num_classes=3)
        
        print(f"Epoch {epoch}/{epochs} => Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"                 => Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f} | Val mIoU: {mean_iou:.4f}")
        
        # Check for best val
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
            torch.save({
                    'epoch': epoch,
                    'best_val_loss': best_val_loss,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                    }, save_path)
            print("  [*] Saved best model.")
        else:
            patience_counter += 1
            print(f"  [!] No improvement. patience={patience_counter}/{patience}")
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break
        scheduler.step()
    
    print("Training complete.")
    return model

In [6]:
def test_point_based_model(model_path, test_dataset_dir, num_samples_to_show=5):
    # 1) Load CLIP
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
    clip_model.eval()
    for p in clip_model.parameters():
        p.requires_grad = False
    
    # 2) Build model
    model = CLIPPointSegModel(clip_model, num_classes=3).to(device)
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    
    # 3) Create test dataset & loader
    test_dataset = PromptBasedSegmentationDataset(
        root_dir=test_dataset_dir,
        clip_preprocess=clip_preprocess,
        mask_transform=MaskEncodeTransform(),
        use_text=True
    )
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    
    # 4) Evaluate
    criterion = nn.CrossEntropyLoss()
    running_loss, running_correct, total_pixels = 0.0, 0, 0
    all_preds, all_masks = [], []
    
    with torch.no_grad():
        for (images, masks, token_ids, prompt_maps) in test_loader:
            images = images.to(device)
            masks  = masks.to(device)
            token_ids = token_ids.to(device)
            prompt_maps = prompt_maps.to(device)
            
            logits = model(images, token_ids, prompt_maps)
            loss   = criterion(logits, masks)
            running_loss += loss.item() * images.size(0)
            
            preds = torch.argmax(logits, dim=1)
            running_correct += (preds == masks).sum().item()
            total_pixels    += masks.numel()
            
            all_preds.append(preds.cpu())
            all_masks.append(masks.cpu())
    
    test_loss = running_loss / len(test_dataset)
    test_acc  = running_correct / total_pixels
    
    all_preds = torch.cat(all_preds, dim=0)
    all_masks = torch.cat(all_masks, dim=0)
    iou_dict, mean_iou = compute_iou_per_class(all_preds, all_masks, num_classes=3)
    
    print(f"Test Results => Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, Mean IoU: {mean_iou:.4f}")
    print(f"IoU per class: {iou_dict}")
    
    # 5) Visualize a few predictions
    print(f"Showing {num_samples_to_show} random samples from test set:")
    indices = random.sample(range(len(test_dataset)), num_samples_to_show)
    for idx in indices:
        image, mask, token_ids, prompt_map = test_dataset[idx]
        
        image_in = image.unsqueeze(0).to(device)
        token_in = token_ids.unsqueeze(0).to(device)
        prompt_in= prompt_map.unsqueeze(0).to(device)
        
        with torch.no_grad():
            out = model(image_in, token_in, prompt_in)
            pred = torch.argmax(out, dim=1).squeeze(0).cpu().numpy()
        
        # Undo CLIP norm for display
        image_den = denormalize(image_in[0].cpu()).permute(1,2,0).numpy()
        mask_np   = mask.numpy()
        pm_np     = prompt_map[0].numpy()
        
        # Display
        fig, axes = plt.subplots(1, 4, figsize=(16,4))
        axes[0].imshow(image_den)
        axes[0].set_title("Input Image")
        axes[1].imshow(mask_np, cmap='gray', vmin=0, vmax=2)
        axes[1].set_title("Ground Truth Mask")
        axes[2].imshow(pm_np, cmap='jet')
        axes[2].set_title("Prompt Heatmap")
        axes[3].imshow(pred, cmap='gray', vmin=0, vmax=2)
        axes[3].set_title("Prediction")
        
        for ax in axes:
            ax.axis('off')
        plt.show()

In [None]:
# Example usage:
import torch
torch.cuda.empty_cache()

# 1) Train
trained_model = train_point_based_model(
    dataset_dir="/notebooks/Dataset/processed2/TrainVal",
    save_path="clip_pointseg_samlike_4.pth",
    epochs=1980,
    batch_size=4,
    lr=2e-4,
    resume=True,
    resume_path="clip_pointseg_samlike_3.pth"
    
)

# # 2) Test
# test_point_based_model(
#     model_path="clip_pointseg_samlike.pth",
#     test_dataset_dir="/notebooks/Dataset/processed2/Test",
#     num_samples_to_show=5
# )

Using device: cuda
Resuming from checkpoint: clip_pointseg_samlike_3.pth
Resumed training from epoch 25, best val loss: 0.1373
Epoch 25/1980 - 27.8% done, 6.80 min elapsed...

In [None]:
import clip
print(hasattr(clip, "load"))

In [17]:
import torch

checkpoint = torch.load("clip_pointseg_samlike.pth", map_location="cuda")
print(checkpoint.keys())  # This will print all keys stored in the checkpoint

odict_keys(['clip_model.positional_embedding', 'clip_model.text_projection', 'clip_model.logit_scale', 'clip_model.visual.class_embedding', 'clip_model.visual.positional_embedding', 'clip_model.visual.proj', 'clip_model.visual.conv1.weight', 'clip_model.visual.ln_pre.weight', 'clip_model.visual.ln_pre.bias', 'clip_model.visual.transformer.resblocks.0.attn.in_proj_weight', 'clip_model.visual.transformer.resblocks.0.attn.in_proj_bias', 'clip_model.visual.transformer.resblocks.0.attn.out_proj.weight', 'clip_model.visual.transformer.resblocks.0.attn.out_proj.bias', 'clip_model.visual.transformer.resblocks.0.ln_1.weight', 'clip_model.visual.transformer.resblocks.0.ln_1.bias', 'clip_model.visual.transformer.resblocks.0.mlp.c_fc.weight', 'clip_model.visual.transformer.resblocks.0.mlp.c_fc.bias', 'clip_model.visual.transformer.resblocks.0.mlp.c_proj.weight', 'clip_model.visual.transformer.resblocks.0.mlp.c_proj.bias', 'clip_model.visual.transformer.resblocks.0.ln_2.weight', 'clip_model.visual.

In [50]:
print(type(self.prompt_encoder))

NameError: name 'self' is not defined

In [9]:
for name, p in model.named_parameters():
    if p.requires_grad:
        print(f"{name} is trainable.")

NameError: name 'model' is not defined

In [9]:
import gradio as gr
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from PIL import Image
import random, os, sys
import matplotlib.pyplot as plt
from pathlib import Path

import clip  # from openai/clip
from torchvision import transforms as T

# ---------------------------------------------------------------------------
# 1) Utility: Undo CLIP normalization for display
# ---------------------------------------------------------------------------
def denormalize(image_tensor):
    """
    Undo CLIP's normalization for visualization.
    CLIP uses mean=[0.48145466, 0.4578275, 0.40821073],
              std=[0.26862954, 0.26130258, 0.27577711].
    """
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1).to(image_tensor.device)
    std  = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1).to(image_tensor.device)
    return torch.clamp(image_tensor * std + mean, 0, 1)

# ---------------------------------------------------------------------------
# 2) Utility: Generate a prompt heatmap around user click
# ---------------------------------------------------------------------------
def generate_prompt_heatmap(click_x, click_y, height=224, width=224, radius=5):
    """
    If click_x < 0 or click_y < 0 => no click (all zeros).
    Otherwise, create a Gaussian around (click_x, click_y).
    """
    prompt_map = np.zeros((height, width), dtype=np.float32)
    # If user gave -1, skip
    if click_x < 0 or click_y < 0:
        return prompt_map[None, ...]  # shape (1, H, W)
    
    # Clamp to image bounds
    cx = int(np.clip(click_x, 0, width - 1))
    cy = int(np.clip(click_y, 0, height - 1))
    
    for gy in range(-radius, radius + 1):
        for gx in range(-radius, radius + 1):
            ny = cy + gy
            nx = cx + gx
            if 0 <= ny < height and 0 <= nx < width:
                dist_sq = gx * gx + gy * gy
                sigma   = radius / 2.0
                val     = np.exp(-dist_sq / (2 * sigma * sigma))
                prompt_map[ny, nx] = max(prompt_map[ny, nx], val)
    
    return prompt_map[None, ...]  # shape (1, H, W)

# ---------------------------------------------------------------------------
# 3) Utility: Convert predicted mask to a color overlay (e.g., class 1=Red, class 2=Green)
# ---------------------------------------------------------------------------
def apply_segmentation_overlay(image_np, seg_mask, alpha=0.5):
    """
    image_np: [H, W, 3] in [0,1]
    seg_mask: [H, W] with class indices (0,1,2)
    Overlays classes 1 and 2 in red/green at alpha=0.5
    """
    h_img, w_img, _ = image_np.shape
    h_mask, w_mask  = seg_mask.shape
    if (h_img != h_mask) or (w_img != w_mask):
        raise ValueError("Mismatch between image and seg_mask shapes.")
    
    base_img = (image_np * 255).astype(np.uint8)   # to uint8
    overlay  = np.zeros((h_img, w_img, 3), dtype=np.uint8)
    overlay[seg_mask == 1] = (255, 0, 0)  # red
    overlay[seg_mask == 2] = (0, 255, 0)  # green
    
    blended = (alpha * overlay + (1 - alpha) * base_img).astype(np.uint8)
    return Image.fromarray(blended)

# ---------------------------------------------------------------------------
# 4) Load the trained model
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Adjust your checkpoint path:
MODEL_PATH = "/Users/jakedugan/Documents/UniversityofEdinburgh/CV/cw1/clip_pointseg_samlike_4.pth"

# Load base CLIP
clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
clip_model.eval()
for p in clip_model.parameters():
    p.requires_grad = False

# Build your custom segmentation model class (already defined in your code):
# e.g.:
#   model = CLIPPointSegModel(clip_model=clip_model, num_classes=3).to(device)
model = CLIPPointSegModel(clip_model, num_classes=3).to(device)

# Load the checkpoint dictionary
checkpoint_dict = torch.load(MODEL_PATH, map_location=device)

# If you stored a dictionary with 'model_state_dict', load that:

if 'model_state_dict' in checkpoint_dict:
    model.load_state_dict(checkpoint_dict["model_state_dict"])
    print("Loaded from 'model_state_dict' in checkpoint.")
else:
    # fallback if it was saved as raw state_dict
    model.load_state_dict(checkpoint_dict)
    print("Loaded from raw state dict (no 'model_state_dict' key found).")

model.eval()
print(f"Model loaded from: {MODEL_PATH}")

# ---------------------------------------------------------------------------
# 5) The inference (Gradio) function
# ---------------------------------------------------------------------------
def segment_interactive(input_image, text_prompt, click_x, click_y, radius=5,
                        x1_box=-1, y1_box=-1, x2_box=-1, y2_box=-1, scribble=None):
    if input_image is None:
        return None

    image_pil = input_image.convert("RGB")
    image_tensor = clip_preprocess(image_pil).unsqueeze(0).to(device)

    # Build prompt_map_np as a 2D array
    prompt_map_np = np.zeros((224, 224), dtype=np.float32)

    if (x1_box >= 0 and x2_box > x1_box and y1_box >= 0 and y2_box > y1_box):
        prompt_map_np[y1_box:y2_box+1, x1_box:x2_box+1] = 1.0
    elif scribble is not None:
        # Check if scribble is a dict with a composite key
        if isinstance(scribble, dict) and scribble.get("composite") is not None:
            if isinstance(scribble["composite"], np.ndarray):
                scribble_img = Image.fromarray(scribble["composite"])
            else:
                scribble_img = scribble["composite"]
            scribble_np = np.array(scribble_img.convert("L").resize((224,224)))
            prompt_map_np[scribble_np > 10] = 1.0
        else:
            # if scribble is directly a PIL image
            scribble_np = np.array(scribble.convert("L").resize((224,224)))
            prompt_map_np[scribble_np > 10] = 1.0
    else:
        prompt_map_np = generate_prompt_heatmap(click_x, click_y, 224, 224, radius=radius)[0]

    # Ensure prompt_map has shape (1, 1, 224, 224)
    prompt_map = torch.from_numpy(prompt_map_np).unsqueeze(0).unsqueeze(0).float().to(device)

    if len(text_prompt.strip()) == 0:
        text_prompt = "a photo of an animal"
    token_ids = clip.tokenize([text_prompt]).to(device)

    with torch.no_grad():
        seg_logits = model(image_tensor, token_ids, prompt_map)
        seg_mask = torch.argmax(seg_logits, dim=1).squeeze(0).cpu().numpy()

    image_denorm = denormalize(image_tensor[0]).permute(1,2,0).cpu().numpy()
    overlay_pil = apply_segmentation_overlay(image_denorm, seg_mask, alpha=0.5)
    return overlay_pil

# ---------------------------------------------------------------------------
# 6) Build the Gradio UI
# ---------------------------------------------------------------------------
def create_ui():
    with gr.Blocks() as demo:
        gr.Markdown("## CLIP-based Prompt Segmentation Demo")
        gr.Markdown(
            "Upload an image and provide a prompt. You can click a point, define a bounding box, or draw a scribble. "
            "If no bounding box or scribble is provided, the model uses the click point."
        )
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image (resized to 224x224)", type="pil")
                text_prompt = gr.Textbox(label="Text Prompt (optional)", placeholder="e.g., 'a photo of a Persian cat'")
                click_x = gr.Number(label="Click X (0-223) or -1 for no click", value=-1, precision=0)
                click_y = gr.Number(label="Click Y (0-223) or -1 for no click", value=-1, precision=0)
                radius = gr.Number(label="Click Gaussian Radius", value=5, precision=0)
                x1_box = gr.Number(label="Box X1 (optional)", value=-1, precision=0)
                y1_box = gr.Number(label="Box Y1 (optional)", value=-1, precision=0)
                x2_box = gr.Number(label="Box X2 (optional)", value=-1, precision=0)
                y2_box = gr.Number(label="Box Y2 (optional)", value=-1, precision=0)
                # Use the new ImageEditor component for scribbles
                scribble = gr.ImageEditor(sources=(), brush=gr.Brush(colors=["#000000"], color_mode="fixed"),
                                          label="Scribble Prompt (optional)", type="pil")
                submit_btn = gr.Button("Segment!")
            with gr.Column():
                output_image = gr.Image(label="Output Overlay", type="pil")
        submit_btn.click(
            fn=segment_interactive,
            inputs=[input_image, text_prompt, click_x, click_y, radius, x1_box, y1_box, x2_box, y2_box, scribble],
            outputs=[output_image]
        )
    return demo

Using device: cpu
Loaded from 'model_state_dict' in checkpoint.
Model loaded from: /Users/jakedugan/Documents/UniversityofEdinburgh/CV/cw1/clip_pointseg_samlike_4.pth


In [11]:
# ---------------------------------------------------------------------------
# 7) Launch the Gradio App
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ui = create_ui()
    ui.launch(server_name="0.0.0.0", server_port=7861)

* Running on local URL:  http://0.0.0.0:7861

To create a public link, set `share=True` in `launch()`.
