In [75]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [76]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn import metrics
from sklearn.cluster import KMeans
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import KNNImputer
from sklearn.feature_selection import mutual_info_classif
from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_val_score, train_test_split

from smoker_status.config import RAW_DATA_DIR
from smoker_status.features import (
    create_encoded_X,
    set_anemia,
    set_blood_pressure_class,
    set_cholesterol_class,
    set_HDL_class,
    set_LDL_class,
)

sns.set_style('whitegrid')
pd.set_option('display.max_columns', None)

df = pd.read_csv(RAW_DATA_DIR / 'train.csv')

In [77]:
X = df.drop(['id'], axis=1)
X.describe()
X.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 159256 entries, 0 to 159255
Data columns (total 23 columns):
 #   Column               Non-Null Count   Dtype  
---  ------               --------------   -----  
 0   age                  159256 non-null  int64  
 1   height(cm)           159256 non-null  int64  
 2   weight(kg)           159256 non-null  int64  
 3   waist(cm)            159256 non-null  float64
 4   eyesight(left)       159256 non-null  float64
 5   eyesight(right)      159256 non-null  float64
 6   hearing(left)        159256 non-null  int64  
 7   hearing(right)       159256 non-null  int64  
 8   systolic             159256 non-null  int64  
 9   relaxation           159256 non-null  int64  
 10  fasting blood sugar  159256 non-null  int64  
 11  Cholesterol          159256 non-null  int64  
 12  triglyceride         159256 non-null  int64  
 13  HDL                  159256 non-null  int64  
 14  LDL                  159256 non-null  int64  
 15  hemoglobin       

In [78]:
X

Unnamed: 0,age,height(cm),weight(kg),waist(cm),eyesight(left),eyesight(right),hearing(left),hearing(right),systolic,relaxation,fasting blood sugar,Cholesterol,triglyceride,HDL,LDL,hemoglobin,Urine protein,serum creatinine,AST,ALT,Gtp,dental caries,smoking
0,55,165,60,81.0,0.5,0.6,1,1,135,87,94,172,300,40,75,16.5,1,1.0,22,25,27,0,1
1,70,165,65,89.0,0.6,0.7,2,2,146,83,147,194,55,57,126,16.2,1,1.1,27,23,37,1,0
2,20,170,75,81.0,0.4,0.5,1,1,118,75,79,178,197,45,93,17.4,1,0.8,27,31,53,0,1
3,35,180,95,105.0,1.5,1.2,1,1,131,88,91,180,203,38,102,15.9,1,1.0,20,27,30,1,0
4,30,165,60,80.5,1.5,1.0,1,1,121,76,91,155,87,44,93,15.4,1,0.8,19,13,17,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
159251,40,155,45,69.0,1.5,2.0,1,1,127,80,64,238,47,72,159,14.5,1,0.8,25,26,13,0,0
159252,50,155,75,82.0,1.0,1.0,1,1,120,80,89,213,202,64,108,14.5,1,0.6,21,20,18,0,0
159253,40,160,50,66.0,1.5,1.0,1,1,114,70,84,189,45,87,93,10.9,1,0.6,15,9,12,0,0
159254,50,165,75,92.0,1.2,1.0,1,1,121,90,122,165,148,55,80,14.4,1,1.1,22,17,37,0,1


