# Results Analysis for FAccT 2022 paper

In [None]:
%matplotlib inline
import pandas as pd
import numpy as np

import seaborn as sns
sns.set_context("paper")
sns.set_theme(style="whitegrid")
import matplotlib.pyplot as plt
plt.show()
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual
from IPython.display import display
    
from fair_embedded_ml.metrics import domain_fairness, model_unfairness
from fair_embedded_ml import results_analysis 
from fair_embedded_ml.results_analysis import exp_names
from fair_embedded_ml import results_plot

x_dir = '~/Projects/fair_embedded_ml_results/'
    
class fixed_copy(fixed):
    def get_interact_value(self):
        return self.value.copy()

In [None]:
# Load all experiment results, the compression (pruning) results, a per-domain MCC view and the
# derived fairness metrics that every cell below works from.
# NOTE(review): column schemas are defined in fair_embedded_ml.results_analysis — check there.
results = results_analysis.get_results(x_dir)
compression_results = results_analysis.get_compression_results(x_dir)
results_mcc = results_analysis.get_results_for_domains(results)
results_fairness = results_analysis.get_fairness_results(results_mcc)

#### Tables

In [None]:
# Shortlisted 16k CNN models (model_selection with a 0.99 threshold — presumably 99% of the best
# MCC; confirm in results_analysis) with their pre-processing settings, overall MCC and model bias.
cnn_16_selection = results_analysis.model_selection(results_fairness, 'sc16_cnn', 0.99)[
    ['input_features','frame_length','frame_step','mel_bins','mfccs','window_fn','all_mcc','model_fairness']]
cnn_16_selection

In [None]:
# Same table for the 16k low-latency CNN (displayed directly, not stored).
results_analysis.model_selection(results_fairness, 'sc16_llcnn', 0.99)[['input_features','frame_length','frame_step','mel_bins','mfccs','window_fn','all_mcc','model_fairness']]

In [None]:
# Same table for the 8k CNN.
cnn_8_selection = results_analysis.model_selection(results_fairness, 'sc8_cnn', 0.99)[
    ['input_features','frame_length','frame_step','mel_bins','mfccs','window_fn','all_mcc','model_fairness']]
cnn_8_selection

In [None]:
# Same table for the 8k low-latency CNN (displayed directly, not stored).
results_analysis.model_selection(results_fairness, 'sc8_llcnn', 0.99)[['input_features','frame_length','frame_step','mel_bins','mfccs','window_fn','all_mcc','model_fairness']]

### Section 5.1: Impact of Sample Rate

In [None]:
# Pool, for each of the four base experiments, the models chosen by select_models_in_mcc_range
# (best=100, min_percentage_of_mcc=0.95 — presumably up to 100 models within 95% of the best MCC),
# then replace experiment ids with the display names used in the figures.
fig_data = pd.concat([results_analysis.select_models_in_mcc_range(results_fairness, 'sc16_cnn', best=100, min_percentage_of_mcc=0.95),
                      results_analysis.select_models_in_mcc_range(results_fairness, 'sc8_cnn', best=100, min_percentage_of_mcc=0.95),
                      results_analysis.select_models_in_mcc_range(results_fairness, 'sc16_llcnn', best=100, min_percentage_of_mcc=0.95),
                      results_analysis.select_models_in_mcc_range(results_fairness, 'sc8_llcnn', best=100, min_percentage_of_mcc=0.95),
                     ])
fig_data['exp_name'] = fig_data['exp_name'].map({'sc16_cnn':'16k CNN', 'sc8_cnn':'8k CNN', 'sc16_llcnn':'16k llCNN', 'sc8_llcnn':'8k llCNN' })

In [None]:
# Mean/std of overall MCC per experiment, each also expressed relative to the 16k CNN baseline.
acc = fig_data.groupby(['exp_name'])['all_mcc'].agg(['mean','std'])
acc['rel_mean'] = acc['mean']/acc.loc['16k CNN','mean'] 
acc['rel_std'] = acc['std']/acc.loc['16k CNN','std'] 
acc

In [None]:
# Mean/std of model bias (model_fairness) per experiment, relative to both 16k baselines.
mf = fig_data.groupby(['exp_name'])['model_fairness'].agg(['mean','std'])
mf['rel_mean cnn'] = mf['mean']/mf.loc['16k CNN','mean'] 
mf['rel_std cnn'] = mf['std']/mf.loc['16k CNN','std'] 
mf['rel_mean llcnn'] = mf['mean']/mf.loc['16k llCNN','mean'] 
mf['rel_std llcnn'] = mf['std']/mf.loc['16k llCNN','std'] 
mf

In [None]:
# Box plot of overall MCC for the pooled models of each experiment in fig_data.
sns.set_context('poster')
f, ax = plt.subplots(figsize=(12,8))
g = sns.boxplot(data = fig_data, x='exp_name', y='all_mcc', palette='tab20', linewidth=1)#.set_title('Accuracy (MCC) results across experiments');
g.set_title('Accuracy (MCC) across experiments', fontsize='large', pad=20)
g.set_xlabel('Experiments', fontsize='large', labelpad=10)
g.set_ylabel('Accuracy (MCC score)', fontsize='large', labelpad=10)
g.tick_params(axis='both', which='major', labelsize=20)
# Dashed divider between the first two and last two boxes (with fig_data's concat order:
# the CNN experiments vs the llCNN experiments).
g.axvline(x=1.5 ,c="black", ls='--', linewidth=1)
sns.despine(bottom=True, left=True)
plt.savefig('figures/accuracy_experiments.png')

