In [1]:
import functools as ft

import jax
import jax.numpy as jnp
from jax import jit, lax, vmap
from numpyro.distributions import StudentT

@ft.partial(jit, static_argnames=('alternative', 'mu', 'conf_level'))
def _t_test_single(x, alternative="two.sided", mu=0, conf_level=0.95):
    """One-sample Student's t-test built on numpyro's ``StudentT`` distribution.

    The argument names and the ``"two.sided"``/``"less"``/``"greater"``
    alternatives follow R's ``t.test(x, alternative, mu, conf.level)``.

    NOTE(review): in the recorded session this version never worked. The
    numpyro import raised ImportError, and calling it raised
    NotImplementedError (presumably from ``StudentT.cdf``/``icdf``). The
    notebook redefines ``_t_test_single`` on top of ``pt``/``qt``, which
    shadows this definition. Consider deleting this cell.

    Parameters
    ----------
    x : array
        1-D sample, reduced over its last axis.
    alternative : str
        ``"two.sided"`` (default), ``"less"`` or ``"greater"``.
        Static under ``jit``.
    mu : float
        Mean under the null hypothesis. Static under ``jit``.
    conf_level : float
        Confidence level of the returned interval. Static under ``jit``.

    Returns
    -------
    t_stat : scalar array
        ``(mean(x) - mu) / (sd(x) / sqrt(n))`` using the n-1 sample sd.
    pval : scalar array
        p-value for the chosen alternative.
    conf_int : array of shape (2,)
        Lower and upper confidence bounds for the mean. The open end of a
        one-sided interval is ``-inf``/``inf``.

    Raises
    ------
    ValueError
        At trace time, if ``alternative`` is not recognised or ``x`` has
        fewer than two observations.
    """
    if alternative not in ("two.sided", "less", "greater"):
        raise ValueError(
            "alternative must be 'two.sided', 'less' or 'greater', "
            f"got {alternative!r}"
        )
    # Use the same axis as the reductions below.
    nx = x.shape[-1]
    if nx < 2:
        raise ValueError("not enough 'x' observations")

    mx = jnp.mean(x, axis=-1)
    # ddof=1: Student's t uses the unbiased (n - 1) sample variance. The
    # default ddof=0 would understate the standard error.
    vx = jnp.var(x, axis=-1, ddof=1)

    df = nx - 1
    stderr = jnp.sqrt(vx / nx)
    t_stat = (mx - mu) / stderr

    t_dist = StudentT(df=df, loc=0, scale=1)

    if alternative == "two.sided":
        pval = 2 * t_dist.cdf(-jnp.abs(t_stat))
        alpha = 1. - conf_level
        crit = t_dist.icdf(1 - alpha / 2)
        conf_int = jnp.array([-crit, crit]) + t_stat
    elif alternative == "less":
        pval = t_dist.cdf(t_stat)
        conf_int = jnp.array([-jnp.inf, t_stat + t_dist.icdf(conf_level)])
    else:  # "greater"
        # By symmetry P(T >= t) == P(T <= -t); avoids 1 - cdf cancellation.
        pval = t_dist.cdf(-t_stat)
        conf_int = jnp.array([t_stat - t_dist.icdf(conf_level), jnp.inf])

    # Convert from t units back to the scale of x: mean(x) +/- crit * stderr.
    conf_int = conf_int * stderr + mu

    return t_stat, pval, conf_int

ImportError: cannot import name 'betainc' from 'numpyro.distributions.util' (/opt/miniconda3/envs/jax/lib/python3.10/site-packages/numpyro/distributions/util.py)

In [14]:
X = jnp.asarray((0.3, 0.4, -0.2, 0.04, 0.07, -0.8, 0.5))  # small sample for exercising the t-test

In [15]:
# The recorded NotImplementedError for this cell came from the numpyro-based
# _t_test_single defined in the first cell.
_t_test_single(x=X)

NotImplementedError: 

In [2]:
import functools as ft

