In [1]:
from BI import bi, jnp
m = bi(platform='cpu')

jax.local_device_count 32


## Simulate large network data

In [2]:
# Covariates
N = 1000
nodal_predictor = m.dist.normal(0, 1, shape = (N,1),  name='Individual_predictor', sample = True)
dyadic_predictor = m.dist.binomial(probs = 0.5, shape = (N,N),  name='Dyadic_predictor', sample = True)
dyadic_predictor = dyadic_predictor.at[jnp.diag_indices_from(dyadic_predictor)].set(0)
dyadic_predictor =  m.net.mat_to_edgl(dyadic_predictor)
block_predictor = m.dist.categorical( jnp.ones(20),  name='Block_predictor', sample = True, shape=(N,))
N_groups, N_by_group = jnp.unique(block_predictor, return_counts=True)

In [3]:
# Sim network

B_intercept = m.net.block_model(jnp.zeros(N), N,N, sample = True)
B_predictor = m.net.block_model(block_predictor, 20, N_by_group, sample = True)
SR = m.net.sender_receiver(nodal_predictor, nodal_predictor, sample = True)
D = m.net.dyadic_effect(dyadic_predictor, sample = True)
network = m.dist.bernoulli(logits=B_intercept + B_predictor + SR + D, sample = True)

In [4]:
# Analytical model
def model(nodal_predictor, dyadic_predictor, block_predictor, network):
    B_intercept = m.net.block_model(jnp.zeros(N), N,N, sample = False, name = 'B_intercept')
    B_predictor = m.net.block_model(block_predictor, 20, N_by_group, sample = False, name='B_predictor')
    SR = m.net.sender_receiver(nodal_predictor, nodal_predictor, sample = False)
    D = m.net.dyadic_effect(dyadic_predictor, sample = False)
    m.dist.bernoulli(logits=B_intercept + B_predictor + SR + D,  obs = network, name = 'network')

In [5]:
m.data_on_model = dict(
    nodal_predictor = nodal_predictor, 
    dyadic_predictor = dyadic_predictor, 
    block_predictor = block_predictor, 
    network = network)
m.fit(model)

sample: 100%|██████████| 1000/1000 [3:54:33<00:00, 14.07s/it, 255 steps of size 2.28e-02. acc. prob=0.91]  


## Inspector object

In [6]:
def model(weight, height):    
    alpha = m.dist.normal( 178, 20, name = 'a')
    beta = m.dist.log_normal( 0, 1, name = 'b')   
    sigma = m.dist.uniform( 0, 50, name = 's')
    m.dist.normal(alpha + beta * weight , sigma, obs=height)


In [7]:
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from jax import random
import jax.numpy as jnp
import numpy as np

class ModelInspector:
    def __init__(self, model):
        """
        Initialize with the numpyro model function.
        """
        self.model = model

    def inspect(self, *model_args, rng_key=None, **model_kwargs):
        """
        Traces the model and returns a dict of sites with their distribution 
        details and internal parameters.
        
        Args:
            *model_args: Arguments required by the model function.
            rng_key: Optional JAX PRNGKey. Defaults to random.PRNGKey(0).
            **model_kwargs: Keyword arguments required by the model.
        """
        if rng_key is None:
            rng_key = random.PRNGKey(0)

        # 1. Trace the model execution
        seeded_model = handlers.seed(self.model, rng_key)
        trace = handlers.trace(seeded_model).get_trace(*model_args, **model_kwargs)

        # 2. Parse the trace
        site_details = {}

        for name, site in trace.items():
            if site['type'] == 'sample':
                # It's a distribution (e.g., Normal, LogNormal)
                dist_obj = site['fn']
                
                # Check if it is an observed site
                is_observed = site['is_observed']

                # Extract internal parameters dynamically
                # arg_constraints contains keys like 'loc', 'scale', 'probs', etc.
                params = {}
                if hasattr(dist_obj, 'arg_constraints'):
                    for param_name in dist_obj.arg_constraints.keys():
                        # Retrieve the value from the object
                        val = getattr(dist_obj, param_name)
                        # Convert JAX array to standard Python/Numpy for readability
                        params[param_name] = self._clean_value(val)

                site_details[name] = {
                    'type': 'sample',
                    'distribution': type(dist_obj).__name__,
                    'parameters': params,
                    'observed': is_observed
                }

            elif site['type'] == 'param':
                # It's a learnable parameter (numpyro.param)
                val = site['value']
                site_details[name] = {
                    'type': 'param',
                    'value': self._clean_value(val)
                }

        return site_details

    def _clean_value(self, val):
        """Helper to convert JAX arrays to native Python types or list."""
        if hasattr(val, 'tolist'):
            return val.tolist()
        return val

In [8]:
inspector = ModelInspector(model)
inspector.inspect()

TypeError: model() missing 2 required positional arguments: 'weight' and 'height'