In [None]:
# Same layout for model bias (model_fairness; the y label states lower is better).
sns.set_context('poster')
f, ax = plt.subplots(figsize=(12,8))
g = sns.boxplot(data = fig_data, x='exp_name', y='model_fairness', palette='tab20', linewidth=1)
g.set_title('Reliability bias across experiments', fontsize='large', pad=20)
g.set_xlabel('Experiments', fontsize='large', labelpad=10)
g.set_ylabel('Reliability bias (lower is better)', fontsize='large', labelpad=10)
g.tick_params(axis='both', which='major', labelsize=20)
g.axvline(x=1.5 ,c="black", ls='--', linewidth=1)
sns.despine(bottom=True, left=True)
plt.savefig('figures/fairness_experiments.png')

In [None]:
# Mean and std of the absolute per-gender fairness scores for each base experiment
# (computed over all models, not only the pooled fig_data selection).
results_fairness[results_fairness['exp_name'].isin(['sc8_cnn','sc16_cnn','sc8_llcnn','sc16_llcnn'])
                ].groupby('exp_name').agg({'female_fairness': [lambda c: c.abs().mean(),
                                                               lambda c: c.abs().std()],
                                           'male_fairness': [lambda c: c.abs().mean(),
                                                             lambda c: c.abs().std()]})

In [None]:
sns.set_theme(style="whitegrid", font_scale=2.2)

# One panel per base experiment: each model's subgroup MCC (male/female) against its overall MCC.
# The dashed y = x line marks subgroup accuracy equal to overall accuracy. Hue levels are renamed
# in the data itself, so the legend labels always match their colours.
g = sns.relplot(data=results_mcc[results_mcc['exp_name'].isin(['sc16_cnn','sc8_cnn','sc16_llcnn','sc8_llcnn'])
                                ].sort_values(by=['model_arch','exp_name']).replace({'domain': {'male_mcc': 'male  ', 'female_mcc': 'female  '}}),
                x='all_mcc', y='domain_mcc', hue='domain', col="exp_name", col_wrap=4, palette=['green', 'fuchsia'], height=6, aspect=1.1)
for ax in g.axes.flat:
    # Facet titles read "exp_name = sc16_llcnn"; turn that into e.g. "16k llCNN".
    ax_title = ax.get_title().split(' = ')[-1]
    arch = ax_title.split('_')[-1]
    train_sr = ax_title.split('-')[0].split('_')[0].strip('sc')
    ax.axline((0.89,0.89), (0.9,0.9), c="black", ls='--', linewidth=2)
    ax.set_title('{}k {}{}'.format(train_sr, arch[:-3], arch[-3:].upper()), fontsize='large', pad=10)
    # The backslash is escaped explicitly: a bare '\m' is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+); the resulting string is unchanged.
    ax.set_xlabel('Overall accuracy \n $\\mathregular{(MCC_{all})}$', fontsize='large', labelpad=15)
    ax.tick_params(axis='both', which='major', labelsize='medium')
    ax.set_xlim([0.45,0.89])
    ax.set_ylim([0.45,0.89])
    ax.set_ylabel('Subgroup accuracy \n $\\mathregular{(MCC_{i})}$', fontsize='large', labelpad=15)

# Figure-level settings: apply once rather than once per facet.
legend = g._legend
legend.set_frame_on(True)
legend.set_bbox_to_anchor((0.96, 0.55))
legend.set_title('Subgroup')
sns.despine(bottom=True, left=True)

plt.suptitle('Accuracy scores for males and females', fontsize='x-large', ha='center', x=0.45, va='top', y=1.1)
plt.savefig('figures/subgroup_performance.png',bbox_inches='tight')

In [None]:
from matplotlib.lines import Line2D

sns.set_theme(style="whitegrid", font_scale=2)

# One panel per base experiment: density of per-subgroup MCC (one coloured curve per domain)
# plus the density of overall MCC as a black dashed curve.
g = sns.FacetGrid(results_mcc[results_mcc['exp_name'].isin(['sc16_cnn','sc8_cnn','sc16_llcnn','sc8_llcnn'])].sort_values(by=['model_arch','exp_name']),
                  col="exp_name", col_wrap=4, height=5, aspect=1.2, hue="domain", palette=['green','fuchsia'])
g.map(sns.kdeplot, "domain_mcc")
g.add_legend(title='Subgroup', bbox_to_anchor=(0.96, 0.58), frameon=True, borderpad=1.6)
# Relabel legend entries by hue value rather than by position: without an explicit hue_order the
# level order depends on the data, so positional relabelling can name the wrong curve.
subgroup_labels = {'male_mcc': 'male', 'female_mcc': 'female'}
for text in g._legend.texts:
    text.set_text(subgroup_labels.get(text.get_text(), text.get_text()))

g.map(sns.kdeplot, "all_mcc", color='black', ls='--')
# Second legend for the overall curve, using an explicit proxy styled like that curve instead of
# a handle picked out of the grid's private legend data.
overall_handle = Line2D([], [], color='black', ls='--')
g.add_legend(title='', handles=[overall_handle], labels=['overall'], bbox_to_anchor=(0.94, 0.41))

for ax in g.axes.flat:
    # Facet titles read "exp_name = sc16_llcnn"; turn that into e.g. "16k llCNN".
    ax_title = ax.get_title().split(' = ')[-1]
    arch = ax_title.split('_')[-1]
    train_sr = ax_title.split('-')[0].split('_')[0].strip('sc')
    ax.set_xlim([0.6,0.93])
    ax.set_title('{}k {}{}'.format(train_sr, arch[:-3], arch[-3:].upper()), fontsize='large', pad=10)
    ax.set_xlabel('Accuracy scores (MCC)', fontsize='large', labelpad=15)
    ax.set_ylabel('Score density', fontsize='large', labelpad=15)
sns.despine(left=True)

