Skip to content

Commit 8785e85

Browse files
committed
Add cubic splines
1 parent d8a70f9 commit 8785e85

File tree

3 files changed

+57
-3
lines changed

3 files changed

+57
-3
lines changed

main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@
4949
parser.add_argument('--model', type=str, choices=['mlp', 'spline'], default='mlp',
5050
help="The model that will be used. Note that spline will not work with all configurations.")
5151

52+
# Spline arguments
5253
parser.add_argument('--num_points', type=int, default=100, help="Number of points in the spline model.")
54+
parser.add_argument('--spline_mode', type=str, choices=['linear', 'cubic'], default='linear')
5355

5456
# MLP arguments
5557
parser.add_argument('--hidden_layers', nargs='+', type=int, help='The dimensions of the hidden layer of the MLP.',

training/qsetup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def construct(system: System, model: Optional[nn.module], xi: float, A: ArrayLik
9292
elif args.parameterization == 'low_rank':
9393
if args.model == 'spline':
9494
model = lowrank.LowRankSpline(
95-
args.num_points, args.T, transform, A, B, args.num_gaussians, args.trainable_weights, args.base_sigma
95+
args.num_points, args.spline_mode, args.T, transform, A, B, args.num_gaussians, args.trainable_weights, args.base_sigma
9696
)
9797
else:
9898
model = lowrank.LowRankWrapper(

training/setups/lowrank.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,56 @@
99
from training.setups.drift import DriftedSetup
1010
from training.utils import forward_and_derivatives
1111

12+
13+
def compute_spline_coefficients(x_knots, y_knots):
14+
n = len(x_knots) - 1
15+
h = jnp.diff(x_knots)
16+
b = (jnp.diff(y_knots, axis=0).T / h).T
17+
v = jnp.zeros((n + 1,) + y_knots.shape[1:], dtype=jnp.float32)
18+
u = jnp.zeros((n + 1,), dtype=jnp.float32)
19+
20+
u = u.at[1:n].set(2 * (h[:-1] + h[1:]))
21+
v = v.at[1:n].set(6 * (b[1:] - b[:-1]))
22+
23+
u = u.at[0].set(1)
24+
u = u.at[n].set(1)
25+
26+
for i in range(1, n):
27+
u = u.at[i].set(u[i] - (h[i - 1] ** 2) / u[i - 1])
28+
v = v.at[i].set(v[i] - (h[i - 1] * v[i - 1]) / u[i - 1])
29+
30+
m = jnp.zeros_like(v)
31+
for i in range(n - 1, 0, -1):
32+
m = m.at[i].set((v[i] - h[i] * m[i + 1]) / u[i])
33+
34+
return m
35+
36+
37+
def evaluate_cubic_spline(x, x_knots, y_knots, m):
38+
i = jnp.searchsorted(x_knots, x) - 1
39+
i = jnp.clip(i, 0, len(x_knots) - 2) # Ensure i is within bounds
40+
h = x_knots[i + 1] - x_knots[i]
41+
A = (x_knots[i + 1] - x) / h
42+
B = (x - x_knots[i]) / h
43+
C = (1 / 6) * (A ** 3 - A) * h ** 2
44+
D = (1 / 6) * (B ** 3 - B) * h ** 2
45+
y = A * y_knots[i] + B * y_knots[i + 1] + C * m[i] + D * m[i + 1]
46+
return y
47+
48+
49+
def compute_cubic_spline(t, x_knots, y_knots):
50+
m = compute_spline_coefficients(x_knots, y_knots)
51+
return evaluate_cubic_spline(t, x_knots, y_knots, m)
52+
53+
54+
vectorized_cubic_spline = jax.vmap(compute_cubic_spline, in_axes=(0, None, None))
55+
1256
interp = jax.vmap(jnp.interp, in_axes=(None, None, 1))
1357

1458

1559
class LowRankSpline(nn.Module):
1660
n_points: int
61+
interpolation: str
1762
T: float
1863
transform: Optional[Callable[[Any], Any]]
1964
A: ArrayLike
@@ -43,8 +88,15 @@ def get_tril(v):
4388
a = a.at[jnp.tril_indices(ndim)].set(v)
4489
return a
4590

46-
mu = interp(t.flatten(), t_grid, y_grid).T
47-
S = interp(t.flatten(), t_grid, S_grid).T
91+
if self.interpolation == 'cubic':
92+
mu = vectorized_cubic_spline(t.flatten(), t_grid, y_grid)
93+
S = vectorized_cubic_spline(t.flatten(), t_grid, S_grid)
94+
elif self.interpolation == 'linear':
95+
mu = interp(t.flatten(), t_grid, y_grid).T
96+
S = interp(t.flatten(), t_grid, S_grid).T
97+
else:
98+
raise ValueError(f"Interpolation method {self.interpolation} not recognized.")
99+
48100
S = get_tril(S)
49101
S = jnp.tril(2 * jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim, dtype=jnp.float32)[None, ...] * jnp.exp(S)
50102

0 commit comments

Comments
 (0)