# Self-consistent Bayesian calibration and inference of the sunspot number (SSN)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import numpyro
from numpyro.infer import MCMC, NUTS
from numpyro import distributions as dist

# Set the number of cores on your machine for parallelism:
cpu_cores = 4
numpyro.set_host_device_count(cpu_cores)

from jax import (
    numpy as jnp, vmap, config, lax, 
    tree_map, tree_flatten, tree_leaves
)
from jax.random import PRNGKey, split
from jax.flatten_util import ravel_pytree

config.update('jax_enable_x64', True)

from scipy.ndimage import gaussian_filter1d

import arviz
from corner import corner

from celerite2.jax import terms, GaussianProcess

### Generate synthetic observations

In [None]:
np.random.seed(10)
t_start = 1600
t_stop = 2022
t = np.linspace(t_start, t_stop, 1000)

# Relative frequency of observations over time: a logistic step that rises
# sharply around 1800, on top of a constant floor of 0.5.
data_collection_rate = 1 / (np.exp((1800 - t) / 10) + 1) + 0.5
# Normalized cumulative rate. Inverting it with np.interp (below) maps uniform
# draws onto observation times distributed like `data_collection_rate`.
data_collection_rate_cdf = np.cumsum(data_collection_rate)
data_collection_rate_cdf /= data_collection_rate_cdf.max()

# Slowly varying, smoothed random amplitude modulation of the cycle.
amp_mod = gaussian_filter1d(
    np.random.uniform(size=len(t)), 100
)

# Ground-truth SSN: cos^2 with an 11-year period, with the amplitude
# modulation rescaled to [0, 20]. `np.ptp` is used because the `ndarray.ptp`
# method was removed in NumPy 2.0.
y = (
    20 * (amp_mod - amp_mod.min()) / np.ptp(amp_mod) *
    np.cos(2 * np.pi * t / 2 / 11)**2
)

observers = []

N_observers = 100

# Each observer records the true SSN times a multiplicative weight. Observer 0
# is the reference (weight exactly 1); the others get 1 - |N(0, 0.3)|. That
# draw can fall to <= 0, which would make the divisions by the weight below
# (and in the model) undefined, so it is clipped to stay strictly positive.
true_weight = np.concatenate([[1], 1-np.abs(np.random.normal(0, 0.3, size=N_observers-1))])
true_weight = np.clip(true_weight, 1e-3, 1)

plt.figure(figsize=(20, 5))

for i in range(N_observers):
    if i == 0:
        # The reference observer covers a fixed window.
        tbounds = [1930, 1990]
    else:
        # Randomly choose start/stop times following the observation rate.
        # The two draws are unordered, and np.random.uniform with
        # low > high is documented as undefined, so sort them.
        tbounds = np.sort(np.interp(
            np.random.uniform(0, 1, size=2),
            data_collection_rate_cdf,
            t
        ))

    # Observation count: lognormal (median 30) with a floor of 50, so most
    # observers get exactly 50; the reference observer gets 600.
    N_observations_per_observer = (
        max(int(np.random.lognormal(np.log(30), 0.5)), 50)
        if i > 0 else 600
    )
    # Randomly distribute observations within the start/stop window.
    x_obs = np.sort(np.random.uniform(
        tbounds[0],
        tbounds[1],
        N_observations_per_observer
    ))
    noise_scale = 0.2
    noise = np.random.normal(scale=noise_scale, size=len(x_obs))
    # Noise is added after weighting.
    y_obs = np.interp(x_obs, t, y) * true_weight[i] + noise
    # Unit (float) uncertainties; the model fits an overall `diag_scale`.
    y_err = np.ones_like(x_obs)

    observers.append([x_obs, y_obs, y_err])

    # Show the observations corrected by the true weight, for comparison
    # with the ground truth.
    plt.errorbar(
        x_obs, y_obs / true_weight[i], y_err,
        fmt='.', color='k', ecolor='silver'
    )
plt.plot(t, y, zorder=10)
plt.gca().set(ylabel='SSN');

In [None]:
# Split the per-observer [times, values, uncertainties] records into three
# parallel lists of JAX arrays, one entry per observer.
x_stack = [jnp.array(times) for times, _, _ in observers]
y_stack = [jnp.array(values) for _, values, _ in observers]
y_stack_errs = [jnp.array(errors) for _, _, errors in observers]

The first observer's time series (index 0, whose weight is fixed to 1) is taken as the ground truth against which all others are calibrated:

