@@ -23,8 +23,8 @@ def one_way_shooting(system, trajectory, fixed_length, key):
23
23
# pick a random direction, either forward or backward
24
24
direction = jax .random .randint (key [1 ], (1 ,), 0 , 2 )[0 ]
25
25
26
- # TODO: Fix correct dt in ps
27
- velocity = (trajectory [point_idx ] - trajectory [point_idx - 1 ]) / 0.001
26
+ # TODO: Fix correct dt in ps / pass previous velocities
27
+ new_velocities = [ (trajectory [point_idx ] - trajectory [point_idx - 1 ]) / 0.001 ]
28
28
29
29
if direction == 0 :
30
30
trajectory = trajectory [:point_idx + 1 ]
@@ -38,27 +38,28 @@ def one_way_shooting(system, trajectory, fixed_length, key):
38
38
key , iter_key = jax .random .split (key [3 ])
39
39
while len (trajectory ) < steps :
40
40
key , iter_key = jax .random .split (key )
41
- point , velocity = step_function (trajectory [- 1 ], velocity , iter_key )
41
+ point , velocity = step_function (trajectory [- 1 ], new_velocities [ - 1 ] , iter_key )
42
42
trajectory .append (point )
43
+ new_velocities .append (velocity )
43
44
44
45
if jnp .isnan (point ).any () or jnp .isnan (velocity ).any ():
45
- return False , trajectory
46
+ return False , trajectory , new_velocities
46
47
47
48
# ensure that our trajectory does not explode
48
49
if (jnp .abs (point ) > MAX_ABS_VALUE ).any ():
49
- return False , trajectory
50
+ return False , trajectory , new_velocities
50
51
51
52
if system .start_state (trajectory [0 ]) and system .target_state (trajectory [- 1 ]):
52
53
if fixed_length == 0 or len (trajectory ) == fixed_length :
53
- return True , trajectory
54
- return False , trajectory
54
+ return True , trajectory , new_velocities
55
+ return False , trajectory , new_velocities
55
56
56
57
if system .target_state (trajectory [0 ]) and system .start_state (trajectory [- 1 ]):
57
58
if fixed_length == 0 or len (trajectory ) == fixed_length :
58
- return True , trajectory [::- 1 ]
59
- return False , trajectory
59
+ return True , trajectory [::- 1 ], new_velocities [:: - 1 ]
60
+ return False , trajectory , new_velocities
60
61
61
- return False , trajectory
62
+ return False , trajectory , new_velocities
62
63
63
64
64
65
def two_way_shooting (system , trajectory , fixed_length , key ):
@@ -71,62 +72,74 @@ def two_way_shooting(system, trajectory, fixed_length, key):
71
72
72
73
steps = MAX_STEPS if fixed_length == 0 else fixed_length
73
74
74
- initial_velocity = system .sample_velocity (key [1 ])
75
-
76
- key , iter_key = jax .random .split (key [2 ])
77
75
new_trajectory = [point ]
76
+ new_velocities = [system .sample_velocity (key [1 ])]
78
77
79
- velocity = initial_velocity
78
+ key , iter_key = jax . random . split ( key [ 2 ])
80
79
while len (new_trajectory ) < steps :
81
80
key , iter_key = jax .random .split (key )
82
- point , velocity = system .step_forward (new_trajectory [- 1 ], velocity , iter_key )
81
+ point , velocity = system .step_forward (new_trajectory [- 1 ], new_velocities [ - 1 ] , iter_key )
83
82
new_trajectory .append (point )
83
+ new_velocities .append (velocity )
84
84
85
85
if jnp .isnan (point ).any () or jnp .isnan (velocity ).any ():
86
- return False , trajectory
86
+ return False , new_trajectory , new_velocities
87
87
88
88
# ensure that our trajectory does not explode
89
89
if (jnp .abs (point ) > MAX_ABS_VALUE ).any ():
90
- return False , trajectory
90
+ return False , new_trajectory , new_velocities
91
91
92
92
if system .start_state (point ) or system .target_state (point ):
93
93
break
94
94
95
- velocity = initial_velocity
96
95
while len (new_trajectory ) < steps :
97
96
key , iter_key = jax .random .split (key )
98
- point , velocity = system .step_backward (new_trajectory [0 ], velocity , iter_key )
97
+ point , velocity = system .step_backward (new_trajectory [0 ], new_velocities [ 0 ] , iter_key )
99
98
new_trajectory .insert (0 , point )
99
+ new_velocities .insert (0 , velocity )
100
100
101
101
if jnp .isnan (point ).any () or jnp .isnan (velocity ).any ():
102
- return False , trajectory
102
+ return False , new_trajectory , new_velocities
103
103
104
104
# ensure that our trajectory does not explode
105
105
if (jnp .abs (point ) > MAX_ABS_VALUE ).any ():
106
- return False , trajectory
106
+ return False , new_trajectory , new_velocities
107
107
108
108
if system .start_state (point ) or system .target_state (point ):
109
109
break
110
110
111
111
# throw away the trajectory if it's not the right length
112
112
if fixed_length != 0 and len (new_trajectory ) != fixed_length :
113
- return False , trajectory
113
+ return False , new_trajectory , new_velocities
114
114
115
115
if system .start_state (new_trajectory [0 ]) and system .target_state (new_trajectory [- 1 ]):
116
- return True , new_trajectory
116
+ return True , new_trajectory , new_velocities
117
117
118
118
if system .target_state (new_trajectory [0 ]) and system .start_state (new_trajectory [- 1 ]):
119
- return True , new_trajectory [::- 1 ]
119
+ return True , new_trajectory [::- 1 ], new_velocities [:: - 1 ]
120
120
121
- return False , trajectory
121
+ return False , new_trajectory , new_velocities
122
122
123
123
124
124
def mcmc_shooting (system , proposal , initial_trajectory , num_paths , key , fixed_length = 0 , warmup = 50 ):
125
125
# pick an initial trajectory
126
126
trajectories = [initial_trajectory ]
127
+ velocities = []
128
+ statistics = {
129
+ 'num_force_evaluations' : 0 ,
130
+ 'num_tries' : 0 ,
131
+ 'num_metropolis_rejected' : 0 ,
132
+ 'warmup' : warmup ,
133
+ 'num_paths' : num_paths ,
134
+ 'max_steps' : MAX_STEPS ,
135
+ 'max_abs_value' : MAX_ABS_VALUE ,
136
+ }
137
+ if fixed_length > 0 :
138
+ statistics ['fixed_length' ] = fixed_length
127
139
128
140
with tqdm (total = num_paths + warmup , desc = 'warming up' if warmup > 0 else '' ) as pbar :
129
141
while len (trajectories ) <= num_paths + warmup :
142
+ statistics ['num_tries' ] += 1
130
143
if len (trajectories ) > warmup :
131
144
pbar .set_description ('' )
132
145
@@ -135,7 +148,8 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_le
135
148
# during warmup, we want an iterative scheme
136
149
traj_idx = traj_idx if traj_idx < len (trajectories ) else - 1
137
150
138
- found , new_trajectory = proposal (system , trajectories [traj_idx ], fixed_length , iter_key )
151
+ found , new_trajectory , new_velocities = proposal (system , trajectories [traj_idx ], fixed_length , iter_key )
152
+ statistics ['num_force_evaluations' ] += len (new_trajectory ) - 1
139
153
140
154
if not found :
141
155
continue
@@ -144,9 +158,12 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_le
144
158
# The first trajectory might have a very unreasonable length, so we skip it
145
159
if len (trajectories ) == 1 or jax .random .uniform (accept_key , shape = (1 ,)) < ratio :
146
160
trajectories .append (new_trajectory )
161
+ velocities .append (new_velocities )
147
162
pbar .update (1 )
163
+ else :
164
+ statistics ['num_metropolis_rejected' ] += 1
148
165
149
- return trajectories [warmup + 1 :]
166
+ return trajectories [warmup + 1 :], velocities [ warmup :], statistics
150
167
151
168
152
169
def unguided_md (system , initial_point , num_paths , key , fixed_length = 0 ):
0 commit comments