In [None]:
import wandb
from tqdm.auto import tqdm
import pandas as pd

def get_summary_metrics(sweep_id, filter_func=None, timeout=60):
    """Collect the config and summary of every run in a W&B sweep into a DataFrame.

    Args:
        sweep_id: Sweep path passed to ``wandb.Api().sweep``.
        filter_func: Optional predicate ``run -> bool``. Runs for which it
            returns a falsy value are skipped. It is ignored if not callable.
        timeout: HTTP timeout in seconds for the W&B public API client.

    Returns:
        ``(sweep, df)``: the sweep object, and a DataFrame with one row per
        kept run. ``run_id`` is the first column, followed by the run's config
        keys and then its summary keys. When a key appears in both config and
        summary, the summary value wins. ``run_id`` always holds ``run.id``.
    """
    api = wandb.Api(timeout=timeout)
    sweep = api.sweep(sweep_id)

    rows = []
    for run in tqdm(sweep.runs, desc='Runs', leave=False):
        if callable(filter_func) and not filter_func(run):
            continue

        # Use a dict literal rather than ``dict(**a, **b)``. The call form
        # raises TypeError when config and summary share a key; the literal
        # lets later mappings override earlier ones.
        row = {'run_id': run.id, **run.config, **run.summary}
        # Re-assert the ID in case config/summary carried a 'run_id' key.
        # Re-assigning an existing key keeps it as the first column.
        row['run_id'] = run.id
        rows.append(row)

    return sweep, pd.DataFrame(rows)

In [None]:
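# Regenerate the per-run metrics from the W&B sweep (needs network access).
# NOTE(review): the sweep path below appears to be missing its sweep ID; the
# W&B API expects 'entity/project/sweep_id'.
# _, results = get_summary_metrics('deeplearn/fspace-inference/')
# metrics = results[['run_id', 'seed', 'dataset', 'context_size', 's/test/acc', 's/test/sel_acc', 's/test/avg_nll', 's/test/ece']].copy()  # .copy() avoids SettingWithCopyWarning on the in-place scaling below
# metrics['s/test/acc'] *= 100  # rescale (presumably fraction -> percent)
# metrics['s/test/ece'] *= 100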

In [None]:
# Stack the cached per-run results (FashionMNIST and CIFAR-10, judging by the
# filenames) into one frame with a fresh 0..N-1 index.
metrics = pd.concat(
    [pd.read_csv(csv_path) for csv_path in ('results/fmnist_lmap_ctx_size.csv',
                                             'results/c10_lmap_ctx_size.csv')],
    ignore_index=True,
)

# Summary table (currently disabled): mean and std over runs (presumably seeds)
# for each (dataset, context_size) pair, rounded per metric and formatted as
# LaTeX "$mean \pm std$" cells.
# mu = metrics.groupby(['dataset', 'context_size']).mean(numeric_only=True).drop(columns=['seed'])
# mu['s/test/acc'] = mu['s/test/acc'].round(1)
# mu['s/test/sel_acc'] = mu['s/test/sel_acc'].round(1)
# mu['s/test/avg_nll'] = mu['s/test/avg_nll'].round(2)
# mu['s/test/ece'] = mu['s/test/ece'].round(1)

# sigma = metrics.groupby(['dataset', 'context_size']).std(numeric_only=True).drop(columns=['seed'])
# sigma['s/test/acc'] = sigma['s/test/acc'].round(1)
# sigma['s/test/sel_acc'] = sigma['s/test/sel_acc'].round(1)
# sigma['s/test/avg_nll'] = sigma['s/test/avg_nll'].round(2)
# sigma['s/test/ece'] = sigma['s/test/ece'].round(1)

# Raw string for ' \pm ': '\p' is an invalid escape (SyntaxWarning on Python 3.12+).
# ('$' + mu.astype(str) + r' \pm ' + sigma.astype(str) + '$')#.reset_index().to_latex('tmp.txt', index=False)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# `sns.set` is a legacy alias; `set_theme` is the documented interface.
sns.set_theme(font_scale=2., style='whitegrid')

fig, ax = plt.subplots(figsize=(5, 5))

# Size the palette from the data so the colour count always matches the
# number of hue levels, instead of hard-coding 2.
n_datasets = metrics['dataset'].nunique()
sns.lineplot(data=metrics, ax=ax, x='context_size', y='s/test/ece',
             hue='dataset', marker='o', markersize=10,
             palette=sns.color_palette('Set1', n_datasets))

# Display names for the legend. Labels not in the map fall back to the raw
# label instead of raising KeyError. That covers unexpected dataset values, or
# a hue-title entry that some seaborn versions may add to the legend.
labelmap = {'fmnist': 'FashionMNIST', 'cifar10': 'CIFAR-10'}

handles, labels = ax.get_legend_handles_labels()
labels = [labelmap.get(label, label) for label in labels]
for handle in handles:
    handle.set(linewidth=2)
ax.legend(handles=handles, labels=labels, fontsize=20, title='', loc='center right')

ax.set(xscale='log', xlabel=r'# Samples ($S$)', ylabel='Calibration Error')

# Save before showing; `plt.show()` rather than `fig.show()`, which can warn
# under Jupyter's non-GUI inline backend.
# fig.savefig('ece_ctx_size.pdf', bbox_inches='tight')
plt.show()