## Load Libraries

In [1]:
import numpy as np
import pandas as pd

import jax

from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.initialization import init_to_value

from aggpp.sim import generate_data, assign_disjoint_bins
from aggpp.models import DisjointAggPP

# Visualization libraries
import altair as alt

## Define helpers

## Generate Data

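Simulate a dataset with `generate_data`, using `N = 1000` and a fixed seed. The call returns two data frames, `df_true` and `df_data`, which are used throughout the rest of the notebook.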

In [2]:
# First argument to generate_data — presumably the number of simulated observations; confirm against aggpp.sim.
N = 1000
# seed=42 is hard-coded, presumably so the generated data are reproducible across runs.
# The call returns two objects, bound here to df_true and df_data.
df_true, df_data = generate_data(N, seed=42)

In [3]:
# Line layer: the `rate` column of df_data against x.
rate = alt.Chart(df_data).mark_line(color='#de425b').encode(
	x=alt.X('x', title='x'),
	y=alt.Y('rate', title='Rate'),
)

# Point layer: the `y` column of df_data (axis title "Count") against x.
count = (
	alt.Chart(df_data)
	.mark_point(size=3)
	.encode(
		x=alt.X('x', title='x'),
		y=alt.Y('y', title='Count'),
	)
)

# Overlay the two layers in one chart; last expression so the chart is displayed.
alt.layer(rate, count).properties(width=600, height=200)

In [4]:
# Label every row with the bin that its x value falls into
# (the displayed labels look like "[80, 85)").
df_data['range'] = assign_disjoint_bins(df_data['x'].values)

# Per-bin summary: total of y and number of rows in each observed bin.
tmp = df_data.groupby('range', observed=True).agg(
	y_agg = ('y', 'sum'),
	N = ('y', 'count')
).reset_index()

# Drop aggregate columns left over from a previous execution of this cell
# before merging. Without this, re-running the cell makes merge() emit
# suffixed duplicates (y_agg_x / y_agg_y, N_x / N_y) and the rate_agg
# computation below then fails with a KeyError.
df_data = (
	df_data
	.drop(columns=['y_agg', 'N', 'rate_agg'], errors='ignore')
	.merge(tmp, on='range', how='left')
)

# Aggregated rate per bin: summed y divided by the number of rows in the bin.
df_data['rate_agg'] = df_data['y_agg'] / df_data['N']
df_data.head(5)

Unnamed: 0,x,log_exposure,f,rate,y,range,y_agg,N,rate_agg
0,82,-1.025,-0.836819,0.15539,0,"[80, 85)",4,35,0.114286
1,86,-1.075,-0.056807,0.32245,1,"[85, 90)",26,68,0.382353
2,74,-0.925,-0.315744,0.289169,0,"[70, 75)",34,42,0.809524
3,74,-0.925,-0.315744,0.289169,0,"[70, 75)",34,42,0.809524
4,87,-1.0875,0.165866,0.397869,0,"[85, 90)",26,68,0.382353


In [5]:
# Plotting copy of df_data without the `range` column, whose Interval
# objects are removed before handing the frame to the charts.
df_plot = df_data.drop(columns='range')

# Solid red line: the `rate` column against x.
rate = alt.Chart(df_plot).mark_line(color='red').encode(
	x=alt.X('x', title='x'),
	y=alt.Y('rate', title='Rate'),
)

# Small points: the `y` column (axis title "Count") against x.
count = (
	alt.Chart(df_plot)
	.mark_point(size=2)
	.encode(
		x=alt.X('x', title='x'),
		y=alt.Y('y', title='Count'),
	)
)

# Dashed orange line: the per-bin aggregated rate, styled differently
# so it can be told apart from the solid `rate` line.
rate_agg = alt.Chart(df_plot).mark_line(color='orange', strokeDash=[5, 5]).encode(
	x=alt.X('x', title='x'),
	y=alt.Y('rate_agg', title='Aggregated Rate'),
)

# Layer order is aggregated rate, then rate, then counts.
alt.layer(rate_agg, rate, count).properties(width=600, height=200)

The following model should fail to accurately infer the rate.

In [26]:
# Build the model from df_data (which by this point carries the `range`, `y_agg`, `N`
# and `rate_agg` columns) and df_true.
# NOTE(review): the meaning of L = 7 is not visible in this notebook — confirm against DisjointAggPP.
model = DisjointAggPP(df_data, df_true, L = 7)

In [27]:
# Fixed PRNG key for the inference run.
prng_key = jax.random.PRNGKey(42)

# Initial location for 'baseline' (presumably a sample site of model.model):
# the negated mean of df_true's log_exposure column.
init_values = dict(baseline=-df_true['log_exposure'].values.mean())

# AutoNormal guide whose initial locations come from init_values.
guide = AutoNormal(
	model.model,
	init_loc_fn=init_to_value(values=init_values),
)

model.run_inference_svi(prng_key, guide)

100%|██████████| 5000/5000 [00:39<00:00, 126.28it/s, init loss: 6953.7036, avg. loss [4751-5000]: 315.2757]


In [28]:
# Use a key derived from prng_key rather than prng_key itself: that exact key
# was already passed to run_inference_svi, and JAX PRNG keys should not be
# reused for a second, independent source of randomness.
post_pred = model.posterior_predictive_svi(jax.random.fold_in(prng_key, 1), guide)

In [29]:
# 2.5%, 50% and 97.5% quantiles of post_pred['rate'] taken along axis 0
# (presumably the posterior-draw axis — confirm against posterior_predictive_svi).
post_rate_sum = np.quantile(post_pred['rate'], q=(0.025, 0.5, 0.975), axis=0)

# The x values of df_true are used directly as positions into each quantile row.
x_idx = df_true['x'].values
lower, median, upper = post_rate_sum

df_true['M'] = median[x_idx]   # median
df_true['CL'] = lower[x_idx]   # lower bound of the 95% interval
df_true['CU'] = upper[x_idx]   # upper bound of the 95% interval

In [30]:
# Red line: the `rate` column of df_plot, with the y axis fixed to [0, 12].
rate = alt.Chart(df_plot).mark_line(color='red').encode(
	x=alt.X('x', title='x'),
	y=alt.Y('rate', title='Intensity', scale=alt.Scale(domain=(0, 12))),
)

# Green line: the posterior median stored in df_true['M'].
rate_est = alt.Chart(df_true).mark_line(color='green').encode(
	x=alt.X('x', title='x'),
	y=alt.Y('M', title='Intensity'),
)

# Green band spanning df_true['CL'] to df_true['CU'].
bands = (
	alt.Chart(df_true)
	.mark_errorband(color='green')
	.encode(
		x=alt.X('x', title='x'),
		y=alt.Y('CL', title='Intensity'),
		y2=alt.Y2('CU', title='Intensity'),
	)
)

# Layer order is rate, then the median estimate, then the interval band.
alt.layer(rate, rate_est, bands).properties(width=600, height=200)