In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.datasets import make_swiss_roll as swiss_roll
from tqdm import tqdm


## A minimal code example demonstrating the diffusion process

In [2]:
# UTILS

# Use Apple's Metal (MPS) backend when PyTorch reports it as available, else the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print('--> running on', device)

def sample_batch(size):
    """Draw ``size`` 2-D training points from a Swiss roll.

    The 3-D roll returned by ``swiss_roll`` (sklearn's ``make_swiss_roll``) is
    projected onto its columns (2, 0) and scaled by 1/10, so the points sit inside
    the [-2, 2] plot limits used in this notebook. The second axis is then mirrored.

    Returns:
        NumPy array of shape (size, 2).
    """
    roll_3d, _unused_position = swiss_roll(size)
    projected = roll_3d[:, [2, 0]] / 10.0
    mirror = np.array([1, -1])
    return projected * mirror

def plot(model, path="Imgs/diffusion_model.png"):
    """Save a 2x3 figure comparing the forward (noising) chain with the learned reverse chain.

    Top row: fresh Swiss-roll samples x_0 and their noised versions at t = T//2 and t = T.
    Bottom row: the model's reverse chain at the same times: final samples (t = 0),
    halfway (t = T//2), and the initial pure noise (t = T).

    Args:
        model: DiffusionModel providing ``T``, ``forward_process``, ``sample`` and
            trainable parameters (used to pick the device).
        path: output image path. Its parent directory is created if missing; the
            original hard-coded ``Imgs/`` directory made ``savefig`` fail when absent.

    Returns:
        The matplotlib Figure (already closed), so a notebook cell ending in
        ``plot(model)`` can still display it.
    """
    T = model.T
    half = T // 2
    # Run where the network weights live rather than relying on a global ``device``.
    model_device = next(model.parameters()).device

    fig, axes = plt.subplots(2, 3, figsize=(10, 6))

    with torch.no_grad():  # visualisation only: no autograd graph needed
        x0 = sample_batch(5000)
        x0_t = torch.from_numpy(x0).to(torch.float32).to(model_device)
        # forward_process returns (mu, sigma, xt); only the noised sample xt is plotted.
        forward_states = [
            x0,
            model.forward_process(x0_t, half)[-1].cpu().numpy(),
            model.forward_process(x0_t, T)[-1].cpu().numpy(),
        ]
        # chain[k] is the state after k reverse steps, so time t corresponds to chain[T - t].
        chain = model.sample(5000, model_device)
        reverse_states = [chain[T - t].cpu().numpy() for t in (0, half, T)]

    titles = [r'$t=0$', r'$t=\frac{T}{2}$', r'$t=T$']
    for col in range(3):
        top, bottom = axes[0, col], axes[1, col]
        top.scatter(forward_states[col][:, 0], forward_states[col][:, 1], alpha=0.1, s=1)
        top.set_title(titles[col], fontsize=17)
        bottom.scatter(reverse_states[col][:, 0], reverse_states[col][:, 1], alpha=.1, s=1, c='r')
        for ax in (top, bottom):
            ax.set_xlim([-2, 2])
            ax.set_ylim([-2, 2])
            ax.set_aspect('equal')

    axes[0, 0].set_ylabel(r'$q(\mathbf{x}^{(0...T)})$', fontsize=17, rotation=0, labelpad=60)
    axes[1, 0].set_ylabel(r'$p(\mathbf{x}^{(0...T)})$', fontsize=17, rotation=0, labelpad=60)

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    fig.savefig(path, bbox_inches='tight')
    plt.close(fig)
    return fig



--> running on mps


### We need a network to learn the reverse conditional distribution $q(x_{t-1} \mid x_t)$

Since our data are 2-D $(x, y)$ swiss-roll coordinates, the network predicts the mean $\mu$ and the log-variance of the reverse Gaussian $p(x_{t-1} \mid x_t)$, which denoises a sample from $x_t$ to $x_{t-1}$. A shared trunk is followed by one output head per time step.

In [3]:
class MLP(nn.Module):

    def __init__(self, N=40, data_dim=2, hidden_dim=64):
        super(MLP, self).__init__()

        self.network_head = nn.Sequential(nn.Linear(data_dim, hidden_dim), nn.ReLU(),
                                          nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), )
        self.network_tail = nn.ModuleList([nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                                         nn.ReLU(), nn.Linear(hidden_dim, data_dim * 2)
                                                         ) for _ in range(N)])

    def forward(self, x, t: int):
        h = self.network_head(x)
        return self.network_tail[t](h)

