# Hero Notebook: TorchTitan Multi-Node Training with Monarch & Lightning SDK

This notebook demonstrates how to run TorchTitan training using Monarch for distributed multi-node training on Lightning AI infrastructure.

<div align="center">
  <img src="./assets/NB_Monarch_Lightning.svg" alt="Monarch Lightning Architecture" width="800"/>
</div>

<!-- Image size settings:
  - Adjust 'width' attribute to control the diagram size (e.g., width="600", width="1000", or width="100%")
  - You can also use 'height' attribute instead (e.g., height="400")
  - Remove width/height attributes to display at original size
-->

## Table of Contents

This notebook provides a comprehensive guide to running distributed multi-node training using **Monarch** (Meta's distributed actor framework) with **TorchTitan** (PyTorch's large-scale LLM training library) on **Lightning AI** infrastructure. You'll learn how to set up, execute, debug, and manage distributed training workflows across multiple GPU nodes. 

Parts I and II form the core of this notebook, covering setup and training. Part III is for users interested in Monarch's advanced features, such as interactive distributed debugging, environment variable management, and workspace code synchronization between the local node and remote nodes.

### What You'll Learn

**Part I: Environment Setup** *(Essential Prerequisites)*
- Install TorchTitan - Set up PyTorch and TorchTitan for LLM training
- Download Llama-3.1-8B Model Assets - Get model tokenizers from Hugging Face
- Install Monarch - Install Meta's distributed actor framework
- Setup Weights & Biases - Configure experiment tracking
- Update Lightning SDK - Get the latest Lightning SDK features
- Verify Installations - Confirm all dependencies are ready

**Part II: Multi-Node Training** *(Core Training Workflow)*
- Import Lightning SDK Components - Import required classes for multi-machine training
- Configure Training Job Parameters - Set up nodes, GPUs, and network settings
- Launch Multi-Node Training Job - Start distributed infrastructure on Lightning AI
- Set Up Process Mesh - Initialize Monarch's distributed computing mesh
- Define TorchTitan Trainer Actor - Create distributed training actor
- Run TorchTitan Training - Execute Llama 3 8B training across nodes

**Part III: Advanced Features** *(Distributed Development & Debugging)*

1. **Environment Variable Management**
   - Spawn Environment Variable Actor - Manage env vars across nodes
   - Get/Set Environment Variables - Inspect and modify remote environments
   - List Environment Variables - Query env vars by prefix

2. **Workspace Synchronization** *(Hot-Reload Code & Configs)*
   - Introduction to sync_workspace - Understanding workspace sync
   - File Checker Actor - Define an actor to read and verify file contents
   - Create Local Configuration - Set up training configs
   - Sync to Remote Nodes - Propagate changes to workers
   - Verify Synchronization - Confirm files are synced

3. **Interactive Debugging with Breakpoints**
   - Debugging Overview - Using pdb with distributed actors
   - Define Debug Trainer - Create actor with breakpoints
   - Spawn and Debug - Run interactive debugging session
   - Debugger Commands - Learn monarch debug CLI commands

**Part IV: Cleanup**
- Stop Process Mesh - Gracefully shut down distributed resources

---

### Key Concepts

- **Monarch Actor**: Distributed computation unit that runs on remote nodes
- **Process Mesh (ProcMesh)**: Network of processes across multiple nodes for distributed computing
- **Endpoint**: Method decorator that makes actor methods callable remotely
- **Workspace Sync**: Synchronize local code/config changes to remote worker nodes without restart
- **Lightning MMT**: Multi-Machine Training orchestration on Lightning AI

### Prerequisites
- Lightning AI account with access to GPU machines (L40S recommended)
- Hugging Face account with Llama model access
- Basic understanding of distributed training concepts

---

# Part I: Environment Setup

Before running the notebook cells, ensure all dependencies are properly installed by following the steps below.

## Install TorchTitan

Clone the TorchTitan repository and install from PyPI:

```bash
# Clone the repository (needed for config files, scripts, and test assets)
git clone https://github.com/pytorch/torchtitan.git
cd torchtitan

# Install TorchTitan from PyPI
pip install torchtitan
```

Note: TorchTitan requires PyTorch. If you need a specific CUDA version, install PyTorch first:

```bash
pip install torch --index-url https://download.pytorch.org/whl/cu126
pip install torchtitan
```

## Download Llama-3.1-8B Model Assets

Download the Llama-3.1-8B tokenizer from Hugging Face. You'll need a Hugging Face token with access to the Llama models:

```bash
python scripts/download_hf_assets.py \
    --repo_id meta-llama/Llama-3.1-8B \
    --assets tokenizer \
    --hf_token=YOUR_HUGGINGFACE_TOKEN_KEY
```

Replace `YOUR_HUGGINGFACE_TOKEN_KEY` with your actual Hugging Face token.

## Install Monarch

Install Monarch (`torchmonarch`) from PyPI, pinned to the version shown in the command:

```bash
pip install torchmonarch==0.3.0
```

For more information, visit: https://github.com/meta-pytorch/monarch

## Setup Weights & Biases

Check whether `wandb` is installed. If not, install it and log in:

```bash
pip install wandb
wandb login
```

Follow the prompts to authenticate with your wandb account.

## Update the Lightning SDK

The latest version of the Lightning SDK supports IP sharing between the client host and remote nodes. This notebook relies on that feature.

```bash
pip install -U lightning_sdk
```

## Verify Installations

After completing the installation steps above, verify that TorchTitan and Monarch are properly installed:

```python
# Verify TorchTitan installation
import torchtitan
print("TorchTitan is installed successfully")

# Verify Monarch installation
import monarch
print("Monarch is installed successfully")

# Verify PyTorch and CUDA
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
```

If all imports succeed, you're ready to proceed with the training workflow below.
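
Because the steps above pin or upgrade specific packages, it can also help to print which versions are actually installed. The following is a minimal sketch using only the standard library; the helper name `installed_version` is illustrative, and the distribution names are assumed to match the `pip install` commands above.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Distribution names taken from the pip commands in Part I.
for name in ("torch", "torchtitan", "torchmonarch", "wandb", "lightning_sdk"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

Any package reported as `not installed` points back to the corresponding installation step.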

---

# Part II: Multi-Node Training with Monarch and Lightning

Now that the environment is set up, we can proceed with configuring and launching the distributed training job.

## Configure Environment and Imports

Set up environment variables and import necessary components for distributed training.

In [None]:
import os
# Need to set before importing monarch
os.environ["MONARCH_FILE_LOG"] = "debug"
os.environ["HYPERACTOR_MESH_ENABLE_LOG_FORWARDING"] = "true"
os.environ["HYPERACTOR_MESH_ENABLE_FILE_CAPTURE"] = "true"
os.environ["HYPERACTOR_MESH_TAIL_LOG_LINES"] = "100"

import socket
import subprocess
import sys
import time

from utils import get_host_ip_addr, bootstrap_addr
from monarch.actor import Actor, enable_transport, endpoint
from monarch._src.actor.bootstrap import attach_to_workers


class Hello(Actor):
    @endpoint
    def hello(self) -> str:
        print("HELLO!")
        return "echo"


# Configuration
NUM_NODES = 2
NUM_GPUS = 8
port = 26600

# Enable client transport
host_ip_addr = get_host_ip_addr(addr_type="public")
enable_transport(f"tcp://{host_ip_addr}:{port}@tcp://0.0.0.0:{port}")
print(f"Client transport enabled at {host_ip_addr}:{port}")

## Launch Multi-Node Training Job

Launch the MMT job using Lightning SDK. This starts worker processes on remote nodes that run the Monarch bootstrap.

In [None]:
from mmt_utils import launch_mmt_job

MMT_JOB_NAME = f"Monarch-v0.2.0-MMT-{NUM_NODES}-nodes"

job, studio = launch_mmt_job(
    num_nodes=NUM_NODES,
    mmt_job_name=MMT_JOB_NAME,
    port=port,
    num_gpus=NUM_GPUS,
)

print("Job launched. You can monitor it using: job.status")
print("To stop the job: job.stop()")
print("To clean up: studio.stop()")

In [None]:
# Check job status
job.status

## Set Up Process Mesh from Workers

Get worker IP addresses from the job and create a process mesh by attaching to the workers.

In [None]:
# Get worker IP addresses from the job
ip_addresses_list_public = [machine.public_ip for machine in job.machines]
print(f"Worker IPs: {ip_addresses_list_public}")

# Create worker addresses
worker_addrs = [f"tcp://{ip}:{port}@tcp://0.0.0.0:{port}" for ip in ip_addresses_list_public]
print(f"Worker addresses: {worker_addrs}")
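
The address strings built in the cell above follow the pattern `tcp://<ip>:<port>@tcp://0.0.0.0:<port>`. As a rough sketch, the same formatting can be wrapped in a small helper with a port sanity check. The name `worker_address` and the example IPs are hypothetical, and reading the part before `@` as the reachable address and the part after it as the local bind address is an assumption rather than something this notebook states.

```python
def worker_address(ip: str, port: int) -> str:
    """Build an address in the tcp://<ip>:<port>@tcp://0.0.0.0:<port> form."""
    if not 0 < port < 65536:
        raise ValueError(f"port out of range: {port}")
    return f"tcp://{ip}:{port}@tcp://0.0.0.0:{port}"


# Hypothetical IPs for illustration; the notebook reads them from job.machines.
example_ips = ["203.0.113.10", "203.0.113.11"]
example_addrs = [worker_address(ip, 26600) for ip in example_ips]
print(example_addrs)
```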

In [None]:
# Attach to workers and create process mesh
host_mesh = attach_to_workers(
    name="host_mesh", ca="trust_all_connections", workers=worker_addrs
)

proc_mesh = host_mesh.spawn_procs(per_host={"gpus": NUM_GPUS})
await proc_mesh.logging_option(stream_to_client=True, aggregate_window_sec=3)

print(f"Process mesh created with {NUM_NODES} nodes and {NUM_GPUS} GPUs per node")
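
With `NUM_NODES` hosts and `NUM_GPUS` processes per host, the mesh holds `NUM_NODES * NUM_GPUS` ranks. The sketch below shows one way to translate a flat rank into a (host, GPU) pair. It assumes a host-major layout, which is an illustrative assumption and not a documented guarantee of Monarch; `rank_to_coords` is a hypothetical helper.

```python
from typing import Tuple


def rank_to_coords(rank: int, gpus_per_host: int) -> Tuple[int, int]:
    """Map a flat rank to (host index, GPU index), assuming host-major order."""
    if rank < 0 or gpus_per_host <= 0:
        raise ValueError("rank must be >= 0 and gpus_per_host must be > 0")
    return divmod(rank, gpus_per_host)


num_nodes, gpus_per_host = 2, 8  # same values as NUM_NODES and NUM_GPUS above
total_ranks = num_nodes * gpus_per_host
print(total_ranks)                       # 16
print(rank_to_coords(9, gpus_per_host))  # (1, 1)
```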

## Quick Test: Hello World Actor

Test the process mesh with a simple Hello actor to verify connectivity.

In [None]:
actor_mesh = proc_mesh.spawn("hello", Hello)
actor_mesh.hello.call().get()
time.sleep(5)
actor_mesh.hello.call().get()
print("Hello actor test completed successfully!")

# Hero Example: Run TorchTitan with Monarch for Llama 3 8B

## Generate Job Name Helper

Define a utility function to generate a unique job name based on the username, number of hosts, and GPUs per host.

In [None]:
import getpass
def get_job_name(num_hosts: int, num_gpus_per_host: int):
    return f"monarch-{getpass.getuser()}-hosts{num_hosts}-gpus{num_gpus_per_host}"
print(get_job_name(num_hosts=NUM_NODES, num_gpus_per_host=NUM_GPUS))

## Define TorchTitan Trainer Actor

Create the `TitanTrainerWrapper` class, a Monarch Actor that wraps TorchTitan's training functionality.

In [None]:
import os
import sys
import logging
from monarch.actor import ProcMesh, Actor, endpoint, current_rank
import socket
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer
from typing import Optional
import torch
from torchtitan.config import JobConfig


class TitanTrainerWrapper(Actor):
    def __init__(self, job_config: JobConfig):
        self.rank = current_rank().rank
        self.job_config = job_config

    def _rprint(self, msg):
        """Helper method to print with rank information."""
        print(f"{self.rank=} {msg}")

    @endpoint
    def init(self):
        logging.getLogger().addHandler(logging.StreamHandler(sys.stderr))
        print(f"Initializing actor: {self.rank} {current_rank()=} {socket.gethostname()=}")


    @endpoint
    def train(self):
        logger.info("Starting training")
        config = self.job_config
        trainer: Optional[Trainer] = None

        try:
            # Only build the trainer here; training runs once, in the else branch below.
            trainer = Trainer(config)

            if config.checkpoint.create_seed_checkpoint:
                assert (
                    int(os.environ["WORLD_SIZE"]) == 1
                ), "Must create seed checkpoint using a single device, to disable sharding."
                assert (
                    config.checkpoint.enable
                ), "Must enable checkpointing when creating a seed checkpoint."
                trainer.checkpointer.save(curr_step=0)
                logger.info("Created seed checkpoint")
            else:
                trainer.train()
        finally:
            if trainer:
                trainer.close()

            if torch.distributed.is_initialized():
                torch.distributed.destroy_process_group()
                logger.info("Process group destroyed.")
        print("Done training")

## Define Async Main Training Function

Set up the main asynchronous function that orchestrates the distributed training.

In [None]:
from torchtitan.config import ConfigManager, JobConfig
from monarch.tools.network import AddrType
from monarch.utils import setup_env_for_distributed

async def async_main(job_config: JobConfig):
    torch.use_deterministic_algorithms(True)
    job_name = get_job_name(NUM_NODES, NUM_GPUS)

    # If use_ipaddr is not passed, MASTER_ADDR defaults to an IPv6 address.
    await setup_env_for_distributed(proc_mesh, use_ipaddr=AddrType.IPv4)

    await proc_mesh.logging_option(stream_to_client=True, aggregate_window_sec=3)

    print(job_config)
    print(f"Spawning meshes on {job_name}")

    trainer_actor = proc_mesh.spawn("trainer_actor", TitanTrainerWrapper, job_config)

    await trainer_actor.init.call()
    await trainer_actor.train.call()

## Initialize Logger and Run Training

Configure the TorchTitan logger and run the training.

In [None]:
init_logger()
config_manager = ConfigManager()

job_name = get_job_name(NUM_NODES, NUM_GPUS)

manual_args = [
        "--job.config_file",
        os.path.expanduser("/teamspace/studios/this_studio/torchtitan/torchtitan/models/llama3/train_configs/llama3_8b.toml"),
        "--model.tokenizer_path",
        "/teamspace/studios/this_studio/torchtitan/assets/hf/Llama-3.1-8B",
        "--training.steps",
        "25",
        "--training.dataset_path",
        "/teamspace/studios/this_studio/torchtitan/tests/assets/c4_test",
        "--job.dump_folder",
        "/teamspace/studios/this_studio/torchtitan/outputs/" + job_name,
        "--training.seq_len",
        "1024",
    ]
config = config_manager.parse_args(manual_args)
await async_main(config)

**Congratulations! You just ran interactive distributed training for a Llama 3 model in a notebook using Monarch actors on Lightning infrastructure!**

This setup already gives you a lot of flexibility: you can change the configuration and launch another training run without starting a new job or provisioning a new set of nodes, and you get Monarch's aggregated logging across ranks.
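
To make such configuration changes less error-prone, the flat flag list passed to `parse_args` can be generated from a dictionary. This is a minimal sketch; `overrides_to_args` is a hypothetical helper, and the values are illustrative.

```python
from typing import Dict, List, Union


def overrides_to_args(overrides: Dict[str, Union[str, int, float]]) -> List[str]:
    """Flatten {"section.key": value} into ["--section.key", "value", ...]."""
    args: List[str] = []
    for key, value in overrides.items():
        args.extend([f"--{key}", str(value)])
    return args


# Values here are illustrative; keys mirror flags already used above.
second_run = overrides_to_args({"training.steps": 50, "training.seq_len": 2048})
print(second_run)
# ['--training.steps', '50', '--training.seq_len', '2048']
```

The resulting list can stand in for the corresponding entries of `manual_args` before calling `config_manager.parse_args` again on the same process mesh.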

Part III covers more advanced Monarch features. One is interactive distributed debugging while training runs on multiple nodes and ranks. Another is `sync_workspace`, which lets you update packages, environments, and files locally and sync them to the remote nodes. Without Monarch, such changes typically mean relaunching the job, which can take a long time.


--- 

# Part III: Advanced Features (Distributed Development & Debugging)

## Environment Variable Management with Remote Actors

Spawn an actor that can interact with environment variables on remote nodes.

In [None]:
from monarch.actor import Actor, endpoint, current_rank
import os
import socket

class EnvVarActor(Actor):
    """Actor for managing environment variables on remote nodes."""

    def __init__(self):
        self.rank = current_rank().rank
        self.hostname = socket.gethostname()

    @endpoint
    def get_env(self, var_name: str) -> dict:
        """Get an environment variable value from the remote node."""
        value = os.environ.get(var_name)
        return {
            "rank": self.rank,
            "hostname": self.hostname,
            "var_name": var_name,
            "value": value
        }

    @endpoint
    def set_env(self, var_name: str, var_value: str) -> dict:
        """Set an environment variable on the remote node."""
        os.environ[var_name] = var_value
        return {
            "rank": self.rank,
            "hostname": self.hostname,
            "var_name": var_name,
            "value": var_value,
            "status": "set"
        }

    @endpoint
    def list_env_vars(self, prefix: str = "") -> dict:
        """List all environment variables matching a prefix."""
        matching_vars = {k: v for k, v in os.environ.items() if k.startswith(prefix)}
        return {
            "rank": self.rank,
            "hostname": self.hostname,
            "matching_vars": matching_vars,
            "count": len(matching_vars)
        }

In [None]:
# Spawn the environment variable actor across all nodes
env_actor = proc_mesh.spawn("env_actor", EnvVarActor)
print("EnvVarActor spawned across all nodes")

In [None]:
# Get an environment variable from all nodes
results = await env_actor.get_env.call("CUDA_VISIBLE_DEVICES")
print("\nCUDA_VISIBLE_DEVICES on all nodes:")
for result in results:
    if len(result) > 1:
        print(f"  Host {result[0].get('hosts', '?')} gpus {result[0].get('gpus', '?')}  Rank {result[1].get('rank', '?')} ({result[1].get('hostname', '?')}): {result[1].get('value', '?')}")
    else:
        print(f"  Rank {result.get('rank', '?')} ({result.get('hostname', '?')}): {result.get('value', '?')}")

In [None]:
# List all environment variables starting with "CUDA"
list_results = await env_actor.list_env_vars.call("CUDA")
print("\nCUDA-related environment variables on all nodes:")
for result in list_results:
    if len(result) > 1:
        print(f"\n  Rank {result[1]['rank']} ({result[1]['hostname']}) - {result[1]['count']} variables:")
        for var_name, var_value in result[1]['matching_vars'].items():
            print(f"    {var_name}={var_value}")
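
Once the per-rank dictionaries are collected, a common follow-up is checking whether every rank reports the same value. The sketch below groups ranks by reported value. It runs on mock payloads shaped like the dicts returned by `EnvVarActor.get_env`; in the notebook those dicts are the second element of each result pair, as the cells above show. The helper name and mock data are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Optional


def group_ranks_by_value(results: List[dict]) -> Dict[Optional[str], List[int]]:
    """Group ranks by the value they reported for one environment variable."""
    groups: Dict[Optional[str], List[int]] = defaultdict(list)
    for entry in results:
        groups[entry["value"]].append(entry["rank"])
    return dict(groups)


# Mock payloads shaped like the dicts returned by EnvVarActor.get_env.
mock_results = [
    {"rank": 0, "hostname": "node-a", "var_name": "NCCL_DEBUG", "value": "INFO"},
    {"rank": 1, "hostname": "node-a", "var_name": "NCCL_DEBUG", "value": "INFO"},
    {"rank": 8, "hostname": "node-b", "var_name": "NCCL_DEBUG", "value": None},
]
groups = group_ranks_by_value(mock_results)
print(groups)
if len(groups) > 1:
    print("Ranks disagree on this variable")
```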

---

## Workspace Synchronization with `sync_workspace`

Sync local files to remote worker nodes without restarting the job.

In [None]:
class FileCheckerActor(Actor):
    """Actor to read and verify file contents on remote nodes."""

    def __init__(self):
        self.rank = current_rank().rank
        self.hostname = socket.gethostname()

    @endpoint
    def read_file(self, file_path: str) -> dict:
        """Read a file and return its contents."""
        try:
            with open(file_path, 'r') as f:
                content = f.read()
            return {
                "rank": self.rank,
                "hostname": self.hostname,
                "file_path": file_path,
                "content": content,
                "exists": True,
                "size": len(content)
            }
        except FileNotFoundError:
            return {
                "rank": self.rank,
                "hostname": self.hostname,
                "file_path": file_path,
                "exists": False,
                "error": "File not found"
            }
        except Exception as e:
            return {
                "rank": self.rank,
                "hostname": self.hostname,
                "file_path": file_path,
                "exists": False,
                "error": str(e)
            }

    @endpoint
    def file_exists(self, file_path: str) -> dict:
        """Check if a file exists on the remote node."""
        exists = os.path.exists(file_path)
        return {
            "rank": self.rank,
            "hostname": self.hostname,
            "file_path": file_path,
            "exists": exists
        }

In [None]:
# Spawn the file checker actor
file_checker = proc_mesh.spawn("file_checker", FileCheckerActor)
print("FileCheckerActor spawned across all nodes")

In [None]:
# Create a local workspace directory for our custom config
local_workspace = "/teamspace/studios/this_studio/monarch_sync_example"
os.makedirs(local_workspace, exist_ok=True)

# Create a custom training configuration file
config_file_name = "custom_training_config.toml"
local_config_path = os.path.join(local_workspace, config_file_name)

# Write initial configuration
with open(local_config_path, 'w') as f:
    f.write("""# TorchTitan Custom Training Configuration
# This file demonstrates workspace synchronization

[training]
batch_size = 32
learning_rate = 0.001
max_steps = 100
warmup_steps = 10

[model]
model_type = "llama3_8b"
seq_len = 1024

[optimizer]
optimizer_type = "AdamW"
weight_decay = 0.01
""")

print(f"Created local config file: {local_config_path}")
with open(local_config_path, 'r') as f:
    print(f"\nInitial configuration:\n{f.read()}")

In [None]:
from monarch.tools.config.workspace import Workspace
from pathlib import Path

# Create a Workspace object pointing to our local directory
workspace = Workspace(dirs=[Path(local_workspace)])

print(f"Workspace configured: {workspace.dirs}")
print(f"\nSyncing workspace to remote nodes...")
# Perform initial sync
await proc_mesh.sync_workspace(workspace=workspace, conda=False, auto_reload=False)

print("Initial workspace sync completed!")
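
After a sync, the dictionaries returned by `FileCheckerActor.read_file` can be compared against the local file to confirm that every rank sees the same content. The sketch below does that comparison by hash on mock payloads. The helper names are hypothetical, and the remote path to read depends on where the synced workspace lands on the workers, which this example does not assume.

```python
import hashlib
from typing import List


def content_digest(text: str) -> str:
    """Return the SHA-256 hex digest of a text payload."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def out_of_sync_ranks(local_text: str, remote_results: List[dict]) -> List[int]:
    """List ranks whose file is missing or differs from the local copy."""
    expected = content_digest(local_text)
    stale = []
    for entry in remote_results:
        if not entry.get("exists") or content_digest(entry["content"]) != expected:
            stale.append(entry["rank"])
    return stale


# Mock payloads shaped like the dicts returned by FileCheckerActor.read_file.
local_text = "[training]\nbatch_size = 32\n"
mock_remote = [
    {"rank": 0, "exists": True, "content": "[training]\nbatch_size = 32\n"},
    {"rank": 1, "exists": True, "content": "[training]\nbatch_size = 16\n"},
    {"rank": 2, "exists": False, "error": "File not found"},
]
print(out_of_sync_ranks(local_text, mock_remote))  # [1, 2]
```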

---

## Debugging with Breakpoints in Monarch

Monarch supports interactive debugging using Python's `pdb`. See the full documentation in `studio_3_interactive_debugging.ipynb`.

In [None]:
class TitanTrainerDebug(Actor):
    """TorchTitan Trainer Actor with debugging breakpoints."""

    def __init__(self, job_config: JobConfig):
        self.rank = current_rank().rank
        self.job_config = job_config
        self.trainer: Optional[Trainer] = None

    def _rprint(self, msg):
        """Helper method to print with rank information."""
        print(f"{self.rank=} {msg}")

    @endpoint
    def init(self):
        logging.getLogger().addHandler(logging.StreamHandler(sys.stderr))
        self._rprint(f"Initializing debug actor: {current_rank()=} {socket.gethostname()=}")

        # Breakpoint 1: After initialization
        breakpoint()  # Debug: Inspect actor initialization state

    @endpoint
    def setup_trainer(self):
        """Setup the trainer with a breakpoint to inspect configuration."""
        logger.info(f"Setting up trainer on rank {self.rank}")
        config = self.job_config

        # Breakpoint 2: Before trainer creation
        if self.rank == 0:
            breakpoint()  # Debug: Inspect job config before trainer creation

        self.trainer = Trainer(config)
        self._rprint("Trainer setup complete")

    @endpoint
    def train_step(self, num_steps: int = 5):
        """Run a few training steps with breakpoints."""
        if not self.trainer:
            raise RuntimeError("Trainer not initialized. Call setup_trainer first.")

        logger.info(f"Starting training for {num_steps} steps on rank {self.rank}")

        for step in range(num_steps):
            if step == 2 and self.rank == 0:
                breakpoint()  # Debug: Inspect mid-training state

            self._rprint(f"Processing step {step + 1}/{num_steps}")

        self._rprint(f"Completed {num_steps} training steps")

    @endpoint
    def cleanup(self):
        """Cleanup resources."""
        logger.info(f"Cleaning up trainer on rank {self.rank}")

        if self.trainer:
            self.trainer.close()

        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
            logger.info("Process group destroyed.")

        self._rprint("Cleanup complete")
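
The actor above guards some breakpoints with `self.rank == 0` so that only one rank stops. That pattern can be factored into a helper, sketched here in plain Python. Python's built-in `breakpoint()` dispatches to `sys.breakpointhook`, so the demonstration swaps in a recording hook to run non-interactively. The helper name is hypothetical, and how Monarch itself hooks into `breakpoint()` is not covered here.

```python
import sys


def breakpoint_on_rank(rank: int, target_rank: int = 0) -> bool:
    """Enter the debugger only on the chosen rank; report whether it fired."""
    if rank != target_rank:
        return False
    breakpoint()  # dispatches to sys.breakpointhook
    return True


# Demonstration with a recording hook so this snippet runs non-interactively.
fired = []
original_hook = sys.breakpointhook
sys.breakpointhook = lambda *args, **kwargs: fired.append("hit")
try:
    results = [breakpoint_on_rank(r) for r in range(4)]
finally:
    sys.breakpointhook = original_hook  # restore the previous debugger hook

print(results)  # [True, False, False, False]
print(fired)    # ['hit']
```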

---

## Cleanup and Stop Process Mesh

In [None]:
# Stop the hello actor first
actor_mesh.stop().get()

In [None]:
# Shutdown the host mesh
host_mesh.shutdown().get()
print("Process mesh stopped successfully!")
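
The cleanup order matters: actors first, then the host mesh, then the job itself. One way to make that order hard to get wrong is `contextlib.ExitStack`, which runs registered callbacks in reverse order even if the body raises. The sketch below uses stand-in callbacks that only record their names; in the notebook they might wrap calls such as `actor_mesh.stop().get()`, `host_mesh.shutdown().get()`, and `job.stop()`.

```python
from contextlib import ExitStack

calls = []  # records the order in which stand-in resources are released

with ExitStack() as stack:
    # Register in acquisition order; ExitStack runs callbacks in reverse.
    stack.callback(lambda: calls.append("job.stop"))
    stack.callback(lambda: calls.append("host_mesh.shutdown"))
    stack.callback(lambda: calls.append("actor_mesh.stop"))
    calls.append("training work")

print(calls)
# ['training work', 'actor_mesh.stop', 'host_mesh.shutdown', 'job.stop']
```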