# Scatter Plots

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import os
import warnings
warnings.filterwarnings('ignore')
os.environ["PYTHONWARNINGS"] = "ignore"

In [3]:
# Make the repository root the working directory so that the project-local
# packages imported by the following cells (`source`, `configs`) can be resolved.
def is_project_root(path):
    """Return True when `path` contains both a `source` and a `configs` directory."""
    return all(os.path.isdir(os.path.join(path, package)) for package in ('source', 'configs'))


# os.path.basename instead of split('/') so the check also works with non-POSIX separators.
cur_folder_name = os.path.basename(os.getcwd())
# The folder-name check alone is not idempotent: when the checkout is not called
# "data-cleaning-stability", every re-run of this cell would climb four more levels.
# Skipping the hop when the current directory already holds the project packages
# makes the cell safe to execute repeatedly.
if cur_folder_name != "data-cleaning-stability" and not is_project_root(os.getcwd()):
    os.chdir("../../../..")

print('Current location: ', os.getcwd())

Current location:  /Users/denys_herasymuk/Research/NYU/ML_Lifecycle_Project


In [4]:
from source.visualizations.scatter_plots import create_scatter_plot
from configs.constants import (ACS_INCOME_DATASET, ACS_EMPLOYMENT_DATASET, LAW_SCHOOL_DATASET, GERMAN_CREDIT_DATASET,
                               CARDIOVASCULAR_DISEASE_DATASET, BANK_MARKETING_DATASET, DIABETES_DATASET)

## Define global configs

In [5]:
from source.custom_classes.database_client import DatabaseClient, get_secrets_path

# Two database clients: one built with the default arguments and one that
# reads its settings from the 'secrets_3.env' secrets file.
db_client_1 = DatabaseClient()
secrets_3_path = get_secrets_path('secrets_3.env')
db_client_3 = DatabaseClient(secrets_path=secrets_3_path)

# Open both connections, in the same order as the clients were created.
for client in (db_client_1, db_client_3):
    client.connect()

In [6]:
# Dataset -> list of all sensitive attribute names listed for it.
# Entries joined with '&' combine two of the single attributes.
DATASETS_ALL_SENSITIVE_ATTRS = {
    # Two single attributes plus their '&' combination.
    ACS_INCOME_DATASET: [
        'SEX', 'RAC1P', 'SEX&RAC1P',
    ],
    LAW_SCHOOL_DATASET: [
        'male', 'race', 'male&race',
    ],
    GERMAN_CREDIT_DATASET: [
        'sex', 'age', 'sex&age',
    ],
    # A single attribute only.
    CARDIOVASCULAR_DISEASE_DATASET: ['gender'],
    BANK_MARKETING_DATASET: ['age'],
    DIABETES_DATASET: ['Gender'],
    # Same attribute names as ACS income; kept last to preserve the key order.
    ACS_EMPLOYMENT_DATASET: [
        'SEX', 'RAC1P', 'SEX&RAC1P',
    ],
}

# Dataset -> the one sensitive attribute (group) selected for it.
# This mapping is what the plotting cells pass as `dataset_to_group`.
DATASETS_SENSITIVE_ATTRS = {
    # '&'-combined attributes
    ACS_INCOME_DATASET:             'SEX&RAC1P',
    LAW_SCHOOL_DATASET:             'male&race',
    # single attributes
    GERMAN_CREDIT_DATASET:          'sex',
    CARDIOVASCULAR_DISEASE_DATASET: 'gender',
    BANK_MARKETING_DATASET:         'age',
    DIABETES_DATASET:               'Gender',
    # '&'-combined again; kept last to preserve the key order
    ACS_EMPLOYMENT_DATASET:         'SEX&RAC1P',
}

# Dataset -> column names grouped under the keys 'cat' and 'num'
# (presumably categorical vs. numerical columns -- TODO confirm against
# create_scatter_plot). Passed to the plotting cells as `dataset_to_column_name`.
DATASET_TO_COLUMN_NAME = {
    DIABETES_DATASET: {
        'cat': ['Family_Diabetes', 'PhysicallyActive', 'RegularMedicine'],
        'num': ['SoundSleep'],
    },
    GERMAN_CREDIT_DATASET: {
        'cat': ['checking-account', 'savings-account', 'employment-since'],
        'num': ['duration', 'credit-amount'],
    },
    ACS_INCOME_DATASET: {
        'cat': ['SCHL', 'MAR'],
        'num': ['AGEP', 'WKHP'],
    },
    LAW_SCHOOL_DATASET: {
        'cat': ['fam_inc', 'tier'],
        'num': ['zfygpa', 'ugpa'],
    },
    BANK_MARKETING_DATASET: {
        'cat': ['education', 'job'],
        'num': ['balance', 'campaign'],
    },
    CARDIOVASCULAR_DISEASE_DATASET: {
        'cat': ['cholesterol', 'gluc'],
        'num': ['weight', 'height'],
    },
    ACS_EMPLOYMENT_DATASET: {
        'cat': ['SCHL', 'DIS', 'MIL'],
        'num': ['AGEP'],
    },
}

