# DDPM with Pytorch for Adult.

This is our first attempt at implementing a DDGM using Python and Pytorch. Our basis is the code from [here](https://github.com/jmtomczak/intro_dgm/blob/main/ddgms/ddgm_example.ipynb).

In [14]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics # plot_roc_curve.
from sklearn.model_selection import train_test_split # Train/test/validation split of data.
import sklearn.preprocessing as preprocessing
import random 

# Pytorch imports
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader 

# Select the compute device: use the GPU whenever CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using '{device}' device.")

# Report the current working directory as a sanity check.
import os
print(f"The working directory is {os.getcwd()}")

# Seed every random number generator used in the notebook so runs are repeatable.
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

Using 'cuda' device.
The working directory is /home/ajo/gitRepos/master_thesis


In [2]:
# Load Adult data. 

In [16]:
class DDPM(nn.Module):
    """Diffusion model with one network per reverse step and a fixed forward variance.

    The forward ("encoding") process adds Gaussian noise ``T`` times using a
    single fixed ``beta``. The reverse process uses a separate network
    ``p_dnns[i]`` for every step; each network outputs the concatenation of a
    mean and a log-variance. A final ``decoder_net`` maps the first latent to
    the mean of the data distribution.

    Every training step walks the whole forward chain and the whole reverse
    chain, so the cost of one ``forward`` call grows linearly with ``T``.

    Args:
        p_dnns: Sequence of modules; ``p_dnns[i]`` takes latent ``i + 1`` and
            returns ``[mu, log_var]`` (concatenated along ``dim=1``) of the
            distribution over latent ``i``. ``forward`` indexes
            ``latents[i + 1]``, so at most ``T - 1`` networks may be given.
        decoder_net: Module that maps the first latent to the mean of p(x_0 | x_1).
        beta: Fixed variance of every forward diffusion step.
        T: Number of forward diffusion steps (number of latent variables).
        D: Dimensionality of the inputs (needed for sampling).
    """

    def __init__(self, p_dnns, decoder_net, beta, T, D):
        super().__init__()

        # A plain list/tuple of modules is NOT registered by nn.Module, which
        # would hide the reverse-step parameters from `parameters()`, `.to()`
        # and `state_dict()`. Wrap it so the networks are actually trained and
        # moved together with the model.
        if isinstance(p_dnns, (list, tuple)) and all(
            isinstance(net, nn.Module) for net in p_dnns
        ):
            p_dnns = nn.ModuleList(p_dnns)
        self.p_dnns = p_dnns  # One network per reverse step p(x_i | x_{i+1}).

        self.decoder_net = decoder_net  # The last network, for p(x_0 | x_1).
        # TODO: consider building the decoder inside the class instead of
        # passing it in, once the rest of the implementation is settled.

        self.D = D  # Dimensionality of inputs (necessary for sampling).
        self.T = T  # Number of steps (or latent variables).

        # Fixed variance of the forward process. Registered as a buffer so that
        # `.to(device)` moves it with the module; non-persistent so that the
        # keys of `state_dict()` are unchanged.
        self.register_buffer("beta", torch.FloatTensor([beta]), persistent=False)

    @staticmethod
    def reparameterization(mu, log_var):
        """Draw a sample from N(mu, exp(log_var)) via the reparameterization trick."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def reparameterization_gaussian_diffusion(self, x, i):
        """Take one forward diffusion step from ``x``.

        Implements the iterative forward equation
        q(x_t | x_{t-1}) = N(sqrt(1 - beta) * x_{t-1}, beta * I).

        Args:
            x: The previous variable in the chain.
            i: Index of the step. Currently unused because a single ``beta``
                is shared by all steps.
                NOTE(review): with a per-step variance schedule this index
                would presumably select ``beta[i]``.

        Returns:
            A noisy tensor with the same shape as ``x``.
        """
        return torch.sqrt(1.0 - self.beta) * x + torch.sqrt(self.beta) * torch.randn_like(x)

    def _diffuse(self, x):
        """Run the full forward chain from ``x`` and return the list [x_1, ..., x_T]."""
        latents = [self.reparameterization_gaussian_diffusion(x, 0)]
        for i in range(1, self.T):
            latents.append(self.reparameterization_gaussian_diffusion(latents[-1], i))
        return latents

    def forward(self, x, reduction="avg"):
        """Compute the negative ELBO for a batch ``x``.

        Args:
            x: Batch of inputs; features are expected along ``dim=1`` because
                the network outputs are split with ``torch.chunk(..., dim=1)``.
            reduction: ``"sum"`` sums the per-sample loss over the batch; any
                other value averages it.

        Returns:
            Scalar tensor holding the negative ELBO.
        """
        ######## Forward diffusion process.
        # The original author notes that we "just wander around in the space
        # using Gaussian random walk".
        # TODO: use the closed-form expression for an arbitrary time step
        # (alpha_bar) to avoid walking the whole chain.
        latents = self._diffuse(x)

        ######## Backward diffusion process.
        # mus[i] / log_vars[i] parameterize p(x_i | x_{i+1}) and are therefore
        # stored in the same order as `latents`, so that index i pairs each
        # prediction with the latent it is meant to explain.
        mus = []
        log_vars = []
        for i, net in enumerate(self.p_dnns):
            h = net(latents[i + 1])
            # The network predicts the variance as well (it is fixed in the DDPM paper).
            mu_i, log_var_i = torch.chunk(h, 2, dim=1)
            mus.append(mu_i)
            log_vars.append(log_var_i)

        # The last step, i.e. p(x_0 | x_1).
        # NOTE(review): the original comment assumes Normal(x | tanh(NN(x_1)), 1);
        # that depends on how `decoder_net` is defined outside this class.
        mu_x = self.decoder_net(latents[0])

        ######## ELBO.
        # Reconstruction error; presumably -MSE(x, mu_x) up to a constant,
        # depending on the externally defined `log_standard_normal`.
        RE = log_standard_normal(x - mu_x).sum(-1)

        # The forward density of latent i is conditioned on the PREVIOUS
        # variable in the chain (the data x for the first latent).
        previous = [x] + latents[:-1]
        scale = torch.sqrt(1.0 - self.beta)
        log_beta = torch.log(self.beta)

        # KL term for the last latent against the standard normal prior.
        KL = (
            log_normal_diag(latents[-1], scale * previous[-1], log_beta)
            - log_standard_normal(latents[-1])
        ).sum(-1)

        # KL terms for the remaining latents against the learned reverse steps.
        for i in range(len(mus)):
            KL_i = (
                log_normal_diag(latents[i], scale * previous[i], log_beta)
                - log_normal_diag(latents[i], mus[i], log_vars[i])
            ).sum(-1)
            # Accumulate; without this the reverse networks never enter the loss.
            KL = KL + KL_i

        # Final ELBO.
        if reduction == "sum":
            loss = -(RE - KL).sum()
        else:
            loss = -(RE - KL).mean()
        return loss

    def sample(self, batch_size=64):
        """Sample from the model by following the reverse chain from pure noise.

        Args:
            batch_size: Number of samples to generate.

        Returns:
            Tensor of shape ``(batch_size, D)`` passed through ``decoder_net``.
        """
        # Draw the initial noise on the same device as the model's buffers.
        z = torch.randn([batch_size, self.D], device=self.beta.device)
        for i in range(len(self.p_dnns) - 1, -1, -1):
            h = self.p_dnns[i](z)
            mu_i, log_var_i = torch.chunk(h, 2, dim=1)
            # NOTE(review): tanh is applied to the mean here but not in
            # `forward`; confirm which of the two is intended.
            z = self.reparameterization(torch.tanh(mu_i), log_var_i)

        mu_x = self.decoder_net(z)
        return mu_x

    def sample_diffusion(self, x):
        """Return the last latent x_T after running the forward chain on ``x``.

        Intended as a sanity check: the result should resemble white noise.
        """
        return self._diffuse(x)[-1]

## This implementation simply goes through the entire forward diffusion chain and backward chain (all latents) in each training iteration.

This makes it extremely simplistic and useless in practice. For example, it does not use any time embeddings, since it trains a different MLP for each latent $x_1, \ldots, x_T$.


However, take some inspiration from the training loop the author has built; I thought it was very smart! He has incorporated validation losses, as well as saving the best model and early stopping after a number of epochs given by a patience parameter.