Skip to content

Commit 3d96406

Browse files
committed
Run all baselines for mueller at once
1 parent 4d67021 commit 3d96406

File tree

1 file changed

+51
-51
lines changed

1 file changed

+51
-51
lines changed

tps_baseline_mueller.py

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -42,56 +42,56 @@ def interpolate_two_points(start, stop, steps):
4242

4343
if __name__ == '__main__':
4444
# variable or fixed length?
45-
variable = True
4645
num_paths = 1000
4746

48-
save_dir = f"out/baselines/mueller"
49-
if variable:
50-
save_dir += "-variable"
51-
52-
os.makedirs(save_dir, exist_ok=True)
53-
54-
xi = 5
55-
dt = 1e-4
56-
T = 275e-4
57-
N = 0 if variable else int(T / dt)
58-
59-
system = System.from_name('mueller_brown', float('inf'))
60-
initial_trajectory = [t.reshape(1, 2) for t in interpolate(jnp.array([system.A, system.B]), 100 if variable else N)]
61-
62-
@jax.jit
63-
def step(_x, _key):
64-
"""Perform one step of forward euler"""
65-
return _x - dt * system.dUdx(_x) + jnp.sqrt(dt) * xi * jax.random.normal(_key, _x.shape)
66-
67-
68-
tps_config = tps1.FirstOrderSystem(
69-
jax.jit(lambda s: jnp.linalg.norm(s - system.A) <= 0.1),
70-
jax.jit(lambda s: jnp.linalg.norm(s - system.B) <= 0.1),
71-
step
72-
)
73-
74-
for method, name in [
75-
(tps1.one_way_shooting, 'one-way-shooting'),
76-
(tps1.two_way_shooting, 'two-way-shooting'),
77-
]:
78-
if os.path.exists(f'{save_dir}/paths-{name}.npy') and os.path.exists(f'{save_dir}/stats-{name}.json'):
79-
print(f"Skipping {name} because the results are already present")
80-
81-
paths = np.load(f'{save_dir}/paths-{name}.npy', allow_pickle=True)
82-
paths = [jnp.array(p.astype(np.float32)) for p in paths]
83-
with open(f'{save_dir}/stats-{name}.json', 'r') as fp:
84-
statistics = json.load(fp)
85-
else:
86-
print('Generating paths for', name)
87-
paths, statistics = tps1.mcmc_shooting(tps_config, method, initial_trajectory, num_paths,
88-
jax.random.PRNGKey(1), warmup=0, fixed_length=N)
89-
90-
paths = [jnp.array(p) for p in paths]
91-
92-
np.save(f'{save_dir}/paths-{name}.npy', np.array(paths, dtype=object), allow_pickle=True)
93-
with open(f'{save_dir}/stats-{name}.json', 'w') as fp:
94-
json.dump(statistics, fp)
95-
96-
system.plot(trajectories=paths)
97-
show_or_save_fig(save_dir, f'mueller-{name}', 'pdf')
47+
for variable in [False, True]:
48+
save_dir = f"out/baselines/mueller"
49+
if variable:
50+
save_dir += "-variable"
51+
52+
os.makedirs(save_dir, exist_ok=True)
53+
54+
xi = 5
55+
dt = 1e-4
56+
T = 275e-4
57+
N = 0 if variable else int(T / dt)
58+
59+
system = System.from_name('mueller_brown', float('inf'))
60+
initial_trajectory = [t.reshape(1, 2) for t in interpolate(jnp.array([system.A, system.B]), 100 if variable else N)]
61+
62+
@jax.jit
63+
def step(_x, _key):
64+
"""Perform one step of forward euler"""
65+
return _x - dt * system.dUdx(_x) + jnp.sqrt(dt) * xi * jax.random.normal(_key, _x.shape)
66+
67+
68+
tps_config = tps1.FirstOrderSystem(
69+
jax.jit(lambda s: jnp.linalg.norm(s - system.A) <= 0.1),
70+
jax.jit(lambda s: jnp.linalg.norm(s - system.B) <= 0.1),
71+
step
72+
)
73+
74+
for method, name in [
75+
(tps1.one_way_shooting, 'one-way-shooting'),
76+
(tps1.two_way_shooting, 'two-way-shooting'),
77+
]:
78+
if os.path.exists(f'{save_dir}/paths-{name}.npy') and os.path.exists(f'{save_dir}/stats-{name}.json'):
79+
print(f"Skipping {name} because the results are already present")
80+
81+
paths = np.load(f'{save_dir}/paths-{name}.npy', allow_pickle=True)
82+
paths = [jnp.array(p.astype(np.float32)) for p in paths]
83+
with open(f'{save_dir}/stats-{name}.json', 'r') as fp:
84+
statistics = json.load(fp)
85+
else:
86+
print('Generating paths for', name)
87+
paths, statistics = tps1.mcmc_shooting(tps_config, method, initial_trajectory, num_paths,
88+
jax.random.PRNGKey(1), warmup=0, fixed_length=N)
89+
90+
paths = [jnp.array(p) for p in paths]
91+
92+
np.save(f'{save_dir}/paths-{name}.npy', np.array(paths, dtype=object), allow_pickle=True)
93+
with open(f'{save_dir}/stats-{name}.json', 'w') as fp:
94+
json.dump(statistics, fp)
95+
96+
system.plot(trajectories=paths)
97+
show_or_save_fig(save_dir, f'mueller-{name}', 'pdf')

0 commit comments

Comments
 (0)