In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm


class Diffusion:
    """Forward diffusion process with a linear beta schedule.

    Precomputes, as 1-D tensors of length ``T`` placed on ``device``:

    * ``beta``      -- linearly spaced from 1e-4 to 0.015,
    * ``alpha``     -- ``1 - beta``,
    * ``alpha_hat`` -- cumulative product of ``alpha``.
    """

    def __init__(self, T=500, device="cpu"):
        """Build the schedule for ``T`` timesteps on ``device``."""
        self.T = T
        self.device = device

        self.beta = torch.linspace(1e-4, 0.015, T).to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def q_sample(self, x0, t, noise=None):
        """Sample ``x_t`` from ``x0`` in closed form.

        Computes ``sqrt(alpha_hat[t]) * x0 + sqrt(1 - alpha_hat[t]) * noise``.

        Args:
            x0: Clean batch; the first dimension is the batch dimension.
            t: Integer tensor of timestep indices into the schedule, one per
                batch element.
            noise: Optional noise tensor shaped like ``x0``. When omitted,
                standard normal noise is drawn with ``torch.randn_like``.

        Returns:
            Tuple ``(x_t, noise)`` where ``noise`` is the tensor that was
            mixed into ``x0``.
        """
        if noise is None:
            noise = torch.randn_like(x0)

        # One coefficient per batch element, reshaped to (B, 1, ..., 1) so it
        # broadcasts over however many trailing dimensions x0 has instead of
        # assuming a 4-D image batch.
        coeff_shape = (-1,) + (1,) * (x0.dim() - 1)
        alpha_hat_t = self.alpha_hat[t]
        sqrt_ah = torch.sqrt(alpha_hat_t).reshape(coeff_shape)
        sqrt_1_ah = torch.sqrt(1 - alpha_hat_t).reshape(coeff_shape)
        return sqrt_ah * x0 + sqrt_1_ah * noise, noise



class SinusoidalEmbedding(nn.Module):
    """Map a 1-D tensor of timesteps to sinusoidal features of width ``dim``.

    The output is ``[sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})]`` with
    ``h = dim // 2`` and ``f_k = exp(-k * ln(10000) / h)``.
    """

    def __init__(self, dim):
        """Store the requested embedding width ``dim``."""
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed timesteps.

        Args:
            t: 1-D tensor of shape ``(B,)``.

        Returns:
            Tensor of shape ``(B, dim)``. When ``dim`` is odd the last column
            is zero padding, so the width always matches ``dim``.
        """
        device = t.device
        half = self.dim // 2
        freqs = torch.exp(
            torch.arange(half, device=device) * (-torch.log(torch.tensor(10000.0)) / half)
        )
        angles = t[:, None] * freqs[None, :]
        parts = [torch.sin(angles), torch.cos(angles)]
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * half columns; pad one zero column so an
            # odd dim does not silently produce dim - 1 features.
            parts.append(angles.new_zeros(angles.shape[0], 1))
        return torch.cat(parts, dim=1)



class Block(nn.Module):
    """Two 3x3 convolutions with an additive time-embedding bias in between.

    Pipeline: ``conv1 -> SiLU -> + Linear(t) -> conv2 -> GroupNorm -> SiLU``.
    Both convolutions use ``padding=1``, so spatial size is preserved.
    """

    def __init__(self, in_ch, out_ch, t_dim):
        """Create the layers.

        Args:
            in_ch: Number of input channels.
            out_ch: Number of output channels.
            t_dim: Width of the time embedding passed to ``forward``.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.time = nn.Linear(t_dim, out_ch)

        # GroupNorm requires num_channels % num_groups == 0. Keep 8 groups
        # when possible (and 1 group for fewer than 8 channels, as before),
        # but fall back to a smaller divisor instead of raising for widths
        # such as 12 or 20.
        if out_ch >= 8:
            groups = next(g for g in (8, 4, 2, 1) if out_ch % g == 0)
        else:
            groups = 1
        self.norm = nn.GroupNorm(groups, out_ch)

    def forward(self, x, t):
        """Apply the block.

        Args:
            x: Feature map of shape ``(B, in_ch, H, W)``.
            t: Time embedding of shape ``(B, t_dim)``.

        Returns:
            Tensor of shape ``(B, out_ch, H, W)``.
        """
        h = F.silu(self.conv1(x))
        # Broadcast the per-channel time bias over H and W. Out-of-place add
        # so the activation tensor tracked by autograd is not mutated.
        h = h + self.time(t)[:, :, None, None]
        h = F.silu(self.norm(self.conv2(h)))
        return h


