# Self-consistent Bayesian sunspot number (SSN) calibration and inference

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import numpyro
from numpyro.infer import MCMC, NUTS
from numpyro import distributions as dist

# Set the number of cores on your machine for parallelism:
cpu_cores = 4
numpyro.set_host_device_count(cpu_cores)

from jax import numpy as jnp, vmap, config
from jax.random import PRNGKey, split
config.update('jax_enable_x64', True)

from scipy.ndimage import gaussian_filter1d

import arviz
from corner import corner

from celerite2.jax import terms, GaussianProcess

### Generate synthetic observations

In [None]:
# Build a reproducible synthetic data set: one hidden "true" SSN curve that is
# sampled by many observers, each with its own constant additive bias.
np.random.seed(10)
t_start = 1600
t_stop = 2022
t = np.linspace(t_start, t_stop, 1000)

# Specify the frequency of observations, which increases sharply after 1800
# (a logistic step centred on 1800 plus a small constant floor):
data_collection_rate = 1/(np.exp((1800 - t) / 10) + 1) + 0.05
# Normalised cumulative rate; used below to map uniform draws onto epochs.
data_collection_rate_cdf = np.cumsum(data_collection_rate)
data_collection_rate_cdf /= data_collection_rate_cdf.max()

# Slowly varying amplitude envelope: heavily smoothed uniform noise.
amp_mod = gaussian_filter1d(
    np.random.uniform(size=len(t)), 100
)

# True signal: the envelope rescaled to [0, 20] times a squared cosine.
# `np.ptp` is used because the `ndarray.ptp` method was removed in NumPy 2.0.
y = (
    20 * (amp_mod - amp_mod.min()) / np.ptp(amp_mod) *
    np.cos(2 * np.pi * t / 2 / 11)**2
)

observers = []

N_observers = 200

# Constant additive bias of each observer: integers drawn from [0, 15).
true_bias = np.random.randint(0, 15, size=N_observers)

# Standard deviation of the Gaussian noise added to every observation; it is
# also the lower limit of the reported uncertainties.
noise_scale = 1

plt.figure(figsize=(20, 5))

for observer_bias in true_bias:
    # Randomly choose start/stop times for observer. The two draws are sorted
    # because np.random.uniform is officially undefined when high < low.
    tbounds = np.sort(
        np.interp(np.random.uniform(0, 1, size=2), data_collection_rate_cdf, t)
    )

    # Randomly distribute observations within start/stop time
    N_observations_per_observer = np.random.randint(10, 20)
    x_obs = np.sort(np.random.uniform(
        tbounds[0], tbounds[1], N_observations_per_observer
    ))
    noise = np.random.normal(scale=noise_scale, size=len(x_obs))
    y_obs = np.interp(x_obs, t, y) + noise + observer_bias

    # Reported uncertainty: decreases linearly from 4 at t_start to 0 at
    # t_stop, but is never smaller than the injected noise scale.
    y_err = np.maximum(
        4 * (t_stop - x_obs) / (t_stop - t_start),
        noise_scale
    )

    observers.append([x_obs, y_obs, y_err])

    # Show the bias-corrected observations on top of the true curve.
    plt.errorbar(
        x_obs, y_obs - observer_bias, y_err,
        fmt='.', color='k', ecolor='silver'
    )
plt.plot(t, y, zorder=10)
plt.gca().set(ylabel='SSN')

In [None]:
# Transpose the per-observer [times, values, errors] triples into three
# parallel lists (one entry per observer) of JAX arrays.
x_stack, y_stack, y_stack_errs = (
    [jnp.array(series) for series in column]
    for column in zip(*observers)
)

### Define a model with jax

In [None]:
from jax import lax
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_flatten, tree_leaves, tree_map

In [None]:
# Indices that put the concatenated observation times of all observers into
# ascending order.
all_times, _ = ravel_pytree(x_stack)
sort = jnp.argsort(all_times)

In [None]:
# Quality factor of the SHO kernel; held fixed rather than sampled.
Q = 100.0


def numpyro_model():
    """Jointly model the per-observer biases and a GP for the shared signal.

    Sample sites:
        'bias': one HalfNormal(10) draw per observer.
        'mu':   a Uniform(-20, 20) offset subtracted from every bias.
        'S0':   HalfNormal(20) amplitude of the SHO kernel.
        'obs':  the time-sorted observations, conditioned on the GP.

    Reads the module-level `x_stack`, `y_stack`, `y_stack_errs`, `sort`,
    `N_observers` and `Q`.
    """
    # construct a prior for the bias of each non-fixed observer
    bias = numpyro.sample(
        'bias',
        dist.HalfNormal(10),
        sample_shape=(N_observers,),
    )
    mu = numpyro.sample('mu', dist.Uniform(low=-20, high=20))

    # GP mean: every observation gets its own observer's bias, minus mu.
    # `observer_of_point[k]` is the observer that made flattened point k.
    points_per_observer = [len(yi) for yi in y_stack]
    observer_of_point = np.repeat(
        np.arange(N_observers), points_per_observer
    )
    model = bias[observer_of_point] - mu

    # the GP is parameterized by an amplitude S0
    S0 = numpyro.sample('S0', dist.HalfNormal(scale=20))

    # We fix the SHO period to the solar cycle period
    kernel = terms.UnderdampedSHOTerm(S0=S0, w0=2*np.pi/11, Q=Q)

    # Flatten the per-observer lists and order every quantity by time.
    times = jnp.concatenate(x_stack)[sort]
    variances = jnp.concatenate(y_stack_errs)[sort]**2
    observed = jnp.concatenate(y_stack)[sort]

    # construct a GP; the inputs are already sorted, so skip the check
    gp = GaussianProcess(
        kernel,
        t=times,
        diag=variances,
        mean=model[sort],
        check_sorted=False
    )
    numpyro.sample('obs', gp.numpyro_dist(), obs=observed)

