# Ray Worker + Torch Distributed

Each Ray worker (actor) keeps the state of its own torch.distributed process group. We provide skeleton code that sets up the worker processes so that collective communication can be implemented on top of them.

In [None]:
import os

# Print all Ray logs (disable log de-duplication). This is set before
# `import ray` so the setting is already in place when Ray is loaded.
os.environ['RAY_DEDUP_LOGS'] = '0'

import ray
import torch
import torch.distributed as dist

# Start Ray (locally); ignore_reinit_error lets this cell be re-run safely.
ray.init(ignore_reinit_error=True)

In [None]:
@ray.remote
class Worker:
    """Ray actor acting as one rank of a torch.distributed process group.

    Each worker holds a one-element tensor initialised to ``rank + 1`` and
    can sum it across all ranks with ``all_reduce``.
    """

    def __init__(self, rank, world_size, master_addr='127.0.0.1', master_port=12355):
        """Record this worker's rank and export the rendezvous settings.

        Args:
            rank: Index of this worker within the group.
            world_size: Total number of workers in the group.
            master_addr: Address exported as ``MASTER_ADDR``.
            master_port: Port exported as ``MASTER_PORT``.
        """
        # Environment variables consumed by torch.distributed's
        # environment-based rendezvous when the process group is created.
        os.environ['RANK'] = str(rank)
        os.environ['WORLD_SIZE'] = str(world_size)
        os.environ['MASTER_ADDR'] = str(master_addr)
        os.environ['MASTER_PORT'] = str(master_port)

        self.rank = rank
        self.world_size = world_size
        self.tensor = torch.tensor([float(self.rank + 1)])
        print(f"Rank {self.rank} initialized.")

    def setup_torch_distributed(self):
        """Create the default "gloo" process group for this worker.

        Must be called on every worker before ``compute``. Calling it again
        once the group exists is a no-op instead of an error.
        """
        if dist.is_initialized():
            return
        dist.init_process_group(backend="gloo", rank=self.rank, world_size=self.world_size)

    def compute(self):
        """Sum this worker's tensor across all ranks, in place.

        Returns:
            The reduced value held by this worker after the all_reduce.
        """
        print(f"[Before all_reduce] Rank {self.rank} tensor = {self.tensor.item()}")
        dist.all_reduce(self.tensor, op=dist.ReduceOp.SUM)
        print(f"[After all_reduce] Rank {self.rank} tensor = {self.tensor.item()}")
        return self.tensor.item()

    def get_tensor(self):
        """Return the current value of this worker's tensor as a float."""
        return self.tensor.item()

    def shutdown(self):
        """Destroy the process group if one was created.

        Guarded so that cleanup does not raise when setup never ran
        (for example after an earlier cell failed).
        """
        if dist.is_initialized():
            dist.destroy_process_group()

# Launch 4 actors
world_size = 4
workers = [Worker.remote(rank=i, world_size=world_size) for i in range(world_size)]

# Initialize the process group on every actor before any collective call.
# All setup calls are submitted first and awaited together, because process
# group initialization needs every rank to take part.
ray.get([w.setup_torch_distributed.remote() for w in workers])

# Trigger all_reduce
ray.get([w.compute.remote() for w in workers])

# Confirm the tensors: each worker starts with rank + 1 and the reduction is
# a SUM, so every worker should report 1 + 2 + 3 + 4 = 10.0.
tensors = ray.get([w.get_tensor.remote() for w in workers])
print("Final tensors from all workers:", tensors)

In [None]:

# Cleanup: ask every actor to shut down, then wait for all of them to finish.
shutdown_refs = [worker.shutdown.remote() for worker in workers]
ray.get(shutdown_refs)