class UNet(nn.Module):
    """Small time-conditioned U-Net for single-channel images.

    Encoder: three ``Block`` stages (64, 128, 256 channels) separated by 2x
    average pooling. Decoder: two ``Block`` stages that upsample and
    concatenate the matching encoder feature map. A 1x1 convolution maps back
    to one output channel.
    """

    def __init__(self):
        """Build the time MLP, encoder/decoder blocks and output projection."""
        super().__init__()
        t_dim = 256
        self.time_mlp = nn.Sequential(
            SinusoidalEmbedding(128),
            nn.Linear(128, t_dim),
            nn.SiLU()
        )

        self.down1 = Block(1, 64, t_dim)
        self.down2 = Block(64, 128, t_dim)
        self.down3 = Block(128, 256, t_dim)

        # Decoder inputs are the upsampled features concatenated with the skip.
        self.up1 = Block(256 + 128, 128, t_dim)
        self.up2 = Block(128 + 64, 64, t_dim)

        self.final = nn.Conv2d(64, 1, 1)

    def forward(self, x, t):
        """Run the network.

        Args:
            x: Input batch of shape ``(B, 1, H, W)``.
            t: 1-D tensor of timesteps, shape ``(B,)``.

        Returns:
            Tensor of shape ``(B, 1, H, W)``.
        """
        t = self.time_mlp(t)

        x1 = self.down1(x, t)
        x2 = self.down2(F.avg_pool2d(x1, 2), t)
        x3 = self.down3(F.avg_pool2d(x2, 2), t)

        # Upsample to the skip tensor's exact spatial size rather than by a
        # fixed factor of 2: avg_pool2d floors odd sizes, so a fixed factor
        # would make torch.cat fail when H or W is not divisible by 4. For
        # sizes that are divisible (e.g. 28x28) the result is the same.
        u1 = F.interpolate(x3, size=x2.shape[-2:])
        u1 = self.up1(torch.cat([u1, x2], dim=1), t)

        u2 = F.interpolate(u1, size=x1.shape[-2:])
        u2 = self.up2(torch.cat([u2, x1], dim=1), t)

        return self.final(u2)



class EMA:
    """Exponential moving average of a model's ``state_dict``.

    ``shadow`` maps each state-dict key to the averaged tensor. Call
    ``update`` after every optimizer step and ``apply`` to load the averaged
    values into the model.
    """

    def __init__(self, model, decay=0.999):
        """Snapshot the current ``state_dict`` of ``model`` as the initial average.

        Args:
            model: Module to track.
            decay: Weight kept from the previous average on each update.
        """
        self.decay = decay
        self.model = model
        self.shadow = {k: v.clone() for k, v in model.state_dict().items()}

    def update(self):
        """Blend the model's current state into the shadow copy.

        Floating-point entries are updated in place as
        ``shadow = decay * shadow + (1 - decay) * value``. Non-floating
        entries (integer counters and similar buffers) are copied verbatim,
        because averaging them would yield fractional values that get
        truncated when loaded back into the model.
        """
        with torch.no_grad():
            for k, v in self.model.state_dict().items():
                avg = self.shadow[k]
                if avg.is_floating_point():
                    avg.mul_(self.decay).add_(v, alpha=1 - self.decay)
                else:
                    avg.copy_(v)

    def apply(self):
        """Load the averaged values into the model, overwriting its state."""
        self.model.load_state_dict(self.shadow)



