In [1]:
import os
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
from inference.spatial_compartmental.viz import  posterior_comparison, post_prediction_comparison
import jax
jax.config.update("jax_enable_x64", True) #64 bit precision calcs
jax.config.update('jax_platform_name', 'cpu')
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import gc
from abm.spatial_compartmental.sir import get_abm
from abm.spatial_compartmental.utils import Neighbourhood, calc_start_n_initial
import pickle
import arviz as az
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [22]:
def vectorized_abm_state_counts(abm_grids_batch):
    """
    Count Susceptible / Infected / Recovered cells in a batch of ABM grid time series.

    Each cell's state is read from three layers on the third-from-last axis
    (layer index 0, 1, 2). A cell is counted as:
      - Susceptible when (layer0, layer1, layer2) == (0, 1, 0)
      - Infected    when (layer0, layer1, layer2) == (1, 0, 0)
      - Recovered   when (layer0, layer1, layer2) == (0, 0, 1)
    Cells matching none of these patterns are not counted in any state.

    Args:
        abm_grids_batch (jnp.ndarray): array of shape (..., 3, H, W). In this
            notebook it is (batch_size, steps, 3, H, W), i.e. every time step of
            every simulation in the batch.

    Returns:
        dict: maps "Susceptible", "Infected" and "Recovered" (in that order) to
              int32 JAX arrays of shape (...) -- (batch_size, steps) for a 5-D
              input -- holding the per-grid, per-time-step count of that state.
    """
    grids = jnp.asarray(abm_grids_batch)
    layer0 = grids[..., 0, :, :]
    layer1 = grids[..., 1, :, :]
    layer2 = grids[..., 2, :, :]

    # Layer values (layer0, layer1, layer2) that identify each state.
    state_patterns = {
        "Susceptible": (0, 1, 0),
        "Infected": (1, 0, 0),
        "Recovered": (0, 0, 1),
    }

    counts = {}
    for state, (v0, v1, v2) in state_patterns.items():
        mask = (layer0 == v0) & (layer1 == v1) & (layer2 == v2)
        # Sum over the spatial (H, W) axes, leaving any leading batch/time axes.
        counts[state] = jnp.sum(mask, axis=(-2, -1)).astype(jnp.int32)
    return counts

In [47]:
#NOTE: Requires 64-bit precision
def calculate_NLPD(key, abm_conf, posterior_preds, grid_size, num_steps):
    """
    Average negative log predictive density (NLPD) of fresh ABM simulations under
    a Gaussian KDE fitted to posterior-predictive samples.

    Args:
        key: JAX PRNG key. One split seeds the KDE jitter; the other is split again
            into one key per posterior sample for the batched ABM run.
        abm_conf: indexable parameter set. Entries 0-3 are passed straight to the
            ABM; entries 4-6 are first converted with ``calc_start_n_initial``.
        posterior_preds: array of shape (n_samples, 2 * num_steps). The columns are
            assumed to follow the same layout as the simulated observations built
            below: Infected counts, then Recovered counts, with the initial time
            step dropped.
        grid_size: grid size passed to the ABM and to ``calc_start_n_initial``.
        num_steps: number of ABM steps to simulate.

    Returns:
        Scalar mean over the n_samples simulations of -log(KDE density).

    Raises:
        ValueError: if ``posterior_preds`` is not 2-D with ``2 * num_steps``
            columns, or if the resulting NLPD is NaN.
    """
    # Fail early with a clear message instead of an opaque broadcasting /
    # dimension error from the jitter add or the KDE evaluation.
    if posterior_preds.ndim != 2 or posterior_preds.shape[1] != 2 * num_steps:
        raise ValueError(
            f'posterior_preds must have shape (n_samples, {2 * num_steps}); '
            f'got {tuple(posterior_preds.shape)}')
    n_samples = posterior_preds.shape[0]

    key, jitter_key = jax.random.split(key)
    samples_t = posterior_preds.T  # gaussian_kde expects (n_features, n_samples)
    # Tiny jitter, presumably to keep the KDE covariance non-singular when a
    # count feature is constant across samples; it is small enough that float64
    # is needed for it to register (see NOTE above).
    jitter = 1e-6 * jax.random.normal(jitter_key, samples_t.shape)
    kde_est = jax.scipy.stats.gaussian_kde(samples_t + jitter)

    abm = get_abm(Neighbourhood.VONNEUMANN, vmap=True)
    # One simulation per key, i.e. as many simulations as posterior samples.
    multi_grid_timeseries = abm(jax.random.split(key, n_samples),
                                grid_size,
                                num_steps,
                                abm_conf[0],
                                abm_conf[1],
                                abm_conf[2],
                                abm_conf[3],
                                calc_start_n_initial(abm_conf[4], grid_size),
                                calc_start_n_initial(abm_conf[5], grid_size),
                                calc_start_n_initial(abm_conf[6], grid_size))
    state_count_timeseries = vectorized_abm_state_counts(multi_grid_timeseries)
    # Drop the first time step and concatenate Infected then Recovered counts;
    # Susceptible counts are deliberately excluded.
    ground_truth_observations = jnp.concatenate((state_count_timeseries['Infected'][:, 1:],
                                                 state_count_timeseries['Recovered'][:, 1:]),
                                                axis=1)

    log_densities = kde_est.logpdf(ground_truth_observations.astype(jnp.float64).T)
    averaged_nlpd = -np.mean(log_densities)

    # NOTE(review): only NaN is rejected; a -inf log density yields +inf, which is returned.
    if np.isnan(averaged_nlpd):
        raise ValueError(f'NLPD for {abm_conf} contains NaNs')
    return averaged_nlpd

