In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from palettable.cmocean.sequential import Thermal_8
from matplotlib import ticker

# Global seaborn theme for every figure in this notebook: white grid background
# with fonts scaled up by 2.5x.
sns.set(font_scale=2.5, style='whitegrid')
# Seaborn palette built from the colours of palettable's `Thermal_8` colormap;
# the plotting cells below pick their line colours from it by index.
palette = sns.color_palette(Thermal_8.mpl_colors)
# Bare expression so the palette is shown as this cell's output.
palette

In [None]:
import wandb

def get_summary_metrics(sweep_id, config_keys=None, filter_func=None):
  """Collect the summary metrics of every finished run of a W&B sweep.

  Args:
    sweep_id: Sweep path passed verbatim to `wandb.Api().sweep`.
    config_keys: Optional iterable of run-config keys to copy into each row.
    filter_func: Optional callable taking the extracted config dict of a run;
      the run is kept only if it returns a truthy value. Non-callables are
      ignored.

  Returns:
    A `(sweep, frame)` tuple: the sweep object returned by the API and a
    `pandas.DataFrame` with one row per kept run, holding `run_id`, the
    requested config values and every summary entry of the run.

  Raises:
    KeyError: If a *finished* run lacks one of `config_keys`.
  """
  api = wandb.Api(timeout=60)
  sweep = api.sweep(sweep_id)
  # Materialise once: the keys are loop-invariant, and a one-shot iterable
  # would otherwise be exhausted after the first run.
  config_keys = list(config_keys or [])

  data = []
  for run in tqdm(sweep.runs, desc='Runs', leave=False):
    # Check the state *before* touching the config, so a run that is skipped
    # anyway can never abort the whole collection with a KeyError.
    if run.state != 'finished':
      continue
    cfg = {k: run.config[k] for k in config_keys}
    if callable(filter_func) and not filter_func(cfg):
      continue
    # Merge without keyword expansion: `dict(run_id=..., **cfg, **summary)`
    # raises TypeError on a duplicated key. On a clash the run id and the
    # explicitly requested config values take precedence over the summary.
    record = {'run_id': run.id, **cfg}
    summary = run.summary
    for key in summary.keys():
      record.setdefault(key, summary[key])
    data.append(record)

  return sweep, pd.DataFrame(data)

### CIFAR-10

In [None]:
# _, softmax = get_summary_metrics('deeplearn/data-aug-likelihood/ekte3la8',
#                                     config_keys=['likelihood', 'seed', 'temperature', 'label_noise'])
# _, dirichlet1 = get_summary_metrics('deeplearn/data-aug-likelihood/hzn8l7c4',
#                                   config_keys=['likelihood', 'seed', 'noise', 'label_noise'])
# dirichlet1['seed'] = -1
# _, dirichlet23 = get_summary_metrics('deeplearn/data-aug-likelihood/nkoxl3ua',
#                                      config_keys=['likelihood', 'seed', 'noise', 'label_noise'])
# results = pd.concat([softmax, dirichlet1, dirichlet23]).reset_index().drop(columns=['index'])
# results.to_csv('results/c10_label_noise.csv', index=False)

### Tiny Imagenet

In [None]:
# _, softmax = get_summary_metrics('deeplearn/data-aug-likelihood/8pcuvf4s',
#                                     config_keys=['likelihood', 'seed', 'temperature', 'label_noise'])
# _, dirichlet = get_summary_metrics('deeplearn/data-aug-likelihood/z3z7nwgl',
#                                   config_keys=['likelihood', 'seed', 'noise', 'label_noise'])
# results = pd.concat([softmax, dirichlet]).reset_index().drop(columns=['index'])
# results.to_csv('results/ti_label_noise.csv', index=False)

In [None]:
dataset = 'cifar10'
# dataset = 'tiny-imagenet'

# Cached sweep summaries, one CSV per supported dataset.
_results_csv = {
    'cifar10': 'results/c10_label_noise.csv',
    'tiny-imagenet': 'results/ti_label_noise.csv',
}
if dataset not in _results_csv:
    raise NotImplementedError
results = pd.read_csv(_results_csv[dataset])

In [None]:
## Best noisy-Dirichlet setting per label-noise level, ranked by the
## seed-averaged BMA test accuracy.
_dirichlet_runs = results[results.likelihood == 'dirichlet']

# Mean accuracy / NLL of every (label_noise, noise) combination.
_dirchlet_mean_acc = (
    _dirichlet_runs
    .groupby(['label_noise', 'noise'])[['csgld/test/bma_acc', 'csgld/test/bma_ce_nll']]
    .mean()
    .reset_index()
)

# Row of the highest mean accuracy within each label-noise level.
_best_dirichlet_rows = _dirchlet_mean_acc.groupby(['label_noise'])['csgld/test/bma_acc'].idxmax()
_best_dirichlet_mean = _dirchlet_mean_acc.iloc[_best_dirichlet_rows]

