In [None]:
import dill
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import ticker
from palettable.cartocolors.diverging import Temps_5

sns.set_theme(font_scale=2.5, style='whitegrid')

folder = '../experiments/transfer_learning/shuffled'

# Load the pre-computed results DataFrame. dill.load, like pickle.load, can
# execute arbitrary code embedded in the file: only point this at result files
# produced by our own experiment runs.
with open(f'{folder}/results.df', 'rb') as f:
    df = dill.load(f)

# Tag every row with its model family, read off the experiment name.
# na=False makes rows with a missing 'expt' simply not match, instead of
# yielding a NaN-containing mask that .loc refuses to index with.
df.loc[df['expt'].str.contains('cnn', na=False), 'model'] = 'CNN'
df.loc[df['expt'].str.contains('mlp', na=False), 'model'] = 'MLP'

# The standard MLP runs (expt == 'mlp') are treated as equivalent to MLP runs
# on shuffled pixels (presumably because a fully-connected net has no spatial
# prior, so a fixed pixel permutation does not change what it can learn), so
# they are re-used under expt='mlp_shuffled_pixels' instead of separate runs.
mlp_shuffled_pixels = df[df['expt'] == 'mlp'].copy()
mlp_shuffled_pixels['expt'] = 'mlp_shuffled_pixels'
# ignore_index=True yields a fresh 0..n-1 index directly; unlike
# reset_index().drop(columns=['index']) it cannot fail on a named index
# or throw away a genuine 'index' column.
df = pd.concat([mlp_shuffled_pixels, df], ignore_index=True)

In [None]:
# Colours for the two model families, in the plotting cell's hue order
# ['CNN', 'MLP']: the 2nd and the last colour of the 5-colour CARTOColors
# "Temps" diverging scheme (from palettable).
temps_colors = Temps_5.mpl_colors
palette = [temps_colors[1], temps_colors[-1]]

In [None]:
# Experiment variants to compare, one panel each: (suffix that follows the
# 3-character model prefix 'cnn'/'mlp' in the 'expt' column, panel title).
PANELS = [
    ('', 'Standard'),
    ('_shuffled_pixels', 'Shuffled Pixels'),
    ('_shuffled_labels', 'Shuffled Labels'),
]
HUE_ORDER = ['CNN', 'MLP']  # must line up with the two colours in `palette`
METRIC = 'raw_err_bound_100'
MARKER_SIZE = 300

fig, axes = plt.subplots(1, len(PANELS), figsize=(17, 5), sharey=True)

for i, (ax, (expt_name, title)) in enumerate(zip(axes, PANELS)):
    is_first = i == 0
    is_last = i == len(PANELS) - 1

    # Rows of this variant for both models: drop the 3-char model prefix
    # and compare what remains against the panel's suffix.
    dfa = df.loc[df['expt'].str[3:] == expt_name]

    sns.lineplot(data=dfa, ax=ax, x='d', y=METRIC,
                 hue='model', legend=False, alpha=.9, lw=7,
                 palette=palette, hue_order=HUE_ORDER)
    # Only the last panel carries a legend; the hue mapping is identical
    # on every panel.
    sns.scatterplot(data=dfa, ax=ax, x='d', y=METRIC, hue='model',
                    alpha=.9, marker='o', s=MARKER_SIZE, legend=is_last,
                    palette=palette, hue_order=HUE_ORDER)

    if is_last:
        # Rebuild the legend so its markers match the plotted ones.
        # NOTE(review): relies on seaborn adding its legend artists to the
        # axes as scatter collections (seaborn 0.11/0.12 behaviour); newer
        # seaborn may return no/other handles here — confirm when upgrading.
        handles, labels = ax.get_legend_handles_labels()
        for h in handles:
            h.set_sizes([MARKER_SIZE])
            h.set_alpha(.9)
        ax.legend(handles=handles, labels=labels, title='', loc='lower right')

    # Scientific-notation x ticks (small mantissas with a x10^n offset).
    # A fresh formatter per axis: formatter instances should not be shared.
    formatter = ticker.ScalarFormatter(useMathText=True)
    formatter.set_scientific(True)
    formatter.set_powerlimits((-1, 1))
    ax.xaxis.set_major_formatter(formatter)

    ax.set(xlabel='', ylabel='', xticks=[0, 1e4, 2e4, 3e4])
    ax.set_title(title, pad=20)
    if is_first:
        # The y axis is shared, so label and tick it once, on the leftmost panel.
        ax.set(ylabel=r'Err. Bound ($\%$)', yticks=[60, 70, 80, 90, 100])

# Lay out only after everything is drawn so titles, tick labels and the x
# offset text are taken into account. No h_pad: it only sets spacing between
# rows, and this figure has a single row.
fig.tight_layout()

# Shared x label for all panels, placed just below the figure;
# bbox_inches='tight' expands the saved area to include it.
fig.text(0.5, -0.05, r'Subspace Dimension $d$', va='center', ha='center')
# Save before showing: depending on the backend, showing can close the figure,
# and Figure.show() warns on non-GUI backends such as Jupyter's inline one.
fig.savefig('shuffled.pdf', bbox_inches='tight')
plt.show()