In [1]:
import numpy as np
from numpy.random import randn as randn
import matplotlib.pyplot as plt
from abc import ABC, abstractmethod
from torchvision.datasets import MNIST, FashionMNIST
from torch.utils.data import DataLoader
from torchvision.transforms import transforms as trns
import torch
import torch.nn as nn
import torch.optim as optim
from schedulefree.adamw_schedulefree import AdamWScheduleFree
from pytorch_msssim import ssim
from PIL import Image
from tqdm import tqdm, trange
import os

from unet_model import UNet
from unet_model import build_xy_coordinates, sinusoidal_2d
DEVICE = "cuda"  # target device for every tensor and the model in this notebook (hard-coded to CUDA)

In [2]:
# Fixed 2-D sinusoidal positional encoding over the 28x28 MNIST pixel grid.
# It is concatenated to the image along the channel axis before every model
# call; hence, presumably, the UNet is built with in_channels = 1 + 2 * pos_emb_dim.
# NOTE(review): `pos` is expected to be 4-D with a leading batch dim of 1
# (it is later `.expand(B, -1, -1, -1)`) — confirm against sinusoidal_2d.
pos_emb_dim = 16
xs, ys = build_xy_coordinates(28, 28)
pos = sinusoidal_2d(xs, ys, dim=pos_emb_dim)
pos = pos.to(DEVICE)

In [3]:
def train(epoch, model, optimizer, loss, dataloader, train_with_cond=False, T=32,
          seed=0, ckpt_path="./mnist-gen/mnist-gen-global.pth"):
    """Train `model` for one epoch with a straight-path flow-matching loss, then checkpoint it.

    For every batch, noise z ~ N(0, I) and a per-sample time t = sigmoid(N(0, 1))
    are drawn. The model sees img_t = (1 - t) * img + t * z concatenated with the
    module-level positional encoding `pos` along dim 1. It is trained to predict
    img - z. The noise and clean-image estimates derived from that prediction are
    supervised as well.

    Args:
        epoch: epoch index; used in the progress label and to vary the RNG seed.
        model: called as model(x, t, cond=...); its output is reshaped to img's shape.
        optimizer: any object with zero_grad() and step().
        loss: criterion called as loss(pred, target), e.g. nn.MSELoss().
        dataloader: yields (img, cond) batches.
        train_with_cond: if True, labels are passed as `cond`; otherwise cond=None.
        T: unused here; kept for signature symmetry with `test`.
        seed: base seed. The RNG is seeded with seed + epoch, so each epoch is
            reproducible but draws a different shuffle / noise / timestep sequence.
        ckpt_path: where the model state dict is saved; its parent dir is created if missing.

    Side effects: reads the module-level `DEVICE` and `pos`; writes `ckpt_path`.
    The caller's global torch RNG state (CPU and CUDA) is restored on exit.
    """
    model.train()
    ema_loss = 0.0

    # Fork the RNG so this epoch's seeding does not leak into the caller.
    # Re-seeding with a constant every epoch would replay the exact same
    # shuffle order, noise z and timesteps t on every epoch.
    cuda_devices = range(torch.cuda.device_count())
    with torch.random.fork_rng(devices=cuda_devices):
        torch.manual_seed(seed + epoch)

        with tqdm(dataloader, desc=f"Epoch {epoch}", smoothing=0.01, disable=True) as pbar:
            for i, (img, cond) in enumerate(pbar):
                img = img.to(DEVICE)
                cond = cond.to(DEVICE).long()
                z = torch.randn_like(img)  # created directly on img's device

                # Straight paths between data (t = 0) and noise (t = 1):
                #   (training)   img_t = img_0 + (z - img_0) * t,  model(img_t, t) ~ img_0 - z
                #   (generation) img_{t-dt} = img_t + model(img_t, t) * dt
                # t is logit-normal: the sigmoid of a standard normal sample.
                t = torch.sigmoid(torch.randn(img.shape[0])).to(DEVICE)
                t_b = t.view(-1, 1, 1, 1)  # broadcast t over the non-batch dims
                img_t = img * (1 - t_b) + z * t_b
                target = img - z

                pos_exp = pos.expand(img.shape[0], -1, -1, -1)
                img_emb = torch.cat([img_t, pos_exp], dim=1)

                optimizer.zero_grad()
                y = model(img_emb, t, cond=cond if train_with_cond else None).view(img.shape)
                # Noise and clean-image estimates implied by y; they equal z and img
                # exactly when y == img - z, so these terms add consistent supervision.
                eps_pred = img_t - (1 - t_b) * y
                img0_pred = img_t + t_b * y

                l = loss(y, target) + loss(eps_pred, z) + loss(img0_pred, img)
                l.backward()
                optimizer.step()

                # Warm-started EMA of the loss: decay ramps from 0 to 0.99 over the first 100 steps.
                ema_decay = min(0.99, i / 100)
                ema_loss = ema_decay * ema_loss + (1 - ema_decay) * l.item()
                pbar.set_postfix({"loss": ema_loss})

    # Ensure the checkpoint directory exists; torch.save does not create it.
    os.makedirs(os.path.dirname(ckpt_path) or ".", exist_ok=True)
    torch.save(model.state_dict(), ckpt_path)


