@@ -15,21 +15,25 @@ def __init__(self, start_state, target_state, step_forward, step_backward, sampl
15
15
self .sample_velocity = sample_velocity
16
16
17
17
18
- def one_way_shooting (system , trajectory , fixed_length , dt , key ):
18
+ def one_way_shooting (system , trajectory , previous_velocities , fixed_length , dt , key ):
19
19
key = jax .random .split (key )
20
20
21
+ if previous_velocities is None :
22
+ previous_velocities = [(trajectory [i ] - trajectory [i - 1 ]) / dt for i in range (1 , len (trajectory ))]
23
+ previous_velocities .insert (0 , system .sample_velocity (key [0 ]))
24
+
21
25
# pick a random point along the trajectory
22
- point_idx = jax .random .randint (key [0 ], (1 ,), 1 , len (trajectory ) - 1 )[0 ]
26
+ point_idx = jax .random .randint (key [1 ], (1 ,), 1 , len (trajectory ) - 1 )[0 ]
23
27
# pick a random direction, either forward or backward
24
- direction = jax .random .randint (key [1 ], (1 ,), 0 , 2 )[0 ]
25
-
26
- new_velocities = [(trajectory [point_idx ] - trajectory [point_idx - 1 ]) / dt ]
28
+ direction = jax .random .randint (key [2 ], (1 ,), 0 , 2 )[0 ]
27
29
28
30
if direction == 0 :
29
31
trajectory = trajectory [:point_idx + 1 ]
32
+ new_velocities = previous_velocities [:point_idx + 1 ]
30
33
step_function = system .step_forward
31
34
else : # direction == 1:
32
35
trajectory = trajectory [point_idx :][::- 1 ]
36
+ new_velocities = previous_velocities [point_idx :][::- 1 ]
33
37
step_function = system .step_backward
34
38
35
39
steps = MAX_STEPS if fixed_length == 0 else fixed_length
@@ -73,7 +77,7 @@ def one_way_shooting(system, trajectory, fixed_length, dt, key):
73
77
return False , trajectory , new_velocities
74
78
75
79
76
- def two_way_shooting (system , trajectory , fixed_length , _dt , key ):
80
+ def two_way_shooting (system , trajectory , _previous_velocities , fixed_length , _dt , key ):
77
81
key = jax .random .split (key )
78
82
79
83
# pick a random point along the trajectory
@@ -170,11 +174,14 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixe
170
174
velocities = stored ['velocities' ]
171
175
statistics = stored ['statistics' ]
172
176
177
+ num_tries = 0
178
+ num_force_evaluations = 0
179
+ num_metropolis_rejected = 0
173
180
try :
174
181
with tqdm (total = num_paths + warmup , initial = len (trajectories ) - 1 ,
175
182
desc = 'warming up' if warmup > 0 else '' ) as pbar :
176
183
while len (trajectories ) <= num_paths + warmup :
177
- statistics [ ' num_tries' ] += 1
184
+ num_tries += 1
178
185
if len (trajectories ) > warmup :
179
186
pbar .set_description ('' )
180
187
@@ -183,20 +190,33 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixe
183
190
# during warmup, we want an iterative scheme
184
191
traj_idx = traj_idx if traj_idx < len (trajectories ) else - 1
185
192
186
- found , new_trajectory , new_velocities = proposal (system , trajectories [traj_idx ], fixed_length , dt , ikey )
187
- statistics ['num_force_evaluations' ] += len (new_trajectory ) - 1
193
+ # trajectories and velocities are one off
194
+ found , new_trajectory , new_velocities = proposal (system ,
195
+ trajectories [traj_idx ],
196
+ velocities [traj_idx - 1 ] if len (trajectories ) > 1 else None ,
197
+ fixed_length , dt , ikey )
198
+ num_force_evaluations += len (new_trajectory ) - 1
188
199
189
200
if not found :
190
201
continue
191
202
192
203
ratio = len (trajectories [- 1 ]) / len (new_trajectory )
193
204
# The first trajectory might have a very unreasonable length, so we skip it
194
205
if len (trajectories ) == 1 or jax .random .uniform (accept_key , shape = (1 ,)) < ratio :
206
+ # only update them in the dictionary once accepted
207
+ # this allows us to continue the progress
208
+ statistics ['num_tries' ] += num_tries
209
+ statistics ['num_force_evaluations' ] += num_force_evaluations
210
+ statistics ['num_metropolis_rejected' ] += num_metropolis_rejected
211
+ num_tries = 0
212
+ num_force_evaluations = 0
213
+ num_metropolis_rejected = 0
214
+
195
215
trajectories .append (new_trajectory )
196
216
velocities .append (new_velocities )
197
217
pbar .update (1 )
198
218
else :
199
- statistics [ ' num_metropolis_rejected' ] += 1
219
+ num_metropolis_rejected += 1
200
220
except KeyboardInterrupt :
201
221
print ('SIGINT received, stopping early' )
202
222
# Fix in case we stop when adding a trajectory
0 commit comments