In [40]:
## Standard libraries
import os
import math
import numpy as np 
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline 
import seaborn as sns
# Set Seaborn style
sns.set(style='darkgrid', font_scale=1.2)
from sklearn.datasets import make_moons

## Progress bar
from tqdm.notebook import tqdm

import torch
print("Using torch", torch.__version__)
#torch.manual_seed(42) # Setting the seed
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from torch.utils.tensorboard import SummaryWriter
import torch.profiler

Using torch 2.1.0


In [41]:
class Generator(nn.Module):
    """Three-layer MLP (Linear-ReLU-Linear-ReLU-Linear, hidden width 128).

    Args:
        input_dim: Number of input features of the first linear layer.
        output_dim: Number of output features of the last linear layer.
        sample_steps: Optional value stored as ``self.K``. It is not read
            anywhere in this class.
        sample_noise: Optional value stored as ``self.s``. It is not read
            anywhere in this class.

    ``sample_steps`` and ``sample_noise`` default to ``None`` so that the
    two-argument call ``Generator(input_dim, output_dim)`` works; existing
    four-argument calls behave exactly as before.
    """

    def __init__(self, input_dim, output_dim, sample_steps=None, sample_noise=None):
        # Initialise nn.Module before touching any attribute: the conventional
        # order, and required as soon as an attribute is a Module or Parameter.
        super().__init__()
        self.K = sample_steps
        self.s = sample_noise

        self.layers = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        return self.layers(x)
    


class EBM(nn.Module):
    """Three-layer MLP (hidden width 128) used as the energy-based prior network.

    Args:
        input_dim: Number of input features of the first linear layer.
        output_dim: Number of output features of the last linear layer.
            ``get_prior`` subtracts the input from the network output, so it
            needs ``output_dim == input_dim``.
        sample_steps: Number of update steps ``get_prior`` runs (``self.K``).
        sample_noise: Step size used by ``get_prior`` (``self.s``).

    ``get_prior`` reads ``self.K`` and ``self.s``; they are now set here so the
    method no longer fails with ``AttributeError``. The defaults (20, 0.1)
    mirror the values the notebook passes to ``PriorSampler`` -- adjust if a
    different schedule is intended.
    """

    def __init__(self, input_dim, output_dim, sample_steps=20, sample_noise=0.1):
        super().__init__()
        self.K = sample_steps
        self.s = sample_noise

        self.layers = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        """Apply the MLP to ``x``."""
        # Returns grad_z[ f_a(z) ]
        return self.layers(x)

    def get_prior(self, x):
        """Run ``self.K`` noisy update steps starting from standard normal noise.

        Args:
            x: Tensor used only as a template (shape, dtype, device) for the
                random draws.

        Returns:
            Tensor shaped like ``x`` after ``self.K`` updates of the form
            ``z + s * (forward(z) - z) + sqrt(2 * s) * noise``.
        """
        z_k = torch.randn_like(x)
        # ``** 0.5`` works whether ``self.s`` is a Python float or a tensor;
        # ``torch.sqrt`` would reject a plain float.
        noise_scale = (2 * self.s) ** 0.5
        for _ in range(int(self.K)):
            z_k = z_k + self.s * (self.forward(z_k) - z_k) + noise_scale * torch.randn_like(z_k)

        return z_k

In [42]:
class PriorSampler():
    """Draws samples by iterating a noisy update driven by a supplied network.

    Args:
        K: Number of update steps; stored as a tensor on ``device``.
        s: Step size; stored as a tensor on ``device``.
        device: Device on which the stored tensors and random draws live.
    """

    def __init__(self, K, s, device):
        self.device = device
        self.K = torch.tensor(K, device=self.device)
        self.s = torch.tensor(s, device=self.device)

    def get_sample(self, x, EBM):
        """Run ``K`` update steps starting from standard normal noise.

        Args:
            x: Tensor used only as a template for the random draws.
            EBM: Callable applied to the current sample at every step; its
                output must be broadcastable against the sample.

        Returns:
            The sample after ``K`` updates of the form
            ``z + s * (EBM(z) - z) + sqrt(2 * s) * noise``.
        """
        # Convert the step count once. Comparing a Python int against the
        # device tensor on every iteration forces a tensor-to-bool conversion
        # (and, on CUDA, a host/device sync) per step.
        num_steps = int(self.K)
        # Loop-invariant noise scale, hoisted out of the loop.
        noise_scale = torch.sqrt(2 * self.s)

        z_k = torch.randn_like(x, device=self.device)
        for _ in range(num_steps):
            z_k = z_k + self.s * (EBM(z_k) - z_k) + (noise_scale * torch.randn_like(z_k, device=self.device))

        return z_k

