# SageMaker HPO with MLflow

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

Train a PyTorch model with SageMaker hyperparameter optimization (HPO) and track every training job in MLflow as a nested (child) run under a single parent run.

## Setup environment

Install necessary libraries

In [None]:
!pip install torchvision mlflow==2.13.2 sagemaker-mlflow==0.1.0

Import necessary libraries

In [None]:
import os

from torchvision import transforms
from torchvision.datasets import MNIST

import mlflow
import sagemaker
from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorch
from sagemaker.tuner import (
    CategoricalParameter,
    ContinuousParameter,
    HyperparameterTuner,
    IntegerParameter,
)

Declare some variables used later

In [None]:
# Define session, role, and region so we can
# perform any SageMaker tasks we need
sagemaker_session = sagemaker.Session()
# Execution role; passed to the PyTorch estimator further down
role = get_execution_role()
# Region of the session; used to build the region-specific MNIST download URL
region = sagemaker_session.boto_region_name
# Default bucket of the session; destination of the training data upload
bucket = sagemaker_session.default_bucket()

# S3 prefix for the training dataset to be uploaded to
prefix = "DEMO-pytorch-mnist"

# MLflow (replace these values with your own)
# Placeholder value: it is used both as the local MLflow tracking URI and as
# the MLFLOW_TRACKING_URI environment variable of the training jobs
tracking_server_arn = "your tracking server arn"
# Name of the MLflow experiment that the HPO parent run is created in
experiment_name = "MNIST"

In [None]:
!mkdir -p training_code

## Get some training data

Download MNIST data

In [None]:
# Local directory that the dataset is downloaded into
local_dir = "data"
# Replace torchvision's mirror list with a single region-specific S3 URL,
# so the download below is served from that bucket
MNIST.mirrors = [
    f"https://sagemaker-example-files-prod-{region}.s3.amazonaws.com/datasets/image/MNIST/"
]
# The dataset object is constructed only for its download side effect;
# the returned object is discarded
MNIST(
    local_dir,
    download=True,
    # Same ToTensor + Normalize pipeline as the training script uses
    # (0.1307 / 0.3081 are presumably the MNIST mean and std -- TODO confirm)
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    ),
)

Upload data to S3

In [None]:
# Upload the directory the dataset was downloaded into (local_dir) under
# `prefix` in `bucket`. The returned value is later passed to the tuner as
# the "training" input channel.
train_input = sagemaker_session.upload_data(path=local_dir, bucket=bucket, key_prefix=prefix)

### Write your training script

In [None]:
%%writefile training_code/mnist.py

import argparse
import json
import logging
import os
import sys

import mlflow
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
from torchinfo import summary
from torchvision import datasets, transforms

# Module logger: everything from DEBUG upwards is written to stdout.
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler(stream=sys.stdout))
logger.setLevel(logging.DEBUG)

# MLflow settings are read from the environment; both are None when unset.
parent_run_id = os.getenv('MLFLOW_PARENT_RUN_ID')
mlflow_experiment_name = os.getenv('MLFLOW_EXPERIMENT_NAME')

# Based on https://github.com/pytorch/examples/blob/master/mnist/main.py
class Net(nn.Module):
    """Two-convolution, two-linear-layer classifier with 10 outputs.

    The input must have a single channel (``conv1`` takes 1 input channel),
    and the feature map after the second pooling step is flattened to 320
    values per row before the fully connected layers.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1)."""
        # First stage: convolution, 2x2 max-pooling, ReLU.
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))

        # Second stage: convolution, channel dropout, 2x2 max-pooling, ReLU.
        dropped = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(dropped, 2))

        # Flatten to rows of 320 features for the linear layers.
        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        # Functional dropout follows the module's train/eval mode.
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)


