# MNIST Lightning training - Amazon SageMaker Training Job


Train an MNIST classification model built with PyTorch Lightning, and log the training, validation, and test metrics with Amazon SageMaker Experiments.


---
This notebook has been designed to work in Amazon SageMaker Studio with `Python 3 (PyTorch 1.12 Python 3.8 CPU Optimized)`.

---

This notebook contains two examples:
- training on a single GPU on an `ml.g4dn.xlarge` instance
- training on 4 GPUs on a single `ml.g4dn.12xlarge` instance

In [None]:
import sagemaker
from sagemaker.experiments import Run
from sagemaker.pytorch import PyTorch
from sagemaker.utils import name_from_base

Set up the SageMaker session, execution role, S3 bucket, and S3 key prefix used by the training jobs.

In [None]:
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()
prefix = "pytorch-demo-mnist"
bucket = sagemaker_session.default_bucket()
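Because `output_path` is set to `s3://{bucket}/{prefix}`, SageMaker writes each job's model archive into a folder named after the training job. The helper below is a minimal sketch, assuming the standard `<output_path>/<job-name>/output/model.tar.gz` layout; `model_artifact_uri` and the example job name are hypothetical.

```python
def model_artifact_uri(bucket: str, prefix: str, job_name: str) -> str:
    """Return the S3 URI where SageMaker is expected to store model.tar.gz.

    Assumes the estimator was created with output_path=f"s3://{bucket}/{prefix}".
    """
    output_path = f"s3://{bucket}/{prefix.strip('/')}"
    return f"{output_path}/{job_name}/output/model.tar.gz"


# Hypothetical job name; real names are base_job_name plus a timestamp suffix.
print(model_artifact_uri("my-bucket", "pytorch-demo-mnist", "lightning-mnist-2024-01-01-00-00-00-000"))
```

Once a job has finished, `estimator.model_data` should report the actual location, so this helper is only useful for predicting paths ahead of time.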

## Training Job on a single GPU

The model and the DataModule are defined in the `mnist_pl.py` and `data_modules.py` scripts in the `code` folder.  
To simplify logging metrics and artifacts to [SageMaker Experiments](https://docs.aws.amazon.com/sagemaker/latest/dg/experiments.html) from within the Lightning training loop, `code/sm_experiments.py` provides a custom Lightning [Logger](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html#make-a-custom-logger).
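The contents of `code/sm_experiments.py` are not shown here. The following is a dependency-free sketch of the idea, assuming the logger mirrors Lightning's `Logger` interface (`name`, `version`, `log_hyperparams`, `log_metrics`) and forwards values to a SageMaker `Run` through its `log_parameters` and `log_metric` methods. The class `SageMakerExperimentsLogger` and the `FakeRun` stand-in are hypothetical; in a real training script the logger would subclass Lightning's `Logger` and obtain the run with `sagemaker.experiments.load_run()`.

```python
from typing import Any, Dict, List, Optional, Tuple


class SageMakerExperimentsLogger:
    """Hypothetical adapter: Lightning-style logging calls -> SageMaker Run calls."""

    def __init__(self, run: Any) -> None:
        # `run` is expected to expose log_parameters(dict) and
        # log_metric(name=..., value=..., step=...), like sagemaker.experiments.Run.
        self._run = run

    @property
    def name(self) -> str:
        return "sagemaker-experiments"

    @property
    def version(self) -> str:
        return "0"

    def log_hyperparams(self, params: Dict[str, Any]) -> None:
        # Keep values to str/int/float; anything else is stringified (assumption).
        clean = {k: v if isinstance(v, (str, int, float)) else str(v) for k, v in params.items()}
        self._run.log_parameters(clean)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        # Forward each metric individually, tagged with the training step.
        for metric_name, value in metrics.items():
            self._run.log_metric(name=metric_name, value=float(value), step=step)


class FakeRun:
    """Stand-in for sagemaker.experiments.Run that records calls in memory."""

    def __init__(self) -> None:
        self.parameters: Dict[str, Any] = {}
        self.metrics: List[Tuple[str, float, Optional[int]]] = []

    def log_parameters(self, parameters: Dict[str, Any]) -> None:
        self.parameters.update(parameters)

    def log_metric(self, name: str, value: float, step: Optional[int] = None) -> None:
        self.metrics.append((name, value, step))


run = FakeRun()
logger = SageMakerExperimentsLogger(run)
logger.log_hyperparams({"batch_size": 512, "epochs": 20})
logger.log_metrics({"train_loss": 0.25, "val_acc": 0.97}, step=100)
print(run.metrics)  # [('train_loss', 0.25, 100), ('val_acc', 0.97, 100)]
```

In a multi-GPU job, Lightning's built-in loggers guard these methods with `rank_zero_only` so that only one process writes; a real implementation would likely need the same guard to avoid duplicate metric entries.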

In [None]:
with Run(
    experiment_name="pytorch-demo-mnist",
    run_name=name_from_base("1x1-gpu"),
    sagemaker_session=sagemaker_session,
) as run:
    estimator = PyTorch(
        entry_point="mnist_pl.py",
        base_job_name="lightning-mnist",
        role=role,
        source_dir="code",
        instance_count=1,
        instance_type="ml.g4dn.xlarge",
        py_version="py38",
        framework_version="1.12.1",
        output_path=f"s3://{bucket}/{prefix}",
        sagemaker_session=sagemaker_session,
        # distribution={"pytorchddp": {"enabled": True}},  # optional: the job runs with or without this
        debugger_hook_config=False,
        hyperparameters={"batch_size": 512, "epochs": 20},  # type: ignore
        # keep_alive_period_in_seconds=20 * 60, 
    )

    estimator.fit(wait=False)
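In SageMaker script mode, the `hyperparameters` dictionary is passed to the entry point as command-line arguments (here `--batch_size 512 --epochs 20`), and the container exposes paths and resources through environment variables such as `SM_MODEL_DIR` and `SM_NUM_GPUS`. How `mnist_pl.py` actually parses them is not shown; the function below is a hypothetical sketch using `argparse`, with defaults chosen only so it also runs outside SageMaker.

```python
import argparse
import os
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Hypothetical argument parser for a SageMaker script-mode entry point."""
    parser = argparse.ArgumentParser()
    # Hyperparameters arrive as --<name> <value> pairs.
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=1)
    # SageMaker sets these environment variables inside the training container;
    # the fallbacks are for local runs.
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR", "./model"))
    parser.add_argument("--num_gpus", type=int, default=int(os.environ.get("SM_NUM_GPUS", "0")))
    return parser.parse_args(argv)


args = parse_args(["--batch_size", "512", "--epochs", "20"])
print(args.batch_size, args.epochs)  # 512 20
```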

## Training Job with multiple GPUs on the same instance

In [None]:
with Run(
    experiment_name="pytorch-demo-mnist",
    run_name=name_from_base("1x4-gpu"),
    sagemaker_session=sagemaker_session,
) as run:
    estimator = PyTorch(
        entry_point="mnist_pl.py",
        base_job_name="lightning-mnist",
        role=role,
        source_dir="code",
        instance_count=1,
        instance_type="ml.g4dn.12xlarge",
        py_version="py38",
        framework_version="1.12.1",
        output_path=f"s3://{bucket}/{prefix}",
        sagemaker_session=sagemaker_session,
        # distribution={"pytorchddp": {"enabled": True}},  # optional: the job runs with or without this
        debugger_hook_config=False,
        hyperparameters={"batch_size": 512, "epochs": 20},  # type: ignore
        # keep_alive_period_in_seconds=20 * 60,
    )

    estimator.fit(wait=False)
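On `ml.g4dn.12xlarge` the same script sees four GPUs instead of one. A common pattern, assumed here rather than taken from `mnist_pl.py`, is to size Lightning's `Trainer` from `SM_NUM_GPUS` and to use the DDP strategy when more than one device is available. If `batch_size` is applied per process, as is usual with DDP, a value of 512 gives an effective global batch of 512 × 4 = 2048 per optimizer step. The helper `trainer_config` is hypothetical; its `trainer_kwargs` entry is meant to be unpacked into `Trainer(**...)`.

```python
import os
from typing import Any, Dict, Optional


def trainer_config(batch_size: int, num_gpus: Optional[int] = None) -> Dict[str, Any]:
    """Hypothetical helper: choose Trainer arguments and report the global batch size."""
    if num_gpus is None:
        # Number of GPUs SageMaker reports for the instance (0 on CPU-only hosts).
        num_gpus = int(os.environ.get("SM_NUM_GPUS", "0"))
    if num_gpus == 0:
        kwargs: Dict[str, Any] = {"accelerator": "cpu", "devices": 1}
        world_size = 1
    else:
        kwargs = {"accelerator": "gpu", "devices": num_gpus}
        if num_gpus > 1:
            kwargs["strategy"] = "ddp"  # one training process per GPU
        world_size = num_gpus
    return {"trainer_kwargs": kwargs, "global_batch_size": batch_size * world_size}


print(trainer_config(512, num_gpus=1))
print(trainer_config(512, num_gpus=4))
```

A larger global batch means fewer optimizer steps per epoch, so learning-rate settings tuned on one GPU may need adjusting on four.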

## Review the tracked metrics

After each training job completes, its training, validation, and test metrics should be recorded in a run within the SageMaker experiment `pytorch-demo-mnist`. The run should also include a confusion matrix in the _chart_ tab:  
 ![screen shot of confusion matrix](images/conf_mat.png)
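A confusion matrix counts, for each true digit, how often the model predicted each class. The SageMaker Python SDK offers `Run.log_confusion_matrix`, which takes true and predicted labels, for logging one to a run; whether `mnist_pl.py` uses it is not shown here. The dependency-free sketch below only illustrates what the chart contains; `confusion_matrix` is a hypothetical helper and the labels are toy data.

```python
from typing import List, Sequence


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int = 10) -> List[List[int]]:
    """Return a num_classes x num_classes matrix; rows are true labels, columns are predictions."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


# Toy labels: one "7" misclassified as "1".
y_true = [0, 1, 7, 7, 3]
y_pred = [0, 1, 7, 1, 3]
cm = confusion_matrix(y_true, y_pred)
print(cm[7][1])  # 1
# Correct predictions lie on the diagonal.
accuracy = sum(cm[i][i] for i in range(10)) / len(y_true)
print(accuracy)  # 0.8
```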
