### Ayudantía IIC3697
# Triplet Loss

En esta parte de la ayudantía veremos cómo entrenar un modelo utilizando *triplet loss*. Para esto utilizaremos el set de datos MNIST: dos imágenes deberán quedar más cercanas entre sí si pertenecen a la misma clase.

Pueden ver una guía de cómo realizar este ejercicio de forma más eficiente (como les comenté en la ayudantía) en la siguiente [página](https://omoindrot.github.io/triplet-loss). Tómenla solo como guía, porque el código está en *TensorFlow*.

In [0]:
!pip install tb-nightly tqdm

Collecting tb-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/0f/37/7ddf6bffbc6df5cbb6f2d7d39691b852b238d00bc32f134f546db6e820ee/tb_nightly-1.14.0a20190514-py3-none-any.whl (3.1MB)
[K     |████████████████████████████████| 3.1MB 2.8MB/s 
Installing collected packages: tb-nightly
Successfully installed tb-nightly-1.14.0a20190514


In [0]:
from itertools import product
import random

import torch
from torch.utils.data import DataLoader, Subset
from torch import nn
import torch.optim as optim
from torch.nn.init import kaiming_normal_
from torch.nn.functional import triplet_margin_loss
from torch.utils.tensorboard import SummaryWriter

import torchvision
import torchvision.datasets as datasets
import torchvision.models as resnet
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda

from tqdm import tqdm_notebook as tqdm

import tensorboardcolab
!mkdir -p logs/tensorboard

Using TensorFlow backend.


### Configuración Básica

In [0]:
# Mini-batch sizes. Triplets are mined inside each batch, so a batch needs at
# least two samples of one class and one sample of another class.
batch_size_train = batch_size_test = 16

# Number of training epochs and the triplet margin given to the train/test loops.
n_epochs, margin = 2, 1

# When True, the ResNet-18 backbone starts from ImageNet weights.
use_imagenet_weights = True

# TensorBoard run name; if it is empty, a random one is generated when the
# writer is created.
run_name = 'Ayudantia'

Configuramos el *device* para usar la GPU cuando esté disponible.

In [0]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Dataset y Data Loader

Cargamos el dataset MNIST. Convertiremos las imágenes a tensores de *PyTorch*, las normalizaremos y las pasaremos de escala de grises (1 canal) a RGB (3 canales) replicando el canal.

In [0]:
!mkdir -p data

# Preprocessing applied to every MNIST image: convert to a tensor, normalize
# with (0.1307, 0.3081) (presumably the usual MNIST mean/std -- confirm), then
# replicate the single channel three times so the images have the 3 channels
# an RGB model expects.
transform = Compose([
                ToTensor(),
                Normalize((0.1307,), (0.3081,)),
                Lambda(lambda x: x.repeat(3, 1, 1) ) # 1 channel -> 3 (grayscale to RGB).
             ])


# Training split, downloaded into ./data if it is not there yet.
mnist_trainset = datasets.MNIST(root='./data', train=True, 
                                download=True, transform=transform)
mnist_trainset = Subset(mnist_trainset, range(1024)) # Keep only the first 1024 samples to reduce the data used.
train_loader = DataLoader(mnist_trainset, batch_size=batch_size_train, shuffle=True)

# Test split, used here for validation; also reduced to its first 1024 samples.
mnist_valset = datasets.MNIST(root='./data', train=False, 
                              download=True, transform=transform)
mnist_valset = Subset(mnist_valset, range(1024)) # Keep only the first 1024 samples to reduce the data used.
# NOTE(review): shuffle=True means the batches -- and therefore the triplets
# mined inside each batch -- change on every evaluation, so validation
# metrics are not directly comparable between evaluations.
test_loader = DataLoader(mnist_valset, batch_size=batch_size_test, shuffle=True)

  0%|          | 0/9912422 [00:00<?, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


9920512it [00:00, 24840020.87it/s]                            


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz


32768it [00:00, 425158.78it/s]
0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


1654784it [00:00, 6903327.05it/s]                           
8192it [00:00, 140549.60it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


### Modelo

Cargamos el modelo (ResNet-18); si se requieren pesos preentrenados, utilizamos los de *ImageNet*.

In [0]:
# ResNet-18 from torchvision (torchvision.models is imported under the alias
# `resnet`). ImageNet weights are loaded when use_imagenet_weights is True.
# NOTE(review): newer torchvision releases replace `pretrained=` with
# `weights=`; confirm against the installed version.
model = resnet.resnet18(pretrained=use_imagenet_weights)

# Replace the final classification layer with an identity, so the network
# outputs the features that used to feed the classifier (the embedding)
# instead of class scores.
model.fc = nn.Identity()

# Move the model to the selected device. As the last expression of the cell,
# this also displays the architecture.
model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/checkpoints/resnet18-5c106cde.pth
100%|██████████| 46827520/46827520 [00:00<00:00, 69582032.22it/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

### Optimizador

Creamos el optimizador: utilizaremos descenso de gradiente estocástico (SGD) con momentum.

In [0]:
# SGD with momentum over all model parameters. The backbone is included, so
# pretrained weights (if any) are fine-tuned as well.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

### Tensorboard

A partir de la versión 1.1 de *PyTorch* existe compatibilidad con esta herramienta. Configuraremos una corrida en *TensorBoard* para poder visualizar distintas métricas durante el entrenamiento.

In [0]:
import secrets

# Name the TensorBoard run: keep the configured name, or fall back to a random
# 32-character hex token when it is empty.
if not run_name:
  run_name = secrets.token_hex(16)
writer = SummaryWriter(log_dir=f'logs/tensorboard/{run_name}')

### Entrenamiento

Definimos varias funciones auxiliares para entrenar nuestro modelo y evaluarlo según corresponda. De esta manera nuestro código se mantiene ordenado.

In [0]:
def is_valid_triple(anchor, positive, negative):
  """Return True when three labels form a valid (anchor, positive, negative) triplet.

  The triplet is valid when the anchor and the positive share a label and the
  negative has a different one. Whether anchor and positive are the *same
  sample* is deliberately not checked here; callers must exclude that case
  themselves (e.g. by comparing batch indices).

  Every argument goes through ``int()``, so plain ints and one-element
  tensors are both accepted.
  """
  anchor_label, positive_label, negative_label = (
      int(label) for label in (anchor, positive, negative))
  if anchor_label != positive_label:
    return False
  return anchor_label != negative_label


def iter_dataloader(data_loader, device):
  """Iterate over ``data_loader`` with a tqdm progress bar, moving batches to ``device``.

  Yields:
    ``(iteration, (images, targets))`` where ``iteration`` counts from 1 and
    both tensors have already been transferred with ``.to(device)``.
  """
  n_batches = len(data_loader)
  numbered_batches = enumerate(data_loader, start=1)
  for step, (batch_images, batch_targets) in tqdm(numbered_batches, total=n_batches):
    yield step, (batch_images.to(device), batch_targets.to(device))


def get_valid_triplets(embeddings, targets, max_triplet_number=2**10):
  """Mine every valid triplet in a batch and return a random subset of them.

  A triplet ``(a, p, n)`` of batch indices is valid when ``a != p``,
  ``targets[a] == targets[p]`` and ``targets[a] != targets[n]``. All valid
  triplets are enumerated in lexicographic ``(a, p, n)`` order, shuffled with
  ``random.shuffle`` and truncated to ``max_triplet_number``.

  The validity mask is built with tensor broadcasting rather than a Python
  loop over all B**3 index combinations that called ``int()`` on every label
  (one host/device sync each). The enumeration order and the use of
  ``random.shuffle`` are kept, so the selected triplets are the same as with
  the loop-based approach for a given ``random`` state.

  Args:
    embeddings: per-sample embeddings, one row per sample along dim 0.
    targets: class labels, one per row of ``embeddings``.
    max_triplet_number: maximum number of triplets returned.

  Returns:
    ``(anchor, positive, negative)``: tensors holding the selected rows of
    ``embeddings``, one row per triplet. Gradients flow back to ``embeddings``.

  Raises:
    ValueError: if the batch contains no valid triplet (e.g. every sample has
      the same label).
  """
  labels = torch.as_tensor(targets, device=embeddings.device).reshape(-1)
  batch_size = labels.size(0)

  # same_label[i, j] is True when samples i and j share a class.
  same_label = labels.unsqueeze(0) == labels.unsqueeze(1)
  # distinct[i, j] is True when i and j are different samples.
  index = torch.arange(batch_size, device=labels.device)
  distinct = index.unsqueeze(0) != index.unsqueeze(1)

  # valid[a, p, n]: (a, p) is a positive pair of distinct samples and n is a
  # negative for a.
  valid = (same_label & distinct).unsqueeze(2) & (~same_label).unsqueeze(1)
  # nonzero() lists indices in row-major, i.e. lexicographic (a, p, n), order.
  triplet_indices = valid.nonzero()
  n_valid = triplet_indices.size(0)
  if n_valid == 0:
    raise ValueError('The batch does not contain any valid triplet: it needs at '
                     'least two samples of one class and one of another class.')

  # Shuffling a list of positions with the stdlib RNG gives the same
  # permutation as shuffling a list of triplets of the same length.
  order = list(range(n_valid))
  random.shuffle(order)
  order = torch.tensor(order[:max_triplet_number], dtype=torch.long,
                       device=triplet_indices.device)
  selected = triplet_indices[order]

  return (embeddings[selected[:, 0]],
          embeddings[selected[:, 1]],
          embeddings[selected[:, 2]])
  
  
def forward_pass(images_or_embeddings, targets, margin):
  """Compute the triplet loss and per-triplet predictions for one batch.

  Args:
    images_or_embeddings: either a 4-D batch of images, which is embedded
      with the notebook-level ``model``, or an already computed batch of
      embeddings (any other number of dimensions).
    targets: class labels of the batch samples.
    margin: triplet margin. It is used both by the loss and to decide whether
      a triplet counts as correctly separated.

  Returns:
    ``(loss, preds)``: the scalar triplet margin loss over the triplets
    mined by ``get_valid_triplets``, and a boolean tensor that is True for
    every triplet whose negative is farther from the anchor than the positive
    by more than ``margin``.
  """
  if images_or_embeddings.ndimension() == 4:
    # NOTE: relies on the global ``model``; it is not an argument.
    embeddings = model(images_or_embeddings)
  else:
    embeddings = images_or_embeddings

  anchor, positive, negative = get_valid_triplets(embeddings, targets)

  # Pass ``margin`` explicitly so the loss uses the same margin as ``preds``;
  # otherwise triplet_margin_loss silently falls back to its default of 1.0.
  loss = triplet_margin_loss(anchor, positive, negative, margin=margin)

  positive_dist = torch.norm(anchor - positive, dim=1)
  negative_dist = torch.norm(anchor - negative, dim=1)
  preds = (negative_dist - positive_dist) > margin

  return loss, preds


def backward_pass(loss, optimizer):
  """Take one optimization step: reset gradients, backpropagate ``loss``, update."""
  # Order matters: old gradients are cleared before backward() accumulates
  # new ones, and step() must see the fresh gradients.
  for stage in (optimizer.zero_grad, loss.backward, optimizer.step):
    stage()

  
def log_train_scalar_tb(writer, mode, epoch_loss, epoch_acc, global_step):
  """Write an epoch's loss and accuracy for ``mode`` as TensorBoard scalars.

  The tags are ``'<mode> Epoch Loss'`` and ``'<mode> Epoch Accuracy'``,
  written in that order and both recorded at ``global_step``.
  """
  metrics = (('Loss', epoch_loss), ('Accuracy', epoch_acc))
  for metric_name, value in metrics:
    writer.add_scalar(f'{mode} Epoch {metric_name}', value, global_step=global_step)
  
  
def log_training(writer, mode, epoch_loss, epoch_acc, epoch, global_step):
  """Report one epoch's metrics to TensorBoard and print a one-line summary."""
  log_train_scalar_tb(writer, mode, epoch_loss, epoch_acc, global_step)

  summary = f'[{mode}] Epoch {epoch} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}'
  print(summary)
  
  
def log_embeddings(writer, model, data_loader, global_step=0, tag='Test_Image_Embeddings'):
  """Embed every sample of ``data_loader`` and send the result to TensorBoard's projector.

  The model is put in eval mode and run without gradient tracking. Batches
  are moved with the notebook-level ``device`` (through ``iter_dataloader``).
  Images, labels and embeddings are gathered on the CPU, so the whole
  dataset is not held in accelerator memory. Labels are passed to the
  projector as strings.

  Args:
    writer: TensorBoard ``SummaryWriter`` receiving the embeddings.
    model: network mapping an image batch to an embedding batch.
    data_loader: loader yielding ``(images, labels)`` batches.
    global_step: step recorded with the embeddings.
    tag: name of the embedding collection in TensorBoard.
  """
  model.eval()

  with torch.no_grad():
    all_embeddings = []
    all_labels = []
    all_images = []
    for _, (images, labels) in iter_dataloader(data_loader, device):
      all_embeddings.append(model(images).cpu())
      all_images.append(images.cpu())
      all_labels.append(labels.cpu())

    all_images = torch.cat(all_images)
    all_embeddings = torch.cat(all_embeddings)
    # One bulk tolist() instead of an int() call (and device sync) per label.
    metadata = [str(int(label)) for label in torch.cat(all_labels).tolist()]

    writer.add_embedding(all_embeddings, metadata=metadata,
                         label_img=all_images, global_step=global_step, tag=tag)

    
def do_train(epoch, model, data_loader, optimizer, device, margin):
  """Train ``model`` for one epoch on in-batch mined triplets and log the results.

  For each batch, the ``model`` argument computes the embeddings, triplets
  are mined inside the batch, and one optimizer step is taken.
  - Loss: average over triplets (each batch's mean loss weighted by its
    triplet count).
  - Accuracy: fraction of triplets whose negative is farther from the anchor
    than the positive by more than ``margin``.
  If the loader yields no triplets at all, both metrics are reported as NaN.

  Results are printed and written to the notebook-level TensorBoard
  ``writer`` at ``global_step = epoch * len(data_loader)``.
  """
  running_loss = 0.0
  running_corrects = 0
  triplet_count = 0
  iters_per_epoch = len(data_loader)
  mode = 'TRAIN'

  model.train()

  for _, (images, targets) in iter_dataloader(data_loader, device):
    # Embed with the model passed in; forward_pass would otherwise use the
    # global ``model`` for 4-D image batches, ignoring this argument.
    embeddings = model(images)
    loss, preds = forward_pass(embeddings, targets, margin)
    backward_pass(loss, optimizer)

    n_valid_triplets = preds.size(0)
    triplet_count += n_valid_triplets
    running_loss += loss.item() * n_valid_triplets
    running_corrects += int(preds.sum().item())

  # Avoid ZeroDivisionError when the loader is empty.
  if triplet_count:
    epoch_loss = running_loss / triplet_count
    epoch_acc = running_corrects / triplet_count
  else:
    epoch_loss = epoch_acc = float('nan')

  global_step = epoch * iters_per_epoch
  log_training(writer, mode, epoch_loss, epoch_acc, epoch, global_step)
  
  
def do_test(epoch, model, data_loader, device, margin):
  """Evaluate ``model`` on in-batch mined triplets for one pass over ``data_loader``.

  The model runs in eval mode without gradient tracking.
  - Loss: average over triplets (each batch's mean loss weighted by its
    triplet count).
  - Accuracy: fraction of triplets whose negative is farther from the anchor
    than the positive by more than ``margin``.
  If the loader yields no triplets at all, both metrics are reported as NaN.

  Results are printed and written to the notebook-level TensorBoard
  ``writer`` at ``global_step = epoch * len(data_loader)``.
  """
  running_loss = 0.0
  running_corrects = 0
  triplet_count = 0
  iters_per_epoch = len(data_loader)
  mode = 'VAL'

  model.eval()

  with torch.no_grad():
    for _, (images, targets) in iter_dataloader(data_loader, device):
      embeddings = model(images)
      loss, preds = forward_pass(embeddings, targets, margin)

      n_valid_triplets = preds.size(0)
      triplet_count += n_valid_triplets
      running_loss += loss.item() * n_valid_triplets
      running_corrects += int(preds.sum().item())

  # Avoid ZeroDivisionError when the loader is empty.
  if triplet_count:
    epoch_loss = running_loss / triplet_count
    epoch_acc = running_corrects / triplet_count
  else:
    epoch_loss = epoch_acc = float('nan')

  global_step = epoch * iters_per_epoch
  log_training(writer, mode, epoch_loss, epoch_acc, epoch, global_step)

In [0]:
# Baseline validation of the untrained embedding (reported as epoch 0).
do_test(0, model, test_loader, device, margin)
# Then alternate one training epoch with one validation pass.
for epoch in range(1, n_epochs + 1):
  do_train(epoch, model, train_loader, optimizer, device, margin)
  do_test(epoch, model, test_loader, device, margin)

HBox(children=(IntProgress(value=0, max=64), HTML(value='')))


[VAL] Epoch 0 Loss: 0.5972 Acc: 0.7222


HBox(children=(IntProgress(value=0, max=64), HTML(value='')))


[TRAIN] Epoch 1 Loss: 1.1619 Acc: 0.7019


HBox(children=(IntProgress(value=0, max=64), HTML(value='')))


[VAL] Epoch 1 Loss: 0.5608 Acc: 0.8528


HBox(children=(IntProgress(value=0, max=64), HTML(value='')))


[TRAIN] Epoch 2 Loss: 0.3501 Acc: 0.8893


HBox(children=(IntProgress(value=0, max=64), HTML(value='')))


[VAL] Epoch 2 Loss: 0.2041 Acc: 0.9422


### Embeddings en Tensorboard

In [0]:
# Send the embeddings of the validation samples, computed with the trained
# model, to TensorBoard; the last epoch number is used as the global step.
log_embeddings(writer, model, test_loader, global_step=n_epochs)

HBox(children=(IntProgress(value=0, max=64), HTML(value='')))




In [0]:
# Close the TensorBoard writer once all logging is done.
writer.close()

### Visualización de Tensorboard en Colab

In [0]:
%load_ext tensorboard

In [0]:
%tensorboard --logdir logs/tensorboard 