In [None]:
# Authenticate this notebook session with the Hugging Face Hub (interactive token prompt).
# NOTE(review): nothing visible in this notebook downloads from or pushes to the Hub;
# the only Hub-related line is `from huggingface_hub import Unet`, which looks wrong —
# confirm whether this login (and that import) are still needed.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
import torch
import torch.nn as nn
from torch.nn import Linear, ReLU, Softplus
import numpy as np
import matplotlib.pyplot as plt
from torch.optim import Adam
import tqdm.notebook as tqdm
from huggingface_hub import Unet

In [None]:


from tqdm.notebook import tqdm
from time import sleep

import torch
import torchvision
import torchvision.transforms as transforms

# ToTensor yields pixels in [0, 1]; the Lambda rescales them to [-1, 1]
# (the same mapping as Normalize(mean=0.5, std=0.5)) so the data is centred
# like the standard-normal noise the sampler starts from.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: (x*2)-1),
])

BATCH_SIZE = 256

trainset = torchvision.datasets.MNIST(
    root='./MNIST',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

testset = torchvision.datasets.MNIST(
    root='./MNIST',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
class MLP(nn.Module):
    """Fully connected network: two hidden layers of width 512 with ReLU.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden = 512
        layers = [
            Linear(input_dim, hidden),
            ReLU(),
            Linear(hidden, hidden),
            ReLU(),
            Linear(hidden, output_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        out = self.model(x)
        return out

def linear_beta_schedule(timesteps):
    """Return ``timesteps`` values evenly spaced from 1e-4 to 0.02, both ends included."""
    start, end = 0.0001, 0.02
    return torch.linspace(start, end, steps=timesteps)

class Diffusion(nn.Module):
    """DDPM-style Gaussian diffusion over flat vectors with an MLP noise predictor.

    The network receives ``[x_t, t]`` concatenated along dim 1 and predicts the
    noise ``eps`` that was used to produce ``x_t`` from ``x_0``.

    Args:
        data_dim: length of each flattened sample (e.g. 784 for 28x28 images).
        T: number of diffusion steps.
        device: device on which the schedule tensors are stored.
    """

    def __init__(self, data_dim, T=300, device='cuda'):
        super().__init__()
        self.device = device
        self.model = MLP(data_dim + 1, data_dim)
        self.data_dim = data_dim
        self.T = T
        self.beta = linear_beta_schedule(T).to(self.device)
        # constants for sampling
        self._init_scalars()

    def _init_scalars(self):
        """Precompute schedule-derived constants from ``self.beta``.

        alpha_t = 1 - beta_t, abar_t = prod_{s<=t} alpha_s, and the variance of the
        true posterior q(x_{t-1} | x_t, x_0):
            beta_tilde_t = beta_t * (1 - abar_{t-1}) / (1 - abar_t),
        with abar_{-1} := 1, so beta_tilde_0 = 0.
        """
        self.alpha = 1 - self.beta
        self.cumprod_alpha = torch.cumprod(self.alpha, dim=0).to(self.device)
        alphas_cumprod_prev = torch.nn.functional.pad(self.cumprod_alpha[:-1], (1, 0), value=1)
        # The numerator uses (1 - abar_{t-1}); using abar_{t-1} gives variance 1 at t=0.
        self.posterior_variance = self.beta * (1 - alphas_cumprod_prev) / (1 - self.cumprod_alpha)

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timestep ``t``.

        Args:
            x: (batch, data_dim) noisy samples.
            t: (batch, 1) integer timesteps; cast to ``x.dtype`` so the
               concatenation does not mix integer and float tensors.
        """
        inp = torch.cat([x, t.to(x.dtype)], dim=1)
        return self.model(inp)

    def sample_q_t(self, x_0, t, noise=None):
        """Forward noising: draw x_t ~ N(sqrt(abar_t) * x_0, (1 - abar_t) * I).

        Args:
            x_0: clean samples.
            t: an int, or a (batch, 1) integer tensor of timesteps.
            noise: optional pre-drawn eps shaped like ``x_0``. It is drawn with
                ``torch.randn_like`` when omitted; pass it in when the caller
                needs eps as a regression target.
        Returns:
            the noised sample x_t.
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        abar_t = self.cumprod_alpha[t]
        return torch.sqrt(abar_t) * x_0 + torch.sqrt(1 - abar_t) * noise

    @torch.no_grad()
    def sample_p_t(self, x_t_prev, t):
        """One reverse (denoising) step: draw x_{t-1} given the current sample x_t.

        Mean (DDPM):
            (x_t - beta_t / sqrt(1 - abar_t) * eps_theta(x_t, t)) / sqrt(alpha_t)
        Std: sqrt(beta_tilde_t). No noise is added where t == 0.

        Args:
            x_t_prev: the current sample x_t.
            t: (batch, 1) integer timesteps.
        Returns:
            x_{t-1}.
        """
        alpha_t = self.alpha[t]
        beta_t = self.beta[t]
        one_minus_abar_t = 1 - self.cumprod_alpha[t]
        eps = self.forward(x_t_prev, t)
        # A single reverse step divides by sqrt(alpha_t), not sqrt(abar_t).
        mean = (x_t_prev - beta_t / torch.sqrt(one_minus_abar_t) * eps) / torch.sqrt(alpha_t)
        std = torch.sqrt(self.posterior_variance[t])
        # Elementwise mask (instead of `if t > 0`) so batched timesteps also work.
        nonzero = (t > 0).to(x_t_prev.dtype)
        noise = torch.randn_like(x_t_prev) * nonzero
        return mean + std * noise

    @torch.no_grad()
    def sample(self, n_samples=1):
        """Generate samples by running the reverse chain from standard-normal noise.

        Args:
            n_samples: number of independent samples (default 1).
        Returns:
            (n_samples, data_dim) tensor of generated samples.
        """
        x_t = torch.randn(size=(n_samples, self.data_dim), device=self.device)
        for step in reversed(range(self.T)):
            t = torch.full((n_samples, 1), step, dtype=torch.long, device=self.device)
            x_t = self.sample_p_t(x_t, t)
        return x_t

    
class DiffusionTrainer:
    """Trains a diffusion model with the simple DDPM noise-prediction (MSE) loss.

    Args:
        diffusion: module exposing ``T`` and ``cumprod_alpha`` and callable as
            ``diffusion(x_t, t)`` to return predicted noise.
        optimizer: optimizer over ``diffusion``'s parameters.
        device: device the timesteps and batches are moved to.
    """

    def __init__(self, diffusion, optimizer, device='cuda'):
        self.diffusion = diffusion
        self.optimizer = optimizer
        self.device = device

    def train_epoch(self, batch):
        """Run ONE optimisation step on ``batch`` (despite the name) and return its loss.

        Args:
            batch: (batch, data_dim) clean samples x_0, already on ``self.device``.
        Returns:
            the MSE between predicted and true noise, as a Python float.
        """
        # The data batch IS x_0 ~ q(x_0); it must not be replaced with random noise.
        x_0 = batch
        # sample a random timestep for every batch element, shape (batch, 1)
        t = torch.randint(0, self.diffusion.T, (x_0.shape[0],)).unsqueeze(-1).to(self.device)
        # Build x_t ~ q(x_t | x_0) here so the exact eps used is available as the target.
        noise = torch.randn_like(x_0)
        abar_t = self.diffusion.cumprod_alpha[t]
        x_t = torch.sqrt(abar_t) * x_0 + torch.sqrt(1 - abar_t) * noise
        pred_noise = self.diffusion(x_t, t)
        # Regress the predicted noise onto the true noise (not onto the data).
        loss = torch.nn.functional.mse_loss(pred_noise, noise)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def train(self, trainloader, epochs=10):
        """Train for ``epochs`` passes over ``trainloader``.

        Each batch's inputs are flattened to (batch, -1); labels are ignored.

        Returns:
            list containing the mean training loss of each epoch.
        """
        self.diffusion.train()
        epoch_losses = []
        for _ in range(epochs):
            total, n_batches = 0.0, 0
            for x, _ in tqdm(trainloader):
                x = x.reshape(x.shape[0], -1).to(self.device)
                total += self.train_epoch(x)
                n_batches += 1
            epoch_losses.append(total / max(n_batches, 1))
        return epoch_losses

In [None]:
# Visualise the forward (noising) process on one training image, every 25 steps.
diffusion = Diffusion(784).cuda()
image = next(iter(trainloader))[0][0].reshape((1, 784)).to('cuda')
steps = list(range(0, diffusion.T, 25))  # follow the model's T instead of a hard-coded 300
fig, axes = plt.subplots(1, len(steps), figsize=(2 * len(steps), 2), squeeze=False)
for ax, t in zip(axes[0], steps):
    x_t = diffusion.sample_q_t(image, t)
    ax.imshow(x_t.view(28, 28).detach().cpu(), cmap='gray')
    ax.set_title(f't={t}')
    ax.axis('off')
plt.show()

In [None]:
# Train the noise predictor; the sampling cell below needs a trained model,
# otherwise it denoises with random weights.
diffusion = Diffusion(784).cuda()
optimizer = Adam(diffusion.parameters(), lr=1e-3)
trainer = DiffusionTrainer(diffusion, optimizer)
epoch_losses = trainer.train(trainloader, epochs=1)

In [None]:
# Draw one sample from the reverse chain and show it as a 28x28 image.
sample = diffusion.sample()
plt.imshow(sample.detach().cpu().reshape(28, 28), cmap='gray')