# Pattern recognition CW2

In [1]:
from scipy.io import loadmat
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neighbors import NearestNeighbors
import time

### Import file with 6 main components:

In [2]:
""" 
camId : which camera was used to get the shot (1 or 2)
filelist: names of the images (with format x_label_camId_index.png)
labels: class of the image (which person's image is it?)
query_idx: indexes of test set
gallery_idx: indexes of test set used for kNN 
train_idx: indexes of training and validation set
"""   
train_idxs = loadmat('cuhk03_new_protocol_config_labeled.mat')

### Import file with the feature vectors of all images:

In [3]:
import json
with open('PR_data/feature_data.json', 'r') as f:
    features = json.load(f)
features_np = np.array(features) #list of features converted to an array 
features_np.shape

(14096, 2048)

In [4]:
train_idxs['labels'].shape

(14096, 1)

In [5]:
train_idxs['gallery_idx'].shape

(5328, 1)

In [6]:
train_idxs['train_idx']

array([[    1],
       [    2],
       [    3],
       ...,
       [14094],
       [14095],
       [14096]], dtype=uint16)

In [7]:
7368+5328

12696

In [8]:
train_idxs['query_idx'].shape

(1400, 1)

In [9]:
12696+1400

14096

## kNN Classification for query set on gallery set

In [11]:
def extract_info_from_filename(image_index): 
    name = str(train_idxs['filelist'][image_index][0][0])
    [dc, label, camId, index] = name.split('_')
    return label, camId

def delete_images(query_image, query_index, gallery_images, gallery_labels):
    gallery_names = train_idxs['filelist'][train_idxs['gallery_idx'].flatten()]
    query_label, query_camId = extract_info_from_filename(query_index)
    new_gallery_images = []
    new_gallery_labels = []
    for img, name in zip(gallery_images, gallery_names):
        [dc, label_g, camId_g, index] = str(name[0][0]).split('_')
        if [int(label_g), int(camId_g)] != [int(query_label), int(query_camId)]:
            new_gallery_images.append(img)
            new_gallery_labels.append(int(label_g))
    return new_gallery_images, new_gallery_labels

def score_rank(query_label, rank_k, new_gallery_labels):
    #print(rank_k)
    rank_k_labels = np.array(new_gallery_labels)[rank_k] #extract the labels which are of rank k
    return (query_label[0] in rank_k_labels) #return true if query label in rank k, else false

In [12]:
train_idxs['gallery_idx'].flatten()

array([   21,    23,    24, ..., 14062, 14064, 14065], dtype=uint16)

In [13]:
def NNClassification_deletions(query_image, query_index, gallery_images, gallery_labels, k):
    neigh = NearestNeighbors(n_neighbors=k, n_jobs=-1)
    new_gallery_images, new_gallery_labels = delete_images(query_image, query_index, gallery_images, gallery_labels)
    neigh.fit(new_gallery_images)   #new_gallery_labels)
    distances, indices = neigh.kneighbors(query_image, k)
    return distances, indices, new_gallery_labels #neigh.score(query_image, query_label), neigh.predict(query_image), 

In [14]:
gallery_labels = train_idxs['labels'][train_idxs['gallery_idx'].flatten()]
query_labels = train_idxs['labels'][train_idxs['query_idx'].flatten()]

query_images = features_np[train_idxs['query_idx'].flatten()]
gallery_images = features_np[train_idxs['gallery_idx'].flatten()]

def get_accuracy_for_k_ranks(k):
    list_of_truths = []
    start = time.time() #time tracking - start time of process
    for index, query_index in enumerate(train_idxs['query_idx'].flatten().tolist()):
        if(index < 101):
            if (index % 10 == 0):
                print("Index: ", index, " & time taken: ", time.time() - start)

            #perform NN Classification and extract top k ranks 
            dist, rank_k_indices, new_gallery_labels = NNClassification_deletions(query_images[index].reshape(1, 2048), query_index, gallery_images, gallery_labels, k)
            #calculate the score (true/false) for top k images and query image recognition
            score = score_rank(query_labels[index], rank_k_indices, new_gallery_labels)
            #create a list of scores to get overall accuracy later on
            list_of_truths.append(score)
    
    #over-all accuracy  
    acc = sum(list_of_truths)/len(list_of_truths)

In [15]:
train_idxs['gallery_idx'].flatten()

array([   21,    23,    24, ..., 14062, 14064, 14065], dtype=uint16)

In [16]:
train_idxs['filelist'][204]

