In [None]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Set the device to CUDA if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Querying the CUDA device name on a CPU-only machine raises, so only ask when CUDA is
# actually present; the last expression is displayed as this cell's output.
device_name = torch.cuda.get_device_name(device) if device.type == 'cuda' else 'CPU'
device_name

In [None]:
# Image preprocessing applied to every sample: Resize(256), a central 224x224 crop,
# conversion to a tensor, then per-channel normalization with the mean/std values
# commonly used with torchvision's ImageNet-pretrained weights.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Load the dataset from the class-wise folders (one sub-directory per class under 'Datasets')
dataset = ImageFolder(root='Datasets', transform=transform)
# shuffle=False is required here: extract_features stacks model outputs in loader order,
# and the neighbour indices computed from those features are later used to index
# `dataset` directly (plot_neighbors, class_indices). Row i of every feature matrix must
# therefore be dataset[i]. Shuffling would also give each model a different row order.
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

In [None]:
# Load pre-trained models and put them in evaluation mode (Module.eval() returns the module
# itself, so the assignments below hold the same objects that were switched to eval mode).
# `pretrained=True` is the legacy torchvision flag; newer releases prefer `weights=...` and
# emit a deprecation warning, but the flag is kept for compatibility with older installs.
# NOTE: torchvision provides no ZFNet, so AlexNet is used as a stand-in under the name `zfnet`.
resnet = models.resnet101(pretrained=True).eval()
googlenet = models.googlenet(pretrained=True).eval()
zfnet = models.alexnet(pretrained=True).eval()

In [None]:
# Function to extract features
def extract_features(model, dataloader, device=None):
    """Run ``model`` over every batch of ``dataloader`` and stack the outputs row-wise.

    Args:
        model: a ``torch.nn.Module``; it is switched to eval mode and run under
            ``torch.no_grad()``.
        dataloader: iterable yielding ``(images, labels)`` batches; labels are ignored.
        device: where to send each image batch. Defaults to the device of the model's
            first parameter (so inputs always match the model), falling back to
            CUDA-if-available for parameter-less models.

    Returns:
        ``np.ndarray`` with one row per input image, in the order the dataloader
        yields them. For the torchvision classifiers used in this notebook,
        ``model(images)`` returns the classification-head outputs (logits), not
        penultimate-layer embeddings. Returns an empty array if the dataloader
        yields no batches.
    """
    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    batch_outputs = []
    model.eval()
    with torch.no_grad():
        for images, _ in dataloader:
            outputs = model(images.to(device))
            batch_outputs.append(outputs.cpu().numpy())

    if not batch_outputs:
        return np.array([])
    # One concatenation instead of a per-sample Python list -> np.array.
    return np.concatenate(batch_outputs, axis=0)

In [None]:
# Place every backbone on the selected device. Module.to returns the module itself, so
# rebinding the names keeps the very same model objects.
resnet, googlenet, zfnet = (net.to(device) for net in (resnet, googlenet, zfnet))

# One full inference pass over the dataset per backbone:
# ResNet-101, GoogLeNet and AlexNet (stored under the name `zfnet`).
resnet_features, googlenet_features, zfnet_features = (
    extract_features(net, dataloader) for net in (resnet, googlenet, zfnet)
)

In [None]:
import matplotlib.pyplot as plt

# Calculate Nearest Neighbors
def find_nearest_neighbors(features, n_neighbors=10):
    """For every row of ``features``, return the indices of its most cosine-similar other rows.

    Args:
        features: 2-D array with one feature vector per image (rows are assumed to be
            in ``dataset`` order).
        n_neighbors: number of neighbours to keep per row; clamped to ``len(features) - 1``.

    Returns:
        Integer array with one row per input row. Row ``i`` lists neighbour indices of
        image ``i`` from most to least similar and never contains ``i`` itself.
    """
    similarity_matrix = cosine_similarity(features)
    # Exclude each image from its own list explicitly. Relying on "self is the row maximum"
    # fails on ties (duplicate images / identical features), where argsort may put another
    # index in self's slot and leave self in the result.
    np.fill_diagonal(similarity_matrix, -np.inf)
    n_keep = min(n_neighbors, similarity_matrix.shape[1] - 1)
    # Negate to sort descending, so column 0 is the nearest neighbour.
    return np.argsort(-similarity_matrix, axis=1, kind='stable')[:, :n_keep]

# Neighbour table (one row per image) for each backbone's features
resnet_neighbors, zfnet_neighbors, googlenet_neighbors = (
    find_nearest_neighbors(feats)
    for feats in (resnet_features, zfnet_features, googlenet_features)
)