def _get_train_data_loader(batch_size, training_dir, is_distributed, **kwargs):
    """Build the DataLoader for the MNIST training split in ``training_dir``.

    Args:
        batch_size: Batch size handed to the DataLoader.
        training_dir: Root directory passed to ``datasets.MNIST``.
        is_distributed: When true, a ``DistributedSampler`` over the dataset
            is used and shuffling by the loader itself is turned off.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    logger.info('Get train data loader')

    preprocessing = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    train_set = datasets.MNIST(training_dir, train=True, transform=preprocessing)

    if is_distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    else:
        sampler = None

    # The loader shuffles only when no sampler was created above.
    return torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=sampler is None,
        sampler=sampler,
        **kwargs,
    )


def _get_test_data_loader(test_batch_size, training_dir, **kwargs):
    """Build a shuffling DataLoader for the MNIST test split in ``training_dir``.

    Args:
        test_batch_size: Batch size handed to the DataLoader.
        training_dir: Root directory passed to ``datasets.MNIST``.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    logger.info('Get test data loader')

    preprocessing = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    test_set = datasets.MNIST(training_dir, train=False, transform=preprocessing)

    return torch.utils.data.DataLoader(
        test_set,
        batch_size=test_batch_size,
        shuffle=True,
        **kwargs,
    )


def _average_gradients(model):
    """Average the gradients of ``model`` across all processes, in place.

    Each parameter's gradient is summed over the default process group with
    ``all_reduce`` and then divided by the world size. The process group must
    already be initialized.

    Args:
        model: Module whose ``parameters()`` gradients are averaged.

    Notes:
        Parameters without a gradient (``grad is None``) are skipped. This
        assumes every rank skips the same parameters, which holds when all
        ranks run the same model and forward pass.
    """
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is None:
            continue
        # Use the default process group (no ``group`` argument) and the
        # ``ReduceOp`` enum; an integer group and ``dist.reduce_op`` are not
        # valid in current PyTorch releases.
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= world_size


def train(args):
    """Train ``Net`` on MNIST and record the run in MLflow.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``hosts``, ``current_host``, ``backend``, ``num_gpus``, ``seed``,
            ``batch_size``, ``test_batch_size``, ``data_dir``, ``lr``,
            ``momentum``, ``epochs``, ``log_interval`` and ``model_dir``.

    Side effects:
        Starts an MLflow run (nested when ``MLFLOW_PARENT_RUN_ID`` was set),
        logs the arguments plus the SageMaker job name/URI as parameters,
        writes and logs ``model_summary.txt``, logs the training loss every
        ``log_interval`` batches, evaluates on the test split after every
        epoch and finally saves the model to ``args.model_dir``.

    Raises:
        KeyError: If the ``SM_TRAINING_ENV`` environment variable is missing.
    """
    is_distributed = len(args.hosts) > 1 and args.backend is not None
    logger.debug('Distributed training - {}'.format(is_distributed))
    use_cuda = args.num_gpus > 0
    logger.debug('Number of gpus available - {}'.format(args.num_gpus))
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    device = torch.device('cuda' if use_cuda else 'cpu')

    region = os.getenv('AWS_REGION')

    # Run as a nested MLflow run whenever a parent run id was provided.
    nested = bool(parent_run_id)

    if is_distributed:
        # Initialize the distributed environment.
        world_size = len(args.hosts)
        os.environ['WORLD_SIZE'] = str(world_size)
        host_rank = args.hosts.index(args.current_host)
        dist.init_process_group(
            backend=args.backend, rank=host_rank, world_size=world_size
        )
        logger.info(
            "Initialized the distributed environment: '{}' backend on {} nodes. ".format(
                args.backend, dist.get_world_size()
            )
            + "Current host rank is {}. Number of gpus: {}".format(
                dist.get_rank(), args.num_gpus
            )
        )

    # set the seed for generating random numbers
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    train_loader = _get_train_data_loader(
        args.batch_size, args.data_dir, is_distributed, **kwargs
    )
    test_loader = _get_test_data_loader(args.test_batch_size, args.data_dir, **kwargs)

    logger.debug(
        "Processes {}/{} ({:.0f}%) of train data".format(
            len(train_loader.sampler),
            len(train_loader.dataset),
            100.0 * len(train_loader.sampler) / len(train_loader.dataset),
        )
    )

    logger.debug(
        "Processes {}/{} ({:.0f}%) of test data".format(
            len(test_loader.sampler),
            len(test_loader.dataset),
            100.0 * len(test_loader.sampler) / len(test_loader.dataset),
        )
    )

    model = Net().to(device)
    if is_distributed and use_cuda:
        # multi-machine multi-gpu case
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        # single-machine multi-gpu case or single-machine or multi-machine cpu case
        model = torch.nn.DataParallel(model)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    with mlflow.start_run(nested=nested):
        # Log every parsed argument together with the SageMaker job identity.
        sm_training_env = json.loads(os.environ['SM_TRAINING_ENV'])
        job_name = sm_training_env['job_name']
        job_uri = f'https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/jobs/{job_name}'
        mlflow.log_params(
            {**vars(args), 'sagemaker_job_name': job_name, 'sagemaker_job_uri': job_uri}
        )

        # Log model summary.
        with open('model_summary.txt', 'w') as f:
            f.write(str(summary(model)))
        mlflow.log_artifact('model_summary.txt')

        steps_per_epoch = len(train_loader)
        for epoch in range(1, args.epochs + 1):
            if is_distributed:
                # Without set_epoch the DistributedSampler reuses the same
                # shuffling order in every epoch.
                train_loader.sampler.set_epoch(epoch)

            model.train()
            for batch_idx, (data, target) in enumerate(train_loader, 1):
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                if is_distributed and not use_cuda:
                    # average gradients manually for multi-machine cpu case only
                    _average_gradients(model)
                optimizer.step()
                if batch_idx % args.log_interval == 0:
                    loss_value = loss.item()
                    logger.info(
                        'Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(
                            epoch,
                            batch_idx * len(data),
                            len(train_loader.sampler),
                            100.0 * batch_idx / len(train_loader),
                            loss_value,
                        )
                    )
                    # Use the number of batches processed since the start of
                    # training as the step, so that points from different
                    # epochs do not share the same step value.
                    mlflow.log_metric(
                        'loss',
                        loss_value,
                        step=(epoch - 1) * steps_per_epoch + batch_idx,
                    )

            test(model, test_loader, device)
        save_model(model, args.model_dir)


