# CLIP-DINOiser visualization demo 🖼️

In [None]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from PIL import Image
from torchvision import transforms as T

from helpers.visualization import mask2rgb
from models.builder import build_model
from segmentation.datasets import PascalVOCDataset
# Silence every Python warning for the rest of the session to keep the demo output readable.
warnings.filterwarnings('ignore')
# Drop any Hydra state left over from a previous run of this cell, so that the
# `initialize` call below can be executed again when the cell is re-run.
GlobalHydra.instance().clear()
# Point Hydra at the "configs" directory; `compose(...)` calls further down read from it.
initialize(config_path="configs", version_base=None)

def visualize_per_image(file_path, support_files, palette, model, class_names):
    """Segment one image with the help of support images and plot the result.

    The query image and the support images (each resized to the query image's
    size) are stacked into a single batch with the query at index 0. Only the
    prediction at index 0 is turned into a mask and visualized.

    Two figures are created: a two-panel figure (image/mask blend on the left,
    mask on the right) and, when at least one detected class has both a name
    and a palette colour, a small legend strip.

    Args:
        file_path: Path of the image to segment.
        support_files: Paths of the support images. An empty list sends the
            query image through the model on its own.
        palette: Sequence of RGB triples indexed by class id.
        model: Segmentation model; it must expose ``vit_patch_size``, which is
            used as the upsampling factor for its output.
        class_names: Class names indexed by class id (legend labels).

    Returns:
        tuple: ``(mask, fig, img)`` - the RGB mask produced by ``mask2rgb``,
        the two-panel matplotlib figure and the query image as a PIL image.

    Raises:
        AssertionError: If ``file_path`` or any support file does not exist.

    Note:
        Reads the module-level ``device`` to decide where the tensors live.
    """
    # Validate every path up front so nothing is loaded when an input is missing.
    assert os.path.isfile(file_path), f"No such file: {file_path}"
    for support_file in support_files:
        assert os.path.isfile(support_file), f"No such file: {support_file}"

    to_tensor = T.PILToTensor()

    # Query image, scaled to [0, 1].
    img = Image.open(file_path).convert('RGB')
    img_tens = to_tensor(img).unsqueeze(0).to(device) / 255.

    # Support images are resized to the query size so the batch can be stacked.
    batch = [img_tens]
    for support_file in support_files:
        support_img = Image.open(support_file).convert('RGB')
        support_img = support_img.resize(img.size, Image.BILINEAR)
        batch.append(to_tensor(support_img).unsqueeze(0).to(device) / 255.)
    merged = torch.cat(batch, dim=0)

    # Original height and width, used to crop the upsampled prediction.
    h, w = img_tens.shape[-2:]

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        output = model(merged).cpu()
    output = F.interpolate(
        output,
        scale_factor=model.vit_patch_size,
        mode="bilinear",
        align_corners=False,
    )[..., :h, :w]

    # output[0] holds the per-class scores of the query image ([C, H, W]);
    # keep the most likely class for each pixel.
    output = output[0].argmax(dim=0)

    # Colourize the class map with the provided palette.
    mask = mask2rgb(output, palette)

    # Classes present in the prediction. Keep only those that have BOTH a name
    # and a colour, so legend labels and legend colours always line up.
    detected_classes = np.unique(output.numpy()).tolist()
    legend = [
        (class_names[idx], palette[idx])
        for idx in detected_classes
        if idx < len(class_names) and idx < len(palette)
    ]

    # Left: image blended with the mask. Right: the mask alone.
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    alpha = 0.5
    blend = alpha * np.array(img) / 255. + (1 - alpha) * mask / 255.
    ax[0].imshow(blend)
    ax[0].axis('off')
    ax[1].imshow(mask)
    ax[1].axis('off')

    # Legend strip: one colour cell per detected class, labelled with its name.
    if legend:
        legend_names, legend_colors = zip(*legend)
        plt.figure(figsize=(6, 1))
        plt.imshow(np.array(legend_colors).reshape(1, -1, 3))
        plt.xticks(np.arange(len(legend_names)), legend_names, rotation=45)
        plt.yticks([])

    return mask, fig, img

# Load the checkpoint onto the CPU; the weights are copied into the model below.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
check_path = './checkpoints/last.pt'
check = torch.load(check_path, map_location='cpu')
# Compose the CLIP-DINOiser config from the Hydra config directory initialized above.
dinoclip_cfg = "clip_dinoiser.yaml"
cfg = compose(config_name=dinoclip_cfg)
# `device` is also read as a global by visualize_per_image.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_model(cfg.model, class_names=PascalVOCDataset.CLASSES).to(device)
model.clip_backbone.decode_head.use_templates = False  # switching off the imagenet templates for fast inference
# strict=False: keys that are missing from or unexpected in the checkpoint are ignored.
model.load_state_dict(check['model_state_dict'], strict=False)
model = model.eval()
# Test with one query image and a list of support images
file = 'assets/airplane.jpg'
support_files = ['assets/air2.jpg']  # a single support image is listed here

# One RGB colour per entry of `class_names` below.
PALETTE = [(0, 0, 0), (156, 143, 189), (79, 158, 101)]

