In [None]:
from matplotlib import pyplot as plt
from matplotlib import rc
from matplotlib.colors import LogNorm
import numpy as np
import pandas as pd
import seaborn as sns

import distutils.spawn
import glob
import os
import pickle
import shutil

In [None]:
# Global figure styling: larger fonts for readability, and LaTeX text rendering
# only when a `latex` executable is available on PATH (usetex needs a working
# LaTeX install).
sns.set(font_scale=1.5)

# shutil.which replaces distutils.spawn.find_executable: distutils is
# deprecated by PEP 632 and removed in Python 3.12.
if shutil.which('latex'):
    rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
    rc('text', usetex=True)

In [None]:
# Experiment sweep to analyse; each run lives in its own folder under results/<exp_name>/.
exp_name: str = 'perf-weights-regularity-2'

In [None]:
# Collect one row per run directory under results/<exp_name>/: the swept
# hyper-parameters from config.pkl and the test accuracy from metrics.pkl.
# NOTE: pickle.load can execute arbitrary code -- only point this at results
# produced by a trusted training run.
results = {'accuracy': [], 'regularity': [], 'lr': [], 'scaling': []}
# glob order is filesystem-dependent; sorting keeps the row order reproducible.
for directory in sorted(glob.glob(os.path.join('results', exp_name, '*'))):
    if not os.path.isdir(directory):
        continue  # ignore stray files sitting next to the run folders
    with open(os.path.join(directory, 'config.pkl'), 'rb') as f:
        config = pickle.load(f)
    with open(os.path.join(directory, 'metrics.pkl'), 'rb') as f:
        metrics = pickle.load(f)
    # Append only once both files have loaded, so a failure on one run can
    # never leave the four lists with different lengths.
    model_config = config['model-config']
    results['regularity'].append(model_config['regularity']['value'])
    results['lr'].append(model_config['lr'])
    results['scaling'].append(model_config['scaling_beta'])
    results['accuracy'].append(metrics['test_accuracy'])

In [None]:
# One row per run, one column per collected field (accuracy, regularity, lr, scaling).
df = pd.DataFrame(results)

In [None]:
# Sanity-check the parsed run table.
df.head(n=5)

In [None]:
# Best test accuracy for every (scaling, regularity) pair, maximised over all
# other swept settings (including learning rate). The string 'max' avoids the
# pandas >= 2.1 FutureWarning raised for aggfunc=np.max.
df2 = pd.pivot_table(df, index='scaling', columns='regularity', values='accuracy', aggfunc='max')

In [None]:
# Round both axes' numeric labels to two decimals so heatmap tick labels stay short.
df2.index, df2.columns = np.round(df2.index, 2), np.round(df2.columns, 2)

In [None]:
# Heatmap of best accuracy over all learning rates. Rows are reversed so the
# scaling values run bottom-to-top; the colour scale spans the full [0, 1] range.
heatmap_kwargs = dict(vmin=0, vmax=1, center=0.5, xticklabels=1, yticklabels=1, square=True)
sns.heatmap(df2.iloc[::-1], **heatmap_kwargs)
#plt.savefig('figures/perf-mnist-reg-scaling.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Same (scaling, regularity) grid as df2, but restricted to runs trained with a
# single learning rate instead of taking the best over all learning rates.
fixed_lr = 0.0001
# lr values come from pickled configs and need not be bit-identical to the
# literal (e.g. 0.1**4 != 1e-4), so match with a tight relative tolerance
# instead of ==, which could silently select no rows.
lr_mask = np.isclose(df['lr'], fixed_lr, rtol=1e-9, atol=0.0)
assert lr_mask.any(), f'no runs found with lr={fixed_lr}'
df3 = pd.pivot_table(df.loc[lr_mask], index='scaling', columns='regularity', values='accuracy', aggfunc='max')

In [None]:
# Inspect the fixed-learning-rate grid before plotting.
df3.head(n=5)

In [None]:
# Round df3's OWN axis labels to two decimals. df3 is pivoted from an
# lr-filtered subset, so its grid need not match df2's; copying df2's labels
# raised a length mismatch or silently mislabelled the heatmap ticks.
df3.index = np.round(df3.index, 2)
df3.columns = np.round(df3.columns, 2)

In [None]:
# Heatmap for the fixed learning rate. Rows are reversed so the scaling values
# run bottom-to-top; the colour scale is narrowed to [0.8, 1], centred at 0.9.
heatmap_kwargs = dict(vmin=0.8, vmax=1, center=0.9, xticklabels=1, yticklabels=1, square=True)
sns.heatmap(df3.iloc[::-1], **heatmap_kwargs)
#plt.savefig('figures/perf-mnist-reg-scaling-fixed-lr.pdf', bbox_inches='tight')
plt.show()