Skip to content

Commit 6c307a2

Browse files
committed
Add evaluation scripts
1 parent 4512926 commit 6c307a2

File tree

2 files changed

+242
-0
lines changed

2 files changed

+242
-0
lines changed

eval/evaluate_mueller.py

+84
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,84 @@
1+
import numpy as np
2+
import jax.numpy as jnp
3+
import jax
4+
from eval.path_metrics import plot_path_energy
5+
from tps.paths import decorrelated
6+
from tps_baseline_mueller import U, dUdx_fn
7+
from scipy.optimize import minimize
8+
import matplotlib.pyplot as plt
9+
import os
10+
11+
# Evaluation settings for the Mueller experiments.
num_paths = 1000
# Noise scale: log_path_likelihood uses sqrt(dt) * xi as the per-step Gaussian scale.
xi = 5
# Thermal energy implied by the noise scale.
kbT = xi ** 2 / 2
# Integration step and total trajectory duration.
dt = 1e-4
T = 275e-4
# Number of integration steps. T / dt is a float quotient, so round to the
# nearest integer instead of truncating: a quotient that lands a hair below the
# exact value (e.g. 274.999...) must not silently lose a step.
N = round(T / dt)

# Starting points for the local minimisations of U that the script section runs
# to determine the global minimum energy.
minima_points = jnp.array([[-0.55828035, 1.44169],
                           [-0.05004308, 0.46666032],
                           [0.62361133, 0.02804632]])
21+
22+
23+
def load(path):
    """Read the trajectories stored in the ``.npy`` file at ``path``.

    Each stored entry is cast to float32 and reshaped to two columns.

    :param path: location of the ``.npy`` file.
    :return: list with one ``(-1, 2)`` float32 array per stored entry.

    NOTE(review): ``allow_pickle=True`` unpickles arbitrary objects, so only
    point this at files from a trusted source.
    """
    trajectories = []
    for raw in np.load(path, allow_pickle=True):
        trajectories.append(raw.astype(np.float32).reshape(-1, 2))
    return trajectories
26+
27+
28+
@jax.jit
def log_path_likelihood(path):
    """Return the log-likelihood of one discretised trajectory.

    The result is the sum of two terms: ``-U(path[0]) / kbT`` for the first
    point, and a zero-mean Gaussian log-density with scale ``sqrt(dt) * xi``
    evaluated on the per-step residual ``x[t+1] - x[t] + dt * dUdx_fn(x[t])``.

    :param path: array of consecutive positions, indexed along the first axis.
    :return: scalar log-likelihood.
    """
    current, following = path[:-1], path[1:]
    # Whatever part of each step the drift term does not explain.
    noise = following - current + dt * dUdx_fn(current)
    initial_term = (-U(path[0]) / kbT).sum()
    step_term = jax.scipy.stats.norm.logpdf(noise, scale=jnp.sqrt(dt) * xi).sum()
    return initial_term + step_term
