# Train a Pytorch Lightning Image Classifier

This example introduces how to train a Pytorch Lightning Module using AIR LightningTrainer. We will demonstrate how to train a basic neural network on the MNIST dataset with distributed data parallelism.

Source: https://docs.ray.io/en/latest/train/examples/lightning/lightning_mnist_example.html

In [1]:
# Requirements:
# ray, torchmetrics, pytorch_lightning

In [2]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock # for handling concurrent writes
from torch.utils.data import DataLoader, random_split, Subset
from torchmetrics import Accuracy
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers.csv_logs import CSVLogger


## Prepare Dataset and Module

In [3]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=100):
        super().__init__()
        self.data_dir = os.getcwd()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        with FileLock(f"{self.data_dir}.lock"):
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )

            # split data into train and val sets
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        with FileLock(f"{self.data_dir}.lock"):
            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)


datamodule = MNISTDataModule(batch_size=128)

Next, define a simple multi-layer perception as the subclass of `pl.LightningModule`.

In [4]:
class MNISTClassifier(pl.LightningModule):
    def __init__(self, lr=1e-3, feature_dim=128):
        torch.manual_seed(421)
        super(MNISTClassifier, self).__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, 10),
            nn.ReLU(),
        )
        self.lr = lr
        self.accuracy = Accuracy(task="multiclass", num_classes=10)
        self.eval_loss = []
        self.eval_accuracy = []
        self.test_accuracy = []
        pl.seed_everything(888)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.linear_relu_stack(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        loss, acc = self._shared_eval(val_batch)
        self.log("val_accuracy", acc)
        self.eval_loss.append(loss)
        self.eval_accuracy.append(acc)
        return {"val_loss": loss, "val_accuracy": acc}

    def test_step(self, test_batch, batch_idx):
        loss, acc = self._shared_eval(test_batch)
        self.test_accuracy.append(acc)
        self.log("test_accuracy", acc, sync_dist=True, on_epoch=True)
        return {"test_loss": loss, "test_accuracy": acc}

    def _shared_eval(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return loss, acc

    def on_validation_epoch_end(self):
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

## Define the Cofigurations for AIR LightningTrainer

The `LightningConfigBuilder` class stores all the parameters involved in training a PyTorch Lightning module. It takes the same parameter lists as those in PyTorch Lightning.

- The `.module()` method takes a subclass of pl.LightningModule and its initialization parameters. `LightningTrainer` will instantiate a model instance internally in the workers’ training loop.

- The `.trainer()` method takes the initialization parameters of `pl.Trainer`. You can specify training configurations, loggers, and callbacks here.

- The `.fit_params()` method stores all the parameters that will be passed into `pl.Trainer.fit()`, including train/val dataloaders, datamodules, and checkpoint paths.

- The `.checkpointing()` method saves the configurations for a `RayModelCheckpoint` callback. This callback reports the latest metrics to the AIR session along with a newly saved checkpoint.

- The `.build()` method generates a dictionary that contains all the configurations in the builder. This dictionary will be passed to `LightningTrainer` later.

In [5]:
from pytorch_lightning.callbacks import ModelCheckpoint
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.train.lightning import (
    LightningTrainer,
    LightningConfigBuilder,
    LightningCheckpoint,
)


def build_lightning_config_from_existing_code(use_gpu):
    # Create a config builder to encapsulate all required parameters.
    # Note that model instantiation and fitting will occur later in the LightingTrainer,
    # rather than in the config builder.
    config_builder = LightningConfigBuilder()

    # 1. define your model
    # model = MNISTClassifier(lr=1e-3, feature_dim=128)
    config_builder.module(cls=MNISTClassifier, lr=1e-3, feature_dim=128)

    # 2. define a ModelCheckpoint callback
    # checkpoint_callback = ModelCheckpoint(
    #     monitor="val_accuracy", mode="max", save_top_k=3
    # )
    config_builder.checkpointing(monitor="val_accuracy", mode="max", save_top_k=3)

    # 3. Define a Lightning trainer
    # trainer = pl.Trainer(
    #     max_epochs=10,
    #     accelerator="cpu",
    #     strategy="ddp",
    #     log_every_n_steps=100,
    #     logger=CSVLogger("logs"),
    #     callbacks=[checkpoint_callback],
    # )
    config_builder.trainer(
        max_epochs=10,
        accelerator="gpu" if use_gpu else "cpu",
        log_every_n_steps=100,
        logger=CSVLogger("logs"),
    )
    # You do not need to provide the checkpoint callback and strategy here,
    # since LightningTrainer configures them automatically.
    # You can also add any other callbacks into LightningConfigBuilder.trainer().

    # 4. Parameters for model fitting
    # trainer.fit(model, datamodule=datamodule)
    config_builder.fit_params(datamodule=datamodule)

    # Finally, compile all the configs into a dictionary for LightningTrainer
    lightning_config = config_builder.build()
    return lightning_config

In [6]:
use_gpu = True # Set it to False if you want to run without GPUs
num_workers = 1


In [7]:
lightning_config = build_lightning_config_from_existing_code(use_gpu=use_gpu)

scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

run_config = RunConfig(
    name="ptl-mnist-example",
    storage_path="/tmp/ray_results",
    checkpoint_config=CheckpointConfig(
        num_to_keep=3,
        checkpoint_score_attribute="val_accuracy",
        checkpoint_score_order="max",
    ),
)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

In [8]:
result = trainer.fit()
print("Validation Accuracy: ", result.metrics["val_accuracy"])
result

0,1
Current time:,2023-09-06 23:32:04
Running for:,00:00:37.63
Memory:,12.2/31.2 GiB

Trial name,status,loc,iter,total time (s),train_loss,val_accuracy,val_loss
LightningTrainer_d3314_00000,TERMINATED,192.168.1.147:700772,10,31.5043,0.0504907,0.970508,-15.8305


[2m[36m(LightningTrainer pid=700772)[0m Starting distributed worker processes: ['700925 (192.168.1.147)']
[2m[36m(RayTrainWorker pid=700925)[0m Setting up process group for: env:// [rank=0, world_size=1]
[2m[36m(RayTrainWorker pid=700925)[0m [rank: 0] Global seed set to 888
[2m[36m(RayTrainWorker pid=700925)[0m GPU available: True (cuda), used: True
[2m[36m(RayTrainWorker pid=700925)[0m TPU available: False, using: 0 TPU cores
[2m[36m(RayTrainWorker pid=700925)[0m IPU available: False, using: 0 IPUs
[2m[36m(RayTrainWorker pid=700925)[0m HPU available: False, using: 0 HPUs
[2m[36m(RayTrainWorker pid=700925)[0m [rank: 0] Global seed set to 888
[2m[36m(RayTrainWorker pid=700925)[0m You are using a CUDA device ('NVIDIA GeForce RTX 3070') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/genera

Sanity Checking: 0it [00:00, ?it/s])[0m 
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]


[2m[36m(RayTrainWorker pid=700925)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[2m[36m(RayTrainWorker pid=700925)[0m 
[2m[36m(RayTrainWorker pid=700925)[0m   | Name              | Type               | Params
[2m[36m(RayTrainWorker pid=700925)[0m ---------------------------------------------------------
[2m[36m(RayTrainWorker pid=700925)[0m 0 | linear_relu_stack | Sequential         | 101 K 
[2m[36m(RayTrainWorker pid=700925)[0m 1 | accuracy          | MulticlassAccuracy | 0     
[2m[36m(RayTrainWorker pid=700925)[0m ---------------------------------------------------------
[2m[36m(RayTrainWorker pid=700925)[0m 101 K     Trainable params
[2m[36m(RayTrainWorker pid=700925)[0m 0         Non-trainable params
[2m[36m(RayTrainWorker pid=700925)[0m 101 K     Total params
[2m[36m(RayTrainWorker pid=700925)[0m 0.407     Total estimated model params size (MB)


Epoch 0:   0%|          | 0/430 [00:00<?, ?it/s]                           




Epoch 0:   2%|▏         | 8/430 [00:00<00:10, 38.79it/s, v_num=0]
Epoch 0:   7%|▋         | 31/430 [00:00<00:04, 96.46it/s, v_num=0]
Epoch 0:  12%|█▏        | 53/430 [00:00<00:03, 124.82it/s, v_num=0]
Epoch 0:  12%|█▏        | 53/430 [00:00<00:03, 124.77it/s, v_num=0]
Epoch 0:  18%|█▊        | 76/430 [00:00<00:02, 146.56it/s, v_num=0]
Epoch 0:  18%|█▊        | 77/430 [00:00<00:02, 145.61it/s, v_num=0]
Epoch 0:  23%|██▎       | 100/430 [00:00<00:02, 159.32it/s, v_num=0]
Epoch 0:  28%|██▊       | 121/430 [00:00<00:01, 165.23it/s, v_num=0]
Epoch 0:  28%|██▊       | 122/430 [00:00<00:01, 165.79it/s, v_num=0]
Epoch 0:  33%|███▎      | 141/430 [00:00<00:01, 168.28it/s, v_num=0]
Epoch 0:  38%|███▊      | 165/430 [00:00<00:01, 175.45it/s, v_num=0]
Epoch 0:  44%|████▎     | 188/430 [00:01<00:01, 180.86it/s, v_num=0]
Epoch 0:  44%|████▍     | 189/430 [00:01<00:01, 181.27it/s, v_num=0]
Epoch 0:  50%|████▉     | 213/430 [00:01<00:01, 186.76it/s, v_num=0]
Epoch 0:  50%|████▉     | 214/430 [00:01<00

[2m[36m(RayTrainWorker pid=700925)[0m `Trainer.fit` stopped: `max_epochs=10` reached.
2023-09-06 23:32:04,156	INFO tune.py:1148 -- Total run time: 37.65 seconds (37.63 seconds for the tuning loop).


Validation Accuracy:  0.970507800579071


Result(
  metrics={'_report_on': 'train_epoch_end', 'train_loss': 0.05049066245555878, 'val_accuracy': 0.970507800579071, 'val_loss': -15.830514907836914, 'epoch': 9, 'step': 4300, 'should_checkpoint': True, 'done': True, 'trial_id': 'd3314_00000', 'experiment_tag': '0'},
  path='/tmp/ray_results/ptl-mnist-example/LightningTrainer_d3314_00000_0_2023-09-06_23-31-26',
  checkpoint=LightningCheckpoint(local_path=/tmp/ray_results/ptl-mnist-example/LightningTrainer_d3314_00000_0_2023-09-06_23-31-26/checkpoint_000009)
)

## Evaluate your model on test dataset

In [9]:
checkpoint: LightningCheckpoint = result.checkpoint
best_model: pl.LightningModule = checkpoint.get_model(MNISTClassifier)

Global seed set to 888


### Single-node Testing

In [10]:
# Download and setup MNIST datamodule on the head node
datamodule.setup()
test_dataloader = datamodule.test_dataloader()

trainer = pl.Trainer()
result = trainer.test(best_model, dataloaders=test_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3070') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: /home/dino/Documents/Machine-Learning-Collection/ML/Ray-examples/ray-train/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

## Multi-node Testing

In [11]:
import ray
import pytorch_lightning as pl

from pytorch_lightning.plugins.environments import (
    LightningEnvironment,
)
from ray.air.util.torch_dist import (
    TorchDistributedWorker,
    init_torch_dist_process_group,
    shutdown_torch_dist_process_group,
)


class RayEnvironment(LightningEnvironment):
    """Setup Lightning DDP training environment for Ray cluster."""

    def world_size(self) -> int:
        return int(os.environ["WORLD_SIZE"])

    def global_rank(self) -> int:
        return int(os.environ["RANK"])

    def local_rank(self) -> int:
        return int(os.environ["LOCAL_RANK"])

    def set_world_size(self, size: int) -> None:
        # Disable it since `world_size()` directly returns data from AIR session.
        pass

    def set_global_rank(self, rank: int) -> None:
        # Disable it since `global_rank()` directly returns data from AIR session.
        pass

    def teardown(self):
        pass


@ray.remote
class TestWorker(TorchDistributedWorker):
    def run(self):
        trainer = pl.Trainer(
            num_nodes=num_workers,
            accelerator="gpu",
            strategy="ddp",
            plugins=[RayEnvironment()],
        )
        return trainer.test(best_model, dataloaders=test_dataloader)


# Create 4 remote Ray Actors, each with 1 GPU
workers = [TestWorker.options(num_gpus=1).remote() for _ in range(num_workers)]

# Initialize the Torch distributed group among the 4 actors.
# This will set up the required environment variables including 
# RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDRESS, ...
init_torch_dist_process_group(workers=workers, backend="nccl")

# Execute the testing run in parallel
results = ray.get([worker.run.remote() for worker in workers])

# Shutdown the process group
shutdown_torch_dist_process_group(workers=workers)

ModuleNotFoundError: No module named 'pytorch_lightning.plugins.environments.lightning_environment'