# Type 2
##### Max: 2 overlapping groups
##### e.g. [18, 32), [20, 30)

In [1]:
import numpy as np
import pandas as pd
import scipy as sp

import jax
import jax.numpy as jnp

import altair as alt
SEED = 123

In [2]:
x = np.linspace(0, 100, 100)
log_exposure = np.exp(-x / 10)
f = 0.5 * np.sin(x / 20) + 0.5 * np.cos(x / 10)
rate = np.exp(log_exposure + f)

np.random.seed(SEED) ; y = np.random.poisson(rate, size=(100,))

df_data = pd.DataFrame({
	'x': x,
	'log_exposure': log_exposure,
	'f': f,
	'rate': rate,
	'y': y
})

print(f"Sum of y: {df_data['y'].sum()}")
df_data.head()

Sum of y: 152


Unnamed: 0,x,log_exposure,f,rate,y
0,0.0,1.0,0.5,4.481689,5
1,1.010101,0.903924,0.522693,4.164587,6
2,2.020202,0.817078,0.540251,3.885801,2
3,3.030303,0.738577,0.552686,3.637378,2
4,4.040404,0.667617,0.560064,3.413307,6


In [3]:
def expand_data(df, y_col):
    result = pd.DataFrame(columns=df.columns)

    for row_id, repeat in zip(range(df.shape[0]), df[y_col]):
        row = df.loc[[row_id]]
        if int(repeat) == 0:
            result = pd.concat([result, row], ignore_index=True)
            continue
        duplicated_rows = pd.concat([row] * repeat, ignore_index=True)
        result = pd.concat([result, duplicated_rows], ignore_index=True)

    return result

In [4]:
expanded_df = expand_data(df_data, 'y')
expanded_df.tail()

  result = pd.concat([result, duplicated_rows], ignore_index=True)


Unnamed: 0,x,log_exposure,f,rate,y
174,95.959596,6.8e-05,-0.990862,0.371282,0
175,96.969697,6.1e-05,-0.976969,0.376474,1
176,97.979798,5.6e-05,-0.956903,0.384102,1
177,98.989899,5e-05,-0.930837,0.394243,0
178,100.0,4.5e-05,-0.898998,0.406996,0


In [5]:
rate = (
  alt.Chart(df_data)
  .mark_line(color='red')
  .encode(
		x=alt.X('x', title='x'),
		y=alt.Y('rate', title='Rate'),
	)
)

count = alt.Chart(df_data).mark_point().encode(
	x=alt.X('x', title='x'),
	y=alt.Y('y', title='Count'),
)

(rate + count).properties(width=600)


In [6]:
def assign_bin(x, y, bins, seed=123):
    np.random.seed(seed)
    cum_bins_id = []
    for age, count in zip(x, y):
        temp_bins = []
        for bin_id, (lower, upper) in enumerate(bins):
            if lower <= age < upper:
                temp_bins.append(bin_id)
        if not temp_bins:
            if count == 0:
                cum_bins_id.append(pd.NA) 
            else:
                cum_bins_id.extend([pd.NA] * count)
        else:
            if count == 0:
                cum_bins_id.append(np.random.choice(temp_bins))
            else:
                cum_bins_id.extend(np.random.choice(temp_bins, size=count))
    cum_bins = []
    for bin_id in cum_bins_id:
        if pd.isna(bin_id):
            cum_bins.append(pd.NA)
        else:
            cum_bins.append(bins[bin_id])
    return cum_bins

In [7]:
type_2_bins = [(0, 10), (0,50), (10, 20), (20, 30), (30, 40), (40, 50), (50, 60), (50, 100.1), (60, 70), (70, 80), (80, 90), (90, 100.1)]
expanded_df["bin"] = assign_bin(x, y, type_2_bins)
expanded_df.head(15)


Unnamed: 0,x,log_exposure,f,rate,y,bin
0,0.0,1.0,0.5,4.481689,5,"(0, 10)"
1,0.0,1.0,0.5,4.481689,5,"(0, 50)"
2,0.0,1.0,0.5,4.481689,5,"(0, 10)"
3,0.0,1.0,0.5,4.481689,5,"(0, 10)"
4,0.0,1.0,0.5,4.481689,5,"(0, 10)"
5,1.010101,0.903924,0.522693,4.164587,6,"(0, 10)"
6,1.010101,0.903924,0.522693,4.164587,6,"(0, 10)"
7,1.010101,0.903924,0.522693,4.164587,6,"(0, 50)"
8,1.010101,0.903924,0.522693,4.164587,6,"(0, 50)"
9,1.010101,0.903924,0.522693,4.164587,6,"(0, 10)"