class PosteriorSamples():
    """Refines a prior sample by iterating a noisy update driven by a generator.

    Args:
        K: Number of update steps; stored as a tensor on ``device``.
        s: Step size; stored as a tensor on ``device``.
        device: Device on which the stored tensors and random draws live.
    """

    def __init__(self, K, s, device):
        self.device = device
        self.K = torch.tensor(K, device=self.device)
        self.s = torch.tensor(s, device=self.device)

    def get_sample(self, x, z_prior, GEN):
        """Run ``K`` update steps starting from ``z_prior``.

        Args:
            x: Tensor used as a template for the noise added to the generator
                output.
            z_prior: Starting sample (returned unchanged when ``K`` is 0).
            GEN: Callable applied to the current sample at every step.

        Returns:
            The sample after ``K`` updates.
        """
        # Convert the step count once. Comparing a Python int against the
        # device tensor on every iteration forces a tensor-to-bool conversion
        # (and, on CUDA, a host/device sync) per step.
        num_steps = int(self.K)
        # Loop-invariant noise scale, hoisted out of the loop.
        noise_scale = torch.sqrt(2 * self.s)

        # Sample from prior
        z_k = z_prior
        for _ in range(num_steps):
            g_k = GEN(z_k)

            # x = g(z) + eps
            x_k = g_k + torch.randn_like(x, device=self.device)
            # NOTE(review): x_k - g_k is just the noise drawn on the line above,
            # so the values of the observed `x` never enter this update (only
            # its shape/dtype do). Presumably the drift was meant to be the
            # gradient of the log-posterior w.r.t. z -- confirm and rework.
            z_k = z_k + self.s * -2*(x_k - g_k) + (noise_scale * torch.randn_like(z_k, device=self.device))

        return z_k

In [43]:
def loss_function(x, z, GENnet):
    """Negative mean log joint of ``(x, z)`` under unit-variance Gaussians.

    Args:
        x: Observed data tensor.
        z: Latent tensor; must broadcast against ``x - GENnet(z)``.
        GENnet: Callable mapping ``z`` to a reconstruction of ``x``.

    Returns:
        Scalar tensor ``mean(0.5 * z**2 + 0.5 * (x - GENnet(z))**2)``
        (normalising constants dropped).

    The value is the *negative* of the log-prior plus log-likelihood, so that
    minimising it with a standard optimiser increases the likelihood. The
    previous version returned the log joint itself, which made gradient
    descent push the reconstruction error up.
    """
    log_prior = -0.5 * torch.square(z)
    log_likelihood = -0.5 * torch.square(x - GENnet(z))
    return -torch.mean(log_prior + log_likelihood)

In [44]:
# --- Hyper-parameters -------------------------------------------------------
NUM_EPOCHS = 1000
BATCH_SIZE = 128
LR = 1e-3
K_prior = 20
K_posterior = 20
STEP_SIZE = 0.1  # step size handed to both samplers

device = "cuda" if torch.cuda.is_available() else "cpu"
EBMnet = EBM(784, 784).to(device)
prior_smp = PriorSampler(K_prior, STEP_SIZE, device)
post_smp = PosteriorSamples(K_posterior, STEP_SIZE, device)
# Generator's constructor also takes the sampler settings; pass them explicitly
# (the former two-argument call raised TypeError).
GENnet = Generator(784, 784, K_posterior, STEP_SIZE).to(device)

# Transforms to apply to dataset. Normalising improves data convergence, numerical stability, and regularisation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Create instance of MNIST dataset
dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)

# Create DataLoader to effectively load data from the above dataset in batches
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

loss_fn = loss_function
optimiserEMB = torch.optim.Adam(EBMnet.parameters(), lr=LR, amsgrad=True)
optimiserGEN = torch.optim.Adam(GENnet.parameters(), lr=LR, amsgrad=True)

# Log a grid of generated samples to TensorBoard every `sample_every` epochs
sample_every = 100
writer = SummaryWriter("runs/VanillaEBM")

try:
    # The profiler is entered exactly once; its schedule is advanced with
    # prof.step() after every training iteration. Re-entering the active
    # profiler with nested `with prof:` blocks restarts/stops it instead.
    with torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./runs/VanillaEBM/profilerlogs'),
            record_shapes=True,
            profile_memory=True,
            with_stack=True
    ) as prof:
        for epoch in tqdm(range(NUM_EPOCHS)):
            for x, _ in loader:  # Load real data in batches
                x = x.view(-1, 784).to(device)

                # 1. Forward Pass -- Sample from exponentially-tilted prior
                z_0 = prior_smp.get_sample(x, EBMnet)
                z_K = post_smp.get_sample(x, z_0, GENnet)

                # 2. Compute loss
                loss = loss_fn(x, z_K, GENnet)

                # 3. Backward Pass
                optimiserEMB.zero_grad()
                optimiserGEN.zero_grad()
                loss.backward()

                # 4. Update model
                optimiserEMB.step()
                optimiserGEN.step()

                # Tell the profiler one iteration has finished so the
                # wait/warmup/active schedule can progress.
                prof.step()

            # range(NUM_EPOCHS) stops at NUM_EPOCHS - 1, so that is the final epoch.
            if epoch % sample_every == 0 or epoch == NUM_EPOCHS - 1:
                with torch.no_grad():
                    # `x` is the last batch of the epoch; it only sets the sample shape.
                    z = prior_smp.get_sample(x, EBMnet)
                    samples = GENnet(z).view(-1, 28, 28).detach().cpu()

                    image = torchvision.utils.make_grid(samples.unsqueeze(1), padding=2, normalize=True)
                    writer.add_image("Samples", image, epoch)
finally:
    # Flush pending events and release the event file even if training is
    # interrupted (common in notebooks).
    writer.close()


      

  0%|          | 0/1000 [00:00<?, ?it/s]