def test(epoch, model, gen_with_cond=False, T=32, save_dir='./unet/mnist-results_alphablend_test',
         grid_size=16):
    """Sample a grid_size x grid_size grid of images with T Euler steps and save it as a PNG.

    Sampling starts from a fixed (seed 0) Gaussian sample at t = 1. It integrates
    x <- x + model(x, t) * (1 / T) while decreasing t by 1 / T, matching the
    generation rule img_{t-dt} = img_t + model(img_t, t) * dt. Pixel values are
    assumed to lie in [0, 1]; they are scaled by 255 and clipped before saving.

    Args:
        epoch: used only in the output filename `gen-{epoch}.png`.
        model: called as model(x, t, cond=...); its output is reshaped to x's shape.
        gen_with_cond: if True, condition on labels 0..9 cycling over the grid; else cond=None.
        T: number of Euler integration steps.
        save_dir: output directory; created if missing.
        grid_size: images per row/column of the saved grid (default 16 -> 256 samples).

    Side effects: reads the module-level `DEVICE` and `pos`; puts `model` in eval
    mode and leaves it there. The caller's global torch RNG state (CPU and CUDA)
    is restored afterwards, even if sampling raises.
    """
    img_size = 28  # MNIST resolution; must match the positional encoding `pos`
    image_count = grid_size * grid_size
    model.eval()

    # Fixed seed so every epoch starts from the same noise, which makes grids comparable.
    cuda_devices = range(torch.cuda.device_count())
    with torch.random.fork_rng(devices=cuda_devices), torch.no_grad():
        torch.manual_seed(0)
        pred_x = torch.randn(image_count, 1, img_size, img_size).to(DEVICE)
        cond = torch.arange(image_count).long().to(DEVICE) % 10
        t = torch.ones(image_count).to(DEVICE)
        pos_exp = pos.expand(image_count, -1, -1, -1)

        dt = 1 / T
        for _ in range(T):
            pred_emb = torch.cat([pred_x, pos_exp], dim=1)
            pred = model(pred_emb, t, cond=cond if gen_with_cond else None).view(pred_x.shape)
            pred_x = pred_x + pred * dt
            t = t - dt

    # Tile (N, 1, H, W) into one (grid_size*H, grid_size*W) image, row-major over the grid.
    grid = pred_x.reshape(grid_size, grid_size, img_size, img_size).permute(0, 2, 1, 3)
    grid = grid.reshape(grid_size * img_size, grid_size * img_size).cpu().numpy()
    grid = (grid * 255).clip(0, 255).astype(np.uint8)

    os.makedirs(save_dir, exist_ok=True)
    Image.fromarray(grid).save(f"{save_dir}/gen-{epoch}.png")

In [5]:
if __name__ == "__main__":
    # Training driver. DEVICE and pos_emb_dim come from the earlier cells.
    save_dir = './unet/mnist-results_alphablend_test'
    os.makedirs(save_dir, exist_ok=True)
    CLASSES = None  # passed to UNet as num_classes; None -> unconditional training and sampling
    EPOCHS = 15
    T = 32  # Euler steps used by `test` when sampling
    use_cond = bool(CLASSES)

    # ToTensor only (no Normalize): pixels stay in [0, 1], which `test` relies on
    # when it scales samples by 255 for the PNG grid.
    transform = trns.Compose([trns.ToTensor()])
    dataset = MNIST("./data", download=True, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=512,
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        persistent_workers=True,
    )

    # Input = 1 image channel + the positional-encoding channels concatenated to it.
    model = UNet(in_channels=1 + pos_emb_dim * 2, out_channels=1, emb_dim=512, num_classes=CLASSES).to(DEVICE)
    optimizer = AdamWScheduleFree(
        model.parameters(), 1e-3, weight_decay=0.001, warmup_steps=100
    )
    loss = nn.MSELoss()

    # Sample before each epoch (epoch 0 = untrained baseline). The schedule-free
    # optimizer is put in train mode for updates and back in eval mode before sampling.
    for epoch in trange(EPOCHS):
        test(epoch, model, gen_with_cond=use_cond, T=T, save_dir=save_dir)
        optimizer.train()
        train(epoch, model, optimizer, loss, dataloader, train_with_cond=use_cond, T=T)
        optimizer.eval()
    # Final sample after the last epoch; uses EPOCHS so it also works when EPOCHS == 0.
    test(EPOCHS, model, gen_with_cond=use_cond, T=T, save_dir=save_dir)

100%|██████████| 15/15 [03:11<00:00, 12.76s/it]
