Skip to content

Commit 8a906da

Browse files
committed
Add baselines for first order mueller
For one and two way shooting, and variational doobs
1 parent 374f50d commit 8a906da

File tree

7 files changed

+493
-52
lines changed

7 files changed

+493
-52
lines changed

eval/path_metrics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from tqdm import tqdm
44

55

6-
def plot_path_energy(paths, U, reduce=jnp.max, already_ln=False):
7-
reduced = jnp.array([reduce(U(path)) for path in tqdm(paths, 'Computing path metric')])
6+
def plot_path_energy(paths, U, reduce=jnp.max, add=0, already_ln=False, **kwargs):
7+
reduced = jnp.array([reduce(U(path)) for path in paths]) + add
88

99
if already_ln:
1010
# Convert reduced to log10
1111
reduced = reduced / jnp.log(10)
12-
plt.plot(jnp.arange(0, len(reduced), 1), reduced)
12+
plt.plot(jnp.arange(0, len(reduced), 1), reduced, **kwargs)
1313
else:
14-
plt.semilogy(jnp.arange(0, len(reduced), 1), reduced)
14+
plt.semilogy(jnp.arange(0, len(reduced), 1), reduced, **kwargs)

evaluate_mueller.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import numpy as np
2+
import jax.numpy as jnp
3+
import jax
4+
from eval.path_metrics import plot_path_energy
5+
from tps_baseline_mueller import U, dUdx_fn, minima_points
6+
from scipy.optimize import minimize
7+
import matplotlib.pyplot as plt
8+
import os
9+
10+
def load(path):
    """Read a ``.npy`` file and return its contents as a squeezed float32 JAX array.

    The file is read with ``allow_pickle=True``, cast to ``float32`` on the
    NumPy side, converted to a JAX array, and stripped of size-1 axes.

    NOTE(review): ``allow_pickle=True`` unpickles arbitrary objects, so this
    should only be pointed at trusted files.
    """
    raw = np.load(path, allow_pickle=True)
    as_float32 = raw.astype(np.float32)
    return jnp.array(as_float32).squeeze()
12+
13+
14+
@jax.jit
def log_prob_path(path):
    """Score one discretised path.

    Each increment ``path[i + 1] - path[i] + dt * dUdx_fn(path[i])`` is
    evaluated under a zero-mean normal with scale ``sqrt(dt) * xi``; the
    summed log-densities are added to ``U(path[0])``.

    ``dt`` and ``xi`` are module-level globals that this file only assigns
    inside its ``__main__`` section, so the function is usable only after
    that section has run.

    NOTE(review): the first term is ``+U(path[0])``; confirm this is the
    intended sign/scale for the start-state contribution.
    """
    previous_frames = path[:-1]
    # Residual of each step once the drift term has been accounted for.
    noise = path[1:] - previous_frames + dt * dUdx_fn(previous_frames)
    step_log_density = jax.scipy.stats.norm.logpdf(noise, scale=jnp.sqrt(dt) * xi)
    return U(path[0]) + step_log_density.sum()
18+
19+
20+
if __name__ == '__main__':
    # Evaluation script: loads saved path ensembles and writes three comparison
    # plots (maximum energy, median energy, log path likelihood) to `savedir`.
    savedir = './out/evaluation/mueller/'
    os.makedirs(savedir, exist_ok=True)

    # (label, file) pairs; replaced below by (label, loaded array) pairs.
    all_paths = [
        ('one-way-shooting', './out/baselines/mueller/paths-one-way-shooting.npy'),
        ('two-way-shooting', './out/baselines/mueller/paths-two-way-shooting.npy'),
        ('var-doobs', './out/var_doobs/mueller/paths.npy'),
    ]

    # `xi` and `dt` are read as globals by `log_prob_path`; keep them at module scope.
    num_paths = 1000
    xi = 5
    dt = 1e-4
    T = 275e-4
    N = int(T / dt)

    # Lowest energy found by locally minimising U from each known minimum;
    # used as an offset so plotted energies are relative to it.
    global_minimum_energy = U(minima_points[0])
    for point in minima_points:
        global_minimum_energy = min(global_minimum_energy, minimize(U, point).fun)
    print("Global minimum energy", global_minimum_energy)

    all_paths = [(name, load(path)) for name, path in all_paths]
    # Plain loop instead of a list comprehension evaluated only for its side effect.
    for name, paths in all_paths:
        print(name, paths.shape)

    for name, paths in all_paths:
        plot_path_energy(paths, U, add=-global_minimum_energy, label=name)

    plt.legend()
    plt.ylabel('Maximum energy')
    plt.savefig(f'{savedir}/mueller-max-energy.pdf', bbox_inches='tight')
    plt.show()

    for name, paths in all_paths:
        plot_path_energy(paths, U, add=-global_minimum_energy, reduce=jnp.median, label=name)

    plt.legend()
    plt.ylabel('Median energy')
    plt.savefig(f'{savedir}/mueller-median-energy.pdf', bbox_inches='tight')
    plt.show()

    for name, paths in all_paths:
        # `reduce` is the identity because log_prob_path already returns one scalar per path.
        plot_path_energy(paths, log_prob_path, reduce=lambda x: x, label=name)
        # The printed statistic is the median of log_prob_path, not an energy;
        # the message now says so.
        print('Median log path likelihood of:', name,
              jnp.median(jnp.array([log_prob_path(path) for path in paths])))

    plt.legend()
    plt.ylabel('log path likelihood')
    plt.savefig(f'{savedir}/mueller-log-path-likelihood.pdf', bbox_inches='tight')
    plt.show()

