In [None]:
import torch
import numpy as np
import mdtraj

import sys
sys.path.insert(1, '/var/home/vs488/Documents/boltzmann/code/boltzmann-generators/')
import boltzgen.zmatrix as zmatrix
import boltzgen.internal as ics
import boltzgen.mixed as mixed

import normflow as nf
from boltzgen.flows import CoordinateTransform
from boltzgen.distributions import Boltzmann

from tqdm import tqdm
from matplotlib import pyplot as plt

from autograd import grad
from autograd import numpy as np
from openmmtools.constants import kB
from simtk import openmm as mm
from simtk import unit
from simtk.openmm import app
from openmmtools.testsystems import AlanineDipeptideImplicit

# Load the alanine dipeptide trajectory used as training data below.
# The location can be overridden with the ALDP_TRAJ_PATH environment variable so
# the notebook is not tied to one machine; the default is the original path.
import os

ALDP_TRAJ_PATH = os.environ.get(
    'ALDP_TRAJ_PATH',
    '/scratch2/vs488/flow/alanine_dipeptide/trajectory/aldp100000.h5',
)
aldp_traj = mdtraj.load(ALDP_TRAJ_PATH)

In [None]:
# Set up the coordinate transformation and the training data.

# Z-matrix handed to CoordinateTransform: entries (atom_index, [ref_1, ref_2, ref_3]).
# Presumably the references define the bond / angle / dihedral used to place the
# atom — confirm against boltzgen.
z_matrix = [
    (1, [4, 5, 6]),
    (0, [1, 4, 5]),
    (2, [1, 0, 4]),
    (3, [1, 0, 2]),
    (7, [6, 4, 5]),
    (9, [8, 6, 7]),
    (10, [8, 6, 9]),
    (11, [10, 8, 9]),
    (12, [10, 8, 11]),
    (13, [10, 11, 12]),
    (17, [16, 14, 15]),
    (19, [18, 16, 17]),
    (20, [18, 19, 16]),
    (21, [18, 19, 20])
]

# Atom indices handed to CoordinateTransform as the backbone set; together with
# the z-matrix atoms above they cover every atom exactly once (checked below).
backbone_indices = [4, 5, 6, 8, 14, 15, 16, 18]

# Center every frame at the origin, then align all frames to frame 0 on the
# mdtraj "backbone" selection. Both mdtraj calls modify aldp_traj in place.
aldp_traj.center_coordinates()
backbone_atom_ind = aldp_traj.top.select("backbone")
aldp_traj.superpose(aldp_traj, 0, atom_indices=backbone_atom_ind,
                    ref_atom_indices=backbone_atom_ind)

# Flatten positions (n_frames, n_atoms, 3) -> (n_frames, 3 * n_atoms).
training_data_xyz = aldp_traj.xyz
n_atoms = training_data_xyz.shape[1]
n_dim = n_atoms * 3

# The z-matrix atoms and backbone_indices should partition the topology's atoms;
# fail early if the loaded trajectory does not match these hand-written lists.
placed_atoms = sorted([atom for atom, _ in z_matrix] + backbone_indices)
assert placed_atoms == list(range(n_atoms)), (
    f"z_matrix + backbone_indices must cover atoms 0..{n_atoms - 1} exactly once"
)

training_data_npy = training_data_xyz.reshape(-1, n_dim)
training_data = torch.from_numpy(training_data_npy.astype("float64"))

In [None]:
# Set up the OpenMM simulation; Boltzmann uses its context below to evaluate energies.

temperature = 1000  # kelvin; passed as a plain number to Boltzmann below
temperature_quantity = temperature * unit.kelvin
# kB is a unit-carrying OpenMM Quantity (energy per temperature), so the
# temperature must carry kelvin as well: kB * 1000 (bare int) keeps kB's units
# instead of producing an energy.
kT = kB * temperature_quantity

testsystem = AlanineDipeptideImplicit()
integrator = mm.LangevinIntegrator(temperature_quantity,
                                   1.0 / unit.picosecond,
                                   1.0 * unit.femtosecond)
implicit_sim = app.Simulation(
    testsystem.topology,
    testsystem.system,
    integrator,
    platform=mm.Platform.getPlatformByName('CPU'),
)
implicit_sim.context.setPositions(testsystem.positions)

In [None]:
# Set up the flow model: K masked affine coupling layers (alternating masks),
# each followed by ActNorm, and finally the coordinate transform.

K = 8
torch.manual_seed(0)

