<a href="https://colab.research.google.com/github/mhtefe/deepLearning/blob/master/imageRetrieval/7_4_ImageRetrieval_Satellite_Resnet50_distance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Copy the NWPU-RESISC45 archive from Google Drive to the local Colab disk and extract it

In [None]:
!mkdir WorkData
!cp /content/drive/My\ Drive/MachineLearning/datas/NWPU-RESISC45.rar WorkData/ 

In [None]:
cd WorkData/

In [None]:
!mkdir NWPU

In [None]:
!unrar x "NWPU-RESISC45.rar" "NWPU/"

In [None]:
ls

In [None]:
import os

import numpy as np

import torch
import torch.nn as nn
import torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

from PIL import Image

nwpuRoot = '/content/WorkData/NWPU/NWPU-RESISC45/'

def getNWPUData(root=nwpuRoot):
    """Collect every file below ``root`` together with its category label.

    The category of a file is the name of the directory that directly contains
    it. Files located directly in ``root`` itself are skipped.

    Args:
        root: dataset root directory; defaults to ``nwpuRoot``.

    Returns:
        (list_file, list_cate): two parallel lists with the full path of each
        file and its category name. Directories and files are visited in
        sorted order, so the position of a file (and hence of its feature row
        later on) is reproducible across runs and machines.
    """
    list_file = []
    list_cate = []
    root_norm = os.path.normpath(root)
    for path, dirs, files in os.walk(root):
        # Sorting dirnames in place makes os.walk descend in a stable order.
        dirs.sort()
        if os.path.normpath(path) == root_norm:
            continue  # the root only holds the class folders
        # basename respects the host separator; splitting on '\\' only worked on
        # Windows and returned the whole path as the "category" on Linux.
        category = os.path.basename(os.path.normpath(path))
        for file in sorted(files):
            list_file.append(os.path.join(path, file))
            list_cate.append(category)
    return list_file, list_cate

# Quick check:
# lst, ctg = getNWPUData()
# print(len(lst), sorted(set(ctg)))

In [None]:
class Categories_NWPU(Dataset):
    """Map-style dataset over the files listed by ``getNWPUData()``.

    Each item is ``(image, category, path)``:
      * ``image``    - the file decoded as an RGB PIL image, passed through
                       ``transform`` when one was given;
      * ``category`` - the label string recorded for that file;
      * ``path``     - the file path the image was read from.

    The file list is built once in ``__init__``, so ``idx`` always refers to the
    same file for the lifetime of the object.
    """

    def __init__(self, transform=None):
        self.input_images, self.input_categories = getNWPUData()
        self.transform = transform

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        path = self.input_images[idx]
        # Open in a context manager so the file handle is released as soon as the
        # image is decoded; convert() returns an independent, loaded copy.
        with Image.open(path) as raw:
            image = raw.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, self.input_categories[idx], path
    
# Preprocessing for the ImageNet-pretrained backbone: resize to a fixed
# 224x224, convert the PIL image to a tensor, then normalise each channel
# with the ImageNet mean / std.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])

categ_dataset = Categories_NWPU(transform=transform)

In [None]:
# ResNet-50 with ImageNet weights used as a frozen feature extractor: the
# classification head `fc` is replaced by an identity, so the forward pass
# returns the vector `fc` would have consumed instead of class scores.
# NOTE(review): `pretrained=` is deprecated in newer torchvision in favour of `weights=`.
model = models.resnet50(pretrained=True, progress=True)
model.fc = nn.Identity()
model.requires_grad_(False)  # freeze every parameter; nothing is trained here

# First GPU when available, CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

model = model.to(device)

In [None]:
batch_size = 32

# shuffle must stay False: the feature rows stacked later are matched back to
# dataset indices, so the loader has to walk the dataset in order.
img_loader = torch.utils.data.DataLoader(
    categ_dataset,
    shuffle=False,
    batch_size=batch_size,
)
print(len(categ_dataset))

