In [None]:
import os
import timm
import torch
import torch.nn as nn
# Define some useful paths.
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
       '/storage'
))
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
%load_ext yamlmagic

# Part 1: Load the train/val datasets and look at an example from each.

In [None]:
%%yaml train_dataset_cfg

# Training-split keyword arguments; a later cell unpacks these into ese.datasets.Segment2D.
task: "WBC/CV/EM/0"
# Alternative task, kept commented out for quick switching:
# task: "ACDC/Challenge2017/MRI/2"
split: "train"
# Presumably makes each item include 'data_id', which later cells index by — keep enabled.
return_data_id: True
label_threshold: 0.5
slicing: "midslice"
resolution: 128 
label: 0

In [None]:
%%yaml val_dataset_cfg

# Validation-split keyword arguments for ese.datasets.Segment2D; identical to the
# training config except for `split`.
task: "WBC/CV/EM/0"
# Alternative task, kept commented out for quick switching:
# task: "ACDC/Challenge2017/MRI/2"
split: "val"
# Presumably makes each item include 'data_id', which later cells index by — keep enabled.
return_data_id: True
label_threshold: 0.5
slicing: "midslice"
resolution: 128 
label: 0

In [None]:
from ese.datasets import Segment2D

# Build one Segment2D dataset per split from the YAML-defined configs (train first, then val).
train_dataset, val_dataset = (
    Segment2D(**split_cfg) for split_cfg in (train_dataset_cfg, val_dataset_cfg)
)

In [None]:
# Report how many examples each split holds.
for split_name, split_dataset in (("train", train_dataset), ("val", val_dataset)):
    print(f"Num {split_name} examples:", len(split_dataset))

In [None]:
# First image of each split with a leading batch axis added via [None]
# (presumably (1, C, H, W); preprocess_image below requires a 4-D input).
train_im, val_im = (split_ds[0]['img'][None] for split_ds in (train_dataset, val_dataset))

In [None]:
# Show the raw first training and validation images side by side.
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
for panel_idx, (panel_image, panel_title) in enumerate(
    ((train_im, "Training Image"), (val_im, "Validation Image")), start=1
):
    plt.subplot(1, 2, panel_idx)
    plt.imshow(panel_image.squeeze(), cmap='gray')
    plt.title(panel_title)
plt.show()

In [None]:
import torchvision.transforms as T

def preprocess_image(image_tensors, target_size, apply_clip_norm, num_channels=3):
    """
    Preprocess a batch of image tensors for an embedding backbone (CLIP or DINOv2).

    Args:
        image_tensors (torch.Tensor): 4-D tensor of shape (B, C, H, W).
        target_size (int): Size passed to ``T.Resize`` and, when ``apply_clip_norm``
            is True, also to ``T.CenterCrop``.
        apply_clip_norm (bool): If True, resize, center-crop and normalize with the
            CLIP per-channel mean/std. If False, only resize.
        num_channels (int): Number of channels the backbone expects (usually 3).
            A single-channel input is repeated to 3 channels when this is 3.

    Returns:
        torch.Tensor: The preprocessed batch.

    Raises:
        AssertionError: If ``image_tensors`` is not 4-D.
    """
    # The docstring must precede this check; otherwise it is a dead string
    # expression and the function has no __doc__.
    assert len(image_tensors.shape) == 4, (
        f"Input must be a 4D tensor, got shape {tuple(image_tensors.shape)}"
    )

    # Grayscale input: repeat the single channel so RGB-pretrained backbones accept it.
    if image_tensors.shape[1] == 1 and num_channels == 3:
        image_tensors = image_tensors.repeat(1, 3, 1, 1)

    steps = [T.Resize(target_size, interpolation=T.InterpolationMode.BICUBIC)]
    if apply_clip_norm:
        # CLIP-style pipeline: center-crop, then normalize with CLIP's channel statistics.
        # NOTE(review): these statistics assume inputs scaled to [0, 1] — confirm for
        # the dataset in use.
        steps += [
            T.CenterCrop(target_size),
            T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                        std=(0.26862954, 0.26130258, 0.27577711)),
        ]
    # NOTE(review): without apply_clip_norm there is no crop and no mean/std
    # normalization; a non-square input may therefore not come out as
    # target_size x target_size, and backbones trained on normalized inputs
    # (e.g. DINOv2) receive raw intensities.

    return T.Compose(steps)(image_tensors)

In [None]:
# Preprocess one train/val image with the CLIP pipeline, report value ranges, and visualize.
import matplotlib.pyplot as plt


def _rescale_for_display(image):
    """Min-max rescale an image tensor to [0, 1] so imshow shows it without clipping.

    CLIP normalization maps [0, 1] inputs to values well outside [0, 1], and
    matplotlib clips out-of-range float RGB data, which would misrepresent the
    preprocessed image. This rescaling is for display only.
    """
    lo, hi = image.min(), image.max()
    # clamp_min guards against a constant image (hi == lo) dividing by zero.
    return (image - lo) / (hi - lo).clamp_min(1e-12)


