In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import wandb

# Notebook-wide plot theme: white grid background with fonts scaled by 1.2.
sns.set(style='whitegrid', font_scale=1.2)

def get_metrics(sweep_id, keys=None, config_keys=None):
  """Collect the logged history of every run in a W&B sweep.

  Args:
    sweep_id: Sweep path handed to ``wandb.Api().sweep``.
    keys: Optional history keys forwarded to ``run.scan_history``. When a list
      is given, ``_runtime``, ``_step`` and ``_timestamp`` are added and
      duplicates removed. The caller's list is left untouched. Any other value
      (including ``None``) is forwarded unchanged.
    config_keys: Optional names of config entries copied from each run onto
      every one of its history rows.

  Returns:
    A ``(sweep, frame)`` tuple: the sweep object and a DataFrame with one row
    per history record, holding ``run_id``, the requested config values and
    the fields of the record.
  """
  api = wandb.Api(timeout=60)
  sweep = api.sweep(sweep_id)

  if isinstance(keys, list):
    # Build a fresh list: extending `keys` in place would leak the bookkeeping
    # columns into the caller's list and grow it on every call.
    keys = sorted(set(keys) | {'_runtime', '_step', '_timestamp'})

  config_keys = config_keys or []

  data = []
  for run in tqdm(sweep.runs, desc='Sweeps', leave=False):
    cfg = {k: run.config[k] for k in config_keys}
    for row in tqdm(run.scan_history(keys=keys), desc='History', leave=False):
      data.append(dict(run_id=run.id, **cfg, **row))

  return sweep, pd.DataFrame(data)

def get_summary_metrics(sweep_id, config_keys=None):
  """Gather the summary of every run in a W&B sweep into one table.

  Args:
    sweep_id: Sweep path handed to ``wandb.Api().sweep``.
    config_keys: Optional names of config entries to attach to each run's row.

  Returns:
    A ``(sweep, frame)`` tuple: the sweep object and a DataFrame with one row
    per run, holding ``run_id``, the requested config values and every entry
    of the run's summary.
  """
  wanted = config_keys or []
  sweep = wandb.Api(timeout=60).sweep(sweep_id)

  records = [
    dict(run_id=run.id, **{name: run.config[name] for name in wanted}, **run.summary)
    for run in tqdm(sweep.runs, desc='Sweeps', leave=False)
  ]

  return sweep, pd.DataFrame(records)

In [None]:
# Per-run summaries of sweep j991kgwd, tagged with its `aug` and `temperature`
# config values.
_, metrics = get_summary_metrics(
  'deeplearn/data-aug-likelihood/j991kgwd',
  config_keys=['aug', 'temperature'],
)

## Augmentation Likelihood
# Per-run summaries of sweep 3ty8kjvq, tagged with `aug_scale` and `temperature`.
_, aug_metrics = get_summary_metrics(
  'deeplearn/data-aug-likelihood/3ty8kjvq',
  config_keys=['aug_scale', 'temperature'],
)

In [None]:
# BMA test accuracy per run of the `aug` sweep.
plt_data = metrics[['run_id', 'aug', 'temperature', 'test/bma_acc']]

# `aug_metrics` was loaded without an `aug` column; tag its runs with the
# sentinel -1 so they show up as their own hue level in the plot.
aug_lik_data = aug_metrics[['run_id', 'temperature', 'test/bma_acc']].assign(aug=-1)

# ignore_index=True gives a clean 0..n-1 index directly. The previous chain of
# reset_index() calls left a stray `level_0` column in the result.
plt_data = pd.concat([plt_data, aug_lik_data], ignore_index=True)

plt_data

In [None]:
# Test accuracy against temperature (log x-axis), one line per `aug` value.
palette = sns.color_palette('tab10', 3)
g = sns.relplot(data=plt_data, kind='line',
                x='temperature', y='test/bma_acc', hue='aug',
                markers=True, dashes=False, linewidth=3, palette=palette,
                height=5, aspect=4/3)
