In [None]:
import sys
from pathlib import Path

sys.path.append(str(Path().parent))

import altair as alt
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
from numpyro.contrib.hsgp.approximation import hsgp_squared_exponential
from numpyro.infer import SVI, Predictive
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.elbo import Trace_ELBO
from numpyro.infer.initialization import init_to_value
from numpyro.optim import Adam
from optax import linear_onecycle_schedule

## Define helpers

In [2]:
def run_inference_svi(
    prng_key,
    model: callable,
    guide: callable,
    num_steps: int = 5_000,
    peak_lr: float = 0.01,
    **model_kwargs,
):
    """Fit ``guide`` to ``model`` with stochastic variational inference.

    The Adam step size follows ``linear_onecycle_schedule(num_steps, peak_lr)``
    and the objective is ``Trace_ELBO``.

    Parameters
    ----------
    prng_key
        Random key handed to ``SVI.run``.
    model, guide
        Model and guide callables handed to ``SVI``.
    num_steps
        Number of optimisation steps; also the length of the learning-rate schedule.
    peak_lr
        Peak learning rate of the one-cycle schedule.
    **model_kwargs
        Forwarded unchanged to ``SVI.run``.

    Returns
    -------
    Whatever ``SVI.run`` returns (run with ``progress_bar=True``).
    """
    step_size = linear_onecycle_schedule(num_steps, peak_lr)
    optimizer = Adam(step_size)
    engine = SVI(model, guide, optimizer, Trace_ELBO())
    return engine.run(prng_key, num_steps, progress_bar=True, **model_kwargs)

def posterior_predictive_svi(
    prng_key,
    model: callable,
    guide: callable,
    params: dict,
    num_samples: int = 2000,
    **model_kwargs,
) -> dict[str, jax.Array]:
    """Draw predictive samples from ``model`` using a fitted ``guide``.

    Parameters
    ----------
    prng_key
        Random key handed to the ``Predictive`` call.
    model, guide
        Model and guide callables handed to ``Predictive``.
    params
        Parameter values handed to ``Predictive`` via ``params``.
    num_samples
        Number of draws requested from ``Predictive``.
    **model_kwargs
        Forwarded unchanged to the ``Predictive`` call.

    Returns
    -------
    The mapping returned by calling the ``Predictive`` object.
    """
    sampler = Predictive(
        model,
        guide=guide,
        params=params,
        num_samples=num_samples,
    )
    draws = sampler(prng_key, **model_kwargs)
    return draws

In [3]:
def create_interval_index_array(intervals_series, grid_start=0, grid_end=None, grid_step=1):
    """
    Create an index array that maps intervals to points on a fine integer grid.

    Parameters:
    -----------
    intervals_series : pandas.Series or pandas.Categorical
        Categorical data whose categories are pd.Interval objects, e.g. the
        result of ``pd.cut``. Both a Series of categorical dtype and a bare
        ``pd.Categorical`` (what ``pd.cut`` returns for array input) are accepted.
        Every interval is treated as closed on the left and open on the right
        [left, right).
    grid_start : int, default 0
        The starting point of the integer grid.
    grid_end : int, optional
        The ending point of the integer grid (inclusive). If None, it is set to
        the floor of the largest right edge among the interval categories.
    grid_step : int, default 1
        The step size for the integer grid.

    Returns:
    --------
    numpy.ndarray
        An index array where each element corresponds to a grid point and
        contains the position (categorical code) of the interval category that
        contains that grid point. Points not in any interval get index -1.
        If intervals overlap, the later category wins.

    Raises:
    -------
    ValueError
        If ``grid_end`` is None and there are no interval categories to infer
        it from.

    Example:
    --------
    >>> import pandas as pd
    >>> import numpy as np
    >>>
    >>> # Create some intervals
    >>> x = np.linspace(0, 10, 11)
    >>> intervals = pd.cut(x, bins=3, right=False)
    >>>
    >>> # Create index array for fine grid from 0 to 10
    >>> index_array = create_interval_index_array(intervals, 0, 10)
    >>> print(index_array)
    [0 0 0 0 1 1 1 2 2 2 2]
    """
    # A Series exposes its categories through the `.cat` accessor, whereas a
    # bare Categorical carries the `categories` attribute directly.
    if isinstance(intervals_series, pd.Series):
        categorical = intervals_series.cat
    else:
        categorical = intervals_series
    unique_intervals = categorical.categories

    # Determine grid bounds if not provided
    if grid_end is None:
        if len(unique_intervals) == 0:
            raise ValueError(
                "cannot infer grid_end: intervals_series has no interval categories"
            )
        max_right = max(interval.right for interval in unique_intervals)
        grid_end = int(np.floor(max_right))

    # Create the fine integer grid (grid_end itself is included)
    grid_points = np.arange(grid_start, grid_end + grid_step, grid_step)

    # Initialize index array with -1 (indicating no interval contains the point)
    index_array = np.full(len(grid_points), -1, dtype=int)

    # The enumeration position of a category is its categorical code.
    for code, interval in enumerate(unique_intervals):
        # Grid points that fall within this interval [left, right)
        inside = (grid_points >= interval.left) & (grid_points < interval.right)
        index_array[inside] = code

    return index_array