### Run posterior sampling with numpyro

In [None]:
rng_seed = 42

# Run one chain per host device requested with numpyro.set_host_device_count.
# The same number is used for the PRNG key batch and for `num_chains`, so the
# two can never disagree when `cpu_cores` is changed.
num_chains = cpu_cores

# One independent PRNG key per chain.
rng_keys = split(
    PRNGKey(rng_seed),
    num_chains
)

# NUTS with a dense (full-matrix) mass matrix.
sampler = NUTS(
    numpyro_model,
    dense_mass=True
)

# Monte Carlo sampling for a number of steps and parallel chains:
mcmc = MCMC(
    sampler,
    num_warmup=1_000,
    num_samples=5_000,
    num_chains=num_chains
)

# Run the MCMC
mcmc.run(rng_keys)

# Convert the samples to an arviz InferenceData object for the plots below.
result = arviz.from_numpyro(mcmc)

### Plot posteriors for the bias of each observer

In [None]:
# Plot every `skip`th bias posterior
skip = 20

# Flatten all leading axes so that each row is one posterior sample and each
# column is one observer.
bias_samples = result.posterior.bias.to_numpy().reshape((-1, N_observers))

# Corner plot of the selected observers, with the injected biases as truths.
corner(
    bias_samples[:, ::skip],
    truths=true_bias[::skip],
    quiet=True,
);

### Plot the posteriors for the GP hyperparameters

In [None]:
# Joint posterior of the kernel amplitude S0 and the offset mu.
hyperparameters = ['S0', 'mu']
corner(result, var_names=hyperparameters, quiet=True);

### Plot the GP prediction using the posterior-mean parameters

In [None]:
# Posterior means of the sampled scalar parameters.
S0_mean = result.posterior.S0.to_numpy().flatten().mean()
mu_mean = result.posterior.mu.to_numpy().mean()

# SHO kernel with the same fixed frequency and Q as in the sampled model.
kernel = terms.UnderdampedSHOTerm(S0=S0_mean, w0=2*np.pi/11, Q=Q)

# Posterior-mean bias of each observer, kept as a column vector.
bias_means = result.posterior.bias.to_numpy().reshape(
    (-1, N_observers)
).mean(0)[:, None]

# Per-observation mean model: the observer's mean bias, minus the mean mu.
model = jnp.concatenate(
    [jnp.ones_like(yi) * b for yi, b in zip(y_stack, bias_means)]
) - mu_mean

# construct a GP (no mean is passed here) on the time-sorted observations
gp = GaussianProcess(
    kernel,
    t=jnp.concatenate(x_stack)[sort],
    diag=jnp.concatenate(y_stack_errs)[sort]**2,
    check_sorted=False
)

# We'll predict on a new timeseries that goes beyond the input bounds
t_pred = np.linspace(t_start - 10, t_stop + 10, 1000)

plt.figure(figsize=(20, 5))

# Values the GP is conditioned on. `model` already has mu_mean subtracted, so
# subtracting mu_mean again leaves the observations minus their mean bias.
residuals = jnp.concatenate(y_stack) - model - mu_mean
y_pred, y_pred_var = gp.predict(
    residuals[sort], t=t_pred,
    return_var=True
)

# Predicted mean with a +/- one standard deviation band.
y_pred_std = np.sqrt(y_pred_var)
plt.plot(t_pred, y_pred, label='GP-inferred SSN')
plt.fill_between(
    t_pred,
    y_pred - y_pred_std,
    y_pred + y_pred_std,
    alpha=0.2, zorder=10
)

# Bias-corrected observations of every observer, drawn behind the curves.
for x_i, y_i, bias_i in zip(x_stack, y_stack, bias_means):
    plt.plot(
        x_i, (y_i - bias_i),
        '.', zorder=-10, alpha=1,
    )

plt.plot(t, y, label='Input SSN')
plt.legend()
plt.gca().set(ylabel='SSN')
plt.show()

### Compare inference with truth

In [None]:
# One row per posterior sample, one column per observer.
inferred_bias = result.posterior.bias.to_numpy().reshape((-1, N_observers))

# Posterior mean and standard deviation of every observer's bias.
bias_center = inferred_bias.mean(0)
bias_spread = inferred_bias.std(0)

plt.figure(figsize=(5, 5))
plt.errorbar(true_bias, bias_center, bias_spread, fmt='o')

# One-to-one reference line from the origin up to the largest inferred mean.
upper = bias_center.max()
plt.plot([0, upper], [0, upper], ls='--', color='silver')

plt.gca().set(xlabel='True bias', ylabel='Inferred bias')
plt.show()