# Run segmentation with the support image(s), no text prompts required
model.apply_found = True  # assuming this flag is still relevant for your setup

# List of class names (this example uses 3 classes)
class_names = ['background', 'aeroplane', 'bicycle']
# Run segmentation
# The second returned value is the matplotlib figure, even though it is bound to `ticks`.
mask, ticks, img = visualize_per_image(file, support_files, PALETTE, model, class_names)

# Evaluation

In [19]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from datasets_custom import PascalVOCDataset # Import the PascalVOCDataset class

# Compose the same Hydra config as the demo cell above.
# `compose` is not imported in this cell; it comes from the imports of the first cell.
dinoclip_cfg = "clip_dinoiser.yaml"
cfg = compose(config_name=dinoclip_cfg)

def calculate_confusion_matrix(pred, target, num_classes):
    """
    Calculate the confusion matrix for a single batch.

    Rows index the ground-truth class and columns the predicted class. Pixels
    whose target lies outside ``[0, num_classes)`` (for example an ignore
    label) are left out.

    Args:
        pred (Tensor): Predicted segmentation map (class indices).
        target (Tensor): Ground truth segmentation map with the same number of
            elements as ``pred``.
        num_classes (int): Number of classes.

    Returns:
        np.array: Confusion matrix for the batch, shape
        ``(num_classes, num_classes)``.

    Raises:
        ValueError: If ``pred`` and ``target`` differ in size, or if a
            prediction on a valid pixel lies outside ``[0, num_classes)``.
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    pred = pred.reshape(-1).long()
    target = target.reshape(-1).long()  # Convert target to integers
    if pred.numel() != target.numel():
        raise ValueError(
            f"pred has {pred.numel()} elements but target has {target.numel()}"
        )

    valid = (target >= 0) & (target < num_classes)
    pred_np = pred[valid].cpu().numpy()
    target_np = target[valid].cpu().numpy()

    # An out-of-range prediction would otherwise be counted in the wrong cell
    # (or break the reshape below), silently corrupting the matrix.
    if pred_np.size and (pred_np.min() < 0 or pred_np.max() >= num_classes):
        raise ValueError(
            f"predictions must lie in [0, {num_classes}); "
            f"got range [{pred_np.min()}, {pred_np.max()}]"
        )

    # Encode every (target, pred) pair as one flat index and count occurrences.
    hist = np.bincount(
        num_classes * target_np + pred_np,
        minlength=num_classes ** 2,
    )
    return hist.reshape(num_classes, num_classes)

def _split_sample(sample):
    """Return ``(image, target)`` for a stored sample; ``target`` is None for a bare image."""
    if isinstance(sample, (tuple, list)):
        image, target = sample
        return image, target
    return sample, None


def evaluate_model_with_support(model, images_per_class, num_classes):
    """
    Evaluate the model on groups of three images and calculate the mean IoU.

    Within the 'aeroplane' entry, samples are consumed in consecutive groups
    of three: the first is the query image and the next two are its support
    images. The three images are resized to 512x512 and passed to the model as
    one batch with the query at index 0. Only the prediction at index 0 is
    scored, against the query's ground-truth mask.

    Args:
        model (torch.nn.Module): The model to evaluate.
        images_per_class (dict): Maps a class name to a list of samples. A
            sample is an ``(image, target)`` pair; a bare image tensor is
            accepted for support images only, since they need no ground truth.
        num_classes (int): Number of classes.

    Returns:
        float: Mean IoU over the classes that occur in the ground truth or in
        the predictions; NaN when nothing was evaluated.

    Raises:
        ValueError: If a query sample carries no ground-truth target.

    Note:
        Reads the module-level ``device`` and calls
        ``calculate_confusion_matrix`` defined earlier in this file.
    """
    eval_size = (512, 512)
    model.eval()
    confusion_matrix = np.zeros((num_classes, num_classes))
    with torch.no_grad():
        for class_name, samples in images_per_class.items():
            if class_name != 'aeroplane' or len(samples) < 3:
                continue  # Skip if not aeroplane or fewer than 3 images

            for i in range(0, len(samples) - 2, 3):
                main_image, target = _split_sample(samples[i])
                if target is None:
                    raise ValueError(
                        "the query sample has no ground-truth target; "
                        "store (image, target) pairs in images_per_class"
                    )
                support_images = [_split_sample(s)[0] for s in samples[i + 1:i + 3]]

                # Resize images to the same size
                main_image = F.interpolate(
                    main_image, size=eval_size, mode='bilinear', align_corners=False
                ).to(device)
                support_images = torch.cat(
                    [
                        F.interpolate(img, size=eval_size, mode='bilinear', align_corners=False)
                        for img in support_images
                    ],
                    dim=0,
                ).to(device)

                # Resize the labels the same way, but with nearest-neighbour so
                # no new label values are invented.
                # Assumes the target is a single-channel class-index map for one
                # image (the loader uses batch_size=1) - TODO confirm against datasets_custom.
                target = target.reshape(1, 1, *target.shape[-2:]).float()
                target = F.interpolate(target, size=eval_size, mode='nearest').long().squeeze(1)

                # Concatenate main image and support images (query first)
                merged = torch.cat((main_image, support_images), dim=0)

                # Forward pass; keep only the prediction for the query image.
                outputs = model(merged)[:1]
                outputs = F.interpolate(outputs, size=eval_size, mode='bilinear', align_corners=False)
                preds = outputs.argmax(dim=1).cpu()
                confusion_matrix += calculate_confusion_matrix(preds, target, num_classes)
                print(f'Tested on class {class_name}, images {i+1} to {i+3}')

    # Calculate mIoU from the confusion matrix
    intersection = np.diag(confusion_matrix)
    union = np.sum(confusion_matrix, axis=0) + np.sum(confusion_matrix, axis=1) - intersection
    # Classes that never occur have an undefined IoU; leave them out of the
    # mean instead of counting them as 0.
    present = union > 0
    if not present.any():
        return np.float64(np.nan)
    return (intersection[present] / union[present]).mean()


def denormalize(img, mean, std):
    """
    Denormalize an image tensor.

    Args:
        img (Tensor): Normalized image tensor with the channel axis third from
            last (``[C, H, W]`` or ``[N, C, H, W]``).
        mean (list): Mean values for each channel.
        std (list): Standard deviation values for each channel.

    Returns:
        Tensor: Denormalized image tensor (``img * std + mean``).
    """
    # Build the statistics on the image's device so GPU tensors work as well.
    mean = torch.tensor(mean, device=img.device).view(-1, 1, 1)
    std = torch.tensor(std, device=img.device).view(-1, 1, 1)
    return img * std + mean


def show_aeroplane_images(dataset, class_name='aeroplane', *, with_targets=False):
    """
    Collect all samples of the dataset whose target contains the given class.

    Despite the name, nothing is displayed; the samples are only gathered.

    Args:
        dataset (PascalVOCDataset): The dataset to search.
        class_name (str): The class name to filter images by.
        with_targets (bool): When True, every collected entry is an
            ``(image, target)`` pair as needed for evaluation; when False
            (default) only the image batch is stored.

    Returns:
        tuple: ``(count, images_per_class)`` - the number of matching samples
        and a dict mapping ``class_name`` to the list of collected entries.
    """
    # Create a DataLoader for the dataset
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    # Get the index of the specified class
    class_idx = PascalVOCDataset.CLASSES.index(class_name)

    # Keep the samples whose target mask contains the class index.
    selected = []
    for images, targets in dataloader:
        if class_idx in targets.unique().tolist():
            selected.append((images, targets) if with_targets else images)

    return len(selected), {class_name: selected}


# Define the pipeline
# NOTE(review): RandomFlip makes the evaluation inputs non-deterministic. It is kept
# because later pipeline steps may expect the keys it adds - confirm before removing.
pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(512, 512), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]

# Path to the Pascal VOC dataset
img_dir = '/Users/micheleverriello/LabelAnything/data/pascal/JPEGImages'
ann_dir = '/Users/micheleverriello/LabelAnything/data/pascal/SegmentationClass'
split_file = '/Users/micheleverriello/LabelAnything/data/pascal/ImageSets/Segmentation/val.txt'

# Load the Pascal VOC validation split (limit=2000 presumably caps the number of
# samples - confirm in datasets_custom).
dataset = PascalVOCDataset(split=split_file, img_dir=img_dir, ann_dir=ann_dir, pipeline=pipeline, limit=2000)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

# Keep the ground-truth masks next to the images: the evaluation needs them.
count, images_per_class = show_aeroplane_images(dataset, class_name='aeroplane', with_targets=True)
print(f'Number of images in the "aeroplane" class: {count}')

# Calculate the mIoU
# NOTE(review): PascalVOCDataset here is the datasets_custom class imported in this
# cell, while the model was built with segmentation.datasets.PascalVOCDataset.CLASSES -
# confirm both define the same class list.
num_classes = len(PascalVOCDataset.CLASSES)
miou = evaluate_model_with_support(model, images_per_class, num_classes)
print(f'mIoU: {miou}')

Number of images in the "aeroplane" class: 90
Tested on class aeroplane, images 1 to 3
Tested on class aeroplane, images 4 to 6
Tested on class aeroplane, images 7 to 9
Tested on class aeroplane, images 10 to 12
Tested on class aeroplane, images 13 to 15
Tested on class aeroplane, images 16 to 18
Tested on class aeroplane, images 19 to 21
Tested on class aeroplane, images 22 to 24
Tested on class aeroplane, images 25 to 27
Tested on class aeroplane, images 28 to 30
Tested on class aeroplane, images 31 to 33
Tested on class aeroplane, images 34 to 36
Tested on class aeroplane, images 37 to 39
Tested on class aeroplane, images 40 to 42
Tested on class aeroplane, images 43 to 45
Tested on class aeroplane, images 46 to 48
Tested on class aeroplane, images 49 to 51
Tested on class aeroplane, images 52 to 54
Tested on class aeroplane, images 55 to 57
Tested on class aeroplane, images 58 to 60
Tested on class aeroplane, images 61 to 63
Tested on class aeroplane, images 64 to 66
Tested on clas