## Setting up the CLIP features model classes 

In [1]:
###############################################
#               Testing the model             #
###############################################

import os
import clip
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import re
from typing import List
from clip.simple_tokenizer import SimpleTokenizer
import torchvision.transforms as T
from matplotlib.backends.backend_pdf import PdfPages
import random
import matplotlib.pyplot as plt

clip_tokenizer = SimpleTokenizer()

#############################################
#                Dataset Class              #
#############################################

import os
import re
import clip
import time
import csv
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, Subset
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from clip.simple_tokenizer import SimpleTokenizer

clip_tokenizer = SimpleTokenizer()

class SegmentationDatasetWithText(Dataset):
    """
    Dataset for segmentation with text prompts.

    Assumes:
      - Images are in root_dir/color.
      - Masks are in root_dir/label, with the same file stem as the image
        and a ".png" extension.
      - Each mask is a color-coded image where:
          • Background: Black [0,0,0] → 0
          • Cat: Orange [255,165,0] → 1
          • Dog: Cyan [0,255,255] → 2
        Pixels of any other color are also left at class 0.
      - A text prompt is generated from the filename, e.g.:
            "cat_Siamese_27.png" → "a photo of a Siamese cat"

    Each item is a tuple (image, mask, token_ids):
      - image: the RGB image after transform_img (if given),
      - mask: a long tensor of class indices after transform_label (if given),
      - token_ids: the CLIP token ids of the text prompt.
    """

    # (RGB color, class index) pairs used to decode the color-coded masks.
    _COLOR_TO_CLASS = (
        ((0, 0, 0), 0),        # background: black
        ((255, 165, 0), 1),    # cat: orange
        ((0, 255, 255), 2),    # dog: cyan
    )

    # Filenames look like "<cat|dog>_<Breed_Name>_<number>".
    _FILENAME_PATTERN = re.compile(r"^(cat|dog)_([A-Za-z_]+)_(\d+)")

    def __init__(self, root_dir, transform_img, transform_label=None):
        """
        root_dir: directory containing the 'color' and 'label' sub-directories.
        transform_img: callable applied to the PIL image (may be None).
        transform_label: optional callable applied to the [1, H, W] mask tensor.
        """
        self.root_dir = root_dir
        self.transform_img = transform_img
        self.transform_label = transform_label

        self.image_dir = os.path.join(root_dir, 'color')
        self.mask_dir = os.path.join(root_dir, 'label')

        # Sorted so that indices are stable across runs.
        self.image_paths = sorted(
            fname for fname in os.listdir(self.image_dir)
            if fname.lower().endswith(('.jpg', '.png', '.jpeg'))
        )

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _rgb_mask_to_labels(mask_np):
        """Convert an (H, W, 3) color-coded mask array into an (H, W) int64 class map."""
        labels = np.zeros(mask_np.shape[:2], dtype=np.int64)
        for color, class_index in SegmentationDatasetWithText._COLOR_TO_CLASS:
            labels[np.all(mask_np == color, axis=-1)] = class_index
        return labels

    @staticmethod
    def _prompt_from_filename(image_filename):
        """Build the text prompt for an image from its filename."""
        name, _ = os.path.splitext(image_filename)
        match = SegmentationDatasetWithText._FILENAME_PATTERN.match(name)
        if match is None:
            return "a photo of an animal"
        animal_type = match.group(1)  # "cat" or "dog"
        breed = match.group(2).replace("_", " ").title()
        return f"a photo of a {breed} {animal_type}"

    def __getitem__(self, index):
        image_filename = self.image_paths[index]
        image_path = os.path.join(self.image_dir, image_filename)
        # Use the same stem and force a .png extension for the mask.
        mask_path = os.path.join(self.mask_dir, Path(image_filename).stem + ".png")

        # Load and transform the image; the context manager closes the file handle.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.transform_img is not None:
            image = self.transform_img(image)

        # Load the mask as RGB and map colors to class indices.
        with Image.open(mask_path) as mask_img:
            mask_np = np.array(mask_img.convert("RGB"), dtype=np.uint8)
        mask = torch.from_numpy(self._rgb_mask_to_labels(mask_np)).long()

        if self.transform_label is not None:
            # The label transform is applied to a [1, H, W] tensor, then the
            # channel dimension is removed again.
            if mask.ndim == 2:
                mask = mask.unsqueeze(0)
            mask = self.transform_label(mask)
            mask = mask.squeeze(0)

        # Generate and tokenize the text prompt based on the filename.
        text_prompt = self._prompt_from_filename(image_filename)
        token_ids = clip.tokenize([text_prompt]).squeeze(0)

        return image, mask, token_ids