def create_mapping_matrix(intervals_series, grid_start=0, grid_end=None, grid_step=1):
    """
    Build a 0/1 matrix that aggregates fine-grid values into interval totals.

    Parameters:
    -----------
    intervals_series : pandas.Series
        A pandas Series of categorical dtype containing pd.Interval objects.
    grid_start : int, default 0
        The starting point of the integer grid.
    grid_end : int, optional
        The ending point of the integer grid. If None, inferred from intervals.
    grid_step : int, default 1
        The step size for the integer grid.

    Returns:
    --------
    numpy.ndarray
        Integer matrix of shape (n_intervals, n_grid_points). Entry (i, j) is 1
        when grid point j lies in interval category i and 0 otherwise, so
        ``matrix @ grid_values`` sums the grid values inside each interval.
    """
    # Interval code per grid point; -1 marks points outside every interval.
    grid_to_interval = create_interval_index_array(
        intervals_series, grid_start, grid_end, grid_step
    )

    shape = (len(intervals_series.cat.categories), len(grid_to_interval))
    mapping_matrix = np.zeros(shape, dtype=int)

    # Set one entry per covered grid point: row = interval code, column = grid position.
    covered = np.flatnonzero(grid_to_interval >= 0)
    mapping_matrix[grid_to_interval[covered], covered] = 1

    return mapping_matrix

## Generate Data

In [4]:
# Number of observations sampled (with replacement) when simulating the data set.
N = 1000

Generate data.

In [5]:
# Seeded generator so the simulated data set is identical on every
# "Restart Kernel -> Run All".
rng = np.random.default_rng(42)

# Ground truth on the integer grid 0..99. `log_exposure` is the offset added to
# `f` on the log scale, so the Poisson rate is exp(log_exposure + f).
x = np.arange(0, 100)
log_exposure = np.exp(-x / 10)
f = np.sin(x / 20) + np.cos(x / 10)
rate = np.exp(log_exposure + f)

df_true = pd.DataFrame({
    'x': x,
    'log_exposure': log_exposure,
    'f': f,
    'rate': rate,
})

# Sample N grid positions with replacement and draw one Poisson count at each.
idx_list = np.arange(len(x))
idx = rng.choice(idx_list, size=N, replace=True)

s = x[idx]
y = rng.poisson(rate[idx])

df_data = pd.DataFrame({
    'x': s,
    'log_exposure': log_exposure[idx],
    'f': f[idx],
    'rate': rate[idx],
    'y': y,
})

df_data.head(10)

Unnamed: 0,x,log_exposure,f,rate,y
0,35,0.030197,0.047529,1.080827,1
1,94,8.3e-05,-1.999616,0.135398,0
2,39,0.020242,0.203027,1.250157,0
3,15,0.22313,0.752376,2.652509,2
4,10,0.367879,1.019728,4.005255,3
5,0,1.0,1.0,7.389056,10
6,44,0.012277,0.501164,1.671031,2
7,69,0.001008,0.512184,1.670614,1
8,33,0.036883,0.009385,1.047356,2
9,86,0.000184,-1.594886,0.202969,0


