## Imports:

In [None]:
from fastai.vision.all import *
from fastbook import *
import cv2
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.manifold import TSNE
from sklearn import preprocessing
import random
from IPython.display import clear_output
from torchvision import transforms
from sklearn.decomposition import PCA
import glob
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.optim as optim
import torchvision
from typing import Tuple
from pytorch_metric_learning import losses
import timm
from torch_lr_finder import LRFinder
import faiss
from tqdm import tqdm
import scipy.sparse as sparse
import scipy.sparse.linalg as linalg
import joblib
from joblib import Parallel, delayed

## Checking GPU:

In [None]:
# Check CUDA availability and pick the compute device used by later cells.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print("Using", device)

# Report cuDNN / GPU details only when a GPU is actually in use.
if device == "cuda":
    print(f"__CUDNN VERSION: {torch.backends.cudnn.version()}")
    print(f"__Number CUDA Devices: {torch.cuda.device_count()}")
    print(f"__CUDA Device Name: {torch.cuda.get_device_name(0)}")
    print(f"__CUDA Device Total Memory [GB]: {torch.cuda.get_device_properties(0).total_memory / 1e9}")

## Data Functions:

In [None]:
# load in data from train and query folders
def _load_split(folder, names):
    """Return ``(paths, class_indices)`` for every file directly inside ``folder``.

    The class name is the part of the file name before the first ``-``; its
    position in ``names`` becomes the integer label.
    """
    import os

    image_paths = []
    classes = []
    # sorted() makes the order (and therefore dataset indices) reproducible;
    # raw glob order depends on the file system.
    for data_path in sorted(glob.glob(os.path.join(folder, '*'))):
        name = os.path.basename(data_path).split("-")[0]
        try:
            idx = names.index(name)
        except ValueError:
            raise ValueError(
                "{!r}: class {!r} is not in names".format(data_path, name)
            ) from None
        classes.append(idx)
        image_paths.append(data_path)
    return image_paths, classes


def load_data(train_path, query_path, names):
    """Collect image paths and integer class labels for a train and a query folder.

    Args:
        train_path: folder holding gallery/training images.
        query_path: folder holding query images.
        names: list of class names; a file ``<name>-<anything>`` gets label
            ``names.index(<name>)``.

    Returns:
        (train_image_paths, train_classes, valid_image_paths, valid_classes),
        each a list with paths and labels in matching order.

    Raises:
        ValueError: if a file's class prefix is not in ``names``.
    """
    train_image_paths, train_classes = _load_split(train_path, names)
    valid_image_paths, valid_classes = _load_split(query_path, names)

    print("Train Images: {} | Query Images: {}".format(len(train_image_paths), len(valid_image_paths)))
    return train_image_paths, train_classes, valid_image_paths, valid_classes

