In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from data_utils import MIT_split_dataset, CustomTransform
import numpy as np
import wandb
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import precision_recall_curve, average_precision_score, PrecisionRecallDisplay
import matplotlib.pyplot as plt
import tqdm
import faiss

%load_ext autoreload
%autoreload 2
wandb.login(key='14a56ed86de5bf43e377d95d05458ca8f15f5017')

config = {
    'IMG_WIDTH': 256,
    'IMG_HEIGHT': 256,
    'TRAINING_DATASET_DIR': '../Week 1/data/MIT_split/train',
    'TEST_DATASET_DIR': '../Week 1/data/MIT_split/test',
    'batch_size': 32,
    'classifier': 'knn',
    'n_neighbors': 5,
    'metric': 'euclidean',
    'num_words': 256,
    'use_bovw': False,
    'voronoi_cells': 64,
    'lookup_cells': 8,
    'device': torch.device("cuda" if torch.cuda.is_available() else "cpu")
}

torch.manual_seed(123) # seed for reproductibility

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mluisgogu2001[0m ([33mc5-g8[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\Luis/.netrc


<torch._C.Generator at 0x1b2b13a2690>

In [9]:
transform_train = CustomTransform(config, mode='train')
transform_test = CustomTransform(config, mode='test')

train_dataset = datasets.ImageFolder(root=config['TRAINING_DATASET_DIR'], transform=transform_train)
test_dataset =  datasets.ImageFolder(root=config['TEST_DATASET_DIR'], transform=transform_test)

total_length = len(train_dataset)
train_size = int(0.8 * total_length)  # e.g., 80% for training
valid_size = total_length - train_size  # remaining 20% for validation

# Split dataset
train_dataset, validation_dataset = random_split(train_dataset, [train_size, valid_size])

dataloader_train = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
dataloader_validation = DataLoader(validation_dataset, batch_size=config['batch_size'], shuffle=True)
dataloader_test = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)

In [5]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', weights='ResNet50_Weights.DEFAULT').eval()
model.fc = nn.Identity()
model = model.to(config['device'])

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to C:\Users\Luis/.cache\torch\hub\v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\Luis/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:09<00:00, 11.3MB/s]


In [6]:
# Define a function to extract features from an image
def extract_features(loader, model, device):
    with torch.no_grad():    
        features, labels = [], []
        for img, label in tqdm.tqdm(loader, total=len(loader), desc='Extracting features'):
            features.append(model(img.to(device)).cpu().detach().numpy())
            labels.append(label)
    return np.concatenate(features).astype('float32'), np.concatenate(labels)

# Bag of Visual Words
class BoVW:
    def __init__(self, config):
        self.features = []

        self.cluster = MiniBatchKMeans(n_clusters=config['num_words'], batch_size=512, n_init='auto')
        self.num_words = config['num_words']

    def fit_transform(self, raw_features):

        print('Fitting the clustering...')
        # Fit the KMeans to the features
        self.cluster.fit(raw_features)
        # Assign each feature to the nearest visual word and create a histogram for each set of features
        visual_words = self.cluster.labels_
        bow_features = np.zeros((len(raw_features), self.num_words), dtype=np.float32)
        for i, word in enumerate(visual_words):
            bow_features[i, word] += 1

        # Normalize the histograms to represent the frequency of each visual word appearing in the image
        bow_features = bow_features / np.linalg.norm(bow_features, axis=1, keepdims=True)
        return bow_features

    def transform(self, raw_features):
        # Assign features to the nearest visual word without re-fitting the KMeans
        visual_words = self.cluster.predict(raw_features)
        bow_features = np.zeros((len(raw_features), self.num_words), dtype=np.float32)
        for i, word in enumerate(visual_words):
            bow_features[i, word] += 1

        # Normalize the histograms as in fit_transform
        bow_features = bow_features / np.linalg.norm(bow_features, axis=1, keepdims=True)
        return bow_features

# k-NN Classifier for Image Retrieval
class ImageRetrievalSystem:
    def __init__(self, model, train_dataloader, test_dataloader, config):
        self.use_bovw = config['use_bovw']
        self.model = model
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.device = config['device']

        self.dim = 2048 if not self.use_bovw else config['num_words']
        self.classifier_type = config['classifier']
        self.n_neighbors = config['n_neighbors']

        if self.use_bovw:
            self.bow_processor = BoVW(config)

        if self.classifier_type == 'knn':
            self.classifier = NearestNeighbors(n_neighbors=config['n_neighbors'], metric=config['metric'])
        else:
            self.classifier = faiss.IndexIVFFlat(faiss.IndexFlatL2(self.dim), self.dim, config['voronoi_cells'])
            self.classifier.nprobe = config['lookup_cells']

    def fit(self):

        features, self.train_labels = extract_features(self.train_dataloader, self.model, self.device)

        if self.use_bovw:
            features = self.bow_processor.fit_transform(features)
        
        print('Fitting the classifier...')
        if self.classifier_type == 'knn':
            self.classifier.fit(features, self.train_labels)
        else:
            self.classifier.train(features)
            self.classifier.add(features)

    def retrieve(self):

        features, labels = extract_features(self.test_dataloader, self.model, self.device)
        if self.use_bovw:
            features = self.bow_processor.transform(features)
        
        print('Retrieving images...')
        if self.classifier_type == 'knn':
            distances, predictions = self.classifier.kneighbors(features, return_distance=True)
        else:
            distances, predictions = self.classifier.search(features, self.n_neighbors)
        
        return predictions, labels, distances

In [8]:
def evaluate(predictions, labels):

    # Prec@1
    prec_at_1 = np.mean([1 if predictions[i, 0] == labels[i] else 0 for i in range(predictions.shape[0])])

    # Prec@5
    prec_at_5 = np.mean([np.sum([1 if predictions[i, j] == labels[i] else 0 for j in range(config['n_neighbors'])]) / config['n_neighbors'] for i in range(predictions.shape[0])])

    # Initialize list to store average precision for each query
    average_precisions = []

    # Compute binary relevance arrays and calculate average precision for each query
    for i in range(predictions.shape[0]):
        # Convert true label into binary format for each prediction
        binary_relevance = np.array([1 if label == labels[i] else 0 for label in predictions[i]])
        
        # Ensure there is at least one positive class to avoid division by zero in AP score calculation
        if np.sum(binary_relevance) > 0:
            # Compute the average precision for the current query
            ap_score = average_precision_score(binary_relevance, np.ones_like(binary_relevance))
            average_precisions.append(ap_score)

    # Compute mean Average Precision (mAP) by averaging all the AP scores
    mean_ap = np.mean(average_precisions)

    return prec_at_1, prec_at_5, mean_ap

In [7]:
pipeline = ImageRetrievalSystem(model, dataloader_train, dataloader_test, config)
pipeline.fit()
predictions, labels, distances = pipeline.retrieve()

predictions = pipeline.train_labels[predictions]

evaluate(predictions, labels)

Extracting features: 100%|██████████| 47/47 [09:52<00:00, 12.61s/it]


Fitting the classifier...


Extracting features: 100%|██████████| 26/26 [05:26<00:00, 12.56s/it]


Retrieving images...
