In [1]:
import numpy as np
import pandas as pd
import scipy as sp

import jax
import jax.numpy as jnp

import altair as alt

In [34]:
# Simulated data: a smooth latent log-rate f(x) plus a known per-point log-exposure offset.
# A fixed seed makes the simulated counts (and every downstream table, fit and plot) reproducible
# under Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)

x = np.linspace(0, 100, 100)
log_exposure = np.exp(-x / 10)  # the log-exposure offset itself, decaying from 1 toward 0
f = 0.5 * np.sin(x / 20) + 0.5 * np.cos(x / 10)
rate = np.exp(log_exposure + f)

# Poisson draws take their shape from `rate`, so changing the number of points above needs no edit here.
y = rng.poisson(rate)

df_data = pd.DataFrame({
    'x': x,
    'log_exposure': log_exposure,
    'f': f,
    'rate': rate,
    'y': y,
})

In [35]:
# True rate (red line) with the simulated counts overlaid. The chart objects get their own
# names so they do not overwrite the `rate` array created when simulating the data.
true_rate_line = (
    alt.Chart(df_data)
    .mark_line(color='red')
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate', title='Rate'),
    )
)

count_points = alt.Chart(df_data).mark_point().encode(
    x=alt.X('x', title='x'),
    y=alt.Y('y', title='Count'),
)

(true_rate_line + count_points).properties(width=600)

In [36]:
# Bin x into 20 equal-width, left-closed intervals and attach per-bin statistics to every row.
df_data['range'] = pd.cut(df_data['x'], bins=20, right=False)

# groupby().transform writes each bin's statistic back onto its rows. Unlike merging an
# aggregate frame, re-running this cell overwrites the columns instead of producing
# y_agg_x / y_agg_y duplicates (which made `df_data['y_agg']` raise KeyError on a re-run).
y_by_bin = df_data.groupby('range', observed=True)['y']
df_data['y_agg'] = y_by_bin.transform('sum')
df_data['N'] = y_by_bin.transform('count')
df_data['rate_agg'] = df_data['y_agg'] / df_data['N']
df_data.head(10)

Unnamed: 0,x,log_exposure,f,rate,y,range,y_agg,N,rate_agg
0,0.0,1.0,0.5,4.481689,5,"[0.0, 5.0)",19,5,3.8
1,1.010101,0.903924,0.522693,4.164587,1,"[0.0, 5.0)",19,5,3.8
2,2.020202,0.817078,0.540251,3.885801,4,"[0.0, 5.0)",19,5,3.8
3,3.030303,0.738577,0.552686,3.637378,4,"[0.0, 5.0)",19,5,3.8
4,4.040404,0.667617,0.560064,3.413307,5,"[0.0, 5.0)",19,5,3.8
5,5.050505,0.603475,0.5625,3.20905,2,"[5.0, 10.0)",19,5,3.8
6,6.060606,0.545496,0.560156,3.021193,4,"[5.0, 10.0)",19,5,3.8
7,7.070707,0.493086,0.553242,2.847179,6,"[5.0, 10.0)",19,5,3.8
8,8.080808,0.445713,0.542012,2.685118,4,"[5.0, 10.0)",19,5,3.8
9,9.090909,0.40289,0.526759,2.53362,3,"[5.0, 10.0)",19,5,3.8


In [37]:
# Altair serializes the data to Vega-Lite JSON, which the Interval-typed `range` column
# does not support, so plot from a copy without it.
df_plot = df_data.drop(columns=['range'])

true_rate_line = (
    alt.Chart(df_plot)
    .mark_line(color='red')
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate', title='Rate'),
    )
)

count_points = alt.Chart(df_plot).mark_point().encode(
    x=alt.X('x', title='x'),
    y=alt.Y('y', title='Count'),
)

# Blue dashed line: empirical per-bin mean count (y_agg / N), distinguishable from the true rate.
agg_rate_line = (
    alt.Chart(df_plot)
    .mark_line(color='blue', strokeDash=[5, 5])
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate_agg', title='Aggregated Rate'),
    )
)

