In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly_express as px
import numpy as np

# Get everything in order

In [None]:
# Dataset selector; not referenced by any of the cells visible in this notebook.
dataset = 'mnist'

# Integer class id -> readable label, used when relabelling per-label results.
cifar10_classes = dict(enumerate([
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]))

mnist_classes = {digit: f'Digit={digit}' for digit in range(10)}

# Raw model identifiers as they appear in the result CSVs. The three lists and
# `models_nice_name` are aligned by position: entry k of each list describes
# the same model.
models_mnist = [
    'xent_mnist',
    'ntxent_mnist',
    'tripent_mnist',
    'trip_mnist',
    'trip_sup_mnist',
    'random_init',
]

models_cifar = [
    'xent_cifar10',
    'ntxent_cifar10',
    'tripent_cifar10',
    'trip_cifar10',
    'trip_sup_cifar10',
    'random_init',
]

models_search = [
    'cross-entropy',
    'nt-xent',
    'triplet-entropy',
    'triplet',
    'triplet-supervised',
    'random',
]

# Display names used in every legend and as keys of the colour palette.
models_nice_name = [
    'Cross-Entropy',
    'NT-XENT',
    'Triplet-Entropy',
    'Triplet-Loss',
    'Triplet-Supervised',
    'Random',
]

# One fixed colour per display name so a model looks the same in every figure.
cols = ['#3274a1', '#e1812c', '#3a923a', '#c03d3e', '#9372b2', '#000075']
colors = dict(zip(models_nice_name, cols))

# Raw identifier -> display name, one lookup per naming scheme.
nice_names_mnist = dict(zip(models_mnist, models_nice_name))
nice_names_cifar = dict(zip(models_cifar, models_nice_name))
nice_names_search = dict(zip(models_search, models_nice_name))

def RMQM(data):
    """Compute the log-scaled RMQM score from the four summary statistics.

    The raw score is the sum of the two distance-increase values plus the
    reciprocals of the two spike values; the function returns
    ``log(1 + raw_score)``.

    Parameters
    ----------
    data
        Object indexable by the keys ``'average_distance_increase'``,
        ``'average_distance_increase_previous'``, ``'average_spike'`` and
        ``'average_spike_previous'``. The notebook passes DataFrame rows via
        ``DataFrame.apply(RMQM, axis=1)``.

    Returns
    -------
    The natural logarithm of one plus the raw score, in whatever numeric
    container type the arithmetic on ``data`` produces.
    """
    increase_terms = (
        data['average_distance_increase']
        + data['average_distance_increase_previous']
    )
    inverse_spike_terms = (
        1 / data['average_spike']
        + 1 / data['average_spike_previous']
    )
    raw_score = increase_terms + inverse_spike_terms
    return np.log(1 + raw_score)

In [None]:
def _load_mnist_results(csv_path, name_lookup):
    """Read one results CSV, relabel its models and drop the random baseline.

    Parameters
    ----------
    csv_path : str
        Path of the CSV file; it must contain a ``model`` column.
    name_lookup : dict
        Maps the raw ``model`` identifiers to display names. An identifier
        that is missing from the mapping raises ``KeyError``.

    Returns
    -------
    pandas.DataFrame
        The rows whose display name is not ``'Random'``, as an independent
        copy. Later cells add an ``RMQM`` column and overwrite ``label`` on
        these frames; without the copy pandas emits SettingWithCopyWarning
        because the frame is the result of a boolean filter.
    """
    frame = pd.read_csv(csv_path)
    frame['model'] = frame.model.apply(lambda raw_name: name_lookup[raw_name])
    return frame[frame.model != 'Random'].copy()


# Distance results: "noise" files feed the *_eps_* frames, "pgd" files the
# *_pgd_* frames; "high_level" files feed *_hl_*, "plot" files feed *_ll_*.
distances_hl_eps_mnist_per_label = _load_mnist_results(
    '../results/distances/results_mnist_noise_p-2_high_level_label_data.csv', nice_names_mnist)
distances_ll_eps_mnist_per_label = _load_mnist_results(
    '../results/distances/results_mnist_noise_p-2_plot_label_data.csv', nice_names_mnist)
distances_hl_pgd_mnist_per_label = _load_mnist_results(
    '../results/distances/results_mnist_pgd_p-2_high_level_label_data.csv', nice_names_mnist)
distances_ll_pgd_mnist_per_label = _load_mnist_results(
    '../results/distances/results_mnist_pgd_p-2_plot_label_data.csv', nice_names_mnist)
distances_hl_eps_mnist = _load_mnist_results(
    '../results/distances/results_mnist_noise_p-2_high_level_data.csv', nice_names_mnist)
