In [1]:
%cd /mnt/models/mreso/monarch/examples/

/mnt/models/mreso/monarch/examples


In [None]:


# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# @noautodeps
# pyre-ignore-all-errors
import logging
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from monarch.tools import commands
from monarch.actor import Actor, current_rank, endpoint
from monarch.utils import setup_env_for_distributed
from torch.nn.parallel import DistributedDataParallel as DDP
# from slurm.utils_with_init import (
from slurm.utils import (
    get_appdef, 
    get_server_info, 
    create_proc_mesh,
)

# Verbose diagnostics for the Rust side of Monarch: full backtraces and
# debug-level logs. These are set in this process's environment; whether
# remote Slurm-launched processes see them is not shown here.
os.environ.update({"RUST_BACKTRACE": "full", "RUST_LOG": "debug"})


# Reconfigure the root logger unconditionally (force=True replaces any
# handlers a previous cell or library already installed).
logging.basicConfig(
    format="%(name)s %(asctime)s %(levelname)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.DEBUG,
    force=True,
)


# Module-scoped logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)

class BarrierActor(Actor):
    """Actor that creates an NCCL process group and runs an all-reduce smoke test.

    Loosely follows the "basic use case" section of PyTorch's DDP tutorial,
    but no model is wrapped in DDP: ``demo_basic`` only checks that a
    cross-process ``all_reduce`` works on the GPU assigned to this actor.

    Adapted from: https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html#basic-use-case
    """

    # Fallback GPUs-per-host count, used to derive a local GPU index from the
    # mesh rank when the LOCAL_RANK environment variable is not available.
    GPUS_PER_HOST: int = 8

    def __init__(self):
        self.rank = current_rank().rank
        # Provisional local GPU index; ``setup`` replaces it with LOCAL_RANK
        # from the environment when that variable is set.
        self.local_rank = int(self.rank % self.GPUS_PER_HOST)

    def _rprint(self, msg):
        """Print ``msg`` prefixed with this actor's rank."""
        print(f"{self.rank=} {msg}")

    @endpoint
    async def setup(self):
        """Initialize the NCCL process group for this actor.

        Reads WORLD_SIZE (required), MASTER_ADDR, MASTER_PORT, RANK and
        LOCAL_RANK from the environment. RANK and LOCAL_RANK fall back to this
        actor's mesh rank and provisional local GPU index instead of 0, so a
        missing variable does not make every process claim rank 0.

        Raises:
            KeyError: if WORLD_SIZE is not set in the environment.
        """
        self._rprint("Initializing torch distributed")

        world_size = int(os.environ["WORLD_SIZE"])
        master_addr = os.environ.get("MASTER_ADDR", "localhost")
        master_port = os.environ.get("MASTER_PORT", "12355")
        # NOTE(review): assumes the mesh rank matches the global torch rank
        # when RANK is unset - confirm against setup_env_for_distributed.
        rank = int(os.environ.get("RANK", self.rank))
        # A single local GPU index is used for set_device, device_id and the
        # demo, so they can never point at different GPUs.
        self.local_rank = int(os.environ.get("LOCAL_RANK", self.local_rank))

        self._rprint(f"{self.local_rank=}")
        # Bind this process to its GPU BEFORE creating the NCCL group.
        torch.cuda.set_device(self.local_rank)
        self._rprint(f"Set GPU device to {self.local_rank}")

        self._rprint(
            f"MASTER_ADDR: {master_addr}, MASTER_PORT: {master_port}, RANK: {rank}"
        )
        dist.init_process_group(
            backend="nccl",
            init_method=f"tcp://{master_addr}:{master_port}",
            world_size=world_size,
            rank=rank,
            # device_id is documented as a torch.device; pass one explicitly.
            device_id=torch.device("cuda", self.local_rank),
        )
        self._rprint("Finished initializing torch distributed")

    @endpoint
    async def cleanup(self):
        """Destroy the default process group if one has been initialized."""
        self._rprint("Cleaning up torch distributed")
        if not dist.is_initialized():
            # Nothing to tear down (e.g. setup never ran or already cleaned up).
            self._rprint("Process group not initialized; skipping destroy")
            return
        dist.destroy_process_group()

    @endpoint
    async def demo_basic(self):
        """All-reduce this actor's rank across the process group and print it.

        ``all_reduce`` defaults to SUM, so each actor should print the sum of
        all participating ranks (``sum(range(world_size))`` if ranks are
        0..world_size-1).
        """
        torch.cuda.set_device(self.local_rank)
        self._rprint("Running basic DDP example")
        self._rprint(f"{torch.cuda.device_count()=}")
        self._rprint(f"{torch.cuda.current_device()=}")
        # Report the GPU this actor actually uses, not always GPU 0.
        self._rprint(f"{torch.cuda.get_device_name(self.local_rank)=}")
        self._rprint(f"{torch.cuda.is_initialized()=}")
        t = self.rank * torch.ones(1, device=torch.device("cuda", self.local_rank))
        dist.all_reduce(t)
        self._rprint(f"{t=}")
        self._rprint("Finished running basic DDP example")


async def main(num_hosts: int = 2, gpus_per_host: int = 8):
    """Run the BarrierActor demo on a Slurm-backed proc mesh, then kill the job.

    Args:
        num_hosts: Number of hosts requested via ``get_appdef``.
        gpus_per_host: GPU count written to the first role's resource spec.

    Once ``server_info`` has been obtained, the ``slurm:///`` handle is
    passed to ``commands.kill`` in ``finally``. This happens whether or not
    the demo succeeds.
    """
    appdef = await get_appdef(num_hosts)
    appdef.roles[0].resource.gpu = gpus_per_host

    server_info = await get_server_info(appdef)

    try:
        print("CREATE PROC MESH")
        proc_mesh = await create_proc_mesh(num_hosts, appdef, server_info)

        # Forward remote stdout/stderr to this client without aggregation.
        await proc_mesh.logging_option(
            stream_to_client=True,
            aggregate_window_sec=None,
        )

        print("SPAWN ACTORS")
        barrier_actor = proc_mesh.spawn("barrier_actor", BarrierActor)
        print("SETUP ENV")
        # Populates the distributed env vars read by BarrierActor.setup.
        await setup_env_for_distributed(proc_mesh)
        print("SETUP CALL")
        await barrier_actor.setup.call()
        print("BASIC DEMO CALL")
        await barrier_actor.demo_basic.call()
        print("CLEANUP CALL")
        await barrier_actor.cleanup.call()

        print("DDP example completed successfully!")

    finally:
        commands.kill(f"slurm:///{server_info.name}")

if __name__ == "__main__":
    # Top-level ``await`` only works where the REPL supports it (IPython /
    # Jupyter autoawait, or ``python -m asyncio``). Run as a plain
    # ``python script.py``, this line is a SyntaxError and
    # ``asyncio.run(main())`` would be needed instead.
    await main()

SyntaxError: expected argument value expression (1532023262.py, line 79)