In [0]:
%pip install pytorch

In [0]:
import os

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

In [0]:
# Define model layers separately for parallel execution
class ParallelLayer(nn.Module):
    """A single ``nn.Linear`` layer intended to be constructed on a remote RPC worker."""

    def __init__(self, input_size, output_size):
        super(ParallelLayer, self).__init__()
        self.layer = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.layer(x)

    def parameter_rrefs(self):
        """Return an ``RRef`` wrapping each of this layer's parameters.

        Intended to run on the process that owns the layer (e.g. through
        ``layer_rref.rpc_sync().parameter_rrefs()``), so that an optimizer on
        another process (such as ``DistributedOptimizer``) can update the
        parameters where they live.
        """
        return [RRef(p) for p in self.parameters()]


# Model split into multiple layers across different workers.
# NOTE: despite the class name, this is layer-wise (pipeline-style) model
# parallelism: each worker holds a whole layer; no weight tensor is sharded.
class TensorParallelModel(nn.Module):
    """Two-layer MLP (10 -> 50 -> 2) whose layers are created on remote workers.

    Args:
        workers: RPC worker names; ``workers[0]`` hosts the first layer and
            ``workers[1]`` the second.

    Because both layers are held through RRefs, ``self.parameters()`` is empty;
    use :meth:`parameter_rrefs` to obtain handles for a distributed optimizer.
    """

    def __init__(self, workers):
        super(TensorParallelModel, self).__init__()
        self.workers = workers  # List of workers (process names)

        # Create Remote References (RRef) for each layer assigned to a worker
        self.layer1_rref = rpc.remote(self.workers[0], ParallelLayer, args=(10, 50))
        self.layer2_rref = rpc.remote(self.workers[1], ParallelLayer, args=(50, 2))

    def forward(self, x):
        # Each line is a synchronous RPC that runs ParallelLayer.forward on the
        # owning worker. Gradients only propagate back through these RPCs when
        # this runs inside a torch.distributed.autograd context; a plain local
        # loss.backward() does not reach the remote weights.
        x = self.layer1_rref.rpc_sync().forward(x)
        x = self.layer2_rref.rpc_sync().forward(x)
        return x

    def parameter_rrefs(self):
        """Return RRefs to every remote parameter, first layer first.

        Pass the result to ``torch.distributed.optim.DistributedOptimizer``;
        ``self.parameters()`` cannot be used because it is empty here.
        """
        param_rrefs = []
        for layer_rref in (self.layer1_rref, self.layer2_rref):
            param_rrefs.extend(layer_rref.rpc_sync().parameter_rrefs())
        return param_rrefs

In [0]:
# Worker function that initializes an RPC server
def worker(rank, world_size):
    """Run one passive RPC worker process named ``worker_{rank}``.

    Intended as the target of ``mp.spawn``, which passes the process index as
    ``rank``. The worker does no work of its own: it joins the RPC group and
    serves incoming requests until every participant reaches shutdown.

    Args:
        rank: RPC rank of this process; also used to build its worker name.
        world_size: total number of RPC participants (workers + master).
    """
    # init_rpc's default env:// rendezvous requires these variables. Keep any
    # values the launcher already set; otherwise assume a single-host run.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    rpc.init_rpc(f"worker_{rank}", rank=rank, world_size=world_size, backend=rpc.BackendType.TENSORPIPE)

    # A graceful shutdown blocks until all RPC participants have finished, so
    # this is where the worker actually waits while it serves requests.
    rpc.shutdown()

In [0]:
# Master function (main training loop)
def _collect_param_rrefs(layer_rref):
    """Return an ``RRef`` for every parameter of the module held by ``layer_rref``.

    Must execute on the worker that owns ``layer_rref``, because ``local_value()``
    is only valid on the owner; ``main`` runs it there via ``rpc.rpc_sync``.
    """
    return [RRef(p) for p in layer_rref.local_value().parameters()]


def main(num_epochs=5, lr=0.001, seed=0):
    """Spawn two RPC workers, train the split model on random data, then shut down.

    Args:
        num_epochs: number of full-batch training steps.
        lr: learning rate for the Adam optimizers created on each worker.
        seed: seed for the master's RNG. It only affects the dummy data; the
            layer weights are initialised inside the worker processes.
    """
    world_size = 3  # Two worker processes + one master process

    # init_rpc's default env:// rendezvous needs these; the spawned workers
    # inherit the environment of this process.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch.manual_seed(seed)

    # Start worker processes (ranks 0..world_size-2); the master takes the last rank.
    worker_procs = mp.spawn(worker, args=(world_size,), nprocs=world_size - 1, join=False)

    # Initialize RPC for master process
    rpc.init_rpc("master", rank=world_size - 1, world_size=world_size, backend=rpc.BackendType.TENSORPIPE)

    try:
        # Create model with parallel layers assigned to workers
        workers = ["worker_0", "worker_1"]
        model = TensorParallelModel(workers)

        # Create dummy dataset
        x = torch.randn(32, 10)
        y = torch.randint(0, 2, (32,))

        criterion = nn.CrossEntropyLoss()

        # The layers live on other processes, so model.parameters() is empty and
        # a local optim.Adam would reject it. Collect RRefs to the remote
        # parameters and let DistributedOptimizer update them on their owners.
        param_rrefs = []
        for layer_rref in (model.layer1_rref, model.layer2_rref):
            param_rrefs.extend(
                rpc.rpc_sync(layer_rref.owner(), _collect_param_rrefs, args=(layer_rref,))
            )
        optimizer = DistributedOptimizer(optim.Adam, param_rrefs, lr=lr)

        # Training loop
        for epoch in range(num_epochs):
            # Distributed autograd records the RPCs made in forward() so the
            # backward pass reaches the remote weights. Each iteration uses a
            # fresh context, so gradients never accumulate and no zero_grad is needed.
            with dist_autograd.context() as context_id:
                outputs = model(x)
                loss = criterion(outputs, y)
                dist_autograd.backward(context_id, [loss])
                optimizer.step(context_id)
            print(f"Epoch {epoch+1}: Loss = {loss.item()}")
    finally:
        # Graceful shutdown waits for every RPC participant, including the workers.
        rpc.shutdown()

    # Reap the worker processes; join() raises if any of them failed and
    # returns False until all have exited.
    while not worker_procs.join():
        pass

In [0]:
# Script entry point. The guard matters for the "spawn" start method used by
# mp.spawn: child processes import the main module under a different name,
# and the guard keeps them from running main() (and spawning again) themselves.
# NOTE(review): inside a Jupyter kernel, spawned children generally cannot look
# up functions/classes defined in notebook cells (e.g. `worker`, `ParallelLayer`);
# saving this as a .py script and running it is the reliable option — confirm
# in your environment.
if __name__ == "__main__":
    main()