In [1]:
import sys
import os

import torch
import torch.nn as nn
import torchvision.models as models
import pandas as pd
# Number of output classes; passed as `num_classes` when the MobileNetV2 model is constructed.
FOOD101_CLASSES = 101

def fix_names(state_dict):
    """Return a new dict whose keys have every ``'module.'`` substring removed.

    Args:
        state_dict: mapping of parameter name -> value (any object with
            ``.items()``). It is not modified.

    Returns:
        A fresh ``dict`` in the same iteration order. Values are carried over
        unchanged. If two keys collapse to the same name, the later one wins.
    """
    renamed = {}
    for name, tensor in state_dict.items():
        # str.replace removes *every* occurrence, not only a leading prefix.
        renamed[name.replace('module.', '')] = tensor
    return renamed

# Build MobileNetV2 with a FOOD101_CLASSES-way classifier head and, if the
# checkpoint file is present, load its weights onto the CPU.
model = models.mobilenet_v2(num_classes=FOOD101_CLASSES)
checkpoint_path = 'mobilenet_v2_food101/pytorch_model.bin'

if os.path.isfile(checkpoint_path):
    print("=> loading checkpoint '{}'".format(checkpoint_path))

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints obtained from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    # The checkpoint is expected to be a dict with 'state_dict' and 'epoch' keys.
    weights = fix_names(checkpoint['state_dict'])
    model.load_state_dict(weights)

    print("=> loaded checkpoint '{}' (epoch {})"
          .format(checkpoint_path, checkpoint['epoch']))
