# Task 3: Implementing AllReduce with Ray

In this task, you will use the point-to-point (P2P) communication APIs in [ray.util.collective](https://docs.ray.io/en/latest/ray-more-libs/ray-collective.html) to implement the AllReduce collective communication operation. Specifically, we have provided you with a template `Worker` class. This class is a Ray Actor: each instance runs in its own actor process, which holds that instance's state and executes its method calls. You need to complete the methods of this class so that the actor processes can perform AllReduce among themselves.

You need to implement:
1. Simple P2P communication: `do_send`, `do_recv`, and `do_send_recv`
2. Ray AllReduce: `ray_all_reduce`, a simple AllReduce that calls Ray's built-in `allreduce`
3. BDE AllReduce: `bde_all_reduce`. This function should implement the BDE (bidirectional exchange) version of AllReduce. The reduce operation sums the messages of all processes.
4. MST AllReduce: `mst_all_reduce`. This function should implement the MST (minimum-spanning tree) version of AllReduce. The reduce operation sums the messages of all processes. You need to build this MST AllReduce from Reduce and Broadcast operations.

For MST and BDE AllReduce, you may only use the P2P communication functions (`do_send`, etc.) and any helper methods you write yourself.

We have provided profiling functions so that you can compare the performance of these implementations.

In [1]:
import ray, torch, math, os, time, itertools
import numpy as np
import ray.util.collective as col
from ray.util.collective import types
import types as t

os.environ["PYTHONWARNINGS"]="ignore::DeprecationWarning"
world_size = 8   # change this to a smaller number if you need to debug
group_name = "dsc_204a"
backend = "gloo"
ray.init() 

2024-03-04 23:45:56,307	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2024-03-04 23:45:58,491	INFO worker.py:1540 -- Connecting to existing Ray cluster at address: 10.47.192.23:6380...
2024-03-04 23:45:58,533	INFO worker.py:1715 -- Connected to Ray cluster. View the dashboard at http://10.47.192.23:8265


Python version: 3.8.18
Ray version: 2.9.3
Dashboard: http://10.47.192.23:8265


In [2]:
@ray.remote
class Worker:
    def __init__(self, world_size, rank, group_name, backend=backend):
        # please initialize the collective group using the following function:
        # ray.util.collective.collective.init_collective_group
        # (https://docs.ray.io/en/latest/ray-more-libs/ray-collective.html#ray.util.collective.collective.init_collective_group)
        #
        # In addition, please add any additional class attributes you find helpful
        
        # YOUR CODE HERE
        col.collective.init_collective_group(world_size, rank, backend, group_name)
        self.rank = rank
        self.world_size = world_size
        self.comm_log = []

    def get_msg(self):
        if hasattr(self, 'msg'):
            return self.msg
        else:
            return None
    
    def get_buf(self):
        if hasattr(self, 'buf'):
            return self.buf
        else:
            return None

    def set_msg(self, msg):
        self.msg = msg
        return True
    
    def set_buf(self, shape, dtype):
        self.buf = torch.zeros(shape, dtype=dtype)
        return True

    def get_comm_log(self):
        return self.comm_log
    
    def empty_log(self):
        self.comm_log = []

    # This is a wrapper function for Ray P2P send (https://docs.ray.io/en/latest/ray-more-libs/ray-collective.html#point-to-point-communication)
    # This function calls Ray P2P send with self.msg as the message to be sent
    # target_rank: the rank of the destination process
    def do_send(self, target_rank):
        # YOUR CODE HERE
        msg = self.get_msg()
        col.send(msg, target_rank, group_name)
        self.comm_log.append(["send", self.rank, target_rank])
        return self.msg

    # This is a wrapper function for Ray P2P recv (https://docs.ray.io/en/latest/ray-more-libs/ray-collective.html#point-to-point-communication)
    # This function calls Ray P2P recv with self.buf as the buffer for incoming messages
    # src_rank: the rank of the sender process
    def do_recv(self, src_rank):
        # YOUR CODE HERE
        col.recv(self.buf, src_rank, group_name)
        self.comm_log.append(["recv", self.rank, src_rank])
        return self.buf
    
    # This function sends the self.msg from the sender to the receiver
    # Please implement this function using self.do_send and self.do_recv as the building blocks
    # src_rank: the rank of the sender's process
    # target_rank: the rank of the destination process
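    def do_send_recv(self, src_rank, target_rank):
        print(f"*** sending from rank {src_rank} to rank {target_rank}***")
        # YOUR CODE HERE
        if self.rank == src_rank:
            # The sender pushes self.msg to the receiver
            self.do_send(target_rank)
            return self.msg
        elif self.rank == target_rank:
            # The receiver fills self.buf with the incoming message
            self.do_recv(src_rank)
            return self.buf
        # Ranks that are neither the sender nor the receiver take no part
        return None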
        
    # This function performs AllReduce using ray.util.collective.allreduce
    # (https://docs.ray.io/en/latest/ray-more-libs/ray-collective.html#point-to-point-communication)
    # The reduce stage adds up all the messages from the processes
    def ray_all_reduce(self, op=types.ReduceOp.SUM):
        # YOUR CODE HERE
        col.allreduce(self.msg, op=op, group_name=group_name)
        return self.msg

    # This function implements the BDE (bidirectional exchange) version of AllReduce
    # The reduce operation will perform addition over all the messages of the processes
    # The recursive halving below assumes world_size is a power of two
    def bde_all_reduce(self, op=types.ReduceOp.SUM):
        # YOUR CODE HERE
        def recursive_bde_all_reduce(left, right):
            if left == right:
                return

            size = right - left + 1
            mid = (left + right) // 2

            # Pair each rank with its mirror in the other half of [left, right]
            if self.rank <= mid:
                partner = self.rank + size // 2
            else:
                partner = self.rank - size // 2

            # The lower half sends first and the upper half receives first,
            # so each send is matched by the partner's recv
            if self.rank <= mid:
                self.do_send(partner)
                recv_msg = self.do_recv(partner)
            else:
                recv_msg = self.do_recv(partner)
                self.do_send(partner)
            self.msg += recv_msg

            if self.rank <= mid:
                recursive_bde_all_reduce(left, mid)
            else:
                recursive_bde_all_reduce(mid + 1, right)

        recursive_bde_all_reduce(0, self.world_size - 1)
        return self.msg

    # This function implements the MST (minimum-spanning tree) version of AllReduce
    # The reduce operation will perform addition over all the messages of the processes
    # You need to implement this MST AllReduce using reduce and broadcast as the building blocks
    def mst_all_reduce(self, op=types.ReduceOp.SUM):
        # YOUR CODE HERE
        def broadcast(root, left, right):
            if left == right:
                return
            
            mid = (left+right)//2
            if root <= mid:
                dest = right
            else:
                dest = left
            
            if self.rank == root:
                self.do_send(dest)
            elif self.rank == dest:
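                # Copy the received data: self.buf is overwritten by later recvs,
                # so self.msg must not alias it across repeated calls
                self.set_msg(self.do_recv(root).clone())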
                           
            if self.rank <= mid and root <= mid:
                broadcast(root, left, mid)
            elif self.rank <= mid and root > mid:
                broadcast(dest, left, mid)
            elif self.rank > mid and root <= mid:
                broadcast(dest, mid+1, right)
            elif self.rank > mid and root > mid:
                broadcast(root, mid+1, right)
            
        def MSTreduce(root, left, right):
            if left == right:
                return
            
            mid = (left+right)//2
            if root <= mid:
                srce = right
            else:
                srce = left
                
            if self.rank <= mid and root <= mid:
                MSTreduce(root, left, mid)
            elif self.rank <= mid and root > mid:
                MSTreduce(srce, left, mid)
            elif self.rank > mid and root <= mid:
                MSTreduce(srce, mid+1, right)
            elif self.rank > mid and root > mid:
                MSTreduce(root, mid+1, right)
                
            if self.rank == srce:
                self.do_send(root)
            elif self.rank == root:
                self.msg += self.do_recv(srce)
        
        MSTreduce(0, 0, self.world_size - 1)
        broadcast(0, 0, self.world_size - 1)
        
        return self.msg

    # YOUR CODE HERE


In [3]:
# Please initialize the workers here and
# declare the collective group using the following function:
# ray.util.collective.collective.create_collective_group 
# (https://docs.ray.io/en/latest/ray-more-libs/ray-collective.html#ray.util.collective.collective.GroupManager.create_collective_group)
# The values for world_size, backend, and group_name are provided at the top of this file
ray.shutdown()
ray.init()

workers = []

for rank in range(world_size):
    workers.append(Worker.remote(world_size=world_size, rank=rank, backend=backend, group_name=group_name))
    
col.collective.create_collective_group(actors=workers,
                                       world_size=world_size,
                                       ranks=list(range(world_size)),
                                       backend=backend,
                                       group_name=group_name)

#
# After this initialization stage, `workers` holds the Ray actor handles
# for the Workers initialized above.
# You will need this list to perform the following collective communication tasks
# YOUR CODE HERE

2024-03-04 23:45:58,669	INFO worker.py:1540 -- Connecting to existing Ray cluster at address: 10.47.192.23:6380...
2024-03-04 23:45:58,676	INFO worker.py:1715 -- Connected to Ray cluster. View the dashboard at http://10.47.192.23:8265


(raylet) It looks like you're creating a detached actor in an anonymous namespace. In order to access this actor in the future, you will need to explicitly connect to this namespace with ray.init(namespace="8234491d-4aa1-4e6e-8ddc-7bffdb4c0058", ...)


(pid=29645, ip=10.47.192.24) NCCL seems unavailable. Please install Cupy following the guide at: https://docs.cupy.dev/en/stable/install.html.


# P2P Communication
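In `do_send_recv`, the role each rank plays depends only on whether it is the sender, the receiver, or neither. The pure-Python sketch below (no Ray) illustrates that dispatch rule with a hypothetical in-memory `mailbox` standing in for `col.send`/`col.recv`; the helper names `p2p_role` and `simulate_send_recv` are illustrative only. The point it shows is that bystander ranks must not post a send or a receive.

```python
from collections import defaultdict, deque


def p2p_role(rank, src_rank, target_rank):
    """Return 'send', 'recv', or None for a rank taking part in one transfer."""
    if rank == src_rank:
        return "send"
    if rank == target_rank:
        return "recv"
    return None  # bystander ranks must not touch the channel


def simulate_send_recv(world_size, src_rank, target_rank, msg):
    """Simulate one transfer; `mailbox` is a hypothetical stand-in for the transport."""
    mailbox = defaultdict(deque)  # (src, dst) -> queued messages
    bufs = {r: None for r in range(world_size)}
    roles = {r: p2p_role(r, src_rank, target_rank) for r in range(world_size)}
    # Senders post first, then receivers drain their incoming channel.
    for r, role in roles.items():
        if role == "send":
            mailbox[(r, target_rank)].append(list(msg))
    for r, role in roles.items():
        if role == "recv":
            bufs[r] = mailbox[(src_rank, r)].popleft()
    return roles, bufs


roles, bufs = simulate_send_recv(4, src_rank=0, target_rank=1, msg=[1.0] * 3)
print(roles)    # {0: 'send', 1: 'recv', 2: None, 3: None}
print(bufs[1])  # [1.0, 1.0, 1.0]
```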

In [4]:
def profiling_p2p(workers, size, num_trials, dtype=torch.float32):
    print(f"***** Start profiling p2p *****")
    GB = 1024**3
    MB = 1024**2
    src = 0
    target = 1
    workers = [workers[src], workers[target]]

    msg = torch.ones(1, int(size), dtype=dtype)
    msg_ref = ray.put(msg)

    for w in workers:
        w.set_msg.remote(msg_ref)
        w.set_buf.remote(msg.shape, msg.dtype)
    
    tic = time.time()
    for i in range(num_trials):
        results = ray.get([workers[0].do_send_recv.remote(src, target),
                           workers[1].do_send_recv.remote(src, target)])
    toc = time.time()
    
    time_cost_per_trial = (toc - tic) / num_trials
    msg_size = size * torch.finfo(dtype).bits  # msg size in bits
    comm_size = msg_size
    bandwidth = comm_size / time_cost_per_trial
    print(f"SendRecv: {[src, target]}\tSize: {comm_size / 8 / MB:.5f} MB\t"
          f"Avg time per trial: {time_cost_per_trial:.5f}s\tBandwidth: {bandwidth / 8 / MB:.2f} MB/s")

    print(f"***** Completed profiling p2p *****")
    return bandwidth, time_cost_per_trial, msg_size, comm_size

In [5]:
def test_p2p(workers):
    print(f"***** Start testing p2p *****")
    src = 0
    target = 1
    workers = [workers[src], workers[target]]
    msg_len = 20
    msg = torch.ones(1, int(msg_len))
    msg_ref = ray.put(msg)

    for w in workers:
        w.empty_log.remote()
        w.set_msg.remote(msg_ref)
        w.set_buf.remote(msg.shape, msg.dtype)
    
    results = ray.get([workers[0].do_send_recv.remote(src, target),
                       workers[1].do_send_recv.remote(src, target)])

    assert(torch.eq(results[0], msg).sum() == msg_len)
    assert(torch.eq(results[1], msg).sum() == msg_len)

    print(f"***** p2p test passed *****")

In [6]:
# size: the number of values we send over the connection; this is a way to 
#       control the communication volume
# num_trials: the number of communication trials to be run during profiling so
#             that an average number can be computed
profiling_p2p(workers, size=1<<23, num_trials=10)

***** Start profiling p2p *****


(Worker pid=29646, ip=10.47.192.24) Use get_job_id() instead
(Worker pid=29646, ip=10.47.192.24)   self._job_id = ray.get_runtime_context().job_id
(pid=17860, ip=10.46.128.21) NCCL seems unavailable. Please install Cupy following the guide at: https://docs.cupy.dev/en/stable/install.html.


(Worker pid=29644, ip=10.47.192.24) *** sending from rank 0 to rank 1***
SendRecv: [0, 1]	Size: 32.00000 MB	Avg time per trial: 0.73848s	Bandwidth: 43.33 MB/s
***** Completed profiling p2p *****


(363496756.71134037, 0.7384810209274292, 268435456, 268435456)
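As a sanity check on the profiler's arithmetic, the snippet below recomputes the P2P figures from the run above without Ray. It assumes float32 values (32 bits each, which is what `torch.finfo(torch.float32).bits` returns) and plugs in the measured per-trial time. Note that the profiler keeps sizes in bits internally and converts to MB only when printing.

```python
# Recompute the numbers printed by profiling_p2p for size=1<<23 float32 values.
MB = 1024**2

size = 1 << 23                        # number of float32 values sent
bits_per_value = 32                   # bits per float32 value
msg_size_bits = size * bits_per_value
time_per_trial = 0.7384810209274292   # measured value from the run above

bandwidth_bits_per_s = msg_size_bits / time_per_trial
print(msg_size_bits)                                # 268435456
print(f"{msg_size_bits / 8 / MB:.5f} MB")           # 32.00000 MB
print(f"{bandwidth_bits_per_s / 8 / MB:.2f} MB/s")  # 43.33 MB/s
```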

In [7]:
# test p2p
test_p2p(workers)

***** Start testing p2p *****
***** p2p test passed *****
(Worker pid=17860, ip=10.46.128.21) *** sending from rank 0 to rank 1***

# Ray AllReduce
Profiling Ray's built-in `allreduce` gives a useful baseline for the custom implementations below.

In [8]:
def profile_ray_all_reduce(workers, size, num_trials, dtype=torch.float32):
    print(f"***** Start profiling ray AllReduce *****")
    GB = 1024**3
    MB = 1024**2
    src = 0
    target = 1

    msg = torch.ones(1, int(size), dtype=dtype)
    msg_ref = ray.put(msg)

    for w in workers:
        w.set_msg.remote(msg_ref)
    
    tic = time.time()
    for i in range(num_trials):
        results = ray.get([w.ray_all_reduce.remote() for w in workers])
    toc = time.time()


    time_cost_per_trial = (toc - tic) / num_trials
    msg_size = size * torch.finfo(dtype).bits  # msg size in bits
    comm_size = 2 * msg_size * (len(workers) - 1) / (len(workers))
    bandwidth = comm_size / time_cost_per_trial
    print(f"SendRecv: {[src, target]}\tSize: {comm_size / 8 / MB:.5f} MB\t"
          f"Avg time per trial: {time_cost_per_trial:.5f}s\tBandwidth: {bandwidth / 8 / MB:.2f} MB/s")

    print(f"***** Completed profiling ray AllReduce *****")
    return bandwidth, time_cost_per_trial, msg_size, comm_size

def test_ray_all_reduce(workers): 
    print(f"***** Start testing ray_all_reduce *****")
    msg_len = 20
    msg = torch.ones(1, int(msg_len))
    msg_ref = ray.put(msg)

    for w in workers:
        w.set_msg.remote(msg_ref)
    
    results = ray.get([w.ray_all_reduce.remote() for w in workers])

    for r in results:
        assert(torch.eq(r, torch.tensor([len(workers)]*msg_len, dtype=msg.dtype)).sum() == msg_len)

    print(f"***** ray_all_reduce test passed *****")




In [9]:
# size: the number of values we send over the connection; this is a way to 
#       control the communication volume
# num_trials: the number of communication trials to be run during profiling so
#             that an average number can be computed
profile_ray_all_reduce(workers, size=1<<23, num_trials=10)

***** Start profiling ray AllReduce *****
(Worker pid=29644, ip=10.47.192.24) *** sending from rank 0 to rank 1***
SendRecv: [0, 1]	Size: 56.00000 MB	Avg time per trial: 1.06062s	Bandwidth: 52.80 MB/s
***** Completed profiling ray AllReduce *****


(442914514.0747709, 1.060615611076355, 268435456, 469762048.0)
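The AllReduce profilers do not count bytes on the wire. They use the formula 2(p-1)/p · n per process, which is the volume moved by bandwidth-optimal AllReduce algorithms such as reduce-scatter followed by allgather. The sketch below (the helper name `allreduce_comm_bits` is illustrative) reproduces the 56 MB figure for 2^23 float32 values and p = 8. Because the same formula is used for the Ray, BDE, and MST profilers, the reported "bandwidth" is this constant divided by the time per trial, so comparing bandwidths is equivalent to comparing times.

```python
MB = 1024**2


def allreduce_comm_bits(num_values, world_size, bits_per_value=32):
    """Per-process volume 2 * (p - 1) / p * n, as computed by the AllReduce profilers."""
    msg_bits = num_values * bits_per_value
    return 2 * msg_bits * (world_size - 1) / world_size


comm_bits = allreduce_comm_bits(1 << 23, world_size=8)
print(comm_bits)                       # 469762048.0
print(f"{comm_bits / 8 / MB:.5f} MB")  # 56.00000 MB

# Bandwidth for the Ray AllReduce run above (1.060615611076355 s per trial).
print(f"{comm_bits / 1.060615611076355 / 8 / MB:.2f} MB/s")  # 52.80 MB/s
```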

In [10]:
test_ray_all_reduce(workers)

***** Start testing ray_all_reduce *****
***** ray_all_reduce test passed *****


# BDE AllReduce
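To see why `bde_all_reduce` leaves the full sum on every rank, the pure-Python sketch below replays its pairing rule without Ray: at each level of the recursion the range [left, right] is split in half and rank r pairs with r ± size/2, so after log2(p) rounds every rank has combined the contributions of all p ranks. It assumes a power-of-two world size, as the Worker implementation implicitly does (for other sizes the pairs stop matching up), and it models each exchange as simultaneous, which the Ray version does not do. The helper names are hypothetical.

```python
def bde_partners(rank, world_size):
    """Partners of `rank`, round by round, mirroring recursive_bde_all_reduce."""
    assert world_size & (world_size - 1) == 0, "assumes a power-of-two world size"
    left, right, partners = 0, world_size - 1, []
    while left < right:
        size = right - left + 1
        mid = (left + right) // 2
        if rank <= mid:
            partners.append(rank + size // 2)
            right = mid
        else:
            partners.append(rank - size // 2)
            left = mid + 1
    return partners


def simulate_bde(values):
    """Simulate BDE AllReduce (sum) on one scalar per rank; return the final values."""
    p = len(values)
    vals = list(values)
    schedules = [bde_partners(r, p) for r in range(p)]
    for step in range(len(schedules[0])):
        # Every pair exchanges its current partial sum and adds the partner's.
        vals = [vals[r] + vals[schedules[r][step]] for r in range(p)]
    return vals


print(bde_partners(0, 8))      # [4, 2, 1]
print(simulate_bde(range(8)))  # [28, 28, 28, 28, 28, 28, 28, 28]
```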

In [11]:
def profile_bde_all_reduce(workers, size, num_trials, dtype=torch.float32):
    print(f"***** Start profiling bde AllReduce *****")
    GB = 1024**3
    MB = 1024**2
    src = 0
    target = 1

    msg = torch.ones(1, int(size), dtype=dtype)
    msg_ref = ray.put(msg)

    for w in workers:
        w.empty_log.remote()
        w.set_msg.remote(msg_ref)
        w.set_buf.remote(msg.shape, msg.dtype)
    
    tic = time.time()
    for i in range(num_trials):
        results = ray.get([w.bde_all_reduce.remote() for w in workers])
    toc = time.time()


    time_cost_per_trial = (toc - tic) / num_trials
    msg_size = size * torch.finfo(dtype).bits  # msg size in bits
    comm_size = 2 * msg_size * (len(workers) - 1) / (len(workers))
    bandwidth = comm_size / time_cost_per_trial
    print(f"SendRecv: {[src, target]}\tSize: {comm_size / 8 / MB:.5f} MB\t"
          f"Avg time per trial: {time_cost_per_trial:.5f}s\tBandwidth: {bandwidth / 8 / MB:.2f} MB/s")

    print(f"***** Completed profiling bde AllReduce *****")
    return bandwidth, time_cost_per_trial, msg_size, comm_size

In [12]:
def test_bde_all_reduce(workers): 
    print(f"***** Start testing bde_all_reduce *****")
    msg_len = 5
    msg = torch.ones(1, int(msg_len))
    msg_ref = ray.put(msg)

    for w in workers:
        w.empty_log.remote()
        w.set_msg.remote(msg_ref)
        w.set_buf.remote(msg.shape, msg.dtype)
    

    results = ray.get([w.bde_all_reduce.remote() for w in workers])

    for r in results:
        assert(torch.eq(r, torch.tensor([len(workers)]*msg_len, dtype=msg.dtype)).sum() == msg_len)

    print(f"***** bde_all_reduce test passed *****")

In [13]:
# size: the number of values we send over the connection; this is a way to 
#       control the communication volume
# num_trials: the number of communication trials to be run during profiling so
#             that an average number can be computed
profile_bde_all_reduce(workers, size=1<<23, num_trials=10)

***** Start profiling bde AllReduce *****
SendRecv: [0, 1]	Size: 56.00000 MB	Avg time per trial: 1.22262s	Bandwidth: 45.80 MB/s
***** Completed profiling bde AllReduce *****


(384225295.69054383, 1.2226213455200194, 268435456, 469762048.0)

In [14]:
test_bde_all_reduce(workers)

***** Start testing bde_all_reduce *****
***** bde_all_reduce test passed *****


# MST AllReduce
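The MST variant composes a reduce onto rank 0 with a broadcast from rank 0, both built by recursive bisection of the rank range [left, right]: the half that does not contain the current root uses one of its endpoints (`right` or `left`) as its local root. The pure-Python sketch below replays that recursion on one scalar per rank (no Ray), logs each simulated message, and checks that every rank ends with the total. The function names are hypothetical and mirror `MSTreduce` and `broadcast` in the Worker; unlike the BDE pairing, this bisection does not need a power-of-two world size.

```python
def mst_reduce(vals, root, left, right, log):
    """Recursive-bisection reduce onto `root`, mirroring MSTreduce in the Worker."""
    if left == right:
        return
    mid = (left + right) // 2
    other = right if root <= mid else left  # local root of the half without `root`
    if root <= mid:
        mst_reduce(vals, root, left, mid, log)
        mst_reduce(vals, other, mid + 1, right, log)
    else:
        mst_reduce(vals, other, left, mid, log)
        mst_reduce(vals, root, mid + 1, right, log)
    vals[root] += vals[other]  # the other half's partial sum is sent to root
    log.append(("reduce", other, root))


def mst_broadcast(vals, root, left, right, log):
    """Recursive-bisection broadcast from `root`, mirroring broadcast in the Worker."""
    if left == right:
        return
    mid = (left + right) // 2
    dest = right if root <= mid else left
    vals[dest] = vals[root]  # root sends its value into the other half
    log.append(("bcast", root, dest))
    if root <= mid:
        mst_broadcast(vals, root, left, mid, log)
        mst_broadcast(vals, dest, mid + 1, right, log)
    else:
        mst_broadcast(vals, dest, left, mid, log)
        mst_broadcast(vals, root, mid + 1, right, log)


def simulate_mst_allreduce(values):
    vals, log = list(values), []
    mst_reduce(vals, 0, 0, len(vals) - 1, log)
    mst_broadcast(vals, 0, 0, len(vals) - 1, log)
    return vals, log


vals, log = simulate_mst_allreduce(range(8))
print(vals)                                       # [28, 28, 28, 28, 28, 28, 28, 28]
print(sum(1 for op in log if op[0] == "reduce"))  # 7
```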

In [15]:
def profile_mst_all_reduce(workers, size, num_trials, dtype=torch.float32):
    print(f"***** Start profiling mst AllReduce *****")
    GB = 1024**3
    MB = 1024**2
    src = 0
    target = 1

    msg = torch.ones(1, int(size), dtype=dtype)
    msg_ref = ray.put(msg)

    for w in workers:
        w.empty_log.remote()
        w.set_msg.remote(msg_ref)
        w.set_buf.remote(msg.shape, msg.dtype)
    
    tic = time.time()
    for i in range(num_trials):
        results = ray.get([w.mst_all_reduce.remote() for w in workers])
    toc = time.time()


    time_cost_per_trial = (toc - tic) / num_trials
    msg_size = size * torch.finfo(dtype).bits  # msg size in bits
    comm_size = 2 * msg_size * (len(workers) - 1) / (len(workers))
    bandwidth = comm_size / time_cost_per_trial
    print(f"SendRecv: {[src, target]}\tSize: {comm_size / 8 / MB:.5f} MB\t"
          f"Avg time per trial: {time_cost_per_trial:.5f}s\tBandwidth: {bandwidth / 8 / MB:.2f} MB/s")

    print(f"***** Completed profiling mst AllReduce *****")
    return bandwidth, time_cost_per_trial, msg_size, comm_size

In [16]:
def test_mst_all_reduce(workers): 
    print(f"***** Start testing mst_all_reduce *****")
    msg_len = 5
    msg = torch.ones(1, int(msg_len))
    msg_ref = ray.put(msg)

    for w in workers:
        w.set_msg.remote(msg_ref)
        w.set_buf.remote(msg.shape, msg.dtype)

    results = ray.get([w.mst_all_reduce.remote() for w in workers])

    for r in results:
        assert(torch.eq(r, torch.tensor([len(workers)]*msg_len, dtype=msg.dtype)).sum() == msg_len)

    print(f"***** mst_all_reduce test passed *****")

In [17]:
# size: the number of values we send over the connection; this is a way to 
#       control the communication volume
# num_trials: the number of communication trials to be run during profiling so
#             that an average number can be computed
profile_mst_all_reduce(workers, size=1<<23, num_trials=10)

***** Start profiling mst AllReduce *****


(raylet, ip=10.47.192.24) [2024-03-04 23:46:44,884 E 102 102] (raylet) node_manager.cc:3024: 1 Workers (tasks / actors) killed due to memory pressure (OOM), 0 Workers crashed due to other reasons at node (ID: 5644f0283c3734021b36913ef36a7035513f86bdf72c39b21c09188a, IP: 10.47.192.24) over the last time period. To see more information about the Workers killed on this node, use `ray logs raylet.out -ip 10.47.192.24`
(raylet, ip=10.47.192.24)
(raylet, ip=10.47.192.24) Refer to the documentation on how to address the out of memory issue: https://docs.ray.io/en/latest/ray-core/scheduling/ray-oom-prevention.html. Consider provisioning more memory on this node or reducing task parallelism by requesting more CPUs per task. To adjust the kill threshold, set the environment variable `RAY_memory_usage_threshold` when starting Ray. To disable worker killing, set the environment variable `RAY_memory_monitor_refresh_ms` to zero.


SendRecv: [0, 1]	Size: 56.00000 MB	Avg time per trial: 1.30574s	Bandwidth: 42.89 MB/s
***** Completed profiling mst AllReduce *****


(359765999.47489977, 1.3057433128356934, 268435456, 469762048.0)

In [18]:
test_mst_all_reduce(workers)

***** Start testing mst_all_reduce *****
***** mst_all_reduce test passed *****


(raylet, ip=10.46.128.21) [2024-03-04 23:47:30,649 E 18308 18308] gcs_rpc_client.h:212: Failed to connect to GCS at address service-ray-cluster:6380 within 5 seconds.
(raylet, ip=10.47.192.24) [2024-03-04 23:47:44,886 E 102 102] (raylet) node_manager.cc:3024: 3 Workers (tasks / actors) killed due to memory pressure (OOM), 0 Workers crashed due to other reasons at node (ID: 5644f0283c3734021b36913ef36a7035513f86bdf72c39b21c09188a, IP: 10.47.192.24) over the last time period. To see more information about the Workers killed on this node, use `ray logs raylet.out -ip 10.47.192.24`
(raylet, ip=10.47.192.24)
(raylet, ip=10.47.192.24) Refer to the documentation on how to address the out of memory issue: https://docs.ray.io/en/latest/ray-core/scheduling/ray-oom-prevention.html. Consider provisioning more memory on this node or reducing task parallelism by requesting more CPUs per task. To adjust the kill threshold, set the environment variable `RAY_memory_

# Profiling Results
Report the profiler results for MST and BDE AllReduce below. Mention the message size and the number of trials used as well (in case they differ from the defaults). You might not observe a significant difference in bandwidth between the two algorithms.

YOUR ANSWER HERE
# Results for BDE:

SendRecv: [0, 1]	Size: 56.00000 MB	Avg time per trial: 1.16643s	Bandwidth: 48.01 MB/s

(402736315.2167556, 1.166425848007202, 268435456, 469762048.0)

# Results for MST:

SendRecv: [0, 1]	Size: 56.00000 MB	Avg time per trial: 0.96710s	Bandwidth: 57.90 MB/s

(485741579.21503687, 0.9671028137207032, 268435456, 469762048.0)

# Conclusion:

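Both runs above used messages of 2^23 float32 values (32 MB per worker). The 56 MB "Size" column is not the message size but the AllReduce communication volume 2(p-1)/p · n that the profiler computes for p = 8 workers. In this run, MST AllReduce was faster than BDE AllReduce (0.967 s vs. 1.166 s per trial, i.e. 57.90 vs. 48.01 MB/s). Because both profilers divide the same volume by the time per trial, the bandwidth column carries no information beyond the timing.

The gap should not be over-interpreted. In the profiling cells earlier in this notebook the ordering was reversed (BDE 1.223 s vs. MST 1.306 s per trial), and the raylet reported OOM worker kills during the MST run, so the difference appears to be within run-to-run variation on this shared cluster. One plausible reason the two algorithms land so close together is that the BDE implementation performs each pairwise exchange as a send followed by a receive rather than concurrently, so each of its log2(p) rounds costs roughly two one-way transfers, the same count as MST's reduce followed by broadcast. Determining which algorithm is better would require repeated runs across several message sizes and cluster placements.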
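A simple alpha-beta cost model (alpha = per-message latency in seconds, beta = seconds per byte, reduction compute ignored) makes the round-counting argument concrete. This is a sketch with assumed, hypothetical parameters, not a fit to the measurements above. It also assumes that Ray's gloo `send` blocks until the matching `recv` has completed, so that the BDE code in this notebook, where the lower rank of each pair sends before it receives, pays for two one-way transfers per round. The helper names `bde_time` and `mst_time` are illustrative.

```python
import math


def bde_time(p, n_bytes, alpha, beta, serialized=True):
    """BDE (recursive doubling) time: log2(p) rounds of full-message exchanges.

    serialized=True charges two one-way transfers per round, modelling a pair
    that sends and then receives instead of exchanging concurrently.
    """
    per_round = alpha + n_bytes * beta
    return math.log2(p) * per_round * (2 if serialized else 1)


def mst_time(p, n_bytes, alpha, beta):
    """MST AllReduce = MST reduce + MST broadcast, each ceil(log2 p) full-message steps."""
    return 2 * math.ceil(math.log2(p)) * (alpha + n_bytes * beta)


p, n = 8, (1 << 23) * 4          # 8 ranks, 2**23 float32 values (32 MiB)
alpha, beta = 1e-4, 1 / 50e6     # hypothetical: 100 us latency, 50 MB/s links
print(f"{mst_time(p, n, alpha, beta) / bde_time(p, n, alpha, beta):.1f}")  # 1.0
print(f"{mst_time(p, n, alpha, beta) / bde_time(p, n, alpha, beta, serialized=False):.1f}")  # 2.0
```

Under this model, a BDE exchange that overlapped the two directions (for example with non-blocking sends, if the backend supports them) would be expected to take about half the time of MST for large messages, whereas the serialized version used here is predicted to match it.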