# Bayesian statistics with Jax/Numpyro

In [None]:
!pip install numpyro

In [None]:
import arviz as az
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
import scipy.stats as st

from numpyro.infer import MCMC, NUTS, Predictive

## Bayesian refresher

### Bayes law

Bayes law is given by:

$$P(A | B) = \frac{P(B | A) P(A)}{P(B)}$$

**Obligatory medical test example**

Before discussing Bayes law in the context of statistical inference, it helps to work through a commonly used introductory example.

Consider a disease that affects 0.1\% of the population (i.e. 1 in 1,000 individuals actually have the disease). Suppose a country is screening individuals at an airport for whether they have this disease. They are administering a test with a sensitivity of 99\% and a specificity of 90\%. Now, suppose that an individual entering the country has tested
positive for the disease. What is the probability that they have the disease?

One of the surprising aspects of this problem is that many people "intuitively" believe that the probability the individual has the disease is relatively high. However, the actual probability that the individual has the disease is just under 1\% (about 0.98\%).

We can get this answer by applying Bayes law. The probability of interest ($P(A | B)$) is "the probability that the individual has the disease given they tested positive for the disease". So let $A$ be defined as "individual has the disease" and $B$ be defined as "tested positive for the disease".

With these defined, we can start calculating pieces of Bayes law:

* $P(B | A)$: This is the probability the test is positive given the individual does in fact have the disease. This is given by the sensitivity (i.e., 0.99).
* $P(A)$: This is the probability of having the disease with no other information, which is
  0.001 since 0.1\% of the population has the disease.
* $P(B)$: This is the probability of getting a positive test, which can be broken into two separate pieces: you either test positive and have the disease, or you test positive and do not have the disease (i.e. $P(B) = P(B | A) P(A) + P(B | \neg A) P(\neg A)$). We have already
  identified $P(B | A)$ and $P(A)$ above, and we know that $P(\neg A) = 1 - P(A)$.
  The specificity of the test gives us $P(\neg B | \neg A) = 0.90$, so $P(B | \neg A) = 1 - P(\neg B | \neg A) = 0.10$.

In [None]:
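p_bga = 0.99                          # sensitivity, P(B | A)
p_a = 0.001                           # prevalence, P(A)
p_bgna = 1 - 0.90                     # 1 - specificity, P(B | ~A)
p_b = p_bga*p_a + p_bgna*(1 - p_a)    # law of total probability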

In [None]:
p_bga*p_a / p_b
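As a sanity check on the formula, we can also simulate a large population and count how often a positive test belongs to someone who actually has the disease. This is a minimal NumPy sketch; the population size and random seed are arbitrary choices, not part of the original example.

```python
import numpy as np

rng = np.random.default_rng(0)
n_people = 2_000_000

prevalence = 0.001   # P(A)
sensitivity = 0.99   # P(B | A)
specificity = 0.90   # P(~B | ~A)

# Draw who has the disease
has_disease = rng.random(n_people) < prevalence

# Sick people test positive with probability = sensitivity;
# healthy people test positive with probability = 1 - specificity
p_positive = np.where(has_disease, sensitivity, 1 - specificity)
tests_positive = rng.random(n_people) < p_positive

# Fraction of positive tests that belong to people with the disease
p_disease_given_positive = has_disease[tests_positive].mean()
print(p_disease_given_positive)
```

The simulated fraction should land close to the Bayes law answer, up to simulation noise.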

## Bayesian statistics

Like other types of statistics, Bayesian statistics is focused on the inverse problem:

> Given some data, $Y$, and a parameterized class of models, $f(\Theta)$, find the parameter vector $\theta \in \Theta$ that could have produced the data.

In Bayesian statistics we recover a "joint distribution of parameters" that could have generated the data. We use Bayes law as below to do this:

$$\underbrace{P(\theta | Y)}_{\text{posterior}} = \frac{\overbrace{P(Y | \theta)}^{\text{likelihood}} \overbrace{P(\theta)}^{\text{prior}}}{\underbrace{P(Y)}_{\text{normalizing component}}}$$

We have labeled each of these components according to names that are commonly used to describe
them:


**Normalizing component**: $P(Y)$

The normalizing component is often ignored because it's simply the value needed to ensure that the posterior is a probability distribution. In fact, you will often find Bayes law written as $P(\theta | Y) \propto P(Y | \theta) P(\theta)$.
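To make the role of $P(Y)$ concrete, here is a minimal sketch for a one-parameter problem: evaluate likelihood × prior on a grid of $\theta$ values, then divide by its numerical integral. The toy setup (prior $\theta \sim N(0, 1)$ and a single observation $y = 1$ with $y | \theta \sim N(\theta, 1)$) is an illustrative assumption; for it, the exact posterior is $N(0.5, 0.5)$ and $P(Y)$ is the $N(0, 2)$ density evaluated at 1.

```python
import numpy as np
import scipy.stats as st
from scipy.integrate import trapezoid

# Illustrative toy problem (an assumption, not from the notebook):
# prior theta ~ N(0, 1), one observation y = 1 with y | theta ~ N(theta, 1)
y_obs = 1.0
theta = np.linspace(-5, 5, 2001)

prior = st.norm(0, 1).pdf(theta)
likelihood = st.norm(theta, 1).pdf(y_obs)
unnormalized = likelihood * prior          # P(Y | theta) P(theta)

# P(Y) is the integral of likelihood * prior over theta
p_y = trapezoid(unnormalized, theta)
posterior = unnormalized / p_y             # now integrates to 1

post_mean = trapezoid(theta * posterior, theta)
post_var = trapezoid((theta - post_mean)**2 * posterior, theta)
print(p_y, post_mean, post_var)
```

This grid approach works for one or two parameters, but the number of grid points grows exponentially with the number of parameters, which is one reason we turn to sampling methods later.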


**Likelihood**: $P(Y | \theta)$

We refer to this term as the likelihood; it establishes how likely it was to observe certain realizations of the data for a given parameter $\theta$.


**Prior**: $P(\theta)$

The prior is what most people would identify as the "defining feature" of Bayesian statistics. The
prior specifies the "prior belief" that the statistician assigns to different parameter values
_before_ having seen any data.

_Why priors?_

One often begins their inference process with some idea about what parameters make sense and which ones don't. A prior reflects the subjective beliefs of the statistician who is running the analysis (or the beliefs of the audience that they are trying to convince).

For example, suppose you were modeling a demand curve. We have come to accept that demand curves typically slope downward, which is a belief that we could express with a prior.

A natural question that follows the introduction of a prior is, "doesn't a prior make your analysis subjective rather than strictly objective?" Yes, it does. If you would like to understand why we don't think this is a problem, we recommend the following discussions of priors:

* https://stat.columbia.edu/~gelman/research/published/philosophy_chapter.pdf
* https://statmodeling.stat.columbia.edu/2016/12/13/bayesian-statistics-whats/
* https://projecteuclid.org/journals/bayesian-analysis/volume-3/issue-3#toc

**Posterior**: $P(\theta | Y)$

The posterior is the conditional distribution that describes the statistician's beliefs about parameter values given the data that they have observed. All of Bayesian statistics will depend on being able to find (and sample from) the posterior.

### "Old school" Bayesian stats

Bayesian statistics used to rely mostly on "conjugate priors" to derive the posterior.

If the prior, $P(\theta)$, and the posterior, $P(\theta | Y)$, belong to the same probability distribution family for a specified likelihood, $P(Y | \theta)$, then we say the prior is a conjugate prior for the likelihood.

Conjugate priors are convenient because they give us a closed-form expression for the posterior. We will examine a single example today, but you can reference the [Wikipedia conjugate priors table](https://en.wikipedia.org/wiki/Conjugate_prior#Table_of_conjugate_distributions) to find more!

**Beta-Bernoulli**

There is a known parameter $n$ which is the number of Bernoulli draws and an unknown parameter $p$ which specifies the probability of success for any Bernoulli trial.

* The prior is specified by the beta distribution with parameters $\alpha$ and $\beta$
* The likelihood is specified by the Binomial distribution with parameters $p$ and $n$

The distribution functions associated with the prior and likelihood are given by

* Likelihood, $P(k | p, n) = {n \choose{k}} p^k (1 - p)^{n-k}$
* Prior, $P(p) = \frac{p^{\alpha - 1} (1 - p)^{\beta - 1}}{B(\alpha, \beta)}$

Thus

\begin{align*}
  P(p | k, \alpha, \beta, n) &\propto {n \choose{k}} p^k (1 - p)^{n-k} \frac{p^{\alpha - 1} (1 - p)^{\beta - 1}}{B(\alpha, \beta)} \\
  &\propto p^{k + \alpha - 1} (1 - p)^{n - k + \beta - 1} \frac{{n \choose{k}}}{B(\alpha, \beta)} \\
  &\dots \\
\end{align*}

If you finish writing out the algebra, you will find that the posterior is $\text{Beta}(\alpha + k, \beta + (n - k))$.
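Concretely, dropping every factor that does not depend on $p$ leaves

$$P(p | k, \alpha, \beta, n) \propto p^{(\alpha + k) - 1} (1 - p)^{(\beta + n - k) - 1},$$

which is the kernel of a $\text{Beta}(\alpha + k, \beta + n - k)$ density. The normalizing constant must therefore be $B(\alpha + k, \beta + n - k)$:

$$P(p | k, \alpha, \beta, n) = \frac{p^{\alpha + k - 1} (1 - p)^{\beta + n - k - 1}}{B(\alpha + k, \beta + n - k)}$$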

**Example**:

Imagine that we'd like to know how likely a student is to pass their PhD qualifying exams.

* The PhD program only admits students that they think are likely to pass, so we begin with a
  prior $p \sim \text{Beta}(8, 2)$
* The student takes 8 classes prior to taking the qualifying exam and passes each class with the same probability that they pass the qualifying exams
* The student has successfully passed 7 of their 8 classes

Since we are using a conjugate prior, the posterior can be written as

$$P(p | k, \alpha, \beta, n) = \text{Beta}(\alpha + k, \beta + (n - k)) = \text{Beta}(15, 3)$$

We can generate this posterior in Python using the code below:

In [None]:
# Prior hyperparameters and number of trials (classes)
alpha, beta = 8, 2
n = 8

# Data
k = 7

bb_prior = st.beta(alpha, beta)
bb_posterior = st.beta(alpha + k, beta + (n-k))

In [None]:
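fig, ax = plt.subplots()

p = np.linspace(0, 1, 100)

f_prior = bb_prior.pdf(p)
f_posterior = bb_posterior.pdf(p)

ax.plot(p, f_prior, "k", label=f"Prior: Beta({alpha}, {beta})")
ax.plot(p, f_posterior, "k--", label=f"Posterior: Beta({alpha + k}, {beta + n - k})")
ax.set_xlabel("$p$")
ax.set_ylabel("Density")
ax.legend()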

In [None]:
bb_posterior.rvs(1000).mean()
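Because the posterior is a known Beta distribution, we don't actually need random draws: a frozen `scipy.stats` distribution can report the exact mean, an equal-tailed credible interval, and tail probabilities. A minimal sketch reusing the hyperparameters above; the 94% level and the 0.8 threshold are arbitrary illustrative choices.

```python
import scipy.stats as st

alpha, beta = 8, 2  # prior hyperparameters
n, k = 8, 7         # classes taken and classes passed

bb_posterior = st.beta(alpha + k, beta + (n - k))

exact_mean = bb_posterior.mean()               # (alpha + k) / (alpha + beta + n)
ci_low, ci_high = bb_posterior.interval(0.94)  # central 94% credible interval
p_above_80 = bb_posterior.sf(0.8)              # posterior P(p > 0.8)
print(exact_mean, (ci_low, ci_high), p_above_80)
```

The Monte Carlo estimate from `rvs(1000).mean()` should agree with `exact_mean` up to sampling noise.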

### "New school" Bayesian stats

Old school Bayesian statistics has one big limitation: you have to choose distributions from the relatively small set of conjugate priors, and there are lots of interesting problems that don't fit into this set...

So how could we use Bayesian statistics to solve some of these problems?

Would it be enough if we could sample from the posterior? If so, how could we sample from it?

Markov chain Monte Carlo (MCMC) methods allow us to do exactly this.

Two papers introduced the original versions of this idea:

1. _Equation of State Calculations by Fast Computing Machines_ by Nicholas Metropolis, Arianna W. Rosenbluth, Marshall Rosenbluth, Augusta H. Teller, Edward Teller
2. _Monte Carlo Sampling Methods Using Markov Chains and Their Applications_ by W. K. Hastings

The key idea of MCMC methods is to construct a Markov chain whose stationary distribution is the posterior distribution. If we draw samples from that chain's stationary distribution, then we're drawing samples from our posterior!

[Online sampling tool](https://chi-feng.github.io/mcmc-demo/app.html)

The question remains: how can we construct this "magic" Markov chain whose stationary distribution happens to be our posterior?
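One classic answer is the random-walk Metropolis algorithm: propose a small random step, then accept it with probability $\min(1, \tilde{P}(\text{proposal}) / \tilde{P}(\text{current}))$, where $\tilde{P}$ is likelihood × prior. Only the ratio matters, so $P(Y)$ cancels and never has to be computed, and because the Gaussian proposal is symmetric, the Hastings correction drops out. The sketch below is a minimal NumPy illustration (numpyro's NUTS is a different, gradient-based sampler) targeting the qualifying-exam posterior, whose exact answer is $\text{Beta}(15, 3)$. The step size, chain length, and burn-in are arbitrary assumptions.

```python
import numpy as np

def log_unnormalized_posterior(p, alpha=8, beta=2, n=8, k=7):
    """log[likelihood * prior] for the Beta-Binomial example, up to a constant."""
    if p <= 0 or p >= 1:
        return -np.inf  # zero prior density outside (0, 1)
    return (k + alpha - 1) * np.log(p) + (n - k + beta - 1) * np.log(1 - p)

def random_walk_metropolis(log_target, start, n_steps, step_size, seed=0):
    rng = np.random.default_rng(seed)
    draws = np.empty(n_steps)
    current, current_lp = start, log_target(start)
    for i in range(n_steps):
        # Symmetric Gaussian proposal around the current state
        proposal = current + step_size * rng.normal()
        proposal_lp = log_target(proposal)
        # Accept with probability min(1, target(proposal) / target(current))
        if np.log(rng.random()) < proposal_lp - current_lp:
            current, current_lp = proposal, proposal_lp
        draws[i] = current  # a rejected proposal repeats the current state
    return draws

draws = random_walk_metropolis(log_unnormalized_posterior, start=0.5,
                               n_steps=50_000, step_size=0.1)
kept = draws[5_000:]  # discard burn-in
print(kept.mean())    # should be close to the exact Beta(15, 3) mean, 15/18
```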

We're going to let "probabilistic programming languages", namely, [numpyro](https://num.pyro.ai/en/stable/) do this for us.

> Note: [PyMC](https://www.pymc.io/welcome.html) is another great probabilistic programming language that we've used quite a bit -- I'm covering numpyro today largely because of a [blog post by Bob Carpenter](https://statmodeling.stat.columbia.edu/2025/10/03/its-a-jax-jax-jax-jax-world/) about numpyro's performance dominance.

#### Simple example

We start with a very simple model and then move on to a more interesting one.

Consider the following model of GDP growth:

$$y_t = \theta_y + \sigma \varepsilon_t$$


where $\varepsilon_t \sim N(0, 1)$ and suppose we know that $\sigma = 1.5$.

Our goal is to find the posterior of the parameter $\theta_y$.

We choose the prior $\theta_y \sim N(1, 4)$ (mean 1, variance 4, i.e. standard deviation 2), and the likelihood implied by the model above is $y_t | \theta_y \sim N(\theta_y, 1.5^2)$.

In [None]:
# GDP growth
y = np.array([
    2.18, 1.49, 0.92, -0.17, 0.51, 1.52, 1.29, 0.94, 1.57, 1.69, 1.48, 1.74,
    0.73, 1.18, 1.07, 1.91, 1.45, 1.84, 1.16, 1.69, 0.9, 0.78, 1.35, 1.16,
    1.23, 2.09, 1.23, 1.58, 1.25, 1.87, 1.69, 1.19, 1.15, 1.16, 1.69, 1.9,
    1.33, 1.14, 1.66, 2.25, 1.05, 2.45, 0.7, 1.16, 0.32, 1.19, -0.01, 0.6,
    1.21, 0.97, 0.91, 0.72, 1.01, 1.16, 2.25, 1.75, 1.28, 1.58, 1.61, 1.78,
    1.91, 1.17, 1.8, 1.44, 2.04, 1.07, 0.86, 1.22, 1.22, 1.22, 1.06, 1.01,
    -0.21, 1.06, 0.2, -1.86, -1.13, -0.29, 0.47, 1.44, 0.64, 1.39, 1.03, 1.07,
    0.3, 1.38, 0.62, 1.31, 1.41, 0.83, 0.65, 0.63, 1.29, 0.41, 1.27, 1.39,
    0.13, 1.92, 1.66, 0.72, 0.86, 1.22, 0.68, 0.17
])



In [None]:
def gdp_model(data=None):
    # Prior for the mean GDP growth
    theta_y = numpyro.sample("theta_y", dist.Normal(1.0, 2.0))
    
    # Likelihood with fixed sigma=1.5
    numpyro.sample("obs_ys", dist.Normal(theta_y, 1.5), obs=data)

# Random keys for JAX
rng_key = jax.random.PRNGKey(20251203)
rng_key, rng_key_infer, rng_key_prior, rng_key_post = jax.random.split(rng_key, 4)

gdp_kernel = NUTS(gdp_model)
gdp_mcmc = MCMC(gdp_kernel, num_warmup=1000, num_samples=2000)
gdp_mcmc.run(rng_key_infer, data=y)
gdp_posterior_samples = gdp_mcmc.get_samples()

In [None]:
from numpyro.infer import Predictive

In [None]:
prior_predictive = Predictive(gdp_model, num_samples=2000)
prior_samples = prior_predictive(rng_key_prior, data=None)

posterior_predictive = Predictive(gdp_model, posterior_samples=gdp_posterior_samples)
post_pred_samples = posterior_predictive(rng_key_post, data=None)

In [None]:
fig = plt.figure(figsize=(12, 10))
gs = fig.add_gridspec(2, 2)

# 1. Top Plot: Posterior Predictive (Spans both columns)
ax_top = fig.add_subplot(gs[0, :])
az.plot_dist(
    prior_samples['obs_ys'].flatten(),
    ax=ax_top, color="C1", label="Prior Predictive"
)
az.plot_dist(
    post_pred_samples['obs_ys'].flatten(),
    ax=ax_top, color="C2", label="Posterior Predictive"
)
ax_top.hist(y, label="Observed Data", density=True)
ax_top.set_title("Posterior Predictive Distribution (GDP Growth)")
ax_top.set_xlabel("GDP Growth (%)")
ax_top.legend()

# 2. Bottom Left: Prior vs Posterior (Parameters)
ax_bl = fig.add_subplot(gs[1, 0])
# Plot Prior density for theta_y
az.plot_dist(prior_samples['theta_y'], ax=ax_bl, color="C0", label="Prior")
# Plot Posterior density for theta_y
az.plot_dist(gdp_posterior_samples['theta_y'], ax=ax_bl, color="C2", label="Posterior")
ax_bl.set_title("Prior vs Posterior Distribution of $\\theta_y$")
ax_bl.set_xlabel("$\\theta_y$ Value")
ax_bl.legend()

# 3. Bottom Right: Time-series of traces
ax_br = fig.add_subplot(gs[1, 1])
ax_br.plot(gdp_posterior_samples['theta_y'], alpha=0.7, color="C2")
ax_br.set_title("Trace of $\\theta_y$")
ax_br.set_xlabel("Iteration")
ax_br.set_ylabel("Parameter Value")

plt.tight_layout()
plt.show()
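This particular model is actually conjugate: with a Normal prior on $\theta_y$ and a known $\sigma$, the posterior is also Normal, with precision equal to the prior precision plus $n / \sigma^2$. That gives a closed-form answer to check the MCMC output against. A minimal sketch; the helper `normal_normal_posterior` is a hypothetical name of ours, not part of numpyro.

```python
import numpy as np

def normal_normal_posterior(y, prior_mean, prior_sd, sigma):
    """Posterior mean and sd of theta when y_t ~ N(theta, sigma^2) iid,
    theta ~ N(prior_mean, prior_sd^2), and sigma is known."""
    y = np.asarray(y, dtype=float)
    prior_precision = 1.0 / prior_sd**2
    data_precision = y.size / sigma**2
    post_precision = prior_precision + data_precision
    post_mean = (prior_precision * prior_mean
                 + data_precision * y.mean()) / post_precision
    return post_mean, np.sqrt(1.0 / post_precision)

# Toy check: prior N(0, 1), sigma = 1, a single observation y = 2
print(normal_normal_posterior([2.0], prior_mean=0.0, prior_sd=1.0, sigma=1.0))
```

In the notebook, `normal_normal_posterior(y, 1.0, 2.0, 1.5)` should closely match `gdp_posterior_samples["theta_y"].mean()` and `.std()`, up to Monte Carlo error.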

#### CAPM example

Now we turn to a slightly more interesting model. We will build a Bayesian version of the CAPM regression.

Our model will be described by a few versions of the following equation:

\begin{align*}
   r_{i, t} - r_{f, t} &= \beta_i (r_{m, t} - r_{f, t}) + \sigma_i \varepsilon_{i, t}\\
   \sigma_i &\sim \text{HalfNormal}(4) \\
   \beta_i &\sim N(0, 3^2)
\end{align*}

In [None]:
# Data for regression
data = pd.read_parquet(
    "https://rice.box.com/shared/static/wbwwg1336g343bauhal74ryz365v0gge.parquet"
)

In [None]:
data.head()

##### Estimating CAPM a single stock at a time

We could estimate the CAPM one stock at a time. Here we estimate the $\beta$ for Zoom (ticker `ZM`).

In [None]:
zm_returns = data.query("ticker == 'ZM'")
ri_m_rf = zm_returns.eval("returns - riskfree").to_numpy()
rm_m_rf = zm_returns.eval("market - riskfree").to_numpy()
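For comparison, the non-Bayesian counterpart of this single-stock model is least squares through the origin, $\hat{\beta} = \sum_t x_t y_t / \sum_t x_t^2$ with $x_t = r_{m,t} - r_{f,t}$ and $y_t = r_{i,t} - r_{f,t}$. With a fairly diffuse prior on $\beta_i$, the posterior mean will typically land near this estimate. A minimal sketch; the helper name is hypothetical and the check uses synthetic data.

```python
import numpy as np

def ols_beta_through_origin(excess_stock, excess_market):
    """Least-squares slope of excess_stock on excess_market with no intercept."""
    x = np.asarray(excess_market, dtype=float)
    y = np.asarray(excess_stock, dtype=float)
    return (x @ y) / (x @ x)

# Synthetic check: data generated with a true beta of 1.2 and no noise
x_toy = np.array([0.01, -0.02, 0.03, 0.015, -0.005])
print(ols_beta_through_origin(1.2 * x_toy, x_toy))
```

With the arrays just built, `ols_beta_through_origin(ri_m_rf, rm_m_rf)` gives a number to compare against `capm_zm_posterior_samples["beta_i"].mean()`.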

In [None]:
def capm_zm(ri_m_rf, rm_m_rf):
    # CAPM parameters
    beta_i = numpyro.sample(
        "beta_i",
        dist.Normal(0.0, 3.0)
    )
    sigma_i = numpyro.sample(
        "sigma_i",
        dist.HalfNormal(4.0)
    )

    # Likelihood
    ll = numpyro.sample("ll", dist.Normal(beta_i*rm_m_rf, sigma_i), obs=ri_m_rf)


# Random keys for JAX
rng_key = jax.random.PRNGKey(20251203)

capm_zm_kernel = NUTS(capm_zm)
capm_zm_mcmc = MCMC(capm_zm_kernel, num_warmup=1000, num_samples=2000)
capm_zm_mcmc.run(rng_key, ri_m_rf=ri_m_rf, rm_m_rf=rm_m_rf)

capm_zm_posterior_samples = capm_zm_mcmc.get_samples()

In [None]:
capm_zm_posterior_samples["beta_i"].mean()

In [None]:
capm_zm_posterior_samples["beta_i"].std()

In [None]:
capm_zm_posterior_samples["sigma_i"].mean()

In [None]:
capm_zm_posterior_samples["sigma_i"].std()

In [None]:
az.plot_trace(az.from_numpyro(posterior=capm_zm_mcmc))

Zoom is estimated as having a slightly negative $\beta$! What does this mean for a stock?
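One way to answer this with the posterior itself is to ask how much posterior probability sits below zero, rather than looking only at the mean. A minimal sketch of a hypothetical helper `summarize_draws` that works on any 1-D array of draws (the 94% level is an arbitrary choice, and the draws below are synthetic):

```python
import numpy as np

def summarize_draws(draws, level=0.94):
    """Mean, equal-tailed credible interval, and P(parameter < 0) from posterior draws."""
    draws = np.asarray(draws, dtype=float)
    tail = (1 - level) / 2
    lower, upper = np.quantile(draws, [tail, 1 - tail])
    return {
        "mean": draws.mean(),
        "interval": (lower, upper),
        "prob_negative": (draws < 0).mean(),
    }

# Synthetic draws for illustration only
fake_draws = np.random.default_rng(0).normal(loc=-0.1, scale=0.2, size=10_000)
print(summarize_draws(fake_draws))
```

Applied to `capm_zm_posterior_samples["beta_i"]`, `prob_negative` measures how confident the single-stock model is that Zoom's beta is actually below zero.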

##### Estimating all "amnesia" models at once

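We could estimate $\beta_i$ for every stock in the S&P 500 at once with a "vectorized" calculation. We call these "amnesia" models because each stock's $\beta_i$ and $\sigma_i$ are estimated independently: nothing learned from one stock informs the others.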

In [None]:
# Basic data
tickers = data["ticker"].unique()
ntickers = tickers.shape[0]
ticker_2_int = dict(zip(tickers, range(ntickers)))
int_2_ticker = {v: k for k, v in ticker_2_int.items()}

ri_m_rf = data.eval("returns - riskfree").to_numpy()
rm_m_rf = data.eval("market - riskfree").to_numpy()
ticker_idx = data["ticker"].map(lambda x: ticker_2_int[x]).to_numpy()

In [None]:
ticker_2_int["AAPL"]

In [None]:
int_2_ticker[3]

In [None]:
ticker_idx

In [None]:
def capm_amnesia(rm_m_rf, ri_m_rf, ticker_idx, ntickers):
    with numpyro.plate("ticker_plate", ntickers):    
        beta_i = numpyro.sample(
            "beta_i", 
            dist.Normal(loc=0.0, scale=3.0)
        )
    
        sigma_i = numpyro.sample(
            "sigma_i",
            dist.HalfNormal(4.0)
        )

    # The mean (loc) for the Normal distribution.
    indexed_beta = beta_i[ticker_idx]
    mu = indexed_beta * rm_m_rf 

    # Select the correct sigma (scale) for each observation using the index array.
    # indexed_sigma has shape (N,)
    indexed_sigma = sigma_i[ticker_idx]

    # Likelihood
    # This defines the likelihood for all N observations simultaneously.
    numpyro.sample(
        "ll", 
        dist.Normal(loc=mu, scale=indexed_sigma), 
        obs=ri_m_rf
    )


# Random keys for JAX
rng_key = jax.random.PRNGKey(20251203)

capm_amnesia_kernel = NUTS(capm_amnesia)
capm_amnesia_mcmc = MCMC(capm_amnesia_kernel, num_warmup=1000, num_samples=2000)
capm_amnesia_mcmc.run(
    rng_key,
    ri_m_rf=ri_m_rf, rm_m_rf=rm_m_rf,
    ticker_idx=ticker_idx, ntickers=ntickers
)

capm_amnesia_posterior_samples = capm_amnesia_mcmc.get_samples()

In [None]:
# Very few stocks have a beta <0... Do we actually believe our ZM estimate?
(capm_amnesia_posterior_samples["beta_i"].mean(axis=0) < 0).mean()

In [None]:
fig, ax = plt.subplots()

ax.hist(
    capm_amnesia_posterior_samples["beta_i"].mean(axis=0),
    bins=[-0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0],
    density=True
)

##### CAPM hierarchical model

The purpose of the hierarchical model is to allow groups of observations to learn from one another. This matters here because we have, at most, 60 observations for each of our stocks, since we are computing the 5-year beta (similar to what Yahoo Finance reports).

To help with this, we are going to introduce two new notions:

* A _hyperparameter_ is a parameter that's an input to a prior. For example, in the amnesia model we specified $\beta_i \sim N(0, 3^2)$, so the mean 0 and the standard deviation 3 were hyperparameters.
* A _hyperprior_ is a prior on a hyperparameter.

Hyperpriors will be a central feature of hierarchical models, and they will be used to group observations. In our example, we originally wrote the following model:

\begin{align*}
   r_{i, t} - r_{f, t} &= \beta_i (r_{m, t} - r_{f, t}) + \sigma_i \varepsilon_{i, t}\\
   \beta_i &\sim N(0, 3^2) \\
   \sigma_i &\sim \text{HalfNormal}(4)
\end{align*}

A hierarchical version of the model might be specified as

\begin{align*}
   r_{i, t} - r_{f, t} &= \beta_i (r_{m, t} - r_{f, t}) + \sigma_i \varepsilon_{i, t}\\
   \sigma_i &\sim \text{HalfNormal}(4) \\
   \beta_i &\sim N(\hat{\mu}_j, \hat{\sigma}_j^2) \\
   \hat{\mu}_j &\sim N(0, 3^2) \\
   \hat{\sigma}_j &\sim \text{HalfNormal}(4)
\end{align*}

where $j$ denotes the GICS sector (the `gics` column) to which stock $i$ belongs.
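To build intuition for how the sector-level parameters let stocks "learn from one another", consider a simplified, hypothetical case: suppose stock $i$'s own data give an estimate $\hat{\beta}_i$ with standard error $s_i$, and the sector's $\hat{\mu}_j$ and $\hat{\sigma}_j$ were known. Under a Normal approximation, the posterior mean of $\beta_i$ is then a precision-weighted average that shrinks $\hat{\beta}_i$ toward $\hat{\mu}_j$, more strongly when $s_i$ is large (few or noisy observations). In the full model $\hat{\mu}_j$ and $\hat{\sigma}_j$ are themselves estimated, so this sketch is only an approximation of what numpyro does.

```python
import numpy as np

def shrink_toward_group(beta_hat, se, group_mean, group_sd):
    """Precision-weighted posterior mean of beta_i, assuming a Normal
    approximation to the stock's likelihood and known group parameters."""
    w_data = 1.0 / np.asarray(se, dtype=float)**2   # precision of the stock's own estimate
    w_group = 1.0 / group_sd**2                     # precision of the sector-level prior
    return (w_data * beta_hat + w_group * group_mean) / (w_data + w_group)

# A noisy negative estimate is pulled strongly toward a sector mean of 1.0
print(shrink_toward_group(beta_hat=-0.2, se=0.5, group_mean=1.0, group_sd=0.5))
# A precise estimate barely moves
print(shrink_toward_group(beta_hat=-0.2, se=0.05, group_mean=1.0, group_sd=0.5))
```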

In [None]:
data.head()

In [None]:
# Ticker/subindustry
_tick_sect = data.loc[:, ["ticker", "gics"]]

# Note -- Duplicated tells us True anywhere that is a duplicate, except
# for the first entry -- This gives us only the unique ticker/gics
tick_sect = _tick_sect.loc[~_tick_sect.duplicated(keep="first"), :]

tickers = tick_sect["ticker"].to_numpy()
ntickers = tickers.shape[0]
sects = tick_sect["gics"].unique()
nsect = sects.shape[0]

# Mappings
ticker_2_int = dict(zip(tickers, range(ntickers)))
int_2_ticker = {v: k for k, v in ticker_2_int.items()}  # Only reverse when unique
sect_2_int = dict(zip(sects, range(nsect)))
int_2_sect = {v: k for k, v in sect_2_int.items()}  # Only reverse when unique
ticker_2_sect = dict(
    zip(
        tick_sect["ticker"].map(ticker_2_int).to_numpy(),
        tick_sect["gics"].map(sect_2_int).to_numpy()
    )
)

# Data
ri_m_rf = data.eval("returns - riskfree").to_numpy()
rm_m_rf = data.eval("market - riskfree").to_numpy()
ticker_idx = data["ticker"].map(lambda x: ticker_2_int[x]).to_numpy()
sect_idx = np.array([ticker_2_sect[x] for x in range(ntickers)])
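pandas can also build these integer codes directly with `pd.factorize`, which assigns codes in order of first appearance, the same order as the `unique()`-based dictionaries above. A minimal sketch on a hypothetical five-row frame that uses the same `ticker`/`gics` column names as `data`:

```python
import pandas as pd

# Hypothetical toy data with the same column names as `data`
toy = pd.DataFrame({
    "ticker": ["AAPL", "ZM", "AAPL", "XOM", "ZM"],
    "gics": ["Information Technology", "Information Technology",
             "Information Technology", "Energy", "Information Technology"],
})

# Integer code for each row's ticker, plus the array of unique tickers
ticker_idx_toy, tickers_toy = pd.factorize(toy["ticker"])

# One row per ticker (first occurrence), then a sector code per ticker
per_ticker = toy.drop_duplicates(subset="ticker")
sect_idx_toy, sects_toy = pd.factorize(per_ticker["gics"])

print(ticker_idx_toy)  # row -> ticker code
print(sect_idx_toy)    # ticker code -> sector code
```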

In [None]:
def capm_hierarchical(
    rm_m_rf,
    ri_m_rf,
    ticker_idx,
    sect_idx,
    ntickers,
    nsect
):
    with numpyro.plate("group_prior_plate", nsect):
        # Hyperprior for the mean of Beta in each sector (mu_hat)
        mu_hat = numpyro.sample(
            "mu_hat", 
            dist.Normal(0.0, 3.0)
        )
    
        # Hyperprior for the standard deviation of Beta in each sector (sigma_hat)
        sigma_hat = numpyro.sample(
            "sigma_hat", 
            dist.HalfNormal(
                4.0
            )
        )

    # We use a plate for the ntickers dimension.
    with numpyro.plate("ticker_plate", ntickers):
        # Beta prior: each stock's beta is centered on its sector's mu_hat, with scale sigma_hat
        beta_i = numpyro.sample(
            "beta_i", 
            dist.Normal(
                loc=mu_hat[sect_idx], 
                scale=sigma_hat[sect_idx]
            )
        )
        
        # Sigma prior: HalfNormal(4) on each stock's residual standard deviation
        sigma_i = numpyro.sample(
            "sigma_i",
            dist.HalfNormal(4)
        )

    # Calculate the mean (mu) and scale (sigma) for every observation using the ticker index (ticker_idx)
    mu = beta_i[ticker_idx] * rm_m_rf
    scale = sigma_i[ticker_idx]
    
    # Define the likelihood for all observations simultaneously.
    numpyro.sample(
        "ll", 
        dist.Normal(loc=mu, scale=scale), 
        obs=ri_m_rf
    )

# Random keys for JAX
rng_key = jax.random.PRNGKey(20251203)

kernel = NUTS(capm_hierarchical)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000)
mcmc.run(
    rng_key,
    ri_m_rf=ri_m_rf, rm_m_rf=rm_m_rf,
    ticker_idx=ticker_idx, sect_idx=sect_idx,
    ntickers=ntickers, nsect=nsect
)

posterior_samples = mcmc.get_samples()

In [None]:
ticker_2_int["ZM"]

In [None]:
posterior_samples["beta_i"][:, ticker_2_int["ZM"]].mean()

In [None]:
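amnesia_means = capm_amnesia_posterior_samples["beta_i"].mean(axis=0)
hierarchical_means = posterior_samples["beta_i"].mean(axis=0)

In [None]:
amnesia_means.shape

In [None]:
# Select the stocks in sector 10 (its name is shown by int_2_sect[10] below)
want = (sect_idx == 10)

In [None]:
data.query("ticker == 'ZM'").shape

In [None]:
int_2_sect[10]

In [None]:
fig, ax = plt.subplots()

ax.scatter(np.arange(ntickers)[want], amnesia_means[want], c="b", label="Amnesia")
ax.scatter(np.arange(ntickers)[want], hierarchical_means[want], c="r", label="Hierarchical")
ax.set_xlabel("Ticker index")
ax.set_ylabel("Posterior mean of $\\beta_i$")
ax.legend()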