# Forward Diffusion Process

In this notebook, we simulate the forward diffusion process. In the forward process, Gaussian noise is added step by step until the original data is indistinguishable from pure noise. We use an image as our example data point and keep adding Gaussian noise to it until the original image is completely lost.

## Loading and preprocessing the image

First, let's install and import the necessary packages.

In [None]:
# First-time setup: create and activate a project environment, then install the required packages
import Pkg
Pkg.generate("diffusion") # generate a new package
Pkg.activate("diffusion") # activate the package
Pkg.add("Printf")
Pkg.add("LaTeXStrings")
Pkg.add("Revise")
Pkg.add("Images")
Pkg.add("Plots")
Pkg.add("Flux")
Pkg.add("StatsBase")
Pkg.add("NNlib")
Pkg.add("BSON")
Pkg.add("ProgressMeter")
Pkg.add("Random")
Pkg.add("JSON")

In [None]:
# On subsequent runs, only activate the existing environment
import Pkg
Pkg.activate("diffusion") # activate the package

In [None]:
using Images
using Plots
using Revise
include("./ImageProcessing.jl")
using .ImageProcessing

Next, we load the image and preprocess it: the image is converted to `Float32` and normalized to the range [-1, 1].

In [None]:
img_f32 = ImageProcessing.load_image("./imgs/corgy-dungeon.jpg")

println("Original: ")
println("Max: ", maximum(img_f32))
println("Min: ", minimum(img_f32))

# normalize the image
img_f32_n = ImageProcessing.normalize(img_f32)

println("Normalized: ")
println("Max: ", maximum(img_f32_n))
println("Min: ", minimum(img_f32_n))


# Show the image; you can pass the normalized image directly to show_image
ImageProcessing.show_image(img_f32_n)

## Forward Process - Adding Noise using the Gaussian Distribution

The Gaussian distribution is a continuous probability distribution that is defined by the probability density function (pdf):

$$
\text{pdf}(x, \mu, \sigma) = \frac{1}{\sigma \sqrt{2\pi}} e^{-\frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^2} 
$$
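As a quick sanity check of the formula, the sketch below evaluates the pdf on a fine grid and approximates its integral with a Riemann sum. For a valid density the area should be close to 1, and the peak at $x = \mu$ should equal $1/(\sigma\sqrt{2\pi})$. The function name `gauss_pdf` and the grid spacing are illustrative choices, not part of the notebook.

```julia
# Gaussian pdf, written exactly as in the formula above (hypothetical helper name)
gauss_pdf(x, μ, σ) = 1 / (σ * sqrt(2π)) * exp(-0.5 * ((x - μ) / σ)^2)

μ, σ = 1.5, 0.7
dx = 1e-3
xs = (μ - 10σ):dx:(μ + 10σ)        # ±10σ covers essentially all of the probability mass

# Riemann-sum approximation of the integral of the pdf, which should be ≈ 1
area = sum(gauss_pdf.(xs, μ, σ)) * dx

# the density peaks at x = μ with height 1 / (σ√(2π))
peak = gauss_pdf(μ, μ, σ)

println("area ≈ ", area)
println("peak = ", peak, "  vs  1/(σ√(2π)) = ", 1 / (σ * sqrt(2π)))
```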

where $\mu$ is the mean and $\sigma$ is the standard deviation. A random variable $x$ drawn from this distribution is written $x \sim \mathcal{N}(\mu, \sigma^2)$, and its density evaluated at $x$ is written $\mathcal{N}(x; \mu, \sigma^2)$. To sample from any normal distribution, we can sample from the standard normal distribution and then scale and shift the samples:

$$
x = \mu + \sigma z \quad \text{where} \quad z \sim \mathcal{N}(0, 1)
$$

Here is a simple example where we draw 1000 samples from a normal distribution with $\mu = 1.5$ and $\sigma = 0.7$ and plot their normalized histogram. The histogram closely follows the probability density function of the normal distribution.

In [None]:
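using StatsBase, LaTeXStrings
μ, σ, n = 1.5, 0.7, 1000
x = μ .+ σ .* randn(n);                      # scale and shift standard normal samples
h = fit(Histogram, x, nbins=50);
width = h.edges[1][2] - h.edges[1][1]        # all bins have the same width
y = h.weights ./ (sum(h.weights) * width);   # normalize counts to a probability density
centers = midpoints(h.edges[1])              # one x-position per bin (edges has one extra entry)
bar(centers, y, bar_width=width, label="simulated", xlabel=L"x", ylabel="probability density")

pdf(x, μ, σ) = 1 / (σ * sqrt(2π)) * exp(-0.5 * (x - μ)^2 / σ^2)
xt = -1:0.01:4
yt = pdf.(xt, μ, σ)
plot!(xt, yt, linewidth=3, label="theoretical")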

### Exercise 

1. Change the mean and standard deviation of the normal distribution and observe how the histogram changes. Generate a histogram for $\mu = 10, \sigma = 12$ (adjust the plotting range `xt` so that it covers the samples).
2. Change the number of samples `n` and observe how the histogram changes. Generate histograms for `n = 10000`, `n = 100`, and `n = 10`.

In [None]:
# Exercise ... 

The multivariate normal distribution generalizes the one-dimensional normal distribution to random vectors $\mathbf{x} \in \mathbb{R}^d$. Its density is written as

$$
\mathcal{N}(\mathbf{x}; \boldsymbol{\mu}, \boldsymbol{\Sigma})
$$

where $\boldsymbol{\mu}$ is the mean vector and $\boldsymbol{\Sigma}$ is the covariance matrix. If the covariance matrix is diagonal, $\boldsymbol{\Sigma} = \operatorname{diag}(\sigma_1^2, \dots, \sigma_d^2)$, the components are independent and we can sample each one separately as $x_i = \mu_i + \sigma_i z_i$ with $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$. In the isotropic case all $\sigma_i$ are equal and $\boldsymbol{\Sigma} = \sigma^2 \mathbf{I}$; this is the form of the covariance used in the forward process below.
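A minimal sketch of sampling from a diagonal multivariate normal with the scale-and-shift trick, assuming illustrative values for the mean vector and the per-component standard deviations $\sigma_i$ (the names `μ_vec`, `σ_vec` and `n_samples` are hypothetical). With many samples, the per-component means and standard deviations should approach the chosen values, and different components should be nearly uncorrelated:

```julia
using Random, Statistics

Random.seed!(0)                      # reproducible samples
μ_vec = [0.0, 2.0, -1.0]             # mean vector (illustrative values)
σ_vec = [1.0, 0.5, 3.0]              # per-component std devs, Σ = diag(σ_vec.^2)
n_samples = 100_000

# x = μ + σ ⊙ z with z ~ N(0, I); each column of X is one sample vector
Z = randn(length(μ_vec), n_samples)
X = μ_vec .+ σ_vec .* Z

println("sample means: ", vec(mean(X, dims=2)))
println("sample stds:  ", vec(std(X, dims=2)))
println("correlation between components 1 and 2: ", cor(X[1, :], X[2, :]))
```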
    

## Forward Process

The forward process $q$ defines how each step of the diffusion is obtained from the previous one, gradually distorting the original sample $x_0$.

A single forward step is defined as
$$
q(x_t \mid x_{t-1}) := \mathcal{N}\left(x_t; \sqrt{1 - \beta_t}\,x_{t-1}, \beta_t \mathbf{I}\right) \tag{1}
$$
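Equation (1) can be read as a sampling rule: $x_t = \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t}\,\epsilon$ with $\epsilon \sim \mathcal{N}(0, \mathbf{I})$. The sketch below implements one such step for an array; `forward_step` is a hypothetical helper introduced only for illustration, and the 4×4 `Float32` array stands in for an image. With $\beta_t = 0$ the step leaves the input unchanged.

```julia
using Random

# One step of eq. (1): x_t = √(1 - β) x_{t-1} + √β ε, with ε ~ N(0, I) of the same size as the input.
# `forward_step` is a hypothetical helper used only for this illustration.
forward_step(x_prev, β) = sqrt(1 - β) .* x_prev .+ sqrt(β) .* randn(eltype(x_prev), size(x_prev))

Random.seed!(1)
x_prev = fill(0.5f0, 4, 4)            # stand-in for an image x_{t-1}
x_next = forward_step(x_prev, 0.02f0) # slightly shrunk and slightly noisy copy
```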

The complete trajectory from $x_0$ to $x_T$ is a Markov chain, so its joint distribution given $x_0$ factorizes into a product of the single-step conditionals:
$$
q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})
$$

As the number of steps grows ($T \rightarrow \infty$), and with a suitable variance schedule, $x_T$ loses all information about the input and approaches an isotropic Gaussian, $x_T \sim \mathcal{N}(0, \mathbf{I})$.
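To see this numerically, here is a sketch that runs many independent chains of eq. (1) on a scalar starting value, assuming the same linear schedule as the schedule cell further down ($\beta_t$ from $10^{-4}$ to $0.02$ over 1000 steps). After the last step, the samples should have a mean close to 0 and a standard deviation close to 1, with almost no trace of the starting value left. The variable names are illustrative.

```julia
using Random, Statistics

Random.seed!(2)
n_steps = 1000
βs = range(1e-4, 2e-2, length=n_steps)   # linear schedule, same values as the schedule cell

# 10_000 independent chains, all starting from the same value x_0 = 1.0
x_chain = fill(1.0, 10_000)
for β in βs
    # one application of q(x_t | x_{t-1}) to every chain
    x_chain .= sqrt(1 - β) .* x_chain .+ sqrt(β) .* randn(length(x_chain))
end

println("mean of x_T: ", mean(x_chain))   # should be close to 0
println("std of x_T:  ", std(x_chain))    # should be close to 1
```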

<div style="text-align: center">
        <img src="./imgs/DDPM.png" alt="DDPM" width="800"/>
</div>

Each forward step scales the previous sample and adds independent Gaussian noise, and a sum of independent Gaussians is again Gaussian, so $x_t$ given $x_0$ is itself Gaussian. Using $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{i=1}^{t} \alpha_i$, we can write this marginal in closed form and, with the reparametrization trick, jump directly from $x_0$ to $x_t$:
$$
q(x_t \mid x_0) = \mathcal{N}\left(x_t; \sqrt{\bar{\alpha}_t}\,x_0, (1 - \bar{\alpha}_t)\mathbf{I}\right) \tag{2}
$$
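To see where eq. (2) comes from, write eq. (1) in reparametrized form, $x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_{t-1}$, with independent $\epsilon_{t-1}, \epsilon_{t-2}, \dots \sim \mathcal{N}(0, \mathbf{I})$, and unroll it. Two independent zero-mean Gaussian terms merge into a single Gaussian whose variance is the sum of their variances:

$$
\begin{aligned}
x_t &= \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_{t-1} \\
    &= \sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_{t-1}}\,\epsilon_{t-2}\right) + \sqrt{1-\alpha_t}\,\epsilon_{t-1} \\
    &= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2} + \underbrace{\sqrt{\alpha_t(1-\alpha_{t-1})}\,\epsilon_{t-2} + \sqrt{1-\alpha_t}\,\epsilon_{t-1}}_{\text{variance } \alpha_t - \alpha_t\alpha_{t-1} + 1 - \alpha_t \;=\; 1 - \alpha_t\alpha_{t-1}} \\
    &= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\,\bar{\epsilon}_{t-2} \\
    &\;\;\vdots \\
    &= \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, \mathbf{I})
\end{aligned}
$$

Drawing $\epsilon \sim \mathcal{N}(0, \mathbf{I})$ and evaluating the last line is exactly a sample from eq. (2).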

### Defining a schedule
The diffusion process is built on a variance schedule, which sets the amount of noise added at each step. Our schedule is defined below using the following quantities:

* `betas`: $\beta_t \in [0, 1]$, the variance of the noise added at step $t$
* `alphas`: $\alpha_t = 1 - \beta_t$
* `alphas_sqrt`: $\sqrt{\alpha_t}$
* `alphas_cumprod`: $\bar{\alpha}_t = \prod_{i=1}^{t} \alpha_i$
* `alphas_cumprod_sqrt`: $\sqrt{\bar{\alpha}_t}$
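A short numerical look at these quantities, assuming the linear schedule used in the next cell ($\beta_t$ from $10^{-4}$ to $0.02$ over 1000 steps). It prints the weight $\sqrt{\bar{\alpha}_t}$ on $x_0$ and the noise standard deviation $\sqrt{1-\bar{\alpha}_t}$ at a few steps; the names `β`, `α` and `α_bar` are illustrative and separate from the notebook's `betas`/`alphas` variables. By the last step $\sqrt{\bar{\alpha}_t}$ is below 0.01, so almost nothing of $x_0$ survives.

```julia
n_steps = 1000
β = collect(range(1e-4, 2e-2, length=n_steps))   # linear schedule, same endpoints as the next cell
α = 1 .- β
α_bar = cumprod(α)                                # α_bar[t] = α_1 * α_2 * ... * α_t

# weight on x_0 and noise std of q(x_t | x_0) at a few steps
for t in (1, 100, 500, 1000)
    println("t = $t:  sqrt(α_bar) = ", round(sqrt(α_bar[t]), digits=4),
            ",  sqrt(1 - α_bar) = ", round(sqrt(1 - α_bar[t]), digits=4))
end
```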

In [None]:
num_timesteps = 1000
betas = range(1e-4, 2e-2, length=num_timesteps) # from 0.0001 to 0.02 in 1000 steps
alphas = 1 .- betas                            # from 0.9999 to 0.98 in 1000 steps
alphas_sqrt = sqrt.(alphas)
alphas_cumprod = cumprod(alphas)
alphas_cumprod_sqrt = sqrt.(alphas_cumprod);

### Fast-forward stepping by jumping
Let's define the function `forward_jump()`, which samples directly from $q(x_t \mid x_0) = \mathcal{N}\left(x_t; \sqrt{\bar{\alpha}_t}\,x_0, (1 - \bar{\alpha}_t)\mathbf{I}\right)$.
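Before writing `forward_jump()`, it can help to check numerically that eq. (2) agrees with applying eq. (1) step by step. This sketch does so for a single scalar "pixel" value, assuming the same linear schedule as the schedule cell; `x0`, `t_target` and `n_samples` are illustrative names. Both approaches should give the mean $\sqrt{\bar{\alpha}_t}\,x_0$ and standard deviation $\sqrt{1-\bar{\alpha}_t}$ up to sampling error, but the jump needs one random draw per sample instead of $t$.

```julia
using Random, Statistics

Random.seed!(3)
n_steps = 1000
β = range(1e-4, 2e-2, length=n_steps)   # same linear schedule as the schedule cell
α_bar = cumprod(1 .- β)                 # α_bar[t] = prod of (1 - β_i) for i = 1..t

x0 = 0.8              # a single "pixel" value in [-1, 1]
t_target = 200        # step we want to reach
n_samples = 50_000    # independent samples for each approach

# (a) apply eq. (1) t_target times
x_steps = fill(x0, n_samples)
for s in 1:t_target
    x_steps .= sqrt(1 - β[s]) .* x_steps .+ sqrt(β[s]) .* randn(n_samples)
end

# (b) a single jump with eq. (2)
x_jump = sqrt(α_bar[t_target]) * x0 .+ sqrt(1 - α_bar[t_target]) .* randn(n_samples)

println("step by step: mean = ", mean(x_steps), ", std = ", std(x_steps))
println("direct jump:  mean = ", mean(x_jump), ", std = ", std(x_jump))
println("eq. (2):      mean = ", sqrt(α_bar[t_target]) * x0, ", std = ", sqrt(1 - α_bar[t_target]))
```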

### Exercise

Define the function and test it with the function `plot_samples(N, M, num_timesteps, img_)` below.

In [None]:
using Flux  # for the randn_like function, if necessary

function forward_jump(t, condition_img; condition_idx=0)
    # forward jump: 0 -> t
    @assert t >= 0

    # get the mean and std of the distribution
    # Exercise ... 
    mean = ...
    std = ...
    # sample from a random normal distribution (check ?randn for help) with the same size as the image
    noise = ...

    return mean + std * noise
end

Let's test the function with our example.

The following function samples N time steps of the diffusion process, starting from the given image. Each row corresponds to one time step: the first column shows the mean $\sqrt{\bar{\alpha}_t}\,x_0$, and the remaining M columns show independent samples from $q(x_t \mid x_0)$. These samples come from the same distribution, yet they differ if you look closely.

In [None]:
N = 5; # number of computed states between x_0 and x_T
M = 4; # number of samples taken from each distribution

function plot_samples(N, M, num_timesteps, img_)
    # Collect the plots row by row: for each time step, first the mean, then M samples
    plts = []
    for idx in 1:N
        t_step = Int((idx-1) * (num_timesteps / N))+1
        img_t = alphas_cumprod_sqrt[t_step] * img_
        plt = show_image(img_t)
        plot!(plt, xticks=[], yticks=[], title=L"\sqrt{\bar{\alpha_t}}x_0")
        push!(plts, plt)
        # Remaining columns: M independent samples from q(x_t | x_0)
        for sample in 1:M
            x_t = forward_jump(t_step, img_)
            plt = show_image(x_t)
            plot!(plt, xticks=[], yticks=[], title="q$sample at t=$(t_step-1)")
            push!(plts, plt)
        end
    end

    # Display the plots: N rows (time steps) × (M + 1) columns
    plot(plts..., layout=(N, M + 1), size=(1200, 800))
end

# Call the function with appropriate parameters
plot_samples(N, M, num_timesteps, img_f32_n)


## Looking at the noise added

Let's look at what a sample of the added noise looks like. To do so, we define the function `forward_jump_with_noise()`, which works like `forward_jump()` but also returns the generated noise.

### Exercise

Define the function and test it below.

In [None]:
function forward_jump_with_noise(t, condition_img; condition_idx=0)
    # forward jump: 0 -> t
    @assert t >= 0

    # get the mean and std of the distribution
    # Exercise ... 
    mean = ...
    std = ...
    # sample from a random normal distribution (check ?randn for help) with the same size as the image
    noise = ... 
    
    return mean + std * noise, noise
end

Now let's look at an example of the noise $\epsilon$ generated by `forward_jump_with_noise()`.
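Because eq. (2) writes $x_t$ as an affine function of $x_0$ and $\epsilon$, knowing the sampled noise is enough to undo the jump exactly: $x_0 = (x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon)/\sqrt{\bar{\alpha}_t}$. A minimal sketch on a random 8×8 array standing in for an image, assuming the same linear schedule as the schedule cell; the names ending in `_toy`, as well as `t_demo` and `x0_rec`, are hypothetical.

```julia
using Random

Random.seed!(4)
β = range(1e-4, 2e-2, length=1000)   # same linear schedule as the schedule cell
α_bar = cumprod(1 .- β)

t_demo = 50
x0_toy = 2 .* rand(8, 8) .- 1        # stand-in "image" with values in [-1, 1)
ε_toy = randn(size(x0_toy))          # noise of the same size, as a forward jump would draw it
xt_toy = sqrt(α_bar[t_demo]) .* x0_toy .+ sqrt(1 - α_bar[t_demo]) .* ε_toy   # eq. (2)

# invert eq. (2) using the known noise
x0_rec = (xt_toy .- sqrt(1 - α_bar[t_demo]) .* ε_toy) ./ sqrt(α_bar[t_demo])

println("max reconstruction error: ", maximum(abs.(x0_rec .- x0_toy)))
```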

In [None]:
t_step = 50

x_t, noise = forward_jump_with_noise(t_step, img_f32_n)

p1 = show_image(img_f32_n)
p1 = plot!(p1, title="x_0", aspect_ratio=:equal, legend=false)
p2 = show_image(x_t)
p2 = plot!(p2, title="x_t", aspect_ratio=:equal, legend=false)
p3 = show_image(noise)
p3 = plot!(p3, title="ε", aspect_ratio=:equal, legend=false)

plot(p1, p2, p3, layout=(1, 3), size=(1200, 400))