In [1]:
import os
import warnings

# Silence all Python warnings in this kernel; this runs before the third-party
# imports below, so the filter is already installed when they are imported.
warnings.filterwarnings('ignore')
# Mirror the same setting in the PYTHONWARNINGS environment variable.
# NOTE(review): presumably intended for child processes that inherit the environment — confirm it is needed.
os.environ["PYTHONWARNINGS"] = "ignore"

import gradio as gr
import pandas as pd

from virny.datasets import ACSIncomeDataset, GermanCreditDataset, LawSchoolDataset, CardiovascularDiseaseDataset
from virny.custom_classes.metrics_interactive_visualizer import MetricsInteractiveVisualizer
from source.custom_classes.database_client import DatabaseClient
from configs.constants import (EXP_COLLECTION_NAME, GERMAN_CREDIT_DATASET, BANK_MARKETING_DATASET, CARDIOVASCULAR_DISEASE_DATASET,
                               DIABETES_DATASET, LAW_SCHOOL_DATASET, ACS_INCOME_DATASET)

In [2]:
def read_metrics_from_db(dataset_names: list, null_imputers: list, db_collection_name: str, tag: str = 'OK'):
    """
    Read experiment metrics from the database and group them by dataset and null imputer.

    For every (dataset, imputer) pair one read query is executed. Column names of the
    extracted records are capitalized part-by-part (e.g. ``model_name`` -> ``Model_Name``),
    bookkeeping columns are dropped, and ``Model_Name`` is rewritten to
    ``<model name>__<null imputer name>_<virny random state>``.

    Parameters
    ----------
    dataset_names
        Names of the datasets to extract metrics for.
    null_imputers
        Names of the null imputers to extract metrics for.
    db_collection_name
        Name of the database collection to query.
    tag
        Value of the ``tag`` field the records must have. Defaults to 'OK'.

    Returns
    -------
    dict
        Nested dictionary ``{dataset_name: {null_imputer: metrics dataframe}}``.

    Raises
    ------
    ValueError
        If the query for a (dataset, imputer) pair returns no records.
    """
    # Bookkeeping columns that are removed from every extracted frame
    columns_to_drop = ['Model_Params', 'Tag', 'Model_Init_Seed', 'Runtime_In_Mins']

    db = DatabaseClient()
    db.connect()

    dataset_metrics_dct = dict()
    try:
        for dataset_name in dataset_names:
            imputer_metrics_dct = dict()
            dataset_metrics_dct[dataset_name] = imputer_metrics_dct
            for null_imputer in null_imputers:
                # Extract experimental data for the defined dataset from MongoDB
                query = {'dataset_name': dataset_name, 'null_imputer_name': null_imputer, 'tag': tag}
                records = db.execute_read_query(db_collection_name, query)
                model_metric_df = pd.DataFrame(records)
                if model_metric_df.empty:
                    # Fail early with a readable message instead of an opaque KeyError
                    # raised by the column operations below
                    raise ValueError(f'No records found in {db_collection_name} for query {query}')

                # Capitalize column names to be consistent across the whole library
                model_metric_df.columns = ['_'.join(part.capitalize() for part in col.split('_'))
                                           for col in model_metric_df.columns]
                model_metric_df = model_metric_df.drop(columns=columns_to_drop)
                model_metric_df['Model_Name'] = (model_metric_df['Model_Name'] + '__' +
                                                 model_metric_df['Null_Imputer_Name'] + '_' +
                                                 model_metric_df['Virny_Random_State'].astype(str))

                # Append only when the same imputer name was already processed for this dataset
                # (a repeated name in null_imputers); a new imputer simply gets its own entry.
                if null_imputer in imputer_metrics_dct:
                    model_metric_df = pd.concat([imputer_metrics_dct[null_imputer], model_metric_df],
                                                axis=0, ignore_index=True)
                imputer_metrics_dct[null_imputer] = model_metric_df

                print(f'Extracted metrics for {dataset_name} dataset and {null_imputer} imputer')
    finally:
        # Release the connection even if a query or a transformation fails
        db.close()

    return dataset_metrics_dct