else:
    # Previously a missing file was skipped silently, leaving `model` with its
    # freshly initialised weights and making every downstream feature
    # meaningless. Make the situation visible.
    print("=> no checkpoint found at '{}'; model keeps its initial weights"
          .format(checkpoint_path))

  warn(


=> loading checkpoint 'mobilenet_v2_food101/pytorch_model.bin'
=> loaded checkpoint 'mobilenet_v2_food101/pytorch_model.bin' (epoch 27)


In [2]:
import torchvision
from torchvision import models, transforms, datasets

# Both splits are read from the same root directory; only `split` differs.
test_dataset = datasets.Food101(root='data/train', split='test')
train_dataset = datasets.Food101(root='data/train', split='train')

In [3]:
# Switch the model to evaluation mode before it is used for inference.
model.eval()

# Module-level slot that `hook` overwrites with the pooled, flattened output
# of the module it is attached to. None until the hook has fired once.
last_hidden_layer_output = None
def hook(module, input, output):
    """Forward hook that stores a pooled, flattened copy of ``output``.

    ``output`` is average-pooled to a 1x1 spatial size and flattened from
    dimension 1 onward. The result is written to the module-level variable
    ``last_hidden_layer_output``. ``module`` and ``input`` are unused, and
    nothing is returned, so the hooked module's output is left unchanged.
    """
    global last_hidden_layer_output
    last_hidden_layer_output = torch.flatten(
        nn.functional.adaptive_avg_pool2d(output, output_size=(1, 1)),
        start_dim=1,
    )


In [4]:
# Load the precomputed table from 'final.csv'. Later cells read an 'index' column,
# a positional slice of feature columns, and a 'correct' column from it.
# NOTE(review): presumably one row per image with its feature vector - confirm against the CSV producer.
data = pd.read_csv('final.csv')

In [5]:
# Display the loaded frame (pandas elides the middle rows and columns of a frame this large).
data

Unnamed: 0,index,0,1,2,3,4,5,6,7,8,...,1272,1273,1274,1275,1276,1277,1278,1279,correct,prediction
0,0,0.839059,1.363838,0.080090,0.065024,1.491130,0.180133,0.081063,0.020513,2.038073,...,0.000000,0.044557,0.010182,0.000000,2.146513,0.123560,0.001441,0.000000,23,23
1,1,0.442812,2.405817,0.000000,0.009632,1.888990,0.004689,0.000000,0.613277,2.301057,...,0.027783,0.288800,0.008069,0.375614,1.273387,0.468740,0.003605,0.010072,23,23
2,2,0.079898,1.391043,0.599216,1.220002,1.637205,0.009429,0.172605,0.725320,1.864051,...,0.143566,0.400629,0.194802,0.386077,2.010916,0.095304,0.000000,0.000000,23,23
3,3,0.027375,0.955091,1.381925,1.178553,1.754310,0.023663,0.000000,0.073860,1.765674,...,0.153445,0.078129,0.132515,1.929875,2.099552,0.054047,0.077225,0.000000,23,23
4,4,0.448010,0.916095,0.461979,0.262524,0.973948,0.188042,0.000000,0.269618,1.456297,...,0.010148,0.113712,0.098903,0.281540,2.169709,1.261045,0.013835,0.000000,23,23
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75745,75745,0.296093,0.000000,0.770814,0.009673,0.805547,0.448031,0.474958,0.115898,0.049163,...,0.238353,0.194286,0.295472,1.275607,0.796487,0.088551,0.181679,1.591767,69,69
75746,75746,0.364466,0.312251,1.551891,0.007010,0.212980,0.684131,0.425860,0.717610,0.000000,...,1.568368,0.105566,1.146569,2.035128,0.810989,0.067670,0.306812,0.724687,69,69
75747,75747,0.008435,0.000000,0.268577,0.181721,0.423351,0.003225,0.809493,0.037407,0.386182,...,0.574975,0.000000,0.232983,0.708428,1.095990,0.603932,0.353420,0.889988,69,69
75748,75748,0.086369,0.000000,0.230066,0.000000,0.566825,0.427237,0.693684,0.192603,0.312837,...,0.288039,0.027522,0.576601,0.872804,1.219102,0.087943,0.170363,0.946733,69,69


In [6]:
def get_image_features(img):
    """Return the pooled ``model.features`` activation for a single image.

    The image is resized to 256, centre-cropped to 224, converted to a tensor
    and normalised with the mean/std below. It is then run through ``model``
    as a batch of one, with ``hook`` temporarily attached to ``model.features``.

    Args:
        img: an image accepted by the torchvision transforms below.
            NOTE(review): presumably a 3-channel PIL image, since Normalize
            is given three means/stds - confirm against callers.

    Returns:
        The first row of ``last_hidden_layer_output`` as written by ``hook``,
        i.e. the flattened pooled feature vector for ``img``.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    batch = transform(img).unsqueeze(0)  # add the batch dimension

    hook_handle = model.features.register_forward_hook(hook)
    try:
        with torch.no_grad():
            # Only the hook's side effect is needed, so the logits are discarded.
            model(batch)
    finally:
        # Always detach the hook, even if the forward pass raises. Otherwise
        # hooks would accumulate on model.features across failed calls.
        hook_handle.remove()
    return last_hidden_layer_output[0]

In [7]:
# Create annoy for nearest neighbour
from annoy import AnnoyIndex
import random

f = 1280  # Length of item vector that will be indexed

t = AnnoyIndex(f, 'angular')
# Item ids come from the 'index' column. The vector is read positionally from
# the `f` columns that start at position 1.
# NOTE(review): assumes the feature columns occupy positions 1..f of final.csv - confirm against its header.
# itertuples yields plain tuples and avoids building a Series per row as iterrows does.
for item_id, row in zip(data['index'], data.itertuples(index=False, name=None)):
    t.add_item(int(item_id), row[1:1 + f])

t.build(100)  # 100 trees
t.save('test.ann')



True

In [10]:
# Re-open the saved index from 'test.ann', using the same `f` and metric as
# the index that was built and saved to that file.
u = AnnoyIndex(f, 'angular')
u.load('test.ann') # super fast, will just mmap the file

True

In [11]:
def get_image_nearest(img, ann, n=5):
    """Look up the ``n`` nearest indexed items to an image.

    Args:
        img: image passed straight to ``get_image_features``.
        ann: an index object exposing ``get_nns_by_vector`` (e.g. an AnnoyIndex).
        n: number of neighbours to request (default 5).

    Returns:
        Whatever ``ann.get_nns_by_vector(..., include_distances=True)``
        returns. Callers unpack it as ``(ids, distances)``.
    """
    # The previous version also computed `data.values[:, 1:1281].shape` into an
    # unused local, which converted the whole DataFrame to an array on every
    # query. That has been removed; the function no longer touches `data`.
    features = get_image_features(img)
    return ann.get_nns_by_vector(features, n=n, include_distances=True)

In [44]:
# Query the loaded index `u` with the first element of test sample `idx`.
idx = 5000
index, distances = get_image_nearest(img=test_dataset[idx][0], ann=u)

In [47]:
# Annoy item ids were taken from the 'index' column when the index was built,
# so match on that column rather than on the DataFrame's own row labels.
# Rows come back in frame order, not in order of distance.
data.loc[data['index'].isin(index), 'correct']

15396    71
15481    71
15707    71
28038    65
28060    65
Name: correct, dtype: int64

In [46]:
# Display the first element of test sample `idx`, i.e. the item that was passed to get_image_nearest.
test_dataset[idx][0]

[0.5099616050720215,
 0.5581638216972351,
 0.5859385132789612,
 0.5926702618598938,
 0.6115350127220154]