<a href="https://colab.research.google.com/github/abursuc/dldiy-practicals/blob/master/siamese_triplet_mnist_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Siamese networks

## Colab preparation

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from os import path


import numpy as np
import random

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

import torch
from torch.optim import lr_scheduler
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms


# 1. Setup and initializations
We'll go through learning feature embeddings using different loss functions on the MNIST dataset. This is just for visualization purposes, thus we'll be using 2-dimensional embeddings, which isn't the best choice in practice.

For every experiment the same embedding network is used (`32 conv 5x5 -> ReLU -> MaxPool 2x2 -> 64 conv 5x5 -> ReLU -> MaxPool 2x2 -> Fully Connected 256 -> ReLU -> Fully Connected 256 -> ReLU -> Fully Connected 2`) with the same hyperparameters.

In [None]:
class ExperimentParams:
    """Plain settings holder shared by every experiment in this notebook."""

    def __init__(self):
        # Number of target classes used by the classification head.
        self.num_classes = 10
        # Run on the GPU whenever PyTorch reports one, otherwise on the CPU.
        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        # Mini-batch size and Adam learning rate; later cells overwrite both.
        self.batch_size = 256
        self.lr = 1e-2
        # Number of passes over the training set.
        self.num_epochs = 10
        # Worker processes handed to every DataLoader.
        self.num_workers = 4
        # Root folder under which the MNIST files are stored.
        self.data_dir = '/home/docker_user/'


# One shared settings object, read (and partly overwritten) by the cells below.
args = ExperimentParams()

## 1.1 Prepare dataset
We'll be working on MNIST dataset

In [None]:

# Mean / std used to normalise the pixel values; get_raw_images() relies on
# the same two numbers (as its defaults) to undo the normalisation.
mean = 0.1307
std = 0.3081

# Training split, downloaded on first use.
train_dataset = MNIST(
    f'{args.data_dir}/data/MNIST',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((mean,), (std,)),
    ]),
)

# Test split, preprocessed exactly like the training split.
test_dataset = MNIST(
    f'{args.data_dir}/data/MNIST',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((mean,), (std,)),
    ]),
)

## 1.2 Common setup

In [None]:

# Legend entries: the ten digit classes as strings.
mnist_classes = [str(digit) for digit in range(10)]

# One hex colour per class, indexed by the integer label.
colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
]

def plot_embeddings(embeddings, targets, title='',xlim=None, ylim=None):
    """Scatter-plot 2-D embeddings, one colour per digit class.

    Args:
        embeddings: array indexed as ``embeddings[rows, 0]`` / ``[rows, 1]``.
        targets: array of labels compared against the integers 0..9.
        title: figure title.
        xlim, ylim: optional ``(low, high)`` axis limits; a falsy value keeps
            matplotlib's automatic range.
    """
    plt.figure(figsize=(10,10))
    for digit in range(10):
        rows = np.where(targets==digit)[0]
        xs, ys = embeddings[rows,0], embeddings[rows,1]
        plt.scatter(xs, ys, alpha=0.5, color=colors[digit])
    if xlim:
        plt.xlim(xlim[0], xlim[1])
    if ylim:
        plt.ylim(ylim[0], ylim[1])
    # Classes were drawn in label order, so the legend entries line up.
    plt.legend(mnist_classes)
    plt.title(title)
def extract_embeddings(dataloader, model, args):
    """Embed every sample of ``dataloader`` with ``model.get_embedding``.

    The model is switched to eval mode and gradients are disabled.

    Args:
        dataloader: yields ``(images, target)`` batches and exposes ``.dataset``.
        model: module providing ``get_embedding(images)``.
        args: settings object; only ``args.device`` is read.

    Returns:
        ``(embeddings, labels)`` as NumPy arrays of shape ``(N, 2)`` and
        ``(N,)``, filled in the order the loader produced the batches.
    """
    with torch.no_grad():
        model.eval()
        total = len(dataloader.dataset)
        embeddings = np.zeros((total, 2))
        labels = np.zeros(total)
        start = 0
        for images, target in dataloader:
            stop = start + len(images)
            batch_embedding = model.get_embedding(images.to(args.device))
            embeddings[start:stop] = batch_embedding.data.cpu().numpy()
            labels[start:stop] = target.numpy()
            start = stop
    return embeddings, labels


def get_raw_images(dataloader,mean=0.1307, std=0.3081):
    """Collect all images of ``dataloader`` with the normalisation undone.

    Each batch is mapped back through ``batch * std + mean`` and written into
    a pre-allocated ``(N, 1, 28, 28)`` NumPy array, in loader order.
    """
    raw_images = np.zeros((len(dataloader.dataset), 1, 28, 28))
    start = 0
    for batch, _ in dataloader:
        stop = start + len(batch)
        restored = batch*std + mean
        raw_images[start:stop] = restored.data.cpu().numpy()
        start = stop

    return raw_images


