# In [1]:
import torch
import pytorch_lightning as pl
import math
import torch.nn.functional as F
import matplotlib.pyplot as plt

# PyTorch dataset of randomly generated (anchor, positive, negative) one-hot triplets
class NTripDS(torch.utils.data.Dataset):
    """Synthetic dataset of random one-hot (anchor, positive, negative) triplets.

    Every access draws three class indices uniformly from ``[0, num_classes)``.
    Of the two non-anchor draws, the one whose index is closer to the anchor's
    index becomes the positive and the other the negative (a tie goes to the
    first of the two draws). Items are re-sampled on every access, so the
    index passed to ``__getitem__`` is ignored.

    Args:
        size: Number of items reported by ``__len__`` (default 1000).
        num_classes: Number of classes, i.e. the one-hot width (default 10).
    """

    def __init__(self, size=1000, num_classes=10):
        super().__init__()
        self.size = size
        self.num_classes = num_classes

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        """Return ``(anchor, positive, negative)`` float tensors of shape ``(1, num_classes)``.

        ``i`` is unused: each call draws a fresh random triplet.
        """
        # Three separate draws, in this order, keep seeded runs reproducible.
        va = torch.randint(0, self.num_classes, (1,))
        vb = torch.randint(0, self.num_classes, (1,))
        vc = torch.randint(0, self.num_classes, (1,))
        anchor = F.one_hot(va, num_classes=self.num_classes).float()
        cand_b = F.one_hot(vb, num_classes=self.num_classes).float()
        cand_c = F.one_hot(vc, num_classes=self.num_classes).float()
        # Compare plain ints instead of relying on implicit tensor -> float conversion.
        a, b, c = va.item(), vb.item(), vc.item()
        if abs(b - a) <= abs(c - a):
            return anchor, cand_b, cand_c
        return anchor, cand_c, cand_b


class Mod(pl.LightningModule):
    """Single linear-layer embedding model trained with a triplet margin loss.

    Maps ``num_classes``-dimensional inputs (one-hot vectors from ``NTripDS``
    in this notebook) to ``emb_dim``-dimensional embeddings.

    Args:
        num_classes: Input feature size (default 10).
        emb_dim: Embedding size (default 2, which makes the result plottable).
        lr: AdamW learning rate (default 1e-3).
    """

    def __init__(self, num_classes=10, emb_dim=2, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.W = torch.nn.Linear(num_classes, emb_dim)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all model parameters."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

    def forward(self, x):
        """Embed ``x`` (last dim ``num_classes``) into ``emb_dim`` features."""
        return self.W(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the triplet margin loss for one (anchor, pos, neg) batch."""
        anc, pos, neg = batch
        # Go through forward() so training and inference share one embedding path.
        loss = F.triplet_margin_loss(self(anc), self(pos), self(neg))
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

def plot(mod, num_classes=10):
    """Scatter-plot the learned embedding of every one-hot class vector.

    Embeds the one-hot vectors for classes ``0 .. num_classes-1`` with ``mod``.
    It plots the first two embedding coordinates and labels each point with
    its class index.

    Args:
        mod: Model mapping ``(num_classes,)`` one-hot rows to embeddings with
            at least two columns.
        num_classes: Number of classes to embed (default 10).
    """
    onehots = F.one_hot(torch.arange(num_classes), num_classes=num_classes).float()
    # Match the model's device; the trainer may have moved the model off the CPU.
    if isinstance(mod, torch.nn.Module):
        param = next(mod.parameters(), None)
        if param is not None:
            onehots = onehots.to(param.device)
    # Pure inference: no autograd graph needed, and .cpu() makes .numpy() safe.
    with torch.no_grad():
        emb = mod(onehots).cpu().numpy()
    xs, ys = emb[:, 0], emb[:, 1]
    for cls, (px, py) in enumerate(zip(xs, ys)):
        plt.annotate(f"{cls}", (px, py))
    plt.scatter(xs, ys)
    plt.show()

# Train the embedding on random triplets, then visualise the learned layout.
traindata = NTripDS()
# The dataset re-samples on every access, so shuffling would add nothing.
trainloader = torch.utils.data.DataLoader(traindata, batch_size=1, shuffle=False)
mod = Mod()
logger = pl.loggers.TensorBoardLogger(save_dir="logs/", flush_secs=1)
trainer = pl.Trainer(limit_train_batches=1.0, max_epochs=3, log_every_n_steps=1, logger=logger)
trainer.fit(model=mod, train_dataloaders=trainloader)

# To view the training curves (the logger above writes under logs/):
#   tensorboard --logdir=logs/

plot(mod)

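# Cell output: ModuleNotFoundError: No module named 'pytorch_lightning'
# (install the package first, e.g. `pip install pytorch-lightning`, then re-run the cell)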