In [79]:
def set_sex(row: pd.Series):
    """Use with `pandas.DataFrame.apply`. Must use axis=1 to apply to
    each row.
    """
    # Male, increase score by 1. Female, decrease score by 1. If score == 0,
    # then neither had majority, so np.nan
    score = 0
    # Rough heuristic based on Table 2 from doi: 10.4178/epih.e2022024
    if row['age'] >= 19 and row['age'] <= 29:
        if row['height(cm)'] <= 165:
            score -= 1
        elif row['height(cm)'] >= 170:
            score += 1
        if row['weight(kg)'] <= 55:
            score -= 1
        elif row['weight(kg)'] >= 70:
            score += 1
        else:  # [60, 65]
            pass
    elif row['age'] >= 30 and row['age'] <= 39:
        if row['height(cm)'] <= 165:
            score -= 1
        elif row['height(cm)'] >= 170:
            score += 1
        if row['weight(kg)'] <= 60:
            score -= 1
        elif row['weight(kg)'] >= 75:
            score += 1
        else:  # [65, 70]
            pass
    elif row['age'] >= 40 and row['age'] <= 49:
        if row['height(cm)'] <= 160:
            score -= 1
        elif row['height(cm)'] >= 170:
            score += 1
        else:  # [165]
            pass
        if row['weight(kg)'] <= 60:
            score -= 1
        elif row['weight(kg)'] >= 70:
            score += 1
        else:  # [65]
            pass
    elif row['age'] >= 50 and row['age'] <= 59:
        if row['height(cm)'] <= 160:
            score -= 1
        elif row['height(cm)'] >= 165:
            score += 1
        if row['weight(kg)'] <= 60:
            score -= 1
        elif row['weight(kg)'] >= 70:
            score += 1
        else:  # [65]
            pass
    elif row['age'] >= 60 and row['age'] <= 69:
        if row['height(cm)'] <= 160:
            score -= 1
        elif row['height(cm)'] >= 165:
            score += 1
        if row['weight(kg)'] <= 55:
            score -= 1
        elif row['weight(kg)'] >= 70:
            score += 1
        else:  # [60, 65]
            pass
    elif row['age'] >= 70:
        if row['height(cm)'] <= 155:
            score -= 1
        elif row['height(cm)'] >= 160:
            score += 1
        if row['weight(kg)'] <= 55:
            score -= 1
        elif row['weight(kg)'] >= 65:
            score += 1
        else:  # [60]
            pass
    return score


X['sex'] = df.apply(set_sex, axis=1)
X

Unnamed: 0,age,height(cm),weight(kg),waist(cm),eyesight(left),eyesight(right),hearing(left),hearing(right),systolic,relaxation,fasting blood sugar,Cholesterol,triglyceride,HDL,LDL,hemoglobin,Urine protein,serum creatinine,AST,ALT,Gtp,dental caries,smoking,sex
0,55,165,60,81.0,0.5,0.6,1,1,135,87,94,172,300,40,75,16.5,1,1.0,22,25,27,0,1,0
1,70,165,65,89.0,0.6,0.7,2,2,146,83,147,194,55,57,126,16.2,1,1.1,27,23,37,1,0,2
2,20,170,75,81.0,0.4,0.5,1,1,118,75,79,178,197,45,93,17.4,1,0.8,27,31,53,0,1,2
3,35,180,95,105.0,1.5,1.2,1,1,131,88,91,180,203,38,102,15.9,1,1.0,20,27,30,1,0,2
4,30,165,60,80.5,1.5,1.0,1,1,121,76,91,155,87,44,93,15.4,1,0.8,19,13,17,0,1,-2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
159251,40,155,45,69.0,1.5,2.0,1,1,127,80,64,238,47,72,159,14.5,1,0.8,25,26,13,0,0,-2
159252,50,155,75,82.0,1.0,1.0,1,1,120,80,89,213,202,64,108,14.5,1,0.6,21,20,18,0,0,0
159253,40,160,50,66.0,1.5,1.0,1,1,114,70,84,189,45,87,93,10.9,1,0.6,15,9,12,0,0,-2
159254,50,165,75,92.0,1.2,1.0,1,1,121,90,122,165,148,55,80,14.4,1,1.1,22,17,37,0,1,2


In [80]:
def count_missing_values(dataframe: pd.DataFrame):
    """https://machinelearningmastery.com/knn-imputation-for-missing-values-in-machine-learning/"""
    for column in dataframe.columns:
        # count number of rows with missing values
        n_miss = dataframe[column].isnull().sum()
        perc = n_miss / dataframe.shape[0] * 100
        print(f'> {column}, Missing {n_miss:d} ({perc:f}%)')


count_missing_values(X)

> age, Missing 0 (0.000000%)
> height(cm), Missing 0 (0.000000%)
> weight(kg), Missing 0 (0.000000%)
> waist(cm), Missing 0 (0.000000%)
> eyesight(left), Missing 0 (0.000000%)
> eyesight(right), Missing 0 (0.000000%)
> hearing(left), Missing 0 (0.000000%)
> hearing(right), Missing 0 (0.000000%)
> systolic, Missing 0 (0.000000%)
> relaxation, Missing 0 (0.000000%)
> fasting blood sugar, Missing 0 (0.000000%)
> Cholesterol, Missing 0 (0.000000%)
> triglyceride, Missing 0 (0.000000%)
> HDL, Missing 0 (0.000000%)
> LDL, Missing 0 (0.000000%)
> hemoglobin, Missing 0 (0.000000%)
> Urine protein, Missing 0 (0.000000%)
> serum creatinine, Missing 0 (0.000000%)
> AST, Missing 0 (0.000000%)
> ALT, Missing 0 (0.000000%)
> Gtp, Missing 0 (0.000000%)
> dental caries, Missing 0 (0.000000%)
> smoking, Missing 0 (0.000000%)
> sex, Missing 0 (0.000000%)


