## Load Libraries

In [2]:
import sys
from pathlib import Path
sys.path.append(str(Path().parent))

import numpy as np
import pandas as pd

import jax
import jax.numpy as jnp
from optax import linear_onecycle_schedule

import numpyro
from numpyro.contrib.hsgp.approximation import hsgp_squared_exponential
import numpyro.distributions as dist
from numpyro.infer import SVI, Predictive
from numpyro.infer.elbo import Trace_ELBO
from numpyro.optim import Adam
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.initialization import init_to_value

import altair as alt

## Define helpers

Helpers for inference

In [3]:
def run_inference_svi(
	prng_key,
	model: callable,
	guide: callable,
	num_steps: int = 5_000,
	peak_lr: float = 0.01,
	**model_kwargs,
):
	"""Fit ``guide`` to ``model`` by stochastic variational inference.

	Optimisation uses Adam driven by a linear one-cycle learning-rate schedule
	that peaks at ``peak_lr`` over ``num_steps`` steps, with the Trace_ELBO
	objective. Extra keyword arguments are forwarded to the model (and guide).

	Returns:
		The result object of ``SVI.run`` (fitted params, losses, state).
	"""
	optimizer = Adam(linear_onecycle_schedule(num_steps, peak_lr))
	engine = SVI(model, guide, optimizer, Trace_ELBO())
	return engine.run(prng_key, num_steps, progress_bar=True, **model_kwargs)

def posterior_predictive_svi(
	prng_key,
	model: callable,
	guide: callable,
	params: dict,
	num_samples: int = 2000,
	**model_kwargs,
) -> dict[str, jax.Array]:
	"""Draw ``num_samples`` posterior predictive samples from a fitted guide.

	Latent values are sampled from ``guide`` (with variational ``params``) and
	pushed through ``model``; keyword arguments are forwarded to the model.

	Returns:
		A dict mapping site names to arrays with a leading sample axis.
	"""
	sampler = Predictive(
		model,
		guide=guide,
		params=params,
		num_samples=num_samples,
	)
	return sampler(prng_key, **model_kwargs)

Define a custom multinomial sampler implemented in JAX (used for data augmentation inside the model).

In [4]:
from collections.abc import Sequence
from jax._src.lax import control_flow as lax_control_flow
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax.random import binomial, split

# Type aliases used in the multinomial signature below: every array-valued
# input is a plain ArrayLike and every dtype alias collapses to DTypeLike.
RealArray = IntegerArray = ArrayLike
DTypeLikeInt = DTypeLikeUInt = DTypeLikeFloat = DTypeLike
Shape = Sequence[int]

def multinomial(
    key: Array,
    n: RealArray,
    p: RealArray,
    *,
    shape: Shape | None = None,
    dtype: DTypeLikeFloat = float,
    unroll: int | bool = 1,
):
  r"""Sample from a multinomial distribution.

  The probability mass function is

  .. math::
      f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}

  Args:
    key: PRNG key.
    n: number of trials. Should have shape broadcastable to ``p.shape[:-1]``.
    p: probability of each outcome, with outcomes along the last axis.
    shape: optional, a tuple of nonnegative integers specifying the result batch
      shape, that is, the prefix of the result shape excluding the last axis.
      Must be broadcast-compatible with ``p.shape[:-1]``. The default (None)
      produces a result shape equal to ``p.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    unroll: optional, unroll parameter passed to :func:`jax.lax.scan` inside the
      implementation of this function.

  Returns:
    An array of counts for each outcome with the specified dtype and with shape
      ``p.shape`` if ``shape`` is None, otherwise ``shape + (p.shape[-1],)``.
  """

  check_arraylike("multinomial", n, p)
  n, p = promote_dtypes_inexact(n, p)

  if shape is None:
    shape = p.shape
  n = jnp.broadcast_to(n, shape[:-1])
  p = jnp.broadcast_to(p, shape)

  def f(remainder, ratio_key):
    ratio, key = ratio_key
    count = binomial(key, remainder, ratio.clip(0, 1), dtype=remainder.dtype)
    return remainder - count, count

  p = jnp.moveaxis(p, -1, 0)

  remaining_probs = lax_control_flow.cumsum(p, 0, reverse=True)
  ratios = p / jnp.where(remaining_probs == 0, 1, remaining_probs)

  keys = split(key, ratios.shape[0])
  remainder, counts = lax_control_flow.scan(f, n, (ratios, keys), unroll=unroll)
  # final remainder should be zero

  return jnp.moveaxis(counts, 0, -1).astype(dtype)

