<a href="https://colab.research.google.com/github/mhtefe/deepLearning/blob/master/7_2_Resnet50ImageRetrieval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!mkdir WorkData
!cp /content/drive/My\ Drive/MachineLearning/datas/101_ObjectCategories.tar.gz WorkData/ 

In [None]:
cd WorkData

In [None]:
ls

In [None]:
!tar -xf 101_ObjectCategories.tar.gz

In [None]:
import os

import numpy as np

import torch
import torch.nn as nn
import torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

from PIL import Image

caltech101Root = '/content/WorkData/101_ObjectCategories/'

def get101CategoriesData(root=None):
    """Collect every image file under the Caltech-101 root with its category.

    The category of a file is the name of the directory that directly
    contains it. Files located directly in ``root`` itself are ignored; only
    its subdirectories contribute entries.

    Args:
        root: directory to scan. Defaults to ``caltech101Root``.

    Returns:
        ``(list_file, list_cate)``: parallel lists of file paths and category
        names. Directories and files are visited in sorted order, so repeated
        runs produce the same ordering (and hence the same feature-row indices).
        A missing root yields two empty lists.
    """
    if root is None:
        root = caltech101Root
    list_file = []
    list_cate = []
    for entry_no, (path, subdirs, files) in enumerate(os.walk(root)):
        # Sorting in place makes os.walk descend into subdirectories in a
        # stable order instead of whatever order the filesystem returns.
        subdirs.sort()
        if entry_no == 0:
            # os.walk yields the root itself first; skip it.
            continue
        # basename() is separator-agnostic; splitting on '\\' only worked on
        # Windows and returned the whole path on Linux.
        category = os.path.basename(os.path.normpath(path))
        for file in sorted(files):
            list_file.append(os.path.join(path, file))
            list_cate.append(category)
    return list_file, list_cate
    
#%%    

