In [None]:
import os
import subprocess

def git_repo_root():
    """
    Return the top-level directory of the Git repository containing the
    current working directory.

    Runs ``git rev-parse --show-toplevel`` and strips surrounding whitespace
    from its output.

    Returns:
        str | None: The path printed by git, or ``None`` when git exits with a
        non-zero status (for example when the current directory is not inside
        a repository) or when the ``git`` executable cannot be launched at all.
    """
    try:
        output = subprocess.check_output(['git', 'rev-parse', '--show-toplevel'], universal_newlines=True)
    except subprocess.CalledProcessError:
        # git ran but reported failure, e.g. not inside a Git repository.
        return None
    except OSError:
        # git could not be started (missing from PATH, not executable, ...).
        # FileNotFoundError and PermissionError are both OSError subclasses.
        return None
    return output.strip()

# Resolve the repository root once; the helper yields None when it cannot.
git_root = git_repo_root()

if not git_root:
    print("Not inside a Git repository.")
else:
    # Run the rest of the notebook relative to the repository root.
    os.chdir(git_root)
    print(f"Changed working directory to: {git_root}")

In [None]:
%load_ext autoreload
%autoreload 2

from diffusion import VPSDE
from data import generate_mixture_gaussians

# Discretisation steps passed to the SDE constructor, and the default number
# of training epochs used by train_score_network.
num_steps = 250
epochs = 5000

# Default-sized draw from the Gaussian-mixture generator.
data = generate_mixture_gaussians()

In [None]:
from training import loss_function
import torch
from diffusion import match_dim
from data import log_likelihood_mixture_gaussians_batch
from tqdm import tqdm
from matplotlib import pyplot as plt
import numpy as np
from prodigyopt import Prodigy

def train_score_network(dataloader, score_net, sde, epochs=epochs, bridge=False, optimizer_str='adam', learning_rate=1e-4):
    """
    Train ``score_net`` and track the log-likelihood of generated samples.

    Every 10th epoch (starting at epoch 0), 1000 two-dimensional samples are
    drawn with ``sde.backward_diffusion`` and scored with
    ``log_likelihood_mixture_gaussians_batch``; the mean and standard
    deviation of those scores are recorded. Every 500 epochs, and after the
    final epoch, the mean per-batch training loss since the previous report
    is written with ``tqdm.write``.

    Args:
        dataloader: Iterable yielding 1-tuples ``(x_batch,)``.
        score_net: Model whose ``parameters()`` are optimised.
        sde: Object forwarded to ``loss_function`` and providing
            ``backward_diffusion(score_net, data_shape=...)``.
        epochs: Number of passes over ``dataloader``. The default is the
            value of the notebook-level ``epochs`` at definition time.
        bridge: Forwarded unchanged to ``loss_function``.
        optimizer_str: ``'adam'`` selects ``torch.optim.Adam``; any other
            value selects ``Prodigy`` with its default settings.
        learning_rate: Learning rate for Adam. Not used by the Prodigy branch.

    Returns:
        tuple: ``(ll_mean, ll_std, epoch_ll)`` - three lists of equal length
        holding the mean log-likelihood, its standard deviation, and the
        epoch index at which each evaluation was taken.
    """
    if optimizer_str == 'adam':
        optimizer = torch.optim.Adam(score_net.parameters(), lr=learning_rate)
    else:
        optimizer = Prodigy(score_net.parameters())

    # Running sum of batch losses and the number of batches it covers, so the
    # reported mean is correct for any dataloader length / reporting window.
    running_loss = 0.0
    batches_seen = 0

    epoch_ll = []
    ll_mean = []
    ll_std = []
    for epoch in tqdm(range(epochs)):
        for x_batch, in dataloader:
            optimizer.zero_grad()
            loss = loss_function(score_net, x_batch, sde, bridge=bridge)
            loss.backward()
            # Gradient clipping was left disabled in the original notebook.
            optimizer.step()
            # .item() yields a plain float; adding the tensor itself would keep
            # autograd history alive for every batch until the next report.
            running_loss += loss.item()
            batches_seen += 1

        if epoch % 10 == 0:
            # NOTE(review): if backward_diffusion does not itself rely on
            # autograd, wrapping this call in torch.no_grad() would avoid
            # building a graph that is discarded by .detach() - confirm.
            samples = sde.backward_diffusion(score_net, data_shape=(1000, 2)).detach()
            lls = log_likelihood_mixture_gaussians_batch(samples).numpy()
            ll_mean.append(lls.mean())
            ll_std.append(lls.std())
            epoch_ll.append(epoch)

        if (epoch + 1) % 500 == 0 or epoch == epochs - 1:
            # max(..., 1) guards against an empty dataloader.
            tqdm.write(f'Epoch: {epoch} and Loss: {running_loss / max(batches_seen, 1)}')
            running_loss = 0.0
            batches_seen = 0

    return ll_mean, ll_std, epoch_ll
    


