# Visualize selected marginals for different models

In [None]:
# Import packages
import torch
import numpy as np
from scipy import stats

import boltzgen as bg
import mdtraj as md

import matplotlib as mpl
from matplotlib import pyplot as plt

from tqdm import tqdm

In [None]:
# Build the model that is used for the transform

# Directory that contains the config and the checkpoints of the run
checkpoint_root = 'models/rnvp_01/'
# Read the configuration and instantiate the model from it
model_config_path = checkpoint_root + 'config/bm.yaml'
config = bg.utils.get_config(model_config_path)
model = bg.BoltzmannGenerator(config)
# CUDA is used only when it is available and explicitly enabled here
enable_cuda = False
if torch.cuda.is_available() and enable_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Put the model on the selected device and switch it to double precision
model = model.to(device).double()

In [None]:
# Load the checkpoint 'model_30000.pt' of the run into the model
checkpoint_path = checkpoint_root + 'checkpoints/model_30000.pt'
model.load(checkpoint_path)

In [None]:
# Read the test trajectory from disk
test_traj_path = 'data/trajectory/aldp_test.h5'
test_data = bg.utils.load_traj(test_traj_path)

In [None]:
# Load model samples
# The string block below is never assigned; it only keeps alternative
# prefix lists around that can be swapped in for the active one.
"""
prefix = ['data/samples/20200903/alpha_1_no_scale_01/samples_batch_num_0_processID_',
          'data/samples/20200825/alpha_1_scale_02/samples_batch_num_0_processID_',
          'data/samples/20200906_baselines/grid_search/alpha_1/alpha_1_grid_search_samples_batch_num_0_processID_',
          'data/samples/20200906_baselines/train_acc_prob/alpha_1/alpha_1_train_acc_prob_samples_batch_num_0_processID_']
prefix = ['data/samples/20200903/alpha_0_no_scale_01/samples_batch_num_0_processID_',
          'data/samples/20200825/alpha_0_scale_02/samples_batch_num_0_processID_',
          'data/samples/20200922/alpha_0_grid_search_samples_batch_num_0_processID_',
          'data/samples/20200922/alpha_0_train_acc_prob_samples_batch_num_0_processID_']
prefix = ['data/samples/20200903/md_no_scale_01/samples_batch_num_0_processID_',
          'data/samples/20200825/md_scale_02/samples_batch_num_0_processID_',
          'data/samples/20200906_baselines/grid_search/md/md_grid_search_samples_batch_num_0_processID_',
          'data/samples/20200906_baselines/train_acc_prob/md/md_train_acc_prob_samples_batch_num_0_processID_']
prefix = ['data/samples/20210115/alpha_0_no_scale_02/0250/samples_batch_num_0_processID_',
          'data/samples/20210115/alpha_0_scale_03/0250/samples_batch_num_0_processID_',
          'data/samples/20200922/alpha_0_grid_search_samples_batch_num_0_processID_',
          'data/samples/20200922/alpha_0_train_acc_prob_samples_batch_num_0_processID_']    
"""
prefix = ['data/samples/20210115/alpha_1_no_scale_02/samples_batch_num_0_processID_',
          'data/samples/20210115/alpha_1_scale_03/samples_batch_num_0_processID_',
          'data/samples/20200906_baselines/grid_search/alpha_1/alpha_1_grid_search_samples_batch_num_0_processID_',
          'data/samples/20200906_baselines/train_acc_prob/alpha_1/alpha_1_train_acc_prob_samples_batch_num_0_processID_']

# Row 0 is filled with the test data further down; row j + 1 holds the
# samples read from prefix[j]. x_np stores the loaded samples, z_np the
# output of the inverse of the model's last flow.
x_np = np.zeros((len(prefix) + 1, 1024 * 1024, 66))
z_np = np.zeros((len(prefix) + 1, 1024 * 1024, 60))
# Only forward evaluations are needed, so no autograd graph is recorded
with torch.no_grad():
    for j, prefix_ in enumerate(prefix):
        # 1024 files per prefix, each written into a block of 1024 rows
        for i in tqdm(range(1024)):
            rows = slice(i * 1024, (i + 1) * 1024)
            x_np_ = np.load(prefix_ + str(i) + '.npy')
            x_np[j + 1, rows, :] = x_np_
            # Evaluate on the model's device and bring the result back to
            # the CPU so the cell also works when enable_cuda is switched on
            x = torch.tensor(x_np_).to(device)
            z, _ = model.flows[-1].inverse(x)
            z_np[j + 1, rows, :] = z.cpu().numpy()
    # Keep the first 1000000 rows of every set; the test data assigned to
    # row 0 has to match this length
    x_np = x_np[:, :1000000, :]
    z_np = z_np[:, :1000000, :]
    x_np[0, :, :] = test_data.cpu().numpy()
    z, _ = model.flows[-1].inverse(test_data.to(device))
    z_np[0, :, :] = z.cpu().numpy()

