# Development
Scratch notebook for experimenting with the DDPM noising process.

## Loading the dataset
This setup is copied from the reference example notebook.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

import numpy as np

from tqdm import tqdm
from torchvision.utils import save_image, make_grid
from tqdm import tqdm
from torch.optim import Adam

import math

In [None]:
# ---- Model / experiment hyperparameters ----

# Where torchvision downloads / caches the datasets.
dataset_path = './datasets'

# NOTE: `cuda` does not currently influence DEVICE; everything is kept on the
# CPU on purpose. To re-enable: torch.device("cuda:0" if cuda else "cpu")
cuda = True
DEVICE = "cpu"

# Dataset choice and its image size as (width, height, channels).
dataset = 'CIFAR10'
if dataset == "CIFAR10":
    img_size = (32, 32, 3)
else:
    img_size = (28, 28, 1)

# Network size.
timestep_embedding_dim = 256
n_layers = 8
hidden_dim = 256
hidden_dims = [hidden_dim] * n_layers  # one hidden width per layer

# Diffusion schedule: number of steps and (start, end) of the linear betas.
n_timesteps = 1000
beta_minmax = [1e-4, 2e-2]

# Optimisation.
train_batch_size = 1
inference_batch_size = 1
lr = 5e-5
epochs = 200

# Reproducibility: seed both the torch and numpy RNGs.
seed = 1234
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
from torchvision.datasets import MNIST, CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


# ToTensor converts PIL images to float C x H x W tensors in [0, 1]; later
# cells rescale them to [-1, 1] before noising.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Pinned (page-locked) host memory only speeds up host -> GPU copies, so only
# request it when we are not running on the CPU.
kwargs = {'num_workers': 1, 'pin_memory': str(DEVICE) != 'cpu'}

