In [2]:
#IMPORTING LIBRARIES

from IPython.display import clear_output
from PIL import Image 
import matplotlib.pylab as plt
import numpy as np
import os
import torch
from torch import dist
from torch import nn
from torch.nn import functional as f
from torch import optim
from torchvision.transforms import *
from tqdm import tqdm

from scraper import *

# Anomaly detection makes autograd report the forward operation behind a
# failing backward pass. It is a debugging aid and slows execution down.
torch.autograd.set_detect_anomaly(True)

# `is_available` has to be *called*: the bare function object is always
# truthy, which selected "cuda:0" even on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
"""The get_images() function in scraper.py fetches 'N' examples from the 
web. This dataset is part of the one used by the original paper. It was obtained from: 
https://sites.google.com/site/imagesimilaritydata/
The text file obtained from the above website contains URLs that the images can be 
retrieved from.
Please set 'N' to your desired value, and uncomment the last line in order to
retrieve the images from the web."""

# Number of examples requested from get_images() in the (commented-out) call below.
N = 5000

# get_images('./urls.txt', N, './images')

In [4]:
"""This function takes an image file, opens it as a PIL Image, resizes it to the 
   desired dimensions and finally converts it, first into a numpy array, 
   then a torch tensor."""

def tensor_from_image(path, dims):
    """Load an image file and return it as a float tensor with channels first.

    Args:
        path: Location of the image file, passed to ``Image.open``.
        dims: Target size handed to ``Resize``.

    Returns:
        A ``torch.float32`` tensor laid out as (channels, height, width), or
        ``None`` when the decoded image is not a 3-dimensional array (for
        example a single-channel image), so that callers can skip it.
    """
    # The context manager closes the underlying file handle; leaving it open
    # leaks descriptors when thousands of images are loaded in a row.
    with Image.open(path) as image:
        resized = Resize(dims)(image)
        # np.array copies the pixels, so the data outlives the closed file and
        # torch.from_numpy is not handed a read-only buffer.
        pixels = np.array(resized)

    if pixels.ndim != 3:
        return None

    # NOTE(review): an image with an alpha channel still passes the check
    # above and yields 4 channels - confirm the dataset never contains one.

    # (H, W, C) -> (C, H, W)
    channels_first = np.transpose(pixels, (2, 0, 1))
    return torch.from_numpy(channels_first).float()

In [5]:
"""This class defines the image dataset we will be using for training.
   The dataset consists of triplets (query, positive, negative)
   where the 'positive' image is more similar to the 'query' image 
   than the negative image is.
   Each sample in this dataset is a dictionary containing these three tensors."""

class ImageDataset:
    """In-memory collection of (query, positive, negative) image triplets.

    Every sample is a dict with the keys 'name', 'query', 'pos' and 'neg';
    the three image entries hold whatever tensor_from_image returned for the
    corresponding file.
    """

    # Sample key -> file name expected inside each triplet directory.
    _FILES = {'query': 'q.png', 'pos': 'p.png', 'neg': 'n.png'}
    _CATEGORIES = ('query', 'pos', 'neg')

    def __init__(self, source_dir, dims, range=(0, 500)):
        """Load the triplets stored under ``source_dir``.

        Args:
            source_dir: Directory containing one sub-directory per triplet,
                each holding 'q.png', 'p.png' and 'n.png'.
            dims: Size every image is resized to.
            range: (start, stop) slice applied to the sorted directory
                listing to choose which entries are loaded.
        """
        # os.listdir returns entries in arbitrary order; sorting makes the
        # selected slice reproducible across runs and machines.
        entries = sorted(os.listdir(source_dir))[range[0]:range[1]]
        samples = []

        for entry in entries:
            entry_dir = os.path.join(source_dir, entry)
            # Stray files (e.g. '.DS_Store') are not triplets; skip them
            # instead of failing while trying to open images inside them.
            if not os.path.isdir(entry_dir):
                continue

            sample = {'name': entry}
            for category in self._CATEGORIES:
                image_path = os.path.join(entry_dir, self._FILES[category])
                sample[category] = tensor_from_image(image_path, dims)

            # tensor_from_image returns None for images it rejects; a triplet
            # is only usable when all three images were accepted.
            if any(sample[category] is None for category in self._CATEGORIES):
                continue

            samples.append(sample)

        self.samples = samples
        self.size = len(samples)
        self.dims = dims

    def __len__(self):
        """Return the number of loaded triplets."""
        return len(self.samples)

    def __iter__(self):
        """Yield each sample with a leading batch axis of size 1 on its images.

        A fresh dict is yielded every time. The stored samples are left
        untouched, so the dataset can be iterated repeatedly and
        ``__getitem__`` / ``display`` keep seeing unbatched tensors.
        """
        for sample in self.samples:
            batched = dict(sample)
            for category in self._CATEGORIES:
                batched[category] = sample[category].unsqueeze(0)
            yield batched

    def __getitem__(self, idx):
        """Return the stored (unbatched) sample(s) at ``idx``."""
        return self.samples[idx]

    def display(self, idx):
        """Plot the query, positive and negative images of sample ``idx``."""
        _, axes = plt.subplots(1, 3, figsize=(15, 15))
        sample = self.samples[idx]
        for axis, category in zip(axes, self._CATEGORIES):
            pixels = sample[category].cpu().numpy()
            # (C, H, W) -> (H, W, C); cast to integers so imshow treats the
            # values as 0-255 colour levels.
            pixels = np.int32(np.transpose(pixels, (1, 2, 0)))
            axis.imshow(pixels)
            axis.set_title(category)
            
            
