# Model tracking - MLflow

[MLflow](https://mlflow.org/docs/latest/) is a library for model tracking, packaging and sharing.

Let's use it to track our training.

First, we need to do the necessary imports:

In [1]:
# Import needed modules
import torch

from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

from torch import nn

from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np

Then we need to set up datasets and data loaders:

In [2]:
# Set up datasets and data loaders

data_dir = "../data"

batch_size = 32

train_dataset = datasets.MNIST(
    data_dir, train=True, download=True, transform=ToTensor()
)
test_dataset = datasets.MNIST(data_dir, train=False, transform=ToTensor())

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Then we'll need to specify a model (in this case a toy multilayer perceptron model):

In [3]:
# Specify model

device = "cuda" if torch.cuda.is_available() else "cpu"

class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(), nn.Linear(28 * 28, 20), nn.ReLU(), nn.Linear(20, 10)
        )

    def forward(self, x):
        return self.layers(x)


model = SimpleMLP().to(device)
print(model)

SimpleMLP(
  (layers): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=20, bias=True)
    (2): ReLU()
    (3): Linear(in_features=20, out_features=10, bias=True)
  )
)


Then we specify the loss criterion and optimizer:

In [4]:
# Specify loss and optimizer

criterion = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(model.parameters())

## Implementing MLflow tracking

To make MLflow tracking work, we need to do the following things:

### Initial setup

In this initial setup we specify the tracking URI (`MLFLOW_TRACKING_URI`). It can be a folder, an SQLite database file or even an [MLflow tracking server](https://mlflow.org/docs/latest/self-hosting/architecture/tracking-server/). In this example we'll use file storage.

After that we set a name for our [experiment](https://mlflow.org/docs/latest/ml/tracking/#runs). In MLflow, each experiment is a grouping of runs and models. A run is, for example, a single training session, while a model is the model trained in that session.
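The kinds of tracking URI mentioned above can be told apart by their scheme. The following is only an illustrative sketch: `describe_tracking_uri` is a hypothetical helper (MLflow does its own URI handling), and the `sqlite:///` and `http://localhost:5000` values are example URIs, not ones used in this notebook.

```python
from urllib.parse import urlparse


def describe_tracking_uri(uri: str) -> str:
    """Roughly describe the storage behind a tracking URI.

    Hypothetical helper for illustration only, not part of MLflow.
    """
    scheme = urlparse(uri).scheme
    if scheme in ("", "file"):
        # A plain path or a file:// URI points at a local directory.
        return "local directory"
    if scheme == "sqlite":
        return "SQLite database file"
    if scheme in ("http", "https"):
        return "tracking server"
    return f"other backend ({scheme})"


for uri in (
    "file:///tmp/mlflow/db",   # the URI used in this example
    "sqlite:///runs.sqlite",   # example SQLite URI (assumption)
    "http://localhost:5000",   # example server address (assumption)
):
    print(f"{uri} -> {describe_tracking_uri(uri)}")
```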

In [5]:
import os
import mlflow

mlflow.set_tracking_uri("file:///tmp/mlflow/db")

experiment_name = "mnist-example"

mlflow.set_experiment(experiment_name)

2025/11/10 23:01:39 INFO mlflow.tracking.fluent: Experiment with name 'mnist-example' does not exist. Creating a new experiment.


<Experiment: artifact_location='file:///tmp/mlflow/db/102422313242706622', creation_time=1762808499456, experiment_id='102422313242706622', last_update_time=1762808499456, lifecycle_stage='active', name='mnist-example', tags={}>

### Finding a model signature

Here we figure out our model's [signature](https://mlflow.org/docs/latest/ml/model/signatures/). The signature describes the form of the model's inputs and outputs and its parameters.

We also keep the `input_example` so that we can use that when we eventually store the model checkpoints.

In [6]:
from mlflow.models import infer_signature

input_example, _ = next(iter(train_loader))

with torch.no_grad():
    # Send example input to device
    input_example = input_example.to(device)
    # Get model output
    output_example = model(input_example)
    # Convert example input and output to numpy arrays
    input_example = input_example.cpu().numpy()
    output_example = output_example.cpu().numpy()

# Infer signature automatically
signature = infer_signature(input_example, output_example)

### Setting tracking during training

To get tracking information, we need to wrap the experiment code in an [mlflow.start_run](https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html#mlflow.start_run) block. The run starts when the block is entered and ends when it is exited.


Within the run, our training parameters can be tracked with [mlflow.log_params](https://mlflow.org/docs/latest/ml/tracking/#start-logging).
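`mlflow.log_params` takes a dictionary of parameter names and values. If a configuration is nested, one option is to flatten it first so that every value gets its own key. This is a minimal sketch; `flatten_params` is a hypothetical helper and the configuration values are made up for illustration.

```python
def flatten_params(params: dict, parent_key: str = "", sep: str = ".") -> dict:
    """Flatten a nested dictionary into a single level with joined keys."""
    flat = {}
    for key, value in params.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            # Recurse into nested dictionaries, carrying the key prefix along.
            flat.update(flatten_params(value, full_key, sep))
        else:
            flat[full_key] = value
    return flat


# Illustrative configuration (values are assumptions, not from the run above).
config = {
    "epochs": 5,
    "batch_size": 32,
    "optimizer": {"name": "AdamW", "lr": 1e-3},
}

flat_config = flatten_params(config)
print(flat_config)
# {'epochs': 5, 'batch_size': 32, 'optimizer.name': 'AdamW', 'optimizer.lr': 0.001}

# Inside a run, the flat dictionary could be passed to mlflow.log_params(flat_config).
```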

Arbitrary artifacts can also be stored. Here we store a summary of our model's structure with [mlflow.log_artifact](https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html?highlight=log_artifact#mlflow.log_artifact).

During the training we record our model's checkpoints with [mlflow.pytorch.log_model](https://mlflow.org/docs/latest/api_reference/python_api/mlflow.pytorch.html#mlflow.pytorch.log_model) and use [mlflow.log_metric](https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html#mlflow.log_metric) to record custom metrics.

In this example we store the model after each epoch, but in practice the model could be stored only when it performs better than the previous ones.
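Storing only improved models could be sketched as follows. `BestMetricTracker` is a hypothetical helper, not part of MLflow, and the accuracies are made-up numbers; in the real training loop the marked line is where `mlflow.pytorch.log_model` would be called.

```python
class BestMetricTracker:
    """Remember the best metric value seen so far (hypothetical helper)."""

    def __init__(self, higher_is_better: bool = True):
        self.higher_is_better = higher_is_better
        self.best = None

    def update(self, value: float) -> bool:
        """Record a value and return True if it is a new best."""
        if self.best is None:
            improved = True
        elif self.higher_is_better:
            improved = value > self.best
        else:
            improved = value < self.best
        if improved:
            self.best = value
        return improved


tracker = BestMetricTracker(higher_is_better=True)

# Made-up accuracies, one per epoch, only for illustration.
epoch_accuracies = [0.88, 0.93, 0.92, 0.95]

saved_epochs = []
for epoch, accuracy in enumerate(epoch_accuracies):
    if tracker.update(accuracy):
        # This is where the checkpoint would be logged in the training loop.
        saved_epochs.append(epoch)

print(saved_epochs)  # [0, 1, 3]
```

For a metric such as loss, where lower is better, the tracker would be created with `higher_is_better=False`.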

In [7]:
from torchinfo import summary

# Train the model
with mlflow.start_run():
    
    num_batches = len(train_loader)
    num_items = len(train_loader.dataset)
    epochs = 5

    params = {
        "epochs": epochs,
        "batch_size": batch_size,
        "loss_function": criterion.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
    }
    
    # Log training parameters.
    mlflow.log_params(params)

    # Log model summary.
    with open("model_summary.txt", "w") as f:
        f.write(str(summary(model)))
    mlflow.log_artifact("model_summary.txt")

    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        total_correct = 0
        for data, target in tqdm(train_loader, total=num_batches):
            # Copy data and targets to the device (GPU if available, otherwise CPU)
            data = data.to(device)
            target = target.to(device)
    
            # Do a forward pass
            outputs = model(data)
    
            # Calculate the loss
            loss = criterion(outputs, target)
            total_loss += loss.item()
    
            # Count number of correct digits
            _, predicted = torch.max(outputs, 1)
            total_correct += (predicted == target).sum().item()
    
            # Backpropagation
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    
        train_loss = total_loss / num_batches
        accuracy = total_correct / num_items
        print(f"Average loss: {train_loss:7f}, accuracy: {accuracy:.2%}")

        mlflow.log_metric("loss", train_loss, step=epoch)
        mlflow.log_metric("accuracy", accuracy, step=epoch)

        # Log model checkpoint
        model_info = mlflow.pytorch.log_model(
            model,
            # Include the epoch in the logged model's name (e.g. SimpleMLP-0)
            name=f"SimpleMLP-{epoch}",
            step=epoch,
            signature=signature,
            input_example=input_example,
            registered_model_name="SimpleMLP",
        )

100%|██████████| 1875/1875 [00:10<00:00, 184.03it/s]


Average loss: 0.440874, accuracy: 88.14%


Successfully registered model 'SimpleMLP'.
Created version '1' of model 'SimpleMLP'.
100%|██████████| 1875/1875 [00:08<00:00, 212.04it/s]


Average loss: 0.245349, accuracy: 92.96%


Registered model 'SimpleMLP' already exists. Creating a new version of this model...
Created version '2' of model 'SimpleMLP'.
100%|██████████| 1875/1875 [00:08<00:00, 213.40it/s]


Average loss: 0.204088, accuracy: 94.21%


Registered model 'SimpleMLP' already exists. Creating a new version of this model...
Created version '3' of model 'SimpleMLP'.
100%|██████████| 1875/1875 [00:08<00:00, 213.43it/s]


Average loss: 0.179658, accuracy: 94.89%


Registered model 'SimpleMLP' already exists. Creating a new version of this model...
Created version '4' of model 'SimpleMLP'.
100%|██████████| 1875/1875 [00:08<00:00, 213.87it/s]


Average loss: 0.163285, accuracy: 95.33%


Registered model 'SimpleMLP' already exists. Creating a new version of this model...
Created version '5' of model 'SimpleMLP'.


## Checking experiments

The data describing our runs is stored under the tracking URI we set earlier (`/tmp/mlflow/db`). We can query it with several functions.

[mlflow.search_experiments](https://mlflow.org/docs/latest/ml/search/search-experiments/) can be used to search for experiments:

In [8]:
mlflow.search_experiments()

[<Experiment: artifact_location='file:///tmp/mlflow/db/102422313242706622', creation_time=1762808499456, experiment_id='102422313242706622', last_update_time=1762808499456, lifecycle_stage='active', name='mnist-example', tags={}>,
 <Experiment: artifact_location='file:///tmp/mlflow/db/0', creation_time=1762808499414, experiment_id='0', last_update_time=1762808499414, lifecycle_stage='active', name='Default', tags={}>]
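The `creation_time` and `last_update_time` fields in this output are Unix timestamps in milliseconds. A small sketch of converting one into a readable UTC time is shown below; `ms_to_datetime` is a hypothetical helper name.

```python
from datetime import datetime, timezone


def ms_to_datetime(timestamp_ms: int) -> datetime:
    """Convert a millisecond Unix timestamp to a timezone-aware UTC datetime."""
    return datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc)


# creation_time of the 'mnist-example' experiment shown above
created = ms_to_datetime(1762808499456)
print(created.isoformat(timespec="seconds"))  # 2025-11-10T21:01:39+00:00
```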

[mlflow.search_runs](https://mlflow.org/docs/latest/ml/search/search-runs/) can be used to find runs:

In [9]:
mlflow.search_runs(
    experiment_names=[experiment_name],
    order_by=["metrics.accuracy DESC"],
)

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.accuracy,metrics.loss,params.batch_size,params.epochs,params.loss_function,params.optimizer,tags.mlflow.source.name,tags.mlflow.user,tags.mlflow.runName,tags.mlflow.source.type
0,87cf4b79e9f44d13ae816ad0991f8ef6,102422313242706622,FINISHED,file:///tmp/mlflow/db/102422313242706622/87cf4...,2025-11-10 21:01:47.930000+00:00,2025-11-10 21:03:44.154000+00:00,0.9533,0.163285,32,5,CrossEntropyLoss,AdamW,/scratch/work/tuomiss1/conda_envs/ml-reproduci...,tuomiss1,ambitious-worm-311,LOCAL


[mlflow.search_logged_models](https://mlflow.org/docs/latest/ml/search/search-models/) can be used to find stored models:

In [10]:
mlflow.search_logged_models()

Unnamed: 0,artifact_location,creation_timestamp,experiment_id,last_updated_timestamp,metrics,model_id,model_type,name,params,source_run_id,status,status_message,tags
0,file:///tmp/mlflow/db/102422313242706622/model...,1762808617086,102422313242706622,1762808624127,"[<Metric: dataset_digest=None, dataset_name=No...",m-395ccf47336b4570b565d77394532446,,SimpleMLP-4,"{'loss_function': 'CrossEntropyLoss', 'epochs'...",87cf4b79e9f44d13ae816ad0991f8ef6,READY,,{'mlflow.source.name': '/scratch/work/tuomiss1...
1,file:///tmp/mlflow/db/102422313242706622/model...,1762808600313,102422313242706622,1762808608287,"[<Metric: dataset_digest=None, dataset_name=No...",m-5b3530b5084d40c691f85821fde4ed8f,,SimpleMLP-3,"{'loss_function': 'CrossEntropyLoss', 'epochs'...",87cf4b79e9f44d13ae816ad0991f8ef6,READY,,{'mlflow.source.name': '/scratch/work/tuomiss1...
2,file:///tmp/mlflow/db/102422313242706622/model...,1762808583149,102422313242706622,1762808591499,"[<Metric: dataset_digest=None, dataset_name=No...",m-dc40f029db8b4717b67398000889daee,,SimpleMLP-2,"{'loss_function': 'CrossEntropyLoss', 'epochs'...",87cf4b79e9f44d13ae816ad0991f8ef6,READY,,{'mlflow.source.name': '/scratch/work/tuomiss1...
3,file:///tmp/mlflow/db/102422313242706622/model...,1762808565945,102422313242706622,1762808574336,"[<Metric: dataset_digest=None, dataset_name=No...",m-49ce3ca923ab47ffbbcfaa73d8861354,,SimpleMLP-1,"{'loss_function': 'CrossEntropyLoss', 'epochs'...",87cf4b79e9f44d13ae816ad0991f8ef6,READY,,{'mlflow.source.name': '/scratch/work/tuomiss1...
4,file:///tmp/mlflow/db/102422313242706622/model...,1762808528192,102422313242706622,1762808557053,"[<Metric: dataset_digest=None, dataset_name=No...",m-ae175ea280f24c4390ed3c209eaf7810,,SimpleMLP-0,"{'loss_function': 'CrossEntropyLoss', 'epochs'...",87cf4b79e9f44d13ae816ad0991f8ef6,READY,,{'mlflow.source.name': '/scratch/work/tuomiss1...


All of these functions support various filters and search functionalities.

For example, we can try to find the model that has the best accuracy:

In [11]:
mlflow.search_logged_models(
    order_by=[
        {"field_name": "metrics.accuracy", "ascending": False}  # Highest accuracy first
    ]
)

Unnamed: 0,artifact_location,creation_timestamp,experiment_id,last_updated_timestamp,metrics,model_id,model_type,name,params,source_run_id,status,status_message,tags
0,file:///tmp/mlflow/db/102422313242706622/model...,1762808617086,102422313242706622,1762808624127,"[<Metric: dataset_digest=None, dataset_name=No...",m-395ccf47336b4570b565d77394532446,,SimpleMLP-4,"{'loss_function': 'CrossEntropyLoss', 'epochs'...",87cf4b79e9f44d13ae816ad0991f8ef6,READY,,{'mlflow.source.name': '/scratch/work/tuomiss1...
1,file:///tmp/mlflow/db/102422313242706622/model...,1762808600313,102422313242706622,1762808608287,"[<Metric: dataset_digest=None, dataset_name=No...",m-5b3530b5084d40c691f85821fde4ed8f,,SimpleMLP-3,"{'loss_function': 'CrossEntropyLoss', 'epochs'...",87cf4b79e9f44d13ae816ad0991f8ef6,READY,,{'mlflow.source.name': '/scratch/work/tuomiss1...
2,file:///tmp/mlflow/db/102422313242706622/model...,1762808583149,102422313242706622,1762808591499,"[<Metric: dataset_digest=None, dataset_name=No...",m-dc40f029db8b4717b67398000889daee,,SimpleMLP-2,"{'loss_function': 'CrossEntropyLoss', 'epochs'...",87cf4b79e9f44d13ae816ad0991f8ef6,READY,,{'mlflow.source.name': '/scratch/work/tuomiss1...
3,file:///tmp/mlflow/db/102422313242706622/model...,1762808565945,102422313242706622,1762808574336,"[<Metric: dataset_digest=None, dataset_name=No...",m-49ce3ca923ab47ffbbcfaa73d8861354,,SimpleMLP-1,"{'loss_function': 'CrossEntropyLoss', 'epochs'...",87cf4b79e9f44d13ae816ad0991f8ef6,READY,,{'mlflow.source.name': '/scratch/work/tuomiss1...
4,file:///tmp/mlflow/db/102422313242706622/model...,1762808528192,102422313242706622,1762808557053,"[<Metric: dataset_digest=None, dataset_name=No...",m-ae175ea280f24c4390ed3c209eaf7810,,SimpleMLP-0,"{'loss_function': 'CrossEntropyLoss', 'epochs'...",87cf4b79e9f44d13ae816ad0991f8ef6,READY,,{'mlflow.source.name': '/scratch/work/tuomiss1...


Unsurprisingly, the best model is the one with the most training.

## Viewing model files

Models can be viewed under the tracking directory (`/tmp/mlflow/db` in this example). Each saved model contains not only model checkpoints, but also model metadata and the inferred Python environment specifications needed to run the model.
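To get a quick overview of what the file store contains, the directory can be listed from Python. This is a minimal sketch that assumes the tracking directory `/tmp/mlflow/db` set earlier; `list_files` is a hypothetical helper, not part of MLflow, and it simply returns nothing if the directory does not exist.

```python
from pathlib import Path


def list_files(root: str, max_depth: int = 3) -> list[str]:
    """Return relative paths of files at most max_depth levels below root."""
    root_path = Path(root)
    if not root_path.is_dir():
        return []
    found = []
    for path in sorted(root_path.rglob("*")):
        relative = path.relative_to(root_path)
        # Keep only regular files that are not nested too deeply.
        if path.is_file() and len(relative.parts) <= max_depth:
            found.append(relative.as_posix())
    return found


# Directory taken from the tracking URI used in this example.
for name in list_files("/tmp/mlflow/db"):
    print(name)
```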

## Launching MLflow UI

MLflow has a UI that can be launched from the command line.

The MLflow UI port is usually not open to the wider world, so you need to do something like
```sh
ssh jupyternode -J loginnode -L 5000:localhost:5000
```
to create an SSH tunnel to the server.
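Whether anything is answering on the forwarded port can be checked from Python on your own machine. This is only a sketch: `port_is_open` is a hypothetical helper, and port 5000 is assumed because it is the port forwarded in the command above.

```python
import socket


def port_is_open(host: str, port: int, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to host:port succeeds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts and unreachable hosts.
        return False


if port_is_open("localhost", 5000):
    print("Port 5000 is reachable; try http://localhost:5000 in a browser.")
else:
    print("Nothing is listening on port 5000 yet.")
```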


In JupyterLab you can start a new Bash terminal and run
```bash
export MLFLOW_TRACKING_URI=file:///tmp/mlflow/db
mlflow ui
```
to open the MLflow UI.

Alternatively, you can launch the MLflow UI on your own machine if you copy the tracking data to your computer or mount the file system on your computer using e.g. sshfs.

CSC and LUMI also provide MLflow as an application in their clusters.

## MLflow Tracking server

MLflow has a [Tracking server](https://mlflow.org/docs/latest/self-hosting/architecture/tracking-server/) that can be used to track inputs from multiple jobs or multi-node setups. It can also be used to view the experiments.

It can be launched with
```bash
export MLFLOW_TRACKING_URI=file:///tmp/mlflow/db
mlflow server
```