In [0]:
%pip install pytorch

In [0]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
from torch.nn.parallel import DistributedDataParallel as DDP

In [0]:
# Define a simple model
class SimpleModel(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear.

    The defaults reproduce the original fixed architecture (10 -> 50 -> 2).
    Existing ``SimpleModel()`` calls and the ``state_dict`` keys
    (``fc1.*``, ``fc2.*``) are therefore unchanged.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features: int = 10, hidden_features: int = 50, num_classes: int = 2):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits for ``x``.

        ``nn.Linear`` acts on the last dimension. ``x``'s last dimension must
        therefore equal ``in_features``; in the output it is replaced by
        ``num_classes``.
        """
        return self.fc2(self.relu(self.fc1(x)))

In [0]:
# Training loop
def train(rank, world_size, epochs=5, batch_size=32, lr=1e-3, dataset_size=1000, seed=0):
    """Run DDP training for one process (one GPU) of a ``world_size``-process job.

    Intended as the target of ``mp.spawn``, which calls it as
    ``train(rank, *args)``. The extra keyword parameters default to the
    original hard-coded values.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of processes in the job.
        epochs: Number of passes over the dataset.
        batch_size: Per-process mini-batch size.
        lr: Adam learning rate.
        dataset_size: Number of synthetic samples to generate.
        seed: Seed for the synthetic data and model init. It must be the same
            on every rank so that all ranks build the *same* dataset, which
            ``DistributedSampler`` then splits into per-rank shards.
    """
    # The default ``env://`` rendezvous reads MASTER_ADDR / MASTER_PORT, and
    # mp.spawn does not set them. setdefault keeps any values the user exported.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")

        # Same seed on every rank -> identical random data, so the sampler's
        # per-rank index split refers to one shared dataset.
        torch.manual_seed(seed)
        x = torch.randn(dataset_size, 10)
        y = torch.randint(0, 2, (dataset_size,))
        dataset = TensorDataset(x, y)

        # DistributedSampler gives each rank its own subset of indices.
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
        dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

        model = DDP(SimpleModel().to(device), device_ids=[rank])
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        for epoch in range(epochs):
            # Changes the shuffle each epoch while keeping it consistent across ranks.
            sampler.set_epoch(epoch)
            for batch_x, batch_y in dataloader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)

                optimizer.zero_grad()
                loss = criterion(model(batch_x), batch_y)
                loss.backward()  # DDP synchronizes gradients across ranks during backward
                optimizer.step()

            if rank == 0:
                print(f"Epoch {epoch+1} completed on rank {rank}")
    finally:
        # Always release the process group, even if training raises.
        dist.destroy_process_group()

In [0]:
# Entry point
def main():
    """Launch one ``train`` process per visible CUDA device via ``mp.spawn``.

    If fewer than two GPUs are visible, prints a message and returns without
    training.

    NOTE(review): ``mp.spawn`` starts workers with the "spawn" start method,
    so each child must be able to locate ``train`` by import. Functions
    defined in a notebook cell often cannot be found this way. If spawning
    fails under Jupyter, move this code into a ``.py`` file and run it as a
    script (or via ``torchrun``).
    """
    n_gpus = torch.cuda.device_count()
    if n_gpus >= 2:
        # One worker per GPU; rank is supplied by mp.spawn as train's first argument.
        mp.spawn(train, args=(n_gpus,), nprocs=n_gpus, join=True)
    else:
        print("Need at least 2 GPUs for DistributedDataParallel")

In [0]:
if __name__ == "__main__":
    # Run main() only when this code executes as the top-level module, not
    # when it is imported (e.g. when spawn-started workers re-import a script).
    main()