In [4]:
class DiffusionModel(nn.Module):
    """Discrete-time Gaussian diffusion over 2-D points (Sohl-Dickstein et al. style).

    Forward (noising) process: q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I).
    Reverse (denoising) process: p(x_{t-1} | x_t) = N(mu, sigma^2 I). The network returns
    mu and log(sigma^2).

    Time steps are 1-based (t = 1..T); the schedule tensors are indexed with t - 1.

    Args:
        mlp_model: module called as ``mlp_model(x_t, t - 1)``. Its output is split in
            half along dim 1 into (mean, log-variance).
        T: number of diffusion steps.
        device: stored as ``self.device``; the methods below do not read it.
    """

    def __init__(self, mlp_model, T=40, device='mps'):
        super().__init__()

        self.mlp_model = mlp_model
        self.T = T
        self.device = device

        # Construct the non-linear beta schedule: a sigmoid of T evenly spaced logits in
        # [-18, 10], rescaled into [1e-5, 0.3]. Noise is tiny early and ~0.3 per step late.
        betas = torch.linspace(-18, 10, self.T)
        self.betas = torch.sigmoid(betas) * (3e-1 - 1e-5) + 1e-5

        self.alpha = 1. - self.betas
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)  # alpha_bar[i] = prod(alpha[:i + 1])
        self.sigma2 = self.betas

    def _check_step(self, t):
        """Raise ValueError unless ``t`` is a valid 1-based step (1 <= t <= T).

        Without this check, t = 0 becomes index -1 and silently reuses the last step.
        """
        if not 1 <= t <= self.T:
            raise ValueError(f"t must be in [1, {self.T}], got {t}")

    def forward_process(self, x0, t):
        """Noise ``x0`` to step ``t`` and return the Gaussian posterior q(x_{t-1} | x_t, x_0).

        Args:
            x0: clean data tensor.
            t: 1-based step in [1, T].

        Returns:
            (mu, sigma, xt):
                mu: posterior mean, same shape as x0.
                sigma: posterior std, a 0-dim tensor shared by all elements.
                xt: the noised sample x_t ~ q(x_t | x_0).
            For t == 1 the posterior is a point mass at x0, so mu = x0 and sigma = 0.

        Raises:
            ValueError: if t is outside [1, T].
        """
        self._check_step(t)
        t = t - 1  # Start indexing at 0
        beta_forward = self.betas[t]
        alpha_forward = self.alpha[t]
        alpha_cum_forward = self.alpha_bar[t]
        # Closed form: q(x_t | x_0) = N(sqrt(alpha_bar_t) x0, (1 - alpha_bar_t) I).
        xt = x0 * torch.sqrt(alpha_cum_forward) + torch.randn_like(x0) * torch.sqrt(1. - alpha_cum_forward)
        if t == 0:
            # alpha_bar[0] == alpha[0], so cov1 below would be exactly 0 and mu would be NaN/inf.
            return x0.clone(), torch.zeros_like(alpha_cum_forward), xt
        # The posterior is the product of two Gaussians in x_{t-1}:
        #   N(sqrt(alpha_bar_{t-1}) x0, 1 - alpha_bar_{t-1})  and  N(xt / sqrt(alpha_t), beta_t / alpha_t).
        # Retrieved from https://github.com/Sohl-Dickstein/Diffusion-Probabilistic-Models/blob/master/model.py#L203
        mu1_scl = torch.sqrt(alpha_cum_forward / alpha_forward)
        mu2_scl = 1. / torch.sqrt(alpha_forward)
        cov1 = 1. - alpha_cum_forward / alpha_forward
        cov2 = beta_forward / alpha_forward
        lam = 1. / cov1 + 1. / cov2  # combined precision
        mu = (x0 * mu1_scl / cov1 + xt * mu2_scl / cov2) / lam
        sigma = torch.sqrt(1. / lam)
        return mu, sigma, xt

    def reverse(self, xt, t):
        """Take one ancestral step x_t -> x_{t-1} using the learned Gaussian.

        Args:
            xt: current state tensor; its dim 1 is split in half by the network output.
            t: 1-based step in [1, T].

        Returns:
            (mu, sigma, samples). For t == 1 the step is the identity: (None, None, xt).

        Raises:
            ValueError: if t is outside [1, T].
        """
        self._check_step(t)
        t = t - 1  # Start indexing at 0
        if t == 0:
            return None, None, xt
        mu, h = self.mlp_model(xt, t).chunk(2, dim=1)
        sigma = torch.sqrt(torch.exp(h))  # h is the predicted log-variance
        samples = mu + torch.randn_like(xt) * sigma
        return mu, sigma, samples

    def sample(self, size, device):
        """Draw ``size`` 2-D points by running the reverse chain from t = T down to t = 1.

        Args:
            size: number of points.
            device: device for the initial noise (and hence the whole chain).

        Returns:
            List of T + 1 tensors of shape (size, 2). Element 0 is the initial N(0, I)
            noise; element k is the state after k reverse steps (the last one is the sample).
            Sampling runs under ``torch.no_grad()``, so no autograd graph is kept.
        """
        with torch.no_grad():
            noise = torch.randn((size, 2)).to(device)
            samples = [noise]
            for t in range(self.T, 0, -1):
                _, _, x = self.reverse(samples[-1], t)
                samples.append(x)
        return samples