distances_ll_eps_mnist = _load_mnist_results(
    '../results/distances/results_mnist_noise_p-2_plot_data.csv', nice_names_mnist)
distances_hl_pgd_mnist = _load_mnist_results(
    '../results/distances/results_mnist_pgd_p-2_high_level_data.csv', nice_names_mnist)
distances_ll_pgd_mnist = _load_mnist_results(
    '../results/distances/results_mnist_pgd_p-2_plot_data.csv', nice_names_mnist)

# Nearest-neighbour search results; these files use the "search" naming scheme.
results_nn_omniglot = _load_mnist_results(
    '../results/search/knn_results_mnist_to_omniglot.csv', nice_names_search)
results_nn_omniglot_finetuned = _load_mnist_results(
    '../results/search/knn_results_mnist_to_omniglot_finetuned.csv', nice_names_search)
results_nn_fashion = _load_mnist_results(
    '../results/search/knn_results_mnist_to_fashion.csv', nice_names_search)
results_nn_fashion_finetuned = _load_mnist_results(
    '../results/search/knn_results_mnist_to_fashion_finetuned.csv', nice_names_search)

In [None]:
# Score every row of the four high-level MNIST frames and store the result in
# a new 'RMQM' column of the same frame.
for scored_frame in (
    distances_hl_eps_mnist,
    distances_hl_pgd_mnist,
    distances_hl_eps_mnist_per_label,
    distances_hl_pgd_mnist_per_label,
):
    scored_frame['RMQM'] = scored_frame.apply(RMQM, axis=1)

In [None]:
def _load_cifar_results(csv_path, name_lookup, add_rmqm=False):
    """Read one results CSV, relabel its models and drop the random baseline.

    Parameters
    ----------
    csv_path : str
        Path of the CSV file; it must contain a ``model`` column.
    name_lookup : dict
        Maps the raw ``model`` identifiers to display names. An identifier
        that is missing from the mapping raises ``KeyError``.
    add_rmqm : bool, default False
        When true, an ``RMQM`` column is computed row by row with ``RMQM``
        before the random baseline is removed.

    Returns
    -------
    pandas.DataFrame
        The rows whose display name is not ``'Random'``, as an independent
        copy. A later cell overwrites ``label`` on the per-label frames;
        without the copy pandas emits SettingWithCopyWarning because the
        frame is the result of a boolean filter.
    """
    frame = pd.read_csv(csv_path)
    frame['model'] = frame.model.apply(lambda raw_name: name_lookup[raw_name])
    if add_rmqm:
        frame['RMQM'] = frame.apply(RMQM, axis=1)
    return frame[frame.model != 'Random'].copy()


# Distance results: "noise" files feed the *_eps_* frames, "pgd" files the
# *_pgd_* frames. Only the "high_level" (*_hl_*) frames carry an RMQM column.
distances_hl_eps_cifar_per_label = _load_cifar_results(
    '../results/distances/results_cifar10_noise_p-2_high_level_label_data.csv', nice_names_cifar, add_rmqm=True)
distances_ll_eps_cifar_per_label = _load_cifar_results(
    '../results/distances/results_cifar10_noise_p-2_plot_label_data.csv', nice_names_cifar)
distances_hl_eps_cifar = _load_cifar_results(
    '../results/distances/results_cifar10_noise_p-2_high_level_data.csv', nice_names_cifar, add_rmqm=True)
distances_ll_eps_cifar = _load_cifar_results(
    '../results/distances/results_cifar10_noise_p-2_plot_data.csv', nice_names_cifar)
distances_hl_pgd_cifar_per_label = _load_cifar_results(
    '../results/distances/results_cifar10_pgd_p-2_high_level_label_data.csv', nice_names_cifar, add_rmqm=True)
distances_ll_pgd_cifar_per_label = _load_cifar_results(
    '../results/distances/results_cifar10_pgd_p-2_plot_label_data.csv', nice_names_cifar)
distances_hl_pgd_cifar = _load_cifar_results(
    '../results/distances/results_cifar10_pgd_p-2_high_level_data.csv', nice_names_cifar, add_rmqm=True)
distances_ll_pgd_cifar = _load_cifar_results(
    '../results/distances/results_cifar10_pgd_p-2_plot_data.csv', nice_names_cifar)

# Nearest-neighbour search results; these files use the "search" naming scheme.
results_nn_caltech_finetuned = _load_cifar_results(
    '../results/search/knn_results_cifar10_to_caltech_finetuned.csv', nice_names_search)
results_nn_caltech = _load_cifar_results(
    '../results/search/knn_results_cifar10_to_caltech.csv', nice_names_search)

