In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from palettable.cmocean.sequential import Thermal_8

# Paper-style seaborn theme: white grid background, all fonts scaled up 2x.
sns.set(style='whitegrid', font_scale=2)

# Wrap palettable's cmocean Thermal_8 colour map as a seaborn palette.
# The bare name on the last line renders the colour swatches in the notebook.
thermal_colors = Thermal_8.mpl_colors
palette = sns.color_palette(thermal_colors)
palette

In [None]:
# Load the CIFAR-10 runs for both likelihoods (softmax and noisy Dirichlet),
# each with and without data augmentation, into one frame with a fresh
# 0..n-1 index. `ignore_index=True` replaces `.reset_index().drop(columns=['index'])`,
# which would silently drop a real data column if a CSV already had one named
# 'index' (reset_index then names its new column 'level_0').
results = pd.concat(
    [
        pd.read_csv(path)
        for path in (
            'results/cifar10_softmax.csv',
            'results/cifar10_noisy_dirichlet.csv',
            'results/cifar10_softmax_noaug.csv',
            'results/cifar10_noisy_dirichlet_noaug.csv',
        )
    ],
    ignore_index=True,
)

# Keep only the runs trained without label noise.
results = results[results.label_noise == 0]

## Train NLL

### Softmax

In [None]:
# BMA train NLL of the softmax-likelihood runs as a function of temperature T,
# one line per value of `augment`; ci='sd' draws standard-deviation bands.
softmax_runs = results[results['likelihood'] == 'softmax']
aug_palette = sns.color_palette([palette[2], palette[-3]])

fig, ax = plt.subplots(figsize=(7, 5.5))
sns.lineplot(
    data=softmax_runs, ax=ax,
    x='temperature', y='train/ce_nll',
    hue='augment', palette=aug_palette, ci='sd',
    marker='o', markersize=14, linewidth=6,
)

# Log-log axes; set after plotting so the aggregation is unaffected.
ax.set(xscale='log', yscale='log', xlabel=r'$T$', ylabel='BMA Train NLL')

# Thicken the legend lines to match the plotted ones and give readable names.
# NOTE(review): the relabel is positional and assumes the no-augmentation
# level sorts first (e.g. False < True) — confirm against the `augment` dtype.
handles = ax.get_legend_handles_labels()[0]
for handle in handles:
    handle.set_linewidth(6)
labels = ['No Aug.', 'Aug.']
ax.legend(handles, labels, loc='lower right')
# To hide the legend instead: ax.legend().remove()

fig.tight_layout()
# fig.savefig('aug_noaug_train_lik_softmax.pdf', bbox_inches='tight')

### Dirichlet

In [None]:
# BMA train NLL of the noisy-Dirichlet runs as a function of the Dirichlet
# noise level, one line per value of `augment`; ci='sd' draws
# standard-deviation bands.
fig, ax = plt.subplots(figsize=(7,5))

sns.lineplot(ax=ax, data=results[results.likelihood == 'dirichlet'],
             x='noise', y='train/ce_nll', hue='augment', ci='sd',
             palette=sns.color_palette([palette[2], palette[-3]]),
             marker='o', linewidth=6, markersize=14)

# The x variable here is the Dirichlet noise, not a temperature, so label it
# $\alpha_\epsilon$ (as in the combined figure) rather than $T$.
ax.set(xscale='log', yscale='log', xlabel=r'$\alpha_\epsilon$', ylabel='BMA Train NLL')

# Thicken the legend lines to match the plotted ones and give readable names.
# NOTE(review): the relabel is positional and assumes the no-augmentation
# level sorts first (e.g. False < True) — confirm against the `augment` dtype.
handles, labels = ax.get_legend_handles_labels()
for h in handles:
    h.set(linewidth=6)
labels = ['No Aug.', 'Aug.']
ax.legend(handles=handles, labels=labels)
# To hide the legend instead: ax.legend().remove()

fig.tight_layout()
# fig.savefig('aug_noaug_train_lik_dirichlet.pdf', bbox_inches='tight')

## Combined: softmax and Dirichlet side by side

In [None]:
# Two panels of BMA train NLL, with and without augmentation:
# left = softmax runs vs. temperature T, right = noisy-Dirichlet runs vs. noise.
# Only the right panel carries a legend.
aug_palette = sns.color_palette([palette[2], palette[-3]])
line_kwargs = dict(y='train/ce_nll', hue='augment', ci='sd', palette=aug_palette,
                   marker='o', linewidth=6, markersize=14)

# NOTE(review): sharex=True forces the T axis and the alpha_eps axis to share
# limits (and scale) although they are different quantities — confirm intended.
fig, axes = plt.subplots(ncols=2, sharex=True, figsize=(13, 5))
softmax_ax, dirichlet_ax = axes

# Draw both panels before touching any axis scale: with sharex=True, setting
# the left panel's xscale would also apply to the right one before it is drawn.
sns.lineplot(ax=softmax_ax, data=results[results['likelihood'] == 'softmax'],
             x='temperature', legend=False, **line_kwargs)
sns.lineplot(ax=dirichlet_ax, data=results[results['likelihood'] == 'dirichlet'],
             x='noise', **line_kwargs)

softmax_ax.set(xscale='log', yscale='log', xlabel=r'$T$', ylabel='BMA Train NLL')
dirichlet_ax.set(xscale='log', yscale='log', xlabel=r'$\alpha_\epsilon$', ylabel='')

# Thicken the legend lines to match the plotted ones and give readable names.
# NOTE(review): the relabel is positional and assumes the no-augmentation
# level sorts first (e.g. False < True) — confirm against the `augment` dtype.
handles = dirichlet_ax.get_legend_handles_labels()[0]
for handle in handles:
    handle.set_linewidth(6)
labels = ['No Aug.', 'Aug.']
dirichlet_ax.legend(handles, labels, loc='lower right')

# Alternative layout: one shared legend centred below both panels.
# sns.move_legend(axes[1], bbox_to_anchor=(-1, -.25, 1, 0),
#                 loc='lower center', ncol=2, borderaxespad=0., frameon=True, title='')

fig.tight_layout()
# fig.savefig('aug_noaug_train_lik.pdf', bbox_inches='tight')