In [None]:
import os
import pandas as pd
import timm
import numpy as np
from wildlife_tools.data import WildlifeDataset
from wildlife_tools.features import DeepFeatures
from wildlife_tools.similarity import CosineSimilarity
import torchvision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont

In [None]:
# Step 1: Prepare dataset
image_folder = '../test/Building_images'

# Keep only image files and sort them so that row indices (and therefore every
# "Image {idx}" result printed later) are reproducible: os.listdir returns entries
# in arbitrary order and may include non-images (e.g. .DS_Store) that PIL cannot open.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp')
image_files = sorted(
    name for name in os.listdir(image_folder)
    if name.lower().endswith(IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(image_folder, name))
)
print(image_files)
if not image_files:
    raise FileNotFoundError(f"No image files found in {image_folder!r}")
image_paths = [os.path.join(image_folder, name) for name in image_files]

# One row per image; WildlifeDataset looks up file locations in the 'path' column.
metadata = pd.DataFrame({'path': image_paths})

# Feature extraction must be deterministic: random rotation / colour jitter at
# inference time yields different embeddings (and different rankings) on every run.
# Set AUGMENT = True only for deliberate test-time-augmentation experiments.
AUGMENT = False
augmentation = [
    T.RandomRotation(10),
    # NOTE: saturation/hue have essentially no effect after Grayscale (R == G == B).
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
] if AUGMENT else []

