2
2
import jax .numpy as jnp
3
3
import jax
4
4
from eval .path_metrics import plot_path_energy
5
+ from systems import System
5
6
from tps .paths import decorrelated
6
- from tps_baseline_mueller import U , dUdx_fn
7
7
from scipy .optimize import minimize
8
8
import matplotlib .pyplot as plt
9
9
import os
15
15
T = 275e-4
16
16
N = int (T / dt )
17
17
18
+ system = System .from_name ('mueller_brown' , float ('inf' ))
19
+
18
20
minima_points = jnp .array ([[- 0.55828035 , 1.44169 ],
19
21
[- 0.05004308 , 0.46666032 ],
20
22
[0.62361133 , 0.02804632 ]])
@@ -27,8 +29,17 @@ def load(path):
27
29
28
30
@jax .jit
29
31
def log_path_likelihood (path ):
30
- rand = path [1 :] - path [:- 1 ] + dt * dUdx_fn (path [:- 1 ])
31
- return (- U (path [0 ]) / kbT ).sum () + jax .scipy .stats .norm .logpdf (rand , scale = jnp .sqrt (dt ) * xi ).sum ()
32
+ rand = path [1 :] - path [:- 1 ] + dt * system .dUdx (path [:- 1 ])
33
+ return (- system .U (path [0 ]) / kbT ).sum () + jax .scipy .stats .norm .logpdf (rand , scale = jnp .sqrt (dt ) * xi ).sum ()
34
+
35
+
36
+ def plot_hist (system , paths , trajectories_to_plot , seed = 1 ):
37
+ system .plot (trajectories = paths )
38
+ colors = plt .rcParams ['axes.prop_cycle' ].by_key ()['color' ]
39
+ idx = jax .random .permutation (jax .random .PRNGKey (seed ), len (paths ))[:trajectories_to_plot ]
40
+ for i , c in zip (idx , colors [1 :]):
41
+ cur_paths = jnp .array (paths [i ])
42
+ plt .plot (cur_paths [:, 0 ].T , cur_paths [:, 1 ].T , c = c )
32
43
33
44
34
45
if __name__ == '__main__' :
@@ -43,19 +54,29 @@ def log_path_likelihood(path):
43
54
('var-doobs' , './out/var_doobs/mueller/paths.npy' , 0 ),
44
55
]
45
56
46
- global_minimum_energy = U (minima_points [ 0 ] )
57
+ global_minimum_energy = min ( system . U (minima_points ) )
47
58
for point in minima_points :
48
- global_minimum_energy = min (global_minimum_energy , minimize (U , point ).fun )
59
+ global_minimum_energy = min (global_minimum_energy , minimize (system . U , point ).fun )
49
60
print ("Global minimum energy" , global_minimum_energy )
50
61
51
62
all_paths = [(name , load (path )[warmup :],) for name , path , warmup in all_paths ]
52
63
[print (name , len (path )) for name , path in all_paths ]
53
64
65
+ for name , paths in all_paths :
66
+ # for this plot we limit ourselves to 250 paths
67
+ plot_hist (system , paths [:250 ], 2 )
68
+ plt .savefig (f'{ savedir } /{ name } -histogram.pdf' , bbox_inches = 'tight' )
69
+ plt .show ()
70
+
71
+ plot_hist (system , decorrelated (paths )[:250 ], 2 )
72
+ plt .savefig (f'{ savedir } /{ name } -decorrelated-histogram.pdf' , bbox_inches = 'tight' )
73
+ plt .show ()
74
+
54
75
for name , paths in all_paths :
55
76
print (name , 'decorrelated trajectories:' , jnp .round (100 * len (decorrelated (paths )) / len (paths ), 2 ), '%' )
56
77
57
78
for name , paths in all_paths :
58
- max_energy = plot_path_energy (paths , U , add = - global_minimum_energy , label = name ) + global_minimum_energy
79
+ max_energy = plot_path_energy (paths , system . U , add = - global_minimum_energy , label = name ) + global_minimum_energy
59
80
print (name , 'max energy mean:' , jnp .round (jnp .mean (max_energy ), 2 ), 'std:' , jnp .round (jnp .std (max_energy ), 2 ))
60
81
print (name , 'min max energy: ' , jnp .round (jnp .min (max_energy ), 2 ))
61
82
@@ -65,7 +86,7 @@ def log_path_likelihood(path):
65
86
plt .show ()
66
87
67
88
for name , paths in all_paths :
68
- plot_path_energy (paths , U , add = - global_minimum_energy , reduce = jnp .median , label = name )
89
+ plot_path_energy (paths , system . U , add = - global_minimum_energy , reduce = jnp .median , label = name )
69
90
70
91
plt .legend ()
71
92
plt .ylabel ('Median energy' )
0 commit comments