"""Please uncomment the lines below to test a sample execution of this class
(Select the concerned lines and press Ctrl + '/')"""
# Clear any output this cell has produced so far.
clear_output()

# EXAMPLE USAGE

# dataset = ImageDataset('./images', (224, 224))
# print("Total: {}".format(dataset.size))
# dataset.display(0)
# sample = dataset[0]
# q, p, n = sample['query'], sample['pos'], sample['neg']
# print(q.size())

In [6]:
"""This function defines the triplet ranking loss as used in the original paper.
Concretely, given three images (query, positive, negative), it provides a measure of
the effectiveness (or more precisely, lack thereof) of the model in
identifying that the 'positive' is more closely related to the 'query' 
than the 'negative'
Mathematically, it is defined as:
L(q, p, n) = max(0, g + dist(q, p) - dist(q, n))
'g' here is a margin, whose value can be set depending on how large of a distinction
is desired as ideal between the 'positive' and the 'negative' images

The code for this function was borrowed from a PyTorch Discussion Forum"""

def triple_loss(q, p, n, g=0.2):
    """Triplet ranking hinge loss, averaged over the batch.

    Args:
        q: Embeddings of the query images.
        p: Embeddings of the positive images.
        n: Embeddings of the negative images.
        g: Margin by which the negative should be farther from the query
            than the positive before the loss reaches zero.

    Returns:
        Scalar tensor: mean of max(0, dist(q, p) - dist(q, n) + g), where
        dist is the pairwise 2-norm distance.
    """
    euclidean = nn.PairwiseDistance(p=2)
    to_positive = euclidean(q, p)
    to_negative = euclidean(q, n)

    # Positive only when the negative is not at least `g` farther away.
    violation = to_positive - to_negative + g
    hinge = torch.max(violation, torch.zeros_like(violation))
    return hinge.mean()

In [7]:
"""This function evaluates the model on a given dataset. The model is 
considered to have done its job on a particular sample correctly if
dist(query, positive) is less than dist(query, negative), 
subsequently implying that 'positive' is more similar to 'query'
than 'negative' is.

The distance computations are done after the image tensors have been passed 
through the model"""

def evaluate(model, samples):
    """Print how often the model ranks 'pos' closer to 'query' than 'neg'.

    A sample counts as correct when the distance between the embeddings of
    'query' and 'pos' is strictly smaller than the distance between the
    embeddings of 'query' and 'neg'.

    Args:
        model: Module that maps an image batch to embedding vectors.
        samples: Iterable of dicts holding 'query', 'pos' and 'neg' image
            tensors, either unbatched (3 dimensions) or already carrying a
            leading batch axis of size 1.
    """
    distance = nn.PairwiseDistance(p=2)
    total = 0
    correct = 0

    # Evaluate with Dropout disabled and without recording autograd graphs,
    # then put the model back into whatever mode the caller had it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for sample in tqdm(samples):
                embeddings = []
                for key in ('query', 'pos', 'neg'):
                    image = sample[key]
                    if image.dim() == 3:
                        # Add the batch axis the model expects.
                        image = image.unsqueeze(0)
                    embeddings.append(model(image.to(device)))
                q, p, n = embeddings

                if distance(q, p).item() < distance(q, n).item():
                    correct += 1
                total += 1
    finally:
        model.train(was_training)

    # An empty input must not crash with a division by zero.
    accuracy = correct / total if total else 0.0
    print("{} samples in total. Correctly identified: {}, accuracy: {:.3f}".format(total, correct, accuracy))

    

