## Runtimes Experiment - Confidence Intervals

This notebook is nearly identical to `runtimes.ipynb`, except that the data it consumes comes from runs of 100 epochs, and the bar plots additionally show the spread across epochs as error bars (confidence intervals drawn with `ci='sd'`, i.e. one standard deviation).

In [None]:
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Allow up to 200 rows to be shown when a DataFrame is displayed in the notebook.
pd.set_option('display.max_rows', 200)

In [None]:
# One row per benchmarked library variant:
#   (display name, 1 if the run is private / DP-SGD else 0, result-file prefix)
# The same display name may appear twice (e.g. 'JAX') when a library has both a
# non-private and a private variant.
library_table = [
    ('JAX',                0, 'jaxdp'),
    ('TensorFlow 2',       0, 'tf2dp'),
    ('TensorFlow 1',       0, 'tf1dp'),
    ('PyTorch',            0, 'pytorch'),
    ('JAX',                1, 'jaxdp'),
    ('Custom TFP',         1, 'tf2dp'),
    ('TFP',                1, 'tf1dp'),
    ('Opacus',             1, 'opacusdp'),
    ('BackPACK',           1, 'backpackdp'),
    ('PyVacy',             1, 'pyvacydp'),
    ('CRB',                1, 'owkindp'),
    ('TensorFlow 2 (XLA)', 0, 'tf2dp'),
    ('Custom TFP (XLA)',   1, 'tf2dp'),
    ('TensorFlow 1 (XLA)', 0, 'tf1dp'),
    ('TFP (XLA)',          1, 'tf1dp'),
]

# Parallel lists consumed by the rest of the notebook.
names = [row[0] for row in library_table]
private = [row[1] for row in library_table]
filenames = [row[2] for row in library_table]

# Experiment identifiers and the batch sizes each one was run with.
expts = [
    'logreg',
    'ffnn',
    'mnist',
    'embed',
    'lstm',
]
batch_sizes = [
    16,
    32,
    64,
    128,
    256,
]

In [None]:
# Sanity check: the three parallel lists are zipped together later, so they
# should all have the same length.
len(names), len(private), len(filenames)

In [None]:
def expt_iterator():
    """Yield every benchmark configuration to load.

    Walks the module-level ``expts`` and ``batch_sizes`` lists and, for each
    pair, every library variant described by the parallel ``private`` /
    ``names`` / ``filenames`` lists.

    Yields:
        Tuples ``(expt, batch_size, name, filename, is_private)`` where
        ``is_private`` is the ``private`` flag converted to ``bool``.
    """
    for experiment in expts:
        for batch_size in batch_sizes:
            yield from (
                (experiment, batch_size, library, stem, bool(flag))
                for flag, library, stem in zip(private, names, filenames)
            )

In [None]:
# Load the raw timing pickle for every configuration from expt_iterator().
# Each entry appended to `files` is
#   (filename, name, expt, bs, dpsgd, use_xla, loaded_object_or_None).
# NOTE(security): pickle.load can execute arbitrary code while unpickling; only
# run this against result files you produced yourself.
files = []
success, errors = 0, 0
for expt, bs, name, filename, dpsgd in expt_iterator():
    pickle_name = f'./raw/{filename}_{expt}_bs_{bs}_priv_{dpsgd}'

    # A run is flagged as XLA when its display name mentions XLA or is a JAX
    # run. Only result files whose prefix starts with 'tf' carry the
    # `_xla_<bool>` suffix in their file name.
    use_xla = 'xla' in name.lower() or name.lower().startswith('jax')
    if filename.startswith('tf'):
        pickle_name += f'_xla_{use_xla}'

    try:
        with open(pickle_name+'.pkl', 'rb') as f:
            d = pickle.load(f)
    except Exception as exc:
        # Best effort: a missing or unreadable result is recorded as None so the
        # remaining configurations still load. Catching Exception (rather than
        # using a bare except) lets Ctrl-C / SystemExit propagate.
        print(f'Failed to load {pickle_name}.pkl ({type(exc).__name__}: {exc})')
        d = None
        errors += 1
    else:
        success += 1
    files.append((filename, name, expt, bs, dpsgd, use_xla, d))

