# Denoising Diffusion Probabilistic Models

In the last section we looked at the forward diffusion process, which turns a clean image into noise. In this section we look at the reverse process, which generates data from noise. Instead of images, we start with a generated 2D dataset.

## Background - Reverse Process

The purpose of the reverse process $p$ is to approximate the previous step $x_{t-1}$ in the diffusion chain based on a sample $x_t$. In practice, this approximation $p(x_{t-1}|x_t)$ must be done without the knowledge of $x_0$.

A parametrizable prediction model with parameters $\theta$ is used to estimate $p_\theta(x_{t-1}|x_t)$.

If the diffusion steps are small enough, the reverse process is also (approximately) Gaussian:

$$
p_\theta({x}_{0:T}) = p({x}_T) \prod^T_{t=1} p_\theta({x}_{t-1} \vert {x}_t) \quad
p_\theta({x}_{t-1} \vert {x}_t) := \mathcal{N}({x}_{t-1}; \boldsymbol{\mu}_\theta({x}_t, t), \boldsymbol{\Sigma}_\theta({x}_t, t))
\tag{3}
$$

In many works, including Ho et al., the covariance is not learned but fixed to $\boldsymbol{\Sigma}_\theta(x_t, t) = \sigma_t^2 I$, so it depends only on the stage $t$ of the diffusion process and not on $x_0$ or $x_t$. This is motivated by the true posterior $q(x_{t-1}|x_t, x_0)$, whose variance $\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t$ likewise depends only on $t$.

### Parameterizing $\mu_\theta$
There are at least 3 ways of parameterizing the mean of the reverse step distribution $p_\theta(x_{t-1}|x_t)$:
1. Directly (a neural network will **estimate $\mu_\theta$**)
$$\mu_\theta = \mu_\theta(x_t,t)$$
2. Via $x_0$ (a neural network will **estimate $x_0$**)
$$\tilde{\mu}_\theta = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x_{0, \theta}(x_t,t) + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t $$

3. Via noise $\epsilon$ subtraction from $x_t$ (a neural network will **estimate $\epsilon$**)
$$x_t=\sqrt{\bar{\alpha}_t}\hat{x}_0 + \epsilon_\theta(x_t,t)\sqrt{1-\bar{\alpha}_t}$$
$$\hat{x}_0=\frac{1}{\sqrt{\bar{\alpha}_t}}x_t - \epsilon_\theta(x_t,t)\sqrt{\frac{1}{\bar{\alpha}_t}-1}$$
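
Substituting $\hat{x}_0$ into the posterior mean of approach 2, and using $\bar{\alpha}_t = \alpha_t\bar{\alpha}_{t-1}$ and $\beta_t = 1 - \alpha_t$, the mean can also be written directly in terms of the predicted noise (this is the form given by Ho et al.):

$$
\begin{aligned}
\tilde{\mu}_\theta
&= \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\cdot\frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t,t)}{\sqrt{\bar{\alpha}_t}} + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t \\
&= \frac{\beta_t + \alpha_t(1-\bar{\alpha}_{t-1})}{\sqrt{\alpha_t}(1-\bar{\alpha}_t)}x_t - \frac{\beta_t}{\sqrt{\alpha_t}\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t,t) \\
&= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t,t)\right)
\end{aligned}
$$

The last step uses $\beta_t + \alpha_t(1-\bar{\alpha}_{t-1}) = 1 - \alpha_t\bar{\alpha}_{t-1} = 1-\bar{\alpha}_t$.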


Approach 3, in which the network predicts the Gaussian noise $\epsilon_\theta(x_t, t)$, is the most widely used. For denoising images, U-Nets are a common choice. The U-Net is a convolutional neural network originally introduced for image segmentation; its name comes from its U-shaped architecture. The input is first downsampled and then upsampled back to the output resolution. Layers of the same resolution on the downsampling and upsampling paths are connected by skip connections, which carry fine-grained spatial information that would otherwise be lost during downsampling. An example of a U-Net for estimating $\epsilon$ is shown in the figure below.

<figure>
<img src="./imgs/U-net.png" alt="U-Net" width="500"/>
<figcaption>U-Net Architecture with ResNet blocks (from: https://cvpr2022-tutorial-diffusion-models.github.io/)</figcaption>
</figure>

## Loading necessary libraries

In [None]:
import Pkg
Pkg.activate("diffusion")
using Plots
using Flux
using BSON
using JSON
using StatsBase
using Printf
using LaTeXStrings
using Dates

In [None]:
using Revise
include("./DenoisingDiffusion/src/DenoisingDiffusion.jl")
using .DenoisingDiffusion
using .DenoisingDiffusion: train!

In [None]:
includet("./DenoisingDiffusion/common/datasets.jl");
includet("./DenoisingDiffusion/common/utilities.jl");
directory = joinpath("outputs", "2d_" * Dates.format(now(), "yyyymmdd_HHMM"))

## Generating the Data

First, we generate a noisy 2D Swiss roll dataset with 10,000 samples. The Swiss roll is a classic dataset with a non-linear (spiral) structure. The data is generated with `make_spiral` to resemble the output of the `make_swiss_roll` function from the `sklearn` library and is then scaled to the range $[-1, 1]$ with `normalize_neg_one_to_one`.
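
The implementation of `make_spiral` lives in `datasets.jl` and is not shown here. As a rough idea of what such a generator does, here is a hypothetical sketch following the parameterization of scikit-learn's `make_swiss_roll` ($t = 1.5\pi(1 + 2u)$ with $u \sim \mathcal{U}(0, 1)$, point $(t\cos t,\ t\sin t)$ plus Gaussian noise), keeping only the two coordinates that form the spiral, followed by a global min-max scaling to $[-1, 1]$. Both function names are made up for illustration and may not match what the notebook's helpers do.

```julia
using Random

# Hypothetical generator in the spirit of scikit-learn's make_swiss_roll, keeping
# only the two coordinates that form the spiral (make_spiral may differ).
function swiss_roll_2d(n::Int; noise::Real=0.0, rng::AbstractRNG=Random.default_rng())
    t = 1.5π .* (1 .+ 2 .* rand(rng, n))          # spiral parameter in [1.5π, 4.5π]
    X = vcat((t .* cos.(t))', (t .* sin.(t))')     # 2 × n matrix, one sample per column
    X .+ noise .* randn(rng, 2, n)                 # add isotropic Gaussian noise
end

# Hypothetical global min-max scaling to [-1, 1]; using one min/max for both
# coordinates keeps the aspect ratio of the spiral intact.
function minmax_neg_one_to_one(X::AbstractArray)
    lo, hi = extrema(X)
    2 .* (X .- lo) ./ (hi - lo) .- 1
end

X_demo = minmax_neg_one_to_one(swiss_roll_2d(1_000; noise=0.5))
```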

In [None]:
n_samples = 10_000
X = normalize_neg_one_to_one(make_spiral(n_samples))
X_val = normalize_neg_one_to_one(make_spiral(floor(Int, 0.1 * n_samples)))

In [None]:
scatter(X[1, :], X[2, :], alpha=0.5, aspectratio=:equal)

## Model

Now that we have our data, we can start building our model. Like the U-Net above, the model should estimate the noise $\epsilon$ in a data point $x$ at a given time step $t$. Since our data is only 2D, we start with a simple feedforward neural network:

<figure>
<img src="./imgs/simple_MLP_model.png" alt="simple_MLP_model" width="500"/>
<figcaption>Conditional noise estimation model.</figcaption>
</figure>

### How can we build a model that takes the data and the time as input and returns the noise?

Notice that the model depends not only on the data point $x$ but also on the time step $t$. This is because the amount of noise contained in $x_t$ is determined by $t$ (through $\bar{\alpha}_t$), so the same $x_t$ calls for a different noise estimate at different time steps. A good model should therefore be time dependent. For efficiency, however, the model should share its weights across all time steps instead of using a separate network per step. This can be done by mapping $t$ to an embedding vector and adding it to the output of the layers that process the data $x$.
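
The idea of shared weights plus a per-time-step vector can be sketched without any library. The layer below is hypothetical (the name `TimeConditionedDense` and its fields are made up): the weights `W` and `b` are shared by all time steps, while the table `E` holds one learnable column per time step, playing the role of a lookup-table embedding, and that column is added to the dense output.

```julia
using Random

# Hypothetical time-conditioned layer: shared weights for the data, plus a
# per-time-step vector looked up in a table and added to the output.
struct TimeConditionedDense
    W::Matrix{Float32}   # d_out × d_in weights applied to the data x
    b::Vector{Float32}   # d_out bias
    E::Matrix{Float32}   # d_out × num_timesteps table, one column per time step
end

TimeConditionedDense(d_in::Int, d_out::Int, num_timesteps::Int; rng::AbstractRNG=Random.default_rng()) =
    TimeConditionedDense(0.1f0 .* randn(rng, Float32, d_out, d_in),
                         zeros(Float32, d_out),
                         0.1f0 .* randn(rng, Float32, d_out, num_timesteps))

# x: d_in × batch matrix, t: one time step per column of x
(layer::TimeConditionedDense)(x::AbstractMatrix, t::AbstractVector{Int}) =
    layer.W * x .+ layer.b .+ layer.E[:, t]

layer = TimeConditionedDense(2, 16, 40)
x = randn(Float32, 2, 8)
h = layer(x, rand(1:40, 8))   # 16 × 8: same x, but the output now depends on t
```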

We can do this using `Flux.Embedding` or `SinusoidalPositionEmbedding` from `DenoisingDiffusion.jl`. The latter creates a matrix in which every column as a whole is unique, so each column can be used as the time embedding for one particular time step. The uniqueness of the columns is achieved by filling the rows with periodic trigonometric functions (sines and cosines) whose frequency changes gradually from row to row. See the image below for a visual demonstration.

<figure>
<img src="./imgs/position_encodings.png" alt="position_encodings" width="500"/>
<figcaption>Heatmap demonstrating the sinusoidal embedding.</figcaption>
</figure>

You can try both embedding functions. We will use the latter for now.
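
To make the sinusoidal idea concrete, here is a minimal sketch in the style of the Transformer position encoding. It is an assumption for illustration only: the function name `sinusoidal_embedding` is made up, and the row layout and frequencies used by `SinusoidalPositionEmbedding` may differ. Each pair of rows uses one frequency, and the frequencies decrease geometrically from 1 towards `1/base`, so every column ends up distinct.

```julia
# Hypothetical sinusoidal time embedding (Transformer-style); the exact layout
# of SinusoidalPositionEmbedding in DenoisingDiffusion.jl may differ.
function sinusoidal_embedding(num_timesteps::Int, d::Int; base::Real=10_000)
    iseven(d) || throw(ArgumentError("embedding dimension d must be even"))
    half = d ÷ 2
    # one frequency per sine/cosine pair, decreasing geometrically from 1 towards 1/base
    freqs = [base^(-(i - 1) / half) for i in 1:half]
    E = zeros(Float32, d, num_timesteps)
    for t in 1:num_timesteps, i in 1:half
        E[2i-1, t] = sin(t * freqs[i])   # odd rows: sine
        E[2i, t] = cos(t * freqs[i])     # even rows: cosine
    end
    E   # column t is the embedding vector of time step t
end

E = sinusoidal_embedding(40, 16)   # 16 × 40: one column per time step
```

Plotting `heatmap(E)` with Plots.jl should give a picture qualitatively similar to the figure above.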

In [None]:
### settings
num_timesteps = 40
to_device = cpu
num_epochs = 120

d_hid = ... # set to reasonable value
d_in_out = ... # set to correct value for the dataset

In [None]:
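# Each Parallel block receives (x, t): a Dense layer processes the data x, a time embedding
# followed by a Dense layer processes t, and both outputs are combined by element-wise addition.
model = ConditionalChain(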
    Parallel(.+, Dense(d_in_out, d_hid), Chain(SinusoidalPositionEmbedding(num_timesteps, d_hid), Dense(d_hid, d_hid))),
    swish,
    Parallel(.+, Dense(d_hid, d_hid), Chain(SinusoidalPositionEmbedding(num_timesteps, d_hid), Dense(d_hid, d_hid))),
    swish,
    Parallel(.+, Dense(d_hid, d_hid), Chain(SinusoidalPositionEmbedding(num_timesteps, d_hid), Dense(d_hid, d_hid))),
    swish,
    Dense(d_hid, d_in_out),
)
display(model)

### Test the model

In [None]:
# test forward pass by getting some samples from our training data
n_batch = 10 
X_t = X[:, 1:n_batch]
X_t = to_device(X_t)
println(size(X_t))
# sample one random time step per sample in the batch (one per column of X_t)
timesteps = rand(1:num_timesteps, size(X_t, 2)) |> to_device
# test forward pass
X_t_hat = model(X_t, timesteps)
# the output should be the same shape as the input
println(size(X_t_hat))

### Let's define our schedule again 
This is similar to the schedule we used in the previous section. We use the same schedule for the forward and reverse process.

In [None]:
βs = linear_beta_schedule(num_timesteps, 8e-6, 9e-5)

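Based on the schedule and our noise prediction model, we can now define a `GaussianDiffusion` model. It precomputes all schedule-dependent coefficients needed for the forward distribution $q(x_t|x_0)$ and the posterior $q(x_{t-1}|x_t, x_0)$ and stores them, together with the model, in a struct for easy access: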
```julia
function GaussianDiffusion(V::DataType, βs::AbstractVector, data_shape::NTuple, denoise_fn)
    αs = 1 .- βs
    α_cumprods = cumprod(αs)
    α_cumprod_prevs = [1, (α_cumprods[1:end-1])...]

    sqrt_α_cumprods = sqrt.(α_cumprods)
    sqrt_one_minus_α_cumprods = sqrt.(1 .- α_cumprods)
    sqrt_recip_α_cumprods = 1 ./ sqrt.(α_cumprods)
    sqrt_recip_α_cumprods_minus_one = sqrt.(1 ./ α_cumprods .- 1)

    posterior_variance = βs .* (1 .- α_cumprod_prevs) ./ (1 .- α_cumprods)
    posterior_log_variance_clipped = log.(max.(posterior_variance, 1e-20))

    posterior_mean_coef1 = βs .* sqrt.(α_cumprod_prevs) ./ (1 .- α_cumprods)
    posterior_mean_coef2 = (1 .- α_cumprod_prevs) .* sqrt.(αs) ./ (1 .- α_cumprods)

    GaussianDiffusion{V}(
        length(βs),
        data_shape,
        denoise_fn,
        βs,
        αs,
        α_cumprods,
        α_cumprod_prevs,
        sqrt_α_cumprods,
        sqrt_one_minus_α_cumprods,
        sqrt_recip_α_cumprods,
        sqrt_recip_α_cumprods_minus_one,
        posterior_variance,
        posterior_log_variance_clipped,
        posterior_mean_coef1,
        posterior_mean_coef2
    )
end
```
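
The constructor above consists only of element-wise array operations, so its core is easy to reproduce without the library. The sketch below computes a few of the fields; the names `linear_betas` and `schedule_coefficients` are hypothetical, and it assumes the betas are a plain linear interpolation between the two endpoints, which the library's `linear_beta_schedule` may not do. Two cheap sanity checks: the posterior variance at $t = 1$ is exactly 0 because $\bar{\alpha}_0 = 1$, and $\sqrt{\bar{\alpha}_T}$, the fraction of the clean signal left in $x_T$, should be close to 0 if $x_T$ is meant to be (almost) pure noise.

```julia
# Hypothetical helper: plain linear interpolation between β_start and β_end
# (the library's linear_beta_schedule may rescale the endpoints).
linear_betas(T::Int, β_start::Real, β_end::Real) = collect(range(β_start, β_end; length=T))

# Hypothetical stand-alone version of some GaussianDiffusion fields
function schedule_coefficients(βs::AbstractVector)
    αs = 1 .- βs
    α_cumprods = cumprod(αs)
    α_cumprod_prevs = [one(eltype(α_cumprods)); α_cumprods[1:end-1]]       # ᾱ_{t-1}, with ᾱ_0 = 1
    posterior_variance = βs .* (1 .- α_cumprod_prevs) ./ (1 .- α_cumprods)  # β̃_t
    (; αs, α_cumprods, α_cumprod_prevs, posterior_variance)
end

coeffs = schedule_coefficients(linear_betas(40, 8e-6, 9e-5))
println("posterior variance at t = 1: ", coeffs.posterior_variance[1])
println("signal left at t = T, sqrt(ᾱ_T): ", sqrt(coeffs.α_cumprods[end]))
```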

In [None]:
diffusion = GaussianDiffusion(Vector{Float32}, βs, (2,), model)
diffusion = diffusion |> to_device

## Training the Model

Now we have everything we need to train our model, except for the loss we want to minimize. In the last section we discussed the forward process $q(x_t|x_{t-1})$, and now the reverse process $p_{\theta}(x_{t-1}|x_t)$. Ideally, the reverse process undoes the forward process: if we start with $x_{t-1}$, generate $x_t$ using $q(x_t|x_{t-1})$ and then use $x_t$ to generate a new $x_{t-1}$ with $p_{\theta}(x_{t-1}|x_t)$, the result should be distributed like the $x_{t-1}$ we started with. Hence, the forward and the reverse process should cancel each other out.

The paper by [Ho et al.](https://arxiv.org/abs/2006.11239), which introduced DDPM, showed that the model can be trained successfully with a simple loss of the form:

$$
\mathcal{L} = \mathbb{E} [|| \epsilon - \epsilon_{\theta} ||^2]
$$

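where $\epsilon \sim \mathcal{N}(0, I)$ is the true noise, $\epsilon_{\theta} = \epsilon_\theta(x_t, t)$ is the predicted noise, and the expectation is taken over the data $x_0$, the time step $t \sim \mathcal{U}\{1, \dots, T\}$ and the noise $\epsilon$. We can use the `Flux.mse` loss function to measure the difference between the true and the predicted noise, and the `Adam` optimizer to minimize it. We will train the model for `num_epochs` epochs (120 in the settings above).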
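
The loss fits into a few lines. The sketch below is a hypothetical stand-alone version (the name `simple_loss` and its arguments are assumptions, not the `DenoisingDiffusion.jl` API): it samples one time step per column, forms $x_t$ in closed form (the formula `q_sample` implements; here the per-sample coefficients are reshaped into a `1 × batch` row so they broadcast over the data dimension, which we assume is what `_extract` takes care of) and returns the mean squared error between the true and the predicted noise.

```julia
using Random, Statistics

# Hypothetical stand-alone ε-prediction loss for a d × batch matrix x0.
# α_cumprods[t] = ᾱ_t; ε_model(x_t, t) is any function returning a d × batch noise estimate.
function simple_loss(ε_model, x0::AbstractMatrix, α_cumprods::AbstractVector;
                     rng::AbstractRNG=Random.default_rng())
    batch = size(x0, 2)
    t = rand(rng, 1:length(α_cumprods), batch)   # one random time step per sample
    ε = randn(rng, eltype(x0), size(x0))         # true Gaussian noise
    a = sqrt.(α_cumprods[t])'                    # 1 × batch row of √ᾱ_t
    s = sqrt.(1 .- α_cumprods[t])'               # 1 × batch row of √(1 - ᾱ_t)
    x_t = a .* x0 .+ s .* ε                      # noised data (closed form, as in q_sample)
    ε_hat = ε_model(x_t, t)                      # predicted noise
    mean(abs2, ε .- ε_hat)                       # mean squared error (like Flux.mse)
end

zero_model(x_t, t) = zero(x_t)                   # baseline that always predicts zero noise
α_cumprods_demo = cumprod(1 .- collect(range(1e-4, 0.05; length=40)))
simple_loss(zero_model, randn(2, 10_000), α_cumprods_demo)
```

With the zero predictor the loss is close to $\mathbb{E}[\epsilon^2] = 1$, which makes a useful baseline: a trained model should do clearly better.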

### The loss function

Now let's define a function that takes the diffusion model, a batch of data and the loss function (here `Flux.mse`), uses them to noise the data points, predicts the noise and calculates the loss. We build it step by step:

1. Sample a random time step for each data point in the batch
2. Sample Gaussian noise with the same shape as the data

In [None]:
x_start = X_t
timesteps = ... # get random time steps for each sample in the batch in the correct range using the rand function
noise = randn(eltype(eltype(diffusion)), size(x_start)) |> to_device

3. Add noise to the data:

    To noise the data, we can use the `q_sample` function from `GaussianDiffusion.jl`. It is defined as follows:
    ```julia
    function q_sample(
        diffusion::GaussianDiffusion, 
        x_start::AbstractArray, 
        timesteps::AbstractVector{Int}, 
        noise::AbstractArray
        )
        coeff1 = _extract(diffusion.sqrt_α_cumprods, timesteps, size(x_start))
        coeff2 = _extract(diffusion.sqrt_one_minus_α_cumprods, timesteps, size(x_start))
        coeff1 .* x_start + coeff2 .* noise
    end
    ```

    Let's call it on `x_start` (the batch `X_t` from above) and see what it returns:

In [None]:
x = ... # sample from the diffusion model using q_sample

4. Estimate the noise from the input data:

    To do so, we call our model through the `denoise_fn` field of the `GaussianDiffusion` struct to estimate the noise in the noised data: `diffusion.denoise_fn(x, timesteps)`.

In [None]:
model_out =  ... # get the output of the model

The untrained model does not yet estimate the noise correctly, since we still have to train it. However, you now have all the pieces for the last step:

5. Calculate the loss:

In [None]:
loss_type = Flux.mse;
est_loss = ... # calculate the loss

Finally, let's put it all together:

In [None]:
function p_losses(diffusion::GaussianDiffusion, loss, x_start::AbstractArray{T,N}; to_device=cpu) where {T,N}
    # estimate and return the loss based on the steps above 
end

In [None]:
# Wrap p_losses into a loss function with the signature used by the `train!` loop imported above
loss(diffusion, x::AbstractArray) = p_losses(diffusion, loss_type, x; to_device=to_device)

In [None]:
# test the loss with the X_t samples from the previous test:
loss(diffusion, X_t)

Define the data loaders and test our initial model using the validation data and the loss function we defined above.

In [None]:
# define the train and val data loaders
train_data = Flux.DataLoader(X |> to_device; batchsize=32, shuffle=true);
val_data = Flux.DataLoader(X_val |> to_device; batchsize=32, shuffle=false);
opt = Adam(0.001);

# Calculating initial loss
val_loss = 0.0
for x in val_data
    global val_loss
    val_loss += loss(diffusion, x)
end
val_loss /= length(val_data)
@printf("\nval loss: %.5f\n", val_loss)

### Start the Training:

The training algorithm is as follows:

Repeat for each mini-batch in every epoch:  
- Sample data: $x_0 \sim q(x_0)$  
- Sample time steps: $t \sim \mathcal{U}\{1, \dots, T\}$  
- Sample noise: $\epsilon \sim \mathcal{N}(0,I)$  
- Generate $x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon$  
- Predict the noise: $\epsilon_{\theta}(x_t, t)$  
- Compute the loss: $\mathcal{L} = \| \epsilon - \epsilon_{\theta}(x_t, t) \|^2$, averaged over the batch  
- Update the model parameters using the optimizer
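
For reference, here is a minimal, self-contained sketch of the training algorithm above written directly with Flux, assuming a recent Flux version with the explicit `Flux.setup`/`Flux.update!` style used in this notebook. It is not the notebook's model or the `train!` function from `DenoisingDiffusion.jl`: to keep it short, the data is a hypothetical toy circle, the schedule is an arbitrary linear one, and the time step enters a plain MLP as an extra input $t/T$ instead of through an embedding. All names (`noisy`, `net_input`, `train_toy!`) are made up for illustration.

```julia
using Flux, Random

# Hypothetical toy setup (not the notebook's model or data): points on a circle,
# an arbitrary linear β schedule and an MLP that gets t/T as a third input.
T_steps = 40
βs_toy = collect(range(1f-4, 0.1f0; length=T_steps))
α_cumprods_toy = cumprod(1 .- βs_toy)

# closed-form forward sample x_t = √ᾱ_t x0 + √(1 - ᾱ_t) ε, one time step per column
noisy(x0, t, ε) = sqrt.(α_cumprods_toy[t])' .* x0 .+ sqrt.(1 .- α_cumprods_toy[t])' .* ε
# network input: the noisy point stacked with the normalized time step
net_input(x_t, t) = vcat(x_t, reshape(Float32.(t) ./ T_steps, 1, :))

function train_toy!(net, data; steps::Int=1000, batch::Int=128, lr=1f-3)
    opt_state = Flux.setup(Adam(lr), net)
    for _ in 1:steps
        x0 = data[:, rand(1:size(data, 2), batch)]   # sample data x0 ~ q(x0)
        t = rand(1:T_steps, batch)                    # sample t ~ U{1, ..., T}
        ε = randn(Float32, size(x0))                  # sample ε ~ N(0, I)
        inp = net_input(noisy(x0, t, ε), t)           # generate x_t
        # predict the noise, compute the MSE loss and its gradient w.r.t. the network
        grads = Flux.gradient(m -> Flux.mse(m(inp), ε), net)
        Flux.update!(opt_state, net, grads[1])        # optimizer step
    end
    net
end

θ = Float32(2π) .* rand(Float32, 1000)
data = 0.8f0 .* vcat(cos.(θ)', sin.(θ)')             # 2 × 1000 points on a circle
net = Chain(Dense(3 => 64, swish), Dense(64 => 64, swish), Dense(64 => 2))

# fixed evaluation batch to compare the loss before and after training
t_eval = rand(1:T_steps, 1000)
ε_eval = randn(Float32, 2, 1000)
inp_eval = net_input(noisy(data, t_eval, ε_eval), t_eval)
loss_before = Flux.mse(net(inp_eval), ε_eval)
train_toy!(net, data)
loss_after = Flux.mse(net(inp_eval), ε_eval)
println("eval loss before: ", loss_before, ", after: ", loss_after)
```

Evaluating the loss on a fixed batch before and after training, as in the last lines, is a quick check that the loop actually learns.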

In [None]:
start_time = time_ns()
opt_state = Flux.setup(opt, diffusion)
history = train!(loss, diffusion, train_data, opt_state, val_data; num_epochs=num_epochs)
elapsed_ns = time_ns() - start_time

@printf "time taken: %.2fs\n" elapsed_ns / 1e9

### Plot the Training and Validation Loss

In [None]:
diffusion = diffusion |> cpu

canvas_train = plot(
    1:length(history["mean_batch_loss"]), history["mean_batch_loss"], label="mean batch loss",
    xlabel="epoch",
    ylabel="loss",
    legend=:right, # :best, :right
    ylims=(0, Inf),
)
plot!(canvas_train, 1:length(history["val_loss"]), history["val_loss"], label="validation loss")
display(canvas_train)

### Plot the Results

We can predict the noise $\epsilon_{\theta}$ using the trained model $\epsilon_\theta(x_t,t)$ simply like this:
```julia
noise = diffusion.denoise_fn(x, timesteps)
```
After that we can predict the start point using the predicted noise (this is done using the `denoise` function from `DenoisingDiffusion.jl`):
$$\hat{x}_0=\frac{1}{\sqrt{\bar{\alpha}_t}}x_t - \epsilon_\theta(x_t,t)\sqrt{\frac{1}{\bar{\alpha}_t}-1}$$

Then we can estimate the mean $\tilde{\mu}_t$ and the variance $\tilde{\beta}_t$ of the posterior $q(x_{t-1} \vert x_t, \hat{x}_0)$, which we use as our reverse-step distribution $p_\theta({x}_{t-1} \vert {x}_t)$:
$$\tilde{\mu}_t(x_t, \hat{x}_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\hat{x}_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t$$
$$\tilde{\beta}_t = \beta_t\frac{(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}$$

This estimation is done using the `q_posterior_mean_variance` function from `DenoisingDiffusion.jl`. Lastly we call the `p_sample` function to sample from the posterior distribution $p_\theta({x}_{t-1} \vert {x}_t)$ to get $x_{t-1}$ for the reverse process:
$$x_{t-1} = \tilde{\mu}_t(x_t, \hat{x}_0)+\sqrt{\tilde{\beta}_t}\, z \quad \text{where} \quad z \sim \mathcal{N}(0, I)$$

> Note that I used $z$ here since it is a sample from the standard normal distribution and should not be confused with the predicted noise $\epsilon_{\theta}$. The resulting $x_{t-1}$ is therefore a sample from the distribution:
>$$ p_\theta({x}_{t-1} \vert {x}_t) := \mathcal{N}({x}_{t-1}; \tilde{\mu}_t(x_t, \hat{x}_0), \tilde{\beta}_t I)$$

To generate samples starting from noise, we apply these steps repeatedly:

Repeat for each sample:  
- Sample noise: $x_T \sim \mathcal{N}(0,I)$  
- Repeat for each time step $t = T, \dots, 1$:  
    - Predict the noise: $\epsilon_{\theta}(x_t, t)$
    - Predict the start point: $\hat{x}_0$
    - Estimate the mean and variance of the posterior distribution: $\tilde{\mu}_t(x_t, \hat{x}_0)$ and $\tilde{\beta}_t$
    - Sample from the posterior distribution: $x_{t-1} = \tilde{\mu}_t(x_t, \hat{x}_0)+\sqrt{\tilde{\beta}_t}\, z$ (no noise is added in the final step $t = 1$)
    - Set $x_t = x_{t-1}$
- Return the final generated sample $x_0$

In [None]:
# Start from pure Gaussian noise x_T and iteratively denoise it down to the first time step.
function p_sample_loop(diffusion::GaussianDiffusion, shape::NTuple; clip_denoised::Bool=true, to_device=cpu)
    T = eltype(eltype(diffusion))
    x = randn(T, shape) |> to_device                  # x_T ~ N(0, I)
    for t in diffusion.num_timesteps:-1:1
        timesteps = fill(t, shape[end]) |> to_device  # same time step for every sample in the batch
        noise = randn(T, size(x)) |> to_device        # z ~ N(0, I)
        # one reverse step x_t -> x_{t-1}; no noise is added in the final step (t = 1)
        x, x_start = p_sample(diffusion, x, timesteps, noise; clip_denoised=clip_denoised, add_noise=(t != 1))
    end
    x
end

function p_sample_loop(diffusion::GaussianDiffusion, batch_size::Int; options...)
    p_sample_loop(diffusion, (diffusion.data_shape..., batch_size); options...)
end

In [None]:
X0 = p_sample_loop(diffusion, 1000)
canvas_samples = scatter(X0[1, :], X0[2, :], alpha=0.5, label="",
    aspectratio=:equal,
    xlims=(-2, 2), ylims=(-2, 2),
)
display(canvas_samples)

## Visualize the reverse process
`p_sample` also returns the estimate of the start sample $\hat{x}_0$ at each step. So we can plot both the current sample $x_t$ after each reverse step and the corresponding estimate $\hat{x}_0$.

First let's define a function which samples the whole reverse process for a given sample:

In [None]:
# Generate new samples and denoise them to the first time step. Return all samples where the last dimension is time.
function p_sample_loop_all(diffusion::GaussianDiffusion, shape::NTuple; clip_denoised::Bool=true, to_device=cpu)
    T = eltype(eltype(diffusion))
    x = randn(T, shape) |> to_device
    x_all = Array{T}(undef, size(x)..., 0) |> to_device
    x_start_all = Array{T}(undef, size(x)..., 0) |> to_device
    tdim = ndims(x_all)
    for t in diffusion.num_timesteps:-1:1
        timesteps = fill(t, shape[end]) |> to_device
        noise = randn(T, size(x)) |> to_device
        x, x_start = p_sample(diffusion, x, timesteps, noise; clip_denoised=clip_denoised, add_noise=(t != 1))
        x_all = cat(x_all, x, dims=tdim)
        x_start_all = cat(x_start_all, x_start, dims=tdim)
    end
    x_all, x_start_all
end

function p_sample_loop_all(diffusion::GaussianDiffusion, batch_size::Int=16; options...)
    p_sample_loop_all(diffusion, (diffusion.data_shape..., batch_size); options...)
end

Note that the splat operator in `size(x)...` passes the size of $x$ as separate arguments instead of as a tuple. For example, if `x` is a 2D array of size `(3, 4)`, then `Array{T}(undef, size(x)..., 0)` is equivalent to `Array{T}(undef, 3, 4, 0)`. In a function definition, `...` does the opposite: `options...` collects any number of (keyword) arguments, which `p_sample_loop_all` then passes on with `options...`.
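
A small example of both uses of `...` (the function names `count_args` and `option_names` are made up for illustration):

```julia
# Splatting: `size(x)...` expands the tuple (3, 4) into two separate arguments.
x = zeros(3, 4)
y = Array{Float64}(undef, size(x)..., 0)   # same as Array{Float64}(undef, 3, 4, 0)
size(y)                                    # (3, 4, 0)

# Slurping: in a definition, `args...` collects any number of positional arguments into a tuple.
count_args(args...) = length(args)
count_args(1, 2, 3)                        # 3

# The same works for keyword arguments, as in `p_sample_loop_all(...; options...)`.
option_names(; options...) = join(sort([string(k) for k in keys(options)]), ",")
option_names(clip_denoised=true, to_device=identity)   # "clip_denoised,to_device"
```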

In [None]:
Xs, X0s = p_sample_loop_all(diffusion, 1000);
# the last 10 frames repeat the final time step so the end result stays visible
anim_denoise = @animate for i ∈ 1:(num_timesteps+10)
    i = min(i, num_timesteps)
    p1 = scatter(Xs[1, :, i], Xs[2, :, i],
        alpha=0.5,
        title=L"${x}_t$",
        label="",
        aspectratio=:equal,
        xlims=(-2, 2), ylims=(-2, 2),
    )
    p2 = scatter(X0s[1, :, i], X0s[2, :, i],
        alpha=0.5,
        title=L"$\hat{x}_0$",
        label="",
        aspectratio=:equal,
        xlims=(-2, 2), ylims=(-2, 2),
    )
    plot(p1, p2, plot_title="i=$i", size=(800, 400))
end
mkpath(directory)  # output directory defined at the top of the notebook
gif(anim_denoise, joinpath(directory, "reverse_x0.gif"), fps=8)