## Ablation Study

This notebook collects and displays the timing data from the ablation study on vectorization and compilation in the JAX and Custom TFP implementations.

In [None]:
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Let pandas display up to 200 rows before truncating its output.
pd.set_option('display.max_rows', 200)

## JAX

In [None]:
# Experiment names; each one is substituted into the result-file names
# (./raw/jaxdp_<expt>_...) that the loading loop opens.
expts = ['logreg', 'ffnn', 'mnist', 'embed', 'cifar10', 'lstm']

In [None]:
# Collect the median of the 'timings' entry stored in each JAX result pickle,
# for every experiment in `expts` and every ablation variant.
#
# Result layout: expt_dict[expt][label], where label is one of
#   'base' -> <name>.pkl
#   'nv'   -> <name>_novmap.pkl
#   'nj'   -> <name>_nojit.pkl
#   'nvj'  -> <name>_nojit_novmap.pkl
#
# NOTE(review): pickle.load can execute arbitrary code while loading; only
# point this at result files you generated yourself.
jax_variants = (
    ('base', ''),
    ('nv', '_novmap'),
    ('nj', '_nojit'),
    ('nvj', '_nojit_novmap'),
)

expt_dict = {}
for expt in expts:
    expt_dict[expt] = {}
    pickle_name = f'./raw/jaxdp_{expt}_bs_128_priv_True'
    for label, suffix in jax_variants:
        # Best effort: a missing or unreadable file just leaves a gap in the
        # results. Catch `Exception` rather than using a bare `except` so that
        # Ctrl-C (KeyboardInterrupt) still interrupts the cell.
        try:
            with open(f'{pickle_name}{suffix}.pkl', 'rb') as f:
                expt_dict[expt][label] = np.median(pickle.load(f)['timings'])
        except Exception:
            # Report the variant that actually failed so the gap can be
            # traced back to a specific file.
            print(f'Failed {expt} {label}')

In [None]:
# Show the collected JAX median timings, keyed by experiment and then variant.
expt_dict

## Custom TFP

In [None]:
# Experiment names for the Custom TFP runs; each one is substituted into the
# result-file names (./raw/tf2dp_<expt>_...) that the loading loop opens.
# Rebinds `expts`: same names as the JAX list, with 'lstm' before 'cifar10'.
expts = ['logreg', 'ffnn', 'mnist', 'embed', 'lstm', 'cifar10']

In [None]:
# Collect the median of the 'timings' entry stored in each Custom TFP result
# pickle, for every experiment in `expts` and every ablation variant.
#
# Result layout: tf_expt_dict[expt][label], where label is one of
#   'base'   -> <name>_xla_False.pkl
#   'xla'    -> <name>_xla_True.pkl
#   'nv'     -> <name>_xla_False_novmap.pkl
#   'nvj'    -> <name>_xla_False_nojit_novmap.pkl
#   'xla_nv' -> <name>_xla_True_novmap.pkl
#
# NOTE(review): pickle.load can execute arbitrary code while loading; only
# point this at result files you generated yourself.
tf_variants = (
    ('base', 'False'),
    ('xla', 'True'),
    ('nv', 'False_novmap'),
    ('nvj', 'False_nojit_novmap'),
    ('xla_nv', 'True_novmap'),
)

tf_expt_dict = {}
for expt in expts:
    tf_expt_dict[expt] = {}
    pickle_name = f'./raw/tf2dp_{expt}_bs_128_priv_True_xla_'
    for label, suffix in tf_variants:
        # Best effort: a missing or unreadable file just leaves a gap in the
        # results. Catch `Exception` rather than using a bare `except` so that
        # Ctrl-C (KeyboardInterrupt) still interrupts the cell.
        try:
            with open(f'{pickle_name}{suffix}.pkl', 'rb') as f:
                tf_expt_dict[expt][label] = np.median(pickle.load(f)['timings'])
        except Exception:
            print(f'Failed {expt} {label}')

In [None]:
# Show the collected Custom TFP median timings, keyed by experiment and then variant.
tf_expt_dict