# Same constructor signature for both datasets; pick the class once.
dataset_cls = CIFAR10 if dataset == 'CIFAR10' else MNIST
train_dataset = dataset_cls(dataset_path, transform=transform, train=True, download=True)
test_dataset = dataset_cls(dataset_path, transform=transform, train=False, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(dataset=test_dataset, batch_size=inference_batch_size, shuffle=False, **kwargs)

## Reference implementation
I first copy the reference code to see its expected behaviour. Afterwards I reimplement it and check whether my version gives exactly the same results.

In [None]:
class TheirCode:
    """Reference DDPM implementation, copied to compare against my own class.

    Builds a linear beta schedule and its derived alpha quantities, and
    implements the forward (noising) and reverse (sampling) processes.
    Time steps are 0-indexed (t = 0 .. n_times - 1) and index directly into
    the schedule tensors built in __init__.
    """

    def __init__(self, model, image_resolution=(32, 32, 3), n_times=1000, beta_minmax=(1e-4, 2e-2), device='cuda'):
        """
        Args:
            model: noise-prediction network, called as ``model(x_t, t)`` by
                forward() and denoise_at_t(). May be None when only the
                schedule / noising helpers are used.
            image_resolution: (H, W, C) of the images; used by sample().
            n_times: number of diffusion steps.
            beta_minmax: (beta_1, beta_T) endpoints of the linear beta schedule.
            device: device the schedule tensors (and sampled noise) live on.
        """
        self.n_times = n_times
        self.img_H, self.img_W, self.img_C = image_resolution

        # forward() and denoise_at_t() call self.model; without this
        # assignment they raised AttributeError.
        self.model = model

        # define linear variance schedule (betas)
        beta_1, beta_T = beta_minmax
        betas = torch.linspace(start=beta_1, end=beta_T, steps=n_times).to(device)  # follows DDPM paper
        self.sqrt_betas = torch.sqrt(betas)

        # define alpha for forward diffusion kernel
        self.alphas = 1 - betas
        self.sqrt_alphas = torch.sqrt(self.alphas)
        alpha_bars = torch.cumprod(self.alphas, dim=0)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1 - alpha_bars)
        self.sqrt_alpha_bars = torch.sqrt(alpha_bars)

        self.device = device

    def extract(self, a, t, x_shape):
        """Gather a[t] per batch element, reshaped to (B, 1, ..., 1).

        The reshape lets the gathered coefficients broadcast against a tensor
        of shape ``x_shape`` whose first axis is the batch.

        From lucidrains' implementation:
            https://github.com/lucidrains/denoising-diffusion-pytorch/blob/beb2f2d8dd9b4f2bd5be4719f37082fe061ee450/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L376
        """
        b, *_ = t.shape
        out = a.gather(-1, t)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))

    def scale_to_minus_one_to_one(self, x):
        """Map values from [0, 1] to [-1, 1]."""
        # according to the DDPMs paper, normalization seems to be crucial to train reverse process network
        return x * 2 - 1

    def reverse_scale_to_zero_to_one(self, x):
        """Map values from [-1, 1] back to [0, 1]."""
        return (x + 1) * 0.5

    def make_noisy(self, x_zeros, t):
        """Sample x_t from x_0 in closed form.

        Returns:
            (noisy_sample, epsilon): the detached perturbed images and the
            standard-normal noise that was used.
        """
        epsilon = torch.randn_like(x_zeros).to(self.device)

        sqrt_alpha_bar = self.extract(self.sqrt_alpha_bars, t, x_zeros.shape)
        sqrt_one_minus_alpha_bar = self.extract(self.sqrt_one_minus_alpha_bars, t, x_zeros.shape)

        # Forward process with fixed variance schedule:
        #   sqrt(alpha_bar_t) * x_zero + sqrt(1 - alpha_bar_t) * epsilon
        noisy_sample = x_zeros * sqrt_alpha_bar + epsilon * sqrt_one_minus_alpha_bar

        return noisy_sample.detach(), epsilon

    def forward(self, x_zeros):
        """One training forward pass on images in [0, 1].

        Returns:
            (perturbed_images, epsilon, pred_epsilon)
        """
        x_zeros = self.scale_to_minus_one_to_one(x_zeros)

        B, _, _, _ = x_zeros.shape

        # (1) randomly choose a diffusion time-step per image
        t = torch.randint(low=0, high=self.n_times, size=(B,)).long().to(self.device)

        # (2) forward diffusion: perturb x_zeros with the fixed variance schedule
        perturbed_images, epsilon = self.make_noisy(x_zeros, t)

        # (3) predict epsilon (noise) given perturbed data at diffusion-timestep t
        pred_epsilon = self.model(perturbed_images, t)

        return perturbed_images, epsilon, pred_epsilon

    def denoise_at_t(self, x_t, timestep, t):
        """One reverse step x_t -> x_{t-1} using the model's noise prediction.

        Args:
            x_t: current images, shape (B, C, H, W).
            timestep: long tensor of shape (B,) holding t for every image.
            t: the same step as a Python int (controls the fresh noise).
        """
        B, _, _, _ = x_t.shape
        # DDPM (Algorithm 2) adds no fresh noise on the final step. Steps are
        # 0-indexed here, so that is t == 0. Upstream used `t > 1`, which
        # also dropped the noise at t == 1.
        if t > 0:
            z = torch.randn_like(x_t).to(self.device)
        else:
            z = torch.zeros_like(x_t).to(self.device)

        # at inference, we use predicted noise (epsilon) to restore perturbed data sample
        epsilon_pred = self.model(x_t, timestep)

        alpha = self.extract(self.alphas, timestep, x_t.shape)
        sqrt_alpha = self.extract(self.sqrt_alphas, timestep, x_t.shape)
        sqrt_one_minus_alpha_bar = self.extract(self.sqrt_one_minus_alpha_bars, timestep, x_t.shape)
        sqrt_beta = self.extract(self.sqrt_betas, timestep, x_t.shape)

        # denoise at time t, utilizing predicted noise
        x_t_minus_1 = 1 / sqrt_alpha * (x_t - (1 - alpha) / sqrt_one_minus_alpha_bar * epsilon_pred) + sqrt_beta * z

        return x_t_minus_1.clamp(-1., 1)

    def sample(self, N):
        """Generate N images in [0, 1] by running the full reverse chain."""
        # start from random noise (x_T, stored in x_t for simplicity)
        x_t = torch.randn((N, self.img_C, self.img_H, self.img_W)).to(self.device)

        # autoregressively denoise from x_T to x_0
        for t in range(self.n_times - 1, -1, -1):
            timestep = torch.tensor([t]).repeat_interleave(N, dim=0).long().to(self.device)
            x_t = self.denoise_at_t(x_t, timestep, t)

        # denormalize x_0 into 0 ~ 1 ranged values
        x_0 = self.reverse_scale_to_zero_to_one(x_t)

        return x_0
    
# Schedule-only instance of the reference code: no model is attached, so only
# the noising / scaling helpers are usable.
theircode = TheirCode(
    model=None,
    image_resolution=img_size,
    n_times=n_timesteps,
    beta_minmax=beta_minmax,
    device=DEVICE,
)

