|
9 | 9 | from training.setups.drift import DriftedSetup
|
10 | 10 | from training.utils import forward_and_derivatives
|
11 | 11 |
|
| 12 | + |
| 13 | +def compute_spline_coefficients(x_knots, y_knots): |
| 14 | + n = len(x_knots) - 1 |
| 15 | + h = jnp.diff(x_knots) |
| 16 | + b = (jnp.diff(y_knots, axis=0).T / h).T |
| 17 | + v = jnp.zeros((n + 1,) + y_knots.shape[1:], dtype=jnp.float32) |
| 18 | + u = jnp.zeros((n + 1,), dtype=jnp.float32) |
| 19 | + |
| 20 | + u = u.at[1:n].set(2 * (h[:-1] + h[1:])) |
| 21 | + v = v.at[1:n].set(6 * (b[1:] - b[:-1])) |
| 22 | + |
| 23 | + u = u.at[0].set(1) |
| 24 | + u = u.at[n].set(1) |
| 25 | + |
| 26 | + for i in range(1, n): |
| 27 | + u = u.at[i].set(u[i] - (h[i - 1] ** 2) / u[i - 1]) |
| 28 | + v = v.at[i].set(v[i] - (h[i - 1] * v[i - 1]) / u[i - 1]) |
| 29 | + |
| 30 | + m = jnp.zeros_like(v) |
| 31 | + for i in range(n - 1, 0, -1): |
| 32 | + m = m.at[i].set((v[i] - h[i] * m[i + 1]) / u[i]) |
| 33 | + |
| 34 | + return m |
| 35 | + |
| 36 | + |
| 37 | +def evaluate_cubic_spline(x, x_knots, y_knots, m): |
| 38 | + i = jnp.searchsorted(x_knots, x) - 1 |
| 39 | + i = jnp.clip(i, 0, len(x_knots) - 2) # Ensure i is within bounds |
| 40 | + h = x_knots[i + 1] - x_knots[i] |
| 41 | + A = (x_knots[i + 1] - x) / h |
| 42 | + B = (x - x_knots[i]) / h |
| 43 | + C = (1 / 6) * (A ** 3 - A) * h ** 2 |
| 44 | + D = (1 / 6) * (B ** 3 - B) * h ** 2 |
| 45 | + y = A * y_knots[i] + B * y_knots[i + 1] + C * m[i] + D * m[i + 1] |
| 46 | + return y |
| 47 | + |
| 48 | + |
| 49 | +def compute_cubic_spline(t, x_knots, y_knots): |
| 50 | + m = compute_spline_coefficients(x_knots, y_knots) |
| 51 | + return evaluate_cubic_spline(t, x_knots, y_knots, m) |
| 52 | + |
| 53 | + |
| 54 | +vectorized_cubic_spline = jax.vmap(compute_cubic_spline, in_axes=(0, None, None)) |
| 55 | + |
12 | 56 | interp = jax.vmap(jnp.interp, in_axes=(None, None, 1))
|
13 | 57 |
|
14 | 58 |
|
15 | 59 | class LowRankSpline(nn.Module):
|
16 | 60 | n_points: int
|
| 61 | + interpolation: str |
17 | 62 | T: float
|
18 | 63 | transform: Optional[Callable[[Any], Any]]
|
19 | 64 | A: ArrayLike
|
@@ -43,8 +88,15 @@ def get_tril(v):
|
43 | 88 | a = a.at[jnp.tril_indices(ndim)].set(v)
|
44 | 89 | return a
|
45 | 90 |
|
46 |
| - mu = interp(t.flatten(), t_grid, y_grid).T |
47 |
| - S = interp(t.flatten(), t_grid, S_grid).T |
| 91 | + if self.interpolation == 'cubic': |
| 92 | + mu = vectorized_cubic_spline(t.flatten(), t_grid, y_grid) |
| 93 | + S = vectorized_cubic_spline(t.flatten(), t_grid, S_grid) |
| 94 | + elif self.interpolation == 'linear': |
| 95 | + mu = interp(t.flatten(), t_grid, y_grid).T |
| 96 | + S = interp(t.flatten(), t_grid, S_grid).T |
| 97 | + else: |
| 98 | + raise ValueError(f"Interpolation method {self.interpolation} not recognized.") |
| 99 | + |
48 | 100 | S = get_tril(S)
|
49 | 101 | S = jnp.tril(2 * jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim, dtype=jnp.float32)[None, ...] * jnp.exp(S)
|
50 | 102 |
|
|
0 commit comments