# 1. Config

In [None]:
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import torch
import torch.nn as nn


@dataclass
class Cifar10_Config:
    """Paths and hyper-parameters for training a DDPM UNet on CIFAR-10.

    Every setting is an annotated dataclass field. Without annotations,
    ``@dataclass`` would see no fields at all: the settings would be plain
    shared class attributes, and the generated ``__init__``/``__repr__``
    would ignore them. With annotations, individual values can be
    overridden at construction, e.g. ``Cifar10_Config(num_epochs=10)``.
    List-valued settings use ``default_factory``, so each instance gets its
    own copy.
    """

    # dataset config
    cifar10_path: str = "path/to/cifar-10-batches-py"
    train_batch_size: int = 128
    eval_batch_size: int = 256
    num_workers: int = 4
    img_h: int = 32
    img_w: int = 32

    # training config
    # Evaluated once, when the class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    start_epoch: int = 0  # overwritten when resuming from a checkpoint
    num_epochs: int = 200
    lr: float = 2e-4
    output_dir: str = "path/to/outputs/"
    save_image_epochs: int = 20  # generate sample grids every N epochs
    save_model_epochs: int = 20  # write a checkpoint every N epochs
    save_generated_imgs_interval: int = 10  # save a grid every N reverse steps
    grid_rows: int = 16
    grid_cols: int = 16
    resume: Optional[str] = None  # checkpoint path to resume from, or None
    beta1: float = 0.9  # Adam betas
    beta2: float = 0.999
    # NOTE(review): not used by the active (cosine) LR schedule.
    warmup: int = 5000

    # model config
    in_channels: int = 3
    hid_channels: int = 128
    out_channels: int = 3
    ch_multipliers: List[int] = field(default_factory=lambda: [1, 2, 2, 2])
    num_res_blocks: int = 2
    apply_attn: List[bool] = field(
        default_factory=lambda: [False, True, False, False]
    )
    time_embedding_dim: int = 512
    drop_rate: float = 0.1
    resample_with_conv: bool = True

    # diffusion config
    # Number of diffusion steps T. Read when building the DDPM scheduler and
    # the model-summary inputs, so it must exist on the config.
    num_timesteps: int = 1000
    beta_start: float = 1e-4
    beta_end: float = 2e-2


# Build the run configuration once; every later cell reads it by this name.
cifar10_train_config = Cifar10_Config()
print("running on ", cifar10_train_config.device)
# Module-level shorthand used by later cells (model placement, summary input).
device = cifar10_train_config.device

# 2. Training

In [None]:
from tqdm import tqdm
from torchinfo import summary
from Dataloader import get_cifar10_dataloader
from UNet import get_unet_model
from Functions import create_images_grid
from DDPM import DDPMScheduler

In [None]:
# init the model and dataloader
model = get_unet_model(cifar10_train_config).to(device)
dataloader = get_cifar10_dataloader(cifar10_train_config)

# Model summary on a dummy batch of images plus per-image timesteps.
# The channel count comes from ``in_channels``; the config has no
# ``img_channels`` attribute. The dummy tensor is not named ``input``,
# to avoid shadowing the builtin.
dummy_imgs = torch.randn(
    cifar10_train_config.train_batch_size,
    cifar10_train_config.in_channels,
    cifar10_train_config.img_h,
    cifar10_train_config.img_w,
).to(device)
dummy_t = torch.randint(
    0,
    cifar10_train_config.num_timesteps,
    (cifar10_train_config.train_batch_size,),
).to(device)
summary(
    model,
    input_data=(dummy_imgs, dummy_t),
    col_names=["input_size", "output_size", "num_params"],
    row_settings=["var_names"],
)

In [None]:
# Optimizer: Adam with the configured learning rate and betas.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=cifar10_train_config.lr,
    betas=(cifar10_train_config.beta1, cifar10_train_config.beta2),
)

# LR schedule: cosine-anneal from ``lr`` down to 1e-9 over the whole run.
# T_max counts batches (num_epochs * batches per epoch), because the training
# loop calls lr_scheduler.step() once per batch. last_epoch keeps its
# default (-1), i.e. a fresh schedule.
total_steps = cifar10_train_config.num_epochs * len(dataloader)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=total_steps, eta_min=1e-9
)
del total_steps
# Alternative kept for reference: linear warmup over ``warmup`` steps.
# lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
#         optimizer, lr_lambda=lambda t: min((t + 1) / cifar10_train_config.warmup, 1.0)
#     ) if cifar10_train_config.warmup > 0 else None

