In [1]:
# Render matplotlib figures inline in the cell output.
%matplotlib inline
# Load the autoreload extension and use mode 2 so that edited project modules
# are re-imported before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import warnings
# Silence every warning category for the rest of this kernel session.
# NOTE(review): this also hides deprecation and convergence warnings from the
# libraries used below -- confirm that blanket suppression is intended.
warnings.filterwarnings('ignore')
# Export the same setting through the environment so that Python child
# processes started from this kernel read it at startup.
os.environ["PYTHONWARNINGS"] = "ignore"

# Visualizations

In [39]:
import os
import pandas as pd

from virny.utils.custom_initializers import create_models_metrics_dct_from_database_df

from source.utils.db_functions import read_model_metric_dfs_from_db
from source.custom_classes.experiments_composer import ExperimentsComposer
from source.custom_classes.experiments_visualizer import ExperimentsVisualizer

## Initialize Configs

In [4]:
# Experiment identity; the results collection name is derived from the
# experiment name by appending the "_results" suffix.
EXPERIMENT_NAME = "stress_testing_nulls"
DB_COLLECTION_NAME = EXPERIMENT_NAME + "_results"

# Dataset label handed to the visualizer, and the session UUID that is passed
# to the database read to select the stored results.
DATASET_NAME = "Folktables_GA_2018"
EXPERIMENT_SESSION_UUID = "c53d250b-5ba9-4d91-a444-ed7eb7919de5"

# Three single attributes followed by their '&'-joined combination.
SENSITIVE_ATTRS = [
    "SEX",
    "RAC1P",
    "AGEP",
    "SEX&RAC1P&AGEP",
]

In [5]:
# Models covered by this notebook; the plotting cells select from this list by
# position (MODEL_NAMES[0], MODEL_NAMES[:2]), so the order matters.
MODEL_NAMES = [
    "DecisionTreeClassifier",
    "LogisticRegression",
    "RandomForestClassifier",
    "XGBClassifier",
    "KNeighborsClassifier",
    "MLPClassifier",
]

In [6]:
# NOTE(review): this import lives outside the main import cell; consider moving
# it there so all dependencies are visible in one place.
from source.utils.db_functions import connect_to_mongodb

# Open the connection for the results collection. The call returns three values:
# the client (closed later, after the read), the collection handle used for the
# read, and a writer callable that is not used anywhere in this notebook.
client, collection_obj, db_writer_func = connect_to_mongodb(DB_COLLECTION_NAME)

## Group Metrics Composition

In [7]:
# Read the metrics stored for this experiment session and convert them into
# `models_metrics_dct`, which the rest of the notebook works from.
# The client is closed in `finally` so the connection is released even when the
# read or the conversion raises.
# NOTE: `client` is closed once this cell finishes, so re-run the connection
# cell before re-running this one.
try:
    model_metric_dfs = read_model_metric_dfs_from_db(collection_obj, EXPERIMENT_SESSION_UUID)
    models_metrics_dct = create_models_metrics_dct_from_database_df(model_metric_dfs)
finally:
    client.close()

In [9]:
# Shape of the metrics frame stored under the first key of the dictionary.
models_metrics_dct[list(models_metrics_dct)[0]].shape

(224, 27)

In [8]:
# First 20 rows of the metrics frame stored under the first key of the dictionary.
models_metrics_dct[list(models_metrics_dct)[0]].head(20)

