In [1]:
import os
import warnings

warnings.filterwarnings('ignore')
os.environ["PYTHONWARNINGS"] = "ignore"

import gradio as gr
import pandas as pd

from virny.datasets import ACSIncomeDataset, GermanCreditDataset, LawSchoolDataset, CardiovascularDiseaseDataset
from virny.custom_classes.metrics_interactive_visualizer import MetricsInteractiveVisualizer
from source.custom_classes.database_client import DatabaseClient
from configs.constants import (EXP_COLLECTION_NAME, GERMAN_CREDIT_DATASET, BANK_MARKETING_DATASET, CARDIOVASCULAR_DISEASE_DATASET,
                               DIABETES_DATASET, LAW_SCHOOL_DATASET, ACS_INCOME_DATASET)

In [9]:
def read_metrics_from_db(dataset_names: list, null_imputers: list, db_collection_name: str,
                         tag: str = 'demo_20240423') -> dict:
    """Load experiment metrics from MongoDB into per-dataset, per-imputer DataFrames.

    For every (dataset, imputer) pair, the records matching ``dataset_name``,
    ``null_imputer_name`` and ``tag`` are read from ``db_collection_name``.
    Column names are converted to Capitalized_Snake_Case and bookkeeping
    columns are dropped. ``Model_Name`` is then suffixed with the imputer name
    and the Virny random state (e.g. ``'lr_clf__baseline_100'``) so that
    separate runs stay distinguishable.

    Args:
        dataset_names: values of the ``dataset_name`` field to query.
        null_imputers: values of the ``null_imputer_name`` field to query.
        db_collection_name: collection holding the experiment records.
        tag: experiment tag to filter on; defaults to the demo run tag.

    Returns:
        ``{dataset_name: {null_imputer: DataFrame}}``. If an imputer name is
        listed more than once, its rows are concatenated into one frame with
        a fresh RangeIndex.

    Raises:
        KeyError: if a query returns no records, or records lacking the fields
            that are dropped or combined below (``model_params``, ``tag``,
            ``model_init_seed``, ``runtime_in_mins``, ``model_name``, ...).
    """
    db = DatabaseClient()
    db.connect()
    try:
        dataset_metrics_dct = dict()
        for dataset_name in dataset_names:
            imputer_metrics_dct = dict()
            dataset_metrics_dct[dataset_name] = imputer_metrics_dct
            for null_imputer in null_imputers:
                # Extract experimental data for the defined dataset from MongoDB
                query = {'dataset_name': dataset_name, 'null_imputer_name': null_imputer, 'tag': tag}
                records = db.execute_read_query(db_collection_name, query)
                model_metric_df = pd.DataFrame(records)

                # Capitalize column names to be consistent across the whole library
                model_metric_df.columns = ['_'.join(part.capitalize() for part in col.split('_'))
                                           for col in model_metric_df.columns]
                model_metric_df = model_metric_df.drop(columns=['Model_Params', 'Tag', 'Model_Init_Seed', 'Runtime_In_Mins'])
                # Make model names unique per imputer and random state, e.g. 'lr_clf__baseline_100'
                model_metric_df['Model_Name'] = (model_metric_df['Model_Name'] + '__' +
                                                 model_metric_df['Null_Imputer_Name'] + '_' +
                                                 model_metric_df['Virny_Random_State'].astype(str))

                if null_imputer in imputer_metrics_dct:
                    # The same imputer was requested more than once: append its rows instead of
                    # overwriting. (The previous index-based check failed with a KeyError for
                    # every distinct imputer after the first one.)
                    imputer_metrics_dct[null_imputer] = pd.concat(
                        [imputer_metrics_dct[null_imputer], model_metric_df], axis=0).reset_index(drop=True)
                else:
                    imputer_metrics_dct[null_imputer] = model_metric_df

                print(f'Extracted metrics for {dataset_name} dataset and {null_imputer} imputer')
    finally:
        # Release the connection even when a query or a column transformation fails
        db.close()

    return dataset_metrics_dct

In [18]:
# Demo configuration per dataset, keyed by dataset name:
#   'data_loader'              -- virny dataset object whose X_data / y_data feed the visualizer
#   'sensitive_attributes_dct' -- sensitive attribute -> value(s), presumably identifying the
#                                 disadvantaged group; '&'-joined keys mapped to None appear to
#                                 denote intersectional groups (TODO confirm against virny docs)
demo_configs = dict()

demo_configs[ACS_INCOME_DATASET] = dict(
    data_loader=ACSIncomeDataset(state=['GA'], year=2018, with_nulls=False,
                                 subsample_size=15_000, subsample_seed=42),
    sensitive_attributes_dct={
        'SEX': '2',
        'RAC1P': [str(race_code) for race_code in range(2, 10)],  # '2' .. '9'
        'SEX&RAC1P': None,
    },
)

demo_configs[LAW_SCHOOL_DATASET] = dict(
    data_loader=LawSchoolDataset(),
    sensitive_attributes_dct={'male': '0', 'race': 'Non-White', 'male&race': None},
)

demo_configs[GERMAN_CREDIT_DATASET] = dict(
    data_loader=GermanCreditDataset(),
    sensitive_attributes_dct={
        'sex': 'female',
        'age': list(range(19, 26)),  # ages 19 .. 25
        'sex&age': None,
    },
)

