In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import seaborn.objects as so

# Global plot theme for every figure in this notebook.
# set_theme is the preferred spelling; sns.set is only kept as an alias for it.
sns.set_theme(font_scale=1.5, style='whitegrid')

In [None]:
import wandb
from tqdm.auto import tqdm
import pandas as pd

def get_summary_metrics(sweep_id, filter_func=None):
  """Collect the config and summary values of every run in a W&B sweep.

  Args:
    sweep_id: Sweep path handed to ``wandb.Api().sweep`` (the calls in this
      notebook use the form ``'entity/project/sweep'``).
    filter_func: Optional predicate called with each run; runs for which it
      returns a falsy value are skipped. Non-callable values are ignored.

  Returns:
    Tuple ``(sweep, frame)``: the sweep object and a DataFrame with one row
    per kept run, holding ``run_id`` plus the run's config and summary keys.
    When a key occurs in both config and summary, the summary value is kept.
  """
  api = wandb.Api(timeout=60)
  sweep = api.sweep(sweep_id)
  use_filter = callable(filter_func)

  data = []
  for run in tqdm(sweep.runs, desc='Runs', leave=False):
    if use_filter and not filter_func(run):
      continue

    # Merge with update() rather than dict(**config, **summary): keyword
    # expansion raises TypeError as soon as the two mappings share a key.
    run_summary = {'run_id': run.id}
    run_summary.update(run.config)
    run_summary.update(run.summary)
    # The run's own id always wins over a logged key that happens to be named 'run_id'.
    run_summary['run_id'] = run.id

    data.append(run_summary)

  return sweep, pd.DataFrame(data)

In [None]:
# Download the summary metrics of the sweep that is tagged mode='lmap' below.
_, fs_metrics = get_summary_metrics('deeplearn/fspace-inference/6vah02u6')

# Keep only the columns used downstream and derive the corruption name / level.
# corr_config is read as '<corruption>_<level>': everything before the last '_'
# is the corruption name, the part after it is the integer level.
# .assign returns a new frame, so no values are written into a slice of the
# downloaded frame (which would trigger SettingWithCopyWarning).
fs_metrics = (
    fs_metrics[['run_id', 'ckpt_path', 'corr_config', 's/test/acc', 's/test/sel_acc', 's/test/ece', 's/test/avg_nll']]
    .assign(
        level=lambda df: df['corr_config'].str.rpartition('_')[2].astype(int),
        corruption=lambda df: df['corr_config'].str.rpartition('_')[0],
    )
)

# Number the distinct checkpoints 0..k-1 in order of first appearance.
# Enumerating unique() directly keeps keys and ids aligned (unique() and
# nunique() disagree on the count when a NaN is present).
ckpt_to_seed_id = {ckpt: seed for seed, ckpt in enumerate(fs_metrics['ckpt_path'].unique().tolist())}
fs_metrics = fs_metrics.assign(
    seed_id=fs_metrics['ckpt_path'].map(ckpt_to_seed_id),
    mode='lmap',
)

In [None]:
# NOTE(review): this cell repeats the preceding cell verbatim (same sweep id,
# same mode='lmap'); it downloads the sweep a second time and rebuilds the same
# frame. Delete it, or point it at the intended sweep if another one was meant.
_, fs_metrics = get_summary_metrics('deeplearn/fspace-inference/6vah02u6')

# corr_config is read as '<corruption>_<level>': split at the last '_'.
# .assign returns a new frame, avoiding writes into a slice of the download.
fs_metrics = (
    fs_metrics[['run_id', 'ckpt_path', 'corr_config', 's/test/acc', 's/test/sel_acc', 's/test/ece', 's/test/avg_nll']]
    .assign(
        level=lambda df: df['corr_config'].str.rpartition('_')[2].astype(int),
        corruption=lambda df: df['corr_config'].str.rpartition('_')[0],
    )
)

# Distinct checkpoints are numbered 0..k-1 in order of first appearance.
ckpt_to_seed_id = {ckpt: seed for seed, ckpt in enumerate(fs_metrics['ckpt_path'].unique().tolist())}
fs_metrics = fs_metrics.assign(
    seed_id=fs_metrics['ckpt_path'].map(ckpt_to_seed_id),
    mode='lmap',
)

In [None]:
# Download the summary metrics of the sweep that is tagged mode='ps' below.
_, ps_metrics = get_summary_metrics('deeplearn/fspace-inference/554f3dib')

# Keep only the columns used downstream and derive the corruption name / level.
# corr_config is read as '<corruption>_<level>': everything before the last '_'
# is the corruption name, the part after it is the integer level.
# .assign returns a new frame, so no values are written into a slice of the
# downloaded frame (which would trigger SettingWithCopyWarning).
ps_metrics = (
    ps_metrics[['run_id', 'ckpt_path', 'corr_config', 's/test/acc', 's/test/sel_acc', 's/test/ece', 's/test/avg_nll']]
    .assign(
        level=lambda df: df['corr_config'].str.rpartition('_')[2].astype(int),
        corruption=lambda df: df['corr_config'].str.rpartition('_')[0],
    )
)

# Number the distinct checkpoints 0..k-1 in order of first appearance.
# Enumerating unique() directly keeps keys and ids aligned (unique() and
# nunique() disagree on the count when a NaN is present).
ckpt_to_seed_id = {ckpt: seed for seed, ckpt in enumerate(ps_metrics['ckpt_path'].unique().tolist())}
ps_metrics = ps_metrics.assign(
    seed_id=ps_metrics['ckpt_path'].map(ckpt_to_seed_id),
    mode='ps',
)

