In [None]:
!pip install "ray[tune]" torch torchvision pytorch-lightning

In [2]:
import os
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from filelock import FileLock
from torchmetrics import Accuracy
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

from ray.train.lightning import LightningTrainer, LightningConfigBuilder

In [3]:
# Smoke-test switch: when True, later cells shrink the run to 3 epochs / 3 samples
# on a single CPU worker. Set SMOKE_TEST to False to run the full experiment.
SMOKE_TEST = True

Our example builds on the MNIST example from the [blog post](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09) we mentioned before. We adapted the original model and dataset definitions into `MNISTClassifier` and `MNISTDataModule`. Note that, despite its name, the `MNISTDataModule` defined below loads the FashionMNIST dataset rather than the original MNIST digits.

### Lightning modules initialization

In [4]:
from model.mlp import MLP
from model.resnet import ResNet
import torchvision.datasets as datasets

In [13]:
class MNISTDataModule(pl.LightningDataModule):
    """FashionMNIST data module providing train, validation and test loaders.

    The official training set is split into a training part and a validation
    part; the official test set is used as-is.

    Args:
        batch_size: Batch size used by all three dataloaders.
        train_transform: Transform applied to training samples. Defaults to
            random flip/rotation followed by resize, center crop and ToTensor.
        test_transform: Transform applied to validation and test samples.
            Defaults to resize, center crop and ToTensor (no augmentation).
        image_size: Target image size for the default transforms (default 28).
        train_valid_split: Fraction of the training set kept for training
            (default 0.8); the remainder becomes the validation set.
        data_dir: Directory the dataset is downloaded to / read from.
        num_workers: Number of worker processes per dataloader.
        split_seed: Seed for the train/validation split. Pass ``None`` to
            draw the split from the global RNG instead.
    """

    def __init__(
        self,
        batch_size,
        train_transform=None,
        test_transform=None,
        image_size=None,
        train_valid_split=None,
        data_dir="data",
        num_workers=2,
        split_seed=42,
    ):
        super().__init__()
        self.image_size = image_size if image_size is not None else 28
        self.train_transform = train_transform if train_transform is not None else transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(30),
            transforms.Resize(self.image_size),
            transforms.CenterCrop(self.image_size),
            transforms.ToTensor()
        ])
        self.test_transform = test_transform if test_transform is not None else transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.CenterCrop(self.image_size),
            transforms.ToTensor()
        ])

        self.batch_size = batch_size
        self.train_valid_split = train_valid_split if train_valid_split is not None else 0.8
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.split_seed = split_seed

    def setup(self, stage=None):
        """Download (if needed) and build the train/validation/test datasets."""
        # setup() runs in every training process; serialize the download so
        # concurrent workers do not write the same files at the same time.
        lock_path = os.path.normpath(os.fspath(self.data_dir)) + ".lock"
        with FileLock(lock_path):
            # Two views of the same training data: one with the (augmenting)
            # training transform and one with the deterministic eval transform.
            train_view = datasets.FashionMNIST(
                root=self.data_dir, train=True, transform=self.train_transform, download=True
            )
            eval_view = datasets.FashionMNIST(
                root=self.data_dir, train=True, transform=self.test_transform, download=True
            )
            self.mnist_test = datasets.FashionMNIST(
                root=self.data_dir, train=False, transform=self.test_transform, download=True
            )

        train_size = int(self.train_valid_split * len(train_view))
        valid_size = len(train_view) - train_size

        # A seeded generator gives every process the same split, so a sample
        # never ends up in "train" on one worker and "validation" on another.
        split_kwargs = {}
        if self.split_seed is not None:
            split_kwargs["generator"] = torch.Generator().manual_seed(self.split_seed)
        train_subset, valid_subset = random_split(
            train_view, [train_size, valid_size], **split_kwargs
        )

        self.mnist_train = train_subset
        # Same validation indices, but read through the non-augmented view so
        # validation metrics are not perturbed by random flips/rotations.
        self.mnist_val = torch.utils.data.Subset(eval_view, valid_subset.indices)

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [14]:
class MNISTClassifier(pl.LightningModule):
    """Image classifier that chains a ResNet backbone and an MLP head.

    Args:
        config: Dict with the keys ``"mlp_config"`` and ``"resnet_config"``
            (keyword arguments forwarded to ``MLP`` and ``ResNet``) and
            ``"lr"`` (learning rate for Adam).

    Note:
        The loss is ``F.nll_loss``, so the network output is treated as
        log-probabilities — presumably the last MLP block should use a
        log-softmax activation; verify against the model config in use.
    """

    def __init__(self, config):
        super().__init__()
        self.accuracy = Accuracy('multiclass', num_classes=10)
        self.mlp = MLP(**config["mlp_config"])
        self.resnet = ResNet(**config["resnet_config"])
        self.lr = config["lr"]

        # Per-batch validation results, aggregated and emptied at epoch end.
        self.validation_step_outputs = []

    def cross_entropy_loss(self, logits, labels):
        """Return the negative log-likelihood loss of ``logits`` against ``labels``."""
        return F.nll_loss(logits, labels)

    def forward(self, x):
        """Run ``x`` through the ResNet and then the MLP."""
        out = self.resnet(x)
        out = self.mlp(out)
        return out

    def training_step(self, train_batch, batch_idx):
        """Compute and log loss/accuracy for one training batch; return the loss."""
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", accuracy)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute loss/accuracy for one validation batch and store them for epoch-end averaging."""
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        batch_result = {"val_loss": loss, "val_accuracy": accuracy}
        self.validation_step_outputs.append(batch_result)

        return batch_result

    def on_validation_epoch_end(self):
        """Log the mean validation loss/accuracy of the epoch that just finished."""
        outputs = self.validation_step_outputs
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        # sync_dist averages the values across processes when training is
        # distributed, so every worker reports the same metric.
        self.log("ptl/val_loss", avg_loss, sync_dist=True)
        self.log("ptl/val_accuracy", avg_acc, sync_dist=True)
        # Drop this epoch's results; otherwise the next epoch's averages would
        # also include every earlier epoch and the list would grow unbounded.
        outputs.clear()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [15]:
import yaml

# Read the model defaults from the YAML file and keep only its "model" section.
with open("model/configs/model.yaml", 'r') as config_file:
    default_config = yaml.safe_load(config_file)['model']

In [16]:
# Display the loaded default model configuration.
default_config

{'regularization_ratio': 0.5,
 'lr': 0.001,
 'resnet_config': {'first_conv': {'in_channels': 3,
   'out_channels': 64,
   'kernel_size': 5,
   'stride': 2,
   'padding': 2},
  'block_list': [{'in_channels': 64,
    'out_channels': 64,
    'kernel_size': 3,
    'stride': 1,
    'padding': 'same'},
   {'in_channels': 64,
    'out_channels': 128,
    'kernel_size': 3,
    'stride': 2,
    'padding': 1},
   {'in_channels': 128,
    'out_channels': 128,
    'kernel_size': 3,
    'stride': 1,
    'padding': 'same'},
   {'in_channels': 128,
    'out_channels': 256,
    'kernel_size': 3,
    'stride': 2,
    'padding': 1},
   {'in_channels': 256,
    'out_channels': 256,
    'kernel_size': 3,
    'stride': 1,
    'padding': 'same'}],
  'pool_size': 2},
 'mlp_config': {'block_list': [{'in_size': 1024,
    'out_size': 512,
    'activation_fun': 'relu',
    'batch_norm': True,
    'dropout': 0.0},
   {'in_size': 512,
    'out_size': 256,
    'activation_fun': 'none',
    'batch_norm': False,
    

## Tuning the model parameters

The parameters above should give you a good accuracy of over 90% already. However, we might improve on this simply by changing some of the hyperparameters. For instance, maybe we get an even higher accuracy if we used a smaller learning rate and larger middle layer size.

Instead of manually looping through all the parameter combinations, let's use Tune to systematically try out parameter combinations and find the best performing set.

First, we need some additional imports:

In [17]:
from pytorch_lightning.loggers import TensorBoardLogger
from ray import air, tune
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining

### Configuring the search space

Now we configure the parameter search space using {class}`LightningConfigBuilder <ray.train.lightning.LightningConfigBuilder>`. In the configuration below, the network architecture is fixed and only the learning rate is searched: `tune.choice()` picks it from `0.001` and `0.01`. To explore a wider range, e.g. between `0.0001` and `0.1`, you can use `tune.loguniform()` instead. It is syntactic sugar that makes sampling across these different orders of magnitude easier, so that small values are also sampled (the Population Based Training example further below uses it for its mutations). The batch size is kept constant across trials.

:::{note}
In `LightningTrainer`, the frequency of metric reporting is the same as the frequency of checkpointing. For example, if you set `builder.checkpointing(..., every_n_epochs=2)`, then for every 2 epochs, all the latest metrics will be reported to the Ray Tune session along with the latest checkpoint. Please make sure the target metrics(e.g. metrics specified in `TuneConfig`, schedulers, and searchers) are logged before saving a checkpoint.

:::


:::{note}
Use `LightningConfigBuilder.checkpointing()` to specify the monitor metric and checkpoint frequency for the Lightning ModelCheckpoint callback. To properly save AIR checkpoints, you must also provide an AIR {class}`CheckpointConfig <ray.air.config.CheckpointConfig>`. Otherwise, LightningTrainer will create a default CheckpointConfig, which saves all the reported checkpoints by default.

:::

In [27]:
# Maximum number of training epochs per trial
num_epochs = 5

# Number of samples drawn from the parameter space
num_samples = 10

accelerator = "gpu"

# Model configuration handed to MNISTClassifier; only "lr" is a Tune search dimension.
config = {
    "lr": tune.choice([0.001, 0.01]),
    #'batch_size': tune.choice([64, 128, 256]),
    "resnet_config": {
        "first_conv": dict(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1),
        "block_list": [
            dict(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1),
            dict(in_channels=16, out_channels=8, kernel_size=3, stride=1, padding="same"),
        ],
        "pool_size": 2,
    },
    "mlp_config": {
        "block_list": [
            # NOTE(review): in_size=72 presumably equals the flattened ResNet output
            # for 28x28 inputs — confirm against model.resnet.
            dict(in_size=72, out_size=64, activation_fun="relu", batch_norm=True, dropout=0.1),
            dict(in_size=64, out_size=10, activation_fun="logsoftmax", batch_norm=False, dropout=0.0),
        ],
    },
    "img_key": 0,
    "class_key": 1,
}


If you have more resources available, you can modify the above parameters accordingly, e.g. use more epochs or more parameter samples.

In [28]:
# Smoke test: cut the run down to 3 epochs and 3 samples, and train on CPU.
if SMOKE_TEST:
    num_epochs, num_samples, accelerator = 3, 3, "cpu"

In [29]:
# Data module with a fixed batch size of 64, shared by all trials.
dm = MNISTDataModule(64)
logger = TensorBoardLogger(save_dir=os.getcwd(), name="tune-mnist", version=".")

# An AIR CheckpointConfig is also required so that checkpoints are saved
# properly in AIR format: keep the two best checkpoints by validation accuracy.
run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="ptl/val_accuracy",
        checkpoint_score_order="max",
    )
)

# Assemble module, trainer, data and checkpointing settings for LightningTrainer.
lightning_config = (
    LightningConfigBuilder()
    .module(
        cls=MNISTClassifier,
        config=config,
    )
    .trainer(
        max_epochs=num_epochs,
        accelerator=accelerator,
        logger=logger,
    )
    .fit_params(datamodule=dm)
    .checkpointing(
        monitor="ptl/val_accuracy",
        save_top_k=2,
        mode="max",
    )
    .build()
)

### Selecting a scheduler

In this example, we use an [Asynchronous Hyperband](https://blog.ml.cmu.edu/2018/12/12/massively-parallel-hyperparameter-optimization/)
scheduler. This scheduler decides at each iteration which trials are likely to perform
badly, and stops these trials. This way we don't waste any resources on bad hyperparameter
configurations.

In [30]:
# ASHA scheduler: trials run for at least 1 and at most `num_epochs` iterations.
scheduler = ASHAScheduler(
    max_t=num_epochs,
    grace_period=1,
    reduction_factor=2,
)

### Training with GPUs

We can specify the number of resources, including GPUs, that Tune should request for each trial.

`LightningTrainer` takes care of environment setup for Distributed Data Parallel training: the model and data will automatically get distributed across GPUs. You only need to set the number of GPUs per worker in `ScalingConfig` and also set `accelerator="gpu"` in `LightningConfigBuilder`.

In [31]:
# Full run: two training workers, each reserving one CPU and one GPU.
scaling_config = ScalingConfig(
    num_workers=2,
    use_gpu=True,
    resources_per_worker={"CPU": 1, "GPU": 1},
)

In [32]:
# Smoke test: a single CPU-only worker is enough.
if SMOKE_TEST:
    scaling_config = ScalingConfig(
        num_workers=1,
        use_gpu=False,
        resources_per_worker={"CPU": 1},
    )

In [33]:
# Base LightningTrainer for the Tuner; it carries no hyper-parameters itself —
# those are supplied per trial through the Tuner's param_space.
lightning_trainer = LightningTrainer(
    run_config=run_config,
    scaling_config=scaling_config,
)

### Putting it together

Lastly, we need to create a `Tuner()` object and start Ray Tune with `tuner.fit()`.

The full code looks like this:

In [None]:
def tune_mnist_asha(num_samples=10):
    """Run a Ray Tune search with the ASHA scheduler and return the best result.

    Args:
        num_samples: Number of configurations sampled from the search space.

    Returns:
        The result with the highest ``ptl/val_accuracy`` among all trials.
    """
    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

    tuner = tune.Tuner(
        lightning_trainer,
        param_space={"lightning_config": lightning_config},
        tune_config=tune.TuneConfig(
            metric="ptl/val_accuracy",
            mode="max",
            num_samples=num_samples,
            scheduler=scheduler,
        ),
        run_config=air.RunConfig(
            name="tune_mnist_asha",
            # A RunConfig given to the Tuner replaces the one set on the trainer,
            # so the checkpoint settings have to be carried over explicitly.
            checkpoint_config=run_config.checkpoint_config,
        ),
    )
    results = tuner.fit()
    # Return the best result so the caller (or the notebook cell) can inspect it.
    return results.get_best_result(metric="ptl/val_accuracy", mode="max")


tune_mnist_asha(num_samples=num_samples)

## Using Population Based Training to find the best parameters

The `ASHAScheduler` terminates those trials early that show bad performance.
Sometimes, this stops trials that would get better after more training steps,
and which might eventually even show better performance than other configurations.

Another popular method for hyperparameter tuning, called
[Population Based Training](https://deepmind.com/blog/article/population-based-training-neural-networks),
instead perturbs hyperparameters during the training run. Tune implements PBT, and
we only need to make some slight adjustments to our code.

In [None]:
def tune_mnist_pbt(num_samples=10):
    """Run a Ray Tune search with Population Based Training and return the best result.

    Args:
        num_samples: Number of trials in the population.

    Returns:
        The result with the highest ``ptl/val_accuracy`` among all trials.
    """
    # The range of hyperparameter perturbation: only the learning rate is mutated.
    mutations_config = (
        LightningConfigBuilder()
        .module(
            config={
                "lr": tune.loguniform(1e-4, 1e-1),
            }
        )
        .build()
    )

    # Create a PBT scheduler that may perturb trials after every training iteration.
    scheduler = PopulationBasedTraining(
        perturbation_interval=1,
        time_attr="training_iteration",
        hyperparam_mutations={"lightning_config": mutations_config},
    )

    tuner = tune.Tuner(
        lightning_trainer,
        param_space={"lightning_config": lightning_config},
        tune_config=tune.TuneConfig(
            metric="ptl/val_accuracy",
            mode="max",
            num_samples=num_samples,
            scheduler=scheduler,
        ),
        run_config=air.RunConfig(
            name="tune_mnist_pbt",
            # A RunConfig given to the Tuner replaces the one set on the trainer,
            # so the checkpoint settings have to be carried over explicitly.
            checkpoint_config=run_config.checkpoint_config,
        ),
    )
    results = tuner.fit()
    # A bare expression inside a function is discarded; return the result instead.
    return results.get_best_result(metric="ptl/val_accuracy", mode="max")

In [None]:
# Launch the PBT search with the sample budget configured earlier in the notebook.
tune_mnist_pbt(num_samples=num_samples)

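An example output of a run could look like this (the table below comes from a search space that included `layer_1_size` and `layer_2_size`, so the columns will differ for the configuration used in this notebook):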

```bash
:emphasize-lines: 12

 +------------------------------+------------+-------+----------------+----------------+---------------------+-----------+--------------------+----------------------+
 | Trial name                   | status     | loc   |   layer_1_size |   layer_2_size |                  lr |      loss |   ptl/val_accuracy |   training_iteration |
 |------------------------------+------------+-------+----------------+----------------+---------------------+-----------+--------------------+----------------------|
 | LightningTrainer_85489_00000 | TERMINATED |       |            64  |            64  | 0.0030@perturbed... | 0.108734  |        0.984954    |                   5  |
 | LightningTrainer_85489_00001 | TERMINATED |       |            32  |            256 | 0.0010@perturbed... | 0.093577  |        0.983411    |                   5  |
 | LightningTrainer_85489_00002 | TERMINATED |       |            128 |            64  | 0.0233@perturbed... | 0.0922348 |        0.983989    |                   5  |
 | LightningTrainer_85489_00003 | TERMINATED |       |            64  |            128 | 0.0002@perturbed... | 0.124648  |        0.98206	  |                   5  |
 | LightningTrainer_85489_00004 | TERMINATED |       |            128 |            256 | 0.0021              | 0.101717  |        0.993248    |                   5  |
 | LightningTrainer_85489_00005 | TERMINATED |       |            32  |            128 | 0.0003@perturbed... | 0.121467  |        0.984182    |                   5  |
 | LightningTrainer_85489_00006 | TERMINATED |       |            128 |            64  | 0.0020@perturbed... | 0.053446  |        0.984375    |                   5  |
 | LightningTrainer_85489_00007 | TERMINATED |       |            64  |            64  | 0.0063@perturbed... | 0.129804  |        0.98669	  |                   5  |
 | LightningTrainer_85489_00008 | TERMINATED |       |            128 |            256 | 0.0436@perturbed... | 0.363236  |        0.982253    |                   5  |
 | LightningTrainer_85489_00009 | TERMINATED |       |            128 |            256 | 0.001               | 0.150946  |        0.985147    |                   5  |
 +------------------------------+------------+-------+----------------+----------------+---------------------+-----------+--------------------+----------------------+
```

As you can see, each sample ran the full number of 5 iterations.
All trials ended with quite good parameter combinations and showed relatively good performances (above `0.98`).
In some runs, the parameters have been perturbed. And the best configuration even reached a mean validation accuracy of `0.993248`!

In summary, AIR LightningTrainer is easy to extend to use with Tune. It only required adding a few lines of code to integrate with Ray Tuner to get great performing parameter configurations.

## More PyTorch Lightning Examples

- {ref}`Use LightningTrainer for Image Classification <lightning_mnist_example>`.
- {ref}`Use LightningTrainer with Ray Data and Batch Predictor <lightning_advanced_example>`
- {ref}`Fine-tune a Large Language Model with LightningTrainer and FSDP <dolly_lightning_fsdp_finetuning>`
- {doc}`/tune/examples/includes/mlflow_ptl_example`: Example for using [MLflow](https://github.com/mlflow/mlflow/)
  and [Pytorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) with Ray Tune.
- {doc}`/tune/examples/includes/mnist_ptl_mini`:
  A minimal example of using [Pytorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
  to train a MNIST model. This example utilizes the Ray Tune-provided
  {ref}`PyTorch Lightning callbacks <tune-integration-pytorch-lightning>`.