In [8]:

def train(model, train_set, test_set, optimizer, loss_function=triple_loss,
          n_epochs=1):
    """Train ``model`` on triplets one sample at a time and report progress.

    After every epoch the mean training loss is printed and the model is
    evaluated on both the training set and the test set.

    Args:
        model: Module that maps an image batch to embedding vectors.
        train_set: Iterable of dicts with 'query', 'pos' and 'neg' image
            tensors (unbatched, or with a leading batch axis of size 1).
        test_set: Samples in the same format, only used for evaluation.
        optimizer: Optimizer already bound to the model's parameters.
        loss_function: Callable ``(q, p, n) -> scalar loss`` applied to the
            three embeddings.
        n_epochs: Number of passes over ``train_set``.
    """
    for epoch in range(n_epochs):
        # Make sure layers such as Dropout are in training mode, whatever an
        # earlier evaluation left behind.
        model.train()
        running_loss = 0.0
        n_seen = 0

        for sample in tqdm(train_set):
            images = []
            for key in ('query', 'pos', 'neg'):
                image = sample[key].to(device)
                if image.dim() == 3:
                    # Add the batch axis the model expects.
                    image = image.unsqueeze(0)
                images.append(image)

            # Gradients accumulate across backward() calls unless cleared,
            # so reset them before every step.
            optimizer.zero_grad()

            q = model(images[0])
            p = model(images[1])
            n = model(images[2])

            # Use the loss supplied by the caller rather than a fixed one.
            loss = loss_function(q, p, n)
            loss.backward()
            optimizer.step()

            # .item() stores a plain float, so no graph-attached tensor is
            # kept alive by the running sum.
            running_loss += loss.item()
            n_seen += 1

        mean_loss = running_loss / n_seen if n_seen else 0.0
        print("Epoch {} completed. Loss: {}".format(epoch + 1, mean_loss))
        print("Evaluating training set:")
        evaluate(model, train_set)
        print("Evaluating test set:")
        evaluate(model, test_set)
        

In [9]:
"""This class defines the model we will be training and using for our ranking
task. The model consists of three smaller convolutional networks, whose outputs are
to be concatenated to produce a final feature vector. Details of the network can
be found in the following paper:
http://wangjiangb.github.io/pdfs/deep_ranking.pdf
"""


class network(nn.Module):
    """Three-branch convolutional embedding model for image ranking.

    One AlexNet-style branch and two shallow branches process the same input.
    The two shallow outputs are concatenated and L2-normalised, joined with
    the L2-normalised AlexNet output, projected by a linear layer to 4096
    features and L2-normalised once more.

    The projection layer takes 7168 input features, which is what the three
    branches produce together for 3 x 224 x 224 images.
    """

    @staticmethod
    def _shallow_branch(subsample_stride, late_conv, late_pool):
        """Build one shallow branch; only the three arguments differ between the two."""
        return nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=3, stride=subsample_stride, padding=2),
            nn.Conv2d(3, 96, kernel_size=8, padding=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=4),
            nn.Conv2d(96, 96, **late_conv),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=late_pool),
            nn.Flatten(),
        )

    def __init__(self):
        super().__init__()

        # AlexNet branch; the layer layout was borrowed from the PyTorch
        # documentation.
        convolutional_part = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.AdaptiveAvgPool2d((6, 6)),
        ]
        dense_part = [
            nn.Dropout(),
            nn.Flatten(),
            nn.Linear(9216, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
        ]
        self.alexnet = nn.Sequential(*convolutional_part, *dense_part)

        # The two shallow branches subsample the input with different strides.
        self.shallow1 = self._shallow_branch(
            4, dict(kernel_size=3, stride=2, padding=1), 2)
        self.shallow2 = self._shallow_branch(
            8, dict(kernel_size=7, padding=7), 4)

        # Projects the concatenated branch features to the final embedding.
        self.embedding = nn.Linear(7168, 4096)

    def forward(self, x):
        """Return one L2-normalised 4096-feature embedding per input image."""
        unit_length = nn.functional.normalize

        deep = unit_length(self.alexnet(x), p=2, dim=1)

        coarse = self.shallow1(x)
        coarser = self.shallow2(x)
        shallow = unit_length(torch.cat((coarse, coarser), 1), p=2, dim=1)

        combined = torch.cat((deep, shallow), 1)
        projected = self.embedding(combined)
        return unit_length(projected, p=2, dim=1)
        
#         return x
        