In [1]:
import numpy as np
import pandas as pd
import scipy as sp

import jax
import jax.numpy as jnp

import altair as alt

In [2]:
# Simulated data: Poisson counts whose log-rate is a smooth signal `f` plus a
# known offset `log_exposure`. A seeded Generator makes the draws reproducible
# under Restart & Run All (the legacy global np.random.poisson was unseeded).
SEED = 42
N_POINTS = 100

data_rng = np.random.default_rng(SEED)

x = np.linspace(0, 100, N_POINTS)
log_exposure = np.exp(-x / 10)                   # offset added on the log scale
f = 0.5 * np.sin(x / 20) + 0.5 * np.cos(x / 10)  # latent smooth signal
rate = np.exp(log_exposure + f)

# One count per x; the sample shape follows `rate` instead of a hard-coded size.
y = data_rng.poisson(rate)

df_data = pd.DataFrame({
    'x': x,
    'log_exposure': log_exposure,
    'f': f,
    'rate': rate,
    'y': y,
})

In [3]:
# True latent rate (red line) against the simulated counts (points).
# The chart objects get their own names so they do not overwrite the numpy
# `rate` array produced by the simulation cell.
rate_line = (
    alt.Chart(df_data)
    .mark_line(color='red')
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate', title='Rate'),
    )
)

count_points = alt.Chart(df_data).mark_point().encode(
    x=alt.X('x', title='x'),
    y=alt.Y('y', title='Count'),
)

(rate_line + count_points).properties(width=600)

In [4]:
# Bin x into 20 left-closed intervals and attach each row's bin totals.
# `transform` broadcasts the per-bin aggregates back onto the rows, so the cell
# can be re-run safely: re-running a `merge` would collide with the y_agg / N
# columns added on the previous run (suffixing them _x/_y) and then fail.
df_data['range'] = pd.cut(df_data['x'], bins=20, right=False)

y_by_bin = df_data.groupby('range', observed=True)['y']
df_data['y_agg'] = y_by_bin.transform('sum')   # total count in the row's bin
df_data['N'] = y_by_bin.transform('count')     # number of rows in the row's bin
df_data['rate_agg'] = df_data['y_agg'] / df_data['N']
df_data.tail(50)

Unnamed: 0,x,log_exposure,f,rate,y,range,y_agg,N,rate_agg
50,50.505051,0.006406,0.454882,1.586115,0,"[50.0, 55.0)",6,5,1.2
51,51.515152,0.005791,0.480635,1.626492,1,"[50.0, 55.0)",6,5,1.2
52,52.525253,0.005234,0.503537,1.663247,0,"[50.0, 55.0)",6,5,1.2
53,53.535354,0.004731,0.52319,1.695405,3,"[50.0, 55.0)",6,5,1.2
54,54.545455,0.004277,0.539223,1.722023,2,"[50.0, 55.0)",6,5,1.2
55,55.555556,0.003866,0.551297,1.742226,1,"[55.0, 60.0)",10,5,2.0
56,56.565657,0.003494,0.559112,1.755241,2,"[55.0, 60.0)",10,5,2.0
57,57.575758,0.003159,0.562404,1.760438,3,"[55.0, 60.0)",10,5,2.0
58,58.585859,0.002855,0.560956,1.757358,2,"[55.0, 60.0)",10,5,2.0
59,59.59596,0.002581,0.554595,1.745735,2,"[55.0, 60.0)",10,5,2.0


In [5]:
# Overlay the per-bin average count on the true rate and the raw counts.
# Altair cannot serialise pd.Interval values, so `range` is dropped for plotting.
df_plot = df_data.drop(columns=['range'])

# Chart objects use their own names instead of overwriting the numpy `rate`.
rate_line = (
    alt.Chart(df_plot)
    .mark_line(color='red')
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate', title='Rate'),
    )
)

count_points = alt.Chart(df_plot).mark_point().encode(
    x=alt.X('x', title='x'),
    y=alt.Y('y', title='Count'),
)

rate_agg_line = (
    alt.Chart(df_plot)
    .mark_line(color='blue', strokeDash=[5, 5])  # blue dashed to distinguish from the true rate
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate_agg', title='Aggregated Rate'),
    )
)

(rate_agg_line + rate_line + count_points).properties(width=600)

