# KL divergence of model with respect to the empirical distribution

In [None]:
# Import packages
import torch
import numpy as np
from scipy import stats

import boltzgen as bg

from matplotlib import pyplot as plt

from tqdm import tqdm

In [None]:
# Specify checkpoint root.
# The root is combined with sub-paths by plain string concatenation
# (e.g. checkpoint_root + 'config/bm.yaml'), so it has to end with a path
# separator; without it the paths collapse to 'rnvpconfig/bm.yaml' etc.
checkpoint_root = 'rnvp/'

In [None]:
# Read the configuration file located below the checkpoint root
config_path = checkpoint_root + 'config/bm.yaml'
config = bg.utils.get_config(config_path)

In [None]:
# Load the training and test trajectories from their .h5 files
training_data, test_data = (
    bg.utils.load_traj(traj_file)
    for traj_file in ('data/train.h5', 'data/test.h5')
)

In [None]:
# Build the Boltzmann generator from the loaded config
model = bg.BoltzmannGenerator(config)

# Select the device: CUDA is used only when it is available AND explicitly
# enabled through the flag below; otherwise everything stays on the CPU
enable_cuda = False
use_cuda = torch.cuda.is_available() and enable_cuda
device = torch.device('cuda' if use_cuda else 'cpu')

# Move the model to that device and switch it to double precision
model = model.to(device).double()

In [None]:
# Plot the loss values stored in the training log as individual points
loss_log_path = checkpoint_root + 'log/loss.csv'
loss = np.loadtxt(loss_log_path)

plt.figure(figsize=(15, 10))
plt.plot(loss, '.')
# Restrict the y axis to the window [-190, -160]
plt.ylim(bottom=-190, top=-160)
plt.show()

In [None]:
# Load the checkpoint file model_30000.pt into the model
checkpoint_path = checkpoint_root + 'checkpoints/model_30000.pt'
model.load(checkpoint_path)

In [None]:
# Draw samples

# Stride used to subsample the data set; the number of generated batches is
# reduced by the same factor
nth = 1

model.eval()

# Per-batch results are collected in lists and concatenated once at the end.
# Growing the arrays with np.concatenate inside the loop copies everything
# gathered so far on every iteration (quadratic cost). Each list starts with
# an empty array of the expected width so the final arrays keep the same
# shape and dtype as before, even if no batch is drawn, and a width mismatch
# still raises.
x_batches = [np.zeros((0, 66))]
z_batches = [np.zeros((0, 60))]
log_p_batches = [np.zeros((0,))]
log_q_batches = [np.zeros((0,))]

