In [23]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [24]:
# Set the device to CUDA if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [25]:
# Setup transformations for the images
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [26]:
# Load the dataset from the class-wise folders
dataset = ImageFolder(root='Datasets', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [27]:
# Load pre-trained models
resnet = models.resnet101(pretrained=True)
resnet.eval()  # Set to evaluation mode
googlenet = models.googlenet(pretrained=True)
googlenet.eval()
zfnet = models.alexnet(pretrained=True)
zfnet.eval()



AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [28]:
# Function to extract features
def extract_features(model, dataloader):
    features = []
    model.eval()
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to('cuda' if torch.cuda.is_available() else 'cpu')
            outputs = model(images)
            features.extend(outputs.cpu().numpy())
    return np.array(features)

In [31]:
# Move models to the same device as the input tensor
resnet.to(device)
googlenet.to(device)
zfnet.to(device)

# Extract features using ResNet-101, Google Net and ZFNet
resnet_features = extract_features(resnet, dataloader)
googlenet_features = extract_features(googlenet, dataloader)
zfnet_features = extract_features(zfnet, dataloader)

In [32]:
# Function to find nearest neighbors
def find_nearest_neighbors(features, index, num_neighbors=10):
    similarities = cosine_similarity([features[index]], features)[0]
    nearest_indices = np.argsort(-similarities)[1:num_neighbors+1]  # Top 10 excluding self
    return nearest_indices

# Example usage for one image per class (assuming balanced classes for simplicity)
num_classes = len(dataset.classes)
class_indices = {i: [] for i in range(num_classes)}
for idx, (_, label) in enumerate(dataset):
    class_indices[label].append(idx)

# Find and print nearest neighbors for one image from each class


In [33]:
for label, indices in class_indices.items():
    representative_idx = indices[0]  # Just taking the first image for simplicity
    neighbors = find_nearest_neighbors(resnet_features, representative_idx)
    print(f"Class {label} representative image at index {representative_idx} has neighbors indices: {neighbors}")


Class 0 representative image at index 0 has neighbors indices: [1247 1400  126  364  778   29 2016 2452 2141  903]
Class 1 representative image at index 435 has neighbors indices: [2394  472 2286 2661  880 1032 1711  579  870 1132]
Class 2 representative image at index 635 has neighbors indices: [2485  155 2691 2582 1097 2747 1588  512 1942  175]
Class 3 representative image at index 1433 has neighbors indices: [ 859 1904 1879  870  253 1132 1747  472 2286 1552]
Class 4 representative image at index 2233 has neighbors indices: [ 931  765 2997 2258 2765 2445  449 1115  424 1349]
Class 5 representative image at index 2356 has neighbors indices: [1456 1878 1073  496 1738 2617  890 1586 2316  647]
Class 6 representative image at index 2463 has neighbors indices: [2872 2916  504  334  335 1821 1231  476 2130 2404]
Class 7 representative image at index 2562 has neighbors indices: [ 132 1960 2660 2204  375  328  947  690 2685  493]
Class 8 representative image at index 2662 has neighbors indi

In [34]:
for label, indices in class_indices.items():
    representative_idx = indices[0]  # Just taking the first image for simplicity
    neighbors = find_nearest_neighbors(googlenet_features, representative_idx)
    print(f"Class {label} representative image at index {representative_idx} has neighbors indices: {neighbors}")


Class 0 representative image at index 0 has neighbors indices: [2498 2992 2873 2500 1275 2099 1946  829 1645 2670]
Class 1 representative image at index 435 has neighbors indices: [  43  785  157 2910 1583  822 1151  369  811 2692]
Class 2 representative image at index 635 has neighbors indices: [1473 2415  899 2128 1995   84 1567 2260  731  115]
Class 3 representative image at index 1433 has neighbors indices: [ 442  766 2485 1471  859 1509  890 2177  368  136]
Class 4 representative image at index 2233 has neighbors indices: [ 760  729  229 2185  507  386 1307  446  255 2282]
Class 5 representative image at index 2356 has neighbors indices: [ 674 1256  193 1891  179 1334 2955 2328 1415  295]
Class 6 representative image at index 2463 has neighbors indices: [2468 1026 1018 1366 1865 3014  625 1693  544 2858]
Class 7 representative image at index 2562 has neighbors indices: [  50 1976  209  793  869 2307  527  318  407 2699]
Class 8 representative image at index 2662 has neighbors indi

In [35]:
# for label, indices in class_indices.items():
#     representative_idx = indices[0]  # Just taking the first image for simplicity
#     neighbors = find_nearest_neighbors(zfnet_features, representative_idx)
#     print(f"Class {label} representative image at index {representative_idx} has neighbors indices: {neighbors}")
