# Distributed Training Optimizations for Stable Diffusion

This notebook demonstrates two optimizations for distributed Stable Diffusion training that improve throughput and reduce cost: Fully Sharded Data Parallel (FSDP) and online (end-to-end) preprocessing.

<div class="alert alert-block alert-info">

<b>Here is the roadmap for this notebook:</b>

<ul>
    <li>Part 1: Setup</li>
    <li>Part 2: Using Fully Sharded Data Parallel (FSDP)</li>
    <li>Part 3: Online (end-to-end) preprocessing and training</li>
</ul>

</div>

## Imports

In [None]:
import os

import lightning.pytorch as pl
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler, UNet2DConditionModel
from transformers import PretrainedConfig, get_linear_schedule_with_warmup
from lightning.pytorch.utilities.types import OptimizerLRScheduler

import ray.data
import ray.train
from torch.distributed.fsdp import BackwardPrefetch
from ray.train.lightning import RayLightningEnvironment, RayTrainReportCallback, RayFSDPStrategy
from ray.train.torch import TorchTrainer, get_device

## 1. Setup

Let's begin with the same code as in the basic pretraining notebook.

We first load the preprocessed latents and cast them to float16.

In [None]:
def convert_precision(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    for k, v in batch.items():
        batch[k] = v.astype(np.float16)
    return batch


columns = ["image_latents_256", "caption_latents"]

train_data_uri = (
    "s3://anyscale-public-materials/ray-summit/stable-diffusion/data/preprocessed/256/"
)
train_ds = ray.data.read_parquet(train_data_uri, columns=columns, shuffle="files")
train_ds = train_ds.map_batches(convert_precision, batch_size=None)

ray_datasets = {"train": train_ds}
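
Casting to float16 halves the bytes per value but rounds each value to the nearest representable half-precision number. The snippet below is a minimal standard-library sketch of that trade-off. The helper `to_float16` is hypothetical and only illustrates the per-element rounding; it is not part of the pipeline above.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    # The "e" format code packs a value as a 2-byte half-precision float.
    return struct.unpack("e", struct.pack("e", x))[0]


bytes_fp32 = struct.calcsize("f")  # size of a single-precision value
bytes_fp16 = struct.calcsize("e")  # size of a half-precision value

value = 0.1
rounded = to_float16(value)
rounding_error = abs(rounded - value)

print(f"float32: {bytes_fp32} bytes, float16: {bytes_fp16} bytes")
print(f"{value} stored as float16 becomes {rounded!r} (error {rounding_error:.2e})")
```

For latents that are already normalized to a small range, this rounding error is usually acceptable, but that is an assumption worth checking against your own data.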

We then define the model.

In [None]:
class StableDiffusion(pl.LightningModule):

    def __init__(
        self,
        lr: float,
        resolution: int,
        weight_decay: float,
        num_warmup_steps: int,
        model_name: str,
    ) -> None:
        # Initialize the parent module before assigning any attributes.
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.resolution = resolution
        self.weight_decay = weight_decay
        self.num_warmup_steps = num_warmup_steps
        # Initialize U-Net.
        model_config = PretrainedConfig.get_config_dict(model_name, subfolder="unet")[0]
        self.unet = UNet2DConditionModel(**model_config)
        # Define the training noise scheduler.
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            model_name, subfolder="scheduler"
        )
        # Setup loss function.
        self.loss_fn = F.mse_loss
        self.current_training_steps = 0

    def on_fit_start(self) -> None:
        """Move cumprod tensor to GPU in advance to avoid data movement on each step."""
        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(
            get_device()
        )

    def forward(
        self, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the model."""
        # Extract inputs.
        latents = batch["image_latents_256"]
        conditioning = batch["caption_latents"]
        # Sample the diffusion timesteps.
        timesteps = self._sample_timesteps(latents)
        # Add noise to the inputs (forward diffusion).
        noise = torch.randn_like(latents)
        noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        # Forward through the model.
        outputs = self.unet(noised_latents, timesteps, conditioning)["sample"]
        return outputs, noise

    def training_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Training step of the model."""
        outputs, targets = self.forward(batch)
        loss = self.loss_fn(outputs, targets)
        self.log(
            "train/loss_mse", loss.item(), prog_bar=False, on_step=True, sync_dist=False
        )
        self.current_training_steps += 1
        return loss

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Configure the optimizer and learning rate scheduler."""
        optimizer = torch.optim.AdamW(
            self.trainer.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        # Set a large training step here to keep lr constant after warm-up.
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=100000000000,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }

    def _sample_timesteps(self, latents: torch.Tensor) -> torch.Tensor:
        return torch.randint(
            0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device
        )
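
The very large `num_training_steps` in `configure_optimizers` is what keeps the learning rate effectively constant after warm-up. The sketch below re-derives the schedule in plain Python, assuming it is a linear warm-up followed by a linear decay to zero at `num_training_steps`. The function `warmup_multiplier` is a hypothetical helper written for illustration; it is not the library implementation.

```python
def warmup_multiplier(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """Factor applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear ramp from 0 to 1 during warm-up.
        return step / max(1, num_warmup_steps)
    # Linear decay from 1 to 0 between the end of warm-up and num_training_steps.
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 1e-4
num_warmup_steps = 10_000
num_training_steps = 100_000_000_000

for step in (0, 5_000, 10_000, 550_000):
    factor = warmup_multiplier(step, num_warmup_steps, num_training_steps)
    print(f"step={step:>7}  lr={base_lr * factor:.10f}")
```

Because the decay is spread over 100 billion steps, the multiplier at step 550,000 differs from 1 by only a few parts per million.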

## 2. Using Fully Sharded Data Parallel (FSDP)

Ray Train also supports Fully Sharded Data Parallel (FSDP) for distributed training.

FSDP is a data-parallel training technique that reduces per-GPU memory usage by sharding model parameters, gradients, and optimizer states across GPUs. The memory saved can be spent on larger models or larger batch sizes.

Here is a diagram to help illustrate how FSDP works.

<img src="https://user-images.githubusercontent.com/26745457/236892936-d4b91751-4689-421e-ac5f-edfd2eeeb635.png" width=800>

### FSDP configuration:

#### Sharding strategy:

There are three different modes of the FSDP sharding strategy:

1. `NO_SHARD`: Parameters, gradients, and optimizer states are not sharded. Similar to DDP.
2. `SHARD_GRAD_OP`: Gradients and optimizer states are sharded during computation, and parameters are additionally sharded outside computation. Similar to ZeRO stage-2.
3. `FULL_SHARD`: Parameters, gradients, and optimizer states are all sharded. It has the lowest GPU memory usage of the three options. Similar to ZeRO stage-3.
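
To get a feel for the difference between the three strategies, here is a rough back-of-the-envelope estimate of per-GPU memory for model state at peak during computation. It is a sketch under loud assumptions: full float32 training with AdamW (4 bytes per parameter, 4 per gradient, 8 for the two optimizer moments), no activations, no mixed-precision copies, and no temporary buffers for gathered layers. The function name and the 1-billion-parameter model are hypothetical.

```python
def per_gpu_state_gb(
    num_params: float,
    world_size: int,
    strategy: str,
    bytes_param: int = 4,  # assumed float32 parameters
    bytes_grad: int = 4,   # assumed float32 gradients
    bytes_optim: int = 8,  # assumed AdamW: two float32 moments per parameter
) -> float:
    """Approximate per-GPU memory (GB) for parameters, gradients, optimizer state."""
    params = num_params * bytes_param
    grads = num_params * bytes_grad
    optim = num_params * bytes_optim
    if strategy == "NO_SHARD":
        total = params + grads + optim
    elif strategy == "SHARD_GRAD_OP":
        # Parameters are held in full during computation; the rest is sharded.
        total = params + (grads + optim) / world_size
    elif strategy == "FULL_SHARD":
        total = (params + grads + optim) / world_size
    else:
        raise ValueError(f"Unknown strategy: {strategy}")
    return total / 1e9


for name in ("NO_SHARD", "SHARD_GRAD_OP", "FULL_SHARD"):
    print(f"{name:>14}: {per_gpu_state_gb(1e9, world_size=2, strategy=name):.1f} GB")
```

Real usage will differ, mainly because activations and transient all-gathered parameters are ignored here, but the ordering of the three strategies holds.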

#### Auto-wrap policy:

FSDP is usually applied in a nested fashion: each wrapped group of layers becomes its own FSDP instance. As a result, only the parameters of a single FSDP instance need to be gathered in full on a device at any point during the forward or backward pass.

Depending on the model architecture, we might need to specify a custom auto-wrap policy.

For example, `transformer_auto_wrap_policy` wraps each Transformer block in its own FSDP instance.
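
A minimal sketch of how such a policy could be passed to the strategy is shown below. It assumes that `BasicTransformerBlock` from `diffusers` is the right unit to wrap for this U-Net, which is an illustrative choice rather than a tuned recommendation. The helper name `build_fsdp_strategy` is hypothetical.

```python
import functools

from diffusers.models.attention import BasicTransformerBlock
from ray.train.lightning import RayFSDPStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def build_fsdp_strategy() -> RayFSDPStrategy:
    """Build an FSDP strategy that wraps each transformer block separately."""
    # Assumption: each BasicTransformerBlock becomes its own FSDP instance.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={BasicTransformerBlock},
    )
    return RayFSDPStrategy(
        sharding_strategy="SHARD_GRAD_OP",
        auto_wrap_policy=auto_wrap_policy,
    )
```

The returned object can be passed as `strategy=` to `pl.Trainer` inside the per-worker training function.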

#### Overlap communication with computation:

FSDP can issue the next all-gather while the current forward or backward computation is still running. This overlap can improve throughput at the cost of slightly higher peak memory usage. It is controlled by the `backward_prefetch` and `forward_prefetch` options.
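
The benefit of prefetching can be pictured with a toy timeline model. The sketch below is an idealized illustration, not a measurement: it assumes each FSDP unit needs one all-gather followed by one compute phase, and that a prefetched all-gather overlaps perfectly with the previous unit's compute. All timings and function names are hypothetical.

```python
def step_time_serial(compute_ms: list[float], gather_ms: list[float]) -> float:
    """Each unit waits for its own all-gather before computing."""
    return sum(c + g for c, g in zip(compute_ms, gather_ms))


def step_time_prefetch(compute_ms: list[float], gather_ms: list[float]) -> float:
    """The all-gather for unit i+1 runs while unit i is computing."""
    total = gather_ms[0]  # the first all-gather cannot be hidden
    for i, compute in enumerate(compute_ms):
        next_gather = gather_ms[i + 1] if i + 1 < len(gather_ms) else 0.0
        # The slower of the two overlapped operations sets the pace.
        total += max(compute, next_gather)
    return total


compute_ms = [10.0, 12.0, 8.0]  # hypothetical per-unit compute times
gather_ms = [4.0, 5.0, 3.0]     # hypothetical per-unit all-gather times

serial = step_time_serial(compute_ms, gather_ms)
prefetch = step_time_prefetch(compute_ms, gather_ms)
print(f"serial: {serial} ms, with prefetch: {prefetch} ms")
```

The extra peak memory comes from holding the prefetched unit's gathered parameters while the current unit is still in use.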




Let's update our training loop to use FSDP.

In [None]:
def train_loop_per_worker_fsdp(config):
    train_ds = ray.train.get_dataset_shard("train")
    train_dataloader = train_ds.iter_torch_batches(
        batch_size=config["batch_size_per_worker"],
        drop_last=True,
    )

    torch.set_float32_matmul_precision("high")
    model = StableDiffusion(
        lr=config["lr"],
        resolution=config["resolution"],
        weight_decay=config["weight_decay"],
        num_warmup_steps=config["num_warmup_steps"],
        model_name=config["model_name"],
    )

    trainer = pl.Trainer(
        max_steps=config["max_steps"],
        max_epochs=config["max_epochs"],
        accelerator="gpu",
        devices="auto",
        precision="bf16-mixed",
        strategy=RayFSDPStrategy(  # Use RayFSDPStrategy instead of RayDDPStrategy
            sharding_strategy="SHARD_GRAD_OP",  # Run FSDP with the SHARD_GRAD_OP sharding strategy
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # Overlap communication with computation in the backward pass
        ),
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
        enable_checkpointing=False,
    )

    trainer.fit(model, train_dataloaders=train_dataloader)


Let's run the training loop with FSDP.

In [None]:
storage_path = "/mnt/cluster_storage/"
experiment_name = "stable-diffusion-pretraining-fsdp"

train_loop_config = {
    "batch_size_per_worker": 8,
    "prefetch_batches": 2,  # Not read by train_loop_per_worker_fsdp as written above
    "every_n_train_steps": 10,  # Intended report/checkpoint interval; not read by the loop as written above
    "lr": 0.0001,
    "num_warmup_steps": 10_000,
    "weight_decay": 0.01,
    "max_steps": 550_000,
    "max_epochs": 1,
    "resolution": 256,
    "model_name": "stabilityai/stable-diffusion-2-base",
}

run_config = ray.train.RunConfig(name=experiment_name, storage_path=storage_path)

scaling_config = ray.train.ScalingConfig(
    num_workers=2,
    use_gpu=True,
)

trainer = TorchTrainer(
    train_loop_per_worker_fsdp,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=ray_datasets,
)

result = trainer.fit()

Let's load the model from the checkpoint and inspect it.

In [None]:
with result.checkpoint.as_directory() as checkpoint_dir:
    ckpt_path = os.path.join(checkpoint_dir, "checkpoint.ckpt")
    model = StableDiffusion.load_from_checkpoint(ckpt_path, map_location="cpu")
    print(model)

## 3. Online (end-to-end) preprocessing and training

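More demanding Stable Diffusion training pipelines call for more sophisticated data handling. The diagrams below illustrate the SDXL random-crop limitation, moving the preprocessors to GPUs, and online preprocessing as a solution.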




<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-summit/stable-diffusion/diagrams/sdxl_random_crop_limitation.png">

<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-summit/stable-diffusion/diagrams/moving_preprocessors_to_gpu.png">

<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-summit/stable-diffusion/diagrams/online_preprocessing_as_a_solution.png">

### Resources for online preprocessing and training

Check out the following resources for more details:

- [Reducing the Cost of Pre-training Stable Diffusion by 3.7x with Anyscale](https://www.anyscale.com/blog/scalable-and-cost-efficient-stable-diffusion-pre-training-with-ray)
- [Pretraining Stable Diffusion (V2) workspace template](https://console.anyscale.com/v2/template-preview/stable-diffusion-pretraining)
- [Processing 2 Billion Images for Stable Diffusion Model Training - Definitive Guides with Ray Series](https://www.anyscale.com/blog/processing-2-billion-images-for-stable-diffusion-model-training-definitive-guides-with-ray-series)
- [We Pre-Trained Stable Diffusion Models on 2 billion Images and Didn't Break the Bank - Definitive Guides with Ray Series](https://www.anyscale.com/blog/we-pre-trained-stable-diffusion-models-on-2-billion-images-and-didnt-break-the-bank-definitive-guides-with-ray-series)

