# MNIST Lightning training - Amazon SageMaker Training Job


Train an MNIST classification model built with PyTorch Lightning, and log the training, validation, and test metrics with Amazon SageMaker Experiments.


---
This notebook has been designed to work in Amazon SageMaker Studio with `Python 3 (PyTorch 1.12 Python 3.8 CPU Optimized)`.

---

This notebook contains two examples:
- training on a single GPU on an `ml.g4dn.xlarge` instance
- training on 4 GPUs on a single `ml.g4dn.12xlarge` instance

#### Install and update libraries
This notebook requires SageMaker Python SDK version 2.123.0 or later in order to use the [recently released new capabilities of SageMaker Experiments](https://aws.amazon.com/about-aws/whats-new/2022/12/amazon-sagemaker-experiments-ml-experiment-management-diverse-ides/).

In [None]:
# %%capture
# %pip install -U "sagemaker >= 2.123.0"
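If you skip the install cell, it can be useful to confirm that the kernel already has a recent enough SDK. The helper below is a hypothetical sketch (not part of the SageMaker SDK) that compares the numeric part of a version string against the 2.123.0 minimum; pre-release suffixes such as `rc1` are ignored, and `packaging.version` is the more robust choice when it is available.

```python
import re


def version_tuple(version: str) -> tuple:
    """Leading numeric release segments, e.g. "2.150.0rc1" -> (2, 150, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # non-numeric segment such as "dev0"
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # segment with a suffix such as "0rc1"
    # Pad to three components so "2.123" compares like "2.123.0".
    return tuple(parts + [0] * (3 - len(parts)))


def meets_minimum(installed: str, minimum: str = "2.123.0") -> bool:
    """True if `installed` is at least `minimum` (pre-release suffixes are ignored)."""
    return version_tuple(installed) >= version_tuple(minimum)
```

In the notebook this would be called as `meets_minimum(sagemaker.__version__)` after `import sagemaker`.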

In [None]:
import sagemaker
from sagemaker.debugger import DebuggerHookConfig, Rule, TensorBoardOutputConfig, rule_configs
from sagemaker.experiments import Run
from sagemaker.pytorch import PyTorch
from sagemaker.utils import name_from_base
from torchvision.datasets import MNIST

Define the SageMaker session, the execution role, and the S3 bucket and prefix used by the training jobs.

In [None]:
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

prefix = "pytorch-demo-mnist"
bucket = sagemaker_session.default_bucket()

## Upload datasets

All training jobs in this sample use the same dataset. Instead of downloading it from the public repository for every training job, we'll upload a copy to S3 and then use [FastFile mode](https://docs.aws.amazon.com/sagemaker/latest/dg/model-access-training-data.html) to serve the data to the training job.

In [None]:
MNIST.mirrors = ["https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/"]

MNIST("data", download=True)

In [None]:
data_s3_uri = sagemaker_session.upload_data(
    "data",
    bucket=bucket,
    key_prefix=f"{prefix}/MNIST-data",
)
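When `upload_data` is given a directory, our understanding is that it keeps each file's path relative to that directory under `key_prefix` and returns `s3://{bucket}/{key_prefix}`, so `data_s3_uri` points at the whole prefix rather than a single file. The hypothetical helper below previews that key layout locally; it does not call S3, and you should check the bucket to confirm the actual keys.

```python
import os
from pathlib import Path


def preview_s3_keys(local_dir: str, key_prefix: str) -> list:
    """Hypothetical helper: list the S3 keys a directory upload would produce,
    assuming each file keeps its path relative to local_dir under key_prefix."""
    root = Path(local_dir)
    keys = []
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            # Use forward slashes so keys look the same on every OS.
            relative = (Path(dirpath) / name).relative_to(root).as_posix()
            keys.append(f"{key_prefix}/{relative}")
    return sorted(keys)
```

For example, `preview_s3_keys("data", f"{prefix}/MNIST-data")` should list keys under `pytorch-demo-mnist/MNIST-data/MNIST/raw/`, since torchvision downloads MNIST into `data/MNIST/raw`.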

We'll then pass the following dictionary to the `fit()` function to tell SageMaker where to find the training and test data. In this case, the two data channels point at the same S3 prefix, but in general the file organization might be different.

In [None]:
fit_inputs = {"train": data_s3_uri, "test": data_s3_uri}
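Each key of `fit_inputs` becomes an input channel. Inside the training container, SageMaker exposes a channel's local path through an environment variable named `SM_CHANNEL_<NAME>` (conventionally `/opt/ml/input/data/<name>`); with FastFile mode the S3 objects appear as files under that path and are streamed when read. A minimal sketch of how a script might resolve its channel directories; `channel_dir` is a hypothetical helper, and the repository's `sm_utils_functions.py` may handle this differently.

```python
import os


def channel_dir(channel: str, default_root: str = "/opt/ml/input/data") -> str:
    """Return the local directory for a SageMaker input channel.

    Prefer the SM_CHANNEL_<NAME> environment variable set in the training
    container; fall back to the conventional path (e.g. for local testing).
    """
    env_name = f"SM_CHANNEL_{channel.upper()}"
    return os.environ.get(env_name, os.path.join(default_root, channel))
```

With the inputs above, `channel_dir("train")` and `channel_dir("test")` would both resolve to directories holding the same MNIST files.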

## Configure Debugger hooks and rules
During training we'll use [SageMaker Debugger](https://docs.aws.amazon.com/sagemaker/latest/dg/train-debugger.html) to monitor the training process and log tensor and scalar outputs. We also include some automatic auditing of the training process, using a few of the most common built-in Debugger rules.
With SageMaker Debugger, we configure the hook that captures the data as part of the training job definition. We also specify the S3 prefix where the TensorBoard-compatible outputs are stored.


In [None]:
hook_config = DebuggerHookConfig(
    hook_parameters={
        "train.save_interval": "100",
        "eval.save_interval": "10",
    }
)

tensorboard_output_config = TensorBoardOutputConfig(
    s3_output_path=f"s3://{bucket}/{prefix}/tensorboard"
)

rules = [
    Rule.sagemaker(rule_configs.vanishing_gradient()),
    Rule.sagemaker(rule_configs.overfit()),
    Rule.sagemaker(rule_configs.overtraining()),
    Rule.sagemaker(rule_configs.poor_weight_initialization()),
]
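The `hook_parameters` above ask the Debugger hook to save tensors every 100 training steps and every 10 evaluation steps. The sketch below lists which steps that covers, assuming the hook counts steps per mode starting from 0 and saves whenever the step is a multiple of the interval; `saved_steps` is a hypothetical helper and does not call smdebug.

```python
def saved_steps(total_steps: int, save_interval: int, start_step: int = 0) -> list:
    """Steps an interval-based hook would save, assuming it saves at
    start_step, start_step + save_interval, ... up to total_steps (exclusive)."""
    return list(range(start_step, total_steps, save_interval))
```

For example, `saved_steps(235, 100)` returns `[0, 100, 200]`: three saves across one epoch of 235 training steps (60,000 images at `batch_size=256`, assuming no validation hold-out).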

## Training Job with single GPU

The code we'll use for the training job is in the `code` folder.
- `mnist_pl.py` is the entry point script; it is executed by the training job.
- `requirements.txt` lists the libraries that are not already present in the training image.
- `models.py` and `data_modules.py` contain the Lightning modules for the model and the data. They can easily be replaced with more complex models of your choice. To make integration with SageMaker Training easier, the scripts are designed to allow setting the hyperparameters from the CLI, following the [recommended best practices](https://pytorch-lightning.readthedocs.io/en/stable/common/hyperparameters.html).
- `sm_utils_functions.py` is a collection of convenience functions for integrating the training script with SageMaker Training.
- `sm_logger.py` provides a [Lightning Logger](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html#make-a-custom-logger) to integrate the training process with [SageMaker Experiments](https://docs.aws.amazon.com/sagemaker/latest/dg/experiments.html) with minimal code changes.
- `sm_debug_callback.py` provides a Lightning Callback to capture tensors and metrics using SageMaker Debugger.
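SageMaker passes the estimator's `hyperparameters` dictionary to the entry point as command-line arguments (for example `--batch_size 256 --epochs 20`), which is why CLI-configurable scripts integrate easily. Below is a minimal `argparse` sketch of the receiving side; `parse_hyperparameters` and its defaults are hypothetical, and the actual `mnist_pl.py` may use Lightning's own argument helpers instead.

```python
import argparse


def parse_hyperparameters(argv=None) -> argparse.Namespace:
    """Parse hyperparameters that SageMaker passes as command-line flags."""
    parser = argparse.ArgumentParser()
    # Hypothetical defaults, used when a hyperparameter is not provided.
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=10)
    # parse_known_args ignores any extra flags this sketch does not declare.
    args, _unknown = parser.parse_known_args(argv)
    return args
```

Called with no arguments inside the container, `argv=None` makes `argparse` read `sys.argv`, which is where the training job's flags arrive.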


We define the experiment name and the run name, then create a PyTorch Estimator and start the training job.

In [None]:
with Run(
    experiment_name="pytorch-demo-mnist",
    run_name=name_from_base("1x1-gpu-ff"),
) as run:
    estimator = PyTorch(
        entry_point="mnist_pl.py",
        base_job_name="lightning-mnist-1x1",
        role=role,
        source_dir="code",
        instance_count=1,
        instance_type="ml.g4dn.xlarge",
        py_version="py38",
        framework_version="1.12.1",
        output_path=f"s3://{bucket}/{prefix}",
        sagemaker_session=sagemaker_session,
        hyperparameters={"batch_size": 256, "epochs": 20},
        rules=rules,
        debugger_hook_config=hook_config,
        tensorboard_output_config=tensorboard_output_config,
        input_mode="FastFile",
        # keep_alive_period_in_seconds=20 * 60,  # Warm pool, useful if re-running the same job
    )

    estimator.fit(inputs=fit_inputs, wait=True)

You can follow the training job's progress in the logs streamed to the cell output, or in the SageMaker console.
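When a job is launched with `wait=False`, as in the multi-GPU example below, the cell returns immediately while the job keeps running. One option is to poll its status; the helper below is a hypothetical sketch that calls any status function until it returns a terminal `TrainingJobStatus` value (`Completed`, `Failed`, or `Stopped`).

```python
import time
from typing import Callable, Iterable

TERMINAL_STATUSES = frozenset({"Completed", "Failed", "Stopped"})


def wait_for_terminal_status(
    get_status: Callable[[], str],
    poll_seconds: float = 30.0,
    terminal: Iterable[str] = TERMINAL_STATUSES,
) -> str:
    """Poll get_status() until it returns a terminal status, then return it."""
    terminal = set(terminal)
    while True:
        status = get_status()
        if status in terminal:
            return status
        time.sleep(poll_seconds)  # "InProgress" and "Stopping" keep polling
```

In the notebook, the status function could be `lambda: sagemaker_session.describe_training_job(estimator.latest_training_job.name)["TrainingJobStatus"]`; alternatively, `estimator.latest_training_job.wait()` blocks until the job ends.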

## Training Job with multiple GPUs on the same instance

In [None]:
with Run(
    experiment_name="pytorch-demo-mnist",
    run_name=name_from_base("1x4-gpu"),
    sagemaker_session=sagemaker_session,
) as run:
    estimator = PyTorch(
        entry_point="mnist_pl.py",
        base_job_name="lightning-mnist-1x4",
        role=role,
        source_dir="code",
        instance_count=1,
        instance_type="ml.g4dn.12xlarge",
        py_version="py38",
        framework_version="1.12.1",
        output_path=f"s3://{bucket}/{prefix}",
        sagemaker_session=sagemaker_session,
        hyperparameters={"batch_size": 128, "epochs": 20},
        debugger_hook_config=hook_config,
        tensorboard_output_config=tensorboard_output_config,
        input_mode="FastFile",
        # keep_alive_period_in_seconds=20 * 60,  # Warm pool, useful if re-running the same job
    )

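    # Pass the same data channels as the single-GPU job; wait=False returns immediately.
    estimator.fit(inputs=fit_inputs, wait=False)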
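This job uses `batch_size=128` instead of 256, so it is worth checking how the global batch and the number of steps per epoch change. The arithmetic below assumes that `mnist_pl.py` runs Lightning's DDP strategy across the four GPUs, that `batch_size` applies per process, and that the training set is the full 60,000 MNIST images with no validation hold-out; `data_modules.py` may split the data differently.

```python
import math


def steps_per_epoch(num_samples: int, per_device_batch: int, num_devices: int = 1) -> int:
    """Steps per epoch when each device receives an equal shard of the data."""
    samples_per_device = math.ceil(num_samples / num_devices)
    return math.ceil(samples_per_device / per_device_batch)


MNIST_TRAIN_SIZE = 60_000  # assumption: full training split, no validation hold-out

single_gpu_steps = steps_per_epoch(MNIST_TRAIN_SIZE, per_device_batch=256)
four_gpu_steps = steps_per_epoch(MNIST_TRAIN_SIZE, per_device_batch=128, num_devices=4)
four_gpu_global_batch = 128 * 4  # samples processed per optimizer step across all GPUs
```

Under these assumptions, the 4-GPU job processes 512 samples per step (versus 256) and needs 118 steps per epoch instead of 235.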

## Review the tracked metrics

After each training job completes, its training, validation, and test metrics should be recorded as a run within the SageMaker experiment `pytorch-demo-mnist`. The run should also include a confusion matrix in the _chart_ tab:

![screenshot of confusion matrix](images/conf_mat.png)
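The metrics appear in the run because the training script forwards them to SageMaker Experiments through `sm_logger.py`. The class below is a hypothetical sketch of that forwarding step, not the contents of `sm_logger.py`: it wraps any object with a `log_metric(name, value, step=...)` method, which `sagemaker.experiments.Run` provides, and mirrors the `log_metrics(metrics, step)` method that a Lightning logger implements. Inside the training container, the run object could be obtained with `sagemaker.experiments.load_run()`.

```python
from typing import Dict, Optional


class ExperimentsMetricsForwarder:
    """Hypothetical sketch: forward a dict of metrics to an Experiments run.

    `run` is any object exposing log_metric(name, value, step=...), such as
    a SageMaker Experiments Run.
    """

    def __init__(self, run):
        self.run = run

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        # One log_metric call per scalar, all tagged with the same step.
        for name, value in metrics.items():
            self.run.log_metric(name=name, value=float(value), step=step)
```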
