# MCMC From Scratch IV: Eight Schools
> In which we try to solve the actual problem

A long, long time ago we introduced the Eight Schools problem (in our first notebook) and the application of Bayesian statistics to it. Then we introduced Markov Chain Monte Carlo, and learned how to work with it, primarily by monitoring convergence and making adjustments as needed.

In this notebook, we'll apply the techniques from the previous notebooks to arrive at some insight for the Eight Schools problem.

In [1]:
import numpy as np
import altair as alt
import pandas as pd
from scipy.stats import norm, halfcauchy, uniform
from tqdm import tqdm

from mcmc import json_dir, simulate_multiple, generate_sample, transition_MH, convergence, Rhat

alt.data_transformers.register('json_dir', json_dir)
alt.data_transformers.enable('json_dir', data_dir='/altairdata')

DataTransformerRegistry.enable('json_dir')

In [2]:
effect_estimates = np.array([28, 8, -3, 7, -1, 1, 18, 12])
std_estimates = np.array([15, 10, 16, 11, 9, 11, 10, 18])
school_names = ["A", "B", "C", "D", "E", "F", "G", "H"]

data = pd.DataFrame({"effect_estimate": effect_estimates,
                   "std_estimate": std_estimates,
                  "school": school_names})

original = alt.Chart(data).mark_bar().encode(x="school", y="effect_estimate").properties(width=500)
original.configure_axisX(labelAngle=0)

# Implementation
We created the probability proportion in our first notebook. We have priors on `mu` and `sigma` that lead to the true effects `t`, that lead to the observed effect estimates `effect_estimates`.

Since the numbers coming out of this function can get extremely small, we work with the logarithm of the probability proportion instead, and modify our MH-sampler accordingly.
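
To see why this matters, here is a small sketch (not part of the original notebooks) of what happens to a density far out in the tail. The point `x = 40.0` is an arbitrary illustrative choice; `norm.logpdf` is SciPy's log-density, which stays in log space throughout.

```python
import numpy as np
from scipy.stats import norm

# A point far out in the tail of a standard normal (illustrative value).
x = 40.0

# The density underflows to exactly 0.0 in float64, so its log is -inf.
with np.errstate(divide="ignore"):
    naive = np.log(norm.pdf(x))

# Computing the log-density directly keeps a finite, usable number.
stable = norm.logpdf(x)

print(naive, stable)
```

Once one chain state evaluates to `-inf`, differences of log-probabilities can turn into `nan`, which is why staying in log space from the start is the safer habit.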

Additionally, it turns out we get a much better sampler if we reparametrize `t` to be `t = mu + eta*sigma` where `eta` is standard normal.
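
As a sanity check that the reparametrization describes the same distribution, here is a minimal sketch. The values of `mu` and `sigma` are made up for illustration and are not taken from the Eight Schools run.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative values only, not estimates from this notebook.
mu, sigma = 5.0, 3.0

# Draw eta from a standard normal, then shift and scale it.
eta = rng.standard_normal(200_000)
t = mu + eta * sigma

# The implied t should look like Normal(mu, sigma).
print(t.mean(), t.std())
```

The sampler then moves around in `eta`, whose prior does not depend on `sigma`, rather than in `t`, whose spread changes with `sigma`.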

We initialize our transition function with a scale of 0.5.

In [3]:
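def log_p_prop(q):
    assert len(q) == 10
    *eta, mu, sigma = q

    # sigma is a scale parameter, so non-positive values have zero probability
    if sigma <= 0: return -np.inf

    # true effects under the reparametrization t = mu + eta*sigma
    t = mu + np.array(eta) * sigma

    # logpdf works in log space directly, avoiding underflow in np.log(pdf(...))
    p_mu = norm.logpdf(mu, loc=8.75, scale=20)
    p_sigma = halfcauchy.logpdf(sigma, loc=0, scale=5)
    p_eta = norm.logpdf(eta).sum()
    p_ee_t = norm.logpdf(effect_estimates, loc=t, scale=std_estimates).sum()
    return p_ee_t + p_eta + p_mu + p_sigma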

def transition_MH_log(current, log_p_prop, scale=1):
    proposal = tuple(norm.rvs(current, scale=scale))
    u = uniform.rvs()
    return proposal if log_p_prop(proposal) - log_p_prop(current) > np.log(u) else current

transition = lambda q: transition_MH_log(q, log_p_prop, scale=0.5)
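
The comparison inside `transition_MH_log` is the usual Metropolis accept rule, `u < p(proposal) / p(current)`, with logarithms taken on both sides. Here is a quick sketch to convince ourselves that the two forms make the same decisions; the two density values are invented for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)

# Made-up unnormalized densities for the current and proposed points.
p_current, p_proposal = 0.02, 0.005

u = rng.uniform(size=10_000)

# Plain rule: accept when u is below the density ratio.
accept_plain = u < p_proposal / p_current

# Log rule: the same inequality after taking logs on both sides.
accept_log = np.log(u) < np.log(p_proposal) - np.log(p_current)

print(accept_plain.mean(), accept_log.mean())
```

Both versions accept about a quarter of the proposals here, matching the density ratio of 0.25.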

We use 4 chains, drawing each starting point at random from distributions at least as wide as the priors, so that the chains start out spread apart.

In [4]:
draw_eta = lambda: norm.rvs(0, scale=5, size=8)
draw_mu = lambda: norm.rvs(loc=8.75, scale=20)
draw_sigma = lambda: halfcauchy.rvs(loc=0, scale=6)

initial_points = [np.hstack([draw_eta(), draw_mu(), draw_sigma()]) for _ in range(4)]
initial_points

[array([-6.04750549, -1.25484591, -3.22637901,  0.46608746,  0.35213615,
         0.44602285,  0.14927671,  2.3390834 , 27.71938233,  2.03730451]),
 array([-2.63713616, -3.73081925,  2.85571866,  5.39713329, -8.24680864,
         1.91387522, -5.31124558, -3.48525332, -7.18339837,  1.37825265]),
 array([ 5.75623003,  2.71323552,  7.45520769,  0.66492298,  3.23292559,
        -5.15362589,  2.20975691, -1.02537958, 25.62653566,  8.3254704 ]),
 array([ 3.43142137, -2.34022024,  2.25220227, -2.0212946 , -1.474567  ,
         1.0436941 , -3.95965747,  4.85534305,  9.7928197 ,  2.6873203 ])]

Since our Metropolis-Hastings sampler is fairly inefficient (or badly configured), we run these 4 chains for 5000 iterations each, warming up with 1000 iterations.

In [5]:
df = simulate_multiple(initial_points, transition, n_iter=5000, n_warmup=1000)
df = df.rename(columns={"x8": "mu", "x9": "sigma"})
df.head()

100%|██████████| 1000/1000 [00:01<00:00, 950.64it/s]
100%|██████████| 4999/4999 [00:04<00:00, 1111.93it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1130.26it/s]
100%|██████████| 4999/4999 [00:04<00:00, 1132.12it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1116.97it/s]
100%|██████████| 4999/4999 [00:04<00:00, 1143.02it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1153.04it/s]
100%|██████████| 4999/4999 [00:04<00:00, 1137.17it/s]


Unnamed: 0,simulation,i,x0,x1,x2,x3,x4,x5,x6,x7,mu,sigma
0,0,0,-0.246168,0.256682,-0.725178,-0.711941,-0.3489,-0.249246,-0.763866,1.404564,3.858413,0.705033
1,0,1,-0.246168,0.256682,-0.725178,-0.711941,-0.3489,-0.249246,-0.763866,1.404564,3.858413,0.705033
2,0,2,-0.246168,0.256682,-0.725178,-0.711941,-0.3489,-0.249246,-0.763866,1.404564,3.858413,0.705033
3,0,3,-0.357963,0.723806,-0.496617,-0.368671,-0.044607,0.084971,0.324904,1.861539,3.623143,1.076791
4,0,4,-0.920588,-0.322486,-0.277006,-0.139981,-0.366585,-0.001361,-0.283373,1.773508,3.742467,0.38413
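
Notice that rows 0 to 2 are identical: those are rejected proposals, where the chain stays put. Counting how often consecutive draws differ gives a rough acceptance rate. Here is a minimal sketch; `acceptance_rate_sketch` is a hypothetical helper (not part of our `mcmc` module), fed with the `mu` and `sigma` columns of the five rows shown above.

```python
import numpy as np

def acceptance_rate_sketch(draws):
    """Fraction of steps where the chain moved (hypothetical helper).

    draws: array of shape (n_draws, n_params), one chain in sampling order.
    """
    draws = np.asarray(draws)
    # A step counts as accepted if any coordinate changed.
    moved = np.any(draws[1:] != draws[:-1], axis=1)
    return moved.mean()

# (mu, sigma) from the five rows printed above.
example = np.array([
    [3.858413, 0.705033],
    [3.858413, 0.705033],
    [3.858413, 0.705033],
    [3.623143, 1.076791],
    [3.742467, 0.38413],
])

rate = acceptance_rate_sketch(example)
print(rate)
```

Five rows are far too few to say anything about the sampler; the same function could be applied to a full chain, for instance the rows of `df` belonging to one `simulation`.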


In [6]:
params = [f"x{i}" for i in range(8)] + ["mu", "sigma"]
df_Rhat = convergence(df, params)

100%|██████████| 101/101 [00:31<00:00,  3.24it/s]
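
`convergence` and `Rhat` come from our `mcmc` module and aren't repeated here. As a reminder of what the diagnostic measures, here is a minimal sketch of the classic Gelman-Rubin statistic for a single parameter. `rhat_sketch` is a hypothetical stand-in and may differ in details from what `Rhat` actually computes.

```python
import numpy as np

def rhat_sketch(chains):
    """Classic potential scale reduction factor (hypothetical helper).

    chains: array of shape (n_chains, n_draws) for one parameter.
    """
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    B = n * chain_means.var(ddof=1)          # between-chain variance
    W = chains.var(axis=1, ddof=1).mean()    # average within-chain variance
    var_hat = (n - 1) / n * W + B / n        # pooled variance estimate
    return np.sqrt(var_hat / W)

rng = np.random.default_rng(0)

# Four chains sampling the same distribution: R-hat close to 1.
mixed = rng.standard_normal((4, 2000))

# Four chains stuck around different values: R-hat well above 1.
stuck = mixed + np.array([[0.0], [2.0], [4.0], [6.0]])

print(rhat_sketch(mixed), rhat_sketch(stuck))
```

Values near 1 mean the chains agree with each other; values well above 1 mean they are still exploring different regions.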


In [7]:
df_plot = df_Rhat.melt(id_vars="i")
alt.Chart(df_plot).mark_line().encode(x="i", y="value", tooltip="value").facet(facet="variable", columns=2)

Looks like we have some convergence going on. Let's look at the histograms!

In [8]:
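cols = [f"x{i}" for i in range(8)]
# Note: df_result is an alias of df (no copy is made), so the loop below
# overwrites the eta columns of df in place with t = mu + eta*sigma.
# Re-running this cell would apply the transformation a second time.
df_result = df
for col in cols:
    df[col] = df.mu + df[col] * df.sigma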

df_result.head()

Unnamed: 0,simulation,i,x0,x1,x2,x3,x4,x5,x6,x7,mu,sigma
0,0,0,3.684856,4.039382,3.347138,3.356471,3.612427,3.682686,3.319862,4.848677,3.858413,0.705033
1,0,1,3.684856,4.039382,3.347138,3.356471,3.612427,3.682686,3.319862,4.848677,3.858413,0.705033
2,0,2,3.684856,4.039382,3.347138,3.356471,3.612427,3.682686,3.319862,4.848677,3.858413,0.705033
3,0,3,3.237692,4.402531,3.08839,3.226162,3.575111,3.714638,3.972997,5.627631,3.623143,1.076791
4,0,4,3.388841,3.618591,3.636061,3.688696,3.601651,3.741944,3.633615,4.423725,3.742467,0.38413


In [9]:
alt.Chart(df.melt(id_vars=["i", "simulation"])).mark_bar().encode(
    x=alt.X("value", bin=alt.Bin(maxbins=50)),
    y="count()",
    tooltip=alt.Tooltip("value", bin=alt.Bin(maxbins=50))
).facet(
    facet="variable",
    columns=2,
)

Let's see how the effect estimates have shrunk.

In [10]:
df_estimates = pd.DataFrame(zip(school_names, df_result[cols].mean()), columns=["school", "effect_estimate"])

In [11]:
new = alt.Chart(df_estimates).mark_bar().encode(
    x="school",
    y=alt.Y("effect_estimate", scale=alt.Scale(domain=[-5, 30])),
    tooltip="effect_estimate",
).properties(
    width=500,
    title="After modelling",
)

(new | original.properties(title="Data")).configure_axisX(labelAngle=0)
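
Why do the bars move toward each other? If we hold `mu` and `sigma` fixed, the posterior mean of each school's true effect is a precision-weighted average of its observed estimate and `mu`. The sketch below uses illustrative fixed values for `mu` and `sigma`; they are assumptions for the example, not posterior summaries from the run above.

```python
import numpy as np

effect_estimates = np.array([28, 8, -3, 7, -1, 1, 18, 12])
std_estimates = np.array([15, 10, 16, 11, 9, 11, 10, 18])

# Illustrative fixed values, NOT results from the sampler.
mu, sigma = 8.0, 5.0

# Weight on the school's own data: its precision relative to the total precision.
data_precision = 1 / std_estimates**2
prior_precision = 1 / sigma**2
weight_on_data = data_precision / (data_precision + prior_precision)

# Conditional posterior mean: pulled from the estimate toward mu.
shrunk = weight_on_data * effect_estimates + (1 - weight_on_data) * mu

print(np.round(shrunk, 2))
```

Schools with a large standard error (like A and H) get a small weight on their own data and are pulled hardest toward `mu`. The full posterior averages this effect over all the sampled values of `mu` and `sigma`.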