# Distributed Training with Ray Actors

## Overview

This notebook demonstrates distributed training using Ray Actors and PyTorch Distributed. You'll build a worker group pattern step-by-step, learning each concept as you go.


<div class="alert alert-block alert-info">

<b> Here is the roadmap for this notebook </b>

<ol>
  <li>Architecture overview</li>
  <li>Part 1: Setup and imports</li>
  <li>Part 2: Building workers and orchestration</li>
  <li>Part 3: Complete example</li>
  <li>Part 4: Hands-on exercise</li>
  <li>Part 5: Connection to Ray Train</li>
</ol>
</div>

**What We'll Build:** Multiple Ray Actors that coordinate to perform distributed collective operations (like broadcasting tensors), simulating a distributed training setup.

**Key Learning Goals:**
- Understand how Ray Actors enable stateful distributed computation
- Learn the worker group pattern for distributed training
- See how Ray integrates with PyTorch Distributed
- Master coordination patterns for multi-actor workflows


## Architecture Overview: What We're Building

Before we start coding, let's understand what we're building and why.


### The Challenge: Distributed Training

In distributed training, we need:
1. **Multiple workers** that can train in parallel
2. **Stateful processes** that maintain model parameters and optimizer state
3. **Communication** between workers to synchronize gradients
4. **Coordination** to ensure all workers stay in sync
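
To make requirement 3 concrete, here is a minimal pure-Python sketch (no Ray or PyTorch) of why workers synchronize gradients. It assumes a hypothetical one-parameter model `y = w * x` with a mean-squared-error loss and equal-sized data shards; the function names are illustrative and are not part of this notebook's code.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y) pair


def mse_gradient(w: float, shard: List[Sample]) -> float:
    """Gradient of mean((w * x - y)^2) with respect to w over one shard."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)


def averaged_gradient(w: float, shards: List[List[Sample]]) -> float:
    """What an all-reduce (sum) followed by division by world size would yield."""
    local_grads = [mse_gradient(w, shard) for shard in shards]  # one per worker
    return sum(local_grads) / len(local_grads)


data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]
shards = [data[:2], data[2:]]  # two simulated workers with equal shard sizes

full = mse_gradient(0.5, data)            # gradient over the whole batch
synced = averaged_gradient(0.5, shards)   # average of per-worker gradients
print(abs(full - synced) < 1e-9)
```

With equal shard sizes, the averaged per-worker gradient matches the full-batch gradient, so every worker can apply the same update and stay in sync.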


### The Solution: Worker Group Pattern

We'll build a **worker group** by orchestrating Ray Actors that communicate using PyTorch Distributed to perform collective operations necessary for distributed training.

Here's a high-level architecture diagram:

<img src="https://anyscale-materials.s3.us-west-2.amazonaws.com/ray-core/ray-core-ray-summit-2025-distributed-training-with-actors.png" width="800">

### Communication Layers

We'll use **two communication systems** that work together:

**Layer 1: Ray RPC (Driver ↔ Workers)**
- Driver creates actors
- Driver calls methods on actors (setup, execute, cleanup)
- Used for orchestration and control

**Layer 2: PyTorch Distributed (Worker ↔ Worker)**
- Workers communicate directly with each other
- Used for collective operations (broadcast, all-reduce, etc.)
- High-performance communication (NCCL for GPU, Gloo for CPU)


### Execution Flow

Here's what happens when we run distributed training:

1. Setup Phase
   1. Driver creates N worker actors
   1. Each worker gets unique rank (0 to N-1)
   1. Driver gets master address from rank 0
   1. All workers initialize PyTorch distributed

2. Training Phase
   1. Workers perform collective operations
   1. Example: Broadcast model parameters
   1. Example: All-reduce gradients
   1. All workers stay synchronized

3. Cleanup Phase
   1. Workers destroy process groups
   1. Actors are terminated

### Why This Pattern Matters

This pattern is **fundamental to Ray Train** and distributed training in general:
- **Scalable**: Works from 2 workers to hundreds
- **Flexible**: Supports CPU and GPU training
- **Efficient**: High-performance communication with NCCL/Gloo
- **Production-ready**: Used by Ray Train for real workloads

Now let's start building!


## Part 1: Setup and Imports

Let's start by importing all necessary libraries and initializing Ray:

In [None]:
import os
import socket
from collections import defaultdict
from datetime import timedelta
from typing import List

import ray
import torch
import torch.distributed as dist

# Import utility functions for the worker actor
from scripts.utils import (
    setup_torch_process_group_impl,
    cleanup_impl,
)

# Initialize Ray
ray.init(ignore_reinit_error=True)

## Part 2: Building Workers and Orchestration Together

Let's get started by building the actor and orchestration code to create our distributed training workers. We'll build this step-by-step, adding one capability at a time. For each step, you'll see both the actor code (what runs on each worker) and the orchestration code (how we coordinate all workers from the driver).


### Step 2.1: Actor Skeleton with Core Methods

First, let's look at the overall structure of our worker actor. Don't worry about understanding every detail yet - we'll implement each method step by step. This skeleton shows you the four core methods our actor will have:

In [None]:
@ray.remote
class DistributedWorker:
    """A minimal Ray actor for distributed operations.
    
    This actor manages the worker lifecycle but delegates actual work
    to functions passed via the execute() method.
    """

    def __init__(self, rank: int, world_size: int):
        """Initialize the distributed worker with rank and world_size."""
        self.rank = rank
        self.world_size = world_size
    
    def setup_torch_process_group(
        self,
        backend: str,
        master_addr: str,
        master_port: int,
        timeout_s: int = 1800,
    ):
        """Initialize torch distributed process group on this worker."""
        return setup_torch_process_group_impl(
            self.rank, self.world_size, backend, master_addr, master_port, timeout_s
        )
    
    def execute(self, func, *args, **kwargs):
        """Execute a function on this worker that leverages torch.distributed."""
        return func(*args, **kwargs)
    
    def cleanup(self):
        """Clean up the distributed process group."""
        return cleanup_impl(self.rank)

**Why This Design?** Notice how the actor is very simple - it just manages basic information (rank and world_size) and provides an `execute()` method to run functions. This keeps the actor small and flexible. All the actual work will be done by functions we pass to `execute()`.
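
The same idea can be sketched without Ray. The class below is a hypothetical plain-Python stand-in for the actor (`LocalWorker`, `describe`, and `square` are made-up names), meant only to show that new behavior comes from the function you pass in, not from changes to the class.

```python
class LocalWorker:
    """Plain-Python stand-in for the actor (hypothetical; no Ray involved)."""

    def __init__(self, rank: int, world_size: int):
        self.rank = rank
        self.world_size = world_size

    def execute(self, func, *args, **kwargs):
        """Run any callable with the given arguments and return its result."""
        return func(*args, **kwargs)


def describe(rank: int, world_size: int) -> str:
    return f"worker {rank} of {world_size}"


def square(x: int) -> int:
    return x * x


worker = LocalWorker(rank=0, world_size=2)
# New behavior is added by passing a different function, not by editing the class.
label = worker.execute(describe, rank=worker.rank, world_size=worker.world_size)
squared = worker.execute(square, 7)
```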

Now let's write the code to create multiple workers:

In [None]:
def create_worker_group(num_workers: int, resources_per_worker: dict):
    return [
        DistributedWorker.options(**resources_per_worker).remote(
            rank=rank,
            world_size=num_workers,
        )
        for rank in range(num_workers)
    ]

Let's test creating some workers:

In [None]:
# Create 2 test workers, each using 1 GPU
test_workers = create_worker_group(
    num_workers=2,
    resources_per_worker={"num_gpus": 1},
)

**What just happened?** The `@ray.remote` decorator turns our regular Python class into a distributed actor that can run on any machine in our cluster. When we call `create_worker_group()`, Ray starts the worker actors in parallel, and we assign each one a unique rank (0, 1, 2, etc.).


### Step 2.2: GPU Setup and Coordination

**The Problem:** When using NCCL (NVIDIA's collective communication library) for GPU training, each worker needs to see **all GPUs on its node**, not just its own GPU. By default, Ray sets `CUDA_VISIBLE_DEVICES` so that each actor sees only the GPUs assigned to it.

**Why This Matters:**
- NCCL uses peer-to-peer GPU communication for efficiency
- Workers on the same node can use fast NVLink/PCIe instead of going through the network
- Without visibility to other GPUs, NCCL falls back to slower communication paths

**The Solution:** We gather GPU information from all workers, group them by node, and set `CUDA_VISIBLE_DEVICES` to include all GPUs on each node.
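
The grouping logic can be sketched in plain Python, without Ray, on hand-written metadata. The function name `visible_devices_per_worker` and the sample node IDs are hypothetical. This sketch also sorts GPU IDs numerically, which is a choice made here rather than something the notebook requires.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def visible_devices_per_worker(
    metadata_list: List[Tuple[str, List[int]]],
) -> Dict[int, str]:
    """Map each worker index to the CUDA_VISIBLE_DEVICES string for its node."""
    node_to_workers = defaultdict(list)
    node_to_gpu_ids = defaultdict(set)
    for worker_idx, (node_id, gpu_ids) in enumerate(metadata_list):
        node_to_workers[node_id].append(worker_idx)
        node_to_gpu_ids[node_id].update(gpu_ids)

    result = {}
    for node_id, worker_indices in node_to_workers.items():
        # Sort numerically so that, for example, GPU 10 comes after GPU 2.
        gpu_ids_str = ",".join(str(g) for g in sorted(node_to_gpu_ids[node_id]))
        for worker_idx in worker_indices:
            result[worker_idx] = gpu_ids_str
    return result


# Hypothetical cluster: two workers on node-a, one on node-b.
sample_metadata = [("node-a", [0]), ("node-a", [1]), ("node-b", [0])]
assignment = visible_devices_per_worker(sample_metadata)
print(assignment)
```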

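Now let's write the orchestration code that solves this problem. This function will:

1. Collect the node ID and GPU IDs from every worker
2. Group workers and GPU IDs by node
3. Set `CUDA_VISIBLE_DEVICES` on each worker to all GPUs on its node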

In [None]:
def share_cuda_visible_devices(workers: List):
    """Share CUDA_VISIBLE_DEVICES across workers on the same node."""

    # Step 1: Collect metadata from all workers using execute()
    metadata_list = ray.get([
        worker.execute.remote(get_worker_metadata)
        for worker in workers
    ])
    
    # Step 2: Group workers by node
    node_to_workers = defaultdict(list)
    for worker_idx, (node_id, gpu_ids) in enumerate(metadata_list):
        node_to_workers[node_id].append(worker_idx)

    node_to_gpu_ids = defaultdict(set)
    for worker_idx, (node_id, gpu_ids) in enumerate(metadata_list):
        for gpu_id in gpu_ids:
            node_to_gpu_ids[node_id].add(str(gpu_id))
    
    # Step 3: Set CUDA_VISIBLE_DEVICES on each worker using execute()
    set_refs = []
    for node_id, worker_indices in node_to_workers.items():
        gpu_ids_str = ",".join(sorted(node_to_gpu_ids[node_id]))
        
        for worker_idx in worker_indices:
            set_ref = workers[worker_idx].execute.remote(
                set_worker_cuda_devices,
                rank=worker_idx,
                gpu_ids_str=gpu_ids_str
            )
            set_refs.append(set_ref)
    
    # Wait for all workers to complete setting CUDA_VISIBLE_DEVICES
    ray.get(set_refs)

The orchestration function above uses two helper functions. Here they are - notice these are regular Python functions, not part of the actor:

In [None]:
def get_worker_metadata():
    """Get metadata about this worker (node_id and GPU IDs)."""
    node_id = ray.get_runtime_context().get_node_id()
    gpu_ids = ray.get_gpu_ids()
    return node_id, gpu_ids

def set_worker_cuda_devices(rank: int, gpu_ids_str: str):
    """Set CUDA_VISIBLE_DEVICES for this worker."""
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids_str
    print(f"[Rank {rank}] Set CUDA_VISIBLE_DEVICES={gpu_ids_str}")
    return True

Let's explore what these helper functions return. Start by getting a reference to worker 0:

In [None]:
test_worker = test_workers[0]
test_worker

This is an actor handle - a reference to the remote worker.

In [None]:
# Call execute.remote() to run the function on the worker
metadata_ref = test_worker.execute.remote(get_worker_metadata)
metadata_ref

This is an ObjectRef - a future/promise that will contain the result.

In [None]:
# Get the actual result
ray.get(metadata_ref)

Returns: (node_id, [gpu_id]). By default, each worker only sees its assigned GPU!
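
If the ObjectRef idea is new, the standard library's futures give a rough analogy. This is only an analogy: an ObjectRef is not a `concurrent.futures.Future`, and `get_fake_metadata` is a hypothetical stand-in for `get_worker_metadata`.

```python
from concurrent.futures import ThreadPoolExecutor


def get_fake_metadata():
    """Hypothetical stand-in for get_worker_metadata()."""
    return "node-a", [0]


with ThreadPoolExecutor(max_workers=1) as pool:
    future = pool.submit(get_fake_metadata)  # returns immediately, like .remote()
    result = future.result()                 # blocks until ready, like ray.get()
```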

Now check worker 1:

In [None]:
ray.get(test_workers[1].execute.remote(get_worker_metadata))

Same node_id? Then both workers are on the same machine.

Gather from all workers at once:

In [None]:
metadata_refs = [
    worker.execute.remote(get_worker_metadata)
    for worker in test_workers
]
metadata_refs

A list of ObjectRefs - one per worker.

In [None]:
metadata_list = ray.get(metadata_refs)
metadata_list

This is what the orchestration function processes!

Now run the full GPU sharing orchestration:

In [None]:
share_cuda_visible_devices(test_workers)

Watch the workers print their expanded GPU visibility.


### Step 2.3: PyTorch Distributed Initialization

**What We're Doing:** Now we need to create a communication channel between workers so they can perform collective operations (broadcast, all-reduce, etc.). PyTorch Distributed provides this through "process groups."

**The Challenge:** PyTorch Distributed workers need to find each other on the network. This requires:
1. One worker (rank 0) to act as the "rendezvous point" and share its address
2. All workers to connect to this address and form a process group
3. Each worker to know its unique rank and the total world size
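
The rendezvous idea can be illustrated with plain sockets and threads. This is a toy sketch that assumes everything runs on one machine. It is not the protocol PyTorch uses internally, and `simulate_rendezvous` is a hypothetical name.

```python
import socket
import threading
from typing import List


def simulate_rendezvous(world_size: int) -> List[int]:
    """Rank 0 listens on a port; every other rank connects and reports its rank."""
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.settimeout(10.0)  # fail instead of hanging if a rank never connects
    server.bind(("127.0.0.1", 0))  # port 0 lets the OS pick a free port
    server.listen(world_size)
    master_addr, master_port = server.getsockname()

    def join(rank: int) -> None:
        # A non-zero rank only needs the shared address and its own rank.
        with socket.create_connection((master_addr, master_port)) as conn:
            conn.sendall(str(rank).encode())

    threads = [threading.Thread(target=join, args=(r,)) for r in range(1, world_size)]
    for t in threads:
        t.start()

    joined = [0]  # rank 0 is the rendezvous point itself
    for _ in range(world_size - 1):
        conn, _ = server.accept()
        with conn:
            data = b""
            while True:
                chunk = conn.recv(16)
                if not chunk:  # peer closed the connection
                    break
                data += chunk
            joined.append(int(data.decode()))

    for t in threads:
        t.join()
    server.close()
    return sorted(joined)


ranks = simulate_rendezvous(world_size=4)
```

Once every rank has checked in, the group is complete, which mirrors the three requirements listed above: a shared address, a connection from each worker, and a known rank and world size.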

Now let's implement the function that `setup_torch_process_group()` delegates to:

In [None]:
def setup_torch_process_group_impl(
    rank: int,
    world_size: int,
    backend: str,
    master_addr: str,
    master_port: int,
    timeout_s: int = 1800,
):
    """Implementation of PyTorch process group initialization."""
    # Set environment variables for torch distributed
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    
    # For NCCL backend, set async error handling
    if backend == "nccl":
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
        
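        # Set CUDA device for this worker. torch.cuda.set_device() expects an
        # index into CUDA_VISIBLE_DEVICES, not the physical GPU ID from Ray.
        if torch.cuda.is_available():
            gpu_ids = [str(gpu_id) for gpu_id in ray.get_gpu_ids()]
            visible = os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")
            if gpu_ids and gpu_ids[0] in visible:
                torch.cuda.set_device(visible.index(gpu_ids[0]))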
    
    # Initialize the process group
    dist.init_process_group(
        backend=backend,
        init_method="env://",
        rank=rank,
        world_size=world_size,
        timeout=timedelta(seconds=timeout_s),
    )
    
    print(f"[Rank {rank}] Process group initialized successfully!")
    return True

We also need a helper function to get the network address for the rendezvous point:

In [None]:
def get_address_and_port():
    """Get the IP address and an available port."""
    ip_address = ray.util.get_node_ip_address()
    
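    # Find an available port. The port is released when the socket closes,
    # so another process could claim it before rank 0 starts listening on it.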
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        s.listen(1)
        port = s.getsockname()[1]
    
    return ip_address, port

Inspect what rank 0 will use as the rendezvous point:

In [None]:
rank_0 = test_workers[0]
rank_0

In [None]:
address_ref = rank_0.execute.remote(get_address_and_port)
address_ref

In [None]:
ray.get(address_ref)

Returns: (IP_address, port_number) — all workers will connect here.

Now here's the orchestration function:

In [None]:
def setup_torch_distributed(workers: List, backend: str = "gloo", timeout_s: int = 1800):
    """Set up torch distributed process group across all workers."""
    # Step 1: Get master address and port from rank 0 worker using execute()
    master_addr, master_port = ray.get(workers[0].execute.remote(get_address_and_port))
    
    # Step 2: Initialize process group on all workers in parallel
    setup_refs = [
         worker.setup_torch_process_group.remote(
            backend=backend,
            master_addr=master_addr,
            master_port=master_port,
            timeout_s=timeout_s,
        )
        for worker in workers
    ]
    
    # Step 3: Wait for all to complete (synchronization barrier)
    ray.get(setup_refs)

PyTorch Distributed uses environment variables (`MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE`) to coordinate. The `backend` argument selects the communication library:
- **"gloo"** for CPU
- **"nccl"** for GPU (faster!)

Run the initialization:

In [None]:
setup_torch_distributed(test_workers, backend="nccl")

Each worker prints when it joins the process group. Now they can communicate!

Let's verify the process group is initialized on one worker:

In [None]:
# Check if distributed is initialized on worker 0
check_ref = test_workers[0].execute.remote(lambda: dist.is_initialized())
check_ref

In [None]:
ray.get(check_ref)

Should return `True` - the process group is ready!


### Step 2.4: Collective Operations

Now workers can perform **collective operations** - all workers participate simultaneously.

**Broadcast:** One worker (source) sends data to all others at once. Used in distributed training to share model weights!

**Orchestration Side:** Trigger `torch.distributed.broadcast` using `execute()`:

In [None]:
def run_distributed_broadcast(workers: List, device: str, src_rank: int = 0):
    """Run a broadcast operation across all workers."""
    world_size = len(workers)
    
    # Execute broadcast on all workers in parallel using execute()
    broadcast_refs = [
        worker.execute.remote(
            broadcast_tensor,
            rank=rank,
            world_size=world_size,
            device=device,
            src_rank=src_rank
        )
        for rank, worker in enumerate(workers)
    ]
    
    # Wait for completion
    tensors = ray.get(broadcast_refs)
    
    # Display results
    for rank, tensor in enumerate(tensors):
        print(f"  Rank {rank}: {tensor.tolist()}")
    
    # Verify all tensors match
    all_same = all(torch.equal(tensors[0], t) for t in tensors)
    assert all_same, "Broadcast failed: tensors do not match!"

**Helper Function:** Define broadcast as a standalone function:

In [None]:
def broadcast_tensor(rank: int, world_size: int, device: str, src_rank: int = 0):
    """Participate in a broadcast operation from src_rank to all workers."""
    if not dist.is_initialized():
        raise RuntimeError("Process group not initialized!")
    
    # Create tensor
    if rank == src_rank:
        tensor = torch.tensor([100.0, 200.0, 300.0, 400.0, 500.0], device=device)
        print(f"[Rank {rank}] Broadcasting tensor: {tensor.tolist()}")
    else:
        tensor = torch.zeros(5, device=device)
        print(f"[Rank {rank}] Before broadcast: {tensor.tolist()}")
    
    # Perform the broadcast
    dist.broadcast(tensor, src=src_rank)
    
    print(f"[Rank {rank}] After broadcast: {tensor.tolist()}")
    return tensor.cpu()

Run the broadcast operation:

In [None]:
run_distributed_broadcast(test_workers, device="cuda", src_rank=0)

Watch the output:
- Rank 0: creates `[100, 200, 300, 400, 500]`
- Rank 1: starts with zeros
- After broadcast: both have the same values!

This is how model weights get shared in distributed training.
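
The broadcast semantics can also be checked in plain Python. The sketch below is a simulation with lists standing in for per-rank tensors; `simulate_broadcast` is a hypothetical helper and no communication takes place.

```python
from typing import List


def simulate_broadcast(
    buffers: List[List[float]], src_rank: int = 0
) -> List[List[float]]:
    """Return per-rank buffers after every rank receives a copy of src_rank's data."""
    source = buffers[src_rank]
    return [list(source) for _ in buffers]


before = [[100.0, 200.0, 300.0, 400.0, 500.0], [0.0] * 5, [0.0] * 5]
after = simulate_broadcast(before, src_rank=0)
all_same = all(buf == after[0] for buf in after)
print(all_same)
```

Each rank ends up with its own copy of the source values, which is the property the `assert` in `run_distributed_broadcast` verifies.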


### Step 2.5: Cleanup

Now let's implement the cleanup function that our actor's `cleanup()` method delegates to:

In [None]:
def cleanup_impl(rank: int) -> bool:
    """Implementation of process group cleanup."""
    if dist.is_initialized():
        print(f"[Rank {rank}] Destroying process group")
        dist.destroy_process_group()
    return True

Now let's define the cleanup orchestration function:

In [None]:
def cleanup_workers(workers: List):
    """Clean up worker actors."""
    cleanup_refs = [worker.cleanup.remote() for worker in workers]
    ray.get(cleanup_refs)

Clean up the test workers:

In [None]:
cleanup_workers(test_workers)

Each worker destroys its PyTorch process group. Always clean up!
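
One way to make cleanup reliable is a `try`/`finally` block. The sketch below uses a hypothetical `FakeWorker` that only records events, so it runs without Ray; it is meant as an illustration of the shape, not as the notebook's implementation.

```python
from typing import List

events: List[str] = []


class FakeWorker:
    """Hypothetical stand-in that only records lifecycle events."""

    def __init__(self, rank: int):
        self.rank = rank

    def setup(self) -> None:
        events.append(f"setup:{self.rank}")

    def cleanup(self) -> None:
        events.append(f"cleanup:{self.rank}")


def failing_step(workers: List[FakeWorker]) -> None:
    raise RuntimeError("simulated failure during training")


def run_with_cleanup(workers: List[FakeWorker], step) -> None:
    """Run setup and one step, always running cleanup afterwards."""
    try:
        for w in workers:
            w.setup()
        step(workers)
    finally:
        for w in workers:
            w.cleanup()


try:
    run_with_cleanup([FakeWorker(0), FakeWorker(1)], failing_step)
except RuntimeError:
    pass  # the error still propagates; cleanup has already run
```

In this notebook, the same shape would put `setup_torch_distributed()` and `run_distributed_broadcast()` inside the `try` and `cleanup_workers()` inside the `finally`.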

Great! Now we have the complete actor and all orchestration functions, built side-by-side.

**Why This Design?**
- **Minimal Actor**: Only 4 methods manage lifecycle and state
- **Flexible**: Any function can be executed via `execute()` without modifying the actor
- **Testable**: Helper functions can be tested independently
- **Reusable**: Functions like `broadcast_tensor` work across different projects


## Part 3: Complete Example

Here's the complete workflow, from creating the workers to cleaning them up:

In [None]:
# Configuration
num_workers = 4
use_gpu = False  # Set to True if you have GPUs

# Determine backend, resources, and device
if use_gpu and torch.cuda.is_available():
    backend = "nccl"
    resources_per_worker = {"num_gpus": 1}
    device = "cuda"
else:
    backend = "gloo"
    resources_per_worker = {"num_cpus": 1}
    device = "cpu"

# Create workers
workers = create_worker_group(
    num_workers=num_workers,
    resources_per_worker=resources_per_worker,
)

# Setup GPU visibility (if using GPUs)
if use_gpu and torch.cuda.is_available():
    share_cuda_visible_devices(workers)

# Initialize PyTorch Distributed
setup_torch_distributed(workers, backend=backend)

# Run distributed operation
run_distributed_broadcast(workers, device=device, src_rank=0)

# Cleanup
cleanup_workers(workers)

## Part 4: Hands-On Exercise

Now it's your turn! Add an all-reduce operation as a standalone function.

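**Task:** Implement an `allreduce_tensor()` function in which each worker contributes its rank value and all workers receive the sum. Part 3 ended with `cleanup_workers(workers)`, so re-run `setup_torch_distributed(workers, backend=backend)` before testing your function.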
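
To check your result without looking at the solution, you can compute the expected value in plain Python. Both helper names below are hypothetical and are not used elsewhere in this notebook.

```python
from typing import Sequence


def expected_allreduce_sum(world_size: int) -> float:
    """Sum of ranks 0..world_size-1, which every worker should hold afterwards."""
    return float(sum(range(world_size)))


def check_allreduce_result(values: Sequence[float], world_size: int) -> bool:
    """True if there is one value per worker and each equals the expected sum."""
    expected = expected_allreduce_sum(world_size)
    return len(values) == world_size and all(v == expected for v in values)
```

After running your function on every worker, pass the scalar from each returned tensor (for example `t.item()`) to `check_allreduce_result()`.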

In [None]:
# TODO: Implement this function
# Hint: Use dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

def allreduce_tensor(rank: int, world_size: int, device: str):
    """Perform an all-reduce (sum) operation."""
    # Create tensor with this worker's rank
    tensor = torch.tensor([float(rank)], device=device)
    print(f"[Rank {rank}] Before all-reduce: {tensor.item()}")
    
    # TODO: Perform all-reduce with SUM operation
    # Your code here
    
    print(f"[Rank {rank}] After all-reduce: {tensor.item()}")
    return tensor.cpu()

In [None]:
# Write your solution here

<details>

<summary>Click to see the solution</summary>

```python
def allreduce_tensor(rank: int, world_size: int, device: str):
    """Perform an all-reduce (sum) operation."""
    tensor = torch.tensor([float(rank)], device=device)
    print(f"[Rank {rank}] Before all-reduce: {tensor.item()}")
    
    # All-reduce sums values from all workers
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    
    print(f"[Rank {rank}] After all-reduce: {tensor.item()}")
    return tensor.cpu()

# Test it using execute():
world_size = len(workers)
allreduce_refs = [
    worker.execute.remote(allreduce_tensor, rank, world_size, device)
    for rank, worker in enumerate(workers)
]
tensors = ray.get(allreduce_refs)
# For 4 workers with ranks 0,1,2,3: sum = 0+1+2+3 = 6
# All workers should receive 6
```


</details>

## Part 5: Connection to Ray Train

This worker group pattern is the foundation of **Ray Train**!

Ray Train provides a high-level API that handles all this complexity:

In [None]:
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig

# Ray Train abstracts away all the worker management
# trainer = TorchTrainer(
#     train_loop_per_worker=your_training_function,
#     scaling_config=ScalingConfig(
#         num_workers=4,
#         use_gpu=True
#     )
# )
# result = trainer.fit()

Under the hood, Ray Train:
- Creates worker actors (like we did)
- Sets up PyTorch distributed (like we did)
- Handles checkpointing, fault tolerance, and more
- Integrates with popular ML frameworks