In [None]:
import matplotlib

# Default font size for figures created from here on.
matplotlib.rcParams['font.size'] = 15

# Plot RMQM

In [None]:
matplotlib.rcParams.update({'font.size': 12})

# Order of the hue levels (bars within each group) in every panel.
order = [
    'Triplet-Supervised',
    'Triplet-Entropy',
    'Cross-Entropy',
    'Triplet-Loss',
    'NT-XENT',
]

# Raw column names -> axis/legend titles shown in the figures.
display_columns = {'optim':'Optimizer', 'embedding_dim':'Representation Dimension', 'model':'Model'}

# Per-label RMQM under noise. Only MNIST rows with embedding_dim > 3 are kept;
# the CIFAR10 frame is used unfiltered.
mnist_rmqm = (
    distances_hl_eps_mnist_per_label[distances_hl_eps_mnist_per_label.embedding_dim > 3]
    .sort_values(['RMQM'], ascending=True)
    .rename(columns=display_columns)
    .replace('adam', 'Adam')
    .replace('sgd', 'SGD+Momentum')
    .assign(Dataset='MNIST')
)
cifar_rmqm = (
    distances_hl_eps_cifar_per_label
    .sort_values(['RMQM'], ascending=True)
    .rename(columns=display_columns)
    .replace('adam', 'Adam')
    .replace('sgd', 'SGD+Momentum')
    .assign(Dataset='Cifar10')
)

# (frame, extra facet arguments, output file) for the three figures.
rmqm_figures = [
    (mnist_rmqm, {}, '../reports/paper_figs/rmqm_mnist.png'),
    (pd.concat([mnist_rmqm, cifar_rmqm]), {'row': 'Dataset'}, '../reports/paper_figs/rmqm.png'),
    (cifar_rmqm, {}, '../reports/paper_figs/rmqm_cifar10.png'),
]

for data, facet_kwargs, figure_path in rmqm_figures:
    sns.catplot(
        data=data,
        x='Representation Dimension',
        y='RMQM',
        col='Optimizer',
        hue='Model',
        hue_order=order,
        # Fixed seed for the bootstrapped error bars; the CIFAR10-only figure
        # previously had none, so its bars changed between runs.
        seed=42,
        kind='bar',
        palette=colors,
        **facet_kwargs
    )
    plt.savefig(figure_path, transparent=True, dpi=100)
    plt.show()

In [None]:
matplotlib.rcParams.update({'font.size': 12})

# Order of the hue levels (bars within each group) in every panel.
order = [
    'Triplet-Supervised',
    'Triplet-Entropy',
    'Cross-Entropy',
    'Triplet-Loss',
    'NT-XENT',
]

# Raw column names -> axis/legend titles shown in the figures.
display_columns = {'optim':'Optimizer', 'embedding_dim':'Representation Dimension', 'model':'Model'}

# Per-label RMQM under PGD. Only MNIST rows with embedding_dim > 3 are kept;
# the CIFAR10 frame is used unfiltered.
mnist_rmqm_pgd = (
    distances_hl_pgd_mnist_per_label[distances_hl_pgd_mnist_per_label.embedding_dim > 3]
    .sort_values(['RMQM'], ascending=True)
    .rename(columns=display_columns)
    .replace('adam', 'Adam')
    .replace('sgd', 'SGD+Momentum')
    .assign(Dataset='MNIST')
)
cifar_rmqm_pgd = (
    distances_hl_pgd_cifar_per_label
    .sort_values(['RMQM'], ascending=True)
    .rename(columns=display_columns)
    .replace('adam', 'Adam')
    .replace('sgd', 'SGD+Momentum')
    .assign(Dataset='Cifar10')
)

# (frame, extra facet arguments, output file) for the three figures.
rmqm_pgd_figures = [
    (mnist_rmqm_pgd, {}, '../reports/paper_figs/rmqm_mnist_pgd.png'),
    (pd.concat([mnist_rmqm_pgd, cifar_rmqm_pgd]), {'row': 'Dataset'}, '../reports/paper_figs/rmqm_pgd.png'),
    (cifar_rmqm_pgd, {}, '../reports/paper_figs/rmqm_cifar10_pgd.png'),
]

for data, facet_kwargs, figure_path in rmqm_pgd_figures:
    sns.catplot(
        data=data,
        x='Representation Dimension',
        y='RMQM',
        col='Optimizer',
        hue='Model',
        hue_order=order,
        # Fixed seed for the bootstrapped error bars; the CIFAR10-only figure
        # previously had none, so its bars changed between runs.
        seed=42,
        kind='bar',
        palette=colors,
        **facet_kwargs
    )
    plt.savefig(figure_path, transparent=True, dpi=100)
    plt.show()