In [None]:
# Show how many result files were loaded and how many failed to load.
success, errors

In [None]:
# df_list = []
# for *row, d in files:
#     # d = [np.median(d['timings'])] if d else [0.]
#     d = [np.mean(d['timings'][1:])] if d else [0.]
#     df_list.append(pd.Series(row + d))

# df = pd.concat(df_list, axis=1).transpose()
# df.columns = ['Filename', 'Library', 'Experiment', 'Batch Size', 'Private?', 'XLA', 'Runtime']
# df['Runtime'] = df['Runtime'].astype(float)
# old_df = df.copy()

In [None]:
# Build a long-format table with one row per (configuration, epoch timing).
# Rows are collected as plain lists and handed to the DataFrame constructor in
# one go, which avoids creating one pd.Series per timing and keeps sensible
# column dtypes instead of object columns.
rows = []
for *row, d in files:
    if d:
        timings = d['timings']
        if len(timings) != 102:
            raise ValueError(
                f'Expected 102 timings for {row}, found {len(timings)}')
        # The first timing is discarded (presumably a warm-up / compilation
        # epoch -- TODO confirm against the benchmark scripts).
        rows.extend(row + [timing] for timing in timings[1:])
    else:
        # Results that failed to load are kept as a single 0-second
        # placeholder row so the library still shows up downstream.
        rows.append(row + [0.])

df = pd.DataFrame(
    rows,
    columns=['Filename', 'Library', 'Experiment', 'Batch Size', 'Private?', 'XLA', 'Runtime'],
)
df['Runtime'] = df['Runtime'].astype(float)

In [None]:
# Display order of the libraries; a library's position in this list becomes
# its value in the 'Order' column.
library_order = [
    'JAX',
    'Custom TFP (XLA)',
    'Custom TFP',
    'TFP (XLA)',
    'TFP',
    'Opacus',
    'BackPACK',
    'CRB',
    'PyVacy',
    'TensorFlow 2',
    'TensorFlow 2 (XLA)',
    'TensorFlow 1',
    'TensorFlow 1 (XLA)',
    'PyTorch',
]

# Start from a -1 sentinel so any library missing from the list is detectable.
df['Order'] = -1
for i, name in enumerate(library_order):
    is_this_library = df['Library'] == name
    df.loc[is_this_library, 'Order'] = i

# Every row must have received a real position.
assert (df['Order'] != -1).all()
df = df.sort_values(by=['Batch Size', 'Order'])

In [None]:
# Preview the first rows of the sorted per-epoch runtime table.
df.head()

In [None]:
# Collapse the per-epoch rows into one mean runtime per configuration, sorted
# by batch size and then by library display order.
group_keys = ['Filename', 'Library', 'Experiment', 'Batch Size', 'Private?', 'XLA', 'Order']
per_configuration = df.groupby(group_keys)
means = per_configuration.agg('mean').reset_index()
means.columns = group_keys + ['Runtime']
means = means.sort_values(by=['Batch Size', 'Order'])

In [None]:
# Human-readable plot titles for each experiment identifier.
expt_to_title = {
    'mnist': 'Convolutional Neural Network (CNN)',
    'lstm': 'LSTM Network',
    'embed': 'Embedding Network',
    'ffnn': 'Fully Connected Neural Network (FCNN)',
    'logreg': 'Logistic Regression',
}


def _format_runtime(seconds):
    """Return the bar label for a mean runtime, or '' when it is not positive."""
    if seconds > 100.:
        return f'{int(seconds)}'
    if seconds > 0.:
        return f'{seconds:.2g}'
    return ''