Helper functions for preprocessing the data

In [5]:
def create_interval_index_array(intervals_series, grid_start=0, grid_end=None, grid_step=1):
	"""
	Create an index array that maps intervals to points on a fine integer grid.

	Parameters:
	-----------
	intervals_series : pandas.Series
		A pandas Series of categorical dtype whose categories are pd.Interval
		objects. Each interval's own closed side(s) are respected, so both
		left-closed ``[left, right)`` intervals (e.g. ``pd.cut(..., right=False)``)
		and right-closed ``(left, right]`` intervals are mapped correctly.
	grid_start : int, default 0
		The starting point of the integer grid.
	grid_end : int, optional
		The (inclusive) ending point of the integer grid. If None, it is
		``floor(max(interval.right))`` over all categories.
	grid_step : int, default 1
		The step size for the integer grid.

	Returns:
	--------
	numpy.ndarray
		An integer array with one element per point of
		``np.arange(grid_start, grid_end + grid_step, grid_step)``. Each element
		is the categorical code of the interval containing that grid point, or
		-1 if no interval contains it. If intervals overlap, the interval with
		the highest code wins.

	Raises:
	-------
	ValueError
		If ``grid_end`` is None and the series has no categories to infer it from.
	"""
	unique_intervals = intervals_series.cat.categories

	if grid_end is None:
		if len(unique_intervals) == 0:
			raise ValueError('cannot infer grid_end: intervals_series has no categories')
		max_right = max(interval.right for interval in unique_intervals)
		grid_end = int(np.floor(max_right))

	grid_points = np.arange(grid_start, grid_end + grid_step, grid_step)

	# -1 marks grid points that no interval covers.
	index_array = np.full(len(grid_points), -1, dtype=int)

	for code, interval in enumerate(unique_intervals):
		# Honour the interval's closed side(s) rather than assuming [left, right).
		if interval.closed_left:
			above_left = grid_points >= interval.left
		else:
			above_left = grid_points > interval.left
		if interval.closed_right:
			below_right = grid_points <= interval.right
		else:
			below_right = grid_points < interval.right

		# enumerate() over the categories yields exactly the categorical codes.
		index_array[above_left & below_right] = code

	return index_array

## Generate Data

In [6]:
N = 1_000  # number of simulated observations

In [7]:
def generate_data(N, x_min = 0, x_max = 99, rng=None):
	"""Simulate Poisson counts on the integer grid ``x_min, ..., x_max``.

	Parameters
	----------
	N : int
		Number of observations to draw.
	x_min, x_max : int
		Inclusive bounds of the integer grid.
	rng : numpy.random.Generator or numpy.random.RandomState, optional
		Source of randomness. If None (default), NumPy's global RNG
		(``np.random``) is used, so ``np.random.seed`` controls reproducibility.

	Returns
	-------
	df_true : pandas.DataFrame
		One row per grid point with columns x, log_exposure, f, rate.
	df_data : pandas.DataFrame
		N rows drawn uniformly with replacement from the grid, with the same
		columns plus the Poisson count ``y``.
	"""
	random_source = np.random if rng is None else rng

	# Grid and latent quantities. The 'log_exposure' column enters the rate on
	# the log scale: rate = exp(log_exposure + f).
	x = np.arange(x_min, x_max + 1)
	log_exposure = np.exp(-x / 10)
	f = np.sin(x / 20) + np.cos(x / 10)
	rate = np.exp(log_exposure + f)

	# Sample *positions* into x (not x values), so a non-zero x_min still
	# produces valid indices.
	positions = np.arange(len(x))
	idx = random_source.choice(positions, size=N, replace=True)

	# Sampled grid points and their Poisson counts.
	s = x[idx]
	y = random_source.poisson(rate[idx])

	df_true = pd.DataFrame({
		'x': x,
		'log_exposure': log_exposure,
		'f': f,
		'rate': rate,
	})

	df_data = pd.DataFrame({
		'x': s,
		'log_exposure': log_exposure[idx],
		'f': f[idx],
		'rate': rate[idx],
		'y': y,
	})

	return df_true, df_data

