In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import i0  # Modified Bessel function of the first kind
from jax.scipy.special import i0e  # Exponentially scaled I0: i0e(x) = exp(-|x|) * i0(x)

def transition_density_cir(theta_t, theta_0, b, t):
    """
    Compute the transition density p(theta_t | theta_0) for the CIR process using the provided formula.

    With decay = exp(-b * t) and c = 1 / (1 - decay), the density is

        c * exp(-c * (theta_0 * decay + theta_t)) * I_0(2 * c * sqrt(theta_0 * theta_t * decay)),

    where I_0 is the order-zero modified Bessel function of the first kind. This is the law
    of (1 - decay) / 2 * X with X non-central chi-squared on 2 degrees of freedom and
    non-centrality 2 * c * theta_0 * decay, i.e. the a = 1 case of sample_CIR (df = 2a).
    Other shape parameters would need I_{a-1} and an extra power factor.

    Numerics: evaluating i0(z) directly overflows float32 once z exceeds about 89 (e.g. for
    small t, where c is large) while the exponential factor underflows to 0, giving
    inf * 0 = nan. Instead I_0(z) = i0e(z) * exp(z) is used and exp(z) is folded into the
    exponent, which becomes -c * (sqrt(theta_0 * decay) - sqrt(theta_t))**2 <= 0.

    :param theta_t: The value at time t (non-negative).
    :param theta_0: The initial value of the process (non-negative).
    :param b: The speed of reversion parameter.
    :param t: Elapsed time; b * t > 0 is expected (b * t == 0 makes c infinite).
    :return: The transition density value.
    """
    decay = jnp.exp(-b * t)
    c = 1 / (1 - decay)
    root_start = jnp.sqrt(theta_0 * decay)  # sqrt of the decayed starting value
    root_end = jnp.sqrt(theta_t)
    bessel_arg = 2 * c * root_start * root_end
    # Equals -c * (theta_0 * decay + theta_t) + bessel_arg, written as a square to avoid
    # cancellation between two large terms; never positive, so exp cannot overflow.
    exponent = -c * (root_start - root_end) ** 2
    return c * jnp.exp(exponent) * i0e(bessel_arg)

In [10]:
import jax.random as jr
def _random_chi2(key, df, shape=(), dtype=float):
    """
    Draw chi-squared samples via the identity chi2(df) = 2 * Gamma(df / 2, 1).

    :param key: PRNG key.
    :param df: Degrees of freedom (scalar or array broadcastable to `shape`).
    :param shape: Output shape.
    :param dtype: Floating dtype of the result. The default, Python ``float``, resolves to
        JAX's default float type (float32 unless x64 is enabled). The former default,
        ``jnp.float_`` (float64), made jax.random.gamma emit an "Explicitly requested dtype
        float64 ... will be truncated" UserWarning under the default configuration.
    :return: Array of chi-squared draws with the given shape and dtype.
    """
    return 2.0 * jr.gamma(key, 0.5 * df, shape=shape, dtype=dtype)

def sample_from_ncx2(key, df, nc, sample_shape=()):
    """
    Draw samples from a non-central chi-squared distribution ncx2(df, nc).

    Two standard representations are used, selected elementwise:
      * df > 1:  chi2(df - 1) + (Z + sqrt(nc))**2,  Z ~ N(0, 1)
      * df <= 1: chi2(df + 2 * N),  N ~ Poisson(nc / 2)
    Both branches are computed and `jnp.where` picks the valid one per element.

    :param key: PRNG key.
    :param df: Degrees of freedom (scalar or array), expected > 0.
    :param nc: Non-centrality parameter (scalar or array), expected >= 0.
    :param sample_shape: Leading shape of independent draws (tuple or list).
    :return: Array of shape sample_shape + broadcast_shapes(shape(df), shape(nc)).
    """
    # df and nc broadcast against each other; concatenating their shapes would turn two
    # length-k parameter vectors into a (k, k) grid of draws.
    shape = tuple(sample_shape) + jnp.broadcast_shapes(jnp.shape(df), jnp.shape(nc))

    key_poisson, key_normal, key_chi2 = jr.split(key, 3)

    poisson_draw = jr.poisson(key_poisson, 0.5 * nc, shape=shape)
    shifted_normal = jr.normal(key_normal, shape=shape) + jnp.sqrt(nc)
    use_normal_branch = jnp.greater(df, 1.0)
    chi2_df = jnp.where(use_normal_branch, df - 1.0, df + 2.0 * poisson_draw)
    chi2 = _random_chi2(key_chi2, chi2_df, shape=shape)
    return jnp.where(use_normal_branch, chi2 + shifted_normal * shifted_normal, chi2)

def sample_CIR(key, theta_0, a, b, t):
    """
    Draw theta_t of a CIR process started at theta_0, directly from its transition law.

    No time-stepping is involved: theta_t is a scaled non-central chi-squared variable,

        theta_t = (1 - e^{-b t}) / 2 * X,
        X ~ ncx2(df = 2a, nc = 2 * theta_0 * e^{-b t} / (1 - e^{-b t})),

    which is consistent with dtheta = b (a - theta) dt + sqrt(2 b theta) dW.

    :param key: PRNG key.
    :param theta_0: Initial value of the process.
    :param a: Shape parameter; the chi-squared has 2a degrees of freedom.
    :param b: Speed of reversion.
    :param t: Elapsed time (b * t == 0 makes 1 - e^{-b t} zero and the non-centrality divide by zero).
    :return: One draw of theta_t.
    """
    decay = jnp.exp(-b * t)
    one_minus_decay = 1 - decay
    ncx2_draw = sample_from_ncx2(
        key,
        df=2 * a,  # degrees of freedom
        nc=2 * theta_0 * decay / one_minus_decay,  # non-centrality parameter
    )
    return one_minus_decay / 2 * ncx2_draw

