# Analyse trained models

In [None]:
# Import packages
import torch
import numpy as np

import boltzgen as bg

from matplotlib import pyplot as plt

from tqdm import tqdm

In [None]:
# Specify checkpoint root
# Root directory of the training run to analyse. Later cells build file paths
# by plain string concatenation (checkpoint_root + 'config/...', + 'log/...',
# + 'checkpoints/...'), so the trailing slash is required.
checkpoint_root = '/draco/u/vstimper/Material_Informatics/boltzmann_generators/models/resampled_09/'

In [None]:
# Read the run configuration stored under the checkpoint root
config_file = checkpoint_root + 'config/bm.yaml'
config = bg.utils.get_config(config_file)

In [None]:
# Load the training trajectory that the models are compared against
trajectory_file = '/draco/u/vstimper/Material_Informatics/boltzmann_generators/data/trajectory/aldp_without_const_100000.h5'
training_data = bg.utils.load_traj(trajectory_file)

In [None]:
# Choose the compute device: the GPU is used only when it is both available
# and explicitly enabled through the flag below
enable_cuda = False
use_gpu = torch.cuda.is_available() and enable_cuda
device = torch.device('cuda' if use_gpu else 'cpu')

# Build the model from the config, move it to the device and switch it to
# double precision
model = bg.BoltzmannGenerator(config).to(device).double()

In [None]:
# Training loss curve, read from the run's log directory
loss = np.loadtxt(checkpoint_root + 'log/loss.csv')
fig, ax = plt.subplots(figsize=(15, 10))
ax.plot(loss, '.')
# Fixed y-window; values outside (-210, -150) are clipped from view
ax.set_ylim(-210, -150)
plt.show()

In [None]:
# Restore the model weights from the selected checkpoint of this run
checkpoint_file = checkpoint_root + 'checkpoints/model_2420000.pt'
model.load(checkpoint_file)

In [None]:
# Model to load for comparison
checkpoint_root_ = '/draco/u/vstimper/Material_Informatics/boltzmann_generators/models/rnvp_01/'
config_ = bg.utils.get_config(checkpoint_root_ + 'config/bm.yaml')
model_ = bg.BoltzmannGenerator(config_)

# Place the comparison model on the same `device` that was selected when the
# first model was set up. Re-deriving the device here (with a second
# `enable_cuda` flag) would overwrite `device` and could leave the two models
# and the evaluation data on different devices if the flags ever diverged.
model_ = model_.to(device)
model_ = model_.double()

model_.load(checkpoint_root_ + 'checkpoints/model_30000.pt')

In [None]:
# Draw samples from both models and evaluate the training data under them.
N_BATCHES = 100     # number of sampling rounds per model
BATCH_SIZE = 1000   # samples drawn per round

# Both generators are only used for inference here, so both are put in
# evaluation mode (previously only `model` was).
model.eval()
model_.eval()


def to_numpy(tensor):
    """Detach `tensor` from the autograd graph and return it as a NumPy array."""
    return tensor.detach().cpu().numpy()


def draw_batch(generator, num_samples):
    """Sample one batch from `generator` and evaluate it.

    Args:
        generator: model exposing `sample`, `p.log_prob` and `flows`.
        num_samples: number of samples to draw.

    Returns:
        Tuple of NumPy arrays `(x, z, log_p, log_q)` where `x` are the samples
        returned by `generator.sample`, `log_q` is the second value returned
        by `sample` (presumably the model log-density of the samples),
        `log_p` is `generator.p.log_prob(x)` and `z` is `x` mapped through the
        inverse of the last flow layer.
    """
    x, log_q = generator.sample(num_samples)
    log_p = generator.p.log_prob(x)
    z, _ = generator.flows[-1].inverse(x)
    return to_numpy(x), to_numpy(z), to_numpy(log_p), to_numpy(log_q)


