In [1]:
import os
import torch
import torch.optim as optim
import torch.distributed as dist
import torch.nn as nn
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
import sys
import time
import matplotlib.pyplot as plt

from ddpm import config as _config
from ddpm.config import cifar10_config
from ddpm.data import get_cifar10_datasets
from ddpm.diffusion_model import DiffusionModel

In [None]:
_config.DEBUG = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpu_count = torch.cuda.device_count()

print("System info: ")
print("Device:", device)
print("Device count", gpu_count)
if gpu_count > 0:
    # get_device_name / get_device_properties raise when no CUDA device exists,
    # so only query them on a GPU machine.
    print("GPU Device:", torch.cuda.get_device_name(0))
    print("GPU RAM:", f"{round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2)} GB")

cifar10_config.res_net_config.initial_pad = 0
# Per-process batch size: every DDP rank builds its own DataLoader with this value,
# so the effective global batch is batch_size * world_size.
batch_size = cifar10_config.batch_size

# Number of full passes over the training set. 8 is a short run; a previous
# setting of 500 was presumably the full-length run.
max_epochs = 8

# DistributedSampler hands each rank a disjoint 1/world_size shard, so one loop
# iteration on every rank already covers the whole dataset once. Dividing by the
# GPU count here would shrink training to max_epochs / world_size passes (and
# divide by zero on a CPU-only machine).
each_epochs = max_epochs

learning_rate = 1e-4

In [None]:
# Load raw datasets (not DataLoaders): the DDP training function builds its own
# per-rank DataLoader around train_dataset with a DistributedSampler.
# test_dataset is not used by any of the cells shown here.
train_dataset, test_dataset = get_cifar10_datasets()

In [None]:
model = DiffusionModel(cifar10_config)

# EMA (exponential moving average) copy of the model: it starts with exactly the
# same weights and is only updated through update_ema, never by the optimizer.
model_ema = DiffusionModel(cifar10_config)
model_ema.load_state_dict(model.state_dict())
# EMA weights are never trained, so they need no autograd tracking; this also
# prevents accidental gradient accumulation if the EMA model is run in a forward pass.
model_ema.requires_grad_(False)
model_ema.eval()

# Utility function to update EMA weights
def update_ema(model, ema_model, alpha=0.9999):
    """Move ``ema_model`` one EMA step towards ``model``, in place.

    Each parameter becomes ``alpha * ema + (1 - alpha) * model``. Buffers (for
    example normalization running statistics) are copied directly from ``model``
    so the EMA model never evaluates with stale buffer values.

    Args:
        model: Module being trained. Its parameters and buffers are read, not modified.
        ema_model: Module with the same structure whose tensors are updated in place.
        alpha: Decay factor in [0, 1]; larger values make the EMA change more slowly.

    Raises:
        ValueError: If the two modules do not expose the same number of
            parameters or buffers (``zip`` would otherwise silently truncate).
    """
    params = list(model.parameters())
    ema_params = list(ema_model.parameters())
    buffers = list(model.buffers())
    ema_buffers = list(ema_model.buffers())
    if len(params) != len(ema_params) or len(buffers) != len(ema_buffers):
        raise ValueError("model and ema_model must have matching parameters and buffers")

    with torch.no_grad():
        for p, p_ema in zip(params, ema_params):
            # In-place update avoids allocating new tensors for every parameter per step.
            p_ema.mul_(alpha).add_(p, alpha=1 - alpha)
        for b, b_ema in zip(buffers, ema_buffers):
            b_ema.copy_(b)

In [None]:
def train(rank, world_size, checkpoint_path=None):
    """Per-process DDP training entry point launched by ``mp.spawn``.

    Each process trains the module-level ``model`` on CUDA device ``rank``. It
    iterates over a ``DistributedSampler`` shard of ``train_dataset`` and updates
    the EMA copy (``model_ema``) after every optimizer step.

    Args:
        rank: Process index supplied by ``mp.spawn``; also used as the CUDA device index.
        world_size: Total number of processes (one per GPU).
        checkpoint_path: Optional file path. When given, rank 0 saves the trained
            model and EMA state dicts there with ``torch.save`` after the last
            epoch. Without it, the trained weights live only inside the worker
            processes and are discarded when they exit.

    NOTE(review): ``mp.spawn`` starts fresh interpreters that must locate ``train``
    by import. Functions defined in a notebook's ``__main__`` often cannot be found
    that way, so this may need to move into an importable ``.py`` module. Confirm
    this in the target environment.
    """
    # The default env:// rendezvous of init_process_group reads these; set sane
    # single-node defaults unless the caller already configured them.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")

    # Bind this process to its GPU before NCCL is initialised.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
    try:
        rank_model = nn.parallel.DistributedDataParallel(model.to(rank), device_ids=[rank])
        rank_model_ema = model_ema.to(rank)
        rank_model_ema.eval()

        sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
        dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)

        optimizer = optim.Adam(rank_model.parameters(), lr=learning_rate)
        # Cosine decay over the epochs this loop actually runs, so the schedule
        # completes exactly when training ends.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(each_epochs, 1))

        for epoch in range(each_epochs):
            start_time = time.time()
            # Without this, every epoch would see the same shuffled order.
            sampler.set_epoch(epoch)
            for images, labels in dataloader:
                images = images.to(rank)
                labels = labels.to(rank)

                # NOTE(review): assumes forward(images, labels) returns a scalar
                # training loss; confirm this in DiffusionModel.
                loss = rank_model(images, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Pass the unwrapped module so update_ema compares the same module
                # structure as the EMA model.
                update_ema(rank_model.module, rank_model_ema)
            scheduler.step()

            end_time = time.time()
            print(f"Rank {rank} Epoch {epoch} took {end_time - start_time:.3f} seconds")

        if checkpoint_path is not None and rank == 0:
            # DDP keeps parameters identical across ranks, so saving from rank 0 is enough.
            torch.save(
                {
                    "model": rank_model.module.state_dict(),
                    "model_ema": rank_model_ema.state_dict(),
                },
                checkpoint_path,
            )
    finally:
        # Always tear down the process group, even if training raised.
        dist.destroy_process_group()

In [None]:
# Launch one training process per visible GPU; mp.spawn passes each process its
# index (rank) as the first argument to train.
world_size = torch.cuda.device_count()
if world_size == 0:
    raise RuntimeError("DDP training with the NCCL backend needs at least one CUDA device")

start_time = time.time()
mp.spawn(train, args=(world_size,), nprocs=world_size)

# Wall-clock seconds for the whole distributed run, shown as the cell output.
time.time() - start_time