def show(img, title=None):
    """Draw an image tensor with matplotlib, optionally under a title.

    ``img`` is a torch.Tensor; its first axis is moved to the end before
    plotting (presumably channels-first to channels-last -- confirm against
    callers).
    """
    pixels = np.transpose(img.numpy(), (1,2,0))
    plt.imshow(pixels, interpolation='nearest')
    plt.axis('off')
    if title is not None:
        plt.title(title)
    # pause a bit so that plots are updated
    plt.pause(0.001)

# 2. Baseline: Classification with softmax
We'll train the model for classification and use outputs of penultimate layer as embeddings. 

We will define our base embedding architecture which will serve as common backbone for our experiments

## 2.1 Architecture

### Exercise

Complete the missing blocks in the definition of the following `EmbeddingNet` architecture: (`32 conv 5x5 -> ReLU -> MaxPool 2x2 -> 64 conv 5x5 -> ReLU -> MaxPool 2x2 -> Fully Connected 256 -> ReLU -> Fully Connected 256 -> ReLU -> Fully Connected 2`)

In [None]:
class EmbeddingNet(nn.Module):
    """Exercise skeleton for the shared embedding backbone.

    The layer definitions and most of the forward pass are intentionally left
    for the reader to complete. Until that is done, ``forward`` fails because
    ``self.conv1`` and ``output`` are not defined.
    """

    def __init__(self):
        super(EmbeddingNet, self).__init__()
        # Exercise: uncomment and complete the layers below following the
        # architecture given in the exercise statement.
#         self.conv1 = nn.Conv2d(1, ...)
#         self.conv2 = ...
#         self.fc1 = ...
#         self.fc2 = ...
#         self.fc3 = ...

    def forward(self, x, debug=False):
        """Map a batch of images ``x`` to embeddings.

        With ``debug`` set, the sizes of the input and of the first hidden
        state are printed.
        """
        # First stage: convolution -> ReLU -> 2x2 max pooling.
        x1 = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)
        # Exercise: compute the remaining stages and assign the result to `output`.
#       output = ...                  
        if debug == True:
            print(f'input: {x.size()}')
            print(f'x1: {x1.size()}')
                      
        return output

    def get_embedding(self, x):
        """Return the embedding of ``x``; for this network it is the forward output."""
        return self.forward(x)
    


If you want to better check the sizes of the hidden states and do debugging, you can add a `debug` variable in the `forward` function just like above

In [None]:
# Push one all-zero 1x28x28 image through the network with debug printing on,
# to inspect the sizes of the hidden states. The tensor is deliberately not
# called `input`: a notebook-level variable of that name would shadow the
# built-in input() for every later cell.
dummy_input = torch.zeros(1, 1, 28, 28)
net = EmbeddingNet()
net(dummy_input,debug=True)

### Question
The dimension of the output is `batch-size x 2`. Why?

Now let's define a classification net that will add a fully connected layer on top of `EmbeddingNet`

### Exercise

Fill in the missing spots in the `forward` pass:

In [None]:
class ClassificationNet(nn.Module):
    """Exercise skeleton: a linear classifier on top of a 2-D embedding network.

    ``embedding_net`` produces the embeddings and ``self.fc`` maps them to
    ``num_classes`` scores. A PReLU module is also created; the skeleton does
    not use it yet.
    """

    def __init__(self, embedding_net, num_classes):
        super(ClassificationNet, self).__init__()
        self.embedding_net = embedding_net
        self.prelu = nn.PReLU()
        # The input size 2 matches the 2-D embeddings used in this notebook.
        self.fc = nn.Linear(2, num_classes)

    def forward(self, x, debug=False):
        """Return class scores for the batch ``x`` (exercise: compute ``embedding``)."""
        # replace None with necessary entry 
        embedding = None
        output = self.fc(embedding)
        
        # Optional size printout, to be enabled once `embedding` is computed.
        # if debug == True:
        #     print(f'input: {x.size()}')
        #     print(f'embedding: {embedding.size()}')
        #     print(f'output: {output.size()}')
            
        return output
    
    def get_embedding(self, x):
        """Return the embedding of ``x`` (exercise: currently returns ``None``)."""
        # replace None with necessary entry 
        return None


## 2.2 Training

In [None]:
# Set up data loaders

# Options shared by both loaders.
kwargs = dict(num_workers=args.num_workers, pin_memory=True)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs,
)
# The test loader keeps the dataset order (no shuffling).
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    **kwargs,
)

# Classification model (embedding backbone + linear head) and its loss.
embedding_net = EmbeddingNet()
model = ClassificationNet(embedding_net, num_classes=args.num_classes)
loss_fn = torch.nn.CrossEntropyLoss()
model.to(args.device)
loss_fn.to(args.device)