In [None]:
# Number of observations contributed by each observer.
npoints = [yi.shape[0] for yi in y_stack]
# True for the observer(s) with the most observations; the maximum is
# computed once rather than once per element.
_max_npoints = max(npoints)
mask = [n == _max_npoints for n in npoints]

# Plot the reference observer explicitly (index 0, weight fixed to 1) rather
# than relying on a loop variable left over from another cell.
reference = 0
plt.errorbar(
    x_stack[reference], y_stack[reference], y_stack_errs[reference],
    fmt='.', ecolor='silver'
)

### Define the model with NumPyro and celerite2 (JAX)

In [None]:
# Permutation that puts all observers' concatenated observation times in
# chronological order; the GPs in this notebook are built on the sorted times.
sort = jnp.argsort(jnp.concatenate(x_stack))

In [None]:
# Quality factor of the SHO kernel; fixed, not inferred.
Q = 100.0


def numpyro_model(x_stack, y_stack, y_stack_errs):
    """Joint model for the observer weights and the latent SSN curve.

    Observer ``i`` is modelled as recording the latent SSN scaled by a weight
    ``w_i``. Observer 0 is the reference with ``w_0 = 1``. The remaining
    weights have a Normal(1, 0.1) prior truncated to [0, 1]. The latent SSN is
    a celerite2 GP with an SHO kernel whose frequency is fixed to 2*pi/11 and
    whose quality factor is the global ``Q``.

    The GP likelihood is evaluated on the calibrated values ``y / w``. Because
    those depend on the sampled weights, this is a change of variables of the
    observed data. The log-Jacobian ``-sum(log w)`` over every observed point
    is therefore added with ``numpyro.factor``, so the posterior is that of the
    raw observations.

    NOTE(review): the noise variance ``diag_scale * err**2`` is applied to the
    calibrated values, i.e. the noise is assumed to scale with ``w``. The
    synthetic data in this notebook add noise *after* weighting, which would
    instead give variance ``err**2 / w**2`` on the calibrated values.

    Args:
        x_stack: list of 1-D arrays of observation times, one per observer;
            the first entry is the reference observer.
        y_stack: list of 1-D arrays of observed values, aligned with x_stack.
        y_stack_errs: list of 1-D arrays of per-point uncertainties.
    """
    # One free weight per non-reference observer, derived from the inputs
    # instead of a global observer count.
    n_free_weights = len(y_stack) - 1
    weight = numpyro.sample(
        'weight',
        dist.TruncatedNormal(loc=1.0, scale=0.1, low=0.0, high=1.0),
        sample_shape=(n_free_weights,),
    )
    all_weights = jnp.concatenate([jnp.ones(1), weight])

    # Chronological order of all points, computed from the arguments so the
    # model does not depend on a permutation built for some other dataset.
    t_all = jnp.concatenate(x_stack)
    order = jnp.argsort(t_all)

    # Weight of the observer that produced each point, in time order.
    w_per_point = jnp.concatenate([
        jnp.broadcast_to(wi, jnp.shape(yi))
        for yi, wi in zip(y_stack, all_weights)
    ])[order]
    y_weighted_sorted = jnp.concatenate(y_stack)[order] / w_per_point
    err_sorted = jnp.concatenate(y_stack_errs)[order]

    # log |d(y/w)/dy| summed over all points: required because the likelihood
    # below is evaluated on the transformed data.
    numpyro.factor('weight_jacobian', -jnp.sum(jnp.log(w_per_point)))

    # The GP is parameterized by an amplitude S0 and a constant mean mu.
    S0 = numpyro.sample('S0', dist.HalfNormal(scale=20))
    mu = numpyro.sample('mu', dist.Uniform(low=-20, high=20))
    # The SHO frequency is fixed to the 11-year solar cycle.
    kernel = terms.UnderdampedSHOTerm(
        S0=S0, w0=2*np.pi/11, Q=Q
    )
    diag_scale = numpyro.sample('diag_scale', dist.Uniform(low=0.01, high=100))
    gp = GaussianProcess(
        kernel,
        t=t_all[order],
        diag=diag_scale * err_sorted**2,
        mean=mu,
        check_sorted=False,
    )
    numpyro.sample(
        'obs', gp.numpyro_dist(), obs=y_weighted_sorted
    )

### Run posterior sampling with numpyro

In [None]:
rng_seed = 42
# One PRNG key per chain. `num_chains` below uses the same `cpu_cores`, so
# the key batch always matches the number of chains (one per host device).
rng_keys = split(
    PRNGKey(rng_seed),
    cpu_cores
)

# dense_mass=True adapts a full (dense) mass matrix during warmup.
sampler = NUTS(
    numpyro_model,
    dense_mass=True
)

