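# Visualizations of Imputation Results for the Law School Dataset

Box plots comparing imputers on the Law School dataset: RMSE for the numerical `ugpa` column, and F1 score and KL divergence for the categorical `tier` column.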

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import warnings
from pathlib import Path

# Silence warnings in this kernel; PYTHONWARNINGS additionally applies to any
# Python processes started later from this kernel (they inherit the environment).
warnings.filterwarnings('ignore')
os.environ["PYTHONWARNINGS"] = "ignore"

# Make the project root (three directory levels above the notebook's working
# directory) importable so the `source` and `configs` packages resolve.
sys.path.append(str(Path.cwd().parent.parent.parent))

In [3]:
from configs.constants import (ACS_INCOME_DATASET, LAW_SCHOOL_DATASET, GERMAN_CREDIT_DATASET,
                               CARDIOVASCULAR_DISEASE_DATASET, BANK_MARKETING_DATASET, DIABETES_DATASET)
from source.custom_classes.database_client import DatabaseClient
from source.visualizations.imputers_viz import create_box_plots_for_diff_imputers

## Configuration

In [4]:
# Dataset whose imputation results are visualized in this notebook.
DATASET_NAME = LAW_SCHOOL_DATASET

# Sensitive attributes per dataset. Entries of the form 'a&b' presumably denote
# the intersection of two attributes — TODO confirm against the metrics code.
DATASETS_SENSITIVE_ATTRS = dict([
    (ACS_INCOME_DATASET, ['SEX', 'RAC1P', 'SEX&RAC1P']),
    (LAW_SCHOOL_DATASET, ['male', 'race', 'male&race']),
    (GERMAN_CREDIT_DATASET, ['sex', 'age', 'sex&age']),
    (CARDIOVASCULAR_DISEASE_DATASET, ['gender']),
    (BANK_MARKETING_DATASET, ['age']),
    (DIABETES_DATASET, ['Gender']),
])

# Disparity metrics use the last listed attribute: the combined 'a&b' entry when
# the dataset has one, otherwise its single attribute.
SENSITIVE_ATTR_FOR_DISPARITY_METRICS = (
    DATASETS_SENSITIVE_ATTRS[DATASET_NAME][-1]
)

In [5]:
# One database client shared by every plotting cell below; it is closed in the
# final cell of the notebook. `DatabaseClient` is imported in the imports cell.
db_client = DatabaseClient()
db_client.connect()

## Metric Visualizations

### Overall Metrics

In [9]:
# Numerical
# Numerical column `ugpa`: RMSE per imputer.
# The y-axis range is fixed by hand for this dataset's results.
create_box_plots_for_diff_imputers(
    dataset_name=DATASET_NAME,
    db_client=db_client,
    column_name='ugpa',
    metric_name='rmse',
    ylim=[0.88, 1.15],
)

In [16]:
# Categorical
# Categorical column `tier`: F1 score per imputer, with `without_dummy=True`
# (presumably omits a dummy baseline imputer — confirm in imputers_viz).
# The y-axis range is fixed by hand for this dataset's results.
create_box_plots_for_diff_imputers(
    dataset_name=DATASET_NAME,
    db_client=db_client,
    column_name='tier',
    metric_name='f1_score',
    ylim=[0.34, 0.54],
    without_dummy=True,
)

In [21]:
# Categorical
# Categorical column `tier`: KL divergence of predictions per imputer, with
# `without_dummy=True`; the y-axis range is left to the plotting helper.
create_box_plots_for_diff_imputers(
    dataset_name=DATASET_NAME,
    db_client=db_client,
    column_name='tier',
    metric_name='kl_divergence_pred',
    without_dummy=True,
)

In [None]:
# Release the database connection opened in the setup cell.
db_client.close()