(agg_rate_line + true_rate_line + count_points).properties(width=600)

In [38]:
import numpyro
from numpyro.contrib.hsgp.approximation import hsgp_squared_exponential
from numpyro.contrib.hsgp.laplacian import eigenfunctions
from numpyro.contrib.hsgp.spectral_densities import (
	diag_spectral_density_squared_exponential
)
import numpyro.distributions as dist
from numpyro.infer import SVI, MCMC, NUTS, Predictive

In [39]:
def create_interval_index_array(intervals_series, grid_start=0, grid_end=None, grid_step=1):
    """
    Create an index array that maps intervals to points on a fine integer grid.

    Parameters
    ----------
    intervals_series : pandas.Series or pandas.Categorical
        Categorical data whose categories are pd.Interval objects, e.g. the result of
        ``pd.cut(..., right=False)`` (either the bare Categorical returned for an array
        input or a categorical Series). Every interval is treated as closed on the left
        and open on the right, [left, right), regardless of its ``closed`` attribute.
    grid_start : int, default 0
        The starting point of the integer grid.
    grid_end : int, optional
        The last candidate grid point (included when it lies on the grid). If None, it is
        inferred as floor(largest right edge among the intervals).
    grid_step : int, default 1
        The step size for the integer grid.

    Returns
    -------
    numpy.ndarray
        One integer per grid point of ``np.arange(grid_start, grid_end + grid_step, grid_step)``:
        the position of the containing interval among the categories (its categorical code),
        or -1 if no interval contains the point. If intervals overlap, the later category wins.

    Raises
    ------
    ValueError
        If ``grid_end`` is None and there are no categories to infer it from.

    Example
    -------
    >>> x = np.linspace(0, 10, 11)
    >>> intervals = pd.cut(x, bins=3, right=False)
    >>> create_interval_index_array(intervals, 0, 10)
    array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2])
    """
    # Wrapping in pd.Series gives a bare Categorical (which has no .cat accessor)
    # the same interface as a categorical Series.
    categories = pd.Series(intervals_series).cat.categories

    # Determine grid bounds if not provided
    if grid_end is None:
        if len(categories) == 0:
            raise ValueError("cannot infer grid_end: intervals_series has no categories")
        grid_end = int(np.floor(max(interval.right for interval in categories)))

    grid_points = np.arange(grid_start, grid_end + grid_step, grid_step)

    # -1 marks grid points that no interval contains.
    index_array = np.full(len(grid_points), -1, dtype=int)

    # enumerate() order equals category order, so `code` is the categorical code.
    for code, interval in enumerate(categories):
        in_interval = (grid_points >= interval.left) & (grid_points < interval.right)
        index_array[in_interval] = code

    return index_array


def create_mapping_matrix(intervals_series, grid_start=0, grid_end=None, grid_step=1):
    """
    Create a mapping matrix that aggregates values from a fine grid to intervals.

    Parameters
    ----------
    intervals_series : pandas.Series or pandas.Categorical
        Categorical data whose categories are pd.Interval objects; intervals are
        treated as [left, right) by ``create_interval_index_array``.
    grid_start : int, default 0
        The starting point of the integer grid.
    grid_end : int, optional
        The ending point of the integer grid. If None, inferred from intervals.
    grid_step : int, default 1
        The step size for the integer grid.

    Returns
    -------
    numpy.ndarray
        An integer matrix of shape (n_intervals, n_grid_points). Row i has a 1 for every
        grid point inside interval i; each column has at most one 1, and columns of grid
        points outside every interval are all zeros. Use it for aggregation:
        ``aggregated_values = matrix @ grid_values``.
    """
    index_array = create_interval_index_array(intervals_series, grid_start, grid_end, grid_step)

    # Wrapping in pd.Series lets a bare Categorical be handled like a categorical Series.
    n_intervals = len(pd.Series(intervals_series).cat.categories)
    n_grid_points = len(index_array)

    mapping_matrix = np.zeros((n_intervals, n_grid_points), dtype=int)

    # Vectorized fill: place a 1 at (interval index, grid index) for every grid point that
    # falls inside some interval; index -1 (no interval) leaves its column all zeros.
    covered_columns = np.flatnonzero(index_array >= 0)
    mapping_matrix[index_array[covered_columns], covered_columns] = 1

    return mapping_matrix

