In [None]:
import wandb
from tqdm.auto import tqdm
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Notebook-wide seaborn theme: white grid background with fonts scaled 2x so
# the exported figures stay legible.
sns.set(style='whitegrid', font_scale=2)

In [None]:
def get_summary_metrics(sweep_id, config_keys=None, filter_func=None, timeout=60):
  """Collect the summary metrics of every run in a W&B sweep into a DataFrame.

  Args:
    sweep_id: W&B sweep path passed to ``wandb.Api().sweep``
      (e.g. ``'entity/project/sweep_id'``).
    config_keys: Optional iterable of ``run.config`` keys to copy into each
      row. A key missing from a run's config is recorded as ``None`` rather
      than raising ``KeyError``, so ``filter_func`` still gets a chance to
      reject that run.
    filter_func: Optional ``callable(run, cfg) -> bool``. Runs for which it
      returns a falsy value are skipped. ``cfg`` is the dict of selected
      config values for that run.
    timeout: Request timeout forwarded to ``wandb.Api``.

  Returns:
    ``(sweep, df)``. ``df`` has one row per kept run. Its columns are
    ``run_id``, then the requested config keys, then every summary metric.
    If a summary key collides with ``run_id`` or a requested config key,
    the run id / config value is kept.
  """
  api = wandb.Api(timeout=timeout)
  sweep = api.sweep(sweep_id)
  keys = list(config_keys or [])

  records = []
  for run in tqdm(sweep.runs, desc='Runs', leave=False):
    cfg = {k: run.config.get(k) for k in keys}
    if callable(filter_func) and not filter_func(run, cfg):
      continue
    record = {'run_id': run.id, **cfg}
    # Merge summary metrics without letting them clobber the identifying
    # columns. Merging with duplicate keyword arguments would raise TypeError.
    for key, value in dict(run.summary).items():
      record.setdefault(key, value)
    records.append(record)

  return sweep, pd.DataFrame(records)

In [None]:
def f(run, cfg):
    """Run filter for ``get_summary_metrics``.

    Keeps a run only if it finished, has ``train_subset == 1`` (presumably
    the full training set), and has ``label_noise < 0.4``. ``cfg`` is accepted
    to match the ``filter_func(run, cfg)`` callback signature. The decision
    itself reads ``run.config`` directly. Runs whose config lacks
    ``train_subset`` or ``label_noise`` are rejected instead of raising
    ``KeyError``.
    """
    config = run.config
    label_noise = config.get('label_noise')
    return (run.state == 'finished'
            and config.get('train_subset') == 1
            and label_noise is not None
            and label_noise < .4)

## CIFAR-10, trained from scratch
_, metrics = get_summary_metrics('deeplearn/pactl/xet192s5', config_keys=['base_width', 'label_noise'],
                                 filter_func=f)
metrics = metrics.assign(mode='scratch')

## CIFAR-10 Transfer
# This sweep records the width under 'pre_base_width' (presumably the width of
# the pretrained network); rename it so both modes share the 'base_width' axis.
_, tr_metrics = get_summary_metrics('deeplearn/pactl/mcex7uak', config_keys=['pre_base_width', 'label_noise'],
                                    filter_func=f)
tr_metrics = tr_metrics.rename(columns={'pre_base_width': 'base_width'}).assign(mode='transfer')

# ignore_index gives a fresh 0..n-1 index without the reset_index/drop detour.
all_metrics = pd.concat([metrics, tr_metrics], ignore_index=True)
# Fail fast here rather than with an opaque error inside the plotting cells.
assert {'base_width', 'label_noise', 'mode'} <= set(all_metrics.columns), \
    'no runs survived filtering, or expected config columns are missing'

In [None]:
from palettable.cartocolors.qualitative import Vivid_3 as _palette

fig, ax = plt.subplots(figsize=(8, 8))
# One palette color per label-noise level (cycled if the palette is shorter).
n_noise_levels = all_metrics['label_noise'].nunique()
sns.lineplot(ax=ax, data=all_metrics, x='base_width', y='sgd/test/acc', hue='label_noise',
             markersize=11, linewidth=4, style='mode',
             palette=sns.color_palette(_palette.mpl_colors, n_noise_levels))
ax.set(xlabel='Base Width', ylabel='Test Accuracy (Last Epoch)')

# Rename legend entries by their text rather than by list position. Seaborn's
# legend lists the column names as subtitle entries and the style levels in
# data order, so positional overwrites silently mislabel when that changes.
legend_names = {'label_noise': 'Label Noise', 'mode': 'Mode',
                'scratch': 'Scratch', 'transfer': 'Transfer'}
handles, labels = ax.get_legend_handles_labels()
for h in handles:
    h.set(linewidth=4)
labels = [legend_names.get(label, label) for label in labels]
ax.legend(handles, labels, loc='lower right', bbox_to_anchor=(.8, 0, .2, 1))

fig.tight_layout()
# fig.savefig('dd_cifar10.pdf', bbox_inches='tight')

In [None]:
from palettable.wesanderson import Moonrise5_6, Moonrise6_5
import numpy as np

fig, ax = plt.subplots(figsize=(14, 9))
# Size the palette to the number of label-noise levels. An unsized list
# palette must otherwise match the level count exactly.
n_noise_levels = all_metrics['label_noise'].nunique()
sns.lineplot(ax=ax, data=all_metrics, x='base_width', y='sgd/test/best_acc', hue='label_noise',
             markersize=12, linewidth=4, style='mode',
             palette=sns.color_palette(Moonrise6_5.mpl_colors, n_noise_levels))
ax.set(xlabel='Base Width', ylabel='Test Accuracy (Early Stopping)')

# Rename legend entries by their text rather than by list position, so the
# labels stay correct if seaborn's legend ordering or level count changes.
legend_names = {'label_noise': 'Label Noise', 'mode': 'Mode',
                'scratch': 'Scratch', 'transfer': 'Transfer'}
handles, labels = ax.get_legend_handles_labels()
for h in handles:
    h.set(linewidth=4)
labels = [legend_names.get(label, label) for label in labels]
ax.legend(handles, labels, loc='upper right', bbox_to_anchor=(1, .0, .5, 1))

fig.tight_layout()
# fig.savefig('es.pdf', bbox_inches='tight')

## Train

In [None]:
from palettable.wesanderson import Moonrise5_6, Moonrise6_5

fig, ax = plt.subplots(figsize=(13, 9))
# Size the palette to the number of label-noise levels. An unsized list
# palette must otherwise match the level count exactly.
n_noise_levels = all_metrics['label_noise'].nunique()
sns.lineplot(ax=ax, data=all_metrics, x='base_width', y='sgd/train/mini_loss', hue='label_noise',
             marker='o', markersize=11, linewidth=6, style='mode',
             palette=sns.color_palette(Moonrise6_5.mpl_colors, n_noise_levels))
ax.set(xlabel='Base Width', ylabel='Train Loss (Last Epoch)', yscale='log')

# Rename legend entries by their text rather than by list position, so the
# labels stay correct if seaborn's legend ordering or level count changes.
legend_names = {'label_noise': 'Label Noise', 'mode': 'Mode',
                'scratch': 'Scratch', 'transfer': 'Transfer'}
handles, labels = ax.get_legend_handles_labels()
# Legend lines are drawn thinner (4) than the plotted curves (6).
for h in handles:
    h.set(linewidth=4)
labels = [legend_names.get(label, label) for label in labels]
ax.legend(handles, labels, loc='upper right', bbox_to_anchor=(1, 0, .5, 1))

fig.tight_layout()
# fig.savefig('train_loss.pdf', bbox_inches='tight')