# Collect per-batch results in lists and concatenate once at the end; growing
# the arrays with np.concatenate inside the loop copies all data every round.
# The two models are sampled alternately, as before.
batches = []
batches_ = []
for batch_idx in tqdm(range(N_BATCHES)):
    batches.append(draw_batch(model, BATCH_SIZE))
    batches_.append(draw_batch(model_, BATCH_SIZE))

x_np, z_np, log_p_np, log_q_np = (
    np.concatenate(parts) for parts in zip(*batches)
)
x__np, z__np, log_p__np, log_q__np = (
    np.concatenate(parts) for parts in zip(*batches_)
)


# Evaluate the training data: target log-density and the log-density assigned
# by each of the two models. `[::1]` keeps every frame; raise the stride to
# subsample.
x_d = training_data[::1].double().to(device)
log_p_d = model.p.log_prob(x_d)
log_q_d = model.log_prob(x_d)
log_q__d = model_.log_prob(x_d)
# Same transformation as applied to the generated samples above
z_d, _ = model.flows[-1].inverse(x_d)
z_d_np = to_numpy(z_d)

log_p_d_np = to_numpy(log_p_d)
log_q_d_np = to_numpy(log_q_d)
log_q__d_np = to_numpy(log_q__d)

In [None]:
# Target log-density log p(x) of the training data and of the samples drawn
# from both models. All three histograms use the same range so that, with an
# equal number of bins, their bin edges coincide and the bar heights are
# directly comparable.
log_p_range = [-50, 5]
plt.hist(log_p_d_np, bins=100, alpha=0.5, label='data', range=log_p_range)
plt.hist(log_p_np, bins=100, alpha=0.5, label='snf', range=log_p_range)
plt.hist(log_p__np, bins=100, alpha=0.5, label='rnvp', range=log_p_range)
plt.legend(loc='upper right')
plt.show()

In [None]:
# Log-density assigned by the first model: training data vs. its own samples
hist_kwargs = dict(bins=100, alpha=0.5, range=[125, 225])
for values, name in ((log_q_d_np, 'data'), (log_q_np, 'snf')):
    plt.hist(values, label=name, **hist_kwargs)
plt.legend(loc='upper right')
plt.show()

In [None]:
# Log-density assigned by the comparison model: training data vs. its own samples
hist_kwargs_ = dict(bins=100, alpha=0.5, range=[145, 205])
for values, name in ((log_q__d_np, 'data'), (log_q__np, 'rnvp')):
    plt.hist(values, label=name, **hist_kwargs_)
plt.legend(loc='upper right')
plt.show()

In [None]:
# Difference log p - log q for the comparison model, on the training data and
# on the model's own samples
log_ratio_range = [-215, -175]
log_ratio_data = log_p_d_np - log_q__d_np
log_ratio_rnvp = log_p__np - log_q__np
plt.hist(log_ratio_data, bins=100, alpha=0.5, label='data', range=log_ratio_range)
plt.hist(log_ratio_rnvp, bins=100, alpha=0.5, label='rnvp', range=log_ratio_range)
plt.legend(loc='upper right')
plt.show()

In [None]:
# Weights exp(log p - log q) for the comparison model, rescaled so that the
# largest weight equals 1 (subtracting the maximum before exponentiating also
# avoids overflow)
log_w_d = log_p_d_np - log_q__d_np   # evaluated on the training data
log_w = log_p__np - log_q__np        # evaluated on the model's own samples
w_d = np.exp(log_w_d - log_w_d.max())
w = np.exp(log_w - log_w.max())      # computed but not plotted in this cell

# Only the training-data weights are shown
plt.hist(w_d, bins=100, alpha=0.5, label='data', range=[0, 1])
plt.show()

In [None]:
# One figure per dimension (60 in total) comparing the marginals of the
# representation obtained by inverting the last flow layer, for the training
# data and for samples from both models
marginal_style = dict(bins=200, alpha=1, histtype='step', linewidth=2, range=[-5, 5])
latent_sets = (('data', z_d_np), ('snf', z_np), ('rnvp', z__np))

