# Multi-Machine Training (MMT) SPMD DDP Example

This notebook demonstrates how to integrate the Lightning SDK's MMT API with Monarch's SPMD DDP training to enable distributed training across multiple nodes.

## Overview

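- **MMT (Multi-Machine Training)**: Lightning SDK's API for running one job across several machines
- **SPMD (Single Program, Multiple Data)**: A parallel execution model in which every process runs the same program on its own portion of the data
- **DDP (Distributed Data Parallel)**: PyTorch's data-parallel training wrapper, which replicates the model in each process and synchronizes gradients between them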

This example shows how to scale PyTorch training across multiple machines using these technologies together.

## 1. Import Dependencies

First, let's import all the necessary libraries for multi-node distributed training.

In [None]:
import asyncio
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from lightning_sdk import Machine, MMT, Studio
from monarch.actor import Actor, current_rank, endpoint, proc_mesh
from torch.nn.parallel import DistributedDataParallel as DDP

## 2. Define the Model

Let's define a simple toy model that we'll use for distributed training. This model has configurable sizes to test different scales.

In [None]:
class ToyModel(nn.Module):
    def __init__(self, input_size=128, hidden_size=512, output_size=64):
        super().__init__()
        self.net1 = nn.Linear(input_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.net2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.net3 = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = self.relu1(self.net1(x))
        x = self.dropout(x)
        x = self.relu2(self.net2(x))
        x = self.dropout(x)
        return self.net3(x)
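
To see how the configurable sizes translate into scale, the parameter count of this three-layer network can be worked out by hand. The sketch below is plain Python; the helper names `toy_model_param_count` and `fp32_size_gb` are illustrative and not part of the notebook, and the size estimate assumes 4-byte float32 parameters.

```python
def toy_model_param_count(input_size=128, hidden_size=512, output_size=64):
    """Count weights and biases of the three nn.Linear layers in ToyModel."""
    net1 = input_size * hidden_size + hidden_size
    net2 = hidden_size * hidden_size + hidden_size
    net3 = hidden_size * output_size + output_size
    return net1 + net2 + net3


def fp32_size_gb(num_params):
    """Approximate parameter memory, assuming 4 bytes (float32) per parameter."""
    return num_params * 4 / (1024 ** 3)


default_params = toy_model_param_count()
large_params = toy_model_param_count(
    input_size=512, hidden_size=1024 * 16, output_size=256
)
print(f"default: {default_params:,} parameters")
print(f"large:   {large_params:,} parameters ({fp32_size_gb(large_params):.2f} GB)")
```

The larger configuration is the one passed to `demo_basic` later in the notebook. The hidden-to-hidden layer dominates the count because it grows quadratically with `hidden_size`.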

## 3. Multi-Node DDP Actor

This is the core component that handles distributed training across multiple nodes. The actor:
- Manages distributed process groups
- Coordinates between different nodes and GPUs
- Handles the actual training loop
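
The rank bookkeeping behind that coordination can be sketched in isolation. This is a minimal illustration, assuming every node runs the same number of processes (one per GPU); `global_rank` and `expected_world_size` are hypothetical helper names, not part of Monarch or the Lightning SDK.

```python
def global_rank(node_rank, local_rank, gpus_per_node):
    """Map (node, local GPU index) to a unique rank across the whole job."""
    if not 0 <= local_rank < gpus_per_node:
        raise ValueError("local_rank must be in [0, gpus_per_node)")
    return node_rank * gpus_per_node + local_rank


def expected_world_size(num_nodes, gpus_per_node):
    """Total number of processes the process group should contain."""
    return num_nodes * gpus_per_node


# Example: 3 nodes with 8 GPUs each.
num_nodes, gpus_per_node = 3, 8
ranks = [
    global_rank(node, local, gpus_per_node)
    for node in range(num_nodes)
    for local in range(gpus_per_node)
]
# Every process gets a distinct rank in [0, world_size).
assert ranks == list(range(expected_world_size(num_nodes, gpus_per_node)))
print(global_rank(node_rank=2, local_rank=3, gpus_per_node=8))
```

For `init_process_group` to succeed, the `WORLD_SIZE` seen by each process has to equal this total process count, not the number of nodes, so it is worth checking what the launcher exports.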

In [None]:
class MultiNodeDDPActor(Actor):
    """
    Multi-node DDP Actor that can run across multiple machines using MMT.
    Adapted from the single-machine DDPActor to work in a multi-node environment.
    """

    def __init__(self, gpus_per_node=8):
        self.rank = current_rank().rank
        self.gpus_per_node = gpus_per_node

        # Get distributed environment variables set by MMT
        self.world_size = int(os.environ.get("WORLD_SIZE", gpus_per_node))
        self.node_rank = int(os.environ.get("NODE_RANK", 0))
        self.master_addr = os.environ.get("MASTER_ADDR", "localhost")
        self.master_port = os.environ.get("MASTER_PORT", "12355")

        # Calculate global rank
        self.global_rank = self.node_rank * self.gpus_per_node + self.rank

    def _rprint(self, msg):
        print(
            f"Node {self.node_rank}, Local Rank {self.rank}, Global Rank {self.global_rank}: {msg}"
        )

    @endpoint
    async def setup(self):
        self._rprint("Initializing torch distributed for multi-node training")
        self._rprint(
            f"World size: {self.world_size}, Master: {self.master_addr}:{self.master_port}"
        )

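        # Bind this process to its local GPU before creating the NCCL process group
        torch.cuda.set_device(self.rank)

        # Initialize the process group for multi-node training.
        # With no init_method given, PyTorch reads MASTER_ADDR and MASTER_PORT
        # from the environment.
        dist.init_process_group(
            backend="nccl",  # Use NCCL for multi-GPU/multi-node
            rank=self.global_rank,
            world_size=self.world_size,
        )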

        self._rprint("Finished initializing torch distributed")

    @endpoint
    async def cleanup(self):
        self._rprint("Cleaning up torch distributed")
        dist.destroy_process_group()

    @endpoint
    async def demo_basic(
        self,
        num_epochs=10,
        batch_size=64,
        input_size=128,
        hidden_size=512,
        output_size=64,
    ):
        self._rprint("Running multi-node DDP example")

        # Create model and move it to the appropriate GPU
        device = torch.device(f"cuda:{self.rank}")
        model = ToyModel(
            input_size=input_size, hidden_size=hidden_size, output_size=output_size
        ).to(device)

        # Wrap model with DDP
        ddp_model = DDP(model, device_ids=[self.rank])

        # Print model size information (only from global rank 0).
        # The size estimate assumes 4-byte float32 parameters.
        if self.global_rank == 0:
            total_params = sum(p.numel() for p in model.parameters())
            total_size_gb = total_params * 4 / (1024**3)
            self._rprint(
                f"Model has {total_params:,} parameters ({total_size_gb:.2f} GB)"
            )

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        # Training loop (each "epoch" here is one optimizer step on a fresh random batch)
        for epoch in range(num_epochs):
            optimizer.zero_grad()

            # Generate random input data
            inputs = torch.randn(batch_size, input_size).to(device)
            labels = torch.randn(batch_size, output_size).to(device)

            # Forward pass
            outputs = ddp_model(inputs)
            loss = loss_fn(outputs, labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            # Print progress from all ranks every 5 epochs
            if epoch % 5 == 0:
                self._rprint(f"Epoch {epoch}, Loss: {loss.item():.4f}")

        self._rprint("Finished multi-node DDP training")

## 4. Training Orchestration

This function orchestrates the training on each node by:
1. Setting up the process mesh
2. Spawning the DDP actor
3. Running the training
4. Cleaning up resources
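
The ordering of these steps matters most when training fails partway through. A minimal sketch, using a stand-in class rather than a real Monarch actor mesh (`FakeActor` and `run_steps` are hypothetical names), shows one way to make sure cleanup still runs via `try`/`finally`:

```python
import asyncio


class FakeActor:
    """Stand-in for an actor mesh; records which steps were called."""

    def __init__(self, fail_in_training=False):
        self.calls = []
        self.fail_in_training = fail_in_training

    async def setup(self):
        self.calls.append("setup")

    async def train(self):
        self.calls.append("train")
        if self.fail_in_training:
            raise RuntimeError("simulated training failure")

    async def cleanup(self):
        self.calls.append("cleanup")


async def run_steps(actor):
    """Run setup, training, and cleanup; cleanup runs even if training raises."""
    await actor.setup()
    try:
        await actor.train()
    finally:
        await actor.cleanup()


ok_actor = FakeActor()
asyncio.run(run_steps(ok_actor))
print(ok_actor.calls)

failing_actor = FakeActor(fail_in_training=True)
try:
    asyncio.run(run_steps(failing_actor))
except RuntimeError as err:
    print(f"training failed: {err}; calls so far: {failing_actor.calls}")
```

In a notebook cell, `await run_steps(actor)` would replace `asyncio.run(...)`, because Jupyter already runs an event loop.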

In [None]:
async def run_multi_node_training(gpus_per_node=8):
    """
    Run multi-node training using the local process mesh.
    This function is called on each node by the MMT job.
    """
    # Get the number of GPUs available on this node
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else gpus_per_node

    print(f"Starting training on node with {num_gpus} GPUs")

    # Create process mesh for this node
    local_proc_mesh = await proc_mesh(
        gpus=num_gpus,
        env={
            "MASTER_ADDR": os.environ.get("MASTER_ADDR", "localhost"),
            "MASTER_PORT": os.environ.get("MASTER_PORT", "12355"),
            "WORLD_SIZE": os.environ.get("WORLD_SIZE", str(num_gpus)),
            "NODE_RANK": os.environ.get("NODE_RANK", "0"),
        },
    )

    # Spawn our actor mesh on top of the process mesh
    ddp_actor = await local_proc_mesh.spawn("ddp_actor", MultiNodeDDPActor, num_gpus)

    # Set up torch.distributed on every process in the mesh
    await ddp_actor.setup.call()

    # Run the training with larger model for multi-node setup
    await ddp_actor.demo_basic.call(
        num_epochs=100,
        batch_size=256,
        input_size=512,
        hidden_size=1024 * 16,  # Large model that benefits from multi-node
        output_size=256,
    )

    # Cleanup
    await ddp_actor.cleanup.call()

## 5. Job Launching with Lightning SDK

These functions handle launching and monitoring the multi-machine training job using Lightning SDK's MMT API.

In [None]:
def launch_mmt_job(num_nodes=3, teamspace="my-teamspace", user="my-user"):
    """
    Launch a multi-machine training job using Lightning SDK's MMT API.
    """
    # Initialize a Studio
    studio = Studio(name="multi-node-ddp-training", teamspace=teamspace, user=user)
    studio.start()

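    # Create the training script content that will be executed on each node.
    # It imports from `mmt_spmd_ddp`, so the code from this notebook must be
    # saved as mmt_spmd_ddp.py where the job can import it.
    # textwrap.dedent strips the literal's indentation so the file is valid Python.
    import textwrap

    training_script = textwrap.dedent(
        """\
        import asyncio
        import sys
        import os

        # Add the current directory to Python path if needed
        sys.path.append(os.getcwd())

        from mmt_spmd_ddp import run_multi_node_training

        if __name__ == "__main__":
            asyncio.run(run_multi_node_training(gpus_per_node=8))
        """
    )

    # Write the training script to the current working directory
    with open("multi_node_train.py", "w") as f:
        f.write(training_script)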

    print(f"Launching MMT job with {num_nodes} nodes...")

    # Run a Multi-machine job
    job = MMT.run(
        command="python multi_node_train.py",
        name="multi-node-ddp-training",
        machine=Machine.T4,  # One T4 GPU per node; choose a multi-GPU machine type for more GPUs per node
        studio=studio,
        num_machines=num_nodes,
        env={
            "CUDA_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",  # Make all GPUs visible
        },
    )

    print(f"Job started with name: {job.name}")
    print(f"Job status: {job.status}")

    # Return both handles so the caller can monitor the job and stop the studio
    return job, studio
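
One detail worth checking when a script is generated from a string literal inside a function: the literal keeps the indentation of the surrounding code, and Python rejects a file whose first statement is indented. The standard library's `textwrap.dedent` removes the common leading whitespace. A small self-contained check (the script body here is a placeholder, not the real training script):

```python
import textwrap

# A script body written inside an indented block keeps its leading spaces.
raw_script = """
        import os

        if __name__ == "__main__":
            print(os.getcwd())
    """

dedented_script = textwrap.dedent(raw_script).lstrip()


def compiles(source):
    """Return True if the source parses as a Python module."""
    try:
        compile(source, "<generated>", "exec")
        return True
    except SyntaxError:  # IndentationError is a subclass of SyntaxError
        return False


print("raw compiles:     ", compiles(raw_script))
print("dedented compiles:", compiles(dedented_script))
```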

In [None]:
def monitor_job(job, studio):
    """
    Monitor the job status and provide updates.
    """
    import time

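    def status_name():
        # job.status may be an enum or a plain string depending on the SDK version
        # (an assumption); normalize it to text so the comparisons below work either way.
        status = job.status
        return str(getattr(status, "value", status))

    print("Monitoring job status...")
    while status_name() in ["Running", "Pending"]:
        print(f"Job status: {status_name()}")
        time.sleep(30)  # Check every 30 seconds

    final_status = status_name()
    print(f"Final job status: {final_status}")

    # Report the outcome
    if final_status == "Completed":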
        print("Training completed successfully!")
    else:
        print(f"Training finished with status: {job.status}")

    # Shut down the studio
    studio.stop()
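
A polling loop like this can wait forever if the job never leaves an active state. The sketch below adds a timeout and takes the status lookup and sleep function as parameters so it can be exercised without a real job. The names `wait_for_job` and `ACTIVE_STATES`, and the status strings themselves, are assumptions for illustration; check the Lightning SDK for the actual status values.

```python
import time

ACTIVE_STATES = {"Pending", "Running"}  # assumed status names


def wait_for_job(get_status, poll_seconds=30.0, timeout_seconds=3600.0,
                 sleep=time.sleep, clock=time.monotonic):
    """Poll get_status() until it leaves ACTIVE_STATES or the timeout expires.

    Returns the last observed status; raises TimeoutError on timeout.
    """
    deadline = clock() + timeout_seconds
    status = get_status()
    while status in ACTIVE_STATES:
        if clock() >= deadline:
            raise TimeoutError(f"job still {status!r} after {timeout_seconds}s")
        sleep(poll_seconds)
        status = get_status()
    return status


# Exercise the loop with a scripted sequence of statuses and no real sleeping.
scripted = iter(["Pending", "Running", "Running", "Completed"])
final_status = wait_for_job(lambda: next(scripted), sleep=lambda _seconds: None)
print(final_status)
```

With a real job, `get_status` would wrap `job.status`, converted to whatever form matches the names in `ACTIVE_STATES`.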

## 6. Example Usage

Here's how to launch a multi-node training job. Make sure to update the configuration parameters below with your actual values.

In [None]:
# Configuration
NUM_NODES = 3
TEAMSPACE = "general"  # Replace with your teamspace
USER = "alisol"  # Replace with your username

# Launch the job
job, studio = launch_mmt_job(
    num_nodes=NUM_NODES, teamspace=TEAMSPACE, user=USER
)

print("Job launched. You can monitor it using: job.status")
print("To stop the job: job.stop()")
print("To clean up: studio.stop()")

## 7. Job Monitoring

You can use this cell to monitor your job progress:

In [None]:
# Check job status
# print(f"Current job status: {job.status}")

# Uncomment to monitor the job automatically
# monitor_job(job, studio)

## 8. Local Testing

For testing purposes, you can also run the training locally (single node) by executing the training function directly:

In [None]:
# Local testing - uncomment to run single-node training
# import os
#
# # Set environment variables for local testing
# os.environ["MASTER_ADDR"] = "localhost"
# os.environ["MASTER_PORT"] = "12355"
# os.environ["WORLD_SIZE"] = str(max(torch.cuda.device_count(), 1))  # one process per local GPU
# os.environ["NODE_RANK"] = "0"
#
# # Run local training
# await run_multi_node_training(gpus_per_node=1)
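
A fixed `MASTER_PORT` such as 12355 can collide with another process on a shared machine. For local experiments, one option is to ask the operating system for a free port. This sketch uses only the standard library, and `find_free_port` is an illustrative helper name.

```python
import socket


def find_free_port(host="127.0.0.1"):
    """Ask the OS for an unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]


port = find_free_port()
print(f"MASTER_PORT candidate: {port}")
```

The value would then be exported as a string, for example `os.environ["MASTER_PORT"] = str(port)`. The port is released when the socket closes, so another process could in principle claim it before training starts; for a single-machine test that race is usually acceptable.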

## Key Concepts

### Multi-Machine Training (MMT)
- Distributed computing across multiple physical machines
- Managed by Lightning SDK for resource allocation and coordination

### SPMD (Single Program, Multiple Data)
- Same program runs on all nodes/processes
- Each process works on different data
- Coordination through message passing
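
The "different data" part is typically decided by rank. Below is a minimal sketch of strided sharding, similar in spirit to what a distributed sampler does (the notebook itself sidesteps this by generating random tensors on each rank). `shard_for_rank` is a hypothetical helper.

```python
def shard_for_rank(items, rank, world_size):
    """Return the subset of items this rank should process (strided split)."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return items[rank::world_size]


samples = list(range(10))
world_size = 4
shards = [shard_for_rank(samples, rank, world_size) for rank in range(world_size)]
for rank, shard in enumerate(shards):
    print(f"rank {rank}: {shard}")
```

Every sample lands on exactly one rank, although shard sizes can differ by one when the dataset size is not a multiple of `world_size`.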

### Distributed Data Parallel (DDP)
- PyTorch's data-parallel training wrapper (`torch.nn.parallel.DistributedDataParallel`)
- Model replicated on every process, one replica per GPU
- Gradients averaged across all replicas during the backward pass
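
Why synchronized gradients keep the replicas identical can be checked with a toy calculation. The sketch below is plain Python with no PyTorch: it fits `y = w * x` with a mean-squared-error loss, splits a batch evenly across two simulated replicas, and averages their gradients, which is the effect of DDP's all-reduce. The data and the helper name `mse_grad` are made up for illustration.

```python
import math


def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# Full-batch gradient, as a single process would compute it.
full_grad = mse_grad(w, xs, ys)

# Two replicas, each with an equally sized shard of the batch.
shards = [(xs[:2], ys[:2]), (xs[2:], ys[2:])]
local_grads = [mse_grad(w, sx, sy) for sx, sy in shards]

# Averaging the local gradients stands in for the all-reduce step.
averaged_grad = sum(local_grads) / len(local_grads)

print(full_grad, averaged_grad)
assert math.isclose(full_grad, averaged_grad)
```

Because every replica applies the same averaged gradient to the same starting weights, the copies stay in sync after each optimizer step. The equality with the full-batch gradient relies on the shards being the same size.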

### Actor Model
- Monarch's abstraction for distributed computation
- Encapsulates state and behavior
- Communicates through message passing

This notebook demonstrates how these technologies work together to enable scalable machine learning training across multiple machines.