## Metric Visualizations

In [7]:
# 'MCAR3' on the train side, paired with each of the three test-side settings.
# Imputation-quality metric: f1_score_difference; model metric: Equalized_Odds_TPR.
create_scatter_plot(
    missingness_types=[{'train': 'MCAR3', 'test': test_missingness}
                       for test_missingness in ('MCAR3', 'MAR3', 'MNAR3')],
    imputation_quality_metric_name='f1_score_difference',
    model_performance_metric_name='Equalized_Odds_TPR',
    dataset_to_column_name=DATASET_TO_COLUMN_NAME,
    dataset_to_group=DATASETS_SENSITIVE_ATTRS,
    db_client_1=db_client_1,
    db_client_3=db_client_3,
    without_dummy=False,
)

Extracted data for german
Extracted data for bank
Extracted data for heart
Extracted data for diabetes
Extracted data for law_school
Extracted data for folk
Extracted data for folk_emp
Extracted data for german
Extracted data for bank
Extracted data for heart
Extracted data for diabetes
Extracted data for law_school
Extracted data for folk
Extracted data for folk_emp

Extracted data for german
Extracted data for bank
Extracted data for heart
Extracted data for diabetes
Extracted data for law_school
Extracted data for folk
Extracted data for folk_emp
Extracted data for german
Extracted data for bank
Extracted data for heart
Extracted data for diabetes
Extracted data for law_school
Extracted data for folk
Extracted data for folk_emp

Extracted data for german
Extracted data for bank
Extracted data for heart
Extracted data for diabetes
Extracted data for law_school
Extracted data for folk
Extracted data for folk_emp
Extracted data for german
Extracted data for bank
Extracted data for hear

In [8]:
# 'MAR3' on the train side, paired with each of the three test-side settings.
# Imputation-quality metric: rmse_difference; model metric: Disparate_Impact.
create_scatter_plot(
    missingness_types=[{'train': 'MAR3', 'test': test_missingness}
                       for test_missingness in ('MCAR3', 'MAR3', 'MNAR3')],
    imputation_quality_metric_name='rmse_difference',
    model_performance_metric_name='Disparate_Impact',
    dataset_to_column_name=DATASET_TO_COLUMN_NAME,
    dataset_to_group=DATASETS_SENSITIVE_ATTRS,
    db_client_1=db_client_1,
    db_client_3=db_client_3,
    without_dummy=False,
)

Extracted data for german
Extracted data for bank
Extracted data for heart
Extracted data for diabetes
Extracted data for law_school
Extracted data for folk
Extracted data for folk_emp
Extracted data for german
Extracted data for bank
Extracted data for heart
Extracted data for diabetes
Extracted data for law_school
Extracted data for folk
Extracted data for folk_emp

Extracted data for german
Extracted data for bank
Extracted data for heart
Extracted data for diabetes
Extracted data for law_school
Extracted data for folk
Extracted data for folk_emp
Extracted data for german
Extracted data for bank
Extracted data for heart
Extracted data for diabetes
Extracted data for law_school
Extracted data for folk
Extracted data for folk_emp

Extracted data for german
Extracted data for bank
Extracted data for heart
Extracted data for diabetes
Extracted data for law_school
Extracted data for folk
Extracted data for folk_emp
Extracted data for german
Extracted data for bank
Extracted data for hear

In [9]:
# Single train/test pair: 'mixed_exp' on the train side and the combined
# 'MCAR1 & MAR1 & MNAR1' setting on the test side.
# Imputation-quality metric: kl_divergence_pred_difference; model metric: Equalized_Odds_TPR.
mixed_missingness_types = [
    {'train': 'mixed_exp', 'test': 'MCAR1 & MAR1 & MNAR1'},
]
create_scatter_plot(
    missingness_types=mixed_missingness_types,
    imputation_quality_metric_name='kl_divergence_pred_difference',
    model_performance_metric_name='Equalized_Odds_TPR',
    dataset_to_column_name=DATASET_TO_COLUMN_NAME,
    dataset_to_group=DATASETS_SENSITIVE_ATTRS,
    db_client_1=db_client_1,
    db_client_3=db_client_3,
    without_dummy=False,
)

Extracted data for german
Extracted data for bank
Extracted data for heart
Extracted data for diabetes
Extracted data for law_school
Extracted data for folk
Extracted data for folk_emp
Extracted data for german
Extracted data for bank
Extracted data for heart
Extracted data for diabetes
Extracted data for law_school
Extracted data for folk
Extracted data for folk_emp

