# IMPORTS

In [1]:
import os
import glob

import numpy as np

from tqdm.auto import tqdm

import skimage
from skimage import io

import matplotlib.pyplot as plt

from annoy import AnnoyIndex # [1]

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models

import albumentations as A # [2]
from albumentations.pytorch.transforms import ToTensorV2

# CLASS DEFINITIONS

In [2]:
class Helper():
    '''Groups small plotting and image-reading utilities

    Both utilities are static methods, so they can be called on the class
    (`Helper.read_img(...)`) as well as on an instance.'''
    def __init__(self):
        pass
    @staticmethod
    def sample_plot(x, figsize = (25, 25), num = 16):
        '''Plots the first samples of a batch on a grid
        x: -> Batch tensor; it is permuted with (0, 2, 3, 1) before plotting, so a 4-D channels-first layout is assumed
        figsize: -> Tuple, telling the size of the figure
        num: -> Maximum number of samples that are to be plotted'''
        plt.figure(figsize = figsize)
        x = x.detach().permute(0, 2, 3, 1).cpu().numpy()
        # Never index past the end of the batch when it holds fewer than `num` samples.
        count = min(num, len(x))
        # Keep the usual 5 x 5 layout and only grow it when more cells are needed.
        side = 5
        while side * side < count:
            side += 1
        for i in range(count):
            plt.subplot(side, side, i + 1, xticks = [], yticks = [])
            plt.imshow(x[i])
    @staticmethod
    def read_img(addr):
        '''Takes URI as an input and returns an image
        addr: -> A URI string that is the address for the input image
        Returns whatever `io.imread` yields for that address'''
        return io.imread(addr)

In [3]:
class get_data(Dataset):
    '''Dataset that reads image files and returns them as preprocessed tensors

    Each image is resized to 224 x 224, normalized (mean 0, std 1, max pixel
    value 255), converted to a tensor and moved to DEVICE.'''
    def __init__(self, files_list):
        '''files_list: -> Sequence of image file paths, one per sample'''
        super().__init__()
        self.files = files_list
        resize = A.Resize(224, 224)
        normalize = A.Normalize(mean = 0.0, std = 1.0, max_pixel_value = 255.0)
        self.tfms = A.Compose([resize, normalize, ToTensorV2()])
    def __len__(self):
        '''Number of files in the dataset'''
        return len(self.files)
    def __getitem__(self, idx):
        '''Reads the file at position idx and returns the transformed tensor on DEVICE'''
        image = io.imread(self.files[idx])
        tensor = self.tfms(image = image)['image']
        return tensor.to(DEVICE)

In [4]:
class Network(nn.Module):
    '''Frozen, pretrained ResNet-50 used as a feature extractor for 4-channel input

    The first convolution is replaced by a 4-channel one, the pooling and
    classification layers are replaced by identities, all parameters are
    frozen and the module is put in evaluation mode.'''
    def __init__(self):
        super(Network, self).__init__()
        self.model = models.resnet50(pretrained = True)
        rgb_weight = self.model.conv1.weight.clone()
        # Altering the model to accommodate 4-channel input
        self.model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            # The first three input channels reuse the pretrained filters ...
            self.model.conv1.weight[:, :3] = rgb_weight
            # ... and the fourth channel is initialised from the first channel's filters.
            self.model.conv1.weight[:, 3] = rgb_weight[:, 0]
        self.model.avgpool = nn.Identity()
        self.model.fc = nn.Identity()
        for param in self.model.parameters():
            param.requires_grad = False
        # The network is only used for inference. Without eval mode the
        # BatchNorm layers would normalise with per-batch statistics and keep
        # updating their running statistics on every forward pass.
        self.eval()
    def forward(self, x):
        '''x: -> Batch of 4-channel images
        Returns the output of the modified backbone for x'''
        return self.model(x)