# Visualize Nearest Neighbors
def _to_displayable(image_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo the Normalize(mean, std) preprocessing step and return an HxWxC array in [0, 1] for imshow.

    The default mean/std mirror the values passed to ``transforms.Normalize`` in this notebook;
    ``image_tensor`` is assumed to be a CxHxW tensor produced by that transform.
    """
    mean_t = torch.tensor(mean, dtype=image_tensor.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=image_tensor.dtype).view(-1, 1, 1)
    return (image_tensor * std_t + mean_t).clamp(0, 1).permute(1, 2, 0).numpy()


def plot_neighbors(image_idx, neighbors, model_name):
    """Show the query image ``dataset[image_idx]`` followed by its neighbours in one row.

    Args:
        image_idx: index of the query image in the global ``dataset``.
        neighbors: 2-D array of neighbour indices; row ``image_idx`` is plotted.
        model_name: label used in the figure title.
    """
    neighbor_indices = neighbors[image_idx]
    # One panel for the query plus one per neighbour; squeeze=False keeps `axs` indexable
    # even when there is a single panel.
    fig, axs = plt.subplots(1, len(neighbor_indices) + 1, figsize=(15, 2), squeeze=False)
    axs = axs[0]
    fig.suptitle(f"{model_name} Nearest Neighbors for Image {image_idx}")

    # Show query image (de-normalized so colors are correct and values lie in [0, 1])
    query_image, _ = dataset[image_idx]
    axs[0].imshow(_to_displayable(query_image))
    axs[0].set_title("Query")
    axs[0].axis('off')

    # Show neighbors
    for panel, neighbor_idx in enumerate(neighbor_indices, start=1):
        neighbor_image, _ = dataset[neighbor_idx]
        axs[panel].imshow(_to_displayable(neighbor_image))
        axs[panel].axis('off')

    plt.show()

# Example: compare the three backbones on one query image.
# image_idx must be a valid index into `dataset`; change it to visualize other images.
image_idx = 2500
for model_neighbors, model_label in (
    (resnet_neighbors, "ResNet-101"),
    (zfnet_neighbors, "ZFNet"),
    (googlenet_neighbors, "GoogleNet"),
):
    plot_neighbors(image_idx, model_neighbors, model_label)

In [None]:
# Function to find nearest neighbors for a single query image.
# NOTE(review): this redefines (shadows) the all-pairs `find_nearest_neighbors` from an
# earlier cell; from here on the name refers to this single-query version.
def find_nearest_neighbors(features, index, num_neighbors=10):
    """Return indices of the ``num_neighbors`` rows most cosine-similar to ``features[index]``.

    Args:
        features: 2-D array with one feature vector per image.
        index: row index of the query image.
        num_neighbors: how many neighbours to return; clamped to ``len(features) - 1``.

    Returns:
        1-D integer array ordered from most to least similar; never contains ``index``.
    """
    similarities = cosine_similarity([features[index]], features)[0]
    # Exclude the query explicitly instead of assuming it sorts first: with duplicate
    # images or tied similarities another row can outrank it, which would leave the query
    # in the result and drop a real neighbour.
    similarities[index] = -np.inf
    n_keep = min(num_neighbors, len(similarities) - 1)
    return np.argsort(-similarities, kind='stable')[:n_keep]

# Group dataset indices by class label. `dataset.targets` lists every sample's label in
# dataset order, so this avoids iterating `dataset` itself, which would load and
# transform every image just to read its label.
num_classes = len(dataset.classes)
class_indices = {i: [] for i in range(num_classes)}
for idx, label in enumerate(dataset.targets):
    class_indices[label].append(idx)

# Nearest neighbors for one representative image per class are printed per model below.


In [None]:
# ResNet-101 features: neighbour indices for the first image of every class
for label, indices in class_indices.items():
    representative_idx = indices[0]  # first image of the class, for simplicity
    neighbors = find_nearest_neighbors(resnet_features, representative_idx)
    message = (
        f"Class {label} representative image at index {representative_idx} "
        f"has neighbors indices: {neighbors}"
    )
    print(message)


In [None]:
# GoogLeNet features: neighbour indices for the first image of every class
for label, indices in class_indices.items():
    representative_idx = indices[0]  # first image of the class, for simplicity
    neighbors = find_nearest_neighbors(googlenet_features, representative_idx)
    message = (
        f"Class {label} representative image at index {representative_idx} "
        f"has neighbors indices: {neighbors}"
    )
    print(message)


In [None]:
# AlexNet features (stored as `zfnet`): neighbour indices for the first image of every class
for label, indices in class_indices.items():
    representative_idx = indices[0]  # Just taking the first image for simplicity
    neighbors = find_nearest_neighbors(zfnet_features, representative_idx)
    print(f"Class {label} representative image at index {representative_idx} has neighbors indices: {neighbors}")