def get_runtime_plot(expt, ylim=None, figsize=(13, 6)):
    """Draw the private and non-private runtime bar charts for one experiment.

    Reads the module-level ``df`` (one row per epoch timing) for the bars and
    ``means`` (one mean runtime per configuration) for the text labels.

    Args:
        expt: Experiment identifier; must be a key of ``expt_to_title``.
        ylim: Optional upper limit of the y-axis. When given, plotted runtimes
            are clipped to ``ylim - 2`` while the labels still show the
            unclipped mean.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        ``(f, ax)``: the figure and the array of its two axes (private runs on
        ``ax[0]``, non-private runs on ``ax[1]``).
    """
    f, ax = plt.subplots(2, 1, figsize=figsize, sharey=True)
    plot_df = df[df['Experiment'] == expt].copy()
    if ylim:
        # Clip so that very slow libraries stop just below the top of the axis.
        plot_df['Runtime'] = np.minimum(plot_df['Runtime'], ylim-2)

    # Top panel: private runs; bottom panel: non-private runs.
    panels = list(zip(ax, (True, False)))

    for panel, is_private in panels:
        # ci='sd' asks seaborn for standard-deviation error bars.
        sns.barplot(x='Library', y='Runtime', hue='Batch Size', ci='sd',
                    data=plot_df[plot_df['Private?'] == is_private],
                    ax=panel, palette='muted')

    for panel, is_private in panels:
        panel_means = means.loc[(means['Experiment'] == expt) & (means['Private?'] == is_private), 'Runtime']
        # Bars and means are paired purely by position.
        # NOTE(review): assumes panel.patches follows the same
        # (batch size, library order) sequence as `means` -- confirm for the
        # installed seaborn version.
        for rect, mean_runtime in zip(panel.patches, panel_means):
            panel.annotate(_format_runtime(mean_runtime),
                           xy=(rect.get_x() + rect.get_width() / 2 - 0.3*rect.get_width(),
                               rect.get_height()),
                           xytext=(0, 3),  # 3 points vertical offset
                           textcoords="offset points",
                           va='bottom', ha='left',
                           fontsize=9, rotation=45)

    # The LSTM title is raised (presumably to clear the rotated bar labels).
    title_y = 1.18 if expt == 'lstm' else 1
    ax[0].set_title('Mean Runtime for One Private Epoch - '+ expt_to_title[expt],
                    y=title_y)
    ax[1].set_title('Mean Runtime for One Non-Private Epoch - '+ expt_to_title[expt])

    for panel in ax:
        panel.set_xlabel('Library')
        panel.set_ylabel('Runtime (sec)')
        if ylim:
            panel.set_ylim(0, ylim)
        # Drop the per-panel legends seaborn created.
        panel.get_legend().remove()

    sns.despine()
    # Re-create a single legend on the bottom panel.
    ax[1].legend()
    f.patch.set_facecolor('white')
    f.tight_layout()
    return f, ax

In [None]:
# table with one batch size

# Check x-axis
# Logistic regression runtimes, with the y-axis limited to 20 (seconds).
f, ax = get_runtime_plot('logreg', ylim=20, figsize=(11, 5))
None

In [None]:
# f.savefig('../../mlsys/assets/logistic_runtimes.pdf')

In [None]:
# Fully connected network runtimes, with the y-axis limited to 20 (seconds).
f, ax = get_runtime_plot('ffnn', 20, figsize=(11, 5))
None

In [None]:
# f.savefig('../../mlsys/assets/ffnn_runtimes.pdf')

In [None]:
# CNN ('mnist') runtimes, with the y-axis limited to 50 (seconds).
f, ax = get_runtime_plot('mnist', 50, figsize=(11, 5))

In [None]:
# f.savefig('../../mlsys/assets/cnn_runtimes.pdf')

In [None]:
# Embedding network runtimes, with the y-axis limited to 20 (seconds).
f, ax = get_runtime_plot('embed', 20, figsize=(11, 5))
None

In [None]:
# f.savefig('../../mlsys/assets/embed_runtimes.pdf')

In [None]:
# LSTM runtimes, with the y-axis limited to 250 (seconds).
f, ax = get_runtime_plot('lstm', 250, figsize=(11, 5))
None

In [None]:
# f.savefig('../../mlsys/assets/lstm_runtimes.pdf')

In [None]:
# Export the sorted per-epoch runtime table (including the helper 'Order'
# column) to CSV.
df.to_csv('arxiv_paper_data.csv')