# Image Retrieval using Learnt Similarity ResNet-18

## Imports

In [1]:
# Install the third-party packages this notebook imports into the environment
# of the running kernel (%pip is an IPython magic, not plain Python).
%pip install numpy torch torchvision matplotlib tqdm

[0m

In [2]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid

import matplotlib.pyplot as plt

import tqdm

## Dataset

In [4]:
#unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())

def denorm(x, channels=None, w=None, h=None, resize=False):
    """Undo the input normalisation of ``x`` and optionally reshape it.

    Args:
        x: Normalised tensor. When ``resize`` is set, its first dimension is
            kept as the leading (batch) dimension of the result.
        channels: Per-sample channel count; required when ``resize`` is True.
        w: Third dimension of the reshaped result; required when ``resize``.
        h: Fourth dimension of the reshaped result; required when ``resize``.
        resize: If True, view the result as ``(x.size(0), channels, w, h)``.

    Returns:
        The un-normalised (and possibly reshaped) tensor.

    Raises:
        ValueError: If ``resize`` is True but ``channels``, ``w`` or ``h`` is
            missing.
    """
    # Validate up front: a missing dimension would otherwise only surface
    # later as an opaque error from ``view``.
    if resize and (channels is None or w is None or h is None):
        raise ValueError(
            "Number of channels, width and height must be provided for resize."
        )

    # ``unnormalize`` is resolved as a module-level global at call time, so it
    # must be defined before denorm is called.
    x = unnormalize(x)

    if resize:
        # NOTE(review): torch image tensors are usually laid out (C, H, W);
        # this views them as (channels, w, h) — confirm the intended order.
        x = x.view(x.size(0), channels, w, h)

    return x

def show(img):
    """Plot a channel-first image tensor on the current matplotlib axes.

    The tensor is copied to host memory and converted to a NumPy array. Its
    axes are then reordered from (C, H, W) to (H, W, C), the layout
    ``plt.imshow`` expects.
    """
    host_array = img.cpu().numpy()
    channels_last = host_array.transpose(1, 2, 0)
    plt.imshow(channels_last)

# Per-channel statistics for the three input channels.
# NOTE(review): standard ImageNet mean/std used as a placeholder — confirm,
# or replace with statistics computed from the ShapeNet training split.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# Inverse of the normalisation below: maps (x - mean) / std back to x,
# since ((x - mean) / std - (-mean / std)) / (1 / std) == x.
unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())

# Resize(224) scales the shorter image side to 224 pixels, ToTensor converts
# to a float tensor in [0, 1], and Normalize standardises each channel.
# Normalize requires explicit mean and std; calling it without them raises.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean.tolist(), std.tolist()),
])

In [None]:
# ImageFolder treats each sub-directory of the root as one class and applies
# ``transform`` to every loaded image.
shapenet_dataset = datasets.ImageFolder(
    root="ShapeNetRendering/",
    transform=transform,
)

### Shuffling and Splitting Dataset

In [22]:
# Fractions of the dataset assigned to the train, test and query splits.
ratios = (0.79, 0.2, 0.01)

num_images = len(shapenet_dataset)

approximate_sizes = tuple(round(num_images * ratio, 1) for ratio in ratios)
print(f"Approximate dataset sizes: {approximate_sizes}")

# Derive integer split sizes from the ratios instead of hard-coding them, so
# the split keeps working if the dataset changes. Test and query are rounded
# to the nearest integer, and train absorbs the remainder. This guarantees the
# sizes sum to len(shapenet_dataset) exactly, as random_split requires.
# For the 135,912-image dataset this yields (107371, 27182, 1359).
test_size = round(num_images * ratios[1])
query_size = round(num_images * ratios[2])
sizes = (num_images - test_size - query_size, test_size, query_size)

# A seeded generator makes the split (and hence the query set) reproducible
# across runs.
split_generator = torch.Generator().manual_seed(0)
[train, test, query] = random_split(shapenet_dataset, sizes, generator=split_generator)