# CoordinateTransform maps latent_size dimensions to the n_dim Cartesian ones
# (60 vs 66 here); the 6 missing DOF are presumably global translation and
# rotation — confirm against boltzgen.
latent_size = n_dim - 6
# Binary mask: 1 on even latent dimensions, 0 on odd ones.
b = torch.Tensor([1 if i % 2 == 0 else 0 for i in range(latent_size)])
hidden_units = 4 * latent_size
flows = []
for i in range(K):
    s = nf.nets.MLP([latent_size, hidden_units, hidden_units, latent_size],
                    output_fn='tanh', output_scale=3.)
    t = nf.nets.MLP([latent_size, hidden_units, hidden_units, latent_size],
                    output_fn='tanh', output_scale=3.)
    # Alternate the mask so every dimension gets transformed; pass the nets by
    # keyword so the scale/translation roles cannot be swapped by the library's
    # positional parameter order.
    layer_mask = b if i % 2 == 0 else 1 - b
    flows += [nf.flows.MaskedAffineFlow(layer_mask, t=t, s=s)]
    flows += [nf.flows.ActNorm(latent_size)]
flows += [CoordinateTransform(training_data, n_dim, z_matrix, backbone_indices)]

# Target (Boltzmann distribution via the OpenMM context) and base distribution.
p = Boltzmann(implicit_sim.context, temperature, energy_cut=1e10, energy_max=1e20)
q0 = nf.distributions.DiagGaussian(latent_size)

# Construct flow model
nfm = nf.NormalizingFlow(q0=q0, flows=flows, p=p)

# Move model on GPU if available (and enabled)
enable_cuda = False
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')
nfm = nfm.to(device)
nfm = nfm.double()

# Sanity check: forward KL on one random training batch before training.
# The batch must live on the same device as the model.
batch_ind = torch.randint(len(training_data), (128, ))
x = training_data[batch_ind, :].double().to(device)
with torch.no_grad():
    kld = nfm.forward_kld(x)

In [None]:
# Show the initial forward KL computed in the previous cell.
print(str(kld))

In [None]:
# Train model on the sum of forward KL (data batches) and reverse KL (model
# samples scored by the Boltzmann target).
batch_size = 128
num_samples = 64
max_iter = 500
trans_iter = 8000  # currently unused (was for a disabled alpha schedule)
n_data = len(training_data)
eval_rkld = 10  # currently unused in this cell

# Collect history in lists (O(1) append) and convert once; np.append inside the
# loop copies the whole array on every iteration.
loss_list, fkld_list, rkld_list = [], [], []

optimizer = torch.optim.AdamW(nfm.parameters(), lr=1e-4, weight_decay=1e-3)
try:
    for it in tqdm(range(max_iter)):
        optimizer.zero_grad()
        ind = torch.randint(n_data, (batch_size, ))
        x = training_data[ind, :].double().to(device)
        fkld = nfm.forward_kld(x)
        rkld = nfm.reverse_kld(num_samples=num_samples)
        loss = fkld + rkld
        # Skip the update on NaN *or* Inf loss so non-finite gradients never
        # reach the parameters.
        if torch.isfinite(loss):
            loss.backward()
            torch.nn.utils.clip_grad_value_(nfm.parameters(), .01)
            gradient_norm = torch.nn.utils.clip_grad_norm_(nfm.parameters(), 100.)
            optimizer.step()

        loss_list.append(loss.item())
        fkld_list.append(fkld.item())
        rkld_list.append(rkld.item())
finally:
    # Convert even if training is interrupted so partial histories stay plottable.
    loss_hist = np.array(loss_list)
    fkld_hist = np.array(fkld_list)
    rkld_hist = np.array(rkld_list)

In [None]:
# Train model with a staged schedule: stage i uses its own batch size, number of
# iterations, reverse-KL sample count, learning rate, Boltzmann energy cut and
# reverse-KL weight.
batch_size = [32, 64] + 8 * [128]
num_iter = 3 * [1500] + 6 * [30] + [300]
num_samples = 3 * [64] + 7 * [512]
lr = 3 * [1e-3] + 7 * [1e-4]
# NOTE(review): these cuts are tiny compared with energy_cut=1e10 used when the
# model was built; confirm 1e-10 ... 1e-5 (rather than e.g. 1e10 ... 1e5) is intended.
E_cut = 4 * [1e-10] + [1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-5]
w_kl = 3 * [0] + [1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-2]

n_stages = len(num_iter)
assert all(len(sched) == n_stages
           for sched in (batch_size, num_samples, lr, E_cut, w_kl)), \
    'all schedule lists must have one entry per stage'

