<a target="_blank" href="https://colab.research.google.com/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/edm%20tutorial.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Building better diffusion models with EDM

In this notebook, we discuss the paper ["Elucidating the Design Space of Diffusion-Based Generative Models" by Karras et al.](https://arxiv.org/pdf/2206.00364), along with ideas from other papers that followed DDPM/DDIM, and look at a more general way of thinking about noise schedules and diffusion models.

**This is part 2 of the diffusion series, so it is expected and strongly recommended that the reader go through [Part 1: Simple Diffusion](./simple%20diffusion%20flax.ipynb) first to understand the basics and get familiar with the way we implement the ideas.**

In the previous notebook, we discussed how to build, train and sample from a denoising model as specified by the DDPM and DDIM papers. We also touched upon the idea that the forward and reverse diffusion processes can be treated as ODEs/SDEs instead of discrete Markov chains, which lets us use ODE/SDE solvers to integrate the denoising model's output and generate images. We carry on from these ideas in this notebook.

The generated images in the previous notebook weren't of great quality. There are several reasons for that, such as too few training epochs and a basic model architecture. Papers that followed DDPM and DDIM introduced several improvements to increase quality, such as better noise schedules and better ways to model the gradient of the log likelihood. The EDM paper by Karras et al. builds on these ideas and also introduces some new generalizations.

## New Ideas

### 1. Noise Schedules

The noise schedules in the DDPM and DDIM papers were simple linear schedules. In the previous notebook, we used the cosine noise schedule to train our model and discussed how it is a 'variance preserving' noise schedule. By that we meant that if we formulate the forward diffusion process of adding noise as

$x_t = \alpha_t x_0 + \sigma_t \epsilon_0$

where 
- $x_t$ is the data sample at time $t$
- $x_0$ is the initial data sample
- $\epsilon_0$ is the Gaussian noise
- $\alpha_t$ and $\sigma_t$ are the signal and noise rates at time $t$ respectively

Then, since $x_0$ and $\epsilon_0$ are independent and $\epsilon_0$ has unit variance, the variance of the data sample at time $t$ is given by

$Var(x_t) = \alpha_t^2 Var(x_0) + \sigma_t^2$

Assuming $Var(x_0) = 1$, we have

$Var(x_t) = \alpha_t^2 + \sigma_t^2$

Then the noise schedule is variance preserving if

$\alpha_t^2 + \sigma_t^2 = 1$
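
As a quick numerical check of this condition, here is a minimal sketch. It assumes a simplified cosine schedule, $\alpha_t = \cos(\pi t / 2)$ and $\sigma_t = \sin(\pi t / 2)$ with $t \in [0, 1]$, which may differ in detail from the schedule used in part 1. The function name `cosine_schedule` is made up for this illustration.

```python
import numpy as np

def cosine_schedule(t):
    """Signal and noise rates of a simplified cosine schedule, t in [0, 1]."""
    alpha = np.cos(0.5 * np.pi * t)  # signal rate: 1 at t=0, 0 at t=1
    sigma = np.sin(0.5 * np.pi * t)  # noise rate: 0 at t=0, 1 at t=1
    return alpha, sigma

t = np.linspace(0.0, 1.0, 11)
alpha, sigma = cosine_schedule(t)

# For a variance preserving schedule this sum stays at 1 for every t
vp_sum = alpha**2 + sigma**2
print(vp_sum)
```

Because $\cos^2 + \sin^2 = 1$, every entry of `vp_sum` equals 1 up to floating point error, so a unit-variance $x_0$ stays at unit variance throughout the forward process.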

The awesome [Score-Based Generative Modeling through Stochastic Differential Equations](https://arxiv.org/pdf/2011.13456) paper also introduced the variance exploding noise schedule, where the variance increases with time instead of staying constant:

$\alpha_t^2 + \sigma_t^2 > 1$

When dealing with a variance exploding noise schedule, one can drop $\alpha_t$ by setting it to 1 and describe the noise schedule using $\sigma_t$ alone:

$x_t = x_0 + \sigma_t \epsilon_0$
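
The forward process above can be sketched in a few lines. This is an illustration only: `add_noise_ve` is a hypothetical helper, not part of the notebook's codebase, and unit-variance Gaussian samples stand in for real image data.

```python
import numpy as np

def add_noise_ve(x0, sigma, rng):
    """Variance exploding forward process: x_t = x_0 + sigma * eps."""
    eps = rng.standard_normal(x0.shape)  # unit-variance Gaussian noise
    return x0 + sigma * eps, eps

rng = np.random.default_rng(0)
x0 = rng.standard_normal(100_000)  # stand-in for unit-variance data

# Empirical variance of x_t for a few noise levels
variances = {}
for sigma in (0.0, 1.0, 10.0):
    x_t, _ = add_noise_ve(x0, sigma, rng)
    variances[sigma] = float(x_t.var())
print(variances)
```

The printed variances grow roughly as $1 + \sigma_t^2$, which is the "exploding" behaviour: the signal is never scaled down, and the noise simply swamps it as $\sigma_t$ increases.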

**From now on in this notebook, we describe the noise schedule using $\sigma_t$ only.**

Of course, a model shouldn't be given an input sample with arbitrarily high variance, so we rescale the noisy samples back to unit variance before passing them to the model. Assuming $Var(x_0) = 1$ as before, the variance of the noisy sample is

$Var(x_t) = 1 + \sigma_t^2$

Thus the scaling factor applied to $x_t$ before it is given to the model is $\frac{1}{\sqrt{1 + \sigma_t^2}}$, i.e. one over the standard deviation of $x_t$.
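
A minimal sketch of this input scaling, assuming unit-variance data. The helper name `scale_model_input` is hypothetical and only meant to illustrate the formula.

```python
import numpy as np

def scale_model_input(x_t, sigma):
    """Rescale a noisy sample so that its variance is roughly 1."""
    return x_t / np.sqrt(1.0 + sigma**2)

rng = np.random.default_rng(0)
x0 = rng.standard_normal(100_000)  # stand-in for unit-variance data

scaled_vars = []
for sigma in (0.5, 5.0, 50.0):
    x_t = x0 + sigma * rng.standard_normal(x0.shape)  # variance exploding noising
    scaled_vars.append(float(scale_model_input(x_t, sigma).var()))
print(scaled_vars)
```

Whatever the noise level, the scaled input stays close to unit variance, so the network always sees inputs of a comparable magnitude.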



# Install Dependencies

In [None]:
!pip install jax[cuda12]==0.4.28 flax[all] orbax grain-nightly augmax clu

# Imports

In [None]:
import tqdm
from flax import linen as nn
import jax
from typing import Dict, Callable, Sequence, Any, Union
from dataclasses import field
import jax.numpy as jnp
import tensorflow_datasets as tfds
import grain.python as pygrain
# import tensorflow as tf
import numpy as np
import augmax

import matplotlib.pyplot as plt
from clu import metrics
from flax.training import train_state  # Useful dataclass to keep train state
import optax
from flax import struct                # Flax dataclasses
import time
import os
from datetime import datetime
from flax.training import orbax_utils
import functools
from tensorflow_datasets.core.utils import gcs_utils
gcs_utils._is_gcs_disabled = True


# Some Important Utils

In [None]:
# Map pixel values from [0, 255] to [-1, 1] (the inverse of denormalizeImage below)
normalizeImage = lambda x: (x - 127.5) / 127.5
denormalizeImage = lambda x: (x + 1.0) * 127.5


def plotImages(imgs, fig_size=(8, 8), dpi=100):
    fig = plt.figure(figsize=fig_size, dpi=dpi)
    imglen = imgs.shape[0]
    for i in range(imglen):
        plt.subplot(fig_size[0], fig_size[1], i + 1)
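        # tensorflow is not imported in this notebook, so convert to uint8 with numpy instead
        plt.imshow(np.clip(np.asarray(denormalizeImage(imgs[i, :, :, :])), 0, 255).astype(np.uint8))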
        plt.axis("off")
    plt.show()

class RandomClass():
    def __init__(self, rng: jax.random.PRNGKey):
        self.rng = rng

    def get_random_key(self):
        self.rng, subkey = jax.random.split(self.rng)
        return subkey
    
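    def get_sigmas(self, steps):
        # NOTE: relies on `theta_min`, `theta_max` and `kappa`, which RandomClass never sets,
        # so calling this on a plain RandomClass instance raises AttributeError.
        return jnp.tan(self.theta_min + steps * (self.theta_max - self.theta_min)) / self.kappa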

    def reset_random_key(self):
        self.rng = jax.random.PRNGKey(42)

class MarkovState(struct.PyTreeNode):
    pass

class RandomMarkovState(MarkovState):
    rng: jax.random.PRNGKey

    def get_random_key(self):
        rng, subkey = jax.random.split(self.rng)
        return RandomMarkovState(rng), subkey

# Data Pipeline

For simplicity, we will use the Oxford Flowers dataset for this notebook. We will use a newer data pipeline based on the `google/grain` library, which lets us avoid any TensorFlow dependency apart from the very first run, where it is needed to download the dataset.

**Tensorflow is required to download the dataset in the very first run. You can install tensorflow cpu version if you are having issues with the cuda stuff.**

**If you have previously downloaded the dataset using the TFDS pipeline in the previous notebook, you might need to delete the old copy stored at `~/tensorflow_datasets/`, otherwise the pipeline will throw an error.**

In [None]:
def load_labels_oxford_flowers102(path):
    def load_labels():
        # open() does not expand "~" on its own, so resolve the home directory first
        with open(os.path.expanduser(path), "r") as f:
            textlabels = [i.strip() for i in f.readlines()]
        return textlabels
    return load_labels

# Configure the following for your datasets
dataToLabelGenMap = {
    "oxford_flowers102": load_labels_oxford_flowers102("~/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt"),   # Change this if required!
}

def get_dataset(data_name="oxford_flowers102", batch_size=64, image_scale=256, method=jax.image.ResizeMethod.LANCZOS3):
    data_source = tfds.data_source(data_name, split="all", try_gcs=False)
    
    gpu_device = jax.devices("gpu")[0]
    cpu_device = jax.devices("cpu")[0]
    
    print(f"Gpu Device: {gpu_device}, Cpu Device: {cpu_device}")
        
    def preprocess(image):
        # image = jax.device_put(image, device=jax.devices("cpu")[0])
        image = (image - 127.5) / 127.5
        image = jax.image.resize(image, (image_scale, image_scale, 3), method=method)
        image = jnp.clip(image, -1.0, 1.0)
        image = jax.device_put(image, device=jax.devices("gpu")[0]) 
        return  image
    
    preprocess = jax.jit(preprocess, backend="cpu")

    augments = augmax.Chain(
        augmax.HorizontalFlip(0.5),
        augmax.RandomContrast((-0.05, 0.05), 1.),
        augmax.RandomBrightness((-0.2, 0.2), 1.)
    )

    augments = jax.jit(augments, backend="cpu")
    
    if os.path.exists(f"./datacache/{data_name}_labels.pkl"):
        print("Loading labels from cache")
        with open(f"./datacache/{data_name}_labels.pkl", "rb") as f:
            import pickle
            embed = pickle.load(f)
            embed_labels = embed["embed_labels"]
            embed_labels_full = embed["embed_labels_full"]
            null_labels = embed["null_labels"]
            null_labels_full = embed["null_labels_full"]
    else:
        print("No cache found, generating labels")
        textlabels = dataToLabelGenMap[data_name]()
        
        model, tokenizer = defaultTextEncodeModel()

        embed_labels, embed_labels_full = encodePrompts(textlabels, model, tokenizer)
        embed_labels = embed_labels.tolist()
        embed_labels_full = embed_labels_full.tolist()
        
        null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
        null_labels = null_labels.tolist()[0]
        null_labels_full = null_labels_full.tolist()[0]
        
        os.makedirs("./datacache", exist_ok=True)
        with open(f"./datacache/{data_name}_labels.pkl", "wb") as f:
            import pickle
            pickle.dump({
                "embed_labels": embed_labels,
                "embed_labels_full": embed_labels_full,
                "null_labels": null_labels,
                "null_labels_full": null_labels_full
                }, f)
        
    embed_labels = [np.array(i, dtype=np.float16) for i in embed_labels]
    embed_labels_full = [np.array(i, dtype=np.float16) for i in embed_labels_full]
    null_labels = np.array(null_labels, dtype=np.float16)
    null_labels_full  = np.array(null_labels_full, dtype=np.float16)
    
    def labelizer(labelidx:int) -> jnp.array:
        label_pooled = embed_labels[labelidx]
        label_seq = embed_labels_full[labelidx]
        # label_pooled = jax.device_put(label_pooled, device=jax.devices("gpu")[0])
        # label_seq = jax.device_put(label_seq, device=jax.devices("gpu")[0])
        return label_pooled, label_seq

    class augmenter(pygrain.RandomMapTransform):
        def random_map(self, element: Dict[str, Any], rng: np.random.Generator) ->  Dict[str, jnp.array]:
            image = element['image']
            image = preprocess(image)
            image = augments(rng.integers(0, 2**32, [2], dtype=np.uint32), image) 
            labelidx = element['label']
            label, label_seq = labelizer(labelidx)
            # image, label = move2gpu(image, label)
            return {'image':image, 'label':label, 'label_seq':label_seq}

    sampler = pygrain.IndexSampler(
        num_records=len(data_source),
        shuffle=True,
        seed=0,
        num_epochs=None,
        shard_options=pygrain.ShardByJaxProcess(),
    )

    transformations = [augmenter(), pygrain.Batch(batch_size, drop_remainder=True)]

    loader = pygrain.DataLoader(
        data_source=data_source,
        sampler=sampler,
        operations=transformations,
        worker_count=4,
        read_options=pygrain.ReadOptions(8, 500),
        worker_buffer_size=5
        )
    return {
        "loader": loader,
        "null_labels": null_labels,
        "null_labels_full": null_labels_full,
        "embed_labels": embed_labels,
        "embed_labels_full": embed_labels_full,
        "length": len(data_source),  
        "batch_size": batch_size,
        "image_size": image_scale
    }
