Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.


# Train on Low-Priority AML Compute

_**A tutorial on training models using low-priority AML compute**_

## Contents
1. [Introduction](#Introduction)
1. [MNIST Dataset](#MNIST-Dataset)
1. [Setup](#Setup)
1. [Model Checkpointing](#Model-Checkpointing)
1. [Create a Low-Priority Cluster](#Create-a-Low-Priority-Cluster)
1. [Experiment](#Experiment)
1. [Handling Metrics](#Handling-Metrics)

## Introduction

This notebook discusses training models using [low-priority compute](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-manage-optimize-cost#low-pri-vm). Training on low-priority compute can cost much less than on dedicated compute.

When creating a compute cluster, you can specify its priority. The priority is either _dedicated_ (the default) or _low priority_. (See [this API doc](https://docs.microsoft.com/en-us/python/api/azureml-core/azureml.core.compute.amlcompute.amlcomputeprovisioningconfiguration?view=azure-ml-py) for details.)

A job using dedicated compute is granted uninterrupted access to a VM. A job using low-priority compute is not: it is only granted temporary access to a VM. The hardware used by low-priority jobs may be preempted in favor of higher-priority jobs, depending on region-wide availability of hardware and other factors.

When a job using low-priority compute is preempted, it stops running and waits for low-priority compute to become available again in the region. Once compute is available again, the run restarts "from scratch" on the new compute. In other words, any state on the previous compute is lost, and the submitted Python script is invoked again from the beginning on the new compute.

This means that, without special handling, every time a training run is preempted and restarted, the model will start training again from scratch. For instance, when training a DNN model, every time the run is preempted and restarted, the run will start training from epoch 0 again. For this reason, we need some special handling in the training script to save state between preemptions, so previous work can be reused after restarts.
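To make the cost of restarting from scratch concrete, here is a small plain-Python simulation (no Azure ML involved) that counts how many epochs a run executes in total when it is preempted. The function `epochs_executed` and its `budgets` parameter are hypothetical: each entry in `budgets` is how many epochs an attempt completes before it is preempted. The sketch assumes a checkpoint after every epoch and preemptions that land exactly on epoch boundaries.

```python
def epochs_executed(num_epochs, budgets, checkpointing):
    """Total epochs executed by a run that is preempted and restarted.

    budgets: for each attempt that gets preempted, how many epochs it
    completes before the preemption. After the listed attempts, the run
    is assumed to finish without further interruption.
    checkpointing: if True, a restart resumes at the next unfinished epoch;
    if False, a restart begins again at epoch 0.
    """
    executed = 0
    next_epoch = 0
    for budget in budgets:
        remaining = num_epochs - next_epoch
        if budget >= remaining:
            # Training finishes before this preemption would hit.
            return executed + remaining
        executed += budget
        next_epoch = next_epoch + budget if checkpointing else 0
    # Final, uninterrupted attempt.
    return executed + (num_epochs - next_epoch)


# A 5-epoch job preempted after 3 epochs, then again after 4 more.
with_ckpt = epochs_executed(5, [3, 4], checkpointing=True)
without_ckpt = epochs_executed(5, [3, 4], checkpointing=False)
```

With checkpointing the job executes only the 5 epochs it needs; without it, the same preemption pattern costs 3 + 4 + 5 = 12 epochs of compute.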

## MNIST Dataset

All models will be trained on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/), which consists of images of handwritten digits.

## Setup

Let's import some dependencies we'll need:

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

from azureml.core import Experiment, Workspace
from azureml.core.compute import AmlCompute, ComputeTarget
from azureml.core.compute_target import ComputeTargetException
from helper import launch_run

As part of the setup you have already created an Azure ML `Workspace`. Let's also create an `Experiment`.

In [None]:
# Initialize the workspace
ws = Workspace.from_config()

# Choose a name for experiment
experiment_name = "low-priority-compute-tutorial"

# Initialize the experiment
experiment = Experiment(ws, experiment_name)

output = {}
output["Subscription ID"] = ws.subscription_id
output["Workspace"] = ws.name
output["Resource Group"] = ws.resource_group
output["Location"] = ws.location
output["Experiment Name"] = experiment.name
pd.set_option("display.max_colwidth", None)  # None means no limit; -1 is deprecated since pandas 1.0
outputDf = pd.DataFrame(data=output, index=[""])
outputDf.T

## Model Checkpointing

### How to Checkpoint

When training a DNN on low-priority compute, you should checkpoint the training state periodically, so that a restarted run can resume from the last checkpoint instead of from scratch.

In the [training script](training_script/training_script.py) used with this notebook, we checkpoint the state after each epoch:

```python
for epoch in range(...):
    ...
    train_epoch(...)

    save_checkpoint(model, optimizer, epoch)
    ...
```

`save_checkpoint` is implemented as:
```python
import torch

MODEL_CHECKPOINT_PATH = "model_checkpoints/checkpoint.pt"

def save_checkpoint(model, optimizer, epoch):
    # Store everything needed to resume: the epoch, the weights, and the optimizer state.
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        MODEL_CHECKPOINT_PATH,
    )
    ...
```

This code saves a copy of the relevant information to `MODEL_CHECKPOINT_PATH`: the index of the last completed epoch, the latest model weights, and the optimizer state.
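One refinement worth considering (the tutorial's script is not stated to do this): a preemption can arrive while the checkpoint file is being written, leaving a truncated file that fails to load on restart. A common pattern is to write to a temporary file in the same directory and then swap it into place with `os.replace`. The sketch below uses `pickle` and a plain dict in place of `torch.save` and real model state so it runs without PyTorch; `save_checkpoint_atomic` and `load_checkpoint_file` are hypothetical names. Whether a rename is atomic on a blob-mounted directory depends on the mount, so treat this as a local-filesystem pattern.

```python
import os
import pickle
import tempfile


def save_checkpoint_atomic(state, path):
    """Write `state` to `path` so readers see either the old or the new file.

    The data goes to a temporary file in the same directory first;
    os.replace then swaps it into place with a single rename.
    """
    directory = os.path.dirname(path) or "."
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(state, f)
            f.flush()
            os.fsync(f.fileno())  # push the bytes to disk before renaming
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial temporary file, then re-raise the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_checkpoint_file(path):
    """Read back a checkpoint written by save_checkpoint_atomic."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

Because the temporary file lives in the same directory as the target, the final rename does not cross filesystems, which is what lets `os.replace` swap the file in one step.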


### Uploading the Checkpoint to the Cloud

Saving to `MODEL_CHECKPOINT_PATH` on the compute automatically uploads the checkpoint to blob storage. This is because, when we launch a script run in [helper.py](helper.py), we pass an `OutputFileDatasetConfig`:

```python
output_dataset_config = OutputFileDatasetConfig(
    name="model_checkpoints",
    destination=output_dataset_destination,
    source="model_checkpoints/",
)

...

src = ScriptRunConfig(
    ...
    arguments=[output_dataset_config, ...],
    ...
)
```

Under the hood, this mounts the blob folder represented by `destination` to the local `model_checkpoints` directory on the compute. As a result, whenever a file is written to the `model_checkpoints` directory on the compute, it is automatically uploaded to the `destination` path in the default blob storage account. Conversely, any file already present in the `destination` blob folder is available in the `model_checkpoints` directory on the compute. If the `destination` parameter is empty, it defaults to `/dataset/{run_id}/model_checkpoints/` in the default blob storage account. See the documentation on [OutputFileDatasetConfig](https://docs.microsoft.com/en-us/python/api/azureml-core/azureml.data.output_dataset_config.outputfiledatasetconfig?view=azure-ml-py) for more info.
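The path rule described above can be modeled in a few lines of plain Python. This is not Azure ML code, only a sketch of the documented default (`/dataset/{run_id}/{output_name}/` when no destination path is given); the function `checkpoint_blob_path` and the run IDs below are hypothetical.

```python
import posixpath


def checkpoint_blob_path(local_file, run_id, destination_path=None,
                         source_dir="model_checkpoints/",
                         output_name="model_checkpoints"):
    """Map a file written under `source_dir` on the compute to its blob path.

    Without an explicit destination path, files land under
    /dataset/{run_id}/{output_name}/ in the default datastore.
    """
    rel = posixpath.relpath(local_file, source_dir)
    if rel.startswith(".."):
        raise ValueError(f"{local_file} is not under {source_dir}")
    base = destination_path or f"/dataset/{run_id}/{output_name}/"
    return posixpath.join(base, rel)


# Default: each run gets its own checkpoint folder.
default_path = checkpoint_blob_path("model_checkpoints/checkpoint.pt", "run-2")
# Override: point a new run at an earlier run's folder.
override_path = checkpoint_blob_path(
    "model_checkpoints/checkpoint.pt", "run-2",
    destination_path="/dataset/run-1/model_checkpoints/",
)
```

The override case is the situation the Experiment section sets up, where a second run reads and writes the first run's checkpoint folder instead of its own.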


### Loading Checkpoints at the Start of a Run

Finally, at the start of our training script, we have:

```python
if checkpoint_file_exists:
    ...
    model, optimizer, starting_epoch = load_checkpoint()
    starting_epoch += 1
else:
    model = init_model()
    optimizer = init_optimizer(model)
    starting_epoch = 0
```

In other words, when the script starts, it first checks whether a checkpoint already exists in storage. If one does, the script loads the saved state and continues from the epoch after the saved one. Otherwise, it starts training a new model from scratch.
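The resume logic can be exercised end to end without a GPU or Azure ML. The sketch below is a hypothetical stand-in for the training script: `train` records epoch indices instead of training a model, stores its state as JSON instead of using `torch.save`, and simulates preemption by raising `Preempted` (a made-up exception) right after a chosen epoch is checkpointed. Calling `train` again with the same checkpoint path plays the role of the restarted script.

```python
import json
import os
import tempfile


class Preempted(Exception):
    """Stands in for the VM being taken away mid-run."""


def train(num_epochs, checkpoint_path, preempt_after=None):
    """Toy training loop that checkpoints after every epoch and can resume.

    Returns the epochs executed by this invocation of the script.
    """
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path) as f:
            state = json.load(f)
        starting_epoch = state["epoch"] + 1  # resume after the saved epoch
    else:
        state = {"epoch": -1, "history": []}
        starting_epoch = 0

    ran = []
    for epoch in range(starting_epoch, num_epochs):
        ran.append(epoch)  # "train" one epoch
        state["epoch"] = epoch
        state["history"].append(epoch)
        with open(checkpoint_path, "w") as f:
            json.dump(state, f)  # checkpoint at the end of every epoch
        if epoch == preempt_after:
            raise Preempted(f"preempted after epoch {epoch}")
    return ran


# The first attempt is preempted after epoch 1; the restart reuses the checkpoint.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.json")
try:
    train(4, path, preempt_after=1)
except Preempted:
    pass
second_attempt = train(4, path)  # runs only epochs 2 and 3
```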

## Create a Low-Priority Cluster

Define a function to create a compute cluster:

In [None]:
def create_compute_cluster(workspace, name, vm_priority="dedicated"):
    """Return the compute cluster `name` in `workspace`, creating it if needed."""
    try:
        compute_target = ComputeTarget(workspace=workspace, name=name)
        print(f"Found existing cluster for {name} -- using it")
        return compute_target
    except ComputeTargetException:
        # vm_priority is either "dedicated" (default) or "lowpriority"
        compute_config = AmlCompute.provisioning_configuration(
            vm_size="Standard_NC6", max_nodes=4, vm_priority=vm_priority
        )
        compute_target = ComputeTarget.create(workspace, name, compute_config)
        print(f"Creating cluster {name}")
        compute_target.wait_for_completion(show_output=True)
        return compute_target

Invoke the function to create a low-priority compute cluster:

In [None]:
low_pri_compute = create_compute_cluster(
    workspace=ws, name="low-pri-compute-cluster", vm_priority="lowpriority"
)

## Experiment

### Introduction

Let's do an experiment simulating the preemption and restart that would happen on low-priority compute.

Preemption of low-priority compute isn't something that we can demonstrate directly. Azure decides to preempt compute based on region-wide availability of hardware and other factors. However, we can try to simulate preemption using dedicated compute.

### Setup

Create a dedicated compute cluster (the function's default priority):

In [None]:
compute = create_compute_cluster(workspace=ws, name="gpu-cluster")

### First Run
Launch a run for 2 epochs on dedicated compute.

In [None]:
run1 = launch_run(experiment=experiment, compute_target=compute, num_epochs=2)

run1

In [None]:
run1.wait_for_completion()

Checking metrics:

In [None]:
run1_metrics = run1.get_metrics()
run1_metrics["training_epoch"]

We can see that the run only trained epochs 0 and 1.

### Second Run

Let's start a second run that continues where the first run left off.

The first run should have saved a checkpoint in the `/dataset/{run1.id}/model_checkpoints/` directory in the default blob storage account. When we start the second run, let's override the checkpoint path in storage from its default of `/dataset/{run2.id}/model_checkpoints/` to `/dataset/{run1.id}/model_checkpoints/`. The second run should continue training the model from where the first run left off. Let's see if that's the case.

Kicking off the second run:

In [None]:
run2 = launch_run(
    experiment=experiment,
    compute_target=compute,
    num_epochs=4,
    output_dataset_storage_path=f"/dataset/{run1.id}/model_checkpoints/",
)

run2

In [None]:
run2.wait_for_completion()

Checking the metrics:

In [None]:
run2_metrics = run2.get_metrics()
run2_metrics["training_epoch"]

We see that, as expected, the second run trained 2 epochs (epochs 2 and 3), picking up where the first run left off.
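If you would rather have the notebook check this than read it off the output, a small helper can compare the two epoch lists. `check_resumed` and `as_list` are hypothetical helpers; they assume `get_metrics()` returns a list for a metric logged several times and a single value for a metric logged once, and accept either form.

```python
def as_list(value):
    """Normalize a metric that may come back as a scalar or as a list."""
    return list(value) if isinstance(value, (list, tuple)) else [value]


def check_resumed(first_epochs, second_epochs):
    """True if the second run starts right after the first run's last epoch
    and then logs consecutive epochs."""
    first = [int(e) for e in as_list(first_epochs)]
    second = [int(e) for e in as_list(second_epochs)]
    start = first[-1] + 1
    return second == list(range(start, start + len(second)))
```

In the notebook you would call it as `check_resumed(run1_metrics["training_epoch"], run2_metrics["training_epoch"])`; a second run that restarted from scratch would log epoch 0 again and fail the check.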

### Validation Accuracy

We can also plot the model's accuracy on the validation set after each epoch:

In [None]:
plt.plot(
    run1_metrics["training_epoch"],
    run1_metrics["validation_accuracy"],
    marker="o",
    linestyle="",
    ms=12,
    label="Run 1",
)
plt.plot(
    run2_metrics["training_epoch"],
    run2_metrics["validation_accuracy"],
    marker="o",
    linestyle="",
    ms=12,
    label="Run 2",
)
plt.legend()
plt.title("Epoch vs Validation Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Validation Accuracy")
plt.show()

From the chart above, we can see that the model's accuracy on the validation set improved with each epoch.

## Handling Metrics

When a preempted job is resumed, the metrics it logged before the preemption are still present on the run. Keep this in mind when reading them.
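One consequence: if a run is preempted after logging an epoch's metrics but before that epoch's checkpoint is saved, the restarted script retrains that epoch and logs it again, so the same epoch can appear twice. A hypothetical helper such as `dedupe_by_epoch` keeps only the last value logged for each epoch. It assumes the epoch metric and the value metric were logged in lockstep, one entry each per epoch attempt.

```python
def dedupe_by_epoch(epochs, values):
    """Keep the last value logged for each epoch, ordered by epoch."""
    if len(epochs) != len(values):
        raise ValueError("epochs and values must have the same length")
    latest = {}
    for epoch, value in zip(epochs, values):
        latest[epoch] = value  # a later entry overwrites an earlier one
    ordered = sorted(latest)
    return ordered, [latest[e] for e in ordered]
```

Keeping the last value, rather than the first, matches the state the resumed run actually continued from.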

You probably also want to ensure all metrics are flushed (that is, sent from the compute to the metrics servers) around the same time that the model checkpoint is saved. This reduces the likelihood of a race condition in which the checkpoint is saved and the run is then preempted before recent metrics have been flushed, so those metrics are lost.

In the training script, you can see that we flush metrics right before saving the checkpoint:

```python
run.flush()
save_checkpoint(model, optimizer, epoch)
```
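To see why the order matters, here is a toy model with a hypothetical buffered logger (not the Azure ML `Run` class): logged values sit in a local buffer until `flush()` moves them to a "server" list, and a preemption discards whatever is still buffered. The sketch assumes the preemption lands immediately after the checkpoint is saved.

```python
class BufferedLogger:
    """Toy logger: values only reach the server when flush() is called."""

    def __init__(self):
        self.buffer = []
        self.server = []

    def log(self, name, value):
        self.buffer.append((name, value))

    def flush(self):
        self.server.extend(self.buffer)
        self.buffer.clear()


def run_until_preempted(num_epochs, preempt_after, flush_first):
    """Return (last checkpointed epoch, epochs whose metrics reached the server)."""
    logger = BufferedLogger()
    checkpoint_epoch = None
    for epoch in range(num_epochs):
        logger.log("training_epoch", epoch)
        if flush_first:
            logger.flush()
        checkpoint_epoch = epoch  # stands in for save_checkpoint(...)
        if epoch == preempt_after:
            break  # preempted: anything still in the local buffer is lost
        if not flush_first:
            logger.flush()
    return checkpoint_epoch, [value for _, value in logger.server]
```

With `flush_first=True`, every epoch covered by the checkpoint has its metrics on the server. With the reverse order, the checkpointed epoch's metrics are lost, and because the restarted run resumes at the next epoch, they are never logged again.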