# Imports

In [2]:
import tqdm
from flax import linen as nn
import jax
from typing import Dict, Callable, Sequence, Any, Union
from dataclasses import field
import jax.numpy as jnp
import tensorflow_datasets as tfds
import tensorflow as tf

import matplotlib.pyplot as plt
from clu import metrics
from flax.training import train_state  # Useful dataclass to keep train state
import optax
from flax import struct                # Flax dataclasses
import time
import os
from datetime import datetime
from flax.training import orbax_utils
import functools

# Some Important Utils

In [3]:
def normalizeImage(x):
    """Map pixel values from the [0, 255] range to [-1, 1].

    Implemented as a plain affine transform. ``jax.nn.standardize`` is not
    used here because it takes a ``variance`` argument and has no ``std``
    keyword, so calling it with ``std=...`` raises ``TypeError``.

    Args:
        x: Scalar or array of pixel values.

    Returns:
        ``(x - 127.5) / 127.5``.
    """
    return (x - 127.5) / 127.5


def denormalizeImage(x):
    """Map values from the [-1, 1] range back to [0, 255].

    Args:
        x: Scalar or array of normalized values.

    Returns:
        ``(x + 1.0) * 127.5``, the inverse of ``normalizeImage``.
    """
    return (x + 1.0) * 127.5


def plotImages(imgs, fig_size=(8, 8), dpi=100):
    """Draw every image of a batch into one matplotlib figure and show it.

    Args:
        imgs: Batch indexed as ``imgs[i, :, :, :]``; ``imgs.shape[0]`` images
            are drawn. Each one is passed through ``denormalizeImage`` and
            cast to ``uint8`` before display.
        fig_size: Used both as the figure size and as the (rows, columns) of
            the subplot grid.
        dpi: Resolution forwarded to ``plt.figure``.
    """
    plt.figure(figsize=fig_size, dpi=dpi)
    image_count = imgs.shape[0]
    for slot in range(1, image_count + 1):
        # Subplot positions are 1-based, image indices are 0-based.
        plt.subplot(fig_size[0], fig_size[1], slot)
        pixels = denormalizeImage(imgs[slot - 1, :, :, :])
        plt.imshow(tf.cast(pixels, tf.uint8))
        plt.axis("off")
    plt.show()

class RandomClass():
    """Stateful holder of a JAX PRNG key that hands out fresh subkeys."""

    def __init__(self, rng: jax.random.PRNGKey):
        """Store the initial PRNG key."""
        self.rng = rng

    def get_random_key(self):
        """Split the stored key, keep one half and return the other."""
        self.rng, subkey = jax.random.split(self.rng)
        return subkey

    def get_sigmas(self, steps):
        """Return ``tan(theta_min + steps * (theta_max - theta_min)) / kappa``.

        ``theta_min``, ``theta_max`` and ``kappa`` are not set by
        ``__init__``; they must be assigned on the instance before this
        method is called.

        Raises:
            AttributeError: If any of the three attributes is missing.
        """
        missing = [
            name
            for name in ("theta_min", "theta_max", "kappa")
            if not hasattr(self, name)
        ]
        if missing:
            raise AttributeError(
                f"{type(self).__name__}.get_sigmas requires the attribute(s) "
                f"{', '.join(missing)} to be set on the instance"
            )
        theta_span = self.theta_max - self.theta_min
        return jnp.tan(self.theta_min + steps * theta_span) / self.kappa

    def reset_random_key(self, seed=42):
        """Replace the stored key with ``jax.random.PRNGKey(seed)``.

        Args:
            seed: Seed for the new key; defaults to 42.
        """
        self.rng = jax.random.PRNGKey(seed)

class MarkovState(struct.PyTreeNode):
    """Field-less base class for states, declared as a flax ``struct.PyTreeNode``."""

class RandomMarkovState(MarkovState):
    """State whose only field is a JAX PRNG key."""
    rng: jax.random.PRNGKey

    def get_random_key(self):
        """Split ``self.rng`` and return ``(new_state, subkey)``.

        This instance is left untouched; the first half of the split is
        wrapped in a new ``RandomMarkovState`` and the second half is handed
        to the caller.
        """
        next_rng, subkey = jax.random.split(self.rng)
        next_state = RandomMarkovState(next_rng)
        return next_state, subkey

# Data Pipeline