for i in range(60):
    print(i)
    for name, latent in latent_sets:
        plt.hist(latent[:, i], label=name, **marginal_style)
    plt.legend(loc='upper right')
    # To keep the figures, call plt.savefig(...) here, before plt.show()
    plt.show()

In [None]:
import mdtraj

# Atom-index quadruples, one row per dihedral angle passed to
# mdtraj.compute_dihedrals (17 torsions in total)
Z_indices = np.array([[4, 6, 8, 14],
                      [11, 10, 8, 6],
                      [16, 14, 8, 6],
                      [1, 4, 6, 8],
                      [5, 4, 6, 8],
                      [7, 6, 8, 4],
                      [12, 10, 8, 4],
                      [13, 10, 8, 11],
                      [15, 14, 8, 16],
                      [18, 16, 14, 8],
                      [0, 1, 4, 6],
                      [17, 16, 14, 15],
                      [19, 18, 16, 14],
                      [2, 1, 4, 0],
                      [3, 1, 4, 0],
                      [20, 18, 16, 19],
                      [21, 18, 16, 19]])

# Torsions of the training trajectory
training_data_traj = mdtraj.load('/draco/u/vstimper/Material_Informatics/boltzmann_generators/data/trajectory/aldp_without_const_100000.h5')
torsions_train = mdtraj.compute_dihedrals(training_data_traj, Z_indices)

# Topology used to wrap the generated coordinates in trajectories. It is read
# from disk once and shared by both generated trajectories instead of parsing
# the same PDB file twice.
ala2_pdb = mdtraj.load('../../snf_noe/data/alanine-dipeptide.pdb').topology

# Torsions of the samples from the first model; the flat 66-dimensional
# samples are reshaped to (frames, 22 atoms, 3 coordinates)
gen_data_traj = mdtraj.Trajectory(x_np.reshape(-1, 22, 3), ala2_pdb)
torsions_gen = mdtraj.compute_dihedrals(gen_data_traj, Z_indices)

# Torsions of the samples from the comparison model
gen_data_traj_ = mdtraj.Trajectory(x__np.reshape(-1, 22, 3), ala2_pdb)
torsions_gen_ = mdtraj.compute_dihedrals(gen_data_traj_, Z_indices)

In [None]:
def periodic_convolution(x, kernel):
    """Convolve a 1-D periodic signal with `kernel`, wrapping at the ends.

    The signal is tiled three times, convolved with ``mode='same'``, and the
    middle copy is returned, so values near either end are smoothed with their
    neighbours from the opposite end.

    Args:
        x: 1-D array-like holding one period of the signal (lists are
            accepted and converted to an array).
        kernel: 1-D array-like of filter weights. It is applied as given, with
            no normalisation. It should not be longer than `x`; otherwise the
            single period of padding on each side is not enough for a true
            circular convolution.

    Returns:
        NumPy array with the same length as `x`.
    """
    # Accept any array-like; `.size` below requires an ndarray
    x = np.asarray(x)
    x_padded = np.concatenate([x, x, x])
    y_padded = np.convolve(x_padded, kernel, mode='same')
    # Keep only the central period, dropping one period of padding per side
    return y_padded[x.size:-x.size]

In [None]:
# Smoothed torsion-angle histograms for the training data and both models.
# The kernel weights sum to 2.5 and are applied as given, so the smoothed
# curves are scaled densities rather than unit-normalised ones.
smoothing_kernel = [0.25, 0.5, 1.0, 0.5, 0.25]


def smoothed_torsion_hist(angles):
    """Return (smoothed density histogram, bin edges) for angles in [-pi, pi].

    Uses 50 bins over (-pi, pi) and smooths periodically with
    `smoothing_kernel`.
    """
    density, edges = np.histogram(angles, 50, range=(-np.pi, np.pi), density=True)
    return periodic_convolution(density, smoothing_kernel), edges