plt.suptitle('Accuracy score distributions for males and females', fontsize='x-large', va='top', y=1.1)
plt.savefig('figures/subgroup_performance_density.png',bbox_inches='tight')

### Section 5.2: Impact of Pre-processing Parameters

In [None]:
# Degrees of freedom and F critical value (third argument 0.01 — presumably the significance
# level) for an analysis over the pre-processing parameters; fcrit_preprocessing is passed as the
# threshold to the importance plots below.
pre_processing_params = ['mel_bins','frame_step','frame_length','mfccs','input_features','window_fn']
dof_preprocessing, fcrit_preprocessing = results_analysis.fcrit(results, pre_processing_params, 0.01)
print('dof:', dof_preprocessing, '\nfcrit:', fcrit_preprocessing)

In [None]:
# Pre-processing parameter importance for overall MCC and model bias on the 8k and 16k tables.
# save_fig=False: the figure is written explicitly by the savefig call below.
sns.set_theme(style="whitegrid", font_scale=1.8)
results_plot.plot_param_importance(df=results_fairness,
                      metrics={'all_mcc':'MCC','model_fairness':'bias'}, 
                      parameters="preprocessing", 
                      select_tables=['8000_','16000_'],
                      fcrit=fcrit_preprocessing,
                      save_fig=False,
                      plot_title = "Pre-processing parameter importance for reliability bias and accuracy (MCC)")
plt.savefig('figures/metric_param_importance.png',bbox_inches='tight')

In [None]:
# Same analysis for per-subgroup MCC (not saved to disk).
results_plot.plot_param_importance(df=results_fairness,
                      metrics={'male_mcc':'male','female_mcc':'female'}, 
                      parameters="preprocessing", 
                      select_tables=['8000_','16000_'],
                      fcrit=fcrit_preprocessing,
                      save_fig=False,
                      palette=['fuchsia', 'green'],
                      plot_title = "Pre-processing parameter importance for subgroup accuracy (MCC)")

In [None]:
# Same analysis for the per-subgroup fairness scores (not saved to disk).
results_plot.plot_param_importance(df=results_fairness,
                      metrics={'male_fairness':'male','female_fairness':'female'}, 
                      parameters="preprocessing", 
                      select_tables=['8000_','16000_'],
                      fcrit=fcrit_preprocessing,
                      save_fig=False,
                      palette=['fuchsia', 'green'],
                      plot_title = "Pre-processing parameter importance for bias across groups")

#### Input to Table 2

In [None]:
# Importance of each pre-processing parameter for model bias, equal-weighted runs only:
# 16k table, CNN architecture.
results_analysis.generate_importance_tables(results_fairness[results_fairness.equal_weighted==True], 
                                            'model_fairness', parameters="preprocessing", model_arch='cnn')['16000_']

In [None]:
# 16k table, low-latency CNN.
results_analysis.generate_importance_tables(results_fairness[results_fairness.equal_weighted==True], 
                                            'model_fairness', parameters="preprocessing", model_arch='low_latency_cnn')['16000_']

In [None]:
# 8k table, CNN.
results_analysis.generate_importance_tables(results_fairness[results_fairness.equal_weighted==True], 
                                            'model_fairness', parameters="preprocessing", model_arch='cnn')['8000_']

In [None]:
# 8k table, low-latency CNN.
results_analysis.generate_importance_tables(results_fairness[results_fairness.equal_weighted==True], 
                                            'model_fairness', parameters="preprocessing", model_arch='low_latency_cnn')['8000_']

#### Tables in Appendix

In [None]:
# Mean/std of overall MCC and model bias per mel_bins setting for log Mel spectrogram models
# in the 16k experiments (transposed so experiments/settings run across the columns).
results_fairness[results_fairness['exp_name'].isin(['sc16_cnn','sc16_llcnn']) 
                 & (results_fairness['input_features']=='log_mel_spectrogram')
                ].groupby(['exp_name','mel_bins'])[['all_mcc','model_fairness']
                                                  ].agg(['mean','std']).T#.style.format('{:.1e}')

In [None]:
# Same table for the 8k experiments.
results_fairness[results_fairness['exp_name'].isin(['sc8_cnn','sc8_llcnn']) 
                 & (results_fairness['input_features']=='log_mel_spectrogram')
                ].groupby(['exp_name','mel_bins'])[['all_mcc','model_fairness']
                                                  ].agg(['mean','std']).T#.style.format('{:.1e}')

In [None]:
# Mean/std per number of MFCCs for MFCC models, 16k experiments.
results_fairness[results_fairness['exp_name'].isin(['sc16_cnn','sc16_llcnn']) 
                 & (results_fairness['input_features']=='mfcc')
                ].groupby(['exp_name','mfccs'])[['all_mcc','model_fairness']
                                                  ].agg(['mean','std']).T#.style.format('{:.1e}')

In [None]:
# Same table for the 8k experiments.
results_fairness[results_fairness['exp_name'].isin(['sc8_cnn','sc8_llcnn']) 
                 & (results_fairness['input_features']=='mfcc')
                ].groupby(['exp_name','mfccs'])[['all_mcc','model_fairness']
                                                  ].agg(['mean','std']).T#.style.format('{:.1e}')

In [None]:
# sc16_cnn: mfcc / log Mel spec | fairest for each
# (values copied by hand from the tables above — re-derive if the results change)
0.016908/0.006664

In [None]:
# sc16_llcnn: mfcc / log Mel spec | fairest for each
# (values copied by hand from the tables above)
0.019317/0.013510

In [None]:
# sc8_cnn: mfcc / log Mel spec | fairest for each
# (values copied by hand from the tables above)
0.028245/0.010909

In [None]:
# sc8_llcnn: mfcc / log Mel spec | fairest for each
# (values copied by hand from the tables above)
0.043740/0.017558