array([array(['1_022_2_06.png'], dtype='<U14')], dtype=object)

In [17]:
query_images

array([[0.98361254, 0.33195475, 0.26549727, ..., 0.02350602, 0.08491972,
        0.00514889],
       [1.30696166, 0.22219124, 0.62375975, ..., 0.02954952, 0.37348923,
        0.05303315],
       [0.29322624, 0.55588913, 0.06961903, ..., 0.15892459, 0.21527511,
        0.09551907],
       ...,
       [0.7593497 , 0.24537075, 2.19513273, ..., 0.47061139, 1.35971546,
        0.46198806],
       [0.48667279, 0.49618778, 0.30824399, ..., 0.04450084, 1.28799629,
        0.37683338],
       [0.05719076, 0.60635394, 0.58300596, ..., 0.40994745, 1.37061644,
        0.55975795]])

In [28]:
gallery_labels = train_idxs['labels'][train_idxs['gallery_idx'].flatten()]
query_labels = train_idxs['labels'][train_idxs['query_idx'].flatten()]

query_images = features_np[train_idxs['query_idx'].flatten()]
gallery_images = features_np[train_idxs['gallery_idx'].flatten()]

N=100


def remove_indices(n_indices, query_index):
    query_label, query_camId = extract_info_from_filename(query_index)
    n_names = train_idxs['filelist'][n_indices]
    #k_images = features_np[k_indices]
    #print(k_names.shape, k_names[0], k_images[0])
    final_n_indices = []
    final_n_labels = []
    for index, name in zip(n_indices, n_names):
        [dc, label_g, camId_g, index_g] = str(name[0][0]).split('_')
        #print(str(name[0][0]).split('_'))
        #print(index)
        if [int(label_g), int(camId_g)] != [int(query_label), int(query_camId)]:
            final_n_indices.append(index)
            final_n_labels.append(int(label_g))
    return final_n_indices, final_n_labels

def get_accuracy(k, N):
    list_of_truths = []
    start = time.time() #time tracking - start time of process
    neigh = NearestNeighbors(n_neighbors=N, n_jobs=-1)
    neigh.fit(gallery_images)   #new_gallery_labels)
    for index, query_index in enumerate(train_idxs['query_idx'].flatten().tolist()):
        if (index % 10 == 0):
            print("Index: ", index, " & time taken: ", time.time() - start)
        N_distances, N_indices = neigh.kneighbors(query_images[index].reshape(1, 2048), N)
        #print(query_labels[index], query_index, N_indices)
        topN_gallery_images = gallery_images[N_indices[0]]
        
        topN_gallery_indices = [np.where(np.all(features_np == image, axis=1))[0][0] for image in topN_gallery_images]
        #print('topN indices', topN_gallery_indices)
        #print('n_indices', N_indices[0])
        reduced_topN_indices, reduced_topN_labels = remove_indices(topN_gallery_indices, query_index)
        if query_labels[index][0] in reduced_topN_labels[:k]:
            list_of_truths.append(True)
        else:
            list_of_truths.append(False)
        if index == 100:
            break
    #over-all accuracy  
    acc = sum(list_of_truths)/len(list_of_truths) 
    print(acc)

In [29]:
%prun get_accuracy(100, 100)

Index:  0  & time taken:  0.6375007629394531
Index:  10  & time taken:  36.28590679168701
Index:  20  & time taken:  69.53656077384949
Index:  30  & time taken:  104.77642774581909
Index:  40  & time taken:  136.43121576309204
Index:  50  & time taken:  170.34395289421082
Index:  60  & time taken:  202.93267583847046
Index:  70  & time taken:  234.7717638015747
Index:  80  & time taken:  266.59189677238464
Index:  90  & time taken:  299.0750877857208
Index:  100  & time taken:  331.22363471984863
0.8910891089108911
 

In [None]:
get_accuracy_for_k_ranks(5)

In [41]:
%prun get_accuracy_for_k_ranks(10)

Index:  0  & time taken:  0.0
Index:  10  & time taken:  23.64891242980957
Index:  20  & time taken:  47.139809370040894
Index:  30  & time taken:  71.5410213470459
Index:  40  & time taken:  95.63419580459595
Index:  50  & time taken:  119.49739098548889
Index:  60  & time taken:  142.7541127204895
Index:  70  & time taken:  165.4745020866394
Index:  80  & time taken:  188.19400596618652
Index:  90  & time taken:  211.17892265319824
Index:  100  & time taken:  233.9523696899414
 