Skip to content

Commit e94d743

Browse files
committed
Update baselines, shooting, and plotting
1 parent 4d8d642 commit e94d743

File tree

3 files changed

+117
-38
lines changed

3 files changed

+117
-38
lines changed

eval/path_metrics.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import jax.numpy as jnp
2+
import matplotlib.pyplot as plt
3+
from tqdm import tqdm
4+
5+
6+
def plot_path_energy(paths, U, reduce=jnp.max, already_ln=False):
7+
reduced = jnp.array([reduce(U(path)) for path in tqdm(paths, 'Computing path metric')])
8+
9+
if already_ln:
10+
# Convert reduced to log10
11+
reduced = reduced / jnp.log(10)
12+
plt.plot(jnp.arange(0, len(reduced), 1), reduced)
13+
else:
14+
plt.semilogy(jnp.arange(0, len(reduced), 1), reduced)

tps/second_order.py

+44-27
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def one_way_shooting(system, trajectory, fixed_length, key):
2323
# pick a random direction, either forward or backward
2424
direction = jax.random.randint(key[1], (1,), 0, 2)[0]
2525

26-
# TODO: Fix correct dt in ps
27-
velocity = (trajectory[point_idx] - trajectory[point_idx - 1]) / 0.001
26+
# TODO: Fix correct dt in ps / pass previous velocities
27+
new_velocities = [(trajectory[point_idx] - trajectory[point_idx - 1]) / 0.001]
2828

2929
if direction == 0:
3030
trajectory = trajectory[:point_idx + 1]
@@ -38,27 +38,28 @@ def one_way_shooting(system, trajectory, fixed_length, key):
3838
key, iter_key = jax.random.split(key[3])
3939
while len(trajectory) < steps:
4040
key, iter_key = jax.random.split(key)
41-
point, velocity = step_function(trajectory[-1], velocity, iter_key)
41+
point, velocity = step_function(trajectory[-1], new_velocities[-1], iter_key)
4242
trajectory.append(point)
43+
new_velocities.append(velocity)
4344

4445
if jnp.isnan(point).any() or jnp.isnan(velocity).any():
45-
return False, trajectory
46+
return False, trajectory, new_velocities
4647

4748
# ensure that our trajectory does not explode
4849
if (jnp.abs(point) > MAX_ABS_VALUE).any():
49-
return False, trajectory
50+
return False, trajectory, new_velocities
5051

5152
if system.start_state(trajectory[0]) and system.target_state(trajectory[-1]):
5253
if fixed_length == 0 or len(trajectory) == fixed_length:
53-
return True, trajectory
54-
return False, trajectory
54+
return True, trajectory, new_velocities
55+
return False, trajectory, new_velocities
5556

5657
if system.target_state(trajectory[0]) and system.start_state(trajectory[-1]):
5758
if fixed_length == 0 or len(trajectory) == fixed_length:
58-
return True, trajectory[::-1]
59-
return False, trajectory
59+
return True, trajectory[::-1], new_velocities[::-1]
60+
return False, trajectory, new_velocities
6061

61-
return False, trajectory
62+
return False, trajectory, new_velocities
6263

6364

6465
def two_way_shooting(system, trajectory, fixed_length, key):
@@ -71,62 +72,74 @@ def two_way_shooting(system, trajectory, fixed_length, key):
7172

7273
steps = MAX_STEPS if fixed_length == 0 else fixed_length
7374

74-
initial_velocity = system.sample_velocity(key[1])
75-
76-
key, iter_key = jax.random.split(key[2])
7775
new_trajectory = [point]
76+
new_velocities = [system.sample_velocity(key[1])]
7877

79-
velocity = initial_velocity
78+
key, iter_key = jax.random.split(key[2])
8079
while len(new_trajectory) < steps:
8180
key, iter_key = jax.random.split(key)
82-
point, velocity = system.step_forward(new_trajectory[-1], velocity, iter_key)
81+
point, velocity = system.step_forward(new_trajectory[-1], new_velocities[-1], iter_key)
8382
new_trajectory.append(point)
83+
new_velocities.append(velocity)
8484

8585
if jnp.isnan(point).any() or jnp.isnan(velocity).any():
86-
return False, trajectory
86+
return False, new_trajectory, new_velocities
8787

8888
# ensure that our trajectory does not explode
8989
if (jnp.abs(point) > MAX_ABS_VALUE).any():
90-
return False, trajectory
90+
return False, new_trajectory, new_velocities
9191

9292
if system.start_state(point) or system.target_state(point):
9393
break
9494

95-
velocity = initial_velocity
9695
while len(new_trajectory) < steps:
9796
key, iter_key = jax.random.split(key)
98-
point, velocity = system.step_backward(new_trajectory[0], velocity, iter_key)
97+
point, velocity = system.step_backward(new_trajectory[0], new_velocities[0], iter_key)
9998
new_trajectory.insert(0, point)
99+
new_velocities.insert(0, velocity)
100100

101101
if jnp.isnan(point).any() or jnp.isnan(velocity).any():
102-
return False, trajectory
102+
return False, new_trajectory, new_velocities
103103