In [5]:
class solution(nn.Module):
    '''Looks up and plots the database images most similar to a query image'''
    def __init__(self, inp, model, ann_idx, files_list, num):
        '''Embeds the query image and retrieves its nearest neighbours
        inp: -> Query image array (e.g. as returned by Helper.read_img)
        model: -> Network used to embed the query
        ann_idx: -> Annoy index holding the database embeddings; it is only queried, never modified
        files_list: -> Image paths, where position i is the file of index item i
        num: -> Number of neighbours to retrieve'''
        super(solution, self).__init__()
        self.num = num
        self.files_list = files_list
        tfms = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean = 0.0, std = 1.0, max_pixel_value = 255.0),
            ToTensorV2()
        ])
        self.q = inp
        # Run on whatever device the model lives on instead of assuming CUDA.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        batch = tfms(image = inp)['image'].unsqueeze(0).to(device)
        with torch.no_grad():
            emb = model(batch).cpu().numpy().flatten()
        # Query by vector: adding the query as a new item to an index that has
        # already been built would modify the index just to perform a lookup.
        self.similar = ann_idx.get_nns_by_vector(emb, self.num)
    def plot_sim(self):
        '''Plots similar images to the input'''
        plt.figure(figsize = (10, 10))
        plt.subplot()
        plt.xticks([])
        plt.yticks([])
        plt.imshow(self.q)
        title = plt.title('Query Image')
        plt.setp(title, color = 'green')
        plt.figure(figsize = (25, 25))
        # The first neighbour is skipped (presumably the query image itself when
        # it is part of the database); slicing also copes with fewer results
        # than requested.
        for i, item in enumerate(self.similar[1:self.num]):
            plt.subplot(5, 5, i + 1, xticks = [], yticks = [])
            plt.imshow(io.imread(self.files_list[item]))
            title = plt.title('Similar Garment')
            plt.setp(title, color = 'green')

# DEVICE SETUP

In [6]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
if DEVICE != "cpu":
    # Release cached GPU memory before any work starts.
    torch.cuda.empty_cache()

# DATASET LOADING AND VISUALIZATION

In [7]:
# glob returns matches in arbitrary order; sorting keeps the item ids (positions
# in files_list) reproducible between runs.
files_list = sorted(glob.glob('../input/flixstock/bottoms_resized_png/*.png'))
len_db = len(files_list)
if not files_list:
    raise FileNotFoundError("No PNG images found in '../input/flixstock/bottoms_resized_png'")
num = 10  # number of similar images to show per query

DB = get_data(files_list)
# A single batch holds the whole database; its size is taken from the globbed
# file list so that it always agrees with the dataset length.
db = DataLoader(DB, shuffle = False, batch_size = len_db)

x = next(iter(db))

Helper.sample_plot(x)

# MODEL DEFINITION AND FEATURE MAP GENERATION

In [8]:
model = Network().to(DEVICE)
# Embeddings are computed for inference only: evaluation mode makes the
# normalisation layers use their stored statistics instead of per-batch ones.
model.eval()

feature_vec = []

# One flattened embedding per database image, in the same order as files_list.
with torch.no_grad():
    for i in tqdm(range(len_db)):
        fv = model(torch.unsqueeze(x[i], axis = 0)).detach().cpu().numpy().flatten()
        feature_vec.append(fv)

# INDEX CREATION

In [9]:
# The length of the first embedding fixes the dimensionality of the index.
feature_len = len(feature_vec[0])
ann_idx = AnnoyIndex(feature_len, 'angular')

# Item ids are the positions of the embeddings in feature_vec.
for item_id in tqdm(range(len_db)):
    embedding = feature_vec[item_id]
    ann_idx.add_item(item_id, embedding)

# Build the index with 10 trees.
ann_idx.build(10)

# OUTPUT

In [14]:
# Use one of the database images as the query.
query_path = files_list[150]
q_img = Helper.read_img(query_path)
# One extra neighbour is requested because plot_sim skips the first result.
sol = solution(inp = q_img, model = model, ann_idx = ann_idx, files_list = files_list, num = num + 1)
sol.plot_sim()

# REFERENCES

[1] Erik Bernhardsson. (2018). Annoy: Approximate Nearest Neighbors in C++/Python.

[2] A. Buslaev, A. Parinov, E. Khvedchenya, V. I. Iglovikov, & A. A. Kalinin (2018). Albumentations: fast and flexible image augmentations. ArXiv e-prints.