In [None]:
# Stack the per-mode frames. ignore_index gives every row a unique label;
# without it both frames contribute their own 0..n-1 index and labels repeat.
metrics = pd.concat([fs_metrics, ps_metrics], ignore_index=True)

# Corruptions shown in the facet grids below. Exactly one assignment must be
# active (with all of them commented out the next line raises NameError on a
# fresh kernel); swap in one of the other groups to plot those corruptions.
corr_list = ['brightness', 'contrast', 'defocus_blur', 'elastic']
# corr_list = ['fog', 'frost', 'frosted_glass_blur', 'gaussian_blur']
# corr_list = ['impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate']
# corr_list = ['saturate', 'shot_noise', 'snow', 'spatter']
# corr_list = ['speckle_noise', 'zoom_blur', 'brightness', 'gaussian_noise']

# .copy() so the in-place rescaling below works on an independent frame
# instead of a filtered view of the concatenated one.
metrics = metrics.loc[
    metrics['corruption'].isin(corr_list),
    ['run_id', 'corruption', 'level', 'mode', 's/test/acc', 's/test/sel_acc'],
].copy()
# Only sel_acc is rescaled here; s/test/acc is used as logged.
# NOTE(review): presumably sel_acc is logged in percent and acc as a fraction - confirm.
metrics['s/test/sel_acc'] /= 100.


# _m = pd.read_csv('results/c10c_fsvi_ctx_train.csv')
# _m['mode'] = 'fsvi_ctx_train'
# _m['s/test/acc'] /= 100
# _m['s/test/sel_acc'] /= 100
# metrics = pd.concat([metrics, _m], ignore_index=True)

# _m = pd.read_csv('results/c10c_fsvi_ctx_c100.csv')
# _m['mode'] = 'fsvi_ctx_c100'
# _m['s/test/acc'] /= 100
# _m['s/test/sel_acc'] /= 100
# metrics = pd.concat([metrics, _m], ignore_index=True)

metrics

In [None]:
# Accuracy vs. corruption level: one panel per corruption, one line per mode,
# with a standard-deviation band (errorbar='sd') around each line.
g = sns.relplot(data=metrics, kind='line', x='level', y='s/test/acc',
                hue='mode', col='corruption', col_wrap=2, errorbar='sd',
                marker='o', markersize=10, linewidth=3,
                height=3.5, aspect=1,
                # A palette name lets seaborn take exactly as many Set2 colours
                # as there are modes; a fixed 4-colour list warns on a mismatch.
                palette='Set2')

g.set_titles(template='{col_name}')
g.set(xlabel='Corruption Level', ylabel='Accuracy', xticks=range(1,6))

# Display names for the values of the 'mode' column.
labelmap = { 'ps': 'PS-MAP', 'fsgc': 'FSGC', 'fs': 'FS-MAP', 'lmap': 'L-MAP' }

handles, labels = g.axes[0].get_legend_handles_labels()
# Unmapped modes keep their raw name instead of raising KeyError.
labels = [labelmap.get(l, l) for l in labels]
for h in handles:
    h.set(markersize=10, linewidth=3)

# Replace the figure-level legend with a single one inside the third panel
# (requires at least three corruptions in the plotted data).
g.axes[2].legend(handles=handles, labels=labels, fontsize=12)
first_panel_legend = g.axes[0].get_legend()
if first_panel_legend is not None:
    first_panel_legend.remove()
if g.legend is not None:
    g.legend.remove()

# plt.show() instead of Figure.show(): the latter only works with GUI backends
# and emits a warning under the inline notebook backend.
plt.show()
# g.fig.savefig('c10c_acc.pdf', bbox_inches='tight')

In [None]:
# Selective accuracy vs. corruption level: one panel per corruption, one line
# per mode, with a standard-deviation band (errorbar='sd') around each line.
# `metrics` was already restricted to corr_list when it was built, so it is
# plotted directly instead of being filtered a second time.
g = sns.relplot(data=metrics, kind='line', x='level', y='s/test/sel_acc',
                hue='mode', col='corruption', col_wrap=2, errorbar='sd',
                marker='o', markersize=10, linewidth=3,
                height=3.5, aspect=1,
                # A palette name lets seaborn take exactly as many Set2 colours
                # as there are modes; a fixed 4-colour list warns on a mismatch.
                palette='Set2')

g.set_titles(template='{col_name}')
g.set(xlabel='Corruption Level', ylabel='Sel. Accuracy', xticks=range(1,6))

# Display names keyed by the 'mode' value. A positional label list silently
# mislabels lines whenever the set or order of plotted modes changes
# (with only 'lmap' and 'ps' present, 'lmap' was being shown as 'FSGC').
labelmap = {
    'ps': 'PS-MAP', 'fsgc': 'FSGC', 'fs': 'FS-MAP', 'lmap': 'L-MAP',
    # NOTE(review): presumably the modes of the optional CSV baselines - confirm.
    'fsvi_ctx_train': 'FSVI', 'fsvi_ctx_c100': 'FSVI (CIFAR-100)',
}

handles, labels = g.axes[0].get_legend_handles_labels()
# Unmapped modes keep their raw name.
labels = [labelmap.get(l, l) for l in labels]
for h in handles:
    h.set(markersize=10, linewidth=3)

# Replace the figure-level legend with a single one inside the third panel
# (requires at least three corruptions in the plotted data).
g.axes[2].legend(handles=handles, labels=labels, fontsize=12)
first_panel_legend = g.axes[0].get_legend()
if first_panel_legend is not None:
    first_panel_legend.remove()
if g.legend is not None:
    g.legend.remove()

# plt.show() instead of Figure.show(): the latter only works with GUI backends
# and emits a warning under the inline notebook backend.
plt.show()
# g.fig.savefig('c10c_sel_acc.pdf', bbox_inches='tight')