import jax.numpy as jnp
from jax import jit, vmap
from jax.scipy.special import betainc
from tensorflow_probability.substrates.jax.math import special as tfp_special

def pt(x, df, loc=0., scale=1.):
    """Student-t CDF for each entry along the leading axis of ``x``.

    Named after R's ``pt``. ``x`` is mapped over axis 0 with ``vmap``.
    ``df``, ``loc`` and ``scale`` are passed unmapped to every ``_pt`` call.
    ``_pt`` marks ``loc``/``scale`` static under ``jit``, so they should be
    hashable Python scalars.
    """
    mapped_cdf = vmap(_pt, in_axes=(0, None, None, None))
    return mapped_cdf(x, df, loc, scale)

@ft.partial(jit, static_argnames=("loc", "scale",))
def _pt(x, df, loc=0., scale=1.):

    scaled = (x - loc) / scale
    scaled_squared = scaled * scaled
    beta_value = df / (df + scaled_squared)

    # when scaled < 0, returns 0.5 * Beta(df/2, 0.5).cdf(beta_value)
    # when scaled > 0, returns 1 - 0.5 * Beta(df/2, 0.5).cdf(beta_value)
    return 0.5 * (
        1
        + jnp.sign(scaled)
        - jnp.sign(scaled) * betainc(0.5 * df, 0.5, beta_value)
    )

In [11]:
x = jnp.zeros(3)  # three t = 0 points; pt should return 0.5 for each

In [15]:
%timeit -n10 -r3 pt(x, 2).block_until_ready()

344 µs ± 38.3 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)


In [3]:
def qt(q, df, loc=0., scale=1.):
    """Student-t quantiles for each entry along the leading axis of ``q``.

    Named after R's ``qt``. ``q`` is mapped over axis 0 with ``vmap``.
    ``df``, ``loc`` and ``scale`` are passed unmapped to every ``_qt`` call.
    ``_qt`` marks ``loc``/``scale`` static under ``jit``, so they should be
    hashable Python scalars.
    """
    per_element = vmap(_qt, in_axes=(0, None, None, None))
    return per_element(q, df, loc, scale)

@ft.partial(jit, static_argnames=("loc", "scale",))
def _qt(q, df, loc=0., scale=1.):
    """Element-wise Student-t quantile function (inverse CDF).

    This inverts the incomplete-beta form of the t CDF:
      1. Form the tail probability p = 2 * min(q, 1 - q).
      2. Compute B = betaincinv(df/2, 1/2, p).
      3. The squared standardized quantile is df * (1/B - 1).
      4. Take the sign from q - 0.5.
    ``loc`` and ``scale`` are static under ``jit``.
    """
    tail_mass = 1 - jnp.abs(1 - 2 * q)  # == 2 * min(q, 1 - q)
    beta_inv = tfp_special.betaincinv(0.5 * df, 0.5, tail_mass)
    magnitude = jnp.sqrt(df * (1 / beta_inv - 1))
    return jnp.sign(q - 0.5) * magnitude * scale + loc

In [19]:
x = 0.5 * jnp.ones(3)  # the median; qt should return 0 for each entry

In [21]:
%timeit -n10 -r3 qt(x, 2).block_until_ready()

358 µs ± 116 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)


