<a href="https://colab.research.google.com/github/mikonvergence/DiffusionFastForward/blob/master/01-colab-Diffusion-Sandbox.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

> This is part of [DiffusionFastForward](https://github.com/mikonvergence/DiffusionFastForward) course. For more content, please go to https://github.com/mikonvergence/DiffusionFastForward.

# Diffusion Sandbox

In this notebook, the intricacies of a denoising diffusion framework are illustrated with the aid of simple snippets.

First, let's import an image to use for the examples.

In [None]:
! git clone https://github.com/mikonvergence/DiffusionFastForward
!pip -q install pytorch-lightning==1.9.3 diffusers einops kornia

print("Indicated packages are installed. We're good to go.")

In [None]:
import sys
sys.path.append('./DiffusionFastForward/')

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import imageio.v2 as imageio

mpl.rcParams['figure.figsize'] = (12, 8)

img = torch.FloatTensor(imageio.imread('./DiffusionFastForward/imgs/hills_2.png')/255)
plt.imshow(img)

### Using CPU or GPU

By default, PyTorch stores tensors on the CPU. Here we move the image tensor to the GPU, when one is available, to speed up the noising operations.

In [None]:
img.device

In [None]:
if torch.cuda.is_available():
    # Get the count of available GPUs
    gpu_count = torch.cuda.device_count()
    print(f"Number of available GPUs: {gpu_count}")
    
    # Default to the first GPU (index 0)
    default_gpu_index = 0
    print(f"Using GPU with index {default_gpu_index} by default")
    
    # Get the name of the GPU being used
    print("Chosen GPU name: " + torch.cuda.get_device_name(default_gpu_index))
    
    # Move the tensor to the selected GPU (default GPU)
    img = img.to(f'cuda:{default_gpu_index}')
else:
    print("CUDA is not available. Using CPU.")
    img = img.to('cpu')  # Fallback to CPU
    
img.device

### Image Basic properties
If the cells above ran successfully, you should see an image of some hills under a thick layer of cloud. If you can see it, congrats!

**Shape of the Image**

Using the imageio package, we read the image `hills_2.png` into a `numpy.ndarray` (an n-dimensional array).

It is also an object with some intrinsic properties, such as its shape. You can use
```
.shape
```
to read the shape of the array. This should give you `(512, 1024, 3)`:

- **512**: the height of our 2-D image (512 pixels)

- **1024**: the width of the image (1024 pixels)

- **3**: the <span style="color:red;">Red</span>, <span style="color:green;">Green</span> and <span style="color:blue;">Blue</span> channels that make the image colored.
   
**Why do we divide by 255?**

The intensity of each of the three colour channels ranges from 0 to 255 ($[0,255]$), so `/255` is simply a normalization: the elements of our tensor are now in $[0,1]$.
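
As a quick illustration of this normalization, the sketch below rescales one 8-bit RGB pixel in plain Python. The pixel values are made up for the example and are not taken from `hills_2.png`.

```python
# A hypothetical 8-bit RGB pixel (values chosen only for illustration).
pixel_uint8 = (255, 128, 0)

# Dividing by 255 maps each channel from [0, 255] to [0, 1].
pixel_unit = tuple(channel / 255 for channel in pixel_uint8)

print(pixel_unit)
```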


In [None]:
img_m = imageio.imread('./DiffusionFastForward/imgs/hills_2.png')

type(img_m)

In [None]:
img_m.shape

### Before we start...
Most diffusion models assume that the images are scaled to the `[-1,+1]` range (which tends to simplify many equations). This tutorial follows the same approach, so we need to define input and output transformation functions `input_T()` and `output_T()`.

Also, let's define our own `show()` wrapper function that displays the image with automatic output transformation!

In [None]:
def input_T(input):
    # [0,1] -> [-1,+1]
    return 2*input-1
    
def output_T(input):
    # [-1,+1] -> [0,1]
    return (input+1)/2

def show(input):
    # plt.imshow runs on cpu.
    if input.is_cuda:
        input = input.to('cpu')
    plt.imshow(output_T(input).clip(0,1))
    
img_=input_T(img)
show(img_)

### Defining a schedule
The diffusion process is built based on a variance schedule, which determines the levels of added noise at each step of the process. To that end, our schedule is defined below, with the following quantities:

* `betas`: $\beta_t$
* `alphas`: $\alpha_t=1-\beta_t$
* `alphas_sqrt`: $\sqrt{\alpha_t}$
* `alphas_cumprod`: $\bar{\alpha}_t=\prod_{i=0}^{t}\alpha_i$
* `alphas_cumprod_sqrt`: $\sqrt{\bar{\alpha}_t}$
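
To make these definitions concrete, here is a minimal plain-Python sketch of the same quantities. The helper name `linear_betas`, the `_demo` variable names, and the choice of 1000 steps are assumptions made for this example, not part of the notebook's code.

```python
import math

def linear_betas(num_steps, beta_start=1e-4, beta_end=2e-2):
    """Evenly spaced betas from beta_start to beta_end (hypothetical helper)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]

betas_demo = linear_betas(1000)
alphas_demo = [1.0 - b for b in betas_demo]

# The running product of the alphas gives alpha-bar for every t.
alphas_cumprod_demo = []
running = 1.0
for a in alphas_demo:
    running *= a
    alphas_cumprod_demo.append(running)

for t in (0, 250, 500, 999):
    a_bar = alphas_cumprod_demo[t]
    print(f"t={t:4d}  signal={math.sqrt(a_bar):.4f}  noise={math.sqrt(1 - a_bar):.4f}")
```

The signal coefficient shrinks towards 0 while the noise coefficient grows towards 1, which is what pushes $x_t$ towards pure Gaussian noise.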

In [None]:
num_timesteps=10000
betas=torch.linspace(1e-4,2e-2,num_timesteps)

alphas=1-betas
alphas_sqrt=alphas.sqrt()
alphas_cumprod=torch.cumprod(alphas,0)
alphas_cumprod_sqrt=alphas_cumprod.sqrt()

In [None]:
betas[0].item()

#### Torch Tensor Default Notice

By default, PyTorch displays floating-point values in tensors with **4 decimal places** for readability. The full precision is preserved internally.

You can use

```
torch.set_printoptions(precision=8)
```
to adjust the precision globally when printing tensors.


## Forward Process
The forward process $q$ determines how subsequent steps in the diffusion are derived (gradual distortion of the original sample $x_0$).

📃 First, let's bring up the key equations describing this process...

Basic format of the forward step:
$$q(x_t|x_{t-1}) := \mathcal{N}(x_t; \sqrt{1 - \beta_t}x_{t-1}, \beta_t I) \tag{1}$$

To step directly from $x_0$ to $x_t$:
$$q(x_t|x_0) = \mathcal{N}(x_t;\sqrt{\bar{\alpha}_t}x_0, (1 - \bar{\alpha}_t)I) \tag{2}$$
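
Equation (2) follows from applying Eq. (1) repeatedly: each step scales the previous mean by $\sqrt{\alpha_t}$ and turns a variance $v_{t-1}$ into $\alpha_t v_{t-1}+\beta_t$. The sketch below checks this numerically in plain Python on a small schedule; the `betas_demo` values are made up for illustration.

```python
import math

# A small, made-up variance schedule used only for this check.
betas_demo = [0.01, 0.02, 0.05, 0.1, 0.2]

mean_coeff = 1.0   # multiplier of x_0 in the mean of q(x_t | x_0)
variance = 0.0     # variance of q(x_t | x_0); x_0 itself carries no noise
alpha_bar = 1.0    # running product of the alphas

for beta in betas_demo:
    alpha = 1.0 - beta
    # One application of Eq. (1): scale by sqrt(alpha), add variance beta.
    mean_coeff *= math.sqrt(alpha)
    variance = alpha * variance + beta
    alpha_bar *= alpha

    # Eq. (2) predicts these closed forms at every step.
    assert math.isclose(mean_coeff, math.sqrt(alpha_bar))
    assert math.isclose(variance, 1.0 - alpha_bar)

print(mean_coeff, variance)
```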

### Let's define `forward_step()` for $q(x_t|x_{t-1})$ and `forward_jump()` for $q(x_t|x_0)$

In [None]:
def forward_step(t, condition_img, return_noise=False):
    """
        forward step: t-1 -> t
    """    
    assert t >= 0

    mean=alphas_sqrt[t]*condition_img    
    std=betas[t].sqrt()
      
    # sampling from N; the noise matches the shape and device of the input image
    if not return_noise:
        return mean+std*torch.randn_like(condition_img)
    else:
        noise=torch.randn_like(condition_img)
        return mean+std*noise, noise
    
def forward_jump(t, condition_img, condition_idx=0, return_noise=False):
    """
        forward jump: 0 -> t
    """   
    assert t >= 0
    
    mean=alphas_cumprod_sqrt[t]*condition_img
    std=(1-alphas_cumprod[t]).sqrt()
      
    # sampling from N; the noise matches the shape and device of the input image
    if not return_noise:
        return mean+std*torch.randn_like(condition_img)
    else:
        noise=torch.randn_like(condition_img)
        return mean+std*noise, noise

In [None]:
N=5 # number of computed states between x_0 and x_T
M=4 # number of samples taken from each distribution

In the first example, we use `forward_jump()` to derive a sample $x_t$ directly from the clean sample $x_0$.

The first column shows the mean image for a given stage of the diffusion, and the subsequent columns to the right show several samples taken from the same distribution (they are different if you look closely!).

### Essence

- The 1st column gradually turns our image into a uniformly grey image from top to bottom. Each entry is the **theoretical mean of the image** at step `t_step`, $\mu_{t}$, obtained by jumping from the original image $x_0$.
- Columns 2 to 5 show 4 independent samples drawn from $x_0$ at the given `t_step`.
- Within a row, the first element is $\mu_{t}$ and the remaining elements are repeated forward jumps with the same `t_step` from the original image $x_0$. In short, each jump is equivalent to adding zero-mean Gaussian white noise with standard deviation $\sqrt{1-\bar{\alpha}_t}$ to $\mu_{t}$.
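
The last point can be checked empirically. The sketch below uses plain Python on a single scalar "pixel" instead of an image tensor; `x0_demo` and `alpha_bar_demo` are hypothetical values chosen for illustration.

```python
import math
import random
import statistics

random.seed(0)  # fixed seed so the run is repeatable

x0_demo = 0.6          # a hypothetical pixel value in [-1, 1]
alpha_bar_demo = 0.5   # a hypothetical value of alpha-bar at some step t

mu_t = math.sqrt(alpha_bar_demo) * x0_demo
sigma_t = math.sqrt(1.0 - alpha_bar_demo)

# Each forward jump is the mean plus scaled standard normal noise.
samples = [mu_t + sigma_t * random.gauss(0.0, 1.0) for _ in range(20000)]

print("theoretical mean/std:", mu_t, sigma_t)
print("empirical mean/std:  ", statistics.fmean(samples), statistics.pstdev(samples))
```

With enough samples the empirical mean and standard deviation should land close to $\mu_t$ and $\sqrt{1-\bar{\alpha}_t}$.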

In [None]:
plt.figure(figsize=(12,8))
for idx in range(N):
    t_step=int(idx*(num_timesteps/N))
    
    plt.subplot(N,1+M,1+(M+1)*idx) # plt.subplot(nrows, ncols, index)
    show(alphas_cumprod_sqrt[t_step]*img_) # first column: the mean fades towards grey (0 in [-1,+1])
    #show(0*img_)
    plt.title(r'$\mu_t=\sqrt{\bar{\alpha}_t}x_0$') if idx==0 else None
    plt.ylabel("t: {:.2f}".format(t_step/num_timesteps))
    plt.xticks([])
    plt.yticks([])
    
    for sample in range(M):
        x_t=forward_jump(t_step,img_)
        
        plt.subplot(N,1+M,2+(1+M)*idx+sample)
        plt.title("t_step: {:d}".format(t_step))
        show(x_t)        
        plt.axis('off')
        
plt.tight_layout()

Alternatively, we can test the process of going from $x_{t-1}$ to $x_t$, which is a single step in the diffusion process. For that we can use the `forward_step` function.

Note that the mean $\mu_t$ is now a bit different (first column) since it is conditioned on a specific sample of $x_{t-1}$!

### Essence

- For each row, we first jump from $x_{0}$ to $x_{t-1}$ (that is, to `t_step-1`) to obtain one specific sample, `prev_img`.
- The first column shows the mean of the next step, $\sqrt{1-\beta_t}x_{t-1}$, computed from that sample.
- The remaining 4 images in each row are independent forward steps from `t_step-1` to `t_step`, all starting from the same `prev_img`.

In [None]:
plt.figure(figsize=(12,8))
for idx in range(N):
    t_step=int(idx*(num_timesteps/N))
    prev_img=forward_jump(max([0,t_step-1]),img_) # jump directly to the t-1 state
    
    plt.subplot(N,1+M,1+(M+1)*idx)
    show(alphas_sqrt[t_step]*prev_img)
    plt.title(r'$\mu_t=\sqrt{1-\beta_t}x_{t-1}$') if idx==0 else None
    plt.ylabel("t: {:.2f}".format(t_step/num_timesteps))
    plt.xticks([])
    plt.yticks([])
    
    for sample in range(M):
        plt.subplot(N,1+M,2+(1+M)*idx+sample)
        x_t=forward_step(t_step,prev_img) # t-1 state to t state
        show(x_t)        
        plt.axis('off')
plt.tight_layout()

## Reverse Process

The purpose of the reverse process $p$ is to approximate the previous step $x_{t-1}$ in the diffusion chain based on a sample $x_t$. In practice, this approximation $p(x_{t-1}|x_t)$ must be done without the knowledge of $x_0$.

A parametrizable prediction model with parameters $\theta$ is used to estimate $p_\theta(x_{t-1}|x_t)$.

The reverse process will also be (approximately) gaussian if the diffusion steps are small enough:

$$p_\theta(x_{t-1}|x_t) := \mathcal{N}(x_{t-1};\mu_\theta(x_t),\Sigma_\theta(x_t))\tag{3}$$

In many works, it is assumed that the variance of this distribution should not depend strongly on $x_0$ or $x_t$, but rather on the stage of the diffusion process $t$. This can be observed in the true distribution $q(x_{t-1}|x_t, x_0)$, where the variance of the distribution equals $\tilde{\beta}_t$.


### Essence

We assume that sampling the previous step $t-1$ from the later step $t$ follows a Gaussian distribution.
* Under this assumption, with the variance tied to the schedule, **all we need to estimate for the reverse process is the <span style="color:red;">mean</span> of the Gaussian distribution**. One way to think about it: if we knew that sampling distribution, we could sample from it many times and take the average at the end. This is so-called *Monte Carlo sampling*. The average converges to the true mean $\mu_\theta$, and the spread described by the variance term $\Sigma_\theta$ averages out.

### Parameterizing $\mu_\theta$
There are at least **3 ways** of parameterizing the <span style="color:red;">mean</span> $\mu_{\theta}(x_t)$ of the reverse step distribution $p_\theta(x_{t-1}|x_t)$:
* Directly (a neural network will estimate $\mu_\theta$)
* Via $x_0$ (a neural network will estimate $x_0$)
$$\tilde{\mu}_\theta = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t\tag{4}$$
* Via subtraction of the noise $\epsilon$ from $x_t$ (a neural network will estimate $\epsilon$)
$$x_0=\frac{1}{\sqrt{\bar{\alpha}_t}}(x_t-\sqrt{1-\bar{\alpha}_t}\epsilon)\tag{5}$$

The approach of approximating the normal noise $\epsilon$ is used most widely.
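
Equations (4) and (5) can be combined: substituting (5) into (4) gives $\tilde{\mu}_\theta=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon\right)$. The plain-Python sketch below checks this on scalars. The schedule values and the names `x0_demo` and `eps_demo` are assumptions made for illustration.

```python
import math

# Hypothetical scalar values for one diffusion step t.
alpha_bar_prev = 0.90          # alpha-bar at t-1
beta_t = 0.02                  # beta at t
alpha_t = 1.0 - beta_t
alpha_bar_t = alpha_bar_prev * alpha_t

x0_demo, eps_demo = 0.3, -1.2  # clean value and the noise that was added

# Forward jump, Eq. (2), written in its sampled form.
x_t = math.sqrt(alpha_bar_t) * x0_demo + math.sqrt(1 - alpha_bar_t) * eps_demo

# Eq. (5): recover x_0 from x_t and the noise.
x0_pred = (x_t - math.sqrt(1 - alpha_bar_t) * eps_demo) / math.sqrt(alpha_bar_t)

# Eq. (4): mean of the reverse step from x_0 and x_t.
mu_via_x0 = (math.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t) * x0_pred
             + math.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t) * x_t)

# The same mean written directly in terms of the noise.
mu_via_eps = (x_t - beta_t / math.sqrt(1 - alpha_bar_t) * eps_demo) / math.sqrt(alpha_t)

print(x0_pred, mu_via_x0, mu_via_eps)
```

Both routes give the same mean, which is why a network that predicts $\epsilon$ is enough to take a reverse step.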

Let's look at what an example $\epsilon$ might look like:

In [None]:
t_step=200

x_t,noise=forward_jump(t_step,img_,return_noise=True)

plt.subplot(1,3,1)
show(img_)
plt.title(r'$x_0$')
plt.axis('off')

plt.subplot(1,3,2)
show(x_t)
plt.title(r'$x_t$')
plt.axis('off')

plt.subplot(1,3,3)
show(noise)
plt.title(r'$\epsilon$')
plt.axis('off')

plt.show()

If $\epsilon$ is predicted correctly, we can use equation (5) to **recover** $x_0$:

### Essence

- We know the exact noise $\epsilon$ that the forward jump function used to produce the noisy $x_t$. We now want to get back to $x_0$.
- Basically, $$x_0=\frac{1}{\sqrt{\bar{\alpha}_t}}(x_t-\sqrt{1-\bar{\alpha}_t}\epsilon)\tag{5}$$
is the denoising process. The subtraction removes the noise added by the jump, and the division by $\sqrt{\bar{\alpha}_t}$ undoes the fading of the mean in the jump function.

- Therefore, this reverse computation recovers $x_0$ **exactly** (up to floating-point error), as the loss in the title of the middle picture below shows.

In [None]:
x_0_pred=(x_t-(1-alphas_cumprod[t_step]).sqrt()*noise)/(alphas_cumprod_sqrt[t_step])

plt.subplot(1,3,1)
show(x_t)
plt.title(r'$x_t$ ($\ell_1$: {:.3f})'.format(F.l1_loss(x_t,img_)))  # raw string keeps the LaTeX backslash intact
plt.axis('off')

plt.subplot(1,3,2)
show(x_0_pred)
plt.title(r'$x_0$ prediction ($\ell_1$: {:.3f})'.format(F.l1_loss(x_0_pred,img_)))
plt.axis('off') 

plt.subplot(1,3,3)
show(img_)
plt.title('$x_0$')
plt.axis('off')

plt.show()

$\ell_1$ stands for the $L_1$ loss, which is the mean absolute error (MAE) between each element in the input $x$ and the target $y$. For the `torch.nn.functional` version, you can refer to the documentation [here](https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html#torch.nn.L1Loss). In a nutshell, by default we take the mean of the absolute errors over all elements.
$$
\ell_1(x,y) =\frac{1}{N}\left[\sum_i l_i\right] 
$$
where
$$
l_i = |x_i - y_i|,
$$
and $N$ is the number of elements.
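
For reference, the same definition can be written in a few lines of plain Python. The function name `l1_loss_demo` is hypothetical and only mirrors the formula above on flat lists of numbers; it is not the PyTorch implementation.

```python
def l1_loss_demo(x, y):
    """Mean absolute error between two equally long sequences of numbers."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    return sum(abs(xi - yi) for xi, yi in zip(x, y)) / len(x)

# Absolute errors are 0, 2 and 3, so the result is their average.
print(l1_loss_demo([1.0, 2.0, 3.0], [1.0, 4.0, 0.0]))
```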


### Approximation (or knowledge) of $x_0$ allows us to approximate the mean $\mu_\theta$ using Eq. (4).

In [None]:
# estimate mean Eq.(4)
mean_pred=x_0_pred*(alphas_cumprod_sqrt[t_step-1]*betas[t_step])/(1-alphas_cumprod[t_step]) + x_t*(alphas_sqrt[t_step]*(1-alphas_cumprod[t_step-1]))/(1-alphas_cumprod[t_step])

# let's compare it to ground truth mean of the previous step (requires knowledge of x_0)
mean_gt=alphas_cumprod_sqrt[t_step-1]*img_

Since the reverse process mean estimate $\tilde{\mu}_\theta$ in (4) is effectively a linear interpolation between the noisy $x_t$ and $x_0$, it is expected to have a higher error (as the additive noise is still present) than the mean computed using the forward process (which is obtained by scaling the clean sample by a scalar value).
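
To see how strongly Eq. (4) leans on the noisy $x_t$, the sketch below computes the two weights in plain Python. It assumes the linear schedule defined earlier (10000 steps from 1e-4 to 2e-2) and `t_step=200`; the helper name `posterior_weights` is hypothetical.

```python
import math

def posterior_weights(t, num_steps=10000, beta_start=1e-4, beta_end=2e-2):
    """Return the Eq. (4) weights on x_0 and x_t for a linear schedule (t >= 1)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    betas = [beta_start + i * step for i in range(t + 1)]

    alpha_bar_prev = math.prod(1.0 - b for b in betas[:t])   # alpha-bar at t-1
    alpha_t = 1.0 - betas[t]
    alpha_bar_t = alpha_bar_prev * alpha_t

    w_x0 = math.sqrt(alpha_bar_prev) * betas[t] / (1.0 - alpha_bar_t)
    w_xt = math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
    return w_x0, w_xt

w_x0, w_xt = posterior_weights(200)
print(f"weight on x_0: {w_x0:.4f}, weight on x_t: {w_xt:.4f}")
```

Most of the weight sits on $x_t$ at this step, so the noise in $x_t$ carries over into the estimated mean.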

In [None]:
plt.subplot(1,3,1)
show(x_t)
plt.title(r'$x_t$   ($\ell_1$: {:.3f})'.format(F.l1_loss(x_t,img_)))
plt.subplot(1,3,2)
show(mean_pred)
plt.title(r'$\tilde{\mu}_{t-1}$' + r'  ($\ell_1$: {:.3f})'.format(F.l1_loss(mean_pred,img_)))
plt.subplot(1,3,3)
show(mean_gt)
plt.title(r'$\mu_{t-1}$' + r'  ($\ell_1$: {:.3f})'.format(F.l1_loss(mean_gt,img_)))
plt.show()

Once we get our `mean_pred` ($\tilde{\mu}_\theta$), we can define our distribution for the previous step. Here the variance is simply set to $\beta_t$:

$$\tilde{\beta}_t=\beta_t \tag{6}$$

$$ p_\theta(x_{t-1}|x_t) := \mathcal{N}(x_{t-1};\tilde{\mu}_\theta(x_t,t),\tilde{\beta}_t I) \tag{7}$$

> Important: the experiment below should be treated as a simulation. In practice, the network must predict either $\epsilon$ or $x_0$ or $\tilde{\mu}_\theta$. Here, the value of $\epsilon$ is simply substituted with the noise known from the forward process.
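
The sketch below replays this kind of simulation on a single scalar in plain Python, to show why the chain returns to $x_0$ when the true noise is supplied. It assumes a 1000-step linear schedule and uses hypothetical names (`reverse_step_demo`, `x0_demo`); it is a simplified stand-in, not the notebook's tensor code.

```python
import math
import random

random.seed(0)

T = 1000
betas = [1e-4 + i * (2e-2 - 1e-4) / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]
alphas_cumprod = []
running = 1.0
for a in alphas:
    running *= a
    alphas_cumprod.append(running)

def reverse_step_demo(eps, x_t, t):
    """One simulated reverse step t -> t-1, given the exact noise eps in x_t."""
    a_bar_t = alphas_cumprod[t]
    x0_pred = (x_t - math.sqrt(1 - a_bar_t) * eps) / math.sqrt(a_bar_t)  # Eq. (5)
    if t == 0:
        return x0_pred, 0.0
    a_bar_prev = alphas_cumprod[t - 1]
    mean = (math.sqrt(a_bar_prev) * betas[t] / (1 - a_bar_t) * x0_pred
            + math.sqrt(alphas[t]) * (1 - a_bar_prev) / (1 - a_bar_t) * x_t)  # Eq. (4)
    sample = mean + math.sqrt(betas[t]) * random.gauss(0.0, 1.0)
    # Noise contained in the new sample relative to x0_pred (simulation only).
    eps_prev = (sample - math.sqrt(a_bar_prev) * x0_pred) / math.sqrt(1 - a_bar_prev)
    return sample, eps_prev

x0_demo = 0.4
eps = random.gauss(0.0, 1.0)
x = math.sqrt(alphas_cumprod[T - 1]) * x0_demo + math.sqrt(1 - alphas_cumprod[T - 1]) * eps

for t in reversed(range(T)):
    x, eps = reverse_step_demo(eps, x, t)

print(x0_demo, x)
```

Because `eps_prev` is recomputed from the known $x_0$ at every step, the random draws cancel and the loop ends at $x_0$ up to floating-point error. A trained network would only approximate this.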

In [None]:
def reverse_step(epsilon, x_t, t_step, return_noise=False):
    
    # estimate x_0 based on epsilon
    x_0_pred=(x_t-(1-alphas_cumprod[t_step]).sqrt()*epsilon)/(alphas_cumprod_sqrt[t_step])
    if t_step==0:
        sample=x_0_pred
        noise=torch.zeros_like(x_0_pred)
    else:
        # estimate mean
        mean_pred=x_0_pred*(alphas_cumprod_sqrt[t_step-1]*betas[t_step])/(1-alphas_cumprod[t_step]) + x_t*(alphas_sqrt[t_step]*(1-alphas_cumprod[t_step-1]))/(1-alphas_cumprod[t_step])

        # standard deviation of the reverse step (square root of the variance beta_t);
        # t_step != 0 is guaranteed in this branch
        beta_pred=betas[t_step].sqrt()

        sample=mean_pred+beta_pred*torch.randn_like(x_t)
        # this noise is only computed for simulation purposes (since x_0_pred is not known normally)
        noise=(sample-x_0_pred*alphas_cumprod_sqrt[t_step-1])/(1-alphas_cumprod[t_step-1]).sqrt()
    if return_noise:
        return sample, noise
    else:
        return sample

In [None]:
x_t,noise=forward_jump(1000-1,img_,return_noise=True)

state_imgs=[x_t]
for t_step in reversed(range(1000)):
    x_t,noise=reverse_step(noise,x_t,t_step,return_noise=True)
    
    if t_step % 500 == 0:
        state_imgs.append(x_t)

In [None]:
# show the initial noisy image
plt.figure()
show(state_imgs[0])

In [None]:
# show the denoised image
plt.figure()
show(state_imgs[-1])

In [None]:
# show the original image 
plt.figure()
show(img_)

In [None]:
plt.figure()
for idx,state_img in enumerate(state_imgs):
    plt.subplot(1,len(state_imgs),idx+1)
    show(state_img.clip(-1,1))
    plt.axis('off')
    


## Packaging into Components
The processes investigated above are neatly packaged into modular components for easier management of the diffusion framework.

First, the forward process component `GaussianForwardProcess` encapsulates the functions of $q(x_t|x_0)$ and $q(x_t|x_{t-1})$.

Below, we can see how different schedules of the variance parameter $\beta$ affect how the noise level changes throughout the progression.
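
For intuition about what a non-linear schedule looks like, here is a plain-Python sketch of the cosine schedule as proposed by Nichol and Dhariwal (2021). This illustrates the general technique only: the `GaussianForwardProcess` class in the course's `src` package may implement its schedules differently, and the name `cosine_betas` is hypothetical.

```python
import math

def cosine_betas(num_steps, s=0.008, max_beta=0.999):
    """Betas derived from a squared-cosine alpha-bar curve (hypothetical helper)."""
    def f(t):
        return math.cos((t / num_steps + s) / (1 + s) * math.pi / 2) ** 2

    betas = []
    for t in range(1, num_steps + 1):
        # beta_t = 1 - alpha_bar(t) / alpha_bar(t-1), clipped for stability.
        betas.append(min(1.0 - f(t) / f(t - 1), max_beta))
    return betas

betas_cos = cosine_betas(1000)
print(betas_cos[0], betas_cos[499], betas_cos[-1])
```

Compared with the linear schedule, the betas start much smaller and rise sharply near the end, so the image is destroyed more gradually in the early steps.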

In [None]:
from src import *

D=128
make_white=False
save=False
line_color='black' #'#9EFFB9'

# we slice a piece of the image to be noised
test_img=img[256-D:256+D,512-D:512+D,:]

for schedule in ['linear','quadratic','sigmoid','cosine']:
    fw=GaussianForwardProcess(1000,
                              schedule)

    plt.figure(figsize=(10,2))
    plt.subplot(1,6,1)    
    plt.plot(fw.betas,color=line_color)
    plt.title(schedule,color=line_color)
    plt.xlabel(r'step $t$',color=line_color)
    plt.ylabel(r'$\beta_t$',color=line_color)
    
    if make_white:
        plt.xticks(color='white')
        plt.gca().tick_params(axis='x', colors='white')
        plt.gca().tick_params(axis='y', colors='white')
        plt.gca().spines['top'].set_color('white')
        plt.gca().spines['right'].set_color('white')
        plt.gca().spines['left'].set_color('white')
        plt.gca().spines['bottom'].set_color('white')
    for step in range(5):
        plt.subplot(1,6,step+2)
        plt.imshow(fw(test_img.permute(2,0,1).unsqueeze(0).to('cpu'),torch.tensor(step*200))[0].permute(1,2,0))
        plt.axis('off')        
    plt.tight_layout()
    
    
    if save:
        plt.savefig('{}.png'.format(schedule),
                    dpi=200,
                    bbox_inches='tight',
                    pad_inches=0.0,
                    transparent=True)