# INSA, GMM Image
## Practical sessions: Introduction to Diffusion

Welcome to this second practical session on diffusion models.  
In the previous session we saw how to train a diffusion model on a toy 2D dataset.  
Today we will continue with a similar dataset and see whether we can learn a class-conditioned diffusion model.  
After that, we will work with pretrained models for image generation and try several things to improve the quality of the generated images.  
Then we will play a bit with a Stable Diffusion model to generate images from text, do some inpainting, or use image conditioning.  
Finally, we will use an inversion mechanism to carefully edit images.

Let's start with our class-conditioned diffusion model.  
To illustrate the concept of class-conditioned diffusion, we will work with a two-class dataset called two moons.  
You probably know it from the sklearn library. It is a simple dataset with two classes and two features.  
Let's load it and plot it.  

In [None]:
import torch
from sklearn.datasets import make_moons

def get_moons_dataset(n_samples=1000, noise=0.1, random_state=42):
    x_0, y = make_moons(n_samples=n_samples, noise=noise, random_state=random_state)
    x_0 = torch.FloatTensor(x_0)
    y = torch.LongTensor(y)
    
    return x_0, y

x_0, y = get_moons_dataset()
print(f"Features shape: {x_0.shape}") 
print(f"Labels shape: {y.shape}")   

In [None]:
import matplotlib.pyplot as plt
plt.scatter(x_0[:,0],x_0[:,1],c=y)
plt.show()

We will use this dataset to illustrate the concept of class-conditioned diffusion.  
The data points define our initial distribution, the one we want to sample from, and the class labels will be our condition.  
As in the previous session, we don't know how to sample from this distribution directly. We will use a noise schedule to gradually add noise to the data until we reach a Gaussian distribution, which we know how to sample from. Then we will use the reverse process to sample from the initial distribution.  

### Noise Schedule
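Let's begin by defining the noise schedule.  
The noise schedule defines the amount of noise added at each timestep.  
We will use a linear schedule for this example: the per-step noise variance $\beta_t = 1-\alpha_t$ grows linearly from $10^{-4}$ to $0.05$ over $T=200$ steps, and we plot the cumulative product $\bar{\alpha}_t = \prod_{s\le t}\alpha_s$.  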

In [None]:
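T = 200
beta_min = 0.0001  # noise variance added at the first step
beta_max = 0.05    # noise variance added at the last step
betas = torch.linspace(beta_min, beta_max, T)  # linear schedule on beta_t
alphas = 1. - betas                             # alpha_t = 1 - beta_t
alpha_bar = torch.cumprod(alphas, dim=0)        # alpha_bar_t = prod_{s<=t} alpha_s
plt.figure(figsize=[6, 6])
plt.plot(torch.arange(T), alpha_bar)
plt.xlabel('Timestep')
plt.ylabel(r'$\bar{\alpha}_t$')
plt.ylim(0, 1.05)
plt.show()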
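
As a quick sanity check, the sketch below rebuilds the same linear schedule inside a helper (the name `linear_schedule` is ours, not part of the session) and prints the signal coefficient $\sqrt{\bar{\alpha}_t}$ and the noise coefficient $\sqrt{1-\bar{\alpha}_t}$ at a few timesteps. With these values, $\bar{\alpha}_T$ ends up below 0.01, so almost none of the original signal survives at the last step.

```python
import torch

def linear_schedule(T: int = 200, beta_min: float = 1e-4, beta_max: float = 0.05):
    """Hypothetical helper: same linear schedule as above, returned as (alphas, alpha_bar)."""
    betas = torch.linspace(beta_min, beta_max, T)  # per-step noise variance
    alphas = 1.0 - betas                           # per-step signal retention
    alpha_bar = torch.cumprod(alphas, dim=0)       # cumulative signal retention
    return alphas, alpha_bar

sched_alphas, sched_alpha_bar = linear_schedule()
for step in [0, 50, 100, 199]:
    signal = sched_alpha_bar[step].sqrt().item()         # weight of x_0 in x_t
    noise = (1.0 - sched_alpha_bar[step]).sqrt().item()  # weight of eps in x_t
    print(f"t={step:3d}  signal={signal:.3f}  noise={noise:.3f}")
```

Since $\bar{\alpha}_t + (1-\bar{\alpha}_t) = 1$, the two printed coefficients always satisfy $\text{signal}^2 + \text{noise}^2 = 1$.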

### Forward Pass
The forward pass is the process of adding noise to the data at each timestep.  
We will use the noise schedule to define the amount of noise to add at each timestep.  
We saw in class that the forward step can be written as:
$$x_t = \sqrt{\alpha_t}x_{t-1} + \sqrt{1-\alpha_t}\epsilon_t$$
where $1-\alpha_t$ is the variance of the noise added at timestep $t$ and $\epsilon_t \sim \mathcal{N}(0, I)$ is a noise sample.  
Here is the function to perform the forward pass.  

In [None]:
from typing import List

def forward_step(x_t_minus_1:torch.Tensor, alphas:torch.Tensor, t:int, eps:torch.Tensor) -> torch.Tensor:
    """
    Takes the previous step, the alphas and the timestep and returns the next step
    args:
        x_t_minus_1: the previous step
        alphas: the alphas of the noise schedule
        t: the timestep
        eps: the noise sample
    returns:
        x_t: the next step
    """
    return alphas[t].sqrt()*x_t_minus_1 + (1-alphas[t]).sqrt()*eps

def forward_pass(x_0:torch.Tensor, alphas:torch.Tensor, T:int=200) -> List[torch.Tensor]:
    """
    Takes the initial data, the alphas and the number of timesteps and returns the list of forward steps
    args:
        x_0: the initial data
        alphas: the alphas of the noise schedule
        T: the number of timesteps
    returns:
        x_series: a list of the forward steps
    """
    x_series = [x_0]
    for t in range(T):
        eps = torch.randn_like(x_0)
        x_series.append(forward_step(x_series[-1], alphas, t, eps))
    return x_series

x_series = forward_pass(x_0, alphas, T)
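
Why does this process end up at a standard Gaussian? If $x_{t-1}$ has unit variance and $\epsilon_t$ is independent standard noise, then $\mathrm{Var}(x_t) = \alpha_t + (1-\alpha_t) = 1$: each step shrinks the signal while keeping the total variance constant. The sketch below (the helper name `empirical_step_variance` is hypothetical) checks this numerically on 1D samples.

```python
import torch

def empirical_step_variance(alpha_t: float, n: int = 200_000, seed: int = 0) -> float:
    """Variance of x_t = sqrt(alpha_t) x_{t-1} + sqrt(1 - alpha_t) eps for unit-variance x_{t-1}."""
    g = torch.Generator().manual_seed(seed)
    x_prev = torch.randn(n, generator=g)  # unit-variance input samples
    eps = torch.randn(n, generator=g)     # independent N(0, 1) noise
    x_t = alpha_t ** 0.5 * x_prev + (1.0 - alpha_t) ** 0.5 * eps
    return x_t.var().item()

for a in [0.9999, 0.95, 0.5]:
    print(a, round(empirical_step_variance(a), 3))  # all close to 1.0
```

For data that does not have unit variance, like two moons, the variance after $t$ steps is $\bar{\alpha}_t\,\mathrm{Var}(x_0) + (1-\bar{\alpha}_t)$ per coordinate, which still tends to 1 as $\bar{\alpha}_t \to 0$.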

Now plot the different steps of the forward pass for t = [0, 12, 25, 50, 75, 100, 125, 150, 175, 200]

In [None]:
figure = plt.figure(figsize=(20, 10))
for i, t in enumerate([0, 12, 25, 50, 75, 100, 125, 150, 175, 200]):
...

You should observe that the data becomes more and more noisy as the timestep increases.  
The following code animates the forward pass and allows you to see the data evolve over time.


In [None]:
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
from IPython.display import HTML
from functools import partial

fig, ax = plt.subplots(figsize=(6, 6))

def animate(i:int, series:List[torch.Tensor], labels:torch.Tensor):
    ax.clear()
    data = series[i]
    ax.scatter(data[:, 0], data[:, 1], s=15, alpha=0.5, c=labels)  # color points by class
    ax.set_axis_off()

animate_forward = partial(animate, series=x_series, labels=y)

anim = FuncAnimation(fig, animate_forward, frames=len(x_series),
                    interval=250)  # 250 ms between frames

HTML(anim.to_jshtml())

For training, we would like diversity in the training batches, meaning different samples at different timesteps.  
We saw in class that it is possible to noise the data directly to a given timestep, without going through the forward pass for all the previous timesteps:  
$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \bar{\alpha}_t = \prod_{s\le t}\alpha_s, \quad \epsilon \sim \mathcal{N}(0, I)$$
The following function samples the data at a given timestep.

In [None]:
def sample_x_t(x_0:torch.Tensor, t, alpha_bar:torch.Tensor, eps:torch.Tensor) -> torch.Tensor:
    """
    Takes the initial data, a timestep (or a batch of timesteps), the cumulative alphas and a noise sample
    and returns a noisy version of the data at that timestep
    args:
        x_0: the initial data, shape (N, 2)
        t: the timestep, either an int or a tensor of shape (N,)
        alpha_bar: the cumulative product of the alphas
        eps: the noise sample, same shape as x_0
    returns:
        x_t: a noisy version of the data at timestep t
    """
    return alpha_bar[t, None].sqrt()*x_0 + (1.-alpha_bar[t, None]).sqrt()*eps


figure = plt.figure(figsize=(20, 4))
for i, t in enumerate([0, 6, 12, 25, 50]):
    figure.add_subplot(1, 5, i+1)
    plt.title(f"t = {t}")
    plt.axis("off")
    eps = torch.randn_like(x_0)
    x_t = sample_x_t(x_0, t, alpha_bar, eps)  # jump directly to timestep t
    plt.scatter(x_t[:,0], x_t[:,1], s=15, alpha=0.5, c=y)
plt.show()
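
To check that the closed form agrees with the step-by-step forward pass, the sketch below noises a constant 1D data point both ways and compares the empirical mean and standard deviation. It assumes the same schedule as above and the indexing used by `sample_x_t`, where `alpha_bar[t]` covers the single steps `0..t`; the helper name is hypothetical. Both means should be close to $2\sqrt{\bar{\alpha}_t}$ and both standard deviations close to $\sqrt{1-\bar{\alpha}_t}$.

```python
import torch

def iterated_vs_closed_form(t: int = 50, n: int = 100_000, seed: int = 0):
    """Compare (mean, std) of x_t obtained by iterating forward steps vs. the closed form."""
    g = torch.Generator().manual_seed(seed)
    alphas = 1.0 - torch.linspace(1e-4, 0.05, 200)
    alpha_bar = torch.cumprod(alphas, dim=0)
    x0 = torch.full((n,), 2.0)  # constant data point: only the noise varies

    # Step-by-step forward process through the single steps 0..t
    x_iter = x0.clone()
    for s in range(t + 1):
        eps = torch.randn(n, generator=g)
        x_iter = alphas[s].sqrt() * x_iter + (1.0 - alphas[s]).sqrt() * eps

    # Closed form, jumping directly to timestep t
    eps = torch.randn(n, generator=g)
    x_closed = alpha_bar[t].sqrt() * x0 + (1.0 - alpha_bar[t]).sqrt() * eps

    return (x_iter.mean().item(), x_iter.std().item()), (x_closed.mean().item(), x_closed.std().item())

(iter_mean, iter_std), (closed_mean, closed_std) = iterated_vs_closed_form()
print(f"iterated: mean={iter_mean:.3f} std={iter_std:.3f}")
print(f"closed:   mean={closed_mean:.3f} std={closed_std:.3f}")
```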


## Training
We will now train a denoising model to learn the reverse process.  
First, we need to create a dataset with our data.  
### Dataset
We now define a torch dataset to load our data.  We then split the data into a train and test set and create dataloaders for each.  

In [None]:
from torch.utils.data import DataLoader, Dataset, random_split

class TwoMoonsDataset(Dataset):
    def __init__(self, n_samples=1000, noise=0.1, random_state=42):
        x_0, y = make_moons(n_samples=n_samples, noise=noise, random_state=random_state)
        self.x_0 = torch.FloatTensor(x_0)
        self.y = torch.LongTensor(y)

    def __len__(self):
        return len(self.x_0)

    def __getitem__(self, idx):
        return self.x_0[idx].float(), self.y[idx]

dataset = TwoMoonsDataset()

# Train/Test Split
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


### Model:
Now that the dataset and dataloaders are created, we can define the model.  
We will use a simple MLP with four hidden layers of 64 neurons each and a final linear layer that outputs the predicted noise.  
We will use the GELU activation function between the layers.  
Remember that the output of the model is the predicted noise, which has the same dimension as the input data.  
Since we will train the model on data at different timesteps, we need to pass the timestep as an additional input to the model.  
We do this by concatenating the timestep to the input data, which is why the first layer takes 3 inputs.  
The following class defines the model.  

In [None]:
class Denoisier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 64),
            torch.nn.GELU(),
            torch.nn.Linear(64, 64),
            torch.nn.GELU(),
            torch.nn.Linear(64, 64),
            torch.nn.GELU(),
            torch.nn.Linear(64, 64),
            torch.nn.GELU(),
            torch.nn.Linear(64, 2),
        )
    def forward(self, x, t):
        # t holds integer timesteps: cast to float before concatenating with x
        x = torch.cat((x, t.float().reshape(-1, 1)), dim=1)
        return self.layers(x)
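
The shape bookkeeping behind `torch.nn.Linear(3, 64)` can be checked in isolation: each 2D point gets its timestep appended as a third feature. The helper below is a hypothetical stand-alone version of that concatenation; casting the integer timesteps to float first keeps the input tensor in a single floating-point dtype.

```python
import torch

def concat_timestep(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Append one timestep per sample as an extra input feature: (B, 2) + (B,) -> (B, 3)."""
    return torch.cat((x, t.float().reshape(-1, 1)), dim=1)

points = torch.randn(8, 2)               # batch of 8 two-dimensional points
timesteps = torch.randint(0, 200, (8,))  # one integer timestep per point
inp = concat_timestep(points, timesteps)
print(inp.shape, inp.dtype)  # torch.Size([8, 3]) torch.float32
```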

### Training loop:
At this point, we have all the components to train the model.  
The training loop of a denoising model is actually quite simple.  
Look at the training algorithm from the paper and implement the training loop.  
![training_loop](images/training.png)
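
In equations, one iteration of that algorithm draws a data point $x_0$, a timestep $t$ uniformly at random and a noise sample $\epsilon \sim \mathcal{N}(0, I)$, then takes a gradient step on

$$\mathcal{L}(\theta) = \left\lVert \epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\ t\right) \right\rVert^2$$

The noisy input inside $\epsilon_\theta$ is exactly what `sample_x_t` computes, and `mse_loss` averages this squared error over the batch and the two coordinates.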


In [None]:
from tqdm.notebook import tqdm

def train(model, train_dataloader, optimizer, alpha_bar, epochs=50, device='cpu'):
    progress_bar = tqdm(range(epochs), desc="Training")
    for epoch in range(epochs):
        total_loss = 0
        for x, y in train_dataloader:
            x = x.to(device)
            eps = torch.randn_like(x)
            t = torch.randint(T, (x.shape[0],), device=device)
            x_t = sample_x_t(x, t, alpha_bar.to(device), eps)
            eps_pred = model(x_t, t)
            loss = torch.nn.functional.mse_loss(eps_pred, eps)
            total_loss += loss.item()*x.shape[0]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        progress_bar.set_postfix({"epoch": epoch, "loss": total_loss/len(train_dataloader.dataset)})
        progress_bar.update()
    progress_bar.close()

### Training:
Now instantiate the model and optimizer (Adam with a learning rate of 1e-3) and train the model for 3000 epochs.  

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Denoisier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train(model, train_loader, optimizer, alpha_bar, epochs=3000, device=device)

### Sampling:
Here is the moment of truth!
We will now sample from the model and see whether we managed to learn the reverse process.  
Remember that the sampling algorithm is the following:  
![sampling_loop](images/sampling.png)  
Use this function to sample 1000 points and plot them.  
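
Written out, each iteration of the sampling algorithm computes, with $\sigma_t^2 = 1-\alpha_t$,

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I)$$

except at the very last step, where $z = 0$ so that no fresh noise is added to the final samples. In the sampling function, `mu_hat_t` is the first term and `sigma` is $\sigma_t$.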

In [None]:
@torch.no_grad()
def sample(num_samples:int, model:torch.nn.Module, alpha:torch.Tensor, alpha_bar:torch.Tensor, T:int=200, device:str='cpu') -> List[torch.Tensor]:
    """
    Takes the model, the alphas and the number of timesteps and returns a list of the sampled data for each timestep
    args:
        num_samples: the number of points to generate
        model: the denoising model
        alpha: the alphas of the noise schedule
        alpha_bar: the cumulative product of the alphas
        T: the number of timesteps
        device: the device to run on
    returns:
        steps: a list of the sampled data for each timestep
    """
    alpha = alpha.to(device)
    alpha_bar = alpha_bar.to(device)
    steps = []
    xt = torch.randn((num_samples, 2)).to(device)
    for t in reversed(range(T)):
        t_batch = torch.full((num_samples,), t).to(device)
        noise_pred = model(xt, t_batch)
        # mean of p(x_{t-1} | x_t) computed from the predicted noise
        mu_hat_t = (xt - (1-alpha[t,None])/(1-alpha_bar[t,None]).sqrt()*noise_pred)/(alpha[t,None]).sqrt()

        # no noise is added at the last step (t = 0)
        z = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)
        sigma = (1.-alpha[t]).sqrt()
        xt = mu_hat_t + sigma*z
        steps.append(xt.clone().detach().to('cpu'))
    return steps

steps = sample(1000, model, alphas, alpha_bar, T=T, device=device)
plt.figure(figsize=[6, 6])
plt.scatter(steps[-1][:,0],steps[-1][:,1],s=15,alpha=0.5)
plt.axis('off')

OK, nothing new here: this is exactly what we did in the previous practical session.  
Here we have no control over the class of the data we are sampling.  
To draw an analogy with images, we would be sampling from a dataset of images without any control over the class of the generated image.  
We would like to be able to sample from a specific class.  
This is where conditional diffusion comes into play.  
Define a conditional denoising model that takes as input the data, the class and the timestep, and outputs the predicted noise.  

In [None]:
class ConditionalDenoisier(torch.nn.Module):
    def __init__(self):
        ...
    def forward(self, x, y, t):
        ...
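
One possible design, sketched below under our own assumptions (the class name `EmbeddingConditionedMLP`, the embedding size and the depth are illustrative choices, not the expected answer), maps each class label to a learned vector with `torch.nn.Embedding` and concatenates it to the point and the timestep. A one-hot encoding of the label would work in the same way.

```python
import torch

class EmbeddingConditionedMLP(torch.nn.Module):
    """Hypothetical conditional denoiser: class label -> learned embedding, concatenated to (x, t)."""
    def __init__(self, num_classes: int = 2, emb_dim: int = 4, hidden: int = 64):
        super().__init__()
        self.class_emb = torch.nn.Embedding(num_classes, emb_dim)
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(2 + 1 + emb_dim, hidden),  # point (2) + timestep (1) + class embedding
            torch.nn.GELU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.GELU(),
            torch.nn.Linear(hidden, 2),  # predicted noise has the same dimension as x
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x: (B, 2) points, y: (B,) integer labels, t: (B,) integer timesteps
        h = torch.cat((x, t.float().reshape(-1, 1), self.class_emb(y)), dim=1)
        return self.layers(h)

net = EmbeddingConditionedMLP()
out = net(torch.randn(5, 2), torch.randint(0, 2, (5,)), torch.randint(0, 200, (5,)))
print(out.shape)  # torch.Size([5, 2])
```

The argument order `(x, y, t)` matches the `forward` signature of the `ConditionalDenoisier` skeleton.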

Now implement a conditional training loop.  

In [None]:
def conditional_train(model, train_dataloader, optimizer, alpha_bar, epochs=50, device='cpu'):
    progress_bar = tqdm(range(epochs), desc="Training")
    for epoch in range(epochs):
        ...

        progress_bar.set_postfix({"epoch": epoch, "loss": total_loss/len(train_dataloader.dataset)})
        progress_bar.update()
    progress_bar.close()

And train your conditional denoising model.  

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ConditionalDenoisier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
...

We also have to modify the sampling function to be able to sample from a specific class.  
Implement the following function and generate 1000 samples from class 0.  
To better visualize the results, you can plot the samples in green and the original data in blue.  

In [None]:
@torch.no_grad()
def conditional_sample(num_samples:int, model:torch.nn.Module, label:int, alpha:torch.Tensor, alpha_bar:torch.Tensor, T:int=200, device:str='cpu') -> List[torch.Tensor]:
    """
    Takes the model, a class label, the alphas and the number of timesteps and returns a list of the sampled data for each timestep
    args:
        num_samples: the number of points to generate
        model: the conditional denoising model
        label: the class to sample from
        alpha: the alphas of the noise schedule
        alpha_bar: the cumulative product of the alphas
        T: the number of timesteps
        device: the device to run on
    returns:
        steps: a list of the sampled data for each timestep
    """
    ...

steps = conditional_sample(1000, model, 0, alphas, alpha_bar, T=T, device=device)
...

Verify that your model can also sample from the other class.  

You should now have a model that can sample from a specific class.  
Although it is not noticeable with our small dataset, sampling with diffusion models is slow, because each step requires a forward pass through the network.  
A naive solution would be to skip some timesteps.  
Implement the following function to sample from a subset of timesteps (50 instead of 200, for instance) and check how it affects the quality of the samples.  
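
Two ingredients may help here. First, a decreasing, roughly evenly spaced subset of timesteps can be obtained with `torch.linspace`. Second, a jump from timestep $t$ down to an earlier timestep $s$ spans several single steps, whose combined signal retention is $\bar{\alpha}_t / \bar{\alpha}_s = \prod_{r=s+1}^{t}\alpha_r$; reusing the single-step $\alpha_t$ for such a jump does not account for this. Both helper names below are hypothetical.

```python
import torch

def strided_timesteps(T: int = 200, num_steps: int = 50) -> torch.Tensor:
    """Hypothetical helper: num_steps integer timesteps evenly spaced from T-1 down to 0."""
    return torch.linspace(T - 1, 0, num_steps).round().long()

def jump_alpha(alpha_bar: torch.Tensor, t: int, s: int) -> torch.Tensor:
    """Combined alpha of the single steps s+1..t, i.e. alpha_bar_t / alpha_bar_s (requires s < t)."""
    return alpha_bar[t] / alpha_bar[s]

ts = strided_timesteps()
print(ts[:5].tolist(), "...", ts[-3:].tolist())
```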

In [None]:
@torch.no_grad()
def skip_step_samples(num_samples: int, model: torch.nn.Module, label: int, alpha: torch.Tensor, alpha_bar: torch.Tensor, T: int=200, num_steps: int=50, device: str='cpu') -> List[torch.Tensor]:
    """
    Takes the model, a class label and the alphas and samples at a subset of timesteps
    args:
        num_samples: number of points to generate
        model: the conditional denoising model
        label: the class to sample from
        alpha: noise schedule alphas
        alpha_bar: cumulative product of alphas
        T: total timesteps in noise schedule
        num_steps: number of actual sampling steps to perform
        device: device to run on
    returns:
        steps: list of sampled data at each sampled timestep
    """
    ...

steps = skip_step_samples(1000, model, 0, alphas, alpha_bar, T=T, num_steps=50, device=device)

plt.figure(figsize=[6, 6])
plt.scatter(steps[-1][:,0], steps[-1][:,1], s=15, alpha=0.5, c='g')
plt.scatter(x_0[:,0], x_0[:,1], s=15, alpha=0.5, c=y)
plt.axis('off')

You should have observed a loss in quality when sampling from a subset of timesteps.  
We saw in class that we could improve the quality of the samples by using a different sampling scheme.  
Implement the following function to sample from a subset of timesteps using the DDIM sampling scheme from [this paper](https://arxiv.org/pdf/2010.02502) and check how it affects the quality of the samples.  
![](images/ddim.png)  
You can even try with 20 steps instead of 50 and look at the quality of the results.
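
The core of the algorithm in the figure is a single update from timestep $t$ to an earlier timestep $s$. The sketch below implements the deterministic variant ($\eta = 0$, i.e. no noise injected during sampling): it first estimates $x_0$ from the predicted noise, then re-noises that estimate to the level of timestep $s$:

$$\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}, \qquad x_s = \sqrt{\bar{\alpha}_s}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_s}\,\epsilon_\theta(x_t, t).$$

The function name `ddim_step` is ours; for the final jump, passing $\bar{\alpha}_s = 1$ returns $\hat{x}_0$ itself.

```python
import torch

def ddim_step(x_t: torch.Tensor, eps_pred: torch.Tensor,
              alpha_bar_t: torch.Tensor, alpha_bar_s: torch.Tensor) -> torch.Tensor:
    """Deterministic DDIM update (eta = 0) from timestep t to an earlier timestep s."""
    # Estimate of the clean data implied by the predicted noise
    x0_hat = (x_t - (1.0 - alpha_bar_t).sqrt() * eps_pred) / alpha_bar_t.sqrt()
    # Move the estimate to the noise level of timestep s, reusing the same noise direction
    return alpha_bar_s.sqrt() * x0_hat + (1.0 - alpha_bar_s).sqrt() * eps_pred

# Toy check with a perfect noise prediction: the update lands on the forward-process point at s
toy_x0 = torch.randn(4, 2)
toy_eps = torch.randn(4, 2)
ab_t, ab_s = torch.tensor(0.3), torch.tensor(0.8)
toy_x_t = ab_t.sqrt() * toy_x0 + (1 - ab_t).sqrt() * toy_eps
toy_x_s = ddim_step(toy_x_t, toy_eps, ab_t, ab_s)
print(torch.allclose(toy_x_s, ab_s.sqrt() * toy_x0 + (1 - ab_s).sqrt() * toy_eps, atol=1e-5))  # True
```

Chaining `ddim_step` over a decreasing subset of timesteps, with `eps_pred` given by the conditional model, gives a DDIM sampler. Since no noise is injected, the same starting noise always produces the same sample.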

In [None]:
@torch.no_grad()
def DDIM_sample_ddim(num_samples: int, model: torch.nn.Module, label: int, alpha: torch.Tensor, alpha_bar: torch.Tensor, T: int=200, num_steps: int=20, device: str='cpu') -> List[torch.Tensor]:
    """
    DDIM sampling at a subset of timesteps
    args:
        num_samples: number of points to generate
        model: the conditional denoising model
        label: the class to sample from
        alpha: noise schedule alphas
        alpha_bar: cumulative product of alphas
        T: total timesteps in noise schedule
        num_steps: number of actual sampling steps to perform
        device: device to run on
    returns:
        steps: list of sampled data at each sampled timestep
    """
    ...
    
    return steps

steps = DDIM_sample_ddim(1000, model, 0, alphas, alpha_bar, T=T, num_steps=50, device=device)

...

Pretty fast and still accurate, right?  
That's it for the 2D data; in the rest of the practical session we will work on images.  