|
| 1 | +import jax |
| 2 | +import jax.numpy as jnp |
| 3 | +from flax import linen as nn |
| 4 | + |
# Batched 1-D linear interpolation: the query points and the x-grid are shared
# (in_axes None), while axis 1 of the value table is mapped over, so the result
# has one row per column of the value table.
interp = jax.vmap(jnp.interp, in_axes=(None, None, 1), out_axes=0)
| 6 | + |
class Interpolant(nn.Module):
    """Piecewise-linear path for a mean and a lower-triangular matrix.

    The interior knots of both paths are learnable parameters; the end knots
    are fixed (``A``/``B`` for the mean, a constant vector for the matrix).
    Calling the module linearly interpolates the knot values at the
    normalized times ``t / T``.
    """

    T: float
    # NOTE(review): A and B are indexed (`A[0]`) and concatenated along axis 0
    # with the (n_points - 2, ...) knot table below, so they are presumably
    # arrays of shape (1, ndim) rather than scalars -- confirm with callers.
    A: float
    B: float
    ndim: int
    n_points: int = 100

    @nn.compact
    def __call__(self, t):
        """Evaluate the interpolated mean and matrix at times ``t``.

        Args:
            t: Times in the same units as ``T``; flattened before
                interpolation.

        Returns:
            ``(mu, S)`` where ``mu`` holds one interpolated mean per
            flattened time and ``S`` one ``ndim x ndim`` lower-triangular
            matrix per flattened time, with strictly-lower entries in
            ``(-1, 1)`` and a positive diagonal.
        """
        t = t / self.T
        ndim = self.ndim
        t_grid = jnp.linspace(0, 1, self.n_points)

        # Unconstrained end value for the matrix path: log(1e-2) on the
        # diagonal, 0 below it, stored as the flattened lower triangle.
        S_0 = jnp.log(1e-2) * jnp.eye(ndim)
        S_0_vec = S_0[jnp.tril_indices(ndim)]

        # Interior knots only ([1:-1]); the end knots are pinned below.
        # Fix: the initializer must read the module fields `self.A` /
        # `self.B`; bare `A` / `B` are undefined names.
        mu_params = self.param(
            'mu_params',
            lambda rng: jnp.linspace(self.A[0], self.B[0], self.n_points)[1:-1],
        )
        S_params = self.param(
            'S_params',
            lambda rng: jnp.linspace(S_0_vec, S_0_vec, self.n_points)[1:-1],
        )
        y_grid = jnp.concatenate([self.A, mu_params, self.B])
        S_grid = jnp.concatenate([S_0_vec[None, :], S_params, S_0_vec[None, :]])

        @jax.vmap
        def get_tril(v):
            # Scatter a flattened lower triangle back into an ndim x ndim matrix.
            a = jnp.zeros((ndim, ndim))
            a = a.at[jnp.tril_indices(ndim)].set(v)
            return a

        # `interp` maps over knot-table columns, hence the transpose back to
        # one row per time.
        mu = interp(t.flatten(), t_grid, y_grid).T
        S = interp(t.flatten(), t_grid, S_grid).T
        S = get_tril(S)
        # Strictly-lower entries squashed to (-1, 1); diagonal exponentiated.
        S = jnp.tril(2 * jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim)[None, ...] * jnp.exp(S)
        return mu, S
| 36 | + |
class MLPfull(nn.Module):
    """MLP producing a mean path and a full lower-triangular matrix path.

    Both outputs are a fixed baseline plus a network correction scaled by
    ``(1 - t) * t`` (in normalized time), so the correction is zero at
    normalized ``t = 0`` and ``t = 1``.
    """

    T: float
    A: float
    B: float
    ndim: int
    xi_0: float = 1e-2

    @nn.compact
    def __call__(self, t):
        """Return ``(mu, S)`` evaluated at times ``t`` (same units as ``T``)."""
        t = t / self.T
        ndim = self.ndim
        n_tril = ndim * (ndim + 1) // 2

        # Baselines: straight line from A to B, and a scaled identity.
        h_mu = (1 - t) * self.A + t * self.B
        base = self.xi_0 * jnp.eye(ndim)
        base = base[None, ...]
        h_S = (1 - 2 * t * (1 - t))[..., None] * base

        # Network input: both baselines (matrix flattened) and the time.
        feats = jnp.hstack([h_mu, h_S.reshape(-1, ndim * ndim), t])
        for _ in range(3):
            feats = nn.Dense(256)(feats)
            feats = nn.swish(feats)
        out = nn.Dense(ndim + n_tril)(feats)

        bump = (1 - t) * t
        mu = h_mu + bump * out[:, :ndim]

        rows, cols = jnp.tril_indices(ndim)

        def fill_lower(vec):
            # Place a flattened lower triangle into an ndim x ndim matrix.
            return jnp.zeros((ndim, ndim)).at[rows, cols].set(vec)

        raw = jax.vmap(fill_lower)(out[:, ndim:])
        # Strictly-lower entries squashed to (-1, 1); diagonal exponentiated.
        off_diag = jnp.tril(2 * jax.nn.sigmoid(raw) - 1.0, k=-1)
        diag = jnp.eye(ndim)[None, ...] * jnp.exp(raw)
        S = off_diag + diag
        S = h_S + 2 * bump[..., None] * S
        return mu, S
| 71 | + |
class MLPdiag(nn.Module):
    """MLP producing a mean path and a per-dimension positive scale path.

    Both outputs are a fixed baseline plus a network correction scaled by
    ``(1 - t) * t`` (in normalized time), so the correction is zero at
    normalized ``t = 0`` and ``t = 1``.
    """

    T: float
    A: float
    B: float
    ndim: int
    xi_0: float = 1e-4

    @nn.compact
    def __call__(self, t):
        """Return ``(mu, sigma)`` evaluated at times ``t`` (same units as ``T``)."""
        t = t / self.T
        ndim = self.ndim

        # Baseline mean: straight line from A to B.
        h_mu = (1 - t) * self.A + t * self.B

        feats = jnp.hstack([h_mu, t])
        for _ in range(3):
            feats = nn.Dense(256)(feats)
            feats = nn.swish(feats)
        out = nn.Dense(2 * ndim)(feats)

        # First ndim outputs correct the mean; the rest, exponentiated,
        # add to the xi_0 baseline scale.
        mean_corr = out[:, :ndim]
        log_scale = out[:, ndim:]
        mu = h_mu + (1 - t) * t * mean_corr
        sigma = (1 - t) * self.xi_0 + t * self.xi_0 + (1 - t) * t * jnp.exp(log_scale)
        return mu, sigma
0 commit comments