# Noise-prediction loss, and the DDPM noise schedule.
criterion = torch.nn.MSELoss()
ddpm_scheduler = DDPMScheduler(
    num_timesteps=cifar10_train_config.num_timesteps,
    beta_start=cifar10_train_config.beta_start,
    beta_end=cifar10_train_config.beta_end,
)

In [None]:
# Optionally restore model / optimizer / LR-scheduler state and continue
# from the epoch after the one stored in the checkpoint.
if cifar10_train_config.resume:
    # Checkpoints saved by the training loop also pickle the config object
    # (under 'parameters'). torch.load's default weights_only=True
    # (PyTorch >= 2.6) rejects such objects, so opt out explicitly.
    # SECURITY: weights_only=False unpickles arbitrary objects, so only
    # resume from checkpoint files you trust.
    checkpoint = torch.load(
        cifar10_train_config.resume, map_location='cpu', weights_only=False
    )
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    cifar10_train_config.start_epoch = checkpoint['epoch'] + 1

print("start epoch: ", cifar10_train_config.start_epoch)

In [None]:
def generation(
    cfg: "Cifar10_Config",
    epoch: int,
    Scheduler: DDPMScheduler,
    model: nn.Module,
):
    """Sample images with the full reverse diffusion process and save grids.

    Starts from a batch of ``cfg.eval_batch_size`` Gaussian-noise images of
    size ``cfg.in_channels x cfg.img_h x cfg.img_w``. It then denoises them
    step by step for ``t = Scheduler.num_timesteps - 1`` down to ``0``.
    Whenever ``t`` is a multiple of ``cfg.save_generated_imgs_interval``
    (always including ``t = 0``), the current batch is saved as a
    ``cfg.grid_rows`` x ``cfg.grid_cols`` grid. The file is
    ``<cfg.output_dir>/samples/epoch_<epoch>/generated_images_<t>.png``.

    The caller is responsible for putting ``model`` in eval mode.

    Args:
        cfg: run configuration (batch size, image size, device, output paths).
        epoch: epoch number, used only to name the output directory.
        Scheduler: DDPM scheduler providing ``num_timesteps`` and
            ``sample_prev``.
        model: noise-prediction network, called as ``model(xt, ts)``.
    """
    # Start from pure Gaussian noise. The noise is drawn on the default
    # device and then moved to cfg.device.
    xt = torch.randn(
        cfg.eval_batch_size,
        cfg.in_channels,
        cfg.img_h,
        cfg.img_w,
    ).to(cfg.device)

    # The output directory depends only on the epoch, so create it once
    # instead of on every save inside the loop.
    grid_save_dir = Path(cfg.output_dir, "samples", f"epoch_{epoch}")
    grid_save_dir.mkdir(parents=True, exist_ok=True)

    timesteps = range(Scheduler.num_timesteps - 1, -1, -1)  # T-1, ..., 0
    with torch.no_grad(), tqdm(timesteps, desc="Generating images") as progress_bar:
        for t in progress_bar:
            # Same timestep for every image in the batch. It is created on
            # the default device, not cfg.device; presumably the model and
            # sample_prev handle placement (matches the training loop).
            # TODO confirm.
            ts = torch.full((xt.shape[0],), t, dtype=torch.long)
            noise_pred = model(xt, ts)
            # sample_prev's second return value is not needed here.
            xt, _ = Scheduler.sample_prev(xt, noise_pred, ts)

            if t % cfg.save_generated_imgs_interval == 0:
                # Map samples to [0, 1] (assumes the model works in [-1, 1]).
                # Then convert to uint8 NHWC numpy arrays for the grid helper.
                generated_images = torch.clamp((xt + 1) / 2, 0.0, 1.0)
                generated_images = generated_images.detach().cpu()
                generated_images = generated_images.permute(0, 2, 3, 1).numpy()
                generated_images = (generated_images * 255).astype("uint8")

                image_grid = create_images_grid(
                    generated_images, rows=cfg.grid_rows, cols=cfg.grid_cols
                )
                image_grid.save(Path(grid_save_dir, f"generated_images_{t}.png"))