In [8]:
# Keeping unique rows
expanded_df = expanded_df.drop_duplicates(subset=['x', 'log_exposure', 'f', 'rate', 'y', 'bin']).reset_index(drop=True)
expanded_df['y'] = np.asarray(expanded_df['y'].values, dtype=np.int64)

print(expanded_df.shape)
expanded_df.head(15)

(130, 6)


Unnamed: 0,x,log_exposure,f,rate,y,bin
0,0.0,1.0,0.5,4.481689,5,"(0, 10)"
1,0.0,1.0,0.5,4.481689,5,"(0, 50)"
2,1.010101,0.903924,0.522693,4.164587,6,"(0, 10)"
3,1.010101,0.903924,0.522693,4.164587,6,"(0, 50)"
4,2.020202,0.817078,0.540251,3.885801,2,"(0, 50)"
5,2.020202,0.817078,0.540251,3.885801,2,"(0, 10)"
6,3.030303,0.738577,0.552686,3.637378,2,"(0, 50)"
7,3.030303,0.738577,0.552686,3.637378,2,"(0, 10)"
8,4.040404,0.667617,0.560064,3.413307,6,"(0, 50)"
9,4.040404,0.667617,0.560064,3.413307,6,"(0, 10)"


In [9]:
# y_agg: the number of counts in each age range
# N: the number of non-missing values in each age range
tmp = expanded_df.groupby('bin', observed=True).agg(
	y_agg = ('y', 'sum'),
	N = ('y', 'count')
).reset_index()

y_agg = np.asarray(tmp['y_agg'].values, dtype=np.int64)
tmp['y_agg'] = y_agg
tmp

Unnamed: 0,bin,y_agg,N
0,"(0, 10)",34,9
1,"(0, 50)",90,39
2,"(10, 20)",20,8
3,"(20, 30)",10,7
4,"(30, 40)",3,3
5,"(40, 50)",15,7
6,"(50, 60)",12,6
7,"(50, 100.1)",36,29
8,"(60, 70)",13,7
9,"(70, 80)",8,5


In [10]:
# rate_agg: the avg number of counts per non-missing age associated with each age range
expanded_df = expanded_df.merge(tmp, on='bin', how='left')
expanded_df['rate_agg'] = expanded_df['y_agg'] / expanded_df['N']
expanded_df['rate_agg'] = np.asarray(expanded_df['rate_agg'].values, dtype=np.float64)
expanded_df.head(10)

Unnamed: 0,x,log_exposure,f,rate,y,bin,y_agg,N,rate_agg
0,0.0,1.0,0.5,4.481689,5,"(0, 10)",34,9,3.777778
1,0.0,1.0,0.5,4.481689,5,"(0, 50)",90,39,2.307692
2,1.010101,0.903924,0.522693,4.164587,6,"(0, 10)",34,9,3.777778
3,1.010101,0.903924,0.522693,4.164587,6,"(0, 50)",90,39,2.307692
4,2.020202,0.817078,0.540251,3.885801,2,"(0, 50)",90,39,2.307692
5,2.020202,0.817078,0.540251,3.885801,2,"(0, 10)",34,9,3.777778
6,3.030303,0.738577,0.552686,3.637378,2,"(0, 50)",90,39,2.307692
7,3.030303,0.738577,0.552686,3.637378,2,"(0, 10)",34,9,3.777778
8,4.040404,0.667617,0.560064,3.413307,6,"(0, 50)",90,39,2.307692
9,4.040404,0.667617,0.560064,3.413307,6,"(0, 10)",34,9,3.777778


In [11]:
# Create a clean dataset without the Interval objects for plotting
df_plot = expanded_df.drop(columns=['bin'])

rate = (
  alt.Chart(df_plot)
  .mark_line(color='red')
  .encode(
		x=alt.X('x', title='x'),
		y=alt.Y('rate', title='Rate'),
	)
)

count = alt.Chart(df_plot).mark_point().encode(
	x=alt.X('x', title='x'),
	y=alt.Y('y', title='Count'),
)

