# Visualize selected marginals for different models

In [None]:
# Import packages
import torch
import numpy as np
from scipy import stats

import boltzgen as bg

from matplotlib import pyplot as plt

from tqdm import tqdm

In [None]:
# Build the Boltzmann generator; only its last flow layer (model.flows[-1])
# is used in the cells below.
# NOTE(review): this cell reads only the config under checkpoint_root; it
# does not load any saved weights — confirm that is intended.
checkpoint_root = 'rnvp/'
config = bg.utils.get_config(checkpoint_root + 'config/bm.yaml')
model = bg.BoltzmannGenerator(config)

# enable_cuda gates GPU use; the model stays on the CPU unless it is True
# AND a CUDA device is present. The model is then cast to float64.
enable_cuda = False
use_gpu = torch.cuda.is_available() and enable_cuda
device = torch.device('cuda' if use_gpu else 'cpu')
model = model.to(device).double()

In [None]:
# Held-out test trajectory; its transformed version fills row 0 of z_np,
# which is plotted as "Ground truth".
test_path = 'data/test.h5'
test_data = bg.utils.load_traj(test_path)

In [None]:
# Load model samples and map them through the inverse of the last flow layer.
# z_np[0] holds the transformed test data; z_np[j + 1] holds model j's samples.
prefix = ['samples/alpha_0_no_scale_01/samples_batch_num_0_processID_',
          'samples/alpha_0_scale_02/samples_batch_num_0_processID_',
          'samples/alpha_0_grid_search_samples_batch_num_0_processID_',
          'samples/alpha_0_train_acc_prob_samples_batch_num_0_processID_']

# Assumes each of the 1024 files per model holds exactly 1024 samples of
# dimension 60 (the slice assignment below requires it) — TODO confirm.
z_np = np.zeros((len(prefix) + 1, 1024 * 1024, 60))
# Pure inference: no_grad avoids building an autograd graph and keeps
# .numpy() valid even if the transform has trainable parameters.
with torch.no_grad():
    for j, pre in enumerate(prefix):
        for i in tqdm(range(1024)):
            x_np = np.load(pre + str(i) + '.npy')
            # Build inputs on the model's device so enable_cuda=True works.
            x = torch.tensor(x_np, device=device)
            z, _ = model.flows[-1].inverse(x)
            # .cpu() is a no-op on CPU and required before .numpy() on GPU.
            z_np[j + 1, (i * 1024):((i + 1) * 1024), :] = z.cpu().numpy()
    # Keep the first 1e6 samples per model. The test set must contain exactly
    # this many samples for the row-0 assignment to broadcast — TODO confirm.
    z_np = z_np[:, :1000000, :]
    z, _ = model.flows[-1].inverse(torch.as_tensor(test_data, device=device))
    z_np[0, :, :] = z.cpu().numpy()

In [None]:
# Estimate every 1-D marginal of z_np with a Gaussian KDE evaluated on a
# fixed grid over [-pi, pi]. NaN entries of a dimension are dropped first.
# Result: kde_marg[model, grid_point, dimension].
int_range = [-np.pi, np.pi]
npoints = 150
x = np.linspace(int_range[0], int_range[1], npoints)
kde_marg = np.zeros((len(z_np), npoints, 60))
for model_ind, samples in enumerate(z_np):
    for dim in tqdm(range(60)):
        column = samples[:, dim]
        finite = column[~np.isnan(column)]
        kde = stats.gaussian_kde(finite)
        kde_marg[model_ind, :, dim] = kde.pdf(x)

In [None]:
# Plot selected marginals: one row per coordinate group (ylabel), one latent
# dimension per column (ind_marg), one curve per row of kde_marg.
ind_marg = np.array([[22, 43, 58], [9, 33, 45], [32, 53, 11], [1, 2, 7]])
ylabel = ['Bond angles', 'Bond lengths', 'Dihedral angles', 'Cartesian coordinates']
# One label per row of kde_marg, in the same order. Raw string because
# '\o' is not a valid escape sequence in a regular string literal.
legend_labels = ['Ground truth', 'maxELT', 'maxELT & SKSD', 'Grid search',
                 r'$\overline{p}_a=0.65$']
nrows, ncols = ind_marg.shape
# squeeze=False keeps ax 2-D so ax[i, j] works for any grid shape;
# figsize reproduces (15, 20) for the 4x3 grid.
f, ax = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows),
                     sharex=True, squeeze=False)
lines = [None] * len(kde_marg)
for i in range(nrows):
    for j in range(ncols):
        for k in range(len(kde_marg)):
            # The last Line2D per model serves as the shared legend handle.
            lines[k], = ax[i, j].plot(x, kde_marg[k, :, ind_marg[i, j]])
        ax[i, j].set_yticks([])
        ax[i, j].tick_params(axis='x', which='both', labelsize=18)
        if j == 0:
            ax[i, j].set_ylabel(ylabel[i], fontsize=22)
f.legend(lines, legend_labels, bbox_to_anchor=(0.905, 0.885), fontsize=16)
plt.show()

In [None]:
# Print which entries of the 60-dim latent vector correspond to bond
# lengths, bond angles and dihedrals, by pushing the transform's index
# sets through the permutation used by the mixed transform.
mixed_tf = model.flows[-1].mixed_transform
ic_tf = mixed_tf.ic_transform
ncarts = mixed_tf.len_cart_inds
permute_inv = mixed_tf.permute_inv
bond_ind = ic_tf.bond_indices
angle_ind = ic_tf.angle_indices
dih_ind = ic_tf.dih_indices

# NOTE(review): 3 * ncarts - 6 presumably counts Cartesian DOF after
# removing 6 rigid-body DOF, and 60..65 label those removed slots in a
# 66-long ordering — confirm against the transform's implementation.
n_cart_dims = 3 * ncarts - 6
ind_perm = np.concatenate([
    np.arange(n_cart_dims),
    np.arange(60, 66),
    np.arange(n_cart_dims, 60),
])
ind = ind_perm[permute_inv]

for group in (bond_ind, angle_ind, dih_ind):
    print(ind[group])