mueller.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
from tps_baseline_mueller import U, A, B, plot_energy_surface
2+
from flax import linen as nn
3+
from flax.training import train_state
4+
import optax
5+
import jax
6+
import jax.numpy as jnp
7+
from tqdm import trange
8+
import matplotlib.pyplot as plt
9+
import os
10+
import numpy as np
11+
12+
13+
class MLPq(nn.Module):
    """MLP mapping a time ``t`` to a mean ``mu`` and a scale ``sigma``.

    ``t`` is divided by the module-level ``T`` (assigned in this file's
    ``__main__`` section) before use. The network has three 128-unit Dense
    layers with swish activations followed by a 4-unit Dense layer; the first
    two outputs perturb the mean and the last two (exponentiated) perturb the
    scale. Both perturbations are multiplied by ``(1 - t) * t``, so at
    normalised ``t = 0`` and ``t = 1`` the mean equals ``A`` and ``B``
    respectively and the scale equals ``2.5e-2``.
    """

    @nn.compact
    def __call__(self, t):
        t = t / T

        # Hidden stack: three identical Dense(128) + swish layers, fed with centred time.
        h = t - 0.5
        for _ in range(3):
            h = nn.swish(nn.Dense(128)(h))
        h = nn.Dense(4)(h)

        # Vanishes at both endpoints, pinning mu and sigma there.
        bridge = (1 - t) * t
        mu = (1 - t) * A + t * B + bridge * h[:, :2]
        sigma = (1 - t) * 2.5 * 1e-2 + t * 2.5 * 1e-2 + bridge * jnp.exp(h[:, 2:])
        return mu, sigma
