In [3]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

Matplotlib is building the font cache; this may take a moment.


In [4]:
imgs_array = np.load('imgs_array.npy')
y = np.load('y.npy')
names = np.load('names.npy')

In [5]:
from sklearn.model_selection import train_test_split 
X_train, X_test, y_train, y_test, names_train, names_test = train_test_split(imgs_array, y, names, test_size=0.2, random_state=42)

In [53]:
from sklearn.metrics.pairwise import euclidean_distances

def image_retrieval_k(train_data, test_data, train_names, test_names, k=20, hide=0):
    avg_precisions = []
    avg_recalls = []
    all_avg_precisions = []
    all_avg_recalls = []
    precisionsatk = []
    
    for idx, query in enumerate(test_data):
        
        all_precisions = []
        all_recalls = []
        precisions = []
        recalls = []

        query = query.reshape((1, -1))
        D = euclidean_distances(train_data, query).squeeze()
        index = np.argsort(D)
        
        last_correct_image_idx = 0
        for i in range(len(index)):
            if train_names[index[i]] == test_names[idx]:
                last_correct_image_idx = i
        
        if k > last_correct_image_idx:
            last_correct_image_idx = k+1
        
        for kk in range(1, last_correct_image_idx+1):
            TP = 0
            FP = 0
            FN = 0
            
            correct_count = 0
            for ind in index:
                if train_names[ind] == test_names[idx]:
                    correct_count += 1
            sized_index = index[:kk]
            
            tmp = [query.reshape((32,32))]
            for ind in sized_index:
                tmp.append(train_data[ind].reshape((32,32)))
                if train_names[ind] == test_names[idx]:
                    TP += 1
                else:
                    FP += 1
            FN = correct_count - TP

            output = np.array(tmp)*255
            output = output.transpose(1, 0, 2)
            output = output.reshape((32, -1))
            im_query = Image.fromarray(output)
            
            # If the last k image is a correct image we add precision to the list
            if train_names[sized_index[-1]] == test_names[idx]:
                precisions.append(TP/(TP+FP))
                recalls.append(TP/(TP+FN))

            # Adding all precisions and recalls to a seperate list
            all_precisions.append(TP/(TP+FP))
            all_recalls.append(TP/(TP+FN))
        
            if hide == 1:
                print("Precision@k: {} \t Recall: {}".format(all_precisions[-1], all_recalls[-1]))
                display(im_query) 
        
        # If no correct images found we append a precision of 0 for AP of 0
        if precisions == []:
            precisions.append(0)
            recalls.append(0)
         
        avg_precisions.append(np.average(precisions))
        avg_recalls.append(np.average(recalls))
        precisionsatk.append(all_precisions[k])
        
        if hide == 2 or hide == 1:
                print("\nAverage Precision for query {}: ".format(idx), avg_precisions[-1])
                print("Precision@k for query {}: ".format(idx), all_precisions[k])
        
        all_avg_precisions.append(np.average(all_precisions))
        all_avg_recalls.append(np.average(all_recalls))
        
    return avg_precisions, avg_recalls, all_avg_precisions, all_avg_recalls, precisionsatk

Input for image_retrieval_k:
- X_train -> Training data (labelled X_train here)
- X_test -> Testing data (used for the queries, labelled as X_test here)
- names_train -> Object in the image names of each image in the training set
- names_test -> " " for the testing set
- k -> returns the precisions at that k value
- last number is the veiwing option:
    - 0 -> everything is hidden (just returns averages and precisionsatk)
    - 1 -> prints the AP and precision@k for each query
    - 2 -> prints the AP and precision@k for each query, and prints each query image set with each overall precision, recall over all images

Output:
- avg_precisions -> A list of all AP results
- avg_recalls -> A list of all average recalls relating to the AP results
- all_avg_precisions -> A list of all precision averages, including non correct images
- all_avg_recalls -> A list of all the recall averages
- precisionatk -> The precisions at the set k value in the input for each query

mAP can be solved by finding the average of the AP list (avg_precisions):
- Which should be = 0.575

