In [None]:
import torch
from torch import nn
import torchvision.transforms
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip, RandomGrayscale, ToTensor, ToPILImage

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from collections import OrderedDict
from math import ceil
import time
import random
import glob

# Model

In [None]:
from Models import ResnetHasher, CustomHasher

### Visualization Tools

In [None]:
# Visualize Data
class ShowTensor():
    """Display a single image tensor with matplotlib's ``imshow``.

    The tensor is expected channels-first, ``(C, H, W)``; it is transposed to
    ``(H, W, C)`` for ``imshow``. It is detached and copied to the CPU before the
    NumPy conversion, so tensors that require grad or live on an accelerator
    (e.g. model inputs after ``.to(device)``) can be shown too.
    """
    def __call__(self, tensor):
        # .numpy() refuses tensors that require grad or are not on the CPU.
        npimg = tensor.detach().cpu().numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
show_tensor = ShowTensor()
class ShowTriplet():
    """Plot an (anchor, positive, negative) triplet of image tensors side by side.

    ``__call__(triplet, axes=True)``:
        triplet: sequence of image tensors in (A, P, N) order; each is expected
            channels-first ``(C, H, W)`` and is transposed to ``(H, W, C)``.
            Tensors are detached and copied to the CPU first, so tensors that
            require grad or live on an accelerator can be shown.
        axes: when False, axis ticks and frames are hidden on all three panels.
    """
    def __call__(self, triplet, axes=True):
        # Convert every image before creating the figure, as the original did.
        npimgs = [tensor.detach().cpu().numpy() for tensor in triplet]
        fig, axs = plt.subplots(1, 3, constrained_layout=True, figsize=(12, 12))
        titles = ("Anchor (A)", "Positive (P)", "Negative (N)")
        for ax, npimg, title in zip(axs, npimgs, titles):
            ax.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
            ax.set_title(title)
            if not axes:
                ax.set_axis_off()
show_triplet = ShowTriplet()

