# Toy diffusion model generating data from a complex 1-d distribution

## Generate data

In [None]:
import numpy as np

# Seeded generator so the synthetic data set is reproducible
generator = np.random.default_rng(seed=12345)
# Mixture of one uniform and three Gaussian components. The list is
# evaluated in order, so the draws happen in the order they are written.
x = np.concatenate([
    # Uniform numbers over the range [0, 1]
    generator.uniform(low=0.0, high=1.0, size=10000),
    # Gaussian centered at 0.2
    generator.normal(0.2, 1, size=10000),
    # High-variance Gaussian centered at 5
    generator.normal(5, 5, size=20000),
    # Low-variance Gaussian centered at 15
    generator.normal(15, 1, size=10000),
])
# Standardize: subtract the mean and divide by the standard deviation
x = (x - x.mean()) / x.std()

In [None]:
import seaborn as sns

# Histogram (100 bins) with a KDE overlay of the generated samples
sns.displot(x, kde=True, bins=100)

## Prepare data for learning

In [None]:
from einops import rearrange
import torch

# Convert the samples to a float32 tensor and append an axis of size 1
X = rearrange(torch.tensor(x, dtype=torch.float32), "x -> x 1")

## Noising functions

In [None]:
# Number of steps in the noising process
diffusion_steps = 40

# Cosine schedule; `s` is a small offset added to the normalized step index
s = 0.008
timesteps = torch.arange(diffusion_steps, dtype=torch.float32)
schedule = torch.cos((timesteps / diffusion_steps + s) / (1 + s) * torch.pi / 2) ** 2

# Rescale so that the first entry is exactly 1
baralphas = schedule / schedule[0]
# One minus the ratio of each entry to its predecessor; the first entry is
# divided by itself, which makes betas[0] zero
betas = 1 - baralphas / torch.cat([baralphas[:1], baralphas[:-1]])
alphas = 1 - betas
sigmas = betas.sqrt()

sns.lineplot(baralphas)

In [None]:
def noise(Xbatch, t):
    """Noise a batch of samples at the given diffusion steps.

    Arguments:
        Xbatch: batch of samples; one noise value is drawn per row
            (the noise has shape ``[len(Xbatch), 1]``).
        t: indices into ``baralphas``, one per row. A tensor is flattened
            first; anything else (e.g. a list of ints) is used as given.

    Returns:
        Tuple ``(noised, eps)`` with the noised batch and the standard
        normal noise that was mixed into it.
    """
    steps = t.flatten() if torch.is_tensor(t) else t
    eps = torch.randn(size=(len(Xbatch), 1))
    level = baralphas[steps]
    # Turn the per-row factors into columns so they broadcast over Xbatch
    signal_scale = rearrange(torch.sqrt(level), "x -> x 1")
    noise_scale = rearrange(torch.sqrt(1 - level), "x -> x 1")
    return signal_scale * Xbatch + noise_scale * eps, eps

In [None]:
# Noise every sample at the same step and plot the result next to the data
noiselevel = 20

noised, eps = noise(X, [noiselevel] * len(X))
sns.displot(X, kde=True, bins=100)
sns.displot(noised, kde=True, bins=100)
# Invert the noising formula using the eps that was actually added
denoised = 1 / torch.sqrt(baralphas[noiselevel]) * (noised - torch.sqrt(1 - baralphas[noiselevel]) * eps)
sns.displot(denoised, kde=True, bins=100)

In [None]:
# Difference between the data and its reconstruction from (noised, eps)
X - denoised

## Diffusion network

In [None]:
import torch.nn as nn

class PolynomialExpansionLayer(nn.Module):
    """Custom layer that appends element-wise powers of the input up to a given degree.

    The output is ``[x, x**2, ..., x**degree]`` joined with ``torch.hstack``.
    Only powers of each individual entry are produced; no products between
    different entries are added.
    """

    def __init__(self, degree):
        """
        Arguments:
            degree: highest power appended to the input
        """
        super().__init__()
        self.degree = degree

    def forward(self, x):
        """
        Arguments:
            x: Tensor to expand

        Returns:
            A copy of ``x`` followed by ``x ** 2`` ... ``x ** degree``; just
            the copy when ``degree`` is below 2.
        """
        expanded = [x.clone()]
        expanded.extend(x ** exp for exp in range(2, self.degree + 1))
        if len(expanded) == 1:
            # Nothing to append: return the copy untouched
            return expanded[0]
        # Stack once instead of re-stacking the growing result for every power
        return torch.hstack(expanded)

In [None]:
# Try the layer on a small integer tensor.
# NOTE: this rebinds the module-level name `x` that held the raw samples.
x = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(x)

PolynomialExpansionLayer(3)(x)

In [None]:
import math