Generate data.

In [8]:
# Seed NumPy's global RNG (which generate_data samples from by default) so the
# simulated dataset is identical on every Restart & Run All.
np.random.seed(0)
df_true, df_data = generate_data(N)

In [9]:
# True rate (line) overlaid on the simulated counts (points).
true_rate_line = alt.Chart(df_data).mark_line(color='#de425b').encode(
	x=alt.X('x', title='x'),
	y=alt.Y('rate', title='Rate'),
)
count_points = alt.Chart(df_data).mark_point().encode(
	x=alt.X('x', title='x'),
	y=alt.Y('y', title='Count'),
)

alt.layer(true_rate_line, count_points).properties(width=600, height=200)

In [10]:
# Bin x into 5-wide, left-closed intervals [0, 5), [5, 10), ..., [95, 100).
df_data['range'] = pd.cut(df_data['x'], bins=np.arange(0, 101, 5), right=False)

# Per-interval total count and number of observations, broadcast back to every
# row. Assigning columns (instead of merging) keeps this cell idempotent:
# re-running it overwrites y_agg / N rather than creating suffixed duplicates.
y_by_range = df_data.groupby('range', observed=True)['y']
df_data['y_agg'] = y_by_range.transform('sum')
df_data['N'] = y_by_range.transform('count')
df_data['rate_agg'] = df_data['y_agg'] / df_data['N']
df_data.head(5)

Unnamed: 0,x,log_exposure,f,rate,y,range,y_agg,N,rate_agg
0,56,0.003698,1.110554,3.047288,3,"[55, 60)",179,56,3.196429
1,71,0.000825,0.287398,1.334056,2,"[70, 75)",69,50,1.38
2,67,0.001231,0.707481,2.031373,4,"[65, 70)",85,41,2.073171
3,63,0.001836,0.991451,2.700097,2,"[60, 65)",132,47,2.808511
4,81,0.000304,-1.032069,0.356377,0,"[80, 85)",16,51,0.313725


In [79]:
# Plot from a copy without the Interval-valued 'range' column.
df_plot = df_data.drop(columns=['range'])

true_rate_line = (
	alt.Chart(df_plot)
	.mark_line(color='red')
	.encode(x=alt.X('x', title='x'), y=alt.Y('rate', title='Rate'))
)

observed_points = (
	alt.Chart(df_plot)
	.mark_point(size=2)
	.encode(x=alt.X('x', title='x'), y=alt.Y('y', title='Count'))
)

# Dashed orange line: per-interval average count (y_agg / N).
binned_rate_line = (
	alt.Chart(df_plot)
	.mark_line(color='orange', strokeDash=[5, 5])
	.encode(x=alt.X('x', title='x'), y=alt.Y('rate_agg', title='Aggregated Rate'))
)

alt.layer(binned_rate_line, true_rate_line, observed_points).properties(width=600, height=200)