32+
33+
34+
if __name__ == '__main__':
    # Evaluate the sampled Mueller trajectories: report decorrelation, energy and
    # log-likelihood statistics, and save one PDF figure per metric.
    savedir = './out/evaluation/mueller/'
    os.makedirs(savedir, exist_ok=True)

    # (label, file with the sampled paths, number of leading paths dropped as warm-up)
    all_paths = [
        ('one-way-shooting', './out/baselines/mueller/paths-one-way-shooting.npy', 50),
        ('variable-one-way-shooting', './out/baselines/mueller-variable/paths-one-way-shooting.npy', 50),
        ('two-way-shooting', './out/baselines/mueller/paths-two-way-shooting.npy', 0),
        ('variable-two-way-shooting', './out/baselines/mueller-variable/paths-two-way-shooting.npy', 0),
        ('var-doobs', './out/var_doobs/mueller/paths.npy', 0),
    ]

    # Minimise U from every starting point and keep the lowest energy found.
    global_minimum_energy = U(minima_points[0])
    for point in minima_points:
        global_minimum_energy = min(global_minimum_energy, minimize(U, point).fun)
    print("Global minimum energy", global_minimum_energy)

    all_paths = [(name, load(path)[warmup:]) for name, path, warmup in all_paths]
    # Plain loop: a comprehension should not be used only for its side effects.
    for name, paths in all_paths:
        print(name, len(paths))

    for name, paths in all_paths:
        print(name, 'decorrelated trajectories:', jnp.round(100 * len(decorrelated(paths)) / len(paths), 2), '%')

    for name, paths in all_paths:
        # The energies are passed with `add=-global_minimum_energy`; adding the
        # minimum back presumably restores absolute energies for the printed
        # statistics (verify against plot_path_energy).
        max_energy = plot_path_energy(paths, U, add=-global_minimum_energy, label=name) + global_minimum_energy
        print(name, 'max energy mean:', jnp.round(jnp.mean(max_energy), 2), 'std:', jnp.round(jnp.std(max_energy), 2))
        print(name, 'min max energy: ', jnp.round(jnp.min(max_energy), 2))

    plt.legend()
    plt.ylabel('Maximum energy')
    plt.savefig(os.path.join(savedir, 'mueller-max-energy.pdf'), bbox_inches='tight')
    plt.show()
    # On non-interactive backends plt.show() does not discard the figure, so
    # close it explicitly; otherwise the next plot is drawn over these curves.
    plt.close()

    for name, paths in all_paths:
        plot_path_energy(paths, U, add=-global_minimum_energy, reduce=jnp.median, label=name)

    plt.legend()
    plt.ylabel('Median energy')
    plt.savefig(os.path.join(savedir, 'mueller-median-energy.pdf'), bbox_inches='tight')
    plt.show()
    plt.close()

    for name, paths in all_paths:
        # The identity `reduce` keeps the per-path log-likelihood untouched.
        likelihood = plot_path_energy(paths, log_path_likelihood, reduce=lambda x: x, label=name)
        print(name, 'mean log-likelihood:', jnp.round(jnp.mean(likelihood), 2), 'std:',
              jnp.round(jnp.std(likelihood), 2))
        print(name, 'maximum log-likelihood:', jnp.round(jnp.max(likelihood), 2))

    plt.legend()
    plt.ylabel('log path likelihood')
    plt.savefig(os.path.join(savedir, 'mueller-log-path-likelihood.pdf'), bbox_inches='tight')
    plt.show()
    plt.close()

eval/evaluate_tps.py

+158
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,158 @@
1+
import json
2+
from functools import partial
3+
4+
import numpy as np
5+
import jax.numpy as jnp
6+
import jax
7+
from eval.path_metrics import plot_path_energy
8+
import matplotlib.pyplot as plt
9+
import os
10+
import openmm.app as app
11+
import openmm.unit as unit
12+
from dmff import Hamiltonian, NeighborList
13+
from tqdm import tqdm
14+
15+
from tps.paths import decorrelated
16+
17+
# Integration time step: as an OpenMM quantity, in picoseconds and in seconds.
dt_as_unit = unit.Quantity(value=1, unit=unit.femtosecond)
dt_in_ps = dt_as_unit.value_in_unit(unit.picosecond)
dt = dt_as_unit.value_in_unit(unit.second)

gamma_as_unit = 1.0 / unit.picosecond
# actually gamma is 1/s, but we are working without units and just need the correct scaling
# TODO: try to get rid of this duplicate definition
gamma = 1.0 * unit.picosecond
gamma_in_ps = gamma.value_in_unit(unit.picosecond)
# From here on `gamma` is a plain float (the value expressed in seconds).
gamma = gamma.value_in_unit(unit.second)

temp = 300
# NOTE(review): 1.380649 and 6.02214076 look like the mantissas of the Boltzmann
# and Avogadro constants, which would make kbT a molar energy in kJ/mol - confirm.
kbT = 1.380649 * 6.02214076 * 1e-3 * temp

init_pdb = app.PDBFile('./files/AD_A.pdb')
# Construct the mass matrix: every atom mass is repeated three times in a row,
# giving one entry per Cartesian coordinate.
mass = [a.element.mass.value_in_unit(unit.dalton) for a in init_pdb.topology.atoms()]
new_mass = [atom_mass for atom_mass in mass for _ in range(3)]
mass = jnp.array(new_mass)
# Obtain xi, gamma is by default 1
xi = jnp.sqrt(2 * kbT / mass / gamma)

# Initialize the potential energy with amber forcefields
ff = Hamiltonian('amber14/protein.ff14SB.xml', 'amber14/tip3p.xml')
potentials = ff.createPotential(init_pdb.topology,
                                nonbondedMethod=app.NoCutoff,
                                nonbondedCutoff=1.0 * unit.nanometers,
                                constraints=None,
                                ewaldErrorTolerance=0.0005)
