# **Deep Generative Models**

<img src="https://hojonathanho.github.io/diffusion/assets/img/denoising_diffusion_all.png" width="60%" />

<a href="https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/main/Indaba_2022_Prac_Template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> [Change colab link to point to prac.]

© Deep Learning Indaba 2022. Apache License 2.0.

**Authors:** James Allingham **[, Add yourself!]**

Adapted from and inspired by [The Annotated Diffusion Model 🤗](https://huggingface.co/blog/annotated-diffusion).

**Introduction:** 

In this practical, we will investigate the fundamentals of generative modeling: a machine learning framework for learning how to sample new, unseen data points that match the distribution of a training dataset. Although generative modeling is a powerful and flexible framework that has driven many exciting advances in ML, it has its own challenges and limitations. This practical walks you through these challenges and shows how to address them by implementing a Denoising Diffusion Model (a.k.a. a Score-Based Generative Model). Diffusion models are the backbone of the recent and exciting [DALL·E 2](https://openai.com/dall-e-2/) and [Imagen](https://imagen.research.google/) models that we've all seen on [Twitter](https://twitter.com/search?q=%23dalle2%20%23imagen&src=typed_query).

**Topics:** 

Content: <font color='blue'>`Generative Models`</font>, <font color='red'>`Probabilistic Graphical Models`</font> 

Level: Advanced.




**Aims/Learning Objectives:**


* Understand the differences between generative and discriminative modeling.
* Understand how the probabilistic approach to ML is key to generative modeling.
* Understand the challenges of building generative models in practice, as well as their solutions.
* Understand, implement, and train a Denoising Diffusion Model.


**Prerequisites:**

* Familiarity with Jax and Haiku – going through the “Introduction to ML using Jax” practical is **strongly** recommended.
* Neural network basics (e.g., what ResNet, BatchNorm, and Adam are).
<!-- * Basic linear algebra. -->
* Basic probability theory (e.g., what a probability distribution is, what Bayes’ rule is).
* [Suggested] Attend the Monte Carlo 101 parallel – this session will provide a lot of background on probability theory.

**Outline:** 

[Points that link to each section. Auto-generate following the instructions [here](https://stackoverflow.com/questions/67458990/how-to-automatically-generate-a-table-of-contents-in-colab-notebook).]

**Before you start:**

For this practical, you will need to use a GPU to speed up training. To do this, go to the "Runtime" menu in Colab, select "Change runtime type" and then in the popup menu, choose "GPU" in the "Hardware accelerator" box.


## Installation and Imports

In [None]:
## Install and import anything required.
#@title Install and import required packages. (Run Cell)

!pip install dm-haiku
!pip install optax
!pip install distrax
!pip install einops

import os 

# https://stackoverflow.com/questions/68340858/in-google-colab-is-there-a-programing-way-to-check-which-runtime-like-gpu-or-tpu
if int(os.environ.get("COLAB_GPU", "0")) > 0:
  print("A GPU is connected.")
elif "COLAB_TPU_ADDR" in os.environ and os.environ["COLAB_TPU_ADDR"]:
  print("A TPU is connected.")
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()
else:
  print("Only CPU accelerator is connected.")

from functools import partial

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import haiku as hk
import optax
import distrax
import tensorflow as tf
import tensorflow_datasets
from einops import rearrange
from opt_einsum import contract
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm

In [None]:
#@title Helper Functions. (Run Cell)

In [None]:
#@title Check what device you are using (Run Cell)
print(f"Num devices: {jax.device_count()}")
print(f" Devices: {jax.devices()}")

## **Sample Section 1**

[Background/content for the section.]

### Subsection - <font color='blue'>`Beginner`</font>

Math foundations:


**Math Task:**

[Optional math task or ask multiple choice question. E.g. the derivation of this would equal a, b or c. We could check this at the end of the prac.]


In [None]:
selection = 'a' #@param ["a", "b", "c"]
print(f"You selected: {selection}")

correct_answer = "a"
assert selection == correct_answer, "Incorrect answer, hint ..."

print("Nice, you got the correct answer!")

Code demonstration

In [None]:
# Code demonstration

**Code Task:**

In [None]:
# Code to be implemented during practical
# You should provide the function signature. 

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!) 

**Group Task:**

Task that involves asking your neighbour or a group a question.

### Subsection - <font color='orange'>`Intermediate`</font>

Math foundations:


**Math Task:**

[Optional math task or ask multiple choice question. E.g. the derivation of this would equal a, b or c. We could check this at the end of the prac.]


In [None]:
selection = 'a' #@param ["a", "b", "c"]
print(f"You selected: {selection}")

correct_answer = "a"
assert selection == correct_answer, "Incorrect answer, hint ..."

print("Nice, you got the correct answer!")

Code demonstration

In [None]:
# Code demonstration

**Code Task:**

In [None]:
# Code to be implemented during practical
# You should provide the function signature. 

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!) 

**Group Task:**

Task that involves asking your neighbour or a group a question.

### Subsection - <font color='green'>`Advanced`</font>

Math foundations:


**Math Task:**

[Optional math task or ask multiple choice question. E.g. the derivation of this would equal a, b or c. We could check this at the end of the prac.]


In [None]:
selection = 'a' #@param ["a", "b", "c"]
print(f"You selected: {selection}")

correct_answer = "a"
assert selection == correct_answer, "Incorrect answer, hint ..."

print("Nice, you got the correct answer!")

Code demonstration

In [None]:
# Code demonstration

**Code Task:**

In [None]:
# Code to be implemented during practical
# You should provide the function signature. 

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!) 

**Group Task:**

Task that involves asking your neighbour or a group a question.

### Subsection - <font color='purple'>`Optional`</font>

Math foundations:


**Math Task:**

[Optional math task or ask multiple choice question. E.g. the derivation of this would equal a, b or c. We could check this at the end of the prac.]


In [None]:
selection = 'a' #@param ["a", "b", "c"]
print(f"You selected: {selection}")

correct_answer = "a"
assert selection == correct_answer, "Incorrect answer, hint ..."

print("Nice, you got the correct answer!")

Code demonstration

In [None]:
# Code demonstration

**Code Task:**

In [None]:
# Code to be implemented during practical
# You should provide the function signature. 

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!) 

**Group Task:**

Task that involves asking your neighbour or a group a question.

### Section Quiz 

Optional end of section quiz. Below is an example of an assessment.

In [None]:
#@title Generate Quiz Form. (Run Cell)
from IPython.display import HTML
HTML(
"""
<iframe 
  src="https://forms.gle/zbJoTSz3nfYq1VrY6"
  width="80%" 
	height="1200px" >
	Loading...
</iframe>
"""
)

## **Deep Generative Models**

### **[Optional] Quick Probability Refresher – The Sum, Product, and Bayes' rules**

In probability theory, there are three fundamental rules that come up over and over again, so let's quickly recap them in case they aren't at the front of your mind. 

1. **The Product Rule:** $$p(x, y) = p(x|y)p(y) = p(y|x)p(x), $$ tells us that the joint probability distribution over two random variables ($x$ and $y$ in this example) is equal to the probability distribution of one of those random variables (e.g., $p(x)$) multiplied by the conditional probability distribution of the second random variable given the first (e.g., $p(y|x)$). 

2. **The Sum Rule:** $$p(x) = \sum_y p(x,y), $$ tells us that we can sum over (or marginalise) a random variable ($y$ in this example) in a joint probability distribution in order to get a distribution over the remaining random variable(s) ($x$ in this example). In the case of a continuous distribution, the sum $\sum$ would be replaced by an integral $\int$.

3. **Bayes' Rule:** $$p(y|x) = \frac{p(x|y)p(y)}{p(x)}$$ tells us how to convert between conditional distributions. That is, if we have $p(x|y)$ (as well as $p(x)$ and $p(y)$), we can convert it to $p(y|x)$ and vice versa. 



TODO: add intuition, discussion point, etc.

#### **[Optional] Math Task**

Bayes' rule can actually be derived from only the product rule. Test your understanding by trying to derive it yourself (or with your neighbours)! 

BONUS task: Go one step further and use the sum rule to write Bayes' rule without $p(x)$.

##### Answer to the math task (Try not to peek until you've given it a good try!) 


\begin{align}
p(x, y) & = p(y|x)p(x) && \triangleright \text{product rule}\\
\therefore p(y|x) &= \frac{p(x, y)}{p(x)} \\
&= \frac{p(x|y)p(y)}{p(x)} && \triangleright \text{product rule}
\end{align}

##### Answer to BONUS math task (Try not to peek until you've given it a good try!) 

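\begin{align}
p(y|x) &= \frac{p(x|y)p(y)}{p(x)} \\
&= \frac{p(x|y)p(y)}{\sum_{y'} p(x, y')} && \triangleright \text{sum rule}\\
&= \frac{p(x|y)p(y)}{\sum_{y'} p(x|y')p(y')} && \triangleright \text{product rule}
\end{align}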
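As a quick numerical sanity check of the three rules (including the BONUS form of Bayes' rule), here is a minimal sketch using a small, made-up discrete joint distribution $p(x, y)$. The table values are arbitrary assumptions, chosen only so that they sum to one.

```python
import jax.numpy as jnp

# Hypothetical joint distribution p(x, y): rows are x in {0, 1, 2}, columns are y in {0, 1}.
p_xy = jnp.array([[0.10, 0.20],
                  [0.25, 0.05],
                  [0.15, 0.25]])

# Sum rule: marginalise out one variable to get p(x) and p(y).
p_x = p_xy.sum(axis=1)  # shape (3,)
p_y = p_xy.sum(axis=0)  # shape (2,)

# Product rule, rearranged: conditionals are the joint divided by a marginal.
p_x_given_y = p_xy / p_y[None, :]  # each column sums to 1
p_y_given_x = p_xy / p_x[:, None]  # each row sums to 1

# Bayes' rule, with p(x) written as sum_{y'} p(x|y') p(y') (the BONUS form).
numerator = p_x_given_y * p_y[None, :]
p_y_given_x_bayes = numerator / numerator.sum(axis=1, keepdims=True)

print(jnp.allclose(p_y_given_x_bayes, p_y_given_x))
```

Computing $p(y|x)$ directly from the joint and computing it via Bayes' rule from $p(x|y)$ and $p(y)$ give the same table.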

### **What is Generative Modeling?**

One way that machine learning methods can be categorised is by whether they are *discriminative* or *generative*. Some of the most common machine learning tasks, such as regression and classification, are typically solved using discriminative methods. So you might be wondering "what is generative modeling?".

To understand what generative modelling is, let us compare and contrast it with discriminative modelling. Specifically, let us consider the task of modelling data from two different classes. 

**Discriminative:** In this case, we are interested in learning a probability distribution $p(y|x)$, i.e., the probability of the class $y$ being either 0 or 1 given an example $x$. For any example $x$, we can determine whether it is more likely to be from class 0 or class 1, and draw a corresponding decision boundary where $p(y = 0|x) = p(y = 1|x)$. 

**Generative:** Here, we are interested in learning the probability distribution $p(x|y)$, i.e., the probability of observing some data $x$ given that it is from class $y$. We might also be interested in learning the probability distribution $p(x)$ which is the probability of observing $x$ in either of the classes. *The cool thing about generative modeling is that if we can sample $x \sim p(x|y)$ or $x \sim p(x)$ we can generate new unseen examples*. 

---

<img src="https://miro.medium.com/max/1000/0*TXI_h4llG-SmdwYY.png" width="60%" />

*In the picture above ([source](https://betterprogramming.pub/generative-vs-discriminative-models-d26def8fd64a_)), red and blue dots represent examples ($x$) from classes $y = 0$ and $y = 1$, respectively. The dashed black line shows the decision boundary in the discriminative case. In the generative case, the light red and blue ovals show the areas where $p(x|y = 0)$ and $p(x|y = 1)$ are large, respectively.*

---

#### **Generative and Discriminative Models are Two Sides of the Same Coin**

From Bayes' rule $$p(y|x) = \frac{p(x|y)p(y)}{p(x)}$$ we can see that the discriminative model $p(y|x)$ and the generative model $p(x|y)$ are very closely linked. If we know $p(x)$ and $p(y)$ then we can convert between a generative and a discriminative model. However, as we will see in the next section, this isn't as easy as it might seem at first glance.
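To make this conversion concrete, here is a minimal sketch of a hypothetical 1-D generative classifier: two classes with equal priors $p(y)$ and Gaussian class-conditionals $p(x|y)$, whose parameters are arbitrary assumptions. Bayes' rule, with $p(x) = \sum_y p(x|y)p(y)$, turns it into the discriminative model $p(y|x)$.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical generative model: priors p(y) and 1-D Gaussian class-conditionals p(x|y).
prior = jnp.array([0.5, 0.5])   # p(y=0), p(y=1)
means = jnp.array([-1.0, 2.0])  # mean of p(x|y) for y = 0, 1
stds = jnp.array([1.0, 1.0])    # standard deviation of p(x|y) for y = 0, 1

def posterior(x):
    """Discriminative p(y|x) for both classes, obtained from the generative model via Bayes' rule."""
    likelihood = norm.pdf(x, loc=means, scale=stds)  # p(x|y) for y = 0, 1
    joint = likelihood * prior                        # p(x|y) p(y)
    return joint / joint.sum()                        # divide by p(x) = sum_y p(x|y) p(y)

print(posterior(-3.0), posterior(0.5), posterior(4.0))
```

Because both classes share the same variance and prior here, the decision boundary $p(y=0|x) = p(y=1|x)$ sits exactly halfway between the two means, at $x = 0.5$. The reverse direction (discriminative to generative) is harder, since it also needs $p(x)$.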

### **Generative Models + NNs = Deep Generative Models**

Now that we've discussed what a generative model is, we can answer the question "What is a **deep** generative model?". In this case, the answer is simple! A deep generative model is a generative model in which $p_\theta (x)$ or $p_\theta (x|y)$ is represented by a (deep) neural network (with parameters $\theta$)!

To be more concrete, a deep generative model is a neural network that allows us to sample $x$ from the distribution $p_\theta(x)$ (or $p_\theta(x|y)$). This is often done by sampling random noise $z$ and using it as the input to the NN, whose output is then a sample $x$. Some deep generative models also allow us to evaluate $p_\theta(x)$ for a given example $x$, i.e., they allow us to measure the probability of observing a particular example. 
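A minimal sketch of the "noise in, sample out" idea follows. It assumes a hypothetical two-layer MLP $f_\theta$ written directly in `jax.numpy` (rather than Haiku), with untrained random weights. We draw $z \sim \mathcal{N}(0, I)$ and return $x = f_\theta(z)$. Until $\theta$ is trained, these samples come from some arbitrary distribution; training is what should make them resemble the data.

```python
import jax
import jax.numpy as jnp

def init_generator(key, noise_dim=2, hidden_dim=32, data_dim=2):
    """Randomly initialise a hypothetical two-layer MLP generator (these weights are untrained)."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (noise_dim, hidden_dim)) / jnp.sqrt(noise_dim),
        "b1": jnp.zeros(hidden_dim),
        "w2": jax.random.normal(k2, (hidden_dim, data_dim)) / jnp.sqrt(hidden_dim),
        "b2": jnp.zeros(data_dim),
    }

def generator(gen_params, z):
    """The neural network f_theta: maps a batch of noise vectors z to data-space points x."""
    h = jnp.tanh(z @ gen_params["w1"] + gen_params["b1"])
    return h @ gen_params["w2"] + gen_params["b2"]

def generator_sample(gen_params, key, num_samples, noise_dim=2):
    """Draw z ~ N(0, I) and push it through the network to get samples x ~ p_theta(x)."""
    z = jax.random.normal(key, (num_samples, noise_dim))
    return generator(gen_params, z)

gen_params = init_generator(jax.random.PRNGKey(0))
x_samples = generator_sample(gen_params, jax.random.PRNGKey(1), num_samples=100)
```

Note that this network only gives us samples. Nothing in it lets us evaluate $p_\theta(x)$ for a given $x$.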

#### **The Challenge of Deep Generative Models**

*The central problem of deep generative models is that we do not know how to write down the distribution $p(x)$.* 

Consider, for example, building a generative model for images of cats. What do you think the true distribution of such images is? How would you write it down mathematically? These are very difficult questions for humans to answer, because images are very high dimensional and it is difficult for us to reason about high-dimensional spaces. 

The distribution $p_\theta(x|y)$ is similarly high dimensional and difficult for us to describe, so we are not saved by considering the conditional model. (This should not come as a surprise, since we know that $p_\theta(x)$ and $p_\theta(x|y)$ are closely tied together via Bayes' rule!)

**Why do we need to write down $p_\theta(x)$?** We said above that we would train a neural network to allow us to draw samples from $p_\theta(x)$. If we don't need to evaluate the probability of those samples, why do we need to worry about writing down $p_\theta(x)$? Unfortunately, we need to **train** our model! Without training, we have no way of ensuring that $p_\theta(x)$, and therefore the samples we draw from it, actually matches the true data distribution $p(x)$. Training, in turn, requires evaluating $p_\theta(x)$.

**How do we train our deep generative model?** From the probabilistic perspective, the most natural way to train our model is maximum likelihood: we want to maximise $p_\theta(x)$ for our observed data. In practice, we instead minimise (using SGD on the weights $\theta$) the average *negative log-likelihood* (NLL): $$ \mathcal{L}_\text{NLL} = -\frac{1}{N} \sum_{n=1}^N \log p_\theta(x_n),$$
where $\{x_n\}_{n=1}^N$ is our training dataset. Thus, we need to be able to evaluate $p_\theta(x_n)$. 
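Here is a minimal sketch of NLL training in a case where $\log p_\theta(x)$ *is* easy to write down. It assumes a 1-D Gaussian model $p_\theta(x) = \mathcal{N}(x|\mu, \sigma^2)$ with $\theta = (\mu, \log\sigma)$, made-up training data, and plain full-batch gradient descent in place of SGD.

```python
import jax
import jax.numpy as jnp

# Hypothetical training set: 1000 draws from N(3, 2^2).
data = 3.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(0), (1000,))

def nll(theta, x):
    """Average NLL of x under p_theta(x) = N(x | mu, sigma^2), with theta = (mu, log_sigma)."""
    mu, log_sigma = theta
    sigma = jnp.exp(log_sigma)  # optimise log(sigma) so that sigma stays positive
    log_p = -0.5 * jnp.log(2 * jnp.pi) - log_sigma - 0.5 * ((x - mu) / sigma) ** 2
    return -jnp.mean(log_p)

grad_nll = jax.jit(jax.grad(nll))

theta = jnp.array([0.0, 0.0])  # initial guess: mu = 0, sigma = 1
for _ in range(2000):
    theta = theta - 0.05 * grad_nll(theta, data)  # full-batch gradient descent

print(theta[0], jnp.exp(theta[1]))  # close to data.mean() and data.std()
```

For a Gaussian, the NLL minimiser is known in closed form (the sample mean and standard deviation), which makes this a convenient check. For a deep generative model of images, $\log p_\theta(x_n)$ is exactly the quantity we usually cannot write down.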

#### **Other Types of Deep Generative Models**







While we will be focusing on Denoising Diffusion Models in this practical, it is worth keeping in mind that there are many other kinds of deep generative models. Different kinds of models make different assumptions, with different pros and cons. Fundamentally, they differ in how they deal with (or sidestep) the problem of calculating $p_\theta(x)$. Some common deep generative models include:

1. Autoregressive models,
2. Variational Autoencoders (VAEs),
3. Normalizing Flows (NFs),
4. Energy-based Models (EBMs), and
5. Generative Adversarial Networks (GANs).

##### **[Optional] Extra reading on other deep generative models**

**Autoregressive models** use the product rule to factorise the probability distribution over a $D$-dimensional observation $\mathbf{x}$ as $$ p(\mathbf{x}) = p(x_1) \prod_{d=1}^{D-1} p(x_{d+1}|x_{d},...,x_1), $$ where $x_d$ is the $d^{th}$ element of $\mathbf{x}$. This factorisation allows us to choose simple distributions for each conditional $p(x_{d+1}|x_d,...,x_1)$ but still have a complex overall distribution for $p(\mathbf{x})$. However, sampling from these models is slow, because the dimensions have to be generated one at a time, and the sequential conditional structure can also make them expensive to train. Furthermore, it is not clear what the correct ordering of the dimensions of $\mathbf{x}$ is. Nevertheless, autoregressive models have seen huge success, a prime example being [WaveNet](https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio)!
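A minimal sketch of both halves of an autoregressive model follows. It assumes a hypothetical toy linear-Gaussian model rather than a neural one. Evaluating $\log p(\mathbf{x})$ is a single vectorised sum of conditional log-densities, whereas sampling has to proceed one dimension at a time.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical toy model: x_1 ~ N(0, 1) and x_{d+1} | x_d, ..., x_1 ~ N(a * x_d, s^2).
# Here each conditional only uses x_d; a neural autoregressive model would instead
# compute the parameters of p(x_{d+1} | x_d, ..., x_1) from all previous elements.
a, s = 0.9, 0.5

def ar_log_prob(x):
    """log p(x) = log p(x_1) + sum_d log p(x_{d+1} | x_d, ..., x_1)."""
    first = norm.logpdf(x[0], loc=0.0, scale=1.0)
    rest = norm.logpdf(x[1:], loc=a * x[:-1], scale=s)
    return first + rest.sum()

def ar_sample(key, D):
    """Ancestral sampling: each element needs the previous one, so this loop is sequential."""
    keys = jax.random.split(key, D)
    xs = [jax.random.normal(keys[0])]
    for d in range(1, D):
        xs.append(a * xs[-1] + s * jax.random.normal(keys[d]))
    return jnp.stack(xs)

x_ar = ar_sample(jax.random.PRNGKey(0), D=5)
print(x_ar, ar_log_prob(x_ar))
```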

[TODO: write 1 paragraph summaries for VAEs, NFs, EBMs, and GANs]

## **Denoising Diffusion Models**





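Denoising diffusion models avoid the problem of estimating $p_\theta(x)$ by instead estimating the *score function* $\nabla_x \log p_\theta(x)$, i.e., the gradient of the log-likelihood with respect to the observed data $x$. This sidesteps a major difficulty, the normalising constant. If $p_\theta(x) = \tilde{p}_\theta(x)/Z_\theta$, then $\nabla_x \log p_\theta(x) = \nabla_x \log \tilde{p}_\theta(x)$, because $Z_\theta$ does not depend on $x$.
It is because of the importance of the score function that denoising diffusion models are also known as *score-based generative models*.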
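Here is a minimal sketch of why working with the score helps. It assumes a 1-D Gaussian $\mathcal{N}(x|\mu, \sigma^2)$ with made-up $\mu$ and $\sigma$. The score computed with `jax.grad` is the same whether or not the log-density includes its normalising term, so the normalising constant never needs to be known.

```python
import jax
import jax.numpy as jnp

mu, sigma = 1.0, 2.0  # hypothetical parameters of N(x | mu, sigma^2)

def log_p_normalised(x):
    """Full log-density, including the normalising term -log(sigma * sqrt(2 pi))."""
    return -0.5 * ((x - mu) / sigma) ** 2 - jnp.log(sigma * jnp.sqrt(2 * jnp.pi))

def log_p_unnormalised(x):
    """Same log-density with the normalising term dropped."""
    return -0.5 * ((x - mu) / sigma) ** 2

score_normalised = jax.grad(log_p_normalised)
score_unnormalised = jax.grad(log_p_unnormalised)

for x in [-3.0, 0.0, 2.5]:
    print(x, score_normalised(x), score_unnormalised(x))  # identical; both equal (mu - x) / sigma^2
```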

**Note.** The score function $\nabla_x \log p_\theta(x)$ is *not* the same as the gradient of the log-likelihood w.r.t. the parameters of the model $\theta$: $\nabla_\theta \log p_\theta(x)$, which we would use to train other deep generative models such as NFs and autoregressive models. $\nabla_\theta \log p_\theta(x)$ tells us how to change the parameters of the model to increase the likelihood, while $\nabla_x \log p_\theta(x)$ tells us how to change the data itself!
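A minimal sketch of the difference follows. It uses a 1-D Gaussian $\mathcal{N}(x|\mu, 1)$ in which the mean $\mu$ stands in for the parameters $\theta$ (an assumption for illustration). `jax.grad` with different `argnums` gives the two gradients.

```python
import jax
import jax.numpy as jnp

def log_p(x, mu):
    """log N(x | mu, 1); here mu plays the role of the model parameters theta."""
    return -0.5 * (x - mu) ** 2 - 0.5 * jnp.log(2 * jnp.pi)

score_wrt_x = jax.grad(log_p, argnums=0)   # nabla_x log p: how to change the data
grad_wrt_mu = jax.grad(log_p, argnums=1)   # nabla_theta log p: how to change the parameters

print(score_wrt_x(2.0, 0.5))  # -1.5 (= mu - x): move x towards mu
print(grad_wrt_mu(2.0, 0.5))  #  1.5 (= x - mu): move mu towards x
```

Both gradients increase the likelihood, but they push in opposite directions. One moves the data point towards the model's mean, and the other moves the model's mean towards the data point.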

#### **Score Function Illustration**



To get some intuition for what the score function is, and how we can use it to create samples from a desired distribution, let's consider a simple bimodal distribution $$ p(\mathbf{x}) = 0.5 \cdot \mathcal{N}\left(\mathbf{x}\, \middle|\, \boldsymbol{\mu}_1, \Sigma_1\right) + 0.5 \cdot \mathcal{N}\left(\mathbf{x}\, \middle|\, \boldsymbol{\mu}_2, \Sigma_2\right), $$ where $$ \boldsymbol{\mu}_1 = \begin{pmatrix} 0.8 \\ 3.2 \end{pmatrix}, \quad \Sigma_1 = \begin{pmatrix} 0.8 & -0.8 \\ -0.8 & 2.0 \end{pmatrix}, \quad \boldsymbol{\mu}_2 = \begin{pmatrix} -1.3 \\ -2.7 \end{pmatrix}, \quad \Sigma_2 = \begin{pmatrix} 1.5 & 0.75 \\ 0.75 & 1.5 \end{pmatrix}. $$

  We can easily implement this density function in `jax` using the [`Distrax`](https://github.com/deepmind/distrax) library: 

In [None]:
# Two Gaussian components, mixed with equal weights 0.5 to give the bimodal density p(x).
dist_1 = distrax.MultivariateNormalFullCovariance([.8, 3.2], covariance_matrix=[[.8, -.8], [-.8, 2.]])
dist_2 = distrax.MultivariateNormalFullCovariance([-1.3, -2.7], covariance_matrix=[[1.5, .75], [.75, 1.5]])
density_fn = lambda x: 0.5*jnp.exp(dist_1.log_prob(x.astype(jnp.float32))) + 0.5*jnp.exp(dist_2.log_prob(x.astype(jnp.float32)))

**Code Task:** Use `jax`'s `grad` transform to implement the score function.

In [None]:
score_fn = None  # YOUR SOLUTION GOES HERE (replace None)

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!) 
score_fn = jax.grad(lambda x: jnp.log(density_fn(x)))

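Now that we have the score function for $p(x)$, we can use it to move randomly sampled noise towards regions where $p(x)$ is high. The procedure is simple. Start with an initial random guess $x_T$, and then, for $T$ *time*steps, calculate $$ x_{t-1} = x_{t} + \lambda \cdot \nabla_x \log p(x_t), $$
until we arrive at $x_0$, which we hope will look like a sample from $p(x)$.
Here $\lambda$ is a step size parameter that controls how big a change we make at each timestep. Note that this update on its own is gradient ascent on $\log p(x)$: run for long enough, it converges to a mode of $p(x)$ rather than producing diverse samples. Adding a small amount of Gaussian noise at each step (Langevin dynamics) is what turns it into a proper sampler.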
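A minimal sketch of the noisy version of this update (unadjusted Langevin dynamics), $x_{t-1} = x_t + \lambda \nabla_x \log p(x_t) + \sqrt{2\lambda}\, z_t$ with $z_t \sim \mathcal{N}(0, I)$, follows. It assumes a standard 2-D Gaussian target (whose score is $-x$) instead of the mixture above, so that the correct answer is known. The `add_noise` flag is a hypothetical switch that lets us compare against the noiseless update.

```python
import jax
import jax.numpy as jnp

# Assumption: target p(x) = N(0, I) in 2-D, whose score is simply -x.
def log_density(x):
    return -0.5 * jnp.sum(x ** 2)  # unnormalised log N(x | 0, I)

batched_score = jax.vmap(jax.grad(log_density))  # score for a batch of points

@jax.jit
def langevin_step(x, key, step_size, noise_scale):
    """One update x <- x + step_size * score(x) + noise_scale * z, with z ~ N(0, I)."""
    z = jax.random.normal(key, x.shape)
    return x + step_size * batched_score(x) + noise_scale * z

def run_chains(key, add_noise=True, num_chains=5000, num_steps=1000, step_size=0.01):
    """Start many chains far from the mode and iterate the update num_steps times."""
    key, init_key = jax.random.split(key)
    x = 5.0 + jax.random.normal(init_key, (num_chains, 2))
    # sqrt(2 * step_size) is the Langevin noise scale; 0.0 recovers plain gradient ascent.
    noise_scale = jnp.sqrt(2.0 * step_size) if add_noise else 0.0
    for _ in range(num_steps):
        key, step_key = jax.random.split(key)
        x = langevin_step(x, step_key, step_size, noise_scale)
    return x

langevin = run_chains(jax.random.PRNGKey(0), add_noise=True)
ascent = run_chains(jax.random.PRNGKey(0), add_noise=False)
print(langevin.std(axis=0), ascent.std(axis=0))
```

With noise, the 5,000 chains spread out to roughly match $\mathcal{N}(0, I)$, with a per-coordinate standard deviation close to 1. Without noise, they all collapse onto the mode at the origin.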

**Group Task:**

1. Using the sliders below, experiment with randomly choosing different coordinates for $\mathbf{x}$, as well as different numbers of timesteps and step sizes. 
2. Compare your initial coordinates $\mathbf{x}_T$ (red crosses) with the final coordinates $\mathbf{x}_0$ (cyan crosses). Which are more similar to the coordinates directly sampled from $p(\mathbf{x})$ (black circles)? 
3. Discuss amongst yourselves to make sure that this behaviour makes sense.

In [None]:
# @title Two Gaussians Score Function
from matplotlib import cm

x0 =  0.1 #@param {type:"slider", min:-5, max:5, step:0.1}
x1 =  -4.9 #@param {type:"slider", min:-5, max:5, step:0.1}
x = jnp.array([x0, x1], dtype=jnp.float32)

num_steps = 5 #@param {type:"slider", min:1, max:10, step:1}
step_size = 0.2 #@param {type:"slider", min:0.1, max:1, step:0.05}


# Make plots of the density p(x)
X0 = jnp.arange(-5, 5.1, 0.25)
X1 = jnp.arange(-5, 5.1, 0.25)
X0, X1 = jnp.meshgrid(X0, X1)

Xs = jnp.concatenate(
    [X0.reshape(-1, 1), X1.reshape(-1, 1)], axis=1
)
Z = jax.vmap(density_fn)(Xs).reshape(X1.shape)


fig = plt.figure(figsize=(12, 6))

ax1 = fig.add_subplot(121, projection='3d')
ax1.plot_surface(X0, X1, Z, cmap=cm.coolwarm,
                linewidth=0, alpha=0.4, antialiased=False)

ax2 = fig.add_subplot(122)
ax2.contourf(X0, X1, Z, cmap=cm.coolwarm,
            alpha=0.4, antialiased=False)


# Plot some samples from the true data distribution
key1, key2 = jax.random.split(jax.random.PRNGKey(42)) 
samples1 = dist_1.sample(seed=key1, sample_shape=(25,))
samples2 = dist_2.sample(seed=key2, sample_shape=(25,))
samples = jnp.concatenate([samples1, samples2])

# Height of each sample on the 3D plot: the full mixture density p(x), matching the surface.
probs = jax.vmap(density_fn)(samples)

ax1.scatter(samples[:, 0], samples[:, 1], probs, c='k', marker='o', alpha=0.5)
ax2.scatter(samples[:, 0], samples[:, 1], c='k', marker='o', alpha=0.5)


# Plot the trajectory of the sample over time
ax1.scatter(x[0], x[1], density_fn(x), c='r', marker='x', s=75)
ax2.scatter(x[0], x[1], c='r', marker='x', s=75)

for step in range(num_steps):
    score = score_fn(x)  # avoid shadowing jax's `grad`, imported at the top
    x = x + step_size*score

    color = 'k' if step < num_steps - 1 else 'c'
    ax1.scatter(x[0], x[1], density_fn(x), c=color, marker='x', s=75)
    ax2.scatter(x[0], x[1], c=color, marker='x', s=75)  


ax1.set_ylim(-5, 5)
ax1.set_xlim(-5, 5)
ax1.set_zlim(0, 0.1)

ax2.set_ylim(-5, 5)
ax2.set_xlim(-5, 5)

plt.show()

#### **[Optional] Connections to EBM, VAEs, and Autoregressive Models**

[TODO]

#### **Diffusion**

In [None]:
# TODO: turn into coding task
def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return jnp.linspace(beta_start, beta_end, timesteps)

# TODO: plot this

In [None]:
# TODO: add explanations for this code! 

TIMESTEPS = 200

# define beta schedule
betas = linear_beta_schedule(timesteps=TIMESTEPS)

# define alphas 
alphas = 1. - betas
alphas_cumprod = jnp.cumprod(alphas, axis=0)
alphas_cumprod_prev = jnp.concatenate((jnp.array([1.0]), alphas_cumprod[:-1]))
sqrt_recip_alphas = jnp.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = jnp.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = jnp.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a[..., t]
    return jnp.reshape(out, (batch_size, *((1,) * (len(x_shape) - 1))))

In [None]:
# forward diffusion
def q_sample(x_start, t, noise):
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
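`q_sample` jumps from $x_0$ straight to $x_t$ in a single step, using the closed form $q(x_t|x_0) = \mathcal{N}(\sqrt{\bar\alpha_t}\,x_0, (1-\bar\alpha_t) I)$, where $\bar\alpha_t$ is `alphas_cumprod[t]`. This closed form follows from the standard DDPM one-step forward process $q(x_t|x_{t-1}) = \mathcal{N}(\sqrt{1-\beta_t}\,x_{t-1}, \beta_t I)$. The sketch below checks this numerically for a scalar $x_0 = 1$. It runs the one-step process 200 times and compares the empirical mean and variance of $x_T$ with the closed form. To keep the block self-contained, it rebuilds the linear schedule under new names.

```python
import jax
import jax.numpy as jnp

# Assumption: the same linear schedule as linear_beta_schedule(200), rebuilt under new names
# so that this check does not overwrite the notebook's `betas` / `alphas_cumprod`.
demo_betas = jnp.linspace(0.0001, 0.02, 200)
demo_alphas_cumprod = jnp.cumprod(1.0 - demo_betas)

def forward_one_step_at_a_time(key, x_start, num_samples):
    """Apply x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * noise for t = 1, ..., T in turn."""
    x = jnp.full((num_samples,), x_start)
    for beta_t in demo_betas:
        key, noise_key = jax.random.split(key)
        x = jnp.sqrt(1.0 - beta_t) * x + jnp.sqrt(beta_t) * jax.random.normal(noise_key, x.shape)
    return x

x_T = forward_one_step_at_a_time(jax.random.PRNGKey(0), x_start=1.0, num_samples=20000)

# Closed form used by q_sample: q(x_T | x_0) = N(sqrt(abar_T) * x_0, 1 - abar_T), with x_0 = 1.
closed_form_mean = jnp.sqrt(demo_alphas_cumprod[-1]) * 1.0
closed_form_var = 1.0 - demo_alphas_cumprod[-1]
print(f"mean: simulated {float(x_T.mean()):.3f} vs closed form {float(closed_form_mean):.3f}")
print(f"var:  simulated {float(x_T.var()):.3f} vs closed form {float(closed_form_var):.3f}")
```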

In [None]:
def get_noisy_image(image, t, key):
    x_start = jnp.asarray(image, dtype=jnp.float32)[jnp.newaxis, ...]
    x_start = x_start / 255.  # normalize the images to [0, 1]
    x_start = (x_start * 2.) - 1.  # convert range to [-1, 1]

    # add noise
    noise = jax.random.normal(key, x_start.shape)
    x_noisy = q_sample(x_start, t, noise)

    # turn back into a PIL image (clip first, since the added noise can push values outside [0, 255])
    x_noisy = (x_noisy + 1.) / 2.
    x_noisy = x_noisy * 255.
    img_array = np.clip(np.asarray(x_noisy), 0, 255).astype(np.uint8)[0]
    noisy_image = Image.fromarray(img_array)

    return noisy_image

In [None]:
from PIL import Image
import requests

# TODO: find better image... Indaba themed?

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image

In [None]:
get_noisy_image(image, jnp.array([40]), jax.random.PRNGKey(42))

In [None]:
imgs = [get_noisy_image(image, jnp.array([t]), jax.random.PRNGKey(0)) for t in [0, 50, 100, 150, 199]]

num_cols = len(imgs)
fig, axs = plt.subplots(figsize=(15, 3), nrows=1, ncols=num_cols, squeeze=False)
for idx, img in enumerate(imgs):
    ax = axs[0, idx]
    ax.imshow(np.asarray(img))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

plt.tight_layout()

### **Diffusion Models in Practice**

#### **The dataset**

In [None]:
mnist = tf.keras.datasets.mnist
# We are getting the labels only for the purposes of exploring the data,
# we won't use them to train our models (or will we?).
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# text_labels = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

In [None]:
# Show 25 randomly selected images at a time
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)

    img_index = np.random.randint(0, len(train_images))
    plt.imshow(train_images[img_index], cmap="gray_r")
    # plt.xlabel(text_labels[train_labels[img_index]])

In [None]:
train_data = np.expand_dims(train_images, axis=3) # add a channel dim
train_data = train_data.astype('float32') # convert to float32
train_data = (train_data - 127.5) / 127.5 # normalize the images to [-1, 1]
# TODO: move all of the above into a MAP statement?
BATCH_SIZE = 128
train_dataset = tf.data.Dataset.from_tensor_slices(train_data).cache().repeat().shuffle(BATCH_SIZE*10).batch(BATCH_SIZE)
train_dataset = iter(tensorflow_datasets.as_numpy(train_dataset))

#### **Building the NN**

In [None]:
class SinusoidalTimeEmbeddings(hk.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def __call__(self, time):
        half_dim = self.dim // 2
        embeddings = jnp.log(10000) / (half_dim - 1)
        embeddings = jnp.exp(jnp.arange(half_dim) * -embeddings)
        embeddings = time[:, jnp.newaxis] * embeddings[jnp.newaxis, :]
        embeddings = jnp.concatenate((jnp.sin(embeddings), jnp.cos(embeddings)), axis=-1)
        return embeddings
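The `SinusoidalTimeEmbeddings` module turns each integer timestep into a vector. The first half of the vector holds sines and the second half holds cosines, at frequencies that decay geometrically from 1 down to 1/10000. Below is the same computation as a plain function (a hypothetical helper, not used by the model), so you can inspect it without Haiku.

```python
import jax.numpy as jnp

def sinusoidal_embedding(time, dim):
    """Mirror of SinusoidalTimeEmbeddings.__call__ as a plain function (no Haiku)."""
    half_dim = dim // 2
    # Frequencies decay geometrically from 1 down to 1/10000 across the half_dim entries.
    freqs = jnp.exp(jnp.arange(half_dim) * -(jnp.log(10000.0) / (half_dim - 1)))
    args = time[:, None] * freqs[None, :]                            # (batch, half_dim)
    return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)  # (batch, 2 * half_dim)

emb = sinusoidal_embedding(jnp.array([0, 10, 199]), dim=28)
print(emb.shape)  # (3, 28)
```

At $t = 0$, every sine is 0 and every cosine is 1. As $t$ grows, the high-frequency entries change quickly and the low-frequency ones change slowly, which gives the network a smooth, distinct code for each timestep.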

In [None]:
class Block(hk.Module):
    def __init__(self, dim_out, groups = 8):
        super().__init__()
        self.proj = hk.Conv2D(dim_out, kernel_shape=3, padding=(1, 1))
        self.norm = hk.GroupNorm(groups)
        self.act = jax.nn.silu

    def __call__(self, x):
        x = self.proj(x)
        x = self.norm(x)
        x = self.act(x)
        return x

class ResnetBlock(hk.Module):
    """https://arxiv.org/abs/1512.03385"""
    
    def __init__(self, dim_out, groups=8, change_dim=False):
        super().__init__()
        self.mlp = hk.Sequential([jax.nn.silu, hk.Linear(dim_out)])
        self.block1 = Block(dim_out, groups=groups)
        self.block2 = Block(dim_out, groups=groups)
        self.res_conv = hk.Conv2D(dim_out, kernel_shape=1, padding=(0, 0)) if change_dim else lambda x: x

    def __call__(self, x, time_emb):
        h = self.block1(x)

        time_emb = self.mlp(time_emb)
        h = time_emb[:, jnp.newaxis, jnp.newaxis] + h

        h = self.block2(h)
        return h + self.res_conv(x)

In [None]:
def Upsample(dim):
    return hk.Conv2DTranspose(dim, kernel_shape=4, stride=2)

def Downsample(dim):
    return hk.Conv2D(dim, kernel_shape=4, stride=2, padding=(1, 1))

In [None]:
# TODO: add picture

class Unet(hk.Module):
    def __init__(
        self,
        dim,
        dim_mults=(1, 2, 4,),
        channels=3,
        resnet_block_groups=7,
    ):
        super().__init__()

        # determine dimensions
        init_dim = dim // 3 * 2
        self.init_conv = hk.Conv2D(init_dim, kernel_shape=7, padding=(3, 3))
        
        # time embeddings
        time_dim = dim * 4
        self.time_mlp = hk.Sequential([
            SinusoidalTimeEmbeddings(dim),
            hk.Linear(time_dim),
            jax.nn.gelu,
            hk.Linear(time_dim),
        ])

        # layers
        self.downs = []
        self.ups = []
        dims = list(map(lambda m: dim * m, dim_mults))

        for ind, stage_dim in enumerate(dims):
            is_last = ind >= len(dims) - 1

            self.downs.append([
                ResnetBlock(stage_dim, groups=resnet_block_groups, change_dim=True),
                ResnetBlock(stage_dim, groups=resnet_block_groups),
                Downsample(stage_dim) if not is_last else lambda x: x,
            ])

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, groups=resnet_block_groups)
        self.mid_block2 = ResnetBlock(mid_dim, groups=resnet_block_groups)

        rev_dims = list(reversed(dims))
        for ind, stage_dim in enumerate(rev_dims):
            is_last = ind >= len(rev_dims) - 1

            # TODO: turn into coding task:
            self.ups.append([
                ResnetBlock(stage_dim, groups=resnet_block_groups, change_dim=True),
                ResnetBlock(stage_dim, groups=resnet_block_groups),
                Upsample(stage_dim) if not is_last else lambda x: x,
            ])

        self.final_block = ResnetBlock(dim, groups=resnet_block_groups)
        self.final_conv = hk.Conv2D(channels, kernel_shape=1, padding=(0, 0))

    def __call__(self, x, time):
        x = self.init_conv(x)
        t = self.time_mlp(time)

        h = []
        # downsample
        for block1, block2, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

        # upsample
        for block1, block2, upsample in self.ups:
            # TODO: turn into coding task:
            x = jnp.concatenate((x, h.pop()), axis=-1)
            x = block1(x, t)
            x = block2(x, t)
            x = upsample(x)

        x = self.final_block(x, t)
        return self.final_conv(x)

#### **Training**

In [None]:
def build_forward_fn(dim, channels, dim_mults=None, resnet_block_groups=7):
    if dim_mults is None:
        dim_mults=(1, 2, 4,)

    def forward_fn(x, time):
        """Forward pass."""
        model = Unet(
            dim, dim_mults=dim_mults, channels=channels,
            resnet_block_groups=resnet_block_groups,
        )

        return model(x, time)

    return forward_fn

In [None]:
LR = 3e-4
IMAGE_SIZE = 28
CHANNELS = 1

# initialise model
init_rng = jax.random.PRNGKey(42)
forward_fn = hk.transform(build_forward_fn(dim=IMAGE_SIZE, channels=CHANNELS))
forward_fn = hk.without_apply_rng(forward_fn)

dummy_input = jnp.ones((1, IMAGE_SIZE, IMAGE_SIZE, CHANNELS))
params = forward_fn.init(init_rng, dummy_input, jnp.array([0]))

# set up the optimiser
optimiser = optax.adam(LR)
opt_state = optimiser.init(params)

In [None]:
# TODO: add task for deriving how this loss relates to NLL?
# Bonus task: actually implement NLL to show that they are the same?
@jax.jit
def loss_fn(params, x_batch, t, noise):
    x_noisy = q_sample(x_start=x_batch, t=t, noise=noise)
    predicted_noise = forward_fn.apply(params, x_noisy, t)

    batch_loss = jnp.mean((predicted_noise - noise)**2)

    return batch_loss

@jax.jit
def update(params, opt_state, batch, t, noise):
  # compute the loss and its gradient w.r.t. the parameters
  loss, grads = jax.value_and_grad(loss_fn)(params, batch, t, noise)
  updates, opt_state = optimiser.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss

# TODO: turn into coding task:
@partial(jax.jit, static_argnums=4)
def sample_iteration(params, key, x, t, add_noise=True):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    # Equation 11 in the DDPM paper (Ho et al., 2020):
    # use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * forward_fn.apply(params, x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if not add_noise:
        # Final step (t = 0): return the mean without adding noise
        return model_mean

    posterior_variance_t = extract(posterior_variance, t, x.shape)
    noise = jax.random.normal(key, x.shape)

    # Algorithm 2 line 4:
    return model_mean + jnp.sqrt(posterior_variance_t) * noise 

# Algorithm 2 but save all images:
def sample(params, key, image_size, batch_size=16, channels=3):
    shape = (batch_size, image_size, image_size, channels)

    img_key, key = jax.random.split(key, num=2)
    # start from pure noise (for each example in the batch)
    img = jax.random.normal(img_key, shape)
    imgs = []
    
    for i in tqdm(reversed(range(0, TIMESTEPS)), desc='sampling loop time step', total=TIMESTEPS):
        noise_key, key = jax.random.split(key, num=2)
        img = sample_iteration(params, noise_key, img, jnp.full((batch_size,), i), i != 0)
        imgs.append(img)

    return imgs
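To make the sampling step concrete outside the notebook, here is a minimal, self-contained sketch of one reverse step of DDPM ancestral sampling (Algorithm 2 of Ho et al., 2020). It assumes a linear $\beta$ schedule with only `T = 10` steps and replaces the trained network with a dummy noise predictor that always returns zeros; `reverse_step` and these settings are hypothetical and are not the notebook's own schedule or model. It also shows why the final step is special: with this posterior variance, $\sigma_0^2$ is exactly zero, so skipping the noise at $t = 0$ changes nothing numerically, and the explicit branch simply makes Algorithm 2's rule visible.

```python
import jax
import jax.numpy as jnp

# Assumed linear beta schedule (a stand-in for the schedule defined earlier in the notebook).
T = 10
betas = jnp.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = jnp.cumprod(alphas)
alphas_cumprod_prev = jnp.concatenate([jnp.ones(1), alphas_cumprod[:-1]])
# Variance of the true posterior q(x_{t-1} | x_t, x_0); equals 0 at t = 0.
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)


def reverse_step(eps_pred, key, x, t):
    """One DDPM ancestral-sampling step for a scalar (Python int) timestep t.

    eps_pred is the model's noise prediction for (x, t) and has the same shape as x.
    """
    # Equation 11 of the DDPM paper: mean of p(x_{t-1} | x_t).
    mean = (x - betas[t] / jnp.sqrt(1.0 - alphas_cumprod[t]) * eps_pred) / jnp.sqrt(alphas[t])
    if t == 0:
        # Algorithm 2 sets z = 0 on the last step, so the output is just the mean.
        return mean
    z = jax.random.normal(key, x.shape)
    return mean + jnp.sqrt(posterior_variance[t]) * z


# Usage: run the whole reverse chain from pure noise with a dummy all-zeros predictor
# (a placeholder for forward_fn.apply; it is NOT a trained model).
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 4, 4, 1))
for t in reversed(range(T)):
    key, step_key = jax.random.split(key)
    x = reverse_step(jnp.zeros_like(x), step_key, x, t)
```

Because the last step is deterministic, two different PRNG keys give identical outputs at $t = 0$ but different outputs at any $t > 0$.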

In [None]:
STEPS = 2501
LOG_EVERY = 100

key = jax.random.PRNGKey(42)

# TODO: fix memory leak

# Training & evaluation loop.
best_loss = jnp.inf
for step in tqdm(range(STEPS)):
    key, noise_key, time_key, sample_key = jax.random.split(key, num=4)

    x_batch = next(train_dataset)
    batch_size = x_batch.shape[0]

    t = jax.random.randint(time_key, (batch_size,), 0, TIMESTEPS)
    noise = jax.random.normal(noise_key, x_batch.shape)
    params, opt_state, loss = update(params, opt_state, x_batch, t, noise)
    # TODO: add losses to TQDM

    if loss < best_loss:
        best_params = params
        best_loss = loss
        best_step = step

    if step % LOG_EVERY == 0:
        print(f'{step:4} - {loss:.4f} \t ({best_step:4} - {best_loss:.4f})')
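The training loop draws a random timestep and Gaussian noise for every example and regresses the network's output onto that noise. The sketch below isolates this objective (the "simple" loss from DDPM) using a hypothetical `q_sample_sketch` that implements the closed-form forward process $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$ under an assumed linear schedule with `T = 200`; the notebook's own `q_sample` and schedule may differ. Two stand-in predictors give a sense of scale: an "oracle" that cheats by using $x_0$ to invert the forward process, and one that always predicts zeros.

```python
import jax
import jax.numpy as jnp

# Assumed linear beta schedule (not necessarily the notebook's own).
T = 200
betas = jnp.linspace(1e-4, 0.02, T)
alphas_cumprod = jnp.cumprod(1.0 - betas)


def q_sample_sketch(x0, t, noise):
    """Closed-form forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise.

    t has shape (batch,); its coefficients are broadcast over the remaining axes of x0.
    """
    abar = alphas_cumprod[t].reshape(-1, *([1] * (x0.ndim - 1)))
    return jnp.sqrt(abar) * x0 + jnp.sqrt(1.0 - abar) * noise


def simple_loss(predict_noise, x0, t, noise):
    """DDPM 'simple' objective: mean squared error between true and predicted noise."""
    x_t = q_sample_sketch(x0, t, noise)
    return jnp.mean((predict_noise(x_t, t) - noise) ** 2)


# Usage with two stand-in predictors (neither is a trained model).
x_key, t_key, n_key = jax.random.split(jax.random.PRNGKey(0), 3)
x0 = jax.random.normal(x_key, (8, 28, 28, 1))
t = jax.random.randint(t_key, (8,), 0, T)
noise = jax.random.normal(n_key, x0.shape)


def oracle(x_t, t):
    # Cheats by using the true x0 to invert the forward process exactly.
    abar = alphas_cumprod[t].reshape(-1, 1, 1, 1)
    return (x_t - jnp.sqrt(abar) * x0) / jnp.sqrt(1.0 - abar)


def always_zero(x_t, t):
    return jnp.zeros_like(x_t)


oracle_loss = simple_loss(oracle, x0, t, noise)
zero_loss = simple_loss(always_zero, x0, t, noise)
```

The oracle's loss should be essentially zero (up to float32 rounding), while the all-zeros predictor scores about $\mathbb{E}[\epsilon^2] = 1$, which is a useful reference point when reading the training log.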

In [None]:
samples = sample(best_params, key, IMAGE_SIZE, 32, CHANNELS)

In [None]:
# Show the first 25 final samples (the last step of the reverse process)
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)

    plt.imshow(samples[-1][i][:, :, 0], cmap="gray_r")

In [None]:
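import matplotlib.animation as animation
from IPython.display import HTML

# Index of the sample to animate; must be smaller than the batch size (32) used above
sample_index = 0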

fig = plt.figure()
ims = []
for i in range(TIMESTEPS):
    im = plt.imshow(samples[i][sample_index][:, :, 0], cmap="gray_r", animated=True)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=5000)
plt.close()

HTML(animate.to_html5_video())

#### **Guided diffusion**

[TODO] explain guided diffusion 

#### **Exercises** 

1. [Beginner] Try scaling up to FashionMNIST. Hint: adding attention (and perhaps ConvNeXt blocks) might help.
2. [Intermediate] Add label information to make a conditional generative model $p(x|y)$.
3. [Advanced] Implement classifier-free guidance.
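As a starting hint for exercise 3, the core of classifier-free guidance (Ho & Salimans) is a combination of two noise predictions from the same network, one conditioned on the label $y$ and one on a "null" label: $\tilde\epsilon = (1 + w)\,\epsilon_\theta(x_t, y) - w\,\epsilon_\theta(x_t, \varnothing)$. The sketch below shows only that combination; `guided_noise` and `guidance_weight` are hypothetical names, and training the network so it can produce both predictions (typically by randomly replacing $y$ with the null label during training) is left to the exercise.

```python
import jax.numpy as jnp


def guided_noise(eps_cond, eps_uncond, guidance_weight):
    """Classifier-free guidance: (1 + w) * eps(x_t, y) - w * eps(x_t, null).

    guidance_weight = 0 recovers the purely conditional prediction; larger values
    typically trade sample diversity for closer agreement with the label.
    """
    return (1.0 + guidance_weight) * eps_cond - guidance_weight * eps_uncond


# Usage with placeholder predictions (in practice both come from the same network,
# called once with the label y and once with the null label).
eps_cond = jnp.array([1.0, 2.0])
eps_uncond = jnp.array([0.5, 0.5])
eps_guided = guided_noise(eps_cond, eps_uncond, guidance_weight=2.0)
```

The guided prediction would then replace the model's output in the reverse-step mean during sampling.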

In [None]:
# @title Building blocks that may help when scaling up the model (optional)

# Intended composition: Residual(PreNorm(Attention(dim)))

class Attention(hk.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # 1x1 convolutions need no padding, so "VALID" is equivalent to zero padding
        self.to_qkv = hk.Conv2D(hidden_dim * 3, kernel_shape=1, padding="VALID", with_bias=False)
        self.to_out = hk.Conv2D(dim, kernel_shape=1, padding="VALID")

    def __call__(self, x):
        b, h, w, c = x.shape
        qkv = jnp.array_split(self.to_qkv(x), 3, axis=-1)
        q, k, v = map(
            lambda t: rearrange(t, "b x y (h d) -> b h d (x y)", h=self.heads), qkv  # c = (h d), i = (x y)
        )
        q = q * self.scale

        sim = contract("b h d i, b h d j -> b h i j", q, k)
        # Softmax trick to avoid numerical issues
        sim = sim - jax.lax.stop_gradient(jnp.max(sim, axis=-1, keepdims=True))
        attn = jax.nn.softmax(sim, axis=-1)

        out = contract("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b x y (h d)", x=h, y=w)
        return self.to_out(out)


class PreNorm(hk.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.norm = hk.GroupNorm(1)

    def __call__(self, x):
        x = self.norm(x)
        return self.fn(x)

class Residual(hk.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def __call__(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
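The `Attention` module above relies on Haiku, `einops.rearrange`, and a `contract` helper from elsewhere in the notebook. As a hedged, framework-free sketch of the same computation, the hypothetical `spatial_attention` function below flattens the $H \times W$ pixel grid into a sequence, projects it to queries, keys, and values with a plain matrix multiply (equivalent to the 1×1 convolution `to_qkv`), and attends over all positions with `jnp.einsum`. The output projection `to_out` is omitted. The stabilisation step subtracts the row-wise maximum; `jax.nn.softmax` already does this internally, so the explicit subtraction mainly documents the trick.

```python
import jax
import jax.numpy as jnp


def spatial_attention(x, w_qkv, heads):
    """Multi-head self-attention over the H*W pixel positions of an NHWC feature map.

    x:     (batch, H, W, C) feature map
    w_qkv: (C, 3 * heads * dim_head) projection weights (same as a bias-free 1x1 conv)
    Returns (batch, H, W, heads * dim_head); the output projection is omitted.
    """
    b, h, w, c = x.shape
    qkv = x.reshape(b, h * w, c) @ w_qkv                  # (b, n, 3 * heads * d)
    q, k, v = jnp.split(qkv, 3, axis=-1)                  # each (b, n, heads * d)
    d = q.shape[-1] // heads
    # (b, n, heads * d) -> (b, heads, n, d), heads as the outer factor like "(h d)"
    q, k, v = (t.reshape(b, h * w, heads, d).transpose(0, 2, 1, 3) for t in (q, k, v))
    sim = jnp.einsum("bhid,bhjd->bhij", q * d ** -0.5, k)
    # Subtract the row-wise max (not argmax) so exp() cannot overflow; softmax is unchanged.
    sim = sim - jax.lax.stop_gradient(jnp.max(sim, axis=-1, keepdims=True))
    attn = jax.nn.softmax(sim, axis=-1)
    out = jnp.einsum("bhij,bhjd->bhid", attn, v)          # (b, heads, n, d)
    return out.transpose(0, 2, 1, 3).reshape(b, h, w, heads * d)


# Usage: 2 heads of size 5 on a 4x4 map with 3 channels.
kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (2, 4, 4, 3))
w_qkv = jax.random.normal(kw, (3, 3 * 2 * 5))
out = spatial_attention(x, w_qkv, heads=2)
```

Unlike a convolution, every output pixel here can draw on every input pixel, which is why attention is suggested for scaling up to more complex datasets.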

## Conclusion
**Summary:**

[Summary of the main points/takeaways from the prac.]

**Next Steps:** 

[Next steps for people who have completed the prac, like optional reading (e.g. blogs, papers, courses, youtube videos). This could also link to other pracs.]

**Appendix:** 

[Anything (probably math heavy stuff) we don't have space for in the main practical sections.]

**References:** 

Much of this material was adapted from and inspired by [The Annotated Diffusion Model 🤗](https://huggingface.co/blog/annotated-diffusion).

For other practicals from the Deep Learning Indaba, please visit [here](https://github.com/deep-learning-indaba/indaba-pracs-2022).

## Feedback

Please provide feedback that we can use to improve our practicals in the future.

In [None]:
#@title Generate Feedback Form. (Run Cell)
from IPython.display import HTML
HTML(
"""
<iframe
  src="https://forms.gle/bvLLPX74LMGrFefo9"
  width="80%"
  height="1200px">
  Loading...
</iframe>
"""
)

<img src="https://baobab.deeplearningindaba.com/static/media/indaba-logo-dark.d5a6196d.png" width="50%" />