# Black-and-white conversion, kept as 3 channels so the RGB backbone accepts it.
transform = T.Compose([
    T.Resize([384, 384]),
    T.Grayscale(num_output_channels=3),
    *augmentation,
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Dummy per-image 'identity' (every image is its own class) for WildlifeDataset compatibility.
metadata['identity'] = range(len(metadata))

# Create the dataset
dataset = WildlifeDataset(metadata, transform=transform)

In [None]:
# Build the MegaDescriptor-L-384 backbone from the Hugging Face Hub ("hf-hub:" prefix)
# and wrap it in a wildlife_tools feature extractor.
model_name = 'hf-hub:BVRA/MegaDescriptor-L-384'
# num_classes=0 asks timm for the network without a classification head,
# so its output can be used directly as an embedding.
backbone = timm.create_model(model_name, pretrained=True, num_classes=0)
extractor = DeepFeatures(backbone)
print("model loaded successfully")

# Run every image of the dataset through the backbone.
features = extractor(dataset)

In [None]:
# Calculate similarity: all-vs-all, i.e. the query set and the database set are
# the same collection of image embeddings.
similarity_function = CosineSimilarity()
query_features = database_features = features
similarity_dict = similarity_function(query_features, database_features)

# The result is indexed by similarity name; take the cosine score matrix.
similarity_matrix = similarity_dict['cosine']

In [None]:
# Function to get top N similar images with similarity scores
def get_top_n_similar_with_scores(similarity_matrix, n=5, exclude_self=True):
    """Return the indices and scores of the ``n`` most similar items for each row.

    Parameters
    ----------
    similarity_matrix : array-like, shape (N, M)
        Pairwise similarity scores; higher means more similar. When
        ``exclude_self`` is True the matrix must be square (query set ==
        database set) so that entry ``[i, i]`` is item ``i`` against itself.
    n : int, default 5
        Number of neighbours per row. Clamped to the number of available
        candidates (``M - 1`` when excluding self, otherwise ``M``);
        non-positive values yield zero columns.
    exclude_self : bool, default True
        Remove the diagonal (self-match) from the candidates.

    Returns
    -------
    top_n_indices : np.ndarray of int, shape (N, k)
        Column indices of the best matches, most similar first.
    top_n_scores : np.ndarray of float, shape (N, k)
        The similarity values belonging to ``top_n_indices``.

    Raises
    ------
    ValueError
        If the input is not 2-D, or ``exclude_self`` is True and it is not square.
    """
    # Work on a float copy: the diagonal is overwritten below and the caller's
    # matrix must stay intact.
    sims = np.array(similarity_matrix, dtype=float)
    if sims.ndim != 2:
        raise ValueError(f"similarity_matrix must be 2-D, got shape {sims.shape}")
    n_rows, n_cols = sims.shape

    if exclude_self:
        if n_rows != n_cols:
            raise ValueError(
                f"exclude_self=True requires a square matrix, got shape {sims.shape}")
        # Mask self-matches explicitly instead of assuming they sort to column 0:
        # duplicate images (ties with self) or float round-off can otherwise leak
        # the query into its own results and push a real match out.
        np.fill_diagonal(sims, -np.inf)
        available = n_cols - 1
    else:
        available = n_cols

    k = max(0, min(n, available))
    # Single stable sort: ties resolve deterministically to the lower index, and
    # scores are gathered from the same ordering rather than re-sorted.
    top_n_indices = np.argsort(-sims, axis=1, kind="stable")[:, :k]
    top_n_scores = np.take_along_axis(sims, top_n_indices, axis=1)
    return top_n_indices, top_n_scores

# Get top 5 similar images for each image
TOP_N = 5
top_5_similar, top_5_scores = get_top_n_similar_with_scores(similarity_matrix, n=TOP_N)

# Panel layout in pixels. Thumbnails use the same 384x384 size as the model transform.
THUMB_SIZE = 384
GAP = 20          # horizontal space between thumbnails
MARGIN = 10       # outer padding
FONT_SIZE = 22    # small enough that "Similar Image: <name>" roughly fits one column
LINE_HEIGHT = FONT_SIZE + 8
HEADER_HEIGHT = MARGIN + 2 * LINE_HEIGHT + MARGIN  # room for two caption lines


def _load_font(size):
    """Load Arial at ``size`` points, falling back to Pillow's built-in default font."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        # arial.ttf may not be on the font search path (common outside Windows);
        # fall back instead of crashing the cell.
        return ImageFont.load_default()


def _load_gray_thumbnail(path, size=THUMB_SIZE):
    """Open ``path`` as a square black-and-white thumbnail in RGB mode."""
    with Image.open(path) as img:  # release the file handle once pixels are copied
        return img.convert("L").convert("RGB").resize((size, size))


def _render_result_panel(query_path, match_paths, match_scores, font):
    """Draw the query thumbnail followed by its matches, each captioned above."""
    n_columns = 1 + len(match_paths)
    # Width derived from the column count so the last match is never clipped.
    width = 2 * MARGIN + n_columns * THUMB_SIZE + (n_columns - 1) * GAP
    height = HEADER_HEIGHT + THUMB_SIZE + MARGIN
    panel = Image.new('RGB', (width, height), (255, 255, 255))
    draw = ImageDraw.Draw(panel)

    def column_x(col):
        return MARGIN + col * (THUMB_SIZE + GAP)

    draw.text((column_x(0), MARGIN), f"Query Image: {os.path.basename(query_path)}",
              fill="black", font=font)
    panel.paste(_load_gray_thumbnail(query_path), (column_x(0), HEADER_HEIGHT))

    for col, (path, score) in enumerate(zip(match_paths, match_scores), start=1):
        # NOTE: score is the raw similarity value x 100 (cosine in this notebook),
        # not a calibrated probability, despite the "Confidence" label.
        caption = f"Similar Image: {os.path.basename(path)}\nConfidence: {score*100:.2f}%"
        draw.text((column_x(col), MARGIN), caption, fill="black", font=font)
        panel.paste(_load_gray_thumbnail(path), (column_x(col), HEADER_HEIGHT))
    return panel


font = _load_font(FONT_SIZE)  # loaded once rather than once per query
image_paths_by_index = metadata['path'].tolist()

# Display results with images and similarity scores
for idx, (similar_indices, scores) in enumerate(zip(top_5_similar, top_5_scores)):
    query_path = image_paths_by_index[idx]
    print(f"Image {idx} ({os.path.basename(query_path)}) is similar to:")
    result_img = _render_result_panel(
        query_path,
        [image_paths_by_index[j] for j in similar_indices],
        scores,
        font,
    )
    # Size the figure to the panel (~1:1 at 100 dpi) so the wide strip is legible.
    fig, ax = plt.subplots(figsize=(result_img.width / 100, result_img.height / 100))
    ax.imshow(result_img)
    ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per query