# Preprocess the images before we pass into our embedding module.
proc_train_im = preprocess_image(train_im, apply_clip_norm=True, target_size=224)
proc_val_im = preprocess_image(val_im, apply_clip_norm=True, target_size=224)

# Drop the batch axis and move channels last, (1, C, H, W) -> (H, W, C), as imshow expects.
vis_train_im = proc_train_im.squeeze(0).permute(1, 2, 0)
vis_val_im = proc_val_im.squeeze(0).permute(1, 2, 0)
# Print the range of the normalized pixel values (these tensors are not rescaled).
print("Training Image Pixel Range:", vis_train_im.min(), vis_train_im.max())
print("Validation Image Pixel Range:", vis_val_im.min(), vis_val_im.max())

# Visualize the processed images. No colorbar: imshow ignores the colormap for RGB
# data, so a colorbar would not describe these pixels.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, image, title in zip(axes, (vis_train_im, vis_val_im), ("Training Image", "Validation Image")):
    ax.imshow(_rescale_for_display(image))
    ax.set_title(title)
plt.show()

# Part 2: Load an embedding model (CLIP or DINOv2), define the embedding helpers, and find the most similar training images for each validation image.

In [None]:
import clip 
from typing import Any, Optional


def load_model(model_type: str, device: Optional[Any] = None, model_name: Optional[str] = None) -> nn.Module:
    """
    Load a pre-trained image embedding backbone (DINOv2 via timm, or OpenAI CLIP).

    Args:
        model_type: Backbone family. A string containing ``'dino'`` loads a DINOv2 ViT
            from timm; otherwise a string containing ``'clip'`` loads a CLIP model.
            ``'dino'`` is checked first.
        device: Optional device for the model. For DINOv2, ``None`` leaves the model
            where timm created it. For CLIP, ``None`` defers to ``clip.load``'s own
            default device.
        model_name: Optional checkpoint override. Defaults to
            ``'vit_large_patch14_dinov2.lvd142m'`` (DINOv2) or ``'ViT-B/32'`` (CLIP).

    Returns:
        The loaded model. The DINOv2 model is put in eval mode here. For CLIP, only
        the model is returned; the preprocessing transform from ``clip.load`` is discarded.

    Raises:
        ValueError: If ``model_type`` is unsupported, or timm lists no DINOv2 models.
        Any other exception from timm/clip is printed and re-raised unchanged.
    """
    try:
        if 'dino' in model_type:
            print("Loading DINOv2 model...")
            # Sanity-check that the installed timm ships any DINOv2 weights at all.
            available_models = timm.list_models('*dinov2*', pretrained=True)
            if not available_models:
                raise ValueError("No DINOv2 models found in timm. Please ensure timm is updated and supports DINOv2.")

            model = timm.create_model(model_name or 'vit_large_patch14_dinov2.lvd142m', pretrained=True)
            model.eval()  # Inference only: disable dropout and similar training-time behavior.
            if device is not None:
                model.to(device)
            return model
        elif 'clip' in model_type:
            print("Loading CLIP model...")
            # Forward `device` only when given: passing device=None explicitly would
            # override clip.load's own default device.
            load_kwargs = {} if device is None else {"device": device}
            model, _ = clip.load(model_name or "ViT-B/32", **load_kwargs)
            return model
        else:
            raise ValueError(f"Model type {model_type} not supported.")
    except Exception as e:
        print(f"Error loading model {model_type}: {e}")
        raise

def get_dino_embedding(model, image_tensor, device, pool_tokens=False):
    """
    Embed a batch with a timm backbone and L2-normalize each per-image embedding.

    Args:
        model: A timm model exposing ``forward_features`` (and ``forward_head`` when
            ``pool_tokens`` is True), e.g. the DINOv2 model from ``load_model``.
        image_tensor (torch.Tensor): Preprocessed batch; dim 0 is the batch.
        device: Device to run the forward pass on.
        pool_tokens (bool): If False (default, original behavior), all output features
            of ``forward_features`` are flattened into one vector per image. If True,
            ``model.forward_head(..., pre_logits=True)`` pools them into a single
            per-image vector first (how it pools depends on the model's global_pool
            setting).

    Returns:
        torch.Tensor: (B, D) unit-L2-norm embeddings on the CPU.
    """
    with torch.no_grad():
        features = model.forward_features(image_tensor.to(device))
        if pool_tokens:
            features = model.forward_head(features, pre_logits=True)
        # reshape (unlike view) also works when the feature tensor is non-contiguous.
        embedding = features.reshape(features.size(0), -1)
        embedding = nn.functional.normalize(embedding, p=2, dim=1)
        return embedding.cpu()