print(f"Dataset Lengths:\n\tTrain: {len(train)},\n\tTest: {len(test)},\n\tQuery: {len(query)}")

Approximate dataset sizes: (107370.5, 27182.4, 1359.1)
Dataset Lengths:
	Train: 107371,
	Test: 27182,
	Train 1359


## Model

### Residual Block

In [None]:
class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 convolutions plus a skip connection.

    Args:
        channels: ``(in_channels, out_channels)`` pair.
        stride: Stride of the first 3x3 convolution and of the projection
            shortcut (default 1).

    When the stride is not 1 or the channel count changes, the shortcut is a
    1x1 convolution followed by batch norm, so its output matches the main
    branch. Otherwise the input is added back unchanged.
    """

    def __init__(self, channels, stride=1):
        super().__init__()

        in_ch, out_ch = channels
        self.in_channel = in_ch
        self.out_channel = out_ch

        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        main_branch = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
        ]
        self.left = nn.Sequential(*main_branch)

        needs_projection = in_ch != out_ch or stride != 1
        if not needs_projection:
            # An empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()
            return

        self.shortcut = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(out_ch),
        )

    def forward(self, x):
        main = self.left(x)
        skip = self.shortcut(x)
        return F.relu(main + skip)

### Modified ResNet-18

In [None]:
class ModifiedResNet18(nn.Module):
    """ResNet-18-style network producing one sigmoid score per sample.

    The stem convolution takes 6 input channels. Presumably these are two RGB
    images concatenated along the channel axis — TODO confirm against the data
    pipeline. ``forward`` returns a tensor of shape ``(batch, 1)`` with values
    in (0, 1).
    """

    def __init__(self):
        # nn.Module.__init__ must be called without extra positional arguments.
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        # As in the ResNet-18 architecture, the 3x3 stride-2 max-pool opens
        # conv2_x, so its residual blocks run at the pooled resolution.
        self.conv2_x = nn.Sequential(
            nn.MaxPool2d(3, stride=2, padding=1),
            self.add_residual_layer((64, 64), 2, stride=1),
        )

        self.conv3_x = self.add_residual_layer((64, 128), 2, stride=2)
        self.conv4_x = self.add_residual_layer((128, 256), 2, stride=2)
        self.conv5_x = self.add_residual_layer((256, 512), 2, stride=2)

        # Adaptive pooling collapses whatever spatial size remains to 1x1.
        # For 224x224 inputs (7x7 here) it matches AvgPool2d(7), and it also
        # accepts other resolutions.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, 1)

    def add_residual_layer(self, channels, num_blocks, stride):
        """Build one stage of ``num_blocks`` residual blocks.

        Args:
            channels: ``(channels_in, channels_out)`` for the stage.
            num_blocks: Number of ResidualBlocks in the stage.
            stride: Stride applied by the first block; later blocks use 1.

        Returns:
            An ``nn.Sequential`` of the blocks. Only the first block changes
            the channel count.
        """
        channels_in, channels_out = channels

        layers = []
        for i in range(num_blocks):
            block_in = channels_in if i == 0 else channels_out
            block_stride = stride if i == 0 else 1
            # ResidualBlock expects an (in, out) channel pair, then the stride.
            layers.append(ResidualBlock((block_in, channels_out), block_stride))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        x = self.conv2_x(x)
        x = self.conv3_x(x)
        x = self.conv4_x(x)
        x = self.conv5_x(x)

        x = self.avgpool(x)

        # (batch, 512, 1, 1) -> (batch, 512)
        x = torch.flatten(x, 1)

        x = self.fc(x)

        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(x)

## Search Algorithm

## Training

### Enable CUDA

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device: " + device)

### Hyperparameters

In [None]:
# Core training hyperparameters: epoch count, Adam learning rate, mini-batch size.
num_epochs, learning_rate, batch_size = 20, 1e-3, 128

# Additional hyperparameters:
#   beta         -- weight on the divergence term in vae_loss
#   image_size   -- side length vae_loss uses when reshaping images
#   weight_decay -- Adam weight decay (L2 penalty)
# NOTE(review): the input transform resizes images to 224, not 137 — confirm
# which value image_size is meant to reflect.
beta, image_size, weight_decay = 1, 137, 1e-5

### Loss Function

In [None]:
def vae_loss(recon_x, x, mu, logvar, beta):
    """Return the beta-weighted VAE objective for one batch.

    The objective is the summed squared reconstruction error plus ``beta``
    times the KL divergence between N(mu, exp(logvar)) and a standard normal.

    Args:
        recon_x: Reconstructed batch. It must contain the same number of
            elements as ``x``, but its shape may differ (e.g. flattened).
        x: Original input batch.
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as ``mu``.
        beta: Weight on the KL term.

    Returns:
        A scalar tensor equal to sum-reduced MSE + beta * KL.

    NOTE(review): ModifiedResNet18 in this notebook returns one sigmoid score
    per sample, not ``(reconstruction, mu, logvar)``. This loss looks like a
    leftover from a VAE template — confirm it is still wanted.
    """
    # With reduction="sum" the result does not depend on how elements are
    # grouped into rows. Flattening both tensors fully therefore gives the
    # same value as reshaping to (-1, image_size ** 2). Unlike that reshape,
    # it works for any channel count or resolution.
    reconstruction_error = F.mse_loss(recon_x.reshape(-1), x.reshape(-1), reduction="sum")

    # KL(N(mu, sigma^2) || N(0, 1)), summed over every latent dimension.
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return reconstruction_error + beta * kl_divergence

### Initialisation

In [None]:
# Shuffled loader over the training split; ordered loader over the held-out
# test split (not the training split).
loader_train = DataLoader(train, batch_size=batch_size, shuffle=True)
loader_test = DataLoader(test, batch_size=batch_size, shuffle=False)

# NOTE(review): neither VAEModel nor latent_dim is defined anywhere visible in
# this notebook. The training loop also expects the model to return
# (reconstruction, mu, logvar), whereas the architecture defined above is
# ModifiedResNet18 — confirm which model is intended.
# The model is moved to the same device the training loop sends each batch to.
model = VAEModel(latent_dim).to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate,
                             weight_decay=weight_decay)

### Sample Data

In [None]:
# Take the first batch of the test loader for a visual sanity check.
samples, _ = next(iter(loader_test))

# ``shape`` is a torch.Size attribute, not a method.
print(f"Dimensions of a batch: {samples.shape}")

samples = samples.cpu()
# The inputs were mean/std-normalised by the transform, so rescale the grid to
# [0, 1] for display (normalize=True). The ``range`` keyword is omitted:
# current torchvision no longer accepts it, and its value (None) was the
# default anyway.
samples = make_grid(samples, nrow=8, padding=2, normalize=True,
                    scale_each=False, pad_value=0)

plt.figure(figsize=(15, 15))
plt.axis('off')

show(samples)

### Training Loop

In [None]:
# Switch the model's layers to training-mode behaviour.
model.train()

# Mean per-batch loss of each epoch, in epoch order.
epoch_losses = []

for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = 0

    with tqdm.tqdm(loader_train, unit="batch") as tepoch:
        for batch_idx, (data, _) in enumerate(tepoch):
            data = data.to(device)
            optimiser.zero_grad()

            # Forward pass and loss.
            reconstruction, mu, logvar = model(data)
            loss = vae_loss(reconstruction, data, mu, logvar, beta)

            # Backward pass and parameter update.
            loss.backward()
            optimiser.step()

            epoch_loss += loss.item()
            num_batches += 1

            # Every 20 batches, refresh the progress bar with the current
            # batch's loss divided by its size.
            if not batch_idx % 20:
                tepoch.set_description(f"Epoch {epoch}")
                tepoch.set_postfix(loss=loss.item() / len(data))

    epoch_losses.append(epoch_loss / num_batches)

## Evaluation

### t-SNE Plot