# Performance vs RMQM

In [None]:
# Raw column names -> titles used in the figures.
task_columns = {'optim':'Optimizer','model':'Model', 'accuracy':"Task Performance"}

# Join the noise RMQM scores with the nearest-neighbour accuracy of each
# downstream task (merge on the columns the two frames share).
# NOTE(review): data1 (Fashion MNIST) is built but is not part of the frame
# plotted below; it is kept so the name stays defined.
data1 = (
    distances_hl_eps_mnist.merge(results_nn_fashion)
    .sort_values(['RMQM'], ascending=True)
    .rename(columns=task_columns)
    .replace('adam', 'Adam')
    .replace('sgd', 'SGD+Momentum')
)
data1['Task'] = 'Fashion MNIST'

data2 = (
    distances_hl_eps_mnist.merge(results_nn_omniglot)
    .sort_values(['RMQM'], ascending=True)
    .rename(columns=task_columns)
    .replace('adam', 'Adam')
    .replace('sgd', 'SGD+Momentum')
)
data2['Task'] = 'Omniglot'

data3 = (
    distances_hl_eps_cifar.merge(results_nn_caltech)
    .sort_values(['RMQM'], ascending=True)
    .rename(columns=task_columns)
    .replace('adam', 'Adam')
    .replace('sgd', 'SGD+Momentum')
)
data3['Task'] = 'Caltech'

data = pd.concat([data2, data3])

# Figure 1: raw accuracy against RMQM, one regression line per task.
plt.figure(figsize=(10,10))
for task, color in zip(data.Task.unique(), cols[-4:]):
    task_points = data[(data.Task==task)&(data.embedding_dim>3)]
    # x and y must be keywords: seaborn >= 0.12 accepts only `data` positionally.
    sns.regplot(
        data=task_points,
        x='Task Performance',
        y='RMQM',
        label=task,
        color=color,
        marker='o'
    )
plt.legend(title='Downstream Task')
plt.xlabel('Task Accuracy')
plt.savefig('../reports/paper_figs/rmqm_vs_accuracy_on_tasks.png', transparent=True, dpi=100)
plt.show()

# Figure 2: accuracy rescaled by each task's maximum. `data` keeps the
# rescaled column afterwards; the correlation cell that follows reads it.
data['Task Performance'] = data.groupby('Task')['Task Performance'].transform(lambda x: x / x.max())
plt.figure(figsize=(10,10))
for task, color in zip(data.Task.unique(), cols[-4:]):
    task_points = data[(data.Task==task)&(data.embedding_dim>3)]
    sns.regplot(
        data=task_points,
        x='Task Performance',
        y='RMQM',
        label=task,
        color=color,
        marker='o'
    )
plt.legend(title='Downstream Task')
plt.savefig('../reports/paper_figs/rmqm_vs_accuracy_on_tasks_scaled.png', transparent=True, dpi=100)
plt.show()

In [None]:
# Pairwise correlations (DataFrame.corr defaults) between task performance,
# RMQM and the embedding dimension, for rows with embedding_dim > 3.
corr_columns = ['Task Performance', 'RMQM', 'Embedding Dimension']
corr_input = (
    data[data.embedding_dim > 3]
    .rename(columns={'embedding_dim': 'Embedding Dimension'})
    [corr_columns]
    .reset_index(drop=True)
)
correlations = corr_input.corr()

sns.heatmap(correlations, annot=True)
plt.savefig(
    '../reports/paper_figs/correlation_plot_caltech_omniglot.png',
    transparent=True,
    dpi=100,
)

In [None]:
# Raw column names -> titles used in the figures.
task_columns = {'optim':'Optimizer','model':'Model', 'accuracy':"Task Performance"}

# Join the PGD RMQM scores with the nearest-neighbour accuracy of each
# downstream task (merge on the columns the two frames share).
# NOTE(review): data1 (Fashion MNIST) is built but is not part of the frame
# plotted below; it is kept so the name stays defined.
data1 = (
    distances_hl_pgd_mnist.merge(results_nn_fashion)
    .sort_values(['RMQM'], ascending=True)
    .rename(columns=task_columns)
    .replace('adam', 'Adam')
    .replace('sgd', 'SGD+Momentum')
)
data1['Task'] = 'Fashion MNIST'

data2 = (
    distances_hl_pgd_mnist.merge(results_nn_omniglot)
    .sort_values(['RMQM'], ascending=True)
    .rename(columns=task_columns)
    .replace('adam', 'Adam')
    .replace('sgd', 'SGD+Momentum')
)
data2['Task'] = 'Omniglot'