rate_agg = (
  alt.Chart(df_plot)
  .mark_line(color='blue', strokeDash=[5, 5])  # Changed color and added dashed line to distinguish
  .encode(
		x=alt.X('x', title='x'),
		y=alt.Y('rate_agg', title='Aggregated Rate'),
	)
)

(rate_agg + rate + count).properties(width=600)

In [12]:
import numpyro
from numpyro.contrib.hsgp.approximation import hsgp_squared_exponential
from numpyro.contrib.hsgp.laplacian import eigenfunctions
from numpyro.contrib.hsgp.spectral_densities import (
	diag_spectral_density_squared_exponential
)
import numpyro.distributions as dist
from numpyro.infer import SVI, MCMC, NUTS, Predictive

In [13]:
def create_interval_index_array(intervals_series, grid_start=0, grid_end=None, grid_step=1):
    """
    Create an index array that maps intervals to points on a fine integer grid.
    
    Parameters:
    -----------
    intervals_series : pandas.Series
        A pandas Series of categorical dtype containing pd.Interval objects
        that are closed on the left and open on the right [left, right).
    grid_start : int, default 0
        The starting point of the integer grid.
    grid_end : int, optional
        The ending point of the integer grid. If None, inferred from intervals.
    grid_step : int, default 1
        The step size for the integer grid.
    
    Returns:
    --------
    numpy.ndarray
        An index array where each element corresponds to a grid point and
        contains the index of the interval that contains that grid point.
        Points not in any interval get index -1.
    
    Example:
    --------
    >>> import pandas as pd
    >>> import numpy as np
    >>> 
    >>> # Create some intervals
    >>> x = np.linspace(0, 10, 11)
    >>> intervals = pd.cut(x, bins=3, right=False)
    >>> 
    >>> # Create index array for fine grid from 0 to 10
    >>> index_array = create_interval_index_array(intervals, 0, 10)
    >>> print(index_array)
    """
    
    # Get unique intervals and their categorical codes
    unique_intervals = intervals_series.cat.categories
    interval_codes = intervals_series.cat.codes
    
    # Determine grid bounds if not provided
    if grid_end is None:
        max_right = max(interval.right for interval in unique_intervals)
        grid_end = int(np.floor(max_right))

    # Create the fine integer grid
    grid_points = np.arange(grid_start, grid_end + grid_step, grid_step)
    
    # Initialize index array with -1 (indicating no interval contains the point)
    index_array = np.full(len(grid_points), -1, dtype=int)
    
    # For each unique interval, find which grid points it contains
    for i, interval in enumerate(unique_intervals):
        # Find grid points that fall within this interval [left, right)
        mask = (grid_points >= interval.left) & (grid_points < interval.right)
        
        # Set the index array to the categorical code for this interval
        index_array[mask] = i
    
    return index_array


def create_mapping_matrix(intervals_series, grid_start=0, grid_end=None, grid_step=1):
    """
    Create a mapping matrix that aggregates values from a fine grid to intervals.
    
    Parameters:
    -----------
    intervals_series : pandas.Series
        A pandas Series of categorical dtype containing pd.Interval objects.
    grid_start : int, default 0
        The starting point of the integer grid.
    grid_end : int, optional
        The ending point of the integer grid. If None, inferred from intervals.
    grid_step : int, default 1
        The step size for the integer grid.
    
    Returns:
    --------
    numpy.ndarray
        A matrix of shape (n_intervals, n_grid_points) where each row corresponds
        to an interval and contains 1s for grid points within that interval.
        This can be used for aggregation: aggregated_values = matrix @ grid_values
    """
    
    # Get the index array
    index_array = create_interval_index_array(intervals_series, grid_start, grid_end, grid_step)
    
    # Get number of unique intervals
    n_intervals = len(intervals_series.cat.categories)
    n_grid_points = len(index_array)
    
    # Create mapping matrix
    mapping_matrix = np.zeros((n_intervals, n_grid_points), dtype=int)
    
    # Fill the mapping matrix
    for grid_idx, interval_idx in enumerate(index_array):
        if interval_idx >= 0:  # Valid interval index
            mapping_matrix[interval_idx, grid_idx] = 1
    
    return mapping_matrix

In [14]:
expanded_df['bin'] = pd.IntervalIndex.from_arrays(
    [interval[0] for interval in expanded_df['bin']],
    [interval[1] for interval in expanded_df['bin']],
    closed='left'
)

expanded_df['bin'] = pd.Categorical(expanded_df['bin'])

