In [8]:
# This example introduces how to train a Pytorch Lightning Module using AIR LightningTrainer. 
# We will demonstrate how to train a basic neural network 
# on the MNIST dataset with distributed data parallelism.
!pip install "torchmetrics>=0.9" "pytorch_lightning>=1.6" "filelock" "ray"

Collecting ray
  Downloading ray-2.6.3-cp310-cp310-manylinux2014_x86_64.whl (56.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.9/56.9 MB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: ray
Successfully installed ray-2.6.3


In [4]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock
from torch.utils.data import DataLoader, random_split, Subset
from torchmetrics import Accuracy
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import trainer
from pytorch_lightning.loggers.csv_logs import CSVLogger

In [5]:
# Prepare the DataModule and the LightningModule.
class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule that serves MNIST train/val/test DataLoaders.

    Args:
        batch_size: Batch size used by every DataLoader.
        data_dir: Directory where MNIST is downloaded/cached. Defaults to the
            current working directory at construction time.
            NOTE(review): this path is captured on the process that builds the
            object; confirm it is valid on remote Ray workers.
        num_workers: Number of DataLoader worker processes.
        seed: Seed for the train/val split, so every process (and every call
            to ``setup``) sees the same partition.
    """

    def __init__(self, batch_size=100, data_dir=None, num_workers=4, seed=42):
        super().__init__()
        self.data_dir = data_dir if data_dir is not None else os.getcwd()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        # Serialize the download so concurrent processes sharing this filesystem
        # do not race on the same files.
        with FileLock(f"{self.data_dir}.lock"):
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )

        # Split into train and val sets with a dedicated, seeded generator. The
        # global RNG is not guaranteed to be in the same state on every DDP
        # worker, which could give each worker a different train/val partition.
        split_generator = torch.Generator().manual_seed(self.seed)
        self.mnist_train, self.mnist_val = random_split(
            mnist, [55000, 5000], generator=split_generator
        )

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def val_dataloader(self):
        return DataLoader(
            self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def test_dataloader(self):
        # The test split is loaded lazily here, under the same download lock.
        with FileLock(f"{self.data_dir}.lock"):
            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )
        return DataLoader(
            self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers
        )


datamodule = MNISTDataModule(batch_size=128)

In [6]:
# Next, define a simple multi-layer perceptron as a subclass of pl.LightningModule.
class MNISTClassifier(pl.LightningModule):
    """Two-layer MLP that classifies flattened 28x28 MNIST images into 10 classes.

    Args:
        lr: Learning rate for the Adam optimizer.
        feature_dim: Width of the hidden layer.
    """

    def __init__(self, lr=1e-3, feature_dim=128):
        # Seed before the layers are created so the initial weights are reproducible.
        torch.manual_seed(421)
        super().__init__()
        # The last Linear layer emits raw logits. There is deliberately no
        # activation after it: cross_entropy applies log-softmax itself, and a
        # trailing ReLU would clamp every negative logit to 0 and zero its
        # gradient. Parameter names (linear_relu_stack.0.*, .2.*) are unchanged.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, 10),
        )
        self.lr = lr
        self.accuracy = Accuracy(task="multiclass", num_classes=10)
        # Per-batch validation results, averaged in on_validation_epoch_end.
        self.eval_loss = []
        self.eval_accuracy = []
        self.test_accuracy = []
        pl.seed_everything(888)

    def forward(self, x):
        """Flatten inputs to rows of 28*28 values and return (N, 10) class logits."""
        x = x.view(-1, 28 * 28)
        return self.linear_relu_stack(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Collect per-batch loss/accuracy; epoch metrics are logged at epoch end."""
        loss, acc = self._shared_eval(val_batch)
        # "val_accuracy" is logged once per epoch in on_validation_epoch_end;
        # logging the same key here too would report the metric twice.
        self.eval_loss.append(loss)
        self.eval_accuracy.append(acc)
        return {"val_loss": loss, "val_accuracy": acc}

    def test_step(self, test_batch, batch_idx):
        """Evaluate one test batch and log its accuracy (synced across devices)."""
        loss, acc = self._shared_eval(test_batch)
        self.test_accuracy.append(acc)
        self.log("test_accuracy", acc, sync_dist=True, on_epoch=True)
        return {"test_loss": loss, "test_accuracy": acc}

    def _shared_eval(self, batch):
        """Return ``(loss, accuracy)`` for an evaluation batch.

        Uses the same cross-entropy loss as ``training_step``. ``F.nll_loss``
        expects log-probabilities, so applying it to raw logits does not
        produce a meaningful loss.
        """
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        return loss, acc

    def on_validation_epoch_end(self):
        """Average the collected batch metrics, log them with sync_dist, and reset."""
        if not self.eval_loss:
            # No validation batches ran this epoch; torch.stack([]) would raise.
            return
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [9]:
from pytorch_lightning.callbacks import ModelCheckpoint
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.train.lightning import (
    LightningTrainer,
    LightningConfigBuilder,
    LightningCheckpoint,
)


def build_lightning_config_from_existing_code(
    use_gpu, datamodule=None, lr=1e-3, feature_dim=128, max_epochs=10
):
    """Translate a plain PyTorch Lightning setup into a LightningTrainer config.

    Args:
        use_gpu: Use the "gpu" accelerator if True, otherwise "cpu".
        datamodule: DataModule passed to ``fit``. Defaults to a new
            ``MNISTDataModule(batch_size=128)``, so the function does not
            silently depend on a mutable notebook-level variable.
        lr: Learning rate forwarded to ``MNISTClassifier``.
        feature_dim: Hidden-layer width forwarded to ``MNISTClassifier``.
        max_epochs: Number of training epochs for the Lightning trainer.

    Returns:
        The config dictionary produced by ``LightningConfigBuilder.build()``.
    """
    if datamodule is None:
        datamodule = MNISTDataModule(batch_size=128)

    # Create a config builder to encapsulate all required parameters.
    # Note that model instantiation and fitting will occur later in the LightningTrainer,
    # rather than in the config builder.
    config_builder = LightningConfigBuilder()

    # 1. define your model
    # model = MNISTClassifier(lr=lr, feature_dim=feature_dim)
    config_builder.module(cls=MNISTClassifier, lr=lr, feature_dim=feature_dim)

    # 2. define a ModelCheckpoint callback
    # checkpoint_callback = ModelCheckpoint(
    #     monitor="val_accuracy", mode="max", save_top_k=3
    # )
    config_builder.checkpointing(monitor="val_accuracy", mode="max", save_top_k=3)

    # 3. Define a Lightning trainer
    # trainer = pl.Trainer(
    #     max_epochs=max_epochs,
    #     accelerator="cpu",
    #     strategy="ddp",
    #     log_every_n_steps=100,
    #     logger=CSVLogger("logs"),
    #     callbacks=[checkpoint_callback],
    # )
    config_builder.trainer(
        max_epochs=max_epochs,
        accelerator="gpu" if use_gpu else "cpu",
        log_every_n_steps=100,
        logger=CSVLogger("logs"),
    )
    # You do not need to provide the checkpoint callback and strategy here,
    # since LightningTrainer configures them automatically.
    # You can also add any other callbacks into LightningConfigBuilder.trainer().

    # 4. Parameters for model fitting
    # trainer.fit(model, datamodule=datamodule)
    config_builder.fit_params(datamodule=datamodule)

    # Finally, compile all the configs into a dictionary for LightningTrainer
    return config_builder.build()

In [12]:
# Ray cluster connection via Ray Client. The address can be overridden with the
# RAY_ADDRESS environment variable instead of editing the notebook.
# NOTE: Ray Client requires the client's Python minor version to match the
# cluster's; a 3.10 client cannot attach to a 3.8 cluster. Align the kernel's
# Python version (and keep the Ray versions in sync) before running this cell.
import ray

RAY_URL = os.environ.get(
    "RAY_ADDRESS",
    "ray://ray-cluster-kuberay-head-svc.ray-cluster.svc.cluster.local:10001",
)

# The guard makes the cell safe to re-run. Keep the connection open, because the
# training cells below run on this cluster. Call `ray.shutdown()` only after
# training has finished.
if not ray.is_initialized():
    ray.init(address=RAY_URL)

RuntimeError: Python minor versions differ between client and server: client is 3.10.10, server is 3.8.13

In [None]:
# Resources for distributed training (set use_gpu to False to run without GPUs).
use_gpu = True
num_workers = 4

# Lightning side: module, checkpointing, pl.Trainer arguments and fit parameters.
lightning_config = build_lightning_config_from_existing_code(use_gpu=use_gpu)

# Ray side: number of training workers and whether they use GPUs.
scaling_config = ScalingConfig(use_gpu=use_gpu, num_workers=num_workers)

# Results go under /tmp/ray_results/ptl-mnist-example. Only the three checkpoints
# with the highest reported "val_accuracy" are retained.
run_config = RunConfig(
    name="ptl-mnist-example",
    storage_path="/tmp/ray_results",
    checkpoint_config=CheckpointConfig(
        checkpoint_score_attribute="val_accuracy",
        checkpoint_score_order="max",
        num_to_keep=3,
    ),
)

# Bundle the Lightning config with the Ray scaling and run settings.
trainer = LightningTrainer(
    run_config=run_config,
    scaling_config=scaling_config,
    lightning_config=lightning_config,
)
