## Triplet loss
### Train a linear embedding space with triplet loss

> Adapt this code to train a linear embedding that takes as input the one-hot encoding of a digit (0 to 9) and outputs a 2D embedding. Then train this embedding with a triplet loss so that the digits appear in order in the embedding space. Finally, plot the embedding space with matplotlib.

[Slides](https://olki.loria.fr/cerisara/lexres/2024td1.html#/4/6)

In [None]:
import torch
import pytorch_lightning as pl

class Mod(pl.LightningModule):
    """Linear embedding trained with a triplet margin loss.

    A single ``torch.nn.Linear`` layer projects each input into the
    embedding space. Training minimises
    ``torch.nn.functional.triplet_margin_loss`` over
    (anchor, positive, negative) batches.

    Args:
        in_features: Size of each input sample. The default (1) matches the
            original hard-coded layer; pass e.g. 10 for one-hot digits.
        out_features: Dimension of the embedding space. The default (5)
            matches the original hard-coded layer; pass e.g. 2 for a
            plottable embedding.
        lr: Learning rate handed to AdamW.
        margin: Margin of the triplet loss (1.0 is the PyTorch default).
    """

    def __init__(
        self,
        in_features: int = 1,
        out_features: int = 5,
        lr: float = 1e-3,
        margin: float = 1.0,
    ):
        super().__init__()
        self.lr = lr
        self.margin = margin
        # Attribute name ``W`` is kept so existing code reading ``mod.W``
        # still works.
        self.W = torch.nn.Linear(in_features, out_features)

    def configure_optimizers(self):
        """Return the AdamW optimizer over all parameters of the module."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        """Embed one (anchor, positive, negative) batch and return its loss.

        The loss is also logged as ``train_loss``, aggregated per epoch.
        """
        anc, pos, neg = batch
        # The same linear map embeds all three members of the triplet.
        ea = self.W(anc)
        ep = self.W(pos)
        en = self.W(neg)
        loss = torch.nn.functional.triplet_margin_loss(
            ea, ep, en, margin=self.margin
        )
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

In [None]:
class TripDS(torch.utils.data.Dataset):
    """Synthetic two-class triplet dataset made of 1-element tensors.

    Every item is drawn on the fly as ``torch.randn(1) / 10`` shifted by
    -0.5 for one class and by +0.5 for the other. Even indices anchor on
    the -0.5 class, odd indices anchor on the +0.5 class. The positive
    comes from the anchor's class and the negative from the other one.
    """

    def __init__(self):
        super().__init__()

    def __len__(self):
        """Return the fixed nominal dataset size (1000)."""
        return 1000

    def __getitem__(self, i):
        """Return a freshly sampled ``(anchor, positive, negative)`` tuple."""
        # The parity of the index selects which class the anchor belongs to.
        own, other = (-0.5, 0.5) if i % 2 == 0 else (0.5, -0.5)
        # Draw order (anchor, positive, negative) is kept so seeded runs
        # reproduce the same samples.
        anchor = torch.randn(1) / 10. + own
        positive = torch.randn(1) / 10. + own
        negative = torch.randn(1) / 10. + other
        return anchor, positive, negative

In [None]:
# Build the synthetic triplet dataset and serve it one triplet per batch,
# in index order (no shuffling).
traindata = TripDS()
trainloader = torch.utils.data.DataLoader(
    traindata,
    batch_size=1,
    shuffle=False,
)

# Model to train.
mod = Mod()

# Training metrics are written under logs/ for TensorBoard.
logger = pl.loggers.TensorBoardLogger(save_dir="logs/", flush_secs=1)

# Train for up to 1000 epochs, logging at every step.
trainer = pl.Trainer(
    limit_train_batches=1.0,
    max_epochs=1000,
    log_every_n_steps=1,
    logger=logger,
)
trainer.fit(model=mod, train_dataloaders=trainloader)