In [1]:
## Standard libraries
import os
import math
import numpy as np 
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline 
import seaborn as sns
# Set Seaborn style
sns.set(style='darkgrid', font_scale=1.2)
from sklearn.datasets import make_moons

## Progress bar
from tqdm.notebook import tqdm

import torch
print("Using torch", torch.__version__)
#torch.manual_seed(42) # Setting the seed
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from torch.utils.tensorboard import SummaryWriter
import torch.profiler

Using torch 2.1.0


In [2]:
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )
    
    def forward(self, x):
        return self.layers(x)

class EBM(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )
    
    def forward(self, x):
        # Returns grad_z[ f_a(z) ]
        return self.layers(x)

In [3]:
class PriorSampler():
    def __init__(self, K, s, device):
        self.device = device
        self.K = torch.tensor(K, device=self.device)
        self.s = torch.tensor(s, device=self.device)
    
    def get_sample(self, x, EBM):
        z_k = torch.randn_like(x, device=self.device)
        step=0
        while step < self.K:
            z_k = z_k + self.s * (EBM(z_k)-z_k) + (torch.sqrt(2*self.s)*torch.randn_like(z_k, device=self.device))
            step += 1
        
        return z_k

class PosteriorSamples():
    def __init__(self, K, s, device):
        self.device = device
        self.K = torch.tensor(K, device=self.device)
        self.s = torch.tensor(s, device=self.device)
    
    def get_sample(self, x, z_prior, GEN):
        # Sample from prior
        z_k = z_prior
        step=0
        while step < self.K:
            g_k = GEN(z_k)
            
            # x = g(z) + eps
            x_k = g_k + torch.randn_like(x, device=self.device) 
            z_k = z_k + self.s * -2*(x_k - g_k) + (torch.sqrt(2*self.s)*torch.randn_like(z_k, device=self.device))
            step += 1
        
        return z_k

In [4]:
def loss_function(x, z, GENnet):
    log_prior = -0.5 * torch.square(z)
    log_likelihood = -0.5 * torch.square(x - GENnet(z))
    return torch.mean(log_prior + log_likelihood)

In [9]:
NUM_EPOCHS = 3000
BATCH_SIZE = 128
LR = 1e-3
K_prior = 20
K_posterior = 20
SAMPLES=1000

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the moons dataset
X, _ = make_moons(n_samples=SAMPLES, noise=0.05, random_state=42)
X = torch.tensor(X).float().to(device)


EBMnet = EBM(input_dim=X.shape[1], output_dim=X.shape[1]).to(device)
prior_smp = PriorSampler(K_prior, 0.1, device)
post_smp = PosteriorSamples(K_posterior, 0.1, device)
GENnet = Generator(input_dim=X.shape[1], output_dim=X.shape[1]).to(device)

# Create DataLoader to effectively load data from the above dataset in batches
loader = DataLoader(X, batch_size=BATCH_SIZE, shuffle=True)

loss_fn = loss_function
optimiserEMB = torch.optim.Adam(EBMnet.parameters(), lr=LR, amsgrad=True)
optimiserGEN = torch.optim.Adam(GENnet.parameters(), lr=LR, amsgrad=True)

# Determine the number of batches
num_batches = (len(X) - 1) // BATCH_SIZE + 1

# Write to tensorboard 10 times
sample_every = NUM_EPOCHS//10
writer = SummaryWriter(f"runs/VanillaEBM")
num_plots = (NUM_EPOCHS // sample_every) - 1
num_cols = min(5, num_plots)  # Maximum of 2 columns
num_rows = (num_plots - 1) // num_cols + 1

# Create a figure with subplots
fig, axs = plt.subplots(num_rows, num_cols, figsize=(15, 5 * num_rows))
fig.suptitle("Generated Samples")

with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./runs/VanillaEBM/profilerlogs'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True
) as prof:
    for epoch in tqdm(range(NUM_EPOCHS)):
        for batch_idx, x in enumerate(loader):  # Load real data in batches
            with prof:
                batch_size = x.shape[0]

                # 1. Forward Pass -- Sample from exponentially-tilted prior
                z_0 = prior_smp.get_sample(x, EBMnet)
                z_K = post_smp.get_sample(x, z_0, GENnet)

                # 2. Compute loss
                loss = loss_fn(x, z_K, GENnet)

            # 3. Backward Pass
            optimiserEMB.zero_grad()
            optimiserGEN.zero_grad()

            with prof:
                loss.backward()

            # 4. Update model
            with prof:
                optimiserEMB.step()
                optimiserGEN.step()

        if epoch % sample_every == 0 or epoch == NUM_EPOCHS:
            with torch.no_grad():
                z = prior_smp.get_sample(x, EBMnet)
                samples = GENnet(z).detach().cpu()

                with prof:
                        plot_num = (epoch // sample_every) - 1
                        row = plot_num // num_cols
                        col = (plot_num % num_cols)

                        # Plot x0 on the specified subplot
                        sns.scatterplot(x=samples[:, 0].cpu().numpy(), y=samples[:, 1].cpu().numpy(),
                        color='red', marker='o', ax=axs[row, col])

                        axs[row, col].set_title(f'Epoch: {epoch}')

                        # Convert the Matplotlib figure to a NumPy array
                        fig.canvas.draw()
                        image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
                        image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))

                        # Write the image to TensorBoard
                        writer.add_image("Vanilla EBM -- Make Moons", image, global_step=epoch, dataformats='HWC')


      

  0%|          | 0/3000 [00:00<?, ?it/s]

  if pd.api.types.is_categorical_dtype(vector):
  if pd.api.types.is_categorical_dtype(vector):
  image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
