In [7]:
import os

from IPython.display import set_matplotlib_formats
import matplotlib.pyplot as plt
import pandas as pd

from jax import random, vmap
import jax.numpy as jnp
from jax.scipy.special import logsumexp

import numpyro
from numpyro import handlers
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Use the "bmh" matplotlib style sheet for every figure in this notebook.
plt.style.use("bmh")

# Emit SVG figures only when the NUMPYRO_SPHINXBUILD environment variable is set.
# NOTE(review): IPython.display.set_matplotlib_formats is deprecated in recent
# IPython releases — confirm it still exists in the pinned environment.
if "NUMPYRO_SPHINXBUILD" in os.environ:
    set_matplotlib_formats("svg")

# Fail fast, with an explanatory message, when the installed NumPyro is not the
# release this notebook was written against.
assert numpyro.__version__.startswith("0.15.0"), (
    f"this notebook expects numpyro 0.15.0.x, found {numpyro.__version__}"
)

In [8]:
# Semicolon-delimited CSV hosted in the rmcelreath/rethinking GitHub repository;
# it is fetched over the network every time this cell runs.
DATASET_URL = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
dset = pd.read_csv(DATASET_URL, sep=";")

# Preview the first rows only, rather than dumping the whole frame into the output.
dset.head(10)

Unnamed: 0,Location,Loc,Population,MedianAgeMarriage,Marriage,Marriage SE,Divorce,Divorce SE,WaffleHouses,South,Slaves1860,Population1860,PropSlaves1860
0,Alabama,AL,4.78,25.3,20.2,1.27,12.7,0.79,128,1,435080,964201,0.45
1,Alaska,AK,0.71,25.2,26.0,2.93,12.5,2.05,0,0,0,0,0.0
2,Arizona,AZ,6.33,25.8,20.3,0.98,10.8,0.74,18,0,0,0,0.0
3,Arkansas,AR,2.92,24.3,26.4,1.7,13.5,1.22,41,1,111115,435450,0.26
4,California,CA,37.25,26.8,19.1,0.39,8.0,0.24,0,0,0,379994,0.0
5,Colorado,CO,5.03,25.7,23.5,1.24,11.6,0.94,11,0,0,34277,0.0
6,Connecticut,CT,3.57,27.6,17.1,1.06,6.7,0.77,0,0,0,460147,0.0
7,Delaware,DE,0.9,26.6,23.1,2.89,8.9,1.39,3,0,1798,112216,0.016
8,District of Columbia,DC,0.6,29.7,17.7,2.53,6.3,1.89,0,0,0,75080,0.0
9,Florida,FL,18.8,26.4,17.0,0.58,8.5,0.32,133,1,61745,140424,0.44


In [11]:
import altair as alt

# Columns compared pairwise in the scatter-plot matrix below.
# NOTE(review): `vars` shadows the Python builtin of the same name; the name is
# kept because other cells may refer to it.
vars = [
    "Population",
    "MedianAgeMarriage",
    "Marriage",
    "WaffleHouses",
    "South",
    "Divorce",
]

# Scatter-plot matrix: one 150x150 panel per (row, column) pair taken from `vars`.
# Points are coloured by the `South` indicator column. The previous encoding
# referenced an `Origin` column, which this data set does not contain.
alt.Chart(dset).mark_circle().encode(
    alt.X(alt.repeat("column"), type="quantitative"),
    alt.Y(alt.repeat("row"), type="quantitative"),
    color="South:N",
).properties(
    width=150,
    height=150,
).repeat(
    row=vars, column=vars
)

In [33]:
# Scatter of Divorce against Marriage: one black circle per row of `dset`.
base = (
    alt.Chart(dset, width=500)
    .mark_circle(color="black")
    .encode(x="Marriage", y="Divorce")
)

# Order-1 polynomial regression of Divorce on Marriage. The fitted values are
# written to a "Best fit" column and then folded back into a "Divorce" column so
# the line can share the scatter's y encoding.
regression_points = base.transform_regression(
    "Marriage", "Divorce", method="poly", order=1, as_=["Marriage", "Best fit"]
)
polynomial_fit = regression_points.mark_line().transform_fold(
    ["Best fit"], as_=["degree", "Divorce"]
)

# Draw the fitted line on top of the observations.
alt.layer(base, polynomial_fit)

In [27]:
def model(marriage=None, age=None, divorce=None):
    """NumPyro model: Normal likelihood for divorce with optional linear terms.

    Sample sites, in the order they are created:
      * ``a``     ~ Normal(0, 0.2)   intercept, always sampled
      * ``bM``    ~ Normal(0, 0.5)   only when ``marriage`` is given
      * ``bA``    ~ Normal(0, 0.5)   only when ``age`` is given
      * ``sigma`` ~ Exponential(1)   scale of the likelihood
      * ``obs``   ~ Normal(mu, sigma), observed at ``divorce`` when it is given

    Args:
        marriage: Values multiplied by ``bM``; ``None`` drops the term (and the site).
        age: Values multiplied by ``bA``; ``None`` drops the term (and the site).
        divorce: Passed as ``obs=`` to the ``obs`` site.
    """
    intercept = numpyro.sample("a", dist.Normal(0.0, 0.2))

    # A predictor that is not supplied contributes a constant 0.0 to the mean.
    marriage_term = 0.0
    if marriage is not None:
        marriage_slope = numpyro.sample("bM", dist.Normal(0.0, 0.5))
        marriage_term = marriage_slope * marriage

    age_term = 0.0
    if age is not None:
        age_slope = numpyro.sample("bA", dist.Normal(0.0, 0.5))
        age_term = age_slope * age

    noise_scale = numpyro.sample("sigma", dist.Exponential(1.0))
    mean = intercept + marriage_term + age_term
    numpyro.sample("obs", dist.Normal(mean, noise_scale), obs=divorce)