# Create the box matrix that is handed to the neighbor list and to every call
# of the potential function.
box = np.array([[50.0, 0.0, 0.0], [0.0, 50.0, 0.0], [0.0, 0.0, 50.0]])
nbList = NeighborList(box, 4.0, potentials.meta["cov_map"])
nbList.allocate(init_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
pairs = nbList.pairs
54+
55+
56+
@jax.jit
@jax.vmap
def U_native(_x):
    """Evaluate the force-field potential energy for a batch of configurations.

    The function is vmapped, so the body sees one configuration at a time and
    reshapes it to ``[22, 3]`` before calling the potential with ``box``,
    ``pairs`` and the force-field parameters.

    :param _x: batch of configurations; each entry must hold 66 values.
    :return: one summed energy per batch entry.
    """
    potential_fn = potentials.getPotentialFunc()
    positions = _x.reshape(22, 3)
    energy = potential_fn(positions, box, pairs, ff.paramset.parameters)
    return energy.sum()
65+
66+
67+
def U_padded(x):
    """Evaluate ``U_native`` on ``x`` after zero-padding the batch dimension.

    The input is viewed as rows of 66 values and copied into a zero array whose
    row count is the next multiple of 100 strictly above the input length. At
    least one padding row is therefore always added, even when the length
    already is a multiple of 100. Only the energies of the real rows are returned.

    NOTE(review): the bucketing presumably limits how many distinct shapes the
    jitted ``U_native`` has to be compiled for - confirm.

    :param x: configurations, reshapeable to ``(-1, 66)``.
    :return: energies of the original (unpadded) rows.
    """
    flat = x.reshape(-1, 66)
    n_frames = flat.shape[0]
    padded_rows = (n_frames // 100 + 1) * 100
    padded = jnp.zeros((padded_rows, 66)).at[:n_frames, :].set(flat)
    return U_native(padded)[:n_frames]
74+
75+
76+
@jax.jit
@jax.vmap
def dUdx_fn(_x):
    """Return the potential-energy gradient, divided by ``mass`` and ``gamma_in_ps``.

    The function is vmapped, so the body differentiates one flat configuration
    at a time.

    :param _x: batch of flat configurations; each entry must hold 66 values.
    :return: per-entry gradient with respect to ``_x``, scaled element-wise by
        ``1 / mass / gamma_in_ps``.
    """

    def total_energy(flat_positions):
        # Call the potential as U(x, box, pairs, ff.paramset.parameters) with x
        # of shape [22, 3], and reduce the result to a scalar so it can be
        # differentiated.
        potential_fn = potentials.getPotentialFunc()
        return potential_fn(flat_positions.reshape(22, 3), box, pairs, ff.paramset.parameters).sum()

    gradient = jax.grad(total_energy)(_x)
    return gradient / mass / gamma_in_ps
88+
89+
90+
@jax.jit
def step_langevin_log_prob(_x, _v, _new_x, _new_v):
    """Return the log-probability of one Langevin velocity update.

    The deterministic part of the new velocity is
    ``alpha * _v - (1 - alpha) / gamma * dUdx_fn(_x)`` with
    ``alpha = exp(-gamma * dt)`` (both in picoseconds). Its difference to
    ``_new_v`` is scored under a zero-mean Gaussian with standard deviation
    ``sqrt(kbT * (1 - alpha ** 2) / mass)``.

    :param _x: configuration before the step (flattened to one batch row).
    :param _v: velocity before the step.
    :param _new_x: configuration after the step; not used by the computation.
    :param _new_v: velocity after the step.
    :return: scalar log-probability summed over all coordinates.
    """
    decay = jnp.exp(-gamma_in_ps * dt_in_ps)
    force_scale = (1 - decay) / gamma_in_ps
    # dUdx_fn is batched, hence the reshape to a single-row batch.
    drift = dUdx_fn(_x.reshape(1, -1))
    predicted_v = decay * _v + force_scale * -drift
    residual = predicted_v - _new_v
    noise_std = jnp.sqrt(kbT * (1 - decay ** 2) / mass)
    return jax.scipy.stats.norm.logpdf(residual, 0, noise_std).sum()
98+
99+
100+
def langevin_log_path_likelihood(path, velocities):
    """Return the log-likelihood of a trajectory together with its velocities.

    The result adds up three contributions: ``-U_native(path[0]) / kbT`` for the
    first configuration, a zero-mean Gaussian log-density with standard
    deviation ``sqrt(kbT / mass)`` for the first velocity, and
    ``step_langevin_log_prob`` for every pair of consecutive frames.

    :param path: sequence of configurations.
    :param velocities: sequence of velocities, same length as ``path``.
    :return: accumulated log-likelihood.
    :raises AssertionError: if ``path`` and ``velocities`` differ in length.
    """
    assert len(path) == len(
        velocities), f'path and velocities must have the same length, but got {len(path)} and {len(velocities)}'

    # U_native is batched, hence the reshape to a single-row batch.
    start_energy = U_native(path[0].reshape(1, -1))
    log_prob = (-start_energy / kbT).sum()
    log_prob += jax.scipy.stats.norm.logpdf(velocities[0], 0, jnp.sqrt(kbT / mass)).sum()

    # One transition term per consecutive pair of frames.
    for step in range(len(path) - 1):
        log_prob += step_langevin_log_prob(path[step], velocities[step], path[step + 1], velocities[step + 1])

    return log_prob
112+
113+
114+
def load(path):
    """Read the per-trajectory arrays stored in the ``.npy`` file at ``path``.

    Each stored entry is cast to float32 and reshaped to rows of 66 values.

    :param path: location of the ``.npy`` file.
    :return: list with one ``(-1, 66)`` float32 array per stored entry.

    NOTE(review): ``allow_pickle=True`` unpickles arbitrary objects, so only
    point this at files from a trusted source.
    """
    trajectories = []
    for raw in np.load(path, allow_pickle=True):
        trajectories.append(raw.astype(np.float32).reshape(-1, 66))
    return trajectories
117+
118+
119+
if __name__ == '__main__':
    # Evaluate the sampled alanine trajectories: print the stored run
    # statistics, then decorrelation, energy and log-likelihood summaries.
    savedir = './out/evaluation/alanine/'
    os.makedirs(savedir, exist_ok=True)

    # (label, directory with the run output, number of leading paths dropped as warm-up)
    all_paths = [
        ('one-way-shooting-var-length-cv', './out/baselines/alanine-one-way-shooting', 50),
        ('one-way-shooting-var-length-rmsd', './out/baselines/alanine-one-way-shooting-rmsd', 50),
        ('one-way-shooting-fixed-length-cv', './out/baselines/alanine-one-way-shooting-1000steps', 50),
        ('one-way-shooting-fixed-length-rmsd', './out/baselines/alanine-one-way-shooting-1000steps-rmsd', 50),
        ('two-way-shooting-var-length-cv', './out/baselines/alanine-two-way-shooting', 0),
        ('two-way-shooting-var-length-rmsd', './out/baselines/alanine-two-way-shooting-rmsd', 0),
        ('two-way-shooting-fixed-length-cv', './out/baselines/alanine-two-way-shooting-1000steps', 0),
    ]

    # print relevant statistics:
    for name, file_path, _warmup in all_paths:
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(f'{file_path}/stats.json', 'r', encoding='utf-8') as fp:
            statistics = json.load(fp)
        print(name, statistics)

    all_paths = [(name, load(f'{path}/paths.npy')[warmup:], load(f'{path}/velocities.npy')[warmup:]) for
                 name, path, warmup in tqdm(all_paths, desc='loading paths')]
    # Plain loop: a comprehension should not be used only for its side effects.
    for name, paths, velocities in all_paths:
        print(name, len(paths), len(velocities))

    for name, paths, _velocities in all_paths:
        print(name, 'decorrelated trajectories:', jnp.round(100 * len(decorrelated(paths)) / len(paths), 2), '%')

    for name, paths, _velocities in all_paths:
        # Highest potential energy visited along each individual path.
        max_energy = jnp.array([jnp.max(U_padded(path)) for path in tqdm(paths)])
        print(name, 'max energy mean:', jnp.round(jnp.mean(max_energy), 2), 'std:', jnp.round(jnp.std(max_energy), 2))
        print(name, 'min max energy:', jnp.round(jnp.min(max_energy), 2))

    for name, paths, velocities in all_paths:
        log_likelihood = jnp.array(
            [langevin_log_path_likelihood(path, current_velocities) for path, current_velocities in
             tqdm(zip(paths, velocities), total=len(paths))])

        print(name, 'max log likelihood:', jnp.round(jnp.max(log_likelihood), 2))
        print(name, 'mean log likelihood:', jnp.round(jnp.mean(log_likelihood), 2), 'std:',
              jnp.round(jnp.std(log_likelihood), 2))

0 commit comments

Comments
 (0)