In [1]:
%%writefile train_script.py
import torch
from torch import nn
from torch import optim
import os
import torch.distributed as dist
def init_dist():
    """Initialise the default process group from torchrun environment variables.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` (as set by ``torchrun``)
    and chooses the backend from hardware availability: NCCL on
    ``cuda:<LOCAL_RANK>`` when CUDA is available, otherwise Gloo on the CPU.

    Returns:
        tuple: ``(rank, world_size, device)`` where ``device`` is the
        ``torch.device`` this process should place its tensors on.

    Raises:
        KeyError: if one of the torchrun environment variables is missing.
    """
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
        # Make this GPU the current device so NCCL and any implicit CUDA
        # allocations use it instead of defaulting to cuda:0.
        torch.cuda.set_device(device)
        dist.init_process_group(
            backend="nccl", rank=rank, world_size=world_size, device_id=device
        )
    else:
        device = torch.device("cpu")
        # `device_id` is meant for an indexed accelerator (some PyTorch releases
        # reject a CPU device with ValueError); Gloo needs no binding, so omit it.
        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    return rank, world_size, device



class PipelineComms:
    """Point-to-point helpers for a linear pipeline of ranks.

    Rank ``r`` receives activations from ``r - 1`` and sends them to ``r + 1``;
    gradients travel the opposite way. The first rank has no previous peer and
    the last rank has no next peer (both stored as ``None``); calling a helper
    that would need a missing peer raises ``RuntimeError``.

    Args:
        rank: this process's rank in the default process group.
        world_size: number of pipeline stages (processes).
    """

    def __init__(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.prev_rank = rank - 1 if rank > 0 else None
        self.next_rank = rank + 1 if rank < world_size - 1 else None

    @staticmethod
    def _peer(peer, direction):
        """Return ``peer`` or raise if this stage has no neighbour in ``direction``."""
        if peer is None:
            raise RuntimeError(f"this pipeline stage has no {direction} peer")
        return peer

    def send_forward(self, tensor):
        """Blocking send of an activation tensor to the next stage."""
        dist.send(tensor.contiguous(), dst=self._peer(self.next_rank, "next"))

    def recv_forward(self, shape, device, dtype=torch.float32):
        """Blocking receive of an activation tensor of ``shape`` from the previous stage."""
        # recv overwrites every element, so there is no need to zero-fill.
        tensor = torch.empty(shape, dtype=dtype, device=device)
        dist.recv(tensor, src=self._peer(self.prev_rank, "previous"))
        return tensor

    def send_backward(self, tensor):
        """Blocking send of a gradient tensor to the previous stage."""
        dist.send(tensor.contiguous(), dst=self._peer(self.prev_rank, "previous"))

    def recv_backward(self, shape, device, dtype=torch.float32):
        """Blocking receive of a gradient tensor of ``shape`` from the next stage."""
        tensor = torch.empty(shape, dtype=dtype, device=device)
        dist.recv(tensor, src=self._peer(self.next_rank, "next"))
        return tensor

    def isend_forward(self, tensor):
        """Non-blocking send to the next stage; returns the request from ``dist.isend``."""
        return dist.isend(tensor.contiguous(), dst=self._peer(self.next_rank, "next"))



class SharedMLP(nn.Module):
    """One pipeline stage of a ``total_layers``-deep MLP split across ranks.

    Each stage holds a run of ``Linear(dim, dim) -> ReLU`` pairs. When
    ``total_layers`` is not divisible by ``world_size``, the first
    ``total_layers % world_size`` stages get one extra pair so that no layer
    is dropped. The last stage additionally owns a ``Linear(dim, 2)`` head and
    a cross-entropy loss.

    Args:
        dim: width of every hidden Linear layer.
        total_layers: number of ``Linear+ReLU`` pairs across the whole pipeline.
        rank: index of this stage (0-based).
        world_size: number of stages.
    """

    def __init__(self, dim, total_layers, rank, world_size):
        super().__init__()
        self.rank = rank
        self.is_first = rank == 0
        self.is_last = rank + 1 == world_size
        self.world_size = world_size

        base, extra = divmod(total_layers, world_size)
        # Number of Linear+ReLU pairs held by this stage.
        self.lpg = base + (1 if rank < extra else 0)

        layers = []
        for _ in range(self.lpg):
            layers.append(nn.Linear(dim, dim))
            layers.append(nn.ReLU())
        if self.is_last:
            layers.append(nn.Linear(dim, 2))
            self.loss_fn = nn.CrossEntropyLoss()
        self.net = nn.Sequential(*layers)

    def forward(self, x, targets=None):
        """Apply this stage's layers.

        On the last stage, when ``targets`` is given, returns the
        cross-entropy loss of the head's output against ``targets``;
        otherwise returns the stage's output tensor.
        """
        x = self.net(x)
        if self.is_last and targets is not None:
            return self.loss_fn(x, targets)
        return x




def gpipe_pipeline_step(model, comms, batch, targets, hidden_dim, chunks, device):
    if comms.rank == 0:
        micro_batches = torch.chunk(batch, chunks)
    if comms.rank == comms.world_size - 1:
        micro_targets = targets.chunk(chunks)
    input_buffers = []
    output_buffers = []

    for i in range(chunks):
        if comms.rank == 0:
            input_data = micro_batches[i]
        else:
            shape = (batch // chunks, hidden_dim)
            input_data = comms.recv_forward(shape, device)
            input_data.requires_grad = True

        if comms.rank == comms.world_size - 1:
            output = model(input_data, micro_targets[i])
        else:
            output = model(input_data)
            comms.send_forward(output.detach())

        input_buffers.append(input_data)
        output_buffers.append(output) 
        
    if comms.rank == comms.world_size - 1:
        total_loss = torch.zeros(output.shape, device=device)
        
    for i in range(chunks):
        input_data = input_buffers[i]
        output = output_buffers[i]

        if comms.rank == comms.world_size - 1:
            loss = output / chunks
            loss.backward()
            total_loss += loss
        else:
            grad_from_next = comms.recv_backward(output.shape, device)
            output.backward(grad_from_next)

        if comms.rank != 0:
            comms.send_backward(input_data.grad)

    if comms.rank == comms.world_size - 1:
        return total_loss



Writing train_script.py


In [6]:
%%writefile trainer.py
import torch
from torch import nn
from torch import optim
import os,time
import torch.distributed as dist
from train_script import PipelineComms,init_dist,SharedMLP,gpipe_pipeline_step

BATCH_SIZE = 96
HIDDEN_DIM = 128
TOTAL_LAYERS = 16
STEPS = 50
CHUNKS = 8  # micro-batches per GPipe step (the recorded output was produced with 3)
LOG_EVERY = 5  # print the loss every LOG_EVERY steps

rank, world_size, device = init_dist()
comms = PipelineComms(rank, world_size)
# Same seed on every rank, so initialisation and data are reproducible run to run.
torch.manual_seed(42)
is_first = rank == 0
is_last = rank == world_size - 1
if is_first:
    print(f"--- Starting Micro PP on {world_size} Processes ({device}) ---")

model = SharedMLP(HIDDEN_DIM, TOTAL_LAYERS, rank, world_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Only the first stage consumes real inputs; the others receive the batch size
# so gpipe_pipeline_step can size the activations they receive.
if is_first:
    fixed_input = torch.randn(BATCH_SIZE, HIDDEN_DIM).to(device)
else:
    fixed_input = BATCH_SIZE

# Only the last stage computes the loss, so only it needs labels.
if is_last:
    fixed_target = torch.randint(0, 2, (BATCH_SIZE,)).to(device)
else:
    fixed_target = None

start_time = time.perf_counter()
model.train()
loss = None
for step in range(STEPS):
    optimizer.zero_grad()
    # Returns the batch loss on the last stage and None on the others.
    loss = gpipe_pipeline_step(
        model, comms, fixed_input, fixed_target, HIDDEN_DIM, CHUNKS, device
    )
    optimizer.step()
    if is_last and step % LOG_EVERY == 0:
        print(f"Step {step:02d} | Loss: {loss.item():.6f}")

if is_last:
    print("--- Training Complete ---")
    duration = time.perf_counter() - start_time
    if loss is not None:  # loss stays None when STEPS == 0
        print(f"Final Loss: {loss.item():.6f} | Time: {duration:.3f}s")
dist.destroy_process_group()


Overwriting trainer.py


In [7]:

!torchrun --nproc_per_node=2 trainer.py
 

W0214 15:28:00.307000 117 torch/distributed/run.py:774] 
W0214 15:28:00.307000 117 torch/distributed/run.py:774] *****************************************
W0214 15:28:00.307000 117 torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0214 15:28:00.307000 117 torch/distributed/run.py:774] *****************************************
--- Starting Micro PP on 2 Processes (cuda:0) ---
Step 00 | Loss: 0.697213
Step 05 | Loss: 0.691673
Step 10 | Loss: 0.691702
Step 15 | Loss: 0.691123
Step 20 | Loss: 0.687392
Step 25 | Loss: 0.554777
Step 30 | Loss: 0.374644
Step 35 | Loss: 0.328649
Step 40 | Loss: 0.299344
Step 45 | Loss: 0.278601
--- Training Complete ---
Final Loss: 0.260255 | Time: 0.994s
