In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR
import torch.distributions as D
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
import torchvision

import numpy as np
import matplotlib.pyplot as plt

import seaborn as sns

sns.set_theme(style="dark")

In [2]:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [3]:
def get_default_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    # if torch.backends.mps.is_available():
    # return torch.device("mps")
    return torch.device("cpu")


DEVICE = get_default_device()

In [4]:
from typing import cast

vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(DEVICE)
# vae = cast(AutoencoderKL, torch.compile(vae, mode="max-autotune"))
vae.eval()
if DEVICE == torch.device("cuda"):
    vae = cast(AutoencoderKL, torch.compile(vae, mode="max-autotune"))
pass

In [5]:
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomRotation(10),
    ]
)

In [6]:
class CifarDataset(torch.utils.data.Dataset):
    def __init__(self, train: bool = True):
        self.cifar = torchvision.datasets.CIFAR10(
            root="/tmp/cifar", download=True, train=train, transform=transform
        )

    def __len__(self):
        return len(self.cifar)

    def __getitem__(self, idx: int):
        x, y = self.cifar[idx]
        t = torch.rand(1)
        return x, y, t

In [7]:
cifar_train = CifarDataset(train=True)
cifar_test = CifarDataset(train=False)

train_dataloader = torch.utils.data.DataLoader(
    cifar_train, batch_size=512, shuffle=True
)
test_dataloader = torch.utils.data.DataLoader(cifar_test, batch_size=512, shuffle=True)

latent_shape = (
    vae.encode(next(iter(train_dataloader))[0].to(DEVICE), return_dict=False)[0]
    .mean[0]
    .shape
)

(
    next(iter(train_dataloader))[0].shape,
    next(iter(train_dataloader))[1].shape,
    next(iter(train_dataloader))[2].shape,
    latent_shape,
)

(torch.Size([512, 3, 32, 32]),
 torch.Size([512]),
 torch.Size([512, 1]),
 torch.Size([4, 4, 4]))

In [8]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = torch.exp(
            -torch.arange(half_dim, device=x.device)
            * torch.log(torch.tensor(10000.0))
            / (half_dim - 1)
        )
        emb = x.unsqueeze(1) * emb.unsqueeze(0)
        return torch.cat([emb.sin(), emb.cos()], dim=-1)


class ResidualBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, groups=8):
        super().__init__()
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_ch))
        self.conv_block = nn.Sequential(
            nn.GroupNorm(groups, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(groups, out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
        )
        self.res_conv = (
            nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        )

    def forward(self, x, t_emb):
        h = self.conv_block(x)
        time_proj = self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)
        h = h + time_proj
        return h + self.res_conv(x)


class UNet(nn.Module):
    def __init__(self, in_ch=4, base_ch=64, channel_mults=(1, 2), time_emb_dim=128):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim * 4),
            nn.SiLU(),
            nn.Linear(time_emb_dim * 4, time_emb_dim),
        )
        self.init_conv = nn.Conv2d(in_ch, base_ch, 3, padding=1)
        down_chs = [base_ch * m for m in channel_mults]
        self.downs = nn.ModuleList()
        prev_ch = base_ch
        for ch in down_chs:
            self.downs.append(ResidualBlock(prev_ch, ch, time_emb_dim))
            prev_ch = ch
        self.bottleneck = ResidualBlock(prev_ch, prev_ch, time_emb_dim)
        self.ups = nn.ModuleList()
        for idx, ch in enumerate(reversed(down_chs)):
            # ch is skip channel
            out_ch = down_chs[-2 - idx] if idx < len(down_chs) - 1 else base_ch
            self.ups.append(
                nn.ModuleDict(
                    {
                        "upsample": nn.ConvTranspose2d(prev_ch, out_ch, 4, 2, 1),
                        "res": ResidualBlock(out_ch + ch, out_ch, time_emb_dim),
                    }
                )
            )
            prev_ch = out_ch
        self.final_conv = nn.Sequential(
            nn.GroupNorm(8, base_ch),
            nn.SiLU(),
            nn.Conv2d(base_ch, in_ch, 1),
        )

    def forward(self, x, t):
        # x: (B, C, H, W), t: (B,)
        t_emb = self.time_mlp(t)
        h = self.init_conv(x)
        skips = []
        for block in self.downs:
            h = block(h, t_emb)
            skips.append(h)
            h = F.avg_pool2d(h, 2)
        h = self.bottleneck(h, t_emb)
        for module in self.ups:
            module = cast(nn.ModuleDict, module)
            h = module["upsample"](h)
            skip = skips.pop()
            h = torch.cat([h, skip], dim=1)
            h = module["res"](h, t_emb)
        # final
        return self.final_conv(h)

    def ode_forward(self, t: torch.Tensor, x: torch.Tensor):
        t_tensor = t.repeat(x.shape[0])
        return self.forward(x, t_tensor)

