In [1]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
imgs_array = np.load('imgs_array.npy')
y = np.load('y.npy')
names = np.load('names.npy')

In [3]:
from sklearn.model_selection import train_test_split 
X_train, X_test, y_train, y_test, names_train, names_test = train_test_split(imgs_array, y, names, test_size=0.2, random_state=42)

In [11]:
from sklearn.metrics.pairwise import euclidean_distances

def image_retrieval_k(train_data, test_data, train_names, test_names, hide=0):
    avg_precisions = []
    avg_recalls = []
    all_avg_precisions = []
    all_avg_recalls = []
    
    for idx, query in enumerate(test_data):
        
        all_precisions = []
        all_recalls = []
        precisions = []
        recalls = []

        query = query.reshape((1, -1))
        D = euclidean_distances(train_data, query).squeeze()
        index = np.argsort(D)
        
        last_correct_image_idx = 0
        for i in range(len(index)):
            if train_names[index[i]] == test_names[idx]:
                last_correct_image_idx = i
                
        for kk in range(1, last_correct_image_idx):
            TP = 0
            FP = 0
            FN = 0
            
            correct_count = 0
            for ind in index:
                if train_names[ind] == test_names[idx]:
                    correct_count += 1
            sized_index = index[:kk]
            
            tmp = [query.reshape((32,32))]
            for ind in sized_index:
                tmp.append(train_data[ind].reshape((32,32)))
                if train_names[ind] == test_names[idx]:
                    TP += 1
                else:
                    FP += 1
            FN = correct_count - TP

            output = np.array(tmp)*255
            output = output.transpose(1, 0, 2)
            output = output.reshape((32, -1))
            im_query = Image.fromarray(output)
            
            # If the last k image is a correct image we add precision to the list
            if train_names[sized_index[-1]] == test_names[idx]:
                precisions.append(TP/(TP+FP))
                recalls.append(TP/(TP+FN))

            # Adding all precisions and recalls to a seperate list
            all_precisions.append(TP/(TP+FP))
            all_recalls.append(TP/(TP+FN))
        
            if hide == 1:
                print("Precision: {} \t Recall: {}".format(all_precisions[-1], all_recalls[-1]))
                display(im_query) 
        
        # If no correct images found we append a precision of 0 for AP of 0
        if precisions == []:
            precisions.append(0)
            recalls.append(0)
            
        avg_precisions.append(np.average(precisions))
        avg_recalls.append(np.average(recalls))
        
        if hide == 2 or hide == 1:
            print("\nAverage Precision for query {}: ".format(idx), avg_precisions[-1])
        
        all_avg_precisions.append(np.average(all_precisions))
        all_avg_recalls.append(np.average(all_recalls))
        
    return avg_precisions, avg_recalls, all_avg_precisions, all_avg_recalls

Input for image_retrieval_k:
- X_train -> Training data (labelled X_train here)
- X_test -> Testing data (used for the queries, labelled as X_test here)
- names_train -> Object in the image names of each image in the training set
- names_test -> " " for the testing set
- last number is the veiwing option:
    - 0 -> everything is hidden (just returns averages)
    - 1 -> prints the AP for each query
    - 2 -> prints the AP for each query, and prints each query image set with each overall precision, recall over all images

Output:
- avg_precisions -> A list of all AP results
- avg_recalls -> A list of all average recalls relating to the AP results
- all_avg_precisions -> A list of all precision averages, including non correct images
- all_avg_recalls -> A list of all the recall averages

mAP can be solved by finding the average of the AP list (avg_precisions):
- Which should be = 0.577

In [12]:
avg_precisions, avg_recalls, all_avg_precisions, all_avg_recalls = image_retrieval_k(X_train, X_test, names_train, names_test, 2)
mAP = np.average(avg_precisions)
print("\nmAP = {}".format(mAP))


Average Precision for query 0:  0.13619681943983533

Average Precision for query 1:  0.7825413000011822

Average Precision for query 2:  1.0

Average Precision for query 3:  0.11066929987301398

Average Precision for query 4:  0.8563186640436217

Average Precision for query 5:  0.3501491148449118

Average Precision for query 6:  0.020383640196746412

Average Precision for query 7:  0.24747722772308606

Average Precision for query 8:  0.3417442591349827

Average Precision for query 9:  0.9601754385964912

Average Precision for query 10:  1.0

Average Precision for query 11:  0.9068864468864469

Average Precision for query 12:  0.4720538435823559

Average Precision for query 13:  0.8859259259259259

Average Precision for query 14:  0.9819004524886878

Average Precision for query 15:  0.436107118843049

Average Precision for query 16:  0.6748730101646999

Average Precision for query 17:  0.9897435897435898

Average Precision for query 18:  0.9809090909090908

Average Precision for query 


Average Precision for query 160:  0.015301721318105732

Average Precision for query 161:  1.0

Average Precision for query 162:  0.2121130556453953

Average Precision for query 163:  0.6323428042506032

Average Precision for query 164:  0.2704788143685297

Average Precision for query 165:  0.3785121291684094

Average Precision for query 166:  0.21217073241502954

Average Precision for query 167:  0.43576987373065773

Average Precision for query 168:  0.6483850512245688

Average Precision for query 169:  1.0

Average Precision for query 170:  0.6353970972954036

Average Precision for query 171:  0.022863187486363942

Average Precision for query 172:  0.12539978132961294

Average Precision for query 173:  1.0

Average Precision for query 174:  0.9498697574355468

Average Precision for query 175:  0.9147463768115943

Average Precision for query 176:  0.8583296476485333

Average Precision for query 177:  0.8303190164954871

Average Precision for query 178:  0.024358566793780238

Average P