In [None]:
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Tear down the default process group if one is currently initialised.

    Safe to call more than once, or after a failed initialisation: when no
    process group exists the call is a no-op instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


In [None]:
def main(rank, world_size, epochs=10, batch_size=64, lr=0.001, data_root='./data'):
    """Train ResNet-18 on CIFAR-10 as one DistributedDataParallel worker.

    Meant to be the per-process entry point handed to ``mp.spawn``: the process
    joins the process group, trains on its own shard of the training set using
    GPU ``rank``, and always leaves the group on exit.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of participating processes.
        epochs: Number of passes over the training set.
        batch_size: Mini-batch size used by this process.
        lr: SGD learning rate (momentum is fixed at 0.9).
        data_root: Directory where CIFAR-10 is stored or downloaded to.
    """
    setup(rank, world_size)
    try:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Only rank 0 downloads the dataset. Every other rank waits at the
        # barrier until rank 0 has finished, then reads the files from disk;
        # concurrent downloads into the same directory would race.
        if rank != 0:
            dist.barrier(device_ids=[rank])
        train_set = torchvision.datasets.CIFAR10(root=data_root, train=True,
                                                 download=(rank == 0),
                                                 transform=transform)
        if rank == 0:
            dist.barrier(device_ids=[rank])

        # DistributedSampler shards the dataset so each process sees a
        # different subset; world_size = number of GPUs.
        train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank)
        train_loader = DataLoader(train_set, batch_size=batch_size, sampler=train_sampler)

        # CIFAR-10 has 10 classes, so size the classifier head accordingly
        # instead of using the default 1000-way output.
        model = resnet18(num_classes=10).to(rank)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

        # Training loop
        for epoch in range(epochs):
            # Reseed the sampler so each epoch uses a different shuffle.
            train_sampler.set_epoch(epoch)
            for i, (inputs, labels) in enumerate(train_loader):
                inputs, labels = inputs.to(rank), labels.to(rank)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                if i % 100 == 0:
                    print(f'Rank {rank}, Epoch {epoch}, Batch {i}, Loss {loss.item()}')
    finally:
        # Leave the process group even if data loading or training raised.
        cleanup()


In [None]:
import torch.multiprocessing as mp

if __name__ == "__main__":
    # One worker process per GPU; change this to match the number of GPUs available.
    world_size = 2

    # Launch `world_size` processes running `main`, and block (join=True) until
    # all of them have exited. mp.spawn supplies the process index as the first
    # argument, so each worker is invoked as main(rank, world_size).
    # NOTE(review): with the "spawn" start method the workers must be able to
    # import `main`; a function defined only in a notebook cell may not be
    # picklable — confirm, or run this from a .py script.
    mp.spawn(fn=main, args=(world_size,), nprocs=world_size, join=True)