<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/Ring_allreduce.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# https://www.bilibili.com/video/BV1mm42137X8/?spm_id_from=333.788.top_right_bar_window_history.content.click&vd_source=1fecee762931e992c96e5e166be13b76

In [None]:
# https://gemini.google.com/app/bcd426a78e4d0c59

In [None]:
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os
import time

def _do_ring_allreduce(rank, world_size, tensor):
    """
    Performs a conceptual ring allreduce operation on a given tensor.
    This function manually simulates the scatter-reduce and all-gather phases
    using point-to-point communication.
    """
    # Ensure the tensor is a float type for summation
    tensor = tensor.float()

    # Divide the tensor into chunks for the ring allreduce
    # For simplicity, we assume the tensor size is divisible by world_size
    chunk_size = tensor.numel() // world_size
    # Reshape the tensor into chunks for easier indexing
    chunks = [tensor.view(-1)[i * chunk_size:(i + 1) * chunk_size] for i in range(world_size)]

    # Calculate left and right neighbors in the ring
    left_neighbor = (rank - 1 + world_size) % world_size
    right_neighbor = (rank + 1) % world_size

    print(f"Rank {rank}: Initial tensor: {tensor.tolist()}")
    print(f"Rank {rank}: Initial chunks: {[c.tolist() for c in chunks]}")
    print(f"Rank {rank}: Left neighbor: {left_neighbor}, Right neighbor: {right_neighbor}")

    # --- Phase 1: Scatter-Reduce ---
    # Each process sends a chunk to its right neighbor and receives from its left.
    # It accumulates the received chunk into its own corresponding chunk.
    print(f"\nRank {rank}: Starting Scatter-Reduce Phase...")
    for i in range(world_size - 1):
        # The chunk to send is (rank - i + world_size) % world_size
        # The chunk to receive is (rank - i - 1 + world_size) % world_size
        send_chunk_idx = (rank - i + world_size) % world_size
        recv_chunk_idx = (rank - i - 1 + world_size) % world_size

        # Create a buffer for receiving
        recv_buffer = torch.empty_like(chunks[recv_chunk_idx])

        # Send to right neighbor, receive from left neighbor
        # Use non-blocking communication (isend/irecv) for better overlap
        send_req = dist.isend(chunks[send_chunk_idx], dst=right_neighbor)
        recv_req = dist.irecv(recv_buffer, src=left_neighbor)

        # Wait for communication to complete
        send_req.wait()
        recv_req.wait()

        # Accumulate the received data into the correct local chunk
        chunks[recv_chunk_idx] += recv_buffer
        print(f"Rank {rank} (iter {i+1}): Received chunk {recv_chunk_idx} from {left_neighbor}, current chunks: {[c.tolist() for c in chunks]}")
        time.sleep(0.1) # Small delay for clearer output in Colab

    print(f"Rank {rank}: Scatter-Reduce Phase Complete. Chunks: {[c.tolist() for c in chunks]}")

    # --- Phase 2: All-Gather ---
    # Each process sends its accumulated chunk to its right neighbor and receives
    # a new chunk from its left neighbor, effectively gathering all the final
    # reduced chunks.
    print(f"\nRank {rank}: Starting All-Gather Phase...")
    for i in range(world_size - 1):
        # The chunk to send is (rank - world_size + 1 + i + world_size) % world_size
        # The chunk to receive is (rank - world_size + i + world_size) % world_size
        send_chunk_idx = (rank - (world_size - 1) + i + world_size) % world_size
        recv_chunk_idx = (rank - (world_size - 1) + i - 1 + world_size) % world_size

        # Create a buffer for receiving
        recv_buffer = torch.empty_like(chunks[recv_chunk_idx])

        # Send to right neighbor, receive from left neighbor
        send_req = dist.isend(chunks[send_chunk_idx], dst=right_neighbor)
        recv_req = dist.irecv(recv_buffer, src=left_neighbor)

        # Wait for communication to complete
        send_req.wait()
        recv_req.wait()

        # Overwrite the received data into the correct local chunk (no accumulation)
        chunks[recv_chunk_idx] = recv_buffer
        print(f"Rank {rank} (iter {i+1}): Received chunk {recv_chunk_idx} from {left_neighbor}, current chunks: {[c.tolist() for c in chunks]}")
        time.sleep(0.1) # Small delay for clearer output in Colab

    print(f"Rank {rank}: All-Gather Phase Complete. Chunks: {[c.tolist() for c in chunks]}")

    # Reconstruct the final tensor from the gathered chunks
    final_tensor = torch.cat(chunks).view_as(tensor)
    return final_tensor

def run_process(rank, world_size):
    """
    Function to be executed by each process.
    Initializes the distributed environment and performs the ring allreduce.
    """
    # Set environment variables for distributed communication
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500' # Use a consistent, available port

    # Initialize the process group
    # 'gloo' backend is suitable for CPU-based communication on a single machine.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    print(f"Rank {rank} (out of {world_size}) process group initialized.")

    # Create an initial tensor for this rank
    # Each element of the tensor is initialized with the rank ID
    initial_data = torch.full((world_size * 2,), float(rank)) # Example: tensor of size 8 for world_size=4
    print(f"Rank {rank}: Initial data for allreduce: {initial_data.tolist()}")

    # Perform the custom ring allreduce
    final_result = _do_ring_allreduce(rank, world_size, initial_data)

    print(f"\nRank {rank}: Final tensor after custom ring allreduce: {final_result.tolist()}")

    # Verify the result against PyTorch's built-in all_reduce for comparison
    # (This part is just for verification, not part of the custom ring allreduce itself)
    torch_built_in_tensor = initial_data.clone()
    dist.all_reduce(torch_built_in_tensor, op=dist.ReduceOp.SUM)
    print(f"Rank {rank}: Result from PyTorch's built-in all_reduce: {torch_built_in_tensor.tolist()}")

    # Clean up the process group
    dist.destroy_process_group()
    print(f"Rank {rank}: Process group destroyed.")

def main():
    """
    Main function to spawn multiple processes for the distributed example.
    """
    world_size = 4 # Number of simulated processes/nodes
    print(f"Spawning {world_size} processes for the conceptual ring allreduce example...")
    # mp.spawn is used to launch multiple processes, each running 'run_process'
    mp.spawn(run_process,
             args=(world_size,),
             nprocs=world_size,
             join=True) # 'join=True' makes the main process wait for all spawned processes to finish
    print("\nAll processes have completed the ring allreduce simulation.")

if __name__ == "__main__":
    # This block ensures that the multiprocessing starts correctly on Windows
    # and in environments like Colab.
    mp.set_start_method('spawn', force=True)
    main()


Spawning 4 processes for the conceptual ring allreduce example...


W0604 04:46:49.151000 344 torch/multiprocessing/spawn.py:169] Terminating process 459 via signal SIGTERM


ProcessExitedException: process 2 terminated with exit code 1