# MNIST - notebook training


Train an MNIST classification model built with PyTorch Lightning, and log the training, validation, and test metrics with Amazon SageMaker Experiments.


---
This notebook is designed to run in Amazon SageMaker Studio with `Python 3 (PyTorch 1.12 Python 3.8 CPU Optimized)`, and has also been tested with `Python 3 (PyTorch 1.12 Python 3.8 GPU Optimized)`.

---

In [None]:
# %%capture
# Uncomment these lines to install the library versions this notebook was tested with.
# %pip install -U "pytorch-lightning==1.8.3"
# %pip install -U "sagemaker >= 2.123.0"

In [None]:
import os
from code.data_modules import MNISTDataModule
from code.models import MNISTModel
from code.sm_logger import SmLogger

import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar
from sagemaker.experiments import Run
from sagemaker.utils import name_from_base

In [None]:
# Workaround for a problem with the progress bar during the validation step:
# disable the separate validation bar so only the training bar is shown.
class LitProgressBar(TQDMProgressBar):
    def init_validation_tqdm(self):
        bar = super().init_validation_tqdm()
        bar.disable = True
        return bar

## Model and Dataloader definition

The model and the DataModule are defined in the `models.py` and `data_modules.py` scripts in the `code` folder.

In [None]:
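model = MNISTModel()
dm = MNISTDataModule(
    train="data",  # path to the training data
    test="data",  # path to the test data
    batch_size=32,
    test_batch_size=500,
    validation_fraction=0.1,  # hold out 10% of the training data for validation
    num_workers=os.cpu_count() or 1,  # os.cpu_count() can return None
)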
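As a rough guide to what this configuration means per epoch, the arithmetic below assumes that `MNISTDataModule` holds out `validation_fraction` of the standard 60,000-image MNIST training set, evaluates on the 10,000-image MNIST test set, and keeps the last partial batch (the PyTorch `DataLoader` default, `drop_last=False`). The `split_sizes` helper is hypothetical; check `code/data_modules.py` for the actual split logic.

```python
import math
from typing import Tuple


def split_sizes(n_samples: int, validation_fraction: float) -> Tuple[int, int]:
    """Hypothetical split: hold out round(n * fraction) samples for validation."""
    n_val = round(n_samples * validation_fraction)
    return n_samples - n_val, n_val


n_train, n_val = split_sizes(60_000, 0.1)  # assumed: full MNIST training set
train_batches = math.ceil(n_train / 32)  # batch_size=32, last partial batch kept
test_batches = math.ceil(10_000 / 500)  # assumed: MNIST test set, test_batch_size=500
print(n_train, n_val, train_batches, test_batches)
```

Under these assumptions, each epoch runs about 1,688 training steps, which matters when reading step-indexed metrics in SageMaker Experiments.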

## Training with SageMaker Experiments logging

To simplify logging metrics and artifacts to [SageMaker Experiments](https://docs.aws.amazon.com/sagemaker/latest/dg/experiments.html) from within the Lightning training loop, `code/sm_logger.py` defines a custom Lightning [Logger](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html#make-a-custom-logger), `SmLogger`.
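The core job of such a logger is to turn each metrics dictionary that Lightning hands to `log_metrics` into calls on the SageMaker `Run`. The sketch below shows only that forwarding logic, without depending on Lightning or SageMaker: `MetricForwarder`, `RecordingRun`, and the metric names are hypothetical stand-ins, and the real `SmLogger` may differ. It assumes the run object exposes `log_metric(name=..., value=..., step=...)`, which matches the keyword arguments of `sagemaker.experiments.Run.log_metric`.

```python
from typing import Dict, List, Optional, Tuple


class RecordingRun:
    """Hypothetical stand-in for sagemaker.experiments.Run that records calls."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, float, Optional[int]]] = []

    def log_metric(self, name: str, value: float, step: Optional[int] = None) -> None:
        self.calls.append((name, value, step))


class MetricForwarder:
    """Hypothetical core of a Lightning logger: forward each metric to the run."""

    def __init__(self, run) -> None:
        self.run = run

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        # Lightning calls this with a dict of scalar metrics and the current step.
        for name, value in metrics.items():
            self.run.log_metric(name=name, value=float(value), step=step)


run = RecordingRun()
forwarder = MetricForwarder(run)
forwarder.log_metrics({"train_loss": 0.42, "train_acc": 0.88}, step=100)
print(run.calls)
```

In an actual Lightning logger, this method would typically live on a subclass of Lightning's `Logger` base class and be decorated with `rank_zero_only`, so that only one process logs during multi-GPU training.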

In [None]:
experiment_name = "pytorch-demo-mnist"
run_name_base = "nb-training"
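`name_from_base` makes each run name unique by appending a timestamp to `run_name_base`, so re-running the training cell creates a new run in the same experiment instead of reusing an old one. The function below is a hypothetical, roughly equivalent sketch (`name_from_base_sketch` is not part of the SDK); the UTC timestamp format with millisecond suffix and the 63-character default limit are assumptions based on the SDK's behavior.

```python
import time


def name_from_base_sketch(base: str, max_length: int = 63) -> str:
    """Hypothetical re-implementation: append a UTC timestamp with milliseconds."""
    moment = time.time()
    millis = int((moment - int(moment)) * 1000)
    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime(moment)) + f"-{millis:03d}"
    # Trim the base so that base + "-" + timestamp fits within max_length characters.
    trimmed = base[: max_length - len(timestamp) - 1]
    return f"{trimmed}-{timestamp}"


print(name_from_base_sketch("nb-training"))
```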

In [None]:
with Run(
    experiment_name=experiment_name,
    run_name=name_from_base(run_name_base),
) as run:
    sm_logger = SmLogger(run)
    trainer = pl.Trainer(
        max_epochs=20,
        default_root_dir="model",  # Save the model to a local folder
        callbacks=[LitProgressBar()],
        logger=sm_logger,
    )
    trainer.fit(model, dm)
    trainer.test(model, dm)

## Review the tracked metrics

Once training is complete, the training, validation, and test metrics should be recorded in a run within the SageMaker experiment `pytorch-demo-mnist`. The run should also include a confusion matrix in the _chart_ tab:

![screen shot of confusion matrix](images/conf_mat.png)
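A confusion matrix summarizes the test predictions per class. Under the common convention, row `i`, column `j` counts test images of digit `i` that the model predicted as digit `j`, so the diagonal holds the correct predictions. The minimal pure-Python sketch below shows the computation on made-up labels. In the SageMaker Python SDK, `Run.log_confusion_matrix(y_true, y_pred, title=...)` uploads such a chart for a run; whether `SmLogger` uses that method is an assumption, since its source isn't shown here.

```python
from typing import List, Sequence


def confusion_matrix(
    y_true: Sequence[int], y_pred: Sequence[int], num_classes: int = 10
) -> List[List[int]]:
    """Entry [i][j] counts samples of true class i predicted as class j."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_label, pred_label in zip(y_true, y_pred):
        matrix[true_label][pred_label] += 1
    return matrix


y_true = [0, 1, 2, 2, 7]  # made-up labels for illustration
y_pred = [0, 1, 2, 7, 7]
cm = confusion_matrix(y_true, y_pred)
print(cm[2])  # → [0, 0, 1, 0, 0, 0, 0, 1, 0, 0]
```

Here one of the two images of digit 2 is misclassified as 7, which shows up off the diagonal in row 2.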
