In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import seaborn.objects as so

In [None]:
import wandb
from tqdm.auto import tqdm
import pandas as pd

def get_summary_metrics(sweep_id, filter_func=None):
  """Collect one row of config + summary values per run of a W&B sweep.

  Args:
    sweep_id: Sweep path passed to ``wandb.Api().sweep`` (an
      ``'entity/project/sweep_id'`` string).
    filter_func: Optional predicate called with each run; runs for which it
      returns a falsy value are skipped. Non-callable values are ignored.

  Returns:
    A ``(sweep, df)`` tuple: the W&B sweep object and a DataFrame with one row
    per kept run. ``run_id`` is the first column, followed by the run's config
    entries and then its summary entries. If a key appears in both config and
    summary, the summary value wins; ``run_id`` is always the W&B run id.
  """
  api = wandb.Api(timeout=60)
  sweep = api.sweep(sweep_id)

  records = []
  for run in tqdm(sweep.runs, desc='Runs', leave=False):
    if callable(filter_func) and not filter_func(run):
      continue

    # Merge with dict.update instead of `dict(run_id=..., **config, **summary)`:
    # keyword unpacking raises TypeError as soon as config and summary (or
    # either of them and 'run_id') share a key.
    record = {'run_id': run.id}
    record.update(run.config)
    record.update(run.summary)
    record['run_id'] = run.id  # a logged 'run_id' key must not shadow the real id
    records.append(record)

  return sweep, pd.DataFrame(records)

In [None]:
# _, lmap_metrics = get_summary_metrics('deeplearn/fspace-inference/6vah02u6')

# lmap_metrics = lmap_metrics[['run_id', 'ckpt_path', 'corr_config', 's/test/acc', 's/test/sel_acc', 's/test/ece', 's/test/avg_nll']]
# lmap_metrics['level'] = lmap_metrics['corr_config'].apply(lambda row: int(row.split('_')[-1]))
# lmap_metrics['corruption'] = lmap_metrics['corr_config'].apply(lambda row: '_'.join(row.split('_')[:-1]))

# ckpt_to_seed_id = dict(zip(lmap_metrics.ckpt_path.unique().tolist(), list(range(lmap_metrics.ckpt_path.nunique()))))
# lmap_metrics['seed_id'] = lmap_metrics['ckpt_path'].apply(lambda row: ckpt_to_seed_id[row])
# lmap_metrics['mode'] = 'lmap'

In [None]:
# Load the FSGC sweep and reshape it into the common
# (mode, corruption, level, seed_id, s/test/acc, s/test/sel_acc) layout.
_, fsgc_raw = get_summary_metrics('deeplearn/fspace-inference/chefbo1z')

# corr_config is parsed as '<corruption>_<level>'; split on the LAST underscore
# only, since corruption names may themselves contain underscores.
fsgc_corr = fsgc_raw['corr_config'].str.rsplit('_', n=1, expand=True)

# Number checkpoints 0..k-1 in order of first appearance; these serve as seed ids.
# pd.factorize keeps codes and uniques aligned even when ckpt_path has missing
# values (they get code -1), unlike zipping unique() with range(nunique()).
fsgc_seed_ids, fsgc_ckpts = pd.factorize(fsgc_raw['ckpt_path'])
ckpt_to_seed_id = dict(zip(fsgc_ckpts.tolist(), range(len(fsgc_ckpts))))

fsgc_metrics = (
    fsgc_raw
    .assign(
        level=fsgc_corr[1].astype(int),
        corruption=fsgc_corr[0],
        seed_id=fsgc_seed_ids,
        mode='fsgc',
    )
    [['mode', 'corruption', 'level', 'seed_id', 's/test/acc', 's/test/sel_acc']]
    # Rescale via .assign rather than an in-place `*=` on the column subset,
    # which raises SettingWithCopyWarning.
    # NOTE(review): only s/test/acc is converted to percent; confirm that
    # s/test/sel_acc is already on the intended scale.
    .assign(**{'s/test/acc': lambda df: df['s/test/acc'] * 100})
)

In [None]:
# Load the PS sweep and reshape it into the common
# (mode, corruption, level, seed_id, s/test/acc, s/test/sel_acc) layout.
_, ps_raw = get_summary_metrics('deeplearn/fspace-inference/eoi1wbz4')

# corr_config is parsed as '<corruption>_<level>'; split on the LAST underscore
# only, since corruption names may themselves contain underscores.
ps_corr = ps_raw['corr_config'].str.rsplit('_', n=1, expand=True)

# Number checkpoints 0..k-1 in order of first appearance; these serve as seed ids.
# pd.factorize keeps codes and uniques aligned even when ckpt_path has missing
# values (they get code -1), unlike zipping unique() with range(nunique()).
ps_seed_ids, ps_ckpts = pd.factorize(ps_raw['ckpt_path'])
ckpt_to_seed_id = dict(zip(ps_ckpts.tolist(), range(len(ps_ckpts))))

