In [None]:
from tqdm.auto import tqdm
import wandb
import pandas as pd


def get_summary_metrics(sweep_id, filter_func=None, api=None):
  """Collect the config and final summary of every run in a W&B sweep.

  Args:
    sweep_id: Sweep path passed to `wandb.Api().sweep`, e.g. 'entity/project/id'.
    filter_func: Optional predicate called with each run; runs for which it
      returns a falsy value are skipped. Ignored when not callable.
    api: Optional existing `wandb.Api` client to reuse across calls; a new
      client is created when omitted.

  Returns:
    (sweep, df): the sweep object and a DataFrame with one row per kept run.
    Each row holds `run_id` (first column) plus every config and summary key.
    When a key appears in both config and summary, the summary value wins;
    `run_id` always holds the W&B run id.
  """
  if api is None:
    api = wandb.Api()
  sweep = api.sweep(sweep_id)

  rows = []
  for run in tqdm(sweep.runs, desc='Runs', leave=False):
    if callable(filter_func) and not filter_func(run):
      continue
    # A dict display merges overlapping keys (later source wins). The previous
    # dict(**config, **summary) form raised TypeError on any shared key.
    row = {'run_id': run.id, **run.config, **run.summary}
    # Re-assign so a config/summary key named `run_id` cannot shadow the real id.
    # The key keeps its first-column position.
    row['run_id'] = run.id
    rows.append(row)

  return sweep, pd.DataFrame(rows)

In [None]:
# Sweeps to aggregate: (sweep id, dataset label, intrinsic_dim override).
# An override of None keeps the `intrinsic_dim` recorded in each run's config.
# The second group is labelled intrinsic_dim = 0 explicitly.
# NOTE(review): presumably those sweeps train without an intrinsic-dimension subspace; confirm.
SWEEPS = [
  ('deeplearn/pactl/jpe804ba', 'CIFAR-10', None),
  ('deeplearn/pactl/6vezy6j2', 'CIFAR-10', None),
  ('deeplearn/pactl/gebtk6cj', 'CIFAR-10', None),
  ('deeplearn/pactl/fxiv2f7l', 'CIFAR-100', None),
  ('deeplearn/pactl/14okon9o', 'CIFAR-100', None),
  ('deeplearn/pactl/93dkg8vg', 'CIFAR-100', None),
  ('deeplearn/pactl/4k5ujb2j', 'CIFAR-10', 0),
  ('deeplearn/pactl/oo2f42i9', 'CIFAR-10', 0),
  ('deeplearn/pactl/ybugd5bn', 'CIFAR-10', 0),
  ('deeplearn/pactl/gtwp5qw2', 'CIFAR-100', 0),
  ('deeplearn/pactl/dmtqx539', 'CIFAR-100', 0),
  ('deeplearn/pactl/e8vkg5u7', 'CIFAR-100', 0),
]

sweep_frames = []
for sweep_id, dataset, intrinsic_dim in SWEEPS:
  _, sweep_metrics = get_summary_metrics(sweep_id)
  sweep_metrics['dataset'] = dataset
  if intrinsic_dim is not None:
    sweep_metrics['intrinsic_dim'] = intrinsic_dim
  sweep_frames.append(sweep_metrics)

# ignore_index yields a fresh 0..n-1 RangeIndex directly. The old reset_index/drop
# round-trip would also have deleted any real config column named 'index'.
all_metrics = pd.concat(sweep_frames, ignore_index=True)
# `train_subset` is stored negated in these runs, so flip its sign for plotting.
# NOTE(review): the meaning of the negative sign is not visible here; confirm the convention.
all_metrics['train_subset'] = -all_metrics['train_subset']

In [None]:
# For each (dataset, intrinsic_dim, train_subset) setting, keep the run with the
# lowest `raw_err_bound_100`. `idxmin` returns index *labels*, so select with `.loc`
# rather than `.iloc`. Runs without a bound are dropped first, so an all-NaN group
# cannot produce a NaN label.
best_metrics = all_metrics.loc[
    all_metrics
    .dropna(subset=['raw_err_bound_100'])
    .groupby(['dataset', 'intrinsic_dim', 'train_subset'])['raw_err_bound_100']
    .idxmin()
]

## CIFAR-10

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

from palettable.cartocolors.diverging import Temps_4

sns.set(font_scale=2., style='whitegrid')

SAVE_FIG = False  # set True to write c10_id_subset.pdf

c10_best = best_metrics[best_metrics.dataset == 'CIFAR-10']

# A list palette makes seaborn map `intrinsic_dim` categorically. The list needs
# exactly one colour per level; otherwise seaborn raises or cycles colours,
# depending on version.
c10_palette = Temps_4.mpl_colors
assert c10_best['intrinsic_dim'].nunique() == len(c10_palette), \
    'Temps_4 must provide exactly one colour per intrinsic_dim value'

fig, ax = plt.subplots(figsize=(9, 6.5))

sns.lineplot(ax=ax, data=c10_best,
             x='train_subset', y='raw_err_bound_100', hue='intrinsic_dim',
             lw=8, marker='o', markersize=18,
             palette=c10_palette)
ax.set(xticks=[.2, .5, .8], xlabel='Prior Train Subset', ylabel=r'Err. Bound ($\%$)')

# Thicken the legend's line samples to match the lw=8 plotted lines.
handles, labels = ax.get_legend_handles_labels()
for h in handles:
    h.set(linewidth=8)
ax.legend(handles=handles, labels=labels, title='Intrinsic Dim.', loc='best')

fig.tight_layout()
# Save before showing: the inline backend may close the figure on show.
if SAVE_FIG:
    fig.savefig('c10_id_subset.pdf', bbox_inches='tight')
# plt.show() rather than fig.show(): the latter warns under Jupyter's non-GUI inline backend.
plt.show()

## CIFAR-100

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

from palettable.cartocolors.diverging import Temps_5

sns.set(font_scale=2., style='whitegrid')

SAVE_FIG = False  # set True to write c100_id_subset.pdf

c100_best = best_metrics[best_metrics.dataset == 'CIFAR-100']

# A list palette makes seaborn map `intrinsic_dim` categorically. The list needs
# exactly one colour per level; otherwise seaborn raises or cycles colours,
# depending on version.
c100_palette = Temps_5.mpl_colors
assert c100_best['intrinsic_dim'].nunique() == len(c100_palette), \
    'Temps_5 must provide exactly one colour per intrinsic_dim value'

fig, ax = plt.subplots(figsize=(9, 6.5))

sns.lineplot(ax=ax, data=c100_best,
             x='train_subset', y='raw_err_bound_100', hue='intrinsic_dim',
             lw=8, marker='o', markersize=18,
             palette=c100_palette)
ax.set(xticks=[.2, .5, .8], xlabel='Prior Train Subset', ylabel=r'Err. Bound ($\%$)')

# Thicken the legend's line samples to match the lw=8 plotted lines.
handles, labels = ax.get_legend_handles_labels()
for h in handles:
    h.set(linewidth=8)
ax.legend(handles=handles, labels=labels, title='Intrinsic Dim.', loc='best')

fig.tight_layout()
# Save before showing: the inline backend may close the figure on show.
if SAVE_FIG:
    fig.savefig('c100_id_subset.pdf', bbox_inches='tight')
# plt.show() rather than fig.show(): the latter warns under Jupyter's non-GUI inline backend.
plt.show()