In [2]:
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms, models

def setup(rank, world_size):
    # Only initialize distributed if world_size > 1
    if world_size > 1:
        dist.init_process_group(
            backend="nccl",
            init_method="tcp://127.0.0.1:29500",
            rank=rank,
            world_size=world_size
        )
        torch.cuda.set_device(rank)

def cleanup(world_size):
    """Tear down the default process group if this run created one.

    Args:
        world_size: Total number of participating processes; nothing is done
            when it is 1 or less.

    The group is destroyed only when it is actually initialised, so calling
    this after a failed or skipped setup (for example from a ``finally``
    block) is a harmless no-op instead of raising.
    """
    if world_size <= 1:
        return
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def train(rank, world_size):
    """Train a single-channel ResNet-18 on MNIST for two epochs on GPU ``rank``.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of processes. Values above 1 enable the
            distributed path (process group, ``DistributedSampler``, DDP);
            otherwise a plain single-GPU loop is run.

    Only rank 0 prints progress. The per-epoch loss it prints is the mean over
    the batches seen by rank 0, not an average across all ranks.
    """
    print(f"[Rank {rank}] Starting training...")
    distributed = world_size > 1
    setup(rank, world_size)

    # Everything after setup() runs under try/finally so the process group is
    # always released, even when data loading or a training step raises.
    try:
        transform = transforms.Compose([transforms.ToTensor()])

        # Let rank 0 download/extract the dataset first; the other ranks wait
        # at the barrier and then find the files already on disk. Without
        # this, all ranks write to ./data concurrently on the first run.
        if distributed and rank != 0:
            dist.barrier(device_ids=[rank])
        dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
        if distributed and rank == 0:
            dist.barrier(device_ids=[rank])

        # Each rank reads its own shard when distributed; otherwise shuffle
        # the whole dataset with the default sampler.
        if distributed:
            sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
            dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)
        else:
            sampler = None
            dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

        # ResNet-18 with a 10-way head; the stem conv is replaced so it
        # accepts MNIST's single input channel instead of three.
        model = models.resnet18(num_classes=10)
        model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        model = model.cuda(rank)

        # Wrap in DDP only when there is more than one process.
        ddp_model = DDP(model, device_ids=[rank]) if distributed else model

        loss_fn = nn.CrossEntropyLoss()
        optimizer = optim.Adam(ddp_model.parameters(), lr=1e-3)

        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by wall-clock adjustments during the run.
        start = time.perf_counter()
        for epoch in range(2):
            # Re-seed the distributed sampler so each epoch gets a new shuffle.
            if sampler is not None:
                sampler.set_epoch(epoch)
            epoch_loss = 0.0
            for X, y in dataloader:
                X, y = X.cuda(rank), y.cuda(rank)
                optimizer.zero_grad()
                output = ddp_model(X)
                loss = loss_fn(output, y)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            if rank == 0:
                print(f"Epoch {epoch} Loss: {epoch_loss/len(dataloader):.4f}")
        end = time.perf_counter()

        if rank == 0:
            print(f"Training completed in {end-start:.2f}s on {world_size} GPU(s)")
    finally:
        cleanup(world_size)

if __name__ == "__main__":
    # One worker process per visible CUDA device.
    world_size = torch.cuda.device_count()

    # train() places the model and every batch on a GPU via .cuda(rank), so
    # stop here with a clear message instead of failing later, after the
    # dataset download, when no device is present.
    if world_size == 0:
        raise RuntimeError("No CUDA devices found; this script requires at least one GPU.")

    if world_size > 1:
        # mp.spawn passes the process index as the first argument (rank).
        mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
    else:
        # Exactly one GPU: run in-process as rank 0 with world_size == 1.
        train(0, world_size)

[Rank 0] Starting training...
Epoch 0 Loss: 0.1342
Epoch 1 Loss: 0.0575
Training completed in 44.54s on 1 GPU(s)