In [6]:
import numpyro
from numpyro.contrib.hsgp.approximation import hsgp_squared_exponential
from numpyro.contrib.hsgp.laplacian import eigenfunctions
from numpyro.contrib.hsgp.spectral_densities import (
	diag_spectral_density_squared_exponential
)
import numpyro.distributions as dist
from numpyro.infer import SVI, MCMC, NUTS, Predictive

In [7]:
def create_interval_index_array(intervals_series, grid_start=0, grid_end=None, grid_step=1):
    """
    Create an index array that maps points on a regular grid to intervals.

    Parameters:
    -----------
    intervals_series : pandas.Series or pandas.Categorical
        Interval-valued categorical data, e.g. the result of ``pd.cut``. Both a
        Series with categorical dtype and a bare ``pd.Categorical`` (what
        ``pd.cut`` returns for a numpy array) are accepted. Each interval's own
        closedness (``[left, right)``, ``(left, right]``, ...) is respected.
    grid_start : int, default 0
        The first point of the grid.
    grid_end : int, optional
        The last point of the grid (inclusive). If None, the floor of the
        largest right edge among the categories is used.
    grid_step : int, default 1
        The step size for the grid.

    Returns:
    --------
    numpy.ndarray
        An integer array with one element per grid point, holding the position
        (categorical code) of the category that contains that grid point.
        Points not in any interval get index -1. If categories overlap, the
        later category wins.

    Raises:
    -------
    ValueError
        If ``grid_end`` is None and there are no categories to infer it from.

    Example:
    --------
    >>> x = np.linspace(0, 10, 11)
    >>> intervals = pd.cut(x, bins=[0, 5, 11], right=False)
    >>> create_interval_index_array(intervals, 0, 10)
    array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    """
    # Wrapping in pd.Series gives a uniform `.cat` accessor for both Series and
    # Categorical input (a bare Categorical has no `.cat`).
    categories = pd.Series(intervals_series).cat.categories

    # Determine grid bounds if not provided
    if grid_end is None:
        if len(categories) == 0:
            raise ValueError("grid_end must be given when there are no interval categories")
        grid_end = int(np.floor(max(interval.right for interval in categories)))

    # Inclusive grid: grid_end is part of the grid.
    grid_points = np.arange(grid_start, grid_end + grid_step, grid_step)

    # -1 marks grid points that no interval contains.
    index_array = np.full(len(grid_points), -1, dtype=int)

    for code, interval in enumerate(categories):
        # Honour the interval's own closedness instead of assuming [left, right).
        if interval.closed_left:
            above_left = grid_points >= interval.left
        else:
            above_left = grid_points > interval.left
        if interval.closed_right:
            below_right = grid_points <= interval.right
        else:
            below_right = grid_points < interval.right
        index_array[above_left & below_right] = code

    return index_array


def create_mapping_matrix(intervals_series, grid_start=0, grid_end=None, grid_step=1):
    """
    Create a 0/1 matrix that sums values on a regular grid into intervals.

    Parameters:
    -----------
    intervals_series : pandas.Series or pandas.Categorical
        Interval-valued categorical data (e.g. the result of ``pd.cut``). A
        Series with categorical dtype or a bare ``pd.Categorical`` is accepted.
    grid_start : int, default 0
        The first point of the grid.
    grid_end : int, optional
        The last point of the grid (inclusive). If None, inferred from the
        intervals by ``create_interval_index_array``.
    grid_step : int, default 1
        The step size for the grid.

    Returns:
    --------
    numpy.ndarray
        An integer matrix of shape (n_intervals, n_grid_points), with one row
        per category (in category order). Entry [i, j] is 1 when grid point j
        lies in interval i and 0 otherwise; grid points outside every interval
        give an all-zero column. Use it for aggregation:
        ``aggregated_values = matrix @ grid_values``.
    """
    # Interval code of every grid point (-1 = not covered).
    index_array = create_interval_index_array(intervals_series, grid_start, grid_end, grid_step)

    # pd.Series(...) so a bare Categorical also exposes `.cat`.
    n_intervals = len(pd.Series(intervals_series).cat.categories)
    n_grid_points = len(index_array)

    mapping_matrix = np.zeros((n_intervals, n_grid_points), dtype=int)

    # Vectorised fill: a single 1 per covered grid point, in its interval's row.
    covered = np.flatnonzero(index_array >= 0)
    mapping_matrix[index_array[covered], covered] = 1

    return mapping_matrix

