In [39]:
import os
import pickle
import numpy as np
from PIL import Image

In [40]:
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch

In [41]:
model = torchvision.models.resnet18(pretrained=True)
model = nn.Sequential(*(list(model.children())[:-1]))

In [55]:
def compute_features(img):
    image = Image.open(img).convert('RGB')
    transform = transforms.Compose([
        #transforms.CenterCrop(200),
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    image = transform(image).unsqueeze(0)

    features = model(image).flatten().detach().numpy()
    return model(image)

In [56]:
def compute_feature_dataset(dataset_dir):    
    features = []
    if os.path.exists(dataset_dir):
        img_files = [os.path.join(dataset_dir, fname) for fname in os.listdir(dataset_dir) if fname.endswith(".jpg")]
        print("Total images: {}".format(len(img_files)))        
        count = 0
        for img_file in sorted(img_files):
            feature = []
            feature = compute_features(img_file)
            features.append(feature)
            if count % 100 == 0 or count == len(img_files) - 1:
                print("Processed image {} ".format(count))
            count += 1
        with open('features.pkl', 'wb') as f:
            pickle.dump(features, f)
    else:
        print("Directory does not exist")
    return features

In [57]:
dir = '{}/BBDD/'.format(os.getcwd())
print(dir)


/Users/siddhantbhambri/Downloads/BBDD/


In [58]:
def load_feature_dataset(feature_dataset_file):
    features = []
    try:
        with open(feature_dataset_file, 'rb') as f:
            features = pickle.load(f)
    except:
        print("No such file")
    if len(features) == 0:
        features = compute_feature_dataset(dataset_dir)
    return np.asarray(features)

In [59]:
def euclidean_distance(x, y):     
    return np.linalg.norm(np.array(x.detach()) - np.array(y.detach()))

def k_nearest_search(query, features, metric = "euclidean_distance", k = 10):   

    if(k > len(features)): 
        return "K is larger than proper length"

    reverse = False
    dist_to_img = []
    distance_gt = 0
    for idx, feature in enumerate(features):
        # calculate distance
        distance = 0
    
        if metric == "euclidean_distance":
            distance = euclidean_distance(feature, query)        
        
        # save the distance of ground truth
         if idx == ground_truth:
            distance_gt = distance
            
        # sort out the best results by appropriate order
        if(len(dist_to_img) < k): 
            dist_to_img.append([distance, idx])
            dist_to_img = sorted(dist_to_img, reverse = reverse)
        else:
            if (not reverse and distance < dist_to_img[-1][0]) or (reverse and distance > dist_to_img[-1][0]):
                dist_to_img[-1] = [distance, idx]
                dist_to_img = sorted(dist_to_img, reverse = reverse)
    
    dist_to_img = sorted(dist_to_img, reverse = reverse)
    ## find the rank of gt picture
     rank = 1
    for i, dti in enumerate(dist_to_img):
        if dti[-1] == ground_truth:
            rank = i + 1
            break
             
    return [dist_to_img]

In [60]:
dataset_dir = "{}/BBDD".format(os.getcwd()) ## dataset path
queryset_dir = "{}/qst1_w1".format(os.getcwd()) ## query set path

In [63]:
k = 10

distance_metric = "euclidean_distance"
feature_dataset_file = 'features.pkl'
results = []
corresps = []
features = load_feature_dataset(feature_dataset_file)

ground truth correspondences
with open('qsd1_w1/gt_corresps.pkl', 'rb') as f:
     corresps = pickle.load(f)

In [64]:
query_set = [fname for fname in sorted(os.listdir(queryset_dir)) if fname.endswith(".jpg")]
for idx, query in enumerate(query_set):
    print("Query {}: {}, ground-truth: {}".format(idx, query, str(corresps[idx][0]).zfill(5)))

    query_feature = compute_features(os.path.join(queryset_dir, query))
    
    [k_nearest] = k_nearest_search(query_feature, features, distance_metric, k)
    
    print("{}-most similar images:".format(k))
    result = []
    for i, image in enumerate(k_nearest):
        print("{}. {}.jpg, score = {}".format(i + 1, str(image[-1]).zfill(5), image[0]))
        result.append(image[-1])
        if i == k - 1:
            results.append(result)   
    print("==================================")
print(results)    
with open('result.pkl', 'wb') as f:
    pickle.dump(results, f)              

10-most similar images:
1. 00276.jpg, score = 1.0286312103271484
2. 00188.jpg, score = 1.0509461164474487
3. 00150.jpg, score = 1.0557793378829956
4. 00161.jpg, score = 1.057380199432373
5. 00022.jpg, score = 1.0591340065002441
6. 00034.jpg, score = 1.0703048706054688
7. 00242.jpg, score = 1.070947289466858
8. 00137.jpg, score = 1.0804861783981323
9. 00240.jpg, score = 1.0812982320785522
10. 00121.jpg, score = 1.0841634273529053
10-most similar images:
1. 00272.jpg, score = 1.0469697713851929
2. 00165.jpg, score = 1.070396065711975
3. 00128.jpg, score = 1.1113601922988892
4. 00240.jpg, score = 1.1176832914352417
5. 00222.jpg, score = 1.1186643838882446
6. 00283.jpg, score = 1.1253557205200195
7. 00020.jpg, score = 1.1309294700622559
8. 00015.jpg, score = 1.1357393264770508
9. 00090.jpg, score = 1.1366055011749268
10. 00258.jpg, score = 1.1385775804519653
10-most similar images:
1. 00240.jpg, score = 1.0971295833587646
2. 00242.jpg, score = 1.1092299222946167
3. 00022.jpg, score = 1.116

10-most similar images:
1. 00130.jpg, score = 1.0568861961364746
2. 00240.jpg, score = 1.0620003938674927
3. 00035.jpg, score = 1.0800104141235352
4. 00140.jpg, score = 1.088382363319397
5. 00150.jpg, score = 1.0921690464019775
6. 00137.jpg, score = 1.0994752645492554
7. 00258.jpg, score = 1.1035208702087402
8. 00015.jpg, score = 1.1061745882034302
9. 00021.jpg, score = 1.110646367073059
10. 00072.jpg, score = 1.1122254133224487
10-most similar images:
1. 00251.jpg, score = 1.0412582159042358
2. 00195.jpg, score = 1.134540319442749
3. 00242.jpg, score = 1.1377849578857422
4. 00240.jpg, score = 1.1447430849075317
5. 00043.jpg, score = 1.1529570817947388
6. 00188.jpg, score = 1.153725266456604
7. 00165.jpg, score = 1.1599212884902954
8. 00022.jpg, score = 1.1613439321517944
9. 00200.jpg, score = 1.1616042852401733
10. 00015.jpg, score = 1.1624407768249512
10-most similar images:
1. 00188.jpg, score = 1.0904144048690796
2. 00240.jpg, score = 1.1218490600585938
3. 00147.jpg, score = 1.1252

In [52]:
corresps = [[120], [170], [277], [227], [251], [274], [285], [258], [117], [203], [192], [22], [113], [101], [174], [155], [270], [47], [286], [215], [262], [245], [257], [182], [262], [38], [238], [67], [86], [133]]
count = 0
#r = [[188, 240, 147, 283, 249], [140, 22, 283, 128, 21], [277, 125, 24, 283, 165], [227, 165, 61, 24, 229], [251, 240, 188, 15, 135], [242, 176, 188, 226, 240], [240, 15, 165, 20, 128], [258, 204, 70, 106, 22], [117, 121, 240, 283, 188], [203, 240, 150, 15, 31], [239, 192, 240, 147, 154], [22, 249, 106, 242, 240], [113, 22, 140, 240, 70], [91, 240, 22, 43, 242], [140, 70, 137, 258, 106], [155, 70, 188, 22, 33], [240, 226, 242, 188, 195], [240, 47, 258, 195, 18], [286, 283, 240, 15, 106], [188, 283, 242, 215, 204], [262, 240, 249, 188, 193], [245, 165, 61, 108, 44], [240, 188, 137, 257, 15], [182, 165, 240, 195, 210], [262, 204, 66, 242, 240], [240, 242, 195, 22, 106], [240, 21, 188, 15, 242], [188, 67, 165, 240, 283], [86, 240, 165, 24, 39], [240, 176, 204, 154, 188]]
#r = [[188], [140], [277], [227], [251], [242], [240], [258], [117], [203], [239], [22], [113], [91], [140], [155], [240], [240], [286], [188], [262], [245], [240], [182], [262], [240], [240], [188], [86], [240]]
#r = [[188], [140], [277], [227], [251], [242], [240], [258], [117], [203], [239], [22], [113], [91], [140], [155], [240], [240], [286], [188], [262], [245], [240], [182], [262], [240], [240], [188], [86], [240]]
#r = [[276], [272], [240], [22], [157], [23], [188], [240], [188], [165], [239], [188], [283], [258], [22], [155], [240], [225], [130], [251], [188], [203], [91], [35], [43], [262], [227], [283], [200], [240]]
#r = [[276, 188, 150, 161, 22, 34, 242, 137, 240, 121], [272, 165, 128, 240, 222, 283, 20, 15, 90, 258], [240, 242, 22, 188, 17, 20, 70, 185, 106, 66], [22, 249, 106, 242, 240, 33, 70, 21, 62, 43], [157, 240, 39, 242, 283, 195, 165, 15, 16, 22], [23, 22, 188, 140, 147, 240, 244, 233, 242, 32], [188, 240, 242, 3, 117, 35, 261, 28, 106, 250], [240, 20, 188, 176, 34, 15, 22, 66, 283, 33], [188, 283, 242, 215, 204, 240, 195, 239, 249, 132], [165, 283, 241, 240, 24, 199, 66, 31, 125, 39], [239, 192, 240, 147, 154, 43, 140, 137, 255, 176], [188, 240, 242, 69, 252, 84, 262, 62, 31, 168], [283, 165, 281, 40, 240, 204, 140, 116, 15, 168], [258, 204, 70, 106, 22, 240, 168, 188, 140, 89], [22, 43, 242, 18, 240, 140, 40, 20, 31, 226], [155, 70, 188, 22, 33, 234, 240, 106, 60, 21], [240, 249, 70, 106, 74, 140, 16, 82, 242, 107], [225, 121, 239, 240, 283, 105, 43, 249, 188, 40], [130, 240, 35, 140, 150, 137, 258, 15, 21, 72], [251, 195, 242, 240, 43, 188, 165, 22, 200, 15], [188, 240, 147, 283, 249, 22, 15, 204, 157, 135], [203, 240, 150, 15, 31, 22, 35, 40, 33, 117], [91, 137, 140, 106, 249, 240, 212, 43, 22, 239], [35, 22, 137, 186, 204, 15, 240, 32, 200, 31], [43, 240, 233, 22, 106, 91, 278, 70, 204, 52], [262, 204, 66, 242, 240, 89, 205, 35, 43, 159], [227, 165, 61, 24, 229, 240, 108, 112, 263, 99], [283, 240, 188, 242, 31, 108, 226, 24, 112, 39], [200, 22, 240, 106, 242, 15, 226, 188, 20, 249], [240, 22, 188, 31, 106, 195, 35, 185, 150, 20]]
r = [[140, 188, 240, 15, 106, 242, 276, 43, 244, 31], [240, 14, 22, 34, 15, 48, 226, 200, 16, 150], [165, 188, 242, 195, 240, 278, 99, 15, 117, 137], [14, 150, 221, 226, 106, 195, 82, 242, 240, 16], [283, 22, 84, 15, 121, 242, 226, 233, 240, 40], [195, 240, 176, 95, 107, 249, 221, 200, 188, 193], [61, 221, 24, 112, 176, 240, 99, 247, 266, 200], [195, 15, 240, 188, 16, 95, 165, 71, 82, 242], [15, 200, 188, 20, 14, 249, 283, 122, 242, 48], [278, 150, 22, 18, 74, 239, 226, 200, 180, 117], [188, 240, 195, 283, 3, 135, 121, 117, 106, 67], [34, 242, 221, 240, 165, 200, 176, 62, 150, 241], [240, 149, 165, 176, 99, 31, 195, 121, 283, 242], [204, 240, 278, 188, 200, 34, 33, 37, 176, 249], [165, 240, 112, 220, 148, 61, 195, 71, 284, 128], [22, 150, 33, 188, 119, 204, 242, 278, 163, 84], [200, 226, 240, 242, 106, 15, 35, 20, 62, 22], [226, 106, 240, 242, 188, 145, 239, 31, 125, 278], [15, 242, 238, 240, 150, 188, 168, 145, 123, 31], [283, 240, 188, 242, 137, 176, 204, 15, 223, 36], [195, 240, 106, 283, 0, 31, 34, 137, 163, 221], [249, 188, 240, 22, 15, 150, 242, 74, 283, 226], [176, 240, 274, 188, 195, 210, 226, 16, 112, 220], [188, 106, 22, 193, 226, 14, 140, 221, 31, 200], [165, 240, 15, 200, 16, 61, 176, 142, 263, 39], [240, 188, 22, 204, 117, 137, 89, 121, 35, 106], [22, 188, 226, 200, 65, 140, 150, 106, 276, 249], [195, 176, 274, 106, 226, 65, 200, 213, 240, 246], [22, 150, 62, 195, 252, 202, 21, 226, 200, 35], [240, 33, 249, 137, 193, 43, 70, 17, 84, 22]]
for i, (a,b) in enumerate(zip(corresps, r)):
    if a[0] in b:
        print("query {}: OK".format(i))
        count += 1
    else:
        print("query {}: Not OK".format(i))
print("ratio = {}/30". format(count))
    

query 0: Not OK
query 1: Not OK
query 2: Not OK
query 3: Not OK
query 4: Not OK
query 5: Not OK
query 6: Not OK
query 7: Not OK
query 8: Not OK
query 9: Not OK
query 10: Not OK
query 11: Not OK
query 12: Not OK
query 13: Not OK
query 14: Not OK
query 15: Not OK
query 16: Not OK
query 17: Not OK
query 18: Not OK
query 19: Not OK
query 20: Not OK
query 21: Not OK
query 22: Not OK
query 23: Not OK
query 24: Not OK
query 25: Not OK
query 26: Not OK
query 27: Not OK
query 28: Not OK
query 29: Not OK
ratio = 0/30


In [None]:
# Check the result file from the code below #



In [25]:
r

[[188, 240, 147, 283, 249, 22, 15, 204, 157, 135],
 [140, 22, 283, 128, 21, 204, 242, 170, 137, 43],
 [277, 125, 24, 283, 165, 226, 281, 162, 89, 242],
 [227, 165, 61, 24, 229, 240, 108, 112, 263, 99],
 [251, 240, 188, 15, 135, 242, 43, 18, 20, 35],
 [242, 176, 188, 226, 240, 274, 202, 195, 35, 79],
 [240, 15, 165, 20, 128, 22, 89, 242, 18, 24],
 [258, 204, 70, 106, 22, 240, 168, 188, 140, 89],
 [117, 121, 240, 283, 188, 40, 35, 20, 84, 32],
 [203, 240, 150, 15, 31, 22, 35, 40, 33, 117],
 [239, 192, 240, 147, 154, 43, 140, 137, 255, 176],
 [22, 249, 106, 242, 240, 33, 70, 21, 62, 43],
 [113, 22, 140, 240, 70, 188, 51, 258, 69, 159],
 [91, 240, 22, 43, 242, 159, 188, 87, 70, 210],
 [140, 70, 137, 258, 106, 21, 240, 234, 259, 35],
 [155, 70, 188, 22, 33, 234, 240, 106, 60, 21],
 [240, 226, 242, 188, 195, 39, 204, 15, 90, 258],
 [240, 47, 258, 195, 18, 48, 140, 161, 128, 96],
 [286, 283, 240, 15, 106, 40, 16, 137, 20, 281],
 [188, 283, 242, 215, 204, 240, 195, 239, 249, 132],
 [262, 240, 

In [54]:
with open('results.pkl', 'wb') as f:
    pickle.dump(r, f)

In [1]:
import pickle

In [2]:
with open('results.pkl', 'rb') as f:
    res = pickle.load(f)