TODO NEW PLAN:
- First example just do toy / fake data - sample from gaussian, learn parameters of gaussian, and spline version of density
- Second example do GD-1 sky positions - coarse filter CMD and proper motions and fit 2D sky positions

A problem I have run into frequently in astronomical data analysis is the need to infer parameters of a density model where some aspects of the model are allowed to be flexible and other components are held more rigid. In these contexts, we are also sometimes interested in learning a flexible representation for the density of sources itself. This post demonstrates how to implement models with flexibility controlled by spline interpolation of function values in [JAX](https://jax.readthedocs.io/en/latest/).

One example of the need for flexibility in density modeling is the classic Galactic astronomy problem of measuring the vertical stellar density profile (and midplane density) of the Galactic disk: In this problem, we start with observations of stellar positions ($x, y, z$) (probably observed under some selection function) and we want to infer the midplane density value and a model for the density profile away from the midplane. Historically, simple, parametric density profiles have been used (e.g., [Bovy et al. 2017](https://ui.adsabs.harvard.edu/abs/2017MNRAS.470.1360B/abstract)), but we now know that there are significant asymmetries in the density of stars (e.g., [Bennett et al. 2019](https://ui.adsabs.harvard.edu/abs/2019MNRAS.482.1417B/abstract)), and so we might now want to fit a parametric density profile plus a model component to handle this asymmetry.

Another problem where the need to fit models with parametric and flexible components arises is in modeling the phase-space density of stellar streams (e.g., [Koposov et al. 2019](https://ui.adsabs.harvard.edu/abs/2019MNRAS.485.4726K/abstract), [Tavangar et al. 2022](https://ui.adsabs.harvard.edu/abs/2022ApJ...925..118T/abstract)). In the case of stellar streams, we generally want to simultaneously fit the "track" or ridgeline of the stream in position and velocity components, the width of the stream, the density along the stream, and a flexible model for the background stellar density in these components.

There are many possible options for adding flexibility to models (see: Machine Learning). One particularly useful tool, used heavily in time-series analysis, is the [Gaussian process](https://en.wikipedia.org/wiki/Gaussian_process) (GP). GPs allow adding controlled flexibility to probabilistic models (i.e. they are weakly parametric, through the specification of a kernel function) and have gained popularity in astronomy recently thanks to advances in the computational efficiency of evaluating GP likelihoods (e.g., [celerite](https://github.com/exoplanet-dev/celerite2) or [tinygp](https://github.com/dfm/tinygp)). I won't go over GPs in this post, but there are many resources available online and on GitHub that give great introductions to GPs (e.g., [Dan Foreman-Mackey's slides](https://speakerdeck.com/dfm/an-astronomers-introduction-to-gaussian-processes-v2) or [Rodrigo Luger's tutorial](https://github.com/LSSTC-DSFP/LSSTC-DSFP-Sessions/blob/main/Sessions/Session13/Day2/answers/01-Introduction-to-GPs.ipynb)).

In this post, we will use another frequently-used tool for specifying flexible models: [cubic splines](https://en.wikipedia.org/wiki/Spline_(mathematics)).

In [None]:
# Some global imports we will need throughout this post:
import astropy.units as u
import jax
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# JAX uses 32-bit floats by default; enable 64-bit precision:
jax.config.update("jax_enable_x64", True)

%matplotlib inline

# Spline models

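In the form we use here, a spline is a piecewise polynomial that passes through a set of $M$ "knots." It is fully determined by the degree of the polynomial pieces, the locations of the knots $x_m$, the function values at the knots $f_m$, and a choice of boundary conditions at the two ends. A common choice for the polynomial degree is 3 (cubic splines).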
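Here is a quick sanity check of what "interpolating cubic spline" means in practice. This is a minimal sketch that uses SciPy's `InterpolatedUnivariateSpline` (the same class used below) with randomly chosen knot values. It confirms two properties: the spline reproduces the knot values exactly, and its second derivative is continuous across the interior knots.

```python
import numpy as np
import scipy.interpolate as sci

rng = np.random.default_rng(seed=42)
x_m = np.linspace(0, 10, 8)  # knot locations
f_m = rng.uniform(-1, 1, 8)  # function values at the knots
spl = sci.InterpolatedUnivariateSpline(x_m, f_m, k=3)

# An interpolating spline passes exactly through the knot values:
at_knots = spl(x_m)

# A cubic spline is twice continuously differentiable, so the second
# derivative should agree just to the left and right of each interior knot:
d2 = spl.derivative(2)
eps = 1e-7
jumps = np.array([abs(d2(x + eps) - d2(x - eps)) for x in x_m[1:-1]])
```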

In [None]:
import scipy.interpolate as sci

In [None]:
rng = np.random.default_rng(seed=42)

M = 8
x_m = np.linspace(0, 10, M)
f_m = rng.uniform(-1, 1, M)
spl = sci.InterpolatedUnivariateSpline(x_m, f_m, k=3)  # k = the polynomial degree

In [None]:
plt.scatter(x_m, f_m)

grid = np.linspace(-1, 11, 1024)
plt.plot(grid, spl(grid), marker="", linestyle="-", color="tab:blue", zorder=-10)

plt.annotate(
    "knots",
    xy=(x_m[0], f_m[0]),
    xytext=(2, 2),
    arrowprops=dict(color="#666", shrinkB=4, arrowstyle="->"),
    ha="center",
)
for m in range(1, 3):
    plt.annotate(
        "     ",
        xy=(x_m[m], f_m[m]),
        xytext=(2, 2),
        arrowprops=dict(color="#666", shrinkB=4, arrowstyle="->"),
        ha="center",
    )

plt.xlabel("$x$")
plt.ylabel("$f$")

The task of finding a spline representation of a function given samples or points is sometimes called "spline regression." The problem is straightforward if we pick and fix the locations of the knots and then treat the values of the function at the knots as the parameters of our model. This type of model has the advantage that the (spatial) scale of flexibility, or the number of "degrees of freedom," is controlled by the number of knots. However, a spline model does not explicitly control the amplitude or spatial scale of the variations. In a GP, by contrast, the kernel function parametrizes these properties directly. Another disadvantage of spline models is that the number of parameters grows with the number of knots (i.e. the degrees of freedom). This can make spline models intractable with some simple optimization routines (e.g., `scipy.optimize.minimize` without gradient information) or with Markov chain Monte Carlo (MCMC) methods that do not use gradient information (e.g., Metropolis-Hastings or [`emcee`](https://emcee.readthedocs.io/en/stable/)).
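Here is a minimal sketch of fixed-knot spline regression in this spirit, using only SciPy. The knot locations are fixed, the function values at the knots are the free parameters, and `scipy.optimize.least_squares` adjusts them to fit noisy samples of a sine curve. The data, noise level, and number of knots are arbitrary choices for illustration. With only eight parameters, a gradient-free fit like this works fine; the concern raised above is how such approaches scale as the number of knots grows.

```python
import numpy as np
import scipy.interpolate as sci
import scipy.optimize as sco

rng = np.random.default_rng(seed=7)
x_obs = np.sort(rng.uniform(0, 10, size=200))
y_obs = np.sin(x_obs) + rng.normal(0, 0.1, size=x_obs.size)  # noisy samples

knots = np.linspace(0, 10, 8)  # fixed knot locations


def residuals(f_knots):
    # Model: the cubic spline through (knots, f_knots); parameters = f_knots
    spl = sci.InterpolatedUnivariateSpline(knots, f_knots, k=3)
    return spl(x_obs) - y_obs


# Gradient-free fit: least_squares estimates the Jacobian by finite differences
res = sco.least_squares(residuals, x0=np.zeros_like(knots))
rms = np.sqrt(np.mean(residuals(res.x) ** 2))
```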

Fortunately, it is possible to use spline models with [JAX](https://jax.readthedocs.io/), which computes gradients of our model functions for us using automatic differentiation. This opens up the possibility of using optimization and sampling methods that perform well with large numbers of parameters. Below are two examples that demonstrate how to implement spline components in density models using JAX, how to optimize the parameters of the models with [`jaxopt`](https://jaxopt.github.io/), and how to generate posterior samples using Hamiltonian Monte Carlo with [`blackjax`](https://blackjax-devs.github.io/blackjax/).

# Example: Fitting a 1D density profile with splines

As a first demonstration of the idea, we will use simulated data to mock up a simpler version of the vertical density problem mentioned above. We will generate simulated data from a Gaussian. We will then fit the density by modeling the points as an [inhomogeneous Poisson process](https://en.wikipedia.org/wiki/Poisson_point_process#Inhomogeneous_Poisson_point_process) with either (1) a Gaussian or (2) a cubic spline density function. In either case, given a density function $n(z)$ (Gaussian or spline), the likelihood of our $N$ data points $z_n$ is the Poisson process likelihood:
$$
\begin{align}
p(\left\{z_n\right\}_N \,|\, n(z)) &=
    \exp{\left[-\int {\rm d}z \, n(z)\right]} \, \prod_n^N n(z_n)
\end{align}
$$
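To make the likelihood concrete, here is a minimal NumPy/SciPy sketch; it is not the implementation used below. It evaluates the Poisson process log-likelihood, $-\int {\rm d}z\,n(z) + \sum_n \ln n(z_n)$, for an arbitrary rate function by computing the integral numerically on a grid. For a Gaussian rate function with $N_0 = 500$, the numerical result should match the closed-form version, because the integral is then just $N_0$.

```python
import numpy as np
from scipy.integrate import trapezoid
from scipy.stats import norm


def poisson_process_ln_like(rate_fn, z_data, z_grid):
    """
    Log-likelihood of an inhomogeneous Poisson process with rate n(z):
    -integral(n(z) dz) + sum_n ln n(z_n), with the integral computed on
    z_grid using the trapezoid rule
    """
    integral = trapezoid(rate_fn(z_grid), z_grid)
    return -integral + np.sum(np.log(rate_fn(z_data)))


def gaussian_rate(z, N0=500.0, mu=0.0, sigma=0.3):
    # n(z) = N0 * Normal(z | mu, sigma)
    return N0 * norm.pdf(z, loc=mu, scale=sigma)


rng = np.random.default_rng(seed=1)
z_data = rng.normal(0.0, 0.3, size=500)
z_grid = np.linspace(-5, 5, 4001)  # wide enough that the tails are negligible

ln_L_numeric = poisson_process_ln_like(gaussian_rate, z_data, z_grid)
# Closed form: the integral of N0 * Normal over all z is N0
ln_L_closed = -500.0 + np.sum(np.log(gaussian_rate(z_data)))
```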

## Case 1: Gaussian model

For our first demo, we will use a Gaussian to fit the data (which were generated by a Gaussian, so this is truly a toy example). In this case:
$$
\begin{align}
n(z \,|\, N_0, \mu, \sigma) &= N_0 \, \mathcal{N}(z \,|\, \mu, \sigma)\\
\mathcal{N}(x \,|\, \mu, \sigma) &= \frac{1}{\sqrt{2\pi\,\sigma^2}} \, e^{-\,\frac{(x - \mu)^2}{2\,\sigma^2}}\\
\end{align}
$$
where $\mathcal{N}$ represents the normal distribution, $N_0$ is the (expected) total number of sources, and the mean $\mu$ and standard deviation $\sigma$ are the usual Gaussian parameters.

The integral that appears in the first term of the Poisson process likelihood above is therefore just the total number $N_0$, because the normal distribution $\mathcal{N}$ integrates to 1:
$$
\begin{align}
    p(\left\{z_n\right\}_N \,|\, N_0, \mu, \sigma) &= \exp{\left[-N_0 \, \int {\rm d}z \, \mathcal{N}(z \,|\, \mu, \sigma)\right]}  \, \prod_n^N n(z_n)\\
    &= e^{-N_0}  \, N_0^N \, \prod_n^N \mathcal{N}(z_n \,|\, \mu, \sigma)
\end{align}
$$

The log-likelihood is therefore (where $N$ is the number of data points, and $N_0$ is a parameter):
$$
\begin{align}
\ln p(\left\{z_n\right\}_N \,|\, N_0, \mu, \sigma) &=
    -N_0 + N\,\ln N_0 + \sum_n^N \ln \mathcal{N}(z_n \,|\, \mu, \sigma)
\end{align}
$$
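Ignoring the prior for a moment, the log-likelihood depends on $N_0$ only through its first two terms. The maximum-likelihood value of $N_0$ therefore follows from setting the derivative to zero:
$$
\begin{align}
\frac{\partial \ln p}{\partial N_0} &= -1 + \frac{N}{N_0} = 0
\quad \Longrightarrow \quad \hat{N}_0 = N
\end{align}
$$
That is, the best-fit normalization is just the number of observed data points. This gives a handy sanity check on the optimized value of `ln_N0` below, which should be close to $\ln N$.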

To start with, we will generate some random, normally distributed points with an arbitrarily chosen mean and standard deviation:

In [None]:
rng = np.random.default_rng(seed=42)

N = 100_000
z = rng.normal(0.03, 0.31, size=N)

# Pack the data into a dictionary so later we can store other metadata. For
# reasons that will be clear later, we also store the number of data points
# in this dictionary data structure:
data = {"N": N, "z": z}

Let's start by making a histogram of the "data" to visualize it:

In [None]:
z_bins = np.linspace(-2, 2, 128)
plt.hist(data["z"], bins=z_bins)
plt.yscale("log")
plt.xlabel("$z$")
plt.ylabel("number of sources");

To visualize an estimate of the density function, we can use the `numpy.histogram` function instead to compute the number counts per bin and divide by the size of each bin:

In [None]:
H, xe = np.histogram(data["z"], bins=z_bins)
xc = 0.5 * (xe[:-1] + xe[1:])
dens = H / (xe[1] - xe[0])

plt.plot(xc, dens, drawstyle="steps-mid", marker="")
plt.yscale("log")

plt.xlabel("$z$")
plt.ylabel("density $n(z)$");
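Here is a small check of this normalization, as a sketch that regenerates the same kind of simulated data. Dividing the counts by the bin width gives a *number* density, and integrating it over the bins recovers the number of points in the histogram range. This differs from `np.histogram(..., density=True)`, which additionally divides by the total count to give a probability density that integrates to 1.

```python
import numpy as np

rng = np.random.default_rng(seed=42)
z = rng.normal(0.03, 0.31, size=100_000)
z_bins = np.linspace(-2, 2, 128)

H, edges = np.histogram(z, bins=z_bins)
widths = np.diff(edges)
dens = H / widths  # number density: counts per unit z

# Integrating the number density over the bins recovers the counts in range
n_in_range = np.sum(dens * widths)

# density=True instead gives a probability density that integrates to 1
pdf, _ = np.histogram(z, bins=z_bins, density=True)
```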

In what follows, we will define several different density models and an objective function for each of them. Ultimately, though, for all of these choices (e.g., Gaussian density model vs. cubic spline), we need to be able to compute the log-likelihood given a choice of parameters. I like to use object-oriented programming (OOP) to structure my code in situations like this. It helps to reduce duplicated code and enables encapsulation and namespacing, and, frankly, I think the benefits of Python shine when using OOP. However, JAX is really designed to be used within a [*functional programming*](https://en.wikipedia.org/wiki/Functional_programming) context because of the way [just-in-time](https://en.wikipedia.org/wiki/Just-in-time_compilation) (JIT) compilation works. You can read more about this on the [JAX Gotchas](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) page. The bottom line is that all JIT-compiled functions must be *pure functions*: their outputs must depend only on their input arguments, and they must not have side effects (such as reading or modifying global state).
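Here is a minimal illustration of why purity matters; it is a sketch, not code used later in this post. A jitted function that reads a global variable captures that variable's value when the function is first traced. Later changes to the global are then silently ignored for as long as JAX reuses the cached compiled version.

```python
import jax
import jax.numpy as jnp

scale = 2.0


@jax.jit
def impure_scale(x):
    # Reads the global `scale`: its value is baked in when this is first traced
    return scale * x


first = impure_scale(jnp.array(1.0))  # traced and compiled with scale = 2.0
scale = 10.0
second = impure_scale(jnp.array(1.0))  # same input shape/dtype: cached version reused


@jax.jit
def pure_scale(x, scale):
    # Pure version: everything the output depends on is passed in explicitly
    return scale * x


third = pure_scale(jnp.array(1.0), 10.0)
```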

There are some advanced ways of implementing more OOP-like code with JAX. Here, though, I'm going to (ab)use Python classes as a simple way of creating namespaces for the functions we will need, with a light form of inheritance that still obeys the *pure function* requirement of JAX. These classes don't look like true OOP because we use `@classmethod`s instead of regular instance methods, but some other OOP ideas still translate. We will start by defining a base `Model` class that implements some common methods we will need for any of the density models we implement:

In [None]:
# We will need to wrap JAX's jit function with a partial function call to get
# it to work with our classmethod's below. We will use it to tell JAX to treat
# the 0'th input (i.e. the class itself in a classmethod) as a compile-time
# constant-valued object:
from functools import partial


class Model:
    # This will store the parameter names and expected sizes of the parameters
    # (to allow for array-valued parameters) for the density models we
    # implement later on:
    param_names = {}

    @classmethod
    @partial(jax.jit, static_argnums=(0,))
    def unpack_pars(cls, p_arr):
        """
        This function takes a parameter array and unpacks it into a dictionary
        with the parameter names as keys.
        """
        p_dict = {}
        j = 0
        for name, size in cls.param_names.items():
            p_dict[name] = jnp.squeeze(p_arr[j : j + size])
            j += size
        return p_dict

    @classmethod
    @partial(jax.jit, static_argnums=(0,))
    def pack_pars(cls, p_dict):
        """
        This function takes a parameter dictionary and packs it into a JAX array
        where the order is set by the parameter name list defined on the class.
        """
        p_arrs = []
        for name in cls.param_names.keys():
            p_arrs.append(jnp.atleast_1d(p_dict[name]))
        return jnp.concatenate(p_arrs)

    @classmethod
    @partial(jax.jit, static_argnums=(0,))
    def ln_posterior(cls, pars, data, *args):
        return cls.ln_likelihood(pars, data, *args) + cls.ln_prior(pars)

    @classmethod
    @partial(jax.jit, static_argnums=(0,))
    def objective(cls, pars_arr, N, data, *args):
        """
        This function computes an objective function to be *minimized*: In our
        case, we will be doing Bayesian statistics, so this is generally the
        negative log-posterior-probability value such that if we minimize the
        objective function, we obtain the maximum a posteriori (MAP) parameter
        values. Here we also normalize the value by the number of data points so
        that scipy's minimizers don't run into overflow issues with the
        gradients.
        """
        pars = cls.unpack_pars(pars_arr)
        return -cls.ln_posterior(pars, data, *args) / N
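As an aside, JAX provides a general-purpose utility, `jax.flatten_util.ravel_pytree`, that does a similar packing and unpacking job for any pytree, including dictionaries of scalar and array-valued parameters. One difference from `pack_pars()` above is that the order of the flat array follows JAX's pytree ordering (sorted dictionary keys) rather than an explicitly declared order. Here is a minimal sketch with a hypothetical parameter dictionary:

```python
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree

# Hypothetical parameter dictionary mixing scalar and array-valued entries
pars = {
    "mean": jnp.array(0.03),
    "ln_std": jnp.array(-1.2),
    "ln_n0": jnp.linspace(0.0, 1.0, 11),
}

# `flat` is a 1D array (what an optimizer wants); `unravel` inverts the packing
flat, unravel = ravel_pytree(pars)
restored = unravel(flat)
```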

With our base `Model` class defined, we can now implement a subclass for the first model we are going to fit to our simulated data: a Gaussian! Because this is the same density model we used to generate the simulated data, we should recover the input parameters:

In [None]:
def ln_normal(x, mu, var):
    """
    Evaluate the log of the normal probability density function with mean `mu`
    and *variance* `var`
    """
    return -0.5 * (jnp.log(2 * np.pi * var) + (x - mu) ** 2 / var)
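Note that the third argument of `ln_normal()` is the *variance*, not the standard deviation; for example, the prior widths below are variances. Here is a quick check against SciPy's normal distribution, using a self-contained copy of the function above:

```python
import jax.numpy as jnp
import numpy as np
from scipy.stats import norm


def ln_normal(x, mu, var):
    """Log of the normal probability density, parametrized by the variance"""
    return -0.5 * (jnp.log(2 * np.pi * var) + (x - mu) ** 2 / var)


x = np.linspace(-3, 3, 7)
ours = ln_normal(x, 0.5, 4.0)  # variance 4 ...
ref = norm.logpdf(x, loc=0.5, scale=2.0)  # ... is a standard deviation of 2
```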

In [None]:
class GaussianModel(Model):
    param_names = {
        "ln_N0": 1,  # the log of the (expected) total number of sources
        "mean": 1,  # the mean of the Gaussian
        "ln_std": 1,  # the log standard deviation
    }

    @staticmethod
    @jax.jit
    def ln_density(x, ln_N0, mean, ln_std):
        """
        This function implements the log-density of our model. Here, this is the
        log-Gaussian.
        """
        var = jnp.exp(2 * ln_std)
        return ln_N0 + ln_normal(x, mean, var)

    @classmethod
    @partial(jax.jit, static_argnums=(0,))
    def ln_likelihood(cls, pars, data):
        """
        Implementation of the log-likelihood for an inhomogeneous Poisson
        process with underlying density (rate) function given by a Gaussian.
        Here the integral over our density function has a simple closed-form
        solution (see the math above).
        """
        dens = cls.ln_density(data["z"], **pars)
        return -jnp.exp(pars["ln_N0"]) + dens.sum()

    @classmethod
    @partial(jax.jit, static_argnums=(0,))
    def ln_prior(cls, pars):
        """
        A very light prior on the parameters. We use normal distributions for
        the priors, with relatively wide (large variance) values for most
        parameters so that the prior does not have much of an influence. Note
        that the third argument of `ln_normal()` is the variance.
        """
        lp = 0.0

        # A very wide, basically unconstrained Gaussian
        lp += ln_normal(pars["ln_N0"], 0, 100)

        # We expect the mean to be close to 0
        lp += ln_normal(pars["mean"], 0, 1)

        # We expect the standard deviation to be small:
        lp += ln_normal(pars["ln_std"], -2, 3)

        return lp

Let's pick some initial values for our parameters and plot the density function corresponding to our parameter choices:

In [None]:
plt.plot(xc, dens, drawstyle="steps-mid", marker="")

init_pars = {"ln_N0": np.log(N) + 0.5, "mean": 1e-1, "ln_std": np.log(0.3)}
init_p = GaussianModel.pack_pars(init_pars)

z_grid = np.linspace(z_bins.min(), z_bins.max(), 1024)
plt.plot(z_grid, np.exp(GaussianModel.ln_density(z_grid, **init_pars)), marker="")

plt.yscale("log")

plt.xlabel("$z$")
plt.ylabel("density $n(z)$");

Those initial parameter values don't look like a very good match to the observed density, but they are probably close enough that an optimizer will be able to find a better solution from there. For the optimizer, we will use SciPy's [L-BFGS-B](https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html) implementation, which is available through the general-purpose `scipy.optimize.minimize()` function. Here we use JAX's `value_and_grad()` to get a function handle, based on our objective function, that returns both the objective value and the gradient with respect to the input parameters. This is where the utility of JAX comes to light: it uses automatic differentiation to compute the gradients for us. We have to set `jac=True` in `minimize()` to tell SciPy to expect the gradient along with the objective function value:

In [None]:
import scipy.optimize as sco

In [None]:
res = sco.minimize(
    jax.value_and_grad(GaussianModel.objective),
    GaussianModel.pack_pars(init_pars),
    args=(len(data["z"]), data),
    jac=True,
    method="l-bfgs-b",
    options=dict(maxiter=1000),
    bounds=[(5, 20), (-2, 2), (-5, 5)],
)
res

It looks like that optimization completed successfully, and after only 10 function evaluations! Let's look at the density function implied by the optimized parameters:

In [None]:
opt_pars = GaussianModel.unpack_pars(res.x)

plt.plot(xc, dens, drawstyle="steps-mid", marker="")

z_grid = np.linspace(z_bins.min(), z_bins.max(), 1024)
plt.plot(
    z_grid,
    np.exp(GaussianModel.ln_density(z_grid, **opt_pars)),
    marker="",
    color="tab:green",
)

plt.yscale("log")

plt.xlabel("$z$")
plt.ylabel("density $n(z)$");

That looks like a pretty good fit! Let's move on to a more flexible example.

## Case 2: Spline model

We will now replace our density model $n(z)$ with a cubic spline representation of the function. We will fix the knot locations to a uniform grid of points in $z$, and the parameters of the model will then be the values of the (log-)density at the knots. JAX itself does not include a cubic spline interpolator; as far as I can tell, it currently only supports linear interpolation (e.g., `jnp.interp`). We will therefore use another package, `jax_cosmo`, which provides a JAX-aware version of SciPy's `InterpolatedUnivariateSpline`:
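For comparison, here is JAX's built-in linear interpolation in the same role, as a minimal sketch with made-up knots and values (not part of the model below). `jnp.interp` is differentiable with respect to the function values at the knots. The gradient of an interpolated value is just the pair of linear interpolation weights for the two neighboring knots. A cubic spline plays the same role but produces a smoother function.

```python
import jax
import jax.numpy as jnp
import numpy as np

xp = jnp.array([0.0, 1.0, 2.0, 3.0])  # fixed knot locations
fp = jnp.array([0.0, 2.0, 1.0, 3.0])  # function values at the knots (the parameters)


def f_at(x, fp):
    # Piecewise-linear interpolation through (xp, fp), evaluated at x
    return jnp.interp(x, xp, fp)


value = f_at(1.25, fp)
# Gradient with respect to the knot values (argnums=1), not the position x
grad_fp = jax.grad(f_at, argnums=1)(1.25, fp)
```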

In [None]:
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline

Awesome, we now have the main tool we need to implement the spline model. However, there is one more mathematical / numerical hurdle: we need to compute the integral of our density model, which appears in the first term of the Poisson process likelihood:
$$
\exp{\left[-\int {\rm d}z \, n(z)\right]}
$$

For generic cubic splines, this integral over all $z$ is not guaranteed to be finite. We therefore have to pick a domain over which to do this integral, which slightly changes the meaning of the normalization: the integral of the density now gives the number of sources *in the domain we choose*. In practice, if we pick a domain that is large enough and the density function falls off quickly (as it does here), there won't be any practical difference. (But note: if you have a rigid selection region, or if you pick a domain that truncates the data, you have to be more careful than I am here!) Since our data end around $z\sim \pm 1.5$, we will pick a window of $(-3, 3)$.

We now need a way of computing the integral of our spline model over this domain. If our parameters were the values of the density itself at the knots, we could use the `InterpolatedUnivariateSpline.integral()` method directly. However, our parameters are the values of the *log*-density, so the integral is not as straightforward. Here, I've implemented a version of [Simpson's rule](https://en.wikipedia.org/wiki/Simpson%27s_rule) that takes in the log-function values and returns the log of the integral. This is more numerically stable than integration tools that require first exponentiating the log-density and then taking the log of the estimated integral:

In [None]:
def ln_simpson(ln_y, x):
    """
    Evaluate the log of the definite integral of a function evaluated on a
    grid using Simpson's rule
    """

    dx = jnp.diff(x)[0]
    num_points = len(x)
    if num_points // 2 == num_points / 2:
        raise ValueError("Because of laziness, the input size must be odd")

    weights_first = jnp.asarray([1.0])
    weights_mid = jnp.tile(jnp.asarray([4.0, 2.0]), [(num_points - 3) // 2])
    weights_last = jnp.asarray([4.0, 1.0])
    weights = jnp.concatenate([weights_first, weights_mid, weights_last], axis=0)

    return jax.scipy.special.logsumexp(ln_y + jnp.log(weights), axis=-1) + jnp.log(
        dx / 3
    )
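To see why working in log space matters, here is a standalone NumPy/SciPy sketch of the same idea as `ln_simpson()`. It uses an extreme, made-up integrand whose log is about 800. Exponentiating first overflows float64, while adding the log of the Simpson's rule weights and combining the terms with `logsumexp` stays finite and matches the exact answer.

```python
import numpy as np
from scipy.special import erf, logsumexp

x = np.linspace(-3, 3, 101)  # odd number of uniformly spaced grid points
dx = x[1] - x[0]
ln_y = 800.0 - 0.5 * x**2  # log of a huge, Gaussian-shaped function

# Simpson's rule weights: 1, 4, 2, 4, ..., 2, 4, 1
w = np.ones_like(x)
w[1:-1:2] = 4.0
w[2:-1:2] = 2.0

# Naive: exponentiate first, then integrate. exp(~800) overflows to inf
with np.errstate(over="ignore"):
    naive = np.log(np.sum(w * np.exp(ln_y)) * dx / 3)

# Log-space version: stays finite
ln_integral = logsumexp(ln_y + np.log(w)) + np.log(dx / 3)

# Exact answer: 800 + ln(sqrt(2 pi) * erf(3 / sqrt(2)))
exact = 800.0 + np.log(np.sqrt(2 * np.pi) * erf(3 / np.sqrt(2)))
```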

With a decision about our integration window and a jax-ified function to compute the value of the log-integral over our spline density function, we can now set up a spline model to fit our toy data:

In [None]:
class GaussianSplineModel(Model):
    knots = jnp.linspace(-3, 3, 11)  # locations of the spline knots
    param_names = {
        "ln_n0": 11,  # the value of the log-density at the knots
    }
    window = (-3, 3)  # integration window for numerical integral of density
    n_integral_pts = 1025  # the number of integration grid points to use

    @staticmethod
    @jax.jit
    def ln_density(x, ln_n0, knots):
        """
        The log-density is just an evaluation of the spline at the input
        """
        ln_dens_spl = InterpolatedUnivariateSpline(knots, ln_n0, k=3)
        return ln_dens_spl(x)

    @classmethod
    @partial(jax.jit, static_argnums=(0,))
    def ln_likelihood(cls, pars, data):
        """
        Implementation of the log-likelihood for an inhomogeneous Poisson
        process with underlying density (rate) function given by a spline
        """
        ln_dens = cls.ln_density(data["z"], pars["ln_n0"], cls.knots)

        # As mentioned above, to compute the integral over the density, we do
        # the integral numerically using Simpson's rule. For my implementation,
        # we must pass in a grid of points and the log of the function to
        # integrate evaluated at these grid points. The number of grid points is
        # hard-set here, but this should be tuned to meet some accuracy criteria
        V_grid = jnp.linspace(*cls.window, cls.n_integral_pts)
        ln_V = ln_simpson(cls.ln_density(V_grid, pars["ln_n0"], cls.knots), V_grid)
        return -jnp.exp(ln_V) + ln_dens.sum()

    @classmethod
    @partial(jax.jit, static_argnums=(0,))
    def ln_prior(cls, pars):
        lp = 0.0
        for name, p in pars.items():
            lp += ln_normal(p, 0, 100).sum()
        return lp

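With our spline model defined, we now need initial values for the parameters. A simple choice is to interpolate the log of the binned density estimate from above at the locations of the knots: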

In [None]:
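# Interpolate the log of the binned density estimate to the knot locations to
# get initial parameter values. The small offset avoids log(0) in empty bins,
# and ext="const" holds the boundary value fixed for knots outside the range of
# the bin centers (instead of extrapolating the cubic):
knots_ln_dens = sci.InterpolatedUnivariateSpline(
    xc, np.log(dens + 1e-8), k=3, ext="const"
)(GaussianSplineModel.knots)

plt.plot(xc, dens, drawstyle="steps-mid", marker="")
plt.scatter(GaussianSplineModel.knots, np.exp(knots_ln_dens))
plt.yscale("log")
plt.ylim(1e0, 3e5)

In [None]:
# Start from the interpolated values, slightly perturbed:
init_pars = {
    "ln_n0": knots_ln_dens
    + rng.uniform(0, 0.5, size=len(GaussianSplineModel.knots))
}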
res = sco.minimize(
    jax.value_and_grad(GaussianSplineModel.objective),
    GaussianSplineModel.pack_pars(init_pars),
    args=(len(data["z"]), data),
    jac=True,
    method="l-bfgs-b",
    options=dict(maxiter=1000, maxls=1000),
)
res

In [None]:
plt.plot(xc, dens, drawstyle="steps-mid", marker="")

opt_pars = GaussianSplineModel.unpack_pars(res.x)

_grid = np.linspace(-2, 2, 1024)
plt.plot(
    _grid,
    np.exp(
        GaussianSplineModel.ln_density(
            _grid, opt_pars["ln_n0"], GaussianSplineModel.knots
        )
    ),
    marker="",
)
plt.yscale("log")
plt.ylim(1e0, 3e5)

TODO: sampling??

In [None]:
@jax.jit
def ln_gaussian_density(params, x):
    var = jnp.exp(2 * params["ln_std"])
    return params["ln_n0"] - 0.5 * (
        jnp.log(2 * np.pi) + 2 * params["ln_std"] + (x - params["x0"]) ** 2 / var
    )


@jax.jit
def ln_gaussian_likelihood(params, x):
    dens = ln_gaussian_density(params, x)
    V = jnp.exp(params["ln_n0"])
    return -V + dens.sum()


@jax.jit
def gaussian_objective(p, data):
    params = {"ln_n0": p[0], "x0": p[1], "ln_std": p[2]}
    return -ln_gaussian_likelihood(params, data["z"]) / len(data["z"])

In [None]:
plt.plot(xc, dens, drawstyle="steps-mid", marker="")

_grid = np.linspace(-2, 2, 128)
init_params = {"ln_n0": np.log(N) + 0.5, "x0": 1e-1, "ln_std": np.log(0.2)}
plt.plot(_grid, np.exp(ln_gaussian_density(init_params, _grid)), marker="")

plt.yscale("log")

In [None]:
from jaxopt import ScipyMinimize

In [None]:
solver = ScipyMinimize(
    method="bfgs", fun=gaussian_objective, options=dict(maxiter=1000, disp=True)
)
init_p = list(init_params.values())
res = solver.run(init_p, data=data)

In [None]:
plt.plot(xc, dens, drawstyle="steps-mid", marker="")

_grid = np.linspace(-2, 2, 128)
init_params = {"ln_n0": res.params[0], "x0": res.params[1], "ln_std": res.params[2]}
plt.plot(_grid, np.exp(ln_gaussian_density(init_params, _grid)), marker="")

plt.yscale("log")

In [None]:
res

In [None]:
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline

In [None]:
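def ln_simpson(ln_y, x, dtype=None):
    """
    Evaluate the log of the definite integral of a function evaluated on a
    uniform grid using Simpson's 1/3 rule, given the log-function values
    """

    dx = jnp.diff(x)[0]
    num_points = len(x)
    if num_points % 2 == 0:
        raise ValueError("This implementation requires an odd number of grid points")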

    weights_first = jnp.asarray([1.0], dtype=dtype)
    weights_mid = jnp.tile(
        jnp.asarray([4.0, 2.0], dtype=dtype), [(num_points - 3) // 2]
    )
    weights_last = jnp.asarray([4.0, 1.0], dtype=dtype)
    weights = jnp.concatenate([weights_first, weights_mid, weights_last], axis=0)

    return jax.scipy.special.logsumexp(ln_y + jnp.log(weights), axis=-1) + jnp.log(
        dx / 3
    )

In [None]:
@jax.jit
def ln_flexible_gaussian_likelihood(p, x, knots, window):
    ln_dens_spl = InterpolatedUnivariateSpline(knots, p, k=3)
    ln_dens = ln_dens_spl(x)

    V_grid = jnp.linspace(*window, 1025)  # TODO: MAGIC NUMBER
    ln_V = ln_simpson(ln_dens_spl(V_grid), V_grid)

    return -jnp.exp(ln_V) + ln_dens.sum()


@jax.jit
def flexible_gaussian_objective(p, data, *args, **kw):
    return -ln_flexible_gaussian_likelihood(p, data["z"], *args, **kw) / len(data["z"])

In [None]:
from scipy.interpolate import interp1d

In [None]:
knots = np.linspace(xc.min(), xc.max(), 11)
knots_ln_dens = interp1d(xc, np.log(dens))(knots)
knots_ln_dens[~np.isfinite(knots_ln_dens)] = np.log(dens[dens != 0].min())

plt.plot(xc, dens, drawstyle="steps-mid", marker="")
plt.scatter(knots, np.exp(knots_ln_dens))
plt.yscale("log")
plt.ylim(1e0, 3e5)

In [None]:
solver = ScipyMinimize(
    method="bfgs",
    fun=flexible_gaussian_objective,
    options=dict(maxiter=1000, disp=True),
)
init_p = knots_ln_dens
res = solver.run(init_p, data=data, knots=knots, window=(-2, 2))

In [None]:
plt.plot(xc, dens, drawstyle="steps-mid", marker="")

_grid = np.linspace(-2, 2, 128)
ln_dens_spl = InterpolatedUnivariateSpline(knots, res.params, k=3)
plt.plot(_grid, np.exp(ln_dens_spl(_grid)), marker="")

plt.yscale("log")

TODO: show sampling as well

TODO: try on real data??

In [None]:
import healpy as hp

In [None]:
import astropy.coordinates as coord
from pyia import GaiaData

g = GaiaData("/Users/apricewhelan/data/GaiaDR3/dr3-rv-plx0.1.fits")

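# Select luminous giant stars with a box in color and absolute magnitude
# (roughly around the red clump), with reasonably precise parallaxes:
MG = g.phot_g_mean_mag - g.distmod  # absolute G-band magnitude
g = g[
    (g.bp_rp.value > 0.5)
    & (g.bp_rp.value < 2)
    & (MG.value > -1)
    & (MG.value < 1)
    & (g.parallax_over_error > 4)
]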
len(g)

In [None]:
c = g.get_skycoord()

In [None]:
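galcen_frame = coord.Galactocentric(galcen_distance=8.275 * u.kpc)
galcen = c.transform_to(galcen_frame)

# Approximate distances from the Sun, using the fact that the Sun sits near
# x = -galcen_distance in the Galactocentric frame: R_xy is the distance
# projected onto the x-y plane, and R is the 3D distance
R_xy = np.sqrt((galcen.x + galcen_frame.galcen_distance) ** 2 + galcen.y**2)
R = np.sqrt(
    (galcen.x + galcen_frame.galcen_distance) ** 2 + galcen.y**2 + galcen.z**2
)

# Keep stars within 0.5 kpc of the Sun (in 3D, or projected onto the plane)
R_lim = 0.5 * u.kpc
R_mask = R < R_lim
R_xy_mask = R_xy < R_lim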

In [None]:
nside = 32
npix = hp.nside2npix(nside)
indices = hp.ang2pix(
    nside, c.galactic.l.degree[R_mask], c.galactic.b.degree[R_mask], lonlat=True
)

hpxmap = np.zeros(npix, dtype=int)
np.add.at(hpxmap, indices, np.ones_like(indices))

In [None]:
hp.mollview(hpxmap, norm=mpl.colors.LogNorm(), min=0.1, max=1e3)

In [None]:
z = galcen.z[R_xy_mask]
data = dict(z=z.to_value(u.kpc))

In [None]:
z_bins = np.linspace(-5, 5, 201)
plt.hist(z.to_value(u.kpc), bins=z_bins)
plt.yscale("log")

In [None]:
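# Zoom in near the midplane and overlay the mirrored distribution (-z) to look
# for asymmetries above and below the plane:
bins = np.linspace(-1, 1, 81)
plt.hist(z.to_value(u.kpc), bins=bins)
plt.hist(-z.to_value(u.kpc), bins=bins, histtype="step")
plt.yscale("log")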

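As a parametric model for the vertical density, we can use a mixture of two ${\rm sech}^2$ profiles with scale heights $h_{z,1}$ and $h_{z,2}$, mixing fraction $\alpha$, and midplane density $n_0$. Its integral over all $z$ has a closed form:
$$
\begin{align}
n(z) &= n_0 \, \left[
    \alpha \, {\rm sech}^2\left(\frac{|z|}{2\,h_{z,1}}\right)
    + (1-\alpha)\,{\rm sech}^2\left(\frac{|z|}{2\,h_{z,2}}\right)
\right]\\
\int {\rm d}z\,n(z) &= 4\,n_0 \, \left[\alpha\,h_{z,1} + (1-\alpha)\,h_{z,2}\right]
\end{align}
$$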
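The closed-form integral follows from $\int {\rm d}z\,{\rm sech}^2\left(\frac{z}{2h}\right) = 2h\left[\tanh\left(\frac{z}{2h}\right)\right]_{-\infty}^{\infty} = 4h$ for each component. Here is a quick numerical check of the mixture integral with `scipy.integrate.quad`. The parameter values are illustrative, chosen so that the closed form equals 1000.

```python
import numpy as np
from scipy.integrate import quad


def two_sech2(z, n0, alpha, hz1, hz2):
    # n(z) = n0 * [alpha sech^2(|z| / 2 hz1) + (1 - alpha) sech^2(|z| / 2 hz2)]
    return n0 * (
        alpha / np.cosh(np.abs(z) / (2 * hz1)) ** 2
        + (1 - alpha) / np.cosh(np.abs(z) / (2 * hz2)) ** 2
    )


n0, alpha, hz1, hz2 = 1000.0, 0.9, 0.2, 0.7  # illustrative values
# The tails beyond |z| = 30 are negligible for these scale heights
numeric, abserr = quad(two_sech2, -30, 30, args=(n0, alpha, hz1, hz2))
closed_form = 4 * n0 * (alpha * hz1 + (1 - alpha) * hz2)
```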

In [None]:
@jax.jit
def ln_two_sech2(z, ln_n0, alpha, ln_hz1, ln_hz2):
    """
    Log-density of a mixture of two sech^2 profiles (see the equation above),
    combined in log space with logaddexp for numerical stability
    """
    hz1 = jnp.exp(ln_hz1)
    hz2 = jnp.exp(ln_hz2)
    lnterm1 = jnp.log(alpha) - 2 * jnp.log(jnp.cosh(jnp.abs(z) / (2 * hz1)))
    lnterm2 = jnp.log(1 - alpha) - 2 * jnp.log(jnp.cosh(jnp.abs(z) / (2 * hz2)))
    return ln_n0 + jnp.logaddexp(lnterm1, lnterm2)

In [None]:
H, xe = np.histogram(data["z"], bins=z_bins)
xc = 0.5 * (xe[:-1] + xe[1:])
dens = H / (xe[1] - xe[0])
plt.plot(xc, dens, drawstyle="steps-mid", marker="")

z_grid = np.linspace(-5, 5, 1024)
close_p = np.array([np.log(len(data["z"])), 0.95, np.log(0.2), np.log(0.7)])
plt.plot(z_grid, np.exp(ln_two_sech2(z_grid, *close_p)), marker="")

plt.yscale("log")

In [None]:
@jax.jit
def ln_likelihood_two_sech2(params, data):
    model_ln_dens = ln_two_sech2(
        data["z"], params["ln_n0"], params["alpha"], params["ln_hz1"], params["ln_hz2"]
    )

    ln_V = (
        params["ln_n0"]
        + jnp.log(4)
        + jnp.logaddexp(
            jnp.log(params["alpha"]) + params["ln_hz1"],
            jnp.log(1 - params["alpha"]) + params["ln_hz2"],
        )
    )
    return -jnp.exp(ln_V) + model_ln_dens.sum()


@jax.jit
def two_sech2_objective(params, data):
    p = {
        "ln_n0": params[0],
        "alpha": params[1],
        "ln_hz1": params[2],
        "ln_hz2": params[3],
    }
    return -ln_likelihood_two_sech2(p, data) / len(data["z"])

In [None]:
print(close_p)
for i in range(len(close_p)):
    pp = np.array(close_p, copy=True)
    par_grid = close_p[i] * np.linspace(0.5, 2, 128)
    objs = np.zeros_like(par_grid)
    for j, val in enumerate(par_grid):
        pp[i] = val
        objs[j] = -two_sech2_objective(pp, data)

    plt.figure()
    plt.plot(par_grid, np.exp(objs - np.nanmax(objs)))
    plt.axvline(close_p[i], marker="")

In [None]:
from jaxopt import ScipyBoundedMinimize, ScipyMinimize

In [None]:
print(two_sech2_objective(close_p, data))

In [None]:
init_params

In [None]:
close_p

In [None]:
solver = ScipyBoundedMinimize(
    method="l-bfgs-b",
    fun=two_sech2_objective,
)

# Start slightly offset from the hand-tuned guess
init_params = close_p + 1e-2

# Bounds are (lower, upper) for (ln_n0, alpha, ln_hz1, ln_hz2)
res = solver.run(init_params, bounds=([0, 0, -5, -5], [12, 1, 5, 5]), data=data)
res.state.iter_num, res.state.success, res.state.status

In [None]:
print(init_params)
print("init", two_sech2_objective(init_params, data))
print("close_p", two_sech2_objective(close_p, data))
print("opt", two_sech2_objective(res.params, data))

In [None]:
plt.plot(xc, dens, drawstyle="steps-mid", marker="")

z_grid = np.linspace(-5, 5, 1024)
plt.plot(z_grid, np.exp(ln_two_sech2(z_grid, *res.params)), marker="")

plt.yscale("log")

In [None]:
from scipy.optimize import minimize, fmin_l_bfgs_b

In [None]:
res = fmin_l_bfgs_b(
    two_sech2_objective,
    fprime=jax.grad(two_sech2_objective),
    approx_grad=False,
    x0=init_params,
    args=(data,),
    bounds=[(0, 12), (0, 1), (-5, 5), (-5, 5)],
    # factr=0 effectively disables the relative function-reduction stopping test
    factr=0,
    maxls=100000,
    maxfun=1500000,
    maxiter=1500000,
    m=128,
    pgtol=1e-10,
)
res

In [None]:
res = minimize(
    two_sech2_objective,
    method="L-BFGS-B",  # also SciPy's default when bounds are given
    jac=jax.grad(two_sech2_objective),
    x0=init_params,
    args=(data,),
    bounds=[(0, 12), (0, 1), (-5, 5), (-5, 5)],
    options=dict(ftol=1e-14),
)
res
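
Both SciPy calls above pass the objective and `jax.grad(...)` separately, so the value and the gradient are computed in separate calls. SciPy also accepts a single function that returns `(value, gradient)` when `jac=True`; in JAX that pair would come from `jax.value_and_grad`. The following is a NumPy-only sketch of the pattern on a toy Gaussian fit, with an analytic gradient standing in for the JAX one. The function name and the toy data are illustrative and are not part of the notebook. Returning plain `float64` NumPy values is also a cheap way to sidestep dtype surprises when JAX runs in its default 32-bit mode.

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
x = rng.normal(2.0, 0.5, size=2000)


def nll_and_grad(p, x):
    # Mean Gaussian negative log-likelihood in (mu, ln sigma), returned together with
    # its analytic gradient (the role jax.value_and_grad would play for a JAX objective)
    mu, ln_s = p
    s2 = np.exp(2 * ln_s)
    r = x - mu
    nll = np.mean(0.5 * np.log(2 * np.pi) + ln_s + 0.5 * r**2 / s2)
    d_mu = -np.mean(r) / s2
    d_ln_s = 1.0 - np.mean(r**2) / s2
    return np.float64(nll), np.asarray([d_mu, d_ln_s], dtype=np.float64)


res_toy = minimize(
    nll_and_grad,
    x0=np.array([0.0, 0.0]),
    args=(x,),
    jac=True,  # fun returns (value, gradient)
    method="L-BFGS-B",
    bounds=[(-10, 10), (-5, 5)],
)
mu_hat, s_hat = res_toy.x[0], np.exp(res_toy.x[1])
print(res_toy.success, mu_hat, s_hat)
```

The maximum-likelihood answers here are the sample mean and the (ddof=0) sample standard deviation, which makes the fit easy to sanity-check.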

In [None]:
res.x

In [None]:
(init_params, data)

In [None]:
from scipy.stats import binned_statistic

import jax
import jax.numpy as jnp
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline
from jaxopt import ScipyMinimize
import blackjax

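Make some fake data: $\phi_1$ values drawn from a Gaussian, and $\phi_2$ values scattered around a sinusoidal track in $\phi_1$: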

In [None]:
N = 1024
window = (-20, 20)
rng = np.random.default_rng(42)
phi1_data = rng.normal(5, 4, size=N)
# The phi1 window is also used later to normalize the density along phi1
assert np.all((phi1_data < window[1]) & (phi1_data > window[0]))

# phi2 follows a sinusoidal track in phi1 with a constant scatter of 0.8
phi2_data = rng.normal(1.5 * np.cos(2 * np.pi * phi1_data / 5), 0.8)
plt.figure(figsize=(10, 3))
plt.scatter(phi1_data, phi2_data)

In [None]:
@jax.jit
def jln_normal(x, mu, var):
    return -0.5 * (jnp.log(2 * np.pi * var) + (x - mu) ** 2 / var)


@jax.jit
def phi2_ln_likelihood(
    phi2_mean_knots,
    phi2_mean_vals,
    phi2_std_knots,
    phi2_std_vals,
    phi1_eval,
    phi2_eval,
):

    phi2_mean = InterpolatedUnivariateSpline(phi2_mean_knots, phi2_mean_vals, k=3)

    phi2_std = InterpolatedUnivariateSpline(phi2_std_knots, phi2_std_vals, k=3)

    phi2_mean_model = phi2_mean(phi1_eval)
    phi2_var_model = phi2_std(phi1_eval) ** 2

    # phi2_mean_model = jnp.interp(phi1_eval, phi2_mean_knots, phi2_mean_vals)
    # phi2_var_model = jnp.interp(phi1_eval, phi2_std_knots, phi2_std_vals) ** 2

    return jln_normal(phi2_eval, phi2_mean_model, phi2_var_model)
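
`phi2_ln_likelihood` models $\phi_2$ at fixed $\phi_1$ as a Gaussian whose mean and standard deviation are cubic splines through values at the knots. A NumPy/SciPy sketch of the same likelihood can be handy for cross-checking the JAX version. The sketch uses `scipy.interpolate.CubicSpline` with `bc_type="not-a-knot"`. We assume, but have not verified, that this matches the boundary condition of the `jax_cosmo` spline, so values near the ends of the knot range may differ slightly. `phi2_ln_likelihood_np` is a hypothetical name.

```python
import numpy as np
from scipy.interpolate import CubicSpline
from scipy.stats import norm


def phi2_ln_likelihood_np(mean_knots, mean_vals, std_knots, std_vals, phi1, phi2):
    # Cubic interpolating splines through the knot values give the track and width
    mean_spl = CubicSpline(mean_knots, mean_vals, bc_type="not-a-knot")
    std_spl = CubicSpline(std_knots, std_vals, bc_type="not-a-knot")
    mu = mean_spl(phi1)
    var = std_spl(phi1) ** 2
    # Same expression as jln_normal in the notebook (third argument is a variance)
    return -0.5 * (np.log(2 * np.pi * var) + (phi2 - mu) ** 2 / var)


knots = np.linspace(-5, 15, 25)
mean_vals = 1.5 * np.cos(2 * np.pi * knots / 5)
std_vals = np.full_like(knots, 0.8)

phi1 = np.array([0.0, 2.5, 7.3])
phi2 = np.array([1.4, -1.2, 0.1])
ll = phi2_ln_likelihood_np(knots, mean_vals, knots, std_vals, phi1, phi2)
print(ll)
```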

In [None]:
phi2_knots = np.linspace(phi1_data.min(), phi1_data.max(), 25)

dp2 = phi2_knots[1] - phi2_knots[0]
_bins = np.linspace(
    phi2_knots[0] - dp2 / 2, phi2_knots[-1] + dp2 / 2, len(phi2_knots) + 1
)
stat = binned_statistic(phi1_data, phi2_data, bins=_bins, statistic=np.nanmean)
phi2_vals = stat.statistic
phi2_vals[np.isnan(phi2_vals)] = 0.0

In [None]:
plt.figure(figsize=(10, 3))
plt.scatter(phi1_data, phi2_data)
plt.scatter(phi2_knots, phi2_vals, color="tab:red")

In [None]:
phi2_ln_likelihood(
    phi2_knots,
    phi2_vals,
    phi2_knots,
    np.full_like(phi2_knots, 0.4),
    phi1_data,
    phi2_data,
)

In [None]:
xx, yy = np.meshgrid(
    np.linspace(phi1_data.min(), phi1_data.max(), 256), np.linspace(-4, 4, 128)
)
zz = phi2_ln_likelihood(
    phi2_knots, phi2_vals, phi2_knots, np.full_like(phi2_knots, 0.4), xx, yy
)
plt.figure(figsize=(10, 3))
plt.pcolormesh(xx, yy, np.exp(zz), shading="auto")
plt.scatter(phi1_data, phi2_data, color="tab:blue")

In [None]:
@jax.jit
def ln_prob(pars, data, phi2_knots):
    n_phi2 = len(phi2_knots)
    phi2_means = pars[:n_phi2]
    ln_phi2_stds = pars[n_phi2 : 2 * n_phi2]
    phi2_stds = jnp.exp(ln_phi2_stds)

    ll = phi2_ln_likelihood(
        phi2_knots, phi2_means, phi2_knots, phi2_stds, data["phi1"], data["phi2"]
    ).sum()

    lp = jln_normal(phi2_means, 0, 2.0).sum()
    lp += jln_normal(ln_phi2_stds, -1, 1.0).sum()

    return ll + lp


@jax.jit
def objective(pars, data, phi2_knots):
    return -ln_prob(pars, data, phi2_knots)
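
Written out, `ln_prob` is the log posterior up to an additive constant. Here $\mu(\phi_1)$ and $\sigma(\phi_1)$ are the cubic splines through the knot values $\mu_k$ and $\sigma_k = e^{\ln\sigma_k}$, and $K$ is the number of knots. The third argument of `jln_normal` is a variance, so the prior on the knot means has variance 2 (standard deviation $\sqrt{2}$):

$$
\ln p\left(\{\mu_k, \ln\sigma_k\} \mid \{\phi_{1,n}, \phi_{2,n}\}\right) = \sum_{n=1}^{N} \ln \mathcal{N}\left(\phi_{2,n} \mid \mu(\phi_{1,n}),\, \sigma^2(\phi_{1,n})\right) + \sum_{k=1}^{K} \ln \mathcal{N}\left(\mu_k \mid 0,\, 2\right) + \sum_{k=1}^{K} \ln \mathcal{N}\left(\ln\sigma_k \mid -1,\, 1\right) + \text{const.}
$$

where $\ln \mathcal{N}(x \mid m, v) = -\tfrac{1}{2}\left[\ln(2\pi v) + (x - m)^2 / v\right]$, exactly as in `jln_normal`.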

In [None]:
init_params = np.concatenate((phi2_vals, np.log(np.full_like(phi2_knots, 0.4))))
data = {"phi1": phi1_data, "phi2": phi2_data}

In [None]:
solver = ScipyMinimize(method="l-bfgs-b", fun=objective)
res = solver.run(init_params, data=data, phi2_knots=phi2_knots)
res.state.iter_num

In [None]:
res.params

In [None]:
plt.figure(figsize=(10, 3))
plt.plot(phi2_knots, res.params[: len(phi2_knots)])
plt.plot(phi2_knots, res.params[len(phi2_knots) :], color="tab:red")
plt.scatter(phi1_data, phi2_data, color="tab:blue", s=2)

In [None]:
xx, yy = np.meshgrid(
    np.linspace(phi1_data.min(), phi1_data.max(), 256), np.linspace(-4, 4, 128)
)
zz = np.exp(
    phi2_ln_likelihood(
        phi2_knots,
        res.params[: len(phi2_knots)],
        phi2_knots,
        np.exp(res.params[len(phi2_knots) :]),
        xx,
        yy,
    )
)
plt.figure(figsize=(10, 3))
plt.pcolormesh(
    xx,
    yy,
    zz,
    shading="auto",
    vmin=np.median(zz[(xx > 0) & (xx < 5)]),
    vmax=np.max(zz[(xx > 0) & (xx < 5)]),
)
plt.scatter(phi1_data, phi2_data, color="tab:blue", s=2)

The sampling loop below is copied from the `blackjax` "getting started" guide:

In [None]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states
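
`inference_loop` relies on `jax.lax.scan`, which threads a carry (here the sampler state) through a step function and stacks the per-step outputs. The sketch below is a NumPy rendition of those semantics, in the spirit of the pure-Python reference loop in the JAX documentation. Both `scan_reference` and `inference_loop_reference` are hypothetical helpers. The toy random-walk "kernel" stands in for the NUTS kernel and returns only the new state, whereas the `blackjax` kernel returns `(state, info)`.

```python
import numpy as np


def scan_reference(f, init, xs):
    # Semantics of jax.lax.scan for array outputs: the carry is threaded through f,
    # and the second output of f is stacked along a new leading axis
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def inference_loop_reference(seed, kernel, initial_state, num_samples):
    # One independent seed per step, mirroring jax.random.split(rng_key, num_samples)
    seeds = np.random.SeedSequence(seed).spawn(num_samples)

    def one_step(state, seed_seq):
        state = kernel(np.random.default_rng(seed_seq), state)
        return state, state

    _, states = scan_reference(one_step, initial_state, seeds)
    return states


def random_walk_kernel(rng, state):
    # Toy stand-in for an MCMC kernel: one Gaussian random-walk step
    return state + rng.normal(size=state.shape)


states_toy = inference_loop_reference(42, random_walk_kernel, np.zeros(3), 100)
print(states_toy.shape)  # (100, 3)
```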

In [None]:
rng_key = jax.random.PRNGKey(42)
# Use independent keys for warmup and sampling rather than reusing one key
warmup_key, sample_key = jax.random.split(rng_key)

In [None]:
fn = jax.tree_util.Partial(ln_prob, data=data, phi2_knots=phi2_knots)
warmup = blackjax.window_adaptation(
    blackjax.nuts,
    fn,
    1000,
)

state, kernel, _ = warmup.run(
    warmup_key,
    res.params,
)

In [None]:
states = inference_loop(sample_key, kernel, state, 1_000)

In [None]:
states.position.shape

In [None]:
plt.figure(figsize=(10, 3))
for i in np.random.choice(states.position.shape[0], size=32):
    params = states.position[i]
    plt.plot(phi2_knots, params[: len(phi2_knots)])

In [None]:
plt.figure(figsize=(10, 3))
for i in np.random.choice(states.position.shape[0], size=32):
    params = states.position[i]
    plt.plot(phi2_knots, params[len(phi2_knots) :])

---

### Infer the log-density along $\phi_1$ as well

The density of stars along $\phi_1$ is modeled as an inhomogeneous Poisson point process (see, e.g., these notes: http://people.ee.duke.edu/~lcarin/PoissonProcess.pdf).
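
Let $\lambda(\phi_1)$ be the rate function, i.e. the expected number of stars per unit $\phi_1$, over the window $W$. Up to a constant that does not depend on $\lambda$, the log-likelihood of the observed positions $\phi_{1,n}$ is:

$$
\ln \mathcal{L} = -\int_W \lambda(\phi_1)\, \mathrm{d}\phi_1 + \sum_{n=1}^{N} \ln \lambda(\phi_{1,n})
$$

This has the same form as `ln_likelihood_two_sech2` above (`-jnp.exp(ln_V) + model_ln_dens.sum()`). Here $\ln\lambda$ is a cubic spline through knot values. The integral $V = \int_W \lambda\, \mathrm{d}\phi_1$ is evaluated numerically with Simpson's rule in log space.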

In [None]:
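def ln_simpson(ln_y, x, dtype=None):
    """Log of the definite integral of exp(ln_y) over x, using Simpson's 1/3 rule.

    Works in log space (via logsumexp) and assumes x is uniformly spaced with an
    odd number of points.
    """

    dx = jnp.diff(x)[0]
    num_points = len(x)
    if num_points % 2 == 0:
        raise ValueError("Simpson's 1/3 rule requires an odd number of grid points")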

    weights_first = jnp.asarray([1.0], dtype=dtype)
    weights_mid = jnp.tile(
        jnp.asarray([4.0, 2.0], dtype=dtype), [(num_points - 3) // 2]
    )
    weights_last = jnp.asarray([4.0, 1.0], dtype=dtype)
    weights = jnp.concatenate([weights_first, weights_mid, weights_last], axis=0)

    return jax.scipy.special.logsumexp(ln_y + jnp.log(weights), axis=-1) + jnp.log(
        dx / 3
    )
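
A NumPy version of `ln_simpson` makes it easy to check the log-space Simpson weights $(1, 4, 2, 4, \ldots, 2, 4, 1)$ against a Gaussian integral with a known value. It also shows why working in log space matters: $e^{1000}$ overflows a float64. `ln_simpson_np` is a hypothetical name for this sketch, which again assumes a uniform grid.

```python
import numpy as np
from scipy.special import logsumexp


def ln_simpson_np(ln_y, x):
    # Log of the composite Simpson's 1/3-rule integral of exp(ln_y), via logsumexp
    n = len(x)
    if n % 2 == 0:
        raise ValueError("Simpson's 1/3 rule requires an odd number of grid points")
    dx = x[1] - x[0]  # assumes a uniform grid
    w = np.ones(n)
    w[1:-1:2] = 4.0  # odd interior points
    w[2:-1:2] = 2.0  # even interior points
    return logsumexp(ln_y + np.log(w)) + np.log(dx / 3)


x = np.linspace(-20, 20, 1025)
# ln of the integral of exp(1000 - x^2 / 2); exponentiating directly would overflow
print(ln_simpson_np(1000.0 - 0.5 * x**2, x))
```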

In [None]:
@jax.jit
def phi1_ln_likelihood(phi1_knots, ln_phi1_rate, phi1_eval):
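    ln_phi1_rate_spl = InterpolatedUnivariateSpline(phi1_knots, ln_phi1_rate, k=3)

    # The spline represents ln(rate), so the expected number of stars in the window,
    # V = integral of exp(ln rate) dphi1, is computed numerically in log space
    # (1025 grid points, since Simpson's 1/3 rule needs an odd number of points)
    _grid = jnp.linspace(*window, 1025)
    lnV = ln_simpson(ln_phi1_rate_spl(_grid), _grid)

    # One term per star: summing over the N stars gives -V + sum_n ln(rate(phi1_n))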
    return -jnp.exp(lnV) / len(phi1_eval) + ln_phi1_rate_spl(phi1_eval)
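
`phi1_ln_likelihood` spreads $-V$ evenly over the $N$ stars, so calling `.sum()` over stars recovers the Poisson log-likelihood. The NumPy sanity check below uses a constant rate, a simplification of the notebook's spline. With $\lambda$ constant on a window of width $W$, $\ln\mathcal{L} = -\lambda W + N \ln\lambda$, which is maximized at $\hat\lambda = N / W$. `per_point_terms` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(1)
window = (-20.0, 20.0)
W = window[1] - window[0]
phi1 = rng.uniform(*window, size=500)
N = len(phi1)


def per_point_terms(ln_rate, phi1):
    # Constant-rate analogue of phi1_ln_likelihood: -V / N + ln(rate) for each star
    V = np.exp(ln_rate) * W
    return -V / len(phi1) + np.full_like(phi1, ln_rate)


# Grid search over ln(rate); the maximum should sit at ln(N / W)
ln_rates = np.linspace(np.log(N / W) - 1, np.log(N / W) + 1, 2001)
ll = np.array([per_point_terms(lr, phi1).sum() for lr in ln_rates])
best = ln_rates[np.argmax(ll)]
print(np.exp(best), N / W)
```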

In [None]:
@jax.jit
def ln_prob2(pars, data, phi2_knots):
    n_phi2 = len(phi2_knots)
    phi2_means = pars[:n_phi2]
    ln_phi2_stds = pars[n_phi2 : 2 * n_phi2]
    phi2_stds = jnp.exp(ln_phi2_stds)

    phi1_rate = pars[2 * n_phi2 : 3 * n_phi2]

    ll = phi2_ln_likelihood(
        phi2_knots, phi2_means, phi2_knots, phi2_stds, data["phi1"], data["phi2"]
    ).sum()

    ll2 = phi1_ln_likelihood(
        phi2_knots,
        phi1_rate,
        data["phi1"],
    ).sum()

    lp = jln_normal(phi2_means, 0, 2.0).sum()
    lp += jln_normal(ln_phi2_stds, -1, 1.0).sum()

    return ll + ll2 + lp


@jax.jit
def objective2(pars, data, phi2_knots):
    return -ln_prob2(pars, data, phi2_knots)

In [None]:
H, xe = np.histogram(data["phi1"], bins=np.linspace(*window, 32))
xc = 0.5 * (xe[:-1] + xe[1:])
# Convert counts to a rate (stars per unit phi1); the small offset avoids log(0) in empty bins
H = np.log((H + 1e-4) / (xe[1] - xe[0]))
init_rate = InterpolatedUnivariateSpline(xc, H, k=3)(phi2_knots)
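
For the initial guess to be on the same scale as the $-V$ term, the histogram has to become a rate: counts divided by the bin width. With that scaling, the rate integrates (as a Riemann sum over the bins) to the number of stars in the window. Dividing by the full window width instead would make the rate too small by a factor equal to the number of bins. The sketch below uses the same binning as the notebook, with freshly simulated values standing in for `data["phi1"]`.

```python
import numpy as np

rng = np.random.default_rng(42)
window = (-20, 20)
phi1 = rng.normal(5, 4, size=1024)

H, xe = np.histogram(phi1, bins=np.linspace(*window, 32))
bin_width = xe[1] - xe[0]
rate = H / bin_width  # stars per unit phi1 in each bin

# Riemann sum of the rate over the window recovers the number of stars in the window
n_in_window = np.sum(rate * bin_width)
print(n_in_window, H.sum())
```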

In [None]:
init_params = np.concatenate(
    (phi2_vals, np.log(np.full_like(phi2_knots, 0.4)), init_rate)
)
data = {"phi1": phi1_data, "phi2": phi2_data}

In [None]:
solver = ScipyMinimize(method="l-bfgs-b", fun=objective2)
res = solver.run(init_params, data=data, phi2_knots=phi2_knots)
res.state.iter_num, res.params

In [None]:
plt.figure(figsize=(10, 3))
plt.plot(phi2_knots, res.params[: len(phi2_knots)])
plt.plot(phi2_knots, res.params[len(phi2_knots) : 2 * len(phi2_knots)], color="tab:red")
plt.plot(phi2_knots, res.params[2 * len(phi2_knots) :], color="tab:green")
plt.scatter(phi1_data, phi2_data, color="tab:blue", s=2)

In [None]:
xx, yy = np.meshgrid(
    np.linspace(phi1_data.min(), phi1_data.max(), 256), np.linspace(-4, 4, 128)
)
zz = np.exp(
    phi2_ln_likelihood(
        phi2_knots,
        res.params[: len(phi2_knots)],
        phi2_knots,
        np.exp(res.params[len(phi2_knots) : 2 * len(phi2_knots)]),
        xx,
        yy,
    )
    + phi1_ln_likelihood(
        phi2_knots, res.params[2 * len(phi2_knots) : 3 * len(phi2_knots)], xx
    )
)

plt.figure(figsize=(10, 3))
plt.pcolormesh(
    xx,
    yy,
    zz,
    shading="auto",
    vmin=np.median(zz[(xx > 0) & (xx < 5)]),
    vmax=np.max(zz[(xx > 0) & (xx < 5)]),
)
# plt.scatter(phi1_data, phi2_data, color='tab:blue', s=2)

In [None]:
rng_key = jax.random.PRNGKey(42)
# Use independent keys for warmup and sampling rather than reusing one key
warmup_key, sample_key = jax.random.split(rng_key)

In [None]:
fn = jax.tree_util.Partial(ln_prob2, data=data, phi2_knots=phi2_knots)
warmup = blackjax.window_adaptation(
    blackjax.nuts,
    fn,
    1000,
)

state, kernel, _ = warmup.run(
    warmup_key,
    res.params,
)

In [None]:
states = inference_loop(sample_key, kernel, state, 1_000)

In [None]:
plt.figure(figsize=(10, 3))
for i in np.random.choice(states.position.shape[0], size=32):
    params = states.position[i]
    plt.plot(phi2_knots, params[: len(phi2_knots)])

In [None]:
plt.figure(figsize=(10, 3))
for i in np.random.choice(states.position.shape[0], size=32):
    params = states.position[i]
    plt.plot(phi2_knots, params[len(phi2_knots) : 2 * len(phi2_knots)])

In [None]:
plt.figure(figsize=(10, 3))
for i in np.random.choice(states.position.shape[0], size=32):
    params = states.position[i]
    plt.plot(phi2_knots, params[2 * len(phi2_knots) : 3 * len(phi2_knots)])

In [None]:
xx, yy = np.meshgrid(
    np.linspace(phi1_data.min(), phi1_data.max(), 256), np.linspace(-4, 4, 128)
)

for i in np.random.choice(states.position.shape[0], size=8):
    params = states.position[i]
    zz = np.exp(
        phi2_ln_likelihood(
            phi2_knots,
            params[: len(phi2_knots)],
            phi2_knots,
            np.exp(params[len(phi2_knots) : 2 * len(phi2_knots)]),
            xx,
            yy,
        )
        + phi1_ln_likelihood(
            phi2_knots, params[2 * len(phi2_knots) : 3 * len(phi2_knots)], xx
        )
    )

    plt.figure(figsize=(10, 3))
    plt.pcolormesh(
        xx,
        yy,
        zz,
        shading="auto",
        vmin=np.median(zz[(xx > 0) & (xx < 5)]),
        vmax=np.max(zz[(xx > 0) & (xx < 5)]),
    )