In [11]:
# import OneHotEncoder from sklearn
from sklearn.preprocessing import OneHotEncoder

# import pandas
import pandas as pd

# function to one hot encode
def one_hot_encode(df, cat_cols, mode, *, fitted_encoder=None, verbose=True):
    """One-hot encode the categorical columns of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame; only the columns listed in ``cat_cols`` are encoded.
    cat_cols : list-like of str
        Names of the categorical columns to encode.
    mode : str
        ``"develop"`` fits a new ``OneHotEncoder`` (``handle_unknown='ignore'``)
        on ``df[cat_cols]``. Any other value (e.g. ``"evaluate"``) reuses
        ``fitted_encoder`` without refitting, so the output columns match the
        ones produced during development.
    fitted_encoder : OneHotEncoder, optional
        Encoder returned by a previous ``"develop"`` call. Required when
        ``mode`` is not ``"develop"``; in ``"develop"`` mode a new encoder is
        fitted and this argument is not used.
    verbose : bool, default True
        Print the generated column names (on by default to keep the original
        notebook output).

    Returns
    -------
    tuple of (pandas.DataFrame, OneHotEncoder)
        The encoded columns, named ``<column>_<category>`` (e.g. ``color_red``)
        and indexed like ``df``, plus the encoder that produced them.

    Raises
    ------
    ValueError
        If ``mode`` is not ``"develop"`` and no ``fitted_encoder`` is given.
    """
    # selecting columns already returns a new frame; no extra copy needed
    cat_df = df[cat_cols]

    if mode == "develop":
        # fit on development data only; the encoder is returned so the caller
        # can pass it back in for evaluation data
        fitted_encoder = OneHotEncoder(handle_unknown='ignore').fit(cat_df)
    elif fitted_encoder is None:
        # previously this path crashed with UnboundLocalError
        raise ValueError(
            f"mode={mode!r} requires fitted_encoder from a 'develop' call"
        )

    col_encoded = fitted_encoder.transform(cat_df)

    # names look like "<column>_<category>", e.g. "workclass_Private"
    new_cols = fitted_encoder.get_feature_names_out()
    if verbose:
        print(new_cols)

    # the encoder is used with .toarray() here because it yields a sparse
    # matrix by default; tolerate an encoder configured for dense output
    dense = col_encoded.toarray() if hasattr(col_encoded, "toarray") else col_encoded

    # keep df's index so the result lines up row-for-row with df (e.g. a
    # shuffled train split) when joined back with its numerical columns
    df_encoded = pd.DataFrame(dense, columns=new_cols, index=df.index)

    return df_encoded, fitted_encoder


# load the raw census data from disk
df = pd.read_csv('./data/adult1.csv')

# separate the target (income) from the predictive features
X, y = df.drop(columns=['income']), df['income']

# partition features by dtype: object columns hold string categories,
# every other dtype is treated as numerical
cat_cols, num_cols = (
    X.select_dtypes(include=['object']).columns,
    X.select_dtypes(exclude=['object']).columns,
)

print("Categorical cols", cat_cols)
print("numerical cols", num_cols)

# fit a fresh encoder on the categorical features and encode them;
# the second return value is the single fitted OneHotEncoder
df_encoded, fitted_encoders = one_hot_encode(X, cat_cols, "develop")

# last expression of the cell: the notebook renders the first rows
df_encoded.head()




Categorical cols Index(['workclass', 'education', 'marital.status', 'occupation',
       'relationship', 'race', 'sex', 'native.country'],
      dtype='object')
numerical cols Index(['age', 'fnlwgt', 'education.num', 'capital.gain', 'capital.loss',
       'hours.per.week'],
      dtype='object')
['workclass_?' 'workclass_Federal-gov' 'workclass_Local-gov'
 'workclass_Never-worked' 'workclass_Private' 'workclass_Self-emp-inc'
 'workclass_Self-emp-not-inc' 'workclass_State-gov'
 'workclass_Without-pay' 'education_10th' 'education_11th'
 'education_12th' 'education_1st-4th' 'education_5th-6th'
 'education_7th-8th' 'education_9th' 'education_Assoc-acdm'
 'education_Assoc-voc' 'education_Bachelors' 'education_Doctorate'
 'education_HS-grad' 'education_Masters' 'education_Preschool'
 'education_Prof-school' 'education_Some-college'
 'marital.status_Divorced' 'marital.status_Married-AF-spouse'
 'marital.status_Married-civ-spouse'
 'marital.status_Married-spouse-absent' 'marital.status_Never-m

Unnamed: 0,workclass_?,workclass_Federal-gov,workclass_Local-gov,workclass_Never-worked,workclass_Private,workclass_Self-emp-inc,workclass_Self-emp-not-inc,workclass_State-gov,workclass_Without-pay,education_10th,...,native.country_Portugal,native.country_Puerto-Rico,native.country_Scotland,native.country_South,native.country_Taiwan,native.country_Thailand,native.country_Trinadad&Tobago,native.country_United-States,native.country_Vietnam,native.country_Yugoslavia
0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
1,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
2,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
3,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
4,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