# Adam, with the learning rate multiplied by 0.1 every 8 scheduler steps.
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1, last_epoch=-1)

In [None]:
# Embeddings of the still-untrained model, as a reference for the later plots.
train_embeddings_baseline, train_labels_baseline = extract_embeddings(
    train_loader, model, args)
plot_embeddings(
    train_embeddings_baseline,
    train_labels_baseline,
    title='Train embeddings before training',
)

In [None]:
def train_classif_epoch(train_loader, model, loss_fn, optimizer, args, log_interval=50):
    """Train ``model`` for one epoch on a classification task.

    Args:
        train_loader: yields ``(data, target)`` batches; ``len(train_loader)``
            and ``len(train_loader.dataset)`` are used for progress messages.
        model: network mapping ``data`` to per-class scores.
        loss_fn: callable ``loss_fn(outputs, target)`` returning a scalar loss.
        optimizer: optimizer over the model parameters.
        args: settings object; only ``args.device`` is read.
        log_interval: print a progress line every ``log_interval`` batches.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats: the mean of the per-batch
        losses and the fraction of correctly classified samples, both taken
        over *every* batch of the epoch. ``(0.0, 0.0)`` if the loader is empty.
    """
    model.train()
    interval_losses = []  # losses since the last progress message
    total_loss, total_corrects = 0.0, 0
    num_samples, num_batches = 0, 0

    for batch_idx, (data, target) in enumerate(train_loader):
        num_batches += 1
        num_samples += data.size(0)

        data, target = data.to(args.device), target.to(args.device)

        optimizer.zero_grad()
        outputs = model(data)
        loss = loss_fn(outputs, target)
        loss.backward()
        optimizer.step()

        # Accumulate on every batch so that batches after the last progress
        # message still count towards the returned averages.
        batch_loss = loss.item()
        interval_losses.append(batch_loss)
        total_loss += batch_loss

        _, preds = torch.max(outputs.detach(), 1)
        total_corrects += int((preds == target).sum().item())

        if batch_idx % log_interval == 0:
            # `len(data)` is the batch size (`len(data[0])` would be the
            # channel count of a single image).
            print('Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f} \tAccuracy: {}'.format(
                batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), np.mean(interval_losses),
                total_corrects / num_samples))
            interval_losses = []

    if num_batches == 0:
        return 0.0, 0.0
    return total_loss / num_batches, total_corrects / num_samples

def test_classif_epoch(test_loader, model, loss_fn, args, log_interval=50):
    with torch.no_grad():
        model.eval()
        losses, corrects = [], 0
        num_samples = 0
    
        for batch_idx, (data, target) in enumerate(test_loader):

            num_samples += data.size(0)
            data, target = data.to(args.device), target.to(args.device)

            outputs = model(data)

            loss = loss_fn(outputs, target)
            losses.append(loss.data.item())

            _,preds = torch.max(outputs.data,1)
            corrects += torch.sum(preds == target.data).cpu()

        return np.sum(losses)/(batch_idx + 1), corrects/num_samples

### Question
Why do we need `optimizer.zero_grad()`? What happens if we remove it?

In [None]:
# Epoch to start from; 0 means training from scratch.
start_epoch = 0

# Advance the scheduler once per skipped epoch (a no-op while start_epoch is 0).
for epoch in range(0, start_epoch):
    scheduler.step()

# Main loop: one training pass, then one evaluation pass on the test set, per epoch.
for epoch in range(start_epoch, args.num_epochs):

    train_loss, train_accuracy = train_classif_epoch(train_loader, model, loss_fn, optimizer, args)

    message = 'Epoch: {}/{}. Train set: Average loss: {:.4f} Average accuracy: {:.4f}'.format(
        epoch + 1, args.num_epochs, train_loss, train_accuracy)
    
    val_loss, val_accuracy = test_classif_epoch(test_loader, model, loss_fn, args)
    
    message += '\nEpoch: {}/{}. Validation set: Average loss: {:.4f}  Average accuracy: {:.4f}'.format(epoch + 1, args.num_epochs,
                                                                             val_loss, val_accuracy)
    print(message)

    # Update the learning-rate schedule once the epoch's optimizer steps are done.
    scheduler.step()


## 2.3 Visualizations


In [None]:
# Embeddings of the trained classification model on the training split ...
train_embeddings_baseline, train_labels_baseline = extract_embeddings(
    train_loader, model, args)
plot_embeddings(
    train_embeddings_baseline,
    train_labels_baseline,
    title='Train embeddings classification',
)
# ... and on the test split.
test_embeddings_baseline, test_labels_baseline = extract_embeddings(
    test_loader, model, args)