27+
28+
29+
if __name__ == '__main__':
    # Training + sampling script: fits MLPq, then integrates a deterministic and
    # a stochastic set of trajectories from it, saves the stochastic paths and
    # writes comparison plots to `savedir`.
    savedir = "out/var_doobs/mueller"
    os.makedirs(savedir, exist_ok=True)

    num_paths = 1000
    xi = 5
    dt = 1e-4
    T = 275e-4  # also read as a global by MLPq.__call__
    # NOTE(review): int() truncates; if T / dt lands just below an integer in
    # floating point this is one step short. Confirm or switch to round().
    N = int(T / dt)
    epochs = 2_500

    q = MLPq()

    BS = 512
    key = jax.random.PRNGKey(1)
    key, *init_key = jax.random.split(key, 3)
    params_q = q.init(init_key[0], jnp.ones([BS, 1]))

    optimizer_q = optax.adam(learning_rate=1e-4)
    state_q = train_state.TrainState.create(apply_fn=q.apply,
                                            params=params_q,
                                            tx=optimizer_q)


    def loss_fn(params_q, key):
        """Mean over a random batch of 0.5 * ||v_t / xi||^2 (see v_t below)."""
        key = jax.random.split(key)
        t = T * jax.random.uniform(key[0], [BS, 1])
        eps = jax.random.normal(key[1], [BS, 2])

        mu_t = lambda _t: state_q.apply_fn(params_q, _t)[0]
        sigma_t = lambda _t: state_q.apply_fn(params_q, _t)[1]

        # Time derivatives via jacrev of the batch-summed output.
        def dmudt(_t):
            _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0))
            return _dmudt(_t).squeeze().T

        def dsigmadt(_t):
            _dsigmadt = jax.jacrev(lambda _t: sigma_t(_t).sum(0))
            return _dsigmadt(_t).squeeze().T

        dUdx_fn = jax.grad(lambda _x: U(_x).sum())

        def v_t(_eps, _t):
            u_t = dmudt(_t) + dsigmadt(_t) * _eps
            _x = mu_t(_t) + sigma_t(_t) * _eps
            # Use the function's own time argument `_t`; the original read the
            # enclosing `t`, which only matched because v_t is called with it.
            out = (u_t + dUdx_fn(_x)) - 0.5 * (xi ** 2) * _eps / sigma_t(_t)
            return out

        loss = 0.5 * ((v_t(eps, t) / xi) ** 2).sum(1, keepdims=True)
        # The former debug print of loss.shape was dropped: under jax.jit it
        # only ran while tracing, not on every step.
        return loss.mean()


    @jax.jit
    def train_step(state_q, key):
        """One optimiser update; returns the new state and the loss value."""
        grad_fn = jax.value_and_grad(loss_fn, argnums=0)
        loss, grads = grad_fn(state_q.params, key)
        state_q = state_q.apply_gradients(grads=grads)
        return state_q, loss


    # One step ahead of the timed loop (this also triggers jit compilation).
    key, loc_key = jax.random.split(key)
    state_q, loss = train_step(state_q, loc_key)

    loss_plot = []
    for _ in trange(epochs):
        key, loc_key = jax.random.split(key)
        state_q, loss = train_step(state_q, loc_key)
        loss_plot.append(loss)

    plt.plot(loss_plot)
    plt.show()

    # Scatter one sample per time point on top of the energy surface.
    t = T * jnp.linspace(0, 1, BS).reshape((-1, 1))
    key, path_key = jax.random.split(key)
    eps = jax.random.normal(path_key, [BS, 2])
    mu_t, sigma_t = state_q.apply_fn(state_q.params, t)
    samples = mu_t + sigma_t * eps
    plot_energy_surface()
    plt.scatter(samples[:, 0], samples[:, 1])
    plt.scatter(A[0, 0], A[0, 1], color='red')
    plt.scatter(B[0, 0], B[0, 1], color='orange')
    plt.show()

    print("Number of potential evaluations", BS * epochs)

    # From here on mu_t / sigma_t are functions of time using the trained parameters.
    mu_t = lambda _t: state_q.apply_fn(state_q.params, _t)[0]
    sigma_t = lambda _t: state_q.apply_fn(state_q.params, _t)[1]


    def dmudt(_t):
        _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0), argnums=0)
        return _dmudt(_t).squeeze().T


    def dsigmadt(_t):
        _dsigmadt = jax.jacrev(lambda _t: sigma_t(_t).sum(0))
        return _dsigmadt(_t).squeeze().T


    # Drift for the noise-free integration.
    u_t = jax.jit(lambda _t, _x: dmudt(_t) + dsigmadt(_t) / sigma_t(_t) * (_x - mu_t(_t)))

    key, loc_key = jax.random.split(key)
    x_t = jnp.ones((BS, N + 1, 2)) * A
    # Sample with the fresh subkey; the original sampled with `key`, which is
    # split again later (JAX PRNG keys must not be reused).
    eps = jax.random.normal(loc_key, shape=(BS, 2))
    x_t = x_t.at[:, 0, :].set(x_t[:, 0, :] + sigma_t(jnp.zeros((BS, 1))) * eps)
    t = jnp.zeros((BS, 1))
    for i in trange(N):
        dx = dt * u_t(t, x_t[:, i, :])
        x_t = x_t.at[:, i + 1, :].set(x_t[:, i, :] + dx)
        t += dt

    x_t_det = x_t.copy()

    # Drift for the integration with noise of scale sqrt(dt) * xi per step.
    u_t = jax.jit(
        lambda _t, _x: dmudt(_t) + (dsigmadt(_t) / sigma_t(_t) - 0.5 * (xi / sigma_t(_t)) ** 2) * (_x - mu_t(_t)))

    BS = num_paths
    key, loc_key = jax.random.split(key)
    x_t = jnp.ones((BS, N + 1, 2)) * A
    eps = jax.random.normal(loc_key, shape=(BS, 2))
    x_t = x_t.at[:, 0, :].set(x_t[:, 0, :] + sigma_t(jnp.zeros((BS, 1))) * eps)
    t = jnp.zeros((BS, 1))
    for i in trange(N):
        key, loc_key = jax.random.split(key)
        eps = jax.random.normal(loc_key, shape=(BS, 2))
        dx = dt * u_t(t, x_t[:, i, :]) + jnp.sqrt(dt) * xi * eps
        x_t = x_t.at[:, i + 1, :].set(x_t[:, i, :] + dx)
        t += dt

    x_t_stoch = x_t.copy()

    np.save(f'{savedir}/paths.npy', np.array([jnp.array(p) for p in x_t_stoch], dtype=object), allow_pickle=True)

    # First ten trajectories of each kind, side by side.
    plt.figure(figsize=(16, 5))
    plt.subplot(121)
    plot_energy_surface()
    plt.plot(x_t_det[:10, :, 0].T, x_t_det[:10, :, 1].T)
    plt.scatter(A[0, 0], A[0, 1], color='red')
    plt.scatter(B[0, 0], B[0, 1], color='orange')

    plt.subplot(122)
    plot_energy_surface()
    plt.plot(x_t_stoch[:10, :, 0].T, x_t_stoch[:10, :, 1].T)
    plt.scatter(A[0, 0], A[0, 1], color='red')
    plt.scatter(B[0, 0], B[0, 1], color='orange')
    plt.savefig(f'{savedir}/selected_paths_det_vs_stoch.png', bbox_inches='tight')
    plt.show()

    # All trajectories of each kind, side by side.
    plt.figure(figsize=(16, 5))
    plt.subplot(121)
    plot_energy_surface(trajectories=x_t_det)

    plt.subplot(122)
    plot_energy_surface(trajectories=x_t_stoch)
    plt.savefig(f'{savedir}/paths_det_vs_stoch.png', bbox_inches='tight')
    plt.show()

    plot_energy_surface(trajectories=x_t_stoch)
    plt.savefig(f'{savedir}/mueller-variational-doobs.pdf', bbox_inches='tight')
    plt.show()