In [6]:
# The charts get their own names so that the numeric `rate` array created in the
# data-generation cell is not rebound to a plot object.
rate_chart = (
    alt.Chart(df_data)
    .mark_line(color='#de425b')
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate', title='Rate'),
    )
)

count_chart = (
    alt.Chart(df_data)
    .mark_point()
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('y', title='Count'),
    )
)

(rate_chart + count_chart).properties(width=600, height=200)

In [7]:
# Bin x into 20 equal-width intervals, closed on the left.
df_data['range'] = pd.cut(df_data['x'], bins=20, right=False)

# Per-interval totals broadcast back onto the rows. `transform` assigns the
# columns directly, so re-running this cell overwrites them; merging a grouped
# frame back into `df_data` would instead create suffixed duplicates
# (`y_agg_x`, `y_agg_y`, ...) on a second run and then fail on `df_data['y_agg']`.
y_by_range = df_data.groupby('range', observed=True)['y']
df_data['y_agg'] = y_by_range.transform('sum')
df_data['N'] = y_by_range.transform('count')

# Mean observed count per interval.
df_data['rate_agg'] = df_data['y_agg'] / df_data['N']
df_data.head(5)

Unnamed: 0,x,log_exposure,f,rate,y,range,y_agg,N,rate_agg
0,35,0.030197,0.047529,1.080827,1,"[34.65, 39.6)",43,37,1.162162
1,94,8.3e-05,-1.999616,0.135398,0,"[89.1, 94.05)",9,51,0.176471
2,39,0.020242,0.203027,1.250157,0,"[34.65, 39.6)",43,37,1.162162
3,15,0.22313,0.752376,2.652509,2,"[14.85, 19.8)",145,61,2.377049
4,10,0.367879,1.019728,4.005255,3,"[9.9, 14.85)",180,54,3.333333


In [8]:
# Create a clean dataset without the Interval objects for plotting
df_plot = df_data.drop(columns=['range'])

# The charts get their own names so that the numeric `rate` array created in the
# data-generation cell is not rebound to a plot object.
rate_chart = (
    alt.Chart(df_plot)
    .mark_line(color='red')
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate', title='Rate'),
    )
)

count_chart = (
    alt.Chart(df_plot)
    .mark_point()
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('y', title='Count'),
    )
)

# Dashed blue line distinguishes the per-interval mean from the true rate.
rate_agg_chart = (
    alt.Chart(df_plot)
    .mark_line(color='blue', strokeDash=[5, 5])
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate_agg', title='Aggregated Rate'),
    )
)

(rate_agg_chart + rate_chart + count_chart).properties(width=600, height=200)

In [9]:
# Standardize the grid locations to zero mean and unit standard deviation.
x_std = (x - x.mean()) / x.std()

# Aggregation matrix from the integer grid 0..99 (step 1) to the interval categories.
mapping_matrix = create_mapping_matrix(df_data['range'], 0, 99, 1)

# Interval code of every observation row.
rid = df_data['range'].cat.codes.to_numpy()

In [10]:
def model(x: np.ndarray,
          N: np.ndarray,
          log_exposure: np.ndarray,
          map_matrix: np.ndarray,
          L: float,
          M: int,
          non_centered: bool,
          y: np.ndarray = None):
    """Poisson model for aggregated counts with an HSGP latent function.

    Parameters
    ----------
    x
        Inputs handed to ``hsgp_squared_exponential``.
    N
        Multiplier applied to the aggregated rate (one value per aggregated cell).
    log_exposure
        Offset added to the latent function on the log scale.
    map_matrix
        Matrix applied to the point-wise rate (``map_matrix @ rate``) to aggregate it.
    L, M
        Forwarded to ``hsgp_squared_exponential`` as ``ell`` and ``m``.
    non_centered
        Forwarded to ``hsgp_squared_exponential``.
    y
        Observed aggregated counts; ``None`` leaves the ``'y'`` site unobserved.
    """
    # Priors, sampled in the order: baseline, sigma, lenscale.
    baseline = numpyro.sample('baseline', dist.Normal(0, 1))
    amplitude = numpyro.sample('sigma', dist.InverseGamma(5, 5))
    length_scale = numpyro.sample('lenscale', dist.InverseGamma(5, 5))

    # Latent function from the HSGP approximation with a squared-exponential kernel.
    latent = hsgp_squared_exponential(
        x=x,
        alpha=amplitude,
        length=length_scale,
        ell=L,
        m=M,
        non_centered=non_centered,
    )

    # Point-wise rate, recorded as the deterministic site 'rate'.
    grid_rate = numpyro.deterministic(
        'rate', jnp.exp(log_exposure + (baseline + latent))
    )

    # Aggregate the rate with the mapping matrix, then scale by N.
    expected_counts = (map_matrix @ grid_rate) * N

    numpyro.sample('y', dist.Poisson(expected_counts), obs=y)

