# Toy diffusion model that generates samples from a complex 2-D distribution

## General imports

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

## Generate data

In [None]:
from sklearn.datasets import make_swiss_roll

x, _ = make_swiss_roll(n_samples=100000, noise=0.5)
# Keep only two coordinates (columns 0 and 2) to make the data two-dimensional and ease visualization
x = x[:, [0, 2]]

# Center each coordinate separately: the forward diffusion process drives the data towards a
# zero-mean standard normal, and a single scalar mean over both columns would leave each axis off-center.
# Scale both axes by one shared std so the spiral keeps its aspect ratio.
x = x - x.mean(axis=0)
x = x / x.std()

plt.scatter(x[:, 0], x[:, 1])


## Prepare data for learning

In [None]:
from einops import rearrange
import torch

X = torch.from_numpy(x).to(torch.float32)  # float32 copy of the normalized numpy data, used for training

## Noising functions

In [None]:
diffusion_steps = 40

# Cosine schedule (Nichol & Dhariwal, "Improved DDPM"); s is the small offset that keeps
# beta_t from being too small near t = 0.
s = 0.008
timesteps = torch.arange(diffusion_steps, dtype=torch.float32)
phase = (timesteps / diffusion_steps + s) / (1 + s) * torch.pi / 2
schedule = torch.cos(phase) ** 2

# Normalizing by the first entry makes baralphas[0] == 1, so index 0 is the clean data.
baralphas = schedule / schedule[0]
# abar_{t-1} for every t, repeating abar_0 at t = 0 so that betas[0] == 0
previous_baralphas = torch.cat([baralphas[:1], baralphas[:-1]])
betas = 1 - baralphas / previous_baralphas
alphas = 1 - betas

sns.lineplot(baralphas)

In [None]:
from einops import rearrange

def noise(Xbatch, t):
    """Sample x_t ~ q(x_t | x_0) in closed form.

    x_t = sqrt(baralphas[t]) * x_0 + sqrt(1 - baralphas[t]) * eps,  eps ~ N(0, I)

    Args:
        Xbatch: clean samples x_0, one row per sample.
        t: integer diffusion step per row, as an (N, 1) or flat (N,) index tensor;
            a single scalar step is also accepted and applied to every row.

    Returns:
        (noised, eps): the noised batch and the Gaussian noise that was added,
        both shaped like Xbatch.
    """
    # randn_like keeps eps on the same device/dtype as the batch
    eps = torch.randn_like(Xbatch)
    # Column vector of cumulative alphas that broadcasts across the feature axis;
    # reshape(-1, 1) also handles a flat (N,) index, which .repeat(1, nfeatures) did not.
    abar = baralphas[t].reshape(-1, 1).to(Xbatch)
    noised = abar ** 0.5 * Xbatch + (1 - abar) ** 0.5 * eps
    return noised, eps

In [None]:
noiselevel = 20

# Noise every point of X to the same diffusion step and overlay the result on the clean data
noised, eps = noise(X, torch.full([len(X), 1], fill_value=noiselevel))
plt.scatter(*X.T)
plt.scatter(*noised.T, marker="*")
# Invert the closed-form forward process using the true noise; this should land back on X
abar_demo = baralphas[noiselevel]
denoised = 1 / torch.sqrt(abar_demo) * (noised - torch.sqrt(1 - abar_demo) * eps)
plt.scatter(*denoised.T, marker="1")

In [None]:
# Reconstruction check: largest absolute difference between X and its noise-then-invert round trip.
# It should be on the order of float32 round-off.
(X - denoised).abs().max()

## Diffusion network

In [None]:
import torch.nn as nn

class DiffusionBlock(nn.Module):
    """Residual MLP block conditioned on the diffusion step.

    The step ``t`` is appended to ``x`` as an extra input column, passed through
    Linear -> LayerNorm -> ReLU twice, and the result is added back onto ``x``.

    Args:
        nunits: width of ``x``; the block's output has the same width.
    """

    def __init__(self, nunits):
        super().__init__()
        widened = nunits + 1  # room for the appended t column
        self.linear1 = nn.Linear(widened, widened)
        self.norm1 = nn.LayerNorm(widened)
        self.linear2 = nn.Linear(widened, nunits)
        self.norm2 = nn.LayerNorm(nunits)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        relu = nn.functional.relu
        hidden = relu(self.norm1(self.linear1(torch.hstack([x, t]))))
        residual = relu(self.norm2(self.linear2(hidden)))
        # Skip connection
        return x + residual
        
    
class DiffusionModel(nn.Module):
    """MLP mapping a noised sample and its diffusion step to an output of the same width.

    Args:
        nfeatures: dimensionality of the data.
        nblocks: number of residual DiffusionBlocks between the input and output layers.
        nunits: hidden width.
    """

    def __init__(self, nfeatures: int, nblocks: int = 2, nunits: int = 64):
        super().__init__()
        self.inblock = nn.Linear(nfeatures + 1, nunits)
        self.midblocks = nn.ModuleList(DiffusionBlock(nunits) for _ in range(nblocks))
        self.outblock = nn.Linear(nunits, nfeatures)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # The step enters at the input layer and again inside every block
        hidden = self.inblock(torch.hstack([x, t]))
        for block in self.midblocks:
            hidden = block(hidden, t)
        return self.outblock(hidden)

model = DiffusionModel(nfeatures=2, nblocks=2)

# Use the GPU when one is available; fall back to CPU so the notebook still runs without CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

model

## Denoising model training

In [None]:
from einops import rearrange
import torch.optim as optim
from tqdm import tqdm

nepochs = 100
batch_size = 2048

loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Linearly decay the learning rate to 1% of its initial value over the run (stepped once per epoch)
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=nepochs)