data3 = (
    distances_hl_pgd_cifar.merge(results_nn_caltech)
    .sort_values(['RMQM'], ascending=True)
    .rename(columns=task_columns)
    .replace('adam', 'Adam')
    .replace('sgd', 'SGD+Momentum')
)
data3['Task'] = 'Caltech'

data = pd.concat([data2, data3])

# Figure 1: raw accuracy against RMQM, one regression line per task.
plt.figure(figsize=(10,10))
for task, color in zip(data.Task.unique(), cols[-4:]):
    task_points = data[(data.Task==task)&(data.embedding_dim>3)]
    # x and y must be keywords: seaborn >= 0.12 accepts only `data` positionally.
    sns.regplot(
        data=task_points,
        x='Task Performance',
        y='RMQM',
        label=task,
        color=color,
        marker='o'
    )
plt.legend(title='Downstream Task')
plt.xlabel('Task Accuracy')
plt.savefig('../reports/paper_figs/rmqm_vs_accuracy_on_tasks_pgd.png', transparent=True, dpi=100)
plt.show()

# Figure 2: accuracy rescaled by each task's maximum. `data` keeps the
# rescaled column afterwards; the correlation cell that follows reads it.
data['Task Performance'] = data.groupby('Task')['Task Performance'].transform(lambda x: x / x.max())
plt.figure(figsize=(10,10))
for task, color in zip(data.Task.unique(), cols[-4:]):
    task_points = data[(data.Task==task)&(data.embedding_dim>3)]
    sns.regplot(
        data=task_points,
        x='Task Performance',
        y='RMQM',
        label=task,
        color=color,
        marker='o'
    )
plt.legend(title='Downstream Task')
plt.savefig('../reports/paper_figs/rmqm_vs_accuracy_on_tasks_scaled_pgd.png', transparent=True, dpi=100)
plt.show()

In [None]:
# Pairwise correlations (DataFrame.corr defaults) between task performance,
# RMQM and the embedding dimension, for rows with embedding_dim > 3.
corr_columns = ['Task Performance', 'RMQM', 'Embedding Dimension']
corr_input = (
    data[data.embedding_dim > 3]
    .rename(columns={'embedding_dim': 'Embedding Dimension'})
    [corr_columns]
    .reset_index(drop=True)
)
correlations = corr_input.corr()

sns.heatmap(correlations, annot=True)
plt.savefig(
    '../reports/paper_figs/correlation_plot_caltech_omniglot_pgd.png',
    transparent=True,
    dpi=100,
)

# Scores over Time

In [None]:
# Replace the integer class ids in the per-label frames with readable names.
# The MNIST frames must use `mnist_classes`; they were previously mapped
# through `cifar10_classes`, which labelled digits as 'airplane', 'automobile', ...
# Running this cell twice raises KeyError because the ids are already replaced.
distances_ll_eps_cifar_per_label['label'] = distances_ll_eps_cifar_per_label['label'].apply(lambda class_id: cifar10_classes[class_id])
distances_ll_eps_mnist_per_label['label'] = distances_ll_eps_mnist_per_label['label'].apply(lambda class_id: mnist_classes[class_id])

distances_ll_pgd_cifar_per_label['label'] = distances_ll_pgd_cifar_per_label['label'].apply(lambda class_id: cifar10_classes[class_id])
distances_ll_pgd_mnist_per_label['label'] = distances_ll_pgd_mnist_per_label['label'].apply(lambda class_id: mnist_classes[class_id])

## Percentage change from the original point

### CIFAR10

In [None]:
# x-axis values: a 100-point grid on [0, 1] without its first two points, so
# every model/optimizer/dimension slice must hold 98 rows.
epsilons = np.linspace(0,1,100)[2:]

