<a href="https://colab.research.google.com/github/mlop-ai/mlop/blob/main/examples/lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h1 align="center" style="font-family: Inter, sans-serif; font-style: normal; font-weight: 700; font-size: 72px">m:lop</h1>


In [None]:
# Install mlop (with its optional "full" extras) plus the libraries this example uses.
%pip install -Uq "mlop[full]" lightning torchmetrics onnx
# Alternative: install the development version of mlop straight from GitHub.
# %pip install "mlop[full] @ git+https://github.com/mlop-ai/mlop.git"
# Alternative for local development: prepend the parent of the current working directory
# to sys.path. Note that "__file__" is a string literal here, so dirname("__file__") is ''.
# import sys; import os; sys.path.insert(0, os.path.dirname(os.path.abspath(os.path.dirname("__file__"))))
import mlop
from mlop.compat import MLOPLogger

# Log in to mlop before any run is started (presumably reads or prompts for credentials).
mlop.login()

In [None]:
import numpy as np
import torch
from torch.nn import functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import lightning.pytorch as pl
import torchmetrics

# Seed Python, NumPy and torch RNGs so the run is reproducible (paired with deterministic=True).
pl.seed_everything(seed=1337)

## Define the model, callbacks, and data module

In [None]:
class LitMLP(pl.LightningModule):
    """Three-layer MLP classifier that logs metrics and ONNX exports to an mlop logger.

    ``forward`` returns log-probabilities (``log_softmax``), which ``loss`` feeds to
    ``F.nll_loss``. After every validation epoch and after testing, the model is
    exported to ONNX and logged through ``self.logger.experiment`` as an artifact.

    Args:
        in_dims: Shape of ONE input sample, without the batch dimension
            (e.g. ``(1, 28, 28)`` for MNIST). Inputs are flattened to
            ``prod(in_dims)`` features before the first layer.
        n_classes: Number of output classes.
        n_layer_1: Width of the first hidden layer.
        n_layer_2: Width of the second hidden layer.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, in_dims, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-4):
        super().__init__()

        # np.prod returns a NumPy scalar; nn.Linear should get a plain Python int.
        self.layer_1 = nn.Linear(int(np.prod(in_dims)), n_layer_1)
        self.layer_2 = nn.Linear(n_layer_1, n_layer_2)
        self.layer_3 = nn.Linear(n_layer_2, n_classes)

        # Records the __init__ arguments in self.hparams ("lr" and "in_dims" are read below).
        self.save_hyperparameters()
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)
        self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)
        # Per-batch validation logits, collected for the end-of-epoch histogram.
        self.validation_step_outputs = []

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, n_classes)`` for batch-first ``x``."""
        batch_size = x.size(0)

        x = x.view(batch_size, -1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        x = F.log_softmax(x, dim=1)

        return x

    def loss(self, xs, ys):
        """Return ``(log_probs, nll_loss)`` for inputs ``xs`` and integer targets ``ys``."""
        logits = self(xs)
        loss = F.nll_loss(logits, ys)
        return logits, loss

    def training_step(self, batch, batch_idx):
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        preds = torch.argmax(logits, 1)

        self.log("train/loss", loss, on_epoch=True)
        self.train_acc(preds, ys)
        self.log("train/acc", self.train_acc, on_epoch=True)

        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])

    def test_step(self, batch, batch_idx):
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        preds = torch.argmax(logits, 1)

        self.test_acc(preds, ys)
        self.log("test/loss_epoch", loss, on_step=False, on_epoch=True)
        self.log("test/acc_epoch", self.test_acc, on_step=False, on_epoch=True)

    def on_test_epoch_end(self):  # args are defined as part of pl API
        """Export the final model to ONNX and log it as a model artifact."""
        model_filename = "model_final.onnx"
        self.to_onnx(model_filename, self._dummy_input(), export_params=True)
        self._log_model_artifact(model_filename)

    def on_validation_epoch_start(self):
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx):
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        preds = torch.argmax(logits, 1)
        self.valid_acc(preds, ys)

        self.log("valid/loss_epoch", loss)  # default on val/test is on_epoch only
        self.log("valid/acc_epoch", self.valid_acc)
        # Detach so the stored tensors never hold on to an autograd graph.
        self.validation_step_outputs.append(logits.detach())

        return logits

    def on_validation_epoch_end(self):
        """Export an ONNX snapshot and log a histogram of this epoch's validation logits."""
        model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
        torch.onnx.export(self, self._dummy_input(), model_filename, opset_version=11)
        self._log_model_artifact(model_filename)

        outputs = self.validation_step_outputs
        if outputs:  # torch.cat raises on an empty list
            flattened_logits = torch.flatten(torch.cat(outputs))
            self.logger.experiment.log(
                {
                    "valid/logits": mlop.Histogram(flattened_logits.to("cpu")),
                    "global_step": self.global_step,
                }
            )
        # Release the collected tensors instead of keeping them until the next epoch.
        self.validation_step_outputs = []

    def _dummy_input(self):
        """Return a zero tensor of shape ``(1, *in_dims)`` for tracing the ONNX export.

        The leading batch dimension is required: ``forward`` treats the first
        dimension as the batch size, so an input shaped plain ``in_dims`` would only
        flatten correctly when ``in_dims[0] == 1``.
        """
        return torch.zeros((1, *self.hparams["in_dims"]), device=self.device)

    def _log_model_artifact(self, model_filename):
        """Log the file ``model_filename`` through the attached logger as an mlop model artifact."""
        # TODO: remove legacy compat
        artifact = mlop.Artifact(name="model.ckpt", type="model")
        artifact.add_file(model_filename)
        self.logger.experiment.log({"model": artifact})