def test(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and report loss and accuracy.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        test_loader: Iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` (its length is the number of evaluated samples).
        device: Device the batches are moved to before the forward pass.

    Side effects:
        Writes a ``Test set: Average loss: ...`` line to the logger and logs
        ``test_average_loss`` and ``test_accuracy`` to the active MLflow run.
    """
    model.eval()
    test_loss = 0
    correct = 0
    num_samples = len(test_loader.dataset)
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum up the batch loss; ``reduction='sum'`` replaces the
            # deprecated ``size_average=False``.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # Index of the max log-probability is the predicted class.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= num_samples
    # Keep this message format unchanged: the tuner's metric regex
    # ("Test set: Average loss: ...") is matched against this log line.
    logger.info(
        'Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss,
            correct,
            num_samples,
            100.0 * correct / num_samples,
        )
    )
    mlflow.log_metrics(
        {
            'test_average_loss': test_loss,
            'test_accuracy': correct / num_samples,
        }
    )

def model_fn(model_dir):
    """Load ``model.pth`` from ``model_dir`` into a DataParallel-wrapped ``Net``.

    The network is wrapped in ``torch.nn.DataParallel`` before the state dict
    is loaded, then moved to CUDA when available and to the CPU otherwise.

    Args:
        model_dir: Directory containing the ``model.pth`` file.

    Returns:
        The loaded model on the selected device.
    """
    weights_path = os.path.join(model_dir, 'model.pth')
    net = torch.nn.DataParallel(Net())
    with open(weights_path, 'rb') as weights_file:
        state = torch.load(weights_file)
    net.load_state_dict(state)

    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return net.to(target_device)


def save_model(model, model_dir):
    """Move ``model`` to the CPU and write its state dict to ``model_dir``/model.pth.

    Args:
        model: Module whose ``state_dict`` is saved.
        model_dir: Directory the ``model.pth`` file is written into.
    """
    logger.info('Saving the model.')
    # recommended way from http://pytorch.org/docs/master/notes/serialization.html
    cpu_state = model.cpu().state_dict()
    torch.save(cpu_state, os.path.join(model_dir, 'model.pth'))


if __name__ == '__main__':
    # Script entry point: parse hyperparameters and container settings, then
    # run training, attached to the parent MLflow run when one was provided.
    parser = argparse.ArgumentParser()

    # Training hyperparameters
    parser.add_argument(
        '--batch-size',
        type=int,
        default=64,
        metavar='N',
        help='input batch size for training (default: 64)',
    )
    parser.add_argument(
        '--test-batch-size',
        type=int,
        default=1000,
        metavar='N',
        help='input batch size for testing (default: 1000)',
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=10,
        metavar='N',
        help='number of epochs to train (default: 10)',
    )
    parser.add_argument(
        '--lr',
        type=float,
        default=0.01,
        metavar='LR',
        help='learning rate (default: 0.01)',
    )
    parser.add_argument(
        '--momentum',
        type=float,
        default=0.5,
        metavar='M',
        help='SGD momentum (default: 0.5)',
    )
    parser.add_argument(
        '--seed', type=int, default=1, metavar='S', help='random seed (default: 1)'
    )
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before logging training status',
    )
    parser.add_argument(
        '--backend',
        type=str,
        default=None,
        help='backend for distributed training (tcp, gloo on cpu and gloo, nccl on gpu)',
    )

    # Container environment
    # ``--hosts`` is parsed as JSON (e.g. '["algo-1", "algo-2"]'), matching
    # the format of SM_HOSTS. ``type=list`` would split a command-line value
    # into single characters.
    parser.add_argument(
        '--hosts', type=json.loads, default=json.loads(os.environ['SM_HOSTS'])
    )
    parser.add_argument(
        '--current-host', type=str, default=os.environ['SM_CURRENT_HOST']
    )
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument(
        '--data-dir', type=str, default=os.environ['SM_CHANNEL_TRAINING']
    )
    # The string default is converted by argparse through ``type=int``.
    parser.add_argument('--num-gpus', type=int, default=os.environ['SM_NUM_GPUS'])

    args = parser.parse_args()

    # Select the experiment only when a name was supplied; calling
    # set_experiment with no name raises.
    if mlflow_experiment_name:
        mlflow.set_experiment(mlflow_experiment_name)

    if parent_run_id:
        # Re-open the parent run so the run started inside train() is nested
        # under it.
        with mlflow.start_run(run_id=parent_run_id):
            train(args)
    else:
        train(args)

Since we're using MLflow in our training script, let's make sure the container installs `mlflow` along with our MLflow AWS plugin before running our training script. We can do this by creating a `requirements.txt` file and putting it in the same directory as our training script.

In [None]:
%%writefile training_code/requirements.txt
mlflow==2.13.2
torchinfo
sagemaker-mlflow==0.1.0

## SageMaker HPO and MLflow

In [None]:
# Search space explored by the tuner: learning rate and training batch size.
hyperparameter_ranges = {
    "lr": ContinuousParameter(min_value=0.001, max_value=0.1),
    "batch-size": CategoricalParameter(values=[32, 64, 128, 256, 512]),
}

# Objective: minimize the average test loss, extracted from the training
# job's log output with the regex below.
objective_type = "Minimize"
objective_metric_name = "average test loss"
metric_definitions = [
    {
        "Name": objective_metric_name,
        "Regex": "Test set: Average loss: ([0-9\\.]+)",
    }
]

Create an MLflow experiment called `MNIST`. We'll give this SageMaker HPO job a run name based on `HPODemo`. Each training attempt will be its own child run under that `HPODemo` parent run.

In [None]:
# Point the local MLflow client at the tracking server and select the experiment.
mlflow.set_tracking_uri(tracking_server_arn)
experiment = mlflow.set_experiment(experiment_name)

# The parent MLflow run stays open while the tuning job is configured and run;
# its run id is handed to every training job through the environment.
with mlflow.start_run(run_name=sagemaker.utils.name_from_base("HPODemo")) as run:
    estimator = PyTorch(
        source_dir="training_code",
        entry_point="mnist.py",
        role=role,
        framework_version="1.13",
        py_version="py39",
        instance_type="ml.c5.2xlarge",
        instance_count=1,
        hyperparameters={"epochs": 5, "backend": "gloo"},
        # Read by the training script / MLflow inside the container.
        environment={
            "MLFLOW_TRACKING_URI": tracking_server_arn,
            "MLFLOW_EXPERIMENT_NAME": experiment.name,
            "MLFLOW_PARENT_RUN_ID": run.info.run_id,
        },
    )

    tuner = HyperparameterTuner(
        estimator=estimator,
        objective_metric_name=objective_metric_name,
        hyperparameter_ranges=hyperparameter_ranges,
        metric_definitions=metric_definitions,
        objective_type=objective_type,
        max_jobs=9,
        max_parallel_jobs=3,
    )
    tuner.fit(inputs={"training": train_input})

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/sagemaker-mlflow|sagemaker_hpo_mlflow.ipynb)