In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import seaborn.objects as so

# Global figure theme: white grid background, fonts scaled 1.5x.
sns.set_theme(style='whitegrid', font_scale=1.5)

# seaborn.objects Plots use their own default theme and do not pick up the
# global theme set above. seaborn >= 0.13 exposes Plot.config.theme, so mirror
# the same style/scale there when available (no-op on older versions).
if hasattr(so.Plot, 'config'):
  so.Plot.config.theme.update(sns.axes_style('whitegrid'))
  so.Plot.config.theme.update(sns.plotting_context('notebook', font_scale=1.5))

In [None]:
import wandb
from tqdm.auto import tqdm
import pandas as pd


def get_summary_metrics(sweep_id, filter_func=None, best_epoch_key='val/best_epoch',
                        epoch_key='sgd/val/epoch',
                        metric_suffixes=('epoch', 'acc', 'nll', 'avg_nll')):
  """Collect one row of config + summary metrics per run of a W&B sweep.

  Each run's config and final summary are merged into one flat dict. If the
  summary records a best validation epoch (``best_epoch_key``), the run history
  is scanned for the first row whose ``epoch_key`` equals it. The values from
  that row then override the final summary values, so the reported metrics
  correspond to the best epoch rather than the last one logged.

  Args:
    sweep_id: W&B sweep path passed to ``wandb.Api().sweep``.
    filter_func: optional predicate ``run -> bool``; runs for which it returns
      a falsy value are skipped. Ignored if not callable.
    best_epoch_key: summary key holding the best validation epoch.
    epoch_key: history key holding the epoch of each logged row.
    metric_suffixes: summary keys ending in any of these suffixes are fetched
      from the run history.

  Returns:
    ``(sweep, df)``: the wandb sweep object and a DataFrame with one row per
    kept run.
  """
  api = wandb.Api(timeout=60)
  sweep = api.sweep(sweep_id)
  suffixes = tuple(metric_suffixes)

  data = []
  for run in tqdm(sweep.runs, desc='Runs', leave=False):
    if callable(filter_func) and not filter_func(run):
      continue

    # Dict literal rather than dict(**a, **b): overlapping keys resolve with
    # the later source winning instead of raising TypeError.
    run_summary = {'run_id': run.id, **run.config, **run.summary}

    best_epoch = run_summary.get(best_epoch_key)
    # Runs that never logged a best epoch (e.g. crashed early) keep their
    # final summary values instead of raising KeyError.
    if best_epoch is not None:
      scan_keys = [k for k in run_summary if k.endswith(suffixes)]
      for row in run.scan_history(keys=scan_keys):
        if row.get(epoch_key) == best_epoch:
          # Every key in `row` is also in run_summary (scan_keys came from it),
          # so this must be an override, not a keyword-argument merge.
          run_summary.update(row)
          break

    data.append(run_summary)

  return sweep, pd.DataFrame(data)

## FashionMNIST

In [None]:
# ## fmnist-mresnet18-lnoise
# _, _data = get_summary_metrics('deeplearn/fspace-inference/r1rfv4xj')
# _data['mode'] = 'PS-MAP'

# ## fmnist-mresnet18-lnoise-fsmap
# _, _data2 = get_summary_metrics('deeplearn/fspace-inference/zl1v1w2n')
# _data2['mode'] = 'FS-MAP'

# results = pd.concat([_data, _data2], ignore_index=False)
# results['decay'] = results['weight_decay'].fillna(results['func_decay'])

# results.to_csv('results/fmnist_lnoise.csv', index=False)

In [None]:
# Load the cached FashionMNIST label-noise results (the path the commented-out
# W&B export cell writes to). Fail fast with a clear message if the CSV lacks a
# column the plots rely on, rather than deep inside seaborn.
results = pd.read_csv('results/fmnist_lnoise.csv')
_missing_cols = {'decay', 'mode', 'label_noise', 'sgd/test/acc', 'sgd/test/avg_nll'} - set(results.columns)
assert not _missing_cols, f'results/fmnist_lnoise.csv is missing columns: {sorted(_missing_cols)}'

In [None]:
# FashionMNIST: test accuracy vs. decay coefficient on a log x-axis, one
# marked line per mode, one facet per label-noise level (wrapped two per row).
p = (
    so.Plot(results, x='decay', y='sgd/test/acc', color='mode')
    .add(so.Line(marker='o'))
    .scale(x='log')
    .label(
        x='Decay Coeff.',
        y='Test Accuracy',
        color='Mode',
        title='Noise: {}'.format,
    )
    .facet('label_noise', wrap=2)
)

p.plot()
# p.save('fmnist_acc_noise.png', bbox_inches='tight')

In [None]:
# FashionMNIST: test average NLL vs. decay coefficient on a log x-axis, one
# marked line per mode, one facet per label-noise level (wrapped two per row).
p = (
    so.Plot(results, x='decay', y='sgd/test/avg_nll', color='mode')
    .add(so.Line(marker='o'))
    .scale(x='log')
    .label(
        x='Decay Coeff.',
        y='Test Avg. NLL',
        color='Mode',
        title='Noise: {}'.format,
    )
    .facet('label_noise', wrap=2)
)

p.plot()
# p.save('fmnist_avg_nll_noise.png', bbox_inches='tight')

## CIFAR-10

In [None]:
# ## c10-mresnet18-lnoise
# _, _data = get_summary_metrics('deeplearn/fspace-inference/fu2bvbac')
# _data['mode'] = 'PS-MAP'

# ## c10-mresnet18-lnoise-fsmap
# _, _data2 = get_summary_metrics('deeplearn/fspace-inference/ovjzu2tf')
# _data2['mode'] = 'FS-MAP'

# results = pd.concat([_data, _data2], ignore_index=False)
# results['decay'] = results['weight_decay'].fillna(results['func_decay'])

# results.to_csv('results/c10_lnoise.csv', index=False)

In [None]:
# Load the cached CIFAR-10 label-noise results (the path the commented-out
# W&B export cell writes to). Fail fast with a clear message if the CSV lacks a
# column the plots rely on, rather than deep inside seaborn.
results = pd.read_csv('results/c10_lnoise.csv')
_missing_cols = {'decay', 'mode', 'label_noise', 'sgd/test/acc', 'sgd/test/avg_nll'} - set(results.columns)
assert not _missing_cols, f'results/c10_lnoise.csv is missing columns: {sorted(_missing_cols)}'

In [None]:
# CIFAR-10: test accuracy vs. decay coefficient on a log x-axis, one marked
# line per mode, one facet per label-noise level (wrapped two per row).
p = (
    so.Plot(results, x='decay', y='sgd/test/acc', color='mode')
    .add(so.Line(marker='o'))
    .scale(x='log')
    .label(
        x='Decay Coeff.',
        y='Test Accuracy',
        color='Mode',
        title='Noise: {}'.format,
    )
    .facet('label_noise', wrap=2)
)

p.plot()
# p.save('c10_acc_noise.png', bbox_inches='tight')

In [None]:
# CIFAR-10: test average NLL vs. decay coefficient on a log x-axis, one marked
# line per mode, one facet per label-noise level (wrapped two per row).
p = (
    so.Plot(results, x='decay', y='sgd/test/avg_nll', color='mode')
    .add(so.Line(marker='o'))
    .scale(x='log')
    .label(
        x='Decay Coeff.',
        y='Test Avg. NLL',
        color='Mode',
        title='Noise: {}'.format,
    )
    .facet('label_noise', wrap=2)
)

p.plot()
# p.save('c10_avg_nll_noise.png', bbox_inches='tight')