In [1]:
from sklearn.tree import DecisionTreeClassifier
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import classification_report, roc_auc_score, accuracy_score
from sklearn.pipeline import Pipeline

In [2]:
df = pd.read_csv("cleaned_data.csv")

In [3]:
X = np.array(df.drop(columns=["h1n1_vaccine", "seasonal_vaccine"]))
y = np.array(df["h1n1_vaccine"])

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

model = DecisionTreeClassifier(
    criterion='gini',
    splitter='best',
    max_depth=5,
    min_samples_split=2,
    min_samples_leaf=1,
    min_weight_fraction_leaf=0.0,
    max_features=None,
    random_state=42,
    max_leaf_nodes=None,
    min_impurity_decrease=0.0,
    class_weight=None,
    ccp_alpha=0.0,
    monotonic_cst=None
)

model.fit(X_train, y_train)

print("Classification Report for Train set")
print(classification_report(y_train, model.predict(X_train)))
print("\nClassification Report for Test set")
print(classification_report(y_test, model.predict(X_test)))

Classification Report for Train set
              precision    recall  f1-score   support

         0.0       0.86      0.95      0.90     16821
         1.0       0.67      0.41      0.51      4544

    accuracy                           0.83     21365
   macro avg       0.76      0.68      0.70     21365
weighted avg       0.82      0.83      0.82     21365


Classification Report for Test set
              precision    recall  f1-score   support

         0.0       0.86      0.95      0.90      4212
         1.0       0.68      0.41      0.52      1130

    accuracy                           0.84      5342
   macro avg       0.77      0.68      0.71      5342
weighted avg       0.82      0.84      0.82      5342



In [4]:
pipeline = Pipeline([
    ('dt', DecisionTreeClassifier(random_state=42))
])

param_grid = {
    'dt__criterion': ['gini', 'entropy'],
    'dt__max_depth': [3, 4, 5, 6, 7, 10],
    'dt__min_samples_split': [2, 3, 4, 5, 7, 10],
    'dt__min_samples_leaf': [1, 2, 3, 4, 5]
}

grid_search = GridSearchCV(
    estimator=pipeline,
    param_grid=param_grid,
    scoring='roc_auc',
    cv=5,
    n_jobs=-1,
)

grid_search.fit(X_train, y_train)
print("Best parameters found: ", grid_search.best_params_)

Best parameters found:  {'dt__criterion': 'gini', 'dt__max_depth': 5, 'dt__min_samples_leaf': 1, 'dt__min_samples_split': 2}


In [5]:
model = DecisionTreeClassifier(
    criterion='gini',
    splitter='best',
    max_depth=5,
    min_samples_split=2,
    min_samples_leaf=1,
    min_weight_fraction_leaf=0.0,
    max_features=None,
    random_state=42,
    max_leaf_nodes=None,
    min_impurity_decrease=0.0,
    class_weight=None,
    ccp_alpha=0.0,
    monotonic_cst=None
)

model.fit(X_train, y_train)

print("Classification Report for Train set")
print(classification_report(y_train, model.predict(X_train)))
print("\nClassification Report for Test set")
print(classification_report(y_test, model.predict(X_test)))

Classification Report for Train set
              precision    recall  f1-score   support

         0.0       0.86      0.95      0.90     16821
         1.0       0.67      0.41      0.51      4544

    accuracy                           0.83     21365
   macro avg       0.76      0.68      0.70     21365
weighted avg       0.82      0.83      0.82     21365


Classification Report for Test set
              precision    recall  f1-score   support

         0.0       0.86      0.95      0.90      4212
         1.0       0.68      0.41      0.52      1130

    accuracy                           0.84      5342
   macro avg       0.77      0.68      0.71      5342
weighted avg       0.82      0.84      0.82      5342



In [6]:
roc_auc_score(y_test, model.predict_proba(X_test)[:, 1])

0.8192712141458454

# Testing the Pipeline

In [7]:
from config import SELECTED_FEATURES, TARGET

In [11]:
df_features = pd.read_csv("test_set_features.csv", index_col="respondent_id")
#df_features = df_features[SELECTED_FEATURES]

In [12]:
df_features

Unnamed: 0_level_0,h1n1_concern,h1n1_knowledge,behavioral_antiviral_meds,behavioral_avoidance,behavioral_face_mask,behavioral_wash_hands,behavioral_large_gatherings,behavioral_outside_home,behavioral_touch_face,doctor_recc_h1n1,...,income_poverty,marital_status,rent_or_own,employment_status,hhs_geo_region,census_msa,household_adults,household_children,employment_industry,employment_occupation
respondent_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
26707,2.0,2.0,0.0,1.0,0.0,1.0,1.0,0.0,1.0,0.0,...,"> $75,000",Not Married,Rent,Employed,mlyzmhmf,"MSA, Not Principle City",1.0,0.0,atmlpfrs,hfxkjkmi
26708,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,Below Poverty,Not Married,Rent,Employed,bhuqouqj,Non-MSA,3.0,0.0,atmlpfrs,xqwwgdyp
26709,2.0,2.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,...,"> $75,000",Married,Own,Employed,lrircsnp,Non-MSA,1.0,0.0,nduyfdeo,pvmttkik
26710,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,"<= $75,000, Above Poverty",Married,Own,Not in Labor Force,lrircsnp,"MSA, Not Principle City",1.0,0.0,,
26711,3.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,1.0,0.0,...,"<= $75,000, Above Poverty",Not Married,Own,Employed,lzgpxyit,Non-MSA,0.0,1.0,fcxhlnwr,mxkfnird
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
53410,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,...,,,,,dqpwygqj,"MSA, Principle City",1.0,1.0,,
53411,3.0,1.0,0.0,1.0,0.0,1.0,1.0,1.0,1.0,0.0,...,Below Poverty,Married,Rent,Employed,qufhixun,Non-MSA,1.0,3.0,fcxhlnwr,vlluhbov
53412,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,Below Poverty,Not Married,Rent,Not in Labor Force,qufhixun,"MSA, Not Principle City",1.0,0.0,,
53413,3.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,...,"<= $75,000, Above Poverty",Married,Own,Not in Labor Force,bhuqouqj,"MSA, Not Principle City",1.0,0.0,,
