## Load Libraries

In [1]:
import numpy as np
import pandas as pd

import jax
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.initialization import init_to_value

from aggpp.sim import generate_data, assign_overlapping_bins
from aggpp.models import DisjointAggPP, OverlapAggPP

# Visualization
import altair as alt

## Generate Data

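Simulate a synthetic data set with `generate_data`, which returns two frames, `df_true` and `df_data`. The chart below overlays the `rate` column (line) and the counts `y` (points) from `df_data` against `x`.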

In [2]:
# First positional argument of generate_data.
# NOTE(review): presumably the number of simulated observations — confirm in aggpp.sim.
N = 1000
# The seed is fixed at 42; the meaning of M is defined by generate_data — TODO confirm.
df_true, df_data = generate_data(N, M = 2, seed=42)

In [3]:
# Line layer: the `rate` column of df_data against x.
rate = (
	alt.Chart(df_data)
	.mark_line(color='#de425b')
	.encode(
		x=alt.X('x', title='x'),
		y=alt.Y('rate', title='Rate'),
	)
)

# Point layer: the `y` column of df_data against x.
count = (
	alt.Chart(df_data)
	.mark_point(size=3)
	.encode(
		x=alt.X('x', title='x'),
		y=alt.Y('y', title='Count'),
	)
)

# Overlay both layers; as the last expression this chart is the cell's output.
alt.layer(rate, count).properties(width=600, height=200)

In [4]:

# Positions to be binned, one per row of df_data.
x = df_data['x'].values

# Bins of width 10 whose left edges advance in steps of 5, so neighbouring
# bins overlap by half their width.
left_limits = list(range(0, 95, 5))
right_limits = list(range(10, 101, 5))

print(left_limits)
print(right_limits)

# Every bin gets the same sampling effort.
sampling_effort = [1] * len(left_limits)
df_data['range'] = assign_overlapping_bins(x, left_limits, right_limits, sampling_effort)

# Per-bin totals: summed counts and number of rows assigned to each bin.
range_totals = (
	df_data
	.groupby('range', observed=True)
	.agg(
		y_agg=('y', 'sum'),
		N=('y', 'count'),
	)
	.reset_index()
)

# Drop derived columns left over from an earlier run of this cell first.
# Without this, a re-run makes merge() suffix the clashing columns
# (y_agg_x / y_agg_y, N_x / N_y) and the rate_agg line raises KeyError.
df_data = (
	df_data
	.drop(columns=['y_agg', 'N', 'rate_agg'], errors='ignore')
	.merge(range_totals, on='range', how='left')
)
df_data['rate_agg'] = df_data['y_agg'] / df_data['N']
df_data.head(5)

[0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90]
[10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100]


Unnamed: 0,x,log_exposure,f,rate,y,range,y_agg,N,rate_agg
0,82,-1.025,-0.836819,0.15539,0,"[80, 90)",16,51,0.313725
1,86,-1.075,-0.056807,0.32245,1,"[80, 90)",16,51,0.313725
2,74,-0.925,-0.315744,0.289169,0,"[70, 80)",12,37,0.324324
3,74,-0.925,-0.315744,0.289169,0,"[65, 75)",76,47,1.617021
4,87,-1.0875,0.165866,0.397869,0,"[85, 95)",32,62,0.516129


In [20]:
# Categorical accessor of the bin column, reused for both lookups below.
range_cat = df_data['range'].cat
# Sorted distinct integer codes that occur in the column.
codes = np.unique(range_cat.codes)
# All declared categories of the column.
categories = range_cat.categories

In [21]:
# Show the codes and the categories, one per line.
print(codes, categories, sep='\n')

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18]
IntervalIndex([  [0, 10),   [5, 15),  [10, 20),  [15, 25),  [20, 30),
                [25, 35),  [30, 40),  [35, 45),  [40, 50),  [45, 55),
                [50, 60),  [55, 65),  [60, 70),  [65, 75),  [70, 80),
                [75, 85),  [80, 90),  [85, 95), [90, 100)],
              dtype='interval[int64, left]')


In [None]:
def create_interval_dict(x: pd.Series):
	"""
	Create a dictionary mapping observed category codes to their intervals.

	Parameters
	----------
	x : pd.Series
		A categorical pandas Series (read through ``x.cat``) whose categories
		are the intervals.

	Returns
	-------
	dict: A dictionary where the keys are the category codes that occur in
				``x`` (in ascending order) and the values are the corresponding
				categories. The missing-value code ``-1`` is skipped.
	"""
	categories = x.cat.categories
	observed_codes = x.cat.codes.sort_values().unique()
	# Look each category up by its code. Zipping the observed codes with the
	# full category list pairs them by position, which goes wrong as soon as a
	# category is unobserved or the NaN code -1 is present.
	return {code: categories[code] for code in observed_codes if code >= 0}

def find_overlapping_intervals(interval_dict: dict):
	"""
	Record, for every interval, which other intervals it overlaps.

	Parameters
	----------
	interval_dict : dict
		Dictionary mapping interval codes to interval objects that provide an
		``overlaps(other)`` method.

	Returns
	-------
	dict: A dictionary where the keys are the interval codes and
				the values are lists of the other codes whose interval overlaps
				the key's interval, in the iteration order of ``interval_dict``.
				Codes without any overlap do not appear as keys.
	"""
	neighbours = {}
	for code, span in interval_dict.items():
		# Every pair is compared; an interval is never matched against itself.
		hits = [
			other_code
			for other_code, other_span in interval_dict.items()
			if other_code != code and span.overlaps(other_span)
		]
		if hits:
			neighbours[code] = hits
	return neighbours

def create_overlap_weights(interval_dict: dict,
							 overlap_dict: dict,
							 sampling_effort_dict: dict):
	"""
	Compute, per interval, the share of sampling effort it contributes at each
	unit step between its left and right bound.

	Parameters
	----------
	interval_dict : dict
		Dictionary mapping interval codes to intervals exposing ``left`` and
		``right``. The bounds are used as array sizes and slice indices.
	overlap_dict : dict
		Dictionary mapping interval codes to lists of overlapping interval
		codes. Codes missing from it are treated as having no overlaps.
	sampling_effort_dict : dict
		Dictionary mapping interval codes to their sampling effort values.

	Returns
	-------
	dict
		Dictionary mapping each interval code to a float array of length
		``right - left`` holding the interval's own effort divided by the
		summed effort of all intervals covering that position.
	"""
	weights_by_code = {}

	for code in interval_dict:
		span = interval_dict[code]
		lo, hi = span.left, span.right
		own_effort = sampling_effort_dict[code]

		# Start from the interval's own effort at every position it covers.
		pooled_effort = np.full(hi - lo, own_effort, dtype=float)

		for neighbour_code in overlap_dict.get(code, ()):
			neighbour = interval_dict[neighbour_code]
			# Shared stretch, expressed as offsets from this interval's left bound.
			start = max(lo, neighbour.left) - lo
			stop = min(hi, neighbour.right) - lo
			if start >= stop:
				continue
			pooled_effort[start:stop] += sampling_effort_dict[neighbour_code]

		# Own effort as a fraction of the pooled effort at each position.
		weights_by_code[code] = own_effort / pooled_effort

	return weights_by_code

In [None]:
# Unit sampling effort for every code in `codes`.
sampling_effort_dict = dict.fromkeys(codes, 1)

{0: 1,
 1: 1,
 2: 1,
 3: 1,
 4: 1,
 5: 1,
 6: 1,
 7: 1,
 8: 1,
 9: 1,
 10: 1,
 11: 1,
 12: 1,
 13: 1,
 14: 1,
 15: 1,
 16: 1,
 17: 1,
 18: 1}

In [31]:
# Code -> interval lookup for the bins stored in df_data['range'].
interval_dict = create_interval_dict(df_data['range'])
# For each code, the other codes whose interval overlaps it.
overlap_dict = find_overlapping_intervals(interval_dict)
# Per-interval arrays of own effort divided by total effort.
# NOTE(review): no visible cell reads overlap_weights afterwards — confirm it is still needed.
overlap_weights = create_overlap_weights(interval_dict, overlap_dict, sampling_effort_dict)

In [5]:
# Build the overlapping-bin model from the binned data and the df_true frame.
# NOTE(review): the meaning of L is defined by aggpp.models.OverlapAggPP — confirm there.
overlap_model = OverlapAggPP(df_data, df_true, L = 7)

The following model should fail to accurately infer the rate.

In [7]:
# PRNG key with a fixed seed.
prng_key = jax.random.PRNGKey(42)

# Initial value for 'baseline': minus the mean of df_true['log_exposure'].
init_values = {'baseline': -df_true['log_exposure'].values.mean()}
init_loc_fn = init_to_value(values=init_values)
guide = AutoNormal(overlap_model.model, init_loc_fn=init_loc_fn)

# NOTE(review): the output recorded for this cell is a broadcasting TypeError
# with shapes (100,) and (190,). It presumably originates inside
# run_inference_svi or the model — confirm before relying on later cells.
overlap_model.run_inference_svi(prng_key, guide)

TypeError: mul got incompatible shapes for broadcasting: (100,), (190,).

In [74]:
# FIXME(review): posterior_predictive_svi, model, svi and model_data are not
# defined or imported in any visible cell, so this would raise NameError on a
# fresh kernel unless they are defined elsewhere — confirm and wire this up
# to overlap_model.
post_pred = posterior_predictive_svi(prng_key, model, guide, svi.params, **model_data)
# 2.5 %, 50 % and 97.5 % quantiles of post_pred['rate'], taken along axis 0.
post_rate_sum = np.quantile(post_pred['rate'], q = (0.025, 0.5, 0.975), axis=0)

In [75]:
# Rows of post_rate_sum, in order: 2.5 % bound, median, 97.5 % bound.
lower_q, median_q, upper_q = post_rate_sum[0], post_rate_sum[1], post_rate_sum[2]
# The x values of df_true serve as positions into each quantile row.
x_positions = df_true['x'].values

df_true['M'] = median_q[x_positions]
df_true['CL'] = lower_q[x_positions]
df_true['CU'] = upper_q[x_positions]

In [76]:
# Display the fitted parameters held by `svi`.
# FIXME(review): `svi` is not defined in any visible cell — confirm where it comes from.
svi.params

{'baseline_auto_loc': Array(0.01706906, dtype=float32),
 'baseline_auto_scale': Array(0.02619283, dtype=float32),
 'beta_auto_loc': Array([ 1.2101327 ,  0.43130377, -0.661234  , -0.33806458, -1.6233759 ,
         1.7236325 ,  0.5317304 ,  0.24385883,  1.2212498 ,  0.17222165,
        -0.45071432, -0.8072674 , -0.36877945, -0.68363196, -0.9435648 ,
         0.8054031 ,  0.99164313, -0.5903625 , -0.60134375,  0.19065456,
        -0.7937727 ,  0.57289404, -0.7106124 ,  1.4395381 , -1.3907038 ,
        -1.2235905 , -0.4283401 , -0.6508319 ,  0.5131671 ,  0.65669364],      dtype=float32),
 'beta_auto_scale': Array([0.07351285, 0.24138463, 0.08228104, 0.13210855, 0.0922631 ,
        0.09694765, 0.10730568, 0.09074406, 0.10360788, 0.10751261,
        0.09498505, 0.12378865, 0.08576813, 0.12131248, 0.09455027,
        0.10784619, 0.10438012, 0.09633049, 0.11039106, 0.09383176,
        0.11243469, 0.09989084, 0.10551156, 0.10110452, 0.0978139 ,
        0.10103423, 0.10059325, 0.10224608, 0.1079

In [None]:
# Line layer: the `rate` column of df_plot, y-axis fixed to the range 0-12.
# FIXME(review): df_plot is not defined in any visible cell — confirm which
# frame is meant here.
rate = alt.Chart(df_plot).mark_line(color='red').encode(
	x=alt.X('x', title='x'),
	y=alt.Y('rate', title='Intensity', scale=alt.Scale(domain=(0, 12))),
)

# Line layer: the `M` column of df_true.
rate_est = alt.Chart(df_true).mark_line(color='green').encode(
	x=alt.X('x', title='x'),
	y=alt.Y('M', title='Intensity'),
)

# Band layer spanning the `CL` and `CU` columns of df_true.
bands = (
	alt.Chart(df_true)
	.mark_errorband(color='green')
	.encode(
		x=alt.X('x', title='x'),
		y=alt.Y('CL', title='Intensity'),
		y2=alt.Y2('CU', title='Intensity'),
	)
)

# Overlay the three layers; as the last expression this chart is the cell's output.
(rate + rate_est + bands).properties(width=600, height=200)