In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from collections import OrderedDict

## Hyperparameters

In [3]:
loss_margin = .2
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True)

Using cache found in /home/apaz/.cache/torch/hub/pytorch_vision_v0.6.0


In [21]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        
        class OntoF16(nn.Module):
            def __init__(self):
                super(OntoF16, self).__init__()
                
                self.f16_map = 2**12 - 1
            def forward(self, x):
                return x * self.f16_map
        
        self.resnet = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True)
        self.enc_layers = OrderedDict([ 
            ('resnet', self.resnet),
                                       
            ('full_1', nn.Linear(1000, 750)), 
            ('relu_1', nn.ReLU()),
            ('norm_1', nn.BatchNorm1d(750)),
            
            ('full_2', nn.Linear(750, 500)),
            ('relu_2', nn.ReLU()),
            ('norm_2', nn.BatchNorm1d(500)),
            
            ('full_3', nn.Linear(500, 250)),
            ('relu_3', nn.ReLU()),
            ('norm_3', nn.BatchNorm1d(250)),
            
            ('output', nn.Linear(250, 8)),
            ('sigmoid', nn.Sigmoid()),
            ('ontoF16', OntoF16())
        ])
        self.encoder = nn.Sequential(self.enc_layers)
    
    def forward(self, x):
        return self.encoder(x)

In [22]:
from torchsummary import summary
def visualize(model):
    summary(model, (3, 64, 64))

encoder = Encoder()
visualize(encoder)

Using cache found in /home/apaz/.cache/torch/hub/pytorch_vision_v0.6.0


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           9,408
            Conv2d-2           [-1, 64, 32, 32]           9,408
       BatchNorm2d-3           [-1, 64, 32, 32]             128
       BatchNorm2d-4           [-1, 64, 32, 32]             128
              ReLU-5           [-1, 64, 32, 32]               0
              ReLU-6           [-1, 64, 32, 32]               0
         MaxPool2d-7           [-1, 64, 16, 16]               0
         MaxPool2d-8           [-1, 64, 16, 16]               0
            Conv2d-9           [-1, 64, 16, 16]          36,864
           Conv2d-10           [-1, 64, 16, 16]          36,864
      BatchNorm2d-11           [-1, 64, 16, 16]             128
      BatchNorm2d-12           [-1, 64, 16, 16]             128
             ReLU-13           [-1, 64, 16, 16]               0
             ReLU-14           [-1, 64,

### Loss
![title](https://wikimedia.org/api/rest_v1/media/math/render/svg/933c19129ec9060b0e7ea6f54f715c4c92010399)

In [4]:
def triplet_loss(a, p, n) : 
    d = nn.PairwiseDistance(p=2)
    distance = d(a, p) - d(a, n) + loss_margin 
    loss = torch.mean(torch.max(distance, torch.zeros_like(distance))) 
    return loss

# Load Training Data

# Training Loop