def denormalize(image_tensor):
    """
    Convert a CLIP-normalized tensor image back to the (0,1) range for visualization.

    image_tensor: tensor whose last three dimensions are (3, H, W); a leading
                  batch dimension is handled by broadcasting.
    Returns:
        A new tensor on the same device as the input with values clamped to [0, 1].
    """
    # Build the statistics on the input's device so CUDA tensors do not hit a
    # CPU/GPU device-mismatch error.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image_tensor.device).view(3, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image_tensor.device).view(3, 1, 1)
    restored = image_tensor * std + mean  # Undo normalization
    return torch.clamp(restored, 0, 1)  # Ensure values are valid

#############################################
#             Segmentation Head             #
#############################################

class ChannelAttention(nn.Module):
    """
    Channel-wise gating module.

    Each channel is globally average-pooled, passed through a two-layer
    bottleneck MLP ending in a sigmoid, and the resulting per-channel weights
    rescale the input feature map.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        bottleneck = in_channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, in_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels = x.shape[0], x.shape[1]
        # [B, C, H, W] -> [B, C] channel descriptors.
        pooled = self.avg_pool(x).view(batch, channels)
        # [B, C] gates in (0, 1), reshaped so they broadcast over H and W.
        gates = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gates

class ImprovedCLIPSegmentationHead(nn.Module):
    """
    Segmentation head that fuses a spatial image feature map with a text embedding.

    The text embedding is projected to the image channel count, tiled over the
    spatial grid and concatenated with the image features. The fused map goes
    through a conv block, a residual block, optional channel attention and a
    1x1 classifier that produces per-class logits at feature-map resolution.
    """

    def __init__(self, in_channels, text_dim, num_classes, dropout_prob=0.25, use_attention=True):
        super().__init__()
        hidden = 256
        # Project text features to match visual feature channels.
        self.text_proj = nn.Linear(text_dim, in_channels)
        # Image and text channels are concatenated, hence 2 * in_channels.
        self.fuse_conv1 = nn.Sequential(
            nn.Conv2d(2 * in_channels, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout_prob),
        )
        self.residual_block = nn.Sequential(
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
        )
        self.relu = nn.ReLU(inplace=True)
        if use_attention:
            self.attention = ChannelAttention(hidden)
        else:
            self.attention = nn.Identity()
        self.fuse_conv2 = nn.Conv2d(hidden, num_classes, kernel_size=1)

    def forward(self, img_features, text_features):
        """
        img_features: [B, C, H, W] image feature map.
        text_features: [B, text_dim] text embedding.
        Returns:
            [B, num_classes, H, W] logits.
        """
        batch, channels, height, width = img_features.shape
        img_features = img_features.float()
        text_features = text_features.float()

        # [B, text_dim] -> [B, C] -> [B, C, 1, 1] -> tiled to [B, C, H, W].
        text_map = self.text_proj(text_features).view(batch, channels, 1, 1)
        text_map = text_map.expand(batch, channels, height, width)

        fused = self.fuse_conv1(torch.cat([img_features, text_map], dim=1))
        # Residual connection around the two-conv block.
        fused = self.relu(fused + self.residual_block(fused))
        fused = self.attention(fused)
        return self.fuse_conv2(fused)

class CLIPSegmentationModel(nn.Module):
    """
    Segmentation model built from a CLIP backbone and a trainable segmentation head.

    The CLIP visual encoder supplies per-patch token features and the CLIP text
    encoder supplies a prompt embedding; both are computed without gradients.
    The segmentation head fuses them and its logits are upsampled to the input
    image resolution.
    """

    def __init__(self, clip_model, num_classes=3):
        super().__init__()
        # Get feature dimension from the CLIP visual model (e.g., 768).
        self.feature_dim = clip_model.visual.output_dim
        print(f'Feature dimension: {self.feature_dim}')  # Expected 768

        self.clip_model = clip_model
        # Segmentation head with input channels and text dimension matching CLIP.
        self.seg_head = ImprovedCLIPSegmentationHead(
            in_channels=self.feature_dim,
            text_dim=self.feature_dim,
            num_classes=num_classes,
            dropout_prob=0.25
        )

    def get_visual_features(self, image):
        """
        Run the CLIP visual transformer and return all token features.

        image: [B, 3, H, W] image batch.
        Returns:
            [B, 1 + num_patches, D] tensor with the class token first. D is the
            projected dimension when the visual model has a projection matrix,
            otherwise the transformer width.
        """
        visual = self.clip_model.visual
        # Match the dtype of the CLIP weights.
        image = image.to(dtype=visual.conv1.weight.dtype)
        x = visual.conv1(image)  # [B, width, H', W']
        x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)  # [B, tokens, width]
        # Broadcast the class embedding to one token per sample.
        cls_tokens = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)
        # The transformer is fed sequence-first input, so swap batch and token axes.
        x = x.permute(1, 0, 2)
        x = visual.transformer(x)
        x = x.permute(1, 0, 2)
        # The attribute may exist but be None, so check the value and not
        # merely its presence before multiplying.
        proj = getattr(visual, "proj", None)
        if proj is not None:
            x = x @ proj
        return x

    def forward(self, image, token_ids):
        """
        image: [B, 3, H, W] normalized image batch.
        token_ids: tokenized text prompts passed to clip_model.encode_text.
        Returns:
            [B, num_classes, H, W] segmentation logits at the input resolution.
        """
        image = image.float()
        # Compute CLIP features without tracking gradients.
        with torch.no_grad():
            visual_features = self.get_visual_features(image)  # [B, num_tokens+1, D]
            tokens = visual_features[:, 1:, :]  # remove the class token
            B, N, D = tokens.shape
            # Patch tokens are laid out on a square grid; reshape raises if N
            # is not a perfect square.
            grid_size = int(np.sqrt(N))
            image_features = tokens.reshape(B, grid_size, grid_size, D).permute(0, 3, 1, 2)
            text_features = self.clip_model.encode_text(token_ids).float()

        # Compute segmentation head normally (gradients tracked here).
        seg_logits = self.seg_head(image_features, text_features)
        seg_logits = F.interpolate(seg_logits, size=(image.shape[2], image.shape[3]),
                                   mode="bilinear", align_corners=False)
        return seg_logits

###############################################
# Training, Evaluation, and Utility Functions #
###############################################
def compute_iou_per_class(preds, targets, num_classes=3):
    """
    Compute Intersection over Union (IoU) for every class index.

    preds, targets: arrays/tensors of class indices with matching shapes.
    Returns:
        dict mapping each class in range(num_classes) to its IoU, or NaN when
        the class appears in neither preds nor targets.
    """
    def _iou(cls):
        in_pred = (preds == cls)
        in_target = (targets == cls)
        overlap = (in_pred & in_target).sum().item()
        combined = (in_pred | in_target).sum().item()
        # An absent class has no defined IoU; report NaN rather than dividing by zero.
        return float('nan') if combined == 0 else overlap / combined

    return {cls: _iou(cls) for cls in range(num_classes)}

## Test the robustness of our CLIP features model

In [None]:
import torch
import clip
import numpy as np
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms as T
from skimage.util import random_noise

# --- Provided perturbation functions (unchanged) ---
def add_salt_and_pepper_noise(image, amount):
    """
    Add salt-and-pepper noise to a uint8 image.

    image: numpy array with values in 0-255.
    amount: value forwarded to random_noise as its `amount` argument.
    Returns a uint8 array with values in 0-255.
    """
    # random_noise is given the image scaled to [0, 1].
    scaled = image.astype(np.float32) / 255.0
    corrupted = random_noise(scaled, mode='s&p', amount=amount)
    return np.clip(corrupted * 255.0, 0, 255).astype(np.uint8)

def occlude_image(image, square_edge):
    """
    Black out a randomly placed square of side `square_edge` in a copy of `image`.

    image: (H, W, C) numpy array; it is not modified.
    square_edge: side length of the square; values <= 0 leave the copy untouched,
                 and a square larger than the image blacks out everything.
    """
    occluded = image.copy()
    if square_edge <= 0:
        return occluded

    height, width, _ = occluded.shape
    x_limit = width - square_edge
    y_limit = height - square_edge
    if x_limit < 0 or y_limit < 0:
        occluded[:] = 0
        return occluded

    # Draw the top-left corner (x first, then y) so the square fits in the image.
    left = np.random.randint(0, x_limit + 1)
    top = np.random.randint(0, y_limit + 1)
    occluded[top:top + square_edge, left:left + square_edge, :] = 0
    return occluded

def decrease_brightness(image, offset):
    """Subtract `offset` from every pixel and return the result clipped to 0-255 as uint8."""
    # Widen to int32 first so the subtraction can go below zero before clipping.
    darker = image.astype(np.int32) - offset
    return np.clip(darker, 0, 255).astype(np.uint8)

def increase_brightness(image, offset):
    """Add `offset` to every pixel and return the result clipped to 0-255 as uint8."""
    # Widen to int32 first so the sum can exceed 255 before clipping.
    brighter = image.astype(np.int32) + offset
    return np.clip(brighter, 0, 255).astype(np.uint8)

def decrease_contrast(image, factor):
    """Multiply every pixel by `factor` and return the result clipped to 0-255 as uint8."""
    scaled = np.multiply(image.astype(np.float32), factor)
    return np.clip(scaled, 0, 255).astype(np.uint8)

def increase_contrast(image, factor):
    """Multiply every pixel by `factor` and return the result clipped to 0-255 as uint8."""
    scaled = np.multiply(image.astype(np.float32), factor)
    return np.clip(scaled, 0, 255).astype(np.uint8)

def add_gaussian_noise(image, sigma):
    """
    Add zero-mean Gaussian noise with standard deviation `sigma` to an image.

    Returns the noisy image clipped to 0-255 as uint8, with the input's shape.
    """
    noisy = image.astype(np.float32) + np.random.normal(0, sigma, image.shape)
    return np.clip(noisy, 0, 255).astype(np.uint8)

def add_gaussian_blur(image, iterations):
    """
    Blur an image by convolving it `iterations` times with a 3x3 kernel.

    The kernel is [[1, 2, 1], [2, 4, 2], [1, 2, 1]] / 16. With iterations == 0
    an unmodified copy of the image is returned.
    """
    import cv2
    # The kernel is the outer product of [1, 2, 1] with itself, scaled by 1/16.
    taps = np.array([1, 2, 1], dtype=np.float32)
    kernel = np.outer(taps, taps) / 16.0
    blurred = image.copy()
    for _ in range(iterations):
        blurred = cv2.filter2D(blurred, ddepth=-1, kernel=kernel)
    return blurred

# --- Updated Perturbation Evaluation Function ---
def test_model_on_perturbations(model_path, test_dataset, batch_size=1, num_classes=3, perturbation_mode='gaussian_noise'):
    """
    Evaluates the CLIP features segmentation model on perturbed test data.

    For each perturbation level, each test sample is:
      1. Denormalized to [0,255] (uint8),
      2. Perturbed using the specified function,
      3. Re-normalized with CLIP's mean and std,
      4. Passed through the segmentation model (with its text tokens),
      5. Compared with the ground truth, ignoring every pixel whose label is
         not a valid class index (0 .. num_classes-1).

    The function computes the mean Dice score at each perturbation level,
    prints it, plots mean Dice versus perturbation level (saved as
    "<perturbation_mode>.pdf"), and returns the list of mean Dice scores.

    Arguments:
      model_path: path of the saved CLIPSegmentationModel state dict.
      test_dataset: dataset yielding (image, mask, token_ids) samples.
      batch_size: DataLoader batch size; samples are still evaluated one by one.
      num_classes: number of segmentation classes.
      perturbation_mode: one of 'gaussian_noise', 'gaussian_blur',
          'contrast_increase', 'contrast_decrease', 'brightness_increase',
          'brightness_decrease', 'occlusion_increase', 'salt_and_pepper_noise'.

    Raises:
      ValueError: if perturbation_mode is not one of the modes above.
    """
    # Perturbation mode -> (levels, perturbation function).
    perturbations = {
        'gaussian_noise': ([0, 2, 4, 6, 8, 10, 12, 14, 16, 18], add_gaussian_noise),
        'gaussian_blur': ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], add_gaussian_blur),
        'contrast_increase': ([1.0, 1.01, 1.02, 1.03, 1.04, 1.05, 1.1, 1.15, 1.2, 1.25], increase_contrast),
        'contrast_decrease': ([1.0, 0.95, 0.90, 0.85, 0.80, 0.60, 0.40, 0.30, 0.20, 0.10], decrease_contrast),
        'brightness_increase': ([0, 5, 10, 15, 20, 25, 30, 35, 40, 45], increase_brightness),
        'brightness_decrease': ([0, 5, 10, 15, 20, 25, 30, 35, 40, 45], decrease_brightness),
        'occlusion_increase': ([0, 5, 10, 15, 20, 25, 30, 35, 40, 45], occlude_image),
        'salt_and_pepper_noise': ([0.00, 0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18], add_salt_and_pepper_noise),
    }
    # Validate the mode before the (slow) model loading below.
    if perturbation_mode not in perturbations:
        raise ValueError(f"Perturbation mode '{perturbation_mode}' not implemented.")
    levels, perturb_fn = perturbations[perturbation_mode]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # CLIP normalization parameters.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3,1,1).to(device)
    std  = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3,1,1).to(device)

    # Load CLIP model.
    clip_model, _ = clip.load("ViT-L/14", device=device)
    clip_model.eval()
    for param in clip_model.parameters():
        param.requires_grad = False

    # Load your CLIP segmentation model.
    model = CLIPSegmentationModel(clip_model=clip_model, num_classes=num_classes).to(device)
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    def dice_ignore(pred, target, num_classes=3, epsilon=1e-6):
        """Mean Dice over the classes present, ignoring pixels without a valid class label."""
        # Only pixels labelled with a class index in [0, num_classes) count.
        valid = (target >= 0) & (target < num_classes)
        if not valid.any():
            return float('nan')
        pred_valid = pred[valid]
        target_valid = target[valid]
        dice_per_class = []
        for cls in range(num_classes):
            pred_cls = (pred_valid == cls).float()
            target_cls = (target_valid == cls).float()
            intersection = (pred_cls * target_cls).sum()
            union = pred_cls.sum() + target_cls.sum()
            if union == 0:
                # Class absent from both prediction and target: skip it in the mean.
                dice_per_class.append(float('nan'))
            else:
                # .item() keeps the list made of plain Python floats.
                dice_per_class.append(((2 * intersection + epsilon) / (union + epsilon)).item())
        return np.nanmean(dice_per_class)

    dice_scores_per_level = []

    # Loop over each perturbation level.
    for level in levels:
        level_dice_scores = []
        for images, masks, token_ids in test_loader:
            # Process each sample in the batch individually.
            for image, mask, token_id in zip(images, masks, token_ids):
                # Denormalize image to [0,255] (uint8). Round before the cast so
                # float error does not shift pixel values down by one.
                denorm_img = denormalize(image)
                np_img = (denorm_img * 255.0).clamp(0, 255).round().cpu().numpy().transpose(1,2,0).astype(np.uint8)

                # Apply perturbation.
                perturbed_np = perturb_fn(np_img, level)

                # Convert back to tensor and re-normalize on the model's device.
                perturbed_tensor = torch.from_numpy(perturbed_np).permute(2, 0, 1).float() / 255.0
                perturbed_tensor = perturbed_tensor.to(device)
                perturbed_tensor = ((perturbed_tensor - mean) / std).unsqueeze(0)

                # Run through model.
                with torch.no_grad():
                    output = model(perturbed_tensor, token_id.unsqueeze(0).to(device))
                pred = torch.argmax(output, dim=1).squeeze(0).cpu()

                # The metric is computed on CPU; the mask never needs to visit the GPU.
                dice_val = dice_ignore(pred, mask.cpu(), num_classes=num_classes)
                level_dice_scores.append(dice_val)

        mean_dice = np.nanmean(level_dice_scores)
        dice_scores_per_level.append(mean_dice)
        print(f"Perturbation level {level} ({perturbation_mode}): Mean Dice Score = {mean_dice:.4f}")

    # Plot mean Dice vs. perturbation level.
    plt.figure(figsize=(8,6))
    plt.plot(levels, dice_scores_per_level, marker='o')
    plt.xlabel(f"{perturbation_mode} level")
    plt.ylabel("Mean Dice Score")
    plt.title(f"Robustness Evaluation: Dice Score vs {perturbation_mode}")
    plt.grid(True)
    plt.savefig(f"{perturbation_mode}.pdf", format="pdf", bbox_inches="tight")
    plt.show()

    return dice_scores_per_level

# Pick the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using {device} device')

# clip.load also returns CLIP's preprocessing transform, used for the test images.
clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
test_transform = clip_preprocess

# Create your test dataset (adjust the root_dir path as needed).
# Masks are resized to 224x224 with nearest-neighbour interpolation.
test_dataset = SegmentationDatasetWithText(
    root_dir="./Dataset/processed/Test",
    transform_img=test_transform,
    transform_label=T.Resize((224, 224), interpolation=T.InterpolationMode.NEAREST),
)

# Specify the path to the saved model.
model_path = "clip_OPENAI_segmentation_best.pth"

# Run the robustness sweep once per perturbation mode, in this order.
for _mode in (
    'gaussian_noise',
    'gaussian_blur',
    'contrast_increase',
    'contrast_decrease',
    'brightness_increase',
    'brightness_decrease',
    'occlusion_increase',
    'salt_and_pepper_noise',
):
    dice_scores = test_model_on_perturbations(model_path, test_dataset, perturbation_mode=_mode)

Feature dimension: 768
Perturbation level 0 (gaussian_noise): Mean Dice Score = 0.9388


KeyboardInterrupt: 

In [None]:
from skimage.util import random_noise

def add_salt_and_pepper_noise(image, amount):
    """
    Corrupt an image with salt-and-pepper noise.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    amount: float forwarded to random_noise as `amount`, e.g. 0.00, 0.02, ..., 0.18.
    Returns:
        numpy array (uint8, values in 0-255) with salt-and-pepper noise added.
    """
    # skimage's random_noise is fed a float image in [0, 1].
    unit_range = image.astype(np.float32) / 255.0
    noisy_unit = random_noise(unit_range, mode='s&p', amount=amount)
    # Scale back to [0, 255] and store as uint8.
    rescaled = np.clip(noisy_unit * 255.0, 0, 255)
    return rescaled.astype(np.uint8)

def occlude_image(image, square_edge):
    """
    Zero out one randomly positioned square in a copy of the image.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    square_edge: integer edge length of the occluding square,
                 e.g. 0, 5, 10, 15, 20, 25, 30, 35, 40, 45.
    Returns:
        numpy array (uint8, values in 0-255); the input array is left unchanged.
    """
    result = image.copy()
    if not square_edge > 0:
        # Nothing to occlude.
        return result

    rows, columns, _ = result.shape
    max_x = columns - square_edge
    max_y = rows - square_edge

    if max_x < 0 or max_y < 0:
        # The square does not fit: occlude the whole image.
        result[:] = 0
    else:
        # Random top-left corner, x drawn before y.
        x0 = np.random.randint(0, max_x + 1)
        y0 = np.random.randint(0, max_y + 1)
        result[y0:y0 + square_edge, x0:x0 + square_edge, :] = 0
    return result

def decrease_brightness(image, offset):
    """
    Darken an image by a constant offset.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    offset: amount subtracted from each pixel, e.g. 0, 5, 10, ..., 45.
    Returns:
        numpy array (uint8, values in 0-255) with decreased brightness.
    """
    # int32 can hold the negative intermediate values that uint8 cannot.
    shifted = np.subtract(image.astype(np.int32), offset)
    # Saturate at the valid pixel range before converting back to uint8.
    return shifted.clip(0, 255).astype(np.uint8)

def increase_brightness(image, offset):
    """
    Brighten an image by a constant offset.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    offset: amount added to each pixel, e.g. 0, 5, 10, ..., 45.
    Returns:
        numpy array (uint8, values in 0-255) with increased brightness.
    """
    # int32 can hold intermediate values above 255 that uint8 cannot.
    shifted = np.add(image.astype(np.int32), offset)
    # Saturate at the valid pixel range before converting back to uint8.
    return shifted.clip(0, 255).astype(np.uint8)

def decrease_contrast(image, factor):
    """
    Scale every pixel value by a multiplicative factor.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    factor: multiplier applied to the pixel values,
            e.g. 1.0, 0.95, 0.90, 0.85, 0.80, 0.60, 0.40, 0.30, 0.20, 0.10.
    Returns:
        numpy array (uint8, values in 0-255) holding the scaled, clipped pixels.
    """
    as_float = image.astype(np.float32)
    scaled = as_float * factor
    return scaled.clip(0, 255).astype(np.uint8)

def increase_contrast(image, factor):
    """
    Scale every pixel value by a multiplicative factor.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    factor: multiplier applied to the pixel values,
            e.g. 1.0, 1.01, 1.02, 1.03, 1.04, 1.05, 1.1, 1.15, 1.2, 1.25.
    Returns:
        numpy array (uint8, values in 0-255) holding the scaled, clipped pixels.
    """
    as_float = image.astype(np.float32)
    scaled = as_float * factor
    return scaled.clip(0, 255).astype(np.uint8)

def add_gaussian_noise(image, sigma):
    """
    Add zero-mean Gaussian noise to every pixel.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    sigma: standard deviation of the Gaussian noise.
    Returns:
        numpy array (uint8, values in 0-255) with the input's shape.
    """
    gaussian = np.random.normal(0, sigma, image.shape)
    summed = image.astype(np.float32) + gaussian
    return summed.clip(0, 255).astype(np.uint8)

def add_gaussian_blur(image, iterations):
    """
    Blur an image by repeated convolution with a 3x3 Gaussian kernel.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    iterations: integer number of convolution passes;
                0 returns an unblurred copy, larger values blur more.
    Returns:
        numpy array with the same depth as the input (ddepth=-1).
    """
    import cv2  # OpenCV is used for the convolution.

    # 3x3 Gaussian kernel; the weights sum to 16, hence the 1/16 scaling.
    weights = np.array([[1, 2, 1],
                        [2, 4, 2],
                        [1, 2, 1]], dtype=np.float32)
    kernel = (1/16) * weights

    result = image.copy()
    for _ in range(iterations):
        result = cv2.filter2D(result, ddepth=-1, kernel=kernel)
    return result


def dice_score_multiclass(pred, target, num_classes=3, epsilon=1e-6):
    """
    Return the Dice score averaged over `num_classes` classes.

    pred, target: tensors of class indices with broadcast-compatible shapes.
    epsilon: smoothing term added to numerator and denominator; a class absent
             from both pred and target therefore scores 1.
    """
    def _class_dice(cls):
        in_pred = (pred == cls).float()
        in_target = (target == cls).float()
        overlap = (in_pred * in_target).sum()
        total = in_pred.sum() + in_target.sum()
        return ((2 * overlap + epsilon) / (total + epsilon)).item()

    return np.mean([_class_dice(cls) for cls in range(num_classes)])

def test_best_model_on_perturbations(best_model_path, test_dataset, batch_size=1, num_classes=3, perturbation_mode='gaussian_noise'):
    """
    Evaluates the best segmentation model on perturbed test data.

    For each perturbation level and each test sample, the function:
      - Denormalizes the input image to [0,255]
      - Applies the chosen perturbation
      - Re-normalizes the image using CLIP's mean/std
      - Feeds the perturbed image (with its token_ids) into the model
      - Computes the multi-class Dice score against the sample's mask.
    The mean Dice score per level is printed, plotted against the perturbation
    level (saved as "<perturbation_mode>.pdf") and returned as a list.

    Arguments:
      best_model_path: path of the saved CLIPSegmentationModel state dict.
      test_dataset: dataset yielding (image, mask, token_ids) samples.
      batch_size: DataLoader batch size; every sample in a batch is evaluated.
      num_classes: number of segmentation classes.
      perturbation_mode: one of 'gaussian_noise', 'gaussian_blur',
          'contrast_increase', 'contrast_decrease', 'brightness_increase',
          'brightness_decrease', 'occlusion_increase', 'salt_and_pepper_noise'.

    Raises:
      ValueError: if perturbation_mode is not one of the modes above.
    """
    # Perturbation mode -> (levels, perturbation function).
    perturbations = {
        'gaussian_noise': ([0, 2, 4, 6, 8, 10, 12, 14, 16, 18], add_gaussian_noise),
        'gaussian_blur': ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], add_gaussian_blur),
        'contrast_increase': ([1.0, 1.01, 1.02, 1.03, 1.04, 1.05, 1.1, 1.15, 1.2, 1.25], increase_contrast),
        'contrast_decrease': ([1.0, 0.95, 0.90, 0.85, 0.80, 0.60, 0.40, 0.30, 0.20, 0.10], decrease_contrast),
        'brightness_increase': ([0, 5, 10, 15, 20, 25, 30, 35, 40, 45], increase_brightness),
        'brightness_decrease': ([0, 5, 10, 15, 20, 25, 30, 35, 40, 45], decrease_brightness),
        'occlusion_increase': ([0, 5, 10, 15, 20, 25, 30, 35, 40, 45], occlude_image),
        'salt_and_pepper_noise': ([0.00, 0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18], add_salt_and_pepper_noise),
    }
    # Validate the mode before the (slow) model loading below.
    if perturbation_mode not in perturbations:
        raise ValueError(f"Perturbation mode '{perturbation_mode}' not implemented.")
    levels, perturb_fn = perturbations[perturbation_mode]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load CLIP model and set to evaluation.
    clip_model, _ = clip.load("ViT-L/14", device=device)
    clip_model.eval()
    for param in clip_model.parameters():
        param.requires_grad = False

    # Load your CLIP segmentation model.
    model = CLIPSegmentationModel(clip_model=clip_model, num_classes=num_classes).to(device)
    state_dict = torch.load(best_model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

    # Define CLIP normalization parameters (kept on CPU; the normalized image
    # is moved to the device afterwards).
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1)

    dice_scores_per_level = []

    # Loop over each perturbation level.
    for level in levels:
        level_dice_scores = []
        for images, masks, token_ids in test_loader:
            # Denormalize the images to [0,1] then scale to [0,255]. Round before
            # the cast so float error does not shift pixel values down by one.
            images_denorm = denormalize(images)
            images_pixels = (images_denorm * 255.0).clamp(0, 255).round().cpu().numpy().astype(np.uint8)

            token_ids = token_ids.to(device)
            masks = masks.to(device)

            # Evaluate every sample of the batch, each paired with its own
            # tokens and mask.
            for idx, chw_pixels in enumerate(images_pixels):
                img_np = np.transpose(chw_pixels, (1, 2, 0))

                # Apply the perturbation.
                perturbed_np = perturb_fn(img_np, level)

                # Convert perturbed image back to a normalized tensor.
                perturbed_tensor = torch.from_numpy(perturbed_np).permute(2, 0, 1).unsqueeze(0).float()
                perturbed_tensor = perturbed_tensor / 255.0
                perturbed_tensor = (perturbed_tensor - mean) / std
                perturbed_tensor = perturbed_tensor.to(device)

                with torch.no_grad():
                    output = model(perturbed_tensor, token_ids[idx:idx + 1])
                pred_mask = torch.argmax(output, dim=1).squeeze(0)

                # Compute multi-class Dice score.
                dice = dice_score_multiclass(pred_mask, masks[idx], num_classes=num_classes)
                level_dice_scores.append(dice)
        mean_dice = np.mean(level_dice_scores)
        dice_scores_per_level.append(mean_dice)
        print(f"Perturbation Level {level} ({perturbation_mode}): Mean Dice Score = {mean_dice:.4f}")

    # Plot mean Dice score vs. perturbation level.
    plt.figure(figsize=(8,6))
    plt.plot(levels, dice_scores_per_level, marker='o')
    plt.xlabel(f"{perturbation_mode} level")
    plt.ylabel("Mean Dice Score")
    plt.title(f"Robustness Evaluation: Dice Score vs {perturbation_mode}")
    plt.grid(True)
    plt.savefig(f"{perturbation_mode}.pdf", format="pdf", bbox_inches="tight")
    plt.show()

    return dice_scores_per_level

###############################################
#         Example Usage in __main__          #
###############################################

if __name__ == "__main__":
    # clip.load returns the model together with CLIP's preprocessing transform,
    # which is used for the test images.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
    test_transform = clip_preprocess

    # Labels are only resized to 224x224 (nearest-neighbour); no extra encoding.
    transform_label = T.Resize((224, 224), interpolation=T.InterpolationMode.NEAREST)

    # Create your test dataset (adjust root_dir as needed).
    test_dataset = SegmentationDatasetWithText(
        root_dir="./processed/Test/",
        transform_img=test_transform,
        transform_label=transform_label,
    )

    best_model_path = "./clip_OPENAI_segmentation_best.pth"

    # Evaluate each perturbation mode in turn, in this order.
    for _mode in (
        'gaussian_noise',
        'gaussian_blur',
        'contrast_increase',
        'contrast_decrease',
        'brightness_increase',
        'brightness_decrease',
        'occlusion_increase',
        'salt_and_pepper_noise',
    ):
        dice_scores = test_best_model_on_perturbations(best_model_path, test_dataset, perturbation_mode=_mode)

In [6]:
import fitz  # PyMuPDF
from PIL import Image

def merge_pdfs_to_grid(pdf_list, output_path, grid=(2, 4), dpi=150):
    """
    Render the first page of each PDF as an image, arrange them in a grid,
    and save the result as a single-page PDF.

    Arguments:
      pdf_list: list of PDF file paths (should have 8 PDFs for a 2x4 grid)
      output_path: output PDF file path.
      grid: tuple (rows, columns) for the layout; default is (2, 4).
      dpi: rendering resolution.

    Raises:
      ValueError: if the number of PDFs does not equal rows * columns, or if
                  there is nothing to merge.
    """
    rows, cols = grid
    pdf_list = list(pdf_list)

    # Check the count up front so no PDF is rendered for a layout that cannot work.
    if len(pdf_list) != rows * cols:
        raise ValueError(f"Expected {rows*cols} images, but got {len(pdf_list)}.")
    if not pdf_list:
        raise ValueError("No PDFs to merge.")

    # Compute zoom factor (default resolution is 72 dpi); same for every page.
    zoom = dpi / 72.0
    mat = fitz.Matrix(zoom, zoom)

    images = []
    for pdf in pdf_list:
        # The context manager closes the document even if rendering fails.
        with fitz.open(pdf) as doc:
            pix = doc.load_page(0).get_pixmap(matrix=mat)
            # Convert pixmap to a PIL Image (RGB)
            images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))

    # Assume all images have similar sizes; every grid cell uses the maximum
    # width and height.
    max_width = max(im.size[0] for im in images)
    max_height = max(im.size[1] for im in images)

    # Create a new blank image with white background.
    merged_img = Image.new('RGB', (cols * max_width, rows * max_height), color=(255, 255, 255))

    # Paste each image at the top-left corner of its grid cell (row-major order).
    for idx, im in enumerate(images):
        row, col = divmod(idx, cols)
        merged_img.paste(im, (col * max_width, row * max_height))

    # Save the merged image as a single-page PDF.
    merged_img.save(output_path, "PDF", resolution=dpi)
    print(f"Merged PDF saved as {output_path}")

if __name__ == "__main__":
    # The eight per-perturbation figures, in grid order (row by row).
    pdf_files = [
        'brightness_decrease.pdf',
        'brightness_increase.pdf',
        'contrast_decrease.pdf',
        'contrast_increase.pdf',
        'gaussian_blur.pdf',
        'gaussian_noise.pdf',
        'occlusion_increase.pdf',
        'salt_and_pepper_noise.pdf',
    ]
    output_pdf = 'merged_grid.pdf'
    # Lay the figures out as 2 rows x 4 columns, rendered at 150 dpi.
    merge_pdfs_to_grid(pdf_files, output_pdf, grid=(2, 4), dpi=150)

Merged PDF saved as merged_grid.pdf
