In [7]:
import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

# Data Preparation

This dataset generates `size` samples, where each sample is a tuple of:



*   A 20-dimensional input tensor (torch.rand(20))
*   A binary class label, 0 or 1, stored as a Python int (torch.randint(0, 2, (1,)).item())





It's useful for testing models quickly without needing real data.

In [2]:
class MyTrainDataset(Dataset):
    """In-memory toy dataset of random features paired with random binary labels.

    Each sample is a ``(features, label)`` tuple: ``features`` is the result
    of ``torch.rand(20)`` and ``label`` is a Python int drawn with
    ``torch.randint(0, 2, (1,))``.
    """

    def __init__(self, size):
        """Generate ``size`` samples up front and keep them in a list."""
        self.size = size
        samples = []
        for _ in range(size):
            # Draw the features before the label so seeded runs stay reproducible.
            features = torch.rand(20)
            label = torch.randint(0, 2, (1,)).item()  # binary classification
            samples.append((features, label))
        self.data = samples

    def __len__(self):
        """Return the sample count requested at construction time."""
        return self.size

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        features, label = self.data[index]
        return features, label

# Single GPU/ Non- distributed Training

The single-GPU training section defines a simple training workflow that runs on a single device—either CPU or a single CUDA-enabled GPU. It includes a custom dataset class (MyTrainDataset) that generates random input-target pairs for binary classification. A Trainer class encapsulates the training logic: it moves data and models to the selected device, performs forward and backward passes, computes binary cross-entropy loss, and updates weights using stochastic gradient descent (SGD). A checkpoint is saved periodically to disk to preserve model state. This setup is ideal for debugging or small-scale training without needing distributed processing, and it’s straightforward to run in notebooks or local machines with limited GPU availability.

In [3]:
class Trainer:
    """Training loop for a single device (CPU or one GPU).

    Args:
        model: Module producing one logit per sample.
        train_data: Iterable of ``(source, targets)`` batches; ``len()`` of it
            is reported as the number of steps.
        optimizer: Optimizer built over ``model``'s parameters.
        device: Device the model and every batch are moved to.
        save_every: A checkpoint is written whenever
            ``epoch % save_every == 0``.
    """

    def __init__(self, model, train_data, optimizer, device, save_every):
        self.device = device
        self.model = model.to(device)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every

    def _run_batch(self, source, targets):
        """Run one optimisation step on a single batch."""
        self.optimizer.zero_grad()
        output = self.model(source)
        # Reshape to the target shape instead of a bare squeeze(): squeeze()
        # would also drop the batch dimension of a one-sample batch and make
        # the loss raise on the resulting size mismatch.
        logits = output.reshape(targets.shape)
        loss = F.binary_cross_entropy_with_logits(logits, targets.float())
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        """Log a header line, then train on every batch of ``train_data``."""
        # Peek at the first batch only to report its size in the log line.
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[{self.device}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        for source, targets in self.train_data:
            source = source.to(self.device)
            targets = targets.to(self.device)
            self._run_batch(source, targets)

    def _save_checkpoint(self, epoch):
        """Write the model's state dict to ``checkpoint.pt`` (overwriting it)."""
        PATH = "checkpoint.pt"
        torch.save(self.model.state_dict(), PATH)
        print(f"Epoch {epoch} | Checkpoint saved at {PATH}")

    def train(self, max_epochs):
        """Train for ``max_epochs`` epochs, checkpointing every ``save_every``."""
        for epoch in range(max_epochs):
            self._run_epoch(epoch)
            if epoch % self.save_every == 0:
                self._save_checkpoint(epoch)


Defines a training loop for standard (non-distributed) PyTorch training. This:



*   Moves data and model to GPU/CPU (device)
*   Calculates loss
*   Performs backprop and optimizer step
*   Saves checkpoint every few epochs

In [4]:
# Utility functions
def load_train_objs():
    """Create the objects the single-device run trains with.

    Returns:
        tuple: ``(train_set, model, optimizer)`` - a 2048-sample
        ``MyTrainDataset``, a ``Linear(20, 1)`` model and an SGD optimizer
        (``lr=1e-3``) over that model's parameters.
    """
    # Build the dataset before the model so the random draws happen in the
    # same order as before.
    samples = MyTrainDataset(2048)
    net = torch.nn.Linear(20, 1)  # Output logit for binary classification
    sgd = torch.optim.SGD(net.parameters(), lr=1e-3)
    return samples, net, sgd

def prepare_dataloader(dataset, batch_size):
    """Wrap ``dataset`` in a shuffling ``DataLoader`` with pinned memory.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.

    Returns:
        DataLoader: Loader built with ``shuffle=True`` and ``pin_memory=True``.
    """
    loader_options = {
        "batch_size": batch_size,
        "pin_memory": True,
        "shuffle": True,
    }
    return DataLoader(dataset, **loader_options)

In [5]:
# Run training on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
total_epochs, save_every, batch_size = 5, 1, 32

dataset, model, optimizer = load_train_objs()
train_data = prepare_dataloader(dataset, batch_size)
trainer = Trainer(
    model=model,
    train_data=train_data,
    optimizer=optimizer,
    device=device,
    save_every=save_every,
)
trainer.train(max_epochs=total_epochs)

[cuda] Epoch 0 | Batchsize: 32 | Steps: 64
Epoch 0 | Checkpoint saved at checkpoint.pt
[cuda] Epoch 1 | Batchsize: 32 | Steps: 64
Epoch 1 | Checkpoint saved at checkpoint.pt
[cuda] Epoch 2 | Batchsize: 32 | Steps: 64
Epoch 2 | Checkpoint saved at checkpoint.pt
[cuda] Epoch 3 | Batchsize: 32 | Steps: 64
Epoch 3 | Checkpoint saved at checkpoint.pt
[cuda] Epoch 4 | Batchsize: 32 | Steps: 64
Epoch 4 | Checkpoint saved at checkpoint.pt


# Multi-GPU

The multi-GPU training section leverages PyTorch's DistributedDataParallel (DDP) framework to scale training across multiple GPUs in parallel. The setup begins with a ddp_setup() function that initializes the process group and assigns each training process to a dedicated GPU. A DDP-compatible Trainer class is then used, which wraps the model with torch.nn.parallel.DistributedDataParallel and distributes data loading using DistributedSampler, ensuring that each process sees a unique shard of the dataset. The set_epoch() call ensures data is reshuffled differently at each epoch across processes. Training proceeds independently on each GPU, with synchronization handled by DDP under the hood. Checkpoints are saved only by the process with rank=0 to avoid race conditions. This structure enables efficient training of large models on machines with multiple GPUs, such as AWS EC2 instances like p3.8xlarge.

If your model contains any BatchNorm layers, it needs to be converted to SyncBatchNorm to sync the running stats of BatchNorm layers across replicas.

Use the helper function torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) to convert all BatchNorm layers in the model to SyncBatchNorm.