In [None]:
sns.set_theme(style="whitegrid", font_scale=2.2)

# One panel per base experiment: female vs male MCC for every model, coloured by input feature
# type. The dashed y = x line marks equal accuracy for both subgroups.
g = sns.relplot(data=results[results['exp_name'].isin(['sc16_cnn','sc8_cnn','sc16_llcnn','sc8_llcnn'])].sort_values(by=['model_arch','exp_name']),
                x='male_mcc', y='female_mcc', hue='input_features', col="exp_name", col_wrap=4, palette=['purple','aqua'], height=6, aspect=1.1)
for ax in g.axes.flat:
    # Facet titles read "exp_name = sc16_llcnn"; turn that into e.g. "16k llCNN".
    ax_title = ax.get_title().split(' = ')[-1]
    arch = ax_title.split('_')[-1]
    train_sr = ax_title.split('-')[0].split('_')[0].strip('sc')
    ax.axline((0.89,0.89), (0.9,0.9), c="black", ls='--', linewidth=0.8)
    ax.set_title('{}k {}{}'.format(train_sr, arch[:-3], arch[-3:].upper()), fontsize='large', pad=10)
    # Backslashes escaped explicitly: '\m' is an invalid escape sequence (same resulting string).
    ax.set_xlabel('Male accuracy \n $\\mathregular{(MCC_{male})}$', fontsize='large', labelpad=25)
    ax.tick_params(axis='both', which='major', labelsize='medium')
    ax.set_xlim([0.45,0.89])
    ax.set_ylim([0.45,0.89])
    ax.set_ylabel('Female accuracy \n $\\mathregular{(MCC_{female})}$', fontsize='large', labelpad=25)

# Figure-level legend settings: apply once rather than once per facet.
legend = g._legend
legend.set_bbox_to_anchor((0.93, 0.51))
legend.set_frame_on(True)
legend.set_title('Feature type')
# Relabel entries by their hue value instead of by position, so each label is guaranteed to sit
# next to the colour of the feature type it names.
feature_labels = {'mfcc': 'MFCC', 'log_mel_spectrogram': 'log Mel spectrogram'}
for text in legend.texts:
    text.set_text(feature_labels.get(text.get_text(), text.get_text()))
sns.despine(bottom=True, left=True)

plt.suptitle('Accuracy score distributions showing effect of feature type on accuracy for males and females', fontsize='x-large', ha='center', x=0.45, va='top', y=1.1)
plt.savefig('figures/subgroup_performance_featuretype.png', bbox_inches='tight')

#### Input to Table 3

In [None]:
# Top model per experiment by overall MCC (get_top_result with best=1), restricted to the four
# base experiments and showing its configuration, MCC and model bias.
the_best = results_analysis.get_top_result(results_fairness, 'all_mcc', best=1)
the_best[the_best['exp_name'].isin(['sc8_cnn','sc16_cnn','sc8_llcnn','sc16_llcnn'])][
    ['run_name','exp_id', 'exp_name','model_arch','input_features','frame_length','frame_step','mel_bins',
     'mfccs','window_fn','all_mcc','model_fairness']].sort_values(['exp_name','all_mcc'], ascending=False)

In [None]:
# Top model per experiment by model_fairness (get_top_result with best=1).
the_fairest = results_analysis.get_top_result(results_fairness, 'model_fairness', best=1)
the_fairest[the_fairest['exp_name'].isin(['sc8_cnn','sc16_cnn','sc8_llcnn','sc16_llcnn'])]

In [None]:
# Model selected for accuracy + fairness: select_models_in_mcc_range with best=1 and a 0.99
# threshold (presumably the top-ranked model within 99% of the best MCC — confirm the ranking
# in results_analysis). One experiment per cell.
results_analysis.select_models_in_mcc_range(results_fairness, 'sc16_cnn', best=1, min_percentage_of_mcc=0.99)

In [None]:
results_analysis.select_models_in_mcc_range(results_fairness, 'sc16_llcnn', best=1, min_percentage_of_mcc=0.99)

In [None]:
results_analysis.select_models_in_mcc_range(results_fairness, 'sc8_cnn', best=1, min_percentage_of_mcc=0.99)

In [None]:
results_analysis.select_models_in_mcc_range(results_fairness, 'sc8_llcnn', best=1, min_percentage_of_mcc=0.99)

In [None]:
# Reduction in model bias when selecting for accuracy + fairness, rather than accuracy alone
# Values are typed in by hand from the outputs above (presumably bias of the most accurate model
# divided by bias of the accuracy + fairness selection); re-derive them if the results change.
# 16k CNN
print(0.012009 / 0.000765)

# 16k llCNN
# NOTE(review): numerator and denominator are identical (ratio 1.0) — either both strategies picked
# the same model or a value was mis-copied; confirm.
print(0.000658 / 0.000658)

# 8k CNN
print(0.009773 / 0.005914)

# 8k llCNN
print(0.040717 / 0.001824)

In [None]:
# Cost to accuracy when selecting for fairness only
# Ratio of the fairest model's MCC to the best MCC (commented lines show the inverse ratio);
# values copied by hand from the tables above.
# 16k CNN
# print(0.877 / 0.849)
print(0.849/0.877)

# 16k llCNN
# print(0.868 / 0.815)
print(0.815 / 0.868)

# 8k CNN
# print(0.804 / 0.762)
print(0.762 / 0.804)

# 8k llCNN
# print(0.778 / 0.740)
print(0.740 / 0.778)

### Section 5.3: Impact of Pruning Hyperparameters

