# Toy diffusion model generating data from a complex 1-d distribution

## Generate data

In [None]:
import numpy as np

generator = np.random.default_rng(seed=12345)
# Build a four-component 1-D mixture. The draws are made left to right, in this fixed order,
# so the seeded stream produces the same values every run.
x = np.concatenate([
    generator.uniform(low=0.0, high=1.0, size=10000),  # flat block on [0, 1)
    generator.normal(0.2, 1, size=10000),               # gaussian centred at 0.2, std 1
    generator.normal(5, 5, size=20000),                 # wide gaussian centred at 5, std 5
    generator.normal(15, 1, size=10000),                # narrow gaussian centred at 15, std 1
])
# Standardize to zero mean and unit standard deviation
x = (x - x.mean()) / x.std()

In [None]:
import seaborn as sns

# Histogram (100 bins) of the standardized mixture with a KDE overlay
sns.displot(data=x, bins=100, kde=True)

## Prepare data for learning

In [None]:
from einops import rearrange
import torch

# float32 column tensor: shape (N,) -> (N, 1), i.e. one feature per sample
X = torch.tensor(x, dtype=torch.float32).unsqueeze(1)

## Noising functions

In [None]:
diffusion_steps = 1000

# Cosine schedule: baralpha(t) = f(t) / f(0), where
# f(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2) and s is a small offset.
s = 0.008
timesteps = torch.arange(diffusion_steps, dtype=torch.float32)
schedule = torch.cos((timesteps / diffusion_steps + s) / (1 + s) * torch.pi / 2) ** 2

baralphas = schedule / schedule[0]
# betas[t] = 1 - baralphas[t] / baralphas[t-1]. Index 0 is divided by itself, so betas[0] == 0.
betas = 1 - baralphas / torch.cat((baralphas[:1], baralphas[:-1]))
alphas = 1 - betas
sigmas = betas.sqrt()

sns.lineplot(baralphas)

In [None]:
def noise(Xbatch, t):
    """Apply the closed-form forward diffusion q(x_t | x_0) to a batch.

    Computes noised = sqrt(baralphas[t]) * Xbatch + sqrt(1 - baralphas[t]) * eps,
    with eps ~ N(0, 1) of the same shape as Xbatch.

    Args:
        Xbatch: clean samples; the first dimension is the batch dimension.
        t: diffusion step(s). Accepts a Python int, a list of ints, or an integer tensor
            of any shape holding either one step per sample or a single step for the
            whole batch.

    Returns:
        (noised, eps): two tensors with Xbatch's shape, device and dtype.
    """
    steps = torch.as_tensor(t, dtype=torch.long, device=baralphas.device).flatten()
    # Shape the per-sample coefficients as (batch, 1, ..., 1) so they broadcast over
    # all non-batch dimensions of Xbatch.
    coef_shape = (-1,) + (1,) * (Xbatch.dim() - 1)
    baralpha_t = baralphas[steps].reshape(coef_shape).to(device=Xbatch.device, dtype=Xbatch.dtype)
    eps = torch.randn_like(Xbatch)
    noised = torch.sqrt(baralpha_t) * Xbatch + torch.sqrt(1 - baralpha_t) * eps
    return noised, eps

In [None]:
noiselevel = 999

# Sanity check of the closed-form forward process: push every sample to the same step,
# then invert the formula using the very eps that was drawn.
noised, eps = noise(X, [noiselevel] * len(X))
sns.displot(X, kde=True, bins=100)       # clean data
sns.displot(noised, kde=True, bins=100)  # data after noising to `noiselevel`
abar = baralphas[noiselevel]
denoised = 1 / torch.sqrt(abar) * (noised - torch.sqrt(1 - abar) * eps)
sns.displot(denoised, kde=True, bins=100)  # reconstruction, expected to match X

In [None]:
# Round-trip check: inverting the forward process with the true eps should recover X up to
# float32 rounding, which 1 / sqrt(baralphas[noiselevel]) amplifies. Show the worst-case
# absolute error rather than dumping the whole tensor.
(X - denoised).abs().max()

## Diffusion network

In [None]:
import torch.nn as nn

# MLP taking a 2-feature input per sample (the training loop feeds [noised value, timestep])
# and producing one output: the predicted noise.
model = nn.Sequential(
    nn.Linear(2, 64),
    nn.ReLU(),
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Linear(256, 1)
)

# Fall back to the CPU so the notebook also runs on machines without a CUDA GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

model

## Denoising model training

In [None]:
from einops import rearrange
import torch.optim as optim
from tqdm import tqdm

nepochs = 100
batch_size = 256

loss_fn = nn.MSELoss()
# Make sure the model is on the training device; another cell may have moved it
# (for example to the CPU for sampling).
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(nepochs):
    # X is stored component by component (the uniform block, then each gaussian). Without a
    # per-epoch shuffle, every mini-batch would come from a single mixture component.
    Xshuffled = X[torch.randperm(len(X))]
    epoch_loss = 0.0
    nbatches = 0
    for i in range(0, len(Xshuffled), batch_size):
        Xbatch = Xshuffled[i:i+batch_size]
        # Step 0 is skipped: baralphas[0] == 1, so the noised input equals Xbatch and eps has
        # zero weight (it cannot be predicted). The sampler also never queries t = 0.
        # NOTE: the raw step index is fed to the model unnormalised; sampling must match this.
        tbatch = torch.randint(1, diffusion_steps, size=[len(Xbatch), 1])
        noised, eps = noise(Xbatch, tbatch)
        model_input = torch.hstack([noised, tbatch.to(noised.dtype)]).to(device)
        predicted_noise = model(model_input)
        loss = loss_fn(predicted_noise, eps.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() detaches the value; summing the loss tensors would keep every batch's autograd graph alive
        epoch_loss += loss.item()
        nbatches += 1
    # Mean MSE over the epoch's mini-batches
    print(f"Epoch {epoch} loss = {epoch_loss / nbatches}")

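Best final-epoch training loss so far (complex mixture):
* Without $t$ as an input: 0.003909660503268242
* With $t$ as a second input (epoch 99): 0.0016174211632460356

Caveat: these figures were presumably printed by a training loop that divided the summed batch losses by the start index of the last batch (≈ `len(X)`), not by the number of batches. If so, they are roughly `batch_size` times smaller than the mean per-batch MSE.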

## Test sampler

In [None]:
def sample(model, nsamples):
    """Draw `nsamples` values from the trained noise-prediction model with DDPM ancestral sampling.

    Starts from standard gaussian noise and, for t = diffusion_steps - 1 down to 1, computes
    x_{t-1} from x_t and the model's noise prediction. The final iterate is x_0.

    The model runs on whatever device its parameters already occupy. `nn.Module.to` moves a
    module in place, so the caller's model is deliberately not moved here; training can continue
    afterwards on its original device.

    Args:
        model: maps a (nsamples, 2) batch of [x_t, t] to a (nsamples, 1) noise prediction.
        nsamples: number of samples to generate.

    Returns:
        A (nsamples, 1) float tensor on the CPU.
    """
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        x = torch.randn(size=(nsamples, 1), device=dev)
        for t in range(diffusion_steps - 1, 0, -1):
            t_column = torch.full((nsamples, 1), float(t), device=dev)
            predicted_noise = model(torch.hstack([x, t_column]))
            abar_t = baralphas[t].item()
            abar_prev = baralphas[t - 1].item()
            beta_t = betas[t].item()
            alpha_t = alphas[t].item()
            # DDPM Eq. 15: estimate of the clean sample x_0 from x_t and the predicted noise
            x0 = (x - (1 - abar_t) ** 0.5 * predicted_noise) / abar_t ** 0.5
            # DDPM Eq. 7: mean of the posterior q(x_{t-1} | x_t, x_0)
            c0 = abar_prev ** 0.5 * beta_t / (1 - abar_t)
            ct = alpha_t ** 0.5 * (1 - abar_prev) / (1 - abar_t)
            x = c0 * x0 + ct * x
            if t > 1:
                # Eq. 7 gives the posterior *variance* beta_tilde_t. A gaussian draw needs the
                # standard deviation, so scaling N(0, 1) by sqrt(variance) is correct.
                # The floor guards against a zero variance.
                variance = max((1 - abar_prev) / (1 - abar_t) * beta_t, 1e-20)
                x = x + variance ** 0.5 * torch.randn(size=(nsamples, 1), device=dev)
        return x.cpu()

In [None]:
# Generate fresh samples from the trained model and compare their histogram with the data's
Xgen = sample(model, nsamples=10000)
sns.displot(data=Xgen, bins=100, kde=True)