## Load Libraries

In [1]:
import numpy as np

import jax
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.initialization import init_to_value

from aggpp.sim import generate_data, assign_overlapping_bins
from aggpp.models import OverlapAggPP

# Visualization
import altair as alt

## Generate Data

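Simulate a dataset with `generate_data`: `df_true` holds the true rate that the estimates are compared against later, and `df_data` holds the individual observations that are aggregated into overlapping bins below.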

In [33]:
# Simulation settings handed to generate_data; the fixed seed keeps the run reproducible.
# NOTE(review): N is presumably the number of simulated observations -- confirm against aggpp.sim.
N = 1000
df_true, df_data = generate_data(N, M=2, seed=3)

In [34]:
# Array of the 'x' column of the observed data; passed to assign_overlapping_bins in the next cell.
x = df_data['x'].values
# (left, right) edges of every aggregation bin, written as pairs so the overlap
# between bins (e.g. (0, 5) and (0, 12)) is visible at a glance.
bin_edges = [
    (0, 5), (0, 12), (5, 15), (8, 25), (10, 20), (15, 30),
    (18, 35), (20, 40), (25, 45), (30, 50), (35, 55), (40, 60),
    (50, 70), (60, 80), (70, 90), (80, 95), (90, 100),
]
left_limits = [left for left, _right in bin_edges]
right_limits = [right for _left, right in bin_edges]

In [35]:
# Same sampling effort (1) for every bin.
# NOTE(review): how assign_overlapping_bins uses the effort weights is not visible here -- confirm.
sampling_effort = [1] * len(left_limits)
df_data['range'] = assign_overlapping_bins(x, left_limits, right_limits, sampling_effort)

# Per-bin totals: summed counts (y_agg) and number of observations (N) in each bin.
agg_by_range = (
    df_data
    .groupby('range', observed=True)
    .agg(y_agg=('y', 'sum'), N=('y', 'count'))
    .reset_index()
)

# Broadcast the per-bin totals back onto every observation.
# Aggregate columns left over from a previous run of this cell are dropped first:
# merging onto a frame that already has 'y_agg'/'N' would yield suffixed
# 'y_agg_x'/'y_agg_y' columns and the rate computation below would raise KeyError.
df_data = (
    df_data
    .drop(columns=['y_agg', 'N', 'rate_agg'], errors='ignore')
    .merge(agg_by_range, on='range', how='left')
)
df_data['rate_agg'] = df_data['y_agg'] / df_data['N']
df_data.head(5)

Unnamed: 0,x,log_exposure,f,rate,y,range,y_agg,N,rate_agg
0,19,-0.2375,2.868263,13.884363,9,"[18, 35)",196,53,3.698113
1,74,-0.925,3.140558,9.166525,9,"[60, 80)",393,80,4.9125
2,41,-0.5125,-0.476412,0.371981,0,"[40, 60)",24,74,0.324324
3,10,-0.125,2.729621,13.526093,13,"[10, 20)",1040,41,25.365854
4,21,-0.2625,2.028391,5.846777,6,"[8, 25)",685,48,14.270833


In [36]:
# One row per distinct (bin, aggregated rate) pair, ordered by bin.
bin_rates = df_data[['range', 'rate_agg']].drop_duplicates()
df_agg = bin_rates.sort_values('range').reset_index(drop=True)

# Pull each bin's endpoints out of the 'range' entries as float columns.
df_agg['x_left'] = df_agg['range'].apply(lambda interval: interval.left).astype(float)
df_agg['x_right'] = df_agg['range'].apply(lambda interval: interval.right).astype(float)

In [37]:
# Plotting frames without the 'range' column (and, for the observations,
# without the per-bin helper columns).
df_plot = df_data.drop(columns=['range', 'y_agg', 'N'])
df_agg_plot = df_agg.drop(columns='range')

# Both per-observation layers share the same x encoding.
obs_base = alt.Chart(df_plot).encode(x=alt.X('x', title='x'))

rate = obs_base.mark_line(color='#de425b').encode(
    y=alt.Y('rate', title='Rate'),
)

count = obs_base.mark_point(size=3).encode(
    y=alt.Y('y', title='Count'),
)

# One horizontal rule per bin, spanning x_left..x_right at the bin's aggregated rate.
segment = (
    alt.Chart(df_agg_plot)
    .mark_rule(color='#ff6361', size=1.5)
    .encode(
        x=alt.X('x_left', title='x'),  # same title as the other layers
        x2=alt.X2('x_right'),
        y=alt.Y('rate_agg', title='Rate'),
    )
)

# Layer everything on a shared x scale.
layered = rate + count + segment
layered.resolve_scale(x='shared').properties(width=600, height=200)

In [38]:
# Build the overlapping-bin model from the observed frame and the reference frame.
overlap_model = OverlapAggPP(df_data, df_true, L=7)

The fit below is expected to fail to recover the true rate accurately.

In [39]:
prng_key = jax.random.PRNGKey(42)

# Initialise the 'baseline' site at minus the mean log-exposure of df_true.
init_values = dict(baseline=-np.mean(df_true['log_exposure'].values))
init_strategy = init_to_value(values=init_values)

guide = AutoNormal(overlap_model.model, init_loc_fn=init_strategy)
overlap_model.run_inference_svi(prng_key, guide)

100%|██████████| 5000/5000 [00:33<00:00, 151.00it/s, init loss: 12159.3789, avg. loss [4751-5000]: 662.7852] 


In [40]:
# Use a key derived from prng_key instead of the exact key already passed to
# run_inference_svi: a JAX PRNG key should not be consumed twice, otherwise the
# two computations draw from correlated random streams. fold_in is deterministic,
# so re-running this cell gives the same result.
predict_key = jax.random.fold_in(prng_key, 1)
post_pred = overlap_model.posterior_predictive_svi(predict_key, guide)

# Rows of post_rate_sum: 2.5%, 50% and 97.5% quantiles of 'rate' taken along axis 0.
# NOTE(review): axis 0 is presumably the posterior-draw axis -- confirm.
post_rate_sum = np.quantile(post_pred['rate'], q=(0.025, 0.5, 0.975), axis=0)

In [41]:
# Look up the quantile rows at each x of df_true.
# NOTE(review): this uses the x values directly as positions into the posterior
# arrays, so it assumes x is a valid integer index -- confirm.
x_idx = df_true['x'].values

# (column name, row in post_rate_sum): median, lower bound, upper bound.
for column, quantile_row in (('M', 1), ('CL', 0), ('CU', 2)):
    df_true[column] = post_rate_sum[quantile_row][x_idx]

In [42]:
# All three layers are drawn from df_true against the same x encoding.
truth_base = alt.Chart(df_true).encode(x=alt.X('x', title='x'))

# True rate (red).
rate = truth_base.mark_line(color='red').encode(
    y=alt.Y('rate', title='Intensity'),
)

# Posterior median (green).
rate_est = truth_base.mark_line(color='green').encode(
    y=alt.Y('M', title='Intensity'),
)

# Band between the lower (CL) and upper (CU) posterior quantiles.
bands = truth_base.mark_errorband(color='green').encode(
    y=alt.Y('CL', title='Intensity'),
    y2=alt.Y2('CU', title='Intensity'),
)

overlay = rate + rate_est + bands
overlay.properties(width=600, height=200)