n_data = len(training_data)

# Collect history in lists and convert once at the end (np.append is quadratic).
loss_list, fkld_list, rkld_list = [], [], []

optimizer = torch.optim.AdamW(nfm.parameters(), lr=lr[0], weight_decay=1e-4)

try:
    for i in range(n_stages):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr[i]
        nfm.p = Boltzmann(implicit_sim.context, temperature, energy_cut=E_cut[i], energy_max=1e20)
        use_rkld = w_kl[i] > 0
        for it in tqdm(range(num_iter[i])):
            optimizer.zero_grad()
            ind = torch.randint(n_data, (batch_size[i], ))
            x = training_data[ind, :].to(device)
            fkld = nfm.forward_kld(x)
            # The reverse KL is always logged, but it only needs an autograd
            # graph when it actually contributes to the loss.
            with torch.set_grad_enabled(use_rkld):
                rkld = nfm.reverse_kld(num_samples=num_samples[i])
            loss = fkld + w_kl[i] * rkld if use_rkld else fkld
            # Skip the update on NaN *or* Inf loss.
            if torch.isfinite(loss):
                loss.backward()
                torch.nn.utils.clip_grad_value_(nfm.parameters(), .01)
                gradient_norm = torch.nn.utils.clip_grad_norm_(nfm.parameters(), 100.)
                optimizer.step()

            loss_list.append(loss.item())
            fkld_list.append(fkld.item())
            rkld_list.append(rkld.item())
finally:
    # Convert even if training is interrupted so partial histories stay plottable.
    loss_hist = np.array(loss_list)
    fkld_hist = np.array(fkld_list)
    rkld_hist = np.array(rkld_list)

In [None]:
# Plot the training histories recorded by the last training cell, one panel each.
fig, axes = plt.subplots(3, 1, figsize=(8, 9), sharex=True)
panels = [
    (loss_hist, 'Loss'),
    (fkld_hist, 'Forward KL'),
    (rkld_hist, 'Reverse KL'),
]
for ax, (hist, title) in zip(axes, panels):
    ax.plot(hist)
    ax.set(title=title, ylabel=title)
axes[-1].set_xlabel('Iteration')
fig.tight_layout()
plt.show()

In [None]:
# Compare target log-densities of model samples with those of real data.
# Draw an explicit data batch instead of reusing `ind` leaked from the training
# loop, so this cell also works on a fresh kernel. No gradients are needed.
with torch.no_grad():
    x, _ = nfm.sample(32)
    print(nfm.p.log_prob(x))
    eval_ind = torch.randint(len(training_data), (32, ))
    x = training_data[eval_ind, :].double().to(device)
    print(nfm.p.log_prob(x))

In [None]:
# Debugging aid: push 8 base samples through the flows layer by layer and print
# the per-sample norm and log-determinant after each layer. No gradients are
# needed, so skip graph construction; update log_q out of place.
with torch.no_grad():
    z, log_q = nfm.q0(8)
    for k, flow in enumerate(nfm.flows):
        z, log_det = flow(z)
        print(f'flow {k}: {type(flow).__name__}')
        print(torch.norm(z, dim=1))
        print(log_det)
        log_q = log_q - log_det

In [None]:
# Inspect the bond index table of the final flow's internal-coordinate transform.
(nfm.flows[-1]
    .mixed_transform
    .ic_transform
    .bond_indices)

In [None]:
# Map 100000 model samples and all training frames through the inverse of the
# final flow (CoordinateTransform) for the per-dimension histograms below.
# Gradients are not needed; without no_grad the graph for 100000 samples would
# be kept in memory.
with torch.no_grad():
    z, _ = nfm.sample(100000)
    z, _ = nfm.flows[-1].inverse(z)
    z_d, _ = nfm.flows[-1].inverse(training_data.double().to(device))
z_np = z.cpu().numpy()
z_d_np = z_d.cpu().numpy()

In [None]:
# Overlay data vs. model-sample histograms for every dimension produced above.
# Iterate over the actual number of columns instead of a hard-coded 60, and
# close each figure so the loop does not accumulate open figures.
for i in range(z_np.shape[1]):
    fig, ax = plt.subplots()
    ax.hist(z_d_np[:, i], bins=100, alpha=0.5, label='data', range=[-3.1, 3.1])
    ax.hist(z_np[:, i], bins=100, alpha=0.5, label='samples', range=[-3.1, 3.1])
    ax.set(title=f'Dimension {i}', xlabel='value', ylabel='count')
    ax.legend(loc='upper right')
    plt.show()
    plt.close(fig)