plot_embeddings(
    test_embeddings_baseline,
    test_labels_baseline,
    title='Test embeddings classification',
)

While the embeddings look separable (which is what we trained them for), they don't have good metric properties. They might not be the best choice as a descriptor for new classes.

# 3. Siamese network
Now we'll train a siamese network that takes a pair of images and trains the embeddings so that the distance between them is minimized if they're from the same class, or greater than some margin value if they represent different classes.
We'll minimize a contrastive loss function*:
$$L_{contrastive}(x_0, x_1, y) = \frac{1}{2} y \lVert f(x_0)-f(x_1)\rVert_2^2 + \frac{1}{2}(1-y)\{max(0, m-\lVert f(x_0)-f(x_1)\rVert_2)\}^2$$

*Raia Hadsell, Sumit Chopra, Yann LeCun, [Dimensionality reduction by learning an invariant mapping](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf), CVPR 2006*

## 3.1 Architecture
We will first define the siamese architecture on top of our `EmbeddingNet`

### Exercise

Fill in the forward part of `SiameseNet`

In [None]:
class SiameseNet(nn.Module):
    """Exercise skeleton: applies one shared embedding network to a pair of inputs."""

    def __init__(self, embedding_net):
        super(SiameseNet, self).__init__()
        # The same network (hence the same weights) is used for both inputs.
        self.embedding_net = embedding_net

    def forward(self, x1, x2):
        """Return the embeddings of ``x1`` and ``x2`` as a tuple.

        Exercise: ``output1`` and ``output2`` are not defined until the two
        missing lines are written.
        """
        # fill in the missing 2 lines :)
        
        return output1, output2

    def get_embedding(self, x):
        """Embed a single batch ``x`` with the shared network."""
        return self.embedding_net(x)

## 3.2 Data loader
We will also need to adapt our data loader to fetch pairs of images 

In [None]:
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
from PIL import Image

class SiameseMNIST(Dataset):
    """
    Wraps an MNIST dataset so that every item is a pair of images.

    train mode: For each sample creates randomly a positive or a negative pair
    test mode: Creates fixed pairs for testing

    Items have the form ``((img1, img2), target)`` where ``target`` is 1 when
    both images share a label and 0 otherwise.
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset

        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        # NOTE(review): recent torchvision releases appear to expose these as
        # `data` / `targets` and warn on the names used here -- confirm before
        # switching.
        if self.train:
            self.train_labels = self.mnist_dataset.train_labels
            self.train_data = self.mnist_dataset.train_data
            labels = self.train_labels.numpy()
        else:
            self.test_labels = self.mnist_dataset.test_labels
            self.test_data = self.mnist_dataset.test_data
            labels = self.test_labels.numpy()

        self.labels_set = set(labels)
        # Dictionary with one key per label; the value is an array holding the
        # indices of the images having that label.
        self.label_to_indices = {label: np.where(labels == label)[0]
                                 for label in self.labels_set}

        if not self.train:
            # generate fixed pairs for testing
            self.test_pairs = self._build_test_pairs()

    def _build_test_pairs(self):
        """Return the fixed test pairs, format ``[index1, index2, label(0/1)]``."""
        # Every random draw goes through this seeded generator (never the
        # global numpy RNG), so the pairs are identical on every run.
        random_state = np.random.RandomState(42)

        # Even indices: pair each image with a random image of the same label.
        positive_pairs = [[i,
                           random_state.choice(self.label_to_indices[self.test_labels[i].item()]),
                           1]
                          for i in range(0, len(self.test_data), 2)]

        # Odd indices: pick a different label at random, then a random image
        # carrying that label.
        negative_pairs = []
        for i in range(1, len(self.test_data), 2):
            other_labels = list(self.labels_set - {self.test_labels[i].item()})
            negative_label = random_state.choice(other_labels)
            negative_pairs.append(
                [i, random_state.choice(self.label_to_indices[negative_label]), 0])

        return positive_pairs + negative_pairs

    def __getitem__(self, index):
        """Return ``((img1, img2), target)`` for position ``index``."""
        # at train time pairs of samples are fetched randomly on the fly
        if self.train:
            # select random label, i.e. similar (1) or non-similar (0) images
            target = np.random.randint(0, 2)
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            if target == 1:
                # select a different image with the same label as img1
                # (keeps drawing until the index differs, so this would not
                # terminate for a class with a single sample)
                siamese_index = index
                while siamese_index == index:
                    siamese_index = np.random.choice(self.label_to_indices[label1])
            else:
                # eliminate label1 from the set of possible labels to select
                siamese_label = np.random.choice(list(self.labels_set - {label1}))
                # randomly select an image having a label from this subset
                siamese_index = np.random.choice(self.label_to_indices[siamese_label])
            img2 = self.train_data[siamese_index]
        else:
            first_index, second_index, target = self.test_pairs[index]
            img1 = self.test_data[first_index]
            img2 = self.test_data[second_index]

        # Convert to 8-bit greyscale PIL images before applying the transform.
        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return (img1, img2), target

    def __len__(self):
        return len(self.mnist_dataset)

## 3.3 Loss function

$$L_{contrastive}(x_0, x_1, y) = \frac{1}{2} y \lVert f(x_0)-f(x_1)\rVert_2^2 + \frac{1}{2}(1-y)\{max(0, m-\lVert f(x_0)-f(x_1)\rVert_2)\}^2$$

### Exercise

Fill in the missing parts of the `contrastive loss`

In [None]:
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss (exercise skeleton)
    Takes embeddings of two samples and a target label == 1 if samples are from the same class and label == 0 otherwise
    """

    def __init__(self, margin):
        super(ContrastiveLoss, self).__init__()
        # Margin `m` of the contrastive loss formula.
        self.margin = margin
        # Small constant; presumably meant for numerical stability when taking
        # the square root of the squared distances -- the skeleton never reads it.
        self.eps = 1e-9

    def forward(self, output1, output2, target, size_average=True):
        """Return the mean (``size_average=True``) or the sum of the per-pair losses."""
        # compute squared distances between output2 and output1  
        squared_distances = None
        # add the second term of the loss. You can use ReLU to express the max in the formula
        losses = 0.5 * (target.float() * squared_distances +
                         None )
        
        return losses.mean() if size_average else losses.sum()
        