### Putting it all together

In [5]:
def train(model, optimizer, nb_epochs=150000, batch_size=64_000):
    """Fit the learned reverse process by minimising a per-step KL divergence.

    Each iteration:
      1. draws a fresh Swiss-roll batch;
      2. picks one step t uniformly from {2, ..., model.T} (t = 1 is skipped because
         ``reverse`` treats that step as the identity);
      3. noises the batch to x_t;
      4. minimises KL(q(x_{t-1} | x_t, x_0) || p(x_{t-1} | x_t)), averaged over batch
         and coordinates.

    Args:
        model: DiffusionModel wrapping the trainable network.
        optimizer: optimizer over the network's parameters.
        nb_epochs: number of gradient steps, each on a freshly sampled batch.
        batch_size: number of points per batch.

    Returns:
        list[float]: the loss after every step (previously collected but discarded).
    """
    # Put batches on the device holding the weights instead of reading a global ``device``.
    param_device = next(model.parameters()).device
    training_loss = []
    for _ in tqdm(range(nb_epochs)):
        x0 = torch.from_numpy(sample_batch(batch_size)).float().to(param_device)
        # Upper bound follows the model's horizon instead of a hard-coded 40.
        t = np.random.randint(2, model.T + 1)
        mu_posterior, sigma_posterior, xt = model.forward_process(x0, t)
        mu, sigma, _ = model.reverse(xt, t)

        # Closed-form KL(N(mu_posterior, sigma_posterior^2) || N(mu, sigma^2)) per coordinate.
        KL = (torch.log(sigma) - torch.log(sigma_posterior)
              + (sigma_posterior ** 2 + (mu_posterior - mu) ** 2) / (2 * sigma ** 2) - 0.5)
        loss = KL.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        training_loss.append(loss.item())
    return training_loss


# `device` comes from the utils cell above. It is also passed to DiffusionModel so the
# model's `device` attribute matches where the weights actually live.
SEED = 0
N_STEPS = 40  # diffusion steps T; the MLP needs exactly one output head per step
torch.manual_seed(SEED)
# NumPy's global RNG drives the step choice in train and, because make_swiss_roll
# falls back to it when no random_state is given, the batches from sample_batch.
np.random.seed(SEED)

model_mlp = MLP(N=N_STEPS, hidden_dim=128).to(device)
model = DiffusionModel(model_mlp, T=N_STEPS, device=device)
optimizer = torch.optim.Adam(model_mlp.parameters(), lr=1e-4)
train(model, optimizer)
plot(model)

 21%|██▏       | 32120/150000 [15:32<59:52, 32.81it/s]  

In [None]:
# Recompute the beta schedule used by DiffusionModel so its values can be inspected:
# a sigmoid over 40 evenly spaced logits in [-18, 10], rescaled into [1e-5, 0.3].
beta_logits = torch.linspace(-18, 10, 40)
betas = 1e-5 + (3e-1 - 1e-5) * torch.sigmoid(beta_logits)

print(betas)

tensor([1.0005e-05, 1.0009e-05, 1.0019e-05, 1.0039e-05, 1.0081e-05, 1.0166e-05,
        1.0339e-05, 1.0696e-05, 1.1426e-05, 1.2924e-05, 1.5995e-05, 2.2291e-05,
        3.5199e-05, 6.1659e-05, 1.1589e-04, 2.2702e-04, 4.5461e-04, 9.2014e-04,
        1.8701e-03, 3.7989e-03, 7.6763e-03, 1.5317e-02, 2.9796e-02, 5.5312e-02,
        9.5000e-02, 1.4616e-01, 1.9823e-01, 2.3992e-01, 2.6735e-01, 2.8313e-01,
        2.9153e-01, 2.9581e-01, 2.9794e-01, 2.9899e-01, 2.9951e-01, 2.9976e-01,
        2.9988e-01, 2.9994e-01, 2.9997e-01, 2.9999e-01])
