In [1]:
# Render matplotlib figures inline, and automatically reload edited imported modules
# (e.g. the virny package) before each cell runs.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import os
import warnings
warnings.filterwarnings('ignore')
os.environ["PYTHONWARNINGS"] = "ignore"

In [3]:
def ensure_repo_root(repo_name: str = "Virny", levels_up: int = 2) -> str:
    """Move the working directory up to the repository root unless already there.

    The relative paths used below (e.g. ``docs/examples``) are resolved from the
    repository root. Presumably the notebook is launched from a folder ``levels_up``
    levels below the root (docs/examples) — TODO confirm if the notebook is moved.

    Args:
        repo_name: Name of the repository root folder.
        levels_up: How many parent levels to climb when not already at the root.

    Returns:
        The working directory after the (possible) change.
    """
    # os.path.basename understands the platform's path separator. Splitting on '/'
    # never isolates the folder name on Windows, so the old check always climbed,
    # even when the notebook was already at the repo root.
    if os.path.basename(os.getcwd()) != repo_name:
        # os.curdir keeps the path valid even when levels_up == 0 (no-op).
        os.chdir(os.path.join(os.curdir, *([os.pardir] * levels_up)))
    return os.getcwd()


print('Current location: ', ensure_repo_root())

Current location:  /Users/denys_herasymuk/UCU/4course_2term/Bachelor_Thesis/Code/Virny


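# Multiple Models Interface Usage

This example loads previously saved metrics for several models on the COMPAS dataset (without sensitive attributes). It composes cross-group metrics for the `sex`, `race` and `sex&race` sensitive attributes and explores them in an interactive web app.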

In [4]:
import os

from virny.utils.custom_initializers import read_model_metric_dfs, create_config_obj
from virny.custom_classes.metrics_interactive_visualizer import MetricsInteractiveVisualizer
from virny.custom_classes.metrics_composer import MetricsComposer

In [5]:
# Every example artifact (the config file and the saved metric results) lives under docs/examples.
ROOT_DIR = os.path.join('docs', 'examples')

# Experiment configuration. create_config_obj takes a YAML file path, so the
# config text is written to disk first and then parsed back into an object.
config_yaml_content = """
dataset_name: COMPAS_Without_Sensitive_Attributes
bootstrap_fraction: 0.8
n_estimators: 50  # Better to input the higher number of estimators than 100; this is only for this use case example
sensitive_attributes_dct: {'sex': 1, 'race': 'African-American', 'sex&race': None}
"""
config_yaml_path = os.path.join(ROOT_DIR, 'experiment_config.yaml')
with open(config_yaml_path, mode='w', encoding='utf-8') as yaml_file:
    yaml_file.write(config_yaml_content)
config = create_config_obj(config_yaml_path=config_yaml_path)

# Models whose previously saved metrics are loaded below, and the folder they were saved in.
model_names = [
    'DecisionTreeClassifier',
    'LogisticRegression',
    'RandomForestClassifier',
    'XGBClassifier',
]
SAVE_RESULTS_DIR_PATH = os.path.join(
    ROOT_DIR,
    'results',
    'COMPAS_Without_Sensitive_Attributes_Metrics_20230812__224136',
)

In [6]:
# Load each model's saved per-group metrics. Later cells index the result by model name.
# Fail fast with a clear message if the results folder is missing (e.g. the working
# directory is not the repo root) or if some requested model has no saved metrics.
assert os.path.isdir(SAVE_RESULTS_DIR_PATH), f'Results folder not found: {SAVE_RESULTS_DIR_PATH}'
models_metrics_dct = read_model_metric_dfs(SAVE_RESULTS_DIR_PATH, model_names=model_names)
missing_models = set(model_names) - set(models_metrics_dct)
assert not missing_models, f'No saved metrics found for: {sorted(missing_models)}'

In [7]:
# Combine the per-model metric frames with the sensitive-attribute definitions from the config.
# The composed (cross-group) metrics themselves are computed in the next cell.
metrics_composer = MetricsComposer(models_metrics_dct, config.sensitive_attributes_dct)

In [8]:
# Compute composed metrics: one row per (metric, model), with one column per sensitive
# attribute (e.g. Equalized_Odds_TPR, Disparate_Impact, Std_Ratio in the output further below).
models_composed_metrics_df = metrics_composer.compose_metrics()

In [185]:
# Per-group metrics of one model: one row per metric. For each sensitive attribute there are
# *_priv / *_dis group columns (presumably privileged / disadvantaged), each with *_correct /
# *_incorrect subsets. head(100) only caps the number of displayed rows.
models_metrics_dct['RandomForestClassifier'].head(100)

Unnamed: 0,Metric,overall,sex_priv,sex_priv_correct,sex_priv_incorrect,sex_dis,sex_dis_correct,sex_dis_incorrect,race_priv,race_priv_correct,...,race_dis_correct,race_dis_incorrect,sex&race_priv,sex&race_priv_correct,sex&race_priv_incorrect,sex&race_dis,sex&race_dis_correct,sex&race_dis_incorrect,Model_Name,Model_Params
0,Mean,0.52427,0.578645,0.60079,0.517352,0.510692,0.514399,0.501767,0.597526,0.618185,...,0.473863,0.484344,0.586391,0.60729,0.529874,0.462617,0.453857,0.482517,RandomForestClassifier,"{'bootstrap': True, 'ccp_alpha': 0.0, 'class_w..."
1,Std,0.067963,0.073618,0.072201,0.077539,0.066551,0.064791,0.070788,0.069162,0.066865,...,0.065947,0.07006,0.068718,0.066018,0.076019,0.067213,0.066631,0.068536,RandomForestClassifier,"{'bootstrap': True, 'ccp_alpha': 0.0, 'class_w..."
2,IQR,0.090596,0.099782,0.098402,0.1036,0.088303,0.085977,0.0939,0.093184,0.089451,...,0.087919,0.091258,0.09202,0.088338,0.101975,0.089184,0.088747,0.090175,RandomForestClassifier,"{'bootstrap': True, 'ccp_alpha': 0.0, 'class_w..."
3,Aleatoric_Uncertainty,0.834874,0.846689,0.826891,0.901488,0.831924,0.81717,0.86744,0.821672,0.807043,...,0.827404,0.880296,0.832383,0.817398,0.872906,0.837346,0.821026,0.874418,RandomForestClassifier,"{'bootstrap': True, 'ccp_alpha': 0.0, 'class_w..."
4,Overall_Uncertainty,0.859083,0.876581,0.856843,0.931213,0.854713,0.839203,0.892051,0.847778,0.832001,...,0.850193,0.903737,0.857995,0.84179,0.901818,0.860162,0.843933,0.897027,RandomForestClassifier,"{'bootstrap': True, 'ccp_alpha': 0.0, 'class_w..."
5,Statistical_Bias,0.405041,0.395811,0.314809,0.620012,0.407346,0.301656,0.661771,0.393484,0.296788,...,0.30951,0.650314,0.396398,0.30252,0.650263,0.41362,0.306294,0.657422,RandomForestClassifier,"{'bootstrap': True, 'ccp_alpha': 0.0, 'class_w..."
6,Jitter,0.106917,0.13209,0.112864,0.185306,0.100631,0.091351,0.122972,0.107225,0.097218,...,0.094812,0.134214,0.108871,0.095304,0.145559,0.104978,0.096287,0.124722,RandomForestClassifier,"{'bootstrap': True, 'ccp_alpha': 0.0, 'class_w..."
7,Per_Sample_Accuracy,0.691061,0.71109,0.918452,0.137143,0.686059,0.936918,0.082177,0.708261,0.930526,...,0.934866,0.09134,0.708783,0.933073,0.102254,0.673472,0.933152,0.08358,RandomForestClassifier,"{'bootstrap': True, 'ccp_alpha': 0.0, 'class_w..."
8,Label_Stability,0.851667,0.807393,0.836903,0.725714,0.862722,0.87397,0.835645,0.848213,0.861316,...,0.869732,0.81732,0.847224,0.866354,0.795493,0.856075,0.866304,0.83284,RandomForestClassifier,"{'bootstrap': True, 'ccp_alpha': 0.0, 'class_w..."
9,TPR,0.679406,0.613333,1.0,0.0,0.691919,1.0,0.0,0.585034,1.0,...,1.0,0.0,0.595745,1.0,0.0,0.734982,1.0,0.0,RandomForestClassifier,"{'bootstrap': True, 'ccp_alpha': 0.0, 'class_w..."


In [135]:
# Preview the composed metrics: each row is one metric for one model (Model_Name column),
# with one column per sensitive attribute.
models_composed_metrics_df.head(20)

Unnamed: 0,Metric,sex,race,sex&race,Model_Name
0,Equalized_Odds_TPR,0.211919,0.195326,0.183576,DecisionTreeClassifier
1,Equalized_Odds_FPR,0.098356,0.104728,0.141078,DecisionTreeClassifier
2,Equalized_Odds_FNR,-0.211919,-0.195326,-0.183576,DecisionTreeClassifier
3,Disparate_Impact,1.234115,1.135965,1.125105,DecisionTreeClassifier
4,Statistical_Parity_Difference,0.193535,0.123016,0.115123,DecisionTreeClassifier
5,Accuracy_Parity,0.009832,0.00684,-0.010984,DecisionTreeClassifier
6,Label_Stability_Ratio,1.02474,0.997454,0.995869,DecisionTreeClassifier
7,IQR_Parity,0.000768,-0.004804,-0.003282,DecisionTreeClassifier
8,Std_Parity,-0.005106,-0.000927,-0.001976,DecisionTreeClassifier
9,Std_Ratio,0.931699,0.986984,0.972422,DecisionTreeClassifier


## Metrics Visualization and Reporting

In [322]:
# Interactive dashboard over both the per-group metrics and the composed metrics of all models.
visualizer = MetricsInteractiveVisualizer(models_metrics_dct, models_composed_metrics_df,
                                          sensitive_attributes_dct=config.sensitive_attributes_dct)

In [323]:
# Launch the web app; a local URL is printed. In the recorded run this call blocked until it
# was stopped with a keyboard interrupt, which also closed the server.
visualizer.start_web_app()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.
Keyboard interruption in main thread... closing server.


In [17]:
# Shut down the web app server if it is still running (it prints the port being closed).
visualizer.stop_web_app()

Closing server running on port: 7860