In [8]:
# --- DDP Setup ---
def ddp_setup(rank, world_size):
    """Initialise the NCCL process group for one training process.

    Args:
        rank: Identifier of this process; also used as the CUDA device index.
        world_size: Total number of processes in the group.
    """
    # Rendezvous address and port that every process in the group connects to.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12355"})
    torch.cuda.set_device(rank)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)

In [9]:
# --- Trainer ---
class Trainer:
    """Training loop for one process of a DistributedDataParallel job.

    Args:
        model: Module producing one logit per sample; it is moved to
            ``gpu_id`` and wrapped in ``DistributedDataParallel``.
        train_data: DataLoader whose sampler exposes ``set_epoch``.
        optimizer: Optimizer built over ``model``'s parameters.
        gpu_id: Device index this process trains on; process 0 is the one
            that writes checkpoints.
        save_every: A checkpoint is written whenever
            ``epoch % save_every == 0``.
    """

    def __init__(self, model, train_data, optimizer, gpu_id, save_every):
        self.gpu_id = gpu_id
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.model = DDP(model.to(gpu_id), device_ids=[gpu_id])

    def _run_batch(self, source, targets):
        """Run one optimisation step on a single batch."""
        self.optimizer.zero_grad()
        # Reshape to the target shape instead of a bare squeeze(): squeeze()
        # would also drop the batch dimension of a one-sample batch and make
        # the loss raise on the resulting size mismatch.
        output = self.model(source).reshape(targets.shape)
        loss = F.binary_cross_entropy_with_logits(output, targets.float())
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        """Log a header line, reseed the sampler, then train on every batch."""
        # Peek at the first batch only to report its size in the log line.
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)  # call this additional line at every epoch
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(source, targets)

    def _save_checkpoint(self, epoch):
        """Write the unwrapped model's state dict to ``checkpoint.pt`` from process 0 only."""
        if self.gpu_id == 0:
            # ``.module`` is the original model inside the DDP wrapper.
            ckp = self.model.module.state_dict()
            PATH = "checkpoint.pt"
            torch.save(ckp, PATH)
            print(f"Epoch {epoch} | Checkpoint saved at {PATH}")

    def train(self, max_epochs):
        """Train for ``max_epochs`` epochs, checkpointing every ``save_every``."""
        for epoch in range(max_epochs):
            self._run_epoch(epoch)
            if epoch % self.save_every == 0:
                self._save_checkpoint(epoch)

