In [None]:
from math import comb

import numpy as np
import torch

import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt

def average_likelihood(p, q, n=10):
    """Expected Binomial(n, q) likelihood of a head count drawn from Binomial(n, p).

    Computes ``sum_k Binom(k; n, p) * Binom(k; n, q)`` for k = 0..n, i.e. the
    probability that coin ``q`` assigns to the number of heads, averaged over
    head counts produced by coin ``p``.

    The sum is evaluated in float64 log space. Evaluating the pmf terms directly
    overflows when converting ``comb(n, k)`` to float (n >= 1030). It also
    underflows the ``p**k * (1-p)**(n-k)`` factor for large n.

    Args:
        p: Heads probability of the data-generating coin, in [0, 1].
        q: Heads probability of the hypothesis coin, in [0, 1].
        n: Number of flips (non-negative int).

    Returns:
        0-d float64 torch tensor with the expected likelihood (use ``.item()``
        for a Python float).
    """
    k = torch.arange(n + 1, dtype=torch.float64)
    n_flips = torch.tensor(float(n), dtype=torch.float64)
    # log C(n, k) via lgamma: the exact big-int comb() cannot be turned into a float for large n.
    log_comb = torch.lgamma(n_flips + 1) - torch.lgamma(k + 1) - torch.lgamma(n_flips - k + 1)

    def log_pmf(prob):
        prob = torch.as_tensor(prob, dtype=torch.float64)
        # xlogy(0, 0) == 0, matching 0**0 == 1, so prob in {0, 1} is handled exactly.
        return log_comb + torch.xlogy(k, prob) + torch.xlogy(n_flips - k, 1 - prob)

    # log of the dot product of the two pmfs, then back to linear scale.
    return torch.logsumexp(log_pmf(p) + log_pmf(q), dim=0).exp()

In [None]:
p = .5
qs = [.3,.6]
ns = [1,2,4,8,16,32,64,128,256,512,1024]

# Expected likelihood of data from the true coin `p` under each hypothesis coin `q`.
records = []
for q in qs:
    for n in ns:
        avg_l = average_likelihood(p, q, n=n).item()
        records.append({'q': q, 'n': n, 'avg_l': avg_l})
likelihoods = pd.DataFrame(records)

# Matrix of shape (len(ns), len(qs)): rows follow `ns` (the plot x-axis), columns follow `qs`.
# Reindex explicitly -- set_index/unstack would sort both axes and silently misalign
# with `ns` and `np.array(qs)` whenever either list is not already sorted.
likelihoods_per_n = (
    likelihoods
    .pivot(index='n', columns='q', values='avg_l')
    .reindex(index=ns, columns=qs)
    .to_numpy()
)

# Posterior over the hypothesis coins (uniform prior), then posterior-predictive P(heads).
# NOTE(review): hypotheses are weighted by the *expected* likelihood over data drawn from p,
# not by the likelihood of a single sampled sequence of flips.
posterior = likelihoods_per_n / likelihoods_per_n.sum(axis=1, keepdims=True)
assert np.isfinite(posterior).all(), "likelihoods underflowed to 0; posterior is undefined"
ppd = posterior @ np.array(qs)

In [None]:
palette = sns.color_palette("muted")
sns.set_palette(palette)

fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: posterior-predictive P(heads) vs. number of flips, with the true coin and
# the candidate coins the model can choose from (drawn from `qs`, not hard-coded).
axs[0].plot(ns, ppd, color=palette[1], marker='o', label='Bayesian Prediction')
axs[0].axhline(y=p, color=palette[0], linestyle='--', label='True Probability')
for q in qs:
    axs[0].axhline(y=q, color='gray', linestyle='--')
axs[0].set_xlabel("Number of In-Context Examples (Coin Flips)")
axs[0].set_ylabel("Predicted Probability of Heads")
axs[0].legend()

# Right panel: expected likelihood under each hypothesis coin (log y-scale).
# Column i of `likelihoods_per_n` corresponds to qs[i].
for col, (q, color) in enumerate(zip(qs, palette[2:])):
    axs[1].plot(ns, likelihoods_per_n[:, col], color=color, marker='o', label=f'Coin with p={q}')
axs[1].set_yscale('log')
axs[1].set_ylabel('Likelihood of Observed Data')
axs[1].set_xlabel("Number of In-Context Examples (Coin Flips)")
axs[1].legend()

# Optional export (must run before plt.show()):
# utils.tikzplotlib_save(f'figures/bad_prior_coin_flip_failure.tex', axis_width=r".4\textwidth", axis_height=r".4\textwidth")
plt.show()