In [34]:
# Seed the PRNG once and split it: `rng_key_` is consumed by this run, while
# `rng_key` stays available to be split again for later operations.
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

# NUTS over `model` with the marriage column as the only predictor; `age` is not
# passed, so the model creates no `bA` site for this run.
# NOTE(review): the priors in `model` (e.g. a ~ Normal(0, 0.2)) are narrow, yet the
# raw, unstandardized columns are passed here — confirm whether the inputs were
# meant to be standardized first.
num_samples = 4000
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(
    rng_key_,
    marriage=dset["Marriage"].to_numpy(),
    divorce=dset["Divorce"].to_numpy(),
)
mcmc.print_summary()

# Keep the posterior draws of this fit for the plots that follow.
samples_1 = mcmc.get_samples()



sample: 100%|██████████| 5000/5000 [00:03<00:00, 1524.62it/s, 3 steps of size 6.85e-01. acc. prob=0.91]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
         a      0.10      0.20      0.10     -0.25      0.41   2096.30      1.00
        bM      0.47      0.02      0.47      0.44      0.49   2319.45      1.00
     sigma      2.03      0.21      2.01      1.68      2.34   2422.75      1.00

Number of divergences: 0


In [57]:
def plot_regression(
    x, y, y_mean, y_hpdi, x_label="Marriage rate", y_label="Divorce rate"
):
    """Layer observations, a predicted-mean line and an interval band in one chart.

    Args:
        x: Predictor values, one per observation.
        y: Observed outcome values, aligned with ``x``.
        y_mean: Predicted outcome for each entry of ``x``.
        y_hpdi: 2-D array indexed as ``y_hpdi[0, :]`` and ``y_hpdi[1, :]``; row 0
            is drawn as one edge of the shaded band and row 1 as the other.
        x_label: Column name and x-axis field for the predictor. The default
            reproduces the original hard-coded label.
        y_label: Column name, y-axis field and axis title for the outcome. The
            default reproduces the original hard-coded label.

    Returns:
        The Altair chart produced by ``alt.layer`` from three layers: circles
        for the observations, a line for ``y_mean`` and a translucent area
        between the two rows of ``y_hpdi``.

    Raises:
        ValueError: If ``x_label`` equals ``y_label``; the two would collide as
            column names in the intermediate DataFrame.
    """
    if x_label == y_label:
        raise ValueError("x_label and y_label must differ")

    # The labels double as DataFrame column names and Altair field shorthands.
    # NOTE(review): labels containing characters that Altair's shorthand treats
    # specially (e.g. ':') would presumably need escaping — confirm before use.
    # NOTE(review): the "5%"/"95%" suffixes are fixed text; they do not follow
    # the interval probability used to compute `y_hpdi`.
    predicted_col = f"{y_label} (predicted)"
    lower_col = f"{y_label} (5% confidence)"
    upper_col = f"{y_label} (95% confidence)"

    results = pd.DataFrame({
        x_label: x,
        y_label: y,
        predicted_col: y_mean,
        lower_col: y_hpdi[0, :],
        upper_col: y_hpdi[1, :],
    })
    base = alt.Chart(results)

    # Every y encoding carries the same explicit title so the layers share a
    # single axis title rather than each contributing its own column name.
    return alt.layer(
        base.mark_circle().encode(
            alt.X(x_label, scale=alt.Scale(zero=False)),
            alt.Y(y_label, scale=alt.Scale(zero=False), title=y_label),
        ),
        base.mark_line().encode(
            x=x_label,
            y=alt.Y(predicted_col, title=y_label),
        ),
        base.mark_area(opacity=0.3).encode(
            x=x_label,
            y=alt.Y(lower_col, title=y_label),
            y2=alt.Y2(upper_col, title=y_label),
        ),
    )



# Empirical posterior of the regression mean: each posterior draw of (a, bM) is
# evaluated at every observed marriage rate. The trailing axis added to the draws
# lets them broadcast against the data vector.
marriage_rate = dset.Marriage.values
posterior_mu = samples_1["a"][..., None] + samples_1["bM"][..., None] * marriage_rate

# Summarise across axis 0: the mean and the 0.9-probability HPDI.
mean_mu = posterior_mu.mean(axis=0)
hpdi_mu = hpdi(posterior_mu, 0.9)

regression_chart = plot_regression(marriage_rate, dset.Divorce.values, mean_mu, hpdi_mu)
regression_chart.properties(title="Regression line with 90% CI", width=600, height=600)