## 3.4 Training

In [None]:
# Set up data loaders
# Both wrappers return pairs of images and a same/different target.
siamese_train_dataset = SiameseMNIST(train_dataset)
siamese_test_dataset = SiameseMNIST(test_dataset)

args.batch_size = 128
kwargs = dict(num_workers=args.num_workers, pin_memory=True)
siamese_train_loader = torch.utils.data.DataLoader(
    siamese_train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs,
)
siamese_test_loader = torch.utils.data.DataLoader(
    siamese_test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    **kwargs,
)

# Fresh embedding network wrapped in the siamese model, with the contrastive loss.
margin = 1.
embedding_net = EmbeddingNet()
model = SiameseNet(embedding_net)
loss_fn = ContrastiveLoss(margin)
model.to(args.device)
loss_fn.to(args.device)

# Adam, with the learning rate multiplied by 0.1 every 8 scheduler steps.
args.lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1, last_epoch=-1)

In [None]:
def train_siamese_epoch(train_loader, model, loss_fn, optimizer, args, log_interval=100):
    """Train a two-input (siamese) ``model`` for one epoch.

    Args:
        train_loader: yields ``((x1, x2), target)`` batches; ``len(train_loader)``
            and ``len(train_loader.dataset)`` are used for progress messages.
        model: callable ``model(x1, x2)`` returning two embeddings.
        loss_fn: callable ``loss_fn(out1, out2, target)`` returning a scalar loss.
        optimizer: optimizer over the model parameters.
        args: settings object; only ``args.device`` is read.
        log_interval: print a progress line every ``log_interval`` batches.

    Returns:
        Mean of the per-batch losses over *every* batch of the epoch, as a
        Python float (``0.0`` if the loader is empty).
    """
    model.train()
    interval_losses = []  # losses since the last progress message
    total_loss, num_batches = 0.0, 0

    for batch_idx, (data, target) in enumerate(train_loader):
        num_batches += 1

        data = tuple(d.to(args.device) for d in data)
        target = target.to(args.device)

        optimizer.zero_grad()

        outputs = model(data[0], data[1])
        # alternatively: outputs = model(*data)

        loss = loss_fn(outputs[0], outputs[1], target)
        # alternatively: loss = loss_fn(*outputs, target)

        loss.backward()
        optimizer.step()

        # Accumulate on every batch so that batches after the last progress
        # message still count towards the returned average.
        batch_loss = loss.item()
        interval_losses.append(batch_loss)
        total_loss += batch_loss

        if batch_idx % log_interval == 0:
            print('Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f} '.format(
                batch_idx * len(data[0]), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), np.mean(interval_losses)))
            interval_losses = []

    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

def test_siamese_epoch(test_loader, model, loss_fn, args, log_interval=50):
    with torch.no_grad():
        model.eval()
        losses = []
        num_samples = 0
    
        for batch_idx, (data, target) in enumerate(test_loader):

            num_samples += data[0].size(0)
            data = tuple(d.to(args.device) for d in data)
            target = target.to(args.device)
            outputs = model(data[0], data[1])

            loss = loss_fn(outputs[0], outputs[1], target)
            losses.append(loss.data.item())
    
        return np.sum(losses)/(batch_idx + 1)

