In [None]:
import torch
from openmmtools.constants import kB
from simtk import openmm as mm
from simtk import unit
from simtk.openmm import app
from openmmtools.testsystems import AlanineDipeptideVacuum
import numpy as np
import mdtraj

import boltzgen.zmatrix as zmatrix
import boltzgen.internal as ics

In [None]:
# Generate a trajectory for use as training data.
# Runs Langevin dynamics of the AlanineDipeptideVacuum test system on the CPU
# platform for 100000 steps and writes a frame to 'aldp100.h5' every 1000 steps.
temperature = 298  # plain number; the kelvin unit is attached where needed
# kB carries units, so it has to be multiplied by a temperature *with* units
# for kT to be an energy rather than an energy-per-kelvin quantity.
kT = kB * (temperature * unit.kelvin)
testsystem = AlanineDipeptideVacuum()
vacuum_sim = app.Simulation(
    testsystem.topology,
    testsystem.system,
    mm.LangevinIntegrator(
        temperature * unit.kelvin,   # heat-bath temperature
        1.0 / unit.picosecond,       # friction coefficient
        1.0 * unit.femtosecond,      # integration time step
    ),
    platform=mm.Platform.getPlatformByName('CPU'),
)
vacuum_sim.context.setPositions(testsystem.positions)
vacuum_sim.minimizeEnergy()

# Keep a handle on the reporter so the HDF5 file can be closed explicitly;
# relying on garbage collection to flush it is fragile when the next cell
# immediately reads the file back.
traj_reporter = mdtraj.reporters.HDF5Reporter('aldp100.h5', 1000)
vacuum_sim.reporters.append(traj_reporter)
try:
    vacuum_sim.step(100000)
finally:
    # Flush and close the output file even if the run is interrupted.
    traj_reporter.close()
del vacuum_sim, traj_reporter

In [None]:
# Read the trajectory file 'aldp100.h5' (written by the simulation cell) with mdtraj.
aldp_traj = mdtraj.load('aldp100.h5')

In [None]:
# Z-matrix specification handed to the internal-coordinate transform.
# Every entry pairs one atom index with a list of three other atom indices
# (presumably its bond / angle / torsion reference atoms -- confirm against
# boltzgen.internal).
z = [
    (1, [4, 5, 6]),
    (0, [1, 4, 5]),
    (2, [1, 0, 4]),
    (3, [1, 0, 2]),
    (7, [6, 4, 5]),
    (9, [8, 6, 7]),
    (10, [8, 6, 9]),
    (11, [10, 8, 9]),
    (12, [10, 8, 11]),
    (13, [10, 11, 12]),
    (17, [16, 14, 15]),
    (19, [18, 16, 17]),
    (20, [18, 19, 16]),
    (21, [18, 19, 20]),
]

# Atom indices passed to the transform as its backbone argument.
backbone_indices = [4, 5, 6, 8, 14, 15, 16, 18]

# Center the coordinates of the trajectory.
aldp_traj.center_coordinates()

# Align the trajectory onto its own frame 0, fitting only the atoms that
# mdtraj's "backbone" selection returns.
ind = aldp_traj.topology.select("backbone")
aldp_traj.superpose(
    reference=aldp_traj,
    frame=0,
    atom_indices=ind,
    ref_atom_indices=ind,
)

# Flatten each frame to a vector of length 3 * n_atoms and wrap the result
# in a float32 torch tensor (astype makes a copy of the trajectory data).
n_atoms = aldp_traj.xyz.shape[1]
n_dim = 3 * n_atoms
training_data_npy = aldp_traj.xyz.reshape(-1, n_dim)
training_data = torch.from_numpy(training_data_npy.astype(np.float32))

# Build the transform from the flattened dimension, the z-matrix, the
# backbone indices and the training data.
ic_transform = ics.InternalCoordinateTransform(n_dim, z, backbone_indices, training_data)

# Round trip: forward pass, then feed its output to the inverse pass.
# Each call returns a pair; the second element is kept as jac / back_jac.
transformed_c, jac = ic_transform.forward(training_data)
back_trans, back_jac = ic_transform.inverse(transformed_c)

In [None]:
# Show the second return value of the forward pass followed by that of the
# inverse pass, each on its own line(s).
print(jac, back_jac, sep="\n")