In [None]:
# Trained-model run ids grouped by why they were picked for compression: highest accuracy,
# lowest bias, and accurate-and-fair. Used to label compression results by selection strategy.
# NOTE(review): best_runs and fairest_runs hold 12 ids each but accurate_fair_runs only 11 —
# confirm whether one run is missing.
best_runs = ['run-1628708435','run-1628757640','run-1628732129','run-1628728028','run-1628733620','run-1628787124','run-1628745492','run-1628778770','run-1628729675','run-1628795284','run-1628762389','run-1628794609']
fairest_runs = ['run-1628769835','run-1628786232','run-1628785897','run-1628742549','run-1628765724','run-1628793147','run-1628763273','run-1628776241','run-1628796466','run-1628758822','run-1628753364','run-1628809058']
accurate_fair_runs = ['run-1628726178','run-1628729666','run-1628733199','run-1628743272','run-1628759090','run-1628715119','run-1628737906','run-1628782987','run-1628799888','run-1628735838','run-1628806790']

In [None]:
# Total number of selected trained models.
len(best_runs)+len(fairest_runs)+len(accurate_fair_runs)

In [None]:
# Attach to every compression result the metrics of the trained model it started from, matched on
# trained_model_path == run_name. On column-name clashes the trained-model copy gets a '_trained'
# suffix and the compression-result columns keep their names.
# NOTE(review): `.iloc[:, 6:22]` picks the trained-model columns by position — confirm it still
# includes run_name and the *_mcc columns if the results schema changes.
trained_results = results[results['exp_name'].isin(exp_names['sc_train'])].iloc[:, 6:22]
results_compress = pd.merge(compression_results, trained_results,
                            how='left', left_on='trained_model_path', right_on='run_name',
                            suffixes=[None, '_trained'])
# results_compress['equal_weighted'].fillna(False, inplace=True)

# Model bias of the compressed model and of the trained model it came from.
results_compress['model_fairness'] = results_compress.apply(
    lambda x: model_unfairness([x['female_mcc'], x['male_mcc']], x['all_mcc']), axis=1)
results_compress['model_fairness_trained'] = results_compress.apply(
    lambda x: model_unfairness([x['female_mcc_trained'], x['male_mcc_trained']], x['all_mcc_trained']), axis=1)

# Effect of compression on each metric: compressed value minus trained value.
for metric in ['all_mcc', 'male_mcc', 'female_mcc', 'model_fairness']:
    results_compress['delta_' + metric] = results_compress[metric] - results_compress[metric + '_trained']

# Why each trained model was selected for compression. Later updates win, giving the precedence
# best > fairest > accurate_fair_runs. Mapping through a dict (rather than np.where with string
# choices, which coerces np.nan to the string 'nan') leaves unlisted runs as real missing values.
selection_reason = {run: 'accurate_fair_runs' for run in accurate_fair_runs}
selection_reason.update({run: 'fairest' for run in fairest_runs})
selection_reason.update({run: 'best' for run in best_runs})
results_compress['model_selected_because'] = results_compress['trained_model_path'].map(selection_reason)

In [None]:
# Keep only the last row for each unique (experiment, trained model, pruning/quantization
# configuration); presumably later rows are re-runs that supersede earlier ones — confirm.
results_compress.drop_duplicates(subset=['exp_name','trained_model_path','pruning_schedule','pruning_learning_rate','pruning_frequency','pruning_final_sparsity','quantize','quantization_optimization'], 
    keep='last', inplace=True)

In [None]:
# Pruning factors analysed (with readable names for the plots), and degrees of freedom / F critical
# value (0.01, presumably the significance level) with and without model_arch as a factor.
# fcrit_pruning (without model_arch) is the one passed to the importance plots below.
pruning_params = ['trained_model_path','pruning_learning_rate','pruning_schedule', 'pruning_frequency', 'pruning_final_sparsity','model_arch']
pruning_params_pretty = ['trained model','learning rate','schedule', 'frequency', 'final sparsity','architecture']
dof_pruning_all_arch, fcrit_pruning_all_arch = results_analysis.fcrit(results_compress, pruning_params, 0.01)
dof_pruning, fcrit_pruning = results_analysis.fcrit(results_compress, pruning_params[:-1], 0.01)
print('dof:', dof_pruning, '\nfcrit:', fcrit_pruning)

In [None]:
# Importance of each compression parameter for model bias, equal-weighted runs only:
# CNN architecture, 16k pruning table.
results_analysis.generate_importance_tables(results_compress[results_compress.equal_weighted==True], 
                                            'model_fairness', parameters="compression", model_arch='cnn')['16000_prune']

In [None]:
# CNN, 8k pruning table.
results_analysis.generate_importance_tables(results_compress[results_compress.equal_weighted==True], 
                                            'model_fairness', parameters="compression", model_arch='cnn')['8000_prune']

In [None]:
# Low-latency CNN, 16k pruning table.
results_analysis.generate_importance_tables(results_compress[results_compress.equal_weighted==True], 
                                            'model_fairness', parameters="compression", model_arch='low_latency_cnn')['16000_prune']

In [None]:
# Low-latency CNN, 8k pruning table.
results_analysis.generate_importance_tables(results_compress[results_compress.equal_weighted==True], 
                                            'model_fairness', parameters="compression", model_arch='low_latency_cnn')['8000_prune']

In [None]:
# Pruning parameter importance for overall MCC and model bias (equal-weighted runs only).
# NOTE(review): save_fig=True here and the explicit savefig below may both write a figure — confirm
# which output is wanted.
sns.set_theme(style="whitegrid", font_scale=1.8)
results_plot.plot_param_importance(df=results_compress[(results_compress.equal_weighted==True)
                                                       # & (results_compress.model_selected_because=='accurate_fair_runs') #select from: best, fairest, accurate_fair_runs
                                                      ],
                                   metrics={'all_mcc':'MCC','model_fairness':'bias'}, 
                                   parameters="compression", 
                                   select_tables=['8000_prune','16000_prune'],
                                   fcrit=fcrit_pruning, 
                                   save_fig=True,
                                   pretty_params=dict(zip(pruning_params, pruning_params_pretty)),
                                   plot_title = "Pruning parameter importance for reliability bias and accuracy (MCC)")
