## **Evaluation of Image Retrieval Using DINO Embeddings**

In this notebook, we will:
- Load the embeddings generated using the DINO model.
- Use FAISS for similarity search to perform image retrieval.
- Evaluate retrieval performance (match accuracy within the top-k results) and cluster quality (silhouette score).
- Visualize the retrieval results.
- Create an evaluation DataFrame summarizing the results.

---

## **Table of Contents**
---

1. Import Libraries
2. Load Embeddings and Image Files
3. Load Augmented Image Mapping and Captions
4. Prepare the Dataset
5. Set Up Similarity Search (FAISS)
6. Define Functions for Retrieval and Evaluation
7. Perform Image Retrieval
8. Cluster Embeddings
9. Compute Unsupervised Evaluation Metrics
10. Visualize Sample Images from Each Cluster
11. Create Evaluation DataFrame
12. Conclusion

---
### **Step 1: Import Libraries**

In [None]:
import os
import re
import gc
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import faiss
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

---
### **Step 2: Load Embeddings and Image Files**

In [None]:
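# Load embeddings and image files
# FAISS operates on C-contiguous float32 vectors, so convert explicitly up front
embeddings = np.ascontiguousarray(np.load('embeddings_dino.npy'), dtype='float32')
image_files = np.load('image_files_dino.npy')

# Row i of embeddings must correspond to image_files[i]
assert len(embeddings) == len(image_files), 'embeddings and image_files are misaligned'

print(f'Embeddings shape: {embeddings.shape}')
print(f'Number of images: {len(image_files)}')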

---
### **Step 3: Load Augmented Image Mapping and Captions**

In [None]:
# Load augmented image mapping
augmented_df = pd.read_csv('augmented_image_mapping.csv')

# Load captions from pubmed_set
captions_file = './pubmed_set/captions.json'
with open(captions_file, 'r') as f:
    captions_data = json.load(f)

# Create a mapping from UUID to caption
uuid_to_caption = {entry['uuid']: entry['caption'] for entry in captions_data.values()}

# Map filenames to captions for pubmed_set images
pubmed_image_captions = {}
pubmed_images_dir = './pubmed_set/images/'
for filename in os.listdir(pubmed_images_dir):
    uuid = os.path.splitext(filename)[0]
    if uuid in uuid_to_caption:
        pubmed_image_captions[filename] = uuid_to_caption[uuid]

# Convert the pubmed captions mapping into a DataFrame
captions_df = pd.DataFrame(
    list(pubmed_image_captions.items()), columns=['image_file', 'caption']
)
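The caption lookup above is a two-step join: `captions.json` provides `uuid`/`caption` pairs, and each file in `pubmed_set/images/` is linked to a pair through its filename stem. A minimal sketch on toy data, assuming (as the code above does) that the JSON is a dict whose values carry `uuid` and `caption` keys; the UUIDs, captions, and filenames here are made up for illustration.

```python
import os

# Toy stand-in for captions.json (made-up values, same shape the notebook assumes)
captions_data = {
    "0": {"uuid": "a1b2", "caption": "H&E stained section"},
    "1": {"uuid": "c3d4", "caption": "Chest radiograph"},
}
# Toy stand-in for os.listdir('./pubmed_set/images/')
filenames = ["a1b2.jpg", "c3d4.png", "e5f6.jpg"]

uuid_to_caption = {entry["uuid"]: entry["caption"] for entry in captions_data.values()}

pubmed_image_captions = {}
for filename in filenames:
    uuid = os.path.splitext(filename)[0]  # 'a1b2.jpg' -> 'a1b2'
    if uuid in uuid_to_caption:
        pubmed_image_captions[filename] = uuid_to_caption[uuid]

print(pubmed_image_captions)
```

Files without a caption entry (here `e5f6.jpg`) are simply skipped, so after the left merge in Step 4 they end up with a NaN caption.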

---
### **Step 4: Prepare the Dataset**

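We create a DataFrame that records, for each embedding, its image file name, the original image it was derived from (for augmented copies), its source dataset (`matches` or `pubmed`), and its caption if one exists.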

In [None]:
# Create a DataFrame for image files
df = pd.DataFrame({
    'image_file': image_files,
    'embedding_index': range(len(image_files))
})

# Create a mapping from augmented to original images
augmented_mapping = augmented_df.set_index('augmented_image')['original_image'].to_dict()

# Apply the mapping to determine the original image
df['original_image'] = df['image_file'].apply(lambda x: augmented_mapping.get(x, x))

# Determine the source dataset for each image
def determine_source(image_name):
    # Matches the format XXX_XX.jpg where XXX and XX are digits
    if re.match(r'^\d{3}_\d{2}\.jpg$', image_name):
        return 'matches'
    # Matches the format XXXX_XX.jpg where XXXX and XX are digits
    elif re.match(r'^\d{4}_\d{2}\.jpg$', image_name):
        return 'matches'
    else:
        return 'pubmed'

# Apply the source determination function
df['source'] = df['image_file'].apply(determine_source)

# Merge captions for pubmed images
df = df.merge(captions_df, on='image_file', how='left')
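`determine_source` decides membership from the filename alone. Its two patterns can be expressed as one (`\d{3,4}`), and a quick check on sample names shows a consequence of classifying by filename: any name outside the pattern, including an augmented copy that was renamed, is labeled `pubmed`. The augmented filename `aug_123_01.jpg` below is hypothetical; the real naming scheme is whatever `augmented_image_mapping.csv` contains.

```python
import re

def determine_source(image_name):
    """'matches' for NNN_NN.jpg or NNNN_NN.jpg filenames, otherwise 'pubmed'."""
    return 'matches' if re.match(r'^\d{3,4}_\d{2}\.jpg$', image_name) else 'pubmed'

# Hypothetical augmented filename; real names come from augmented_image_mapping.csv
augmented_mapping = {'aug_123_01.jpg': '123_01.jpg'}

for image_file in ['123_01.jpg', '1234_12.jpg', '12345_01.jpg', 'a1b2c3.jpg', 'aug_123_01.jpg']:
    # Unmapped files map to themselves, as with augmented_mapping.get(x, x) above
    original_image = augmented_mapping.get(image_file, image_file)
    print(f"{image_file:>15}  file={determine_source(image_file):8} "
          f"original={determine_source(original_image)}")
```

For `aug_123_01.jpg` the file is classified as `pubmed` while its original is `matches`, so it matters which column the function is applied to.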

---
### **Step 5: Set Up Similarity Search with FAISS**

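FAISS (Facebook AI Similarity Search) is a library for efficient similarity search and clustering of dense vectors. Here we use `IndexFlatL2`, which performs exact (brute-force) nearest-neighbor search and returns squared Euclidean distances.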

In [None]:
# Define the dimension of the embeddings
d = embeddings.shape[1]

# Build the index
index = faiss.IndexFlatL2(d)

# Add embeddings to the index
index.add(embeddings)

print(f'Number of vectors in the index: {index.ntotal}')
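`IndexFlatL2` does no approximation: every query is compared against every stored vector, and the k smallest squared Euclidean distances are returned. The NumPy sketch below reproduces that computation for small arrays, which can be handy for sanity-checking FAISS output; `flat_l2_search` is a hypothetical helper, not part of FAISS. Like `index.search`, it takes a 2-D batch of queries.

```python
import numpy as np

def flat_l2_search(database, queries, k):
    """Exact k-NN by squared Euclidean distance, mirroring a flat L2 index.

    Returns (distances, indices), each of shape (n_queries, k), sorted by ascending distance.
    """
    database = np.asarray(database, dtype=np.float32)
    queries = np.asarray(queries, dtype=np.float32)
    # ||q - x||^2 = ||q||^2 - 2 q.x + ||x||^2, computed for all query/database pairs at once
    sq_dists = (
        (queries ** 2).sum(axis=1, keepdims=True)
        - 2.0 * queries @ database.T
        + (database ** 2).sum(axis=1)
    )
    sq_dists = np.maximum(sq_dists, 0.0)  # clip tiny negatives caused by rounding
    indices = np.argsort(sq_dists, axis=1, kind="stable")[:, :k]
    distances = np.take_along_axis(sq_dists, indices, axis=1)
    return distances, indices

database = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
queries = database[[0, 3]]  # two stored vectors used as queries
D, I = flat_l2_search(database, queries, k=2)
print(I)  # each query's nearest neighbour is itself
print(D)  # squared distances, not Euclidean distances
```

Because each stored vector is its own nearest neighbour at distance 0, a query taken from the index normally returns itself first. If cosine similarity is preferred over L2 distance, L2-normalizing every row beforehand makes the two rankings identical, since for unit vectors $\lVert a-b \rVert^2 = 2 - 2\cos(a, b)$.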

---
### **Step 6: Define Functions for Retrieval and Evaluation**

Function to Display Images with Captions

In [None]:
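def display_images(image_paths, titles=None, captions=None, cols=1, figsize=(60, 32)):
    """Show images in a grid; each subplot is titled with its title and, if present, its caption."""
    rows = len(image_paths) // cols + int(len(image_paths) % cols != 0)
    plt.figure(figsize=figsize)
    for idx, img_path in enumerate(image_paths):
        plt.subplot(rows, cols, idx+1)
        with Image.open(img_path) as img:
            plt.imshow(np.asarray(img))  # load pixels into an array so the file can be closed
        plt.axis('off')
        title = titles[idx] if titles else ''
        caption = captions[idx] if captions else None
        # Missing captions (None, or NaN from the left merge) would otherwise print as "None"/"nan"
        if isinstance(caption, str) and caption:
            title = f"{title}\n{caption}"
        plt.title(title, fontsize=8)
    plt.tight_layout()
    plt.show()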

Function to Retrieve and Display Results

In [None]:
def retrieve_and_display(query_image_name, top_k=10):
    # Get the index of the query image
    query_indices = df[df['original_image'] == query_image_name]['embedding_index'].tolist()
    if not query_indices:
        print(f"No embeddings found for query image: {query_image_name}")
        return
    
    # Use the first embedding as the query
    query_idx = query_indices[0]
    query_embedding = embeddings[query_idx].reshape(1, -1)
    
    # Perform the search. The query's own embedding is in the index,
    # so it normally comes back as rank 1 with distance 0.
    distances, indices = index.search(query_embedding, top_k)
    
    # Get retrieved images (including the query image itself if present)
    retrieved_indices = indices[0]
    retrieved_images = df.iloc[retrieved_indices]
    
    # Check if the query image is already included in the retrieval
    if query_idx not in retrieved_indices:
        # Manually add the query image to the top of the results
        query_row = df.iloc[[query_idx]]
        retrieved_images = pd.concat([query_row, retrieved_images]).drop_duplicates().head(top_k)
    
    # Prepare paths and titles
    query_image_file = df.loc[query_idx, 'image_file']
    query_image_path = os.path.join('./combined_dataset/', query_image_file)
    retrieved_image_paths = [os.path.join('./combined_dataset/', img_file) for img_file in retrieved_images['image_file']]
    titles = [f"Rank {i+1}: {img_name}" for i, img_name in enumerate(retrieved_images['image_file'])]
    captions = []
    for _, row in retrieved_images.iterrows():
        # Only pubmed images have captions; missing ones are NaN after the merge
        if row['source'] == 'pubmed' and isinstance(row['caption'], str):
            captions.append(row['caption'])
        else:
            captions.append('')
    
    # Combine query image and retrieved images for uniform display
    all_image_paths = [query_image_path] + retrieved_image_paths
    all_titles = [f"Query: {query_image_file}"] + titles
    all_captions = [''] + captions  # Queries come from the matches set, which has no captions
    
    # Dynamically adjust figure size based on number of images
    total_images = len(all_image_paths)
    figsize = (15, total_images * 3)  # Each image gets 3 units of vertical space

    # Display the query and retrieved images together
    print("Displaying Query and Retrieved Images:")
    display_images(all_image_paths, titles=all_titles, captions=all_captions, cols=1, figsize=figsize)
    
    return retrieved_images
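As written, `retrieve_and_display` searches with the query's own embedding, and that embedding is stored in the index, so the query normally comes back at rank 1 with distance 0 and occupies one of the `top_k` slots. If you would rather report k neighbours other than the query, one option is to search with k + 1 and drop the query's own index. `drop_self_matches` below is a hypothetical helper sketching that idea on toy search output.

```python
import numpy as np

def drop_self_matches(indices, distances, query_ids, k):
    """Remove each query's own id from its result row and keep the next k hits.

    Assumes the search was run with k + 1 neighbours so that k remain after removal.
    """
    kept_indices, kept_distances = [], []
    for row_idx, row_dist, qid in zip(indices, distances, query_ids):
        mask = row_idx != qid  # True for every hit that is not the query itself
        kept_indices.append(row_idx[mask][:k])
        kept_distances.append(row_dist[mask][:k])
    return np.array(kept_indices), np.array(kept_distances)

# Toy search output for two queries (ids 0 and 3), retrieved with k + 1 = 3
indices = np.array([[0, 1, 2], [3, 2, 1]])
distances = np.array([[0.0, 1.0, 4.0], [0.0, 10.0, 13.0]])
I, D = drop_self_matches(indices, distances, query_ids=[0, 3], k=2)
print(I)
print(D)
```

Augmented copies of the query share its original image and may still rank highly; whether they should count as hits is a policy choice for the evaluation, not something this helper decides.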

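Function to Evaluate Retrieval

Rows in the `matches` set whose original image carries the same ID as the query (the part of the filename before the first underscore, e.g. `123` in `123_01.jpg`) count as matches. The reported match accuracy is the fraction of those rows that appear among the retrieved images, i.e. recall within the top-k. The query image itself counts toward both the total and the hits, and the score cannot reach 1 when a group has more members than the number of images retrieved.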

In [None]:
def evaluate_retrieval(query_image_name, retrieved_images):
    """Return (matches found, matches available, match accuracy) for one query."""
    # Group ID = filename part before the first underscore, e.g. '123' in '123_01.jpg'.
    # Match on 'ID_' rather than 'ID' so that '123' does not also match '1234_01.jpg'.
    group_prefix = query_image_name.split('_')[0] + '_'

    # Count all rows in the matches set that belong to the query's group
    in_group = df['original_image'].str.startswith(group_prefix) & (df['source'] == 'matches')
    total_available_matches = int(in_group.sum())

    # Count how many retrieved rows belong to the same group
    retrieved_in_group = (retrieved_images['original_image'].str.startswith(group_prefix)
                          & (retrieved_images['source'] == 'matches'))
    num_matches_found = int(retrieved_in_group.sum())

    # Match accuracy is recall within the top-k: fraction of the group's rows that were retrieved
    match_accuracy = num_matches_found / total_available_matches if total_available_matches > 0 else 0

    return num_matches_found, total_available_matches, match_accuracy
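Comparing on the bare ID with `startswith` is fragile when IDs have different lengths: `'123'` is a prefix of `'1234_01.jpg'`. The sketch below computes the same row-count recall on plain lists of original-image names (assumed to be already restricted to the `matches` set) and compares whole group IDs instead; `group_id` and `recall_at_k` are hypothetical helpers and the filenames are made up.

```python
def group_id(image_name):
    """Group ID = text before the first underscore, e.g. '123_01.jpg' -> '123'."""
    return image_name.split("_")[0]

def recall_at_k(query_name, retrieved_originals, match_set_originals):
    """Fraction of the query's group rows (in the matches set) that were retrieved."""
    gid = group_id(query_name)
    total = sum(group_id(name) == gid for name in match_set_originals)
    found = sum(group_id(name) == gid for name in retrieved_originals)
    return found / total if total else 0.0

match_set_originals = ["123_01.jpg", "123_02.jpg", "123_03.jpg", "1234_01.jpg"]
retrieved_originals = ["123_01.jpg", "1234_01.jpg", "123_02.jpg"]

# A bare prefix test counts the unrelated '1234_01.jpg' as part of group '123'
naive_total = sum(name.startswith("123") for name in match_set_originals)
recall = recall_at_k("123_01.jpg", retrieved_originals, match_set_originals)
print(naive_total, recall)  # 2 of the 3 rows in group '123' were retrieved
```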

---
### **Step 7: Perform Image Retrieval**

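We pick one query image from the `matches` set interactively, retrieve its 10 nearest neighbors, display them, and evaluate the result. Step 11 repeats the evaluation for every image in the `matches` set without displaying results.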

In [None]:
# Step 7: Perform Image Retrieval

# Get all unique original images from the matches set
matches_images = df[df['source'] == 'matches']['original_image'].unique()

# Display available images
if len(matches_images) > 0:
    print("Available images for querying from the 'matches' set:")
    for idx, image_file in enumerate(matches_images):
        print(f"{idx}: {image_file}")

    # Prompt the user to select an image
    while True:
        try:
            query_idx = int(input(f"Enter the index of the query image (0 to {len(matches_images) - 1}): "))
            if 0 <= query_idx < len(matches_images):
                query_image_name = matches_images[query_idx]
                print(f"Selected query image: {query_image_name}")
                
                # Retrieve and display images
                print("\nRetrieving images...")
                retrieved_images = retrieve_and_display(query_image_name, top_k=10)
                
                if retrieved_images is not None:
                    # Evaluate the retrieval
                    print("\nEvaluating retrieval performance...")
                    num_matches_found, total_available_matches, match_accuracy = evaluate_retrieval(query_image_name, retrieved_images)
                    
                    # Display evaluation results
                    print("\nRetrieval Evaluation Results:")
                    print(f"Total Matches Available: {total_available_matches}")
                    print(f"Matches Found in Retrieved Images: {num_matches_found}")
                    print(f"Match Accuracy: {match_accuracy:.2f}")
                else:
                    print("No images retrieved.")
                
                break  # Exit the loop after processing
            else:
                print(f"Invalid index. Please enter a number between 0 and {len(matches_images) - 1}.")
        except ValueError:
            print("Invalid input. Please enter a valid integer.")
else:
    print("No images available in the 'matches' set.")

---
### **Step 8: Cluster Embeddings**

We can use clustering algorithms like K-Means to group similar images and evaluate the cohesiveness of these clusters.

In [None]:
# Choose the number of clusters (dynamic based on dataset size)
num_clusters = max(2, min(10, len(embeddings) // 100))  # One cluster per 100 images, capped at 10 and never fewer than 2

# Initialize K-Means
kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10)

# Fit K-Means on the embeddings
cluster_labels = kmeans.fit_predict(embeddings)

# Add cluster labels to the DataFrame
df['cluster_label'] = cluster_labels

print(f"Clustering completed with {num_clusters} clusters.")
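K-Means minimizes inertia, the sum of squared distances from each point to its assigned centroid; with `n_init=10`, scikit-learn runs the algorithm ten times and keeps the run with the lowest inertia. The sketch below shows Lloyd's algorithm, the classic alternating iteration behind k-means, in plain NumPy on toy 2-D points. `kmeans_lloyd` is a hypothetical illustration with random initialization and a single run, not a replacement for the library call.

```python
import numpy as np

def kmeans_lloyd(X, k, n_iter=100, seed=0):
    """Lloyd's algorithm: alternate nearest-centroid assignment and centroid update."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=np.float64)
    # Initialise centroids from k distinct data points (plain random init, not k-means++)
    centroids = X[rng.choice(len(X), size=k, replace=False)].copy()
    labels = np.zeros(len(X), dtype=int)
    for it in range(n_iter):
        # Assignment step: squared distance from every point to every centroid
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        new_labels = d2.argmin(axis=1)
        if it > 0 and np.array_equal(new_labels, labels):
            break  # assignments stopped changing, so the centroids are stable
        labels = new_labels
        # Update step: move each centroid to the mean of its points (keep it if empty)
        for j in range(k):
            members = X[labels == j]
            if len(members):
                centroids[j] = members.mean(axis=0)
    inertia = float(((X - centroids[labels]) ** 2).sum())
    return labels, centroids, inertia

X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
labels, centroids, inertia = kmeans_lloyd(X, k=2)
print(labels, inertia)
```

Each iteration can only lower (or keep) the inertia, but the result depends on the starting centroids, which is why restarting several times and keeping the best run helps.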

---
### **Step 9: Compute Unsupervised Evaluation Metrics**

**Silhouette Score:** measures how similar an object is to its own cluster compared to other clusters. It ranges from -1 to 1, with higher values indicating better clustering.

In [None]:
# Compute Silhouette Score
if num_clusters > 1:  # Silhouette Score is valid only if there are >1 clusters
    sil_score = silhouette_score(embeddings, cluster_labels)
    print(f'Silhouette Score: {sil_score:.4f}')
else:
    print("Silhouette Score cannot be computed with a single cluster.")
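For point $i$, let $a(i)$ be its mean distance to the other points in its own cluster and $b(i)$ its smallest mean distance to the points of any other cluster. Then $s(i) = \frac{b(i) - a(i)}{\max(a(i), b(i))}$, and the Silhouette Score is the mean of $s(i)$ over all points. The sketch below implements that definition with Euclidean distance on a tiny 1-D example; `silhouette` is a hypothetical helper that builds the full pairwise distance matrix, so it is only suitable for small inputs. For large embedding sets, `silhouette_score` accepts a `sample_size` argument to score a random subset.

```python
import numpy as np

def silhouette(X, labels):
    """Mean silhouette coefficient with Euclidean distance (O(n^2) memory)."""
    X = np.asarray(X, dtype=np.float64)
    labels = np.asarray(labels)
    clusters = np.unique(labels)
    if len(clusters) < 2:
        raise ValueError("silhouette needs at least 2 clusters")
    dist = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=2))
    scores = np.zeros(len(X))
    for i in range(len(X)):
        same = labels == labels[i]
        same[i] = False  # exclude the point itself from a(i)
        if not same.any():
            continue  # singleton cluster: s(i) = 0 by convention
        a = dist[i, same].mean()
        b = min(dist[i, labels == c].mean() for c in clusters if c != labels[i])
        denom = max(a, b)
        scores[i] = (b - a) / denom if denom > 0 else 0.0
    return scores.mean()

X = np.array([[0.0], [1.0], [10.0], [11.0]])
score = silhouette(X, [0, 0, 1, 1])
print(round(score, 4))
```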

---
### **Step 10: Visualize Sample Images from Each Cluster**

Visualize Sample Images

In [None]:
# Function to display images from clusters
def display_cluster_images(cluster_num, images_per_cluster=5):
    """
    Display a sample of images from a specified cluster.

    Parameters:
    - cluster_num: The cluster number to display images from.
    - images_per_cluster: The number of images to sample and display.
    """
    cluster_images = df[df['cluster_label'] == cluster_num]['image_file'].values
    if len(cluster_images) == 0:
        print(f"Cluster {cluster_num} is empty.")
        return
    sample_images = np.random.choice(cluster_images, size=min(images_per_cluster, len(cluster_images)), replace=False)
    sample_image_paths = [os.path.join('./combined_dataset/', img_file) for img_file in sample_images]
    titles = [f"Cluster {cluster_num}: {img_file}" for img_file in sample_images]
    display_images(sample_image_paths, titles=titles, figsize=(30, 16), cols=3)

# Visualize images from each cluster
for cluster in range(num_clusters):
    print(f"\nCluster {cluster} (Total Images: {len(df[df['cluster_label'] == cluster])}):")
    display_cluster_images(cluster)

---
### **Step 11: Create Evaluation DataFrame**

Evaluate Over All Images in the Matches Set

In [None]:
# Prepare a list to store evaluation results
evaluation_results = []

# Loop through each query image and evaluate
for query_image_name in tqdm(matches_images, desc='Evaluating Queries'):
    # Retrieve images without displaying
    query_indices = df[df['original_image'] == query_image_name]['embedding_index'].tolist()
    if not query_indices:
        print(f"No embeddings found for query image: {query_image_name}")
        continue

    # Use the first embedding as the query
    query_idx = query_indices[0]
    query_embedding = embeddings[query_idx].reshape(1, -1)
    
    # Perform the search
    distances, indices = index.search(query_embedding, 10)  # Retrieve top 10 results
    retrieved_indices = indices[0]
    retrieved_images = df.iloc[retrieved_indices]
    
    # Evaluate retrieval
    num_matches_found, total_available_matches, match_accuracy = evaluate_retrieval(query_image_name, retrieved_images)
    
    # Store the results
    evaluation_results.append({
        'Queried Image': query_image_name,
        'Number of Matches Found': num_matches_found,
        'Total Available Matches': total_available_matches,
        'Match Accuracy': match_accuracy
    })
    
    # Clear memory
    gc.collect()

# Convert results to a DataFrame
evaluation_df = pd.DataFrame(evaluation_results)

# Display the evaluation DataFrame
evaluation_df
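`Match Accuracy` is computed per query, and the table supports more than one summary. The sketch below uses a made-up three-row table with the notebook's column names (toy values, not results) to contrast the macro average (each query weighs equally), the micro average (pooled over all matches, so larger groups weigh more), and the share of queries with at least one match found.

```python
import pandas as pd

# Toy stand-in for evaluation_df: same column names, illustrative values only
evaluation_df = pd.DataFrame({
    "Queried Image": ["123_01.jpg", "456_01.jpg", "789_01.jpg"],
    "Number of Matches Found": [2, 1, 0],
    "Total Available Matches": [3, 2, 4],
    "Match Accuracy": [2 / 3, 1 / 2, 0.0],
})

# Macro average: every query counts equally
macro = evaluation_df["Match Accuracy"].mean()
# Micro average: pool all matches across queries
micro = (evaluation_df["Number of Matches Found"].sum()
         / evaluation_df["Total Available Matches"].sum())
# Hit rate: fraction of queries with at least one match retrieved
hit_rate = (evaluation_df["Number of Matches Found"] > 0).mean()

print(f"macro={macro:.3f}  micro={micro:.3f}  hit_rate={hit_rate:.3f}")
```

Because the query image is itself in the index and in its own group, it usually counts as one match found, so all three numbers are optimistic unless self-matches are excluded.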

Save Results to CSV

In [None]:
# Save to CSV
evaluation_df.to_csv('evaluation_results_dino.csv', index=False)

print('Evaluation results saved to evaluation_results_dino.csv')

---
### **Step 12: Conclusion**

We have successfully:

- Loaded the embeddings generated using the DINO model.
- Performed exact nearest-neighbor image retrieval with FAISS, using no labels for the search itself.
- Visualized retrieval results to qualitatively assess similarity.
- Applied K-Means clustering on the embeddings.
- Computed the Silhouette Score to evaluate clustering quality.
- Visualized sample images from each cluster.
- Created an evaluation DataFrame summarizing the results for all query images.