Skip to content

Commit 26b2989

Browse files
committed
polished notebook
1 parent 9444fd0 commit 26b2989

File tree

4 files changed

+369
-782
lines changed

4 files changed

+369
-782
lines changed

notebooks/models.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import jax
2+
import jax.numpy as jnp
3+
from flax import linen as nn
4+
5+
interp = jax.vmap(jnp.interp, in_axes=(None, None, 1))
6+
7+
class Interpolant(nn.Module):
8+
T: float
9+
A: float
10+
B: float
11+
ndim: int
12+
n_points: int = 100
13+
@nn.compact
14+
def __call__(self, t):
15+
t = t/self.T
16+
ndim = self.ndim
17+
t_grid = jnp.linspace(0,1,self.n_points)
18+
S_0 = jnp.log(1e-2)*jnp.eye(ndim)
19+
S_0_vec = S_0[jnp.tril_indices(ndim)]
20+
mu_params = self.param('mu_params', lambda rng: jnp.linspace(A[0], B[0], self.n_points)[1:-1])
21+
S_params = self.param('S_params', lambda rng: jnp.linspace(S_0_vec, S_0_vec, self.n_points)[1:-1])
22+
y_grid = jnp.concatenate([self.A, mu_params, self.B])
23+
S_grid = jnp.concatenate([S_0_vec[None,:], S_params, S_0_vec[None,:]])
24+
25+
@jax.vmap
26+
def get_tril(v):
27+
a = jnp.zeros((ndim,ndim))
28+
a = a.at[jnp.tril_indices(ndim)].set(v)
29+
return a
30+
31+
mu = interp(t.flatten(), t_grid, y_grid).T
32+
S = interp(t.flatten(), t_grid, S_grid).T
33+
S = get_tril(S)
34+
S = jnp.tril(2*jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim)[None,...]*jnp.exp(S)
35+
return mu, S
36+
37+
class MLPfull(nn.Module):
38+
T: float
39+
A: float
40+
B: float
41+
ndim: int
42+
xi_0: float = 1e-2
43+
@nn.compact
44+
def __call__(self, t):
45+
t = t/self.T
46+
ndim = self.ndim
47+
h_mu = (1-t)*self.A + t*self.B
48+
S_0 = self.xi_0*jnp.eye(ndim)
49+
S_0 = S_0[None,...]
50+
h_S = (1-2*t*(1-t))[...,None]*S_0
51+
h = jnp.hstack([h_mu, h_S.reshape(-1,ndim*ndim), t])
52+
h = nn.Dense(256)(h)
53+
h = nn.swish(h)
54+
h = nn.Dense(256)(h)
55+
h = nn.swish(h)
56+
h = nn.Dense(256)(h)
57+
h = nn.swish(h)
58+
h = nn.Dense(ndim + ndim*(ndim+1)//2)(h)
59+
mu = h_mu + (1-t)*t*h[:,:ndim]
60+
61+
@jax.vmap
62+
def get_tril(v):
63+
a = jnp.zeros((ndim,ndim))
64+
a = a.at[jnp.tril_indices(ndim)].set(v)
65+
return a
66+
# S = h[:,ndim:].reshape(-1,ndim,ndim)
67+
S = get_tril(h[:,ndim:])
68+
S = jnp.tril(2*jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim)[None,...]*jnp.exp(S)
69+
S = h_S + 2*((1-t)*t)[...,None]*S
70+
return mu, S
71+
72+
class MLPdiag(nn.Module):
73+
T: float
74+
A: float
75+
B: float
76+
ndim: int
77+
xi_0: float = 1e-4
78+
@nn.compact
79+
def __call__(self, t):
80+
t = t/self.T
81+
ndim = self.ndim
82+
h_mu = (1-t)*self.A + t*self.B
83+
h = jnp.hstack([h_mu, t])
84+
h = nn.Dense(256)(h)
85+
h = nn.swish(h)
86+
h = nn.Dense(256)(h)
87+
h = nn.swish(h)
88+
h = nn.Dense(256)(h)
89+
h = nn.swish(h)
90+
h = nn.Dense(2*ndim)(h)
91+
mu = h_mu + (1-t)*t*h[:,:ndim]
92+
sigma = (1-t)*self.xi_0 + t*self.xi_0 + (1-t)*t*jnp.exp(h[:,ndim:])
93+
return mu, sigma

notebooks/tps_gaussian.ipynb

Lines changed: 244 additions & 155 deletions
Large diffs are not rendered by default.

notebooks/tps_gaussian_2nd.ipynb

Lines changed: 0 additions & 627 deletions
This file was deleted.

notebooks/vf.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import jax
2+
import jax.numpy as jnp
3+
4+
def get_parameterization_fn(params, state):
5+
gauss_params = lambda _t: state.apply_fn(params, _t)
6+
def dgauss_paramsdt(_t):
7+
_gauss_params = lambda _t: jax.tree.map(lambda _a: _a.sum(0), gauss_params(_t))
8+
return jax.tree.map(lambda a: a.squeeze().T, jax.jacrev(_gauss_params)(_t))
9+
return gauss_params, dgauss_paramsdt
10+
11+
def v_t_diag(_eps, _t, params, state):
12+
gauss_params, dgauss_paramsdt = get_parameterization_fn(params, state)
13+
_eps = _eps.squeeze()
14+
mu_t_val, s_val = gauss_params(_t)
15+
dmudt_val, dsdt_val = dgauss_paramsdt(_t)
16+
_x = mu_t_val + jnp.sqrt(s_val)*_eps
17+
dlogdx = -_eps/jnp.sqrt(s_val)
18+
u_t = dmudt_val - 0.5*dlogdx*dsdt_val
19+
out = (u_t - drift(_x)) + 0.5*(xi**2)*dlogdx.squeeze()
20+
return out
21+
22+
def v_t_full(_eps, _t, params, state):
23+
gauss_params, dgauss_paramsdt = get_parameterization_fn(params, state)
24+
mu_t_val, S_t_val = gauss_params(_t)
25+
dmudt_val, dSdt_val = dgauss_paramsdt(_t)
26+
_x = mu_t_val + jax.lax.batch_matmul(S_t_val, _eps).squeeze()
27+
dlogdx = -jax.scipy.linalg.solve_triangular(jnp.transpose(S_t_val, (0,2,1)), _eps)
28+
dSigmadt = jax.lax.batch_matmul(dSdt_val, jnp.transpose(S_t_val, (0,2,1)))
29+
dSigmadt += jax.lax.batch_matmul(S_t_val, jnp.transpose(dSdt_val, (0,2,1)))
30+
u_t = dmudt_val - 0.5*jax.lax.batch_matmul(dSigmadt, dlogdx).squeeze()
31+
out = (u_t - drift(_x)) + 0.5*(xi**2)*dlogdx.squeeze()
32+
return out

0 commit comments

Comments
 (0)