# The Astro DDPM

In this final part, we are going to work with a pre-trained diffusion model for generating astronomical images.


- The reference implementation in PyTorch is available at: https://github.com/Smith42/astroddpm
- The paper is available at: https://arxiv.org/abs/2111.01713

We will implement the reverse diffusion process to generate new images from noise.


In [None]:
from jax import numpy as jnp
import jax

from unet import UNet
import matplotlib.pyplot as plt

import numpy as np

In [None]:
# Compact array printing: 4 decimal places, fixed-point instead of scientific
# notation for small values, and lines wrapped at 85 characters.
jnp.set_printoptions(precision=4, suppress=True, linewidth=85)

In [None]:
# Load the pre-trained network from the checkpoint file.
# NOTE(review): presumably `UNet.read` converts the PyTorch `.pt` weights of the
# reference implementation into this JAX U-Net -- confirm in `unet.py`.
model = UNet.read("../data/models/probes_model_00745000.pt")

In [None]:
def cosine_beta_schedule(timesteps, s=0.008):
    """Return the cosine noise schedule ``beta_t`` for ``timesteps`` steps.

    The cumulative signal level ``alpha_bar`` follows a squared cosine on a
    uniform grid over ``[0, 1]``; the per-step ``beta_t`` is recovered from the
    ratio of consecutive ``alpha_bar`` values.

    Parameters
    ----------
    timesteps : int
        Number of diffusion steps; the returned array has this length.
    s : float
        Small offset added to the grid before taking the cosine.

    Returns
    -------
    jax.Array
        Array of shape ``(timesteps,)`` with values clipped to ``[0, 0.999]``.
    """
    # timesteps + 1 grid points are needed to form `timesteps` consecutive ratios.
    x = jnp.linspace(0.0, 1.0, timesteps + 1)
    alphas_cumprod = jnp.cos((x + s) / (1.0 + s) * jnp.pi * 0.5) ** 2
    # Normalise so that the cumulative product starts at exactly 1.
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # The last ratio is close to 0, i.e. beta close to 1; cap it below 1.
    # The bounds are passed positionally because the `a_min` / `a_max` keyword
    # names were deprecated in JAX in favour of `min` / `max`; positional
    # arguments work with both signatures.
    return jnp.clip(betas, 0, 0.999)

In [None]:
from typing import Callable
from dataclasses import dataclass
from jax.tree_util import register_dataclass


@register_dataclass
@dataclass
class DiffusionModel:
    """Diffusion model"""

    beta: jax.Array
    alpha: jax.Array
    alpha_bar: jax.Array
    model: Callable

    @classmethod
    def from_beta(cls, beta, model):
        """Create diffusion model from beta schedule. Beta and the model must be consistent!"""
        alpha = 1.0 - beta
        alpha_bar = jnp.cumprod(alpha)
        return cls(beta=beta, alpha=alpha, alpha_bar=alpha_bar, model=model)

    @jax.jit
    def reverse_diffusion_step(self, x_t, t):
        """Reverse diffusion step"""
        predicted_noise = self.model(x_t, t)
        jax.debug.print("noise {}", predicted_noise.mean())

        alpha_t = jnp.expand_dims(self.alpha[t], axis=(1, 2, 3))
        alpha_bar_t = jnp.expand_dims(self.alpha_bar[t], axis=(1, 2, 3))

        mu = (1.0 / jnp.sqrt(alpha_t)) * (
            x_t - (1.0 - alpha_t) / jnp.sqrt(1.0 - alpha_bar_t) * predicted_noise
        )
        jax.debug.print("mu {}", predicted_noise.mean())
        return jnp.clip(mu, -1, 1)

    def sample(self, key, image_size, batch_size=1):
        """Generate images using reverse diffusion process."""
        # Start from random noise
        key, subkey = jax.random.split(key)

        x = jax.random.normal(subkey, (batch_size, 3) + image_size)

        # Reverse diffusion TODO: implement as JAX scan...
        for t in reversed(range(len(self.beta))):
            key, subkey = jax.random.split(key)
            t_batch = jnp.ones((batch_size,), dtype=jnp.int32) * t

            x = self.reverse_diffusion_step(x, t_batch)

            noise = jax.random.normal(subkey, x.shape) if t > 0 else 0

            # thats the brownian motion term from the Langevin Dynamics example
            x = x + jnp.sqrt(self.beta[t]) * noise
            jax.debug.print("{}", x.mean())

        return x

In [None]:
# Number of diffusion steps.
# NOTE(review): this has to match the schedule the checkpoint was trained
# with -- confirm against the reference implementation.
time_steps = 1000
# Cosine noise schedule with one beta per diffusion step.
beta = cosine_beta_schedule(time_steps)
# Bundle the schedule (plus the derived alpha / alpha_bar) with the network.
diffusion_model = DiffusionModel.from_beta(beta, model)

In [None]:
# Draw a single image (batch_size defaults to 1) of spatial size 256 x 256,
# using a fixed seed so that the run is reproducible.
result = diffusion_model.sample(jax.random.PRNGKey(237), (256, 256))

In [None]:
def reverse_transform(x):
    """Map a generated image back to display range.

    Reverse transform, see Eq (9) of https://arxiv.org/pdf/2111.01713v1.
    The values are shifted and scaled by ``0.5 * (x + 1)``, divided by the
    maximum of ``x``, the axes are reversed with ``.T`` and the last axis of
    the result is flipped.
    """
    rescaled = 0.5 * (x + 1) / x.max()
    transposed = rescaled.T
    return np.flip(transposed, axis=-1)

In [None]:
# Take the first image of the generated batch and apply the reverse transform.
data = reverse_transform(result[0])


def percentile_clip(x, pmin=0.05, pmax=99.99):
    """Clamp ``x`` to the range spanned by two of its own percentiles.

    ``pmin`` and ``pmax`` are percentiles (0-100) of ``x``; values below the
    ``pmin`` percentile or above the ``pmax`` percentile are set to the
    respective bound.
    """
    lower, upper = (np.percentile(x, q) for q in (pmin, pmax))
    return np.clip(x, lower, upper)