Unnamed: 0,Metric,Bootstrap_Model_Seed,Model_Name,Model_Params,Run_Number,Dataset_Name,Num_Estimators,Test_Set_Index,Tag,Record_Create_Date_Time,...,Injector_Config_Lst,AGEP_dis,AGEP_priv,RAC1P_dis,RAC1P_priv,SEX&RAC1P&AGEP_dis,SEX&RAC1P&AGEP_priv,SEX_dis,SEX_priv,overall
0,Accuracy,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,0,OK,2023-04-22 14:09:26.705,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.770941,0.819535,0.771368,0.782867,0.737063,0.905405,0.746923,0.813958,0.7791
1,Accuracy,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,1,OK,2023-04-22 14:09:26.716,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.768537,0.815962,0.76801,0.780637,0.735664,0.899614,0.745962,0.809583,0.7765
2,Accuracy,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,2,OK,2023-04-22 14:09:26.728,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.767456,0.81358,0.766484,0.779447,0.734965,0.893822,0.746154,0.806667,0.7752
3,Accuracy,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,3,OK,2023-04-22 14:09:26.745,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.762889,0.805241,0.7616,0.774093,0.734266,0.8861,0.743269,0.798958,0.77
4,Accuracy,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,4,OK,2023-04-22 14:09:26.758,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.757481,0.794521,0.757021,0.766954,0.730769,0.874517,0.738654,0.790833,0.7637
5,Accuracy,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,5,OK,2023-04-22 14:09:26.769,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.752073,0.787373,0.749695,0.762046,0.724476,0.862934,0.734423,0.783542,0.758
6,Accuracy,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,6,OK,2023-04-22 14:09:26.782,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.747987,0.777844,0.741758,0.758477,0.720979,0.861004,0.730385,0.7775,0.753
14,Accuracy,102,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_2,Folktables_GA_2018,10,0,OK,2023-04-22 14:12:05.983,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.735128,0.799285,0.739316,0.749108,0.730769,0.880309,0.744615,0.747292,0.7459
15,Accuracy,102,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_2,Folktables_GA_2018,10,1,OK,2023-04-22 14:12:05.993,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.733806,0.79869,0.737179,0.748364,0.730769,0.882239,0.742885,0.746667,0.7447
16,Accuracy,102,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_2,Folktables_GA_2018,10,2,OK,2023-04-22 14:12:06.004,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.732604,0.796903,0.733822,0.748067,0.726573,0.880309,0.741154,0.745833,0.7434


In [77]:
# Build the composer from the loaded per-model metrics and the sensitive attributes.
exp_composer = ExperimentsComposer(models_metrics_dct, SENSITIVE_ATTRS)
# The result is looked up below as [model name][experiment iteration][float key]
# and each leaf supports `.head()`.
# NOTE(review): the float key presumably is the injected-nulls fraction from
# `Injector_Config_Lst` -- confirm against ExperimentsComposer.
exp_subgroup_metrics_dct = exp_composer.create_exp_subgroup_metrics_dct()

In [42]:
# Inspect the subgroup metrics stored under the 0.5 key of the first experiment
# iteration. The model name is taken from MODEL_NAMES (instead of a repeated
# literal) so this cell stays in sync with the plotting cells, which also use
# MODEL_NAMES[0].
exp_subgroup_metrics_dct[MODEL_NAMES[0]]['Exp_iter_1'][0.5].head(100)

