A problem I have run into frequently in astronomical data analysis is the need to infer parameters of a density model where some aspects of the model are allowed to be flexible and other components are held more rigid. In these contexts, we are also sometimes interested in learning a flexible representation for the density of sources itself. This post demonstrates how to implement models with flexibility controlled by spline interpolation of function values in [JAX](https://jax.readthedocs.io/en/latest/).

One example of the need for flexibility in density modeling is the classic Galactic astronomy problem of measuring the vertical stellar density profile (and midplane density) of the Galactic disk: In this problem, we start with observations of stellar positions ($x, y, z$) (probably observed under some selection function) and we want to infer the midplane density value and a model for the density profile away from the midplane. Historically, simple, parametric density profiles have been used (e.g., [Bovy et al. 2017](https://ui.adsabs.harvard.edu/abs/2017MNRAS.470.1360B/abstract)), but we now know that there are significant asymmetries in the density of stars (e.g., [Bennett et al. 2019](https://ui.adsabs.harvard.edu/abs/2019MNRAS.482.1417B/abstract)), and so we might now want to fit a parametric density profile plus a model component to handle this asymmetry.

Another problem where the need to fit models with parametric and flexible components arises is in modeling the phase-space density of stellar streams (e.g., [Koposov et al. 2019](https://ui.adsabs.harvard.edu/abs/2019MNRAS.485.4726K/abstract), [Tavangar et al. 2022](https://ui.adsabs.harvard.edu/abs/2022ApJ...925..118T/abstract)). In the case of stellar streams, we generally want to simultaneously fit the "track" or ridgeline of the stream in position and velocity components, the width of the stream, the density along the stream, and a flexible model for the background stellar density in these components.

There are many possible options for adding flexibility to models (see: Machine Learning). One particularly useful tool, used heavily in time series analysis, is the [Gaussian process](https://en.wikipedia.org/wiki/Gaussian_process) (GP). GPs add controlled flexibility to probabilistic models (they are weakly parametric, through the specification of a kernel function) and have gained popularity in astronomy recently thanks to advances in the computational efficiency of evaluating GP likelihoods (e.g., [celerite](https://github.com/exoplanet-dev/celerite2) or [tinygp](https://github.com/dfm/tinygp)). I won't go over GPs in this post, but there are many resources available online and on GitHub that give great introductions to GPs (e.g., [Dan Foreman-Mackey's slides](https://speakerdeck.com/dfm/an-astronomers-introduction-to-gaussian-processes-v2) or [Rodrigo Luger's tutorial](https://github.com/LSSTC-DSFP/LSSTC-DSFP-Sessions/blob/main/Sessions/Session13/Day2/answers/01-Introduction-to-GPs.ipynb)).

In this post, we will use another frequently-used tool for specifying flexible models: [cubic splines](https://en.wikipedia.org/wiki/Spline_(mathematics)).

In [None]:
# Some global imports we will need throughout this post:
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
import numpy as np

# Spline models

An interpolating spline function is fully determined by the degree of the polynomial used, the locations of the $M$ "knots" $x_m$, and the function values at the knots $f_m$ (given a choice of boundary conditions). A common choice for the polynomial degree is 3, which gives cubic splines.

In [None]:
from scipy.interpolate import InterpolatedUnivariateSpline

In [None]:
rng = np.random.default_rng(seed=42)

M = 8
x_m = np.linspace(0, 10, M)
f_m = rng.uniform(-1, 1, M)
spl = InterpolatedUnivariateSpline(x_m, f_m, k=3)  # k = the polynomial degree

In [None]:
plt.scatter(x_m, f_m)

grid = np.linspace(-1, 11, 1024)
plt.plot(grid, spl(grid), marker="", linestyle="-", color="tab:blue", zorder=-10)

plt.annotate(
    "knots",
    xy=(x_m[0], f_m[0]),
    xytext=(2, 2),
    arrowprops=dict(color="#666", shrinkB=4, arrowstyle="->"),
    ha="center",
)
for m in range(1, 3):
    plt.annotate(
        "     ",
        xy=(x_m[m], f_m[m]),
        xytext=(2, 2),
        arrowprops=dict(color="#666", shrinkB=4, arrowstyle="->"),
        ha="center",
    )

plt.xlabel("$x$")
plt.ylabel("$f$")

The task of finding a spline representation of a function given samples or points is sometimes called "spline regression." The problem is straightforward if we pick and fix the locations of the knots of the spline function we want to fit, and then treat the values of the function at the knot locations as parameters of our model. This type of model has the advantage that the (spatial) scale of flexibility, or the "degrees of freedom," is controllable by setting the number of knots. However, unlike in GPs, where kernel functions can be used to parametrize the amplitude or spatial scales of your problem, these things are not explicitly controlled in a spline model. One other disadvantage of a spline model is that the number of parameters grows as you increase the number of knots (i.e., the degrees of freedom) of the model. This can make spline models intractable in some simple optimization routines (e.g., using `scipy.optimize.minimize` without gradient information) or in some Markov Chain Monte Carlo (MCMC) methods that do not use gradient information (e.g., Metropolis-Hastings or [`emcee`](https://emcee.readthedocs.io/en/stable/)).
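
To make the fixed-knot version of spline regression concrete, here is a minimal sketch. It relies on the fact that an interpolating cubic spline is linear in its knot values, so for Gaussian noise the fit reduces to linear least squares. The helper names `spline_design_matrix` and `fit_knot_values` are hypothetical, and `scipy.interpolate.CubicSpline` (with its default not-a-knot boundary conditions) is used as a stand-in for the spline class used elsewhere in this post.

```python
import numpy as np
from scipy.interpolate import CubicSpline


def spline_design_matrix(x, knots):
    """Return the (len(x), len(knots)) matrix B such that B @ f is the
    interpolating cubic spline with knot values f, evaluated at x."""
    # Interpolating each unit vector gives one column of the basis. This works
    # because the interpolating spline is linear in the knot values.
    basis = CubicSpline(knots, np.eye(len(knots)), axis=0)
    return basis(x)


def fit_knot_values(x, y, knots):
    """Least-squares estimate of the function values at fixed knots."""
    B = spline_design_matrix(x, knots)
    f_hat, *_ = np.linalg.lstsq(B, y, rcond=None)
    return f_hat


rng = np.random.default_rng(123)
knots = np.linspace(0, 10, 8)
f_true = rng.uniform(-1, 1, len(knots))

# Simulated noisy samples of the true spline (illustrative values only):
x = rng.uniform(0, 10, 512)
y = CubicSpline(knots, f_true)(x) + rng.normal(0, 0.1, len(x))

f_hat = fit_knot_values(x, y, knots)
print("largest error in the recovered knot values:", np.abs(f_hat - f_true).max())
```

This shortcut only applies when the likelihood is Gaussian with known variance and the spline enters linearly. The examples later in this post also put a spline on the width and on the log-density, which is where gradient-based optimization and sampling become useful.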

Fortunately, it is possible to use spline models with [JAX](https://jax.readthedocs.io/), which automatically gives us access to functional gradients and therefore opens up the possibility of using optimization and sampling methods that perform well with large numbers of parameters. Below are two examples that demonstrate how to implement spline components in density models using JAX, to optimize the parameters of the models with [`jaxopt`](https://jaxopt.github.io/), and to generate posterior samples using Hamiltonian Monte Carlo with [`blackjax`](https://blackjax-devs.github.io/blackjax/).
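
As a small illustration of what "access to gradients" means here, the sketch below differentiates a Gaussian log-likelihood with respect to the function values at the knots. To keep it self-contained it uses `jnp.interp`, a piecewise-linear interpolant that ships with JAX, as a stand-in for a cubic spline. The function names and the fake data are assumptions made for this example only.

```python
import jax
import jax.numpy as jnp


def ln_likelihood(knot_vals, knots, x, y, sigma):
    """Gaussian log-likelihood of y given an interpolated model evaluated at x."""
    # Linear (not cubic) interpolation, used here only as a simple stand-in:
    model = jnp.interp(x, knots, knot_vals)
    ln_norm = -jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi)
    return jnp.sum(-0.5 * ((y - model) / sigma) ** 2 + ln_norm)


# jax.grad differentiates with respect to the first argument (the knot values):
grad_ln_likelihood = jax.jit(jax.grad(ln_likelihood))

knots = jnp.linspace(0.0, 10.0, 8)
true_vals = jnp.sin(knots)
x = jnp.linspace(0.0, 10.0, 256)
y = jnp.interp(x, knots, true_vals)  # noise-free fake data

# The gradient (numerically) vanishes at the values that generated the data ...
grad_at_truth = grad_ln_likelihood(true_vals, knots, x, y, 0.1)
# ... and points back toward them when every knot value is shifted upward:
grad_offset = grad_ln_likelihood(true_vals + 0.5, knots, x, y, 0.1)
print(grad_at_truth.shape, grad_offset)
```

The gradient has one entry per knot, so a gradient-based optimizer or sampler sees the full direction of improvement in a single evaluation, no matter how many knots the model has.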

# Example: The (vertical) density profile in the Solar Neighborhood


In [None]:
from scipy.stats import binned_statistic

import jax
import jax.numpy as jnp
# Note: this JAX-compatible spline shadows the scipy class imported earlier
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline
from jaxopt import ScipyMinimize
import blackjax

Make some fake data:

In [None]:
N = 1024
window = (-20, 20)
rng = np.random.default_rng(42)
# Draw phi1 from a Gaussian and check that every point lands inside the
# window that is used later to normalize the density:
phi1_data = rng.normal(5, 4, size=N)
assert np.all((phi1_data < window[1]) & (phi1_data > window[0]))

phi2_data = rng.normal(1.5 * np.cos(2 * np.pi * phi1_data / 5), 0.8)
plt.figure(figsize=(10, 3))
plt.scatter(phi1_data, phi2_data)

In [None]:
@jax.jit
def jln_normal(x, mu, var):
    return -0.5 * (jnp.log(2 * np.pi * var) + (x - mu) ** 2 / var)


@jax.jit
def phi2_ln_likelihood(
    phi2_mean_knots,
    phi2_mean_vals,
    phi2_std_knots,
    phi2_std_vals,
    phi1_eval,
    phi2_eval,
):

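    # Cubic splines for the mean and standard deviation of phi2 as a function of phi1:
    phi2_mean = InterpolatedUnivariateSpline(phi2_mean_knots, phi2_mean_vals, k=3)

    phi2_std = InterpolatedUnivariateSpline(phi2_std_knots, phi2_std_vals, k=3)

    # Evaluate both splines at the requested phi1 values:
    phi2_mean_model = phi2_mean(phi1_eval)
    phi2_var_model = phi2_std(phi1_eval) ** 2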

    # phi2_mean_model = jnp.interp(phi1_eval, phi2_mean_knots, phi2_mean_vals)
    # phi2_var_model = jnp.interp(phi1_eval, phi2_std_knots, phi2_std_vals) ** 2

    return jln_normal(phi2_eval, phi2_mean_model, phi2_var_model)

In [None]:
phi2_knots = np.linspace(phi1_data.min(), phi1_data.max(), 25)

dp2 = phi2_knots[1] - phi2_knots[0]
_bins = np.linspace(
    phi2_knots[0] - dp2 / 2, phi2_knots[-1] + dp2 / 2, len(phi2_knots) + 1
)
stat = binned_statistic(phi1_data, phi2_data, bins=_bins, statistic=np.nanmean)
phi2_vals = stat.statistic
phi2_vals[np.isnan(phi2_vals)] = 0.0

In [None]:
plt.figure(figsize=(10, 3))
plt.scatter(phi1_data, phi2_data)
plt.scatter(phi2_knots, phi2_vals, color="tab:red")

In [None]:
phi2_ln_likelihood(
    phi2_knots,
    phi2_vals,
    phi2_knots,
    np.full_like(phi2_knots, 0.4),
    phi1_data,
    phi2_data,
)

In [None]:
xx, yy = np.meshgrid(
    np.linspace(phi1_data.min(), phi1_data.max(), 256), np.linspace(-4, 4, 128)
)
zz = phi2_ln_likelihood(
    phi2_knots, phi2_vals, phi2_knots, np.full_like(phi2_knots, 0.4), xx, yy
)
plt.figure(figsize=(10, 3))
plt.pcolormesh(xx, yy, np.exp(zz), shading="auto")
plt.scatter(phi1_data, phi2_data, color="tab:blue")

In [None]:
@jax.jit
def ln_prob(pars, data, phi2_knots):
    n_phi2 = len(phi2_knots)
    phi2_means = pars[:n_phi2]
    ln_phi2_stds = pars[n_phi2 : 2 * n_phi2]
    phi2_stds = jnp.exp(ln_phi2_stds)

    ll = phi2_ln_likelihood(
        phi2_knots, phi2_means, phi2_knots, phi2_stds, data["phi1"], data["phi2"]
    ).sum()

    lp = jln_normal(phi2_means, 0, 2.0).sum()
    lp += jln_normal(ln_phi2_stds, -1, 1.0).sum()

    return ll + lp


@jax.jit
def objective(pars, data, phi2_knots):
    return -ln_prob(pars, data, phi2_knots)

In [None]:
init_params = np.concatenate((phi2_vals, np.log(np.full_like(phi2_knots, 0.4))))
data = {"phi1": phi1_data, "phi2": phi2_data}

In [None]:
solver = ScipyMinimize(method="l-bfgs-b", fun=objective)
res = solver.run(init_params, data=data, phi2_knots=phi2_knots)
res.state.iter_num

In [None]:
res.params

In [None]:
plt.figure(figsize=(10, 3))
plt.plot(phi2_knots, res.params[: len(phi2_knots)])
plt.plot(phi2_knots, res.params[len(phi2_knots) :], color="tab:red")
plt.scatter(phi1_data, phi2_data, color="tab:blue", s=2)

In [None]:
xx, yy = np.meshgrid(
    np.linspace(phi1_data.min(), phi1_data.max(), 256), np.linspace(-4, 4, 128)
)
zz = np.exp(
    phi2_ln_likelihood(
        phi2_knots,
        res.params[: len(phi2_knots)],
        phi2_knots,
        np.exp(res.params[len(phi2_knots) :]),
        xx,
        yy,
    )
)
plt.figure(figsize=(10, 3))
plt.pcolormesh(
    xx,
    yy,
    zz,
    shading="auto",
    vmin=np.median(zz[(xx > 0) & (xx < 5)]),
    vmax=np.max(zz[(xx > 0) & (xx < 5)]),
)
plt.scatter(phi1_data, phi2_data, color="tab:blue", s=2)

The inference loop below is copied from the `blackjax` getting-started guide:

In [None]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

In [None]:
rng_key = jax.random.PRNGKey(42)

In [None]:
fn = jax.tree_util.Partial(ln_prob, data=data, phi2_knots=phi2_knots)
warmup = blackjax.window_adaptation(
    blackjax.nuts,
    fn,
    1000,
)

state, kernel, _ = warmup.run(
    rng_key,
    res.params,
)

In [None]:
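# Derive a fresh key for sampling instead of reusing the key passed to the warmup:
states = inference_loop(jax.random.fold_in(rng_key, 1), kernel, state, 1_000)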

In [None]:
states.position.shape

In [None]:
plt.figure(figsize=(10, 3))
for i in np.random.choice(states.position.shape[0], size=32):
    params = states.position[i]
    plt.plot(phi2_knots, params[: len(phi2_knots)])

In [None]:
plt.figure(figsize=(10, 3))
for i in np.random.choice(states.position.shape[0], size=32):
    params = states.position[i]
    plt.plot(phi2_knots, params[len(phi2_knots) :])

---

### Infer log-density as well:

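We can also infer the (log-)density of points along $\phi_1$ by treating the data as draws from an inhomogeneous Poisson point process; see, e.g., [these notes on Poisson processes](http://people.ee.duke.edu/~lcarin/PoissonProcess.pdf).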
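
For a rate function $\lambda(x)$ on a window $W$, the log-likelihood of observed points $x_i$ under an inhomogeneous Poisson process is $\ln L = \sum_i \ln\lambda(x_i) - \int_W \lambda(x)\,\mathrm{d}x$. The sketch below (plain NumPy, with a hypothetical function name and made-up uniform data) evaluates this for a constant rate, where the maximum is known analytically to be $N / T$ for $N$ points in a window of length $T$.

```python
import numpy as np


def ln_poisson_process_likelihood(ln_rate, x, window, n_grid=2049):
    """ln L = sum_i ln(rate(x_i)) - integral of the rate over the window."""
    grid = np.linspace(window[0], window[1], n_grid)
    rate = np.exp(ln_rate(grid))
    # Trapezoidal rule for the expected number of points in the window:
    expected_count = np.sum(0.5 * (rate[1:] + rate[:-1]) * np.diff(grid))
    return np.sum(ln_rate(x)) - expected_count


window = (-20.0, 20.0)
rng = np.random.default_rng(0)
x = rng.uniform(*window, size=500)

# For a constant rate, ln L = N ln(rate) - rate * T, which peaks at rate = N / T:
trial_rates = np.linspace(5.0, 20.0, 301)
ln_likes = np.array(
    [
        ln_poisson_process_likelihood(
            lambda g, r=r: np.full_like(g, np.log(r)), x, window
        )
        for r in trial_rates
    ]
)
best_rate = trial_rates[np.argmax(ln_likes)]
print(best_rate, len(x) / (window[1] - window[0]))
```

In the model below the constant log-rate is replaced by a cubic spline in $\phi_1$ and the integral is evaluated in log space, but the two terms of the likelihood are the same.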

In [None]:
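def ln_simpson(ln_y, x, dtype=None):
    """Evaluate the log of a definite integral using Simpson's 1/3 rule.

    ``ln_y`` is the log of the integrand on the grid ``x``, which must be
    uniformly spaced and contain an odd number of points.
    """

    dx = jnp.diff(x)[0]  # assumes uniform grid spacing
    num_points = len(x)
    if num_points % 2 == 0:
        raise ValueError("Simpson's 1/3 rule requires an odd number of grid points")

    # Simpson weights: 1, 4, 2, 4, ..., 2, 4, 1
    weights_first = jnp.asarray([1.0], dtype=dtype)
    weights_mid = jnp.tile(
        jnp.asarray([4.0, 2.0], dtype=dtype), [(num_points - 3) // 2]
    )
    weights_last = jnp.asarray([4.0, 1.0], dtype=dtype)
    weights = jnp.concatenate([weights_first, weights_mid, weights_last], axis=0)

    # Sum the weighted integrand in log space to avoid overflow or underflow:
    return jax.scipy.special.logsumexp(ln_y + jnp.log(weights), axis=-1) + jnp.log(
        dx / 3
    )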
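
As a sanity check on the idea of applying Simpson's rule in log space, here is a NumPy/SciPy sketch of the same calculation, compared against an integral with a known answer, $\int_0^1 e^x\,\mathrm{d}x = e - 1$. The name `ln_simpson_numpy` is hypothetical and not part of the code used in the rest of this post.

```python
import numpy as np
from scipy.special import logsumexp


def ln_simpson_numpy(ln_y, x):
    """Log of the Simpson's-rule integral of exp(ln_y) on a uniform, odd-length grid."""
    if len(x) % 2 == 0:
        raise ValueError("Simpson's 1/3 rule requires an odd number of grid points")
    dx = x[1] - x[0]
    weights = np.ones(len(x))
    weights[1:-1:2] = 4.0  # odd-indexed interior points
    weights[2:-1:2] = 2.0  # even-indexed interior points
    return logsumexp(ln_y + np.log(weights)) + np.log(dx / 3)


grid = np.linspace(0.0, 1.0, 101)
exact = np.log(np.e - 1.0)

# The integrand is exp(x), so its log is just x:
ln_integral = ln_simpson_numpy(grid, grid)

# Shifting the log-integrand by -800 would underflow to zero if exponentiated
# directly, but the log-space sum still returns the shifted answer:
ln_tiny = ln_simpson_numpy(grid - 800.0, grid)
print(ln_integral - exact, ln_tiny - (exact - 800.0))
```

The second case is the reason for working in log space: a log-rate spline can take very negative values away from the data, and exponentiating before summing would lose those contributions entirely.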

In [None]:
@jax.jit
def phi1_ln_likelihood(phi1_knots, ln_phi1_rate, phi1_eval):
    ln_phi1_rate_spl = InterpolatedUnivariateSpline(phi1_knots, ln_phi1_rate, k=3)

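    # Integrate the rate over the window (in log space) to get the expected
    # number of points, V:
    _grid = jnp.linspace(*window, 1025)
    lnV = ln_simpson(ln_phi1_rate_spl(_grid), _grid)

    # Per-point terms: V is divided by the number of points so that summing over
    # all points gives the Poisson process log-likelihood, -V + sum(ln rate)
    return -jnp.exp(lnV) / len(phi1_eval) + ln_phi1_rate_spl(phi1_eval)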

In [None]:
@jax.jit
def ln_prob2(pars, data, phi2_knots):
    n_phi2 = len(phi2_knots)
    phi2_means = pars[:n_phi2]
    ln_phi2_stds = pars[n_phi2 : 2 * n_phi2]
    phi2_stds = jnp.exp(ln_phi2_stds)

    # The third block of parameters holds the log-rate at the knots:
    ln_phi1_rate = pars[2 * n_phi2 : 3 * n_phi2]

    ll = phi2_ln_likelihood(
        phi2_knots, phi2_means, phi2_knots, phi2_stds, data["phi1"], data["phi2"]
    ).sum()

    ll2 = phi1_ln_likelihood(
        phi2_knots,
        ln_phi1_rate,
        data["phi1"],
    ).sum()

    lp = jln_normal(phi2_means, 0, 2.0).sum()
    lp += jln_normal(ln_phi2_stds, -1, 1.0).sum()

    return ll + ll2 + lp


@jax.jit
def objective2(pars, data, phi2_knots):
    return -ln_prob2(pars, data, phi2_knots)

In [None]:
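H, xe = np.histogram(data["phi1"], bins=np.linspace(*window, 32))
xc = 0.5 * (xe[:-1] + xe[1:])
# The rate is the number of points per unit phi1, so divide the counts by the
# bin width; the small offset avoids taking log(0) in empty bins:
H = np.log((H + 1e-4) / np.diff(xe))
init_rate = InterpolatedUnivariateSpline(xc, H, k=3)(phi2_knots)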

In [None]:
init_params = np.concatenate(
    (phi2_vals, np.log(np.full_like(phi2_knots, 0.4)), init_rate)
)
data = {"phi1": phi1_data, "phi2": phi2_data}

In [None]:
solver = ScipyMinimize(method="l-bfgs-b", fun=objective2)
res = solver.run(init_params, data=data, phi2_knots=phi2_knots)
res.state.iter_num, res.params

In [None]:
plt.figure(figsize=(10, 3))
plt.plot(phi2_knots, res.params[: len(phi2_knots)])
plt.plot(phi2_knots, res.params[len(phi2_knots) : 2 * len(phi2_knots)], color="tab:red")
plt.plot(phi2_knots, res.params[2 * len(phi2_knots) :], color="tab:green")
plt.scatter(phi1_data, phi2_data, color="tab:blue", s=2)

In [None]:
xx, yy = np.meshgrid(
    np.linspace(phi1_data.min(), phi1_data.max(), 256), np.linspace(-4, 4, 128)
)
zz = np.exp(
    phi2_ln_likelihood(
        phi2_knots,
        res.params[: len(phi2_knots)],
        phi2_knots,
        np.exp(res.params[len(phi2_knots) : 2 * len(phi2_knots)]),
        xx,
        yy,
    )
    + phi1_ln_likelihood(
        phi2_knots, res.params[2 * len(phi2_knots) : 3 * len(phi2_knots)], xx
    )
)

plt.figure(figsize=(10, 3))
plt.pcolormesh(
    xx,
    yy,
    zz,
    shading="auto",
    vmin=np.median(zz[(xx > 0) & (xx < 5)]),
    vmax=np.max(zz[(xx > 0) & (xx < 5)]),
)
# plt.scatter(phi1_data, phi2_data, color='tab:blue', s=2)

In [None]:
rng_key = jax.random.PRNGKey(42)

In [None]:
fn = jax.tree_util.Partial(ln_prob2, data=data, phi2_knots=phi2_knots)
warmup = blackjax.window_adaptation(
    blackjax.nuts,
    fn,
    1000,
)

state, kernel, _ = warmup.run(
    rng_key,
    res.params,
)

In [None]:
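# Derive a fresh key for sampling instead of reusing the key passed to the warmup:
states = inference_loop(jax.random.fold_in(rng_key, 1), kernel, state, 1_000)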

In [None]:
plt.figure(figsize=(10, 3))
for i in np.random.choice(states.position.shape[0], size=32):
    params = states.position[i]
    plt.plot(phi2_knots, params[: len(phi2_knots)])

In [None]:
plt.figure(figsize=(10, 3))
for i in np.random.choice(states.position.shape[0], size=32):
    params = states.position[i]
    plt.plot(phi2_knots, params[len(phi2_knots) : 2 * len(phi2_knots)])

In [None]:
plt.figure(figsize=(10, 3))
for i in np.random.choice(states.position.shape[0], size=32):
    params = states.position[i]
    plt.plot(phi2_knots, params[2 * len(phi2_knots) : 3 * len(phi2_knots)])

In [None]:
xx, yy = np.meshgrid(
    np.linspace(phi1_data.min(), phi1_data.max(), 256), np.linspace(-4, 4, 128)
)

for i in np.random.choice(states.position.shape[0], size=8):
    params = states.position[i]
    zz = np.exp(
        phi2_ln_likelihood(
            phi2_knots,
            params[: len(phi2_knots)],
            phi2_knots,
            np.exp(params[len(phi2_knots) : 2 * len(phi2_knots)]),
            xx,
            yy,
        )
        + phi1_ln_likelihood(
            phi2_knots, params[2 * len(phi2_knots) : 3 * len(phi2_knots)], xx
        )
    )

    plt.figure(figsize=(10, 3))
    plt.pcolormesh(
        xx,
        yy,
        zz,
        shading="auto",
        vmin=np.median(zz[(xx > 0) & (xx < 5)]),
        vmax=np.max(zz[(xx > 0) & (xx < 5)]),
    )