def sample_dirichlet_from_cir(key, thetas, alphas, b, T):
    """
    Run one independent CIR process per component to time T and normalise the results.

    Component i starts at thetas[i] with shape parameter alphas[i]; all components share
    the reversion speed b. The terminal values are divided by their sum, so the output
    lies on the simplex. Per the function name, the intent is an (approximate, for large T)
    Dirichlet(alphas) draw.

    :param key: PRNG key; split into one sub-key per component.
    :param thetas: Initial values, one per component.
    :param alphas: Shape parameters, one per component.
    :param b: Speed of reversion shared by all components.
    :param T: Time horizon at which each process is read off.
    :return: 1-D array of length len(alphas), normalised to sum to 1.
    :raises ValueError: If thetas and alphas have different lengths.
    """
    n_components = len(alphas)
    if len(thetas) != n_components:
        # JAX clamps out-of-range indices instead of raising, so a short `thetas` would
        # silently reuse its last entry for the remaining components.
        raise ValueError(
            f"thetas and alphas must have the same length, got {len(thetas)} and {n_components}"
        )

    component_keys = jr.split(key, n_components)
    # Build the array once rather than copying it on every `.at[i].set`.
    terminal_values = jnp.array(
        [sample_CIR(k, thetas[i], alphas[i], b, T) for i, k in enumerate(component_keys)]
    )
    return terminal_values / terminal_values.sum()

# Experiment configuration (also used by later cells).
thetas = jnp.array([0, 0, 1])  # initial value of each component's CIR process
alphas = jnp.array([2, 3, 4])  # per-component shape parameters (Dirichlet parameters)
b = 1   # speed of reversion, shared by every component
T = 10  # time horizon at which the CIR processes are read off

key = jr.PRNGKey(0)  # fixed seed so the displayed sample is reproducible
sample = sample_dirichlet_from_cir(key, thetas, alphas, b, T)
sample

  return 2.0 * jr.gamma(key, 0.5 * df, shape=shape, dtype=dtype)


Array([0.34041935, 0.42871287, 0.23086777], dtype=float32)

In [None]:
def calculate_dirichlet_statistics(alphas, samples):
    """
    Empirical per-component mean and variance (ddof = 0) of a batch of samples.

    :param alphas: Dirichlet parameters; not used by the computation (kept for interface compatibility).
    :param samples: Array-like with independent samples along axis 0
        (in this notebook, one row per sample and one column per component).
    :return: Tuple (empirical_mean, empirical_variance), both reduced over axis 0.
    """
    del alphas  # unused
    sample_axis = 0
    empirical_mean = np.mean(samples, axis=sample_axis)
    empirical_variance = np.var(samples, axis=sample_axis)
    return empirical_mean, empirical_variance

def theoretical_dirichlet_statistics(alphas):
    """
    Calculate the theoretical per-component mean and variance of a Dirichlet distribution.

    With alpha_0 = sum(alphas):
        mean_i = alpha_i / alpha_0
        var_i  = alpha_i * (alpha_0 - alpha_i) / (alpha_0**2 * (alpha_0 + 1))

    :param alphas: Parameters of the Dirichlet distribution (iterable of numbers or 1-D array).
    :return: Tuple (theoretical_mean, theoretical_variance), each a list with one entry per component.
    """
    # Promote to floating point before squaring: for an integer JAX array the denominator
    # alpha_0**2 * (alpha_0 + 1) would be computed in int32 (default 32-bit mode) and
    # silently wrap around once alpha_0 reaches about 1290.
    alpha_sum = sum(alphas) * 1.0
    theoretical_mean = [alpha / alpha_sum for alpha in alphas]
    theoretical_variance = [
        (alpha * (alpha_sum - alpha)) / (alpha_sum**2 * (alpha_sum + 1)) for alpha in alphas
    ]
    return theoretical_mean, theoretical_variance



# Monte-Carlo check: empirical moments of many CIR-based samples vs. the closed-form
# Dirichlet moments. Reuses `key`, `thetas`, `alphas`, `b` and `T` from the cell above.
num_samples = 10000
num_steps = 50  # NOTE(review): unused here; sample_CIR draws theta_T directly, without time-stepping
num_samples_reduced = num_samples  # size of this run; lower it for a quicker check

# One independent sub-key per sample. fold_in derives a fresh stream without reassigning
# `key`, so re-running this cell reproduces the same samples.
samples_key = jr.fold_in(key, 1)
sample_keys = jr.split(samples_key, num_samples_reduced)

# vmap batches the per-sample draw instead of making num_samples_reduced eager Python calls.
samples_reduced = jax.vmap(
    lambda sample_key: sample_dirichlet_from_cir(sample_key, thetas, alphas, b, T)
)(sample_keys)
assert samples_reduced.shape == (num_samples_reduced, len(alphas))

# Calculate empirical and theoretical statistics for the sample set
empirical_mean_reduced, empirical_variance_reduced = calculate_dirichlet_statistics(alphas, samples_reduced)
theoretical_mean, theoretical_variance = theoretical_dirichlet_statistics(alphas)

(empirical_mean_reduced, empirical_variance_reduced), (theoretical_mean, theoretical_variance)