# Modular coding - Lightning & MLflow

MLflow has good support for PyTorch Lightning. Let's explore that a bit.

First we need the usual imports.

In [1]:
# Import needed modules
import torch

from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

from torch import nn

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import lightning as L

## Data handling - LightningDataModule

We will again define the LightningDataModule.

In [2]:
class MNISTDataModule(L.LightningDataModule):
    """LightningDataModule serving the MNIST training and test splits.

    Parameters
    ----------
    data_dir : str
        Directory where the MNIST files are stored (downloaded there if missing).
    batch_size : int
        Batch size used by both the training and the test dataloader.
    """

    def __init__(self, data_dir="../data", batch_size=32):
        # In init-function you can set arguments like data paths
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage):
        # setup-function is used to specify the datasets for the given stage
        if stage == "fit":
            self.train_dataset = datasets.MNIST(
                self.data_dir, train=True, download=True, transform=ToTensor()
            )
        if stage == "test":
            # download=True so that testing also works on a fresh data_dir,
            # even if fitting has never been run
            self.test_dataset = datasets.MNIST(
                self.data_dir, train=False, download=True, transform=ToTensor()
            )

    def train_dataloader(self):
        # train_dataloader specifies how to set up a training dataloader;
        # the training data is shuffled
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def test_dataloader(self):
        # test_dataloader specifies how to set up a test dataloader.
        # It must use the held-out test split (not the training split) and a
        # fixed order.
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

## Model writing - LightningModule

Let's again define the LightningModule.

In [3]:
class SimpleMLP(L.LightningModule):
    """Small multilayer perceptron classifying 28x28 MNIST images into 10 digits.

    Parameters
    ----------
    hidden_size : int
        Width of the single hidden layer.
    """

    def __init__(self, hidden_size=20):
        # Init is done similar to nn.Module
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 10),
        )
        # We specify loss function in the module as well
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        # Forward is done similar to nn.Module
        return self.layers(x)

    def _shared_step(self, batch, prefix):
        """Compute loss and accuracy for one batch and log them.

        The metrics are logged as ``{prefix}_loss`` and ``{prefix}_accuracy``,
        aggregated per epoch.

        Returns
        -------
        torch.Tensor
            The batch loss.
        """
        data, target = batch
        outputs = self(data)

        # Calculate the loss
        loss = self.loss(outputs, target)

        # Fraction of correctly predicted digits in this batch
        predicted = outputs.argmax(dim=1)
        batch_size = outputs.shape[0]
        accuracy = (predicted == target).sum().item() / batch_size

        # batch_size is passed explicitly so the epoch-level mean is weighted by
        # the true batch size (the last batch of an epoch may be smaller)
        self.log(
            f"{prefix}_loss", loss, on_epoch=True, on_step=False, batch_size=batch_size
        )
        self.log(
            f"{prefix}_accuracy",
            accuracy,
            on_epoch=True,
            on_step=False,
            batch_size=batch_size,
        )
        return loss

    def training_step(self, batch):
        # training_step specifies how data is fed into the model and how the
        # loss is calculated; Lightning backpropagates the returned loss
        return self._shared_step(batch, "training")

    def test_step(self, batch):
        # test_step evaluates a batch and logs test metrics; returns the loss
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        # configure_optimizers-function specifies how the optimizer is created
        return torch.optim.AdamW(self.parameters())

## Logging - MLflow autologger

This time, let's use MLflow's [autologging feature](https://mlflow.org/docs/latest/ml/tracking/autolog#autolog-pytorch), which supports automatic logging from PyTorch Lightning:

In [4]:
import os
import mlflow

# Where MLflow stores its runs. An MLFLOW_TRACKING_URI environment variable
# takes precedence; otherwise a local file store under /tmp is used.
tracking_uri = os.environ.get("MLFLOW_TRACKING_URI", "file:///tmp/mlflow/db")
mlflow.set_tracking_uri(tracking_uri)

experiment_name = "mnist-lightning"

# Make this the active experiment (MLflow creates it if it does not exist yet)
mlflow.set_experiment(experiment_name)

# Autolog Lightning training; keep a checkpoint for every save, not only the
# "best" one according to the monitored metric
mlflow.pytorch.autolog(checkpoint_save_best_only=False)

## Training - Trainer

We specify the Trainer like before:

In [5]:
from lightning.pytorch.callbacks import TQDMProgressBar

In [6]:
# Seed Python, NumPy and PyTorch (and dataloader workers) so that weight
# initialization and data shuffling are reproducible between runs
SEED = 42
L.seed_everything(SEED, workers=True)

model = SimpleMLP()
datamodule = MNISTDataModule()

trainer = L.Trainer(
    max_epochs=5,
    # Refresh the progress bar only every 100 batches to keep the output short
    callbacks=[TQDMProgressBar(refresh_rate=100)],
)

/scratch/work/tuomiss1/conda_envs/ml-reproducibility/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/work/tuomiss1/conda_envs/ml-reproducibility ...
INFO:pytorch_lightning.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


This time, however, we wrap the `trainer.fit` call inside an `mlflow.start_run` context:

In [7]:
# Everything autologged while fitting is attached to this MLflow run
with mlflow.start_run() as run:
    trainer.fit(model=model, datamodule=datamodule)

/scratch/work/tuomiss1/conda_envs/ml-reproducibility/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/work/tuomiss1/conda_envs/ml-reproducibility ...

  | Name   | Type             | Params | Mode 
----------------------------------------------------
0 | layers | Sequential       | 15.9 K | train
1 | loss   | CrossEntropyLoss | 0      | train
----------------------------------------------------
15.9 K    Trainable params
0         Non-trainable params
15.9 K    Total params
0.064     Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode


Epoch 4: 100%|██████████| 1875/1875 [00:11<00:00, 160.95it/s, v_num=4]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 1875/1875 [00:11<00:00, 160.37it/s, v_num=4]




## Examining the output

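We can see that MLflow automatically recorded plenty of information for each run: metrics (e.g. training loss and accuracy), parameters (e.g. the optimizer settings) and tags (e.g. the latest checkpoint):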

In [8]:
# One row per run of this experiment, returned as a pandas DataFrame
mlflow.search_runs(
    experiment_names=[experiment_name],
    output_format="pandas",
)

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.training_accuracy,metrics.training_loss,params.differentiable,params.capturable,...,params.maximize,params.fused,params.weight_decay,tags.mlflow.latest_checkpoint_artifact,tags.mlflow.user,tags.mlflow.source.name,tags.Mode,tags.mlflow.runName,tags.mlflow.source.type,tags.mlflow.autologging
0,813832b599c34bf38414eaa025b6845d,747633508105639431,FINISHED,file:///tmp/mlflow/db/747633508105639431/81383...,2025-11-10 22:03:42.797000+00:00,2025-11-10 22:04:49.251000+00:00,0.950533,0.168893,False,False,...,False,,0.01,checkpoints/epoch_4/checkpoint.pth,tuomiss1,/scratch/work/tuomiss1/conda_envs/ml-reproduci...,training,dashing-hen-904,LOCAL,
1,a76fc81ba9f441c09c58c0e2d3b4a056,747633508105639431,FINISHED,file:///tmp/mlflow/db/747633508105639431/a76fc...,2025-11-10 21:58:47.133000+00:00,2025-11-10 21:59:50.084000+00:00,0.952867,0.164744,False,False,...,False,,0.01,checkpoints/epoch_4/checkpoint.pth,tuomiss1,/scratch/work/tuomiss1/conda_envs/ml-reproduci...,training,fortunate-loon-636,LOCAL,
2,cb6348373b004ecc98032a51a769b8a2,747633508105639431,FAILED,file:///tmp/mlflow/db/747633508105639431/cb634...,2025-11-10 21:58:07.339000+00:00,2025-11-10 21:58:07.346000+00:00,,,,,...,,,,,tuomiss1,/scratch/work/tuomiss1/conda_envs/ml-reproduci...,,chill-donkey-189,LOCAL,
3,5d9ac19da3eb438194fac049234036af,747633508105639431,FINISHED,file:///tmp/mlflow/db/747633508105639431/5d9ac...,2025-11-10 21:49:05.178000+00:00,2025-11-10 21:49:12.203000+00:00,,,,,...,,,,,tuomiss1,/scratch/work/tuomiss1/conda_envs/ml-reproduci...,,peaceful-fowl-174,LOCAL,
4,3b2bedf65917449e8da4c3f124f0304b,747633508105639431,FINISHED,file:///tmp/mlflow/db/747633508105639431/3b2be...,2025-11-10 21:36:56.192000+00:00,2025-11-10 21:38:02.178000+00:00,0.95225,0.167361,False,False,...,False,,0.01,checkpoints/epoch_4/checkpoint.pth,tuomiss1,/scratch/work/tuomiss1/conda_envs/ml-reproduci...,training,traveling-grouse-320,LOCAL,pytorch
