# "MCMC From Scratch III: Convergence"
> "Exploring a metric for convergence of chains in MCMC"

In the first notebook, we introduced the Eight schools problem and approached it using Bayesian statistics. The fundamental question was "how do we adjust the effect estimate for school A (which was very high) in light of what we know about the other schools?" This amounted to building a model with the parameter `t`, the "true" effect size per school. Our model also included hyperparameters `mu` and `sigma`.

In the second notebook, we introduced Markov Chain Monte Carlo (MCMC) methods, and gave the original formulation in the form of the Metropolis-Hastings algorithm. We ended that notebook with the question of convergence: how can we tell the accuracy of our estimates? This notebook explores several issues with convergence, and provides a method for monitoring it, straight out of [Bayesian Data Analysis](http://www.stat.columbia.edu/~gelman/book/).

In [1]:
import numpy as np
import altair as alt
import pandas as pd
from scipy.stats import norm, uniform
from tqdm import tqdm

from mcmc import json_dir, simulate, simulate_multiple, generate_sample, make_2d_histogram, transition_MH, visualize_simulation

alt.data_transformers.register('json_dir', json_dir)
alt.data_transformers.enable('json_dir', data_dir='/altairdata')

DataTransformerRegistry.enable('json_dir')

In [2]:
effect_estimates = np.array([28, 8, -3, 7, -1, 1, 18, 12])
std_estimates = np.array([15, 10, 16, 11, 9, 11, 10, 18])
school_names = ["A", "B", "C", "D", "E", "F", "G", "H"]

df = pd.DataFrame({"effect_estimate": effect_estimates,
                   "std_estimate": std_estimates,
                  "school": school_names})


def p_prop(q):
    *t, mu, sigma = q
    p_mu = norm.pdf(mu, loc=8.75, scale=20)
    p_sigma = uniform.pdf(sigma, loc=0, scale=100)
    p_t_mu_sigma = norm.pdf(t, loc=mu, scale=sigma).prod()
    p_ee_t = norm.pdf(effect_estimates, loc=t, scale=std_estimates).prod()
    return p_ee_t * p_t_mu_sigma * p_mu * p_sigma

# Problem illustration: far away starting point
Up until now we've always started our Markov chains at (0, 0). But what happens if the distribution lies far away from that point? The chain then spends its first part just travelling towards the distribution, without producing useful draws.

In [3]:
def p_prop_start(q):
    x, y = q
    return norm.pdf(x, loc=25)*norm.pdf(y, loc=25)

bins = np.linspace(22, 28, 100)
sample = generate_sample(p_prop_start, bins, bins)
true_density = alt.Chart(sample).mark_point().encode(x="x0", y="x1", color="p")

transition = lambda current: transition_MH(current, p_prop_start)
df = simulate((0, 0), transition, n_iter=2000)
simulation = alt.Chart(df).mark_point().encode(x="x0", y="x1")

true_density + simulation

100%|██████████| 10000/10000 [00:02<00:00, 4663.33it/s]
0it [00:00, ?it/s]
100%|██████████| 1999/1999 [00:01<00:00, 1876.03it/s]


In [4]:
df = simulate((0, 0), transition, n_iter=2000)
alt.Chart(df).mark_point().encode(x="x0", y="x1")

0it [00:00, ?it/s]
100%|██████████| 1999/1999 [00:01<00:00, 1643.58it/s]


In [5]:
alt.Chart(df).mark_bar().encode(x=alt.X("x0", bin=alt.Bin(maxbins=50)), y="count()")

That doesn't look very good: the first part of the chain is mostly just walking towards the density, so those points aren't really draws from the probability distribution. We could of course run our Markov chain for longer to dilute the effect of the initial "walk-in" period. However, even with our Eight Schools model this already starts to take quite some time. A better idea is to do some warmup: instead of taking many more samples, we simply discard some at the beginning.

In [6]:
def simulate(initial, transition, n_iter=100, n_warmup=50):
    current = initial
    for _ in range(n_warmup):
        current = transition(current)
    
    result = []
    for _ in range(n_iter):
        current = transition(current)
        result.append(current)
    return result
    
simulate((10, 10), transition, n_iter=10, n_warmup=5)

[array([10.14809936, 10.1056201 ]),
 array([10.3425135 ,  9.91510501]),
 array([10.34463992, 10.06788861]),
 array([10.47276565,  9.89004561]),
 array([10.40585508,  9.93804762]),
 array([10.36900128,  9.96813257]),
 array([10.36900128,  9.96813257]),
 array([10.43219288, 10.10059745]),
 array([10.51833388, 10.21735884]),
 array([10.51833388, 10.21735884])]

In [7]:
def simulate(initial, transition, n_iter=100, n_warmup=0):
    # Warmup
    current = initial
    for _ in tqdm(range(n_warmup)):
        current = transition(current)
        
    # Simulation
    result = [current]
    for _ in tqdm(range(n_iter-1)):
        current = transition(current)
        result.append(current)
        
    # Bookkeeping
    n_dim = len(initial)
    result = pd.DataFrame(result,
                        columns=[f"x{i}" for i in range(n_dim)],
                         index=range(n_iter))
    result.index.name = "i"
    return result

df = simulate((0, 0), transition, n_iter=1000, n_warmup=1000)
simulation = alt.Chart(df).mark_point().encode(x="x0", y="x1")

true_density + simulation

100%|██████████| 1000/1000 [00:00<00:00, 1601.42it/s]
100%|██████████| 999/999 [00:00<00:00, 1504.87it/s]


That already looks a lot better. It's common to warm up for as many iterations as the actual simulation. This is a conservative choice, but it works well in practice.
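
To put a number on the effect of warmup, here is a minimal, self-contained sketch that compares the sample mean with and without warmup for the same far-away target. It does not use the `mcmc` module: `log_target`, `metropolis_step` and `run` are names made up for this illustration, the proposal scale of 1 is an assumption, and the density is handled in log space to avoid underflow far from the mode.

```python
import numpy as np

def log_target(q):
    # Unnormalised log-density of two independent N(25, 1) variables,
    # mirroring p_prop_start but in log space.
    return -0.5 * np.sum((np.asarray(q) - 25.0) ** 2)

def metropolis_step(current, rng, scale=1.0):
    # Random-walk Metropolis: symmetric normal proposal, accepted with
    # probability min(1, target(proposal) / target(current)).
    proposal = current + rng.normal(scale=scale, size=current.shape)
    log_ratio = log_target(proposal) - log_target(current)
    return proposal if np.log(rng.uniform()) < log_ratio else current

def run(initial, n_iter, n_warmup, seed=0):
    rng = np.random.default_rng(seed)
    current = np.asarray(initial, dtype=float)
    for _ in range(n_warmup):  # these draws are thrown away
        current = metropolis_step(current, rng)
    draws = np.empty((n_iter, current.size))
    for k in range(n_iter):
        current = metropolis_step(current, rng)
        draws[k] = current
    return draws

cold = run((0, 0), n_iter=2000, n_warmup=0)
warm = run((0, 0), n_iter=2000, n_warmup=2000)
print("mean without warmup:", cold.mean(axis=0))
print("mean with warmup:   ", warm.mean(axis=0))
```

The true mean is (25, 25). Without warmup, the walk-in draws pull the estimate towards the starting point; with warmup, they never enter the average.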

# Problem illustration: multiple modes
Having a warmup period doesn't solve all of our problems. If a distribution has multiple modes (that is, multiple "peaks"), then even with a correctly initialized Markov chain and warmup, we might not explore the true distribution very well. An example with two modes:

In [8]:
def p_prop_bimodal(q):
    x, y = q
    return norm.pdf(x)*(norm.pdf(y, loc=-3) + norm.pdf(y, loc=3))

bins = np.linspace(-6, 6, 150)
sample = generate_sample(p_prop_bimodal, bins, bins)
true_density = alt.Chart(sample).mark_point().encode(x="x0", y="x1", color="p")

true_density

100%|██████████| 22500/22500 [00:07<00:00, 2934.33it/s]


We have two peaks in this example: one centered on (0, 3) and one centered on (0, -3), as we would expect given the form of `p_prop_bimodal`. Now let's see what a Markov chain started at (0, -3) does after 10000 (!) iterations.

In [9]:
transition = lambda current: transition_MH(current, p_prop_bimodal, scale=0.05)
df = simulate((0, -3), transition, n_iter=10_000)
make_2d_histogram(df, bins, bins)

0it [00:00, ?it/s]
100%|██████████| 9999/9999 [00:07<00:00, 1340.50it/s]


Uh-oh, it only saw one peak! This is a problem because we cannot detect it from the Markov chain alone.

An obvious mitigating measure is to use multiple chains, starting from multiple points in the space. Using four chains, starting from (-4, -4), (-4, 4), (4, 4), and (4, -4), with 1000 warmup steps and 1000 retained steps each, we recover the correct distribution.

In [10]:
df = simulate_multiple([(-4, -4), (-4, 4), (4, 4), (4, -4)], transition, n_iter=1000, n_warmup=1000)
hist = make_2d_histogram(df, bins, bins)

hist

100%|██████████| 1000/1000 [00:00<00:00, 1175.03it/s]
100%|██████████| 999/999 [00:00<00:00, 1356.99it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1376.22it/s]
100%|██████████| 999/999 [00:00<00:00, 1423.46it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1376.43it/s]
100%|██████████| 999/999 [00:00<00:00, 1385.51it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1372.46it/s]
100%|██████████| 999/999 [00:00<00:00, 1331.06it/s]
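
`simulate_multiple` is imported from the accompanying `mcmc` module and isn't shown in this notebook. A minimal sketch of what such a helper could look like, assuming it runs one chain per starting point and stacks the results with a `simulation` label and an iteration counter `i` (the names `run_chain` and `run_chains` are made up for this illustration, and the real helper may differ):

```python
import pandas as pd

def run_chain(initial, transition, n_iter=100, n_warmup=0):
    # Same idea as `simulate` above, minus the progress bars.
    current = initial
    for _ in range(n_warmup):
        current = transition(current)
    rows = [current]
    for _ in range(n_iter - 1):
        current = transition(current)
        rows.append(current)
    columns = [f"x{d}" for d in range(len(initial))]
    return pd.DataFrame(rows, columns=columns)

def run_chains(initials, transition, n_iter=100, n_warmup=0):
    # Hypothetical stand-in for `simulate_multiple`: one chain per start,
    # labelled by a `simulation` column and an iteration counter `i`.
    frames = []
    for chain_id, initial in enumerate(initials):
        chain = run_chain(initial, transition, n_iter, n_warmup)
        chain.insert(0, "i", list(range(n_iter)))
        chain.insert(0, "simulation", chain_id)
        frames.append(chain)
    return pd.concat(frames, ignore_index=True)

# Toy deterministic transition, only to show the shape of the result.
halve = lambda q: tuple(v / 2 for v in q)
chains = run_chains([(-4, -4), (4, 4)], halve, n_iter=3)
print(chains)
```

Keeping all chains in one long DataFrame makes it easy to group by `simulation` later on.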


So how do we choose starting points? From our prior information, we typically already have a good idea of the *range* of parameters. We pick our starting values spread around, near the edges of the ranges. This ensures we cover most of the space.
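
One simple way to get such dispersed starting points, sketched here under the assumption that we have a rough (low, high) range for each parameter (`corner_starts` is a hypothetical helper, not part of the `mcmc` module):

```python
from itertools import product

def corner_starts(ranges):
    # One starting point per corner of the box spanned by `ranges`,
    # where `ranges` holds a (low, high) pair for each parameter.
    return list(product(*ranges))

starts = corner_starts([(-4, 4), (-4, 4)])
print(starts)  # [(-4, -4), (-4, 4), (4, -4), (4, 4)]
```

The number of corners doubles with every extra parameter, so for models with many parameters one would pick a handful of corners rather than all of them.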

# Problem illustration: bad configuration
But our problems don't end there. Even with multiple dispersed starting points, and *even if we start smack dab in the middle of the distribution*, taking steps that are too big (too large a `scale` parameter) means we will reject a great many proposals because they land too far outside the distribution.

In [11]:
def p_prop_peaked(q):
    x, y = q
    return norm.pdf(x, loc=0, scale=1) * norm.pdf(y, loc=0, scale=1)


bins = np.linspace(-3, 3, 150)

sample = generate_sample(p_prop_peaked, bins, bins)
true_density = alt.Chart(sample).mark_point().encode(x="x0", y="x1", color="p")

true_density

100%|██████████| 22500/22500 [00:04<00:00, 4846.09it/s]


In [12]:
def transition(current, p_prop, scale=1):
    proposal = norm.rvs(loc=current, scale=scale)
    u = uniform.rvs()
    return proposal if p_prop(proposal) / p_prop(current) >= u else current

In [13]:
def make_2d_histogram(df, x_bins, y_bins):
    """
    Create a 2d histogram from a simulation dataframe
    
    :param df: DataFrame with columns x0, x1, encoding points
    :param x_bins: iterable of bin edges for the x-axis
    :param y_bins: iterable of bin edges for the y-axis
    :return: an Altair chart with a histogram of the points
    """
    n_x, n_y = len(x_bins)-1, len(y_bins)-1
    H, _, _ = np.histogram2d(df.x0, df.x1, bins=[x_bins, y_bins])
    binned = pd.DataFrame([(x_bins[i], y_bins[j], H[i, j]) for i in range(n_x) for j in range(n_y)],
                  columns=("x0", "x1", "p"))
    binned.p = binned.p / binned.p.sum()
    return alt.Chart(binned).mark_point().encode(x="x0", y="x1", color="p", tooltip=["x0", "x1"])

In [14]:
transition = lambda current: transition_MH(current, p_prop_peaked, scale=10)
df = simulate((-0.5, 0), transition, n_iter=1000)
hist = make_2d_histogram(df, bins, bins)
hist

0it [00:00, ?it/s]
100%|██████████| 999/999 [00:00<00:00, 1994.99it/s]


We see that our histogram is very sparse. Indeed, we don't have that many distinct points.

In [15]:
len(df[["x0", "x1"]].drop_duplicates())

31
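
The count of distinct points is closely related to the acceptance rate: 31 distinct points in 1000 draws means only about 30 of the 999 proposals were accepted, roughly 3%. A small sketch that computes this directly from a chain, assuming the draws are given as an array with one row per iteration (`acceptance_rate` is a name made up for this illustration):

```python
import numpy as np

def acceptance_rate(draws):
    # A rejected proposal repeats the previous draw exactly, so the share of
    # consecutive draws that differ estimates the acceptance rate.
    draws = np.asarray(draws, dtype=float)
    moved = np.any(draws[1:] != draws[:-1], axis=1)
    return moved.mean()

toy = [(0.0, 0.0), (0.0, 0.0), (0.5, -0.2), (0.5, -0.2), (0.5, -0.2)]
print(acceptance_rate(toy))  # 0.25
```

For the chain above this could be called as `acceptance_rate(df[["x0", "x1"]].to_numpy())`.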

The solution is obvious: make your `scale` parameter smaller. But how do you diagnose the problem in the first place? One obvious diagnostic is the trace plot, which traces out how a parameter evolves over the iterations. In the trace plot for `x0` in our current example, we see long flat stretches, indicating that our chain isn't exploring the space very well. This is a good reason to set the scale parameter a bit smaller.

In [16]:
alt.Chart(df.reset_index()).mark_line().encode(x="i", y="x0")

In [17]:
alt.Chart(df.reset_index()).mark_line().encode(x="i", y="x0") | alt.Chart(df.reset_index()).mark_line().encode(x="i", y="x1")

The histogram looks a lot better with a smaller scale parameter.

In [18]:
transition = lambda current: transition_MH(current, p_prop_peaked, scale=1)
df = simulate((0, 0), transition, n_iter=1000)
hist = make_2d_histogram(df, bins, bins)
hist

0it [00:00, ?it/s]
100%|██████████| 999/999 [00:00<00:00, 1778.11it/s]


And the individual parameter plots look like "hairy caterpillars". They seem to be exploring their space well, having developed some sort of stable distribution. We call this property "stationarity".

In [19]:
alt.Chart(df.reset_index()).mark_line().encode(x="i", y="x0") | alt.Chart(df.reset_index()).mark_line().encode(x="i", y="x1")

However, the opposite problem may happen as well. We might be taking steps that are too small (setting `scale` too low). In this case we don't even reach the real distribution.

In [20]:
transition = lambda current: transition_MH(current, p_prop_peaked, scale=0.01)
df = simulate_multiple([(-3, -3), (-3, 3), (3, 3), (3, -3)], transition, n_iter=1000, n_warmup=1000)
hist = make_2d_histogram(df, bins, bins)
hist

100%|██████████| 1000/1000 [00:00<00:00, 1614.58it/s]
100%|██████████| 999/999 [00:00<00:00, 1854.81it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1790.78it/s]
100%|██████████| 999/999 [00:00<00:00, 1866.50it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1820.93it/s]
100%|██████████| 999/999 [00:00<00:00, 1999.77it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1759.00it/s]
100%|██████████| 999/999 [00:00<00:00, 1878.40it/s]


Note that the individual trace plots look okay. A bit low in variation, maybe, but they *are* walking around. However, if we compare them together, we see that they really haven't *mixed*. Each chain appears to be stuck in its own little world. This could be a reason to set the scale parameter a bit higher.

In [21]:
alt.Chart(df[df.simulation==0]).mark_line().encode(x="i", y="x1").interactive()

In [22]:
alt.Chart(df).mark_line().encode(x="i", y="x0", color="simulation:N") | alt.Chart(df).mark_line().encode(x="i", y="x1", color="simulation:N")

The histogram, again, looks a lot better.

In [23]:
transition = lambda current: transition_MH(current, p_prop_peaked, scale=1)
df = simulate_multiple([(-3, -3), (-3, 3), (3, 3), (3, -3)], transition, n_iter=1000, n_warmup=1000)
hist = make_2d_histogram(df, bins, bins)
hist

100%|██████████| 1000/1000 [00:00<00:00, 1711.35it/s]
100%|██████████| 999/999 [00:00<00:00, 1939.46it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1853.70it/s]
100%|██████████| 999/999 [00:00<00:00, 1934.27it/s]
100%|██████████| 1000/1000 [00:00<00:00, 2034.15it/s]
100%|██████████| 999/999 [00:00<00:00, 1830.62it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1960.98it/s]
100%|██████████| 999/999 [00:00<00:00, 1650.19it/s]


And look at the trace plots! They are pretty much interchangeable, in spite of having wildly different starting points.

In [24]:
alt.Chart(df).mark_line().encode(x="i", y="x0").facet(facet="simulation:N", columns=2)

# A systematic monitoring of convergence

Looking at trace plots is all well and good, but how hairy should a caterpillar be? And how similar should mixed trace plots be? A systematic way of looking at this is through variance. If we split a chain in two and the two halves have similar means and variances, its distribution is stable. Likewise, if chains with different starting points have similar means and variances, they have mixed. Both checks boil down to comparing the variance *within* chains to the variance *between* chains.

To implement this, let's take our bimodal example and run 4 chains. Each chain we will split in two, so we get 8 "pseudo-chains".

In [25]:
bins = np.linspace(-6, 6, 150)

transition = lambda current: transition_MH(current, p_prop_bimodal, scale=0.5)
df = simulate_multiple([(-4, -4), (-4, 4), (4, 4), (4, -4)], transition, n_iter=1000, n_warmup=1000)
hist = make_2d_histogram(df, bins, bins)

hist

100%|██████████| 1000/1000 [00:00<00:00, 1327.58it/s]
100%|██████████| 999/999 [00:00<00:00, 1180.08it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1220.09it/s]
100%|██████████| 999/999 [00:00<00:00, 1260.69it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1346.56it/s]
100%|██████████| 999/999 [00:00<00:00, 1357.99it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1450.81it/s]
100%|██████████| 999/999 [00:00<00:00, 1381.01it/s]


In [26]:
def split_chain(chain):
    n_iter = chain.i.max()+1
    first_half = chain[chain.i < n_iter / 2]
    # Copy the slice so the assignments below don't write into a view of `chain`
    second_half = chain[chain.i >= n_iter / 2].copy()
    
    # Update chain number
    second_half.simulation = second_half.simulation + 0.5
    second_half.i = second_half.i - int(n_iter / 2)
    
    return pd.concat([first_half, second_half])
    
df = df.groupby("simulation").apply(split_chain).reset_index(drop=True)
df.sample(10)

| | simulation | i | x0 | x1 |
|---|---|---|---|---|
| 1736 | 1.5 | 236 | 0.933915 | -3.274473 |
| 806 | 0.5 | 306 | 0.677692 | -2.330731 |
| 3161 | 3.0 | 161 | -0.386769 | -3.353522 |
| 1761 | 1.5 | 261 | -1.568762 | -2.479262 |
| 1945 | 1.5 | 445 | -0.951129 | -3.600105 |
| 3248 | 3.0 | 248 | 0.703267 | -3.201291 |
| 2084 | 2.0 | 84 | -0.536146 | -1.436394 |
| 2368 | 2.0 | 368 | -0.387418 | -1.797777 |
| 3372 | 3.0 | 372 | -0.049361 | -3.587052 |
| 2855 | 2.5 | 355 | 0.843897 | -3.111331 |


We calculate the within-chain variance `W` by simply averaging the variances of each of the chains.

$$\mathrm{var}(x) = \frac{1}{n-1} \sum_{i=1}^{n} (x_i - \bar{x})^2$$

In [27]:
parameter = "x1"

chain = df[df.simulation==3.5]

def variance(chain):
    n = len(chain)
    return 1 / (n - 1) * ((chain - chain.mean())**2).sum()
    
def W(df, parameter):
    return df.groupby("simulation")[parameter].apply(variance).mean()

W(df, parameter)

0.741995211362141

We can think of the between-chain variance `B` as the variance we would obtain if every chain was stuck at one value, its own mean. `B` is then the variance of the means, multiplied by the number of iterations per chain.

In [28]:
def B(df, parameter):
    n_iter = df.i.max()
    mean_chains = df.groupby("simulation")[parameter].mean()
    return variance(mean_chains) * n_iter

B(df, parameter)

24.645279080700863

To estimate the variance of the total distribution, `varhat`, we take a weighted average of `W` and `B`. On its own, `W` underestimates the total variance as long as the chains have not yet explored the whole distribution. The `B` term corrects for this, and its weight shrinks as the number of iterations grows.

In [29]:
def varhat(df, parameter):
    n_iter = df.i.max()
    return ((n_iter - 1) * W(df, parameter) + B(df, parameter)) / n_iter

varhat(df, parameter), W(df, parameter)

(0.7898975838457857, 0.741995211362141)
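
Written out, the three quantities so far look as follows. The notation is assumed here, following the usual presentation in Bayesian Data Analysis: there are $m$ chains of $n$ draws each, $x_{ij}$ is the $i$-th draw of chain $j$, $\bar{x}_{\cdot j}$ is the mean of chain $j$, and $\bar{x}_{\cdot\cdot}$ is the mean of the chain means.

$$W = \frac{1}{m} \sum_{j=1}^{m} s_j^2, \qquad s_j^2 = \frac{1}{n-1} \sum_{i=1}^{n} (x_{ij} - \bar{x}_{\cdot j})^2$$

$$B = \frac{n}{m-1} \sum_{j=1}^{m} (\bar{x}_{\cdot j} - \bar{x}_{\cdot\cdot})^2$$

$$\widehat{\mathrm{var}}^{+}(x) = \frac{n-1}{n} W + \frac{1}{n} B$$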

But estimating the total variance of the distribution we're sampling from doesn't tell us enough: it might be anything, really. Instead, what we look at is the square root of the ratio between the estimated total variance and the within-chain variance. This "standardizes" the variance and gives us a dimensionless quantity to monitor. We call it `Rhat` (or $\widehat{R}$).

In [30]:
def Rhat(df, parameter):
    return np.sqrt(varhat(df, parameter) / W(df, parameter))

Rhat(df, parameter)

1.0317746217682062

This isn't a very efficient implementation. Let's improve it a tad. 

In [31]:
def Rhat(df, parameter):
    n_iter = df.i.max()
    n_chains = df.simulation.nunique()
    
    chains = df.groupby("simulation")[parameter]
    
    W = chains.apply(lambda c: ((c - c.mean())**2).sum() / (n_iter - 1)).mean()
    
    chain_means = chains.mean()
    B = chain_means.apply(lambda c_mean: (c_mean - chain_means.mean())**2).sum() / (n_chains - 1)
    
    return np.sqrt((n_iter - 1)/n_iter + B / W)
    
Rhat(df, "x1")

1.0317099774566303
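
For comparison, here is a minimal NumPy sketch of the same split-chain idea that works on a plain array instead of a DataFrame. It assumes the draws for one parameter are stored with shape `(n_chains, n_draws)`; `split_rhat` is a name made up for this illustration, and the synthetic `mixed` and `stuck` arrays are not output of the sampler above.

```python
import numpy as np

def split_rhat(draws):
    # `draws` has shape (n_chains, n_draws); each chain is cut in half,
    # giving 2 * n_chains pseudo-chains of equal length n.
    draws = np.asarray(draws, dtype=float)
    n = draws.shape[1] // 2
    halves = np.concatenate([draws[:, :n], draws[:, n:2 * n]], axis=0)

    W = halves.var(axis=1, ddof=1).mean()      # within-chain variance
    B = n * halves.mean(axis=1).var(ddof=1)    # between-chain variance
    var_plus = (n - 1) / n * W + B / n         # pooled variance estimate
    return np.sqrt(var_plus / W)

rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))                         # four chains, same target
stuck = mixed + np.array([[-3.0], [-3.0], [3.0], [3.0]])   # chains in separate modes
print(split_rhat(mixed), split_rhat(stuck))
```

Chains that sample the same distribution give a value very close to 1, while chains stuck around different modes give a value well above 1.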

Of course we can plot this metric over time. As the chains come closer together, `Rhat` approaches 1, which suggests that the chains agree with each other and that each is exploring the distribution adequately.

In [32]:
def convergence(df, params):
    a_Rhat = [[i] + [Rhat(df[df.i <= i], p) for p in params] for i in tqdm(range(1, df.i.max()+1))]
    return pd.DataFrame(a_Rhat, columns=["i"] + [f"Rhat_{p}" for p in params])

df_Rhat = convergence(df, ["x0", "x1"])
alt.Chart(df_Rhat).mark_line().encode(x="i", y="Rhat_x0") | alt.Chart(df_Rhat).mark_line().encode(x="i", y="Rhat_x1")

100%|██████████| 499/499 [00:06<00:00, 76.17it/s]


Plotting `Rhat` for each parameter already gives us some more information! It appears `x0` is explored well, but `x1` is not. Let's draw some more samples.

In [33]:
bins = np.linspace(-6, 6, 150)

transition = lambda current: transition_MH(current, p_prop_bimodal, scale=0.5)
df = simulate_multiple([(-4, -4), (-4, 4), (4, 4), (4, -4)], transition, n_iter=5000, n_warmup=1000)
a_Rhat = [[i] + [Rhat(df[df.i <= i], p) for p in ["x0", "x1"]] for i in tqdm(range(1, df.i.max()+1))]
df_Rhat = pd.DataFrame(a_Rhat, columns=["i", "Rhat_x0", "Rhat_x1"])

alt.Chart(df_Rhat).mark_line().encode(x="i", y="Rhat_x0") | alt.Chart(df_Rhat).mark_line().encode(x="i", y="Rhat_x1")

100%|██████████| 1000/1000 [00:00<00:00, 1325.83it/s]
100%|██████████| 4999/4999 [00:03<00:00, 1370.48it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1380.16it/s]
100%|██████████| 4999/4999 [00:03<00:00, 1379.65it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1358.43it/s]
100%|██████████| 4999/4999 [00:03<00:00, 1335.44it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1442.58it/s]
100%|██████████| 4999/4999 [00:03<00:00, 1379.26it/s]
100%|██████████| 4999/4999 [00:49<00:00, 101.22it/s]


In [34]:
hist = make_2d_histogram(df, bins, bins)
hist

# Summary
We have looked at several potential convergence problems with the Metropolis-Hastings algorithm.

- Starting too far away
- Multiple modes in the distribution
- Bad configuration: too big steps or too small steps

We also developed some mitigating measures for these problems:

- Warmup: discarding the first half of the samples
- Multiple chains: starting at various points in the sample space
- Rhat: monitoring convergence in a systematic way

In the final notebook, we will apply these things to our Eight Schools problem.