# Monte Carlo sampling for a number of steps and parallel chains:
mcmc = MCMC(
    sampler,
    num_warmup=100,
    num_samples=500,
    num_chains=cpu_cores,
)

# Run the MCMC
mcmc.run(rng_keys, x_stack, y_stack, y_stack_errs)

# Convert to ArviZ InferenceData for the plots below.
result = arviz.from_numpyro(mcmc)

### Plot posteriors of the weight (calibration factor) for a subset of observers

In [None]:
# Corner plot of every `skip`-th free observer weight, with the true weights
# marked. Posterior draws are pooled over chains into (n_draws, n_weights).
skip = 10
corner(
    result.posterior.weight.to_numpy().reshape((-1, N_observers - 1))[:, ::skip],
    plot_datapoints=False,
    quiet=True,
    truths=true_weight[1::skip],
);

### Plot the posteriors for the GP hyperparameters

In [None]:
# Joint posterior of the GP amplitude, GP mean and noise-scale factor.
corner(result, quiet=True, var_names=['S0', 'mu', 'diag_scale']);

### Plot the GP prediction at the posterior-mean parameters

In [None]:
# Posterior-mean hyperparameters (pooled over chains and draws).
kernel = terms.UnderdampedSHOTerm(
    S0=result.posterior.S0.to_numpy().flatten().mean(),
    w0=2*np.pi/11,
    Q=Q
)
diag_scale_mean = result.posterior.diag_scale.to_numpy().flatten().mean()
mu_mean = result.posterior.mu.to_numpy().flatten().mean()

# Posterior-mean weights, with the reference observer's weight fixed to 1.
weight_means = np.concatenate([[1], result.posterior.weight.to_numpy().reshape(
    (-1, N_observers - 1)
).mean(0)])

# Calibrated observations y / w, concatenated and put in time order; this is
# the data the predictive GP is conditioned on.
model = jnp.concatenate(
    [yi / wi for yi, wi in zip(y_stack, weight_means)]
)[sort]

# Construct the GP. The noise variance is scaled by the inferred `diag_scale`,
# mirroring the `diag` term used when fitting `numpyro_model`; keep the two
# in sync if the noise model changes.
gp = GaussianProcess(
    kernel,
    t=jnp.concatenate(x_stack)[sort],
    diag=diag_scale_mean * jnp.concatenate(y_stack_errs)[sort]**2,
    check_sorted=False,
    mean=mu_mean
)

# We'll predict on a new timeseries that goes beyond the input bounds
t_pred = np.linspace(t_start - 10, t_stop + 10, 1000)

plt.figure(figsize=(20, 5))

y_pred, y_pred_var = gp.predict(
    model, t=t_pred,
    return_var=True
)
y_pred_std = np.sqrt(y_pred_var)

plt.plot(t_pred, y_pred, label='GP-inferred SSN')
plt.fill_between(
    t_pred,
    y_pred - y_pred_std,
    y_pred + y_pred_std,
    alpha=0.2, zorder=10
)
# Each observer's data divided by its posterior-mean weight.
for xi, yi, wi in zip(x_stack, y_stack, weight_means):
    plt.plot(
        xi, yi / wi,
        '.', zorder=-10, alpha=1,
    )

plt.plot(t, y, label='Input SSN')
plt.legend()
plt.gca().set(ylabel='SSN', ylim=[-5, 1.1*y.max()])
plt.show()

### Compare inference with truth

In [None]:
# inferred_weight.mean(0).shape, true_weight.shape

In [None]:
# Posterior draws of the free weights, pooled over chains:
# shape (n_chains * n_draws, N_observers - 1).
inferred_weight = result.posterior.weight.to_numpy().reshape((-1, N_observers - 1))
weight_mean = inferred_weight.mean(0)
weight_std = inferred_weight.std(0)

plt.figure(figsize=(5, 5))

# Colour each observer by how many observations it contributed.
plt.scatter(true_weight[1:], weight_mean, marker='.', c=npoints[1:])
plt.errorbar(
    true_weight[1:], weight_mean, weight_std,
    fmt=',', zorder=-2, ecolor='silver'
)
plt.colorbar(label='Number of observations')
# 1:1 reference line spanning both the true and the inferred ranges, so every
# point can be compared against it.
mm = [
    min(true_weight[1:].min(), weight_mean.min()),
    max(true_weight[1:].max(), weight_mean.max()),
]
plt.plot(mm, mm, ls='--', color='silver')
plt.gca().set(
    xlabel='True weight', ylabel='Inferred weight'
)
plt.show()