In [None]:
# einops is imported in the next cell; %pip (unlike !pip) installs into the running kernel's environment.
%pip install einops

In [None]:
import einops as ein
from einops.layers.torch import Rearrange

import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision.utils import save_image
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np

In [None]:
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO of a VAE: summed binary cross-entropy plus the Gaussian KL term.

    NOTE(review): not called anywhere in the visible notebook (the diffusion model
    uses ``Diffusion.loss``); looks like a leftover from a VAE experiment.

    Args:
        recon_x: reconstructed values in [0, 1], same shape as ``x``.
        x: target values in [0, 1].
        mu: latent means.
        logvar: latent log-variances, same shape as ``mu``.
    Returns:
        Scalar tensor: BCE summed over all elements and the batch, plus
        KL(N(mu, exp(logvar)) || N(0, I)) summed likewise.
    """
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kld = -0.5 * torch.sum(kl_terms)
    return bce + kld  # -ELBO

In [None]:
class Denoiser (nn.Module):
    """Tiny convolutional noise predictor eps_theta(x_t, t) conditioned on the timestep.

    The timestep is encoded as T sinusoidal features sin(2k * pi * t / T), k = 1..T.
    A linear layer projects them to one scalar per sample, which is added (broadcast
    over H and W) to the intermediate feature map.

    Args:
        T: number of diffusion steps; also the width of the sinusoidal time encoding.
    """

    def __init__(self, T):
        super (Denoiser, self).__init__()
        self.T = T
        # NOTE(review): a single channel throughout is very low capacity; consider
        # widening (e.g. 1 -> 32 -> 64 -> 128 -> 1) if the loss plateaus.
        self.layer1 = nn.Sequential (
            nn.Conv2d (1, 1, 3, padding=1),
            nn.ReLU (),
            nn.Conv2d (1, 1, 3, padding=1),
            nn.ReLU ()
        )
        # No output activation: the regression target is standard-normal noise
        # (drawn with randn_like in the diffusion loss), so predictions must be able
        # to take any real value. A Sigmoid here would confine them to (0, 1).
        self.layer2 = nn.Sequential (
            nn.Conv2d (1, 1, 3, padding=1),
            nn.ReLU (),
            nn.Conv2d (1, 1, 3, padding=1),
        )

        self.time_embed = nn.Linear (self.T, 1)

    def _time_features (self, t):
        """Return the (B, T) sinusoidal encoding sin(i * pi * t / T), i = 2, 4, ..., 2T.

        ``t`` may be a tensor or a sequence of integer timesteps; a scalar becomes a
        batch of one. The features are built directly on the device and dtype of
        ``time_embed`` so the projection works on CPU and GPU alike.
        """
        param = self.time_embed.weight
        t = torch.as_tensor (t, device=param.device).reshape (-1).to (param.dtype)
        freqs = torch.arange (2, 2 * self.T + 1, 2, device=param.device).to (param.dtype) * np.pi
        return torch.sin (t[:, None] * freqs[None, :] / self.T)

    def forward (self, x, t):
        """Predict the noise contained in ``x`` at timesteps ``t``.

        Args:
            x: noisy images shaped (B, 1, H, W); the single input channel is fixed
               by the first Conv2d.
            t: (B,) integer timesteps (tensor or sequence); a scalar is broadcast
               over the batch.
        Returns:
            Tensor with the same shape as ``x`` (3x3 convs with padding=1, 1 channel out).
        """
        time_embed = self.time_embed (self._time_features (t))  # (B, 1)
        h = self.layer1 (x) + time_embed[:, :, None, None]
        return self.layer2 (h)

class Diffusion (nn.Module):
    """DDPM-style Gaussian diffusion with a linear beta schedule and a noise-predicting model.

    Args:
        noise_steps: number of diffusion steps T; timesteps are indexed 0..T-1.
        device: device on which the schedule tensors are created.
        latent_dim: accepted for backward compatibility; unused.
        eps_model: optional noise-prediction network called as ``eps_model(x_t, t)``;
            defaults to ``Denoiser(noise_steps)``.
    """

    def __init__(self, noise_steps=100, device="cpu", latent_dim=2, eps_model=None):
        super (Diffusion, self).__init__()
        beta = torch.linspace (1e-4, 0.02, noise_steps, device=device)
        alpha = 1 - beta
        # Non-persistent buffers follow model.to()/.double() like parameters do, but
        # stay out of state_dict, so the checkpoint layout is unchanged.
        self.register_buffer ("beta", beta, persistent=False)
        self.register_buffer ("alpha", alpha, persistent=False)
        self.register_buffer ("alpha_bar", torch.cumprod (alpha, dim=0), persistent=False)
        self.T = noise_steps
        self.eps_model = eps_model if eps_model is not None else Denoiser (noise_steps)

    @staticmethod
    def _expand (v, x):
        """Reshape a per-sample vector ``v`` (B,) so it broadcasts against ``x`` (B, ...)."""
        return v.reshape (-1, *([1] * (x.dim () - 1)))

    def q (self, x0, t):
        """Mean and variance of q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) I).

        Args:
            x0: clean samples (B, ...).
            t: (B,) long tensor of timesteps.
        Returns:
            (mean shaped like x0, variance shaped (B, 1, ..., 1)).
        """
        alpha_bar = self._expand (self.alpha_bar[t], x0)
        mean = torch.sqrt (alpha_bar) * x0
        var = 1 - alpha_bar
        return mean, var

    def sample_q (self, x0, t, epsilon=None):
        """Draw x_t ~ q(x_t | x_0), using ``epsilon`` as the standard-normal noise if given."""
        if epsilon is None:
            epsilon = torch.randn_like (x0)

        mean, var = self.q (x0, t)
        return mean + torch.sqrt (var) * epsilon

    def p_sample (self, xt, t):
        """One reverse step x_t -> x_{t-1} of p_theta, with variance beta_t.

        Args:
            xt: current samples (B, ...).
            t: (B,) long tensor of timesteps.
        Returns:
            Samples shaped like ``xt``. No noise is added where t == 0 (the final step).
        """
        eps_theta = self.eps_model (xt, t)
        alpha = self._expand (self.alpha[t], xt)
        alpha_bar = self._expand (self.alpha_bar[t], xt)
        beta = self._expand (self.beta[t], xt)
        eps_coef = (1 - alpha) / torch.sqrt (1 - alpha_bar)
        mean = (1 / torch.sqrt (alpha)) * (xt - eps_coef * eps_theta)
        # Zero the noise for samples at the last step so the output is the posterior mean.
        keep_noise = self._expand ((t != 0).to (xt.dtype), xt)
        return mean + keep_noise * torch.sqrt (beta) * torch.randn_like (xt)

    def loss (self, x0, noise=None):
        """Simplified DDPM objective: MSE between the true and predicted noise.

        A timestep is drawn uniformly from [0, T) for every sample in the batch.
        """
        t = torch.randint (0, self.T, (x0.size(0),), device=x0.device, dtype=torch.long)
        if noise is None:
            noise = torch.randn_like (x0)
        xt = self.sample_q (x0, t, noise)
        eps_theta = self.eps_model (xt, t)
        return F.mse_loss (eps_theta, noise)


    

In [None]:
# Seed the global torch RNG: weight init, timestep/noise draws and loader shuffling all use it.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Diffusion (device=device).to(device)
optimizer = optim.Adam(model.parameters(), lr=5e-4)

batch_size = 256
DATA_DIR = '../data'
# Get train and test data
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DATA_DIR, train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)
# Evaluation gains nothing from shuffling; keep a fixed order.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DATA_DIR, train=False, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=False)


In [None]:
def train (epoch):
    """Run one training epoch over ``train_loader`` and report the mean loss.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader`` and ``device``.
    ``model.loss`` returns a batch-mean MSE, so the epoch loss is the average of the
    per-batch values. It is not divided again by the batch size.

    Args:
        epoch: 1-based epoch number, used only for logging.
    Returns:
        Mean per-batch training loss for the epoch.
    """
    model.train ()
    total_loss = 0.0
    n_batches = 0
    for data, _ in train_loader:
        data = data.to (device)
        optimizer.zero_grad ()
        loss = model.loss (data)
        loss.backward ()
        optimizer.step ()
        total_loss += loss.item ()
        n_batches += 1
    avg_loss = total_loss / max (n_batches, 1)  # guard against an empty loader
    print ('Train Epoch: {}\tAverage loss: {:.6f}'.format (epoch, avg_loss))
    return avg_loss

def test (epoch):
    """Evaluate the notebook-level ``model`` on ``test_loader`` and report the mean loss.

    ``model.loss`` is a batch-mean MSE, so the summed value is normalised by the number
    of batches, not by the dataset size. The loss is stochastic: it draws random
    timesteps and noise on every call.

    Args:
        epoch: epoch number; kept for a signature parallel to ``train`` and unused here.
    Returns:
        Mean per-batch test loss.
    """
    model.eval ()  # evaluation mode; gradients are disabled separately below
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad ():
        for data, _ in test_loader:
            data = data.to (device)
            total_loss += model.loss (data).item ()
            n_batches += 1

    test_loss = total_loss / max (n_batches, 1)  # guard against an empty loader
    print ('====> Test set loss: {:.4f}'.format (test_loss))
    return test_loss

In [None]:
# Alternate one optimisation pass and one held-out evaluation per epoch.
epochs = 10
for epoch in range(1, epochs + 1):
    for phase in (train, test):
        phase(epoch)

In [None]:
# NOTE(review): this pickles the whole module object, so loading it later needs the
# Denoiser/Diffusion class definitions importable; model.state_dict() is more portable.
torch.save(obj=model, f="diffusion.pth")