In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from sklearn.decomposition import PCA

from plotly.subplots import make_subplots
import plotly.figure_factory as ff
import plotly.graph_objects as go

import os
from PIL import Image


# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression so the notebook displays which device was selected.
device

device(type='cuda')

# Diffusion as a Markov Chain

$$
q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1-\beta_t}\mathbf{x}_{t-1}, \beta_t\mathbf{I})
$$

The diffused version of an image at any timestep $t$ of the Markov chain can be computed directly from the original image $\mathbf{x}_0$ in one shot, without iteratively applying the forward diffusion process for the intermediate timesteps:

$$
q(\mathbf{x}_t | \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha_t}} \mathbf{x}_0, (1 - \bar{\alpha_t}) \mathbf{I})
$$
$\text{where } \alpha_t = 1 - \beta_t \text{ and } \bar{\alpha_t} = \prod_{i=1}^{t} \alpha_i$

Derivation:

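\begin{align*}
q(\mathbf{x}_t | \mathbf{x}_{t-1}) &= \mathcal{N}(\mathbf{x}_t; \sqrt{1-\beta_t}\mathbf{x}_{t-1}, \beta_t\mathbf{I}) \\
\mathbf{x}_t &= \sqrt{1 - \beta_t}\,\mathbf{x}_{t-1} + \sqrt{\beta_t}\,\boldsymbol{\epsilon}_{t-1}, \quad \boldsymbol{\epsilon}_{t-1} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\
&= \sqrt{\alpha_t}\,\mathbf{x}_{t-1} + \sqrt{1 - \alpha_t}\,\boldsymbol{\epsilon}_{t-1} \\
&= \sqrt{\alpha_t \alpha_{t-1}}\,\mathbf{x}_{t-2} + \sqrt{\alpha_t (1 - \alpha_{t-1})}\,\boldsymbol{\epsilon}_{t-2} + \sqrt{1 - \alpha_t}\,\boldsymbol{\epsilon}_{t-1} \\
&= \sqrt{\alpha_t \alpha_{t-1}}\,\mathbf{x}_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}}\,\bar{\boldsymbol{\epsilon}}_{t-2} \\
&\;\;\vdots \\
&= \sqrt{\bar{\alpha_t}}\,\mathbf{x}_0 + \sqrt{1 - \bar{\alpha_t}}\,\boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})
\end{align*}

The two noise terms are merged using the fact that the sum of two independent zero-mean Gaussians is a Gaussian whose variance is the sum of the variances: $\alpha_t (1 - \alpha_{t-1}) + (1 - \alpha_t) = 1 - \alpha_t \alpha_{t-1}$. Repeating this substitution down to $\mathbf{x}_0$ gives the closed form above.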


# Create the Diffusion class

In [2]:
class ForwardDiffusion(nn.Module):
    """Closed-form forward (noising) process of a diffusion model.

    A linear beta schedule of length ``cap_T`` is built once and the cumulative
    product of ``alpha_t = 1 - beta_t`` is stored, so that a noised sample can be
    produced directly from the clean image:

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

    Args:
        cap_T: Number of diffusion steps (length of the schedule).
        beta_1: First value of the linear beta schedule.
        beta_T: Last value of the linear beta schedule.
        *args, **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, cap_T, beta_1, beta_T, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        beta_tensor = torch.linspace(beta_1, beta_T, cap_T)
        alpha_tensor = 1 - beta_tensor

        # Registered as a buffer (not a plain attribute) so that `.to(device)` /
        # `.cuda()` actually move the schedule together with the module.
        # persistent=False keeps the module's state_dict empty, as it was before.
        self.register_buffer(
            'cum_prod_alpha_tensor',
            torch.cumprod(alpha_tensor, dim=0),
            persistent=False,
        )

    def apply_noise_kernel(self, timestep, image, noise):
        """Sample ``x_t`` from ``q(x_t | x_0)`` in a single step.

        Args:
            timestep: Zero-based index into the schedule. Either a Python int
                (the same step for the whole batch) or an integer tensor of
                shape ``(batch,)`` holding one step per sample.
            image: Clean images ``x_0``; the first dimension is the batch.
            noise: Gaussian noise with the same shape as ``image``.

        Returns:
            The noised images, with the same shape as ``image``.
        """
        schedule = self.cum_prod_alpha_tensor
        if torch.is_tensor(timestep):
            # Index tensors must live on the same device as the indexed tensor.
            timestep = timestep.to(device=schedule.device, dtype=torch.long)

        alpha_bar_t = schedule[timestep]
        if alpha_bar_t.dim() > 0:
            # One factor per sample: reshape to (batch, 1, 1, ...) so it broadcasts
            # over the channel/spatial dimensions instead of the trailing axis.
            trailing_ones = [1] * (image.dim() - 1)
            alpha_bar_t = alpha_bar_t.to(image.device).reshape(-1, *trailing_ones)

        destroy_factor = torch.sqrt(alpha_bar_t)
        noise_factor = torch.sqrt(1 - alpha_bar_t)

        diffused_img = (destroy_factor * image) + (noise_factor * noise)

        return diffused_img

# Check that information flows correctly through the computational graph

In [3]:
# Smoke test on random data shaped like a batch of 60000 single-channel 28x28 images.
img_t = torch.randn(60000, 1, 28, 28).to(device)
noise = torch.randn_like(img_t).to(device)

# Build a 1000-step schedule and apply the noise kernel at schedule index 1.
diffusion = ForwardDiffusion(1000, 1e-4, 2e-2).to(device)
noisy_image_t = diffusion.apply_noise_kernel(1, img_t, noise)

# Display the output shape; it should match the shape of the input batch.
noisy_image_t.shape

torch.Size([60000, 1, 28, 28])

# Load the complete MNIST data

In [4]:
# Folder holding the already-downloaded MNIST files (download=False below).
# The MNIST_ROOT environment variable overrides the default so the notebook is
# not tied to a single machine's directory layout.
root_path = os.environ.get('MNIST_ROOT', '/mnt/c/Users/121js/OneDrive/Desktop/TorchImages/mnist/')

transform = transforms.Compose([transforms.ToTensor()])
mnist = datasets.MNIST(root_path, download=False, transform=transform)

# A single batch spanning the whole dataset, so one `next(...)` call returns every image.
loader = DataLoader(mnist, batch_size=len(mnist), shuffle=False)
images_tensor, labels = next(iter(loader))

images_tensor.shape

torch.Size([60000, 1, 28, 28])

# Use PCA to reduce MNIST data to 1D from 784

In [5]:
def pca_transform(images_tensor):
    """Project a batch of images onto its first principal component.

    A fresh PCA is fitted on the given batch on every call.

    Args:
        images_tensor: torch tensor whose first dimension indexes the samples;
            all remaining dimensions are flattened into one feature vector.

    Returns:
        1-D numpy array with one value per sample.
    """
    # Take the sample count from the input instead of assuming 60000 images.
    num_samples = images_tensor.shape[0]

    # detach()/cpu() make the conversion safe for tensors that track gradients
    # or live on the GPU; both are no-ops for a plain CPU tensor.
    images_np = images_tensor.detach().cpu().numpy().reshape(num_samples, -1)

    pca = PCA(n_components=1)
    images_1d = pca.fit_transform(images_np).reshape(-1,)

    return images_1d

def scale_data(data, lower_a, upper_b):
    """Min-max scale ``data`` linearly into the interval ``[lower_a, upper_b]``.

    Args:
        data: Array-like exposing ``.min()``, ``.max()`` and element-wise arithmetic.
        lower_a: Value the minimum of ``data`` is mapped to.
        upper_b: Value the maximum of ``data`` is mapped to.

    Returns:
        Scaled data with the same shape as ``data``. If every element of ``data``
        is identical, every output element equals ``lower_a``.
    """
    min_val = data.min()
    max_val = data.max()
    value_range = max_val - min_val

    if value_range == 0:
        # Constant input: the usual formula would divide by zero and produce NaN.
        # Map everything to the lower bound instead.
        return (data - min_val) + lower_a

    scaled_data = lower_a + ((data - min_val) * (upper_b - lower_a)) / value_range

    return scaled_data

# Function to visualize at every timestep

In [6]:
def make_plot(timestep, dist_1d, which_image):
    """Build a two-panel snapshot of the forward diffusion at one timestep.

    Left panel: KDE curve of the 1-D data distribution. Right panel: one sampled
    image drawn as a grayscale heatmap.

    Args:
        timestep: Timestep shown in the figure title.
        dist_1d: 1-D array of values whose density is plotted.
        which_image: 2-D (height x width) array of pixel intensities.

    Returns:
        The assembled plotly figure.
    """
    fig = make_subplots(
        rows=1, cols=2, column_widths=[0.5, 0.3], horizontal_spacing=0.2,
        subplot_titles=("MNIST Distribution (converted to 1D using PCA)", "Sampled Image")
    )
    fig.update_layout(title=f'Snapshot from Forward Diffusion at timestep = {timestep}', width=1100, showlegend=False, template='plotly_dark')

    kde_fig = ff.create_distplot([dist_1d], ['mnist'], show_hist=False, show_rug=False)     # plot the distribution in left panel
    for trace in kde_fig['data']:
        fig.add_trace(trace, row=1, col=1)

    fig.update_xaxes(range=[-1, 1], row=1, col=1)    # Keep the boundaries same to have smooth visualizations
    fig.update_yaxes(range=[0, 2], row=1, col=1)

    fig.add_trace(go.Heatmap(z=which_image, colorscale='gray', showscale=False), row=1, col=2)     # add the image to the right
    fig.update_xaxes(zeroline=False, showticklabels=False, row=1, col=2)
    # A heatmap's y-axis grows upwards, so row 0 of `z` would be drawn at the bottom
    # and the image would appear vertically flipped. Reversing the axis puts row 0
    # at the top, matching the usual image orientation.
    fig.update_yaxes(zeroline=False, showticklabels=False, autorange='reversed', row=1, col=2)

    return fig

# Experiment

In [7]:
cap_T = 200
beta_1 = 1e-4
beta_T = 2e-2
diffusion = ForwardDiffusion(cap_T, beta_1, beta_T).to(device)

images_2_gpu = images_tensor.to(device)

image_index = 7

# Make sure the output folder exists before the first frame is written,
# so the loop does not fail on a fresh checkout.
output_dir = 'forward_exp_images'
os.makedirs(output_dir, exist_ok=True)

for timestep in range(0, cap_T+1):
    if timestep == 0:
        # No noise added at timestep=0, just inspect starting conditions
        current_images = images_2_gpu
    else:
        # randn_like already allocates on the same device as images_2_gpu
        transition_kernel = torch.randn_like(images_2_gpu)

        noisy_images = diffusion.apply_noise_kernel(timestep-1, images_2_gpu, transition_kernel)       # timestep=1 corresponds to 0th element in the alpha tensor, so reduce 1
        current_images = noisy_images

    # Same pipeline for every frame: project to 1-D, rescale to [-1, 1], plot, save.
    dist_1d = pca_transform(current_images.cpu())
    scaled_dist_1d = scale_data(dist_1d, -1, 1)

    sampled_image = current_images[image_index].squeeze(0)           # This goes as input to the plotly Heatmap, which expects 2 dim, so thats why squeeze the channel(=1) dim

    fig = make_plot(timestep, scaled_dist_1d, sampled_image.cpu())

    fig.write_image(os.path.join(output_dir, f'{timestep}.png'))

# Compile into GIF

In [8]:
def make_gif(image_folder='forward_exp_images', output_path='forward_exp.gif',
             first_timestep=1, last_timestep=None, duration=200):
    """Stitch the per-timestep PNG snapshots into an endlessly looping GIF.

    Frames are read from ``{image_folder}/{timestep}.png`` for every timestep in
    ``first_timestep..last_timestep`` (inclusive).

    Args:
        image_folder: Folder containing the PNG frames.
        output_path: Path of the GIF to write.
        first_timestep: First frame to include. The default of 1 leaves out
            ``0.png``; pass 0 to include it.
        last_timestep: Last frame to include. Defaults to the notebook-level
            ``cap_T``.
        duration: Display time of each frame in milliseconds.
    """
    if last_timestep is None:
        last_timestep = cap_T

    images = []

    for timestep in range(first_timestep, last_timestep+1):
        filename = f'{timestep}.png'
        image_path = os.path.join(image_folder, filename)
        # Image.open is lazy and keeps the file handle open. copy() loads the
        # pixels into memory so the handle can be closed straight away.
        with Image.open(image_path) as img:
            images.append(img.copy())

    images[0].save(output_path, save_all=True, append_images=images[1:], duration=duration, loop=0)

make_gif()