## Test CLIP

In [None]:
###############################################
#               Testing the model             #
###############################################

import os
import clip
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import re
from typing import List
from clip.simple_tokenizer import SimpleTokenizer
import torchvision.transforms as T
from matplotlib.backends.backend_pdf import PdfPages
import random
import matplotlib.pyplot as plt

# Module-level CLIP tokenizer instance; its decode() turns token ids back into
# text, which is handy when debugging the prompts produced by the dataset.
clip_tokenizer = SimpleTokenizer()

#############################################
#                Dataset Class              #
#############################################

class SegmentationDatasetWithText(Dataset):
    """Segmentation dataset yielding ``(image, mask, token_ids)`` triples.

    Images are read from ``<root_dir>/color`` and masks from
    ``<root_dir>/label``. A mask shares the base name of its image and uses a
    ``.png`` extension. A CLIP text prompt is derived from the image file name.

    Args:
        root_dir: Directory containing the ``color`` and ``label`` folders.
        transform_img: Callable applied to the RGB PIL image (skipped if falsy).
        transform_label: Optional callable applied to the mask tensor. The mask
            is given a leading channel dimension before the call and squeezed
            back afterwards.
    """

    # Extensions (compared case-insensitively) that mark a file as an image.
    _IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')
    # Expected base name: "<cat|dog>_<breed words>_<number>...".
    _NAME_PATTERN = re.compile(r"^(cat|dog)_([A-Za-z_]+)_(\d+)")

    def __init__(self, root_dir, transform_img, transform_label=None):
        self.root_dir = root_dir
        self.transform_img = transform_img
        self.transform_label = transform_label

        # Images live in `root_dir/color`, masks in `root_dir/label`.
        self.image_dir = os.path.join(root_dir, 'color')
        self.mask_dir = os.path.join(root_dir, 'label')

        # Sorted so that the index -> file mapping is deterministic.
        self.image_paths = sorted(
            fname for fname in os.listdir(self.image_dir)
            if fname.lower().endswith(self._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _mask_filename(image_filename):
        """Return the mask file name that belongs to ``image_filename``.

        Any non-PNG extension (``.jpg``, ``.jpeg``, ``.JPG``, ...) is swapped
        for ``.png``; PNG images keep their name untouched. A plain
        ``str.replace('.jpg', '.png')`` would miss ``.jpeg`` and upper-case
        extensions, which the directory listing accepts.
        """
        stem, ext = os.path.splitext(image_filename)
        if ext.lower() == '.png':
            return image_filename
        return stem + '.png'

    @staticmethod
    def _build_prompt(image_filename):
        """Build the text prompt for an image from its file name.

        Falls back to a generic prompt (and prints a warning) when the name
        does not follow the ``<cat|dog>_<breed>_<number>`` convention.
        """
        name, _ = os.path.splitext(image_filename)
        match = SegmentationDatasetWithText._NAME_PATTERN.match(name)
        if not match:
            print(f"Warning: Filename {image_filename} does not match expected format.")
            return "a photo of an animal"
        animal_type = match.group(1)  # 'cat' or 'dog'
        # Underscores separate the words of the breed; title-case them.
        breed = match.group(2).replace("_", " ").title()
        return f"a photo of a {breed} {animal_type}"

    def __getitem__(self, index):
        image_filename = self.image_paths[index]
        image_path = os.path.join(self.image_dir, image_filename)
        mask_path = os.path.join(self.mask_dir, self._mask_filename(image_filename))

        # convert() returns an independent, fully loaded copy, so the file
        # handle can be released as soon as the `with` block ends.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.transform_img:
            image = self.transform_img(image)

        # Load the mask as single-channel and turn it into an int64 tensor.
        with Image.open(mask_path) as raw_mask:
            mask = torch.from_numpy(np.array(raw_mask.convert("L"), dtype=np.int64))

        # The label transform is called on a [1, H, W] tensor.
        if self.transform_label:
            if mask.ndim == 2:  # mask shape [H, W]
                mask = mask.unsqueeze(0)  # now shape [1, H, W]
            mask = self.transform_label(mask)
            mask = mask.squeeze(0)  # back to shape [H, W]

        text_prompt = self._build_prompt(image_filename)
        token_ids = clip.tokenize([text_prompt]).squeeze(0)

        return image, mask, token_ids

class EncodeMask:
    """Callable transform mapping raw mask intensities to class ids 0, 1, 2.

    Values below 36 become 0, values in [36, 192) become 1 and values of 192
    or more become 2. The dtype of the input tensor is preserved.
    """

    def __call__(self, mask):
        # The three ranges are disjoint, so they can all be derived from the
        # untouched input before any value is rewritten.
        is_low = mask < 36
        is_mid = (mask >= 36) & (mask < 192)
        is_high = mask >= 192

        encoded = (
            mask.masked_fill(is_low, 0)
            .masked_fill(is_mid, 1)
            .masked_fill(is_high, 2)
        )
        # Final safety net: keep every value inside the valid class range.
        return encoded.clamp(min=0, max=2)

def denormalize(image_tensor):
    """Undo CLIP's per-channel normalization and clamp the result to [0, 1].

    Args:
        image_tensor: Normalized image tensor whose channel axis broadcasts
            against a (3, 1, 1) tensor, e.g. [3, H, W] or [B, 3, H, W].

    Returns:
        A new tensor with values in [0, 1], suitable for visualization.
    """
    # Build the statistics on the input's device so that CUDA tensors can be
    # denormalized without a device-mismatch error.
    device = image_tensor.device
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(3, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(3, 1, 1)
    # x_norm = (x - mean) / std  =>  x = x_norm * std + mean
    return torch.clamp(image_tensor * std + mean, 0, 1)

#############################################
#             Segmentation Head             #
#############################################

class ChannelAttention(nn.Module):
    """Channel-wise gating for 4-D feature maps.

    Each channel is globally average-pooled, passed through a two-layer
    bottleneck MLP ending in a sigmoid, and the resulting per-channel weight
    rescales the input feature map.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Divisor used for the width of the bottleneck layer.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        bottleneck = in_channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, in_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _height, _width = x.shape
        # [B, C, 1, 1] -> [B, C]: one descriptor per channel.
        pooled = self.avg_pool(x).view(batch, channels)
        gates = self.fc(pooled)
        # Broadcast the per-channel gate over the spatial dimensions.
        return x * gates.view(batch, channels, 1, 1)

class ImprovedCLIPSegmentationHead(nn.Module):
    """Segmentation head that fuses an image feature map with a text embedding.

    The text embedding is linearly projected to the image channel count,
    tiled over the spatial grid and concatenated with the image features. The
    fused map then goes through a conv block, a residual block, optional
    channel attention and a final 1x1 classifier.

    Args:
        in_channels: Channel count C of the image feature map.
        text_dim: Dimension of the incoming text embedding.
        num_classes: Number of output channels (classes).
        dropout_prob: Drop probability of the Dropout2d after the first conv.
        use_attention: When False, the attention stage is an identity.
    """

    def __init__(self, in_channels, text_dim, num_classes, dropout_prob=0.25, use_attention=True):
        super().__init__()
        print(f'text_dim: {text_dim}')
        print(f'in_channels: {in_channels}')

        hidden = 256  # width of every intermediate feature map

        self.text_proj = nn.Linear(text_dim, in_channels)
        # Input is image features and tiled text features stacked on dim 1.
        self.fuse_conv1 = nn.Sequential(
            nn.Conv2d(in_channels * 2, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout_prob),
        )
        self.residual_block = nn.Sequential(
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
        )
        self.relu = nn.ReLU(inplace=True)
        self.attention = ChannelAttention(hidden) if use_attention else nn.Identity()
        self.fuse_conv2 = nn.Conv2d(hidden, num_classes, kernel_size=1)

    def forward(self, img_features, text_features):
        """Return class logits of shape [B, num_classes, H, W].

        Args:
            img_features: Image feature map of shape [B, C, H, W].
            text_features: Text embedding that ``text_proj`` maps to B*C values.
        """
        batch, channels, height, width = img_features.shape

        # Work in float32 throughout so that mixed-precision inputs do not clash.
        img_features = img_features.float()
        text_features = text_features.float()

        # [B, C] -> [B, C, 1, 1] -> tiled (as a view) to [B, C, H, W].
        text_map = (
            self.text_proj(text_features)
            .view(batch, channels, 1, 1)
            .expand(batch, channels, height, width)
        )

        fused = self.fuse_conv1(torch.cat((img_features, text_map), dim=1))
        fused = self.relu(fused + self.residual_block(fused))  # skip connection
        return self.fuse_conv2(self.attention(fused))

#############################################
#          CLIP Segmentation Model          #
#############################################

class CLIPSegmentationModel(nn.Module):
    """Text-conditioned segmentation model built on a CLIP backbone.

    Per-patch visual tokens are taken from the CLIP vision transformer,
    reshaped into a 2-D feature map and fused with the CLIP text embedding by
    the segmentation head. The logits are upsampled to the input resolution.

    Args:
        clip_model: CLIP model exposing ``visual`` (with ``output_dim``,
            ``conv1``, ``class_embedding``, ``positional_embedding``,
            ``ln_pre``, ``transformer`` and optionally ``proj``) and
            ``encode_text``.
        num_classes: Number of segmentation classes.
    """

    def __init__(self, clip_model, num_classes=3):
        super().__init__()
        # 768 for CLIP ViT-L/14; read from the model so other variants work.
        self.feature_dim = clip_model.visual.output_dim
        print(f'feature dim: {self.feature_dim}')

        self.clip_model = clip_model

        # The head dimensions follow the backbone instead of a hard-coded 768.
        # NOTE(review): assumes encode_text() returns embeddings of the same
        # size as the visual output_dim (the case for OpenAI CLIP's joint
        # embedding space) - confirm for other backbones.
        self.seg_head = ImprovedCLIPSegmentationHead(
            in_channels=self.feature_dim,
            text_dim=self.feature_dim,
            num_classes=num_classes,
            dropout_prob=0.25,
        )

    def get_visual_features(self, image):
        """Return all visual tokens (class token first) as [B, tokens+1, dim].

        The steps of the vision transformer are replayed by hand so that the
        per-patch tokens are kept. Note that no post-transformer layer norm is
        applied here; the tokens go straight from the transformer into the
        optional projection.
        """
        visual = self.clip_model.visual
        # Match the dtype of the backbone weights (e.g. fp16 on GPU).
        image = image.to(dtype=visual.conv1.weight.dtype)

        # Patch embedding: [B, width, H', W'] -> [B, tokens, width].
        x = visual.conv1(image)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)

        # Prepend the class token (broadcast over the batch via the zeros).
        cls_tokens = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_tokens, x], dim=1)  # [B, tokens+1, width]

        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)

        # The transformer is fed sequence-first tensors.
        x = x.permute(1, 0, 2)  # [tokens+1, B, width]
        x = visual.transformer(x)
        x = x.permute(1, 0, 2)  # [B, tokens+1, width]

        # Project when a projection matrix exists. A plain hasattr() check is
        # not enough because the attribute may be present but set to None.
        proj = getattr(visual, "proj", None)
        if proj is not None:
            x = x @ proj  # [B, tokens+1, output_dim]
        return x

    def forward(self, image, token_ids):
        """Return segmentation logits of shape [B, num_classes, H_in, W_in].

        Args:
            image: Preprocessed image batch [B, 3, H_in, W_in].
            token_ids: Tokenized text prompts passed to ``encode_text``.
        """
        image = image.float()

        visual_features = self.get_visual_features(image)  # [B, N+1, D]

        tokens = visual_features[:, 1:, :]  # drop the class token => [B, N, D]
        B, N, D = tokens.shape
        grid_size = int(np.sqrt(N))  # patches form a square grid

        # Use the actual token width D rather than a hard-coded 768.
        image_features = tokens.reshape(B, grid_size, grid_size, D).permute(0, 3, 1, 2)  # [B, D, H, W]

        text_features = self.clip_model.encode_text(token_ids).float()

        seg_logits = self.seg_head(image_features, text_features)

        # Upsample from the patch grid back to the input resolution.
        seg_logits = F.interpolate(seg_logits, size=(image.shape[2], image.shape[3]),
                                   mode="bilinear", align_corners=False)
        return seg_logits

###############################################
# Training, Evaluation, and Utility Functions #
###############################################
def compute_iou_per_class(preds, targets, num_classes=3):
    """
    Compute the Intersection over Union (IoU) of every class.

    preds and targets hold class indices and must be comparable element-wise.
    Returns a dict mapping each class index in range(num_classes) to its IoU;
    a class that occurs in neither input maps to NaN.
    """
    def _iou(cls):
        predicted = preds == cls
        actual = targets == cls
        overlap = (predicted & actual).sum().item()
        combined = (predicted | actual).sum().item()
        # An empty union means the class is absent; the IoU is undefined.
        return overlap / combined if combined != 0 else float('nan')

    return {cls: _iou(cls) for cls in range(num_classes)}

In [None]:
from skimage.util import random_noise

def add_salt_and_pepper_noise(image, amount):
    """
    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    amount: float value representing the proportion of pixels to be replaced with noise.
            For example: 0.00, 0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18.
    Returns:
        perturbed_image: numpy array (uint8, values in 0-255) with salt and pepper noise added.
    """
    # skimage works on floats in [0, 1].
    image_float = image.astype(np.float32) / 255.0
    noisy = random_noise(image_float, mode='s&p', amount=amount)
    # Round before casting: after the /255 and *255 round trip a value can land
    # a hair below its integer, and a plain uint8 cast would truncate it to the
    # next lower intensity (so untouched pixels would not be preserved).
    noisy = np.clip(np.rint(noisy * 255.0), 0, 255).astype(np.uint8)
    return noisy

def occlude_image(image, square_edge):
    """
    Black out one randomly placed square of the image.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    square_edge: integer edge length of the occluding square
                 (e.g. 0, 5, 10, 15, 20, 25, 30, 35, 40, 45).
    Returns:
        A copy of the image with the square set to 0. A non-positive edge
        leaves the copy untouched; a square that does not fit zeroes the
        whole image.
    """
    occluded = image.copy()
    if not square_edge > 0:
        return occluded

    height, width, _ = occluded.shape
    # Largest top-left coordinates that keep the square inside the image.
    last_left = width - square_edge
    last_top = height - square_edge
    if last_left < 0 or last_top < 0:
        occluded[:] = 0
        return occluded

    left = np.random.randint(0, last_left + 1)
    top = np.random.randint(0, last_top + 1)
    occluded[top:top + square_edge, left:left + square_edge, :] = 0
    return occluded

def decrease_brightness(image, offset):
    """
    Darken an image by subtracting a constant from every pixel.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    offset: integer amount to subtract (e.g. 0, 5, 10, 15, 20, 25, 30, 35, 40, 45).
    Returns:
        perturbed_image: numpy array (uint8, values in 0-255) with decreased brightness.
    """
    # Widen to a signed type first so the subtraction cannot wrap around.
    darkened = np.subtract(image.astype(np.int32), offset)
    return np.clip(darkened, 0, 255).astype(np.uint8)

def increase_brightness(image, offset):
    """
    Brighten an image by adding a constant to every pixel.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    offset: integer amount to add (e.g. 0, 5, 10, 15, 20, 25, 30, 35, 40, 45).
    Returns:
        perturbed_image: numpy array (uint8, values in 0-255) with increased brightness.
    """
    # Widen first so sums above 255 are representable, then saturate.
    brightened = np.add(image.astype(np.int32), offset)
    return np.clip(brightened, 0, 255).astype(np.uint8)

def decrease_contrast(image, factor):
    """
    Scale every pixel value by a multiplicative factor.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    factor: multiplier for the pixel values
            (e.g. 1.0, 0.95, 0.90, 0.85, 0.80, 0.60, 0.40, 0.30, 0.20, 0.10).
    Returns:
        perturbed_image: numpy array (uint8, values in 0-255) with decreased contrast.
    """
    scaled = np.multiply(image.astype(np.float32), factor)
    # Saturate to the valid range; the uint8 cast drops the fractional part.
    return np.clip(scaled, 0, 255).astype(np.uint8)

def increase_contrast(image, factor):
    """
    Scale every pixel value by a multiplicative factor.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    factor: multiplier for the pixel values
            (e.g. 1.0, 1.01, 1.02, 1.03, 1.04, 1.05, 1.1, 1.15, 1.2, 1.25).
    Returns:
        perturbed_image: numpy array (uint8, values in 0-255) with increased contrast.
    """
    scaled = np.multiply(image.astype(np.float32), factor)
    # Products above 255 saturate; the uint8 cast drops the fractional part.
    return np.clip(scaled, 0, 255).astype(np.uint8)

def add_gaussian_noise(image, sigma):
    """
    Add zero-mean Gaussian noise to every pixel.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    sigma: standard deviation of the Gaussian noise.
    Returns:
        perturbed_image: numpy array (uint8, values in 0-255)
    """
    noise = np.random.normal(loc=0, scale=sigma, size=image.shape)
    noisy = np.add(image.astype(np.float32), noise)
    # Saturate to the valid pixel range before going back to uint8.
    return np.clip(noisy, 0, 255).astype(np.uint8)

def add_gaussian_blur(image, iterations):
    """
    Blur an image by repeated convolution with a 3x3 Gaussian kernel.

    image: numpy array of shape (H, W, C) with dtype=np.uint8, values in 0-255.
    iterations: integer number of convolution passes
                (0 leaves the image unblurred, 9 gives a heavy blur).
    Returns:
        perturbed_image: numpy array (uint8, values in 0-255) that has been blurred.
    """
    import cv2  # OpenCV performs the convolution.

    # The 3x3 kernel (1/16) * [[1, 2, 1], [2, 4, 2], [1, 2, 1]] is the outer
    # product of [1/4, 1/2, 1/4] with itself.
    taps = np.array([1, 2, 1], dtype=np.float32) / 4
    kernel = np.outer(taps, taps)

    blurred = image.copy()
    passes_left = iterations
    while passes_left > 0:
        # ddepth=-1 keeps the output depth equal to the input depth.
        blurred = cv2.filter2D(blurred, ddepth=-1, kernel=kernel)
        passes_left -= 1
    return blurred


def dice_score_multiclass(pred, target, num_classes=3, epsilon=1e-6):
    """
    Return the Dice score averaged over the classes 0..num_classes-1.

    pred and target are tensors of class indices (e.g. of shape [H, W]).
    epsilon smooths the ratio, so a class that is absent from both inputs
    contributes a score of 1.
    """
    def _dice(cls):
        pred_hits = (pred == cls).float()
        target_hits = (target == cls).float()
        overlap = (pred_hits * target_hits).sum()
        # Sum of both cardinalities (the Dice denominator).
        total = pred_hits.sum() + target_hits.sum()
        return ((2 * overlap + epsilon) / (total + epsilon)).item()

    return np.mean([_dice(cls) for cls in range(num_classes)])

def test_best_model_on_perturbations(best_model_path, test_dataset, batch_size=1, num_classes=3, perturbation_mode='gaussian_noise'):
    """
    Evaluate the saved segmentation model on perturbed test images.

    For every level of the selected perturbation, each test image is
    denormalized to uint8 pixels in [0, 255], perturbed, re-normalized with
    CLIP's mean/std and segmented. The multi-class Dice score is averaged over
    all test images, printed per level, plotted against the level, and the
    plot is saved to "<perturbation_mode>.pdf".

    Supported perturbation_mode values (10 levels each):
      - 'gaussian_noise':        sigma in {0, 2, ..., 18}
      - 'gaussian_blur':         number of 3x3 blur passes in {0, 1, ..., 9}
      - 'contrast_increase':     factor in {1.0, 1.01, ..., 1.25}
      - 'contrast_decrease':     factor in {1.0, 0.95, ..., 0.10}
      - 'brightness_increase':   offset in {0, 5, ..., 45}
      - 'brightness_decrease':   offset in {0, 5, ..., 45}
      - 'occlusion_increase':    square edge in {0, 5, ..., 45}
      - 'salt_and_pepper_noise': amount in {0.00, 0.02, ..., 0.18}

    Args:
        best_model_path: Path of the state dict saved for CLIPSegmentationModel.
        test_dataset: Dataset yielding (normalized image, mask, token_ids).
        batch_size: Loader batch size; every image of a batch is perturbed and
            scored individually.
        num_classes: Number of segmentation classes.
        perturbation_mode: One of the modes listed above.

    Returns:
        List with the mean Dice score of each perturbation level, in order.

    Raises:
        ValueError: If perturbation_mode is not one of the supported modes.
    """
    import matplotlib.pyplot as plt
    from torch.utils.data import DataLoader

    # mode -> (levels, perturbation function taking (uint8 HWC image, level)).
    perturbations = {
        'gaussian_noise': ([0, 2, 4, 6, 8, 10, 12, 14, 16, 18], add_gaussian_noise),
        'gaussian_blur': ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], add_gaussian_blur),
        'contrast_increase': ([1.0, 1.01, 1.02, 1.03, 1.04, 1.05, 1.1, 1.15, 1.2, 1.25], increase_contrast),
        'contrast_decrease': ([1.0, 0.95, 0.90, 0.85, 0.80, 0.60, 0.40, 0.30, 0.20, 0.10], decrease_contrast),
        'brightness_increase': ([0, 5, 10, 15, 20, 25, 30, 35, 40, 45], increase_brightness),
        'brightness_decrease': ([0, 5, 10, 15, 20, 25, 30, 35, 40, 45], decrease_brightness),
        'occlusion_increase': ([0, 5, 10, 15, 20, 25, 30, 35, 40, 45], occlude_image),
        'salt_and_pepper_noise': ([0.00, 0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18], add_salt_and_pepper_noise),
    }
    # Reject an unknown mode before paying for the model load.
    if perturbation_mode not in perturbations:
        raise ValueError(f"Perturbation mode '{perturbation_mode}' not implemented.")
    levels, perturb_fn = perturbations[perturbation_mode]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the CLIP backbone in evaluation mode with frozen weights.
    clip_model, _ = clip.load("ViT-L/14", device=device)
    clip_model.eval()
    for param in clip_model.parameters():
        param.requires_grad = False

    # Load the trained segmentation model.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model = CLIPSegmentationModel(clip_model=clip_model, num_classes=num_classes).to(device)
    state_dict = torch.load(best_model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

    # CLIP normalization parameters, used to re-normalize perturbed images.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1)

    dice_scores_per_level = []

    for level in levels:
        level_dice_scores = []
        for images, masks, token_ids in test_loader:
            # Back to uint8 pixels. Round before the cast so that float error
            # from the normalize/denormalize round trip cannot shift a value
            # down by one intensity step.
            images_pixels = (denormalize(images) * 255.0).round().clamp(0, 255)
            images_pixels = images_pixels.cpu().numpy().astype(np.uint8)

            # Perturb every image of the batch, not only the first one.
            perturbed_images = []
            for chw_pixels in images_pixels:
                img_np = np.transpose(chw_pixels, (1, 2, 0))  # CHW -> HWC
                perturbed_np = perturb_fn(img_np, level)
                perturbed_images.append(
                    torch.from_numpy(perturbed_np).permute(2, 0, 1).float()
                )

            # Re-normalize exactly like the CLIP preprocessing does.
            perturbed_tensor = torch.stack(perturbed_images) / 255.0
            perturbed_tensor = ((perturbed_tensor - mean) / std).to(device)

            token_ids = token_ids.to(device)
            masks = masks.to(device)

            with torch.no_grad():
                output = model(perturbed_tensor, token_ids)
            pred_masks = torch.argmax(output, dim=1)  # [B, H, W]

            # One Dice score per image so every sample has the same weight.
            for pred_mask, mask in zip(pred_masks, masks):
                level_dice_scores.append(
                    dice_score_multiclass(pred_mask, mask, num_classes=num_classes)
                )

        mean_dice = np.mean(level_dice_scores)
        dice_scores_per_level.append(mean_dice)
        print(f"Perturbation Level {level} ({perturbation_mode}): Mean Dice Score = {mean_dice:.4f}")

    # Plot the mean Dice score vs. perturbation level.
    figure = plt.figure(figsize=(8,6))
    plt.plot(levels, dice_scores_per_level, marker='o')
    plt.xlabel(f"{perturbation_mode} level")
    plt.ylabel("Mean Dice Score")
    plt.title(f"Robustness Evaluation: Dice Score vs {perturbation_mode}")
    plt.grid(True)
    plt.savefig(f"{perturbation_mode}.pdf", format="pdf", bbox_inches="tight")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(figure)

    return dice_scores_per_level

# Example usage in __main__:
if __name__ == "__main__":
    # clip.load returns (model, preprocess); the preprocess transform is what
    # the test images need, i.e. the same preprocessing as during training.
    clip_model, clip_preprocess = clip.load(
        "ViT-L/14",
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    test_transform = clip_preprocess

    # Masks are resized with nearest-neighbour interpolation and then mapped
    # to class ids by EncodeMask.
    transform_label = T.Compose([
        T.Resize((224, 224), interpolation=T.InterpolationMode.NEAREST),
        EncodeMask(),
    ])

    test_dataset = SegmentationDatasetWithText(
        root_dir="./processed/Test/",
        transform_img=test_transform,
        transform_label=transform_label,
    )

    best_model_path = "./clip_OPENAI_segmentation_best.pth"

    # Run the robustness evaluation once per perturbation family, in order.
    perturbation_modes = (
        'gaussian_noise',
        'gaussian_blur',
        'contrast_increase',
        'contrast_decrease',
        'brightness_increase',
        'brightness_decrease',
        'occlusion_increase',
        'salt_and_pepper_noise',
    )
    for mode in perturbation_modes:
        dice_scores = test_best_model_on_perturbations(
            best_model_path, test_dataset, perturbation_mode=mode
        )