plt.savefig('figures/pruning_param_importance.png',bbox_inches='tight')

In [None]:
# Same analysis for the change in MCC and bias caused by compression (delta_* columns).
results_plot.plot_param_importance(df=results_compress[(results_compress.equal_weighted==True)
                                                       # & (results_compress.model_selected_because=='best')
                                                      ],
                                   metrics={'delta_all_mcc':'delta MCC','delta_model_fairness':'delta bias'}, 
                                   parameters="compression", 
                                   select_tables=['8000_prune','16000_prune'],
                                   fcrit=fcrit_pruning,  
                                   pretty_params=dict(zip(pruning_params, pruning_params_pretty)),
                                   plot_title = "Pruning parameter importance for change in accuracy and reliability bias")

In [None]:
sns.set_theme(style="whitegrid", font_scale=2.2)

# One panel per compression experiment (equal-weighted runs only): female vs male MCC of each
# compressed model, coloured by pruning learning rate. Dashed y = x line: equal subgroup accuracy.
g = sns.relplot(data=results_compress[(results_compress.equal_weighted==True)
                                      # & (results_compress.exp_name.isin(['sc8_cnn-compress_ew','sc8_llcnn-compress_ew']))
                                      # & (results_compress.model_selected_because=='fairest')
                                     ].sort_values(by=['model_arch','exp_name']),
                x='male_mcc', y='female_mcc', hue='pruning_learning_rate',
                col="exp_name", col_wrap=4,
                palette=['red','purple','aqua'],
                height=6, aspect=1.1)
for ax in g.axes.flat:
    # Facet titles read "exp_name = sc8_llcnn-compress_ew"; turn that into e.g. "8k llCNN".
    ax_title = ax.get_title().split(' = ')[-1]
    arch = ax_title.split('-')[0].split('_')[-1]
    train_sr = ax_title.split('-')[0].split('_')[0].strip('sc')
    ax.axline((0.89,0.89), (0.9,0.9), c="black", ls='--', linewidth=0.8)
    ax.set_title('{}k {}{}'.format(train_sr, arch[:-3], arch[-3:].upper()), fontsize='large', pad=10)
    # Backslashes escaped explicitly: '\m' is an invalid escape sequence (same resulting string).
    ax.set_xlabel('Male accuracy \n $\\mathregular{(MCC_{male})}$', fontsize='large', labelpad=25)
    ax.tick_params(axis='both', which='major', labelsize='medium')
    ax.set_xlim([0.73,0.91])
    ax.set_ylim([0.73,0.91])
    ax.set_ylabel('Female accuracy \n $\\mathregular{(MCC_{female})}$', fontsize='large', labelpad=25)

# Figure-level legend settings: apply once rather than once per facet.
legend = g._legend
legend.set_bbox_to_anchor((0.95, 0.56))
legend.set_frame_on(True)
legend.set_title('Pruning \nLearning Rate')
sns.despine(bottom=True, left=True)

plt.suptitle("Accuracy scores showing effect of pruning learning rate on males and females", fontsize='large', ha='center', x=0.42, va='top', y=1.15)
plt.savefig('figures/subgroup_performance_pruning_learning_rate.png', bbox_inches='tight')

In [None]:
sns.set_theme(style="whitegrid", font_scale=2.2)

# One panel per compression experiment (equal-weighted runs only): change in model bias against
# change in overall MCC caused by compression. Colour and size encode the compressed model's MCC;
# marker shape encodes why its trained model was selected.
g = sns.relplot(data=results_compress[(results_compress.equal_weighted==True)
                                      # & (results_compress.model_selected_because=='fairest')
                                     ].sort_values(by=['model_arch','exp_name']),
                x='delta_all_mcc', y='delta_model_fairness', hue='all_mcc', size='all_mcc',
                col="exp_name", col_wrap=4,
                palette='hsv',
                style='model_selected_because',
                markers=['o','*','^'],
                height=6, aspect=1.1)
for ax in g.axes.flat:
    # Facet titles read "exp_name = sc8_llcnn-compress_ew"; turn that into e.g. "8k llCNN".
    ax_title = ax.get_title().split(' = ')[-1]
    arch = ax_title.split('-')[0].split('_')[-1]
    train_sr = ax_title.split('-')[0].split('_')[0].strip('sc')
    # Zero reference lines: points left of x=0 lost accuracy, points above y=0 became more biased.
    ax.axhline(y=0, color="black", ls='-', linewidth=0.5)
    ax.axvline(x=0, color="black", ls='-', linewidth=0.5)
    ax.set_title('{}k {}{}'.format(train_sr, arch[:-3], arch[-3:].upper()), fontsize='large', pad=10)
    # Axis labels describe the plotted delta columns (compressed minus trained value).
    ax.set_xlabel('Change in accuracy \n ($\\Delta$ MCC)', fontsize='large', labelpad=25)
    ax.tick_params(axis='both', which='major', labelsize='medium')
    ax.set_ylabel('Change in reliability bias \n ($\\Delta$ bias)', fontsize='large', labelpad=25)

legend = g._legend
legend.set_bbox_to_anchor((0.98,0.56))
legend.set_frame_on(True)
legend.set_title('MCC score')
sns.despine(bottom=True, left=True)

