Skip to content

Commit 5fcc35d

Browse files
committed
Fix shooting and TPS baseliens
1 parent 6dbbfdf commit 5fcc35d

File tree

5 files changed

+64
-93
lines changed

5 files changed

+64
-93
lines changed

eval/path_metrics.py

+2
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@ def plot_path_energy(paths, U, reduce=jnp.max, add=0, already_ln=False, **kwargs
1212
plt.plot(jnp.arange(0, len(reduced), 1), reduced, **kwargs)
1313
else:
1414
plt.semilogy(jnp.arange(0, len(reduced), 1), reduced, **kwargs)
15+
16+
return reduced

evaluate_mueller.py

-69
This file was deleted.

tps/paths.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import jax.numpy as jnp
2+
from tqdm import tqdm
3+
4+
5+
def decorrelated(paths):
6+
prev = paths[0]
7+
decorrelated = [prev]
8+
9+
for x in tqdm(paths[1:]):
10+
# check if the two arrays share a common value
11+
if not jnp.in1d(prev, x).any():
12+
decorrelated.append(x)
13+
prev = x
14+
15+
return decorrelated

tps/second_order.py

+30-10
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,25 @@ def __init__(self, start_state, target_state, step_forward, step_backward, sampl
1515
self.sample_velocity = sample_velocity
1616

1717

18-
def one_way_shooting(system, trajectory, fixed_length, dt, key):
18+
def one_way_shooting(system, trajectory, previous_velocities, fixed_length, dt, key):
1919
key = jax.random.split(key)
2020

21+
if previous_velocities is None:
22+
previous_velocities = [(trajectory[i] - trajectory[i - 1]) / dt for i in range(1, len(trajectory))]
23+
previous_velocities.insert(0, system.sample_velocity(key[0]))
24+
2125
# pick a random point along the trajectory
22-
point_idx = jax.random.randint(key[0], (1,), 1, len(trajectory) - 1)[0]
26+
point_idx = jax.random.randint(key[1], (1,), 1, len(trajectory) - 1)[0]
2327
# pick a random direction, either forward or backward
24-
direction = jax.random.randint(key[1], (1,), 0, 2)[0]
25-
26-
new_velocities = [(trajectory[point_idx] - trajectory[point_idx - 1]) / dt]
28+
direction = jax.random.randint(key[2], (1,), 0, 2)[0]
2729

2830
if direction == 0:
2931
trajectory = trajectory[:point_idx + 1]
32+
new_velocities = previous_velocities[:point_idx + 1]
3033
step_function = system.step_forward
3134
else: # direction == 1:
3235
trajectory = trajectory[point_idx:][::-1]
36+
new_velocities = previous_velocities[point_idx:][::-1]
3337
step_function = system.step_backward
3438

3539
steps = MAX_STEPS if fixed_length == 0 else fixed_length
@@ -73,7 +77,7 @@ def one_way_shooting(system, trajectory, fixed_length, dt, key):
7377
return False, trajectory, new_velocities
7478

7579

76-
def two_way_shooting(system, trajectory, fixed_length, _dt, key):
80+
def two_way_shooting(system, trajectory, _previous_velocities, fixed_length, _dt, key):
7781
key = jax.random.split(key)
7882

7983
# pick a random point along the trajectory
@@ -170,11 +174,14 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixe
170174
velocities = stored['velocities']
171175
statistics = stored['statistics']
172176

177+
num_tries = 0
178+
num_force_evaluations = 0
179+
num_metropolis_rejected = 0
173180
try:
174181
with tqdm(total=num_paths + warmup, initial=len(trajectories) - 1,
175182
desc='warming up' if warmup > 0 else '') as pbar:
176183
while len(trajectories) <= num_paths + warmup:
177-
statistics['num_tries'] += 1
184+
num_tries += 1
178185
if len(trajectories) > warmup:
179186
pbar.set_description('')
180187

@@ -183,20 +190,33 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixe
183190
# during warmup, we want an iterative scheme
184191
traj_idx = traj_idx if traj_idx < len(trajectories) else -1
185192

186-
found, new_trajectory, new_velocities = proposal(system, trajectories[traj_idx], fixed_length, dt, ikey)
187-
statistics['num_force_evaluations'] += len(new_trajectory) - 1
193+
# trajectories and velocities are one off
194+
found, new_trajectory, new_velocities = proposal(system,
195+
trajectories[traj_idx],
196+
velocities[traj_idx - 1] if len(trajectories) > 1 else None,
197+
fixed_length, dt, ikey)
198+
num_force_evaluations += len(new_trajectory) - 1
188199

189200
if not found:
190201
continue
191202

192203
ratio = len(trajectories[-1]) / len(new_trajectory)
193204
# The first trajectory might have a very unreasonable length, so we skip it
194205
if len(trajectories) == 1 or jax.random.uniform(accept_key, shape=(1,)) < ratio:
206+
# only update them in the dictionary once accepted
207+
# this allows us to continue the progress
208+
statistics['num_tries'] += num_tries
209+
statistics['num_force_evaluations'] += num_force_evaluations
210+
statistics['num_metropolis_rejected'] += num_metropolis_rejected
211+
num_tries = 0
212+
num_force_evaluations = 0
213+
num_metropolis_rejected = 0
214+
195215
trajectories.append(new_trajectory)
196216
velocities.append(new_velocities)
197217
pbar.update(1)
198218
else:
199-
statistics['num_metropolis_rejected'] += 1
219+
num_metropolis_rejected += 1
200220
except KeyboardInterrupt:
201221
print('SIGINT received, stopping early')
202222
# Fix in case we stop when adding a trajectory

tps_baseline.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
from functools import partial
32
import traceback
43
import jax
54
import numpy as np
@@ -33,6 +32,7 @@
3332
parser.add_argument('--mechanism', type=str, choices=['one-way-shooting', 'two-way-shooting'], required=True)
3433
parser.add_argument('--states', type=str, default='phi-psi', choices=['phi-psi', 'rmsd'])
3534
parser.add_argument('--fixed_length', type=int, default=0)
35+
parser.add_argument('--warmup', type=int, default=0)
3636
parser.add_argument('--num_paths', type=int, required=True)
3737
parser.add_argument('--num_steps', type=int, default=10,
3838
help='The number of MD steps taken at once. More takes longer to compile but runs faster in the end.')
@@ -213,6 +213,18 @@ def step_langevin_forward(_x, _v, _key):
213213
return _x + dt_in_ps * new_v, new_v
214214

215215

216+
@jax.jit
217+
def step_langevin_backward(_x, _v, _key):
218+
"""Perform one step of backward langevin"""
219+
alpha = jnp.exp(-gamma_in_ps * dt_in_ps)
220+
f_scale = (1 - alpha) / gamma_in_ps
221+
prev_x = _x - dt_in_ps * _v
222+
prev_v = 1 / alpha * (_v + f_scale * dUdx_fn(prev_x) - jnp.sqrt(
223+
kbT * (1 - alpha ** 2) / mass) * jax.random.normal(_key, _x.shape))
224+
225+
return prev_x, prev_v
226+
227+
216228
@jax.jit
217229
def step_langevin_log_prob(_x, _v, _new_x, _new_v):
218230
alpha = jnp.exp(-gamma_in_ps * dt_in_ps)
@@ -225,6 +237,8 @@ def step_langevin_log_prob(_x, _v, _new_x, _new_v):
225237

226238
def langevin_log_path_likelihood(path_and_velocities):
227239
path, velocities = path_and_velocities
240+
assert len(path) == len(velocities), \
241+
f'path and velocities must have the same length, but got {len(path)} and {len(velocities)}'
228242

229243
log_prob = (-U(path[0]) / kbT).sum()
230244
log_prob += jax.scipy.stats.norm.logpdf(velocities[0], 0, jnp.sqrt(kbT / mass)).sum()
@@ -235,18 +249,6 @@ def langevin_log_path_likelihood(path_and_velocities):
235249
return log_prob
236250

237251

238-
@jax.jit
239-
def step_langevin_backward(_x, _v, _key):
240-
"""Perform one step of backward langevin"""
241-
alpha = jnp.exp(-gamma_in_ps * dt_in_ps)
242-
f_scale = (1 - alpha) / gamma_in_ps
243-
prev_x = _x - dt_in_ps * _v
244-
prev_v = 1 / alpha * (_v + f_scale * dUdx_fn(prev_x) - jnp.sqrt(
245-
kbT * (1 - alpha ** 2) / mass) * jax.random.normal(_key, _x.shape))
246-
247-
return prev_x, prev_v
248-
249-
250252
# Choose a system, either phi psi, or rmsd
251253
# system = tps1.System(
252254
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) < 0.1)),
@@ -324,7 +326,8 @@ def step_langevin_backward(_x, _v, _key):
324326

325327
try:
326328
paths, velocities, statistics = tps2.mcmc_shooting(system, mechanism, initial_trajectory,
327-
args.num_paths, dt_in_ps, jax.random.PRNGKey(1), warmup=0,
329+
args.num_paths, dt_in_ps, jax.random.PRNGKey(1),
330+
warmup=args.warmup,
328331
fixed_length=args.fixed_length,
329332
stored=stored)
330333
# paths = tps2.unguided_md(system, B, 1, key)

0 commit comments

Comments
 (0)