In [None]:
# Epoch to start from; 0 means training from scratch.
start_epoch = 0

# needed for annealing learning rate in case of resuming of training
# (a no-op while start_epoch is 0)
for epoch in range(0, start_epoch):
    scheduler.step()

# main training loop
for epoch in range(start_epoch, args.num_epochs):

    # train stage
    train_loss = train_siamese_epoch(siamese_train_loader, model, loss_fn, optimizer, args)
    message = 'Epoch: {}/{}. Train set: Average loss: {:.4f}'.format(
        epoch + 1, args.num_epochs, train_loss)
    
    # testing/validation stage    
    test_loss = test_siamese_epoch(siamese_test_loader, model, loss_fn, args)
    
    message += '\nEpoch: {}/{}. Validation set: Average loss: {:.4f}'.format(epoch + 1, args.num_epochs,
                                                                             test_loss)
    print(message)

    # Update the learning-rate schedule once the epoch's optimizer steps are done.
    scheduler.step()


## 3.5 Visualizations

In [None]:
# The plain (single-image) loaders are reused here: extract_embeddings only
# needs model.get_embedding, which the siamese wrapper provides.
train_embeddings_cl, train_labels_cl = extract_embeddings(
    train_loader, model, args)
plot_embeddings(
    train_embeddings_cl,
    train_labels_cl,
    title='Train embeddings (constrastive loss)',
)
test_embeddings_cl, test_labels_cl = extract_embeddings(
    test_loader, model, args)
plot_embeddings(
    test_embeddings_cl,
    test_labels_cl,
    title='Test embeddings (contrastive loss)',
)

In order to compare two vectors $x_1$ and $x_2$ we can use the `cosine similarity`

$$\text{similarity}=\frac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert_2, \epsilon)}$$

An alternative is the Euclidean distance.

In order to save computation at query time we can pre-process our vectors and L2-normalize them. Now we can simply perform comparison by dot product

#### Exercise
Perform L2-normalization on the embeddings using `numpy` 

In [None]:
# L2-normalize embeddings
# Exercise: replace the `....` placeholder with a NumPy expression that scales
# every row of the test embeddings (presumably `test_embeddings_cl` from the
# cell above) to unit Euclidean norm.
test_embeddings_norm = ....

### Question
Why do we normalize features?

### Exercise
Now write a function `most_sim` that computes all dot products between a query vector and the dataset, extracts the indices of the `topk` most similar vectors and puts them in a list of tuples `(index, similarity)`.

In [None]:
def most_sim(x, emb, topk=6):
       """Exercise stub: find the ``topk`` entries of ``emb`` most similar to ``x``.

       Currently returns ``None``. ``launch_query`` indexes the result as a
       sequence of pairs whose first item selects an image and whose second
       item is printed as a similarity score.
       """
       return None

In [None]:
# De-normalised copies of the test images, in the (unshuffled) order of
# test_loader, used to display the query results.
test_images_raw = get_raw_images(test_loader)

In [None]:
def launch_query(test_embeddings_norm, test_images_raw, query_id=None):
    """Show a query image next to its nearest neighbours in embedding space.

    Args:
        test_embeddings_norm: 2-D array of embeddings, one row per image.
        test_images_raw: array of images aligned row-by-row with the embeddings.
        query_id: row index of the query; a random valid row is drawn when
            it is ``None``.
    """
    if query_id is None:
        # randrange excludes the upper bound; random.randint(0, n) could
        # return n itself, which is one past the last valid row.
        query_id = random.randrange(test_embeddings_norm.shape[0])
    query_vector = test_embeddings_norm[query_id,:]

    print(f'query_id: {query_id} | query_embedding: {query_vector}')
    # Each entry of knns is indexed as (image index, similarity score).
    knns = most_sim(query_vector, test_embeddings_norm)
    knn_images = np.array([test_images_raw[x[0]] for x in knns ])

    # The first entry is labelled as the query itself, followed by 5 neighbours.
    title=['q: 1.0', f'1nn: {knns[1][1]:.3}', f'2nn: {knns[2][1]:.3}', 
           f'3nn: {knns[3][1]:.3}', f'4nn: {knns[4][1]:.3}', f'5nn: {knns[5][1]:.3}']
    show(torchvision.utils.make_grid(torch.from_numpy(knn_images)), title=title)
    

In [None]:
# Run five queries, each on a randomly drawn test image.
for _ in range(5):
    launch_query(test_embeddings_norm, test_images_raw)

# 4. Triplet network
We'll train a triplet network, that takes an anchor, positive (same class as anchor) and negative (different class than anchor) examples. The objective is to learn embeddings such that the anchor is closer to the positive example than it is to the negative example by some margin value.

