|
| 1 | +from tps_baseline_mueller import U, A, B, plot_energy_surface |
| 2 | +from flax import linen as nn |
| 3 | +from flax.training import train_state |
| 4 | +import optax |
| 5 | +import jax |
| 6 | +import jax.numpy as jnp |
| 7 | +from tqdm import trange |
| 8 | +import matplotlib.pyplot as plt |
| 9 | +import os |
| 10 | +import numpy as np |
| 11 | + |
| 12 | + |
| 13 | +class MLPq(nn.Module): |
| 14 | + @nn.compact |
| 15 | + def __call__(self, t): |
| 16 | + t = t / T |
| 17 | + h = nn.Dense(128)(t - 0.5) |
| 18 | + h = nn.swish(h) |
| 19 | + h = nn.Dense(128)(h) |
| 20 | + h = nn.swish(h) |
| 21 | + h = nn.Dense(128)(h) |
| 22 | + h = nn.swish(h) |
| 23 | + h = nn.Dense(4)(h) |
| 24 | + mu = (1 - t) * A + t * B + (1 - t) * t * h[:, :2] |
| 25 | + sigma = (1 - t) * 2.5 * 1e-2 + t * 2.5 * 1e-2 + (1 - t) * t * jnp.exp(h[:, 2:]) |
| 26 | + return mu, sigma |
| 27 | + |
| 28 | + |
| 29 | +if __name__ == '__main__': |
| 30 | + savedir = f"out/var_doobs/mueller" |
| 31 | + os.makedirs(savedir, exist_ok=True) |
| 32 | + |
| 33 | + num_paths = 1000 |
| 34 | + xi = 5 |
| 35 | + dt = 1e-4 |
| 36 | + T = 275e-4 |
| 37 | + N = int(T / dt) |
| 38 | + epochs = 2_500 |
| 39 | + |
| 40 | + q = MLPq() |
| 41 | + |
| 42 | + BS = 512 |
| 43 | + key = jax.random.PRNGKey(1) |
| 44 | + key, *init_key = jax.random.split(key, 3) |
| 45 | + params_q = q.init(init_key[0], jnp.ones([BS, 1])) |
| 46 | + |
| 47 | + optimizer_q = optax.adam(learning_rate=1e-4) |
| 48 | + state_q = train_state.TrainState.create(apply_fn=q.apply, |
| 49 | + params=params_q, |
| 50 | + tx=optimizer_q) |
| 51 | + |
| 52 | + |
| 53 | + def loss_fn(params_q, key): |
| 54 | + key = jax.random.split(key) |
| 55 | + t = T * jax.random.uniform(key[0], [BS, 1]) |
| 56 | + eps = jax.random.normal(key[1], [BS, 2]) |
| 57 | + |
| 58 | + mu_t = lambda _t: state_q.apply_fn(params_q, _t)[0] |
| 59 | + sigma_t = lambda _t: state_q.apply_fn(params_q, _t)[1] |
| 60 | + |
| 61 | + def dmudt(_t): |
| 62 | + _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0)) |
| 63 | + return _dmudt(_t).squeeze().T |
| 64 | + |
| 65 | + def dsigmadt(_t): |
| 66 | + _dsigmadt = jax.jacrev(lambda _t: sigma_t(_t).sum(0)) |
| 67 | + return _dsigmadt(_t).squeeze().T |
| 68 | + |
| 69 | + dUdx_fn = jax.grad(lambda _x: U(_x).sum()) |
| 70 | + |
| 71 | + def v_t(_eps, _t): |
| 72 | + u_t = dmudt(_t) + dsigmadt(_t) * _eps |
| 73 | + _x = mu_t(_t) + sigma_t(_t) * _eps |
| 74 | + out = (u_t + dUdx_fn(_x)) - 0.5 * (xi ** 2) * _eps / sigma_t(t) |
| 75 | + return out |
| 76 | + |
| 77 | + loss = 0.5 * ((v_t(eps, t) / xi) ** 2).sum(1, keepdims=True) |
| 78 | + print(loss.shape, 'loss.shape', flush=True) |
| 79 | + return loss.mean() |
| 80 | + |
| 81 | + |
| 82 | + @jax.jit |
| 83 | + def train_step(state_q, key): |
| 84 | + grad_fn = jax.value_and_grad(loss_fn, argnums=0) |
| 85 | + loss, grads = grad_fn(state_q.params, key) |
| 86 | + state_q = state_q.apply_gradients(grads=grads) |
| 87 | + return state_q, loss |
| 88 | + |
| 89 | + |
| 90 | + key, loc_key = jax.random.split(key) |
| 91 | + state_q, loss = train_step(state_q, loc_key) |
| 92 | + |
| 93 | + loss_plot = [] |
| 94 | + for i in trange(epochs): |
| 95 | + key, loc_key = jax.random.split(key) |
| 96 | + state_q, loss = train_step(state_q, loc_key) |
| 97 | + loss_plot.append(loss) |
| 98 | + |
| 99 | + plt.plot(loss_plot) |
| 100 | + plt.show() |
| 101 | + |
| 102 | + t = T * jnp.linspace(0, 1, BS).reshape((-1, 1)) |
| 103 | + key, path_key = jax.random.split(key) |
| 104 | + eps = jax.random.normal(path_key, [BS, 2]) |
| 105 | + mu_t, sigma_t = state_q.apply_fn(state_q.params, t) |
| 106 | + samples = mu_t + sigma_t * eps |
| 107 | + plot_energy_surface() |
| 108 | + plt.scatter(samples[:, 0], samples[:, 1]) |
| 109 | + plt.scatter(A[0, 0], A[0, 1], color='red') |
| 110 | + plt.scatter(B[0, 0], B[0, 1], color='orange') |
| 111 | + plt.show() |
| 112 | + |
| 113 | + print("Number of potential evaluations", BS * epochs) |
| 114 | + |
| 115 | + mu_t = lambda _t: state_q.apply_fn(state_q.params, _t)[0] |
| 116 | + sigma_t = lambda _t: state_q.apply_fn(state_q.params, _t)[1] |
| 117 | + |
| 118 | + |
| 119 | + def dmudt(_t): |
| 120 | + _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0), argnums=0) |
| 121 | + return _dmudt(_t).squeeze().T |
| 122 | + |
| 123 | + |
| 124 | + def dsigmadt(_t): |
| 125 | + _dsigmadt = jax.jacrev(lambda _t: sigma_t(_t).sum(0)) |
| 126 | + return _dsigmadt(_t).squeeze().T |
| 127 | + |
| 128 | + |
| 129 | + u_t = jax.jit(lambda _t, _x: dmudt(_t) + dsigmadt(_t) / sigma_t(_t) * (_x - mu_t(_t))) |
| 130 | + |
| 131 | + key, loc_key = jax.random.split(key) |
| 132 | + x_t = jnp.ones((BS, N + 1, 2)) * A |
| 133 | + eps = jax.random.normal(key, shape=(BS, 2)) |
| 134 | + x_t = x_t.at[:, 0, :].set(x_t[:, 0, :] + sigma_t(jnp.zeros((BS, 1))) * eps) |
| 135 | + t = jnp.zeros((BS, 1)) |
| 136 | + for i in trange(N): |
| 137 | + dx = dt * u_t(t, x_t[:, i, :]) |
| 138 | + x_t = x_t.at[:, i + 1, :].set(x_t[:, i, :] + dx) |
| 139 | + t += dt |
| 140 | + |
| 141 | + x_t_det = x_t.copy() |
| 142 | + |
| 143 | + u_t = jax.jit( |
| 144 | + lambda _t, _x: dmudt(_t) + (dsigmadt(_t) / sigma_t(_t) - 0.5 * (xi / sigma_t(_t)) ** 2) * (_x - mu_t(_t))) |
| 145 | + |
| 146 | + BS = num_paths |
| 147 | + key, loc_key = jax.random.split(key) |
| 148 | + x_t = jnp.ones((BS, N + 1, 2)) * A |
| 149 | + eps = jax.random.normal(key, shape=(BS, 2)) |
| 150 | + x_t = x_t.at[:, 0, :].set(x_t[:, 0, :] + sigma_t(jnp.zeros((BS, 1))) * eps) |
| 151 | + t = jnp.zeros((BS, 1)) |
| 152 | + for i in trange(N): |
| 153 | + key, loc_key = jax.random.split(key) |
| 154 | + eps = jax.random.normal(key, shape=(BS, 2)) |
| 155 | + dx = dt * u_t(t, x_t[:, i, :]) + jnp.sqrt(dt) * xi * eps |
| 156 | + x_t = x_t.at[:, i + 1, :].set(x_t[:, i, :] + dx) |
| 157 | + t += dt |
| 158 | + |
| 159 | + x_t_stoch = x_t.copy() |
| 160 | + |
| 161 | + np.save(f'{savedir}/paths.npy', np.array([jnp.array(p) for p in x_t_stoch], dtype=object), allow_pickle=True) |
| 162 | + |
| 163 | + plt.figure(figsize=(16, 5)) |
| 164 | + plt.subplot(121) |
| 165 | + plot_energy_surface() |
| 166 | + plt.plot(x_t_det[:10, :, 0].T, x_t_det[:10, :, 1].T) |
| 167 | + plt.scatter(A[0, 0], A[0, 1], color='red') |
| 168 | + plt.scatter(B[0, 0], B[0, 1], color='orange') |
| 169 | + |
| 170 | + plt.subplot(122) |
| 171 | + plot_energy_surface() |
| 172 | + plt.plot(x_t_stoch[:10, :, 0].T, x_t_stoch[:10, :, 1].T) |
| 173 | + plt.scatter(A[0, 0], A[0, 1], color='red') |
| 174 | + plt.scatter(B[0, 0], B[0, 1], color='orange') |
| 175 | + plt.savefig(f'{savedir}/selected_paths_det_vs_stoch.png', bbox_inches='tight') |
| 176 | + plt.show() |
| 177 | + |
| 178 | + plt.figure(figsize=(16, 5)) |
| 179 | + plt.subplot(121) |
| 180 | + plot_energy_surface(trajectories=x_t_det) |
| 181 | + |
| 182 | + plt.subplot(122) |
| 183 | + plot_energy_surface(trajectories=x_t_stoch) |
| 184 | + plt.savefig(f'{savedir}/paths_det_vs_stoch.png', bbox_inches='tight') |
| 185 | + plt.show() |
| 186 | + |
| 187 | + plot_energy_surface(trajectories=x_t_stoch) |
| 188 | + plt.savefig(f'{savedir}/mueller-variational-doobs.pdf', bbox_inches='tight') |
| 189 | + plt.show() |
0 commit comments