tps/first_order.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,26 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_le
101101
# pick an initial trajectory
102102
trajectories = [initial_trajectory]
103103

104+
statistics = {
105+
'num_force_evaluations': 0,
106+
'num_tries': 0,
107+
'num_metropolis_rejected': 0,
108+
'warmup': warmup,
109+
'num_paths': num_paths,
110+
}
111+
if fixed_length > 0:
112+
statistics['fixed_length'] = fixed_length
113+
else:
114+
statistics['max_steps'] = MAX_STEPS
115+
104116
with tqdm(total=num_paths) as pbar:
105117
while len(trajectories) <= num_paths + warmup:
118+
statistics['num_tries'] += 1
119+
106120
key, iter_key, accept_key = jax.random.split(key, 3)
107121
found, new_trajectory = proposal(system, trajectories[-1], fixed_length, iter_key)
122+
statistics['num_force_evaluations'] += len(new_trajectory) - 1
123+
108124
if not found:
109125
continue
110126

@@ -115,19 +131,30 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_le
115131

116132
if len(trajectories) > warmup:
117133
pbar.update(1)
134+
else:
135+
statistics['num_metropolis_rejected'] += 1
118136

119-
return trajectories[warmup + 1:]
137+
return trajectories[warmup + 1:], statistics
120138

121139

122140
def unguided_md(system, initial_point, num_paths, key, fixed_length=0):
123141
trajectories = []
124142
current_frame = initial_point.clone()
125143
current_trajectory = []
126144

145+
statistics = {
146+
'num_force_evaluations': 0,
147+
'num_paths': num_paths,
148+
'max_steps': MAX_STEPS,
149+
}
150+
if fixed_length > 0:
151+
statistics['fixed_length'] = fixed_length
152+
127153
with tqdm(total=num_paths) as pbar:
128154
while len(trajectories) < num_paths:
129155
key, iter_key = jax.random.split(key)
130156
next_frame = system.step(current_frame, iter_key)
157+
statistics['num_force_evaluations'] += 1
131158

132159
is_transition = not (system.start_state(next_frame) or system.target_state(next_frame))
133160
if is_transition:
@@ -153,4 +180,4 @@ def unguided_md(system, initial_point, num_paths, key, fixed_length=0):
153180

154181
current_frame = next_frame
155182

156-
return trajectories
183+
return trajectories, statistics

0 commit comments

Comments
 (0)