104104
# ensure that our trajectory does not explode
105105
if (jnp.abs(point) > MAX_ABS_VALUE).any():
106-
return False, trajectory
106+
return False, new_trajectory, new_velocities
107107

108108
if system.start_state(point) or system.target_state(point):
109109
break
110110

111111
# throw away the trajectory if it's not the right length
112112
if fixed_length != 0 and len(new_trajectory) != fixed_length:
113-
return False, trajectory
113+
return False, new_trajectory, new_velocities
114114

115115
if system.start_state(new_trajectory[0]) and system.target_state(new_trajectory[-1]):
116-
return True, new_trajectory
116+
return True, new_trajectory, new_velocities
117117

118118
if system.target_state(new_trajectory[0]) and system.start_state(new_trajectory[-1]):
119-
return True, new_trajectory[::-1]
119+
return True, new_trajectory[::-1], new_velocities[::-1]
120120

121-
return False, trajectory
121+
return False, new_trajectory, new_velocities
122122

123123

124124
def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_length=0, warmup=50):
125125
# pick an initial trajectory
126126
trajectories = [initial_trajectory]
127+
velocities = []
128+
statistics = {
129+
'num_force_evaluations': 0,
130+
'num_tries': 0,
131+
'num_metropolis_rejected': 0,
132+
'warmup': warmup,
133+
'num_paths': num_paths,
134+
'max_steps': MAX_STEPS,
135+
'max_abs_value': MAX_ABS_VALUE,
136+
}
137+
if fixed_length > 0:
138+
statistics['fixed_length'] = fixed_length
127139

128140
with tqdm(total=num_paths + warmup, desc='warming up' if warmup > 0 else '') as pbar:
129141
while len(trajectories) <= num_paths + warmup:
142+
statistics['num_tries'] += 1
130143
if len(trajectories) > warmup:
131144
pbar.set_description('')
132145

@@ -135,7 +148,8 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_le
135148
# during warmup, we want an iterative scheme
136149
traj_idx = traj_idx if traj_idx < len(trajectories) else -1
137150

138-
found, new_trajectory = proposal(system, trajectories[traj_idx], fixed_length, iter_key)
151+
found, new_trajectory, new_velocities = proposal(system, trajectories[traj_idx], fixed_length, iter_key)
152+
statistics['num_force_evaluations'] += len(new_trajectory) - 1
139153

140154
if not found:
141155
continue
@@ -144,9 +158,12 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_le
144158
# The first trajectory might have a very unreasonable length, so we skip it
145159
if len(trajectories) == 1 or jax.random.uniform(accept_key, shape=(1,)) < ratio:
146160
trajectories.append(new_trajectory)
161+
velocities.append(new_velocities)
147162
pbar.update(1)
163+
else:
164+
statistics['num_metropolis_rejected'] += 1
148165

149-
return trajectories[warmup + 1:]
166+
return trajectories[warmup + 1:], velocities[warmup:], statistics
150167

151168

152169
def unguided_md(system, initial_point, num_paths, key, fixed_length=0):

tps_baseline.py

+59-11
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import os
2+
from functools import partial
3+
24
import jax
35
import numpy as np
46
import matplotlib.pyplot as plt
57
import jax.numpy as jnp
68
from tqdm import trange, tqdm
9+
import json
710

811
# install openmm (from conda)
912
import openmm.app as app
@@ -13,6 +16,7 @@
1316
# install mdtraj
1417
import mdtraj as md
1518

19+
from eval.path_metrics import plot_path_energy
1620
from tps import first_order as tps1
1721
from tps import second_order as tps2
1822
from tps.plot import PeriodicPathHistogram
@@ -198,6 +202,26 @@ def step_langevin_forward(_x, _v, _key):
198202

199203
return _x + dt_in_ps * new_v, new_v
200204

205+
@jax.jit
206+
def step_langevin_log_density(_x, _v, _new_x, _new_v):
207+
alpha = jnp.exp(-gamma_in_ps * dt_in_ps)
208+
f_scale = (1 - alpha) / gamma_in_ps
209+
new_v_det = alpha * _v + f_scale * -dUdx_fn_unscaled(_x) / mass
210+
new_v_rand = new_v_det - _new_v
211+
212+
return jax.scipy.stats.norm.logpdf(new_v_rand, 0, jnp.sqrt(kbT * (1 - alpha ** 2) / mass)).sum()
213+
214+
215+
def langevin_log_path_density(path_and_velocities):
216+
path, velocities = path_and_velocities
217+
218+
log_prob = (-U(path[0]) / kbT).sum()
219+
log_prob += jax.scipy.stats.norm.logpdf(velocities[0], 0, jnp.sqrt(kbT / mass)).sum()
220+
221+
for i in range(1, len(path)):
222+
log_prob += step_langevin_log_density(path[i - 1], velocities[i - 1], path[i], velocities[i])
223+
224+
return log_prob
201225

202226
@jax.jit
203227
def step_langevin_backward(_x, _v, _key):
@@ -235,8 +259,8 @@ def step_langevin_backward(_x, _v, _key):
235259