In [3]:
# Demo settings per sample dataset: the data loader to visualize and the sensitive
# attributes passed to the visualizer.
# NOTE(review): values in sensitive_attributes_dct presumably mark the disadvantaged groups,
# with None used for intersectional attributes — confirm against the virny documentation.
demo_configs = {
    ACS_INCOME_DATASET: dict(
        data_loader=ACSIncomeDataset(
            state=['GA'],
            year=2018,
            with_nulls=False,
            subsample_size=15_000,
            subsample_seed=42,
        ),
        sensitive_attributes_dct={
            'SEX': '2',
            'RAC1P': [str(race_code) for race_code in range(2, 10)],  # '2' .. '9'
            'SEX&RAC1P': None,
        },
    ),
    LAW_SCHOOL_DATASET: dict(
        data_loader=LawSchoolDataset(),
        sensitive_attributes_dct={
            'male': '0',
            'race': 'Non-White',
            'male&race': None,
        },
    ),
    GERMAN_CREDIT_DATASET: dict(
        data_loader=GermanCreditDataset(),
        sensitive_attributes_dct={
            'sex': 'female',
            'age': list(range(19, 26)),  # 19 .. 25
            'sex&age': None,
        },
    ),
    CARDIOVASCULAR_DISEASE_DATASET: dict(
        data_loader=CardiovascularDiseaseDataset(),
        sensitive_attributes_dct={'gender': '1'},
    ),
}

In [4]:
# Load the baseline-imputer metrics for every dataset configured in demo_configs
dataset_metrics_dct = read_metrics_from_db(
    dataset_names=list(demo_configs),
    null_imputers=['baseline'],
    db_collection_name=EXP_COLLECTION_NAME,
)

Extracted metrics for folk dataset and baseline imputer
Extracted metrics for law_school dataset and baseline imputer
Extracted metrics for german dataset and baseline imputer
Extracted metrics for heart dataset and baseline imputer


In [5]:
# Reshape every baseline metrics frame from long to wide format: each value of the
# Subgroup column becomes its own column holding the corresponding Metric_Value.
for dataset_name in dataset_metrics_dct.keys():
    model_metric_df = dataset_metrics_dct[dataset_name]['baseline']

    # The result is written back into dataset_metrics_dct, so a frame without the Subgroup
    # column was already pivoted by a previous run of this cell; skip it to keep the cell
    # safe to re-run.
    if 'Subgroup' not in model_metric_df.columns:
        continue

    # Create columns based on values in the Subgroup column
    index_columns = [col for col in model_metric_df.columns
                     if col not in ('Subgroup', 'Metric_Value')]
    pivoted_model_metric_df = model_metric_df.pivot(columns='Subgroup', values='Metric_Value',
                                                     index=index_columns).reset_index()
    # Drop the leftover 'Subgroup' name of the column axis
    pivoted_model_metric_df = pivoted_model_metric_df.rename_axis(None, axis=1)

    dataset_metrics_dct[dataset_name]['baseline'] = pivoted_model_metric_df

In [10]:
# Sanity check: distinct model identifiers available for the German Credit dataset
dataset_metrics_dct[GERMAN_CREDIT_DATASET]['baseline'].loc[:, 'Model_Name'].unique()

array(['dt_clf__baseline_100', 'dt_clf__baseline_200',
       'dt_clf__baseline_300', 'dt_clf__baseline_400',
       'dt_clf__baseline_500', 'dt_clf__baseline_600',
       'lr_clf__baseline_100', 'lr_clf__baseline_200',
       'lr_clf__baseline_300', 'lr_clf__baseline_400',
       'lr_clf__baseline_500', 'lr_clf__baseline_600'], dtype=object)

In [9]:
# Preview the first rows of the pivoted German Credit metrics
dataset_metrics_dct[GERMAN_CREDIT_DATASET]['baseline'].head(n=5)