In [69]:
def model(x: np.ndarray,
		N: np.ndarray,
		log_P: np.ndarray,
		int_map: dict[int, np.ndarray],
		L: float,
		M: int,
		T: np.ndarray = None,
		bin_width: float = 5.0):
	"""Poisson model with an HSGP latent function and multinomial data augmentation.

	The per-point rate is ``exp(log_P + baseline + f(x))``, where ``f`` comes from
	``hsgp_squared_exponential``. Each interval total ``T[i]`` is split across the
	points of that interval by a multinomial draw, and the result is scored as
	the observation of ``Poisson(rate * N / bin_width)``.

	Args:
		x: inputs to the HSGP approximation, one entry per grid point.
		N: per-point multiplier of the rate in the likelihood.
		log_P: per-point offset added to the log rate.
		int_map: maps interval index ``i`` to the positions in ``x`` covered by
			that interval; ``i`` is also used to index ``T``.
		L: passed to ``hsgp_squared_exponential`` as ``ell``.
		M: passed to ``hsgp_squared_exponential`` as ``m``.
		T: per-interval observed totals indexed by the keys of ``int_map``.
			Required despite the ``None`` default.
		bin_width: divisor of ``rate * N`` in the likelihood. Defaults to 5.0,
			matching the 5-unit intervals built with ``pd.cut`` in this notebook.

	Raises:
		ValueError: if ``T`` is None.
	"""
	if T is None:
		raise ValueError('`T` (per-interval totals) is required for data augmentation')

	# --- Priors ---
	beta = numpyro.sample('baseline', dist.Normal(0, 1))
	sigma = numpyro.sample('sigma', dist.LogNormal(0, 1))
	lenscale = numpyro.sample('lenscale', dist.LogNormal(0, 1))

	# --- Latent function (centered HSGP parameterization) ---
	f = hsgp_squared_exponential(
		x=x,
		alpha=sigma,
		length=lenscale,
		ell=L,
		m=M,
		non_centered=False,
	)

	# --- Rate ---
	log_rate = log_P + (beta + f)
	rate = numpyro.deterministic('rate', jnp.exp(log_rate))

	# --- Data augmentation ---
	# Split each interval total across its points in proportion to the current
	# rate. NOTE(review): the draw uses a fixed PRNG key and is not a numpyro
	# sample site, so the "observed" y below changes with the parameters; the
	# notebook expects this model to fail, presumably for this reason.
	y_aug = jnp.zeros(len(x))
	for i, ind in int_map.items():
		probs = rate[ind] / jnp.sum(rate[ind])
		y_aug = y_aug.at[ind].set(multinomial(key=jax.random.PRNGKey(0), n=T[i], p=probs))

	# --- Likelihood ---
	with numpyro.plate('data', len(x)):
		numpyro.sample('y', dist.Poisson(rate * N / bin_width), obs=y_aug)


In [70]:
# Interval codes that occur in the data, in ascending order.
int_codes = df_data['range'].cat.codes.sort_values().unique()

# Interval code of every point on the model's x grid (-1 where no interval
# covers a point). The grid bounds come from df_true['x'] so that grid
# positions line up with the model inputs built from df_true.
int_ind_arr = create_interval_index_array(df_data['range'],
	grid_start=int(df_true['x'].min()),
	grid_end=int(df_true['x'].max()),
	grid_step=1)
ind_arr = np.arange(len(int_ind_arr))

# Maps each interval code to the grid positions it covers.
int_map = {
	int(code): ind_arr[int_ind_arr == code]
	for code in int_codes
}

# Standardize x
x = df_true['x'].values
x_std = (x - np.mean(x)) / np.std(x)

# Per-interval totals and counts over the intervals that contain data.
df_obs = df_data.groupby('range', observed=True).agg(
	y_agg=('y', 'sum'),
	N=('y', 'count')
).reset_index()

# The model indexes T by interval *code* (the keys of int_map). Aggregating over
# every category (observed=False) yields one entry per code, with 0 for empty
# intervals, so codes and positions stay aligned even when an interval has no data.
T = df_data.groupby('range', observed=False)['y'].sum().to_numpy()
assert len(T) > max(int_map), 'T must have an entry for every interval code'

In [71]:
# N[j] = number of observations in the interval that contains grid point
# df_true['x'][j]. Counting per interval and mapping onto the full grid keeps
# N aligned with the model's x even if some grid point was never sampled.
n_per_interval = df_data.groupby('range', observed=False).size().to_numpy()
grid_interval = df_data['range'].cat.categories.get_indexer(df_true['x'])
assert (grid_interval >= 0).all(), 'every grid point must fall inside an interval'
N = n_per_interval[grid_interval]
df_N = pd.DataFrame({'x': df_true['x'].to_numpy(), 'N': N})

In [72]:
# Keyword arguments shared by the SVI fit and the posterior predictive.
model_data = dict(
	x=x_std,                                 # standardized grid inputs
	N=N,                                     # per-point rate multiplier
	log_P=df_true['log_exposure'].values,    # log-rate offset
	int_map=int_map,                         # interval code -> grid positions
	L=10.0,                                  # HSGP boundary, passed as `ell`
	M=30,                                    # number of HSGP basis functions, passed as `m`
	T=T,                                     # per-interval observed totals
)