# Keep only the individual runs whose (label_noise, noise) pair is a winner.
_best_dirichlet_pairs = [tuple(v) for v in _best_dirichlet_mean[['label_noise', 'noise']].values]
_filter = _dirichlet_runs[['label_noise', 'noise']].apply(tuple, axis=1).isin(_best_dirichlet_pairs)

best_dirichlet = _dirichlet_runs[_filter]

In [None]:
## Best tempered-softmax setting per label-noise level, ranked by the
## seed-averaged BMA test accuracy.
_softmax_runs = results[results.likelihood == 'softmax']

# Mean accuracy / NLL of every (label_noise, temperature) combination.
_softmax_mean_acc = (
    _softmax_runs
    .groupby(['label_noise', 'temperature'])[['csgld/test/bma_acc', 'csgld/test/bma_ce_nll']]
    .mean()
    .reset_index()
)

# Row of the highest mean accuracy within each label-noise level.
_best_softmax_rows = _softmax_mean_acc.groupby(['label_noise'])['csgld/test/bma_acc'].idxmax()
_best_softmax_mean = _softmax_mean_acc.iloc[_best_softmax_rows]

# Keep only the individual runs whose (label_noise, temperature) pair is a winner.
_best_softmax_pairs = [tuple(v) for v in _best_softmax_mean[['label_noise', 'temperature']].values]
_filter = _softmax_runs[['label_noise', 'temperature']].apply(tuple, axis=1).isin(_best_softmax_pairs)

best_softmax = _softmax_runs[_filter]

## Test Accuracy

In [None]:
# fig, ax = plt.subplots(figsize=(6.7,6.5))

# sns.lineplot(ax=ax, data=results[(results.likelihood == 'softmax') & (results.temperature == 1)],
#              x='label_noise', y='csgld/test/bma_acc',
#              marker='o', markersize=18, linewidth=7, ci='sd', label=r'Softmax ($T = 1$)',
#              color=palette[-4])

# sns.lineplot(ax=ax, data=best_softmax, x='label_noise', y='csgld/test/bma_acc',
#              marker='o', markersize=18, linewidth=7, ci='sd', label='Tempered Softmax',
#              color=palette[2])
# # sns.lineplot(ax=ax, data=_softmax_mean_acc, x='label_noise', y='csgld/test/bma_acc', style='temperature',
# #              legend=False, alpha=.3, color=palette[2])
# # for _, row in _best_softmax_mean.iterrows():
# #     ax.text(row['label_noise'] - .05, row['csgld/test/bma_acc'] - .05, rf"T={row['temperature']:0.0e}",
# #             fontsize=14, bbox=dict(facecolor=palette[2], alpha=.4))

# sns.lineplot(ax=ax, data=best_dirichlet, x='label_noise', y='csgld/test/bma_acc',
#              marker='o', markersize=18, linewidth=7, ci='sd', label='Noisy Dirichlet',
#              color=palette[-3])
# # sns.lineplot(ax=ax, data=_dirchlet_mean_acc, x='label_noise', y='csgld/test/bma_acc', style='noise',
# #              legend=False, alpha=.3, color=palette[-3])
# # for _, row in _best_dirichlet_mean.iterrows():
# #     ax.text(row['label_noise'] - .05, row['csgld/test/bma_acc'] + .03,
# #             rf"$\alpha_\epsilon$={row['noise']:0.0e}", fontsize=14, bbox=dict(facecolor=palette[-3], alpha=.4))

# ax.legend(fontsize=22.5)
# ax.legend().remove()

# ax.set(xlabel='Label Noise', ylabel='BMA Test Accuracy')
# fig.tight_layout()
# fig.savefig('label_noise_acc.pdf', bbox_inches='tight')

## Test NLL

In [None]:
# fig, ax = plt.subplots(figsize=(6.7,6.5))

# sns.lineplot(ax=ax, data=results[(results.likelihood == 'softmax') & (results.temperature == 1)],
#              x='label_noise', y='csgld/test/bma_ce_nll',
#              marker='o', markersize=18, linewidth=7, ci='sd', label=r'Softmax ($T = 1$)',
#              color=palette[-4])

# sns.lineplot(ax=ax, data=best_softmax, x='label_noise', y='csgld/test/bma_ce_nll',
#              marker='o', markersize=18, linewidth=7, ci='sd', label='Tempered Softmax',
#              color=palette[2])
# # sns.lineplot(ax=ax, data=_softmax_mean_acc, x='label_noise', y='csgld/test/bma_ce_nll', style='temperature',
# #              legend=False, alpha=.3, color=palette[1])

