In [None]:

import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP


In [None]:

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group if one is currently initialised.

    Calling this when no group exists is a no-op.
    """
    if not dist.is_initialized():
        return
    dist.destroy_process_group()


In [None]:

def get_model(rank):
    """Build a randomly initialised ResNet-18 with a 10-way output layer.

    The model is moved to ``cuda:{rank}`` when CUDA is available and to the
    CPU otherwise.  If a process group has been initialised, the model is
    additionally wrapped in ``DistributedDataParallel``.

    Args:
        rank: Index of this process; also used as the CUDA device index.

    Returns:
        A ``(model, device)`` tuple.
    """
    # Deliberately no `pretrained=` keyword: it is deprecated in torchvision's
    # weights API and triggers a warning on newer releases.  Calling the
    # constructor without it yields random initialisation on both the old and
    # the new API.
    model = torchvision.models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, 10)  # CIFAR10 has 10 classes

    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{rank}" if use_cuda else "cpu")
    model = model.to(device)

    if dist.is_initialized():
        # device_ids is only given for GPU modules; CPU modules pass None.
        model = DDP(model, device_ids=[rank] if use_cuda else None)

    return model, device


In [None]:

def get_dataloader(rank, world_size, batch_size=64):
    """Create the CIFAR10 training loader for one process.

    Images are resized to 224x224, converted to tensors and normalised with
    mean 0.5 / std 0.5 per channel.  The dataset is read from ``./data`` and
    requested with ``download=True``.

    Args:
        rank: Index of this process, forwarded to the distributed sampler.
        world_size: Number of processes; values above 1 enable a
            ``DistributedSampler``.
        batch_size: Batch size handed to the ``DataLoader``.

    Returns:
        A ``(train_loader, sampler)`` tuple; ``sampler`` is ``None`` when
        ``world_size`` is 1 or less.
    """
    preprocessing = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_dataset = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=preprocessing,
    )

    if world_size > 1:
        # A sampler is supplied, so DataLoader-level shuffling stays off.
        sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
        shuffle = False
    else:
        # No sampler: let the DataLoader shuffle the dataset itself.
        sampler = None
        shuffle = True

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=shuffle,
    )
    return train_loader, sampler


In [None]:

def train(rank, world_size, return_dict):
    """Train the model for one epoch and record the elapsed wall-clock time.

    When ``world_size`` is greater than 1 the process group is set up first
    and is always torn down again, even if training raises.

    Args:
        rank: Index of this process.
        world_size: Total number of participating processes.
        return_dict: Mutable mapping; rank 0 stores the elapsed seconds of the
            training loop under the key ``"time"``.
    """
    distributed = world_size > 1
    if distributed:
        setup(rank, world_size)

    try:
        model, device = get_model(rank)
        train_loader, sampler = get_dataloader(rank, world_size)

        optimizer = optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()

        def _wait_for_device():
            # CUDA kernels are launched asynchronously; block until the
            # queued work has finished so the timer covers real execution.
            if device.type == "cuda":
                torch.cuda.synchronize(device)

        _wait_for_device()
        start_time = time.time()

        model.train()
        for epoch in range(1):  # 1 epoch for demo
            # Explicit None check: truthiness of a sampler goes through
            # __len__ and would skip set_epoch for a zero-length sampler.
            if sampler is not None:
                sampler.set_epoch(epoch)

            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

        _wait_for_device()
        total_time = time.time() - start_time

        if rank == 0:
            return_dict["time"] = total_time
    finally:
        # Release the process group on success and on failure alike.
        if distributed:
            cleanup()


In [None]:

def run_experiment(world_size):
    """Run one timed training pass and return the time reported by rank 0.

    With ``world_size`` of 1 (or less) training runs directly in the calling
    process.  Otherwise ``world_size`` worker processes are spawned and the
    result is passed back through a manager-backed dictionary.

    Args:
        world_size: Number of training processes to use.

    Returns:
        The value stored by ``train`` under ``"time"``.
    """
    if world_size <= 1:
        # In-process baseline: a plain dict is enough, no manager process
        # has to be started just to carry one value.
        results = {}
        train(0, 1, results)
        return results["time"]

    # NOTE(review): spawned workers must be able to import `train`; confirm
    # this holds when the function is defined only inside a notebook cell.
    # The context manager shuts the manager's server process down again once
    # the result has been read; previously it was left running.
    with mp.Manager() as manager:
        return_dict = manager.dict()
        mp.spawn(train, args=(world_size, return_dict), nprocs=world_size, join=True)
        # Read the value while the manager is still alive.
        return return_dict["time"]


if __name__ == "__main__":
    # Baseline: a single in-process run whose time is the reference value.
    print("Running single process baseline...")
    t1 = run_experiment(world_size=1)
    print("Time (1 process):", t1)

    # The multi-process comparison only runs when two or more CUDA devices
    # are reported; one process is started per device.
    if torch.cuda.device_count() >= 2:
        world_size = torch.cuda.device_count()
        print(f"Running distributed with {world_size} processes...")
        tN = run_experiment(world_size=world_size)
        print("Time (N processes):", tN)

        # A ratio above 1 means the distributed run finished sooner.
        speedup = t1 / tN
        print("Speedup:", speedup)