In [81]:
imputer = KNNImputer(n_neighbors=5, missing_values=0)
X_trans = imputer.fit_transform(X[['age', 'height(cm)', 'weight(kg)', 'sex']])
sum(np.isnan(X_trans).flatten())

np.int64(0)

In [82]:
X_full_sex = pd.DataFrame(X_trans, columns=['age', 'height(cm)', 'weight(kg)', 'sex'])
X['sex'] = X_full_sex['sex']
count_missing_values(X)

> age, Missing 0 (0.000000%)
> height(cm), Missing 0 (0.000000%)
> weight(kg), Missing 0 (0.000000%)
> waist(cm), Missing 0 (0.000000%)
> eyesight(left), Missing 0 (0.000000%)
> eyesight(right), Missing 0 (0.000000%)
> hearing(left), Missing 0 (0.000000%)
> hearing(right), Missing 0 (0.000000%)
> systolic, Missing 0 (0.000000%)
> relaxation, Missing 0 (0.000000%)
> fasting blood sugar, Missing 0 (0.000000%)
> Cholesterol, Missing 0 (0.000000%)
> triglyceride, Missing 0 (0.000000%)
> HDL, Missing 0 (0.000000%)
> LDL, Missing 0 (0.000000%)
> hemoglobin, Missing 0 (0.000000%)
> Urine protein, Missing 0 (0.000000%)
> serum creatinine, Missing 0 (0.000000%)
> AST, Missing 0 (0.000000%)
> ALT, Missing 0 (0.000000%)
> Gtp, Missing 0 (0.000000%)
> dental caries, Missing 0 (0.000000%)
> smoking, Missing 0 (0.000000%)
> sex, Missing 0 (0.000000%)


In [83]:
X[X['sex'] > 0].describe()  # male
X[X['sex'] < 0].describe()  # female
# TODO: first use sex to figure out healthy ranges for BMI and other
# health stuff
# TODO: then convert sex to categorical M or F, and do one hot encoding

Unnamed: 0,age,height(cm),weight(kg),waist(cm),eyesight(left),eyesight(right),hearing(left),hearing(right),systolic,relaxation,fasting blood sugar,Cholesterol,triglyceride,HDL,LDL,hemoglobin,Urine protein,serum creatinine,AST,ALT,Gtp,dental caries,smoking,sex
count,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0,66812.0
mean,48.102137,157.44049,56.333518,77.039575,0.93927,0.931628,1.030848,1.030518,119.983596,74.856193,96.542268,197.399569,102.808492,62.05945,114.86462,13.869009,1.084072,0.798015,24.346555,21.052236,26.936239,0.161288,0.205756,-1.582333
std,10.618427,6.194094,6.863181,7.430205,0.402633,0.388671,0.172906,0.17201,13.586715,9.134896,14.600432,29.730029,55.540888,14.736449,29.56016,1.396359,0.367897,0.169646,9.022559,11.945094,26.104331,0.367799,0.404256,0.567977
min,20.0,135.0,30.0,51.0,0.1,0.1,1.0,1.0,77.0,44.0,46.0,91.0,8.0,9.0,10.0,4.9,1.0,0.1,6.0,1.0,3.0,0.0,0.0,-2.0
25%,40.0,155.0,50.0,72.0,0.7,0.7,1.0,1.0,110.0,69.0,88.0,175.0,63.0,51.0,93.0,13.0,1.0,0.7,19.0,14.0,15.0,0.0,0.0,-2.0
50%,45.0,155.0,55.0,77.0,1.0,1.0,1.0,1.0,119.0,75.0,94.0,198.0,88.0,61.0,114.0,13.7,1.0,0.8,23.0,18.0,20.0,0.0,0.0,-2.0
75%,55.0,160.0,60.0,82.0,1.2,1.2,1.0,1.0,130.0,80.0,102.0,221.0,132.0,72.0,136.0,14.7,1.0,0.9,27.0,25.0,30.0,0.0,0.0,-1.0
max,85.0,175.0,85.0,125.8,9.9,9.9,2.0,2.0,213.0,120.0,365.0,351.0,466.0,136.0,1660.0,21.0,6.0,9.9,778.0,745.0,816.0,1.0,1.0,-0.2


