In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np
from utils.df_loader import load_adult_df, load_compas_df, load_german_df, load_diabetes_df, load_breast_cancer_df
from sklearn.model_selection import train_test_split
from utils.preprocessing import preprocess_df
from utils.save import save_result_as_csv

#### Select dataset ####
# Every supported dataset name, mapped to the loader imported from utils.df_loader.
DATASET_LOADERS = {
    'adult': load_adult_df,
    'german': load_german_df,
    'compas': load_compas_df,
    'diabetes': load_diabetes_df,
    'breast_cancer': load_breast_cancer_df,
}

dataset_name = 'compas' # [adult, german, compas, diabetes, breast_cancer]

# Fail early, naming the offending value and the valid choices, so a typo in
# `dataset_name` is easy to diagnose.
if dataset_name not in DATASET_LOADERS:
    raise ValueError(
        f"Unsupported dataset {dataset_name!r}; expected one of {sorted(DATASET_LOADERS)}"
    )

# Loader callable handed to preprocess_df in the next cell.
dataset_loading_fn = DATASET_LOADERS[dataset_name]

In [2]:
# Run the project's preprocessing helper on the selected loader. The returned
# object is read by later cells through its `numerical_cols`, `categorical_cols`,
# `scaled_df`, `dummy_df`, `ohe_feature_names` and `target_name` attributes.
df_info = preprocess_df(dataset_loading_fn)

In [3]:
# Show the column names that preprocessing reports as numerical.
df_info.numerical_cols

['age', 'priors_count', 'days_b_screening_arrest', 'length_of_stay']

In [4]:
# Show the column names that preprocessing reports as categorical.
df_info.categorical_cols

['age_cat',
 'sex',
 'race',
 'c_charge_degree',
 'is_recid',
 'is_violent_recid',
 'two_year_recid',
 'class']

In [5]:
# Column-wise maximum of the scaled frame. The frame mixes numeric and string
# columns, so the result has dtype object and string columns report their
# lexicographically largest value rather than a numeric maximum.
# NOTE(review): presumably a check that the numerical columns top out at 1.0
# after scaling — confirm against preprocess_df.
df_info.scaled_df.max()

age                                 1.0
age_cat                    Less than 25
sex                                Male
race                              Other
priors_count                        1.0
days_b_screening_arrest             1.0
c_charge_degree                       M
is_recid                              1
is_violent_recid                      1
two_year_recid                        1
length_of_stay                      1.0
class                        Medium-Low
dtype: object

In [6]:
# Fixed seed so the shuffled split below is repeatable between runs.
seed = 123

### Separate df_info.dummy_df into a training frame (80% of rows) and a test frame (the rest).
train_df, test_df = train_test_split(
    df_info.dummy_df,
    train_size=0.8,
    shuffle=True,
    random_state=seed,
)

In [7]:
### Convert each split into NumPy arrays: features first, then target.
# The generator variable does not leak into the notebook namespace, so only
# X_train, y_train, X_test and y_test are defined by this cell.
X_train, y_train = (
    np.array(train_df[columns])
    for columns in (df_info.ohe_feature_names, df_info.target_name)
)
X_test, y_test = (
    np.array(test_df[columns])
    for columns in (df_info.ohe_feature_names, df_info.target_name)
)

In [8]:
# Largest single entry of the training feature array.
X_train.max()

1.0

In [9]:
# Largest single entry of the test feature array.
X_test.max()

1.0