Unnamed: 0,Metric,Bootstrap_Model_Seed,Model_Name,Model_Params,Run_Number,Dataset_Name,Num_Estimators,Test_Set_Index,Tag,Record_Create_Date_Time,...,Injector_Config_Lst,AGEP_dis,AGEP_priv,RAC1P_dis,RAC1P_priv,SEX&RAC1P&AGEP_dis,SEX&RAC1P&AGEP_priv,SEX_dis,SEX_priv,overall
6,Accuracy,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,6,OK,2023-04-22 14:09:26.782,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.747987,0.777844,0.741758,0.758477,0.720979,0.861004,0.730385,0.7775,0.753
20,Accuracy,102,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_2,Folktables_GA_2018,10,6,OK,2023-04-22 14:12:06.054,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.720707,0.762954,0.716117,0.733492,0.702797,0.874517,0.713462,0.743333,0.7278
34,Entropy,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,6,OK,2023-04-22 14:09:26.782,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.14539,0.0,0.0,0.129821,0.0,0.0,0.103232,0.0,0.135128
48,Entropy,102,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_2,Folktables_GA_2018,10,6,OK,2023-04-22 14:12:06.054,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.0,0.0,0.229812,0.0,0.0,0.0,0.0,0.241828,0.0
62,F1,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,6,OK,2023-04-22 14:09:26.782,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.760917,0.860038,0.773918,0.788926,0.737327,0.917051,0.759602,0.809422,0.784015
76,F1,102,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_2,Folktables_GA_2018,10,6,OK,2023-04-22 14:12:06.054,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.739345,0.848899,0.745345,0.773107,0.695341,0.927048,0.727006,0.797768,0.764329
90,FNR,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,6,OK,2023-04-22 14:09:26.782,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.228082,0.116423,0.200442,0.203203,0.222222,0.099548,0.191606,0.2125,0.202313
104,FNR,102,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_2,Folktables_GA_2018,10,6,OK,2023-04-22 14:12:06.054,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.237567,0.138011,0.248482,0.198477,0.326389,0.065611,0.275912,0.15625,0.214591
118,FPR,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,6,OK,2023-04-22 14:09:26.782,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.277889,0.581152,0.329693,0.291595,0.33662,0.368421,0.356504,0.2375,0.304338
132,FPR,102,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_2,Folktables_GA_2018,10,6,OK,2023-04-22 14:12:06.054,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.324412,0.573298,0.327645,0.355403,0.267606,0.473684,0.298374,0.407292,0.346119


In [43]:
# Inspect the subgroup metrics stored under the 0.05 key of the first experiment
# iteration. The model name is taken from MODEL_NAMES (instead of a repeated
# literal) so this cell stays in sync with the plotting cells, which also use
# MODEL_NAMES[0].
exp_subgroup_metrics_dct[MODEL_NAMES[0]]['Exp_iter_1'][0.05].head(100)

Unnamed: 0,Metric,Bootstrap_Model_Seed,Model_Name,Model_Params,Run_Number,Dataset_Name,Num_Estimators,Test_Set_Index,Tag,Record_Create_Date_Time,...,Injector_Config_Lst,AGEP_dis,AGEP_priv,RAC1P_dis,RAC1P_priv,SEX&RAC1P&AGEP_dis,SEX&RAC1P&AGEP_priv,SEX_dis,SEX_priv,overall
1,Accuracy,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,1,OK,2023-04-22 14:09:26.716,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.768537,0.815962,0.76801,0.780637,0.735664,0.899614,0.745962,0.809583,0.7765
15,Accuracy,102,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_2,Folktables_GA_2018,10,1,OK,2023-04-22 14:12:05.993,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.733806,0.79869,0.737179,0.748364,0.730769,0.882239,0.742885,0.746667,0.7447
29,Entropy,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,1,OK,2023-04-22 14:09:26.716,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.136834,0.0,0.0,0.119452,0.0,0.0,0.088621,0.0,0.1242
43,Entropy,102,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_2,Folktables_GA_2018,10,1,OK,2023-04-22 14:12:05.993,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.0,0.0,0.171997,0.0,0.0,0.0,0.0,0.18547,0.0
57,F1,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,1,OK,2023-04-22 14:09:26.716,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.788816,0.888003,0.805428,0.815001,0.761364,0.941834,0.781545,0.843278,0.811853
71,F1,102,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_2,Folktables_GA_2018,10,1,OK,2023-04-22 14:12:05.993,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.762058,0.876732,0.775956,0.79386,0.742475,0.931844,0.770787,0.804439,0.78815
85,FNR,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,1,OK,2023-04-22 14:09:26.716,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.167939,0.055513,0.131419,0.14702,0.1625,0.047511,0.137591,0.146181,0.141993
99,FNR,102,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_2,Folktables_GA_2018,10,1,OK,2023-04-22 14:12:05.993,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.179505,0.073246,0.176698,0.144657,0.229167,0.056561,0.179562,0.131597,0.154982
113,FPR,101,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_1,Folktables_GA_2018,10,1,OK,2023-04-22 14:09:26.716,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.30015,0.620419,0.356314,0.313894,0.367606,0.407895,0.38374,0.256771,0.328082
127,FPR,102,DecisionTreeClassifier,"{'ccp_alpha': 0.0, 'class_weight': None, 'crit...",Run_2,Folktables_GA_2018,10,1,OK,2023-04-22 14:12:05.993,...,"[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]",0.35993,0.636126,0.369283,0.391424,0.309859,0.473684,0.343496,0.435937,0.384018