g.set(xscale='log', xlabel=r'$T$', ylabel='Test Accuracy')

# Hide the automatic legend and draw a hand-labelled one below the axes.
g.legend.set_visible(False)
handles, _ = g.ax.get_legend_handles_labels()
# NOTE(review): these labels are positional; they presumably follow the hue
# order -1, then the two `aug` values. Confirm against the data.
mode_labels = ['Aug. Likelihood', 'No Augmentation', 'Augmentation']
g.fig.legend(handles=handles, labels=mode_labels, title='Training Mode',
             loc='lower center', bbox_to_anchor=(.5, -0.05, .1, 0.),
             ncol=3, borderaxespad=-2, frameon=True)

g.fig.tight_layout()
# g.fig.savefig('accT.png', bbox_inches='tight')

## Aug. Likelihood Sweep

### cSGLD Aug. Scale vs. Prior Scale

In [None]:
# Per-run summaries of sweep e1sdwmzy with the swept config values attached.
_, metrics = get_summary_metrics(
  'deeplearn/data-aug-likelihood/e1sdwmzy',
  config_keys=['aug_scale', 'prior_scale', 'temperature'],
)
metrics

In [None]:
# Distinct temperatures present in the loaded sweep.
metrics['temperature'].unique()

In [None]:
# Accuracy grid for the runs at temperature T: rows are aug_scale, columns are
# prior_scale, cells are test/bma_acc.
T = 1e-3
at_T = metrics.temperature == T
plt_data = (
  metrics.loc[at_T, ['run_id', 'aug_scale', 'prior_scale', 'test/bma_acc']]
  .pivot(index='aug_scale', columns='prior_scale', values='test/bma_acc')
)

fig, ax = plt.subplots(figsize=(7, 7))
g = sns.heatmap(plt_data, ax=ax,
                annot=True, fmt='.4f', annot_kws={'fontsize': 14},
                linewidths=1.,
                cbar=True, cbar_kws={'shrink': .5},
                yticklabels=plt_data.index,
                cmap=sns.color_palette('summer'))
g.set_xlabel(r'Prior Scale ($\sigma)$')
g.set_ylabel(r'Aug. Scale ($\epsilon$)')
g.set_title(rf'cSGLD at $T = {T}$ with Random Aug.')
# Re-apply the existing tick labels to change only their font size.
g.set_xticklabels(g.get_xticklabels(), fontdict={'size': 16})
g.set_yticklabels(g.get_yticklabels(), fontdict={'size': 12})
fig.tight_layout()

# fig.savefig('csgld_aug_lik_scales_cold.png', bbox_inches='tight')

### SGD Aug. Scale vs. Prior Scale

In [None]:
# Per-run summaries of sweep mre7c18c with aug_scale and prior_scale attached.
_, metrics = get_summary_metrics(
  'deeplearn/data-aug-likelihood/mre7c18c',
  config_keys=['aug_scale', 'prior_scale'],
)
metrics

In [None]:
# Accuracy grid: rows are aug_scale, columns are prior_scale, cells are
# test/best_acc.
plt_data = (
  metrics[['run_id', 'aug_scale', 'prior_scale', 'test/best_acc']]
  .pivot(index='aug_scale', columns='prior_scale', values='test/best_acc')
)

fig, ax = plt.subplots(figsize=(7, 7))
g = sns.heatmap(plt_data, ax=ax,
                annot=True, fmt='.3f', annot_kws={'fontsize': 14},
                linewidths=1.,
                cbar=True, cbar_kws={'shrink': .5},
                yticklabels=plt_data.index,
                cmap=sns.color_palette('summer'))
g.set_xlabel(r'Prior Scale ($\sigma)$')
g.set_ylabel(r'Aug. Scale ($\epsilon$)')
g.set_title('SGD with Fixed Aug.')
# Re-apply the existing tick labels to change only their font size.
g.set_xticklabels(g.get_xticklabels(), fontdict={'size': 16})
g.set_yticklabels(g.get_yticklabels(), fontdict={'size': 12})
fig.tight_layout()