In [40]:
# Standardize x so the HSGP boundary L (set in the sampling cell) is expressed on a unit scale.
x_std = (x - np.mean(x)) / np.std(x)

# One column per observed x (the model evaluates the rate at every x), one row per bin.
# grid_end follows len(x) instead of a hard-coded 99, so the shapes stay consistent if the
# number of points changes.
mapping_matrix = create_mapping_matrix(
    df_data['range'],
    grid_start=0,
    grid_end=len(x) - 1,
    grid_step=1,
)

# The matrix is built on the integer index grid 0..len(x)-1 while the bins are defined in
# x units; verify that column k really selects the bin that contains x[k].
bin_codes = df_data['range'].cat.codes.to_numpy()
expected_matrix = (np.arange(mapping_matrix.shape[0])[:, np.newaxis] == bin_codes).astype(int)
assert np.array_equal(mapping_matrix, expected_matrix), 'grid indices do not line up with the x bins'

In [None]:
def model(x, log_exposure, map_matrix, L, M, non_centered, y=None):
  """Poisson model for binned counts with a latent HSGP log-rate.

  The rate at each input point is exp(log_exposure + baseline + f(x)), where f is drawn via
  numpyro's squared-exponential HSGP approximation. Only per-bin totals are observed: the
  point rates are summed into bins with `map_matrix`, and each total is Poisson distributed.

  Args:
    x: inputs passed to the HSGP (the notebook passes standardized x).
    log_exposure: per-point additive offset on the log-rate scale.
    map_matrix: 0/1 matrix with one row per bin; `map_matrix @ rate` gives per-bin means.
    L: forwarded to hsgp_squared_exponential as `ell`.
    M: forwarded to hsgp_squared_exponential as `m`.
    non_centered: forwarded to hsgp_squared_exponential.
    y: observed per-bin totals, or None to sample them.
  """
  # Priors: same site names and sampling order as before.
  baseline = numpyro.sample('baseline', dist.Normal(0, 10))
  amplitude = numpyro.sample('sigma', dist.InverseGamma(5, 5))
  length_scale = numpyro.sample('lenscale', dist.InverseGamma(5, 5))

  # Latent function f(x) from the HSGP approximation.
  latent = hsgp_squared_exponential(
    x=x,
    alpha=amplitude,
    length=length_scale,
    ell=L,
    m=M,
    non_centered=non_centered,
  )

  # Point-level rate, recorded as a deterministic site, then summed into bins.
  log_rate = log_exposure + (baseline + latent)
  point_rate = numpyro.deterministic('rate', jnp.exp(log_rate))
  bin_rate = map_matrix @ point_rate
  numpyro.sample('y', dist.Poisson(bin_rate), obs=y)

In [74]:
# Observed per-bin totals used as the likelihood's data. observed=False keeps every category
# of 'range' in category order, so entry i lines up with row i of mapping_matrix even if a bin
# received no observations. Passing it explicitly also silences the pandas warning emitted
# when `observed` is left implicit on a categorical groupby.
df_train = df_data.groupby('range', observed=False).agg(
    y_agg=('y', 'sum'),
    N=('y', 'count'),
).reset_index()

y_agg = df_train['y_agg'].values
N = df_train['N'].values