In [None]:
# One feature row per NWPU image. Row i of all_stack_matrix belongs to
# categ_dataset[i] because img_loader iterates the dataset in order.
all_feats = []
total = 0
totalData = len(img_loader.dataset)
last_pct = -1
model.eval()  # inference mode (BatchNorm uses running statistics)

with torch.no_grad():  # no autograd graph is needed just to read features
    for inputs, _labels, _paths in img_loader:
        inputs = inputs.to(device)
        preds = model(inputs)
        all_feats.append(preds.cpu().numpy())

        # Count the images actually in this batch; the last batch can be smaller
        # than batch_size, and adding batch_size pushed `total` past totalData.
        total += inputs.shape[0]
        # Multiply before the integer division: 100*(total//totalData) stayed 0
        # until the final batch and then jumped straight to 100.
        pct = 100 * total // totalData
        if pct != last_pct:  # print each percentage once, not once per batch
            print(pct)
            last_pct = pct

all_stack_matrix = np.vstack(all_feats)

In [None]:
cd ..

In [None]:
!cp /content/drive/My\ Drive/MachineLearning/datas/UCMerced_LandUse.zip WorkData/ 

In [None]:
cd WorkData/

In [None]:
!unzip UCMerced_LandUse.zip > UCMerced

In [None]:
from scipy import spatial
import matplotlib.pyplot as plt
import matplotlib.image as ima

In [None]:
# Pick one UC Merced image as the query; the commented-out lines are alternative queries.
#filename = 'airplane/airplane00.tif'
#filename = 'storagetanks/storagetanks00.tif' # this one is interesting, gives round shapes as output
#filename = 'beach/beach00.tif' #this one looks cool
#filename = 'baseballdiamond/baseballdiamond03.tif' # this can be a lil bit challenging
filename = 'overpass/overpass03.tif' # more challenging
#filename = 'freeway/freeway03.tif'
#filename = 'agricultural/agricultural08.tif'
#filename = 'harbor/harbor08.tif'
#filename = 'mediumresidential/mediumresidential08.tif'
#filename = 'river/river01.tif'

img_path = '/content/WorkData/UCMerced_LandUse/Images/' + filename
top_k = 5  # number of nearest NWPU images shown next to the query

img = Image.open(img_path).convert('RGB')
pilImgT = transform(img)

model.eval()
with torch.no_grad():
    # The model returns a batch of one row; keep row 0 so the query is a 1-D vector.
    test_features_np = model(pilImgT.unsqueeze(0).to(device)).cpu().numpy()[0]

# Cosine distance from every stored feature row to the query in one vectorised call.
# scipy.spatial.distance.cosine only accepts 1-D vectors, so the previous per-row
# loop with a 2-D query raises on current SciPy (and was slow in pure Python).
distances = spatial.distance.cdist(
    all_stack_matrix, test_features_np[np.newaxis, :], 'cosine'
).ravel()
indices = distances.argsort()[:top_k]


def _frame(ax, color, width=10):
    """Outline ``ax`` with a thick coloured border (via its spines) and hide ticks."""
    for spine in ax.spines.values():
        spine.set_edgecolor(color)
        spine.set_linewidth(width)
    ax.set_xticks([])
    ax.set_yticks([])


# squeeze=False keeps `axes` indexable even when top_k == 0.
fig, axes = plt.subplots(nrows=1, ncols=top_k + 1, figsize=(25, 25), squeeze=False)
axes = axes[0]

# First panel: the query image (green frame).
_frame(axes[0], 'green')
axes[0].imshow(ima.imread(img_path), interpolation='lanczos')
axes[0].set_title('query')

# Remaining panels: nearest NWPU images, closest first (blue frames).
for ax, index in zip(axes[1:], indices):
    # Only the path is needed; indexing categ_dataset would decode and transform the image.
    match_path = categ_dataset.input_images[index]
    _frame(ax, 'blue')
    ax.imshow(ima.imread(match_path), interpolation='lanczos')
    ax.set_title(f'{os.path.basename(os.path.dirname(match_path))}\n{distances[index]:.3f}')

plt.show()