In [9]:
N_EPOCHS = 200


model = UNet()
model = model.to(DEVICE)
optimizer = AdamW(model.parameters(), lr=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=N_EPOCHS)

ckpt_path = f"../ckpt/{model.__class__.__name__}"
if ckpt_path and False:
    model.load_state_dict(torch.load(f"{ckpt_path}/model.pth", map_location=DEVICE))
    optimizer.load_state_dict(
        torch.load(f"{ckpt_path}/optimizer.pth", map_location=DEVICE)
    )
    scheduler.load_state_dict(
        torch.load(f"{ckpt_path}/scheduler.pth", map_location=DEVICE)
    )

if DEVICE == torch.device("cuda"):
    model = cast(UNet, torch.compile(model, mode="reduce-overhead"))

In [10]:
from torchdiffeq import odeint


@torch.no_grad()
def sample_with_ode(
    model: nn.Module,
    n_samples: int = 500,
    n_steps: int = 25,
):
    model.eval()
    model_device = next(model.parameters()).device
    initial_samples = torch.randn((n_samples, *latent_shape), device=model_device)
    t_span = torch.linspace(0.0, 1.0, n_steps).to(model_device)
    trajectory = odeint(
        model.ode_forward,
        initial_samples,
        t_span,
        method="euler",
        atol=1e-5,
        rtol=1e-5,
    )
    trajectory = cast(torch.Tensor, trajectory)

    return trajectory[-1]

In [11]:
from torchmetrics.image.fid import FrechetInceptionDistance
import os


def to_latent(vae: AutoencoderKL, x_raw: torch.Tensor) -> torch.Tensor:
    return vae.encode(x_raw, return_dict=False)[0].mean


def from_latent(vae: AutoencoderKL, z: torch.Tensor) -> torch.Tensor:
    return vae.decode(cast(torch.FloatTensor, z), return_dict=False)[0]

In [12]:
from tqdm.auto import tqdm


def train(
    model: nn.Module,
    dataloader: DataLoader,
    val_dataloader: DataLoader,
    n_epochs: int,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.CosineAnnealingLR,
    verbose: bool = False,
    contrastive_flow_weight: float = 0.00,
):
    def step(x_raw, t):
        x = to_latent(vae, x_raw)
        noise = torch.randn_like(x, device=DEVICE)
        t_expanded = t.unsqueeze(-1).unsqueeze(-1)
        z = t_expanded * x + (1 - t_expanded) * noise
        target_u = x - noise
        u = model(z, t.squeeze(-1))
        loss = F.mse_loss(u, target_u)
        if contrastive_flow_weight > 0.0:
            u_hat = torch.roll(u, shifts=1, dims=0)
            loss_contrastive = F.mse_loss(u, u_hat)
            loss = loss - contrastive_flow_weight * loss_contrastive
        return loss

    @torch.no_grad()
    def get_fid(nfe: int = 25):
        real_images = next(iter(val_dataloader))[0]
        real_images = (real_images * 255).to(torch.uint8).to(DEVICE)  # BS x 3 x 32 x 32
        z_s = sample_with_ode(model, n_samples=real_images.shape[0], n_steps=nfe)
        generated_images = from_latent(vae, z_s)
        generated_images = (
            (generated_images * 255).to(torch.uint8).to(DEVICE)
        )  # BS x 3 x 32 x 32
        fid = FrechetInceptionDistance().to(DEVICE)
        fid.update(real_images, real=True)
        fid.update(generated_images, real=False)
        return fid.compute()

    def ckpt_callback(epoch: int, val_loss: float):
        ckpt_dir = "../ckpt"
        subdir = f"{ckpt_dir}/{model.__class__.__name__}"
        if not os.path.exists(subdir):
            os.makedirs(subdir)
        torch.save(model.state_dict(), f"{subdir}/model.pth")
        torch.save(optimizer.state_dict(), f"{subdir}/optimizer.pth")
        torch.save(scheduler.state_dict(), f"{subdir}/scheduler.pth")
        print(f"Saved checkpoint at epoch {epoch} with val loss {val_loss:.4f}")

    log_interval = 1
    model.train()
    best_val_loss = float("inf")
    loss_history = []
    val_loss_history = []
    fid_25_history = []
    fid_50_history = []
    for epoch in range(n_epochs):
        losses = []
        for x_raw, _, t in tqdm(dataloader):
            loss = step(x_raw.to(DEVICE), t.to(DEVICE))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        scheduler.step()

        val_losses = []
        with torch.no_grad():
            for x_raw, _, t in val_dataloader:
                loss = step(x_raw.to(DEVICE), t.to(DEVICE))
                val_losses.append(loss.item())

        current_loss = float(np.mean(losses))
        current_val_loss = float(np.mean(val_losses))

        fid_25 = get_fid(nfe=25)
        fid_50 = get_fid(nfe=50)
        fid_25_history.append(fid_25)
        fid_50_history.append(fid_50)
        if (epoch % log_interval == 0 or epoch == n_epochs - 1) and verbose:
            print(
                f"Epoch {epoch}\t loss: {current_loss:.4f}\t val loss: {current_val_loss:.4f}\t FID_25: {fid_25:.4f}\t FID_50: {fid_50:.4f}"
            )

        loss_history.append(current_loss)
        val_loss_history.append(current_val_loss)

        if current_val_loss < best_val_loss:
            best_val_loss = current_val_loss
            ckpt_callback(epoch, current_val_loss)

    return loss_history, val_loss_history