plt.suptitle("Change in accuracy and reliability bias after pruning, by model selection strategy", fontsize='x-large', ha='center', x=0.45, va='top', y=1.1)
# Own file name: this figure previously reused the pruning-learning-rate figure's path and overwrote it.
plt.savefig('figures/pruning_delta_mcc_vs_delta_bias.png', bbox_inches='tight')

In [None]:
# One panel per compression experiment (equal-weighted runs only): density of the change in overall
# MCC caused by compression, with one curve per strategy used to select the trained model.
g = sns.FacetGrid(results_compress[(results_compress.equal_weighted==True) 
                                   # & (results_compress.model_selected_because=='best')
                                  ].sort_values(by=['model_arch','exp_name']),
                  col="exp_name", col_wrap=4, height=5, aspect=1.2, hue="model_selected_because", palette=['red','blue','orange']
                 )
# NOTE(review): `cut` is in units of KDE bandwidth (the curve extends 0.91 bandwidths past the
# data) — confirm this value is intended.
g.map(sns.kdeplot, "delta_all_mcc", cut=0.91, linewidth=2)
# g.map(sns.kdeplot, "model_fairness")
g.add_legend(title='Selection strategy', bbox_to_anchor=(0.96, 0.58), frameon=True, borderpad=0.6)
# Relabel entries by hue value rather than by position: with no explicit hue_order the level order
# follows the data, so positional relabelling can attach a label to the wrong curve.
strategy_labels = {'fairest': 'fairest', 'accurate_fair_runs': 'accurate & fair', 'best': 'best accuracy'}
for text in g._legend.texts:
    text.set_text(strategy_labels.get(text.get_text(), text.get_text()))
for ax in g.axes.flat:
    # Facet titles read "exp_name = sc8_llcnn-compress_ew"; turn that into e.g. "8k llCNN".
    ax_title = ax.get_title().split(' = ')[-1]
    arch = ax_title.split('-')[0].split('_')[-1]
    train_sr = ax_title.split('-')[0].split('_')[0].strip('sc')
    ax.set_title('{}k {}{}'.format(train_sr, arch[:-3], arch[-3:].upper()), fontsize='large', pad=10)
    ax.set_xlabel('Change in accuracy (MCC)', labelpad=15)
    ax.set_ylabel('Density', labelpad=15)
sns.despine(left=True)

title = plt.suptitle('Score distributions', fontsize='x-large', va='top', y=1.1)
plt.savefig('figures/model_selection_delta_mcc_density.png',bbox_inches='tight')

In [None]:
# One panel per compression experiment (equal-weighted runs only): density of the change in model
# bias caused by compression, with one curve per strategy used to select the trained model.
g = sns.FacetGrid(results_compress[(results_compress.equal_weighted==True) 
                                   # & (results_compress.exp_name.isin(['sc8_cnn-compress_ew','sc8_llcnn-compress_ew']))
                                   # & (results_compress.model_selected_because=='best')
                                  ].sort_values(by=['model_arch','exp_name','model_selected_because']),
                  col="exp_name", col_wrap=4, height=6, aspect=1.1, hue="model_selected_because", palette=['red','blue','orange']
                 )
g.map(sns.kdeplot, "delta_model_fairness")
g.add_legend(title='Selection strategy', bbox_to_anchor=(0.95, 0.58), frameon=True, borderpad=0.6)
# Relabel entries by hue value rather than by position, so each label always names its own curve.
strategy_labels = {'accurate_fair_runs': 'accurate + inclusive', 'best': 'accuracy', 'fairest': 'inclusion'}
for text in g._legend.texts:
    text.set_text(strategy_labels.get(text.get_text(), text.get_text()))
for ax in g.axes.flat:
    # Zero reference: density to the right of it means compression increased model bias.
    # NOTE(review): vlines spans data y in [0, 80], which also stretches the y-limits to 80.
    ax.vlines(x=0, ymin=0, ymax=80, color="black", ls='--', linewidth=2)
    # Facet titles read "exp_name = sc8_llcnn-compress_ew"; turn that into e.g. "8k llCNN".
    ax_title = ax.get_title().split(' = ')[-1]
    arch = ax_title.split('-')[0].split('_')[-1]
    train_sr = ax_title.split('-')[0].split('_')[0].strip('sc')
    ax.set_title('{}k {}{}'.format(train_sr, arch[:-3], arch[-3:].upper()), fontsize='large', pad=10)
    # '\\D' escaped explicitly: '\D' is an invalid escape sequence (same resulting string).
    ax.set_xlabel('$\\Delta$ reliability bias \n', labelpad=15)
    ax.set_ylabel('Density', labelpad=15)
sns.despine(left=True)

title = plt.suptitle(r'Density distributions of $\Delta$ reliability bias for model selection strategies', fontsize='large', ha='center', va='top', y=1.15)
plt.savefig('figures/model_selection_delta_fairness_density.png',bbox_inches='tight')

In [None]:
# Subset of the compression results used for sparsity-level model selection. For every
# (equal_weighted, experiment, final sparsity) group, the group's best MCC and lowest model bias
# are broadcast to each of its rows.
sparsity_keys = ['equal_weighted', 'exp_name', 'pruning_final_sparsity']
results_sparsity = results_compress.loc[:, ['equal_weighted','exp_name','pruning_final_sparsity',
                                            'pruning_learning_rate','pruning_schedule','pruning_frequency',
                                            'trained_model_path','model_selected_because','all_mcc','model_fairness']]
sparsity_groups = results_sparsity.groupby(sparsity_keys)
# `assign` returns a new, independent frame, so adding columns cannot raise SettingWithCopyWarning.
results_sparsity = results_sparsity.assign(
    sparsity_max_mcc=sparsity_groups['all_mcc'].transform('max'),
    sparsity_min_fairness=sparsity_groups['model_fairness'].transform('min'),
)