In [84]:
X['anemia'] = X.apply(set_anemia, axis=1)
X[X['sex'] > 0].describe()  # male, anemia mean: 0.01, std: 0.15
# X[X['sex'] < 0].describe()  # femmale, anemia mean: 0.08, std: 0.35

Unnamed: 0,age,height(cm),weight(kg),waist(cm),eyesight(left),eyesight(right),hearing(left),hearing(right),systolic,relaxation,fasting blood sugar,Cholesterol,triglyceride,HDL,LDL,hemoglobin,Urine protein,serum creatinine,AST,ALT,Gtp,dental caries,smoking,sex,anemia
count,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0
mean,41.563498,170.923327,74.956471,87.311202,1.05388,1.051119,1.019006,1.018292,124.324964,78.33245,99.660897,194.637337,145.545195,51.366871,114.421985,15.467627,1.067122,0.961242,26.362663,30.523906,42.922764,0.224525,0.604755,1.635522,0.017275
std,11.923572,5.476381,9.66235,7.353712,0.39482,0.387243,0.136547,0.134007,11.739804,8.60215,15.706657,27.334829,67.465405,11.446143,27.099804,1.023463,0.332441,0.153083,9.684511,20.054878,32.829402,0.417271,0.488906,0.512388,0.145365
min,20.0,150.0,50.0,51.0,0.1,0.1,1.0,1.0,80.0,47.0,48.0,77.0,11.0,18.0,1.0,5.8,1.0,0.1,9.0,1.0,2.0,0.0,0.0,0.2,0.0
25%,35.0,165.0,70.0,82.0,0.9,0.9,1.0,1.0,117.0,72.0,91.0,176.0,94.0,43.0,96.0,14.9,1.0,0.9,20.0,19.0,23.0,0.0,0.0,1.0,0.0
50%,40.0,170.0,75.0,87.0,1.0,1.0,1.0,1.0,124.0,79.0,97.0,196.0,138.0,49.0,114.0,15.5,1.0,1.0,24.0,26.0,34.0,0.0,1.0,2.0,0.0
75%,50.0,175.0,80.0,92.0,1.2,1.2,1.0,1.0,132.0,84.0,105.0,214.0,184.0,58.0,132.0,16.1,1.0,1.1,30.0,38.0,52.0,0.0,1.0,2.0,0.0
max,85.0,190.0,130.0,127.0,9.9,9.9,2.0,2.0,203.0,133.0,375.0,393.0,766.0,135.0,1860.0,20.4,5.0,5.9,656.0,2914.0,999.0,1.0,1.0,2.0,3.0


In [85]:
# Adding cholesterol classes
X['HDL class'] = X.apply(set_HDL_class, axis=1)
X['LDL class'] = X.apply(set_LDL_class, axis=1)
X['Cholesterol class'] = X.apply(set_cholesterol_class, axis=1)
X.describe()