In [None]:
# Visualise the reference forward process on the first training image at a
# handful of time steps.
seed = 1996
torch.manual_seed(seed)
np.random.seed(seed)

batch = next(iter(train_loader))
batch[0] = theircode.scale_to_minus_one_to_one(batch[0])

for idx in [100 * x + 1 for x in range(10)] + [999]:
    img = theircode.make_noisy(batch[0], torch.tensor([idx]))[0]
    # Select the first image rather than squeeze(). squeeze() would also drop
    # the channel axis of 1-channel (MNIST) images, and would keep the batch
    # axis for batch sizes > 1; either case breaks permute(1, 2, 0).
    img = theircode.reverse_scale_to_zero_to_one(img)[0]
    plt.imshow(img.detach().permute(1, 2, 0))
    plt.show()

# My implementation
Here I create my own noising class, `DiffusionNoiser`.

In [None]:
class DiffusionNoiser(nn.Module):
    """A class to define the noising process. Will have methods to add noise
    in the closed and open form to images, but also inverse operations, such as
    removing (predicted) noise from an image, or sampling a new image.

    My convention for indexing is as follows:
    - t=0 is the original image.
    - t=steps(1000) should be pure gaussian noise
    - take image of step t and alphas/betas of step t to go to t+1
    - this means alphas and betas go from t=0 to t=steps(1000)-1

    Since alpha_bars[t] = prod_{s<=t} (1 - betas[s]), the closed-form methods
    called with time step t relate the original image to the image at step t+1.

    Time steps may be a Python int / 0-d tensor (same step for every image) or
    a tensor of shape [B] (one step per image of a [B, C, H, W] batch).

    Turned it into a nn.Module such that I don't need to worry about putting
    tensors on the correct device, etc.
    """

    def __init__(   self,
                    steps: int = 1000,
                    beta_start: float = 1e-4,
                    beta_end: float = 0.02
                 ):
        """
        Args:
            steps: number of noising steps to use.
            beta_start: variance of the first noising step (betas[0]).
            beta_end: variance of the last noising step (betas[steps - 1]);
                the betas are spaced linearly between the two.
        """
        super().__init__()

        betas = torch.linspace(beta_start, beta_end, steps)
        alpha_bars = torch.cumprod(1 - betas, 0)

        self.register_buffer("betas", betas)
        self.register_buffer("alpha_bars", alpha_bars)

    def _coef(self, values, t, img):
        """Index a per-step schedule at t so it broadcasts against img.

        A scalar t yields a 0-d tensor. A t of shape [B] yields shape
        [B, 1, ..., 1], so each image gets its own coefficient. Without the
        reshape, the [B] vector would broadcast along img's last (width) axis.
        """
        coef = values[t]
        if coef.dim() == 0:
            return coef
        return coef.reshape(-1, *([1] * (img.dim() - 1)))

    def forward(self, img, noise, t):
        """Calls DiffusionNoiser.closed_form_noise()."""
        return self.closed_form_noise(img, noise, t)

    def closed_form_noise(self, img, noise, t):
        """Adds noise to an image using the DDPM closed form formula.

        Args:
            img: The image(s) to add noise to [(B x) C x H x W].
            noise: the gaussian noise to add N(0,1) [(B x) C x H x W].
            t: integer time step (single integer or shape [B,]).
        """
        alpha_bar = self._coef(self.alpha_bars, t, img)
        return torch.sqrt(alpha_bar) * img + torch.sqrt(1 - alpha_bar) * noise

    def noise_from_closed_form_noise(self, img, noised_img, t):
        """Inverse of closed_form_noise: returns the noise, given the original
        image and a noised version.

        Args:
            img: the original image.
            noised_img: a noised version of the image.
            t: the timestep used to get from img to noised_img.
        """
        alpha_bar = self._coef(self.alpha_bars, t, noised_img)
        return (noised_img - torch.sqrt(alpha_bar) * img) / torch.sqrt(1 - alpha_bar)

    def img_from_closed_form_noise(self, noised_img, noise, t):
        """Inverse of closed_form_noise: returns the original image, given the
        noise and a noised version of the image.

        Args:
            noised_img: a noised version of the image.
            noise: the noise that was added to get noised_img.
            t: the timestep used to get from image to noised_img using noise.
        """
        alpha_bar = self._coef(self.alpha_bars, t, noised_img)
        return (noised_img - torch.sqrt(1 - alpha_bar) * noise) / torch.sqrt(alpha_bar)

    def forward_noise_step(self, img_prev, noise, t):
        """The forward noising process in a step-wise manner: computes a
        slightly noisier image from img_prev.

        Args:
            img_prev: image corresponding to step t.
            noise: the noise to add to get to step t+1.
            t: current time step."""
        beta = self._coef(self.betas, t, img_prev)
        return torch.sqrt(1 - beta) * img_prev + torch.sqrt(beta) * noise

    def denoising_step(self, img_next, noise_next, t, new_noise):
        """The inverse process of the forward_noise_step: predicts image at
        step t from image at step t+1.

        Args:
            img_next: image from step t+1.
            noise_next: the (predicted) noise in img_next.
            t: current time step.
            new_noise: new pure gaussian noise N(0, 1) to add (pass zeros for
                the final step, t=0, to follow DDPM sampling).
        """
        beta = self._coef(self.betas, t, img_next)
        alpha_bar = self._coef(self.alpha_bars, t, img_next)

        mean_t = (
            img_next - noise_next * beta / torch.sqrt(1 - alpha_bar)
        ) / torch.sqrt(1 - beta)
        # DDPM uses sigma_t^2 = beta_t, so the fresh noise is scaled by the
        # standard deviation sqrt(beta), not by the variance beta.
        img_t = mean_t + torch.sqrt(beta) * new_noise

        return img_t