In [None]:
def fairest_pruned_model_in_mcc_range(df, exp_name, pruning_final_sparsity, min_percentage_of_mcc=0.99):
    """Return the least-biased model whose MCC is close to the best of its group.

    Restricts ``df`` to one experiment at one final pruning sparsity and keeps the rows whose
    ``all_mcc`` is at least ``min_percentage_of_mcc`` times the group's best ``all_mcc``. From those
    rows it returns the one with the lowest ``model_fairness`` (reliability bias; lower is better).

    Args:
        df: results table with ``exp_name``, ``pruning_final_sparsity``, ``all_mcc`` and
            ``model_fairness`` columns.
        exp_name: experiment to select from.
        pruning_final_sparsity: final sparsity to select from (exact match).
        min_percentage_of_mcc: fraction of the group's best MCC a model must reach (1 keeps only
            the best-MCC models; 0 keeps every model with non-negative MCC).

    Returns:
        pandas.Series: the selected row. Its name is its position among the qualifying rows.

    Raises:
        ValueError: if no row with a non-missing ``all_mcc`` matches ``exp_name`` and
            ``pruning_final_sparsity``.
    """
    in_group = (df.exp_name == exp_name) & (df.pruning_final_sparsity == pruning_final_sparsity)
    group_mcc = df.loc[in_group, 'all_mcc']
    if group_mcc.isna().all():  # also true when the group is empty
        raise ValueError(
            f"no results with all_mcc for exp_name={exp_name!r} at "
            f"pruning_final_sparsity={pruning_final_sparsity!r}")
    min_mcc = group_mcc.max() * min_percentage_of_mcc

    # reset_index gives the returned row its position within the candidates as its name.
    candidates = df[in_group & (df.all_mcc >= min_mcc)].reset_index(drop=True)
    return candidates.iloc[candidates['model_fairness'].idxmin()]

In [None]:
def pruned_model_selection(df, min_percentage_of_mcc,
                           arch_prefixes=('sc8_', 'sc8_ll', 'sc16_', 'sc16_ll'),
                           sparsities=(0.2, 0.5, 0.75, 0.8, 0.85, 0.9)):
    """Pick one model per (compression experiment, final sparsity) from the equal-weighted runs.

    For every experiment ``<prefix>cnn-compress_ew`` and every final sparsity, picks the least-biased
    model whose MCC is within ``min_percentage_of_mcc`` of that group's best, using
    ``fairest_pruned_model_in_mcc_range``.

    Args:
        df: compression results with an ``equal_weighted`` column plus the columns required by
            ``fairest_pruned_model_in_mcc_range``.
        min_percentage_of_mcc: fraction of each group's best MCC a model must reach.
        arch_prefixes: experiment-name prefixes, iterated in this order (outer loop).
        sparsities: final pruning sparsities, iterated in this order (inner loop).

    Returns:
        pandas.DataFrame: one row per (experiment, sparsity), indexed by each selected row's name.
    """
    equal_weighted_runs = df[df.equal_weighted == True]  # noqa: E712 -- column may be object dtype
    selected = [
        fairest_pruned_model_in_mcc_range(equal_weighted_runs, prefix + 'cnn-compress_ew',
                                          sparsity, min_percentage_of_mcc)
        for prefix in arch_prefixes
        for sparsity in sparsities
    ]
    # Build the frame in one go from the selected rows. DataFrame.append was removed in pandas 2.0,
    # and calling it in a loop was quadratic anyway.
    return pd.DataFrame(selected)

In [None]:
# Mean/std of overall MCC and model bias per compression experiment, over the models picked at each
# final sparsity by pruned_model_selection. The second argument is the MCC threshold as a fraction of
# each group's best: 1 = best-MCC models only, 0.995 / 0.99 = within 0.5% / 1%,
# 0 = any model with MCC >= 0.
pruned_model_selection(results_sparsity, 1).groupby('exp_name').agg({'all_mcc':['mean','std'],'model_fairness':['mean','std']})

In [None]:
pruned_model_selection(results_sparsity, 0.995).groupby('exp_name').agg({'all_mcc':['mean','std'],'model_fairness':['mean','std']})

In [None]:
# pd.set_option('display.float_format', '{:.1e}'.format)
pruned_model_selection(results_sparsity, 0.99).groupby('exp_name').agg({'all_mcc':['mean','std'],'model_fairness':['mean','std']})

In [None]:
# NOTE: this changes pandas' float display format globally (scientific notation); it stays in effect
# until the next cell resets it.
pd.set_option('display.float_format', '{:.1e}'.format)
pruned_model_selection(results_sparsity, 0).groupby('exp_name').agg({'all_mcc':['mean','std'],'model_fairness':['mean','std']})

In [None]:
# Reset the float display format, then list the models selected at the 0.99 threshold ordered by
# experiment and final sparsity.
pd.set_option('display.float_format', '{:.4}'.format)

acc_df = pruned_model_selection(results_sparsity, 0.99).sort_values(['exp_name','pruning_final_sparsity'])
acc_df

In [None]:
# How often each pruning learning rate was chosen, per experiment.
acc_df.groupby(['exp_name','pruning_learning_rate'])['all_mcc'].count()

In [None]:
# How often each pruning frequency was chosen, per experiment.
acc_df.groupby(['exp_name','pruning_frequency'])['all_mcc'].count()

In [None]:
# How often each pruning schedule was chosen, per experiment.
acc_df.groupby(['exp_name','pruning_schedule'])['all_mcc'].count()

In [None]:
# How often each trained model was chosen, per experiment.
acc_df.groupby(['exp_name','trained_model_path'])['all_mcc'].count()

In [None]:
# How often each trained-model selection strategy was chosen overall.
acc_df.groupby(['model_selected_because'])['all_mcc'].count()