In [None]:
import os

# Restrict CUDA-based libraries in this kernel to GPU index 7. The variable is
# only honoured if it is set before CUDA gets initialised, hence the first cell.
# NOTE(review): no later cell visibly uses the GPU; this may be vestigial.
os.environ.update(CUDA_VISIBLE_DEVICES='7')

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import seaborn.objects as so

# `sns.set` is only a backward-compatible alias of `sns.set_theme`, which is the
# documented, preferred entry point for global plot styling.
sns.set_theme(style='whitegrid', font_scale=1.5)

In [None]:
import wandb
from tqdm.auto import tqdm
import pandas as pd

def get_trace_metrics(sweep_id, filter_func=None):
  """Download the logged history of every run in a W&B sweep as one long table.

  Every history row of every kept run becomes one output row, augmented with
  the run's id and all of the run's config entries.

  Args:
    sweep_id: Sweep path passed straight to ``wandb.Api().sweep``
      (e.g. ``'entity/project/sweep'``).
    filter_func: Optional predicate called with each run; runs for which it
      returns a falsy value are skipped. Ignored if it is not callable.

  Returns:
    ``(sweep, df)``: the wandb sweep object and a ``pandas.DataFrame`` with one
    row per history row across all kept runs. Columns are ``run_id``, the run's
    config keys and the logged history keys. When a key occurs in more than one
    of these sources, the history value wins over the config value, which wins
    over ``run_id``.
  """
  api = wandb.Api(timeout=60)
  sweep = api.sweep(sweep_id)

  records = []
  for run in tqdm(sweep.runs, desc='Runs', leave=False):
    if callable(filter_func) and not filter_func(run):
      continue

    # NOTE(review): per the wandb docs, Run.history() returns a *sampled*
    # history (500 rows by default); use run.scan_history() if runs log more.
    # to_dict('records') keeps per-column dtypes, whereas iterrows() coerces
    # every row to a single dtype (e.g. int columns come back as floats).
    for row in run.history().to_dict('records'):
      # A dict display instead of dict(**...): keys shared by the config and the
      # history no longer raise "got multiple values for keyword argument".
      records.append({'run_id': run.id, **run.config, **row})

  return sweep, pd.DataFrame(records)

In [None]:
# Fetch the per-step traces of the two sweeps being compared; the wandb sweep
# object (first element of the returned tuple) is not needed here.
_, fs_metrics = get_trace_metrics(sweep_id='deeplearn/fspace-inference/t0u3iuuu')  # tagged mode='fs' below
_, ps_metrics = get_trace_metrics(sweep_id='deeplearn/fspace-inference/btbjb050')  # tagged mode='ps' below


In [None]:
# Split `corr_config` on its LAST underscore only: the trailing token is the
# integer severity level, everything before it is the corruption name, which
# may itself contain underscores (e.g. 'speckle_noise').
fs_metrics['level'] = fs_metrics['corr_config'].str.rsplit('_', n=1).str[-1].astype(int)
fs_metrics['corruption'] = fs_metrics['corr_config'].str.rsplit('_', n=1).str[0]

# Give each distinct checkpoint a seed_id 0..k-1 in order of first appearance.
# Missing (NaN) checkpoint paths are left out of the mapping and get a NaN
# seed_id instead of shifting/dropping the ids of real paths.
ckpt_to_seed_id = {ckpt: seed for seed, ckpt in enumerate(fs_metrics['ckpt_path'].dropna().unique())}
fs_metrics['seed_id'] = fs_metrics['ckpt_path'].map(ckpt_to_seed_id)
fs_metrics['mode'] = 'fs'

In [None]:
# Split `corr_config` on its LAST underscore only: the trailing token is the
# integer severity level, everything before it is the corruption name, which
# may itself contain underscores (e.g. 'speckle_noise').
ps_metrics['level'] = ps_metrics['corr_config'].str.rsplit('_', n=1).str[-1].astype(int)
ps_metrics['corruption'] = ps_metrics['corr_config'].str.rsplit('_', n=1).str[0]

# Give each distinct checkpoint a seed_id 0..k-1 in order of first appearance.
# Missing (NaN) checkpoint paths are left out of the mapping and get a NaN
# seed_id instead of shifting/dropping the ids of real paths.
ckpt_to_seed_id = {ckpt: seed for seed, ckpt in enumerate(ps_metrics['ckpt_path'].dropna().unique())}
ps_metrics['seed_id'] = ps_metrics['ckpt_path'].map(ckpt_to_seed_id)
ps_metrics['mode'] = 'ps'

In [None]:
# Corruptions kept for plotting.
corr_list = ['speckle_noise', 'shot_noise', 'pixelate', 'gaussian_blur']

# Stack both modes into one long table. ignore_index=True gives a fresh, unique
# index: both input frames start at 0, so a plain concat has duplicate labels.
metrics = (
    pd.concat([fs_metrics, ps_metrics], ignore_index=True)
    .loc[:, ['run_id', 'corruption', 'level', 'seed_id', 'mode', 's/test/sel_pred', 's/test/threshold']]
    .loc[lambda df: df['corruption'].isin(corr_list)]
    .reset_index(drop=True)
)
metrics

In [None]:
# Corruption severity level to plot.
c_level = 5
# Pin the hue order and map legend labels by key, so the display names cannot
# get out of sync with whichever mode happens to appear first in the data.
mode_order = ['fs', 'ps']
mode_labels = {'fs': r"FSGC", 'ps': r"PS-MAP"}

g = sns.relplot(data=metrics[metrics.level == c_level], kind='line',
                x='s/test/threshold', y='s/test/sel_pred',
                hue='mode', hue_order=mode_order,
                col='corruption', col_wrap=2, errorbar='sd',
                height=3.3,
                palette=sns.color_palette("Set2", len(mode_order)))

g.set_titles(template='{col_name}')
g.set(xlabel=r'Threshold $\tau$', ylabel='Sel. Accuracy', ylim=(.25, 1 + 1e-3))

# Replace the figure-level legend with one in-axes legend on the third panel,
# using display names for the modes and thicker line samples.
handles, labels = g.axes[0].get_legend_handles_labels()
labels = [mode_labels.get(label, label) for label in labels]
for h in handles:
    h.set(markersize=10, linewidth=3)
g.axes[2].legend(handles=handles, labels=labels, loc='lower right', fontsize=14)
g.legend.remove()

# Figure.show() is aimed at GUI backends (it can warn under the inline
# backend); plt.show() renders correctly everywhere. Note that FacetGrid.fig is
# a deprecated alias of FacetGrid.figure.
plt.show()
# g.figure.savefig(f'c10c_sel_pred_level_{c_level}.pdf', bbox_inches='tight')