![alt text](images/anchor_negative_positive.png "Source: FaceNet")
Source: [2] *Schroff, Florian, Dmitry Kalenichenko, and James Philbin. [Facenet: A unified embedding for face recognition and clustering.](https://arxiv.org/abs/1503.03832) CVPR 2015.*

**Triplet loss**:   $L_{triplet}(x_a, x_p, x_n) = \max(0, m +  \lVert f(x_a)-f(x_p)\rVert_2^2 - \lVert f(x_a)-f(x_n)\rVert_2^2)$

## 4.1 Architecture
We will first define the triplet architecture on top of our `EmbeddingNet`

### Exercise

Fill in the forward part of `TripletNet`

In [None]:
class TripletNet(nn.Module):
    """Exercise skeleton: applies one shared embedding network to three inputs."""

    def __init__(self, embedding_net):
        super(TripletNet, self).__init__()
        # The same network (hence the same weights) is used for all three inputs.
        self.embedding_net = embedding_net

    def forward(self, x1, x2, x3):
        """Return the embeddings of ``x1``, ``x2`` and ``x3`` as a tuple.

        Exercise: ``output1`` .. ``output3`` are not defined until the three
        missing lines are written.
        """
        # missing 3 lines here
        
        return output1, output2, output3

    def get_embedding(self, x):
        """Embed a single batch ``x`` with the shared network."""
        return self.embedding_net(x)

## 4.2 Data loader
We will also need to adapt our data loader to fetch triplets of images 

In [None]:
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
from PIL import Image

class TripletMNIST(Dataset):
    """
    Wraps an MNIST dataset so that every item is a triplet of images.

    Train: For each sample (anchor) randomly chooses a positive and negative samples
    Test: Creates fixed triplets for testing

    Items have the form ``((anchor, positive, negative), [])``; the target is
    an empty list.
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        # NOTE(review): recent torchvision releases appear to expose these as
        # `data` / `targets` and warn on the names used here -- confirm before
        # switching.
        if self.train:
            self.train_labels = self.mnist_dataset.train_labels
            self.train_data = self.mnist_dataset.train_data
            labels = self.train_labels.numpy()
        else:
            self.test_labels = self.mnist_dataset.test_labels
            self.test_data = self.mnist_dataset.test_data
            labels = self.test_labels.numpy()

        self.labels_set = set(labels)
        # Maps each label to the array of indices of the images carrying it.
        self.label_to_indices = {label: np.where(labels == label)[0]
                                 for label in self.labels_set}

        if not self.train:
            # generate fixed triplets for testing
            self.test_triplets = self._build_test_triplets()

    def _build_test_triplets(self):
        """Return the fixed ``[anchor, positive, negative]`` index triplets."""
        # Every random draw goes through this seeded generator (never the
        # global numpy RNG), so the triplets are identical on every run.
        random_state = np.random.RandomState(29)

        triplets = []
        for i in range(len(self.test_data)):
            anchor_label = self.test_labels[i].item()
            positive_index = random_state.choice(self.label_to_indices[anchor_label])
            # Pick a different label first, then an image carrying it.
            negative_label = random_state.choice(list(self.labels_set - {anchor_label}))
            negative_index = random_state.choice(self.label_to_indices[negative_label])
            triplets.append([i, positive_index, negative_index])
        return triplets

    def __getitem__(self, index):
        """Return ``((anchor, positive, negative), [])`` for position ``index``."""
        if self.train:
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            # Positive: a different image with the anchor's label (keeps drawing
            # until the index differs, so this would not terminate for a class
            # with a single sample).
            positive_index = index
            while positive_index == index:
                positive_index = np.random.choice(self.label_to_indices[label1])
            # Negative: a random image from a randomly chosen other label.
            negative_label = np.random.choice(list(self.labels_set - {label1}))
            negative_index = np.random.choice(self.label_to_indices[negative_label])
            img2 = self.train_data[positive_index]
            img3 = self.train_data[negative_index]
        else:
            anchor_index, positive_index, negative_index = self.test_triplets[index]
            img1 = self.test_data[anchor_index]
            img2 = self.test_data[positive_index]
            img3 = self.test_data[negative_index]

        # Convert to 8-bit greyscale PIL images before applying the transform.
        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        img3 = Image.fromarray(img3.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
        return (img1, img2, img3), []

    def __len__(self):
        return len(self.mnist_dataset)

## 4.3 Loss function

### Exercise

Fill in the missing parts of the `triplet loss`:
 $L_{triplet}(x_a, x_p, x_n) = \max(0, m +  \lVert f(x_a)-f(x_p)\rVert_2^2 - \lVert f(x_a)-f(x_n)\rVert_2^2)$

In [None]:
class TripletLoss(nn.Module):
    """
    Triplet loss (exercise skeleton)
    Takes embeddings of an anchor sample, a positive sample and a negative sample
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        # Margin `m` of the triplet loss formula.
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        """Return the mean (``size_average=True``) or the sum of the per-triplet losses.

        Exercise: the three ``None`` placeholders must be replaced first;
        as written, ``losses`` is ``None`` and the return statement fails.
        """
        distance_positive = None # fill in code
        distance_negative = None  # fill in code
        # you can again use ReLU instead of max         
        losses = None # fill in code
        return losses.mean() if size_average else losses.sum()

## 4.4 Training

In [None]:
# Both wrappers return triplets of images.
triplet_train_dataset = TripletMNIST(train_dataset)
triplet_test_dataset = TripletMNIST(test_dataset)

args.batch_size = 128
kwargs = dict(num_workers=args.num_workers, pin_memory=True)
triplet_train_loader = torch.utils.data.DataLoader(
    triplet_train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs,
)
triplet_test_loader = torch.utils.data.DataLoader(
    triplet_test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    **kwargs,
)

# Fresh embedding network wrapped in the triplet model, with the triplet loss.
margin = 1.
embedding_net = EmbeddingNet()
model = TripletNet(embedding_net)
loss_fn = TripletLoss(margin)
model.to(args.device)
loss_fn.to(args.device)

# Adam, with the learning rate multiplied by 0.1 every 8 scheduler steps.
args.lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1, last_epoch=-1)
n_epochs = 5
log_interval = 100

