# Fault Tolerance in Ray Train + PyTorch
© 2025, Anyscale. All Rights Reserved

This notebook will walk you through ensuring fault tolerance in Ray Train + PyTorch.

<div class="alert alert-block alert-info">

<b> Here is the roadmap for this notebook </b>

<ol>
  <li>Overview of fault tolerance in Ray Train</li>
  <li>Automatic retries</li>
  <li>Manual restoration</li>
</ol>
</div>

**Imports**

In [None]:
import os   
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torchvision.models import resnet18
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray
from ray.train import ScalingConfig, RunConfig
from ray.train.torch import TorchTrainer

**Utilities**

In [None]:
def build_resnet18():
    """Return a 10-class ResNet-18 whose first convolution takes 1-channel input.

    The torchvision model is created first and its ``conv1`` is then replaced
    by a bias-free 64-filter 7x7 convolution (stride 2, padding 3) that accepts
    single-channel (grayscale MNIST) images.
    """
    net = resnet18(num_classes=10)
    # Swap the stem for one that reads grayscale images.
    net.conv1 = torch.nn.Conv2d(
        1,   # in_channels: grayscale MNIST images
        64,  # out_channels
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )
    return net


def load_model_ray_train() -> torch.nn.Module:
    """Build the grayscale ResNet-18 and hand it to Ray Train for placement.

    ``ray.train.torch.prepare_model`` takes the place of a manual
    ``model.to("cuda")``. Other code in this notebook accesses ``.module`` on
    the result, so it presumably wraps the model for distributed training
    (confirm against the Ray docs for single-worker runs).
    """
    return ray.train.torch.prepare_model(build_resnet18())


