In [7]:
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp
import numpy as np
import pandas

from data.election88 import data

numpyro.set_platform('cpu')
numpyro.set_host_device_count(10)

state = np.array(data["state"])-1
state_onehot = jnp.array(np.eye(data["n_state"])[state-1])
black = jnp.array(data["black"])
female = jnp.array(data["female"])
y = jnp.array(data["y"])

In [11]:
def election(black, female, state_onehot, y):
    mean_a = numpyro.sample("mean_a", dist.Normal(0,100))
    log_scale_a = numpyro.sample("log_scale_a", dist.Normal(0,10))

    with numpyro.plate('state_param', state_onehot.shape[1]):
        a = numpyro.sample("a", dist.Normal(mean_a, jnp.exp(log_scale_a)))
    
    b_black = numpyro.sample("b_black", dist.Normal(0,100))
    b_female = numpyro.sample("b_female", dist.Normal(0,100))

    logits = jnp.dot(state_onehot,a) + b_black*black + b_female*female

    with numpyro.plate('observation', len(y)):
        return numpyro.sample('result', dist.Bernoulli(logits = logits),obs=y)

In [21]:
from jax import random
from numpyro.infer import MCMC, NUTS 

from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam

nuts_kernel = NUTS(election)

mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=5000 ,num_chains=10)

rng_key = random.PRNGKey(0)

mcmc.run(rng_key, black, female, state_onehot, y, extra_fields=('num_steps',))

mcmc.print_summary()



reparam_model = reparam(election, config={"a": LocScaleReparam(0)})

nuts_kernel2 = NUTS(reparam_model)

mcmc2 = MCMC(nuts_kernel2, num_warmup=1000, num_samples=5000 ,num_chains=10)

rng_key2 = random.PRNGKey(0)

mcmc2.run(rng_key, black, female, state_onehot, y, extra_fields=('num_steps',))


mcmc2.print_summary()

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]


                   mean       std    median      5.0%     95.0%     n_eff     r_hat
         a[0]      0.43      0.43      0.43     -0.27      1.15  60889.18      1.00
         a[1]      0.51      0.15      0.51      0.27      0.76  63997.99      1.00
         a[2]      0.45      0.19      0.45      0.14      0.76  63561.88      1.00
         a[3]      0.44      0.06      0.44      0.34      0.54  54561.46      1.00
         a[4]      0.50      0.16      0.50      0.25      0.77  58873.09      1.00
         a[5]      0.27      0.16      0.27     -0.01      0.54  62102.14      1.00
         a[6]      0.02      0.27      0.03     -0.41      0.48  57821.70      1.00
         a[7]     -0.15      0.39     -0.13     -0.77      0.49  51154.72      1.00
         a[8]      0.73      0.09      0.73      0.58      0.86  56249.67      1.00
         a[9]      0.64      0.13      0.64      0.43      0.86  61233.54      1.00
        a[10]      0.43      0.43      0.44     -0.28      1.14  61848.09  

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]


                      mean       std    median      5.0%     95.0%     n_eff     r_hat
 a_decentered[0]     -0.00      1.00     -0.00     -1.62      1.67  63488.70      1.00
 a_decentered[1]      0.18      0.39      0.18     -0.46      0.81  30664.23      1.00
 a_decentered[2]      0.04      0.46      0.04     -0.72      0.80  35853.86      1.00
 a_decentered[3]      0.01      0.21      0.01     -0.33      0.35  10716.40      1.00
 a_decentered[4]      0.17      0.41      0.17     -0.49      0.85  30850.99      1.00
 a_decentered[5]     -0.39      0.42     -0.39     -1.05      0.31  32699.73      1.00
 a_decentered[6]     -0.97      0.65     -0.97     -2.06      0.06  49392.70      1.00
 a_decentered[7]     -1.36      0.87     -1.35     -2.80      0.05  61329.70      1.00
 a_decentered[8]      0.70      0.26      0.69      0.27      1.13  14250.55      1.00
 a_decentered[9]      0.50      0.34      0.50     -0.06      1.05  22559.22      1.00
a_decentered[10]     -0.00      1.00     -

In [17]:
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam

optimizer = Adam(step_size = 0.05)

learnable_model = reparam(election, config={"mean_a": LocScaleReparam(),"log_scale_a": LocScaleReparam(),"a": LocScaleReparam(), "b_black": LocScaleReparam(), "b_female": LocScaleReparam()})

# setup the inference algorithm
svi = SVI(learnable_model, AutoNormal(learnable_model), optimizer, loss=Trace_ELBO())

# do gradient steps
svi_result = svi.run(random.PRNGKey(0), 20000, black, female, state_onehot, y)
params = svi_result.params

100%|██████████| 20000/20000 [00:32<00:00, 618.88it/s, init loss: 83676.1641, avg. loss [19001-20000]: 7605.6636]


In [18]:
params

{'a_centered': DeviceArray(0.01152046, dtype=float32),
 'a_decentered_auto_loc': DeviceArray([ 0.16957209,  0.2497343 ,  0.00467953,  0.07183243,
               0.36429596, -0.3674008 , -1.0734462 , -1.4650859 ,
               0.7028089 ,  0.6113871 ,  0.01626012, -0.18858299,
              -0.21062893,  1.0230243 , -1.6301727 ,  0.91783285,
               0.20873275,  1.7495271 , -0.68199146, -0.30577976,
              -1.1665448 ,  0.11208902, -0.7536845 ,  2.2065012 ,
              -0.82361096, -1.0587468 ,  0.18874536,  0.35023004,
               0.34948853, -0.08739155, -0.35708398, -1.062291  ,
               0.51508826, -0.02952179,  0.4292868 ,  0.04279809,
              -0.4950164 , -0.44563857, -2.0882251 ,  1.570776  ,
              -0.04335152,  1.3176328 ,  0.03253926,  1.4917134 ,
               0.06242191,  1.5670143 , -0.9344607 , -0.8193793 ,
              -0.59759915,  0.02053729,  1.6927997 ], dtype=float32),
 'a_decentered_auto_scale': DeviceArray([1.0786226 , 0.315

In [19]:
nuts_kernel3 = NUTS(learnable_model)

mcmc3 = MCMC(nuts_kernel3, num_warmup=1000, num_samples=5000 ,num_chains=10)

rng_key3 = random.PRNGKey(0)

mcmc3.run(rng_key3, black, female, state_onehot, y, extra_fields=('num_steps',))

mcmc3.print_summary()

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]

  0%|          | 0/6000 [00:00<?, ?it/s]


                              mean       std    median      5.0%     95.0%     n_eff     r_hat
         a_decentered[0]      0.21      0.65      0.21     -0.86      1.28 110148.71      1.00
         a_decentered[1]      0.33      0.24      0.33     -0.06      0.72  80524.11      1.00
         a_decentered[2]      0.25      0.29      0.24     -0.23      0.71  86338.23      1.00
         a_decentered[3]      0.22      0.11      0.22      0.04      0.41  32927.29      1.00
         a_decentered[4]      0.33      0.25      0.33     -0.08      0.75  83825.05      1.00
         a_decentered[5]     -0.04      0.26     -0.04     -0.46      0.38  88796.06      1.00
         a_decentered[6]     -0.41      0.41     -0.41     -1.08      0.28 100273.00      1.00
         a_decentered[7]     -0.67      0.57     -0.66     -1.62      0.25  98837.29      1.00
         a_decentered[8]      0.67      0.15      0.67      0.43      0.91  45417.31      1.00
         a_decentered[9]      0.54      0.21     