236260
plt.title(f"{human_format(steps)} steps @ {temp} K, dt = {human_format(dt)}s")
237261
ramachandran(trajectory_phi_psi)
238-
plt.scatter(phis_psis(A)[0], phis_psis(A)[1], color='red', marker='*')
239-
plt.scatter(phis_psis(B)[0], phis_psis(B)[1], color='green', marker='*')
262+
plt.scatter(phis_psis(A)[0, 0], phis_psis(A)[0, 1], color='red', marker='*')
263+
plt.scatter(phis_psis(B)[0, 0], phis_psis(B)[0, 1], color='green', marker='*')
240264
plt.show()
241265

242266
# Choose a system, either phi psi, or rmsd
@@ -255,10 +279,10 @@ def step_langevin_backward(_x, _v, _key):
255279
)
256280

257281
system = tps2.SecondOrderSystem(
258-
jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(A), radius)),
259-
jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(B), radius)),
260-
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
261-
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
282+
# jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(A), radius)),
283+
# jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(B), radius)),
284+
jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
285+
jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
262286
step_langevin_forward,
263287
step_langevin_backward,
264288
jax.jit(lambda key: jnp.sqrt(kbT / mass) * jax.random.normal(key, (1, 66)))
@@ -282,34 +306,58 @@ def step_langevin_backward(_x, _v, _key):
282306
initial_trajectory = [p for p in initial_trajectory]
283307
save_trajectory(mdtraj_topology, jnp.array(initial_trajectory), f'{savedir}/initial_trajectory.pdb')
284308

285-
load = False
309+
load = True
286310
if load:
287311
paths = np.load(f'{savedir}/paths.npy', allow_pickle=True)
312+
velocities = np.load(f'{savedir}/velocities.npy', allow_pickle=True)
313+
with open(f'{savedir}/stats.json', 'r') as fp:
314+
statistics = json.load(fp)
288315
else:
289-
paths = tps2.mcmc_shooting(system, tps2.two_way_shooting, initial_trajectory, 100, jax.random.PRNGKey(1),
290-
warmup=10)
316+
paths, velocities, statistics = tps2.mcmc_shooting(system, tps2.two_way_shooting, initial_trajectory,
317+
100, jax.random.PRNGKey(1), warmup=10)
291318
# paths = tps2.unguided_md(system, B, 1, key)
292319
paths = [jnp.array(p) for p in paths]
320+
velocities = [jnp.array(p) for p in velocities]
293321
# store paths
294322
np.save(f'{savedir}/paths.npy', np.array(paths, dtype=object), allow_pickle=True)
323+
np.save(f'{savedir}/velocities.npy', np.array(velocities, dtype=object), allow_pickle=True)
324+
# save statistics, which is a dictionary
325+
with open(f'{savedir}/stats.json', 'w') as fp:
326+
json.dump(statistics, fp)
295327

328+
print(statistics)
296329
print([len(p) for p in paths])
297330
plt.hist([len(p) for p in paths], bins=jnp.sqrt(len(paths)).astype(int).item())
298331
plt.show()
299332

300333
path_hist = PeriodicPathHistogram()
301-
for i, path in tqdm(enumerate(paths)):
334+
for i, path in tqdm(enumerate(paths), desc='Adding paths to histogram', total=len(paths)):
302335
path_hist.add_path(np.array(phis_psis(path)))
303336

304337
plt.title(f"{human_format(len(paths))} paths @ {temp} K, dt = {human_format(dt)}s")
305-
path_hist.plot(cmin=0.001)
338+
path_hist.plot(cmin=0.01)
306339
ramachandran(None, states=[
307340
{'name': 'A', 'center': phis_psis(A).squeeze(), 'radius': radius},
308341
{'name': 'B', 'center': phis_psis(B).squeeze(), 'radius': radius},
309342
], alpha=0.7)
310343
plt.savefig(f'{savedir}/paths.png', bbox_inches='tight')
311344
plt.show()
312345

346+
plot_path_energy(paths, jax.vmap(U))
347+
plt.ylabel('Maximum energy')
348+
plt.savefig(f'{savedir}/max_energy.png', bbox_inches='tight')
349+
plt.show()
350+
351+
plot_path_energy(paths, jax.vmap(U), reduce=jnp.median)
352+
plt.ylabel('Median energy')
353+
plt.savefig(f'{savedir}/median_energy.png', bbox_inches='tight')
354+
plt.show()
355+
356+
plot_path_energy(list(zip(paths, velocities)), langevin_log_path_density, reduce=lambda x: x, already_ln=True)
357+
plt.ylabel('Path Density')
358+
plt.savefig(f'{savedir}/path_density.png', bbox_inches='tight')
359+
plt.show()
360+
313361
for i, path in tqdm(enumerate(paths)):
314362
save_trajectory(mdtraj_topology, jnp.array([kabsch_align(p.reshape(-1, 3), B.reshape(-1, 3))[0] for p in path]),
315363
f'{savedir}/trajectory_{i}.pdb')

0 commit comments

Comments
 (0)