torsion_hists_train = []
torsion_hists_gen = []
torsion_hists_gen_ = []
xticks = None

for col in range(torsions_train.shape[1]):
    smoothed_train, bin_edges = smoothed_torsion_hist(torsions_train[:, col])
    smoothed_gen, _edges = smoothed_torsion_hist(torsions_gen[:, col])
    smoothed_gen_, _edges = smoothed_torsion_hist(torsions_gen_[:, col])

    # Bin centres, used as x positions when plotting
    xticks = 0.5 * (bin_edges[1:] + bin_edges[:-1])

    torsion_hists_train.append(smoothed_train)
    torsion_hists_gen.append(smoothed_gen)
    torsion_hists_gen_.append(smoothed_gen_)

In [None]:
# Selections of torsions, given as indices into the torsion_hists_* lists
# (equivalently, row numbers of Z_indices).
# torsions_simple is only referenced by the commented-out two-row plot layout.
torsions_simple = [4, 5, 6, 7, 8]
# torsions_complex is plotted and labelled phi, gamma_1, psi, gamma_2, gamma_3
# (in this order) and is also used for the KL values.
torsions_complex = [0, 1, 2, 10, 12]

In [None]:
# Figure: smoothed torsion distributions of the five torsions listed in
# torsions_complex, one panel each, comparing the training data (grey), the
# first model (red) and the comparison model (blue).
fig, axes = plt.subplots(nrows=1, ncols=5, sharey=True, sharex=True, figsize=(15, 3))
# Keep a 2-D (row, column) array so panels can be addressed as axes[0, k]
axes = axes.reshape((1, 5))
fig.subplots_adjust(hspace=0.05, wspace=0.15)

for ax, torsion in zip(axes[0], torsions_complex):
    ax.plot(xticks, torsion_hists_train[torsion], color='grey', linewidth=5)
    ax.plot(xticks, torsion_hists_gen[torsion], color='red', linewidth=3)
    ax.plot(xticks, torsion_hists_gen_[torsion], color='blue', linewidth=3)
    ax.set_yticks([])
    ax.set_xlim(-np.pi, np.pi)

# The y axis is shared, so this sets the limits of every panel
axes[0, 0].set_ylim(0, 1.5)

# Panel titles. Raw strings keep the LaTeX backslashes from being read as
# (invalid) Python escape sequences.
panel_labels = (r'$\phi$', r'$\gamma_1$', r'$\psi$', r'$\gamma_2$', r'$\gamma_3$')
for ax, panel_label in zip(axes[0], panel_labels):
    ax.text(0, 1.35, panel_label)

axes[0, 0].set_ylabel('density')

# Hand-placed colour legend in the last panel
axes[0, -1].text(-np.pi+1, 1.1, 'Target', color='grey')
axes[0, -1].text(-np.pi+1, 0.95, 'RNVP', color='blue')
axes[0, -1].text(-np.pi+1, 0.8, 'RNVP + MCMC', color='red')
plt.savefig('torsion_angles.png', dpi=300)

In [None]:
# KL divergence between the binned torsion distribution of the training data
# and that of the first model, for each torsion in torsions_complex.
# The stored histograms are smoothed densities (not probabilities per bin),
# so each is first normalised to sum to one; summing the raw values would
# scale the result by a constant factor and not give a KL divergence.
eps = 1e-5  # guards against log(0) and division by zero in empty bins
for i in torsions_complex:
    p_train = torsion_hists_train[i] / np.sum(torsion_hists_train[i])
    p_gen = torsion_hists_gen[i] / np.sum(torsion_hists_gen[i])
    KL_NF = np.sum(p_train * np.log((p_train + eps) / (p_gen + eps)))
    print("{:1.2f}".format(KL_NF))