|
| 1 | +import json |
| 2 | +from functools import partial |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import jax.numpy as jnp |
| 6 | +import jax |
| 7 | +from eval.path_metrics import plot_path_energy |
| 8 | +import matplotlib.pyplot as plt |
| 9 | +import os |
| 10 | +import openmm.app as app |
| 11 | +import openmm.unit as unit |
| 12 | +from dmff import Hamiltonian, NeighborList |
| 13 | +from tqdm import tqdm |
| 14 | + |
| 15 | +from tps.paths import decorrelated |
| 16 | + |
| 17 | +dt_as_unit = unit.Quantity(value=1, unit=unit.femtosecond) |
| 18 | +dt_in_ps = dt_as_unit.value_in_unit(unit.picosecond) |
| 19 | +dt = dt_as_unit.value_in_unit(unit.second) |
| 20 | + |
| 21 | +gamma_as_unit = 1.0 / unit.picosecond |
| 22 | +# actually gamma is 1/s, but we are working without units and just need the correct scaling |
| 23 | +# TODO: try to get rid of this duplicate definition |
| 24 | +gamma = 1.0 * unit.picosecond |
| 25 | +gamma_in_ps = gamma.value_in_unit(unit.picosecond) |
| 26 | +gamma = gamma.value_in_unit(unit.second) |
| 27 | + |
| 28 | +temp = 300 |
| 29 | +kbT = 1.380649 * 6.02214076 * 1e-3 * temp |
| 30 | + |
| 31 | +init_pdb = app.PDBFile('./files/AD_A.pdb') |
| 32 | +# Construct the mass matrix |
| 33 | +mass = [a.element.mass.value_in_unit(unit.dalton) for a in init_pdb.topology.atoms()] |
| 34 | +new_mass = [] |
| 35 | +for mass_ in mass: |
| 36 | + for _ in range(3): |
| 37 | + new_mass.append(mass_) |
| 38 | +mass = jnp.array(new_mass) |
| 39 | +# Obtain xi, gamma is by default 1 |
| 40 | +xi = jnp.sqrt(2 * kbT / mass / gamma) |
| 41 | + |
| 42 | +# Initialize the potential energy with amber forcefields |
| 43 | +ff = Hamiltonian('amber14/protein.ff14SB.xml', 'amber14/tip3p.xml') |
| 44 | +potentials = ff.createPotential(init_pdb.topology, |
| 45 | + nonbondedMethod=app.NoCutoff, |
| 46 | + nonbondedCutoff=1.0 * unit.nanometers, |
| 47 | + constraints=None, |
| 48 | + ewaldErrorTolerance=0.0005) |
| 49 | +# Create a box used when calling |
| 50 | +box = np.array([[50.0, 0.0, 0.0], [0.0, 50.0, 0.0], [0.0, 0.0, 50.0]]) |
| 51 | +nbList = NeighborList(box, 4.0, potentials.meta["cov_map"]) |
| 52 | +nbList.allocate(init_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) |
| 53 | +pairs = nbList.pairs |
| 54 | + |
| 55 | + |
| 56 | +@jax.jit |
| 57 | +@jax.vmap |
| 58 | +def U_native(_x): |
| 59 | + """ |
| 60 | + Calling U by U(x, box, pairs, ff.paramset.parameters), x is [22, 3] and output the energy, if it is batched, use vmap |
| 61 | + """ |
| 62 | + _U = potentials.getPotentialFunc() |
| 63 | + |
| 64 | + return _U(_x.reshape(22, 3), box, pairs, ff.paramset.parameters).sum() |
| 65 | + |
| 66 | + |
| 67 | +def U_padded(x): |
| 68 | + x = x.reshape(-1, 66) |
| 69 | + orig_length = x.shape[0] |
| 70 | + padded_length = orig_length // 100 * 100 + 100 |
| 71 | + x_empty = jnp.zeros((padded_length, 66)) |
| 72 | + x = x_empty.at[:x.shape[0], :].set(x.reshape(-1, 66)) |
| 73 | + return U_native(x)[:orig_length] |
| 74 | + |
| 75 | + |
| 76 | +@jax.jit |
| 77 | +@jax.vmap |
| 78 | +def dUdx_fn(_x): |
| 79 | + def U(_x): |
| 80 | + """ |
| 81 | + Calling U by U(x, box, pairs, ff.paramset.parameters), x is [22, 3] and output the energy, if it is batched, use vmap |
| 82 | + """ |
| 83 | + _U = potentials.getPotentialFunc() |
| 84 | + |
| 85 | + return _U(_x.reshape(22, 3), box, pairs, ff.paramset.parameters) |
| 86 | + |
| 87 | + return jax.grad(lambda _x: U(_x).sum())(_x) / mass / gamma_in_ps |
| 88 | + |
| 89 | + |
| 90 | +@jax.jit |
| 91 | +def step_langevin_log_prob(_x, _v, _new_x, _new_v): |
| 92 | + alpha = jnp.exp(-gamma_in_ps * dt_in_ps) |
| 93 | + f_scale = (1 - alpha) / gamma_in_ps |
| 94 | + new_v_det = alpha * _v + f_scale * -dUdx_fn(_x.reshape(1, -1)) |
| 95 | + new_v_rand = new_v_det - _new_v |
| 96 | + |
| 97 | + return jax.scipy.stats.norm.logpdf(new_v_rand, 0, jnp.sqrt(kbT * (1 - alpha ** 2) / mass)).sum() |
| 98 | + |
| 99 | + |
| 100 | +def langevin_log_path_likelihood(path, velocities): |
| 101 | + assert len(path) == len( |
| 102 | + velocities), f'path and velocities must have the same length, but got {len(path)} and {len(velocities)}' |
| 103 | + log_prob = (-U_native(path[0].reshape(1, -1)) / kbT).sum() |
| 104 | + log_prob += jax.scipy.stats.norm.logpdf(velocities[0], 0, jnp.sqrt(kbT / mass)).sum() |
| 105 | + |
| 106 | + for i in range(1, len(path)): |
| 107 | + log_prob += step_langevin_log_prob(path[i - 1], velocities[i - 1], path[i], velocities[i]) |
| 108 | + |
| 109 | + # log_prob += step_langevin_log_prob(path[:-1], velocities[:-1], path[1:], velocities[1:]).sum() |
| 110 | + |
| 111 | + return log_prob |
| 112 | + |
| 113 | + |
| 114 | +def load(path): |
| 115 | + loaded = np.load(path, allow_pickle=True) |
| 116 | + return [p.astype(np.float32).reshape(-1, 66) for p in loaded] |
| 117 | + |
| 118 | + |
| 119 | +if __name__ == '__main__': |
| 120 | + savedir = './out/evaluation/alanine/' |
| 121 | + os.makedirs(savedir, exist_ok=True) |
| 122 | + |
| 123 | + all_paths = [ |
| 124 | + ('one-way-shooting-var-length-cv', './out/baselines/alanine-one-way-shooting', 50), |
| 125 | + ('one-way-shooting-var-length-rmsd', './out/baselines/alanine-one-way-shooting-rmsd', 50), |
| 126 | + ('one-way-shooting-fixed-length-cv', './out/baselines/alanine-one-way-shooting-1000steps', 50), |
| 127 | + ('one-way-shooting-fixed-length-rmsd', './out/baselines/alanine-one-way-shooting-1000steps-rmsd', 50), |
| 128 | + ('two-way-shooting-var-length-cv', './out/baselines/alanine-two-way-shooting', 0), |
| 129 | + ('two-way-shooting-var-length-rmsd', './out/baselines/alanine-two-way-shooting-rmsd', 0), |
| 130 | + ('two-way-shooting-fixed-length-cv', './out/baselines/alanine-two-way-shooting-1000steps', 0), |
| 131 | + ] |
| 132 | + |
| 133 | + # print relevant statistics: |
| 134 | + for name, file_path, _warmup in all_paths: |
| 135 | + with open(f'{file_path}/stats.json', 'r') as fp: |
| 136 | + statistics = json.load(fp) |
| 137 | + print(name, statistics) |
| 138 | + |
| 139 | + all_paths = [(name, load(f'{path}/paths.npy')[warmup:], load(f'{path}/velocities.npy')[warmup:]) for |
| 140 | + name, path, warmup in tqdm(all_paths, desc='loading paths')] |
| 141 | + [print(name, len(path), len(velocities)) for name, path, velocities in all_paths] |
| 142 | + |
| 143 | + for name, paths, _velocities in all_paths: |
| 144 | + print(name, 'decorrelated trajectories:', jnp.round(100 * len(decorrelated(paths)) / len(paths), 2), '%') |
| 145 | + |
| 146 | + for name, paths, _velocities in all_paths: |
| 147 | + max_energy = jnp.array([jnp.max(U_padded(path)) for path in tqdm(paths)]) |
| 148 | + print(name, 'max energy mean:', jnp.round(jnp.mean(max_energy), 2), 'std:', jnp.round(jnp.std(max_energy), 2)) |
| 149 | + print(name, 'min max energy:', jnp.round(jnp.min(max_energy), 2)) |
| 150 | + |
| 151 | + for name, paths, velocities in all_paths: |
| 152 | + log_likelihood = jnp.array( |
| 153 | + [langevin_log_path_likelihood(path, current_velocities) for path, current_velocities in |
| 154 | + tqdm(zip(paths, velocities), total=len(paths))]) |
| 155 | + |
| 156 | + print(name, 'max log likelihood:', jnp.round(jnp.max(log_likelihood), 2)) |
| 157 | + print(name, 'mean log likelihood:', jnp.round(jnp.mean(log_likelihood), 2), 'std:', |
| 158 | + jnp.round(jnp.std(log_likelihood), 2)) |
0 commit comments