In [1]:
import os

import joblib
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
def get_stratify_col(y, stratify_col):
    """Return the labels to stratify a split on, or None to disable it.

    Parameters
    ----------
    y : pandas.DataFrame or pandas.Series
        Target data that is about to be split.
    stratify_col : hashable or None
        Name of the column of ``y`` to stratify on. ``None`` means
        "do not stratify".

    Returns
    -------
    None when ``stratify_col`` is None; ``y`` itself when ``y`` is
    one-dimensional; otherwise ``y[stratify_col]``.
    """
    if stratify_col is None:
        return None

    # A one-dimensional target (e.g. a Series) has no columns to pick from:
    # ``y[stratify_col]`` would be a lookup in its *index* and raise KeyError
    # for a column name. The target itself is the only thing to stratify on.
    if getattr(y, "ndim", None) == 1:
        return y

    return y[stratify_col]

def read_data(path,
            target_column,
            set_index = None):
    """Load a CSV and separate it into a target series and a feature frame.

    The column ``'default.payment.next.month'`` is renamed to ``'TARGET'``
    before the split, so ``target_column`` must name a column as it appears
    *after* that rename.

    Parameters
    ----------
    path : anything accepted by ``pandas.read_csv``
        Location of the CSV.
    target_column : str
        Column to extract as the target.
    set_index : optional
        Passed to ``pandas.read_csv`` as ``index_col``.

    Returns
    -------
    (target, features)
        Note the order: the target comes first. The target's index is reset
        to 0..n-1, whereas the features keep the index read from the file, so
        the two line up by position, not necessarily by label.
    """
    raw = pd.read_csv(path, index_col=set_index)

    # Give the label column a short, stable name.
    renamed = raw.rename(columns={'default.payment.next.month': 'TARGET'})

    # Everything except the target is a feature.
    features = renamed.drop(columns=[target_column])
    target = renamed[target_column].reset_index(drop=True)

    return target, features

def run_split_data(x, y,
                    stratify_col=None,
                    TEST_SIZE=0.2,
                    save_file = True,
                    return_file = True,
                    output_dir = "output"):
    """Split features and target into train, validation and test sets.

    The split is done in two passes with a fixed ``random_state`` of 42:
    first ``TEST_SIZE * 2`` of the data is held out, then that hold-out is
    cut in half. ``TEST_SIZE`` is therefore the size of *each* of the
    validation and test sets (the default 0.2 gives 60 % / 20 % / 20 %).

    Parameters
    ----------
    x, y
        Features and target, as accepted by ``train_test_split``.
    stratify_col : optional
        Forwarded to ``get_stratify_col`` for both passes; ``None`` disables
        stratification.
    TEST_SIZE : float or int
        Size of each of the validation and test sets.
    save_file : bool
        When true, dump the six splits as ``<name>.pkl`` under ``output_dir``.
    return_file : bool
        When true, return the six splits; otherwise return ``None``.
    output_dir : str
        Directory the pickles are written to. It is created if missing.

    Returns
    -------
    ``(x_train, y_train, x_valid, y_valid, x_test, y_test)`` or ``None``.
    """
    # Pass 1: train vs. a hold-out that is twice the requested size.
    strat_train = get_stratify_col(y, stratify_col)
    x_train, x_holdout, y_train, y_holdout = train_test_split(x, y,
                                       stratify = strat_train,
                                       test_size= TEST_SIZE*2,
                                       random_state= 42)

    # Pass 2: halve the hold-out into validation and test.
    strat_holdout = get_stratify_col(y_holdout, stratify_col)
    x_valid, x_test, y_valid, y_test = train_test_split(x_holdout, y_holdout,
                                       stratify = strat_holdout,
                                       test_size= 0.5,
                                       random_state= 42)

    if save_file:
        # joblib.dump does not create parent directories; without this a
        # fresh checkout fails with FileNotFoundError after the split is done.
        os.makedirs(output_dir, exist_ok=True)

        splits = {
            "x_train": x_train,
            "y_train": y_train,
            "x_valid": x_valid,
            "y_valid": y_valid,
            "x_test": x_test,
            "y_test": y_test,
        }
        for split_name, split_data in splits.items():
            joblib.dump(split_data, os.path.join(output_dir, f"{split_name}.pkl"))

    if return_file:
        return x_train, y_train, x_valid, y_valid, x_test, y_test
    return None

In [5]:
# Input CSV, relative to the notebook's working directory.
PATH = 'data/UCI_Credit_Card.csv'

# read_data hands back the target first, then the features.
output_df, input_df = read_data(
    PATH,
    set_index='ID',
    target_column='TARGET',
)

# Default arguments: no stratification, and the splits are also written to disk.
(x_train, y_train,
 x_valid, y_valid,
 x_test, y_test) = run_split_data(input_df, output_df)

In [4]:
# Report the shape of every split, one per line, in train / valid / test order.
print(
    *(part.shape for part in (x_train, y_train,
                              x_valid, y_valid,
                              x_test, y_test)),
    sep="\n",
)

(18000, 23)
(18000,)
(6000, 23)
(6000,)
(6000, 23)
(6000,)
