In [1]:
%%writefile train_script.py
import torch
from torch import nn
from torch import optim
import os
import torch.distributed as dist
def init_dist():
    """Initialise the default torch.distributed process group from the environment.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from ``os.environ``.
    When CUDA is available, it binds this process to ``cuda:{LOCAL_RANK}`` and
    uses the NCCL backend. Otherwise it runs on CPU with the Gloo backend.

    Returns:
        (rank, world_size, device): the global rank of this process, the
        total number of processes, and the ``torch.device`` this rank should
        place its tensors on.

    Raises:
        KeyError: if one of the required environment variables is missing.
    """
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
        # Make this rank's GPU the current CUDA device so that anything relying
        # on the "current device" does not default to cuda:0 on every rank.
        torch.cuda.set_device(device)
        dist.init_process_group(
            backend="nccl", rank=rank, world_size=world_size, device_id=device
        )
    else:
        device = torch.device("cpu")
        # No device_id here: binding a process to a device is an NCCL
        # optimisation, and some PyTorch releases reject a CPU device_id
        # with a ValueError.
        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    return rank, world_size, device



class PipelineComms:
    """Point-to-point send/recv helpers for a linear pipeline of ranks.

    Stage ``rank`` sends activations forward to ``rank + 1`` and gradients
    backward to ``rank - 1``. The first stage has no previous rank and the
    last stage has no next rank; both are stored as ``None``. Calling a
    method in a direction with no neighbour raises ``RuntimeError``.
    """

    def __init__(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.prev_rank = rank - 1 if rank > 0 else None
        self.next_rank = rank + 1 if rank < world_size - 1 else None

    def _require(self, peer, direction):
        """Return ``peer``; raise RuntimeError if this stage has no such neighbour."""
        if peer is None:
            raise RuntimeError(
                f"rank {self.rank} has no {direction} pipeline neighbour "
                f"(world_size={self.world_size})"
            )
        return peer

    def send_forward(self, tensor):
        """Blocking send of ``tensor`` (made contiguous) to the next stage."""
        dst = self._require(self.next_rank, "next")
        dist.send(tensor.contiguous(), dst=dst)

    def recv_forward(self, shape, device, dtype=torch.float32):
        """Blocking receive of a ``shape``/``dtype`` tensor from the previous stage."""
        src = self._require(self.prev_rank, "previous")
        # recv overwrites every element, so an uninitialised buffer is enough.
        tensor = torch.empty(shape, dtype=dtype, device=device)
        dist.recv(tensor, src=src)
        return tensor

    def send_backward(self, tensor):
        """Blocking send of ``tensor`` (made contiguous) to the previous stage."""
        dst = self._require(self.prev_rank, "previous")
        dist.send(tensor.contiguous(), dst=dst)

    def recv_backward(self, shape, device, dtype=torch.float32):
        """Blocking receive of a ``shape``/``dtype`` tensor from the next stage."""
        src = self._require(self.next_rank, "next")
        # recv overwrites every element, so an uninitialised buffer is enough.
        tensor = torch.empty(shape, dtype=dtype, device=device)
        dist.recv(tensor, src=src)
        return tensor

    def isend_forward(self, tensor):
        """Non-blocking send to the next stage; returns the handle from ``dist.isend``.

        NOTE(review): if ``tensor`` is non-contiguous, the temporary copy made
        by ``.contiguous()`` must stay alive until the send completes. Confirm
        the returned handle keeps it alive, or hold a reference until ``wait()``.
        """
        dst = self._require(self.next_rank, "next")
        return dist.isend(tensor.contiguous(), dst=dst)



class SharedMLP(nn.Module):
    """One pipeline stage of a ``total_layers``-deep MLP split across ranks.

    Each stage holds a run of ``Linear(dim, dim) + ReLU`` pairs. When
    ``total_layers`` is not divisible by ``world_size``, the leftover layers go
    to the lowest ranks (one extra each). All stages together therefore hold
    exactly ``total_layers`` hidden layers. The last stage also owns a
    ``Linear(dim, num_classes)`` head and a ``CrossEntropyLoss``.

    Args:
        dim: width of every hidden layer.
        total_layers: number of hidden ``Linear + ReLU`` pairs across all stages.
        rank: index of this stage (0 = first).
        world_size: number of stages.
        num_classes: output width of the head on the last stage (default 2).
    """

    def __init__(self, dim, total_layers, rank, world_size, num_classes=2):
        super().__init__()
        self.rank = rank
        self.is_first = rank == 0
        self.is_last = rank + 1 == world_size
        self.world_size = world_size
        # Hidden layers owned by *this* stage. Plain floor division would
        # silently drop total_layers % world_size layers from the model.
        base, extra = divmod(total_layers, world_size)
        self.lpg = base + (1 if rank < extra else 0)
        layers = []
        for _ in range(self.lpg):
            layers.append(nn.Linear(dim, dim))
            layers.append(nn.ReLU())
        if self.is_last:
            layers.append(nn.Linear(dim, num_classes))
            self.loss_fn = nn.CrossEntropyLoss()
        self.net = nn.Sequential(*layers)

    def forward(self, x, targets=None):
        """Run this stage on ``x``.

        On the last stage with ``targets`` given, returns the cross-entropy
        loss of the head's logits. Otherwise returns the stage output:
        activations for intermediate stages, logits on the last stage.
        """
        out = self.net(x)
        if self.is_last and targets is not None:
            return self.loss_fn(out, targets)
        return out




def naive_pipeline_step(model, comms, batch, target, hidden_dim, device):
    """Push one batch forward and backward through this pipeline stage.

    On rank 0, ``batch`` is the input tensor. On every other rank it is the
    integer batch size, used to size the ``(batch, hidden_dim)`` activation
    buffer received from the previous stage. Gradients accumulate into
    ``model``'s parameters; the caller is responsible for ``zero_grad`` and
    ``step``.

    Returns the loss tensor on the last stage and ``None`` on every other stage.
    """
    if comms.rank != 0:
        stage_in = comms.recv_forward((batch, hidden_dim), device)
        # Track gradients on the received activation so that d(loss)/d(input)
        # can be sent back to the previous stage after backward.
        stage_in.requires_grad = True
    else:
        stage_in = batch

    stage_out = model(stage_in, target)

    if model.is_last:
        # The last stage's model call returns the loss itself.
        stage_out.backward()
    else:
        # Ship a detached view of the activations, then wait for the
        # gradient of the loss w.r.t. them from the next stage.
        comms.send_forward(stage_out.detach())
        downstream_grad = comms.recv_backward(stage_out.shape, device)
        stage_out.backward(downstream_grad)

    if not model.is_first:
        comms.send_backward(stage_in.grad)

    return stage_out if model.is_last else None



Writing train_script.py


In [None]:
%%writefile trainer.py
import torch
from torch import nn
from torch import optim
import os,time
import torch.distributed as dist
from train_script import PipelineComms,init_dist,SharedMLP,naive_pipeline_step

BATCH_SIZE = 32
HIDDEN_DIM = 128
TOTAL_LAYERS = 16
STEPS = 50
# NOTE(review): CHUNKS is not used yet. naive_pipeline_step pushes the whole
# batch through as a single micro-batch; this is kept for a micro-batched schedule.
CHUNKS = 8


def main():
    """Train a SharedMLP split across all ranks with a naive pipeline schedule.

    Rank 0 owns the fixed input batch. The last rank owns the fixed targets
    and reports the loss. Every rank owns one stage of the model and its own
    Adam optimizer over that stage's parameters.
    """
    rank, world_size, device = init_dist()
    try:
        comms = PipelineComms(rank, world_size)
        # Same seed on every rank, so weight init and data are reproducible.
        torch.manual_seed(42)
        if rank == 0:
            print(f"--- Starting Micro PP on {world_size} Processes ({device}) ---")

        model = SharedMLP(HIDDEN_DIM, TOTAL_LAYERS, rank, world_size).to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        # Only rank 0 needs real inputs. Downstream ranks pass the batch size,
        # which naive_pipeline_step uses to size its receive buffer.
        if rank == 0:
            fixed_input = torch.randn(BATCH_SIZE, HIDDEN_DIM).to(device)
        else:
            fixed_input = BATCH_SIZE

        # Only the last stage computes the loss, so only it needs targets.
        if model.is_last:
            fixed_target = torch.randint(0, 2, (BATCH_SIZE,)).to(device)
        else:
            fixed_target = None

        loss = None
        start_time = time.perf_counter()
        model.train()
        for step in range(STEPS):
            optimizer.zero_grad()
            # Returns the loss on the last stage and None on all others.
            step_loss = naive_pipeline_step(
                model, comms, fixed_input, fixed_target, HIDDEN_DIM, device
            )
            optimizer.step()
            if model.is_last:
                loss = step_loss
                if step % 5 == 0:
                    print(f"Step {step:02d} | Loss: {loss.item():.6f}")

        # CUDA work is asynchronous; wait for it so the elapsed time covers
        # the actual computation, not just the kernel launches.
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        duration = time.perf_counter() - start_time

        if model.is_last:
            print("--- Training Complete ---")
            # loss stays None only when STEPS == 0.
            final_loss = "n/a" if loss is None else f"{loss.item():.6f}"
            print(f"Final Loss: {final_loss} | Time: {duration:.3f}s")
    finally:
        # Tear the process group down even if a training step raised.
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

In [None]:
# NOTE(review): stray debug cell; it only prints "hi" and can be removed.
print("hi")

In [None]:
# Launch trainer.py as 2 local processes. torchrun sets the RANK, WORLD_SIZE
# and LOCAL_RANK environment variables that init_dist() reads.
!torchrun --nproc_per_node=2 trainer.py