# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, DistributedSampler
import torch.distributed as dist
import os

# In [None]:
def _local_rank():
    """Return the index of the GPU this process should use on its own node."""
    # dist.get_rank() is the *global* rank, which exceeds the per-node GPU
    # count on multi-node jobs. Launchers such as torchrun export LOCAL_RANK;
    # fall back to global rank modulo the local GPU count when it is absent.
    env_value = os.environ.get('LOCAL_RANK')
    if env_value is not None:
        return int(env_value)
    return dist.get_rank() % torch.cuda.device_count()


def _build_train_loader(data_root, batch_size, num_workers, local_rank):
    """Build the rank-sharded CIFAR-10 training loader.

    Returns:
        (loader, sampler, num_classes). The sampler is returned so the caller
        can reseed its shuffle every epoch via ``set_epoch``.
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ImageNet channel statistics, as expected by the pretrained backbone.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Download once per node before any rank reads the files. Letting every
    # rank download into the same root concurrently races on the archive.
    # NOTE(review): assumes data_root is node-local; on a filesystem shared
    # across nodes, consider gating on the global rank instead.
    if local_rank == 0:
        datasets.CIFAR10(root=data_root, train=True, download=True)
    dist.barrier(device_ids=[local_rank])

    dataset = datasets.CIFAR10(root=data_root, train=True, download=False,
                               transform=transform)
    sampler = DistributedSampler(dataset)
    # shuffle must stay False: the DistributedSampler shards and shuffles.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=True,
                        sampler=sampler)
    return loader, sampler, len(dataset.classes)


def _build_model(num_classes, device, local_rank):
    """Create a pretrained ResNet-18 with a ``num_classes``-way head, wrapped in DDP."""
    # NOTE: `pretrained=` is deprecated in newer torchvision in favour of
    # `weights=`; kept here for compatibility with older releases.
    model = models.resnet18(pretrained=True)
    # The pretrained head outputs 1000 ImageNet logits; replace it so the
    # output dimension matches the dataset's label space.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model = model.to(device)
    return nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank)


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch,
                     log_interval, is_main):
    """Run one optimisation pass over this rank's shard of ``loader``."""
    for i, (inputs, labels) in enumerate(loader):
        # Memory is pinned by the DataLoader, so the copy can be asynchronous.
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()  # DDP all-reduces gradients across ranks here.
        optimizer.step()

        # Log from rank 0 only to avoid one duplicate line per process.
        # The value is rank 0's local batch loss, not a global average.
        if is_main and i % log_interval == 0:
            print(f'Epoch {epoch}, Batch {i}, Loss {loss.item()}')


def main(*, epochs=10, batch_size=64, lr=0.001, momentum=0.9,
         data_root='./data', num_workers=4, log_interval=100):
    """Fine-tune a pretrained ResNet-18 on CIFAR-10 with DistributedDataParallel.

    Runs one process per GPU. The process group is initialised with
    ``init_method='env://'``, so a launcher (e.g. torchrun) must provide
    the rendezvous environment variables.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Per-process batch size (effective global batch is
            batch_size * world size).
        lr: SGD learning rate.
        momentum: SGD momentum.
        data_root: Directory where CIFAR-10 is stored / downloaded.
        num_workers: DataLoader worker processes per rank.
        log_interval: Print the loss every this many batches (rank 0 only).
    """
    dist.init_process_group(backend='nccl', init_method='env://')
    try:
        local_rank = _local_rank()
        # Bind this process to its GPU so CUDA/NCCL default to that device.
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda', local_rank)
        is_main = dist.get_rank() == 0

        train_loader, train_sampler, num_classes = _build_train_loader(
            data_root, batch_size, num_workers, local_rank)
        model = _build_model(num_classes, device, local_rank)

        criterion = nn.CrossEntropyLoss().to(device)
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

        model.train()
        for epoch in range(epochs):
            # Reseed the sampler so each epoch uses a different shuffle.
            train_sampler.set_epoch(epoch)
            _train_one_epoch(model, train_loader, criterion, optimizer,
                             device, epoch, log_interval, is_main)

        if is_main:
            print('Finished Training')
    finally:
        # Release NCCL communicators even if training raises.
        dist.destroy_process_group()

if __name__ == '__main__':
    # Start training only when executed as a script, never on import.
    # main() initialises its process group with init_method='env://', so
    # launch this file through a distributed launcher that sets those
    # environment variables.
    main()
