In [1]:
from __future__ import annotations
from datasets import load_dataset
import numpy as np
import random
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
from collections import deque
from pathlib import Path
from model import UNet
from diffusion import GaussianDiffusionTrainer, GaussianDiffusionSampler

In [2]:
# CIFAR-10 training split from the Hugging Face Hub (50,000 labelled 32x32 images).
dataset = load_dataset(path="uoft-cs/cifar10", split="train")
dataset

Dataset({
    features: ['img', 'label'],
    num_rows: 50000
})

In [3]:
# Decode every image into one (N, H, W, C) uint8 array in a single pass, then
# reorder to (N, C, H, W) and rescale pixel values from [0, 255] to [-1, 1].
# `.float()` returns a fresh tensor, so the in-place ops after it avoid extra
# full-size float temporaries without touching the shared uint8 buffer.
images = (
    torch.from_numpy(np.stack([np.asarray(img) for img in dataset["img"]]))
    .permute(0, 3, 1, 2)
    .float()
    .div_(255.0)
    .mul_(2)
    .sub_(1)
)
# Cheap sanity checks so a bad decode fails here rather than deep inside training.
assert images.ndim == 4 and images.shape[1] == 3, images.shape
assert images.min() >= -1 and images.max() <= 1
print(f"{images.shape=} {images.dtype=}")

images.shape=torch.Size([50000, 3, 32, 32]) images.dtype=torch.float32


In [4]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [5]:
# --- Diffusion / model hyperparameters ---
T = 1000                # number of diffusion steps (passed to both UNet and the trainer)
ch = 128                # base channel width handed to UNet
ch_mult = [1, 2, 2, 2]  # presumably per-level channel multipliers of `ch` — TODO confirm in model.UNet
attn = [1]              # presumably indices of UNet levels with self-attention — TODO confirm in model.UNet
num_res_blocks = 2      # presumably residual blocks per UNet level — TODO confirm in model.UNet
dropout = 0.1
lr = 1e-4               # Adam learning rate (scaled by the warmup multiplier)

# Endpoints of the noise (beta) schedule; the schedule shape itself is defined
# in diffusion.GaussianDiffusionTrainer — presumably linear, TODO confirm.
beta_1 = 1e-4
beta_T = 0.02

# Seed every RNG in play before the model is built, so weight init, DataLoader
# shuffling and dropout are reproducible across Restart & Run All.
# torch.manual_seed also seeds all CUDA devices.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

def warmup_lr(step, warmup=5000):
    """Linear learning-rate warmup multiplier for ``torch.optim.lr_scheduler.LambdaLR``.

    Args:
        step: scheduler step counter (LambdaLR passes its internal counter, starting at 0).
        warmup: number of steps over which the multiplier ramps linearly from 0 to 1.

    Returns:
        ``min(step, warmup) / warmup``: 0.0 at step 0, rising linearly, and 1.0 from
        ``warmup`` onwards. A non-positive ``warmup`` disables warmup (always 1.0)
        instead of dividing by zero.
    """
    if warmup <= 0:
        return 1.0
    return min(step, warmup) / warmup

# Noise-prediction network: it receives a noisy image x_t (plus its timestep).
# NOTE(review): under the standard DDPM epsilon-parameterisation it predicts the
# Gaussian noise eps mixed into x_0 to form x_t, not the single-step noise between
# x_{t-1} and x_t. TODO confirm against diffusion.GaussianDiffusionTrainer.
net = UNet(
    T=T,
    ch=ch,
    ch_mult=ch_mult,
    attn=attn,
    num_res_blocks=num_res_blocks,
    dropout=dropout,
).to(device)

# NOTE(review): this rebinds `optim`, shadowing the `torch.optim` module imported at
# the top of the notebook. The training cell relies on this name, so it is kept;
# refer to the module explicitly as `torch.optim`.
optim = torch.optim.Adam(params=net.parameters(), lr=lr)
# Scale the learning rate by the warmup multiplier on every sched.step().
sched = torch.optim.lr_scheduler.LambdaLR(optimizer=optim, lr_lambda=warmup_lr)

trainer = GaussianDiffusionTrainer(net, beta_1, beta_T, T)
trainer = trainer.to(device)

In [None]:
batch_size = 32        # images per optimisation step
total_steps = 800_000  # total optimisation steps for the run
def infiniteloop(dataloader):
    """Yield batches from ``dataloader`` forever, starting a new pass whenever one ends.

    Each pass iterates ``dataloader`` afresh, so a shuffling DataLoader produces a
    new order every epoch.

    Raises:
        ValueError: if a complete pass yields no batches (e.g. ``drop_last=True``
            with fewer samples than ``batch_size``), instead of spinning forever.
    """
    while True:
        produced_any = False
        for batch in dataloader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError("dataloader yielded no batches; cannot loop over it forever")

dataloader = DataLoader(images, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)
datalooper = infiniteloop(dataloader)

# Optional gradient-norm clipping; None keeps training unclipped (e.g. set to 1.0 to enable).
grad_clip = None
# Checkpoint every `ckpt_every` steps, and always after the final step.
ckpt_every = 10000
# NOTE(review): "cifat10" / ".pkt" look like typos for "cifar10" / ".pt"; the name is
# kept so checkpoints already written by this run (and any code loading them) still match.
ckpt_template = "ckpts/cifat10_{step}.pkt"
Path(ckpt_template).parent.mkdir(exist_ok=True, parents=True)

# Dropout must be active even if another cell switched the model to eval mode.
net.train()

# NOTE(review): re-running this cell continues from the current weights and scheduler
# state (both created in an earlier cell) but restarts `step` at 0, so checkpoint
# step labels restart and earlier files get overwritten.
with trange(total_steps, dynamic_ncols=True) as pbar:
    recent_losses = deque(maxlen=1000)  # sliding window for a smoothed loss readout
    for step in pbar:
        optim.zero_grad()
        x_0 = next(datalooper).to(device)
        # assumes trainer(x_0) returns an unreduced loss tensor; averaged to a scalar here
        loss = trainer(x_0).mean()
        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
        optim.step()
        sched.step()

        recent_losses.append(loss.item())
        pbar.set_postfix(loss=f"{np.mean(recent_losses):.3f}")

        # Previously the last ~10k steps were never saved (final save was step 790000).
        is_last_step = step == total_steps - 1
        if step % ckpt_every == 0 or is_last_step:
            # Pickles the whole module (loading needs model.UNet importable); optimizer
            # and scheduler state are not saved, so an exact resume is not possible.
            torch.save(net, ckpt_template.format(step=step))

 49%|████▉     | 395149/800000 [15:30:11<15:41:08,  7.17it/s, loss=0.029]