assert len(y_agg) == mapping_matrix.shape[0], 'expected one observed total per mapping-matrix row'

  df_train = df_data.groupby('range').agg(


In [75]:
sampler = NUTS(model)
mcmc = MCMC(sampler, num_warmup=500, num_samples=1_000, num_chains=4)

rng_key, rng_subkey = jax.random.split(jax.random.PRNGKey(42))

# HSGP settings forwarded to hsgp_squared_exponential as `ell`, `m` and `non_centered`.
# NOTE(review): max|x_std| is about 1.7 here, so L=2.0 leaves a boundary factor of only ~1.17;
# HSGP guidance usually suggests >= 1.2-1.5 — consider increasing L.
L = 2.0
M = 30
non_centered = True

# Keyword arguments are forwarded to `model`, so they must match its signature exactly.
# `model` takes no `N` (bin sizes are already encoded in map_matrix), so passing N=N
# raises TypeError on a fresh kernel.
mcmc.run(
    rng_subkey,
    x=x_std,
    log_exposure=df_data['log_exposure'].values,
    map_matrix=mapping_matrix,
    L=L,
    M=M,
    non_centered=non_centered,
    y=y_agg,
)

  mcmc = MCMC(sampler, num_warmup=500, num_samples=1_000, num_chains=4)
sample: 100%|██████████| 1500/1500 [00:03<00:00, 390.46it/s, 12 steps of size 1.07e-01. acc. prob=0.84]
sample: 100%|██████████| 1500/1500 [00:02<00:00, 697.54it/s, 31 steps of size 1.20e-01. acc. prob=0.86]
sample: 100%|██████████| 1500/1500 [00:01<00:00, 852.76it/s, 7 steps of size 1.42e-01. acc. prob=0.72] 
sample: 100%|██████████| 1500/1500 [00:02<00:00, 531.55it/s, 25 steps of size 4.23e-02. acc. prob=0.89]


In [76]:
import arviz as az

# Wrap the fitted MCMC run in an ArviZ InferenceData object for summaries and diagnostics.
idata = az.from_numpyro(posterior=mcmc)

In [80]:
# Posterior summary (mean, sd, HDI, ESS, r_hat) of the deterministic per-point `rate` site.
df_sum = az.summary(idata, var_names=['rate'])
df_sum

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
rate[0],4.669,0.944,2.999,6.454,0.046,0.024,420.0,922.0,1.01
rate[1],4.385,0.831,2.853,5.921,0.036,0.019,517.0,1150.0,1.00
rate[2],4.134,0.743,2.765,5.502,0.030,0.016,620.0,1265.0,1.00
rate[3],3.907,0.672,2.732,5.208,0.025,0.014,704.0,1302.0,1.01
rate[4],3.697,0.614,2.592,4.838,0.022,0.013,749.0,1322.0,1.01
...,...,...,...,...,...,...,...,...,...
rate[95],0.621,0.211,0.235,0.989,0.010,0.006,451.0,600.0,1.01
rate[96],0.646,0.222,0.254,1.042,0.010,0.006,454.0,595.0,1.01
rate[97],0.675,0.233,0.256,1.095,0.011,0.007,453.0,631.0,1.01
rate[98],0.705,0.245,0.249,1.136,0.011,0.007,446.0,662.0,1.01


In [81]:
# Posterior mean of the per-point rate. az.summary lists rate[0], rate[1], ... in x order,
# so its rows line up with df_plot as long as the counts match.
assert len(df_sum) == len(df_plot), 'expected one summary row per x value'
df_plot['rate_est'] = df_sum['mean'].to_numpy()

true_rate_line = (
    alt.Chart(df_plot)
    .mark_line(color='red')
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate', title='Rate'),
    )
)

# Green dashed line: posterior mean rate from the aggregated-count model.
est_rate_line = (
    alt.Chart(df_plot)
    .mark_line(color='green', strokeDash=[5, 5])
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate_est', title='Estimated Rate'),
    )
)

count_points = alt.Chart(df_plot).mark_point().encode(
    x=alt.X('x', title='x'),
    y=alt.Y('y', title='Count'),
)

# Blue dashed line: empirical per-bin mean count (y_agg / N).
agg_rate_line = (
    alt.Chart(df_plot)
    .mark_line(color='blue', strokeDash=[5, 5])
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate_agg', title='Aggregated Rate'),
    )
)

(agg_rate_line + true_rate_line + count_points + est_rate_line).properties(width=600)