In [1]:
import sys

import numpy as np
import torch
import mdtraj

# Make the local `boltzgen` package importable. The path is relative to the
# notebook's current working directory. Only append it once so re-running
# this cell does not keep growing sys.path with duplicates.
BOLTZGEN_ROOT = "../"
if BOLTZGEN_ROOT not in sys.path:
    sys.path.append(BOLTZGEN_ROOT)

import boltzgen.zmatrix as zmatrix
import boltzgen.internal as ics
import boltzgen.mixed as mixed

# Trajectory used as training data in the next cell (presumably alanine
# dipeptide, "aldp" — confirm against the data source). Path is relative to
# the working directory.
ALDP_TRAJ_PATH = "aldp100.h5"
aldp_traj = mdtraj.load(ALDP_TRAJ_PATH)

In [2]:
# Z-matrix for the internal-coordinate part of the transform. Each entry is
# (atom_index, [ref_a, ref_b, ref_c]); presumably the boltzgen convention of
# bond/angle/dihedral reference atoms — TODO confirm against boltzgen.zmatrix.
z = [
    (1, [4, 5, 6]),
    (0, [1, 4, 5]),
    (2, [1, 0, 4]),
    (3, [1, 0, 2]),
    (7, [6, 4, 5]),
    (9, [8, 6, 7]),
    (10, [8, 6, 9]),
    (11, [10, 8, 9]),
    (12, [10, 8, 11]),
    (13, [10, 11, 12]),
    (17, [16, 14, 15]),
    (19, [18, 16, 17]),
    (20, [18, 19, 16]),
    (21, [18, 19, 20])
]

# Atom indices passed to the transforms as the "backbone" set (presumably the
# atoms kept in Cartesian form by the mixed transform — TODO confirm).
# NOTE(review): this hand-written list is independent of the mdtraj
# "backbone" selection used for superposition below; verify they agree.
backbone_indices = [4, 5, 6, 8, 14, 15, 16, 18]

# Center every frame, then superpose all frames onto frame 0 using the mdtraj
# "backbone" atom selection. Both calls modify aldp_traj in place.
aldp_traj.center_coordinates()
ind = aldp_traj.top.select("backbone")
aldp_traj.superpose(aldp_traj, 0, atom_indices=ind, ref_atom_indices=ind)

# Flatten positions from (n_frames, n_atoms, 3) to (n_frames, n_atoms * 3)
# and convert to a float32 torch tensor.
xyz = aldp_traj.xyz
n_atoms = xyz.shape[1]
n_dim = n_atoms * 3
training_data_npy = xyz.reshape(-1, n_dim)
training_data = torch.from_numpy(training_data_npy.astype("float32"))

# Use the computed dimensionality rather than a hard-coded 66 so the mixed
# transform stays consistent with the internal-coordinate transform below.
mixed_transform = mixed.MixedTransform(n_dim, backbone_indices, z, training_data)

# For comparison only: the pure internal-coordinate transform. Keep its
# log-Jacobian under its own name instead of overwriting it below.
ic_transform = ics.InternalCoordinateTransform(n_dim, z, backbone_indices, training_data)
transformed_c, ic_jac = ic_transform.forward(training_data)

# Forward and inverse through the mixed transform.
mixed_coords, jac = mixed_transform.forward(training_data)
orig_coords, invjac = mixed_transform.inverse(mixed_coords)

# The inverse must reproduce the input coordinates. Coordinates are float32,
# so allow a small absolute tolerance rather than exact equality.
ROUNDTRIP_ATOL = 1e-3
max_roundtrip_err = (orig_coords - training_data).abs().max().item()
print(f"max round-trip error: {max_roundtrip_err:.2e}")
assert max_roundtrip_err < ROUNDTRIP_ATOL, "mixed transform inverse does not reproduce the input"

# Side-by-side view of consecutive 3-column groups of frame 0:
# org = input, ict = internal-coordinate transform output,
# mxd = mixed transform output, rcd = mixed transform round trip.
N_SHOW = 20
for i in range(N_SHOW):
    cols = slice(3 * i, 3 * i + 3)
    print("org", training_data[0, cols])
    print("ict", transformed_c[0, cols])
    print("mxd", mixed_coords[0, cols])
    print("rcd", orig_coords[0, cols])
    print("--------")
    print("--------")

Inside MixedTransform constructor
org tensor([-0.3444, -0.2671,  0.0676])
ict tensor([-0.0012,  0.8984, -0.4642])
mxd tensor([-1.3905, -0.4556,  2.3946])
rcd tensor([-0.3444, -0.2671,  0.0676])
--------
org tensor([-0.2896, -0.2343, -0.0207])
ict tensor([-0.7018,  0.0529, -0.0597])
mxd tensor([-0.7838,  0.7156, -1.2061])
rcd tensor([-0.2896, -0.2343, -0.0207])
--------
org tensor([-0.3459, -0.1498, -0.0604])
ict tensor([-8.1801e-04, -9.5989e-01,  6.3359e-02])
mxd tensor([-0.9657, -2.1087,  0.4958])
rcd tensor([-0.3459, -0.1498, -0.0604])
--------
org tensor([-0.2988, -0.3073, -0.1011])
ict tensor([-0.0011,  0.3681, -0.1829])
mxd tensor([-0.0911,  0.6042,  0.7680])
rcd tensor([-0.2988, -0.3073, -0.1011])
--------
org tensor([-0.1475, -0.2014,  0.0106])
ict tensor([-0.1475, -0.2014,  0.0106])
mxd tensor([ 0.1409, -1.5165, -0.9968])
rcd tensor([-0.1475, -0.2014,  0.0106])
--------
org tensor([-0.0755, -0.2853,  0.0616])
ict tensor([-0.0755, -0.2853,  0.0616])
mxd tensor([-0.1766, -1.3807,