Single GPU Training Loop

In [None]:
import torch
from torch.utils.data import DataLoader

def train_single_gpu(model, train_dataset, learning_rate, epochs, device, batch_size=64):
    """Train ``model`` on a single device with Adam and cross-entropy loss.

    The model is moved to ``device``, switched to training mode and updated
    in place; nothing is returned. One progress line with the current batch
    loss is printed after every optimizer step.

    Args:
        model: Module to train; its ``parameters()`` are handed to Adam.
        train_dataset: Dataset whose items unpack into ``(inputs, targets)``
            once batched by ``DataLoader``.
        learning_rate: Learning rate passed to ``torch.optim.Adam``.
        epochs: Number of full passes over ``train_dataset``.
        device: Device that the model and every batch are moved to.
        batch_size: Number of samples per batch. Defaults to 64.
    """
    model.to(device)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Gradients accumulate across backward() calls, so clear them
            # before computing this batch's gradients.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")


Distributed Data Parallel (DDP)

In [None]:
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def train_ddp(model, train_dataset, learning_rate, epochs, rank, world_size, batch_size=64):
    """Run the DistributedDataParallel training loop for one worker process.

    Joins an NCCL process group, wraps ``model`` in DDP on CUDA device
    ``rank``, trains with Adam and cross-entropy loss on this rank's share of
    the data, and tears the process group down again, even when training
    raises. Only rank 0 prints the per-batch loss, and the value printed is
    the loss of rank 0's own batch.

    ``rank`` is used both as the process rank and as the CUDA device index.
    NOTE(review): that presumably assumes one node with one process per GPU;
    confirm before using in a multi-node launch.

    Args:
        model: Module to train; it is moved to device ``rank`` before wrapping.
        train_dataset: Dataset whose items unpack into ``(inputs, targets)``
            once batched by ``DataLoader``.
        learning_rate: Learning rate passed to ``torch.optim.Adam``.
        epochs: Number of full passes over this rank's shard of the data.
        rank: Rank of this process, also used as its CUDA device index.
        world_size: Total number of processes in the group.
        batch_size: Per-process number of samples per batch. Defaults to 64.
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # Everything after a successful init runs under try/finally so a failure
    # during setup or training cannot leave the process group initialised.
    try:
        torch.cuda.set_device(rank)

        model.to(rank)
        model = DDP(model, device_ids=[rank])

        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        criterion = torch.nn.CrossEntropyLoss()

        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank
        )
        # Ordering is delegated to the sampler, so DataLoader's own shuffle
        # stays off.
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler
        )

        for epoch in range(epochs):
            # Tell the sampler which epoch this is so its shuffling order
            # differs from one epoch to the next.
            train_sampler.set_epoch(epoch)
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(rank), targets.to(rank)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                if rank == 0:
                    print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")
    finally:
        dist.destroy_process_group()


Fully Sharded Data Parallel (FSDP)

In [None]:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def train_fsdp(model, train_dataset, learning_rate, epochs, rank, world_size, batch_size=64):
    """Run the FullyShardedDataParallel training loop for one worker process.

    Joins an NCCL process group, wraps ``model`` in FSDP, trains with Adam and
    cross-entropy loss on this rank's share of the data, and tears the process
    group down again, even when training raises. Only rank 0 prints the
    per-batch loss, and the value printed is the loss of rank 0's own batch.

    ``rank`` is used both as the process rank and as the CUDA device index.
    NOTE(review): that presumably assumes one node with one process per GPU;
    confirm before using in a multi-node launch.

    Args:
        model: Module to train; it is wrapped in FSDP and then moved to
            device ``rank``.
        train_dataset: Dataset whose items unpack into ``(inputs, targets)``
            once batched by ``DataLoader``.
        learning_rate: Learning rate passed to ``torch.optim.Adam``.
        epochs: Number of full passes over this rank's shard of the data.
        rank: Rank of this process, also used as its CUDA device index.
        world_size: Total number of processes in the group.
        batch_size: Per-process number of samples per batch. Defaults to 64.
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # Everything after a successful init runs under try/finally so a failure
    # during setup or training cannot leave the process group initialised.
    try:
        torch.cuda.set_device(rank)

        # NOTE(review): the module is wrapped on whatever device it arrives
        # on and only moved afterwards; FSDP also offers a ``device_id``
        # argument for initialising on the GPU -- confirm which is intended.
        model = FSDP(model).to(rank)

        # The optimizer is built after wrapping so it receives the wrapped
        # module's parameters.
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        criterion = torch.nn.CrossEntropyLoss()

        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank
        )
        # Ordering is delegated to the sampler, so DataLoader's own shuffle
        # stays off.
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler
        )

        for epoch in range(epochs):
            # Tell the sampler which epoch this is so its shuffling order
            # differs from one epoch to the next.
            train_sampler.set_epoch(epoch)
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(rank), targets.to(rank)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                if rank == 0:
                    print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")
    finally:
        dist.destroy_process_group()
