In [3]:
import jax.numpy as jnp
from jax import random

def jax_credibility_interval(data=None, dist_func=None, credibility_level=0.90, key=None):
    """
    Calculate a central credibility interval based on percentiles using JAX.
    
    Parameters:
    -----------
    data : jnp.array, optional
        Samples from the posterior distribution
    dist_func : callable, optional
        Function that generates samples when called with a key and size
    credibility_level : float, default=0.90
        The desired credibility level (between 0 and 1)
    key : jax.random.PRNGKey, optional
        Random key for JAX's random number generation
    
    Returns:
    --------
    tuple
        (lower_bound, upper_bound) representing the credibility interval
    """
    if data is None and dist_func is None:
        raise ValueError("Either data or dist_func must be provided")
    
    # If no data is provided but we have a distribution function, generate samples
    if data is None:
        if key is None:
            key = random.PRNGKey(0)
        data = dist_func(key, 10000)
    
    # Calculate the percentiles based on the credibility level
    alpha = (1 - credibility_level) / 2
    
    # Sort the data for percentile calculation
    sorted_data = jnp.sort(data)
    n = len(sorted_data)
    
    # Calculate indices for the percentiles
    lower_idx = jnp.floor(n * alpha).astype(int)
    upper_idx = jnp.floor(n * (1 - alpha)).astype(int)
    
    # Get the values at those indices
    lower_bound = sorted_data[lower_idx]
    upper_bound = sorted_data[upper_idx]
    
    return lower_bound, upper_bound

# Example usage:
# (This would be replaced with your actual code)
def example_usage():
    # Initialize a random key
    key = random.PRNGKey(42)
    
    # Example 1: Using pre-computed data
    from jax.random import beta
    samples = beta(key, a=10.0, b=3.0, shape=(10000,))
    lower, upper = jax_credibility_interval(data=samples, credibility_level=0.90)
    print(f"90% Credibility Interval: [{lower:.4f}, {upper:.4f}]")
    
    # Example 2: Using a distribution function
    def beta_sampler(key, size):
        return beta(key, a=10.0, b=3.0, shape=(size,))
    
    new_key = random.split(key)[0]
    lower, upper = jax_credibility_interval(
        dist_func=beta_sampler, 
        credibility_level=0.95,
        key=new_key
    )
    print(f"95% Credibility Interval: [{lower:.4f}, {upper:.4f}]")

# This shows the function signature and example usage


example_usage(  )

90% Credibility Interval: [0.5628, 0.9294]
95% Credibility Interval: [0.5211, 0.9459]
