# Distributed Training for Stable Diffusion

This notebook demonstrates how to train a Stable Diffusion model using PyTorch Lightning and Ray Train. 

<div class="alert alert-block alert-info">

<b>Here is the roadmap for this notebook:</b>

<ul>
    <li>Part 1: Load the preprocessed data into a Ray Dataset</li>
    <li>Part 2: Define a stable diffusion model</li>
    <li>Part 3: Define a PyTorch Lightning training loop</li>
    <li>Part 4: Migrate the training loop to Ray Train</li>
    <li>Part 5: Create and fit a Ray Train TorchTrainer</li>
</ul>

</div>

## Imports

In [None]:
import os

import pytorch_lightning as pl
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler, UNet2DConditionModel
from pytorch_lightning.utilities.types import OptimizerLRScheduler
from transformers import PretrainedConfig, get_linear_schedule_with_warmup

import ray
import ray.data
import ray.train
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
)
from ray.train.torch import TorchTrainer, get_device

<img src="https://anyscale-materials.s3.us-west-2.amazonaws.com/stable-diffusion/training_architecture_v3.jpeg" width="700px">

The preceding architecture diagram illustrates the training pipeline for Stable Diffusion. 

It consists of three main stages:
1. **Streaming data from the preprocessing stage**
2. **Training the model**
3. **Storing the model checkpoints**


## 1. Load the preprocessed data into a Ray Dataset

Let's start by specifying the dataset we want to use. We'll read `parquet` files that were produced by the preprocessing stage.

In [None]:
columns = ["image_latents_256", "caption_latents"]

train_data_uri = (
    "s3://anyscale-public-materials/ray-summit/stable-diffusion/data/preprocessed/256/"
)
train_ds = ray.data.read_parquet(train_data_uri, columns=columns, shuffle="files")
train_ds

<div class="alert alert-block alert-warning">

<b>NOTE:</b> We make use of column pruning by setting `columns=columns` in `read_parquet` to only load the columns we need. Column pruning is a good practice to follow when working with large datasets to reduce memory usage.

</div>

Because pyarrow, and in turn parquet, does not support saving float16, the preprocessed columns are stored as float32. We therefore add a step that converts them to float16.

Halving the precision of the data helps us reduce the memory usage and speed up the training process.

In [None]:
def convert_precision(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    for k, v in batch.items():
        batch[k] = v.astype(np.float16)
    return batch

train_ds = train_ds.map_batches(convert_precision, batch_size=None)
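
As a rough illustration of the memory saving, the sketch below applies the same conversion to a synthetic batch. The batch shapes are made-up placeholders and not necessarily the real shapes of the preprocessed latents.

```python
import numpy as np


def convert_precision(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Cast every column of a batch to float16 (same logic as the cell above)."""
    return {k: v.astype(np.float16) for k, v in batch.items()}


# Hypothetical shapes, chosen only for illustration.
batch = {
    "image_latents_256": np.random.rand(8, 4, 32, 32).astype(np.float32),
    "caption_latents": np.random.rand(8, 77, 1024).astype(np.float32),
}

bytes_before = sum(v.nbytes for v in batch.values())
half_batch = convert_precision(batch)
bytes_after = sum(v.nbytes for v in half_batch.values())

print(f"float32: {bytes_before} bytes, float16: {bytes_after} bytes")
```

Each float16 value takes 2 bytes instead of 4, so the converted batch occupies half the memory of the original.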

We form a dictionary of the datasets to eventually pass to the trainer.

In [None]:
ray_datasets = {"train": train_ds}

<div class="alert alert-block alert-warning">

<b>NOTE:</b> We did not create a validation dataset in the preprocessing step. Validation can consume valuable GPU hours and resources that could be better utilized for training, especially on high-performance GPUs like the A100. Thoughtful scheduling of validation can help optimize resource usage.

</div>


## 2. Define a stable diffusion model

Apart from the `get_device()` helper used in `on_fit_start`, this "standard" LightningModule does not refer to Ray or Ray Train, which makes migrating workloads easier.

In [None]:
class StableDiffusion(pl.LightningModule):
    def __init__(
        self,
        lr: float,
        resolution: int,
        weight_decay: float,
        num_warmup_steps: int,
        model_name: str,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.resolution = resolution
        self.weight_decay = weight_decay
        self.num_warmup_steps = num_warmup_steps
        # Initialize U-Net.
        model_config = PretrainedConfig.get_config_dict(model_name, subfolder="unet")[0]
        self.unet = UNet2DConditionModel(**model_config)
        # Define the training noise scheduler.
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            model_name, subfolder="scheduler"
        )
        # Setup loss function.
        self.loss_fn = F.mse_loss
        self.current_training_steps = 0

    def on_fit_start(self) -> None:
        """Move cumprod tensor to GPU in advance to avoid data movement on each step."""
        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(
            get_device()
        )

    def forward(
        self, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the model."""
        # Extract inputs.
        latents = batch["image_latents_256"]
        conditioning = batch["caption_latents"]
        # Sample the diffusion timesteps.
        timesteps = self._sample_timesteps(latents)
        # Add noise to the inputs (forward diffusion).
        noise = torch.randn_like(latents)
        noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        # Forward through the model.
        outputs = self.unet(noised_latents, timesteps, conditioning)["sample"]
        return outputs, noise

    def training_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Training step of the model."""
        outputs, targets = self.forward(batch)
        loss = self.loss_fn(outputs, targets)
        self.log(
            "train/loss_mse", loss.item(), prog_bar=False, on_step=True, sync_dist=False
        )
        self.current_training_steps += 1
        return loss

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Configure the optimizer and learning rate scheduler."""
        optimizer = torch.optim.AdamW(
            self.trainer.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        # Set a large training step here to keep lr constant after warm-up.
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=100000000000,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }

    def _sample_timesteps(self, latents: torch.Tensor) -> torch.Tensor:
        return torch.randint(
            0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device
        )
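
To make the forward-diffusion step in `forward` more concrete, here is a minimal NumPy sketch of the DDPM noising rule `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`, which is the blend that `add_noise` performs. The linear beta schedule, its endpoints, and the tensor shapes are assumptions for illustration. The real values come from the scheduler config loaded with `from_pretrained`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed linear beta schedule with 1000 steps (illustrative values only).
num_train_timesteps = 1000
betas = np.linspace(1e-4, 0.02, num_train_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)


def add_noise(
    latents: np.ndarray, noise: np.ndarray, timesteps: np.ndarray
) -> np.ndarray:
    """Blend clean latents with Gaussian noise according to each sample's timestep."""
    # Reshape the per-sample coefficients so they broadcast over (C, H, W).
    sqrt_alpha = np.sqrt(alphas_cumprod[timesteps]).reshape(-1, 1, 1, 1)
    sqrt_one_minus_alpha = np.sqrt(1.0 - alphas_cumprod[timesteps]).reshape(-1, 1, 1, 1)
    return sqrt_alpha * latents + sqrt_one_minus_alpha * noise


latents = rng.standard_normal((4, 4, 32, 32))  # hypothetical latent shape
noise = rng.standard_normal(latents.shape)
# One uniformly sampled timestep per sample, as in `_sample_timesteps`.
timesteps = rng.integers(0, num_train_timesteps, size=latents.shape[0])

noised_latents = add_noise(latents, noise, timesteps)

# The training target is the noise itself; a perfect prediction gives zero MSE.
perfect_loss = np.mean((noise - noise) ** 2)
```

Small timesteps leave the latents almost untouched, while timesteps near the end of the schedule produce almost pure noise. The U-Net is trained to recover `noise` from `noised_latents`, which is why `forward` returns the noise as the target.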

## 3. Define a PyTorch Lightning training loop

Here is a training loop that is specific to PyTorch Lightning.

It performs the following steps:
1. **Model Initialization:**
   - Instantiate the diffusion model.
2. **Trainer Setup:**
   - Instantiate the Lightning Trainer. With `devices="auto"`, Lightning picks a data parallel strategy on its own when more than one GPU is visible.
3. **Training Execution:**
   - Run the trainer using the `fit` method.

In [None]:
def lightning_training_loop(
    train_loader: torch.utils.data.DataLoader,
    storage_path: str,
    model_name: str = "stabilityai/stable-diffusion-2-base",
    resolution: int = 256,
    lr: float = 1e-4,
    max_epochs: int = 1,
    num_warmup_steps: int = 10_000,
    weight_decay: float = 1e-2,
) -> None:
    # 1. Initialize the model
    torch.set_float32_matmul_precision("high")
    model = StableDiffusion(
        model_name=model_name,
        resolution=resolution,
        lr=lr,
        num_warmup_steps=num_warmup_steps,
        weight_decay=weight_decay,
    )

    # 2. Initialize the Lightning Trainer
    trainer = pl.Trainer(
        accelerator="gpu",
        devices="auto",
        precision="bf16-mixed",
        max_epochs=max_epochs,
        default_root_dir=storage_path
    )

    # 3. Run the trainer
    trainer.fit(model=model, train_dataloaders=train_loader)
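
The `num_warmup_steps` argument feeds the schedule built in `configure_optimizers`. The pure-Python sketch below mirrors the multiplier that a linear warm-up followed by linear decay applies to the base learning rate, and shows why a very large `num_training_steps` keeps the rate almost constant after warm-up. The helper name `lr_multiplier` is hypothetical, and this approximates the library's behavior rather than reproducing its code.

```python
def lr_multiplier(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """Linear warm-up from 0 to 1, then linear decay towards 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 1e-4
warmup = 10_000
total = 100_000_000_000  # same very large horizon as in the model definition

for step in (0, 5_000, 10_000, 550_000):
    print(step, base_lr * lr_multiplier(step, warmup, total))
```

Because the decay is spread over 100 billion steps, the multiplier barely moves away from 1 during any realistic run.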


Here is how we would run the Lightning training loop on a single GPU.

```python
pl_compatible_data_loader = train_ds.limit(128).iter_torch_batches(batch_size=8)
storage_path = "/mnt/local_storage/lightning/stable-diffusion-pretraining/"
lightning_training_loop(train_loader=pl_compatible_data_loader, storage_path=storage_path)
```

As a quick check, we can run this on a GPU worker using Ray Core.

In [None]:
@ray.remote(num_gpus=1)
def demo_train():
    # Use a small slice of the dataset so the check finishes quickly.
    pl_compatible_data_loader = train_ds.limit(128).iter_torch_batches(batch_size=8)
    return lightning_training_loop(
        train_loader=pl_compatible_data_loader,
        storage_path="/mnt/cluster_storage/demo_single_node_train",
    )

In [None]:
ray.get(demo_train.remote())

Let's inspect the storage path to see what files were created.

In [None]:
!ls /mnt/cluster_storage/demo_single_node_train --recursive

## 4. Migrate the training loop to Ray Train

Next, we migrate the training loop to Ray Train to run distributed data parallel training.

### Distributed Data Parallel Training
Here is a diagram showing the standard distributed data parallel training loop.

<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-ai-libraries/diagrams/multi_gpu_pytorch_v4.png" width=800>

Note how the model state is initially synchronized across all the GPUs before the training loop begins.

Then after each backward pass, the gradients are synchronized across all the GPUs. 
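
As a toy illustration of the gradient synchronization step, the NumPy sketch below simulates two workers that each compute a gradient on their own data shard and then average the results, which is the effect of an all-reduce. It uses a hypothetical linear-regression model and is not how DDP is implemented internally. It only shows that, with equal shard sizes, the averaged gradient matches the gradient over the combined batch.

```python
import numpy as np

rng = np.random.default_rng(0)


def mse_gradient(w: np.ndarray, x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(x)


# Every worker starts from the same weights (initial model state sync).
w = rng.standard_normal(3)

# A global batch of 16 samples split evenly across 2 simulated workers.
x = rng.standard_normal((16, 3))
y = rng.standard_normal(16)
shards = [(x[:8], y[:8]), (x[8:], y[8:])]

# Each worker runs its own backward pass on its shard.
local_grads = [mse_gradient(w, xs, ys) for xs, ys in shards]

# All-reduce (mean): every worker ends up with the same averaged gradient.
synced_grad = np.mean(local_grads, axis=0)

# Reference: the gradient a single process would compute on the full batch.
full_grad = mse_gradient(w, x, y)
print(np.allclose(synced_grad, full_grad))
```

Since every worker applies the same averaged gradient to identical weights, the model replicas stay in sync after each optimizer step.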

### Ray Train Migration

Here are the changes we need to make to the training loop to migrate it to Ray Train.

In [None]:
def train_loop_per_worker(
    config: dict, # Update the function signature to comply with Ray Train
):  
    # Prepare data loaders using Ray
    train_ds = ray.train.get_dataset_shard("train")
    train_dataloader = train_ds.iter_torch_batches(
        batch_size=config["batch_size_per_worker"],
        drop_last=True,
        prefetch_batches=config["prefetch_batches"],
    )

    # Same model initialization as vanilla lightning
    torch.set_float32_matmul_precision("high")
    model = StableDiffusion(
        lr=config["lr"],
        resolution=config["resolution"],
        weight_decay=config["weight_decay"],
        num_warmup_steps=config["num_warmup_steps"],
        model_name=config["model_name"],
    )

    # Same trainer setup as vanilla lightning except we add Ray Train specific arguments
    trainer = pl.Trainer(
        max_steps=config["max_steps"],
        max_epochs=config["max_epochs"],
        accelerator="gpu",
        precision="bf16-mixed",
        devices="auto",  # Set devices to "auto" to use all available GPUs
        strategy=RayDDPStrategy(),  # Use RayDDPStrategy for distributed data parallel training
        plugins=[
            RayLightningEnvironment()
        ],  # Use RayLightningEnvironment to run the Lightning Trainer
        callbacks=[
            RayTrainReportCallback()
        ],  # Use RayTrainReportCallback to report metrics and checkpoints
        enable_checkpointing=False,  # Disable lightning checkpointing
    )

    # Same fit call as vanilla Lightning
    trainer.fit(model, train_dataloaders=train_dataloader)

Here is the same diagram as before but with the Ray Train specific components highlighted.

<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-ai-libraries/diagrams/multi_gpu_lightning_annotated_v5.png" width=800>

We made use of:
- `ray.train.get_dataset_shard("train")` to get the training dataset shard.
- `RayDDPStrategy` to perform distributed data parallel training.
- `RayLightningEnvironment` to run the Lightning Trainer.
- `RayTrainReportCallback` to report metrics and checkpoints.

## 5. Create and fit a Ray Train TorchTrainer

Let's first specify the scaling configuration to tell Ray Train to use 2 GPU training workers.

In [None]:
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

We then specify the run configuration to tell Ray Train where to store the checkpoints and metrics.

In [None]:
storage_path = "/mnt/cluster_storage/"
experiment_name = "stable-diffusion-pretraining"

run_config = ray.train.RunConfig(name=experiment_name, storage_path=storage_path)
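
Ray Train writes the run's outputs under a directory that combines `storage_path` and `name`. The small sketch below derives that location, assuming the default layout. It matches the directory removed in the clean-up step at the end of this notebook.

```python
from pathlib import PurePosixPath

storage_path = "/mnt/cluster_storage/"
experiment_name = "stable-diffusion-pretraining"

# Expected experiment directory: <storage_path>/<name>
experiment_dir = PurePosixPath(storage_path) / experiment_name
print(experiment_dir)
```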

Now we can create our Ray Train `TorchTrainer`:

In [None]:
train_loop_config = {
    "batch_size_per_worker": 8,
    "prefetch_batches": 2,
    "lr": 0.0001,
    "num_warmup_steps": 10_000,
    "weight_decay": 0.01,
    "max_steps": 550_000,
    "max_epochs": 1,
    "resolution": 256,
    "model_name": "stabilityai/stable-diffusion-2-base",
}

trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=ray_datasets,
)
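
With data parallel training, each optimizer step consumes one batch per worker, so the effective global batch size is the per-worker batch size times the number of workers. The sketch below works through the arithmetic. The dataset size is a hypothetical placeholder, since the real row count isn't stated here, and the even split across workers is a simplifying assumption.

```python
batch_size_per_worker = 8  # from train_loop_config
num_workers = 2  # from scaling_config
num_rows = 1_000_000  # hypothetical dataset size, for illustration only

global_batch_size = batch_size_per_worker * num_workers

# Assuming rows are split evenly across workers and drop_last=True
# discards the final incomplete batch on each worker.
rows_per_worker = num_rows // num_workers
steps_per_epoch = rows_per_worker // batch_size_per_worker

print(f"global batch size: {global_batch_size}")
print(f"steps per epoch:   {steps_per_epoch}")
```

Because the config sets both `max_epochs` and `max_steps`, training stops at whichever limit is reached first.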

Here is a high-level architecture of how Ray Train works:

<img src="https://docs.ray.io/en/latest/_images/overview.png" width=600>

Here are some key points:
- The scaling config specifies the number of training workers.
- A trainer actor process is launched that oversees the training workers.

We call `.fit()` to start the training job.

In [None]:
result = trainer.fit()
result
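
Once `fit()` returns, the `Result` object exposes the last reported metrics, the latest checkpoint, and the output path. The helper below is a hypothetical convenience function (not part of Ray) that summarizes those attributes. The stand-in object and its values are made up so that the sketch runs without a cluster.

```python
from types import SimpleNamespace
from typing import Any


def summarize_result(result: Any) -> dict[str, Any]:
    """Collect the most commonly inspected fields of a training result."""
    return {
        "path": result.path,
        "metrics": dict(result.metrics or {}),
        "has_checkpoint": result.checkpoint is not None,
        "error": result.error,
    }


# Stand-in with made-up values; in the notebook you would pass the real `result`.
fake_result = SimpleNamespace(
    path="/mnt/cluster_storage/stable-diffusion-pretraining/<trial-dir>",
    metrics={"train/loss_mse": 0.1, "step": 100},
    checkpoint=object(),
    error=None,
)
summary = summarize_result(fake_result)
print(summary)
```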

## Clean up 

Let's clean up the storage path to remove the checkpoints and artifacts we created during this notebook.

In [None]:
!rm -rf /mnt/cluster_storage/stable-diffusion-pretraining /mnt/cluster_storage/demo_single_node_train