### Exercise

Code your own train/test sequences similarly to the previous examples.
Watch out for some differences though.

In [None]:
def train_triplet_epoch(train_loader, model, loss_fn, optimizer, args, log_interval=100):
    """Exercise stub: train a triplet ``model`` for one epoch.

    The batch loop is left for the reader to write. Until then the return
    statement fails because ``batch_idx`` is never assigned.
    """
    model.train()
    losses = []
    total_loss, num_samples =  0, 0
  
    # fill in code here
    
    # Intended result: accumulated loss divided by the number of batches.
    return total_loss/(batch_idx + 1)

def test_triplet_epoch(test_loader, model, loss_fn, args, log_interval=50):
    """Exercise stub: evaluate a triplet ``model`` on ``test_loader``.

    The batch loop is left for the reader to write. Until then the return
    statement fails because ``batch_idx`` is never assigned.
    """
    losses = []
    num_samples = 0
    # fill in code here

    # Intended result: sum of the per-batch losses divided by the number of batches.
    return np.sum(losses)/(batch_idx + 1)


In [None]:
# Epoch to start from; 0 means training from scratch.
start_epoch = 0

# needed for annealing learning rate in case of resuming of training
for epoch in range(0, start_epoch):
    scheduler.step()

# main training loop
for epoch in range(start_epoch, args.num_epochs):

    # train stage
    train_loss = train_triplet_epoch(triplet_train_loader, model, loss_fn, optimizer, args)
    message = 'Epoch: {}/{}. Train set: Average loss: {:.4f}'.format(
        epoch + 1, args.num_epochs, train_loss)

    # testing/validation stage
    test_loss = test_triplet_epoch(triplet_test_loader, model, loss_fn, args)

    message += '\nEpoch: {}/{}. Validation set: Average loss: {:.4f}'.format(epoch + 1, args.num_epochs,
                                                                             test_loss)
    print(message)

    # Step the scheduler only after the epoch's optimizer updates, as the
    # classification and siamese loops do. Stepping it first shifts the
    # learning-rate schedule one epoch early.
    scheduler.step()

## 4.5 Visualizations

In [None]:
# The plain (single-image) loaders are reused here: extract_embeddings only
# needs model.get_embedding, which the triplet wrapper provides.
train_embeddings_tl, train_labels_tl = extract_embeddings(
    train_loader, model, args)
plot_embeddings(
    train_embeddings_tl,
    train_labels_tl,
    title='Train triplet embeddings',
)
test_embeddings_tl, test_labels_tl = extract_embeddings(
    test_loader, model, args)
plot_embeddings(
    test_embeddings_tl,
    test_labels_tl,
    title='Val triplet embeddings',
)

In [None]:
# L2-normalize embeddings: divide every row by its Euclidean norm. The norm is
# clamped to a tiny epsilon so an all-zero embedding stays zero instead of
# turning into NaN through a 0/0 division.
test_embeddings_tl_norm = test_embeddings_tl / np.maximum(
    np.linalg.norm(test_embeddings_tl, axis=-1, keepdims=True), 1e-12)

In [None]:
# De-normalised copies of the test images, in the (unshuffled) order of
# test_loader, used to display the query results.
test_images_raw = get_raw_images(test_loader)

In [None]:
# Run five queries, each on a randomly drawn test image.
for _ in range(5):
    launch_query(test_embeddings_tl_norm, test_images_raw)