In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision import models
from sklearn.neighbors import NearestNeighbors
import numpy as np
from tqdm import tqdm

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Input pipeline configuration. The normalization constants are the standard
# ImageNet channel statistics expected by torchvision's pretrained weights.
DATA_ROOT = 'Datasets'  # ImageFolder layout: one sub-directory per class
BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    # Resizing straight to a fixed (H, W) does not preserve aspect ratio.
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

dataset = datasets.ImageFolder(root=DATA_ROOT, transform=transform)
# shuffle=False keeps a fixed image order, so row i of every model's feature
# matrix refers to the same image; the retrieval code below relies on this.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Load the pretrained backbones onto `device` and switch them to inference mode.
# torchvision ships no ZFNet, so AlexNet (its close predecessor) stands in for it.
resnet_model, zfnet_model, googlenet_model = (
    architecture(pretrained=True).to(device).eval()
    for architecture in (models.resnet101, models.alexnet, models.googlenet)
)

def extract_features(model, dataloader, device=None):
    """Run ``model`` over every batch of ``dataloader`` and collect the outputs.

    Args:
        model: a ``torch.nn.Module``; it is put into eval mode here.
        dataloader: iterable yielding ``(images, labels)`` batches.
        device: device to run inference on. Defaults to the device holding the
            model's first parameter (CPU for a parameterless model), so inputs
            always land next to the weights.

    Returns:
        ``(features, labels)``. ``features`` holds the model outputs, each
        batch flattened to ``(batch_size, -1)``, concatenated in dataloader
        order. ``labels`` holds the labels concatenated the same way. Both
        tensors are on the CPU.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for images, batch_labels in tqdm(dataloader):
            outputs = model(images.to(device))
            # Flatten to 2-D so truncated backbones (e.g. (B, C, 1, 1) maps)
            # still produce one vector per image; a no-op for (B, classes).
            outputs = outputs.reshape(outputs.size(0), -1)
            # Move each batch off the accelerator right away so device memory
            # holds one batch of outputs, not the whole dataset's features.
            feature_batches.append(outputs.cpu())
            label_batches.append(batch_labels.cpu())
    if not feature_batches:
        raise ValueError("dataloader yielded no batches")
    return torch.cat(feature_batches), torch.cat(label_batches)

# Run every backbone over the same (unshuffled) dataloader, so row i of each
# feature matrix and label vector refers to the same image.
(
    (resnet_features, resnet_labels),
    (zfnet_features, zfnet_labels),
    (googlenet_features, googlenet_labels),
) = (
    extract_features(backbone, dataloader)
    for backbone in (resnet_model, zfnet_model, googlenet_model)
)

In [None]:
# scikit-learn works on host NumPy arrays, so bring each feature tensor to the CPU.
resnet_features_np, zfnet_features_np, googlenet_features_np = (
    feature_tensor.cpu().numpy()
    for feature_tensor in (resnet_features, zfnet_features, googlenet_features)
)

def find_nearest_neighbors(features, query_feature, k=10, metric='minkowski'):
    """Return the ``k`` rows of ``features`` closest to ``query_feature``.

    Args:
        features: 2-D array-like with one feature vector per row.
        query_feature: a single feature vector of the same length as a row of
            ``features``.
        k: number of neighbours requested. It is clamped to ``len(features)``,
            so a small dataset returns every row instead of raising.
        metric: distance metric forwarded to ``NearestNeighbors``. The default
            ``'minkowski'`` is sklearn's own default (Euclidean with p=2).

    Returns:
        ``(distances, indices)``, each of shape ``(1, min(k, len(features)))``,
        ordered from nearest to farthest. When ``query_feature`` is itself a
        row of ``features``, that row is normally the first hit (distance 0).

    Raises:
        ValueError: if ``features`` is empty or ``k`` < 1.
    """
    features = np.asarray(features)
    n_samples = len(features)
    if n_samples == 0:
        raise ValueError("features must contain at least one row")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # The index is rebuilt on every call; callers issuing many queries against
    # the same features may prefer to fit a NearestNeighbors once and reuse it.
    nbrs = NearestNeighbors(
        n_neighbors=min(k, n_samples), algorithm='auto', metric=metric
    ).fit(features)
    query = np.asarray(query_feature).reshape(1, -1)
    distances, indices = nbrs.kneighbors(query)
    return distances, indices

In [None]:
# For each class, use its first image as the query and print the 10 nearest
# neighbours each backbone retrieves for it. Row indices line up across the
# three feature matrices because the dataloader iterates with shuffle=False.
labels_np = resnet_labels.cpu().numpy()
class_sample_indices = {
    class_id: np.where(labels_np == class_id)[0][0]
    for class_id in np.unique(labels_np)
}

features_by_model = {
    "ResNet": resnet_features_np,
    "ZFNet": zfnet_features_np,
    "GoogleNet": googlenet_features_np,
}

for class_id, img_index in class_sample_indices.items():
    print(f"\nClass {class_id}:")
    for model_name, model_features in features_by_model.items():
        _, neighbor_indices = find_nearest_neighbors(model_features, model_features[img_index], k=10)
        print(f"{model_name} Neighbors:", neighbor_indices[0])