# sns.lineplot(ax=ax, data=best_dirichlet, x='label_noise', y='csgld/test/bma_ce_nll',
#              marker='o', markersize=18, linewidth=7, ci='sd', label='Noisy Dirichlet',
#              color=palette[-3])
# # sns.lineplot(ax=ax, data=_dirchlet_mean_acc, x='label_noise', y='csgld/test/bma_ce_nll', style='noise',
# #              legend=False, alpha=.3, color=palette[-3])

# # ax.legend().remove()
# ax.legend(fontsize=22)

# ax.set(xlabel='Label Noise', ylabel='BMA Test NLL', xticks=[0., .25, .5, .75])
# formatter = ticker.ScalarFormatter(useMathText=True)
# formatter.set_scientific(True)
# formatter.set_powerlimits((-1,1))
# ax.yaxis.set_major_formatter(formatter)

# fig.tight_layout()
# # fig.savefig('label_noise_nll.pdf', bbox_inches='tight')

## Combined

In [None]:
def _plot_curves(ax, curves, metric):
    """Draw one seaborn line (mean with an 'sd' band) per `(runs, label, color)` entry.

    `metric` is the column of each `runs` frame plotted against `label_noise`.
    """
    for runs, label, color in curves:
        # NOTE(review): `ci=` is deprecated in seaborn >= 0.12 in favour of
        # `errorbar='sd'`; kept so the cell still runs on older installations.
        sns.lineplot(ax=ax, data=runs, x='label_noise', y=metric,
                     marker='o', markersize=18, linewidth=7, ci='sd', label=label,
                     color=color)


def _annotate_best(ax, best_means, hparam, prefix, text_xy, color, halign):
    """Point an arrow-label with the best `hparam` value at every other row of `best_means`.

    `text_xy` is an `(x_all, y_all)` pair of label positions in data coordinates,
    consumed in row order; `prefix` is the text placed before the formatted value.
    """
    x_all, y_all = text_xy
    for _i, ((_, row), _x, _y) in enumerate(zip(best_means.iterrows(), x_all, y_all)):
        # Label only rows 0, 2, 4, ... to keep the panel readable.
        if _i % 2:
            continue
        ax.annotate(prefix + f"{row[hparam]:.0e}", xy=(row['label_noise'], row['csgld/test/bma_acc']), xycoords='data',
                    xytext=(_x, _y), textcoords='data', fontsize=23.5,
                    arrowprops=dict(facecolor=color, shrink=0.1, alpha=.6),
                    horizontalalignment=halign, verticalalignment='center')


## Annotation locations (data coordinates), hand-tuned per dataset.
## Looking them up by `dataset` fails loudly for an unsupported dataset instead
## of silently reusing stale coordinates from an earlier execution.
_softmax_text_xy = {
    'cifar10': ([.15, .5, .5, .5, .6], [.75, .5, .6, .5, .5]),
    'tiny-imagenet': ([.15, .5, .25, .5, .6], [.45, .5, .3, .5, .1]),
}
_dirichlet_text_xy = {
    'cifar10': ([.2, .5, .6, .5, .7], [.95, .5, .9, .5, .8]),
    'tiny-imagenet': ([.2, .3, .4, .5, .6], [.65, .6, .55, .5, .4]),
}

## The three curves shown in both panels: (runs, legend label, colour).
_curves = [
    (results[(results.likelihood == 'softmax') & (results.temperature == 1)],
     r'Softmax ($T = 1$)', palette[-4]),
    (best_softmax, 'Tempered Softmax', palette[2]),
    (best_dirichlet, 'Noisy Dirichlet', palette[-3]),
]

fig, axes = plt.subplots(ncols=2, figsize=(15,7.5), sharex=True)

## Left panel: BMA test accuracy, annotated with the selected hyperparameters.
_plot_curves(axes[0], _curves, 'csgld/test/bma_acc')
_annotate_best(axes[0], _best_softmax_mean, 'temperature', r"$T^{*}$=",
               _softmax_text_xy[dataset], palette[2], 'right')
_annotate_best(axes[0], _best_dirichlet_mean, 'noise', r"$\alpha_\epsilon^{*}$=",
               _dirichlet_text_xy[dataset], palette[-3], 'left')

# The right-hand panel carries the only legend of the figure.
axes[0].legend().remove()

axes[0].set(xlabel='Label Noise', ylabel='BMA Test Accuracy', xticks=[0.,.2,.4,.6,.8])

###############################

## Right panel: BMA test NLL.
_plot_curves(axes[1], _curves, 'csgld/test/bma_ce_nll')

axes[1].legend(fontsize=26, loc='upper left')

axes[1].set(xlabel='Label Noise', ylabel='BMA Test NLL')
# Scientific notation with a shared power-of-ten offset on the NLL axis.
formatter = ticker.ScalarFormatter(useMathText=True)
formatter.set_scientific(True)
formatter.set_powerlimits((-1,1))
axes[1].yaxis.set_major_formatter(formatter)

fig.tight_layout()
fig.savefig('label_noise_acc_nll.pdf', bbox_inches='tight')