In [None]:
# Estimate each of the 60 marginals of every set in z_np with a Gaussian KDE
# and evaluate it on an evenly spaced grid over [-pi, pi]
int_range = [-np.pi, np.pi]
npoints = 150
x = np.linspace(*int_range, npoints)
kde_marg = np.zeros((len(z_np), npoints, 60))
for i, z_set in enumerate(z_np):
    for j in tqdm(range(60)):
        column = z_set[:, j]
        # NaN entries are excluded before the density is fitted
        kde = stats.gaussian_kde(column[~np.isnan(column)])
        kde_marg[i, :, j] = kde.pdf(x)

In [None]:
# Plot the KDE marginals of three selected coordinates per group:
# one row of panels per entry of ylabel, one curve per set in kde_marg
ind_marg = np.array([[22, 43, 58], [9, 33, 45], [32, 53, 11], [1, 2, 7]])
ylabel = ['Bond angles', 'Bond lengths', 'Dihedral angles', 'Cartesian coordinates']
f, ax = plt.subplots(4, 3, figsize=(15, 20), sharex=True)
lines = [None] * len(kde_marg)
for i, group in enumerate(ind_marg):
    for j, marg in enumerate(group):
        panel = ax[i, j]
        # The handles of the last panel drawn are the ones used for the legend
        lines = [panel.plot(x, kde_set[:, marg])[0] for kde_set in kde_marg]
        panel.set_yticks([])
        panel.tick_params(axis='x', which='both', labelsize=18)
        if j == 0:
            # Only the first column carries the group label
            panel.set_ylabel(ylabel[i], fontsize=22)
# The LaTeX label is a raw string: '\o' is not a valid escape sequence in a
# regular string literal and triggers a warning on recent Python versions
f.legend(lines, ['Ground truth', 'maxELT', 'maxELT & SKSD', 'Grid search', r'$\overline{p}_a=0.65$'],
         bbox_to_anchor=(0.905, 0.885), fontsize=16)
plt.savefig('plots/marginals/alpha1.eps')
plt.show()

In [None]:
# Look up the bond, angle and dihedral indices stored in the transform of the
# last flow and print where they end up after the index permutation below
mixed_transform = model.flows[-1].mixed_transform
ic_transform = mixed_transform.ic_transform
ncarts = mixed_transform.len_cart_inds
permute_inv = mixed_transform.permute_inv
bond_ind = ic_transform.bond_indices
angle_ind = ic_transform.angle_indices
dih_ind = ic_transform.dih_indices

# Indices 60..65 are inserted right after the first 3 * ncarts - 6 entries
n_lead = 3 * ncarts - 6
ind_perm = np.concatenate([np.arange(n_lead), np.arange(60, 66), np.arange(n_lead, 60)])
ind = ind_perm[permute_inv]

for group_ind in (bond_ind, angle_ind, dih_ind):
    print(ind[group_ind])

## Ramachandran plot

In [None]:
# Read the topology from the PDB file and wrap every sample set in x_np
# as a trajectory of frames with 22 atoms and 3 coordinates each
ala2_pdb = md.load('code/snf_noe/data/alanine-dipeptide.pdb')
ala2_top = ala2_pdb.topology
traj = [md.Trajectory(coords.reshape(-1, 22, 3), ala2_top) for coords in x_np]

In [None]:
# Flattened phi and psi angles of every trajectory, one row per trajectory
phi = np.array([md.compute_phi(traj_)[1].ravel() for traj_ in traj])
psi = np.array([md.compute_psi(traj_)[1].ravel() for traj_ in traj])
# NaN angles are overwritten with zero in place
np.putmask(phi, np.isnan(phi), 0)
np.putmask(psi, np.isnan(psi), 0)

In [None]:
# Save and show a log-scaled 2D histogram of (phi, psi) for every sample set
pref = 'plots/ramachandran/md/'
# NOTE: file_name is indexed with the row number of phi, so it needs one
# entry per row of phi; the longer alternative list is kept in the comment
file_name = ['md', 'init'] #['md', 'maxelt', 'maxelt_sksd', 'grid_search', 'acc_prob']

for i, (phi_i, psi_i) in enumerate(zip(phi, psi)):
    plt.figure(figsize=(10, 10))
    plt.hist2d(phi_i, psi_i, bins=64, norm=mpl.colors.LogNorm())
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    # Raw strings: '\p' is not a valid escape sequence in a regular literal
    plt.xlabel(r'$\phi$', fontsize=24)
    plt.ylabel(r'$\psi$', fontsize=24)
    plt.savefig(pref + file_name[i] + '.png')
    plt.show()