class PositionalEncoding(nn.Module):
    """Custom layer that adds features in a fashion similar to a transformer positional encoding

    A ``[max_t, d]`` table is precomputed with sines in the even columns and
    cosines in the odd columns. In ``forward`` one input column is cast to
    integers and used to pick a table row for every input row.
    """

    def __init__(self, t_index: int, max_t: int, d: int = 256):
        """
        Arguments:
            t_index: column of the input that holds the position
            max_t: number of rows of the table; positions index into it
            d: number of encoding features appended to each input row
        """
        super().__init__()
        self.t_index = t_index

        position = torch.arange(max_t).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d, 2) * (-math.log(10000.0) / d))
        pe = torch.zeros(max_t, d)
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d has one cosine column fewer than sine columns, so trim
        # div_term to match; for even d the slice keeps every entry.
        pe[:, 1::2] = torch.cos(position * div_term[: d // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, n_features]``; column
                ``t_index`` is converted with ``.long()`` and used as the
                row index into the table

        Returns:
            Tensor, shape ``[batch_size, n_features + d]``: the input with
            the selected table rows appended
        """
        return torch.hstack([x, self.pe[x[:, self.t_index].long()]])

In [None]:
# Try the layer on a small integer tensor; t_index=1 means the second
# column (2, 4, 6) selects the encoding rows.
x = torch.tensor([[1, 2], [3, 4], [200, 6]])
print(x)

PositionalEncoding(1, 1000, 32)(x)

In [None]:
import torch.nn as nn

# Plain MLP baseline: 2 inputs -> 256 -> 512 -> 1024 -> 1 output, with
# LayerNorm and ReLU after each hidden linear layer
model = nn.Sequential(
    nn.Linear(2, 256),
    nn.LayerNorm(256),
    nn.ReLU(),
    nn.Linear(256, 512),
    nn.LayerNorm(512),
    nn.ReLU(),
    nn.Linear(512, 1024),
    nn.LayerNorm(1024),
    nn.ReLU(),
    nn.Linear(1024, 1)
)

# Use the GPU when one is available and fall back to the CPU otherwise, so
# the notebook also runs on machines without CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

model

In [None]:
# Old implementation
class DiffusionModel(nn.Module):
    """Three 1024-unit hidden layers with ``t`` appended before every linear layer."""

    def __init__(self):
        super().__init__()

        # After the first layer each linear input is 1024 activations plus t
        self.linear1 = nn.Linear(2, 1024)
        self.norm1 = nn.LayerNorm(1024)
        self.linear2 = nn.Linear(1025, 1024)
        self.norm2 = nn.LayerNorm(1024)
        self.linear3 = nn.Linear(1025, 1024)
        self.norm3 = nn.LayerNorm(1024)
        self.linear4 = nn.Linear(1025, 1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor stacked horizontally with ``t`` to form the 2 input features
            t: Tensor appended again before every later linear layer

        Returns:
            Output of the last linear layer (one feature per row)
        """
        hidden_stages = (
            (self.linear1, self.norm1),
            (self.linear2, self.norm2),
            (self.linear3, self.norm3),
        )
        val = x
        for linear, norm in hidden_stages:
            val = torch.hstack([val, t])  # (Re)inject t
            val = nn.functional.relu(norm(linear(val)))
        # Inject t one last time before the output projection
        return self.linear4(torch.hstack([val, t]))

In [None]:
import torch.nn as nn

class DiffusionBlock(nn.Module):
    """Residual block of two linear layers that receives ``t`` as an extra input column."""

    def __init__(self, nunits):
        """
        Arguments:
            nunits: number of features of the block's input and output
        """
        super().__init__()
        width = nunits + 1  # Features plus the appended t column
        self.linear1 = nn.Linear(width, width)
        self.norm1 = nn.LayerNorm(width)
        self.linear2 = nn.Linear(width, nunits)
        self.norm2 = nn.LayerNorm(nunits)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        Arguments:
            x: Tensor with ``nunits`` features per row
            t: Tensor stacked horizontally onto ``x`` as one extra column

        Returns:
            The transformed features added to ``x`` (skip connection)
        """
        hidden = torch.hstack([x, t])
        hidden = nn.functional.relu(self.norm1(self.linear1(hidden)))
        hidden = nn.functional.relu(self.norm2(self.linear2(hidden)))
        return hidden + x
        
    
class DiffusionModel(nn.Module):
    """Input projection, a stack of ``DiffusionBlock``s, and an output projection."""

    def __init__(self, nblocks: int = 2, nunits: int = 64):
        """
        Arguments:
            nblocks: number of ``DiffusionBlock``s between input and output
            nunits: number of features carried between the blocks
        """
        super().__init__()

        self.inblock = nn.Linear(2, nunits)
        blocks = [DiffusionBlock(nunits) for _ in range(nblocks)]
        self.midblocks = nn.ModuleList(blocks)
        self.outblock = nn.Linear(nunits, 1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor stacked horizontally with ``t`` to form the 2 input features
            t: Tensor also handed to every mid block

        Returns:
            Output of the final linear layer (one feature per row)
        """
        val = self.inblock(torch.hstack([x, t]))  # Add t to inputs
        for midblock in self.midblocks:
            val = midblock(val, t)
        return self.outblock(val)

# Build the residual model with five mid blocks
model = DiffusionModel(nblocks=5)

# Use the GPU when one is available and fall back to the CPU otherwise, so
# the notebook also runs on machines without CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

model

## Denoising model training

In [None]:
from einops import rearrange
import torch.optim as optim
from tqdm import tqdm

nepochs = 100
batch_size = 2048

loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Linear decay of the learning rate from 1.0x to 0.01x over `nepochs` steps
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=nepochs)

for epoch in range(nepochs):
    epoch_loss = steps = 0
    # X was built by concatenating the mixture components one after the
    # other, so visit it in a fresh random order every epoch; with plain
    # slicing each batch would come from a single region of the data.
    order = torch.randperm(len(X))
    for i in range(0, len(X), batch_size):
        Xbatch = X[order[i:i+batch_size]]
        # Named `tbatch` so the module-level `timesteps` of the noise
        # schedule is not overwritten by the loop
        tbatch = torch.randint(0, diffusion_steps, size=[len(Xbatch), 1])
        noised, eps = noise(Xbatch, tbatch)
        predicted_noise = model(noised.to(device), tbatch.to(device))
        loss = loss_fn(predicted_noise, eps.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate a plain float: summing the loss tensors themselves
        # keeps every batch attached to the autograd graph and makes the
        # report below print a tensor instead of a number
        epoch_loss += loss.item()
        steps += 1
    # Advance the learning-rate schedule once per epoch; without this call
    # the scheduler never changes the learning rate
    scheduler.step()
    print(f"Epoch {epoch} loss = {epoch_loss / steps}")

Best model (complex problem), measured before the loss-reporting fix below (not comparable with the later numbers):
* Without t: 0.003909660503268242
* With t as a second input: Epoch 99 loss = 0.0016174211632460356
* Custom network with 4 linear units (256 -> 512 -> 1024 -> 1) with ReLU, LayerNorm, and t reinjection, batchsize=1024: Epoch 99 loss = 0.0004879182088188827

After fixing issues with loss reporting:
* Custom network with 4 linear units (256 -> 512 -> 1024 -> 1) with ReLU, LayerNorm, and t reinjection, batchsize=1024: Epoch 99 loss = 0.5827606320381165
* Custom network with 4 linear units (1024 -> 1024 -> 1024 -> 1) with ReLU, LayerNorm, and t reinjection, batchsize=2048: Epoch 999 loss = 0.530308723449707

## Test sampler

In [None]:
@torch.no_grad()
def sample(model, nsamples):
    """Generate samples by starting from Gaussian noise and denoising step by step.

    Arguments:
        model: callable ``model(x, t)`` returning the predicted noise for
            the current samples ``x`` and a column ``t`` of step values
        nsamples: number of samples to generate

    Returns:
        Tensor, shape ``[nsamples, 1]``, located on ``device``
    """
    x = torch.randn(size=(nsamples, 1)).to(device)
    # Walk from the last step down to step 1 (inclusive)
    for t in reversed(range(1, diffusion_steps)):
        abar_t = baralphas[t]
        abar_prev = baralphas[t-1]
        step_column = torch.ones(size=[nsamples, 1]).to(device) * t
        predicted_noise = model(x, step_column)
        # Predict original sample using DDPM Eq. 15
        x0 = (x - (1 - abar_t) ** (0.5) * predicted_noise) / abar_t ** (0.5)
        # Predict previous samples using DDPM Eq. 7
        c0 = (abar_prev ** (0.5) * betas[t]) / (1 - abar_t)
        ct = alphas[t] ** (0.5) * (1 - abar_prev) / (1 - abar_t)
        x = c0 * x0 + ct * x
        if t <= 1:
            # No noise is added on the final step
            continue
        variance = (1 - abar_prev) / (1 - abar_t) * betas[t]
        # Keep the variance strictly positive before taking the root
        std = torch.clamp(variance, min=1e-20) ** (0.5)
        x += std * torch.randn(size=(nsamples, 1)).to(device)
    return x

In [None]:
# Draw 10000 samples and plot their density estimate next to the data's
Xgen = sample(model, 10000).cpu()
sns.displot(X, kind="kde")
sns.displot(Xgen, kind="kde")