In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import seaborn.objects as so

# Global seaborn theme for every figure below: white background with grid
# lines and fonts scaled 1.5x. set_theme is the documented entry point;
# sns.set is only a legacy alias for it.
sns.set_theme(style='whitegrid', font_scale=1.5)

In [None]:
import wandb
from tqdm.auto import tqdm
import pandas as pd

def get_summary_metrics(sweep_id, filter_func=None):
  """Collect config and summary values for every run in a W&B sweep.

  Args:
    sweep_id: Sweep path passed to ``wandb.Api().sweep``, e.g.
      ``'deeplearn/fspace-inference/chefbo1z'``.
    filter_func: Optional predicate called with each run object. Runs for
      which it returns a falsy value are skipped. Ignored if not callable.

  Returns:
    A ``(sweep, df)`` tuple: the sweep object and a DataFrame with one row
    per kept run. Columns are ``run_id`` followed by the run's config keys
    and then its summary keys. When a key appears in both config and
    summary, the summary value wins. ``run_id`` always holds ``run.id``.
  """
  api = wandb.Api(timeout=60)
  sweep = api.sweep(sweep_id)

  rows = []
  for run in tqdm(sweep.runs, desc='Runs', leave=False):
    if callable(filter_func) and not filter_func(run):
      continue

    # Merge with dict.update rather than `dict(..., **config, **summary)`.
    # Keyword unpacking raises TypeError on any key shared by config and
    # summary (or a config key named 'run_id'), which would abort the
    # whole collection.
    row = {'run_id': run.id}
    row.update(run.config)
    row.update(run.summary)
    row['run_id'] = run.id  # keep the real run id; position stays first
    rows.append(row)

  return sweep, pd.DataFrame(rows)

In [None]:
# Load the 'fs' sweep and keep only the test metrics per corruption config.
_, fs_raw = get_summary_metrics('deeplearn/fspace-inference/chefbo1z')

# .copy() detaches the column subset from fs_raw, so the column assignments
# below write to an independent frame (no SettingWithCopyWarning).
fs_metrics = fs_raw[['run_id', 'ckpt_path', 'corr_config', 's/test/acc', 's/test/sel_acc', 's/test/ece', 's/test/avg_nll']].copy()

# corr_config has the form '<corruption>_<level>'. The corruption name may
# itself contain underscores, so split only on the last one.
fs_corr_parts = fs_metrics['corr_config'].str.rsplit('_', n=1)
fs_metrics['level'] = fs_corr_parts.str[-1].astype(int)
fs_metrics['corruption'] = fs_corr_parts.str[0]

# Integer seed id per distinct checkpoint, in order of first appearance.
# pd.factorize keeps codes and uniques aligned. Pairing unique() with
# range(nunique()) does not: nunique() ignores NaN but unique() keeps it,
# so one checkpoint would be silently dropped from the mapping.
fs_seed_codes, fs_seed_ckpts = pd.factorize(fs_metrics['ckpt_path'])
ckpt_to_seed_id = dict(zip(fs_seed_ckpts, range(len(fs_seed_ckpts))))
fs_metrics['seed_id'] = fs_seed_codes
fs_metrics['mode'] = 'fs'

In [None]:
# Load the 'ps' sweep and keep only the test metrics per corruption config.
_, ps_raw = get_summary_metrics('deeplearn/fspace-inference/eoi1wbz4')

# .copy() detaches the column subset from ps_raw, so the column assignments
# below write to an independent frame (no SettingWithCopyWarning).
ps_metrics = ps_raw[['run_id', 'ckpt_path', 'corr_config', 's/test/acc', 's/test/sel_acc', 's/test/ece', 's/test/avg_nll']].copy()

# corr_config has the form '<corruption>_<level>'. The corruption name may
# itself contain underscores, so split only on the last one.
ps_corr_parts = ps_metrics['corr_config'].str.rsplit('_', n=1)
ps_metrics['level'] = ps_corr_parts.str[-1].astype(int)
ps_metrics['corruption'] = ps_corr_parts.str[0]

# Integer seed id per distinct checkpoint, in order of first appearance.
# pd.factorize keeps codes and uniques aligned. Pairing unique() with
# range(nunique()) does not: nunique() ignores NaN but unique() keeps it,
# so one checkpoint would be silently dropped from the mapping.
ps_seed_codes, ps_seed_ckpts = pd.factorize(ps_metrics['ckpt_path'])
ckpt_to_seed_id = dict(zip(ps_seed_ckpts, range(len(ps_seed_ckpts))))
ps_metrics['seed_id'] = ps_seed_codes
ps_metrics['mode'] = 'ps'

In [None]:
# Both frames must share a schema, or concat silently pads with NaN columns.
assert fs_metrics.columns.equals(ps_metrics.columns), 'fs/ps column mismatch'

# Stack both sweeps into one long-form frame; the 'mode' column tells them
# apart. ignore_index=True builds a fresh 0..n-1 index instead of two
# overlapping RangeIndexes, so row labels stay unique.
metrics = pd.concat([fs_metrics, ps_metrics], ignore_index=True)

In [None]:
# Test accuracy vs. corruption level: one panel per corruption type, one line
# per mode. Each point aggregates the runs at that level (seaborn default:
# mean) with a +/-1 SD band. Presumably one run per checkpoint; verify.
SAVE_FIG = False  # set True to also write the figure to c10c.pdf

g = sns.relplot(data=metrics, kind='line', x='level', y='s/test/acc',
                hue='mode', col='corruption', col_wrap=3, errorbar='sd',
                marker='o', markersize=20)
g.set_axis_labels('Corruption level', 'Test accuracy')
g.set_titles(col_template='{col_name}')

# FacetGrid.figure is the current accessor; .fig is a deprecated alias.
g.figure.tight_layout()
if SAVE_FIG:
  g.figure.savefig('c10c.pdf', bbox_inches='tight')