In [44]:
# Derive group-level metrics from the subgroup metrics. The result is looked up
# below with the same nesting: [model name][experiment iteration][float key].
exp_avg_group_metrics_dct = exp_composer.compose_group_metrics(exp_subgroup_metrics_dct)

In [45]:
# Inspect the group metrics stored under the 0.5 key of the first experiment
# iteration. The model name is taken from MODEL_NAMES (instead of a repeated
# literal) so this cell stays in sync with the plotting cells, which also use
# MODEL_NAMES[0].
exp_avg_group_metrics_dct[MODEL_NAMES[0]]['Exp_iter_1'][0.5].head(100)

Unnamed: 0,Metric,SEX,RAC1P,AGEP,SEX&RAC1P&AGEP,Model_Name
0,Equalized_Odds_TPR,-0.049384,-0.023622,-0.105608,-0.191726,DecisionTreeClassifier
1,Equalized_Odds_FPR,0.005043,0.00517,-0.276075,-0.11894,DecisionTreeClassifier
2,Disparate_Impact,1.028784,0.99492,1.002776,1.03414,DecisionTreeClassifier
3,Statistical_Parity_Difference,0.029663,-0.005318,0.002895,0.033792,DecisionTreeClassifier
4,Accuracy_Parity,-0.038494,-0.017047,-0.036052,-0.155873,DecisionTreeClassifier
5,Label_Stability_Ratio,1.068731,0.982722,0.935809,0.96099,DecisionTreeClassifier
6,IQR_Parity,-0.043025,0.001928,0.001826,-0.036516,DecisionTreeClassifier
7,Std_Parity,-0.027653,0.003487,0.005419,-0.013523,DecisionTreeClassifier
8,Std_Ratio,0.762326,1.034582,1.055606,0.86851,DecisionTreeClassifier
9,Jitter_Parity,-0.038243,0.00897,0.043678,0.028686,DecisionTreeClassifier


In [46]:
# Inspect the group metrics stored under the 0.3 key of the first experiment
# iteration. The model name is taken from MODEL_NAMES (instead of a repeated
# literal) so this cell stays in sync with the plotting cells, which also use
# MODEL_NAMES[0].
exp_avg_group_metrics_dct[MODEL_NAMES[0]]['Exp_iter_1'][0.3].head(100)

Unnamed: 0,Metric,SEX,RAC1P,AGEP,SEX&RAC1P&AGEP,Model_Name
0,Equalized_Odds_TPR,-0.033765,-0.010951,-0.108343,-0.163276,DecisionTreeClassifier
1,Equalized_Odds_FPR,0.010855,0.007839,-0.282018,-0.102039,DecisionTreeClassifier
2,Disparate_Impact,1.049886,1.009298,1.005764,1.078998,DecisionTreeClassifier
3,Statistical_Parity_Difference,0.052727,0.010036,0.006213,0.078909,DecisionTreeClassifier
4,Accuracy_Parity,-0.035176,-0.011721,-0.046761,-0.153105,DecisionTreeClassifier
5,Label_Stability_Ratio,1.0685,0.988394,0.924832,0.959175,DecisionTreeClassifier
6,IQR_Parity,-0.03971,0.001162,0.004776,-0.031439,DecisionTreeClassifier
7,Std_Parity,-0.027532,0.002593,0.008413,-0.012046,DecisionTreeClassifier
8,Std_Ratio,0.756403,1.026503,1.091746,0.877404,DecisionTreeClassifier
9,Jitter_Parity,-0.039716,0.005379,0.052325,0.029466,DecisionTreeClassifier