# fig.savefig('sgd_aug_lik_scales.png', bbox_inches='tight')

### SGD Prior Scale with or without Random Data Augmentation

This is the usual SGD training setup (with or without random data augmentation), except that the prior is made explicit.

In [None]:
# Per-run summaries of sweep zay9l79c with aug and prior_scale attached.
_, metrics = get_summary_metrics(
  'deeplearn/data-aug-likelihood/zay9l79c',
  config_keys=['aug', 'prior_scale'],
)
metrics

In [None]:
# Accuracy grid: rows are the `aug` config value, columns are prior_scale,
# cells are test/best_acc.
plt_data = (
  metrics[['run_id', 'aug', 'prior_scale', 'test/best_acc']]
  .pivot(index='aug', columns='prior_scale', values='test/best_acc')
)

fig, ax = plt.subplots(figsize=(7, 3))
g = sns.heatmap(plt_data, ax=ax,
                annot=True, fmt='.3f', annot_kws={'fontsize': 14},
                linewidths=1.,
                cbar=True, cbar_kws={'shrink': .5},
                yticklabels=plt_data.index,
                cmap=sns.color_palette('summer'))
g.set_xlabel(r'Prior Scale ($\sigma)$')
g.set_ylabel(r'Aug. (?)')
g.set_title('SGD with Random Data Aug.')
# Re-apply the existing tick labels to change only their font size.
g.set_xticklabels(g.get_xticklabels(), fontdict={'size': 16})
g.set_yticklabels(g.get_yticklabels(), fontdict={'size': 12})
fig.tight_layout()

# fig.savefig('sgd_lik_scales.png', bbox_inches='tight')

### SGD Prior Scale with Fixed Set of Data Augmentations

In [None]:
# Per-run summaries of sweep sjcc8rjk with prior_scale attached.
_, metrics = get_summary_metrics(
  'deeplearn/data-aug-likelihood/sjcc8rjk',
  config_keys=['prior_scale'],
)
metrics

In [None]:
# Accuracy strip: columns are prior_scale, cells are test/best_acc. The summary
# value train/epoch serves as the row index; its tick labels are blanked below.
plt_data = (
  metrics[['run_id', 'train/epoch', 'prior_scale', 'test/best_acc']]
  .pivot(index='train/epoch', columns='prior_scale', values='test/best_acc')
)

fig, ax = plt.subplots(figsize=(7, 2))
g = sns.heatmap(plt_data, ax=ax,
                annot=True, fmt='.3f', annot_kws={'fontsize': 14},
                linewidths=1.,
                cbar=True, cbar_kws={'shrink': .5},
                yticklabels=plt_data.index,
                cmap=sns.color_palette('summer'))
g.set_xlabel(r'Prior Scale ($\sigma)$')
g.set_ylabel(r'')
g.set_title('SGD with Fixed Data Aug.')
g.set_xticklabels(g.get_xticklabels(), fontdict={'size': 16})
# An empty label list removes the y tick labels.
g.set_yticklabels([], fontdict={'size': 12})
fig.tight_layout()

fig.savefig('sgd_lik_scales_fixed_aug.png', bbox_inches='tight')

## cSGLD with Fixed Set of Data Augmentations

In [None]:
# Per-run summaries of sweep o7gy31ps with aug, prior_scale and temperature
# attached.
_, metrics = get_summary_metrics(
  'deeplearn/data-aug-likelihood/o7gy31ps',
  config_keys=['aug', 'prior_scale', 'temperature'],
)
metrics

In [None]:
T = .5
plt_data = metrics[metrics.temperature == T][['run_id', 'aug', 'prior_scale', 'test/bma_acc']]
plt_data = plt_data.pivot(index='aug', columns='prior_scale', values='test/bma_acc')

fig, ax = plt.subplots(figsize=(7, 3))
g = sns.heatmap(data=plt_data, ax=ax, fmt='.3f', linewidths=1.,
                annot=True, annot_kws=dict(fontsize=14),
                cbar=True, cbar_kws=dict(shrink=.5),
                yticklabels=plt_data.index,
                cmap=sns.color_palette('summer'))