ps_metrics = (
    ps_raw
    .assign(
        level=ps_corr[1].astype(int),
        corruption=ps_corr[0],
        seed_id=ps_seed_ids,
        mode='ps',
    )
    [['mode', 'corruption', 'level', 'seed_id', 's/test/acc', 's/test/sel_acc']]
    # Rescale via .assign rather than an in-place `*=` on the column subset,
    # which raises SettingWithCopyWarning.
    # NOTE(review): only s/test/acc is converted to percent; confirm that
    # s/test/sel_acc is already on the intended scale.
    .assign(**{'s/test/acc': lambda df: df['s/test/acc'] * 100})
)

In [None]:
# fsvi_metrics = pd.read_csv('results/c10c_fsvi_ctx_train.csv')
# fsvi_metrics['mode'] = 'fsvi_ctx_train'

# FSVI results come from a local CSV export rather than a W&B sweep; bring them
# into the same (mode, corruption, level, seed_id, metrics...) layout.
# NOTE(review): unlike the fsgc/ps cells, s/test/acc is not multiplied by 100
# here -- confirm the CSV already stores it in percent before comparing modes.
fsvi_metrics = (
    pd.read_csv('results/c10c_fsvi_ctx_c100.csv')
    .assign(mode='fsvi')
    .rename(columns={'run_id': 'seed_id'})
    [['mode', 'corruption', 'level', 'seed_id', 's/test/acc', 's/test/sel_acc']]
)

In [None]:
# Stack every mode into one long-format frame with a fresh 0..n-1 row index.
metrics = pd.concat(
    (fsgc_metrics, ps_metrics, fsvi_metrics),
    axis=0,
    ignore_index=True,
)

In [None]:
# Plot the chosen metric against corruption level: one panel per corruption,
# one line per inference mode, with a +/- 1 sd band across seeds (errorbar='sd').
sns.set_theme(font_scale=1.8, style='whitegrid')

corr_list = ['speckle_noise', 'shot_noise', 'pixelate', 'gaussian_blur']
plt_metrics = metrics[metrics.corruption.isin(corr_list)]
plt_metrics = plt_metrics[plt_metrics['mode'] != 'fsvi']
_m = 'sel_acc'  # metric suffix to plot: 'acc' or 'sel_acc'

# Display names; unknown keys fall back to the raw mode / metric name.
labelmap = { 'ps': 'PS-MAP', 'fsgc': 'FSGC', 'fs': 'FS-MAP', 'lmap': 'L-MAP', 'fsvi': 'FSVI' }
ylabel_map = { 'acc': 'Acc.', 'sel_acc': 'Sel. Acc.' }

# Map each mode to a fixed Set2 colour. seaborn complains when a list palette's
# length differs from the number of hue levels (e.g. 3 colours vs. 2 modes once
# fsvi is dropped), and a dict keeps each mode's colour stable whichever modes
# are filtered out. Extend mode_order if new modes are plotted.
mode_order = ['fsgc', 'ps', 'fsvi', 'lmap', 'fs']
mode_palette = dict(zip(mode_order, sns.color_palette("Set2", len(mode_order))))

g = sns.relplot(data=plt_metrics, kind='line', x='level', y=f's/test/{_m}',
                hue='mode', col='corruption', col_wrap=2, errorbar='sd',
                marker='o', markersize=10, linewidth=3,
                height=3.5, aspect=1.,
                palette=mode_palette)

g.set_titles(template='{col_name}')
g.set(xlabel='Corr. Level', ylabel=ylabel_map.get(_m, _m),
      xticks=sorted(plt_metrics['level'].unique()))

# Reuse the per-mode legend handles found on the first panel, rename them, and
# restyle them so the legend matches the plotted lines.
handles, labels = g.axes[0].get_legend_handles_labels()
labels = [labelmap.get(l, l) for l in labels]
for h in handles:
    h.set(markersize=10, linewidth=3, marker='o')

# Clear any legend on the first panel BEFORE placing the custom one, so this
# still works when the target panel is the first one (e.g. only two corruptions).
first_legend = g.axes[0].get_legend()
if first_legend is not None:
    first_legend.remove()
legend_ax = g.axes[-2] if len(g.axes) > 1 else g.axes[0]
legend_ax.legend(handles=handles, labels=labels, fontsize=22, loc='lower left')
# Drop seaborn's figure-level legend in favour of the in-panel one.
if g.legend is not None:
    g.legend.remove()

g.figure.tight_layout()
# g.figure.savefig(f'c10c_{_m}.pdf', bbox_inches='tight')  # save before plt.show(), which may close the figure
plt.show()