In [1]:
import torch
from torch import nn
import torchvision.transforms
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip, RandomGrayscale, ToTensor, ToPILImage

from PIL import Image
from skimage.util import random_noise
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from collections import OrderedDict
from math import ceil
import time
import random
import glob


# Model

In [2]:
from Models import CustomHasher, SmallHasher, ResnetHasher

### Visualization Tools

In [3]:
# Visualize Data
class ShowTensor():
    def __call__(self, tensor):
        npimg = tensor.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
show_tensor = ShowTensor()
class ShowTriplet():
    def __call__(self, triplet, axes=True):
        npimgs = [ tensor.numpy() for tensor in triplet ] 
        fig, axs = plt.subplots(1,3, constrained_layout=True, figsize=(12,12))
        axs[0].imshow(np.transpose(npimgs[0], (1, 2, 0)), interpolation='nearest')
        axs[1].imshow(np.transpose(npimgs[1], (1, 2, 0)), interpolation='nearest')
        axs[2].imshow(np.transpose(npimgs[2], (1, 2, 0)), interpolation='nearest')
        axs[0].set_title("Anchor (A)")
        axs[1].set_title("Positive (P)")
        axs[2].set_title("Negative (N)")
        if not axes:
            axs[0].set_axis_off()
            axs[1].set_axis_off()
            axs[2].set_axis_off()
        plt.show()
show_triplet = ShowTriplet()