In [4]:
def get_dataset(data_name="celeb_a", batch_size=64, image_scale=256):
    """Build an endless, augmented, batched NumPy iterator over a TFDS dataset.

    Args:
        data_name: Name passed to ``tfds.load``.
        batch_size: Number of images per batch; incomplete batches are dropped.
        image_scale: Side length the images are resized to.

    Returns:
        A tuple ``(iterator, size)`` where ``iterator`` yields batches of
        images clipped to [-1, 1] and ``size`` is ``len()`` of the loaded,
        un-repeated dataset.
    """
    def augmenter(image_scale=256, method="area"):
        """Return the per-sample preprocessing function used in ``map``."""
        @tf.function()
        def augment(sample):
            # Scale pixel values from [0, 255] to [-1, 1] before resizing.
            pixels = tf.cast(sample["image"], tf.float32)
            scaled = (pixels - 127.5) / 127.5
            resized = tf.image.resize(
                scaled, [image_scale, image_scale], method=method, antialias=True
            )
            flipped = tf.image.random_flip_left_right(resized)
            contrasted = tf.image.random_contrast(flipped, 0.999, 1.05)
            brightened = tf.image.random_brightness(contrasted, 0.2)
            # The random contrast/brightness can leave [-1, 1]; clamp back.
            return tf.clip_by_value(brightened, -1.0, 1.0)
        return augment

    # Load the dataset (all splits merged).
    data: tf.data.Dataset = tfds.load(data_name, split="all", shuffle_files=True)

    # The cache sits before the map, so it holds the raw samples and the
    # random augmentations are applied again each time a sample is read.
    cached = data.cache()
    augmented = cached.map(
        augmenter(image_scale, method="area"),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    # Repeat indefinitely, then shuffle within a 4096-element buffer.
    endless = augmented.repeat().shuffle(4096)
    batched = endless.batch(batch_size, drop_remainder=True)
    final_data = batched.prefetch(tf.data.experimental.AUTOTUNE).as_numpy_iterator()
    return final_data, len(data)


# Noise Schedulers

A noise schedule governs how noise is added in the forward diffusion steps. Given a time step $t$, it 
returns the signal rate $\alpha_t$ and the noise rate $\sigma_t$ with which to scale the initial data sample $x_0$ and the Gaussian noise $\epsilon_0$, as given by the equation:

$x_t = \alpha_t * x_0 + \sigma_t * \epsilon_0$

where $x_t$ is the data sample at time $t$, $x_0$ is the initial data sample, $\epsilon_0$ is the Gaussian noise, and $\alpha_t$ and $\sigma_t$ are the signal and noise rates at time $t$, respectively.

In variance preserving diffusion, the noise schedule is such that the variance of the data sample remains constant across time steps. This basically means the following:

$\alpha_t^2 + \sigma_t^2 = 1$

The idea is that with increasing time step $t$, the signal rate $\alpha_t$ decreases and the noise rate $\sigma_t$ increases, reducing the proportion of the initial data sample and increasing the amount of noise, so that the data sample diffuses smoothly towards the target normal distribution.

Of course, there are many ways to schedule the noise, and the variance-preserving constraint isn't the only option either: there are variance-exploding schedules as well as variance-preserving ones. In this notebook, we will be looking at variance-preserving noise schedules only.

In [None]:
class NoiseScheduler():
    """Base interface for diffusion noise schedules.

    A schedule maps a step to a (signal rate, noise rate) pair via
    ``get_rates``. The methods that raise ``NotImplementedError`` here must
    be provided by subclasses.
    """

    def __init__(self, timesteps,
                    dtype=jnp.float32,
                    clip_min=-1.0,
                    clip_max=1.0,
                    *args, **kwargs):
        """Store the schedule configuration.

        Args:
            timesteps: Stored as ``max_timesteps``.
            dtype: Stored as ``dtype``.
            clip_min: Stored as ``clip_min``.
            clip_max: Stored as ``clip_max``.
            *args, **kwargs: Accepted and ignored.
        """
        self.max_timesteps = timesteps
        self.dtype = dtype
        self.clip_min, self.clip_max = clip_min, clip_max

    def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Abstract: subclasses return ``(timesteps, new_state)``."""
        raise NotImplementedError

    def get_weights(self, steps):
        """Abstract: subclasses return weights for the given steps."""
        raise NotImplementedError

    def reshape_rates(self, rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Reshape both members of a (signal, noise) rate pair to ``shape``."""
        signal_rate, noise_rate = rates
        return jnp.reshape(signal_rate, shape), jnp.reshape(noise_rate, shape)

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Abstract: subclasses return the (signal, noise) rates at ``steps``."""
        raise NotImplementedError

    def add_noise(self, images, noise, steps) -> jnp.ndarray:
        """Return ``signal_rate * images + noise_rate * noise`` at ``steps``."""
        signal_rate, noise_rate = self.get_rates(steps)
        scaled_signal = signal_rate * images
        scaled_noise = noise_rate * noise
        return scaled_signal + scaled_noise

    def remove_all_noise(self, noisy_images, noise, steps, clip_denoised=True, rates=None) -> jnp.ndarray:
        """Invert ``add_noise``: recover the clean sample from a noisy one.

        ``clip_denoised`` and ``rates`` are accepted but not used by this
        implementation; the rates always come from ``get_rates(steps)`` and
        the result is not clipped.
        """
        signal_rate, noise_rate = self.get_rates(steps)
        return (noisy_images - noise * noise_rate) / signal_rate

    def transform_inputs(self, x, steps) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(x, steps)`` unchanged."""
        return x, steps

    def get_posterior_mean(self, x_0, x_t, steps):
        """Abstract: subclasses return the posterior mean."""
        raise NotImplementedError

    def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
        """Abstract: subclasses return the posterior variance."""
        raise NotImplementedError

    def get_max_variance(self):
        """Return ``sqrt(alpha**2 + sigma**2)`` evaluated at ``max_timesteps``."""
        final_signal, final_noise = self.get_rates(self.max_timesteps)
        return jnp.sqrt(final_signal ** 2 + final_noise ** 2)


## Cosine Noise Scheduler

The cosine schedule is one of the most widely used noise schedules. It's a variance-preserving noise schedule and can either be parameterized directly through the functions $\alpha(t)$ and $\sigma(t)$, defined as 

$\cos(\dfrac{\pi t}{2 T})$ and $\sin(\dfrac{\pi t}{2 T})$ respectively, 

which satisfies the variance-preserving constraint because $\sin^2(x) + \cos^2(x) = 1$, 
or in terms of functions depending on a $\beta(t)$ parameter, as defined in the original DDPM paper and many other places. 

In the original DDPM paper, the forward diffusion step is defined as 

$q(x_t|x_{t-1}) := \mathcal{N}(x_t;\sqrt{1-\beta_t}x_{t-1}, \beta_t I)$

where $q(x_t|x_{t-1})$ is the forward diffusion step. It's the conditional distribution of the data sample $x_t$ given the previous data sample $x_{t-1}$, and the equation states that $x_t$ is normally distributed with mean $\sqrt{1-\beta_t}x_{t-1}$ and variance $\beta_t$.

In human speak, the forward diffusion is phrased as:

$x_t = \sqrt{1-\beta_t}x_{t-1} + \sqrt{\beta_t} \epsilon_t$

where $x_t$ is the data sample at time $t$, $x_{t-1}$ is the previous data sample, and $\epsilon_t$ is the Gaussian noise at time $t$. Notice that $x_t$ is phrased in terms of $x_{t-1}$ instead of the initial data sample $x_0$ as we did in the 'Noise Schedulers' section. 

To convert this formulation to the type 
$x_t = \alpha_t  x_0 + \sigma_t  \epsilon_0$,

we can find that our signal rate $\alpha_t$ and noise rate $\sigma_t$ are given by:

$\alpha_t = \prod_{s=1}^{t} \sqrt{1-\beta_s}$

$\sigma_t = \sqrt{1-\alpha_t^2}$ 

Note that the maths in the DDPM paper uses the symbol $\alpha$ for a different thing: an intermediate value $\alpha_t = 1-\beta_t$. So be careful with the notation; in this notebook, $\alpha$ denotes the signal rate.

In [None]:

class ContinuousNoiseScheduler(NoiseScheduler):
    """
    General Continuous Noise Scheduler

    Fixes the schedule horizon (``timesteps``) to 1 and forwards every other
    argument to ``NoiseScheduler.__init__`` unchanged.
    """
    def __init__(self, *args, **kwargs):
        # The horizon is passed positionally on purpose: with
        # ``timesteps=1, *args`` the first forwarded positional argument
        # would also bind to ``timesteps`` and raise
        # "got multiple values for argument 'timesteps'".
        super().__init__(1, *args, **kwargs)

class CosineContinuousNoiseScheduler(ContinuousNoiseScheduler):
    """Cosine schedule: signal rate ``cos`` and noise rate ``sin`` of the same angle."""

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return the (signal, noise) rates at ``steps``, reshaped to ``shape``.

        The angle is ``pi * steps / (2 * max_timesteps)``; the signal rate is
        its cosine and the noise rate its sine.
        """
        signal_rates = jnp.cos((jnp.pi * steps) / (2 * self.max_timesteps))
        noise_rates = jnp.sin((jnp.pi * steps) / (2 * self.max_timesteps))
        return self.reshape_rates((signal_rates, noise_rates), shape=shape)
    
    def get_weights(self, steps):
        """Return ``1 / (1 + alpha**2 / sigma**2)`` for the given steps.

        Args:
            steps: A scalar or an array of steps.

        Returns:
            A scalar for scalar or single-element ``steps`` (as before);
            otherwise an array with the same shape as ``steps``.
        """
        # Reshaping to () only works for a single element, so batches keep
        # their own shape while single-element input still yields a scalar.
        out_shape = () if jnp.size(steps) == 1 else jnp.shape(steps)
        alpha, sigma = self.get_rates(steps, shape=out_shape)
        return 1 / (1 + (alpha ** 2 / sigma ** 2))
    