In [None]:
from torch.utils.data import DataLoader, TensorDataset
from model import MLP


data = generate_mixture_gaussians(num_samples=4000)
dataloader = DataLoader(TensorDataset(data), batch_size=500, shuffle=True)

# One entry per experiment, run in this order:
# (result-key suffix, logarithmic_scheduling flag, optimizer_str,
#  extra keyword arguments for train_score_network).
run_configs = [
    ('alin', False, 'adam', {'learning_rate': 1e-4}),
    ('alog', True, 'adam', {'learning_rate': 1e-4}),
    ('plin', False, 'prodigy', {}),
    ('plog', True, 'prodigy', {}),
]

# Keys follow the pattern 'll_mean_<suffix>', 'll_std_<suffix>' and
# 'epoch_ll_<suffix>', which is what the loading cell reads back.
results = {}
for suffix, use_log_schedule, optimizer_name, train_kwargs in run_configs:
    # Build a fresh SDE and network for every run.
    sde = VPSDE(num_steps, 0.1, 20, logarithmic_scheduling=use_log_schedule)
    score_net = MLP()
    run_ll_mean, run_ll_std, run_epoch_ll = train_score_network(
        dataloader, score_net, sde, optimizer_str=optimizer_name, **train_kwargs
    )
    results[f'll_mean_{suffix}'] = run_ll_mean
    results[f'll_std_{suffix}'] = run_ll_std
    results[f'epoch_ll_{suffix}'] = run_epoch_ll

# save the results
import pickle
# The context manager closes (and flushes) the file even if pickling fails.
with open('llvsepochs.pkl', 'wb') as results_file:
    pickle.dump(results, results_file)


In [None]:
# Load results
import pickle
# Read the pickled training curves back in. pickle.load can execute arbitrary
# code, so only open files this notebook produced itself.
with open('llvsepochs.pkl', 'rb') as results_file:
    results = pickle.load(results_file)

# Unpack per-run curves: a* = Adam, p* = Prodigy; *lin / *log = linear /
# logarithmic scheduling.
ll_mean_alin = results['ll_mean_alin']
ll_std_alin = results['ll_std_alin']
epoch_ll_alin = results['epoch_ll_alin']
ll_mean_alog = results['ll_mean_alog']
ll_std_alog = results['ll_std_alog']
epoch_ll_alog = results['epoch_ll_alog']
ll_mean_plin = results['ll_mean_plin']
ll_std_plin = results['ll_std_plin']
epoch_ll_plin = results['epoch_ll_plin']
ll_mean_plog = results['ll_mean_plog']
ll_std_plog = results['ll_std_plog']
epoch_ll_plog = results['epoch_ll_plog']

from matplotlib import pyplot as plt
import numpy as np
# Use the seaborn look. matplotlib 3.6 renamed its bundled seaborn styles to
# 'seaborn-v0_8*', so fall back to the legacy name on older releases.
seaborn_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(seaborn_style)

tex_fonts = {
    # Use LaTeX to write all text (needs a working LaTeX installation)
    "text.usetex": True,
    "font.family": "serif",
}

plt.rcParams.update(tex_fonts)

def plot_ll_vs_epochs(ll_mean, ll_std, epoch_ll, label):
    """
    Plot one log-likelihood curve with a shaded band of one standard
    deviation on either side, using the pyplot interface.

    Args:
        ll_mean: Sequence of mean log-likelihood values.
        ll_std: Sequence of standard deviations, same length as ``ll_mean``.
        epoch_ll: Sequence of x-axis positions (epoch indices).
        label: Legend label for the curve.
    """
    plt.plot(epoch_ll, ll_mean, label=label)

    centre = np.array(ll_mean)
    spread = np.array(ll_std)
    lower_edge = centre - spread
    upper_edge = centre + spread
    plt.fill_between(epoch_ll, lower_edge, upper_edge, alpha=0.3)

plt.figure(figsize=(8, 4))

# (legend label, mean curve, std curve, evaluation epochs) for each run,
# drawn in this order.
curves = (
    ('Adam Linear', ll_mean_alin, ll_std_alin, epoch_ll_alin),
    ('Adam Logarithmic', ll_mean_alog, ll_std_alog, epoch_ll_alog),
    ('Prodigy Linear', ll_mean_plin, ll_std_plin, epoch_ll_plin),
    ('Prodigy Logarithmic', ll_mean_plog, ll_std_plog, epoch_ll_plog),
)
for curve_label, curve_mean, curve_std, curve_epochs in curves:
    plot_ll_vs_epochs(curve_mean, curve_std, curve_epochs, curve_label)

# Axis ranges and labels.
plt.xlim(0, 5000)
plt.ylim(-8, 0)
plt.xlabel('Epochs')
plt.ylabel('Mean Log Likelihood')

plt.legend()
plt.tight_layout()
# Write the figure to disk before displaying it.
plt.savefig('ll_vs_epochs.pdf')
plt.show()