In [8]:
# Standardize x (zero mean, unit standard deviation) before it is used as GP input.
x_std = (x - np.mean(x)) / np.std(x)

# One grid column per observation: the model multiplies this matrix by the
# per-observation rate vector, so the grid must have exactly len(x) points.
mapping_matrix = create_mapping_matrix(
    df_data['range'],
    grid_start=0,
    grid_end=len(x_std) - 1,
    grid_step=1,
)

# The grid is the integers 0..len(x)-1, while the observations sit at
# x = linspace(0, 100, 100), which are not integers. Verify that the integer
# grid still puts every observation in the same bin pd.cut assigned to it;
# otherwise the aggregation in the model would silently mix up bins.
assert mapping_matrix.shape == (len(df_data['range'].cat.categories), len(x_std))
assert (mapping_matrix.sum(axis=0) == 1).all()
np.testing.assert_array_equal(
    mapping_matrix.argmax(axis=0),
    df_data['range'].cat.codes.to_numpy(),
)

In [9]:
def model(x, log_exposure, N, map_matrix, L, M, non_centered, y=None):
    """
    Poisson model for counts observed only as per-bin totals.

    A latent log-rate ``log_exposure + (baseline + f(x))`` is defined for every
    observation, where ``f`` comes from ``hsgp_squared_exponential`` (an HSGP
    approximation of a squared-exponential GP). The per-observation rates are
    summed into bins with ``map_matrix`` and those bin totals are the Poisson
    means of ``y``.

    Parameters
    ----------
    x : array
        GP inputs, one per observation.
    log_exposure : array
        Offset added to the log-rate, one per observation.
    N : array
        Accepted for interface compatibility; not used in the model body.
    map_matrix : array
        Aggregation matrix; its last dimension must match the number of
        observations because it is applied as ``map_matrix @ rate``.
    L, M :
        Passed to ``hsgp_squared_exponential`` as ``ell`` and ``m``.
    non_centered : bool
        Passed through to ``hsgp_squared_exponential``.
    y : array, optional
        Observed bin totals; None leaves the 'y' site unobserved.

    Sites: 'baseline', 'sigma', 'lenscale', 'y' (sampled) and 'rate'
    (deterministic, per-observation rate).
    """
    # Priors, sampled in the original order.
    baseline = numpyro.sample('baseline', dist.Normal(0, 10))
    amplitude = numpyro.sample('sigma', dist.InverseGamma(5, 5))
    length_scale = numpyro.sample('lenscale', dist.InverseGamma(5, 5))

    # Latent function values at each x.
    gp_effect = hsgp_squared_exponential(
        x=x,
        alpha=amplitude,
        length=length_scale,
        ell=L,
        m=M,
        non_centered=non_centered,
    )

    # Per-observation rate, then summed into bins for the likelihood.
    per_obs_rate = numpyro.deterministic('rate', jnp.exp(log_exposure + (baseline + gp_effect)))
    numpyro.sample('y', dist.Poisson(map_matrix @ per_obs_rate), obs=y)

In [10]:
# Per-bin training targets. Grouping by the categorical `range` gives one row
# per category in category order -- the same row order as `mapping_matrix`.
# `observed=False` is passed explicitly: leaving it implicit triggers a pandas
# FutureWarning, and keeping unobserved categories (as rows with y_agg=0, N=0)
# keeps the targets aligned with the mapping-matrix rows even if a bin is empty.
df_train = df_data.groupby('range', observed=False).agg(
    y_agg=('y', 'sum'),
    N=('y', 'count'),
).reset_index()

assert len(df_train) == mapping_matrix.shape[0], "targets must align with mapping_matrix rows"