Calling the set_epoch() method on the DistributedSampler at the beginning of each epoch is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be used in each epoch.

In [10]:
# --- Launch ---
def train_ddp(rank, world_size, total_epochs, save_every, batch_size):
    """Entry point for one DDP training process.

    Args:
        rank: Identifier of this process, also used as its GPU index.
        world_size: Total number of processes in the job.
        total_epochs: Number of epochs to train for.
        save_every: Checkpoint interval, in epochs.
        batch_size: Batch size used by this process's DataLoader.
    """
    ddp_setup(rank, world_size)
    try:
        dataset = MyTrainDataset(2048)
        sampler = DistributedSampler(dataset)
        train_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, pin_memory=True)

        model = torch.nn.Linear(20, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        trainer = Trainer(model, train_loader, optimizer, gpu_id=rank, save_every=save_every)
        trainer.train(total_epochs)
    finally:
        # Always tear the process group down, even when training raises;
        # otherwise re-running this function in the same kernel fails
        # because the default group is still initialised.
        destroy_process_group()

DistributedSampler chunks the input data across all distributed processes.

The DataLoader combines a dataset and a
sampler, and provides an iterable over the given dataset.

Each process will receive an input batch of 32 samples; the effective batch size is 32 * nprocs, or 128 when using 4 GPUs.

In [11]:
# --- Config ---
# Settings for the multi-GPU run; the spawn cell further down passes all four
# to train_ddp and uses world_size as the number of processes to start.
world_size = 1  # Set to >1 for multi-GPU when running on AWS
total_epochs = 3  # number of epochs each process trains for
save_every = 1  # a checkpoint is written whenever epoch % save_every == 0
batch_size = 32  # batch size of each process's DataLoader

world_size is the number of processes across the training job. For GPU training, this corresponds to the number of GPUs in use, and each process works on a dedicated GPU.

In [13]:
# Single-process run: this kernel acts as rank 0 of a one-process group, so
# world_size stays 1 here. The other settings are read from the config cell
# instead of being repeated as literals, so editing the config changes this run.
train_ddp(rank=0, world_size=1, total_epochs=total_epochs, save_every=save_every, batch_size=batch_size)

[GPU0] Epoch 0 | Batchsize: 32 | Steps: 64
Epoch 0 | Checkpoint saved at checkpoint.pt
[GPU0] Epoch 1 | Batchsize: 32 | Steps: 64
Epoch 1 | Checkpoint saved at checkpoint.pt
[GPU0] Epoch 2 | Batchsize: 32 | Steps: 64
Epoch 2 | Checkpoint saved at checkpoint.pt


We only need to save model checkpoints from one process. Without this condition, each process would save its own copy of the identical model.

In [None]:
# --- Spawn (when you have access to multiple GPUs. Else use the command above.) ---
import torch.multiprocessing as mp  # PyTorch wrapper around Python's native multiprocessing

# train_ddp's first parameter (rank) is not part of this tuple: spawn supplies
# the process index itself, ahead of these arguments.
spawn_args = (world_size, total_epochs, save_every, batch_size)
mp.spawn(train_ddp, args=spawn_args, nprocs=world_size)

Colab does not fully support torch.multiprocessing.spawn(), especially when it tries to create new Python processes inside a notebook cell.

This happens even when world_size=1 because spawn() still starts a subprocess, which doesn’t work reliably in Colab's Jupyter runtime.

# Fault Tolerance in Distributed Training with torchrun

The fault-tolerant training section builds on the DDP foundation and introduces support for elastic training via torchrun and checkpoint-based resumption. Instead of manually specifying ranks, it reads the LOCAL_RANK environment variable set by torchrun, enabling seamless integration with PyTorch's built-in launcher. This version of the Trainer class supports snapshot loading, allowing training to resume from the last saved epoch in the event of a failure or preemption. The snapshot includes both the model's state and the number of epochs already completed. By restoring from these checkpoints, training jobs can continue without restarting from scratch. This approach is critical in cloud environments where interruptions are common or where long training jobs need robust recovery mechanisms.

PyTorch offers a utility called torchrun that provides fault-tolerance and elastic training. When a failure occurs, torchrun logs the errors and attempts to automatically restart all the processes from the last saved “snapshot” of the training job.

In [14]:
# --- DDP Setup ---
def ddp_setup():
    """Initialise the NCCL process group from environment variables.

    Under torchrun the rendezvous variables are already exported and are left
    untouched. When the cell is run directly (e.g. in Colab) they are missing,
    so single-process defaults are filled in; without them
    ``init_process_group`` cannot resolve rank, world size or master address.
    """
    single_process_defaults = {
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "12355",
        "RANK": "0",
        "WORLD_SIZE": "1",
    }
    for name, value in single_process_defaults.items():
        # setdefault never overrides a value provided by the launcher.
        os.environ.setdefault(name, value)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))  # default to 0 for Colab
    torch.cuda.set_device(local_rank)
    init_process_group(backend="nccl")