In [11]:
# observed=False keeps one row per interval category (empty ones included), so
# the rows of `df_obs` stay aligned with the rows of `mapping_matrix`, which has
# one row per category. Passing it explicitly pins this behaviour and silences
# the pandas FutureWarning about the changing default.
df_obs = df_data.groupby('range', observed=False).agg(
    y_agg=('y', 'sum'),
    N=('y', 'count'),
).reset_index()

y_agg = df_obs['y_agg'].values
# NOTE: this rebinds `N` from the scalar sample size to the per-interval counts.
N = df_obs['N'].values

  df_obs = df_data.groupby('range').agg(


In [12]:
# Keyword arguments for `model`.
model_data = dict(
    x=x_std,
    N=N,
    log_exposure=log_exposure,
    map_matrix=mapping_matrix,
    L=2.0,               # passed as HSGP `ell`: half-width of the approximation domain [-L, L]
    M=30,                # passed as HSGP `m`: number of basis functions
    non_centered=False,  # False selects the centered parameterization
    y=y_agg,
)

The following model should fail to accurately infer the rate.

In [13]:
# `AutoNormal` and `init_to_value` come from the notebook's import cell.
prng_key = jax.random.PRNGKey(42)

# Start the 'baseline' site at minus the mean of the offset.
init_values = {'baseline': -log_exposure.mean()}
guide = AutoNormal(model, init_loc_fn=init_to_value(values=init_values))

# `svi` holds the result returned by `run_inference_svi`.
svi = run_inference_svi(prng_key, model, guide, **model_data)

100%|██████████| 5000/5000 [00:03<00:00, 1468.58it/s, init loss: 296270.6562, avg. loss [4751-5000]: 491.0443] 


In [14]:
# Derive a distinct key for prediction rather than reusing the key that was
# already consumed by the SVI run (JAX keys should not be reused).
predict_key = jax.random.fold_in(prng_key, 1)
post_pred = posterior_predictive_svi(predict_key, model, guide, svi.params, **model_data)

# 2.5%, 50% and 97.5% quantiles of the sampled 'rate' across draws (axis 0).
post_rate_sum = np.quantile(post_pred['rate'], q=(0.025, 0.5, 0.975), axis=0)

In [None]:
# Rows of `post_rate_sum` are the lower, median and upper quantiles; look each
# one up at the grid positions stored in `df_true['x']`.
grid_idx = df_true['x'].to_numpy()
q_lower, q_median, q_upper = post_rate_sum

df_true['M'] = q_median[grid_idx]
df_true['CL'] = q_lower[grid_idx]
df_true['CU'] = q_upper[grid_idx]

In [30]:
# The charts get their own names so that the numeric `rate` array created in the
# data-generation cell is not rebound to a plot object.

# True rate (red), with a fixed y-domain of 0..12.
rate_chart = (
    alt.Chart(df_plot)
    .mark_line(color='red')
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate', title='Intensity', scale=alt.Scale(domain=(0, 12))),
    )
)

# Posterior median of the rate (green).
rate_est_chart = (
    alt.Chart(df_true)
    .mark_line(color='green')
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('M', title='Intensity'),
    )
)

# Band between the lower (CL) and upper (CU) posterior quantiles.
bands_chart = (
    alt.Chart(df_true)
    .mark_errorband(color='green')
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('CL', title='Intensity'),
        y2=alt.Y2('CU', title='Intensity'),
    )
)

(rate_chart + rate_est_chart + bands_chart).properties(width=600, height=200)