In [None]:
import torch
import numpy as np
import mdtraj

import sys
sys.path.insert(1, '/var/home/vs488/Documents/boltzmann/code/boltzmann-generators/')
import boltzgen.zmatrix as zmatrix
import boltzgen.internal as ics
import boltzgen.mixed as mixed

import normflow as nf
from boltzgen.flows import CoordinateTransform
from boltzgen.distributions import Boltzmann, BoltzmannParallel, TransformedBoltzmann, TransformedBoltzmannParallel

from tqdm import tqdm
from matplotlib import pyplot as plt

from autograd import grad
from autograd import numpy as np
from simtk import openmm as mm
from simtk import unit
from simtk.openmm import app
from openmmtools.testsystems import AlanineDipeptideVacuum

# Load the alanine dipeptide MD trajectory that serves as the flow's training data.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
traj_file = '/scratch2/vs488/flow/alanine_dipeptide/trajectory/aldp100000.h5'
aldp_traj = mdtraj.load(traj_file)

In [None]:
# Set up coordinate transformation

# Z-matrix for the 14 atoms that are not in `backbone_indices`: each entry is
# (atom, [three reference atoms]) — presumably the bond/angle/dihedral references
# as expected by boltzgen's CoordinateTransform; TODO confirm the convention.
z_matrix = [
    (1, [4, 5, 6]),
    (0, [1, 4, 5]),
    (2, [1, 0, 4]),
    (3, [1, 0, 2]),
    (7, [6, 4, 5]),
    (9, [8, 6, 7]),
    (10, [8, 6, 9]),
    (11, [10, 8, 9]),
    (12, [10, 8, 11]),
    (13, [10, 11, 12]),
    (17, [16, 14, 15]),
    (19, [18, 16, 17]),
    (20, [18, 19, 16]),
    (21, [18, 19, 20])
]

# The remaining 8 atoms, handed to CoordinateTransform separately from the Z-matrix.
backbone_indices = [4, 5, 6, 8, 14, 15, 16, 18]

# Remove the global translation, then align every frame onto frame 0 using the
# backbone atoms selected by mdtraj (both as moving and reference atoms).
aldp_traj.center_coordinates()
ind = aldp_traj.top.select("backbone")
aldp_traj.superpose(aldp_traj, 0, atom_indices=ind, ref_atom_indices=ind)

# Flatten positions from (n_frames, n_atoms, 3) to (n_frames, 3 * n_atoms) and
# wrap them in a float64 torch tensor for training.
frames_xyz = aldp_traj.xyz
n_atoms = frames_xyz.shape[1]
n_dim = 3 * n_atoms
training_data_npy = frames_xyz.reshape(-1, n_dim)
training_data = torch.from_numpy(training_data_npy.astype("float64"))

In [None]:
# Set up simulation object for energy computation

temperature = 1000  # Kelvin (multiplied by unit.kelvin below)

testsystem = AlanineDipeptideVacuum()

# The Simulation needs an integrator, although this notebook only uses the
# context for energy evaluations. The Reference platform is always
# double precision, so no 'Precision' property is passed.
integrator = mm.LangevinIntegrator(temperature * unit.kelvin,
                                   1.0 / unit.picosecond,
                                   1.0 * unit.femtosecond)
platform = mm.Platform.getPlatformByName('Reference')
implicit_sim = app.Simulation(testsystem.topology, testsystem.system,
                              integrator, platform)
implicit_sim.context.setPositions(testsystem.positions)

In [None]:
# Set up model

# Define flows
K = 5  # number of (RealNVP, RealNVP, ActNorm, MCMC) stages
torch.manual_seed(0)

# Set prior and q0.
# `p` is the target handed to the NormalizingFlow; `p_` is the transformed-space
# target interpolated against q0 inside the MCMC layers.
p = Boltzmann(implicit_sim.context, temperature, energy_cut=1e2, energy_max=1e20)
# Use the data dimensionality instead of a hard-coded 66 (= 22 atoms * 3).
transform = CoordinateTransform(training_data, n_dim, z_matrix, backbone_indices)
p_ = TransformedBoltzmannParallel(testsystem, temperature, energy_cut=1e2,
                                  energy_max=1e20, transform=transform)
# NOTE(review): p__ is not used anywhere in the visible notebook; consider removing it.
p__ = BoltzmannParallel(testsystem, temperature, energy_cut=1e2, energy_max=1e20)

# Latent size n_dim - 6 (= 60 for alanine dipeptide). Assumes the coordinate
# transform removes the 6 rigid-body degrees of freedom — TODO confirm in boltzgen.
latent_size = n_dim - 6
hidden_units = 128
q0 = nf.distributions.DiagGaussian(latent_size, trainable=False)


def make_mlp(size, hidden):
    """Return a 3-hidden-layer normflow MLP mapping `size` features to `size` features."""
    return nf.nets.MLP([size, hidden, hidden, hidden, size])


# Alternating 1/0 mask; the second coupling layer of each stage uses 1 - b.
b = torch.Tensor([1 if i % 2 == 0 else 0 for i in range(latent_size)])
flows = []
for i in range(K):
    # Two Real NVP coupling layers with complementary masks. Scale net is built
    # before shift net, as before, so the seeded initialisation is unchanged.
    s_net = make_mlp(latent_size, hidden_units)
    t_net = make_mlp(latent_size, hidden_units)
    flows += [nf.flows.MaskedAffineFlow(b, s_net, t_net)]
    s_net = make_mlp(latent_size, hidden_units)
    t_net = make_mlp(latent_size, hidden_units)
    flows += [nf.flows.MaskedAffineFlow(1 - b, s_net, t_net)]
    # ActNorm
    flows += [nf.flows.ActNorm(latent_size)]
    # MCMC layer targeting an interpolation between p_ and q0; the weight grows
    # linearly with the stage index and reaches 1 at the last stage.
    dist = nf.distributions.LinearInterpolation(p_, q0, (i + 1) / K)
    proposal = nf.distributions.DiagGaussianProposal((latent_size,),
                                                     0.1 * np.ones(latent_size))
    flows += [nf.flows.MetropolisHastings(dist, proposal, 20)]