Unnamed: 0,Metric,Model_Name,Virny_Random_State,Dataset_Name,Num_Estimators,Record_Create_Date_Time,Session_Uuid,Null_Imputer_Name,Evaluation_Scenario,Experiment_Iteration,...,sex&age_dis_incorrect,sex&age_priv,sex&age_priv_correct,sex&age_priv_incorrect,sex_dis,sex_dis_correct,sex_dis_incorrect,sex_priv,sex_priv_correct,sex_priv_incorrect
0,Accuracy,dt_clf__baseline_100,100,german,50,2024-04-23 13:33:04.028,00a7a284-0176-11ef-831e-3cfdfe6139a0,baseline,baseline,exp_iter_1,...,0.0,0.723247,1.0,0.0,0.722892,1.0,0.0,0.709677,1.0,0.0
1,Accuracy,dt_clf__baseline_200,200,german,50,2024-04-23 13:33:09.186,00a7a284-0176-11ef-831e-3cfdfe6139a0,baseline,baseline,exp_iter_2,...,0.0,0.77037,1.0,0.0,0.689655,1.0,0.0,0.784038,1.0,0.0
2,Accuracy,dt_clf__baseline_300,300,german,50,2024-04-23 13:33:14.137,00a7a284-0176-11ef-831e-3cfdfe6139a0,baseline,baseline,exp_iter_3,...,0.0,0.781955,1.0,0.0,0.773196,1.0,0.0,0.773399,1.0,0.0
3,Accuracy,dt_clf__baseline_400,400,german,50,2024-04-23 13:33:18.439,00a7a284-0176-11ef-831e-3cfdfe6139a0,baseline,baseline,exp_iter_4,...,0.0,0.778195,1.0,0.0,0.776596,1.0,0.0,0.762136,1.0,0.0
4,Accuracy,dt_clf__baseline_500,500,german,50,2024-04-23 13:33:21.790,00a7a284-0176-11ef-831e-3cfdfe6139a0,baseline,baseline,exp_iter_5,...,0.0,0.750929,1.0,0.0,0.730769,1.0,0.0,0.739796,1.0,0.0


In [7]:
# Build one (not yet started) gradio web app per sample dataset, in the order of demo_configs
dataset_names = list(demo_configs)
sample_demos = []
for dataset_name in dataset_names:
    dataset_config = demo_configs[dataset_name]
    data_loader = dataset_config['data_loader']

    visualizer = MetricsInteractiveVisualizer(
        X_data=data_loader.X_data,
        y_data=data_loader.y_data,
        model_metrics=dataset_metrics_dct[dataset_name]['baseline'],
        sensitive_attributes_dct=dataset_config['sensitive_attributes_dct'],
    )
    # start_app=False: the app is only created here; all apps are launched together later
    sample_demo = visualizer.create_web_app(start_app=False)
    sample_demos.append(sample_demo)

In [8]:
# Combine the per-dataset apps into a single tabbed web application.
# Tab labels are the dataset names with underscores replaced by spaces.
tab_labels = [tab_name.replace('_', ' ') for tab_name in dataset_names]
demo = gr.TabbedInterface(sample_demos, tab_labels, theme=gr.themes.Soft())
demo.launch(inline=False, debug=True, show_error=True)

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


Traceback (most recent call last):
  File "/Users/denys_herasymuk/Research/NYU/ML_Lifecycle_Project/Code/data-cleaning-stability/vldb_venv/lib/python3.9/site-packages/gradio/queueing.py", line 459, in call_prediction
    output = await route_utils.call_process_api(
  File "/Users/denys_herasymuk/Research/NYU/ML_Lifecycle_Project/Code/data-cleaning-stability/vldb_venv/lib/python3.9/site-packages/gradio/route_utils.py", line 232, in call_process_api
    output = await app.get_blocks().process_api(
  File "/Users/denys_herasymuk/Research/NYU/ML_Lifecycle_Project/Code/data-cleaning-stability/vldb_venv/lib/python3.9/site-packages/gradio/blocks.py", line 1533, in process_api
    result = await self.call_function(
  File "/Users/denys_herasymuk/Research/NYU/ML_Lifecycle_Project/Code/data-cleaning-stability/vldb_venv/lib/python3.9/site-packages/gradio/blocks.py", line 1151, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/Users/denys_herasymuk/Research/NYU/ML_Lifecycl

Keyboard interruption in main thread... closing server.