In [15]:
# Standardize x
x_std = (x - np.mean(x)) / np.std(x)

mapping_matrix = create_mapping_matrix(
  expanded_df['bin'],
  grid_start=0,
  grid_end=99,
  grid_step=1
)

In [16]:
def model(x, log_exposure, N, map_matrix, L, M, non_centered, y=None):
  # --- Priors ---
  beta = numpyro.sample('baseline', dist.Normal(0, 10))
  sigma = numpyro.sample('sigma', dist.InverseGamma(5, 5))
  lenscale = numpyro.sample('lenscale', dist.InverseGamma(5, 5))
  
  # --- Parameterization ---
  f = hsgp_squared_exponential(
		x=x,
		alpha=sigma,
		length=lenscale,
		ell=L,
		m=M,
		non_centered=non_centered
	)
  
  # --- Likelihood ---
  log_rate = log_exposure + (beta + f)
  rate = numpyro.deterministic('rate', jnp.exp(log_rate))
  mu_agg = (map_matrix @ rate)
  numpyro.sample('y', dist.Poisson(mu_agg), obs=y)

In [17]:
df_train = expanded_df.groupby('bin').agg(
	y_agg=('y', 'sum'),
	N=('y', 'count')
).reset_index()

y_agg = df_train['y_agg'].values
N = df_train['N'].values

  df_train = expanded_df.groupby('bin').agg(


In [18]:
sampler = NUTS(model)
mcmc = MCMC(sampler, num_warmup=500, num_samples=1_000, num_chains=4)

rng_key, rng_subkey = jax.random.split(jax.random.PRNGKey(42))

L = 2.0
M = 30
non_centered = True

mcmc.run(rng_subkey, 
	x=x_std, 
	log_exposure=df_data['log_exposure'].values, 
	N=N, 
	map_matrix=mapping_matrix, 
	L=L, 
	M=M, 
	non_centered=non_centered, 
	y=y_agg
)

  mcmc = MCMC(sampler, num_warmup=500, num_samples=1_000, num_chains=4)


ValueError: Poisson distribution got invalid rate parameter.

In [None]:
import arviz as az
idata = az.from_numpyro(mcmc)

In [None]:
df_sum = az.summary(
	data=idata,
	var_names=['rate']
)
df_sum

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
rate[0],4.572,0.928,2.895,6.277,0.036,0.023,667.0,1363.0,1.00
rate[1],4.268,0.802,2.822,5.744,0.025,0.018,1046.0,1749.0,1.00
rate[2],4.011,0.709,2.704,5.319,0.020,0.014,1304.0,1706.0,1.00
rate[3],3.788,0.640,2.668,5.056,0.017,0.012,1456.0,1612.0,1.00
rate[4],3.592,0.587,2.533,4.714,0.016,0.011,1434.0,1604.0,1.00
...,...,...,...,...,...,...,...,...,...
rate[95],0.383,0.158,0.119,0.674,0.007,0.004,570.0,1177.0,1.01
rate[96],0.410,0.172,0.129,0.721,0.007,0.005,553.0,1144.0,1.01
rate[97],0.443,0.187,0.146,0.788,0.008,0.006,534.0,1066.0,1.01
rate[98],0.481,0.204,0.146,0.847,0.009,0.006,510.0,1029.0,1.00


In [None]:
df_plot['rate_est'] = df_sum['mean'].values

rate = (
  alt.Chart(df_plot)
  .mark_line(color='red')
  .encode(
		x=alt.X('x', title='x'),
		y=alt.Y('rate', title='Rate'),
	)
)

rate_est = (
  alt.Chart(df_plot)
	.mark_line(color='green', strokeDash=[5, 5])  # Changed
	.encode(
		x=alt.X('x', title='x'),
		y=alt.Y('rate_est', title='Estimated Rate'),
	)
)

count = alt.Chart(df_plot).mark_point().encode(
	x=alt.X('x', title='x'),
	y=alt.Y('y', title='Count'),
)

rate_agg = (
  alt.Chart(df_plot)
  .mark_line(color='blue', strokeDash=[5, 5])  # Changed color and added dashed line to distinguish
  .encode(
		x=alt.X('x', title='x'),
		y=alt.Y('rate_agg', title='Aggregated Rate'),
	)
)

(rate_agg + rate + count + rate_est).properties(width=600)

ValueError: Length of values (100) does not match length of index (116)