g.set(xlabel=r'Prior Scale ($\sigma)$', ylabel=r'Aug. Scale ($\epsilon$)', title=rf'cSGLD at $T = {T}$ with Fixed Aug.')
g.set_xticklabels(g.get_xticklabels(), fontdict=dict(size=16))
g.set_yticklabels(g.get_yticklabels(), fontdict=dict(size=12))
fig.tight_layout()

fig.savefig('csgld_fixed_aug_cold.png', bbox_inches='tight')

### cSGLD + Fixed Aug + FRN Layer

In [None]:
# Per-run summaries of sweep 9vf05eep with aug, prior_scale and temperature
# attached.
_, metrics = get_summary_metrics(
  'deeplearn/data-aug-likelihood/9vf05eep',
  config_keys=['aug', 'prior_scale', 'temperature'],
)
metrics

In [None]:
# Accuracy grid for the runs at temperature T: rows are the `aug` config value,
# columns are prior_scale, cells are test/bma_acc.
T = 1
# Keep only the runs at temperature T so the data matches the "T = ..." title,
# as the other cSGLD heatmaps do. Previously T was used in the title only.
plt_data = metrics[metrics.temperature == T][['run_id', 'aug', 'prior_scale', 'test/bma_acc']]
plt_data = plt_data.pivot(index='aug', columns='prior_scale', values='test/bma_acc')

fig, ax = plt.subplots(figsize=(7, 3))
g = sns.heatmap(data=plt_data, ax=ax, fmt='.3f', linewidths=1.,
                annot=True, annot_kws=dict(fontsize=14),
                cbar=True, cbar_kws=dict(shrink=.5),
                yticklabels=plt_data.index,
                cmap=sns.color_palette('summer'))
# The rows come from `aug`, not `aug_scale`, so use the same y-label as the
# other `aug` heatmap rather than the "Aug. Scale" label of the aug_scale grids.
g.set(xlabel=r'Prior Scale ($\sigma)$', ylabel=r'Aug. (?)', title=rf'cSGLD at $T = {T}$ with Fixed Aug + FRN Layer')
# Re-apply the existing tick labels to change only their font size.
g.set_xticklabels(g.get_xticklabels(), fontdict=dict(size=16))
g.set_yticklabels(g.get_yticklabels(), fontdict=dict(size=12))
fig.tight_layout()

# fig.savefig('csgld_fixed_aug_frn.png', bbox_inches='tight')

In [None]:
# Per-run summaries of sweep 5lady5ar with aug_scale and prior_scale attached.
_, metrics = get_summary_metrics(
  'deeplearn/data-aug-likelihood/5lady5ar',
  config_keys=['aug_scale', 'prior_scale'],
)
metrics

In [None]:
# Accuracy grid: rows are aug_scale, columns are prior_scale, cells are
# test/bma_acc. T is used for the title only; no temperature column is loaded
# for this sweep.
T = 1
plt_data = (
  metrics[['run_id', 'aug_scale', 'prior_scale', 'test/bma_acc']]
  .pivot(index='aug_scale', columns='prior_scale', values='test/bma_acc')
)

fig, ax = plt.subplots(figsize=(7, 7))
g = sns.heatmap(plt_data, ax=ax,
                annot=True, fmt='.3f', annot_kws={'fontsize': 14},
                linewidths=1.,
                cbar=True, cbar_kws={'shrink': .5},
                yticklabels=plt_data.index,
                cmap=sns.color_palette('summer'))
g.set_xlabel(r'Prior Scale ($\sigma)$')
g.set_ylabel(r'Aug. Scale ($\epsilon$)')
g.set_title(rf'cSGLD at $T = {T}$ with Avg. Aug + FRN Layer')
# Re-apply the existing tick labels to change only their font size.
g.set_xticklabels(g.get_xticklabels(), fontdict={'size': 16})
g.set_yticklabels(g.get_yticklabels(), fontdict={'size': 12})
fig.tight_layout()