Unnamed: 0,age,height(cm),weight(kg),waist(cm),eyesight(left),eyesight(right),hearing(left),hearing(right),systolic,relaxation,fasting blood sugar,Cholesterol,triglyceride,HDL,LDL,hemoglobin,Urine protein,serum creatinine,AST,ALT,Gtp,dental caries,smoking,sex,anemia,HDL class,LDL class,Cholesterol class
count,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0,159256.0
mean,44.306626,165.266929,67.143662,83.00199,1.005798,1.000989,1.023974,1.023421,122.503648,76.874071,98.352552,195.796165,127.616046,55.852684,114.607682,14.796965,1.074233,0.892764,25.516853,26.550296,36.216004,0.197996,0.437365,0.285549,0.044381,1.252844,1.03122,0.509149
std,11.842286,8.81897,12.586198,8.957937,0.402113,0.392299,0.152969,0.151238,12.729315,8.994642,15.32974,28.396959,66.188989,13.964141,28.158931,1.431213,0.347856,0.179346,9.464882,17.75307,31.204643,0.39849,0.496063,1.676109,0.258191,0.610643,0.856905,0.595055
min,20.0,135.0,30.0,51.0,0.1,0.1,1.0,1.0,77.0,44.0,46.0,77.0,8.0,9.0,1.0,4.9,1.0,0.1,6.0,1.0,2.0,0.0,0.0,-2.0,0.0,0.0,0.0,0.0
25%,40.0,160.0,60.0,77.0,0.8,0.8,1.0,1.0,114.0,70.0,90.0,175.0,77.0,45.0,95.0,13.8,1.0,0.8,20.0,16.0,18.0,0.0,0.0,-2.0,0.0,1.0,0.0,0.0
50%,40.0,165.0,65.0,83.0,1.0,1.0,1.0,1.0,121.0,78.0,96.0,196.0,115.0,54.0,114.0,15.0,1.0,0.9,24.0,22.0,27.0,0.0,0.0,1.0,0.0,1.0,1.0,0.0
75%,55.0,170.0,75.0,89.0,1.2,1.2,1.0,1.0,130.0,82.0,103.0,217.0,165.0,64.0,133.0,15.8,1.0,1.0,29.0,32.0,44.0,0.0,1.0,2.0,0.0,2.0,2.0,1.0
max,85.0,190.0,130.0,127.0,9.9,9.9,2.0,2.0,213.0,133.0,375.0,393.0,766.0,136.0,1860.0,21.0,6.0,9.9,778.0,2914.0,999.0,1.0,1.0,2.0,3.0,2.0,4.0,2.0


In [92]:
X['blood pressure class'] = X.apply(set_blood_pressure_class, axis=1)
X[X['sex'] > 0].describe()
# X[X['sex'] < 0].describe()

Unnamed: 0,age,height(cm),weight(kg),waist(cm),eyesight(left),eyesight(right),hearing(left),hearing(right),systolic,relaxation,fasting blood sugar,Cholesterol,triglyceride,HDL,LDL,hemoglobin,Urine protein,serum creatinine,AST,ALT,Gtp,dental caries,smoking,sex,anemia,HDL class,LDL class,Cholesterol class,blood pressure class
count,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0,92444.0
mean,41.563498,170.923327,74.956471,87.311202,1.05388,1.051119,1.019006,1.018292,124.324964,78.33245,99.660897,194.637337,145.545195,51.366871,114.421985,15.467627,1.067122,0.961242,26.362663,30.523906,42.922764,0.224525,0.604755,1.635522,0.017275,1.082688,1.020867,0.48491,1.291961
std,11.923572,5.476381,9.66235,7.353712,0.39482,0.387243,0.136547,0.134007,11.739804,8.60215,15.706657,27.334829,67.465405,11.446143,27.099804,1.023463,0.332441,0.153083,9.684511,20.054878,32.829402,0.417271,0.488906,0.512388,0.145365,0.57759,0.824828,0.57789,0.998468
min,20.0,150.0,50.0,51.0,0.1,0.1,1.0,1.0,80.0,47.0,48.0,77.0,11.0,18.0,1.0,5.8,1.0,0.1,9.0,1.0,2.0,0.0,0.0,0.2,0.0,0.0,0.0,0.0,0.0
25%,35.0,165.0,70.0,82.0,0.9,0.9,1.0,1.0,117.0,72.0,91.0,176.0,94.0,43.0,96.0,14.9,1.0,0.9,20.0,19.0,23.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
50%,40.0,170.0,75.0,87.0,1.0,1.0,1.0,1.0,124.0,79.0,97.0,196.0,138.0,49.0,114.0,15.5,1.0,1.0,24.0,26.0,34.0,0.0,1.0,2.0,0.0,1.0,1.0,0.0,2.0
75%,50.0,175.0,80.0,92.0,1.2,1.2,1.0,1.0,132.0,84.0,105.0,214.0,184.0,58.0,132.0,16.1,1.0,1.1,30.0,38.0,52.0,0.0,1.0,2.0,0.0,1.0,2.0,1.0,2.0
max,85.0,190.0,130.0,127.0,9.9,9.9,2.0,2.0,203.0,133.0,375.0,393.0,766.0,135.0,1860.0,20.4,5.0,5.9,656.0,2914.0,999.0,1.0,1.0,2.0,3.0,2.0,4.0,2.0,4.0