model.train()
for epoch in range(nepochs):
    epoch_loss = 0.0
    steps = 0
    # Reshuffle every epoch so batches mix different points each time
    permutation = torch.randperm(len(X))
    for i in range(0, len(X), batch_size):
        Xbatch = X[permutation[i:i+batch_size]]
        # One random diffusion step per sample; named `t` so the schedule cell's `timesteps` is not overwritten
        t = torch.randint(0, diffusion_steps, size=[len(Xbatch), 1])
        noised, eps = noise(Xbatch, t)
        # The network is trained to predict the noise that was added
        predicted_noise = model(noised.to(device), t.to(device))
        loss = loss_fn(predicted_noise, eps.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields a plain float; accumulating the loss tensor itself would chain every
        # step's autograd graph together and print a (possibly CUDA) tensor
        epoch_loss += loss.item()
        steps += 1
    # Advance the LinearLR schedule once per epoch (total_iters=nepochs)
    scheduler.step()
    print(f"Epoch {epoch} loss = {epoch_loss / steps}")

Best results so far (training loss at the final epoch):
* DiffusionModel(nfeatures=2, nblocks=5), 40 diffusion steps, batchsize=2048, Epoch 99 loss = 0.4353588819503784
* DiffusionModel(nfeatures=2, nblocks=2), 40 diffusion steps, batchsize=2048, Epoch 99 loss = 0.4339998960494995


## Test sampler

In [None]:
def sample_ddpm(model, nsamples, nfeatures):
    """Draw samples by ancestral sampling, as in Ho et al.'s DDPM paper (Algorithm 2).

    Starts from standard normal noise at step diffusion_steps - 1 and denoises down to
    step 0 using the model's noise prediction at each step.

    Args:
        model: network called as model(x_t, t) that predicts the noise added at step t.
        nsamples: number of samples to draw.
        nfeatures: dimensionality of each sample.

    Returns:
        Tensor of shape (nsamples, nfeatures) on `device`.
    """
    with torch.no_grad():
        x = torch.randn(size=(nsamples, nfeatures)).to(device)
        for t in reversed(range(1, diffusion_steps)):
            step = torch.full([nsamples, 1], t).to(device)
            eps_hat = model(x, step)
            # Mean of p(x_{t-1} | x_t); see the DDPM paper between Eqs. 11 and 12:
            #   (x_t - (1 - alpha_t) / sqrt(1 - abar_t) * eps_hat) / sqrt(alpha_t)
            noise_scale = (1 - alphas[t]) / ((1 - baralphas[t]) ** 0.5)
            x = 1 / (alphas[t] ** 0.5) * (x - noise_scale * eps_hat)
            if t == 1:
                # The last step produces x_0; no noise is added
                continue
            # DDPM section 3.2: sigma_t^2 = beta_t is the optimal choice when x_0 ~ N(0, I)
            sigma = betas[t] ** 0.5
            x += sigma * torch.randn(size=(nsamples, nfeatures)).to(device)
        return x

In [None]:
Xgen = sample_ddpm(model, 10000, 2).cpu()
# Real data first, generated samples drawn on top
plt.scatter(*X.T)
plt.scatter(*Xgen.T, marker="1")

In [None]:
def sample_ddpm_x0(model, nsamples, nfeatures, *, posterior_variance=False):
    """Sampler that uses the DDPM equations to predict x0, then uses that to predict x_{t-1}.

    This is how DDPM is implemented in HuggingFace Diffusers, to allow working with models that predict
    x0 instead of the noise. It is also how we explain it in the Mixture of Diffusers paper.
    For a noise-predicting model, the posterior mean computed here is algebraically the same as
    the mean used in sample_ddpm.

    Args:
        model: network called as model(x_t, t) that predicts the noise added at step t.
        nsamples: number of samples to draw.
        nfeatures: dimensionality of each sample.
        posterior_variance: variance of the noise added between steps. False (default) uses
            beta_t, which DDPM section 3.2 notes is optimal for x_0 ~ N(0, I). True uses the
            posterior variance beta_tilde_t = (1 - baralpha_{t-1}) / (1 - baralpha_t) * beta_t
            (DDPM Eq. 7), the "fixed_small" choice in Diffusers.

    Returns:
        Tensor of shape (nsamples, nfeatures) on `device`.
    """
    with torch.no_grad():
        x = torch.randn(size=(nsamples, nfeatures)).to(device)
        for t in range(diffusion_steps-1, 0, -1):
            predicted_noise = model(x, torch.full([nsamples, 1], t).to(device))
            # Predict the original sample from the noise estimate (DDPM Eq. 15)
            x0 = (x - (1 - baralphas[t]) ** (0.5) * predicted_noise) / baralphas[t] ** (0.5)
            # Posterior mean of q(x_{t-1} | x_t, x_0) (DDPM Eq. 7)
            c0 = (baralphas[t-1] ** (0.5) * betas[t]) / (1 - baralphas[t])
            ct = alphas[t] ** (0.5) * (1 - baralphas[t-1]) / (1 - baralphas[t])
            x = c0 * x0 + ct * x
            # Add noise on every step except the last one, which produces x_0
            if t > 1:
                if posterior_variance:
                    variance = (1 - baralphas[t-1]) / (1 - baralphas[t]) * betas[t]
                else:
                    variance = betas[t]
                # Guard against a zero variance before taking the square root
                variance = torch.clamp(variance, min=1e-20)
                std = variance ** (0.5)
                x += std * torch.randn(size=(nsamples, nfeatures)).to(device)
        return x

In [None]:
Xgen = sample_ddpm_x0(model, 10000, 2).cpu()
# Overlay generated samples (marker "1") on the real data
for points, style in ((X, {}), (Xgen, {"marker": "1"})):
    plt.scatter(points[:, 0], points[:, 1], **style)