The following model should fail to accurately infer the rate.

In [73]:
# Fixed seed for the SVI fit.
prng_key = jax.random.PRNGKey(42)

# Start the baseline at the negated mean of the log-exposure offset.
init_values = dict(baseline=-df_true['log_exposure'].values.mean())
guide = AutoNormal(model, init_loc_fn=init_to_value(values=init_values))

svi = run_inference_svi(prng_key, model, guide, **model_data)

100%|██████████| 5000/5000 [00:33<00:00, 151.13it/s, init loss: 3797.7534, avg. loss [4751-5000]: 350.7061]


In [74]:
# Derive a fresh key: prng_key was already consumed by the SVI fit, and JAX
# keys must not be reused for independent random operations.
pred_key = jax.random.split(prng_key)[1]
post_pred = posterior_predictive_svi(pred_key, model, guide, svi.params, **model_data)

# 2.5%, 50% and 97.5% posterior quantiles of the rate (rows: quantiles).
post_rate_sum = np.quantile(post_pred['rate'], q=(0.025, 0.5, 0.975), axis=0)

In [75]:
# post_rate_sum rows are the (2.5%, 50%, 97.5%) quantiles and its columns
# follow the rows of df_true (the model's x grid), so assign positionally.
# Indexing by x values would only be correct when the grid starts at 0.
assert post_rate_sum.shape[1] == len(df_true), 'rate summaries must align with df_true rows'
df_true['M'] = post_rate_sum[1]
df_true['CL'] = post_rate_sum[0]
df_true['CU'] = post_rate_sum[2]

In [76]:
# Fitted variational parameters: an AutoNormal loc/scale pair per latent site
# (e.g. 'baseline_auto_loc', 'baseline_auto_scale').
svi.params

{'baseline_auto_loc': Array(0.01706906, dtype=float32),
 'baseline_auto_scale': Array(0.02619283, dtype=float32),
 'beta_auto_loc': Array([ 1.2101327 ,  0.43130377, -0.661234  , -0.33806458, -1.6233759 ,
         1.7236325 ,  0.5317304 ,  0.24385883,  1.2212498 ,  0.17222165,
        -0.45071432, -0.8072674 , -0.36877945, -0.68363196, -0.9435648 ,
         0.8054031 ,  0.99164313, -0.5903625 , -0.60134375,  0.19065456,
        -0.7937727 ,  0.57289404, -0.7106124 ,  1.4395381 , -1.3907038 ,
        -1.2235905 , -0.4283401 , -0.6508319 ,  0.5131671 ,  0.65669364],      dtype=float32),
 'beta_auto_scale': Array([0.07351285, 0.24138463, 0.08228104, 0.13210855, 0.0922631 ,
        0.09694765, 0.10730568, 0.09074406, 0.10360788, 0.10751261,
        0.09498505, 0.12378865, 0.08576813, 0.12131248, 0.09455027,
        0.10784619, 0.10438012, 0.09633049, 0.11039106, 0.09383176,
        0.11243469, 0.09989084, 0.10551156, 0.10110452, 0.0978139 ,
        0.10103423, 0.10059325, 0.10224608, 0.1079

In [77]:
# True rate (red) vs. posterior median (green line) and its CL-CU quantile band.
truth_line = (
	alt.Chart(df_plot)
	.mark_line(color='red')
	.encode(
		x=alt.X('x', title='x'),
		y=alt.Y('rate', title='Intensity', scale=alt.Scale(domain=(0, 12))),
	)
)

median_line = alt.Chart(df_true).mark_line(color='green').encode(
	x=alt.X('x', title='x'),
	y=alt.Y('M', title='Intensity'),
)

quantile_band = alt.Chart(df_true).mark_errorband(color='green').encode(
	x=alt.X('x', title='x'),
	y=alt.Y('CL', title='Intensity'),
	y2=alt.Y2('CU', title='Intensity'),
)

alt.layer(truth_line, median_line, quantile_band).properties(width=600, height=200)