class ImagePredictionLogger(pl.Callback):
    """Lightning callback that logs predictions on a fixed batch of images.

    At the end of each validation epoch the module is run on the stored images,
    and one captioned ``mlop.Image`` per sample is logged under ``"examples"``.

    Args:
        val_samples: ``(images, labels)`` pair, e.g. one batch from a DataLoader.
        num_samples: Number of leading samples of that batch to keep.
    """

    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        images, labels = val_samples
        self.val_imgs = images[:num_samples]
        self.val_labels = labels[:num_samples]

    def on_validation_epoch_end(self, trainer, pl_module):
        # NOTE(review): images are handed to mlop.Image on the module's device;
        # confirm mlop.Image accepts non-CPU tensors.
        images = self.val_imgs.to(device=pl_module.device)
        predicted = pl_module(images).argmax(1)
        examples = []
        for image, guess, truth in zip(images, predicted, self.val_labels):
            examples.append(mlop.Image(image, caption=f"Pred:{guess}, Label:{truth}"))
        payload = {"examples": examples, "global_step": trainer.global_step}
        trainer.logger.experiment.log(payload)


class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule for MNIST with a reproducible train/validation split.

    Args:
        data_dir: Directory MNIST is downloaded to and read from.
        batch_size: Training batch size; validation and test loaders use 10x this.
        val_size: Number of images held out of the 60k training set for validation.
        split_seed: Seed of the dedicated generator used for the train/val split.
            Lightning calls ``setup()`` again inside ``trainer.fit()``, so an
            unseeded split would be redrawn and change which images are held out.
    """

    def __init__(self, data_dir=".lightning/", batch_size=128, val_size=5000, split_seed=1337):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_size = val_size
        self.split_seed = split_seed
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def prepare_data(self):
        """Download the MNIST train and test sets (no state is assigned here)."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` (all of them when ``stage`` is None)."""
        if stage in ("fit", "validate", None):
            mnist = MNIST(self.data_dir, train=True, transform=self.transform)
            # Seeded generator: every setup() call yields the same split, independent
            # of how far the global RNG has advanced in between.
            generator = torch.Generator().manual_seed(self.split_seed)
            self.mnist_train, self.mnist_val = random_split(
                mnist, [len(mnist) - self.val_size, self.val_size], generator=generator
            )
        if stage in ("test", None):
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        # Reshuffle every epoch; still reproducible under pl.seed_everything.
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=10 * self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=10 * self.batch_size)


# Build the data module eagerly (download, then split) so that one batch of
# validation images is available for ImagePredictionLogger before training starts.
mnist = MNISTDataModule()
mnist.prepare_data()
mnist.setup()  # stage=None: builds the train/val split and the test set
# NOTE(review): trainer.fit() calls setup("fit") again; these samples stay held out
# only if the train/val split is reproducible across setup() calls — confirm.
samples = next(iter(mnist.val_dataloader()))  # (images, labels) of the first val batch

## Start training with **mlop**

In [None]:
# Route Lightning's logging to mlop and attach the prediction-image callback.
mlop_logger = MLOPLogger(name=".lightning")
trainer = pl.Trainer(
    logger=mlop_logger,
    log_every_n_steps=50,
    max_epochs=5,
    deterministic=True,
    callbacks=[ImagePredictionLogger(samples)],
)

model = LitMLP(in_dims=(1, 28, 28))
try:
    trainer.fit(model, mnist)
    # Pass the model explicitly: with ckpt_path=None and no model argument,
    # Lightning reloads the best checkpoint from fit() rather than using the
    # in-memory weights of `model`.
    trainer.test(model, datamodule=mnist, ckpt_path=None)
finally:
    # Close the mlop run even if training or testing raises.
    mlop.finish()