In [27]:
avg_precisions, avg_recalls, all_avg_precisions, all_avg_recalls, precisionsatk = image_retrieval_k(X_train, X_test, names_train, names_test, 20, 2)
mAP = np.average(avg_precisions)
print("\nmAP = {}".format(mAP))


Average Precision for query 0:  0.13619681943983533
Precision@k for query 0:  0.1

Average Precision for query 1:  0.7825413000011822
Precision@k for query 1:  0.6

Average Precision for query 2:  1.0
Precision@k for query 2:  0.7

Average Precision for query 3:  0.104396318923837
Precision@k for query 3:  0.15

Average Precision for query 4:  0.8563186640436217
Precision@k for query 4:  0.55

Average Precision for query 5:  0.3501491148449118
Precision@k for query 5:  0.25

Average Precision for query 6:  0.020383640196746412
Precision@k for query 6:  0.0

Average Precision for query 7:  0.24747722772308606
Precision@k for query 7:  0.15

Average Precision for query 8:  0.3417442591349827
Precision@k for query 8:  0.25

Average Precision for query 9:  0.9601754385964912
Precision@k for query 9:  0.75

Average Precision for query 10:  1.0
Precision@k for query 10:  0.75

Average Precision for query 11:  0.9068864468864469
Precision@k for query 11:  0.7

Average Precision for query 12:


Average Precision for query 101:  0.6971641463949158
Precision@k for query 101:  0.35

Average Precision for query 102:  0.2902016225334977
Precision@k for query 102:  0.2

Average Precision for query 103:  0.8281222962337513
Precision@k for query 103:  0.6

Average Precision for query 104:  0.6806628551129212
Precision@k for query 104:  0.45

Average Precision for query 105:  1.0
Precision@k for query 105:  0.85

Average Precision for query 106:  0.9316525579683476
Precision@k for query 106:  0.55

Average Precision for query 107:  0.07807364118011885
Precision@k for query 107:  0.05

Average Precision for query 108:  0.8861904761904762
Precision@k for query 108:  0.55

Average Precision for query 109:  0.7933527886556142
Precision@k for query 109:  0.65

Average Precision for query 110:  0.40635662982299314
Precision@k for query 110:  0.2

Average Precision for query 111:  0.011901745055385896
Precision@k for query 111:  0.0

Average Precision for query 112:  0.9263085453633054
Prec


Average Precision for query 199:  0.2140254892015508
Precision@k for query 199:  0.1

Average Precision for query 200:  0.11427349932191912
Precision@k for query 200:  0.05

Average Precision for query 201:  0.2162341715562543
Precision@k for query 201:  0.1

Average Precision for query 202:  0.5123406800707562
Precision@k for query 202:  0.35

Average Precision for query 203:  0.44414933560975467
Precision@k for query 203:  0.45

Average Precision for query 204:  0.4657310127685745
Precision@k for query 204:  0.35

Average Precision for query 205:  0.4777164316798618
Precision@k for query 205:  0.25

Average Precision for query 206:  0.11668099323511975
Precision@k for query 206:  0.15

Average Precision for query 207:  0.43588638404103597
Precision@k for query 207:  0.3

Average Precision for query 208:  0.021684634204689156
Precision@k for query 208:  0.0

Average Precision for query 209:  0.14604015865310085
Precision@k for query 209:  0.1

Average Precision for query 210:  0.6360

In [55]:
import pandas
data = {'Precision@k': precisionsatk, 'Average Precision': avg_precisions}
df = pandas.DataFrame(data=data)

In [56]:
pandas.set_option("display.max_rows", 281, "display.max_columns", 4)
df

Unnamed: 0,Precision@k,Average Precision
0,0.095238,0.136197
1,0.571429,0.782541
2,0.666667,1.0
3,0.142857,0.104396
4,0.52381,0.856319
5,0.238095,0.350149
6,0.0,0.020384
7,0.142857,0.247477
8,0.238095,0.341744
9,0.714286,0.960175


In [58]:
df.to_csv('MPEG7-metrics_k=20.csv')