def build_data_loader_ray_train(batch_size: int) -> torch.utils.data.DataLoader:
    """Return a shuffled MNIST training loader prepared for Ray Train.

    Images are converted to tensors and normalized with mean 0.5 / std 0.5.
    Incomplete trailing batches are dropped. ``prepare_data_loader``
    automatically installs a ``DistributedSampler`` so each worker sees its own
    shard.

    Args:
        batch_size: Per-worker batch size.

    NOTE(review): every worker calls ``MNIST(..., download=True)`` on the same
    relative ``./data`` path. Workers sharing a node and working directory may
    download concurrently, so consider guarding the download with a file lock.
    """
    mnist_train = MNIST(
        root="./data",
        train=True,
        download=True,
        transform=Compose([ToTensor(), Normalize((0.5,), (0.5,))]),
    )
    loader = torch.utils.data.DataLoader(
        mnist_train,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    return ray.train.torch.prepare_data_loader(loader)

def print_metrics_ray_train(loss: torch.Tensor, epoch: int) -> dict[str, float]:
    """Print this worker's loss/epoch and return them as a metrics dict.

    Every worker that calls this prints its own line, tagged with its world
    rank.

    Args:
        loss: Loss tensor from the last training step; ``.item()`` is called
            on it, so it must hold a single value.
        epoch: Index of the epoch being reported.

    Returns:
        ``{"loss": <float>, "epoch": <int>}``, suitable for passing to
        ``ray.train.report``.
    """
    metrics = {"loss": loss.item(), "epoch": epoch}
    world_rank = ray.train.get_context().get_world_rank()  # report from all workers
    print(f"{metrics=} {world_rank=}")
    return metrics

## 1. Overview of fault tolerance in Ray Train

Ray Train provides two main mechanisms to handle failures:

- Automatic retries
- Manual restoration

Here is a diagram showing these two primary mechanisms:

<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-summit/stable-diffusion/diagrams/fault_tolerant_cropped_v2.png" width=800>


## 2. Automatic retries

### 2.1 Modifying the Training Loop to Enable Checkpoint Loading

We need to make use of `get_checkpoint()` in the training loop to enable checkpoint loading for fault tolerance.

Here is what the modified training loop looks like:

In [None]:
def train_loop_ray_train_with_checkpoint_loading(config: dict):
    """Per-worker training loop that resumes from the latest Ray Train checkpoint.

    On a fresh run training starts at epoch 0. If Ray Train provides a
    checkpoint (after an automatic retry or a manual restore), the loop:
    - restores model and optimizer state from ``model.pt`` / ``optimizer.pt``;
    - continues from the epoch after the one stored in ``extra_state.pt``.

    Args:
        config: Must contain ``"global_batch_size"`` (divided evenly across
            workers) and ``"num_epochs"`` (total number of epochs, counted
            from 0).
    """
    # Same initialization of loss, model, optimizer as before
    criterion = CrossEntropyLoss()
    model = load_model_ray_train()
    optimizer = Adam(model.parameters(), lr=1e-3)

    # Same initialization of the data loader as before
    global_batch_size = config["global_batch_size"]
    batch_size = global_batch_size // ray.train.get_context().get_world_size()
    data_loader = build_data_loader_ray_train(batch_size=batch_size)

    # The checkpoint stores the *unwrapped* model's state_dict. Unwrap only if
    # a `.module` wrapper is present; a non-distributed model has none, and
    # unconditionally calling `model.module` would raise AttributeError.
    unwrapped_model = getattr(model, "module", model)

    # Assume we start from epoch 0 unless we find a checkpoint
    start_epoch = 0

    # Load the latest checkpoint if it exists
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        # Continue training from a previous checkpoint
        with checkpoint.as_directory() as ckpt_dir:
            # Rank 0 saved these tensors from its own device. Loading to CPU
            # avoids deserializing them onto that device on every worker.
            # load_state_dict then copies them onto each worker's parameters.
            model_state_dict = torch.load(
                os.path.join(ckpt_dir, "model.pt"),
                map_location="cpu",
            )
            # Load the model and optimizer state
            unwrapped_model.load_state_dict(model_state_dict)
            optimizer.load_state_dict(
                torch.load(
                    os.path.join(ckpt_dir, "optimizer.pt"),
                    map_location="cpu",
                )
            )

            # Load the last epoch from the extra state
            extra_state = torch.load(
                os.path.join(ckpt_dir, "extra_state.pt"),
                map_location="cpu",
            )
            start_epoch = extra_state["epoch"] + 1

    # Same loop as before except it starts at a parameterized start_epoch
    for epoch in range(start_epoch, config["num_epochs"]):
        # A DistributedSampler reshuffles only when told the epoch. Without
        # this, every epoch would repeat the same order. Single-worker runs
        # may use a plain sampler without `set_epoch`, hence the check.
        sampler = getattr(data_loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        for images, labels in data_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        metrics = print_metrics_ray_train(loss, epoch)
        # We now save the optimizer and epoch state in addition to the model
        save_checkpoint_and_metrics_ray_train_with_extra_state(
            model, metrics, optimizer, epoch
        )

We also need to update the checkpoint saving function so that it saves the optimizer and epoch state:

In [None]:
def save_checkpoint_and_metrics_ray_train_with_extra_state(
    model: torch.nn.Module,
    metrics: dict[str, float],
    optimizer: torch.optim.Optimizer,
    epoch: int,
) -> None:
    """Report metrics and, from rank 0, a checkpoint holding full resume state.

    Rank 0 writes ``model.pt`` (unwrapped model state_dict), ``optimizer.pt``
    and ``extra_state.pt`` (``{"epoch": epoch}``) into a temporary directory.
    It wraps that directory in a ``ray.train.Checkpoint``. Every worker then
    calls ``ray.train.report``; non-zero ranks pass ``checkpoint=None``.

    Args:
        model: The (possibly wrapped) model returned by ``prepare_model``.
        metrics: Metrics to report for this epoch.
        optimizer: Optimizer whose state is saved for resumption.
        epoch: Index of the epoch that just finished.
    """
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        checkpoint = None
        if ray.train.get_context().get_world_rank() == 0:
            # Save the unwrapped module so state_dict keys carry no wrapper
            # prefix. Fall back to the model itself when it isn't wrapped,
            # instead of crashing on a missing `.module`.
            unwrapped_model = getattr(model, "module", model)
            # === Make sure to save all state needed for resuming training ===
            torch.save(
                unwrapped_model.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt"),
            )
            torch.save(
                optimizer.state_dict(),
                os.path.join(temp_checkpoint_dir, "optimizer.pt"),
            )
            torch.save(
                {"epoch": epoch},
                os.path.join(temp_checkpoint_dir, "extra_state.pt"),
            )
            # ================================================================
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)

        # Report while still inside the `with`: the temporary directory (and
        # thus the checkpoint files) is deleted when the block exits.
        ray.train.report(  # use ray.train.report to save the metrics and checkpoint
            metrics,  # train.report will only save worker rank 0's metrics
            checkpoint=checkpoint,
        )

### 2.2 Configuring Automatic Retries
Now that we have enabled checkpoint loading, we can configure a `FailureConfig`, which sets the maximum number of times Ray Train retries a failed training job.

In [None]:
# Two GPU-backed training workers.
scaling_config = ScalingConfig(use_gpu=True, num_workers=2)

# NOTE(review): `experiment_name` is not used by any cell shown here, and its
# value ("cifar-vit") does not describe this MNIST / ResNet-18 run.
experiment_name = "fault-tolerant-cifar-vit"

# Where the run (and its checkpoints) is persisted. The manual-restoration
# section reuses `run_name` + `storage_path` to point a new trainer here.
storage_path = "/mnt/cluster_storage/training/"
run_name = "distributed-mnist-resnet18-auto-retry"

# Allow Ray Train to restart the worker group up to 3 times on failure.
failure_config = ray.train.FailureConfig(max_failures=3)

run_config = RunConfig(
    name=run_name,
    storage_path=storage_path,
    failure_config=failure_config,
)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_ray_train_with_checkpoint_loading,
    run_config=run_config,
    scaling_config=scaling_config,
    train_loop_config={"num_epochs": 1, "global_batch_size": 512},
)

Now we can proceed to run the training job as before. 

This time, if any worker fails, Ray Train will create a new attempt, restart the worker group and resume training from the last checkpoint up to the specified maximum number of failures.

In [None]:
# Launch training. If a worker fails, Ray Train starts a new attempt (up to
# FailureConfig.max_failures times), and the loop resumes from the latest
# reported checkpoint. As the cell's last expression, the return value of
# fit() is displayed by the notebook.
trainer.fit()

## 3. Manual restoration


If the retries are exhausted, we can perform a manual restoration by re-initializing the `TorchTrainer` with a `RunConfig` that uses the same run `name` and `storage_path` as the original run.

In [None]:
# A fresh trainer pointed at the same run directory (same `name` under the
# same `storage_path`). Its training loop picks up the latest checkpoint
# stored there via `ray.train.get_checkpoint()`.
restored_trainer = TorchTrainer(
    train_loop_per_worker=train_loop_ray_train_with_checkpoint_loading,
    scaling_config=scaling_config,
    train_loop_config={"num_epochs": 1, "global_batch_size": 512},
    run_config=RunConfig(
        storage_path=storage_path,
        name=run_name,
    ),
)

Running the fit method will resume training from the last checkpoint.

Since we have already completed all epochs, we expect training to terminate immediately.

In [None]:
# Resume the run from its latest checkpoint. Every epoch already finished, so
# the loop's range is empty and fit() should return almost immediately.
result = restored_trainer.fit()
# Bare last expression: lets the notebook display the result object.
result