In [15]:
# --- Trainer ---
class Trainer:
    """DDP training loop that can resume from a snapshot file.

    The device index is read from the ``LOCAL_RANK`` environment variable
    (0 when unset). If ``snapshot_path`` exists at construction time, the
    model weights and the number of completed epochs are restored from it
    before the model is wrapped in ``DistributedDataParallel``.

    Args:
        model: Module producing one logit per sample.
        train_data: DataLoader whose sampler exposes ``set_epoch``.
        optimizer: Optimizer built over ``model``'s parameters.
        save_every: A snapshot is written whenever
            ``epoch % save_every == 0``.
        snapshot_path: File the snapshot is loaded from and saved to.
    """

    def __init__(self, model, train_data, optimizer, save_every, snapshot_path):
        self.gpu_id = int(os.environ.get("LOCAL_RANK", 0))
        self.model = model.to(self.gpu_id)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.epochs_run = 0
        self.snapshot_path = snapshot_path
        if os.path.exists(snapshot_path):
            print("Loading snapshot...")
            self._load_snapshot(snapshot_path)
        # Wrap only after loading: the snapshot stores the unwrapped module's
        # state dict (see _save_snapshot).
        self.model = DDP(self.model, device_ids=[self.gpu_id])

    def _load_snapshot(self, snapshot_path):
        """Restore model weights and the completed-epoch counter from disk."""
        loc = f"cuda:{self.gpu_id}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, source, targets):
        """Run one optimisation step on a single batch."""
        self.optimizer.zero_grad()
        # Reshape to the target shape instead of a bare squeeze(): squeeze()
        # would also drop the batch dimension of a one-sample batch and make
        # the loss raise on the resulting size mismatch.
        output = self.model(source).reshape(targets.shape)
        loss = F.binary_cross_entropy_with_logits(output, targets.float())
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        """Log a header line, reseed the sampler, then train on every batch."""
        # Peek at the first batch only to report its size in the log line.
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(source, targets)

    def _save_snapshot(self, epoch):
        """Write model weights and the completed-epoch count (local rank 0 only).

        ``epoch`` is the index of the epoch that just finished.
        """
        if self.gpu_id == 0:
            snapshot = {
                "MODEL_STATE": self.model.module.state_dict(),
                # Store the number of finished epochs, i.e. the index of the
                # next epoch to run. Storing ``epoch`` itself would make a
                # resumed job train the already-finished epoch a second time.
                "EPOCHS_RUN": epoch + 1,
            }
            torch.save(snapshot, self.snapshot_path)
            print(f"Epoch {epoch} | Snapshot saved to {self.snapshot_path}")

    def train(self, max_epochs):
        """Train from the restored epoch up to ``max_epochs``, snapshotting every ``save_every``."""
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)

In [18]:
# --- Run Directly in Notebook ---
# CONFIGURE HERE
total_epochs = 5  # passed to Trainer.train as the epoch upper bound
save_every = 1  # a snapshot is written whenever epoch % save_every == 0
batch_size = 32  # batch size of the DataLoader
snapshot_path = "snapshot.pt"  # file the Trainer resumes from (if it exists) and saves to

In [None]:
# Run Training
ddp_setup()
try:
    dataset = MyTrainDataset(2048)
    model = torch.nn.Linear(20, 1)
    # The optimizer must be built over the parameters of the model that is
    # actually trained. Building it over a second, throwaway Linear layer
    # means optimizer.step() never changes `model`.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    train_data = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # the DistributedSampler is given as the sampler instead
        sampler=DistributedSampler(dataset),
        pin_memory=True,
    )
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
    trainer.train(total_epochs)
finally:
    # Tear the process group down even when training raises, so the cell can
    # be re-run in the same kernel.
    destroy_process_group()