In [15]:
@ft.partial(jit, static_argnames=('alternative', 'mu', 'conf_level',))
def _t_test_single(x, alternative="two.sided", mu=0, conf_level=0.95):
    """One-sample Student's t-test using the ``pt``/``qt`` helpers.

    The argument names and the ``"two.sided"``/``"less"``/``"greater"``
    alternatives follow R's ``t.test(x, alternative, mu, conf.level)``.

    Parameters
    ----------
    x : array
        Sample, reduced over its last axis. Only 1-D input is supported
        (TODO: batched input breaks the confidence-interval construction).
        Integer input is accepted. Probabilities are built in the float
        dtype that ``jnp.mean`` returns.
    alternative : str
        ``"two.sided"`` (default), ``"less"`` or ``"greater"``.
        Static under ``jit``.
    mu : float
        Mean under the null hypothesis. Static under ``jit``.
    conf_level : float
        Confidence level of the returned interval. Static under ``jit``.

    Returns
    -------
    t_stat : array of shape (1,)
        ``(mean(x) - mu) / (sd(x) / sqrt(n))`` using the n-1 sample sd.
    pval : array of shape (1,)
        p-value for the chosen alternative.
    conf_int : array of shape (2, 1)
        Lower and upper confidence bounds for the mean. The open end of a
        one-sided interval is ``-inf``/``inf``.

    Raises
    ------
    ValueError
        At trace time, if ``alternative`` is not recognised or ``x`` has
        fewer than two observations.

    Notes
    -----
    A constant sample has zero standard error, so the outputs are NaN/inf.
    Data-dependent errors cannot be raised from inside ``jit``.
    """
    if alternative not in ("two.sided", "less", "greater"):
        raise ValueError(
            "alternative must be 'two.sided', 'less' or 'greater', "
            f"got {alternative!r}"
        )
    # Use the same axis as the reductions below; shapes are static under jit.
    nx = x.shape[-1]
    if nx < 2:
        raise ValueError("not enough 'x' observations")

    mx = jnp.mean(x, axis=-1, keepdims=True)
    # ddof=1: Student's t uses the unbiased (n - 1) sample variance. The
    # default ddof=0 would understate the standard error.
    vx = jnp.var(x, axis=-1, ddof=1, keepdims=True)

    df = nx - 1
    stderr = jnp.sqrt(vx / nx)
    t_stat = (mx - mu) / stderr

    # Probabilities are built in mx's float dtype, not x.dtype. For integer
    # x, casting 1 - conf_level to int would give alpha == 0 and an
    # infinite interval.
    if alternative == "two.sided":
        pval = 2 * pt(-jnp.abs(t_stat), df)
        alpha = jnp.array([1. - conf_level], dtype=mx.dtype)
        crit = qt(1 - alpha / 2, df)
        conf_int = jnp.array([-crit, crit]) + t_stat
    elif alternative == "less":
        pval = pt(t_stat, df)
        level = jnp.array([conf_level], dtype=mx.dtype)
        upper = t_stat + qt(level, df)
        conf_int = jnp.array([jnp.full_like(upper, -jnp.inf), upper])
    else:  # "greater"
        # By symmetry P(T >= t) == P(T <= -t); avoids 1 - cdf cancellation.
        pval = pt(-t_stat, df)
        level = jnp.array([conf_level], dtype=mx.dtype)
        lower = t_stat - qt(level, df)
        conf_int = jnp.array([lower, jnp.full_like(lower, jnp.inf)])

    # Convert from t units back to the scale of x: mean(x) +/- crit * stderr.
    conf_int = conf_int * stderr + mu

    return t_stat, pval, conf_int

In [20]:
# Benchmark input only. A constant sample has zero standard error, so
# _t_test_single returns NaN statistics for it (0 / 0 with the default mu=0).
x = jnp.zeros(100)

In [11]:
# NOTE(review): the recorded output (non-zero mean) is stale. This cell
# (In [11]) ran before x was redefined as zeros (In [20]), and re-running it
# now gives [0.].
x.mean(keepdims=True)

DeviceArray([0.01857143], dtype=float32)

In [28]:
%timeit -n10 -r3 _t_test_single(x)

The slowest run took 7.01 times longer than the fastest. This could mean that an intermediate result is being cached.
34 µs ± 32 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)


In [1]:
import jax.random as jrand
import jax.numpy as jnp

In [2]:
key = jrand.PRNGKey(seed=0)  # fixed seed so the gamma draws are reproducible

In [3]:
shape = (1000,)
# 1000 draws from a unit-scale Gamma with shape parameter a = 1.5.
rg = jrand.gamma(key, a=1.5, shape=shape)

In [4]:
rg.mean()  # sanity check: a unit-scale Gamma(a) has mean a, so expect ~1.5

DeviceArray(1.5262381, dtype=float32)