# From Score Matching to DDPM: Swiss Roll Diffusion Example

In this next section we will go from the simple Langevin dynamics example to Denoising Diffusion Probabilistic Models (DDPM) [Ho et al](https://arxiv.org/abs/2006.11239).
The Langevin dynamics example required the knowledge of the score function, which we derived from the bimodal Gaussian 
PDF using `jax.grad`.

In general, however, we do not have access to the score function; we only have access to the data. Fortunately, we
can estimate the score function from the data with a technique called "score matching", introduced by [Hyvarinen, 2005](https://jmlr.org/papers/volume6/hyvarinen05a/hyvarinen05a.pdf).
The core of the method is the interesting result that the MSE loss between the model $\mathcal{F}_{\theta}$ and the true score function, i.e. the gradient of the log PDF:

$$\mathcal{L}_{mse} = \frac{1}{2} E_{\mathbf{x} \sim p(\mathbf{x})} \left[ \left\lVert \mathcal{F}_{\theta}(\mathbf{x}) - \nabla_{\mathbf{x}} \log p(\mathbf{x}) \right\rVert_2^2 \right]$$

can be reformulated, up to an additive constant that does not depend on $\theta$, as:

$$\mathcal{L}_{matching} = E_{\mathbf{x} \sim p(\mathbf{x})} \left[ \text{tr}\left( \nabla_{\mathbf{x}}  \mathcal{F}_{\theta}(\mathbf{x})  \right) + \frac{1}{2} \left\lVert \mathcal{F}_{\theta}(\mathbf{x}) \right\rVert_2^2 \right]$$


See also:
- https://github.com/joseph-nagel/diffusion-demo/blob/main/notebooks/swissroll.ipynb
- https://github.com/acids-ircam/diffusion_models/
- 


In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from sklearn.datasets import make_swiss_roll
import numpy as np
from .utils import animate_trajectory, plot_vector_field

In [None]:
# Seeded NumPy random state; it is passed to make_swiss_roll so the dataset is reproducible.
random_state = np.random.RandomState(8923)

In [None]:
# Number of points to draw from the swiss roll.
n_samples = 1_000
# make_swiss_roll returns a tuple; element [0] holds the 3D point coordinates.
x = make_swiss_roll(n_samples=n_samples, noise=0.0, random_state=random_state)[0]
# Keep columns 0 and 2 only (column 1 is discarded) and rescale by 0.13.
x = 0.13 * x[:, [0, 2]]  # drop the middle (y) axis to get a 2D roll

In [None]:
# Scatter plot of the 2D swiss roll samples.
ax = plt.subplot()

ax.scatter(x[:, 0], x[:, 1], s=1)
ax.set_aspect("equal")

# Symmetric axis limits around the origin.
scale = 2.0
ax.set_xlim(-scale, scale)
ax.set_ylim(-scale, scale)

In [None]:
# Number of forward diffusion steps.
n_steps = 100

# Noise schedule: one beta per step, geometrically spaced from 1e-3 up to 1.
betas = jnp.geomspace(1e-3, 1, n_steps)

In [None]:
# iterative version
def diffuse_one_step(carry, beta):
    """Apply a single forward-diffusion (noising) step.

    Written as a ``jax.lax.scan`` body: ``carry`` is the pair
    ``(current samples, PRNG key)`` and ``beta`` is the noise level of this
    step. Returns the new carry ``(noised samples, advanced key)`` together
    with the noised samples as the per-step output.
    """
    state, rng = carry
    # Advance the key and derive a fresh subkey for this step's noise.
    rng, noise_rng = jax.random.split(rng)
    gaussian = jax.random.normal(noise_rng, shape=state.shape)

    keep = jnp.sqrt(1 - beta)
    spread = jnp.sqrt(beta)
    noised = keep * state + spread * gaussian
    return (noised, rng), noised


# Initial scan carry: the clean samples together with a PRNG key.
init = (x, jax.random.PRNGKey(876234))

# Run the chain over all betas; [1] selects the stacked per-step outputs
# (the final carry at index 0 is discarded).
x_diffused = jax.lax.scan(diffuse_one_step, init=init, xs=betas)[1]

In [None]:
# animate_trajectory is imported from the local utils module; presumably it renders
# the per-step samples and writes the animation to the given file name (confirm in utils).
animate_trajectory(x_diffused, "diffusion_iterative.gif")

If you do the math, you can directly evaluate the state at time t = idx in closed form; see e.g. https://lilianweng.github.io/posts/2021-07-11-diffusion-models/

In [None]:
# Per-step alphas, reshaped to (n_steps, 1, 1) so they broadcast against the
# (n_steps, n_samples, 2) noise array below.
alphas = 1 - betas.reshape(-1, 1, 1)
# Cumulative product along the step axis: alpha_bar_t = alpha_1 * ... * alpha_t.
alpha_bars = jnp.cumprod(alphas, axis=0)

# closed form version
# Independent Gaussian noise for every step and every sample.
noise = jax.random.normal(jax.random.PRNGKey(1234), shape=(n_steps, n_samples, 2))
# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, evaluated for all t at once.
x_diffused_closed = jnp.sqrt(alpha_bars) * x + jnp.sqrt(1 - alpha_bars) * noise

In [None]:
# Same helper as for the iterative chain; presumably writes the animation of the
# closed-form trajectory to the given file name (confirm in utils).
animate_trajectory(x_diffused_closed, "diffusion_closed.gif")

## Score Matching




In [None]:
@jax.jit
def score_matching_loss(model, samples):
    """Implicit score matching loss (Hyvarinen, 2005).

    Computes ``mean_i[ tr(J(x_i)) + 0.5 * ||model(x_i)||^2 ]`` where ``J`` is
    the Jacobian of ``model`` with respect to its input.

    Parameters
    ----------
    model
        Callable mapping a single sample of shape ``(d,)`` to a score estimate
        of shape ``(d,)``; it is vmapped over the batch here. Because the
        function is wrapped in ``jax.jit``, ``model`` must be a valid jit
        argument (a pytree of arrays).
    samples
        Batch of samples with shape ``(n, d)``.

    Returns
    -------
    Scalar loss averaged over the batch.
    """
    # The model predicts the score (gradient of the log-density), not a log-density.
    score = jax.vmap(model)(samples)
    # Sum of squares instead of ``norm(...) ** 2``: same value, but the gradient
    # stays finite when a score vector is exactly zero (the square root inside
    # the norm has no finite derivative at 0, which would yield NaN gradients).
    norm_loss = 0.5 * jnp.sum(score**2, axis=-1)

    # Per-sample Jacobian, shape (n, d, d); its trace is the divergence of the model.
    jacob_mat = jax.vmap(jax.jacobian(model))(samples)
    tr_jacobian_loss = jnp.trace(jacob_mat, axis1=-2, axis2=-1)

    return (tr_jacobian_loss + norm_loss).mean()

In [None]:
import equinox as eqx
import optax

In [None]:
class MLP(eqx.Module):
    """Three-layer perceptron with softplus activations on the hidden layers."""

    fc1: eqx.nn.Linear
    fc2: eqx.nn.Linear
    fc3: eqx.nn.Linear

    def __init__(self, in_dim, out_dim, key, hidden_dim=128):
        """Build the layers, giving each one its own PRNG key split from ``key``."""
        key_in, key_hidden, key_out = jax.random.split(key, 3)
        self.fc1 = eqx.nn.Linear(in_dim, hidden_dim, key=key_in)
        self.fc2 = eqx.nn.Linear(hidden_dim, hidden_dim, key=key_hidden)
        self.fc3 = eqx.nn.Linear(hidden_dim, out_dim, key=key_out)

    def __call__(self, x):
        """Apply the network to ``x``; the output layer has no activation."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = jax.nn.softplus(layer(hidden))
        return self.fc3(hidden)

In [None]:
# Score network with 2 inputs and 2 outputs; weights are initialised from a fixed PRNG key.
model = MLP(2, 2, key=jax.random.PRNGKey(1234))

In [None]:
# Adam optimizer with a fixed learning rate.
optimizer = optax.adam(learning_rate=1e-3)
# Initialise the optimizer state from the array leaves only, so it matches the
# parameter pytree that make_step hands to optimizer.update.
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

In [None]:
# Number of optimisation steps for the training loop below.
# NOTE: this rebinds the module-level `n_steps` that the diffusion cells above
# set to 100 (number of diffusion steps); re-running those cells after this
# one will pick up 1000 instead.
n_steps = 1000


@eqx.filter_jit
def make_step(model, opt_state, x):
    """Perform one optimisation step of score matching on the batch ``x``.

    Returns the updated model, the updated optimizer state and the loss
    value computed before the update was applied.
    """
    loss_and_grad = eqx.filter_value_and_grad(score_matching_loss)
    loss, gradients = loss_and_grad(model, x)

    # The optimizer receives only the array leaves of the model as parameters.
    params = eqx.filter(model, eqx.is_array)
    deltas, new_opt_state = optimizer.update(gradients, opt_state, params)

    new_model = eqx.apply_updates(model, deltas)
    return new_model, new_opt_state, loss


# Full-batch training: every step is computed on the whole dataset x.
for step in range(n_steps):
    model, opt_state, train_loss = make_step(model, opt_state, x)
    # "\r" returns to the start of the line so the progress message is overwritten in place.
    print(f"Step {step}, loss: {train_loss:.4f}", end="\r")

In [None]:
def sample_simple(model, x, n_steps=70, eps=1e-3):
    """Move ``x`` along the model output with fixed-size, noise-free steps.

    Each step applies ``x <- x + eps * forward(model, x)``.

    Parameters
    ----------
    model
        Model passed through to ``forward``.
    x
        Initial points.
    n_steps
        Number of update steps. It fixes the scan length, so it must be a
        Python int and is treated as a static argument under ``jax.jit``.
    eps
        Step size.

    Returns
    -------
    The stacked states after every step (leading axis of length ``n_steps``).
    """

    # Step function for jax.lax.scan; the per-step input is unused.
    def step(x, _):
        # NOTE(review): `forward` is not defined in the visible part of this
        # file; presumably it evaluates the model on the batch -- confirm.
        x = x + eps * forward(model, x)
        return x, x

    return jax.lax.scan(step, x, xs=None, length=n_steps)[1]


# `n_steps` is declared static: as a traced value it cannot set the scan
# length, so passing it explicitly would fail under a plain `jax.jit`.
sample_simple = jax.jit(sample_simple, static_argnames=("n_steps",))