In [None]:
# Visualize Model
from torchsummary import summary
def visualizeModel(model, input_size=(3, 64, 64), device=None):
    """Show a torchsummary layer summary of ``model``.

    Args:
        model: the ``nn.Module`` to summarize.
        input_size: shape of one input sample, without the batch dimension
            (defaults to the 3x64x64 images used in this notebook).
        device: 'cuda' or 'cpu' for the dummy input torchsummary builds. When None,
            it is inferred from the model's first parameter (CPU if it has none),
            so a CPU model is not fed CUDA tensors on a GPU machine.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = 'cuda' if first_param is not None and first_param.is_cuda else 'cpu'
    summary(model, input_size, device=device)
from Models import HasherBlock

#print([x for x in ResnetHasher().resnet.children()])
#ResnetHasher().visualize()

In [None]:
# Inspect a freshly constructed CustomHasher. visualize() comes from the Models module
# (not shown here); presumably it prints a layer/shape summary — confirm in Models.
CustomHasher().visualize()

### Image Augmentation/Transforms

In [None]:
class RandomR4:
    """Rotate an image by an angle picked uniformly at random from ``angles``.

    Args:
        angles: non-empty sequence of rotation angles in degrees,
            e.g. ``(0, 90, 180, 270)`` for the four right-angle rotations.
    """
    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        return TF.rotate(x, random.choice(self.angles))
class ToRGBTensor:
    """Convert a PIL image to RGB mode, then to a tensor with ``ToTensor``.

    Forcing RGB means every image yields three channels, whatever mode the
    file was stored in (e.g. grayscale or with an alpha channel).
    """
    def __init__(self):
        self.tt = ToTensor()

    def __call__(self, img):
        rgb_img = img.convert('RGB')
        return self.tt(rgb_img)

In [None]:
# Augmentation pipeline that turns each anchor image into its positive (P) view.
# Set `transform = False` to disable augmentation (an empty Compose returns its input unchanged).
transform = True
if transform:
    transforms = Compose([
        RandomCrop(size=(64, 64), padding=4),  # random 64x64 window after padding 4px on each side
        RandomHorizontalFlip(p=.5),
        RandomR4((0, 90, 180, 270)),           # one of the four right-angle rotations
        RandomGrayscale(p=0.3),
    ])
else:
    transforms = Compose([])

# Load Training Data

In [None]:
class TripletDataset(torch.utils.data.Dataset):
    """(anchor, positive, negative) image triplets built from a directory of PNG files.

    The PNG paths are sorted and split in half: the first half are anchors and the
    second half are negatives, paired by position. The positive for each anchor is
    ``transforms`` applied to the anchor image. With an odd number of files the last
    file is never used.

    Args:
        directory: folder searched (non-recursively) for ``*.png`` files.
        transforms: callable applied to the anchor PIL image to produce the positive.
        batch_size: number of triplets per batch yielded by ``generate_batches``.
    """

    def __init__(self, directory, transforms, batch_size=64):
        self.transforms = transforms
        self.batch_size = batch_size

        # glob order is filesystem-dependent; sort so the anchor/negative pairing is reproducible.
        self.file_list = sorted(glob.glob(f'{directory}/*.png'))
        self.triplets_original = len(self.file_list) // 2

        self.tt = ToRGBTensor()

        print(f'Found {len(self.file_list)} images.')

    def __len__(self):
        """Number of triplets: half the number of PNG files, rounded down."""
        return self.triplets_original

    # (A, P, N)
    def __getitem__(self, index):
        """Return the ``index``-th triplet as three RGB tensors ``(A, P, N)``.

        Negative indices count from the end. Out-of-range indices raise IndexError,
        rather than silently pairing the odd leftover file or the wrong halves.
        """
        if index < 0:
            index += len(self)
        if not 0 <= index < len(self):
            raise IndexError(f'triplet index out of range for dataset of {len(self)} triplets')
        anchor_path = self.file_list[index]
        negative_path = self.file_list[self.triplets_original + index]

        A = self._load_rgb(anchor_path)
        P = self.transforms(A)
        N = self._load_rgb(negative_path)

        return (self.tt(A), self.tt(P), self.tt(N))

    @staticmethod
    def _load_rgb(path):
        """Read an image fully into memory as RGB and close its file handle."""
        # Converting before augmentation keeps padding/grayscale ops on plain RGB data.
        with Image.open(path) as img:
            return img.convert('RGB')

    def generate_batches(self):
        """Yield each batch as one stacked tensor: all anchors, then all positives, then all negatives.

        Each batch holds up to ``batch_size`` triplets (the last may be smaller), so its
        length is always a multiple of 3. Yields nothing for an empty dataset.
        """
        for start in range(0, len(self), self.batch_size):
            stop = min(start + self.batch_size, len(self))
            anchors, positives, negatives = zip(*(self[i] for i in range(start, stop)))
            yield torch.stack([*anchors, *positives, *negatives])

    def len_batches(self):
        """Number of batches one pass of ``generate_batches`` yields."""
        return ceil(self.triplets_original / self.batch_size)

In [None]:
# Build triplets from every *.png directly inside ./Dataset (relative to the current working
# directory); `transforms` is applied to each anchor to produce its positive (P).
dataset = TripletDataset('Dataset', transforms)

In [None]:
# Sanity check: display the first (anchor, positive, negative) triplet.
triplet = dataset[0]
show_triplet(triplet)

## Hyperparameters

## Triplet Loss:
$$L(A, P, N) = \max\Big( d\big(f(A), f(P)\big) - d\big(f(A), f(N)\big) + \alpha,\ 0 \Big)$$

Where:
* ${\displaystyle A}$ is an anchor input, 
* ${\displaystyle P}$ is a positive input of the same class as ${\displaystyle A}$, 
* ${\displaystyle N}$ is a negative input of a different class from ${\displaystyle A}$, 
* ${\displaystyle \alpha }$ is the margin enforced between the positive and negative distances (`norm_margin` in the code below), 
* ${\displaystyle \operatorname{f}}$ is an embedding on a metric space, and 
* ${\displaystyle \operatorname{d}}$ is a distance function on that space (here the L1 or L2 norm, i.e. Manhattan or Euclidean distance respectively, selected by `norm_type`).

In [None]:
# Model
model = CustomHasher()

# Loss: triplet margin loss; `norm_type` picks the distance d(.,.) in the formula above.
norm_margin = .2            # the margin (alpha) of the triplet loss
norm_type = 'Manhattan'     # 'Manhattan' (L1) or 'Euclidean' (L2)
norm_p_by_type = {'Manhattan': 1, 'Euclidean': 2}
if norm_type not in norm_p_by_type:
    # A typo previously fell through to L2 silently; fail loudly instead.
    raise ValueError(f"norm_type must be one of {sorted(norm_p_by_type)}, got {norm_type!r}")
loss_fn = nn.TripletMarginLoss(margin=norm_margin, p=norm_p_by_type[norm_type])

# Optimizer
opt = torch.optim.Adam(model.parameters(), lr=.02)

# Training Loop

In [None]:
# Average training loss of each completed epoch, in epoch order (also drives early stopping).
epoch_avg_losses = []

# Train for up to `max_epochs` epochs. Once more than `early_stopping_min_epochs` epochs
# have run, stop as soon as an epoch's average loss is higher than the previous epoch's.
max_epochs = 99
early_stopping_min_epochs = 5

model.train()
for epoch in range(1, max_epochs + 1):
    
    batch_num = 0
    batches = dataset.generate_batches()
    len_batches = dataset.len_batches()  # constant for the whole epoch, so computed once
    epoch_losses = []
    
    epoch_start_time = time.time()
    for batch in batches:
        # Batch is all As, then all Ps, then all Ns. There are always an equal number of each.
        batch_start_time = time.time()
        
        # Forward: one pass over the stacked batch, then split the outputs back into A/P/N.
        opt.zero_grad()
        out_tensors = model(batch)
        A, P, N = out_tensors.split(len(batch) // 3)
        
        # Backward
        loss = loss_fn(A, P, N)
        loss.backward()
        opt.step()
        
        batch_end_time = time.time()
        
        batch_num += 1
        loss_num = loss.item()
        epoch_losses.append(loss_num)
        batch_time = round(batch_end_time - batch_start_time, 2)
        # Rough estimate: assumes every batch of the epoch takes as long as this one.
        est_epoch_time_seconds = batch_time * len_batches
        est_epoch_time_minutes = round(est_epoch_time_seconds / 60, 2)
        print(f'Completed batch: {str(batch_num).rjust(4, " ")} of {len_batches} | Loss: {str(round(loss_num, 2)).rjust(10, " ")} | Time: {str(batch_time).rjust(6, " ")} (est. {est_epoch_time_minutes} min for epoch)')
    
    if not epoch_losses:
        # Without any batches the epoch average below would divide by zero.
        raise RuntimeError('dataset yielded no batches; check that the Dataset directory contains PNG images')
    
    epoch_end_time = time.time()
    epoch_time_minutes = round((epoch_end_time - epoch_start_time) / 60, 2)
    epoch_average_loss = sum(epoch_losses) / len(epoch_losses)
    print(f"""\
╔═════════════════════════════════════════════════════════════╗\n║ \
End of Epoch: {str(epoch).rjust(3, " ")} | \
Average Loss: {str(round(epoch_average_loss, 2)).rjust(10, " ")} | \
Time: {str(epoch_time_minutes).rjust(6, " ")} ║\n\
╚═════════════════════════════════════════════════════════════╝""")
    
    # Early stopping: compare with the previous epoch's average (the list), not the current float.
    # The epoch that triggers the stop is still recorded in epoch_avg_losses.
    previous_avg_loss = epoch_avg_losses[-1] if epoch_avg_losses else None
    epoch_avg_losses.append(epoch_average_loss)
    if epoch > early_stopping_min_epochs and epoch_average_loss > previous_avg_loss:
        break