# Final layer maps between the latent/internal representation and Cartesian coordinates.
flows += [transform]

# Construct flow model
nfm = nf.NormalizingFlow(q0=q0, flows=flows, p=p)

# Move model on GPU if available
enable_cuda = False
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')
nfm = nfm.to(device)
nfm = nfm.double()

# Initialize ActNorm: one forward-KL pass on a random batch of training data
# initialises the ActNorm layers from data. It is timed because it includes the
# coordinate transform on every sample.
init_ind = torch.randint(len(training_data), (256, ))
x = training_data[init_ind, :].double().to(device)
from time import time
t_start = time()  # separate name so the MLP variable from the loop is not clobbered
kld = nfm.forward_kld(x)
print(time() - t_start)

In [None]:
# Forward-KL loss of the ActNorm initialisation batch computed in the model-setup cell
print(str(kld))

In [None]:
# Train model
batch_size = 256
num_samples = 32   # flow samples per step for the reverse KL (only used if use_rkld)
max_iter = 10
trans_iter = 8000  # length of the (currently disabled) annealing schedule below
n_data = len(training_data)
eval_rkld = 10     # NOTE(review): not referenced in the visible code
use_rkld = False   # add the reverse KL (requires energy evaluations) to the loss

loss_hist = np.array([])
fkld_hist = np.array([])
rkld_hist = np.array([])

# The sampling cell calls nfm.eval(); make sure re-running this cell trains in train mode.
nfm.train()
optimizer = torch.optim.AdamW(nfm.parameters(), lr=1e-4, weight_decay=1e-5)
for it in tqdm(range(max_iter)):
    #nfm.p.alpha = np.max([0., 1 - it / trans_iter])
    optimizer.zero_grad()
    ind = torch.randint(n_data, (batch_size, ))
    x = training_data[ind, :].double().to(device)
    fkld = nfm.forward_kld(x)
    loss = fkld
    if use_rkld:
        rkld = nfm.reverse_kld(num_samples=num_samples)
        loss = loss + rkld
    # Only step on finite, negative losses. Positive losses are treated as outlier
    # batches and skipped (the plotting cell masks them as well). torch.isfinite
    # rejects both NaN and +/-inf, so a -inf loss can no longer reach backward().
    if torch.isfinite(loss) and loss < 0:
        loss.backward()
        optimizer.step()

    loss_hist = np.append(loss_hist, loss.item())
    fkld_hist = np.append(fkld_hist, fkld.item())
    if use_rkld:
        rkld_hist = np.append(rkld_hist, rkld.item())

In [None]:
# Plot the training loss. Positive losses (batches the training loop did not step
# on) are masked to NaN in a copy, so loss_hist itself keeps the raw history and
# this cell can be re-run safely.
loss_plot = np.where(loss_hist > 0, np.nan, loss_hist)
fig, ax = plt.subplots()
ax.plot(loss_plot)
ax.set(xlabel='Iteration', ylabel='Loss', title='Training loss (positive values masked)')
plt.show()

# rkld_hist stays empty unless the training loop records reverse-KL values.
if len(rkld_hist) > 0:
    fig, ax = plt.subplots()
    ax.plot(rkld_hist)
    ax.set(xlabel='Iteration', ylabel='Reverse KL', title='Reverse KL')
    plt.show()

In [None]:
# Draw samples from the trained flow. Map both the samples and every 50th training
# frame through the inverse of the final CoordinateTransform flow, so the two can
# be compared marginal-by-marginal in the next cell.
n_plot_samples = 10000
nfm.eval()
# Inference only: without no_grad, an autograd graph would be kept for all samples
# through every flow and MCMC step.
with torch.no_grad():
    z, _ = nfm.sample(n_plot_samples)
    z, _ = nfm.flows[-1].inverse(z)
    z_d, _ = nfm.flows[-1].inverse(training_data[::50].double().to(device))
z_np = z.detach().cpu().numpy()
z_d_np = z_d.detach().cpu().numpy()

In [None]:
# Overlay per-coordinate histograms of the transformed training data and of the
# flow samples. The number of coordinates is taken from the arrays rather than
# hard-coded.
n_coords = z_np.shape[1]
for i in range(n_coords):
    fig, ax = plt.subplots()
    ax.hist(z_d_np[:, i], bins=100, alpha=0.5, label='data', range=[-3.1, 3.1])
    ax.hist(z_np[:, i], bins=100, alpha=0.5, label='samples', range=[-3.1, 3.1])
    ax.set(title=f'Coordinate {i}', xlabel='Value', ylabel='Count')
    ax.legend(loc='upper right')
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per coordinate

In [None]:
# NOTE(review): this cell repeats the previous histogram cell verbatim.
# For each of the 60 coordinates, overlay histograms of data and flow samples.
for coord in range(60):
    print(coord)
    for values, name in ((z_d_np, 'data'), (z_np, 'samples')):
        plt.hist(values[:, coord], bins=100, alpha=0.5, label=name, range=[-3.1, 3.1])
    plt.legend(loc='upper right')
    plt.show()