## mlflow.pytorch 101
#### Tested with Python 3.9.1 - 64 Bit
##### Ref: https://www.mlflow.org/docs/latest/python_api/mlflow.pytorch.html

In [2]:
import os

import pytorch_lightning as pl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from pytorch_lightning.metrics.functional import accuracy

import mlflow.pytorch
from mlflow.tracking import MlflowClient

In [3]:
class MNISTModel(pl.LightningModule):
    """Single-linear-layer MNIST classifier used to demonstrate MLflow autologging.

    Each input is flattened and passed through one linear layer
    (28 * 28 inputs -> 10 outputs) followed by a ReLU.
    """

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Return the ReLU-activated linear scores for the batch ``x``.

        ``x`` is reshaped to ``(x.size(0), -1)`` before the linear layer, so
        every sample must flatten to 28 * 28 values.
        """
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        """Compute the loss for one batch and log loss and accuracy.

        Args:
            batch: An ``(inputs, labels)`` pair; the labels are used as the
                cross-entropy target.
            batch_nb: Batch index passed by the trainer (unused here).

        Returns:
            The cross-entropy loss tensor for this batch.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # Accuracy compares predicted class indices with the labels; passing
        # the scalar loss here (as the code did before) is meaningless.
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        # Log through the LightningModule's own logging API
        self.log("train_loss", loss, on_epoch=True)
        self.log("acc", acc, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all model parameters (lr=0.02)."""
        return torch.optim.Adam(self.parameters(), lr=0.02)

def print_auto_logged_info(r):
    """Print a summary of what was logged for the MLflow run ``r``.

    Prints, in order: the run id, the artifact paths listed under the
    run's "model" directory, the params, the metrics, and the tags whose
    key does not start with "mlflow.".
    """
    # Keep only tags outside the "mlflow." key prefix
    user_tags = {}
    for key, value in r.data.tags.items():
        if key.startswith("mlflow."):
            continue
        user_tags[key] = value

    run_id = r.info.run_id
    model_entries = MlflowClient().list_artifacts(run_id, "model")
    artifact_paths = [entry.path for entry in model_entries]

    print(f"run_id: {run_id}")
    print(f"artifacts: {artifact_paths}")
    print(f"params: {r.data.params}")
    print(f"metrics: {r.data.metrics}")
    print(f"tags: {user_tags}")

# Settings a reader may want to tune
SEED = 42
BATCH_SIZE = 32
MAX_EPOCHS = 20
PROGRESS_BAR_REFRESH_RATE = 20

# Seed torch's global RNG before the model is built so the initial weights
# (and therefore the logged metrics) are reproducible on a fresh re-run
torch.manual_seed(SEED)

# Initialize our model
mnist_model = MNISTModel()

# Initialize DataLoader from the MNIST training split; the data is
# downloaded into the current working directory if not already present
train_ds = MNIST(
    os.getcwd(),
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

# Initialize a trainer
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    progress_bar_refresh_rate=PROGRESS_BAR_REFRESH_RATE,
)

# Auto log all MLflow entities
mlflow.pytorch.autolog()

# Train the model inside an MLflow run so autologging attaches to it
with mlflow.start_run() as run:
    trainer.fit(mnist_model, train_loader)

# Fetch the auto-logged parameters and metrics
print_auto_logged_info(mlflow.get_run(run_id=run.info.run_id))

0it [00:00, ?it/s]Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to c:\Users\dmangonakis\Desktop\Yelp_Dataset\notebooks\MNIST\raw\train-images-idx3-ubyte.gz
9920512it [00:02, 5007346.98it/s]                             Extracting c:\Users\dmangonakis\Desktop\Yelp_Dataset\notebooks\MNIST\raw\train-images-idx3-ubyte.gz

0it [00:00, ?it/s][ADownloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to c:\Users\dmangonakis\Desktop\Yelp_Dataset\notebooks\MNIST\raw\train-labels-idx1-ubyte.gz

  0%|          | 0/28881 [00:00<?, ?it/s][A

0it [00:00, ?it/s][A[AExtracting c:\Users\dmangonakis\Desktop\Yelp_Dataset\notebooks\MNIST\raw\train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to c:\Users\dmangonakis\Desktop\Yelp_Dataset\notebooks\MNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s][A[A

  1%|          | 16384/1648877 [00:00<00:14, 109960.31it/s][A[A

  6%|▌         | 