## Metrics Visualization and Reporting

In [92]:
# Build the visualizer from both metric dictionaries computed above, together
# with the dataset label, the full model list and the sensitive attributes.
# All plotting cells below go through this one instance.
visualizer = ExperimentsVisualizer(exp_subgroup_metrics_dct=exp_subgroup_metrics_dct,
                                   exp_avg_runs_group_metrics_dct=exp_avg_group_metrics_dct,
                                   dataset_name=DATASET_NAME,
                                   model_names=MODEL_NAMES,
                                   sensitive_attrs=SENSITIVE_ATTRS)

### Subgroup metrics per dataset, experiment iteration, and model

In [93]:
# Subgroup metrics of type 'variance' for the first model in MODEL_NAMES,
# experiment iteration 'Exp_iter_1'.
visualizer.create_subgroups_grid_pct_lines_plot(model_name=MODEL_NAMES[0],
                                                exp_iter='Exp_iter_1',
                                                subgroup_metrics_type='variance')

In [94]:
# Subgroup metrics of type 'error' for the first model in MODEL_NAMES,
# experiment iteration 'Exp_iter_1'.
visualizer.create_subgroups_grid_pct_lines_plot(model_name=MODEL_NAMES[0],
                                                exp_iter='Exp_iter_1',
                                                subgroup_metrics_type='error')

### Group metrics per dataset, experiment iteration, and model

In [95]:
# Group metrics of type 'variance' for the first model in MODEL_NAMES,
# experiment iteration 'Exp_iter_1'.
visualizer.create_groups_grid_pct_lines_plot(model_name=MODEL_NAMES[0],
                                             exp_iter='Exp_iter_1',
                                             group_metrics_type='variance')

In [96]:
# Group metrics of type 'fairness' for the first model in MODEL_NAMES,
# experiment iteration 'Exp_iter_1'.
visualizer.create_groups_grid_pct_lines_plot(model_name=MODEL_NAMES[0],
                                             exp_iter='Exp_iter_1',
                                             group_metrics_type='fairness')

### Specific subgroup metric per dataset, experiment iteration, and multiple models

In [97]:
# The 'Jitter' subgroup metric for the first two models in MODEL_NAMES,
# experiment iteration 'Exp_iter_1'.
visualizer.create_subgroups_grid_pct_lines_per_model_plot(subgroup_metric='Jitter',
                                                          model_names=MODEL_NAMES[:2],
                                                          exp_iter='Exp_iter_1')

In [100]:
# The 'F1' subgroup metric for the first two models in MODEL_NAMES,
# experiment iteration 'Exp_iter_1'.
visualizer.create_subgroups_grid_pct_lines_per_model_plot(subgroup_metric='F1',
                                                          model_names=MODEL_NAMES[:2],
                                                          exp_iter='Exp_iter_1')

### Specific group metric per dataset, experiment iteration, and multiple models

In [101]:
# The 'Label_Stability_Ratio' group metric for the first two models in
# MODEL_NAMES, experiment iteration 'Exp_iter_1'.
visualizer.create_groups_grid_pct_lines_per_model_plot(group_metric='Label_Stability_Ratio',
                                                       model_names=MODEL_NAMES[:2],
                                                       exp_iter='Exp_iter_1')

In [102]:
# The 'Disparate_Impact' group metric for the first two models in MODEL_NAMES,
# experiment iteration 'Exp_iter_1'.
visualizer.create_groups_grid_pct_lines_per_model_plot(group_metric='Disparate_Impact',
                                                       model_names=MODEL_NAMES[:2],
                                                       exp_iter='Exp_iter_1')