In [None]:
!pip install -q datasets

In [None]:
from scipy.stats import dirichlet
from tqdm.notebook import tqdm
from datasets import load_dataset
import numpy as np
import gc


def bootstrap_CI(p, alpha=0.05, k=2000):
  """
    Computes a percentile-bootstrap confidence interval of the mean of `p`.

    `p` is resampled with replacement `k` times (each resample has len(p)
    elements), the mean of every resample is taken, and the central
    100*(1-alpha)% interval of those means is returned, i.e. from percentile
    100*(alpha/2) to 100*(1-alpha/2). The percentile positions are rounded
    outwards (floor for the lower bound, ceil for the upper bound) so that the
    broadest interval is picked, and are then clamped to the valid index range.
    Line Clemmensen suggests picking k (number of repeats) to 1000 or 2000 for
    this task, hence the default.

    Args:
      p:     1-D np.ndarray with all values in [0, 1].
      alpha: Significance level; 0.05 gives a 95% interval.
      k:     Number of bootstrap resamples (must be >= 1).

    Returns:
      (CI, N): CI is an np.ndarray [lower, upper]; N is len(p) (the support).

    Raises:
      AssertionError: if `p` is not a 1-D ndarray with values in [0, 1], or k < 1.

    Note: sampling uses NumPy's global random state (np.random.choice), and the
    resample matrix has shape (k, len(p)), so memory grows as k * len(p).
  """
  import numpy as np
  assert isinstance(p, np.ndarray)
  assert p.ndim == 1
  assert ((0 <= p) & (p <= 1)).all()
  assert k >= 1
  N = len(p)
  bootstraps = np.random.choice(p, (k,N), replace=True)
  ci_lower = alpha/2.
  ci_upper = 1.-(alpha/2.)
  # ceil(k*ci_upper) can reach k (e.g. k=10, alpha=0.05 -> ceil(9.75) == 10), which is
  # one past the last sorted mean; clamp both positions to [0, k-1] to stay in range.
  idxs = np.clip(
    [int(np.floor(k*ci_lower)), int(np.ceil(k*ci_upper))],
    0,
    k-1
  )
  CI = np.sort(np.mean(bootstraps, axis=-1))[idxs]     # Sorts lowest to highest
  # <= rather than <: a degenerate interval is valid when every resample mean coincides.
  assert CI[0] <= CI[1]  # To be on the safe side...
  return CI, N    # Returns CI and support (N)


def posterior_mean(alpha_, df, n_samples):
  """
    Monte-Carlo estimate of the probability mass a symmetric Dirichlet puts on
    the correct answers, averaged over the rows of `df`.

    For every row, `n_samples` vectors are drawn from a Dirichlet with
    `n_choices` components, all with concentration `alpha_`. The mass on the
    first `n_correct` components is divided by the total mass, giving one value
    per draw (the Dirichlet is symmetric, so which components are summed does
    not change the distribution).

    Args:
      alpha_:    Concentration shared by all Dirichlet components.
      df:        DataFrame with integer columns 'n_choices' and 'n_correct'.
      n_samples: Number of Dirichlet draws per row.

    Returns:
      (mean, CI): `mean` is the average of the per-row means; `CI` is whatever
      bootstrap_CI returns for the pooled draws of all rows, using k=1000.
  """
  per_row_draws = []
  per_row_means = []
  rows = tqdm(df.iterrows(), total=len(df), desc="Looping over datapoints", leave=False)
  for _, record in rows:
    k_total = record['n_choices']
    k_true = record['n_correct']     # Invariant: k_true < k_total
    concentration = alpha_ * np.ones((k_total,))
    draws = dirichlet.rvs(alpha=concentration, size=n_samples)     # [n_samples, k_total]
    mass_true = draws[:, :k_true].sum(axis=-1)
    mass_false = draws[:, k_true:].sum(axis=-1)
    # Sanity checks: every draw is a proper probability vector and nothing degenerated to NaN.
    assert np.isclose(draws.sum(axis=-1), 1.).all()
    assert not np.isnan(mass_true).any()
    assert not np.isnan(mass_false).any()
    assert mass_true.shape == (n_samples,)
    assert mass_false.shape == (n_samples,)
    share_true = mass_true / (mass_true + mass_false)
    assert share_true.shape == (n_samples,)
    per_row_draws.append(share_true)
    row_mean = np.mean(share_true, axis=0)
    per_row_means.append(row_mean)
    assert isinstance(row_mean, float)

  # The bootstrap runs over all draws of all rows flattened into one vector.
  CI = bootstrap_CI(np.array(per_row_draws).reshape(-1), k=1000)
  mean = np.mean(np.array(per_row_means))
  # Drop the per-row buffers explicitly before returning.
  del per_row_draws
  del per_row_means
  gc.collect()
  return mean, CI

# Load the TruthfulQA multiple-choice validation split as a DataFrame.
df = load_dataset('truthful_qa', 'multiple_choice', split='validation').to_pandas()
# Per question: number of mc2 answer options, and the sum of their labels
# (assumes labels are 0/1 with 1 = correct -- TODO confirm against the dataset card).
df['n_choices'] = df['mc2_targets'].map(lambda targets: len(targets['labels']))
df['n_correct'] = df['mc2_targets'].map(lambda targets: sum(targets['labels']))
# Sweep 50 log-spaced concentrations from 1e-2 to 1e2 and keep the posterior mean of each.
alphas = np.logspace(-2,2,50)
posterior_baseline = []
CIs = []
pbar = tqdm(alphas, desc="Alphas search", leave=False)
for alpha in pbar:
  pbar.set_description(f"Alpha={alpha:.2E}")
  # Only the mean is needed during the search; the CI is discarded here.
  p_mean, _ = posterior_mean(alpha_=alpha, df=df, n_samples=500)
  posterior_baseline.append(p_mean)
# Pick the concentration with the highest posterior mean.
best_ix = np.argmax(np.array(posterior_baseline))
alpha_opt = alphas[best_ix]
posterior_opt = posterior_baseline[best_ix]
print(f"Optimal alpha = {alpha_opt:.2e} with posterior mean {posterior_opt:.2f}.")

Alphas search:   0%|          | 0/50 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

Optimal alpha = 6.55e-02 with posterior mean 0.45.


In [None]:
# Fresh run at the selected alpha; this time only the bootstrap CI is kept and the mean is discarded.
_, CI = posterior_mean(df=df, n_samples=500, alpha_=alpha_opt)
print(CI)

Looping over datapoints:   0%|          | 0/817 [00:00<?, ?it/s]

(array([0.44731453, 0.44993357]), 408500)
