Skip to content

Commit 48a3bc0

Browse files
committed
Refactor tps_baseline.py
1 parent b77138d commit 48a3bc0

File tree

6 files changed

+81
-155
lines changed

6 files changed

+81
-155
lines changed

README.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ python eval/evaluate_mueller.py
2828
and
2929

3030
```bash
31-
python tps_baseline.py
31+
python tps_baseline.py --mechanism two-way-shooting --num_paths 1000 --states phi-psi
32+
# num_steps compiles multiple MD steps into a single one, making sampling faster. But this makes startup longer. Only really worth it for long running simulations
33+
python tps_baseline.py --mechanism two-way-shooting --num_paths 1000 --fixed_length 1000 --states phi-psi --num_steps 50
34+
python tps_baseline.py --mechanism two-way-shooting --num_paths 1000 --states rmsd
3235
python eval/evaluate_tps.py
3336
```
3437

eval/path_metrics.py

+14
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,17 @@ def plot_path_energy(paths, U, reduce=jnp.max, add=0, already_ln=False, **kwargs
1414
plt.semilogy(jnp.arange(0, len(reduced), 1), reduced, **kwargs)
1515

1616
return reduced
17+
18+
19+
def plot_iterative_min_max_energy(paths, U, potential_calls):
20+
reduced = jnp.array([jnp.max(U(path)) for path in tqdm(paths)])
21+
22+
iterative_min = [reduced[0]]
23+
for c in reduced[1:]:
24+
iterative_min.append(min(c, iterative_min[-1]))
25+
26+
plt.xlabel('Number of potential calls')
27+
plt.ylabel('Minimum energy of best path so far')
28+
plt.semilogy(jnp.cumsum(jnp.array(potential_calls)), iterative_min)
29+
30+
return iterative_min

systems.py

-2
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,6 @@ def U(_x):
105105
yticklabels=[r'$-\pi$', r'$-\frac {\pi} {2}$', '0', r'$\frac {\pi} {2}$', r'$\pi$'],
106106
square=True, periodic=True,
107107
)
108-
109-
110108
else:
111109
raise ValueError(f"Unknown cv: {cv}")
112110

tps/second_order.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixe
158158
trajectories = [initial_trajectory]
159159
velocities = []
160160
statistics = {
161-
'num_force_evaluations': 0,
161+
'num_force_evaluations': [],
162162
'num_tries': 0,
163163
'num_metropolis_rejected': 0,
164164
'warmup': warmup,
@@ -193,7 +193,8 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixe
193193
# trajectories and velocities are one off
194194
found, new_trajectory, new_velocities = proposal(system,
195195
trajectories[traj_idx],
196-
velocities[traj_idx - 1] if len(trajectories) > 1 else None,
196+
velocities[traj_idx - 1] if len(
197+
trajectories) > 1 else None,
197198
fixed_length, dt, ikey)
198199
num_force_evaluations += len(new_trajectory) - 1
199200

@@ -206,7 +207,7 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixe
206207
# only update them in the dictionary once accepted
207208
# this allows us to continue the progress
208209
statistics['num_tries'] += num_tries
209-
statistics['num_force_evaluations'] += num_force_evaluations
210+
statistics['num_force_evaluations'].append(num_force_evaluations)
210211
statistics['num_metropolis_rejected'] += num_metropolis_rejected
211212
num_tries = 0
212213
num_force_evaluations = 0

0 commit comments

Comments
 (0)