## Bonus: Hyperparameter tuning of distributed training with Ray Tune and Ray Train

This bonus notebook shows how to tune the hyperparameters of a distributed training job by combining Ray Tune with Ray Train.

<img src="https://docs.ray.io/en/latest/_images/train-tuner.svg" width=600>

## Imports

In [None]:
import tempfile
import os
from typing import Any

import torch
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.models import resnet18
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

import ray
from ray import tune, train
from ray.train.torch import TorchTrainer

Here, we will use the example of training a ResNet18 model on the MNIST dataset.

In [None]:
def train_loop_ray_train(config: dict):  # pass in hyperparameters in config
    criterion = CrossEntropyLoss()

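    model = resnet18(num_classes=10)  # MNIST has 10 classes
    # MNIST images are grayscale, so replace the first convolution to accept 1 input channel
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # Move the model to the worker's device and wrap it in DistributedDataParallel
    model = train.torch.prepare_model(model)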

    global_batch_size = config["global_batch_size"]
    batch_size = global_batch_size // ray.train.get_context().get_world_size()
    optimizer = Adam(model.parameters(), lr=config["lr"])
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    train_data = MNIST(root="./data", train=True, download=True, transform=transform)
    data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    # Add a DistributedSampler and move each batch to the worker's device
    data_loader = train.torch.prepare_data_loader(data_loader)

    for epoch in range(config["num_epochs"]):
        # Make the sampler shuffle the data differently in each epoch
        data_loader.sampler.set_epoch(epoch)

        for images, labels in data_loader:  # each worker receives its own shard of the data
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()  # Gradients are averaged across the workers
            optimizer.step()

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            torch.save(
                model.module.state_dict(), os.path.join(temp_checkpoint_dir, "model.pt")
            )
            # Report the last batch's loss and the checkpoint, which Ray Tune uses to compare trials
            ray.train.report(
                {"loss": loss.item()},
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
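
The role of `set_epoch` and of the sharded data loader can be illustrated with a small, framework-free sketch. It is a simplification under stated assumptions, not the actual `DistributedSampler` code: the function name `shard_indices` and the use of Python's `random` module are hypothetical, and the real sampler additionally handles datasets whose size is not divisible by the number of workers.

```python
import random


def shard_indices(
    num_samples: int, world_size: int, rank: int, epoch: int, seed: int = 0
) -> list[int]:
    """Return the sample indices one worker would see in a given epoch (illustrative only)."""
    indices = list(range(num_samples))
    # Every worker shuffles with the same seed + epoch, so all workers agree on
    # the permutation, and the permutation depends on the epoch number.
    random.Random(seed + epoch).shuffle(indices)
    # Each worker then takes every world_size-th index, starting at its own rank.
    return indices[rank::world_size]


# Two workers splitting a toy dataset of 8 samples in epoch 0
shard_0 = shard_indices(num_samples=8, world_size=2, rank=0, epoch=0)
shard_1 = shard_indices(num_samples=8, world_size=2, rank=1, epoch=0)
print(shard_0, shard_1)
```

The two shards are disjoint and together cover the whole dataset. Without the epoch entering the seed, every epoch would reuse the same permutation, which is why the training loop calls `set_epoch` at the start of each epoch.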

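We now pass the training loop to a `TorchTrainer` to perform distributed training on two GPU workers. Note that this `train_loop_config` does not set `lr`, which the training loop requires; the learning rate is supplied by the tuner's search space further down.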

In [None]:
trainer = TorchTrainer(
    train_loop_ray_train,
    train_loop_config={"num_epochs": 2, "global_batch_size": 128},
    run_config=train.RunConfig(
        storage_path="/mnt/cluster_storage/dist_train_tune_example/",
        name="tune_example",
    ),
    scaling_config=train.ScalingConfig(
        num_workers=2,
        use_gpu=True,
    ),
)
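
With the settings above, the training loop divides the global batch size by the number of workers. The arithmetic can be sketched as follows; the helper name `per_worker_batch_size` is hypothetical and only mirrors the integer division used in the training loop.

```python
def per_worker_batch_size(global_batch_size: int, world_size: int) -> int:
    """Split a global batch evenly across workers using integer division."""
    return global_batch_size // world_size


# Settings from the trainer above: global batch of 128 on 2 workers
local = per_worker_batch_size(global_batch_size=128, world_size=2)
print(local, local * 2)  # 64 128

# If the global batch size is not divisible by the worker count,
# the effective global batch shrinks slightly.
local_uneven = per_worker_batch_size(global_batch_size=128, world_size=3)
print(local_uneven, local_uneven * 3)  # 42 126
```

Keeping the global batch size fixed while varying the number of workers keeps the effective batch size per optimizer step roughly constant, provided the division is exact.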

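A Ray Train trainer is itself a Ray Tune trainable, so we can pass it directly to `tune.Tuner`, as we have done before. Values given under `train_loop_config` in `param_space` take precedence over those set on the trainer, so each trial here runs for one epoch with its own sampled learning rate.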

In [None]:
tuner = tune.Tuner(
    trainer,
    param_space={
        "train_loop_config": {
            "num_epochs": 1,
            "global_batch_size": 128,
            "lr": tune.loguniform(1e-4, 1e-1),
        }
    },
    tune_config=tune.TuneConfig(
        mode="min",
        metric="loss",
        num_samples=2,
    ),
)

results = tuner.fit()

best_result = results.get_best_result()
best_result.config
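
The search space above samples the learning rate with `tune.loguniform(1e-4, 1e-1)`. To make the meaning of "log-uniform" concrete, here is a minimal standard-library sketch. It illustrates the distribution only and is not Ray Tune's implementation; the helper `sample_loguniform` is a hypothetical name.

```python
import math
import random


def sample_loguniform(low: float, high: float, rng: random.Random) -> float:
    """Draw a value whose logarithm is uniformly distributed between log(low) and log(high)."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
samples = [sample_loguniform(1e-4, 1e-1, rng) for _ in range(1000)]

# Count samples per decade: [1e-4, 1e-3), [1e-3, 1e-2), [1e-2, 1e-1]
decades = [0, 0, 0]
for s in samples:
    decades[min(int(math.log10(s) + 4), 2)] += 1
print(decades)
```

Each decade receives roughly a third of the samples, so small learning rates such as `1e-4` are explored as often as large ones such as `1e-1`. A plain uniform distribution over the same range would place about 90% of its samples in the top decade.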

### Clean up

In [None]:
!rm -rf /mnt/cluster_storage/dist_train_tune_example/