In [None]:
# Shared preprocessing for every image: resize to 256x256, center-crop to
# 224x224, convert to a CHW float tensor, then normalise each channel with the
# standard ImageNet mean/std.
data_transforms = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# dataset function for loading images and classes in
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs loaded from file paths.

    The base class is referenced explicitly: this class reuses the name
    ``Dataset``, so re-running the cell would otherwise subclass itself.

    Args:
        image_paths: list of image file paths.
        labels: class index for each path, in the same order as ``image_paths``.
        transform: optional callable applied to the loaded PIL image.
        train: stored as ``is_train``; training and evaluation currently use
            the same transform.
    """

    def __init__(self, image_paths, labels, transform=None, train=False):
        self.is_train = train
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_filepath = self.image_paths[idx]
        # Force 3-channel RGB so grayscale / palette / RGBA files do not reach
        # a 3-channel Normalize with the wrong channel count. The context
        # manager closes the file handle PIL otherwise keeps open lazily.
        with Image.open(image_filepath) as img:
            anchor_image = img.convert("RGB")
        anchor_label = self.labels[idx]

        if self.transform is not None:
            anchor_image = self.transform(anchor_image)
        return anchor_image, anchor_label

In [None]:
# display images function:
def imshow(img, title=None):
    """Display a CHW image tensor in its own matplotlib figure.

    The tensor is shifted/scaled (``/2 + 0.5``), moved to HWC layout and
    min-max stretched to 0-255 before the uint8 cast. The result is a
    contrast-stretched view, not an exact inverse of the input normalisation.
    """
    hwc = np.transpose((img / 2 + 0.5).numpy(), (1, 2, 0))
    stretched = cv2.normalize(hwc, None, alpha=0, beta=255,
                              norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
    as_bytes = stretched.astype(np.uint8)

    plt.figure(figsize=(10, 10), dpi=80)
    plt.imshow(as_bytes)
    plt.title(title)
    plt.axis("off")
    plt.show()

def show_batch(dataloader, n_samples=5, names=None):
    """Show the first ``n_samples`` images of one batch and print their labels.

    Args:
        dataloader: iterable yielding ``(images, labels)`` batches.
        n_samples: how many images/labels of the batch to show.
        names: optional list mapping class index -> class name. When omitted,
            a notebook-global ``names`` is used if one exists; otherwise the
            raw integer labels are printed.
    """
    # next(it) works on every PyTorch version; the ``.next()`` method no
    # longer exists on DataLoader iterators in newer releases.
    anchors, label = next(iter(dataloader))
    imshow(torchvision.utils.make_grid(anchors[:n_samples]), "Anchor Images")

    if names is None:
        names = globals().get("names")
    print("Labels")
    if names is None:
        print([int(idx) for idx in label[:n_samples]])
    else:
        print([names[int(idx)] for idx in label[:n_samples]])

## Fine Tuning functions:

In [None]:
# fine tuning model parameters
def fine_tune(dataloader, model, epochs=1, lr=0.00005, eval_dataset='oxford', eval_difficulty='easy'):
    """Fine-tune ``model`` with Circle loss and measure mAP after every epoch.

    Args:
        dataloader: yields ``(images, labels)`` batches.
        model: network mapping images to embeddings; updated in place.
        epochs: number of passes over ``dataloader``.
        lr: AdamW learning rate.
        eval_dataset, eval_difficulty: split handed to ``compute_map`` for the
            per-epoch evaluation (defaults keep the Oxford / easy evaluation).

    Returns:
        (iteration_loss, epoch_loss, maps): loss of every batch, mean loss of
        every epoch, and the mAP measured after every epoch.
    """
    criterion = losses.CircleLoss(m=0.4, gamma=80)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=5e-6)

    iteration_loss = []
    epoch_loss = []
    maps = []

    for epoch in range(epochs):
        # The evaluation at the end of every epoch puts the model in eval
        # mode; switch back so dropout / batch-norm behave as in training.
        model.train()
        running_loss = []
        for data in dataloader:
            anchor, anchor_label = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()
            anchor_out = model(anchor)
            loss = criterion(anchor_out, anchor_label)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            running_loss.append(loss_value)
            iteration_loss.append(loss_value)
        epoch_loss.append(np.mean(running_loss))

        # Keep a checkpoint of the latest weights on disk.
        torch.save({
            'model_state_dict': model.state_dict(),
            }, "checkpoint")

        # Evaluate without building autograd graphs for every gallery image.
        model.eval()
        with torch.no_grad():
            mAP = compute_map(model, eval_dataset, eval_difficulty, 0, 0, False)[-1]
        maps.append(mAP)

        print("Epoch: {}/{} - Loss: {:.4f} - mAP: {:.4f}".format(epoch+1, epochs, np.mean(running_loss), maps[-1]))

    print('Finished Training')
    return iteration_loss, epoch_loss, maps

In [None]:
def plot_loss(iteration_loss, epoch_loss, maps):
    """Plot the three curves returned by ``fine_tune``, each in its own figure."""
    curves = (
        (iteration_loss, "Loss over each iteration"),
        (epoch_loss, "Loss over each epoch"),
        (maps, "mAP over each epoch"),
    )
    for values, caption in curves:
        plt.plot(values)
        plt.title(caption)
        plt.show()

## Models:

In [None]:
def load_model(mod):
    """Build a pretrained backbone whose forward pass returns feature vectors.

    Args:
        mod: one of 'vgg', 'resnet', 'swin', 'vit', 'mae_vit'.

    Returns:
        The model moved to ``device``, or ``None`` (after printing a hint) when
        ``mod`` is not recognised.
    """
    if mod == 'vgg':
        out_model = vgg16_bn(pretrained=True)
        # keep only the first four classifier layers (drops the final layers)
        out_model.classifier = out_model.classifier[:4]
    elif mod == 'resnet':
        out_model = resnet50(pretrained=True)
        out_model.fc = nn.Identity()
    elif mod == 'swin':
        out_model = timm.create_model('swin_large_patch4_window7_224_in22k', pretrained=True)
        out_model.head = nn.Identity()
    elif mod == 'vit':
        out_model = timm.create_model('vit_large_patch16_224_in21k', pretrained=True)
        out_model.head = nn.Identity()
    elif mod == 'mae_vit':
        out_model = timm.create_model('vit_large_patch16_224_in21k')
        # map_location lets the checkpoint load on machines without a GPU.
        mae_pretrained = torch.load('Saved_models/mae_pretrain_vit_large.pth', map_location='cpu')
        out_model.load_state_dict(mae_pretrained['model'], strict=False)
        # NOTE(review): unlike 'vit', the head is kept. With strict=False any
        # head weights missing from the MAE checkpoint stay randomly
        # initialised. Confirm before changing, since saved fine-tuned
        # checkpoints include the head.
    else:
        print("Unknown Model: Try 'vgg' or 'resnet' or 'vit' or 'swin' or 'mae_vit'")
        # Return early: calling .to() on None would raise AttributeError.
        return None
    return out_model.to(device)

In [None]:
def load_ft(m, mod, dataset):
    """Load fine-tuned weights from ``Saved_models/{mod}-{dataset}-model`` into ``m``.

    The file must be a dict with a ``'model_state_dict'`` entry, the same
    layout as the checkpoint written by ``fine_tune``. Returns ``m`` itself.
    """
    # Map tensors to the CPU first: load_state_dict copies them into m's
    # existing parameters wherever those live, and loading still works on
    # machines lacking the GPU the checkpoint was saved from.
    checkpoint = torch.load("Saved_models/{}-{}-model".format(mod, dataset), map_location="cpu")
    m.load_state_dict(checkpoint['model_state_dict'])
    return m

## Loading ROxford5k:

In [None]:
# easy data: Oxford "easy" gallery plus the query set
ox_train_path = "roxford5k/easy"
ox_query_path = "roxford5k/query"
ox_names = [
    'radcliffe_camera', 'hertford', 'all_souls', 'bodleian', 'balliol',
    'magdalen', 'christ_church', 'pitt_rivers', 'ashmolean', 'keble',
    'cornmarket',
]
(ox_easy_image_paths, ox_easy_classes,
 ox_query_image_paths, ox_query_classes) = load_data(ox_train_path, ox_query_path, ox_names)

ox_easy_dataset = Dataset(ox_easy_image_paths, ox_easy_classes, transform=data_transforms, train=True)
ox_query_dataset = Dataset(ox_query_image_paths, ox_query_classes, transform=data_transforms, train=False)

ox_easy_loader = DataLoader(ox_easy_dataset, batch_size=32, shuffle=True)
ox_query_loader = DataLoader(ox_query_dataset, batch_size=32, shuffle=True)

In [None]:
# hard data (load_data also re-reads the query folder)
ox_train_path = "roxford5k/hard"
(ox_hard_image_paths, ox_hard_classes,
 ox_query_image_paths, ox_query_classes) = load_data(ox_train_path, ox_query_path, ox_names)

ox_hard_dataset = Dataset(ox_hard_image_paths, ox_hard_classes, transform=data_transforms, train=True)
ox_hard_loader = DataLoader(ox_hard_dataset, batch_size=32, shuffle=True)

In [None]:
# medium data: hard images followed by easy images (labels kept aligned)
ox_medium_image_paths = [*ox_hard_image_paths, *ox_easy_image_paths]
ox_medium_classes = [*ox_hard_classes, *ox_easy_classes]

ox_medium_dataset = Dataset(ox_medium_image_paths, ox_medium_classes, transform=data_transforms, train=True)
ox_medium_loader = DataLoader(ox_medium_dataset, batch_size=32, shuffle=True)

In [None]:
# junk images, then the full Oxford pool = medium + junk
ox_train_path = "roxford5k/junk"
(ox_junk_image_paths, ox_junk_classes,
 ox_query_image_paths, ox_query_classes) = load_data(ox_train_path, ox_query_path, ox_names)

ox_all_image_paths = [*ox_medium_image_paths, *ox_junk_image_paths]
ox_all_classes = [*ox_medium_classes, *ox_junk_classes]

ox_all_dataset = Dataset(ox_all_image_paths, ox_all_classes, transform=data_transforms, train=True)
ox_all_loader = DataLoader(ox_all_dataset, batch_size=32, shuffle=True)

print("Train Images: {} | Query Images: {}".format(len(ox_all_image_paths), len(ox_query_image_paths)))

## Loading RParis6k:

In [None]:
# easy data: Paris "easy" gallery plus the query set
par_train_path = "rparis6k/easy"
par_query_path = "rparis6k/query"
par_names = [
    "defense", "eiffel", "invalides", "louvre", "moulinrouge", "museedorsay",
    "notredame", "pantheon", "pompidou", "sacrecoeur", "triomphe",
]
(par_easy_image_paths, par_easy_classes,
 par_query_image_paths, par_query_classes) = load_data(par_train_path, par_query_path, par_names)

par_easy_dataset = Dataset(par_easy_image_paths, par_easy_classes, transform=data_transforms, train=True)
par_query_dataset = Dataset(par_query_image_paths, par_query_classes, transform=data_transforms, train=False)

par_easy_loader = DataLoader(par_easy_dataset, batch_size=32, shuffle=True)
par_query_loader = DataLoader(par_query_dataset, batch_size=32, shuffle=True)

In [None]:
# hard data (load_data also re-reads the query folder)
par_train_path = "rparis6k/hard"
(par_hard_image_paths, par_hard_classes,
 par_query_image_paths, par_query_classes) = load_data(par_train_path, par_query_path, par_names)

par_hard_dataset = Dataset(par_hard_image_paths, par_hard_classes, transform=data_transforms, train=True)
par_hard_loader = DataLoader(par_hard_dataset, batch_size=32, shuffle=True)

In [None]:
# medium data: hard images followed by easy images (labels kept aligned)
par_medium_image_paths = [*par_hard_image_paths, *par_easy_image_paths]
par_medium_classes = [*par_hard_classes, *par_easy_classes]

par_medium_dataset = Dataset(par_medium_image_paths, par_medium_classes, transform=data_transforms, train=True)
par_medium_loader = DataLoader(par_medium_dataset, batch_size=32, shuffle=True)

In [None]:
# junk images, then the full Paris pool = medium + junk
par_train_path = "rparis6k/junk"
(par_junk_image_paths, par_junk_classes,
 par_query_image_paths, par_query_classes) = load_data(par_train_path, par_query_path, par_names)

par_all_image_paths = [*par_medium_image_paths, *par_junk_image_paths]
par_all_classes = [*par_medium_classes, *par_junk_classes]

par_all_dataset = Dataset(par_all_image_paths, par_all_classes, transform=data_transforms, train=True)
par_all_loader = DataLoader(par_all_dataset, batch_size=32, shuffle=True)

print("Train Images: {} | Query Images: {}".format(len(par_all_image_paths), len(par_query_image_paths)))

## Data Computation Functions:

In [None]:
def train_data(model, train_dset, print_opt=True):
    """Extract one feature vector per image of ``train_dset`` with ``model``.

    Args:
        model: network applied to each image with a batch dimension of 1.
        train_dset: indexable dataset yielding ``(image_tensor, label)`` pairs
            with CHW CPU image tensors.
        print_opt: print per-image progress and the final array shapes.

    Returns:
        (trn_features, trn_images): an ``(N, D)`` array (each model output
        flattened) and an ``(N, H, W, C)`` uint8 array of contrast-stretched
        images for display.
    """
    # Send images to whichever device holds the model's parameters instead of
    # assuming "cuda", so extraction also works on CPU-only machines.
    first_param = next(iter(model.parameters()), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    train_features = []
    train_images = []

    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        for idx, (img, _) in enumerate(train_dset):
            if print_opt:
                print("Image {} / {}".format(idx+1, len(train_dset)), end="\r")

            # unnormalize image and save: shift/scale, CHW -> HWC, min-max stretch to 0-255
            save_img = img/2 + 0.5
            save_img = np.transpose(save_img.numpy(), (1, 2, 0))
            save_img = cv2.normalize(save_img, None, alpha = 0, beta = 255, norm_type = cv2.NORM_MINMAX, dtype = cv2.CV_32F)
            train_images.append(save_img.astype(np.uint8))

            # compute feature for img (add a batch dimension of 1)
            feature = model(img.unsqueeze(0).to(model_device))
            train_features.append(feature.detach().cpu().numpy())

    # stack to (N, D); flattening also tolerates model outputs with more than 2 dims
    trn_features = np.concatenate(train_features, axis=0).reshape(len(train_features), -1)
    trn_images = np.array(train_images)

    # displaying shapes
    if print_opt:
        print("Features shape: {} | Images shape: {}".format(trn_features.shape, trn_images.shape), end="\r")
    return trn_features, trn_images

In [None]:
def valid_data(model, valid_dset, print_opt=True):
    """Extract one feature vector per query image of ``valid_dset`` with ``model``.

    Args:
        model: network applied to each image with a batch dimension of 1.
        valid_dset: indexable dataset yielding ``(image_tensor, label)`` pairs
            with CHW CPU image tensors.
        print_opt: print per-image progress and the final array shapes.

    Returns:
        (val_features, val_images): an ``(N, D)`` array (each model output
        flattened) and an ``(N, H, W, C)`` uint8 array of contrast-stretched
        images for display.
    """
    # Send images to whichever device holds the model's parameters instead of
    # assuming "cuda", so extraction also works on CPU-only machines.
    first_param = next(iter(model.parameters()), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    valid_features = []
    valid_images = []

    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        for idx, (img, _) in enumerate(valid_dset):
            if print_opt:
                print("Image {} / {}".format(idx+1, len(valid_dset)), end="\r")

            # unnormalize image and save: shift/scale, CHW -> HWC, min-max stretch to 0-255
            save_img = img/2 + 0.5
            save_img = np.transpose(save_img.numpy(), (1, 2, 0))
            save_img = cv2.normalize(save_img, None, alpha = 0, beta = 255, norm_type = cv2.NORM_MINMAX, dtype = cv2.CV_32F)
            valid_images.append(save_img.astype(np.uint8))

            # compute feature for img (add a batch dimension of 1)
            feature = model(img.unsqueeze(0).to(model_device))
            valid_features.append(feature.detach().cpu().numpy())

    # stack to (N, D); flattening also tolerates model outputs with more than 2 dims
    val_features = np.concatenate(valid_features, axis=0).reshape(len(valid_features), -1)
    val_images = np.array(valid_images)

    # displaying shapes
    if print_opt:
        print("Features shape: {} | Images shape: {}".format(val_features.shape, val_images.shape), end="\r")
    return val_features, val_images

In [None]:
def pca(trn_features, val_features, dim=128, print_opt=True):
    """Project train and query features onto ``dim`` principal components.

    The PCA basis is fitted on the training (gallery) features only and then
    applied unchanged to both sets.

    Returns:
        (reduced_train, reduced_query) arrays with ``dim`` columns.
    """
    reducer = PCA(n_components=dim).fit(trn_features)
    reduced_trn, reduced_val = (reducer.transform(block) for block in (trn_features, val_features))

    if print_opt:
        print("Train Features shape: {} | Valid Features shape: {}".format(reduced_trn.shape, reduced_val.shape))

    return reduced_trn, reduced_val

## Diffusion:

In [None]:
def get_offline_result(i):
    """Solve the truncated diffusion system for item ``i``.

    Reads the module-level globals ``trunc_ids``, ``lap_alpha`` and
    ``trunc_init``, which ``Diffusion.get_offline_results`` assigns before
    running this function over every item. Restricts ``lap_alpha`` to the ids
    in row ``i`` of ``trunc_ids`` and solves ``L x = trunc_init`` with
    conjugate gradient (at most 20 iterations, tolerance 1e-6).

    Returns:
        The solution vector, one score per id in ``trunc_ids[i]``.
    """
    import inspect

    ids = trunc_ids[i]
    trunc_lap = lap_alpha[ids][:, ids]
    # SciPy 1.12 renamed cg's ``tol`` to ``rtol`` and 1.14 removed ``tol``;
    # pass the tolerance under whichever keyword the installed SciPy accepts.
    tol_kw = "rtol" if "rtol" in inspect.signature(linalg.cg).parameters else "tol"
    scores, _ = linalg.cg(trunc_lap, trunc_init, maxiter=20, **{tol_kw: 1e-6})
    return scores

class Diffusion(object):
    """Diffusion class

    Builds a mutual kNN graph over ``features`` and computes, for every item,
    a truncated diffusion score vector (``get_offline_results``). Retrieval
    then compares these score vectors instead of the raw features.
    """
    def __init__(self, features, cache_dir):
        """
        Args:
            features: 2-D array with one feature vector per row (the caller in
                this notebook stacks query rows above gallery rows).
            cache_dir: stored on the instance; no method shown here reads it.
        """
        self.features = features
        self.N = len(self.features)
        self.cache_dir = cache_dir
        # use ANN for large datasets
        self.use_ann = self.N >= 100000
        if self.use_ann:
            self.ann = ANN(self.features, method='cosine')
        self.knn = KNN(self.features, method='cosine')

    def get_offline_results(self, n_trunc, kd=50):
        """Get offline diffusion results for each gallery feature

        Args:
            n_trunc: neighbours kept per item; each item's linear system is
                restricted to these ids. NOTE(review): must not exceed N,
                otherwise faiss pads the neighbour ids with -1.
            kd: neighbours used to build the affinity graph (in the exact-kNN
                branch the first ``kd`` of the ``n_trunc`` neighbours).

        Returns:
            ``(N, N)`` float32 CSR matrix; row ``i`` holds item ``i``'s scores
            at the columns listed in ``trunc_ids[i]``.
        """
        # Published as module globals because the per-item solver
        # ``get_offline_result`` reads them from there; the joblib workers
        # below are threads (prefer='threads'), so they see the same values.
        global trunc_ids, trunc_init, lap_alpha
        if self.use_ann:
            # approximate search picks the truncation ids, exact kNN builds the graph
            _, trunc_ids = self.ann.search(self.features, n_trunc)
            sims, ids = self.knn.search(self.features, kd)
            lap_alpha = self.get_laplacian(sims, ids)
        else:
            sims, ids = self.knn.search(self.features, n_trunc)
            trunc_ids = ids
            lap_alpha = self.get_laplacian(sims[:, :kd], ids[:, :kd])
        # right-hand side e_0: a unit impulse on the first truncated id
        trunc_init = np.zeros(n_trunc)
        trunc_init[0] = 1

        results = Parallel(n_jobs=-1, prefer='threads')(delayed(get_offline_result)(i)
                                      for i in range(self.N))
        # each result has n_trunc scores, aligned with trunc_ids[i]
        all_scores = np.concatenate(results)

        rows = np.repeat(np.arange(self.N), n_trunc)
        offline = sparse.csr_matrix((all_scores, (rows, trunc_ids.reshape(-1))),
                                    shape=(self.N, self.N),
                                    dtype=np.float32)
        return offline

    def get_laplacian(self, sims, ids, alpha=0.99):
        """Get Laplacian_alpha matrix

        Computes ``I - alpha * D^(-1/2) W D^(-1/2)`` where ``W`` is the mutual
        kNN affinity matrix and ``D`` the diagonal matrix of its row sums.
        """
        affinity = self.get_affinity(sims, ids)
        num = affinity.shape[0]
        # row sums of W; the tiny epsilon avoids dividing by zero for isolated nodes
        degrees = affinity @ np.ones(num) + 1e-12
        # mat: degree matrix ^ (-1/2)
        mat = sparse.dia_matrix(
            (degrees ** (-0.5), [0]), shape=(num, num), dtype=np.float32)
        stochastic = mat @ affinity @ mat
        sparse_eye = sparse.dia_matrix(
            (np.ones(num), [0]), shape=(num, num), dtype=np.float32)
        lap_alpha = sparse_eye - alpha * stochastic
        return lap_alpha

    def get_affinity(self, sims, ids, gamma=3):
        """Create affinity matrix for the mutual kNN graph of the whole dataset
        Args:
            sims: similarities of kNN (negative values are clipped to 0 IN PLACE,
                which also modifies the caller's array when a view is passed)
            ids: indexes of kNN
            gamma: exponent applied to the clipped similarities
        Returns:
            affinity: affinity matrix (float32 CSC, shape (num, num))
        """
        num = sims.shape[0]
        sims[sims < 0] = 0  # similarity should be non-negative
        sims = sims ** gamma
        # vec_ids: feature vectors' ids
        # mut_ids: mutual (reciprocal) nearest neighbors' ids
        # mut_sims: similarites between feature vectors and their mutual nearest neighbors
        vec_ids, mut_ids, mut_sims = [], [], []
        for i in range(num):
            # check reciprocity: i is in j's kNN and j is in i's kNN when i != j
            ismutual = np.isin(ids[ids[i]], i).any(axis=1)
            # position 0 is assumed to be item i itself (its own nearest neighbour)
            ismutual[0] = False
            if ismutual.any():
                vec_ids.append(i * np.ones(ismutual.sum(), dtype=int))
                mut_ids.append(ids[i, ismutual])
                mut_sims.append(sims[i, ismutual])
        # NOTE(review): raises ValueError if no item has any mutual neighbour
        vec_ids, mut_ids, mut_sims = map(np.concatenate, [vec_ids, mut_ids, mut_sims])
        affinity = sparse.csc_matrix((mut_sims, (vec_ids, mut_ids)),
                                     shape=(num, num), dtype=np.float32)
        return affinity

In [None]:
class BaseKNN(object):
    """KNN base class

    Stores the database as a C-contiguous float32 array (the layout the faiss
    indexes used by subclasses work with). Subclasses must create
    ``self.index`` before calling ``add`` or ``search``.
    """
    def __init__(self, database, method):
        # ``method`` is not used here; subclasses use it to choose the metric.
        if database.dtype != np.float32:
            database = database.astype(np.float32)
        self.N = len(database)
        self.D = database[0].shape[-1]
        self.database = database if database.flags['C_CONTIGUOUS'] \
                               else np.ascontiguousarray(database)

    def add(self, batch_size=10000):
        """Add data into index, in chunks of ``batch_size`` rows when it is large."""
        if self.N <= batch_size:
            self.index.add(self.database)
        else:
            # plain loop: index.add is called only for its side effect
            for start in tqdm(range(0, len(self.database), batch_size),
                              desc='[index] add'):
                self.index.add(self.database[start:start + batch_size])

    def search(self, queries, k):
        """Search
        Args:
            queries: query vectors
            k: get top-k results
        Returns:
            sims: similarities of k-NN
            ids: indexes of k-NN
        """
        if not queries.flags['C_CONTIGUOUS']:
            queries = np.ascontiguousarray(queries)
        if queries.dtype != np.float32:
            queries = queries.astype(np.float32)
        sims, ids = self.index.search(queries, k)
        return sims, ids


class KNN(BaseKNN):
    """Exact (brute-force) k-nearest-neighbour search on a flat faiss index.

    Args:
        database: feature vectors in database
        method: 'cosine' (inner-product index; equals cosine similarity only
            for L2-normalised rows) or 'euclidean' (L2 index)
    """
    def __init__(self, database, method):
        super().__init__(database, method)
        if method == 'cosine':
            flat_index_cls = faiss.IndexFlatIP
        elif method == 'euclidean':
            flat_index_cls = faiss.IndexFlatL2
        else:
            raise KeyError(method)
        self.index = flat_index_cls(self.D)
        # Mirror the index onto all GPUs when CUDA_VISIBLE_DEVICES is set and non-empty.
        if os.environ.get('CUDA_VISIBLE_DEVICES'):
            self.index = faiss.index_cpu_to_all_gpus(self.index)
        self.add()


class ANN(BaseKNN):
    """Approximate nearest neighbor search class (faiss IVF-PQ index).

    Args:
        database: feature vectors in database
        method: 'cosine' (inner product) or 'euclidean' (L2)
        M: number of product-quantizer sub-vectors (the feature dimension must
            be divisible by M)
        nbits: bits per sub-vector code
        nlist: number of inverted lists (coarse clusters)
        nprobe: inverted lists visited per query at search time
    """
    def __init__(self, database, method, M=128, nbits=8, nlist=316, nprobe=64):
        super().__init__(database, method)
        quantizer_cls, metric = {
            'cosine': (faiss.IndexFlatIP, faiss.METRIC_INNER_PRODUCT),
            'euclidean': (faiss.IndexFlatL2, faiss.METRIC_L2),
        }[method]
        self.quantizer = quantizer_cls(self.D)
        # IndexIVFPQ uses METRIC_L2 unless told otherwise, whatever the
        # quantizer's metric is, so pass it explicitly.
        self.index = faiss.IndexIVFPQ(self.quantizer, self.D, nlist, M, nbits, metric)
        # Train on a random 20% sample of the float32, C-contiguous copy made
        # by BaseKNN (the raw ``database`` may be float64 or strided).
        samples = self.database[np.random.permutation(self.N)[:self.N // 5]]
        print("[ANN] train")
        self.index.train(samples)
        self.add()
        self.index.nprobe = nprobe

## Image retrieval functions:

In [None]:
def cosine_sim(train_data, queries):
    """Rank gallery items for each query by decreasing cosine similarity.

    Returns an ``(n_queries, n_train)`` array whose row ``q`` lists gallery
    indices from most to least similar to query ``q``.
    """
    similarity = cosine_similarity(queries, train_data)
    distance = 1 - similarity
    return np.argsort(distance)

In [None]:
def diffusion(train_data, queries, kd=50, truncation_size=586):
    """Rank gallery items for each query with graph diffusion.

    Queries and gallery are stacked into one kNN graph. Offline diffusion
    score vectors are computed for every node and L2-normalised per row.
    Queries are then ranked against gallery items by the inner product of
    those vectors.

    Args:
        train_data: gallery features, one row per item.
        queries: query features, one row per query.
        kd: neighbours used to build the affinity graph.
        truncation_size: neighbours kept per item in the truncated solve.

    Returns:
        ``(n_queries, n_train)`` array of gallery indices, best match first.
    """
    cache_dir = "cache"

    n_queries = len(queries)
    graph = Diffusion(np.vstack([queries, train_data]), cache_dir)

    # A kNN search cannot return more neighbours than there are items: faiss
    # pads with id -1, which would silently index the last row. The graph also
    # uses the first ``kd`` of the truncated neighbours, so kd <= n_trunc.
    n_trunc = min(truncation_size, graph.N)
    kd = min(kd, n_trunc)

    offline = graph.get_offline_results(n_trunc, kd)

    features = preprocessing.normalize(offline, norm="l2", axis=1)
    scores = features[:n_queries] @ features[n_queries:].T
    return np.argsort(-scores.toarray())

In [None]:
def image_retrieval_k(train_data, test_data, train_names, test_names, train_images, test_images, k=10, view_option=0, border_size=3, print_opt=True, diff=False):
    """Rank the gallery for every query and score each ranking.

    Args:
        train_data, test_data: gallery / query feature matrices.
        train_names, test_names: class name of every gallery / query item.
        train_images, test_images: display images (used when view_option == 1).
        k: cut-off for precision@k and number of matches displayed.
        view_option: 0 prints progress, 1 displays every query's top-k.
        border_size: border width handed to ``display_retrieval``.
        print_opt: print progress (only when ``view_option == 0``).
        diff: rank with graph diffusion instead of cosine similarity.

    Returns:
        (avg_precisions, precisionsatk): per-query average precision (mean of
        the precision at each rank holding a gallery item of the query's class;
        NaN when the gallery has none) and per-query precision@k (precision
        over the whole gallery when it has fewer than k items).
    """
    train_names = np.asarray(train_names)
    avg_precisions = []
    precisionsatk = []

    # Gallery order for every query, best match first.
    if diff:
        indexes = diffusion(train_data, test_data)
    else:
        indexes = cosine_sim(train_data, test_data)

    n_queries = len(test_data)
    for idx, index in enumerate(indexes):
        # relevant[r] is True when the gallery item ranked r has the query's class.
        relevant = train_names[index] == test_names[idx]
        hits = np.cumsum(relevant)
        precision_at_rank = hits / np.arange(1, len(index) + 1)

        # AP: average of the precision values at the ranks of the correct items.
        if relevant.any():
            avg_precisions.append(float(precision_at_rank[relevant].mean()))
        else:
            avg_precisions.append(float("nan"))
        precisionsatk.append(float(precision_at_rank[min(k, len(index)) - 1]))

        # display retrieval:
        if view_option == 0:
            if print_opt:
                print("Percentage Complete: {}%".format(round((idx + 1) / n_queries * 100, 2)), end="\r")
        elif view_option == 1:
            display_retrieval(test_data, test_images, idx, train_images, index, test_names, train_names, index, avg_precisions[-1], precisionsatk[-1], border_size, k)

    return avg_precisions, precisionsatk

In [None]:
 def display_retrieval(test_data, test_images, idx, train_images, index, test_names, train_names, sized_index, avg_precisions, precisionsatk, border_size, k):
    top_k_images = [test_images[idx]]
    for i in range(0,k):
        top_k_images.append(train_images[index[i]])

    fig, axes = plt.subplots(1, k+1, figsize=(200/k, 200/k))
    for i, (image, ax) in enumerate(zip(top_k_images, axes.ravel())):
        if i == 0:
            query_name = test_names[idx]
            title = "Query: {}".format(query_name)
        else:
            title = train_names[sized_index[i-1]]
            if train_names[sized_index[i-1]] == query_name:
                color = (0, 255, 0)
                image = border(image, color, border_size)
            else:
                color = (255, 0, 0)
                image = border(image, color, border_size)
        # display all set options
        ax.imshow(image, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    plt.show()
    print("Label: {}".format(test_names[idx]))
    print("Average Precision for query {}: ".format(idx), avg_precisions)
    print("Precision@k for query {}: ".format(idx), precisionsatk)
    print("\n")

In [None]:
def border(img, color, border_size):
    """Return a copy of ``img`` framed by a solid ``border_size``-pixel border.

    Args:
        img: ``(h, w, 3)`` or ``(h, w)`` uint8 image; a 2-D image is repeated
            across three channels.
        color: 3-tuple fill colour of the frame.
        border_size: frame width in pixels on every side.

    Returns:
        ``(h + 2*border_size, w + 2*border_size, 3)`` uint8 array.
    """
    h, w = img.shape[:2]
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)

    # Fill the whole canvas with the colour, then paste the image in the
    # middle, so the frame is exactly border_size wide on all four sides.
    base = np.empty((h + 2 * border_size, w + 2 * border_size, 3), dtype=np.uint8)
    base[:] = color
    base[border_size:h + border_size, border_size:w + border_size] = img

    return base

In [None]:
def _select_split(dataset, difficulty):
    """Return ``(names, query_dataset, query_classes, train_dataset, train_classes)``.

    Looks up the notebook-level Oxford / Paris globals. For an unknown dataset
    or difficulty it prints a hint and returns None.
    """
    if dataset == 'oxford':
        names, query_dataset, query_classes = ox_names, ox_query_dataset, ox_query_classes
        if difficulty == 'easy':
            train = (ox_easy_dataset, ox_easy_classes)
        elif difficulty == 'medium':
            train = (ox_medium_dataset, ox_medium_classes)
        elif difficulty == 'hard':
            train = (ox_hard_dataset, ox_hard_classes)
        else:
            print("Unkown mode: Try 'easy', 'medium', or 'hard'")
            return None
    elif dataset == 'paris':
        names, query_dataset, query_classes = par_names, par_query_dataset, par_query_classes
        if difficulty == 'easy':
            train = (par_easy_dataset, par_easy_classes)
        elif difficulty == 'medium':
            train = (par_medium_dataset, par_medium_classes)
        elif difficulty == 'hard':
            train = (par_hard_dataset, par_hard_classes)
        else:
            print("Unkown mode: Try 'easy', 'medium', or 'hard'")
            return None
    else:
        print("Unknown dataset: Try 'oxford' or 'paris'")
        return None
    return (names, query_dataset, query_classes) + train


def compute_map(model, dataset, difficulty, pca_opt=64, view=0, print_opt=True, diff=False):
    """Embed a gallery split and the queries with ``model`` and compute mAP.

    Args:
        model: feature extractor. Callers typically put it in eval mode first
            (``fine_tune`` does).
        dataset: 'oxford' or 'paris'.
        difficulty: 'easy', 'medium' or 'hard' gallery split.
        pca_opt: number of PCA dimensions; 0 disables PCA.
        view: 0 prints progress, 1 displays each query's retrieval.
        print_opt: print progress messages.
        diff: rank with diffusion instead of cosine similarity.

    Returns:
        (trn_features, trn_names, val_features, val_names, AP, precisionsatk, mAP)
        with the normalised (pre-PCA) features, or None (after printing a hint)
        for an unknown dataset / difficulty.
    """
    split = _select_split(dataset, difficulty)
    if split is None:
        return None
    names, query_dataset, query_classes, train_dataset, train_classes = split

    # loading in training data
    if print_opt:
        print("Loading Training features...")
    trn_features, trn_images = train_data(model, train_dataset, print_opt)

    # loading in validation data
    if print_opt:
        print("\n\nLoading Query features...")
    val_features, val_images = valid_data(model, query_dataset, print_opt)

    # computing names
    trn_names = np.array([names[idx] for idx in train_classes])
    val_names = np.array([names[idx] for idx in query_classes])

    # normalization (queries and gallery together)
    n_queries = len(val_features)
    features = np.vstack([val_features, trn_features])
    # NOTE(review): axis=0 scales each feature *dimension* across all images;
    # per-image L2 normalisation (axis=1) is the usual choice before cosine /
    # inner-product retrieval. Confirm which is intended.
    features = preprocessing.normalize(features, norm="l2", axis=0)
    val_features = features[:n_queries]
    trn_features = features[n_queries:]

    # compute PCA for dimension reduction
    if pca_opt != 0:
        if print_opt:
            print("\n\nComputing PCA dimension reduction...")
        trn_pca, val_pca = pca(trn_features, val_features, pca_opt, print_opt)
    else:
        if print_opt:
            print("")
        trn_pca, val_pca = trn_features, val_features

    # compute mAP (can display retrieval)
    if print_opt:
        print("\nComputing mAP...")
    AP, precisionsatk = image_retrieval_k(trn_pca, val_pca, trn_names, val_names, trn_images, val_images, 10, view, 15, print_opt, diff)
    mAP = np.mean(AP)
    if print_opt:
        print("\n\nmAP = {}".format(mAP))

    return trn_features, trn_names, val_features, val_names, AP, precisionsatk, mAP

In [None]:
def compute_pixel(names, query_dataset, query_classes, train_dataset, train_classes):
    """Build raw grayscale-pixel vectors as a no-learning retrieval baseline.

    Returns:
        (trn_pixels, val_pixels, trn_names, val_names, trn_images, val_images):
        one flattened grayscale image per row for gallery / queries, their
        class names, and the uint8 RGB images they were computed from.
    """
    def to_uint8(img):
        # CHW tensor -> HWC uint8: shift/scale, transpose, min-max stretch to 0-255.
        arr = np.transpose((img / 2 + 0.5).numpy(), (1, 2, 0))
        arr = cv2.normalize(arr, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
        return arr.astype(np.uint8)

    def to_pixels(images):
        # Flatten each grayscale image; reshape(-1) fits any image size
        # (a fixed 50176 only matches 224x224 inputs).
        return np.array([cv2.cvtColor(im, cv2.COLOR_RGB2GRAY).reshape(-1) for im in images])

    trn_names = np.array([names[idx] for idx in train_classes])
    val_names = np.array([names[idx] for idx in query_classes])

    val_images = np.array([to_uint8(sample[0]) for sample in query_dataset])
    trn_images = np.array([to_uint8(sample[0]) for sample in train_dataset])

    val_pixels = to_pixels(val_images)
    trn_pixels = to_pixels(trn_images)

    return trn_pixels, val_pixels, trn_names, val_names, trn_images, val_images

## Finding the Learning Rate:

In [None]:
def find_lr(m, dataloader):
    """Run a learning-rate range test on ``m`` and plot loss against learning rate.

    The sweep starts at the optimizer's lr (1e-7) and runs to 10 over 100
    iterations using Circle loss. ``lr_finder.reset()`` is then called, which
    per the torch-lr-finder docs restores the model and optimizer state.
    """
    criterion = losses.CircleLoss(m=0.4, gamma=80)
    optimizer = optim.AdamW(m.parameters(), lr=1e-7, weight_decay=1e-2)
    # Use the notebook-wide device instead of hard-coding "cuda", so the range
    # test also runs on CPU-only machines.
    lr_finder = LRFinder(m, optimizer, criterion, device=device)
    lr_finder.range_test(dataloader, end_lr=10, num_iter=100)
    lr_finder.plot()
    lr_finder.reset()

## Visualization functions:

In [None]:
# visualization for a given query image:
def visualize_query(dataset, difficulty, query, query_name, train_data, train_names, k):
    """Plot a t-SNE map of a query and its ``k`` nearest gallery features.

    Args:
        dataset, difficulty: only used in the plot title.
        query: the query feature vector.
        query_name: the query's class name (printed under the plot).
        train_data: gallery features, one row per item.
        train_names: class name of every gallery item.
        k: number of nearest (Euclidean) gallery items to embed.
    """
    # Not imported at the top of the notebook (only cosine_similarity is).
    from sklearn.metrics.pairwise import euclidean_distances

    # Rank the gallery by Euclidean distance to the query and keep the k closest.
    query = query.reshape((1, -1))
    order = np.argsort(euclidean_distances(train_data, query).squeeze())
    nearest = order[:k]
    data = np.asarray(train_data)[nearest]
    name = [train_names[i] for i in nearest]

    # Embed query + neighbours together. t-SNE requires perplexity < n_samples,
    # so the default of 30 is reduced for small k.
    joined_data = np.concatenate((query, data))
    tsne = TSNE(random_state=42, perplexity=min(30.0, len(joined_data) - 1))
    joined_tsne = tsne.fit_transform(joined_data)
    query_tsne = joined_tsne[:1]
    X_tsne = joined_tsne[1:]

    # Sort neighbours by class name, then get the unique names (sorted), each
    # point's class id and the per-class counts.
    sort_idx = np.argsort(name)
    sorted_tSNE = X_tsne[sort_idx]
    names, y, counts = np.unique(np.asarray(name)[sort_idx], return_inverse=True, return_counts=True)

    # setting colours: fall back to random colours when there are more classes
    # than fixed colours.
    colors = ["#476A2A", "#7851B8", "#BD3430", "#4A2D4E", "#875525", "#A83683", "#4E655E", "#853541", "#3A3120", "#535D8E", "blue"]
    if len(names) > len(colors):
        colors = []
        for _ in range(len(names)):
            r = random.random()
            b = random.random()
            g = random.random()
            colors.append((r, g, b))

    # Plot output: pad limits by 25% of the span (query included) so no point is clipped.
    plt.figure(figsize=(10, 10))
    lo = joined_tsne.min(axis=0)
    hi = joined_tsne.max(axis=0)
    pad = 0.25 * np.maximum(hi - lo, 1e-6)
    plt.xlim(lo[0] - pad[0], hi[0] + pad[0])
    plt.ylim(lo[1] - pad[1], hi[1] + pad[1])

    for (x0, y0), cls in zip(sorted_tSNE, y):
        plt.text(x0, y0, str(cls), color=colors[cls], fontdict={'weight': 'bold', 'size': 9})

    plt.text(query_tsne[0, 0], query_tsne[0, 1], ".", color=(1,0,0), fontdict={'weight': 'bold', 'size': 40})
    plt.title("t-SNE for {}-{} for k = {}".format(dataset, difficulty, k))
    plt.xlabel("t-SNE feature 0")
    plt.ylabel("t-SNE feature 1")
    plt.show()

    # print information relating to plot:
    print("Query Image: {} | Red Square\n".format(query_name))

    print(f'{"NAME":<20s} {"NUMBER":<10s} {"COUNT":<5s}')
    for idx, (cls_name, count) in enumerate(zip(names, counts)):
        print(f'{cls_name:<20s} {str(idx):<10s} {str(count):<5s}')

In [None]:
def _select_class_points(X_tsne, filenames, class_name):
    """Pick the embedded points (and their class names) that should be plotted.

    Args:
        X_tsne: 2-D embedding, one row per sample, aligned with ``filenames``.
        filenames: per-sample class names.
        class_name: ``["all"]`` to keep every sample, otherwise a collection of
            class names to keep. A bare string is treated as one class name
            (instead of being matched by substring).

    Returns:
        ``(points, labels)``: a 2-D array of the kept rows (shape ``(0, d)`` when
        nothing matches) and the list of their class names.
    """
    X_tsne = np.asarray(X_tsne)
    if isinstance(class_name, str):
        class_name = [class_name]
    if len(class_name) > 0 and class_name[0] == "all":
        return X_tsne, list(filenames)
    # zip() truncates to the shorter input, matching the original pairing of
    # embedded rows with filenames.
    keep = [i for i, name in zip(range(len(X_tsne)), filenames) if name in class_name]
    return X_tsne[np.asarray(keep, dtype=int)], [filenames[i] for i in keep]


def _encode_sorted_labels(sorted_names):
    """Map class names to consecutive integer ids in order of first appearance.

    Returns:
        ``(names, y, counts)``: the unique names, the integer id of every input
        element, and how many elements carry each unique name.
    """
    names = []
    y = []
    counts = []
    index = {}  # name -> integer id; O(1) lookup instead of scanning `names`
    for name in sorted_names:
        if name not in index:
            index[name] = len(names)
            names.append(name)
            counts.append(0)
        label = index[name]
        counts[label] += 1
        y.append(label)
    return names, y, counts


def _tsne_axis_limits(values, scale=1.5):
    """Return ``(low, high)`` axis limits that strictly enclose ``values``.

    Keeps the original scaling rule (multiply each extreme by +/- ``scale``).
    The sign is chosen from the plotted values themselves, so the limits always
    lie outside the data range.
    """
    low = float(np.min(values))
    high = float(np.max(values))
    low = low * scale if low < 0 else low * -scale
    high = high * scale if high > 0 else high * -scale
    if low == high:
        # All values are 0: avoid a zero-width axis.
        low, high = low - 1.0, high + 1.0
    return low, high


def visualize_class(dataset, difficulty, class_name, filenames, features, model_name, save_folder=None):
    """Plot a 2-D t-SNE embedding of ``features``, labelled by class.

    t-SNE is fitted on *all* rows of ``features``. Afterwards only the samples
    whose entry in ``filenames`` is listed in ``class_name`` are drawn, or every
    sample when ``class_name[0] == "all"``. Each sample is drawn as the integer
    id of its class; ids follow sorted class-name order.

    Args:
        dataset: dataset name, used only in the plot title.
        difficulty: difficulty/split name, used only in the plot title.
        class_name: ``["all"]`` or a collection of class names to plot. A
            single string is treated as one class name.
        filenames: per-row class names, aligned with the rows of ``features``.
        features: feature matrix passed to ``TSNE.fit_transform``.
        model_name: model name, used only in the plot title.
        save_folder: when not None, the figure is written to this path with
            ``plt.savefig`` and then closed. Otherwise it is shown and a
            NAME / NUMBER / COUNT legend is printed.

    Raises:
        ValueError: if no entry of ``filenames`` matches ``class_name``.
    """
    # Fit on the full feature set so the selected classes are placed relative
    # to every other sample.
    tsne = TSNE(random_state=42)
    X_tsne = tsne.fit_transform(features)

    class_tSNE, class_names = _select_class_points(X_tsne, filenames, class_name)
    if len(class_names) == 0:
        raise ValueError(
            "visualize_class: no entries of `filenames` match class_name={!r}".format(class_name))

    # Sort by class name so each class forms a contiguous run of labels.
    sort_idx = np.argsort(class_names)
    sorted_names = [class_names[i] for i in sort_idx]
    sorted_tSNE = class_tSNE[sort_idx]

    names, y, counts = _encode_sorted_labels(sorted_names)

    colors = ["#476A2A", "#7851B8", "#BD3430", "#4A2D4E", "#875525", "#A83683", "#4E655E", "#853541", "#3A3120", "#535D8E", "blue"]
    if len(names) > len(colors):
        # Not enough fixed colours: use one random RGB colour per class.
        colors = [(random.random(), random.random(), random.random()) for _ in names]

    plt.figure(figsize=(10, 10))
    # Limits come from the plotted points, not the full embedding, so a class
    # subset is never clipped.
    plt.xlim(*_tsne_axis_limits(sorted_tSNE[:, 0]))
    plt.ylim(*_tsne_axis_limits(sorted_tSNE[:, 1]))

    for point, label in zip(sorted_tSNE, y):
        plt.text(point[0], point[1], str(label), color=colors[label], fontdict={'weight': 'bold', 'size': 9})

    plt.title("t-SNE w. {} for {}-{}".format(model_name, dataset, difficulty))

    # Column 0 is drawn on the x axis and column 1 on the y axis.
    plt.xlabel("t-SNE feature 0")
    plt.ylabel("t-SNE feature 1")

    if save_folder is not None:
        plt.savefig(save_folder)
        # Release the figure so repeated calls (e.g. in a loop) do not pile up
        # open figures.
        plt.close()
    else:
        plt.show()
        # Print the legend mapping plotted integer ids back to class names.
        print(f'{"NAME":<20s} {"NUMBER":<10s} {"COUNT":<5s}')
        for idx, (name, count) in enumerate(zip(names, counts)):
            print(f'{name:<20s} {str(idx):<10s} {str(count):<5s}')

## Retrieval:


In [None]:
# Build the base model with this notebook's load_model helper.
# 'resnet' presumably selects the backbone architecture (confirm in load_model).
m = load_model('resnet')

In [None]:
# Load previously fine-tuned weights for the 'resnet' model on 'oxford' into m.
# NOTE(review): assumes load_ft returns the model with those weights restored; confirm.
m = load_ft(m, 'resnet', 'oxford')

In [None]:
# Run a learning-rate search for m over ox_all_loader, to help choose the lr
# passed to fine_tune below.
# NOTE(review): find_lr presumably wraps torch_lr_finder.LRFinder (imported at
# the top of the notebook); confirm.
find_lr(m, ox_all_loader)

In [None]:
# Fine-tune m on ox_all_loader for 50 epochs at lr=5e-5.
# Returns values presumably holding per-iteration losses, per-epoch losses, and
# mAP values tracked during training (`maps`); confirm in fine_tune.
iteration_loss, epoch_loss, maps = fine_tune(ox_all_loader, m, epochs=50, lr=0.00005)

In [None]:
# Plot the loss and mAP curves returned by the fine-tuning cell.
plot_loss(iteration_loss, epoch_loss, maps)

In [None]:
# Compute retrieval mAP of the current model on one dataset/difficulty split.
# If m is a torch nn.Module, eval() switches it to inference mode in place and
# returns the same object, so `model` and `m` are the same network.
model = m.eval()
# `dataset` and `difficulty` are also read by the visualization cells further down.
dataset = 'oxford'
difficulty = 'easy'
# The flags below are passed positionally to compute_map; their meaning is
# defined there (not visible in this cell).
view = 0
pca_opt = 0
print_opt = True
diff = True

# Unpacks train/query features and names plus AP, precision@k and mAP results
# (meanings inferred from the variable names; confirm against compute_map).
trn_features, trn_names, val_features, val_names, AP, precisionsatk, mAP = compute_map(model, dataset, difficulty, pca_opt, view, print_opt, diff)

In [None]:
## Pixel baseline: retrieval on raw pixels (ox_query_* queries vs. ox_easy_* train set).
# NOTE(review): this cell rebinds trn_names, val_names, AP, precisionsatk and view.
# The compute_map cell also binds these names, and the visualization cells read them.
# After running this cell, the visualization pairs model features (trn_features,
# val_features) with names/AP from this pixel run. Confirm the orderings match.
names = ox_names
query_dataset = ox_query_dataset
query_classes = ox_query_classes
train_dataset = ox_easy_dataset
train_classes = ox_easy_classes
view = 1

trn_pixels, val_pixels, trn_names, val_names, trn_images, val_images = compute_pixel(names, query_dataset, query_classes, train_dataset, train_classes)

# The positional arguments after the data (10, view, 15, True, False) are
# interpreted by image_retrieval_k; see its definition.
AP, precisionsatk = image_retrieval_k(trn_pixels, val_pixels, trn_names, val_names, trn_images, val_images, 10, view, 15, True, False)

## Visualization:

In [None]:
# Inspect one query: show the image with its average precision (AP), then place
# it on the t-SNE map of the training features. The query is drawn as a large
# red "." marker.
# Author's note: hard queries = 8, 43, 50.
# Index of the query to inspect.
idx = 0

# Display the query image titled with its name and AP (rounded to 4 decimals).
# NOTE(review): val_images is only assigned by the pixel-baseline cell. AP is
# whichever of that cell or the compute_map cell ran last. Confirm which AP is
# intended here.
plt.imshow(val_images[idx])
plt.title("Query Image | {} | AP={}".format(val_names[idx], np.round(AP[idx], 4)))
plt.axis("off")
plt.show()

# Place the query among the training points; k=-1 is forwarded to visualize_query.
visualize_query(dataset, difficulty, val_features[idx], val_names[idx], trn_features, trn_names, k=-1)

In [None]:
# Plot the t-SNE map of the training features, labelled by class.
# Pass ["all"] to plot every class, or ["name1", "name2", ...] to plot only those.
# "ResNet-ft" appears only in the plot title.
# NOTE(review): trn_names must be row-aligned with trn_features. The pixel-baseline
# cell rebinds trn_names.
visualize_class(dataset, difficulty, ["all"], trn_names, trn_features, "ResNet-ft")