Skip to content

Commit e9db8e0

Browse files
committed
Add histograms to evaluation
1 parent 993860c commit e9db8e0

File tree

2 files changed

+40
-7
lines changed

2 files changed

+40
-7
lines changed

eval/evaluate_mueller.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import jax.numpy as jnp
33
import jax
44
from eval.path_metrics import plot_path_energy
5+
from systems import System
56
from tps.paths import decorrelated
6-
from tps_baseline_mueller import U, dUdx_fn
77
from scipy.optimize import minimize
88
import matplotlib.pyplot as plt
99
import os
@@ -15,6 +15,8 @@
1515
T = 275e-4
1616
N = int(T / dt)
1717

18+
system = System.from_name('mueller_brown', float('inf'))
19+
1820
minima_points = jnp.array([[-0.55828035, 1.44169],
1921
[-0.05004308, 0.46666032],
2022
[0.62361133, 0.02804632]])
@@ -27,8 +29,17 @@ def load(path):
2729

2830
@jax.jit
2931
def log_path_likelihood(path):
30-
rand = path[1:] - path[:-1] + dt * dUdx_fn(path[:-1])
31-
return (-U(path[0]) / kbT).sum() + jax.scipy.stats.norm.logpdf(rand, scale=jnp.sqrt(dt) * xi).sum()
32+
rand = path[1:] - path[:-1] + dt * system.dUdx(path[:-1])
33+
return (-system.U(path[0]) / kbT).sum() + jax.scipy.stats.norm.logpdf(rand, scale=jnp.sqrt(dt) * xi).sum()
34+
35+
36+
def plot_hist(system, paths, trajectories_to_plot, seed=1):
37+
system.plot(trajectories=paths)
38+
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
39+
idx = jax.random.permutation(jax.random.PRNGKey(seed), len(paths))[:trajectories_to_plot]
40+
for i, c in zip(idx, colors[1:]):
41+
cur_paths = jnp.array(paths[i])
42+
plt.plot(cur_paths[:, 0].T, cur_paths[:, 1].T, c=c)
3243

3344

3445
if __name__ == '__main__':
@@ -43,19 +54,29 @@ def log_path_likelihood(path):
4354
('var-doobs', './out/var_doobs/mueller/paths.npy', 0),
4455
]
4556

46-
global_minimum_energy = U(minima_points[0])
57+
global_minimum_energy = min(system.U(minima_points))
4758
for point in minima_points:
48-
global_minimum_energy = min(global_minimum_energy, minimize(U, point).fun)
59+
global_minimum_energy = min(global_minimum_energy, minimize(system.U, point).fun)
4960
print("Global minimum energy", global_minimum_energy)
5061

5162
all_paths = [(name, load(path)[warmup:],) for name, path, warmup in all_paths]
5263
[print(name, len(path)) for name, path in all_paths]
5364

65+
for name, paths in all_paths:
66+
# for this plot we limit ourselves to 250 paths
67+
plot_hist(system, paths[:250], 2)
68+
plt.savefig(f'{savedir}/{name}-histogram.pdf', bbox_inches='tight')
69+
plt.show()
70+
71+
plot_hist(system, decorrelated(paths)[:250], 2)
72+
plt.savefig(f'{savedir}/{name}-decorrelated-histogram.pdf', bbox_inches='tight')
73+
plt.show()
74+
5475
for name, paths in all_paths:
5576
print(name, 'decorrelated trajectories:', jnp.round(100 * len(decorrelated(paths)) / len(paths), 2), '%')
5677

5778
for name, paths in all_paths:
58-
max_energy = plot_path_energy(paths, U, add=-global_minimum_energy, label=name) + global_minimum_energy
79+
max_energy = plot_path_energy(paths, system.U, add=-global_minimum_energy, label=name) + global_minimum_energy
5980
print(name, 'max energy mean:', jnp.round(jnp.mean(max_energy), 2), 'std:', jnp.round(jnp.std(max_energy), 2))
6081
print(name, 'min max energy: ', jnp.round(jnp.min(max_energy), 2))
6182

@@ -65,7 +86,7 @@ def log_path_likelihood(path):
6586
plt.show()
6687

6788
for name, paths in all_paths:
68-
plot_path_energy(paths, U, add=-global_minimum_energy, reduce=jnp.median, label=name)
89+
plot_path_energy(paths, system.U, add=-global_minimum_energy, reduce=jnp.median, label=name)
6990

7091
plt.legend()
7192
plt.ylabel('Median energy')

potentials.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@
22

33

44
def U_double_well(xs, a=1.0, b=-4.0, c=0, d=1.0, beta=1.0):
5+
if xs.ndim == 1:
6+
xs = xs.reshape(1, -1)
7+
58
x, y = xs[:, 0], xs[:, 1]
69
return beta * (a * (x ** 4) + b * (x ** 2) + c * x + 0.5 * d * (y ** 2))
710

811

912
def U_double_well_hard(xs, beta=1.0):
13+
if xs.ndim == 1:
14+
xs = xs.reshape(1, -1)
15+
1016
A = jnp.array([[-3, 0]])
1117
B = jnp.array([[3, 0]])
1218
U1 = -(((xs - A) @ jnp.array([[1, 0.5], [0.5, 1.0]])) * (xs - A)).sum(1)
@@ -16,6 +22,9 @@ def U_double_well_hard(xs, beta=1.0):
1622

1723

1824
def U_double_well_dual_channel(xs, beta=1.0):
25+
if xs.ndim == 1:
26+
xs = xs.reshape(1, -1)
27+
1928
x, y = xs[:, 0], xs[:, 1]
2029
borders = x ** 6 + y ** 6
2130
e1 = +2.0 * jnp.exp(-(12.0 * (x - 0.00) ** 2 + 12.0 * (y - 0.00) ** 2))
@@ -25,6 +34,9 @@ def U_double_well_dual_channel(xs, beta=1.0):
2534

2635

2736
def U_mueller_brown(xs, beta=1.0):
37+
if xs.ndim == 1:
38+
xs = xs.reshape(1, -1)
39+
2840
x, y = xs[:, 0], xs[:, 1]
2941
e1 = -200 * jnp.exp(-(x - 1) ** 2 - 10 * y ** 2)
3042
e2 = -100 * jnp.exp(-x ** 2 - 10 * (y - 0.5) ** 2)

0 commit comments

Comments
 (0)