In [None]:
# Histogram-based estimate of the KL divergence between the (phi, psi)
# density of set 0 and that of every other set
nbins = 64
# Added to both densities so empty bins do not produce log(0) or a division by zero
eps = 1e-10
# All histograms share the same bin edges over [-pi, pi]^2. Without an explicit
# range every histogram would derive its edges from its own data, so the bins
# of two sets would not line up and the bin area below would only be approximate.
angle_range = [[-np.pi, np.pi], [-np.pi, np.pi]]

hist = [
    np.histogram2d(phi_i, psi_i, nbins, range=angle_range, density=True)[0]
    for phi_i, psi_i in zip(phi, psi)
]

# Area of a single bin; turns the sum over bins into an integral estimate
bin_area = 4 * np.pi ** 2 / nbins ** 2
kld = [
    np.sum(hist[0] * np.log((hist[0] + eps) / (hist_i + eps))) * bin_area
    for hist_i in hist[1:]
]

In [None]:
# Display kld, which holds one value for every row of phi except row 0
kld

In [None]:
# Read columns 1 to 5 of 'kld.csv', skipping the first row of the file
a = np.loadtxt(
    'results/ramachandran/kld.csv',
    delimiter=',',
    skiprows=1,
    usecols=(1, 2, 3, 4, 5),
)

In [None]:
# Plot the first four entries of the first three rows of a as dots
for row in range(3):
    plt.plot(a[row, :4], '.')
plt.show()

In [None]:
# NOTE(review): m2 is not defined in any cell visible here, so this cell
# presumably raises a NameError when run on its own - confirm and remove if stale
m2

In [None]:
# Load the training-progress text files of two runs
trainprog_files = [
    'models/alpha_0_scale_04/checkpoints/trainprog_hmc_ei_sksd_ckpt_00700.txt',
    'models/alpha_0_scale_03/checkpoints/trainprog_hmc_ei_sksd_ckpt_02000.txt',
]
t1 = np.loadtxt(trainprog_files[0])
t2 = np.loadtxt(trainprog_files[1])

In [None]:
# Plot the first column of t2 against the row index
plt.plot(t2[:, 0])

In [None]:
# One 200-bin histogram of psi per sample set
for i, psi_i in enumerate(psi):
    plt.hist(psi_i, bins=200)
    #plt.savefig('psi_' + file_name[i] + '.png')
    plt.show()

In [None]:
# One 200-bin histogram of phi per sample set
for i, phi_i in enumerate(phi):
    plt.hist(phi_i, bins=200)
    #plt.savefig('phi_' + file_name[i] + '.png')
    plt.show()

In [None]:
# Draw samples

nth = 1

model.eval()

# Results are gathered per batch and joined once after the loop; concatenating
# inside the loop would re-copy the growing arrays in every iteration.
# The empty float64 seed arrays fix the shape and dtype of the final arrays.
x_batches = [np.zeros((0, 66))]
z_batches = [np.zeros((0, 60))]
log_p_batches = [np.zeros((0,))]
log_q_batches = [np.zeros((0,))]

# Pure sampling and evaluation: no autograd graph has to be recorded
with torch.no_grad():
    for i in tqdm(range(1000 // nth)):
        # 1000 model samples per batch together with their model log-density
        x_batch, log_q = model.sample(1000)
        # Log-probability of the samples under model.p
        log_p = model.p.log_prob(x_batch)
        # Output of the inverse of the last flow for the same samples
        z_batch, _ = model.flows[-1].inverse(x_batch)
        x_batches.append(x_batch.cpu().numpy())
        z_batches.append(z_batch.cpu().numpy())
        log_p_batches.append(log_p.cpu().numpy())
        log_q_batches.append(log_q.cpu().numpy())

x_np_ = np.concatenate(x_batches)
z_np_ = np.concatenate(z_batches)
log_p_np = np.concatenate(log_p_batches)
log_q_np = np.concatenate(log_q_batches)

In [None]:
# Append the freshly drawn samples to x_np as one additional set
new_set = np.expand_dims(x_np_, axis=0)
x_np = np.concatenate((x_np, new_set))

In [None]:
# ala2_top is reused from the cell that reads 'alanine-dipeptide.pdb'
#ala2_top = md.load('code/snf_noe/data/alanine-dipeptide.pdb')
# All sets in x_np are merged into a single trajectory of 22-atom frames
traj = md.Trajectory(x_np.reshape(-1, 22, 3), ala2_top)
# Flattened dihedral angles of that trajectory; NaN entries are set to zero
phi = np.array(md.compute_phi(traj)[1].ravel())
np.putmask(phi, np.isnan(phi), 0)
psi = np.array(md.compute_psi(traj)[1].ravel())
np.putmask(psi, np.isnan(psi), 0)

In [None]:
# 200-bin histogram of phi, written to 'phi_init.png' before it is shown
plt.hist(phi, bins=200)
plt.savefig('phi_init.png')
plt.show()