In [None]:
loss_history, val_loss_history = train(
    model=model,
    dataloader=train_dataloader,
    val_dataloader=test_dataloader,
    n_epochs=N_EPOCHS,
    optimizer=optimizer,
    scheduler=scheduler,
    verbose=True,
)
plt.plot(loss_history)
plt.plot(val_loss_history)
plt.show()

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch 0	 loss: 22.3724	 val loss: 20.1046	 FID_25: 263.0175	 FID_50: 257.9994
Saved checkpoint at epoch 0 with val loss 20.1046


  0%|          | 0/98 [00:00<?, ?it/s]

Epoch 1	 loss: 18.4770	 val loss: 17.3702	 FID_25: 249.6280	 FID_50: 250.5529
Saved checkpoint at epoch 1 with val loss 17.3702


  0%|          | 0/98 [00:00<?, ?it/s]

Epoch 2	 loss: 16.3204	 val loss: 15.7189	 FID_25: 244.5102	 FID_50: 247.2575
Saved checkpoint at epoch 2 with val loss 15.7189


  0%|          | 0/98 [00:00<?, ?it/s]

Epoch 3	 loss: 14.7098	 val loss: 14.1502	 FID_25: 240.9409	 FID_50: 241.9839
Saved checkpoint at epoch 3 with val loss 14.1502


  0%|          | 0/98 [00:00<?, ?it/s]

Epoch 4	 loss: 13.3711	 val loss: 12.8383	 FID_25: 239.0325	 FID_50: 234.9365
Saved checkpoint at epoch 4 with val loss 12.8383


  0%|          | 0/98 [00:00<?, ?it/s]

Epoch 5	 loss: 12.1774	 val loss: 11.7377	 FID_25: 236.4914	 FID_50: 237.3486
Saved checkpoint at epoch 5 with val loss 11.7377


  0%|          | 0/98 [00:00<?, ?it/s]

Epoch 6	 loss: 11.2177	 val loss: 10.8918	 FID_25: 231.2109	 FID_50: 229.3739
Saved checkpoint at epoch 6 with val loss 10.8918


  0%|          | 0/98 [00:00<?, ?it/s]

Epoch 7	 loss: 10.3140	 val loss: 9.9680	 FID_25: 230.3189	 FID_50: 236.2601
Saved checkpoint at epoch 7 with val loss 9.9680


  0%|          | 0/98 [00:00<?, ?it/s]

Epoch 8	 loss: 9.6016	 val loss: 9.3335	 FID_25: 230.4183	 FID_50: 231.6828
Saved checkpoint at epoch 8 with val loss 9.3335


  0%|          | 0/98 [00:00<?, ?it/s]