for optim in ['adam']:
    plt.figure(figsize=(20,20))
    plt.rcParams.update({'font.size': 15})
    # One panel per embedding dimension on a 2x3 grid.
    for i,embedding_dim in enumerate([16,32,64,128,256,512]):
        plt.subplot(2,3,i+1)
        for model in distances_ll_eps_cifar.model.unique():
            data = distances_ll_eps_cifar[
                (distances_ll_eps_cifar.model==model)&
                (distances_ll_eps_cifar.optim==optim)&
                (distances_ll_eps_cifar.embedding_dim==embedding_dim)
            ]
            center = data.original_pct_change_mean
            # Band half-width taken from the std-error column, as in the
            # sibling PGD and MNIST cells; it used to be mean/sqrt(n), which
            # is not an error estimate.
            # NOTE(review): the std error is divided by sqrt(n) once more,
            # mirroring the sibling cells -- confirm this is intended.
            half_width = data.original_pct_change_std_error/np.sqrt(len(data))
            plt.plot(
                epsilons,
                center,
                '.-',
                label=model,
                c=colors[model])
            plt.fill_between(
                epsilons,
                center-half_width,
                center+half_width,
                color=colors[model],
                alpha=.3
            )

        plt.title(f'Embedding Dimension={embedding_dim}')
        # Legend only in the top-right panel.
        if i==2:
            plt.legend(prop={'size': 15})

        # y label on the left column, x label on the bottom row.
        if i==0 or i==3:
            plt.ylabel('Average Percantage Change In Distance To Original Point')
        if i>=3:
            plt.xlabel('Epsilon')

        plt.ylim(0,4)

    # Save once, after all six panels are drawn (previously rewritten after
    # every panel).
    plt.savefig(
        f'../reports/paper_figs/cifar_average_percantage_change_dim_{optim}.png',
        transparent=True,
        dpi=100
    )
    plt.show()

In [None]:
# x-axis values 2..29; every model/optimizer/dimension slice must hold 28 rows.
pgd_steps = range(2,30)
panel_dims = [16,32,64,128,256,512]

for optim in ['adam']:
    plt.figure(figsize=(20,20))
    # One panel per embedding dimension on a 2x3 grid.
    for i,embedding_dim in enumerate(panel_dims):
        plt.rcParams.update({'font.size': 15})
        plt.subplot(2,3,i+1)
        for model in distances_ll_pgd_cifar.model.unique():
            selected = (
                (distances_ll_pgd_cifar.model==model)&
                (distances_ll_pgd_cifar.optim==optim)&
                (distances_ll_pgd_cifar.embedding_dim==embedding_dim)
            )
            data = distances_ll_pgd_cifar[selected]
            model_color = colors[model]
            half_width = data.original_pct_change_std_error/np.sqrt(len(data))

            plt.plot(pgd_steps, data.original_pct_change_mean, '.-', label=model, c=model_color)
            plt.fill_between(
                pgd_steps,
                (data.original_pct_change_mean-half_width),
                (data.original_pct_change_mean+half_width),
                color=model_color,
                alpha=.3
            )

        plt.title(f'Embedding Dimension={embedding_dim}')

        # Legend only in the top-right panel.
        if i==2:
            plt.legend(prop={'size': 15})
        # y label on the left column, x label on the bottom row.
        if i in (0, 3):
            plt.ylabel('Average Percantage Change In Distance To Original Point')
        if i>=3:
            plt.xlabel('Number of PGD Iterations')

        plt.ylim(-2,8)

    figure_path = f'../reports/paper_figs/cifar_average_percantage_change_dim_{optim}_pgd.png'
    plt.savefig(figure_path, transparent=True, dpi=100)
    plt.show()

### MNIST

In [None]:
# x-axis values: a 100-point grid on [0, 1] without its first two points, so
# every model/optimizer/dimension slice must hold 98 rows.
epsilons = np.linspace(0,1,100)[2:]

for optim in ['adam', 'sgd']:
    plt.figure(figsize=(20,20))
    plt.rcParams.update({'font.size': 15})
    # One panel per embedding dimension on a 2x3 grid.
    for i,embedding_dim in enumerate([16,32,64,128,256,512]):
        plt.subplot(2,3,i+1)
        for model in distances_ll_eps_mnist.model.unique():
            data = distances_ll_eps_mnist[
                (distances_ll_eps_mnist.model==model)&
                (distances_ll_eps_mnist.optim==optim)&
                (distances_ll_eps_mnist.embedding_dim==embedding_dim)
            ]
            center = data.original_pct_change_mean
            half_width = data.original_pct_change_std_error/np.sqrt(len(data))
            plt.plot(
                epsilons,
                center,
                '.-',
                label=model,
                c=colors[model])
            plt.fill_between(
                epsilons,
                center-half_width,
                center+half_width,
                color=colors[model],
                alpha=.3
            )

        plt.title(f'Embedding Dimension={embedding_dim}')
        # Legend only in the top-right panel.
        if i==2:
            plt.legend(prop={'size': 15})

        # y label on the left column, x label on the bottom row.
        if i==0 or i==3:
            plt.ylabel('Average Percantage Change In Distance To Original Point')
        if i>=3:
            plt.xlabel('Epsilon')

        # The two optimizers get different y ranges.
        if optim=='adam':
            plt.ylim(0,1.25)
        else:
            plt.ylim(0,0.6)

    # Save once, after all six panels are drawn (previously rewritten after
    # every panel).
    plt.savefig(
        f'../reports/paper_figs/mnist_average_percantage_change_dim_{optim}.png',
        transparent=True,
        dpi=100
    )
    plt.show()