for _batch in tqdm(range(1000 // nth)):
    # Samples drawn from the model together with the log_q values it returns
    x, log_q = model.sample(1000)
    # Log-probability of the same samples under model.p
    log_p = model.p.log_prob(x)
    # Map the samples through the inverse of the last flow layer
    z, _ = model.flows[-1].inverse(x)

    x_batches.append(x.detach().cpu().numpy())
    z_batches.append(z.detach().cpu().numpy())
    log_p_batches.append(log_p.detach().cpu().numpy())
    log_q_batches.append(log_q.detach().cpu().numpy())

x_np = np.concatenate(x_batches)
z_np = np.concatenate(z_batches)
log_p_np = np.concatenate(log_p_batches)
log_q_np = np.concatenate(log_q_batches)


# Reference samples: every nth frame of the test set, treated exactly like
# the generated samples above.
# To use the training set instead: training_data[::nth].double().to(device)
z_d = test_data[::nth].double().to(device)
log_p_d = model.p.log_prob(z_d)
z_d, _ = model.flows[-1].inverse(z_d)
z_d_np = z_d.detach().cpu().numpy()

log_p_d_np = log_p_d.detach().cpu().numpy()

## Use histogram to compute KLD

In [None]:
# Estimate the density of every dimension with a normalized histogram
nbins = 200
hist_range = [-5, 5]
ndims = z_np.shape[1]


def _histogram_densities(samples):
    """Return an (nbins, ndims) array; column i is the density histogram of samples[:, i]."""
    densities = np.zeros((nbins, ndims))
    for dim in range(ndims):
        # np.histogram returns (values, bin_edges); only the values are kept
        densities[:, dim] = np.histogram(samples[:, dim], bins=nbins,
                                         range=hist_range, density=True)[0]
    return densities


# hists_train is built from the data samples z_d_np, hists_gen from the
# generated samples z_np; both use the same bins
hists_train = _histogram_densities(z_d_np)
hists_gen = _histogram_densities(z_np)

In [None]:
# Plot both histogram densities of every dimension on top of each other.
# Each histogram value belongs to one bin, so it is drawn at the centre of
# that bin. The centres are derived from hist_range and nbins (instead of a
# hard-coded linspace(-5, 5, nbins)) so the axis stays correct when either
# setting is changed.
bin_edges = np.linspace(hist_range[0], hist_range[1], nbins + 1)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

for i in range(ndims):
    print(i)
    plt.plot(bin_centers, hists_train[:, i])
    plt.plot(bin_centers, hists_gen[:, i])
    plt.show()

In [None]:
# Compute the KLD of every dimension from the histogram densities
# eps is added to numerator and denominator so that empty bins neither
# divide by zero nor evaluate log(0)
eps = 1e-10
log_ratio = np.log((hists_train + eps) / (hists_gen + eps))
kld_unscaled = (hists_train * log_ratio).sum(axis=0)
# Multiplying by the bin width turns the sum over bins into an integral estimate
kld = kld_unscaled * (hist_range[1] - hist_range[0]) / nbins

In [None]:
# Median of the per-dimension KLD values (shown as the cell output)
np.median(kld)

In [None]:
# Split the per-dimension KLD into coordinate groups, using the index
# bookkeeping stored on the last flow layer
mixed_transform = model.flows[-1].mixed_transform
ic_transform = mixed_transform.ic_transform

ncarts = mixed_transform.len_cart_inds
permute_inv = mixed_transform.permute_inv
bond_ind = ic_transform.bond_indices
angle_ind = ic_transform.angle_indices
dih_ind = ic_transform.dih_indices

# The leading 3 * ncarts - 6 entries form the Cartesian group
n_cart = 3 * ncarts - 6
kld_cart, kld_rest = kld[:n_cart], kld[n_cart:]

# Six zero entries are inserted after the Cartesian block before the vector
# is reordered with permute_inv.
# NOTE(review): presumably these stand in for six coordinates the transform
# drops, so that the index arrays below line up - confirm against boltzgen
kld_ = np.concatenate([kld_cart, np.zeros(6), kld_rest])[permute_inv]
kld_bond, kld_angle, kld_dih = (kld_[ind] for ind in (bond_ind, angle_ind, dih_ind))

In [None]:
# Print resulting KLDs: for every coordinate group the sorted values followed
# by their mean and median. The first heading previously read
# 'Cartesian coorinates' (misspelled); the leading newlines of the remaining
# headings separate the groups in the output.
kld_groups = [
    ('Cartesian coordinates', kld_cart),
    ('\n\nBond lengths', kld_bond),
    ('\n\nBond angles', kld_angle),
    ('\n\nDihedral angles', kld_dih),
]

for heading, kld_group in kld_groups:
    print(heading)
    print(np.sort(kld_group))
    print('Mean:', np.mean(kld_group))
    print('Median:', np.median(kld_group))

In [None]:
# Histograms of the groups: split the columns of hists_train the same way the
# KLD vector is split (Cartesian block first, six zero columns inserted after
# it, then reordering with permute_inv) and plot every column.
n_cart = 3 * ncarts - 6
hists_train_cart = hists_train[:, :n_cart]
hists_train_ = np.concatenate([hists_train[:, :n_cart], np.zeros((nbins, 6)),
                               hists_train[:, n_cart:]], axis=1)
hists_train_ = hists_train_[:, permute_inv]
hists_train_bond = hists_train_[:, bond_ind]
hists_train_angle = hists_train_[:, angle_ind]
hists_train_dih = hists_train_[:, dih_ind]

# Histogram values are drawn at their bin centres, derived from hist_range
# and nbins rather than a hard-coded linspace(-5, 5, nbins), so the axis
# stays correct when the histogram settings change.
bin_edges = np.linspace(hist_range[0], hist_range[1], nbins + 1)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

for hists in [hists_train_cart, hists_train_bond, hists_train_angle, hists_train_dih]:
    for i in range(hists.shape[1]):
        print(i)
        plt.plot(bin_centers, hists[:, i])
        plt.show()

## Use Gaussian KDE to compute KLD

In [None]:
# Estimate the density of every dimension with a one-dimensional Gaussian KDE
ndims = z_np.shape[1]

# One estimator per dimension: kde_train is fitted on the data samples
# z_d_np, kde_gen on the generated samples z_np
kde_train = [stats.gaussian_kde(z_d_np[:, dim]) for dim in range(ndims)]
kde_gen = [stats.gaussian_kde(z_np[:, dim]) for dim in range(ndims)]

In [None]:
# Evaluate both KDEs of every dimension on a common grid over [-5, 5] and
# draw them into the same figure
x = np.linspace(-5, 5, 200)
for i, (kernel_data, kernel_model) in enumerate(zip(kde_train, kde_gen)):
    print(i)
    plt.plot(x, kernel_data.pdf(x))
    plt.plot(x, kernel_model.pdf(x))
    plt.show()

In [None]:
# Compute KLD of every dimension by numerically integrating
# p_train * log(p_train / p_gen) over int_range
eps = 1e-10  # keeps the ratio and the logarithm finite where a density vanishes
int_range = [-5, 5]
npoints = 1000

x = np.linspace(int_range[0], int_range[1], npoints)
# linspace includes both end points, so neighbouring grid points are
# (b - a) / (npoints - 1) apart - not (b - a) / npoints
dx = (int_range[1] - int_range[0]) / (npoints - 1)

kld = np.zeros(ndims)
for i in tqdm(range(ndims)):
    # Evaluate each KDE only once per dimension; evaluation cost grows with
    # the number of samples the estimator was fitted on
    p_train = kde_train[i].pdf(x)
    p_gen = kde_gen[i].pdf(x)
    integrand = p_train * np.log((p_train + eps) / (p_gen + eps))
    # Trapezoidal rule on the uniform grid: the two end points get half weight
    kld[i] = dx * (np.sum(integrand) - 0.5 * (integrand[0] + integrand[-1]))

In [None]:
# Split the per-dimension KLD into coordinate groups, using the index
# bookkeeping stored on the last flow layer
mixed_transform = model.flows[-1].mixed_transform
ic_transform = mixed_transform.ic_transform

ncarts = mixed_transform.len_cart_inds
permute_inv = mixed_transform.permute_inv
bond_ind = ic_transform.bond_indices
angle_ind = ic_transform.angle_indices
dih_ind = ic_transform.dih_indices

# The leading 3 * ncarts - 6 entries form the Cartesian group
n_cart = 3 * ncarts - 6
kld_cart, kld_rest = kld[:n_cart], kld[n_cart:]

# Six zero entries are inserted after the Cartesian block before the vector
# is reordered with permute_inv.
# NOTE(review): presumably these stand in for six coordinates the transform
# drops, so that the index arrays below line up - confirm against boltzgen
kld_ = np.concatenate([kld_cart, np.zeros(6), kld_rest])[permute_inv]
kld_bond, kld_angle, kld_dih = (kld_[ind] for ind in (bond_ind, angle_ind, dih_ind))

In [None]:
# Print resulting KLDs: for every coordinate group the sorted values followed
# by their mean and median. The first heading previously read
# 'Cartesian coorinates' (misspelled); the leading newlines of the remaining
# headings separate the groups in the output.
kld_groups = [
    ('Cartesian coordinates', kld_cart),
    ('\n\nBond lengths', kld_bond),
    ('\n\nBond angles', kld_angle),
    ('\n\nDihedral angles', kld_dih),
]

for heading, kld_group in kld_groups:
    print(heading)
    print(np.sort(kld_group))
    print('Mean:', np.mean(kld_group))
    print('Median:', np.median(kld_group))