y_agg = df_train['y_agg'].to_numpy()
N = df_train['N'].to_numpy()

  df_train = df_data.groupby('range').agg(


In [11]:
# Sanity-check the GP input before it reaches JAX: a plain 1-D numpy array with
# one entry per observation, matching the mapping-matrix columns.
assert isinstance(x_std, np.ndarray)
assert x_std.shape == (len(df_data),)
assert mapping_matrix.shape[1] == x_std.shape[0]
print(f"Type: {type(x_std)}")

Type: <class 'numpy.ndarray'>


In [12]:
# HSGP approximation settings (passed to the model as `ell` and `m`).
L = 2.0
M = 30
non_centered = True

sampler = NUTS(model)
# NumPyro runs chains in parallel only when enough devices are available (see
# numpyro.set_host_device_count); otherwise it falls back to sequential chains
# with a warning. Request sequential chains explicitly.
mcmc = MCMC(
    sampler,
    num_warmup=500,
    num_samples=1_000,
    num_chains=4,
    chain_method='sequential',
)

rng_key, rng_subkey = jax.random.split(jax.random.PRNGKey(42))

mcmc.run(
    rng_subkey,
    x=x_std,
    log_exposure=df_data['log_exposure'].to_numpy(),
    N=N,
    map_matrix=mapping_matrix,
    L=L,
    M=M,
    non_centered=non_centered,
    y=y_agg,
)

  mcmc = MCMC(sampler, num_warmup=500, num_samples=1_000, num_chains=4)
sample: 100%|██████████| 1500/1500 [00:04<00:00, 357.92it/s, 10 steps of size 1.02e-01. acc. prob=0.84]
sample: 100%|██████████| 1500/1500 [00:01<00:00, 800.24it/s, 13 steps of size 1.39e-01. acc. prob=0.77]
sample: 100%|██████████| 1500/1500 [00:01<00:00, 759.16it/s, 5 steps of size 1.57e-01. acc. prob=0.75] 
sample: 100%|██████████| 1500/1500 [00:02<00:00, 696.98it/s, 10 steps of size 7.36e-02. acc. prob=0.86]


In [13]:
import arviz as az
# Convert the fitted MCMC object into an ArviZ InferenceData container.
idata = az.from_numpyro(posterior=mcmc)

In [14]:
# Posterior summary table (mean, sd, HDI, MCSE, ESS, r_hat) for the 'rate' site.
df_sum = az.summary(idata, var_names=['rate'])
df_sum

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
rate[0],3.149,0.715,1.944,4.508,0.026,0.015,684.0,963.0,1.00
rate[1],2.895,0.625,1.864,4.149,0.021,0.013,825.0,943.0,1.00
rate[2],2.678,0.554,1.746,3.800,0.017,0.011,958.0,974.0,1.01
rate[3],2.490,0.498,1.574,3.432,0.015,0.010,1044.0,947.0,1.01
rate[4],2.326,0.452,1.506,3.197,0.014,0.010,1068.0,1073.0,1.00
...,...,...,...,...,...,...,...,...,...
rate[95],0.368,0.148,0.127,0.657,0.007,0.004,471.0,799.0,1.00
rate[96],0.390,0.157,0.134,0.689,0.007,0.005,502.0,1081.0,1.00
rate[97],0.416,0.167,0.141,0.731,0.007,0.005,530.0,1151.0,1.00
rate[98],0.445,0.178,0.148,0.771,0.008,0.005,543.0,1094.0,1.00


In [15]:
# Posterior-mean rate per observation, taken straight from the draws rather than
# from the rounded `mean` column of the summary table (and without relying on
# the summary's row order).
rate_post_mean = idata.posterior['rate'].mean(dim=('chain', 'draw')).values
assert rate_post_mean.shape == (len(df_plot),)

# New frame instead of mutating df_plot in place.
df_plot_est = df_plot.assign(rate_est=rate_post_mean)

# Chart objects use their own names instead of overwriting the numpy `rate`.
rate_line = (
    alt.Chart(df_plot_est)
    .mark_line(color='red')
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate', title='Rate'),
    )
)

rate_est_line = (
    alt.Chart(df_plot_est)
    .mark_line(color='green', strokeDash=[5, 5])  # green dashed: posterior-mean estimate
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate_est', title='Estimated Rate'),
    )
)

count_points = alt.Chart(df_plot_est).mark_point().encode(
    x=alt.X('x', title='x'),
    y=alt.Y('y', title='Count'),
)

rate_agg_line = (
    alt.Chart(df_plot_est)
    .mark_line(color='blue', strokeDash=[5, 5])  # blue dashed: per-bin average count
    .encode(
        x=alt.X('x', title='x'),
        y=alt.Y('rate_agg', title='Aggregated Rate'),
    )
)

(rate_agg_line + rate_line + count_points + rate_est_line).properties(width=600)