def get_clip_embedding(model, image_tensor, device):
    """
    Embed a batch with a CLIP model's image encoder and L2-normalize each embedding.

    Args:
        model: A CLIP model exposing ``encode_image`` (e.g. from ``load_model``).
        image_tensor (torch.Tensor): Preprocessed batch (see ``preprocess_image``).
        device: Device to run the forward pass on.

    Returns:
        torch.Tensor: Unit-L2-norm float32 embeddings on the CPU, normalized over the
        last dimension. An all-zero embedding stays zero instead of becoming NaN.
    """
    with torch.no_grad():
        embedding = model.encode_image(image_tensor.to(device))
        # Upcast before normalizing: encode_image may return half precision (CLIP
        # weights can be fp16 off-CPU), and downstream similarity math is safer in fp32.
        embedding = embedding.float()
        # Same normalization as get_dino_embedding; its eps avoids 0/0 -> NaN.
        embedding = nn.functional.normalize(embedding, p=2, dim=-1)
        return embedding.cpu()

In [None]:
# Run on the GPU when one is visible to this kernel, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Backbone family passed to load_model; switch to "dino" to use DINOv2 instead of CLIP.
model_type = "clip"
model = load_model(model_type, device=device)

In [None]:
# Map each example's data_id to its image (with a leading batch axis added via [None])
# so later cells can look images up by id.
def _build_id_to_image(dataset):
    """Return {data_id: img[None]} for every item in ``dataset``.

    Each item is fetched once (dataset indexing may load from disk). Raises
    AssertionError on a duplicate data_id, which would otherwise silently overwrite
    an earlier image and skew the similarity table.
    """
    id_to_image = {}
    for idx in range(len(dataset)):
        item = dataset[idx]
        data_id = item['data_id']
        assert data_id not in id_to_image, f"Duplicate data_id {data_id!r} at index {idx}"
        id_to_image[data_id] = item['img'][None]
    return id_to_image


train_dataid_to_im = _build_id_to_image(train_dataset)
val_dataid_to_im = _build_id_to_image(val_dataset)

In [None]:
import pandas as pd
from tqdm import tqdm

# Resolve the backbone the same way load_model does (substring match, 'dino' first),
# and fail loudly on anything else; otherwise a stale proc_*_im from an earlier cell
# could be reused silently.
backbone = "dino" if "dino" in model_type else ("clip" if "clip" in model_type else None)
if backbone is None:
    raise ValueError(f"Unsupported model_type {model_type!r}; expected 'dino' or 'clip'.")

# Per-backbone preprocessing and embedding function.
preprocess_kwargs = {
    "dino": dict(apply_clip_norm=False, target_size=518),
    "clip": dict(apply_clip_norm=True, target_size=224),
}[backbone]
embedding_fn = {"dino": get_dino_embedding, "clip": get_clip_embedding}[backbone]


def _embed(image):
    """Preprocess one batched image for the current backbone and return its embedding."""
    return embedding_fn(model, preprocess_image(image, **preprocess_kwargs), device)


# Embed every training image once up front instead of once per validation image:
# V + T forward passes rather than V * T.
# NOTE(review): with flattened DINO token features each cached embedding is large;
# memory grows with the training-set size.
train_embeddings = {
    train_data_id: _embed(train_image)
    for train_data_id, train_image in tqdm(train_dataid_to_im.items(), desc='Embedding training images')
}

records = []
for val_data_id, val_image in tqdm(val_dataid_to_im.items(), desc='Computing similarity over validation images'):
    embedding_val = _embed(val_image)
    # Cosine similarity between this validation image and each training image.
    for train_data_id, embedding_train in train_embeddings.items():
        similarity = torch.nn.functional.cosine_similarity(embedding_val, embedding_train).item()
        records.append({
            'val_data_id': val_data_id,
            'train_data_id': train_data_id,
            'similarity': similarity,
        })

# One row per (validation, training) pair, in validation-major order.
similarity_df = pd.DataFrame(records)

In [None]:
# One row per (validation, training) pair; check that invariant before displaying.
assert len(similarity_df) == len(val_dataid_to_im) * len(train_dataid_to_im), "missing or extra similarity rows"
similarity_df

In [None]:
num_train_images = 8
num_panels = num_train_images + 1
# For each validation image, plot it next to its `num_train_images` most similar
# training images (highest cosine similarity first). groupby(sort=False) keeps the
# validation ids in first-appearance order and avoids re-filtering the full table
# for every id.
for val_id, val_id_df in similarity_df.groupby('val_data_id', sort=False):
    top_matches = val_id_df.nlargest(num_train_images, 'similarity')
    fig, axes = plt.subplots(1, num_panels, figsize=(3 * num_panels, 3), squeeze=False)
    axes = axes[0]
    axes[0].imshow(val_dataid_to_im[val_id].squeeze(), cmap='gray')
    axes[0].set_title("Validation Image")
    # Read each match's similarity straight from the top-k rows.
    matches = zip(top_matches['train_data_id'], top_matches['similarity'])
    for rank, (ax, (train_id, sim)) in enumerate(zip(axes[1:], matches), start=1):
        ax.imshow(train_dataid_to_im[train_id].squeeze(), cmap='gray')
        ax.set_title(f"Train Image {rank}: Sim {sim:.2f}")
    plt.show()
    plt.close(fig)  # Release each figure; this loop can create many.