# Demonstration of `numpyro-oop` usage



In [1]:
import pandas as pd
import numpyro
import numpyro.distributions as dist

from numpyro_oop import BaseNumpyroModel

  from .autonotebook import tqdm as notebook_tqdm


## Basic regression example

In [2]:
# example from https://num.pyro.ai/en/stable/tutorials/bayesian_regression.html

DATASET_URL = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
dset = pd.read_csv(DATASET_URL, sep=";")
dset.head()

Unnamed: 0,Location,Loc,Population,MedianAgeMarriage,Marriage,Marriage SE,Divorce,Divorce SE,WaffleHouses,South,Slaves1860,Population1860,PropSlaves1860
0,Alabama,AL,4.78,25.3,20.2,1.27,12.7,0.79,128,1,435080,964201,0.45
1,Alaska,AK,0.71,25.2,26.0,2.93,12.5,2.05,0,0,0,0,0.0
2,Arizona,AZ,6.33,25.8,20.3,0.98,10.8,0.74,18,0,0,0,0.0
3,Arkansas,AR,2.92,24.3,26.4,1.7,13.5,1.22,41,1,111115,435450,0.26
4,California,CA,37.25,26.8,19.1,0.39,8.0,0.24,0,0,0,379994,0.0


In [3]:
def standardize(x):
    return (x - x.mean()) / x.std()


dset["AgeScaled"] = dset.MedianAgeMarriage.pipe(standardize)
dset["MarriageScaled"] = dset.Marriage.pipe(standardize)
dset["DivorceScaled"] = dset.Divorce.pipe(standardize)

### Defining the model class

The basic idea of `numpyro-oop` is that the user should define a new class that inherits from `BaseNumpyroModel`. 
This new class mainly needs to define the `model` method.
Then all other sampling and prediction steps are handled by `numpyro-oop`, or related libraries (e.g. `arviz`).

Here's a demonstration of a basic `RegressionModel` class plus `model` definition for the data above.
See the [numpyro demo here](https://num.pyro.ai/en/stable/tutorials/bayesian_regression.html) for more.

In [4]:
class RegressionModel(BaseNumpyroModel):
    def model(self, data=None):
        a = numpyro.sample("a", dist.Normal(0.0, 0.2))
        bM = numpyro.sample("bM", dist.Normal(0.0, 0.5))
        M = bM * data.MarriageScaled.values
        sigma = numpyro.sample("sigma", dist.Exponential(1.0))
        mu = numpyro.deterministic("mu", a + M)
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=data.DivorceScaled.values)

We can now instantiate this class, adding the dataset and a random seed.


In [5]:
m1 = RegressionModel(data=dset, seed=42)

We can also render a graphical representation of our model (requires `graphviz` package).

In [6]:
m1_graph = m1.render()

Now sample from the model (using the NUTS sampler, by default):

In [7]:
m1.sample()

  self.mcmc = MCMC(
sample: 100%|██████████| 2000/2000 [00:01<00:00, 1988.47it/s, 3 steps of size 8.93e-01. acc. prob=0.88]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 2012.41it/s, 3 steps of size 7.06e-01. acc. prob=0.94]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 2024.45it/s, 7 steps of size 6.50e-01. acc. prob=0.93]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 2081.19it/s, 3 steps of size 8.94e-01. acc. prob=0.90]


In [8]:
m1.mcmc.print_summary(0.90)


                mean       std    median      5.0%     95.0%     n_eff     r_hat
         a      0.00      0.11      0.01     -0.17      0.19   4029.47      1.00
        bM      0.35      0.13      0.35      0.13      0.55   3909.32      1.00
     sigma      0.95      0.10      0.94      0.77      1.09   3576.33      1.00

Number of divergences: 0


In [9]:
preds = m1.predict()

  predictive = Predictive(


In [11]:
preds["mu"].mean(axis=0)

Array([ 9.8082935e-03,  5.4307252e-01,  1.9002507e-02,  5.7984954e-01,
       -9.1328040e-02,  3.1321731e-01, -2.7521229e-01,  2.7644044e-01,
       -2.2004701e-01, -2.8440651e-01,  1.8449832e-01,  4.4193631e-01,
        5.2468413e-01, -2.0165858e-01, -2.6968556e-02,  1.2933306e-01,
        1.8449832e-01,  1.9369255e-01,  4.6585146e-02, -6.0620391e-01,
       -1.6488174e-01, -3.9473706e-01, -3.3037758e-01, -4.4070810e-01,
       -7.2939612e-02, -1.3729911e-01, -1.4649333e-01, -4.5356981e-02,
       -3.1198916e-01, -4.8667917e-01,  2.8196722e-02, -3.0279493e-01,
        2.8196722e-02,  6.0743207e-01, -2.9360071e-01,  3.4079999e-01,
       -1.0971647e-01, -4.2231968e-01, -4.6829075e-01, -1.8327014e-01,
        6.1408355e-04, -6.3745402e-02,  1.2933306e-01,  8.7406427e-01,
       -3.3957177e-01,  3.7390932e-02,  1.2013883e-01,  1.9369255e-01,
       -2.6601809e-01,  9.7520059e-01], dtype=float32)

## Hierarchical regression example