In [1]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity


In [2]:

# Setup transformations for the images
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


In [3]:

# Load the dataset from the class-wise folders
dataset = ImageFolder(root='C:/Users/ASUS/Desktop/CSE465/Datasets', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


In [4]:

# Load pre-trained models
resnet = models.resnet101(pretrained=True)
resnet.eval()  # Set to evaluation mode
googlenet = models.googlenet(pretrained=True)
googlenet.eval()
# zfnet = models.zfnet(pretrained=True) 
# zfnet.eval()




GoogLeNet(
  (conv1): BasicConv2d(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (conv2): BasicConv2d(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): BasicConv2d(
    (conv): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (inception3a): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track

In [5]:
# Function to extract features
def extract_features(model, dataloader):
    features = []
    model.eval()
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to('cuda' if torch.cuda.is_available() else 'cpu')
            outputs = model(images)
            features.extend(outputs.cpu().numpy())
    return np.array(features)


In [6]:

# Extract features using ResNet-101, Google Net and ZFNet
resnet_features = extract_features(resnet, dataloader)
googlenet_features = extract_features(googlenet, dataloader)
# zfnet_features = extract_features(zfnet, dataloader)



In [7]:
# Function to find nearest neighbors
def find_nearest_neighbors(features, index, num_neighbors=10):
    similarities = cosine_similarity([features[index]], features)[0]
    nearest_indices = np.argsort(-similarities)[1:num_neighbors+1]  # Top 10 excluding self
    return nearest_indices

# Example usage for one image per class (assuming balanced classes for simplicity)
num_classes = len(dataset.classes)
class_indices = {i: [] for i in range(num_classes)}
for idx, (_, label) in enumerate(dataset):
    class_indices[label].append(idx)

# Find and print nearest neighbors for one image from each class


In [8]:
for label, indices in class_indices.items():
    representative_idx = indices[0]  # Just taking the first image for simplicity
    neighbors = find_nearest_neighbors(resnet_features, representative_idx)
    print(f"Class {label} representative image at index {representative_idx} has neighbors indices: {neighbors}")


Class 0 representative image at index 0 has neighbors indices: [1076  735 1274 1860  867 1417 2308 1087  447 1819]
Class 1 representative image at index 435 has neighbors indices: [ 802 1495  833 2274 1297 1030 2931 1887  955 1979]
Class 2 representative image at index 635 has neighbors indices: [ 737 1533 1560 2180 1953 1207 2924 2170  956  339]
Class 3 representative image at index 1433 has neighbors indices: [ 393 2390  677 2240 2797 1953 2180 2494  481   98]
Class 4 representative image at index 2233 has neighbors indices: [  53  325 1283 1532  957  851 2538 2375 2719 2993]
Class 5 representative image at index 2356 has neighbors indices: [1449 2397  573 1914  222  914  907 1309 1659 1660]
Class 6 representative image at index 2463 has neighbors indices: [2511 2083 1600  117 1652 1679 2050 1552 2814 1060]
Class 7 representative image at index 2562 has neighbors indices: [2895  202 2826 1115 2801 2369 1836  977 1414 1653]
Class 8 representative image at index 2662 has neighbors indi

In [9]:
for label, indices in class_indices.items():
    representative_idx = indices[0]  # Just taking the first image for simplicity
    neighbors = find_nearest_neighbors(googlenet_features, representative_idx)
    print(f"Class {label} representative image at index {representative_idx} has neighbors indices: {neighbors}")


Class 0 representative image at index 0 has neighbors indices: [1190 2286  656 2673 1436  875 1522 1975 2357 2000]
Class 1 representative image at index 435 has neighbors indices: [1882 2978 1662 1449 2645  155 2272 1147 2192 2499]
Class 2 representative image at index 635 has neighbors indices: [ 872  347  640 1524 1560 1656 1225  993  585 1631]
Class 3 representative image at index 1433 has neighbors indices: [1017 1528 2905 2051 1837  394 1043   49  812 2988]
Class 4 representative image at index 2233 has neighbors indices: [  44  869  982   83 1870  350  597 2814 2445 1481]
Class 5 representative image at index 2356 has neighbors indices: [2185 1045 1759  175 2236 1347  745  267 1031 2220]
Class 6 representative image at index 2463 has neighbors indices: [2904  818 2905 2489 1325 1615 2058 2192 1043  917]
Class 7 representative image at index 2562 has neighbors indices: [2853 1017  475 2141  812 1513 2402 2897  877 2105]
Class 8 representative image at index 2662 has neighbors indi

In [10]:
# for label, indices in class_indices.items():
#     representative_idx = indices[0]  # Just taking the first image for simplicity
#     neighbors = find_nearest_neighbors(zfnet_features, representative_idx)
#     print(f"Class {label} representative image at index {representative_idx} has neighbors indices: {neighbors}")
