@@ -42,56 +42,56 @@ def interpolate_two_points(start, stop, steps):
42
42
43
43
if __name__ == '__main__' :
44
44
# variable or fixed length?
45
- variable = True
46
45
num_paths = 1000
47
46
48
- save_dir = f"out/baselines/mueller"
49
- if variable :
50
- save_dir += "-variable"
51
-
52
- os .makedirs (save_dir , exist_ok = True )
53
-
54
- xi = 5
55
- dt = 1e-4
56
- T = 275e-4
57
- N = 0 if variable else int (T / dt )
58
-
59
- system = System .from_name ('mueller_brown' , float ('inf' ))
60
- initial_trajectory = [t .reshape (1 , 2 ) for t in interpolate (jnp .array ([system .A , system .B ]), 100 if variable else N )]
61
-
62
- @jax .jit
63
- def step (_x , _key ):
64
- """Perform one step of forward euler"""
65
- return _x - dt * system .dUdx (_x ) + jnp .sqrt (dt ) * xi * jax .random .normal (_key , _x .shape )
66
-
67
-
68
- tps_config = tps1 .FirstOrderSystem (
69
- jax .jit (lambda s : jnp .linalg .norm (s - system .A ) <= 0.1 ),
70
- jax .jit (lambda s : jnp .linalg .norm (s - system .B ) <= 0.1 ),
71
- step
72
- )
73
-
74
- for method , name in [
75
- (tps1 .one_way_shooting , 'one-way-shooting' ),
76
- (tps1 .two_way_shooting , 'two-way-shooting' ),
77
- ]:
78
- if os .path .exists (f'{ save_dir } /paths-{ name } .npy' ) and os .path .exists (f'{ save_dir } /stats-{ name } .json' ):
79
- print (f"Skipping { name } because the results are already present" )
80
-
81
- paths = np .load (f'{ save_dir } /paths-{ name } .npy' , allow_pickle = True )
82
- paths = [jnp .array (p .astype (np .float32 )) for p in paths ]
83
- with open (f'{ save_dir } /stats-{ name } .json' , 'r' ) as fp :
84
- statistics = json .load (fp )
85
- else :
86
- print ('Generating paths for' , name )
87
- paths , statistics = tps1 .mcmc_shooting (tps_config , method , initial_trajectory , num_paths ,
88
- jax .random .PRNGKey (1 ), warmup = 0 , fixed_length = N )
89
-
90
- paths = [jnp .array (p ) for p in paths ]
91
-
92
- np .save (f'{ save_dir } /paths-{ name } .npy' , np .array (paths , dtype = object ), allow_pickle = True )
93
- with open (f'{ save_dir } /stats-{ name } .json' , 'w' ) as fp :
94
- json .dump (statistics , fp )
95
-
96
- system .plot (trajectories = paths )
97
- show_or_save_fig (save_dir , f'mueller-{ name } ' , 'pdf' )
47
+ for variable in [False , True ]:
48
+ save_dir = f"out/baselines/mueller"
49
+ if variable :
50
+ save_dir += "-variable"
51
+
52
+ os .makedirs (save_dir , exist_ok = True )
53
+
54
+ xi = 5
55
+ dt = 1e-4
56
+ T = 275e-4
57
+ N = 0 if variable else int (T / dt )
58
+
59
+ system = System .from_name ('mueller_brown' , float ('inf' ))
60
+ initial_trajectory = [t .reshape (1 , 2 ) for t in interpolate (jnp .array ([system .A , system .B ]), 100 if variable else N )]
61
+
62
+ @jax .jit
63
+ def step (_x , _key ):
64
+ """Perform one step of forward euler"""
65
+ return _x - dt * system .dUdx (_x ) + jnp .sqrt (dt ) * xi * jax .random .normal (_key , _x .shape )
66
+
67
+
68
+ tps_config = tps1 .FirstOrderSystem (
69
+ jax .jit (lambda s : jnp .linalg .norm (s - system .A ) <= 0.1 ),
70
+ jax .jit (lambda s : jnp .linalg .norm (s - system .B ) <= 0.1 ),
71
+ step
72
+ )
73
+
74
+ for method , name in [
75
+ (tps1 .one_way_shooting , 'one-way-shooting' ),
76
+ (tps1 .two_way_shooting , 'two-way-shooting' ),
77
+ ]:
78
+ if os .path .exists (f'{ save_dir } /paths-{ name } .npy' ) and os .path .exists (f'{ save_dir } /stats-{ name } .json' ):
79
+ print (f"Skipping { name } because the results are already present" )
80
+
81
+ paths = np .load (f'{ save_dir } /paths-{ name } .npy' , allow_pickle = True )
82
+ paths = [jnp .array (p .astype (np .float32 )) for p in paths ]
83
+ with open (f'{ save_dir } /stats-{ name } .json' , 'r' ) as fp :
84
+ statistics = json .load (fp )
85
+ else :
86
+ print ('Generating paths for' , name )
87
+ paths , statistics = tps1 .mcmc_shooting (tps_config , method , initial_trajectory , num_paths ,
88
+ jax .random .PRNGKey (1 ), warmup = 0 , fixed_length = N )
89
+
90
+ paths = [jnp .array (p ) for p in paths ]
91
+
92
+ np .save (f'{ save_dir } /paths-{ name } .npy' , np .array (paths , dtype = object ), allow_pickle = True )
93
+ with open (f'{ save_dir } /stats-{ name } .json' , 'w' ) as fp :
94
+ json .dump (statistics , fp )
95
+
96
+ system .plot (trajectories = paths )
97
+ show_or_save_fig (save_dir , f'mueller-{ name } ' , 'pdf' )
0 commit comments