In [None]:
class Categories_101(Dataset):
    """Caltech-101 image dataset yielding ``(image, category, path)`` triples.

    File paths and category names are gathered by ``get101CategoriesData()``.
    """

    def __init__(self, transform=None):
        """Index all images.

        Args:
            transform: optional callable applied to each RGB PIL image
                (e.g. a torchvision ``Compose`` pipeline).
        """
        self.input_images, self.input_categories = get101CategoriesData()
        self.transform = transform

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.input_images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, and return
        ``(image, category, path)``."""
        path = self.input_images[idx]
        # The context manager closes the file handle as soon as the RGB copy
        # exists, instead of leaving it open until garbage collection.
        with Image.open(path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self.input_categories[idx], path
    
# Per-channel mean/std normalisation applied after ToTensor.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
# Fixed 224x224 resize -> tensor -> normalise.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])

categ_dataset = Categories_101(transform=transform)

In [None]:
# Pretrained ResNet-50 used purely as a frozen feature extractor.
model = models.resnet50(pretrained=True, progress=True)
# Replace the classification head with a pass-through, so forward() returns
# the features the fc layer would have received.
model.fc = nn.Identity()
# Freeze all weights; they are only used for inference in this notebook.
model.requires_grad_(False)

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

model = model.to(device)

In [None]:
# Keep dataset order (no shuffling) so feature row i corresponds to
# categ_dataset[i] when results are looked up by index.
batch_size = 32
img_loader = DataLoader(
    categ_dataset,
    batch_size=batch_size,
    shuffle=False,
)
print(len(img_loader.dataset))

In [None]:
# Embed every dataset image with the frozen backbone. The loader does not
# shuffle, so row i of all_stack_matrix belongs to categ_dataset[i].
all_feats = []
total = 0
num_images = len(img_loader.dataset)
model.eval()  # evaluation mode, e.g. BatchNorm uses its running statistics

with torch.no_grad():  # no autograd graph is needed for feature extraction
    for inputs, _labels, _paths in img_loader:
        preds = model(inputs.to(device))
        all_feats.append(preds.cpu().numpy())

        # Count the images actually in this batch (the last one may be short)
        # and report progress against the real dataset size.
        total += inputs.size(0)
        print(100 * (total / num_images))

# Stack per-batch feature arrays into one matrix, one row per image.
all_stack_matrix = np.vstack(all_feats)

In [None]:
ls

In [None]:
# Absolute destination: an earlier cell already cd'ed into WorkData, so the
# relative path "WorkData/" would point at /content/WorkData/WorkData.
!cp /content/drive/My\ Drive/MachineLearning/datas/256_ObjectCategories.zip /content/WorkData/

In [None]:
cd ..

In [None]:
cd WorkData/

In [None]:
# Extract quietly (-q), overwriting on re-runs (-o), into 256_objectcategories/,
# the directory the query cell reads from. A "> name" redirect would only dump
# unzip's log into a file, and that file would block a same-named folder.
!unzip -o -q /content/WorkData/256_ObjectCategories.zip -d /content/WorkData/256_objectcategories

In [None]:
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# k-NN index over the dataset feature matrix (Euclidean distance, ball tree).
n_neighbors = 10
neighbors = NearestNeighbors(
    n_neighbors=n_neighbors, algorithm='ball_tree', metric='euclidean'
)
neighbors.fit(all_stack_matrix)

In [None]:
# Caltech-256 query category; uncomment an alternative to try another one.
#filename = '230.trilobite-101'
#filename = '012.binoculars'
#filename = '007.bat'
filename = '001.ak47'
#filename = '029.cannon'
#filename = '038.chimp'

# "<category>/<3-digit prefix>_0001.jpg": the image numbered 0001 in that category.
filename = filename + '/' + filename[0:3] + '_0001.jpg'
img_path = '/content/WorkData/256_objectcategories/256_ObjectCategories/' + filename

# Load the query as RGB and close the file handle right away, then apply the
# same preprocessing as the indexed images.
with Image.open(img_path) as query_img:
    img = query_img.convert('RGB')
pilImgT = transform(img)

# Embed the single query image (batch of one); no autograd graph is needed.
with torch.no_grad():
    test_features = model(pilImgT.unsqueeze(0).to(device))
test_features_np = test_features.cpu().numpy()

# indices[0] lists the dataset positions of the nearest neighbours.
_, indices = neighbors.kneighbors(test_features_np)

# Plotting helper adapted from an external source (reference still to be added).
def similar_images(indices, n_cols=5):
    """Plot the dataset images at ``indices`` in a grid.

    Args:
        indices: iterable of integer positions into ``categ_dataset``, e.g. one
            row of the array returned by ``neighbors.kneighbors``.
        n_cols: images per grid row; 10 indices with the default give a
            2x5 grid. Any number of indices is supported.
    """
    indices = list(indices)
    n_rows = max(1, (len(indices) + n_cols - 1) // n_cols)
    plt.figure(figsize=(15, 15), facecolor='white')
    for plotnumber, index in enumerate(indices, start=1):
        plt.subplot(n_rows, n_cols, plotnumber)
        # Only the file path is needed. Reading it directly avoids decoding,
        # resizing and normalising the image through the dataset transform.
        img_file = categ_dataset.input_images[index]
        plt.imshow(mpimg.imread(img_file), interpolation='lanczos')
    plt.tight_layout()

print(indices.shape)

# Label the query as "<category>/<file stem>". Splitting the full path on '.'
# would cut at the dot inside the category name (e.g. "001.ak47").
query_dir, query_file = os.path.split(img_path)
query_label = os.path.basename(query_dir) + '/' + os.path.splitext(query_file)[0]

plt.imshow(mpimg.imread(img_path), interpolation='lanczos')
plt.xlabel(query_label + '_Original Image', fontsize=20)
plt.show()
print('********* Predictions ***********')
# Row 0 holds the neighbours of the single query image.
similar_images(indices[0])