noiser = DiffusionNoiser()

In [None]:
# Sanity check of DiffusionNoiser.denoising_step. Start from pure noise and run
# the full reverse chain, using the *true* noise (relative to the clean image)
# in place of a model prediction. With no fresh noise on the final step, the
# last mean equals the clean image, so the plot should show the original.
seed = 1996
torch.manual_seed(seed)
np.random.seed(seed)

batch = next(iter(train_loader))
batch[0] = theircode.scale_to_minus_one_to_one(batch[0])
x_zero = batch[0]

noisy_img = torch.randn_like(x_zero)
for idx in reversed(range(len(noiser.betas))):
    t = torch.tensor([idx])

    # Pretending actual noise is predicted noise from model:
    predicted_noise = noiser.noise_from_closed_form_noise(x_zero, noisy_img, t)

    # DDPM sampling adds no fresh noise on the final step (idx == 0).
    new_noise = torch.randn_like(x_zero) if idx > 0 else torch.zeros_like(x_zero)
    noisy_img = noiser.denoising_step(noisy_img, predicted_noise, t, new_noise)

# Select the first image instead of squeeze() so 1-channel images keep their
# channel axis for permute(1, 2, 0).
noisy_img = theircode.reverse_scale_to_zero_to_one(noisy_img)[0]
plt.imshow(noisy_img.detach().permute(1, 2, 0))
plt.show()

In [None]:
# Both schedules should agree bit-for-bit: the reference alphas are 1 - betas.
(theircode.alphas == 1 - noiser.betas).all()

In [None]:
# Round-trip check: noise images with the reference code, then recover the
# clean image with DiffusionNoiser.img_from_closed_form_noise and report how
# close the reconstruction is.
seed = 1996
torch.manual_seed(seed)
np.random.seed(seed)

batch = next(iter(train_loader))
batch[0] = theircode.scale_to_minus_one_to_one(batch[0])

for idx in [100 * x + 1 for x in range(10)] + [999]:
    img = batch[0]
    t = torch.tensor([idx])
    noised_img, noise = theircode.make_noisy(img, t)

    # noised_img2 = noiser(img, noise, t)
    # print(torch.all(noised_img == noised_img2))

    # noise2 = noiser.noise_from_closed_form_noise(img, noised_img, t)
    # print(torch.all(noise == noise2))
    # print(torch.max(torch.abs(noise - noise2)))
    # print(torch.all(torch.isclose(noise, noise2)))

    img2 = noiser.img_from_closed_form_noise(noised_img, noise, t)
    print(torch.all(torch.isclose(img, img2)))
    # Largest *absolute* reconstruction error; the signed max would hide
    # errors where img2 overshoots img.
    print(torch.max(torch.abs(img - img2)))

    # img2 = theircode.reverse_scale_to_zero_to_one(img2).squeeze()
    # plt.imshow(img2.detach().permute(1, 2, 0))
    # plt.show()


    # img = theircode.make_noisy(batch[0], torch.tensor([idx]))[0]
    # img = theircode.reverse_scale_to_zero_to_one(img).squeeze()
    # plt.imshow(img.detach().permute(1, 2, 0))
    # plt.show()