In [None]:
# x-axis values 2..29; every model/optimizer/dimension slice must hold 28 rows.
pgd_steps = range(2,30)

for optim in ['adam', 'sgd']:
    plt.figure(figsize=(20,20))
    plt.rcParams.update({'font.size': 15})
    # One panel per embedding dimension on a 2x3 grid.
    for i,embedding_dim in enumerate([16,32,64,128,256,512]):
        plt.subplot(2,3,i+1)
        for model in distances_ll_pgd_mnist.model.unique():
            data = distances_ll_pgd_mnist[
                (distances_ll_pgd_mnist.model==model)&
                (distances_ll_pgd_mnist.optim==optim)&
                (distances_ll_pgd_mnist.embedding_dim==embedding_dim)
            ]
            center = data.original_pct_change_mean
            # Band half-width taken from the std-error column, as in the
            # sibling noise and CIFAR cells; it used to be mean/sqrt(n),
            # which is not an error estimate.
            # NOTE(review): the std error is divided by sqrt(n) once more,
            # mirroring the sibling cells -- confirm this is intended.
            half_width = data.original_pct_change_std_error/np.sqrt(len(data))
            plt.plot(
                pgd_steps,
                center,
                '.-',
                label=model,
                c=colors[model])
            plt.fill_between(
                pgd_steps,
                center-half_width,
                center+half_width,
                color=colors[model],
                alpha=.3
            )

        plt.title(f'Embedding Dimension={embedding_dim}')
        # Legend only in the top-right panel.
        if i==2:
            plt.legend(prop={'size': 15})

        # y label on the left column, x label on the bottom row.
        if i==0 or i==3:
            plt.ylabel('Average Percantage Change In Distance To Original Point')
        if i>=3:
            plt.xlabel('Number of PGD Iterations')

        plt.ylim(-0.05,2.5)

    # Save once, after all six panels are drawn (previously rewritten after
    # every panel).
    plt.savefig(
        f'../reports/paper_figs/mnist_average_percantage_change_dim_{optim}_pgd.png',
        transparent=True,
        dpi=100
    )
    plt.show()

## Total walking distance

### MNIST

In [None]:
# x-axis values: a 100-point grid on [0, 1] without its first two points, so
# every model/optimizer/dimension slice must hold 98 rows.
epsilons = np.linspace(0,1,100)[2:]

for optim in ['adam', 'sgd']:
    plt.figure(figsize=(20,20))
    plt.rcParams.update({'font.size': 15})
    # One panel per embedding dimension on a 2x3 grid.
    for i,embedding_dim in enumerate([16,32,64,128,256,512]):
        plt.subplot(2,3,i+1)
        for model in distances_ll_eps_mnist.model.unique():
            data = distances_ll_eps_mnist[
                (distances_ll_eps_mnist.model==model)&
                (distances_ll_eps_mnist.optim==optim)&
                (distances_ll_eps_mnist.embedding_dim==embedding_dim)
            ]
            center = data.original_mean
            half_width = data.original_std_error/np.sqrt(len(data))
            plt.plot(
                epsilons,
                center,
                '.-',
                label=model,
                c=colors[model])
            plt.fill_between(
                epsilons,
                center-half_width,
                center+half_width,
                color=colors[model],
                alpha=.3
            )

        plt.title(f'Embedding Dimension={embedding_dim}')
        # Legend only in the top-right panel.
        if i==2:
            plt.legend(prop={'size': 15})

        # y label on the left column, x label on the bottom row.
        if i==0 or i==3:
            plt.ylabel('Average Distance to Original Point')
        if i>=3:
            plt.xlabel('Epsilon')

    # Save once, after all six panels are drawn (previously rewritten after
    # every panel).
    plt.savefig(
        f'../reports/paper_figs/mnist_original_distance_dim_{optim}.png',
        transparent=True,
        dpi=100
    )
    plt.show()

In [None]:
# x-axis values 2..29; every model/optimizer/dimension slice must hold 28 rows.
pgd_steps = range(2,30)

