## **Comparative Analysis of Pre-trained Models for Image Retrieval in Computational Pathology**


In this notebook, we will:
- Aggregate and compare evaluation results from four pre-trained models: PLIP, UNI, CLIP, DINO.
- Visualize and analyze the embeddings and retrieval results.
- Highlight "match sets" in t-SNE plots to assess how models cluster similar images.
- Draw conclusions about the relative performance of each model.

---

## **Table of Contents**
---

1. Import Libraries
2. Load Evaluation Data
3. Create Master Evaluation DataFrame
4. Visualize Model Performance
5. Visualize Embeddings Using t-SNE with Highlighted Matches
6. Compare Retrieval Results Across Models
7. Conclusion

---
### **Step 1: Import Libraries**

We begin by importing the necessary libraries.

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.manifold import TSNE
import faiss
import random
import gc
import re

---
### **Step 2: Load Evaluation Data**

For each model, we load the embeddings, the corresponding list of image files, the per-query evaluation results and, if available, precomputed cluster labels.

In [None]:
# List of models
models = ['PLIP', 'UNI', 'DINO', 'CLIP']

# Directory paths
embeddings_dir = './'  # Adjust if embeddings are stored elsewhere
evaluation_dir = './'  # Adjust if evaluation results are stored elsewhere
combined_dataset_dir = './combined_dataset/'  # Directory containing images

# Dictionary to hold model data
model_data = {}

for model in models:
    # Load embeddings and image files
    embeddings = np.load(os.path.join(embeddings_dir, f'embeddings_{model.lower()}.npy'))
    image_files = np.load(os.path.join(embeddings_dir, f'image_files_{model.lower()}.npy'))
    
    # Load evaluation results
    evaluation_df = pd.read_csv(os.path.join(evaluation_dir, f'evaluation_results_{model.lower()}.csv'))
    
    # Load cluster labels if available
    cluster_labels = None
    cluster_labels_file = os.path.join(evaluation_dir, f'cluster_labels_{model.lower()}.npy')
    if os.path.exists(cluster_labels_file):
        cluster_labels = np.load(cluster_labels_file)
    
    # Store in the dictionary
    model_data[model] = {
        'embeddings': embeddings,
        'image_files': image_files,
        'evaluation_df': evaluation_df,
        'cluster_labels': cluster_labels
    }
    
    print(f'{model}:')
    print(f'  Embeddings shape: {embeddings.shape}')
    print(f'  Number of images: {len(image_files)}')
    print(f'  Evaluation results loaded with shape: {evaluation_df.shape}\n')
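Later steps compare models query by query, so a quick consistency check can catch mismatched files early. The helper below is a hypothetical sketch (`check_alignment` is not part of the original notebook). It assumes each entry of `model_data` holds an `embeddings` array and an `image_files` array, as built above, and compares every model against the first one.

```python
import numpy as np
import pandas as pd


def check_alignment(model_data, reference=None):
    """Hypothetical helper: check that embeddings and file lists line up.

    For each model, reports whether the number of embedding rows equals the
    number of image files, and whether the file list matches the reference
    model's list as a set and in the same order.
    """
    names = list(model_data)
    reference = reference or names[0]
    ref_files = [str(f) for f in model_data[reference]['image_files']]

    rows = []
    for name in names:
        embeddings = np.asarray(model_data[name]['embeddings'])
        files = [str(f) for f in model_data[name]['image_files']]
        rows.append({
            'model': name,
            'rows_match_files': embeddings.shape[0] == len(files),
            'same_files_as_ref': set(files) == set(ref_files),
            'same_order_as_ref': files == ref_files,
        })
    return pd.DataFrame(rows).set_index('model')


# Usage in this notebook: check_alignment(model_data)
```

A differing order is harmless for the per-model steps (t-SNE, retrieval by file name), but a differing file set means the models were not evaluated on the same images.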

---
### **Step 3: Create Master Evaluation DataFrame**

We will merge the evaluation results from all models into a single DataFrame for comparison.

In [None]:
# Start from the queried images and total available matches of one model
# (assumes all models were evaluated on the same set of queries)
master_df = model_data['PLIP']['evaluation_df'][['Queried Image', 'Total Available Matches']].copy()

# Add match counts and accuracies from each model, joined on the query name
# so rows stay aligned even if the CSVs list queries in different orders
for model in models:
    eval_df = model_data[model]['evaluation_df'][['Queried Image', 'Number of Matches Found', 'Match Accuracy']]
    eval_df = eval_df.rename(columns={
        'Number of Matches Found': f'Matches Found ({model})',
        'Match Accuracy': f'Match Accuracy ({model})',
    })
    master_df = master_df.merge(eval_df, on='Queried Image', how='left', validate='one_to_one')

# Display the master DataFrame
master_df
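Assigning each model's columns by position assumes every `evaluation_results_*.csv` lists the queries in the same order. The small self-contained demo below uses toy data and a hypothetical helper, `combine_evaluations`, to show how joining on `Queried Image` keeps rows aligned even when the order differs; `validate='one_to_one'` additionally makes pandas raise a `MergeError` if a query name is duplicated in any table.

```python
import pandas as pd


def combine_evaluations(evals, key='Queried Image'):
    """Hypothetical helper: join per-model 'Match Accuracy' columns on `key`."""
    combined = None
    for model, df in evals.items():
        part = df[[key, 'Match Accuracy']].rename(
            columns={'Match Accuracy': f'Match Accuracy ({model})'}
        )
        if combined is None:
            combined = part
        else:
            # Raises pandas.errors.MergeError if a query name is not unique
            combined = combined.merge(part, on=key, how='outer', validate='one_to_one')
    return combined


# Toy tables listing the same queries in different orders
a = pd.DataFrame({'Queried Image': ['q1', 'q2'], 'Match Accuracy': [0.5, 1.0]})
b = pd.DataFrame({'Queried Image': ['q2', 'q1'], 'Match Accuracy': [0.25, 0.75]})

combined = combine_evaluations({'A': a, 'B': b}).set_index('Queried Image')
print(combined)
```

Positional assignment would have paired `q1` with B's `q2` row (0.25); the join keeps `q1` paired with 0.75.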

In [None]:
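import math

# Long-format table (one row per query/model pair) for the per-query bar charts.
# Assumption: the "top 15" queries are the 15 with the most available matches.
top_queries = master_df.nlargest(15, 'Total Available Matches')
plot_data = top_queries.melt(
    id_vars=['Queried Image'],
    value_vars=[f'Match Accuracy ({model})' for model in models],
    var_name='Model', value_name='Accuracy'
)
plot_data['Model'] = plot_data['Model'].str.extract(r'\((.*?)\)', expand=False)

# Queried images to plot
queried_images = plot_data['Queried Image'].unique()
num_images = len(queried_images)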

# Create subplots
num_cols = 3
num_rows = math.ceil(num_images / num_cols)
fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, num_rows * 5), sharey=True)

# Flatten axes for easier indexing
axes = axes.flatten()

# Plot each queried image in a separate subplot
for i, queried_image in enumerate(queried_images):
    subset = plot_data[plot_data['Queried Image'] == queried_image]
    axes[i].bar(subset['Model'], subset['Accuracy'], color='skyblue', alpha=0.8)
    axes[i].set_title(queried_image, fontsize=10)
    axes[i].set_xlabel('Models', fontsize=9)
    axes[i].set_ylabel('Accuracy', fontsize=9)
    axes[i].tick_params(axis='x', rotation=45)

# Remove extra subplots
for j in range(i + 1, len(axes)):
    fig.delaxes(axes[j])

# Adjust layout
plt.tight_layout()
plt.suptitle('Model Accuracies for Top 15 Queried Images', fontsize=16, y=1.02)
plt.show()

---
### **Step 4: Visualize Model Performance**

**4.1. Plot Heatmap of Match Accuracies**

We will visualize the match accuracies of different models for each queried image using a heatmap.

In [None]:
# Prepare data for heatmap
accuracy_columns = [f'Match Accuracy ({model})' for model in models]
heatmap_data = master_df[['Queried Image'] + accuracy_columns].set_index('Queried Image')

# Create a heatmap
plt.figure(figsize=(12, 8))
sns.heatmap(heatmap_data, annot=True, fmt=".2f", cmap='viridis')
plt.title('Match Accuracies Across Models and Queried Images')
plt.ylabel('Queried Image')
plt.xlabel('Model')
plt.show()

**4.2. Plot Box Plots of Match Accuracies**

We will use box plots to compare the distribution of match accuracies for each model.

In [None]:
# Prepare data for box plots
accuracy_data = pd.melt(master_df, id_vars=['Queried Image'], value_vars=accuracy_columns,
                        var_name='Model', value_name='Match Accuracy')
accuracy_data['Model'] = accuracy_data['Model'].str.extract(r'\((.*?)\)')

# Create box plots
plt.figure(figsize=(10, 6))
sns.boxplot(x='Model', y='Match Accuracy', data=accuracy_data)
plt.title('Distribution of Match Accuracies per Model')
plt.ylabel('Match Accuracy')
plt.xlabel('Model')
plt.show()
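The box plots show the spread; a compact table can make the comparison easier to quote. The sketch below (hypothetical helper `summarize_accuracies`, not part of the original notebook) computes per-model summary statistics from the `Match Accuracy (<model>)` columns and counts the queries on which each model reaches the highest accuracy, ties included.

```python
import pandas as pd


def summarize_accuracies(master_df, models):
    """Hypothetical helper: per-model statistics of 'Match Accuracy (<model>)'."""
    columns = {f'Match Accuracy ({m})': m for m in models}
    acc = master_df[list(columns)].rename(columns=columns)

    summary = acc.agg(['mean', 'median', 'std', 'min', 'max']).T
    # A model "wins" a query when it reaches that query's highest accuracy (ties count)
    best = acc.max(axis=1)
    summary['queries_won_or_tied'] = acc.eq(best, axis=0).sum()
    return summary.sort_values('mean', ascending=False)


# Usage in this notebook: summarize_accuracies(master_df, models)
```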

**4.3. Plot Line Plots of Match Accuracies**

We plot each model's match accuracy across queried images, sorted by the number of available matches, to see whether accuracy changes with match-set size.

In [None]:
# Sort master_df by total available matches
sorted_df = master_df.sort_values('Total Available Matches')

# Plot match accuracies
plt.figure(figsize=(14, 6))
for model in models:
    plt.plot(sorted_df['Queried Image'], sorted_df[f'Match Accuracy ({model})'], marker='o', label=model)
plt.xticks(rotation=90)
plt.xlabel('Queried Image')
plt.ylabel('Match Accuracy')
plt.title('Match Accuracy per Model Across Queried Images')
plt.legend()
plt.tight_layout()
plt.show()
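To put numbers on the trend in the line plot, one option is to average each model's accuracy within groups of queries that have the same number of available matches. `accuracy_by_set_size` is a hypothetical helper; it assumes the column names used in `master_df` above.

```python
import pandas as pd


def accuracy_by_set_size(master_df, models):
    """Hypothetical helper: mean accuracy per model for each match-set size."""
    columns = [f'Match Accuracy ({m})' for m in models]
    groups = master_df.groupby('Total Available Matches')
    table = groups[columns].mean()
    table.columns = list(models)
    # Number of queries behind each row, to show how reliable each mean is
    table.insert(0, 'n_queries', groups.size())
    return table


# Usage in this notebook: accuracy_by_set_size(master_df, models)
```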

---
### **Step 5: Visualize Embeddings Using t-SNE with Highlighted Matches**

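We create a t-SNE plot for each model, coloring points by K-Means cluster and highlighting "match sets": groups of images that share a numeric filename prefix (e.g. `387_01.jpg` and `387_02.jpg`). All other images are labeled `pubmed`.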

**5.1. Prepare Data for t-SNE Visualization**

In [None]:
# Create a mapping of "match sets"
def get_match_set(image_name):
    # Extract the prefix before the underscore (e.g., '387' from '387_01.jpg')
    match = re.match(r'^(\d+)_\d+\.jpg$', image_name)
    if match:
        return match.group(1)
    else:
        return 'pubmed'

# For each model, prepare the DataFrame with embeddings and labels
for model in models:
    embeddings = model_data[model]['embeddings']
    image_files = model_data[model]['image_files']
    
    df = pd.DataFrame({
        'image_file': image_files,
        'embedding_index': range(len(image_files))
    })
    
    df['match_set'] = df['image_file'].apply(get_match_set)
    model_data[model]['df'] = df
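A query whose match set contains only one image has no available matches, so it can help to check set sizes before plotting. The sketch below mirrors the `<set>_<n>.jpg` rule of `get_match_set`; the helper names `match_set_sizes` and `singleton_match_sets` are hypothetical.

```python
import re
from collections import Counter

# Same naming rule as get_match_set: '<set id>_<image number>.jpg'
MATCH_PATTERN = re.compile(r'^(\d+)_\d+\.jpg$')


def match_set_sizes(image_files, background_label='pubmed'):
    """Hypothetical helper: count images per match set."""
    counts = Counter()
    for name in image_files:
        m = MATCH_PATTERN.match(str(name))
        counts[m.group(1) if m else background_label] += 1
    return counts


def singleton_match_sets(counts, background_label='pubmed'):
    """Match sets with a single image (0 available matches when queried)."""
    return sorted(s for s, n in counts.items() if s != background_label and n < 2)


# Usage in this notebook:
# counts = match_set_sizes(model_data['PLIP']['image_files'])
# print(singleton_match_sets(counts))
```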

**5.2. Compute t-SNE Embeddings**

In [None]:
# Compute t-SNE embeddings for each model
for model in models:
    print(f'Computing t-SNE for {model}...')
    embeddings = model_data[model]['embeddings']
    
    # Reduce dimensions to 2D (max_iter replaced n_iter in scikit-learn 1.5;
    # on older versions pass n_iter=500 instead)
    tsne = TSNE(n_components=2, random_state=42, perplexity=30, max_iter=500)
    embeddings_2d = tsne.fit_transform(embeddings)
    model_data[model]['embeddings_2d'] = embeddings_2d
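The scikit-learn `TSNE` documentation recommends reducing very high-dimensional inputs (for example to about 50 dimensions with PCA) before running t-SNE, and recent versions raise an error when the perplexity is not smaller than the number of samples. The sketch below (hypothetical helper `tsne_2d`) applies both ideas. The 50-dimension target and the `(n_samples - 1) / 3` perplexity cap (so that the roughly 3 × perplexity neighbours scikit-learn considers exist in the data) are assumptions, not values from the original notebook; the iteration count is left at scikit-learn's default.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE


def tsne_2d(embeddings, perplexity=30.0, pca_dims=50, random_state=42):
    """Hypothetical helper: 2-D t-SNE with optional PCA pre-reduction."""
    X = np.asarray(embeddings, dtype=np.float32)
    n_samples, n_features = X.shape

    # PCA first, but only when it actually lowers the dimensionality
    if pca_dims is not None and n_features > pca_dims and n_samples >= pca_dims:
        X = PCA(n_components=pca_dims, random_state=random_state).fit_transform(X)

    # Cap perplexity so that ~3 * perplexity neighbours exist (assumed heuristic)
    perplexity = min(float(perplexity), max(1.0, (n_samples - 1) / 3.0))

    tsne = TSNE(n_components=2, perplexity=perplexity, init='pca', random_state=random_state)
    return tsne.fit_transform(X)


# Usage in this notebook:
# model_data[model]['embeddings_2d'] = tsne_2d(model_data[model]['embeddings'])
```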

**5.3. Visualize t-SNE Plots with Highlighted Match Sets**

In [None]:
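# First, perform K-Means clustering on the embeddings.
# Note: this overwrites any cluster labels loaded in Step 2.
num_clusters = 10  # Adjust based on dataset size; the 'tab10' colormap below has only 10 distinct colors
for model in models:
    embeddings = model_data[model]['embeddings']
    # n_init is set explicitly because its default changed across scikit-learn versions
    kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=42)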
    cluster_labels = kmeans.fit_predict(embeddings)
    model_data[model]['cluster_labels'] = cluster_labels

# Updated plotting code
import matplotlib.patches as mpatches

def plot_tsne_with_clusters_and_matches(model):
    embeddings_2d = model_data[model]['embeddings_2d']
    df = model_data[model]['df']
    cluster_labels = model_data[model]['cluster_labels']

    plt.figure(figsize=(12, 10))

    # Plot all images colored by their cluster labels
    scatter = plt.scatter(
        embeddings_2d[:, 0], embeddings_2d[:, 1],
        c=cluster_labels, cmap='tab10', s=10, alpha=0.6
    )

    # Highlight 'matches' images
    matches_df = df[df['match_set'] != 'pubmed']
    match_sets = matches_df['match_set'].unique()

    # Assign different markers and edge colors to different match sets
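    # Filled markers only: with facecolors='none', matplotlib ignores edgecolors for
    # unfilled markers such as '+', which would make those points invisible
    markers = ['o', 's', '^', 'D', 'v', 'P', '*', 'X', 'h', 'p']  # Extend if needed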
    edge_colors = ['black', 'blue', 'green', 'red', 'purple', 'orange', 'brown', 'pink', 'gray', 'olive']
    match_set_styles = {}
    for i, match_set in enumerate(match_sets):
        match_set_styles[match_set] = {
            'marker': markers[i % len(markers)],
            'edgecolor': edge_colors[i % len(edge_colors)]
        }

    # Plot 'matches' images with specific markers and edge colors
    for match_set in match_sets:
        indices = matches_df[matches_df['match_set'] == match_set].index
        plt.scatter(
            embeddings_2d[indices, 0], embeddings_2d[indices, 1],
            facecolors='none',
            edgecolors=match_set_styles[match_set]['edgecolor'],
            marker=match_set_styles[match_set]['marker'],
            s=80, linewidths=1.5,
            label=f'Match Set {match_set}'
        )

    plt.title(f'{model} Embeddings t-SNE Plot with Clusters and Highlighted Match Sets')
    plt.xticks([])
    plt.yticks([])

    # Create legend for match sets
    handles = []
    for match_set in match_sets:
        style = match_set_styles[match_set]
        marker = plt.Line2D(
            [], [], color=style['edgecolor'], marker=style['marker'],
            linestyle='None', markersize=10, markerfacecolor='none',
            label=f'Match Set {match_set}'
        )
        handles.append(marker)
    plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

# Plot for each model
for model in models:
    plot_tsne_with_clusters_and_matches(model)
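The t-SNE plots are qualitative, and t-SNE distorts global distances. Step 1 already imports `silhouette_score` without using it; one way to add a quantitative check is to score how compact and well separated the match sets are in each model's original embedding space. The helper `match_set_silhouette` is hypothetical, and the cosine metric and the exclusion of background (`pubmed`) and single-image sets are assumptions. Higher values (up to 1) mean images from the same match set sit closer to each other than to images from other sets.

```python
import numpy as np
from sklearn.metrics import silhouette_score


def match_set_silhouette(embeddings, match_sets, background_label='pubmed', metric='cosine'):
    """Hypothetical helper: silhouette score of match sets in embedding space.

    Background images and match sets with a single image are excluded.
    Returns NaN when fewer than two match sets remain.
    """
    X = np.asarray(embeddings, dtype=np.float64)
    labels = np.asarray(match_sets).astype(str)

    # Keep only images that belong to a match set with at least two members
    in_set = labels != background_label
    values, counts = np.unique(labels[in_set], return_counts=True)
    keep = in_set & np.isin(labels, values[counts >= 2])

    if np.unique(labels[keep]).size < 2:
        return float('nan')
    return float(silhouette_score(X[keep], labels[keep], metric=metric))


# Usage in this notebook:
# for model in models:
#     score = match_set_silhouette(model_data[model]['embeddings'], model_data[model]['df']['match_set'])
#     print(f'{model}: match-set silhouette = {score:.3f}')
```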

---
### **Step 6: Compare Retrieval Results Across Models**

We will select a single query image from the "matches" set and compare the retrieval results across different models.

**6.1. Select a Query Image from the "Matches" Set**

In [None]:
# Display available images for querying from the 'matches' set
matches_df = model_data[models[0]]['df']
matches_images = matches_df[matches_df['match_set'] != 'pubmed']['image_file'].unique()

print("Available images for querying from the 'matches' set:")
for idx, image_file in enumerate(matches_images):
    print(f"{idx}: {image_file}")

# Prompt the user to select an image index
# Re-prompt until a valid index is entered, so query_image_name is always defined
while True:
    try:
        query_idx = int(input("\nEnter the index of the image you want to select: "))
        if not 0 <= query_idx < len(matches_images):
            raise ValueError("Invalid index selected.")
        break
    except ValueError as e:
        print(f"Error: {e}. Please enter a valid index.")

query_image_name = matches_images[query_idx]
print(f"\nSelected query image: {query_image_name}")
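`input()` blocks when the notebook runs non-interactively (for example in a batch run). A minimal alternative, assuming it is acceptable to fix the choice in code, picks the query by index or at random with a seed. `choose_query` is a hypothetical helper; Step 1 already imports `random`.

```python
import random


def choose_query(candidates, index=None, seed=None):
    """Hypothetical helper: pick a query image without prompting.

    Uses `index` if given; otherwise draws one candidate at random,
    reproducibly when `seed` is set.
    """
    candidates = list(candidates)
    if not candidates:
        raise ValueError('no candidate query images')
    if index is not None:
        if not 0 <= index < len(candidates):
            raise IndexError(f'index {index} out of range for {len(candidates)} candidates')
        return candidates[index]
    return random.Random(seed).choice(candidates)


# Usage in this notebook:
# query_image_name = choose_query(matches_images, seed=0)
```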

**6.2. Perform Retrieval and Display Results**

In [None]:
def retrieve_images_for_model(model, query_image_name, top_k=5):
    embeddings = model_data[model]['embeddings']
    image_files = model_data[model]['image_files']
    df = model_data[model]['df']
    
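    # FAISS requires a C-contiguous float32 matrix
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    
    # Build an exact (brute-force) L2 index over all embeddings
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    
    # Get the index of the query image
    query_indices = df[df['image_file'] == query_image_name]['embedding_index'].tolist()
    if not query_indices:
        print(f"No embeddings found for query image: {query_image_name} in model {model}")
        return None, None
    query_idx = query_indices[0]
    
    # Search top_k + 1 neighbours so that top_k remain after dropping the query
    query_embedding = embeddings[query_idx].reshape(1, -1)
    distances, indices = index.search(query_embedding, top_k + 1)
    
    # Exclude the query by index rather than by rank: duplicate embeddings can tie
    # at distance 0, so the query is not guaranteed to come first.
    # FAISS pads missing results with -1, so drop those too.
    retrieved_indices = [int(i) for i in indices[0] if i != query_idx and i >= 0][:top_k]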
    retrieved_images = [image_files[i] for i in retrieved_indices]
    
    # Evaluate retrieval
    retrieved_df = df.iloc[retrieved_indices]
    query_match_set = df.loc[query_idx, 'match_set']
    correct_matches = retrieved_df[retrieved_df['match_set'] == query_match_set]
    num_matches_found = len(correct_matches)
    total_available_matches = len(df[df['match_set'] == query_match_set]) - 1  # Exclude the query image
    match_accuracy = num_matches_found / total_available_matches if total_available_matches > 0 else 0
    
    return retrieved_images, match_accuracy

# Retrieve images for each model
retrieved_images_dict = {}
match_accuracies = {}
for model in models:
    retrieved_images, match_accuracy = retrieve_images_for_model(model, query_image_name, top_k=15)
    retrieved_images_dict[model] = retrieved_images
    match_accuracies[model] = match_accuracy
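The match accuracy computed in `retrieve_images_for_model` is the number of same-set images retrieved divided by the number available (a recall-style measure), so it cannot reach 1.0 when a match set has more than `top_k` other images. The sketch below reproduces the retrieval in plain NumPy so it can be checked on toy data, and also reports precision at k. It illustrates the idea under stated assumptions and is not the notebook's FAISS code path: `knn_retrieve` and `retrieval_scores` are hypothetical names, and the `'cosine'` option (L2-normalize, then rank by inner product) is an alternative the original does not use.

```python
import numpy as np


def knn_retrieve(embeddings, query_idx, top_k=5, metric='l2'):
    """Hypothetical helper: brute-force k-NN that never returns the query.

    metric='l2'     ranks by squared Euclidean distance (same ordering as
                    faiss.IndexFlatL2, which also reports squared distances).
    metric='cosine' L2-normalizes rows, then ranks by inner product.
    """
    X = np.asarray(embeddings, dtype=np.float32)
    if metric == 'cosine':
        norms = np.linalg.norm(X, axis=1, keepdims=True)
        X = X / np.maximum(norms, 1e-12)
        scores = -(X @ X[query_idx])        # negate: smaller score = closer
    elif metric == 'l2':
        scores = np.sum((X - X[query_idx]) ** 2, axis=1)
    else:
        raise ValueError(f'unknown metric: {metric!r}')
    scores[query_idx] = np.inf              # exclude the query itself
    return np.argsort(scores, kind='stable')[:top_k]


def retrieval_scores(match_sets, query_idx, retrieved):
    """Match accuracy as defined in this notebook, plus precision at k."""
    match_sets = np.asarray(match_sets).astype(str)
    target = match_sets[query_idx]
    hits = int(np.sum(match_sets[retrieved] == target))
    available = int(np.sum(match_sets == target)) - 1   # exclude the query
    return {
        'matches_found': hits,
        'match_accuracy': hits / available if available > 0 else 0.0,
        'precision_at_k': hits / len(retrieved) if len(retrieved) > 0 else 0.0,
    }
```

Cosine ranking ignores vector length, so two embeddings pointing in the same direction count as identical even if their magnitudes differ; whether that suits a given model is a judgment call.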

**6.3. Display Retrieval Results Side by Side**

In [None]:
def display_retrievals_side_by_side(query_image_path, retrieved_image_paths_dict, match_accuracies):
    num_models = len(retrieved_image_paths_dict)
    num_retrieved = len(next(iter(retrieved_image_paths_dict.values())))
    fig, axes = plt.subplots(num_retrieved + 1, num_models + 1, figsize=(4 * (num_models + 1), 4 * (num_retrieved + 1)))
    
    # Display the query image in the first column
    query_img = Image.open(query_image_path)
    axes[0, 0].imshow(query_img)
    axes[0, 0].axis('off')
    axes[0, 0].set_title('Query Image', fontsize=12)
    
    # Fill the rest of the query column with empty plots
    for i in range(1, num_retrieved + 1):
        axes[i, 0].axis('off')
    
    # Display retrieved images for each model
    for col_idx, model in enumerate(retrieved_image_paths_dict.keys()):
        # Add model name and match accuracy in the header row
        axes[0, col_idx + 1].text(0.5, 0.5, f'{model}\nMatch Accuracy: {match_accuracies[model]:.2f}', 
                                  fontsize=12, ha='center', va='center')
        axes[0, col_idx + 1].axis('off')
        
        # Add retrieved images under the respective column
        retrieved_paths = retrieved_image_paths_dict[model]
        for row_idx, img_file in enumerate(retrieved_paths):
            img_path = os.path.join(combined_dataset_dir, img_file)
            img = Image.open(img_path)
            axes[row_idx + 1, col_idx + 1].imshow(img)
            axes[row_idx + 1, col_idx + 1].axis('off')
    
    plt.tight_layout()
    plt.show()


# Prepare the query image path
query_image_path = os.path.join(combined_dataset_dir, query_image_name)

# Display the retrieval results
display_retrievals_side_by_side(query_image_path, retrieved_images_dict, match_accuracies)

---
### **Step 7: Conclusion**


**[PLACEHOLDER]**