def _train(model, diffusion, ema, optimizer, loader, device, epochs):
    """Train ``model`` to predict the noise added by ``diffusion.q_sample``.

    After every optimizer step the EMA is updated. At the end of each epoch
    the sample-weighted mean loss over that epoch is printed.
    """
    model.train()
    for epoch in range(epochs):
        # Accumulate on-device so the loss is only synchronised to the host
        # once per epoch.
        loss_sum = torch.zeros((), device=device)
        seen = 0
        for images, _ in tqdm(loader, desc=f"Epoch {epoch}"):
            images = images.to(device)
            batch = images.size(0)
            t = torch.randint(0, diffusion.T, (batch,), device=device)
            xt, noise = diffusion.q_sample(images, t)

            loss = F.mse_loss(model(xt, t), noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ema.update()

            loss_sum += loss.detach() * batch
            seen += batch

        # Report the epoch mean rather than the last minibatch's loss, which
        # is noisy; max() guards against an empty loader.
        mean_loss = loss_sum.item() / max(seen, 1)
        print(f"Epoch {epoch} | Loss {mean_loss:.4f}")


def _sample(model, diffusion, n, device):
    """Generate ``n`` 1x28x28 samples by running the reverse process.

    Starts from standard normal noise and, for each step ``i`` from
    ``T - 1`` down to 0, applies
    ``x = (x - (1 - alpha_i) / sqrt(1 - alpha_hat_i) * eps) / sqrt(alpha_i)
    + sqrt(beta_i) * z``, where ``eps`` is the model output and ``z`` is
    fresh normal noise (zero at the final step).
    """
    with torch.no_grad():
        x = torch.randn(n, 1, 28, 28, device=device)
        for i in reversed(range(diffusion.T)):
            t = torch.full((n,), i, device=device, dtype=torch.long)
            eps = model(x, t)

            alpha = diffusion.alpha[i]
            alpha_hat = diffusion.alpha_hat[i]
            beta = diffusion.beta[i]

            # No noise is added on the last step (i == 0).
            z = torch.randn_like(x) if i > 0 else 0
            x = (1 / torch.sqrt(alpha)) * (
                x - (1 - alpha) / torch.sqrt(1 - alpha_hat) * eps
            ) + torch.sqrt(beta) * z
    return x


def main():
    """Train the diffusion model on MNIST, then sample and save 16 images.

    Uses CUDA when available. Trains for 30 epochs with Adam (lr 1e-4) and
    batch size 128, loads the EMA weights into the model, generates 16
    samples and writes them to ``clean_mnist_ddpm.png``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Device:", device)

    diffusion = Diffusion(T=500, device=device)
    model = UNet().to(device)
    ema = EMA(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Scale pixels from [0, 1] to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    loader = DataLoader(
        datasets.MNIST("data", train=True, download=True, transform=transform),
        batch_size=128,
        shuffle=True
    )

    _train(model, diffusion, ema, optimizer, loader, device, epochs=30)

    # Sample with the averaged weights.
    ema.apply()
    model.eval()

    x = _sample(model, diffusion, 16, device)

    # Map from [-1, 1] back to [0, 1] before writing the image grid.
    save_image((x + 1) / 2, "clean_mnist_ddpm.png")
    print("Saved clean_mnist_ddpm.png")

# Script entry point: run training and sampling only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()


Device: cuda


100%|██████████| 9.91M/9.91M [00:00<00:00, 17.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 453kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.24MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.82MB/s]
Epoch 0: 100%|██████████| 469/469 [00:27<00:00, 17.03it/s]


Epoch 0 | Loss 0.0568


Epoch 1: 100%|██████████| 469/469 [00:26<00:00, 17.38it/s]


Epoch 1 | Loss 0.0567


Epoch 2: 100%|██████████| 469/469 [00:26<00:00, 17.38it/s]


Epoch 2 | Loss 0.0501


Epoch 3: 100%|██████████| 469/469 [00:27<00:00, 17.10it/s]


Epoch 3 | Loss 0.0492


Epoch 4: 100%|██████████| 469/469 [00:27<00:00, 17.08it/s]


Epoch 4 | Loss 0.0529


Epoch 5: 100%|██████████| 469/469 [00:27<00:00, 17.18it/s]


Epoch 5 | Loss 0.0488


Epoch 6: 100%|██████████| 469/469 [00:27<00:00, 17.07it/s]


Epoch 6 | Loss 0.0477


Epoch 7: 100%|██████████| 469/469 [00:27<00:00, 17.13it/s]


Epoch 7 | Loss 0.0437


Epoch 8: 100%|██████████| 469/469 [00:27<00:00, 17.13it/s]


Epoch 8 | Loss 0.0385


Epoch 9: 100%|██████████| 469/469 [00:27<00:00, 17.11it/s]


Epoch 9 | Loss 0.0449


Epoch 10: 100%|██████████| 469/469 [00:27<00:00, 17.10it/s]


Epoch 10 | Loss 0.0445


Epoch 11: 100%|██████████| 469/469 [00:27<00:00, 17.13it/s]


Epoch 11 | Loss 0.0414


Epoch 12: 100%|██████████| 469/469 [00:27<00:00, 17.10it/s]


Epoch 12 | Loss 0.0401


Epoch 13: 100%|██████████| 469/469 [00:27<00:00, 17.15it/s]


Epoch 13 | Loss 0.0415


Epoch 14: 100%|██████████| 469/469 [00:27<00:00, 17.10it/s]


Epoch 14 | Loss 0.0379


Epoch 15: 100%|██████████| 469/469 [00:27<00:00, 17.12it/s]


Epoch 15 | Loss 0.0364


Epoch 16: 100%|██████████| 469/469 [00:27<00:00, 17.12it/s]


Epoch 16 | Loss 0.0386


Epoch 17: 100%|██████████| 469/469 [00:27<00:00, 17.09it/s]


Epoch 17 | Loss 0.0368


Epoch 18: 100%|██████████| 469/469 [00:27<00:00, 17.06it/s]


Epoch 18 | Loss 0.0328


Epoch 19: 100%|██████████| 469/469 [00:27<00:00, 17.10it/s]


Epoch 19 | Loss 0.0454


Epoch 20: 100%|██████████| 469/469 [00:27<00:00, 17.11it/s]


Epoch 20 | Loss 0.0383


Epoch 21: 100%|██████████| 469/469 [00:27<00:00, 17.13it/s]


Epoch 21 | Loss 0.0422


Epoch 22: 100%|██████████| 469/469 [00:27<00:00, 17.12it/s]


Epoch 22 | Loss 0.0347


Epoch 23: 100%|██████████| 469/469 [00:27<00:00, 17.12it/s]


Epoch 23 | Loss 0.0387


Epoch 24: 100%|██████████| 469/469 [00:27<00:00, 17.11it/s]


Epoch 24 | Loss 0.0323


Epoch 25: 100%|██████████| 469/469 [00:27<00:00, 17.10it/s]


Epoch 25 | Loss 0.0407


Epoch 26: 100%|██████████| 469/469 [00:27<00:00, 17.12it/s]


Epoch 26 | Loss 0.0361


Epoch 27: 100%|██████████| 469/469 [00:27<00:00, 17.11it/s]


Epoch 27 | Loss 0.0323


Epoch 28: 100%|██████████| 469/469 [00:27<00:00, 17.14it/s]


Epoch 28 | Loss 0.0376


Epoch 29: 100%|██████████| 469/469 [00:27<00:00, 17.10it/s]


Epoch 29 | Loss 0.0376
Saved clean_mnist_ddpm.png