## ABM NLPDs (negative log predictive density)

In [None]:
# Posterior-predictive samples from the three ABM-based inference runs.
# These are all cat(infected,recovered).
# Each .npy file stores a pickled dict, hence allow_pickle=True and .item().
# np.load is used directly: jnp.load would try (and fail) to turn the object
# array into a JAX array. Unpickling can execute code, so only load files
# produced by our own inference pipeline.
abm_a, abm_b, abm_c = (
    np.load(f'inference_outputs/2025-06-12/{run}/post_predictions.npy',
            allow_pickle=True).item()
    for run in ('07-45-52', '07-52-12', '07-54-50')
)

In [49]:
# NLPD of each ABM run: fit a KDE on 1000 posterior-predictive draws and score
# fresh ABM simulations (10x10 grid, 30 steps) under it.
for idx, (name, abm_dict) in enumerate(zip(['A', 'B', 'C'], [abm_a, abm_b, abm_c])):
    # Only the first (presumably only) entry is used: its key is the ABM config,
    # its value holds the posterior-predictive samples under 'obs'.
    abm_conf, abm_outputs = next(iter(abm_dict.items()))

    abm_preds = jax.random.choice(jax.random.fold_in(jax.random.key(72), idx),
                                  abm_outputs['obs'],
                                  shape=(1000,), replace=False)

    try:
        nlpd = calculate_NLPD(jax.random.fold_in(jax.random.key(1337), idx),
                              abm_conf,
                              abm_preds,
                              grid_size=10, num_steps=30)
    except Exception as err:
        # Report the failure and keep going, but let Ctrl-C / SystemExit propagate
        # (a bare `except:` would swallow them) and show why the run failed.
        print(f'ABM {name} NLPD: NaN ({type(err).__name__}: {err})')
    else:
        print(f'ABM {name} NLPD: {nlpd:.3e}')

ABM A NLPD: 1.217e+02
ABM B NLPD: NaN
ABM C NLPD: 1.576e+02


## MCMLP NLPDs (negative log predictive density)

In [None]:

# Posterior-predictive samples from the three MCMLP-based inference runs.
# Each .npy file stores a pickled dict, hence allow_pickle=True and .item().
# np.load is used directly: jnp.load would try (and fail) to turn the object
# array into a JAX array. Unpickling can execute code, so only load files
# produced by our own inference pipeline.
mcmlp_a, mcmlp_b, mcmlp_c = (
    np.load(f'inference_outputs/2025-06-12/{run}/post_predictions.npy',
            allow_pickle=True).item()
    for run in ('07-30-02', '07-35-35', '07-36-23')
)

In [51]:
# NLPD of each MCMLP run: fit a KDE on 1000 posterior-predictive draws and score
# fresh ABM simulations (10x10 grid, 30 steps) under it.
for idx, (name, mcmlp_dict) in enumerate(zip(['A', 'B', 'C'], [mcmlp_a, mcmlp_b, mcmlp_c])):
    # Only the first (presumably only) entry is used: its key is the ABM config,
    # its value holds the posterior-predictive samples under 'obs'.
    mcmlp_conf, mcmlp_outputs = next(iter(mcmlp_dict.items()))

    # MCMLP observations are cat(susceptible, infected, recovered) along axis 1;
    # drop the susceptible block so the layout matches cat(infected, recovered).
    _, infect_preds, rec_preds = jnp.split(mcmlp_outputs['obs'], 3, axis=1)
    mcmlp_preds = jnp.concat((infect_preds, rec_preds), axis=1)

    mcmlp_preds = jax.random.choice(jax.random.fold_in(jax.random.key(72), idx),
                                    mcmlp_preds,
                                    shape=(1000,), replace=False)

    try:
        nlpd = calculate_NLPD(jax.random.fold_in(jax.random.key(1337), idx),
                              mcmlp_conf,
                              mcmlp_preds,
                              grid_size=10, num_steps=30)
    except Exception as err:
        # Report the failure and keep going, but let Ctrl-C / SystemExit propagate
        # (a bare `except:` would swallow them) and show why the run failed.
        print(f'MCMLP {name} NLPD: NaN ({type(err).__name__}: {err})')
    else:
        print(f'MCMLP {name} NLPD: {nlpd:.3e}')

MCMLP A NLPD: 3.631e+03
MCMLP B NLPD: 3.308e+03
MCMLP C NLPD: 2.193e+03
