# Introduction to Ray Train + PyTorch
© 2025, Anyscale. All Rights Reserved

This notebook will walk you through the basics of distributed training with Ray Train and PyTorch.

<div class="alert alert-block alert-info">

<b> Here is the roadmap for this notebook </b>

<ol>
  <li>When to use Ray Train</li>
  <li>Single GPU Training with PyTorch</li>
  <li>Distributed Training with Ray Train and PyTorch</li>
  <li>Ray Train in Production</li>
</ol>
</div>

__Install Dependencies__

In [None]:
%%bash
uv pip install -r python_depset.lock --no-cache-dir --no-deps --system

**Imports**

In [None]:
import csv
import datetime
import os
import tempfile

from pathlib import Path

import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import pandas as pd
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torchvision.models import resnet18
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray
from ray.train import ScalingConfig, RunConfig
from ray.train.torch import TorchTrainer

## 1. When to use Ray Train


Use Ray Train when you face one of the following challenges:

|Challenge|Detail|Solution|
|---|---|---|
|**Need to speed up or scale up training**| Training jobs might take a long time to complete, or require a lot of compute | Ray Train provides a **distributed training** framework that allows engineers to scale training jobs to multiple GPUs |
|**Minimize overhead of setting up clusters**| Engineers need to manage the underlying infrastructure | Ray Train **provisions the underlying infrastructure** via Ray's cluster autoscaler. |
|**Achieve observability**| Engineers need to connect to different nodes and GPUs to find the root cause of failures, fetch logs, traces, etc | Ray Train **provides observability** via Ray's dashboard, metrics, and traces that allow engineers to monitor the training job |
|**Ensure reliable training**| Training jobs can fail due to hardware failures, network issues, or other unexpected events | Ray Train **ensures fault tolerance** via checkpointing, automatic retries, and the ability to resume training from the last checkpoint |
|**Avoid significant code rewrite**| Engineers might need to fully rewrite their training loop to support distributed training | Ray Train has **built-in integrations** with the PyTorch ecosystem (PyTorch, Lightning, Hugging Face), tree-based methods (XGBoost, LightGBM), and more to minimize the code changes needed |


## 2. Single GPU Training with PyTorch

### 2.1. Overview

We will start by fitting a `ResNet18` model to the `MNIST` dataset. Conceptually, we will follow the recipe presented below.

|<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-ai-libraries/diagrams/single_gpu_pytorch_v3.png" width="70%" loading="lazy">|
|:--|
|An overview of the single-GPU training process. At a high level, a PyTorch training loop has three key stages: loading the dataset, training on mini-batches on a single GPU, and saving the model checkpoint to persistent storage.|

In [None]:
def train_loop_torch(num_epochs: int = 2, batch_size: int = 128, local_path: str = "./checkpoints"):

    # Model, Loss, Optimizer
    criterion = CrossEntropyLoss()
    model = load_model_torch()
    optimizer = Adam(model.parameters(), lr=1e-5)

    # Load the data loader
    data_loader = build_data_loader_torch(batch_size=batch_size)

    # Training loop
    for epoch in range(num_epochs):
        for images, labels in data_loader:

            # Move the data to the GPU
            images, labels = images.to("cuda"), labels.to("cuda")

            # Forward pass
            outputs = model(images)

            # Compute the loss
            loss = criterion(outputs, labels)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Update the weights
            optimizer.step()

        # Report the metrics
        metrics = report_metrics_torch(loss=loss, epoch=epoch)
        
        # Save the checkpoint and metrics
        Path(local_path).mkdir(parents=True, exist_ok=True)
        save_checkpoint_and_metrics_torch(metrics=metrics, model=model, local_path=local_path)

<div class="alert alert-block alert-info">

Quick notes:

<ul>
    <li><code>report_metrics_torch</code> and <code>save_checkpoint_and_metrics_torch</code> are defined below,</li>
    <li><code>local_path</code> is used for checkpointing. The default, <code>./checkpoints</code>, is relative to the current working directory, which is the notebook location (check <code>pwd</code> below).</li>
</ul>
</div>

In [None]:
!pwd

### 2.2. Build model and load it on the GPU

Build [Resnet18](https://pytorch.org/vision/main/models/resnet.html#resnet)

In [None]:
def build_resnet18():
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        in_channels=1, # grayscale MNIST images
        out_channels=64,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )
    return model

<div class="alert alert-block alert-info">

resnet18's <code>model.conv1</code> has <code>in_channels=3</code> by default. Here, we work with the <a href="https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html#mnist" target="_blank">MNIST</a> grayscale images, thus <code>in_channels=1</code>.
</div>
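
To make the effect of this change concrete, the sketch below computes the number of weights in the first convolution for 3-channel and 1-channel inputs, and the spatial size of its output for a 28×28 MNIST image. It relies only on the standard convolution output-size formula (dilation of 1 assumed). The helper names are illustrative and are not part of torchvision.

```python
def conv2d_weight_count(in_channels: int, out_channels: int, kernel_size: int) -> int:
    """Number of weights in a bias-free 2D convolution with a square kernel."""
    return out_channels * in_channels * kernel_size * kernel_size


def conv2d_output_size(input_size: int, kernel_size: int, stride: int, padding: int) -> int:
    """Spatial output size along one dimension (dilation of 1 assumed)."""
    return (input_size + 2 * padding - kernel_size) // stride + 1


# Same kernel size, stride, and padding as the conv1 layer defined above.
rgb_weights = conv2d_weight_count(in_channels=3, out_channels=64, kernel_size=7)
gray_weights = conv2d_weight_count(in_channels=1, out_channels=64, kernel_size=7)
mnist_output = conv2d_output_size(input_size=28, kernel_size=7, stride=2, padding=3)

print(rgb_weights)   # 9408
print(gray_weights)  # 3136
print(mnist_output)  # 14
```

Switching to a single input channel shrinks the first layer to a third of its original weight count, while the rest of the network is unchanged.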

Load the model on a single GPU

In [None]:
def load_model_torch() -> torch.nn.Module:
    model = build_resnet18()

    # move to the GPU device
    model.to("cuda")
    return model

### 2.3. Create Dataset and DataLoader

In [None]:
dataset = MNIST(root="./data", train=True, download=True)

In [None]:
!tree ./data

Let's display 9 example (image, target) pairs:

In [None]:
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3

for i in range(1, cols * rows + 1):
    sample_idx = np.random.randint(0, len(dataset.data))
    img, label = dataset[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img, cmap="gray")

Define a DataLoader to apply transformations and load data in batches

In [None]:
def build_data_loader_torch(batch_size: int) -> torch.utils.data.DataLoader:
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    dataset = MNIST(root="./data", train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    return train_loader

### 2.4. Create metrics and checkpointing

Compute the metrics and report them with a simple print statement. They are written to a CSV file in the next step.

In [None]:
def report_metrics_torch(loss: torch.Tensor, epoch: int) -> dict[str, float]:
    metrics = {"loss": loss.item(), "epoch": epoch}
    print(metrics)
    return metrics

Save the metrics and the checkpoint to the previously defined local directory.

In [None]:
def save_checkpoint_and_metrics_torch(metrics: dict[str, float], model: torch.nn.Module, local_path: str) -> None:

    # Save the metrics
    with open(os.path.join(local_path, "metrics.csv"), "a") as f:
        writer = csv.writer(f)
        writer.writerow(metrics.values())

    # Save the model
    checkpoint_path = os.path.join(local_path, "model.pt")
    torch.save(model.state_dict(), checkpoint_path)

### 2.5. Run the training loop
Schedule the training loop on a single GPU

In [None]:
timestamp = datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d_%H-%M-%S")
local_path = f"/mnt/local_storage/single_gpu_mnist/torch_{timestamp}/"

<div class="alert alert-info">

<b>Note about Anyscale storage options</b>

In this example <code>local_path</code> points to Anyscale's <a href="https://docs.anyscale.com/configuration/storage/#local-storage-for-a-node" target="_blank">local storage</a>. It's a convenient, quickly accessible location for this basic example.

* Anyscale provides each node with its own volume and disk and doesn’t share them with other nodes.
* Local storage is very fast - Anyscale supports the Non-Volatile Memory Express (NVMe) interface.
* This is not persistent storage: Anyscale deletes data in local storage after instances are terminated.

Read more about available <a href="https://docs.anyscale.com/configuration/storage" target="_blank">storage</a> options.
</div>

Start the training:

In [None]:
train_loop_torch(
    num_epochs=1,
    local_path=local_path
)

Let's inspect the produced checkpoints and metrics

In [None]:
!ls -l {local_path}

In [None]:
metrics = pd.read_csv(
    os.path.join(local_path, "metrics.csv"),
    header=None,
    names=["loss", "epoch"],
)

metrics

### 2.6. Use checkpointed model to generate predictions

Load model checkpoint to the device

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
loaded_model = build_resnet18()
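# map_location lets the checkpoint load even when no GPU is available
loaded_model.load_state_dict(
    torch.load(os.path.join(local_path, "model.pt"), map_location=device, weights_only=True)
)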
loaded_model.to(device)
loaded_model.eval()

Generate predictions on 9 randomly selected images from the MNIST dataset.

In [None]:
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3

for i in range(1, cols * rows + 1):
    sample_idx = np.random.randint(0, len(dataset.data))
    img, label = dataset[sample_idx]
    normalized_img = Normalize((0.5,), (0.5,))(ToTensor()(img))
    normalized_img = normalized_img.to(device)

    # use loaded model to generate preds
    with torch.no_grad():        
        prediction = loaded_model(normalized_img.unsqueeze(0)).argmax().cpu()

    figure.add_subplot(rows, cols, i)
    plt.title(f"label: {label}; pred: {int(prediction)}")
    plt.axis("off")
    plt.imshow(img, cmap="gray")

## 3. Distributed Data Parallel Training with Ray Train and PyTorch

Let's consider the case where we have a very large dataset of images that would take a long time to train on a single GPU. We would now like to scale this training job to run on multiple GPUs.

|<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-ai-libraries/diagrams/multi_gpu_pytorch_v4.png" width="900px" loading="lazy">|
|:--|
|Schematic overview of DistributedDataParallel (DDP) training: (1) the model is replicated from the <code>GPU rank 0</code> to all other workers; (2) each worker receives a shard of the dataset and processes a mini-batch; (3) during the backward pass, gradients are averaged across GPUs; (4) checkpoint and metrics from rank 0 GPU are saved to the persistent storage.|
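
Steps (2) and (3) of the schematic can be illustrated without any GPU. The sketch below is a pure-Python toy, not DDP itself: it fits `y ≈ w * x` with a mean squared error loss, splits one global batch across two simulated workers, and checks that averaging the per-worker gradients reproduces the gradient computed on the full batch. The model, data, and function names are assumptions made for illustration.

```python
def mean_gradient(w: float, batch: list[tuple[float, float]]) -> float:
    """Gradient of the mean squared error of y ≈ w * x with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


w = 0.5
global_batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]

# Step 2: each worker receives an equally sized shard of the global batch.
world_size = 2
shards = [global_batch[rank::world_size] for rank in range(world_size)]

# Step 3: each worker computes a local gradient, then the gradients are averaged.
# In real DDP this averaging is an all-reduce performed during the backward pass.
local_grads = [mean_gradient(w, shard) for shard in shards]
averaged_grad = sum(local_grads) / world_size

# With equal shard sizes, the average matches the single-process gradient.
full_grad = mean_gradient(w, global_batch)

print(local_grads)    # [-15.0, -30.0]
print(averaged_grad)  # -22.5
print(full_grad)      # -22.5
```

Because every worker applies the same averaged gradient, the model replicas stay identical after each optimizer step.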

<div class="alert alert-block alert-info">
<b>Here is a migration roadmap: from PyTorch DDP to PyTorch with Ray Train</b>

<ol>
    <li>Configure scale and GPUs</li>
    <li>Migrate the model to Ray Train</li>
    <li>Migrate the dataset to Ray Train</li>
    <li>Build checkpoints and metrics reporting</li>
    <li>Configure persistent storage</li>
</ol>
</div>

### 3.1. Overview of the training loop in Ray Train

Let's see what this data-parallel training loop looks like with Ray Train and PyTorch.

In [None]:
def train_loop_ray_train(config: dict):  # pass in hyperparameters in config

    criterion = CrossEntropyLoss()

    # Use Ray Train to wrap the model with DistributedDataParallel
    model = load_model_ray_train()
    optimizer = Adam(model.parameters(), lr=1e-5)

    # Calculate the batch size for each worker
    global_batch_size = config["global_batch_size"]
    world_size = ray.train.get_context().get_world_size()
    batch_size = global_batch_size // world_size
    print(f"{world_size=}\n{batch_size=}")

    # Use Ray Train to prepare the data loader (adds a DistributedSampler when there are multiple workers)
    data_loader = build_data_loader_ray_train(batch_size=batch_size)

    # Main training loop
    for epoch in range(config["num_epochs"]):

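        # Reshuffle the data differently in each epoch.
        # The DistributedSampler is only added when there is more than one worker.
        if world_size > 1:
            data_loader.sampler.set_epoch(epoch)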

        # images, labels are now sharded across the workers
        for images, labels in data_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()

            # gradients are now averaged across the workers
            loss.backward()
            optimizer.step()

        # Use Ray Train to report metrics
        metrics = print_metrics_ray_train(loss, epoch)

        # Use Ray Train to save checkpoint and metrics
        save_checkpoint_and_metrics_ray_train(model, metrics)

<div class="alert alert-block alert-info">

<b>Main training loop</b>
<ul>
  <li><strong>global_batch_size</strong>: the total number of samples processed in a single training step of the entire training job.
    <ul>
      <li>It's computed as: <code>batch size * DDP workers * gradient accumulation steps</code>.</li>
    </ul>
  </li>
  <li>Notice that images and labels are no longer manually moved to device (<code>images.to("cuda")</code>). This is done by 
    <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_data_loader.html#ray-train-torch-prepare-data-loader" target="_blank">
      prepare_data_loader()
    </a>.
  </li>
  <li>The config passed to this function is defined below. It is passed to Ray Train's <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html#ray-train-torch-torchtrainer" target="_blank">TorchTrainer</a>.</li>
  <li>
    <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.v2.api.context.TrainContext.html#ray-train-v2-api-context-traincontext" target="_blank">
      TrainContext
    </a> lets users get useful information about the training run, e.g. node rank, world size, world rank, and experiment name.
  </li>

  <li><code>load_model_ray_train</code> and <code>build_data_loader_ray_train</code> are implemented below.</li>
</ul>
</div>
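
The batch-size relationship above is simple arithmetic, sketched below. The function names are hypothetical, and the divisibility check is an addition for illustration; the training loop in this notebook uses plain floor division (`global_batch_size // world_size`) and no gradient accumulation.

```python
def per_worker_batch_size(global_batch_size: int, world_size: int) -> int:
    """Split the global batch evenly across the workers."""
    if global_batch_size % world_size != 0:
        raise ValueError("global_batch_size should be divisible by world_size")
    return global_batch_size // world_size


def effective_global_batch_size(
    batch_size: int, num_workers: int, grad_accumulation_steps: int = 1
) -> int:
    """Samples contributing to one optimizer step across the whole job."""
    return batch_size * num_workers * grad_accumulation_steps


# Values used in this notebook: a global batch of 128 split across 2 workers.
batch_size = per_worker_batch_size(global_batch_size=128, world_size=2)
print(batch_size)  # 64

# Going the other way recovers the global batch size.
print(effective_global_batch_size(batch_size, num_workers=2))  # 128
```

Keeping the global batch size fixed while changing the number of workers means the optimizer sees the same number of samples per step, so hyperparameters such as the learning rate do not need to change with scale.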

In [None]:
train_loop_config = {
    "num_epochs": 1, 
    "global_batch_size": 128
}

### 3.2. Configure scale and GPUs
Outside of our training function, we create a `ScalingConfig`.

In [None]:
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)

<div class="alert alert-block alert-info">

<a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.ScalingConfig.html#ray-train-scalingconfig" target="_blank">ScalingConfig</a> configures:

<ul>
  <li><code>num_workers</code>: The number of distributed training worker processes.</li>
  <li><code>use_gpu</code>: Whether each worker should use a GPU (or CPU).</li>
</ul>

See docs on configuring <a href="https://docs.ray.io/en/latest/train/user-guides/using-gpus.html" target="_blank">scale and GPUs</a> for more details.
</div>

#### 3.2.1. Note on Ray Train key concepts

Ray Train is built around [four key concepts](https://docs.ray.io/en/latest/train/overview.html):
1. **Training function** (implemented above as `train_loop_ray_train`): A Python function that contains your model training logic.
1. **Worker**: A process that runs the training function.
1. **Scaling config**: Specifies the number of workers and the compute resources (CPUs, GPUs, or TPUs).
1. **Trainer**: A Python class (Ray Actor) that ties together the training function, workers, and scaling configuration to execute a distributed training job.
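
To show how the four concepts relate, here is a deliberately simplified toy in pure Python. It is not how Ray Train is implemented: the `ToyTrainer` and `ToyScalingConfig` classes are hypothetical, workers are threads rather than separate processes on a cluster, and the rank is injected through the config, whereas real training functions read it from `ray.train.get_context()`.

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Callable


@dataclass
class ToyScalingConfig:
    num_workers: int  # how many workers run the training function


class ToyTrainer:
    """Ties together a training function, workers, and a scaling config."""

    def __init__(
        self,
        train_fn: Callable[[dict], dict],
        scaling_config: ToyScalingConfig,
        train_loop_config: dict,
    ) -> None:
        self.train_fn = train_fn
        self.scaling_config = scaling_config
        self.train_loop_config = train_loop_config

    def fit(self) -> list[dict]:
        world_size = self.scaling_config.num_workers

        def run_worker(world_rank: int) -> dict:
            # Every worker runs the same function with the same hyperparameters;
            # only the rank differs.
            worker_config = {
                **self.train_loop_config,
                "world_rank": world_rank,
                "world_size": world_size,
            }
            return self.train_fn(worker_config)

        with ThreadPoolExecutor(max_workers=world_size) as pool:
            return list(pool.map(run_worker, range(world_size)))


def toy_train_fn(config: dict) -> dict:
    """Stand-in for the training function; it only echoes what it received."""
    return {"world_rank": config["world_rank"], "epochs_run": config["num_epochs"]}


toy_results = ToyTrainer(toy_train_fn, ToyScalingConfig(num_workers=2), {"num_epochs": 1}).fit()
print(toy_results)
```

The point of the toy is the division of responsibilities: the training function knows nothing about how many workers exist until it is launched, and the scaling config can change without touching the training logic.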

|<img src="https://docs.ray.io/en/latest/_images/overview.png" width="700px" loading="lazy">|
|:--|
|High-level architecture of Ray Train|

### 3.3. Migrating the model to Ray Train

Use the [`prepare_model()`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_model.html#ray-train-torch-prepare-model) utility function to:

* automatically move your model to the correct device,
* wrap the model in PyTorch's DDP or FSDP.

In [None]:
def load_model_ray_train() -> torch.nn.Module:
    model = build_resnet18()
    model = ray.train.torch.prepare_model(model) # Instead of model = model.to("cuda")
    return model

<div class="alert alert-block alert-info">
  <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_model.html#ray-train-torch-prepare-model" target="_blank">
    prepare_model()
  </a> allows users to specify additional parameters:
  <ul>
    <li><code>parallel_strategy</code>: "ddp", "fsdp" – wrap models in <code>DistributedDataParallel</code> or <code>FullyShardedDataParallel</code></li>
    <li><code>parallel_strategy_kwargs</code>: pass additional arguments to "ddp" or "fsdp"</li>
  </ul>
  <p>
    With <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_model.html#ray-train-torch-prepare-model" target="_blank">
      prepare_model()
    </a> you can use the same code regardless of number of workers or the device type being used (CPU, GPU).
  </p>
</div>

### 3.4. Migrating the dataset to Ray Train

Use the [`prepare_data_loader()`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_data_loader.html#ray-train-torch-prepare-data-loader) utility function to automatically:

* move the batches to the right device,
* copy data from host (CPU) memory to device (GPU) memory,
* pass PyTorch's [`DistributedSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler) to the DataLoader, if using more than 1 worker. Each worker will load a subset of the original dataset that is exclusive to it.

[`prepare_data_loader()`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_data_loader.html#ray-train-torch-prepare-data-loader) allows users to use the same code regardless of number of workers or the device type being used (CPU, GPU).
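
The sharding idea behind a distributed sampler can be sketched in a few lines of pure Python. This is a simplified imitation, not PyTorch's implementation: it assumes the dataset size is divisible by the number of workers (the real sampler pads or drops samples otherwise), and `shard_indices` is a hypothetical helper. It also shows why the training loop calls `set_epoch`: the epoch is mixed into the shuffle seed so each epoch uses a different order.

```python
import random


def shard_indices(
    dataset_size: int, world_size: int, rank: int, epoch: int, seed: int = 0
) -> list[int]:
    """Indices one worker would read in a given epoch (simplified sketch)."""
    indices = list(range(dataset_size))
    # Every worker shuffles with the same seed, so all ranks agree on the order.
    # Mixing in the epoch gives each epoch its own shuffle.
    random.Random(seed + epoch).shuffle(indices)
    # Each rank takes every `world_size`-th index, starting at its own offset.
    return indices[rank::world_size]


world_size = 2
epoch0 = [shard_indices(10, world_size, rank, epoch=0) for rank in range(world_size)]
epoch1 = [shard_indices(10, world_size, rank, epoch=1) for rank in range(world_size)]

# Within an epoch, the shards are disjoint and together cover the whole dataset.
assert sorted(epoch0[0] + epoch0[1]) == list(range(10))
assert sorted(epoch1[0] + epoch1[1]) == list(range(10))

print(epoch0)
print(epoch1)
```

Since each worker only reads its own shard, two workers process the dataset in roughly half the number of steps per epoch.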

In [None]:
def build_data_loader_ray_train(batch_size: int) -> torch.utils.data.DataLoader:
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    train_data = MNIST(root="./data", train=True, download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)

    # Automatically pass a DistributedSampler instance as a DataLoader sampler
    train_loader = ray.train.torch.prepare_data_loader(train_loader)
    return train_loader

<div class="alert alert-block alert-warning">

<b>Ray Data integration</b>

This step isn't necessary if you are integrating your Ray Train workload with Ray Data. That integration is especially useful if preprocessing is CPU-heavy and you want to run preprocessing and training on separate instances.
</div>

### 3.5. Reporting checkpoints and metrics

To monitor progress, we can continue to print/log metrics as before. This time we chose to log from all workers.

In [None]:
def print_metrics_ray_train(loss: torch.Tensor, epoch: int) -> dict[str, float]:
    metrics = {"loss": loss.item(), "epoch": epoch}
    world_rank = ray.train.get_context().get_world_rank() # report from all workers
    print(f"{metrics=} {world_rank=}")
    return metrics

<div class="alert alert-block alert-info">

If you want to log only from the rank 0 worker, use this code:

```python
def print_metrics_ray_train(loss: torch.Tensor, epoch: int) -> dict[str, float]:
    metrics = {"loss": loss.item(), "epoch": epoch}
    world_rank = ray.train.get_context().get_world_rank()
    if world_rank == 0:  # report only from the rank 0 worker
        print(f"{metrics=} {world_rank=}")
    return metrics
```

</div>

We will report intermediate metrics and checkpoints using the [`ray.train.report`](https://docs.ray.io/en/latest/train/api/doc/ray.train.report.html#ray.train.report) utility function.

In [None]:
def save_checkpoint_and_metrics_ray_train(
    model: torch.nn.Module, metrics: dict[str, float]
) -> None:
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        torch.save(
            model.module.state_dict(),  # note the `.module` to unwrap the DistributedDataParallel
            os.path.join(temp_checkpoint_dir, "model.pt"),
        )

        ray.train.report(
            metrics,
            checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
        )

<div class="alert alert-block alert-info">
  <p><strong>Quick notes:</strong></p>
  <ul>
    <li>
      Use 
      <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.report.html#ray.train.report" target="_blank">
        ray.train.report
      </a> to save the metrics and checkpoint.
    </li>
    <li>Only metrics from the rank 0 worker are reported.</li>
  </ul>
</div>

#### 3.5.1. Note on the checkpoint lifecycle

Here is the lifecycle of a checkpoint from being created using a local path to being uploaded to persistent storage.

<img src="https://docs.ray.io/en/latest/_images/checkpoint_lifecycle.png" width=800>


<div class="alert alert-block alert-info">
  <p><strong>Notes:</strong></p>
  <ul>
    <li>
      Given it is the same model across all workers, we can instead only build the checkpoint on the worker of rank 0.
      Note that we will still need to call 
      <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.report.html#ray.train.report" target="_blank">
        ray.train.report
      </a> on all workers to ensure that the training loop is synchronized.
    </li>
    <li>Ray Train expects all workers to be able to write files to the same persistent storage location.</li>
    <li>Cloud storage is the recommended persistent storage location.</li>
  </ul>
</div>

In [None]:
def save_checkpoint_and_metrics_ray_train(
    model: torch.nn.Module, metrics: dict[str, float]
) -> None:
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        checkpoint = None

        # checkpoint only from rank 0 worker
        if ray.train.get_context().get_world_rank() == 0:
            torch.save(
                model.module.state_dict(), os.path.join(temp_checkpoint_dir, "model.pt")
            )
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)

        ray.train.report(
            metrics,
            checkpoint=checkpoint,
        )

Check our guide on [saving and loading checkpoints](https://docs.ray.io/en/latest/train/user-guides/checkpoints.html) for more details and best practices.
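
The lifecycle described above (write to a temporary local directory, then persist a copy) can be mimicked with the standard library. The sketch below is only an analogy: `toy_report` is a hypothetical stand-in for the reporting step, the `checkpoint_000000` directory name is illustrative, and a second temporary directory plays the role of persistent storage. It does not call Ray and does not reflect Ray Train's internal implementation.

```python
import shutil
import tempfile
from pathlib import Path


def toy_report(metrics: dict, checkpoint_dir, storage_dir: Path, index: int):
    """Hypothetical reporting step: persist the checkpoint directory, if any."""
    if checkpoint_dir is None:
        return None  # this worker reported metrics only
    target = storage_dir / f"checkpoint_{index:06d}"
    shutil.copytree(checkpoint_dir, target)
    return target


with tempfile.TemporaryDirectory() as storage_root:
    storage_dir = Path(storage_root)  # stands in for persistent storage
    persisted = []

    for world_rank in range(2):
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint_dir = None
            if world_rank == 0:  # only rank 0 writes the model file
                (Path(temp_checkpoint_dir) / "model.pt").write_bytes(b"fake weights")
                checkpoint_dir = temp_checkpoint_dir
            # Every worker still calls the reporting step.
            persisted.append(toy_report({"loss": 0.1}, checkpoint_dir, storage_dir, index=0))

    # The temporary directories are gone, but the persisted copy remains.
    saved_files = sorted(p.name for p in storage_dir.rglob("*") if p.is_file())

print(saved_files)  # ['model.pt']
```

The temporary directory can be discarded as soon as the reporting step returns, which is why the training function above creates it inside a `with` block.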

### 3.6. Configure remote storage

Create a `RunConfig` object to specify the path where results (including checkpoints and artifacts) will be saved.

In [None]:
storage_path = "/mnt/cluster_storage/training/"
run_config = RunConfig(storage_path=storage_path, name="distributed-mnist-resnet18")

### 3.7. Launching the distributed training job

Distributed data-parallel training, but now using Ray Train.

|<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-ai-libraries/diagrams/multi_gpu_pytorch_annotated_v5.png" width="70%" loading="lazy">|
|:--|
||

We can now launch a distributed training job with a [`TorchTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html#ray.train.torch.TorchTrainer).

In [None]:
trainer = TorchTrainer(
    train_loop_ray_train,
    scaling_config=scaling_config,
    run_config=run_config,
    train_loop_config=train_loop_config,
)

Calling `trainer.fit()` will start the run and block until it completes.

While the job runs, we'll be able to observe relevant logs.

|<img src="https://assets-training.s3.us-west-2.amazonaws.com/ray-intro/ray-train-intro-logs.png" width="80%" loading="lazy">|
|:--|
||

In [None]:
result = trainer.fit()

### 3.8. Access the training results

After training completes, a `Result` object is returned which contains information about the training run, including the metrics and checkpoints reported during training.

In [None]:
result

In [None]:
!ls /mnt/cluster_storage/training/distributed-mnist-resnet18/

We can check the metrics produced by the training job.

In [None]:
result.metrics_dataframe

### 3.9. Use checkpointed model to generate predictions

We can also take the latest checkpoint and load it to inspect the model.

In [None]:
ckpt = result.checkpoint
with ckpt.as_directory() as ckpt_dir:
    model_path = os.path.join(ckpt_dir, "model.pt")
    loaded_model_ray_train = build_resnet18()
    state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
    loaded_model_ray_train.load_state_dict(state_dict)
    loaded_model_ray_train.to("cuda")
    loaded_model_ray_train.eval()

loaded_model_ray_train

<div class="alert alert-block alert-info">
  <p>
    To learn more about the training results, see the
    <a href="https://docs.ray.io/en/latest/train/user-guides/results.html" target="_blank">
      docs
    </a> on inspecting the training results.
  </p>
</div>

Generate predictions on 9 randomly selected images from the MNIST dataset.

In [None]:
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3

for i in range(1, cols * rows + 1):
    sample_idx = np.random.randint(0, len(dataset.data))
    img, label = dataset[sample_idx]
    normalized_img = Normalize((0.5,), (0.5,))(ToTensor()(img))
    normalized_img = normalized_img.to("cuda")

    # use loaded model to generate preds
    with torch.no_grad():        
        prediction = loaded_model_ray_train(normalized_img.unsqueeze(0)).argmax().cpu()

    figure.add_subplot(rows, cols, i)
    plt.title(f"label: {label}; pred: {int(prediction)}")
    plt.axis("off")
    plt.imshow(img, cmap="gray")



## 4. Ray Train in Production

Here are some examples of Ray Train in production:
1. Canva uses Ray Train + Ray Data to cut down Stable Diffusion training costs by 3.7x. Read the [Anyscale blog post](https://www.anyscale.com/blog/scalable-and-cost-efficient-stable-diffusion-pre-training-with-ray) and the [Canva case study](https://www.anyscale.com/resources/case-study/how-canva-built-a-modern-ai-platform-using-anyscale).
2. Anyscale uses Ray Train + Deepspeed to finetune language models. Read more [here](https://github.com/ray-project/ray/tree/master/doc/source/templates/04_finetuning_llms_with_deepspeed).


In [None]:
# Run this cell for file cleanup 
!rm -rf /mnt/local_storage/single_gpu_mnist