# Ray Train + Ray Data
© 2025, Anyscale. All Rights Reserved

This notebook will walk you through integrating Ray Train with Ray Data.

<div class="alert alert-block alert-info">

<b> Here is the roadmap for this notebook </b>

<ol>
  <li>When to integrate Ray Train with Ray Data</li>
  <li>Architecture</li>
  <li>Integrating Ray Train with Ray Data</li>
  <li>Training with Ray Train and Ray Data</li>
</ol>
</div>

**Imports**

In [None]:
import os   
import tempfile

from PIL import Image
import numpy as np
import pandas as pd
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torchvision.models import resnet18
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray
from ray.train import ScalingConfig, RunConfig
from ray.train.torch import TorchTrainer

**Utilities**

In [None]:
def build_resnet18():
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        in_channels=1,  # grayscale MNIST images
        out_channels=64,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )
    return model


def load_model_ray_train() -> torch.nn.Module:
    model = build_resnet18()
    model = ray.train.torch.prepare_model(model)  # Instead of model = model.to("cuda")
    return model


def print_metrics_ray_train(loss: torch.Tensor, epoch: int) -> dict[str, float]:
    metrics = {"loss": loss.item(), "epoch": epoch}
    world_rank = ray.train.get_context().get_world_rank()  # every worker prints its own metrics
    print(f"{metrics=} {world_rank=}")
    return metrics

def save_checkpoint_and_metrics_ray_train(
    model: torch.nn.Module, metrics: dict[str, float]
) -> None:
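    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        # Unwrap DistributedDataParallel via `.module`. `prepare_model` only wraps the
        # model in DDP when training with more than one worker.
        unwrapped_model = model.module if hasattr(model, "module") else model
        torch.save(
            unwrapped_model.state_dict(),
            os.path.join(temp_checkpoint_dir, "model.pt"),
        )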

        ray.train.report(
            metrics,
            checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
        )

## 1. When to integrate Ray Train with Ray Data
Use both Ray Train and Ray Data when you face one of the following challenges:

| Challenge | Detail | Solution |
| --- | --- | --- |
| Need to perform online or just-in-time data processing | The training pipeline requires processing data on the fly, such as data augmentation, normalization, or other transformations that may differ for each training epoch. | Ray Data executes transformations in a streaming fashion while the training workers consume batches, so preprocessing happens just in time instead of in a separate offline step. |
| Need to improve hardware utilization | Training and data processing need to be scaled independently to keep GPUs fully utilized, especially when preprocessing is CPU-intensive. | Ray Data can distribute data processing across multiple CPU nodes, while Ray Train runs the training loop on GPUs. |
| Need a consistent interface for loading data | The training process may need to load data from various sources, such as Parquet, CSV, or lakehouses. | Ray Data provides a consistent interface for loading, shuffling, sharding, and batching data for training loops. |

## 2. Architecture

Here is a diagram showing a sample integration of Ray Data and Ray Train:

<img src="https://anyscale-materials.s3.us-west-2.amazonaws.com/ray-train-deep-dive/ray_train_v2_architecture.png" width="900" loading="lazy">

## 3. Integrating Ray Train with Ray Data

Here is what the training loop looks like when it uses **Ray Data** instead of the **PyTorch DataLoader**:

In [None]:
def train_loop_ray_train_ray_data(config: dict):
    # Same initialization as before
    criterion = CrossEntropyLoss()
    model = load_model_ray_train()
    optimizer = Adam(model.parameters(), lr=1e-3)

    # This time, use Ray Train's integration with Ray Data to load the data.
    # Split the global batch size across the training workers.
    global_batch_size = config["global_batch_size"]
    batch_size = global_batch_size // ray.train.get_context().get_world_size()
    data_loader = build_data_loader_ray_train_ray_data(batch_size=batch_size)

    for epoch in range(config["num_epochs"]):
        # No DistributedSampler, so no `data_loader.sampler.set_epoch(epoch)` call:
        # Ray Data already shards the dataset across workers. `iter_torch_batches`
        # also moves each batch to this worker's device, so no `.to(device)` is needed.

        # Batches are now dictionaries keyed by column name instead of tuples.
        for batch in data_loader:
            outputs = model(batch["image"])
            loss = criterion(outputs, batch["label"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        metrics = print_metrics_ray_train(loss, epoch)
        save_checkpoint_and_metrics_ray_train(model, metrics)
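
The training loop derives each worker's batch size with integer division, so a `global_batch_size` that is not divisible by the world size silently shrinks the effective global batch. The sketch below is a stricter variant that fails loudly instead; `per_worker_batch_size` is a hypothetical helper, not part of Ray Train. Inside the training loop you would pass it `ray.train.get_context().get_world_size()`.

```python
def per_worker_batch_size(global_batch_size: int, world_size: int) -> int:
    """Split a global batch evenly across data-parallel workers (hypothetical helper)."""
    if world_size <= 0:
        raise ValueError("world_size must be positive")
    if global_batch_size % world_size != 0:
        # Plain `//` would drop the remainder and train on a smaller global batch.
        raise ValueError(
            f"global_batch_size={global_batch_size} is not divisible by world_size={world_size}"
        )
    return global_batch_size // world_size


print(per_worker_batch_size(512, 2))  # 256, matching this notebook's configuration
```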


Here is the updated `build_data_loader_ray_train_ray_data` function that uses Ray Data to load the data:

In [None]:
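def build_data_loader_ray_train_ray_data(
    batch_size: int, prefetch_batches: int = 2
):
    # Get this worker's shard of the dataset passed to TorchTrainer under the "train" key.
    dataset_iterator = ray.train.get_dataset_shard("train")
    # Iterate over the shard in batches of torch tensors, fetching batches ahead of use.
    data_loader = dataset_iterator.iter_torch_batches(
        batch_size=batch_size, prefetch_batches=prefetch_batches
    )
    return data_loader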

<div class="alert alert-block alert-info">

**Note** Use `iter_torch_batches` to build a PyTorch-compatible data loader. By default, it yields each batch as a dictionary that maps column names to `torch.Tensor`s.

</div>
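
Because `iter_torch_batches` yields dictionaries keyed by column name, the training loop indexes `batch["image"]` and `batch["label"]` instead of unpacking an `(images, labels)` tuple. The sketch below mimics that column-oriented batch layout with NumPy so it runs without a Ray cluster. `iter_column_batches` is a hypothetical stand-in, not Ray's implementation: it ignores device placement, prefetching, and tensor conversion. Like `iter_torch_batches` with its default `drop_last=False`, it keeps the final partial batch.

```python
from typing import Dict, Iterator, List

import numpy as np


def iter_column_batches(
    rows: List[Dict[str, np.ndarray]], batch_size: int
) -> Iterator[Dict[str, np.ndarray]]:
    """Group row dicts into column-oriented batches (hypothetical, NumPy-only)."""
    for start in range(0, len(rows), batch_size):
        chunk = rows[start : start + batch_size]
        # Stack each column separately, e.g. {"image": (B, 1, 28, 28), "label": (B,)}.
        yield {key: np.stack([row[key] for row in chunk]) for key in chunk[0]}


rows = [
    {"image": np.zeros((1, 28, 28), dtype=np.float32), "label": np.int64(i % 10)}
    for i in range(10)
]
batches = list(iter_column_batches(rows, batch_size=4))
print([b["image"].shape for b in batches])  # [(4, 1, 28, 28), (4, 1, 28, 28), (2, 1, 28, 28)]
```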

### 3.1 Preparing the data
Store the training data in a format that Ray Data can read efficiently.

This notebook uses Parquet, a columnar storage format that is efficient for reading and writing data.

In [None]:
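torch_dataset = MNIST(root="./data", train=True, download=True)
# One row per image: "image" holds the 28x28 pixel values as nested lists, "label" the digit.
df = pd.DataFrame(
    {"image": torch_dataset.data.tolist(), "label": torch_dataset.targets.numpy()}
)
# On Anyscale, /mnt/cluster_storage is shared storage that every node in the cluster can read.
df.to_parquet("/mnt/cluster_storage/mnist.parquet")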

Next, construct a Ray Data Dataset from the Parquet source.

In [None]:
train_ds = ray.data.read_parquet("/mnt/cluster_storage/mnist.parquet")

Apply the same preprocessing steps as the PyTorch DataLoader version: convert each image to a tensor and normalize it.

In [None]:
def transform_images(row: dict) -> dict:
    # Define the torchvision transform.
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    # Rebuild the 28x28 uint8 image from the pixel values stored in Parquet.
    image_arr = np.array(row["image"], dtype=np.uint8)
    row["image"] = transform(Image.fromarray(image_arr))
    return row

<div class="alert alert-block alert-info">

**Note** Unlike the PyTorch DataLoader, which preprocesses data in worker processes on the same node as the training process, Ray Data can run the preprocessing on any node in the cluster.

The preprocessed data is passed to the training workers through the Ray object store, a distributed in-memory object store.

</div>

In [None]:
train_ds = train_ds.map(transform_images)
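
`transform_images` applies torchvision's `ToTensor`, which turns a `(28, 28)` uint8 image into a `(1, 28, 28)` float tensor scaled to `[0, 1]`, followed by `Normalize((0.5,), (0.5,))`, which computes `(x - 0.5) / 0.5` and so maps pixels into `[-1, 1]`. Because `map` runs this once per row, a vectorized transform applied with `map_batches` can reduce per-row overhead. The sketch below is one such alternative. `transform_images_batch` is a hypothetical name, and the sketch assumes each batch arrives in Ray Data's default NumPy batch format with `image` convertible to a `uint8` array of shape `(B, 28, 28)`; check your dataset's actual batch layout before swapping it in.

```python
from typing import Dict

import numpy as np


def transform_images_batch(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Vectorized counterpart of `transform_images` (hypothetical; assumes (B, 28, 28) uint8 images)."""
    images = np.asarray(batch["image"], dtype=np.float32) / 255.0  # like ToTensor: [0, 255] -> [0, 1]
    images = (images - 0.5) / 0.5                                  # like Normalize((0.5,), (0.5,)): -> [-1, 1]
    batch["image"] = images[:, np.newaxis, :, :]                   # add the channel dim: (B, 1, 28, 28)
    return batch


# Synthetic stand-in for one batch: a white, a black, and a dark-gray image.
images = np.zeros((3, 28, 28), dtype=np.uint8)
images[0] = 255
images[2] = 51
out = transform_images_batch({"image": images, "label": np.array([7, 2, 1])})
print(out["image"].shape)  # (3, 1, 28, 28)
```

In this notebook it would replace the row-wise call rather than run on top of it, as `train_ds = train_ds.map_batches(transform_images_batch)` applied to the dataset returned by `read_parquet`.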

## 4. Training with Ray Train and Ray Data

Pass the constructed `train_ds` to the `TorchTrainer` via the `datasets` parameter.

In [None]:
datasets = {"train": train_ds}
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
storage_path = "/mnt/cluster_storage/training/"
trainer = TorchTrainer(
    train_loop_ray_train_ray_data,
    train_loop_config={"num_epochs": 1, "global_batch_size": 512},
    scaling_config=scaling_config,
    run_config=RunConfig(storage_path=storage_path, name="dist-mnist-res18-ray-data"),
    datasets=datasets,
)
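
As a rough sanity check on this configuration: MNIST's training split has 60,000 images, so with `num_workers=2` and `global_batch_size=512` each worker iterates over a shard of about 30,000 rows in batches of 256. The arithmetic below assumes the shards are exactly equal in size and that `iter_torch_batches` keeps the final partial batch (its default `drop_last=False`); `steps_per_epoch` is a hypothetical helper.

```python
import math


def steps_per_epoch(num_rows: int, num_workers: int, global_batch_size: int) -> tuple[int, int, int]:
    """Return (rows per worker, per-worker batch size, batches per worker per epoch)."""
    rows_per_worker = num_rows // num_workers               # assumes an equal split across workers
    batch_size = global_batch_size // num_workers           # same formula as the training loop
    num_batches = math.ceil(rows_per_worker / batch_size)   # the final partial batch is kept
    return rows_per_worker, batch_size, num_batches


rows, bs, steps = steps_per_epoch(num_rows=60_000, num_workers=2, global_batch_size=512)
print(rows, bs, steps)            # 30000 256 118
print(rows - (steps - 1) * bs)    # size of the last, partial batch: 48
```

Note that the loss each worker prints at the end of an epoch comes from this last, smaller batch.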


Calling `trainer.fit()` launches the training workers, and Ray Data loads, preprocesses, and shards `train_ds` across them.

In [None]:
trainer.fit()