In [4]:
# Visualize Model
SmallHasher().visualize()

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 64, 64]             168
         MaxPool2d-2            [-1, 6, 32, 32]               0
              ReLU-3            [-1, 6, 32, 32]               0
       BatchNorm2d-4            [-1, 6, 32, 32]              12
       HasherBlock-5            [-1, 6, 32, 32]               0
            Conv2d-6           [-1, 12, 32, 32]             660
         MaxPool2d-7           [-1, 12, 16, 16]               0
              ReLU-8           [-1, 12, 16, 16]               0
       BatchNorm2d-9           [-1, 12, 16, 16]              24
      HasherBlock-10           [-1, 12, 16, 16]               0
           Conv2d-11           [-1, 24, 16, 16]           2,616
        MaxPool2d-12             [-1, 24, 8, 8]               0
             ReLU-13             [-1, 24, 8, 8]               0
      BatchNorm2d-14             [-1, 2

### Image Augmentation/Transforms

In [5]:
class RandomRotate:
    # Rotate by one of the given angles.
    def __init__(self, angles):
        self.angles = angles
    def __call__(self, x):
        angle = random.choice(self.angles)
        return TF.rotate(x, angle)

class RandomFillCrop:
    def __init__(self, chance):
        self.chance = chance
        self.rcs = [ RandomCrop((64, 64), padding=4, fill=0), 
                     RandomCrop((64, 64), padding=4, fill=255) ] 
    def __call__(self, img):
        return random.choice(self.rcs)(img) if random.random() < self.chance else img

class Noise:
    def __init__(self):
        pass
    def __call__(self, img):
        noise_type = random.random()
        if noise_type < .3:
            noise_type = 's&p'
        elif noise_type < .6:
            noise_type = 'gaussian'
        elif noise_type < .9:
            noise_type = 'speckle'
        else:
            return img
        
        return Image.fromarray(random_noise(np.asarray(img), mode=noise_type))

# Used in model
class ToRGBTensor:
    def __init__(self):
        self.tt = ToTensor()
    def __call__(self, img):
        return self.tt(img.convert('RGB'))

In [6]:

transforms = Compose([
    RandomFillCrop(.8),
    Noise(),
    RandomHorizontalFlip(.5),
    RandomRotate(range(0, 360, 30)),
    RandomGrayscale(0.3)
]) 

# Load Training Data

In [7]:
class TripletDataset(torch.utils.data.Dataset):    
    def __init__(self, directory, transforms, batch_size=64):
        self.transforms = transforms
        self.batch_size = batch_size
        
        self.file_list = glob.glob(f'{directory}/*.png')
        self.triplets_original = len(self.file_list) // 2
        
        self.tt = ToRGBTensor()
        
        print(f'Found {len(self.file_list)} images.')

    def __len__(self):
        return self.triplets_original

    # (A, P, N)
    def __getitem__(self, index):
        anchor = self.file_list[index]
        negative = self.file_list[self.triplets_original + index]
        
        A = Image.open(anchor)
        P = self.transforms(A)
        N = Image.open(negative)
        
        return (self.tt(A), self.tt(P), self.tt(N))
    
    def generate_batches(self):
        triplets_retrieved = 0
        while True:
            tr_after_yield = triplets_retrieved + self.batch_size
            
            triplet_tensors = [ self.__getitem__(x) for x in range(triplets_retrieved, min(tr_after_yield, len(self))) ]
            As=[];Ps=[];Ns=[]
            for triplet in triplet_tensors:
                As.append(triplet[0])
                Ps.append(triplet[1])
                Ns.append(triplet[2])
            combined = As + Ps + Ns
            
            tensor_stack = torch.stack(combined)
            yield tensor_stack
            triplets_retrieved = tr_after_yield
            
            if triplets_retrieved >= self.triplets_original:
                return
    def num_batches(self):
        return ceil(self.triplets_original / self.batch_size)

In [8]:
dataset = TripletDataset('TrainDataset', transforms)

Found 537379 images.


In [9]:
validationset = TripletDataset('ValidDataset', transforms)

Found 179127 images.


In [10]:
triplet = dataset[0]
show_triplet(triplet)
triplet = validationset[0]
show_triplet(triplet)

TypeError: Cannot handle this data type: (1, 1, 4), <f8

## Hyperparameters

## Triplet Loss:
$$L(\displaystyle A,\displaystyle P,\displaystyle N) = max\Big( d(f(\displaystyle A), f(\displaystyle P)) − d(f(\displaystyle A), f(\displaystyle N)) + margin, 0 \Big)$$

Where:
* ${\displaystyle A}$ is an anchor input, 
* ${\displaystyle P}$ is a positive input of the same class as ${\displaystyle A}$, 
* ${\displaystyle N}$ is a negative input of a different class from ${\displaystyle A}$, 
* ${\displaystyle \alpha }$  is a margin between positive and negative pairs, 
* ${\displaystyle \operatorname{f}}$ is an embedding on a metric space, and 
* ${\displaystyle \operatorname{d}}$ is a distance function on that space (In this case L1 or L2 norm, Manhattan and Euclidean distance respectively).

In [11]:
# Model
model = CustomHasher()

# Loss
norm_margin = .4
norm_type = 'Manhattan'
loss_fn = nn.TripletMarginLoss(margin=norm_margin, p= 1 if norm_type=='Manhattan' else 2)

# Optimizer
opt = torch.optim.Adam(model.parameters(), lr=.01)

# Training Loop

In [12]:
print_batches = True

# Returns loss
def train_batch(model, batch, batchnum):
        batch_start_time = time.time()
        
        # Forward
        model.zero_grad()
        out_tensors = model.forward(batch)
        A, P, N = out_tensors.split(len(batch) // 3)
        
        # Backward
        loss = loss_fn(A, P, N)
        loss.backward()
        opt.step()

        batch_end_time = time.time()
        
        if print_batches:
            loss_num = loss.item()
            batch_time = round(batch_end_time-batch_start_time, 2)
            num_batches = dataset.num_batches()
            est_epoch_time_seconds = batch_time*num_batches
            est_epoch_time_minutes = round(est_epoch_time_seconds / 60, 2)
            print(f'Completed train batch: {str(batch_num).rjust(4, " ")} of {num_batches} | \
Loss: {"{:6.4f}".format(loss_num))} | \
Time: {str(batch_time).rjust(6, " ")} (est. {est_epoch_time_minutes} min for epoch)')
        return loss_num

# Returns loss
def valid_batch(model, batch, batch_num):
        batch_start_time = time.time()

        # Forward
        model.zero_grad()
        out_tensors = model.forward(vbatch)
        A, P, N = out_tensors.split(len(vbatch) // 3)
        
        # Backward
        loss = loss_fn(A, P, N)
        batch_end_time = time.time()

        loss_num = loss.item()
        
        if print_batches:
            batch_time = round(batch_end_time-batch_start_time, 2)
            num_batches = validationset.num_batches()
            est_epoch_time_seconds = batch_time*num_batches
            est_epoch_time_minutes = round(est_epoch_time_seconds / 60, 2)
            print(f'Completed valid batch: {str(batch_num).rjust(4, " ")} of {num_batches} | \
Loss: {"{:6.4f}".format(loss_num)} | \
Time: {str(batch_time).rjust(6, " ")} (est. {est_epoch_time_minutes} min for epoch)')
        return loss_num


SyntaxError: unexpected EOF while parsing (<fstring>, line 1)

In [13]:
epoch_avg_losses = []
# Train for 5 epochs
for epoch in range(1, 100):
    
    batch_num = 0
    batch_valid_losses = []
    
    # Training Batches
    epoch_start_time = time.time()
    for batch in dataset.generate_batches():
        batch_num += 1
        
        train_batch(model, batch, batch_num)

    # Validation batches
    batch_num = 0
    for vbatch in validationset.generate_batches():
        batch_num += 1

        loss_num = valid_batch(model, vbatch, batch_num)
        batch_valid_losses.append(loss_num)
    
    epoch_end_time = time.time()
    epoch_time_minutes = round((epoch_end_time - epoch_start_time) / 60, 2)
    epoch_average_train_loss = sum(batch_valid_losses) / len(batch_valid_losses)
    print(f"""\
╔══════════════════════════════════════════════════════════════════╗\n║ \
End of Epoch: {str(epoch).rjust(3, " ")} | \
Average Loss: {str(round(epoch_average_train_loss, 2)).rjust(10, " ")} | \
Time: {str(epoch_time_minutes).rjust(6, " ")} min. ║\n\
╚══════════════════════════════════════════════════════════════════╝""")
    
    # Early Stopping
    if (epoch > 5):
        if (epoch_average_train_loss > epoch_avg_losses[-1]):
            break
    epoch_avg_losses.append(epoch_average_train_loss)

TypeError: Cannot handle this data type: (1, 1, 4), <f8

# Export Model