# fig.savefig('csgld_avg_aug_frn.png', bbox_inches='tight')

## Noisy Dirichlet Likelihoods

In [None]:
## Noisy Dirichlet
_, metrics = get_summary_metrics(
  'deeplearn/data-aug-likelihood/wou01xes',
  config_keys=['augment', 'temperature'],
)

## Noisy Dirichlet + KL Consistency
_, metrics_consistency = get_summary_metrics(
  'deeplearn/data-aug-likelihood/2irgmpw7',
  config_keys=['temperature'],
)
# This sweep is loaded without an `augment` value; -1 marks its runs as a
# separate hue level once both sweeps share one table.
metrics_consistency = metrics_consistency.assign(augment=-1)

metrics = pd.concat([metrics, metrics_consistency]).reset_index()
metrics

In [None]:
# BMA accuracy against temperature (log x-axis), one line per `augment` value.
palette = sns.color_palette('tab10', 3)
g = sns.relplot(data=metrics, kind='line',
                x='temperature', y='csgld/test/bma_acc', hue='augment',
                markers=True, dashes=False, linewidth=3, palette=palette,
                height=5, aspect=4/3)
g.set(xscale='log', xlabel=r'$T$', ylabel='BMA Accuracy')
g.fig.savefig('save.png', bbox_inches='tight')

## Noisy Labels

In [None]:
## Softmax
_, softmax_metrics = get_summary_metrics(
  'deeplearn/data-aug-likelihood/4ypwsj0k',
  config_keys=['label_noise', 'temperature'],
)
# Only the Dirichlet sweep is loaded with a `noise` value, so the softmax rows
# get NaN in that column.
softmax_metrics = softmax_metrics.assign(kind='softmax', noise=float('NaN'))

## Noisy Dirichlet
_, dirichlet_metrics = get_summary_metrics(
  'deeplearn/data-aug-likelihood/zboj4e6h',
  config_keys=['label_noise', 'noise', 'temperature'],
)
dirichlet_metrics = dirichlet_metrics.assign(kind='dirichlet')

# Only the Dirichlet runs at temperature 1 enter the combined table.
dirichlet_at_t1 = dirichlet_metrics[dirichlet_metrics.temperature == 1]
metrics = pd.concat([softmax_metrics, dirichlet_at_t1]).reset_index()

In [None]:
# metrics = pd.read_csv('label_noise.csv')
# Best run (highest BMA accuracy) for every (kind, label_noise) pair.
# idxmax() returns index *labels*, so select with .loc. The earlier .iloc was
# only correct while the index happened to be a fresh 0..n-1 range.
best_idx = metrics.groupby(by=['kind', 'label_noise'])['csgld/test/bma_acc'].idxmax()
# drop=True discards the old index instead of adding it as one more column.
best_metrics = metrics.loc[best_idx.values].reset_index(drop=True)

In [None]:
# Thick lines: the best run per (kind, label_noise), coloured by likelihood kind.
kind_colors = sns.color_palette('tab10', 2)
g = sns.relplot(data=best_metrics, kind='line',
                x='label_noise', y='csgld/test/bma_acc', hue='kind',
                marker='o', markersize=10, linewidth=3, palette=kind_colors,
                height=5, aspect=4/3)

# Thin, faint lines: all runs of each kind, with the line style keyed on the
# column named next to it (`noise` for dirichlet, `temperature` for softmax).
overlays = (
  ('dirichlet', 'noise', kind_colors[0]),
  ('softmax', 'temperature', kind_colors[1]),
)
for kind_name, style_column, line_color in overlays:
  sns.lineplot(ax=g.ax, data=metrics[metrics.kind == kind_name],
               x='label_noise', y='csgld/test/bma_acc',
               style=style_column, legend=False,
               linewidth=1, alpha=.5, color=line_color)

g.set(xlabel='Label Noise', ylabel='Test BMA Accuracy')

g.fig.savefig('label_noise.pdf', bbox_inches='tight')