In [None]:
# define the training process
def train_evaluate(
    cfg: "Cifar10_Config",
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    ddpm_scheduler: DDPMScheduler,
    train: bool = True,
    evaluate: bool = True,
):
    """Train the noise-prediction model, periodically sampling and checkpointing.

    Runs epochs ``cfg.start_epoch`` through ``cfg.num_epochs - 1``. For each
    epoch:

    * if ``train``: one optimisation pass over ``dataloader``;
    * if ``evaluate``: on every ``cfg.save_image_epochs``-th epoch and on the
      last epoch, switch the model to eval mode and call ``generation``;
    * on every ``cfg.save_model_epochs``-th epoch and on the last epoch,
      save a checkpoint into ``cfg.output_dir``.

    NOTE: the learning-rate scheduler is not a parameter. The helpers below
    step and checkpoint the notebook-level global ``lr_scheduler``.

    Args:
        cfg: run configuration (epochs, device, output paths, intervals).
        model: network called as ``model(noisy_imgs, t)`` to predict noise.
        dataloader: yields ``(imgs, labels)`` batches; labels are ignored.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss between predicted and true noise.
        ddpm_scheduler: provides ``num_timesteps`` and ``forward_process``.
        train: whether to run the training pass each epoch.
        evaluate: whether to generate sample images on scheduled epochs.
    """
    # Resume the step counter consistently with the starting epoch.
    global_step = cfg.start_epoch * len(dataloader)

    for epoch in range(cfg.start_epoch, cfg.num_epochs):
        is_last_epoch = epoch == cfg.num_epochs - 1

        if train:
            global_step = _train_one_epoch(
                cfg, model, dataloader, optimizer, criterion, ddpm_scheduler,
                epoch, global_step,
            )

        if evaluate and ((epoch + 1) % cfg.save_image_epochs == 0 or is_last_epoch):
            model.eval()
            generation(cfg, epoch, ddpm_scheduler, model)

        if (epoch + 1) % cfg.save_model_epochs == 0 or is_last_epoch:
            _save_checkpoint(cfg, model, optimizer, epoch)


def _train_one_epoch(
    cfg, model, dataloader, optimizer, criterion, ddpm_scheduler, epoch, global_step
):
    """Run one training pass over ``dataloader`` and return the new global step."""
    model.train()
    mean_loss = 0.0

    # Context manager guarantees the bar is closed even if a step raises.
    with tqdm(total=len(dataloader)) as progress_bar:
        progress_bar.set_description(f"Epoch {epoch}")

        for step, (imgs, _) in enumerate(dataloader):
            imgs = imgs.to(cfg.device)
            train_batch_size = imgs.shape[0]

            # Sample one random timestep per image. It is created on the
            # default device; presumably forward_process and the model
            # handle placement (matches generation()). TODO confirm.
            t = torch.randint(
                0, ddpm_scheduler.num_timesteps, (train_batch_size,)
            ).long()

            # Apply forward diffusion at t, then predict the added noise.
            noisy_imgs, noise = ddpm_scheduler.forward_process(imgs, t)
            noise_pred = model(noisy_imgs, t)
            loss = criterion(noise_pred, noise)

            # Incremental running mean of the per-batch loss for this epoch.
            mean_loss = mean_loss + (loss.detach().item() - mean_loss) / (step + 1)

            loss.backward()
            # Clip the global gradient norm to 1.0.
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()  # notebook-level global, stepped per batch
            optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": mean_loss, "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            global_step += 1

    return global_step


def _save_checkpoint(cfg, model, optimizer, epoch):
    """Save model, optimizer, LR-scheduler state and the config for ``epoch``."""
    output_dir = Path(cfg.output_dir)
    # torch.save does not create missing parent directories. The directory is
    # not guaranteed to exist here (e.g. with evaluate=False, nothing else
    # creates it).
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'parameters': cfg,
        'epoch': epoch
    }
    torch.save(checkpoint, Path(output_dir, f"unet{cfg.img_h}_{cfg.img_w}_e{epoch}.pth"))

In [None]:
# Launch training. Both phases are enabled: sample grids are written every
# save_image_epochs epochs and checkpoints every save_model_epochs epochs
# (each also on the final epoch).
train_evaluate(
    train=True,
    evaluate=True,
    cfg=cifar10_train_config,
    model=model,
    dataloader=dataloader,
    optimizer=optimizer,
    criterion=criterion,
    ddpm_scheduler=ddpm_scheduler,
)