demo_configs[CARDIOVASCULAR_DISEASE_DATASET] = dict(
    data_loader=CardiovascularDiseaseDataset(),
    sensitive_attributes_dct={'gender': '1'},
)

In [19]:
# Fetch the 'baseline' imputer metrics for every configured demo dataset
dataset_metrics_dct = read_metrics_from_db(
    dataset_names=[*demo_configs],
    null_imputers=['baseline'],
    db_collection_name=EXP_COLLECTION_NAME,
)

Extracted metrics for folk dataset and baseline imputer
Extracted metrics for law_school dataset and baseline imputer
Extracted metrics for german dataset and baseline imputer
Extracted metrics for heart dataset and baseline imputer


In [20]:
# Reshape the long-format baseline metrics into wide format: one column per
# value of `Subgroup`, holding `Metric_Value`. Every other column becomes part
# of the row key, and a default RangeIndex is restored afterwards.
for dataset_name, imputer_metrics_dct in dataset_metrics_dct.items():
    model_metric_df = imputer_metrics_dct['baseline']
    if 'Subgroup' not in model_metric_df.columns:
        # Already pivoted (this cell was re-run): skip it so the cell stays idempotent,
        # instead of failing on the missing 'Subgroup' column.
        continue

    index_cols = [col for col in model_metric_df.columns if col not in ('Subgroup', 'Metric_Value')]
    imputer_metrics_dct['baseline'] = (
        model_metric_df
        .pivot(index=index_cols, columns='Subgroup', values='Metric_Value')
        .reset_index()
        .rename_axis(None, axis=1)  # drop the leftover 'Subgroup' name on the columns axis
    )

In [21]:
# Preview the ACS Income baseline metrics (one column per subgroup after pivoting)
acs_income_metrics_df = dataset_metrics_dct[ACS_INCOME_DATASET]['baseline']
acs_income_metrics_df.head(n=5)

Unnamed: 0,Metric,Model_Name,Virny_Random_State,Dataset_Name,Num_Estimators,Record_Create_Date_Time,Session_Uuid,Null_Imputer_Name,Evaluation_Scenario,Experiment_Iteration,...,SEX&RAC1P_priv,SEX&RAC1P_priv_correct,SEX&RAC1P_priv_incorrect,SEX_dis,SEX_dis_correct,SEX_dis_incorrect,SEX_priv,SEX_priv_correct,SEX_priv_incorrect,overall
0,Accuracy,lr_clf__baseline_100,100,folk,50,2024-04-23 11:39:36.338,be84ae66-0165-11ef-a016-ae7d8bf09115,baseline,baseline,exp_iter_1,...,0.807987,1.0,0.0,0.833797,1.0,0.0,0.795134,1.0,0.0,0.813667
1,Accuracy,lr_clf__baseline_200,200,folk,50,2024-04-23 12:24:26.037,be84ae66-0165-11ef-a016-ae7d8bf09115,baseline,baseline,exp_iter_2,...,0.8116,1.0,0.0,0.826531,1.0,0.0,0.807843,1.0,0.0,0.817
2,Aleatoric_Uncertainty,lr_clf__baseline_100,100,folk,50,2024-04-23 11:39:36.338,be84ae66-0165-11ef-a016-ae7d8bf09115,baseline,baseline,exp_iter_1,...,0.590708,0.535772,0.821874,0.56557,0.515599,0.816258,0.591517,0.534126,0.814268,0.57908
3,Aleatoric_Uncertainty,lr_clf__baseline_200,200,folk,50,2024-04-23 12:24:26.037,be84ae66-0165-11ef-a016-ae7d8bf09115,baseline,baseline,exp_iter_2,...,0.572155,0.516852,0.810392,0.555837,0.500843,0.817869,0.573127,0.519962,0.796637,0.564655
4,Epistemic_Uncertainty,lr_clf__baseline_100,100,folk,50,2024-04-23 11:39:36.338,be84ae66-0165-11ef-a016-ae7d8bf09115,baseline,baseline,exp_iter_1,...,0.015859,0.014343,0.022235,0.015082,0.013374,0.023649,0.016275,0.014807,0.02197,0.015703


In [22]:
# Build one Gradio app per demo dataset without starting it (start_app=False).
# The apps are built in demo_configs order, so sample_demos lines up with dataset_names.
dataset_names = list(demo_configs)
sample_demos = []
for dataset_name in dataset_names:
    dataset_config = demo_configs[dataset_name]
    data_loader = dataset_config['data_loader']
    visualizer = MetricsInteractiveVisualizer(
        X_data=data_loader.X_data,
        y_data=data_loader.y_data,
        model_metrics=dataset_metrics_dct[dataset_name]['baseline'],
        sensitive_attributes_dct=dataset_config['sensitive_attributes_dct'],
    )
    sample_demos.append(visualizer.create_web_app(start_app=False))

In [23]:
# Combine the per-dataset apps into one tabbed web app. Tab labels are the
# dataset names with underscores shown as spaces (e.g. 'law_school' -> 'law school').
tab_names = [dataset_name.replace('_', ' ') for dataset_name in dataset_names]
demo = gr.TabbedInterface(sample_demos, tab_names, theme=gr.themes.Soft())
# debug=True keeps this cell running until interrupted (see the captured
# "Keyboard interruption" output); inline=False keeps the app out of the cell output.
demo.launch(inline=False, debug=True, show_error=True)

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.
Keyboard interruption in main thread... closing server.