for optim in ['adam', 'sgd']:
    plt.figure(figsize=(20,20))
    plt.rcParams.update({'font.size': 15})
    # One panel per embedding dimension on a 2x3 grid.
    for i,embedding_dim in enumerate([16,32,64,128,256,512]):
        plt.subplot(2,3,i+1)
        for model in distances_ll_pgd_mnist.model.unique():
            data = distances_ll_pgd_mnist[
                (distances_ll_pgd_mnist.model==model)&
                (distances_ll_pgd_mnist.optim==optim)&
                (distances_ll_pgd_mnist.embedding_dim==embedding_dim)
            ]
            center = data.original_mean
            half_width = data.original_std_error/np.sqrt(len(data))
            plt.plot(
                pgd_steps,
                center,
                '.-',
                label=model,
                c=colors[model])
            plt.fill_between(
                pgd_steps,
                center-half_width,
                center+half_width,
                color=colors[model],
                alpha=.3
            )

        plt.title(f'Embedding Dimension={embedding_dim}')
        # Legend only in the top-right panel.
        if i==2:
            plt.legend(prop={'size': 15})

        # y label on the left column, x label on the bottom row.
        if i==0 or i==3:
            plt.ylabel('Average Distance to Original Point')
        if i>=3:
            plt.xlabel('Number of PGD Iterations')

    # Save once, after all six panels are drawn (previously rewritten after
    # every panel).
    plt.savefig(
        f'../reports/paper_figs/mnist_original_distance_dim_{optim}_pgd.png',
        transparent=True,
        dpi=100
    )
    plt.show()

### CIFAR

In [None]:
# x-axis values: a 100-point grid on [0, 1] without its first two points, so
# every model/optimizer/dimension slice must hold 98 rows.
epsilons = np.linspace(0,1,100)[2:]

for optim in ['adam']:
    plt.figure(figsize=(20,20))
    plt.rcParams.update({'font.size': 15})
    # One panel per embedding dimension on a 2x3 grid.
    for i,embedding_dim in enumerate([16,32,64,128,256,512]):
        plt.subplot(2,3,i+1)
        for model in distances_ll_eps_cifar.model.unique():
            data = distances_ll_eps_cifar[
                (distances_ll_eps_cifar.model==model)&
                (distances_ll_eps_cifar.optim==optim)&
                (distances_ll_eps_cifar.embedding_dim==embedding_dim)
            ]
            center = data.original_mean
            half_width = data.original_std_error/np.sqrt(len(data))
            plt.plot(
                epsilons,
                center,
                '.-',
                label=model,
                c=colors[model])
            plt.fill_between(
                epsilons,
                center-half_width,
                center+half_width,
                color=colors[model],
                alpha=.3
            )

        plt.title(f'Embedding Dimension={embedding_dim}')
        # Legend only in the top-right panel.
        if i==2:
            plt.legend(prop={'size': 15})

        # y label on the left column, x label on the bottom row.
        if i==0 or i==3:
            plt.ylabel('Average Distance to Original Point')
        if i>=3:
            plt.xlabel('Epsilon')

    # Save once, after all six panels are drawn (previously rewritten after
    # every panel).
    plt.savefig(
        f'../reports/paper_figs/cifar_original_distance_dim_{optim}.png',
        transparent=True,
        dpi=100
    )
    plt.show()



In [None]:
# x-axis values 2..29; every model/optimizer/dimension slice must hold 28 rows.
pgd_steps = range(2,30)

for optim in ['adam']:
    plt.figure(figsize=(20,20))
    plt.rcParams.update({'font.size': 15})
    # One panel per embedding dimension on a 2x3 grid.
    for i,embedding_dim in enumerate([16,32,64,128,256,512]):
        plt.subplot(2,3,i+1)
        for model in distances_ll_pgd_cifar.model.unique():
            data = distances_ll_pgd_cifar[
                (distances_ll_pgd_cifar.model==model)&
                (distances_ll_pgd_cifar.optim==optim)&
                (distances_ll_pgd_cifar.embedding_dim==embedding_dim)
            ]
            center = data.original_mean
            half_width = data.original_std_error/np.sqrt(len(data))
            plt.plot(
                pgd_steps,
                center,
                '.-',
                label=model,
                c=colors[model])
            plt.fill_between(
                pgd_steps,
                center-half_width,
                center+half_width,
                color=colors[model],
                alpha=.3
            )

        plt.title(f'Embedding Dimension={embedding_dim}')
        # Legend only in the top-right panel.
        if i==2:
            plt.legend(prop={'size': 15})

        # y label on the left column, x label on the bottom row.
        if i==0 or i==3:
            plt.ylabel('Average Distance to Original Point')
        if i>=3:
            plt.xlabel('Number of PGD Iterations')

    # Save once, after all six panels are drawn (previously rewritten after
    # every panel).
    plt.savefig(
        f'../reports/paper_figs/cifar_original_distance_dim_{optim}_pgd.png',
        transparent=True,
        dpi=100
    )
    plt.show()