In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [2]:
df = pd.read_csv('heart_failure_records.csv')

In [3]:
df.head()

Unnamed: 0,age,anaemia,creatinine_phosphokinase,diabetes,ejection_fraction,high_blood_pressure,platelets,serum_creatinine,serum_sodium,sex,smoking,time,DEATH_EVENT
0,75.0,0,582,0,20,1,265000.0,1.9,130,1,0,4,1
1,55.0,0,7861,0,38,0,263358.03,1.1,136,1,0,6,1
2,65.0,0,146,0,20,0,162000.0,1.3,129,1,1,7,1
3,50.0,1,111,0,20,0,210000.0,1.9,137,1,0,7,1
4,65.0,1,160,1,20,0,327000.0,2.7,116,0,0,8,1


In [4]:
# Check the datatypes of each column
df.dtypes

age                         float64
anaemia                       int64
creatinine_phosphokinase      int64
diabetes                      int64
ejection_fraction             int64
high_blood_pressure           int64
platelets                   float64
serum_creatinine            float64
serum_sodium                  int64
sex                           int64
smoking                       int64
time                          int64
DEATH_EVENT                   int64
dtype: object

In [5]:
# No Null Values
df.isna().sum()

age                         0
anaemia                     0
creatinine_phosphokinase    0
diabetes                    0
ejection_fraction           0
high_blood_pressure         0
platelets                   0
serum_creatinine            0
serum_sodium                0
sex                         0
smoking                     0
time                        0
DEATH_EVENT                 0
dtype: int64

In [6]:
# I would prefer age to be an integer
df['age'] = df['age'].astype('int64')

In [7]:
# Recheck datatypes
df.dtypes

age                           int64
anaemia                       int64
creatinine_phosphokinase      int64
diabetes                      int64
ejection_fraction             int64
high_blood_pressure           int64
platelets                   float64
serum_creatinine            float64
serum_sodium                  int64
sex                           int64
smoking                       int64
time                          int64
DEATH_EVENT                   int64
dtype: object

In [8]:
# Take a look at the first 5 values
df.head()

Unnamed: 0,age,anaemia,creatinine_phosphokinase,diabetes,ejection_fraction,high_blood_pressure,platelets,serum_creatinine,serum_sodium,sex,smoking,time,DEATH_EVENT
0,75,0,582,0,20,1,265000.0,1.9,130,1,0,4,1
1,55,0,7861,0,38,0,263358.03,1.1,136,1,0,6,1
2,65,0,146,0,20,0,162000.0,1.3,129,1,1,7,1
3,50,1,111,0,20,0,210000.0,1.9,137,1,0,7,1
4,65,1,160,1,20,0,327000.0,2.7,116,0,0,8,1


For now, I would like to make each variable Binary to make things a little easier.  To do this in the best possible manner, we should check the distribution.

In [9]:
df[['age', 'creatinine_phosphokinase', 'ejection_fraction', 'platelets', 'serum_creatinine', 'serum_sodium', 'time']].describe()

Unnamed: 0,age,creatinine_phosphokinase,ejection_fraction,platelets,serum_creatinine,serum_sodium,time
count,299.0,299.0,299.0,299.0,299.0,299.0,299.0
mean,60.829431,581.839465,38.083612,263358.029264,1.39388,136.625418,130.26087
std,11.894997,970.287881,11.834841,97804.236869,1.03451,4.412477,77.614208
min,40.0,23.0,14.0,25100.0,0.5,113.0,4.0
25%,51.0,116.5,30.0,212500.0,0.9,134.0,73.0
50%,60.0,250.0,38.0,262000.0,1.1,137.0,115.0
75%,70.0,582.0,45.0,303500.0,1.4,140.0,203.0
max,95.0,7861.0,80.0,850000.0,9.4,148.0,285.0


In [10]:
# for now, let's use the mean of each variable
def make_binary(series):
    mean = round(series.mean())
    temp = []
    for obs in series:
        if obs < mean:
            temp.append(0)
        else:
            temp.append(1)
    return temp

df['age_over_60'] = make_binary(df['age'])
df['cpk_over_581'] = make_binary(df['creatinine_phosphokinase'])
df['ejection_fraction_over_38'] = make_binary(df['ejection_fraction'])
df['platelets_above_mean'] = make_binary(df['platelets'])
df['serum_creatinine_above_avg'] = make_binary(df['serum_creatinine'])
df['serum_sodium_above_137'] = make_binary(df['serum_sodium'])
df['follow_up_over_130_days'] = make_binary(df['time'])

In [11]:
final_binary = df[['anaemia', 'diabetes', 'high_blood_pressure', 'sex',
                   'smoking', 'age_over_60', 'cpk_over_581', 'ejection_fraction_over_38',
                   'platelets_above_mean', 'serum_creatinine_above_avg',
                   'serum_sodium_above_137', 'follow_up_over_130_days', 'DEATH_EVENT']]
final_binary.head()

Unnamed: 0,anaemia,diabetes,high_blood_pressure,sex,smoking,age_over_60,cpk_over_581,ejection_fraction_over_38,platelets_above_mean,serum_creatinine_above_avg,serum_sodium_above_137,follow_up_over_130_days,DEATH_EVENT
0,0,0,1,1,0,1,1,0,1,1,0,0,1
1,0,0,0,1,0,0,1,1,1,1,0,0,1
2,0,0,0,1,1,1,0,0,0,1,0,0,1
3,1,0,0,1,0,0,0,0,0,1,1,0,1
4,1,1,0,0,0,1,0,0,1,1,0,0,1


### Attribute Information:

Thirteen (13) clinical features:

- age: age of the patient (years)                                                            - Numeric
- anaemia: decrease of red blood cells or hemoglobin (boolean)                               - Categorical
- high blood pressure: if the patient has hypertension (boolean)                             - Categorical
- creatinine phosphokinase (CPK): level of the CPK enzyme in the blood (mcg/L)               - Numeric
- diabetes: if the patient has diabetes (boolean)                                            - Categorical
- ejection fraction: percentage of blood leaving the heart at each contraction (percentage)  - Numeric
- platelets: platelets in the blood (kiloplatelets/mL)                                       - Numeric
- sex: woman or man (binary)                                                                 - Categorical
- serum creatinine: level of serum creatinine in the blood (mg/dL)                           - Numeric
- serum sodium: level of serum sodium in the blood (mEq/L)                                   - Numeric
- smoking: if the patient smokes or not (boolean)                                            - Categorical
- time: follow-up period (days)                                                              - Categorical/Numeric
- [target] death event: if the patient deceased during the follow-up period (boolean)        - Categorical-Y

In [12]:
# Make a class which will house Gini calculations

class Gini:
    '''
    A class to preform Gini calculations for a decision tree
    '''
    def __init__(self, total_observations):
        self.total_observations = int(total_observations)
    
    def gini_impurity_binary(self, class0_value, class1_value):
        '''
        Calculates and returns Gini Impurity score
        
        Gini Impurity = 1 - (class0_value/total_values(branch))**2 - (class1_value/total_values(branch))**2
        
        Params:
            - class0_value - Number of class0 values (used to compute probability) - class 0 implied as True
            - class1_value - Number of class1 values (used to compute probability) - class 1 implied as False
        
        Implicit parameters:
            - total_values - total number of class0_value + class1_value
        
        returns Gini Impurity Score
        '''
        
        # Get the number of total values of the node branch
        total_values = class0_value + class1_value
        
        # Calculate the probability of each class compared to the total values
        class0_probability = class0_value / total_values
        class1_probability = class1_value / total_values
        
        # Calculate the Gini impurity score
        gini_impurity = 1 - (class0_probability**2) - (class1_probability**2)
        
        return gini_impurity
    
    def total_gini_impurity_binary(self, gini_impurity_scores, left_sum, right_sum):
        '''
        Calculates and returns Total Gini Impurity of a binary node
        
        Total Gini Impurity = ((total_left / (total_left + total_right)) * gini_impurity_left) + ((total_right / (total_left + total_right)) * gini_impurity_right)
        
        Params:
            - gini_impurity_scores - list of computed gini_impurity scores for a binary class
            - left_sum - total number of observations on the left side of a binary node
            - right_sum - total number of observations on the right side of a binary node
        
        Implicit parameters:
            - total_values - total number of observations (left_sum + right_sum)       
        
        returns Total Gini Impurity
        '''
        
        # Separate the impurity scores for left and right
        left_impurity, right_impurity = gini_impurity_scores

        # Calculate weighted probability for left and right sides
        weighted_left = ((left_sum / self.total_observations) * left_impurity)
        weighted_right = ((right_sum / self.total_observations) * right_impurity)
        
        # Calculate the Total Gini Impurity
        total_gini_impurity = weighted_left + weighted_right
        
        return total_gini_impurity

In [13]:
# Testing the class

g = Gini(300)

# (t)arget - Heart Failure
# (d)ata - Sex (male/female)

t1_d1 = 31
t1_d0 = 76

t0_d1 = 92
t0_d0 = 27

left = g.gini_impurity_binary(t1_d1, t1_d0)
right = g.gini_impurity_binary(t0_d1, t0_d0)

print('Left: ', left)
print('Right: ', right)

left_sum = t1_d1 + t1_d0
right_sum = t0_d1 + t0_d0

print('Total Gini Impurity: ', g.total_gini_impurity_binary([left, right], left_sum, right_sum))

Left:  0.4115643287623374
Right:  0.350822682013982
Total Gini Impurity:  0.2859509411241132


# Your task now is to calculate the Gini Impurity of the Boolean/Binary data

Boolean/Binary Columns:
- anemia
- high blood pressure
- diabetes
- sex (1==Female, 0==Male)
- smoking

#### Anemia

In [14]:
# Grab the number of observations where patient is:
    # dead and anaemic
    # alive and anaemic
    # dead and nonanaemic
    # alive and nonanaemic

dead_anaemic = len(df['DEATH_EVENT'][(df['anaemia'] == 1) & (df['DEATH_EVENT'] == 1)])
dead_nonanaemic = len(df['DEATH_EVENT'][(df['anaemia'] == 0) & (df['DEATH_EVENT'] == 1)])
alive_anaemic = len(df['DEATH_EVENT'][(df['anaemia'] == 1) & (df['DEATH_EVENT'] == 0)])
alive_nonanaemic = len(df['DEATH_EVENT'][(df['anaemia'] == 0) & (df['DEATH_EVENT'] == 0)])

# Instantiate the Gini class with total observations
total_observations = len(df)
g = Gini(total_observations)

# Calculate the Gini Impurity score for Anemic and Non-Anaemic patients 
anaemic = g.gini_impurity_binary(dead_anaemic, alive_anaemic)
nonanaemic = g.gini_impurity_binary(dead_nonanaemic, alive_nonanaemic)

# Calculate the Total Gini Impurity for df['anaemia']
total_gini_anaemia = g.total_gini_impurity_binary([anaemic, nonanaemic], (dead_anaemic + alive_anaemic), (dead_nonanaemic + alive_nonanaemic))

# Print out the values
print('--------------------------------------------------------------------------------------')
print(f'Number of Anaemic patients who have died of Heart Failure: {dead_anaemic}')
print(f'Number of Anaemic patients who have not died of Heart Failure: {alive_anaemic}')
print(f'Number of Non-Anaemic patients who have died of Heart Failure: {dead_nonanaemic}')
print(f'Number of Non-Anaemic patients who have not died of Heart Failure: {alive_nonanaemic}')
print('--------------------------------------------------------------------------------------')
print(f'Gini Impurity for Anaemic patients: {anaemic}')
print(f'Gini Impurity for Non-Anaemic patients: {nonanaemic}')
print('--------------------------------------------------------------------------------------')
print(f'Total Gini Impurity for Anaemia column: {total_gini_anaemia}')
print('--------------------------------------------------------------------------------------')

--------------------------------------------------------------------------------------
Number of Anaemic patients who have died of Heart Failure: 46
Number of Anaemic patients who have not died of Heart Failure: 83
Number of Non-Anaemic patients who have died of Heart Failure: 50
Number of Non-Anaemic patients who have not died of Heart Failure: 120
--------------------------------------------------------------------------------------
Gini Impurity for Anaemic patients: 0.4588666546481581
Gini Impurity for Non-Anaemic patients: 0.4152249134948096
--------------------------------------------------------------------------------------
Total Gini Impurity for Anaemia column: 0.43405362456096996
--------------------------------------------------------------------------------------


#### High Blood Pressure

In [15]:
# Grab the number of observations where patient is:
    # dead and has high blood pressure (dead_hbp)
    # alive and has high blood pressure (alive_hbp)
    # dead and does not have high blood pressure (dead_nohbp)
    # alive and does not have high blood pressure (alive_nohbp)

dead_hbp = len(df['DEATH_EVENT'][(df['high_blood_pressure'] == 1) & (df['DEATH_EVENT'] == 1)])
dead_nohbp = len(df['DEATH_EVENT'][(df['high_blood_pressure'] == 0) & (df['DEATH_EVENT'] == 1)])
alive_hbp = len(df['DEATH_EVENT'][(df['high_blood_pressure'] == 1) & (df['DEATH_EVENT'] == 0)])
alive_nohbp = len(df['DEATH_EVENT'][(df['high_blood_pressure'] == 0) & (df['DEATH_EVENT'] == 0)])

# Instantiate the Gini class with total observations
total_observations = len(df)
g = Gini(total_observations)

# Calculate the Gini Impurity score for High Blood Pressure and Non-High Blood Pressure patients 
hbp = g.gini_impurity_binary(dead_hbp, alive_hbp)
nohbp = g.gini_impurity_binary(dead_nohbp, alive_nohbp)

# Calculate the Total Gini Impurity for df['high_blood_pressure']
total_gini_hbp = g.total_gini_impurity_binary([hbp, nohbp], (dead_hbp + alive_hbp), (dead_nohbp + alive_nohbp))

# Print out the values
print('--------------------------------------------------------------------------------------')
print(f'Number of High Blood Pressure patients who have died of Heart Failure: {dead_hbp}')
print(f'Number of High Blood Pressure patients who have not died of Heart Failure: {alive_hbp}')
print(f'Number of Non-High Blood Pressure patients who have died of Heart Failure: {dead_nohbp}')
print(f'Number of Non-High Blood Pressure patients who have not died of Heart Failure: {alive_nohbp}')
print('--------------------------------------------------------------------------------------')
print(f'Gini Impurity for High Blood Pressure patients: {hbp}')
print(f'Gini Impurity for Non-High Blood Pressure patients: {nohbp}')
print('--------------------------------------------------------------------------------------')
print(f'Total Gini Impurity for High Blood Pressure column: {total_gini_hbp}')
print('--------------------------------------------------------------------------------------')

--------------------------------------------------------------------------------------
Number of High Blood Pressure patients who have died of Heart Failure: 39
Number of High Blood Pressure patients who have not died of Heart Failure: 66
Number of Non-High Blood Pressure patients who have died of Heart Failure: 57
Number of Non-High Blood Pressure patients who have not died of Heart Failure: 137
--------------------------------------------------------------------------------------
Gini Impurity for High Blood Pressure patients: 0.46693877551020413
Gini Impurity for Non-High Blood Pressure patients: 0.4149750239132744
--------------------------------------------------------------------------------------
Total Gini Impurity for High Blood Pressure column: 0.4332231641061762
--------------------------------------------------------------------------------------


#### Diabetes

In [16]:
# Grab the number of observations where patient is:
    # dead and has diabetes (dead_diabetes)
    # alive and has diabetes (alive_diabetes)
    # dead and does not have diabetes (dead_nodiabetes)
    # alive and does not have diabetes (alive_nodiabetes)

dead_diabetes = len(df['DEATH_EVENT'][(df['diabetes'] == 1) & (df['DEATH_EVENT'] == 1)])
dead_nodiabetes = len(df['DEATH_EVENT'][(df['diabetes'] == 0) & (df['DEATH_EVENT'] == 1)])
alive_diabetes = len(df['DEATH_EVENT'][(df['diabetes'] == 1) & (df['DEATH_EVENT'] == 0)])
alive_nodiabetes = len(df['DEATH_EVENT'][(df['diabetes'] == 0) & (df['DEATH_EVENT'] == 0)])

# Instantiate the Gini class with total observations
total_observations = len(df)
g = Gini(total_observations)

# Calculate the Gini Impurity score for Diabetic and Non-Diabetic patients 
diabetes = g.gini_impurity_binary(dead_diabetes, alive_diabetes)
nodiabetes = g.gini_impurity_binary(dead_nodiabetes, alive_nodiabetes)

# Calculate the Total Gini Impurity for df['diabetes']
total_gini_diabetes = g.total_gini_impurity_binary([diabetes, nodiabetes], (dead_diabetes + alive_diabetes), (dead_nodiabetes + alive_nodiabetes))

# Print out the values
print('--------------------------------------------------------------------------------------')
print(f'Number of Diabetic patients who have died of Heart Failure: {dead_diabetes}')
print(f'Number of Diabetic patients who have not died of Heart Failure: {alive_diabetes}')
print(f'Number of Non-Diabetic patients who have died of Heart Failure: {dead_nodiabetes}')
print(f'Number of Non-Diabetic patients who have not died of Heart Failure: {alive_nodiabetes}')
print('--------------------------------------------------------------------------------------')
print(f'Gini Impurity for Diabetic patients: {diabetes}')
print(f'Gini Impurity for Non-Diabetic patients: {nodiabetes}')
print('--------------------------------------------------------------------------------------')
print(f'Total Gini Impurity for Diabetes column: {total_gini_diabetes}')
print('--------------------------------------------------------------------------------------')

--------------------------------------------------------------------------------------
Number of Diabetic patients who have died of Heart Failure: 40
Number of Diabetic patients who have not died of Heart Failure: 85
Number of Non-Diabetic patients who have died of Heart Failure: 56
Number of Non-Diabetic patients who have not died of Heart Failure: 118
--------------------------------------------------------------------------------------
Gini Impurity for Diabetic patients: 0.43519999999999986
Gini Impurity for Non-Diabetic patients: 0.4365173734971596
--------------------------------------------------------------------------------------
Total Gini Impurity for Diabetes column: 0.4359666320685811
--------------------------------------------------------------------------------------


#### Sex

In [17]:
# Grab the number of observations where patient is:
    # dead and female (dead_female)
    # alive and female (alive_female)
    # dead and male (dead_male)
    # alive and male (alive_male)

dead_female = len(df['DEATH_EVENT'][(df['sex'] == 1) & (df['DEATH_EVENT'] == 1)])
dead_male = len(df['DEATH_EVENT'][(df['sex'] == 0) & (df['DEATH_EVENT'] == 1)])
alive_female = len(df['DEATH_EVENT'][(df['sex'] == 1) & (df['DEATH_EVENT'] == 0)])
alive_male = len(df['DEATH_EVENT'][(df['sex'] == 0) & (df['DEATH_EVENT'] == 0)])

# Instantiate the Gini class with total observations
total_observations = len(df)
g = Gini(total_observations)

# Calculate the Gini Impurity score for Female and Male patients 
female = g.gini_impurity_binary(dead_female, alive_female)
male = g.gini_impurity_binary(dead_male, alive_male)

# Calculate the Total Gini Impurity for df['sex']
total_gini_sex = g.total_gini_impurity_binary([female, male], (dead_female + alive_female), (dead_male + alive_male))

# Print out the values
print('--------------------------------------------------------------------------------------')
print(f'Number of Female patients who have died of Heart Failure: {dead_female}')
print(f'Number of Female patients who have not died of Heart Failure: {alive_female}')
print(f'Number of Male patients who have died of Heart Failure: {dead_male}')
print(f'Number of Male patients who have not died of Heart Failure: {alive_male}')
print('--------------------------------------------------------------------------------------')
print(f'Gini Impurity for Female patients: {female}')
print(f'Gini Impurity for Male patients: {male}')
print('--------------------------------------------------------------------------------------')
print(f'Total Gini Impurity for Sex column: {total_gini_sex}')
print('--------------------------------------------------------------------------------------')

--------------------------------------------------------------------------------------
Number of Female patients who have died of Heart Failure: 62
Number of Female patients who have not died of Heart Failure: 132
Number of Male patients who have died of Heart Failure: 34
Number of Male patients who have not died of Heart Failure: 71
--------------------------------------------------------------------------------------
Gini Impurity for Female patients: 0.43490275268360085
Gini Impurity for Male patients: 0.43791383219954644
--------------------------------------------------------------------------------------
Total Gini Impurity for Sex column: 0.43596015518920045
--------------------------------------------------------------------------------------


#### Smoking

In [18]:
# Grab the number of observations where patient is:
    # dead and is a smoker (dead_smoking)
    # alive and is a smoker (alive_smoking)
    # dead and is a non-smoker (dead_nonsmoking)
    # alive and is a non-smoker (alive_nonsmoking)

dead_smoking = len(df['DEATH_EVENT'][(df['smoking'] == 1) & (df['DEATH_EVENT'] == 1)])
dead_nonsmoking = len(df['DEATH_EVENT'][(df['smoking'] == 0) & (df['DEATH_EVENT'] == 1)])
alive_smoking = len(df['DEATH_EVENT'][(df['smoking'] == 1) & (df['DEATH_EVENT'] == 0)])
alive_nonsmoking = len(df['DEATH_EVENT'][(df['smoking'] == 0) & (df['DEATH_EVENT'] == 0)])

# Instantiate the Gini class with total observations
total_observations = len(df)
g = Gini(total_observations)

# Calculate the Gini Impurity score for Smoking and Non-Smoking patients 
smoking = g.gini_impurity_binary(dead_smoking, alive_smoking)
nonsmoking = g.gini_impurity_binary(dead_nonsmoking, alive_nonsmoking)

# Calculate the Total Gini Impurity for df['smoking']
total_gini_smoking = g.total_gini_impurity_binary([smoking, nonsmoking], (dead_smoking + alive_smoking), (dead_nonsmoking + alive_nonsmoking))

# Print out the values
print('--------------------------------------------------------------------------------------')
print(f'Number of Smoking patients who have died of Heart Failure: {dead_smoking}')
print(f'Number of Smoking patients who have not died of Heart Failure: {alive_smoking}')
print(f'Number of Non-Smoking patients who have died of Heart Failure: {dead_nonsmoking}')
print(f'Number of Non-Smoking patients who have not died of Heart Failure: {alive_nonsmoking}')
print('--------------------------------------------------------------------------------------')
print(f'Gini Impurity for Smoking patients: {smoking}')
print(f'Gini Impurity for Non-Smoking patients: {nonsmoking}')
print('--------------------------------------------------------------------------------------')
print(f'Total Gini Impurity for Smoking column: {total_gini_smoking}')
print('--------------------------------------------------------------------------------------')

--------------------------------------------------------------------------------------
Number of Smoking patients who have died of Heart Failure: 30
Number of Smoking patients who have not died of Heart Failure: 66
Number of Non-Smoking patients who have died of Heart Failure: 66
Number of Non-Smoking patients who have not died of Heart Failure: 137
--------------------------------------------------------------------------------------
Gini Impurity for Smoking patients: 0.4296875
Gini Impurity for Non-Smoking patients: 0.43883617656337215
--------------------------------------------------------------------------------------
Total Gini Impurity for Smoking column: 0.43589880883733956
--------------------------------------------------------------------------------------


In [19]:
print('Total Gini Impurity scores for binary data')
print('------------------------------------------')
print('Anaemia: ', total_gini_anaemia)
print('High Blood Pressure: ', total_gini_hbp)
print('Diabetes: ', total_gini_diabetes)
print('Sex: ', total_gini_sex)
print('Smoking: ', total_gini_smoking)

Total Gini Impurity scores for binary data
------------------------------------------
Anaemia:  0.43405362456096996
High Blood Pressure:  0.4332231641061762
Diabetes:  0.4359666320685811
Sex:  0.43596015518920045
Smoking:  0.43589880883733956


### InOrder categories lowest-> highest

1. High Blood Pressure
2. Anaemia
3. Smoking
4. Sex
5. Diabetes

In [20]:
# Quickly make use of the new binary columns
# Age over 60
dead_over60 = len(final_binary['DEATH_EVENT'][(final_binary['age_over_60'] == 1) & (final_binary['DEATH_EVENT'] == 1)])
dead_under60 = len(final_binary['DEATH_EVENT'][(final_binary['age_over_60'] == 0) & (final_binary['DEATH_EVENT'] == 1)])
alive_over60 = len(final_binary['DEATH_EVENT'][(final_binary['age_over_60'] == 1) & (final_binary['DEATH_EVENT'] == 0)])
alive_under60 = len(final_binary['DEATH_EVENT'][(final_binary['age_over_60'] == 0) & (final_binary['DEATH_EVENT'] == 0)])

total_observations = len(df)
g = Gini(total_observations)

over60 = g.gini_impurity_binary(dead_over60, alive_over60)
under60 = g.gini_impurity_binary(dead_under60, alive_under60)
print((dead_over60, alive_over60), (dead_over60, alive_over60))
print('----------')

total_gini_over60 = g.total_gini_impurity_binary([over60, under60], (dead_over60 + alive_over60), (dead_under60 + alive_under60))
print(total_gini_over60)

(51, 84) (51, 84)
----------
0.4306740625934688


In [21]:
# CPK over 581
dead_cpk1 = len(final_binary['DEATH_EVENT'][(final_binary['cpk_over_581'] == 1) & (final_binary['DEATH_EVENT'] == 1)])
dead_cpk0 = len(final_binary['DEATH_EVENT'][(final_binary['cpk_over_581'] == 0) & (final_binary['DEATH_EVENT'] == 1)])
alive_cpk1 = len(final_binary['DEATH_EVENT'][(final_binary['cpk_over_581'] == 1) & (final_binary['DEATH_EVENT'] == 0)])
alive_cpk0 = len(final_binary['DEATH_EVENT'][(final_binary['cpk_over_581'] == 0) & (final_binary['DEATH_EVENT'] == 0)])

total_observations = len(df)
g = Gini(total_observations)

cpk1 = g.gini_impurity_binary(dead_cpk1, alive_cpk1)
cpk0 = g.gini_impurity_binary(dead_cpk0, alive_cpk0)
print((dead_cpk1, alive_cpk1), (dead_cpk0, alive_cpk0))
print('----------')

total_gini_cpk1 = g.total_gini_impurity_binary([cpk1, cpk0], (dead_cpk1 + alive_cpk1), (dead_cpk0 + alive_cpk0))
print(total_gini_cpk1)

(34, 77) (62, 126)
----------
0.43571088344446346


In [22]:
# Ejection Fraction over 38
dead_ef1 = len(final_binary['DEATH_EVENT'][(final_binary['ejection_fraction_over_38'] == 1) & (final_binary['DEATH_EVENT'] == 1)])
dead_ef0 = len(final_binary['DEATH_EVENT'][(final_binary['ejection_fraction_over_38'] == 0) & (final_binary['DEATH_EVENT'] == 1)])
alive_ef1 = len(final_binary['DEATH_EVENT'][(final_binary['ejection_fraction_over_38'] == 1) & (final_binary['DEATH_EVENT'] == 0)])
alive_ef0 = len(final_binary['DEATH_EVENT'][(final_binary['ejection_fraction_over_38'] == 0) & (final_binary['DEATH_EVENT'] == 0)])

total_observations = len(df)
g = Gini(total_observations)

ef1 = g.gini_impurity_binary(dead_ef1, alive_ef1)
ef0 = g.gini_impurity_binary(dead_ef0, alive_ef0)
print((dead_ef1, alive_ef1), (dead_ef0, alive_ef0))
print('----------')

total_gini_ef1 = g.total_gini_impurity_binary([ef1, ef0], (dead_ef1 + alive_ef1), (dead_ef0 + alive_ef0))
print(total_gini_ef1)

(38, 119) (58, 84)
----------
0.42215656806441615


In [23]:
# Platelets_above_mean
dead_p1 = len(final_binary['DEATH_EVENT'][(final_binary['platelets_above_mean'] == 1) & (final_binary['DEATH_EVENT'] == 1)])
dead_p0 = len(final_binary['DEATH_EVENT'][(final_binary['platelets_above_mean'] == 0) & (final_binary['DEATH_EVENT'] == 1)])
alive_p1 = len(final_binary['DEATH_EVENT'][(final_binary['platelets_above_mean'] == 1) & (final_binary['DEATH_EVENT'] == 0)])
alive_p0 = len(final_binary['DEATH_EVENT'][(final_binary['platelets_above_mean'] == 0) & (final_binary['DEATH_EVENT'] == 0)])

total_observations = len(df)
g = Gini(total_observations)

p1 = g.gini_impurity_binary(dead_p1, alive_p1)
p0 = g.gini_impurity_binary(dead_p0, alive_p0)
print((dead_p1, alive_p1), (dead_p0, alive_p0))
print('----------')

total_gini_p1 = g.total_gini_impurity_binary([p1, p0], (dead_p1 + alive_p1), (dead_p0 + alive_p0))
print(total_gini_p1)

(47, 101) (49, 102)
----------
0.4359442279108274


In [24]:
# Serum Creatinine above avg
dead_cre1 = len(final_binary['DEATH_EVENT'][(final_binary['serum_creatinine_above_avg'] == 1) & (final_binary['DEATH_EVENT'] == 1)])
dead_cre0 = len(final_binary['DEATH_EVENT'][(final_binary['serum_creatinine_above_avg'] == 0) & (final_binary['DEATH_EVENT'] == 1)])
alive_cre1 = len(final_binary['DEATH_EVENT'][(final_binary['serum_creatinine_above_avg'] == 1) & (final_binary['DEATH_EVENT'] == 0)])
alive_cre0 = len(final_binary['DEATH_EVENT'][(final_binary['serum_creatinine_above_avg'] == 0) & (final_binary['DEATH_EVENT'] == 0)])

total_observations = len(df)
g = Gini(total_observations)

cre1 = g.gini_impurity_binary(dead_cre1, alive_cre1)
cre0 = g.gini_impurity_binary(dead_cre0, alive_cre0)
print((dead_cre1, alive_cre1), (dead_cre0, alive_cre0))
print('----------')

total_gini_cre1 = g.total_gini_impurity_binary([cre1, cre0], (dead_cre1 + alive_cre1), (dead_cre0 + alive_cre0))
print(total_gini_cre1)

(87, 131) (9, 72)
----------
0.40320947500843796


In [25]:
# Serum sodium above 137
dead_na1 = len(final_binary['DEATH_EVENT'][(final_binary['serum_sodium_above_137'] == 1) & (final_binary['DEATH_EVENT'] == 1)])
dead_na0 = len(final_binary['DEATH_EVENT'][(final_binary['serum_sodium_above_137'] == 0) & (final_binary['DEATH_EVENT'] == 1)])
alive_na1 = len(final_binary['DEATH_EVENT'][(final_binary['serum_sodium_above_137'] == 1) & (final_binary['DEATH_EVENT'] == 0)])
alive_na0 = len(final_binary['DEATH_EVENT'][(final_binary['serum_sodium_above_137'] == 0) & (final_binary['DEATH_EVENT'] == 0)])

total_observations = len(df)
g = Gini(total_observations)

na1 = g.gini_impurity_binary(dead_na1, alive_na1)
na0 = g.gini_impurity_binary(dead_na0, alive_na0)
print((dead_na1, alive_na1), (dead_na0, alive_na0))
print('----------')

total_gini_na1 = g.total_gini_impurity_binary([na1, na0], (dead_na1 + alive_na1), (dead_na0 + alive_na0))
print(total_gini_na1)

(37, 123) (59, 80)
----------
0.4173952142633719


In [26]:
# follow up over 130 days
dead_fu1 = len(final_binary['DEATH_EVENT'][(final_binary['follow_up_over_130_days'] == 1) & (final_binary['DEATH_EVENT'] == 1)])
dead_fu0 = len(final_binary['DEATH_EVENT'][(final_binary['follow_up_over_130_days'] == 0) & (final_binary['DEATH_EVENT'] == 1)])
alive_fu1 = len(final_binary['DEATH_EVENT'][(final_binary['follow_up_over_130_days'] == 1) & (final_binary['DEATH_EVENT'] == 0)])
alive_fu0 = len(final_binary['DEATH_EVENT'][(final_binary['follow_up_over_130_days'] == 0) & (final_binary['DEATH_EVENT'] == 0)])

total_observations = len(df)
g = Gini(total_observations)

fu1 = g.gini_impurity_binary(dead_fu1, alive_fu1)
fu0 = g.gini_impurity_binary(dead_fu0, alive_fu0)
print((dead_fu1, alive_fu1), (dead_fu0, alive_fu0))
print('----------')

total_gini_fu1 = g.total_gini_impurity_binary([fu1, fu0], (dead_fu1 + alive_fu1), (dead_fu0 + alive_fu0))
print(total_gini_fu1)

(18, 116) (78, 87)
----------
0.37932683799004374


In [27]:
# New Categories
gini_scores = {'anaemia': total_gini_anaemia,
               'high_blood_pressure': total_gini_hbp,
               'diabetes': total_gini_diabetes,
               'sex': total_gini_sex,
               'smoking': total_gini_smoking,
               'age_over_60': total_gini_over60,
               'cpk_over_581': total_gini_cpk1,
               'ejection_fraction_over_38': total_gini_ef1,
               'platelets_above_mean': total_gini_p1,
               'serum_creatinine_above_avg': total_gini_cre1,
               'serum_sodium_above_137': total_gini_na1,
               'follow_up_over_130_days': total_gini_fu1}
sorted_scores = {k: v for k, v in sorted(gini_scores.items(), key=lambda x: x[1])}
for k, v in sorted_scores.items():
    print('--------------------------')
    print(f'Column: {k} - Gini Impurity: {v}')

--------------------------
Column: follow_up_over_130_days - Gini Impurity: 0.37932683799004374
--------------------------
Column: serum_creatinine_above_avg - Gini Impurity: 0.40320947500843796
--------------------------
Column: serum_sodium_above_137 - Gini Impurity: 0.4173952142633719
--------------------------
Column: ejection_fraction_over_38 - Gini Impurity: 0.42215656806441615
--------------------------
Column: age_over_60 - Gini Impurity: 0.4306740625934688
--------------------------
Column: high_blood_pressure - Gini Impurity: 0.4332231641061762
--------------------------
Column: anaemia - Gini Impurity: 0.43405362456096996
--------------------------
Column: cpk_over_581 - Gini Impurity: 0.43571088344446346
--------------------------
Column: smoking - Gini Impurity: 0.43589880883733956
--------------------------
Column: platelets_above_mean - Gini Impurity: 0.4359442279108274
--------------------------
Column: sex - Gini Impurity: 0.43596015518920045
--------------------------

If we were to make a tree based on these values, the tree would probably overfit so, we can set a threshold to act as feature exclusion

In [28]:
# Let's set the threshold for the mean of the Gini impurity scores

mean_gini_impurity = sum([score for _, score in sorted_scores.items()]) / len(sorted_scores)
print(mean_gini_impurity)

0.42495997116977474


In [29]:
# Finally, let's revise the dictionary
final_sorted_gini = {k: v for k, v in sorted_scores.items() if v <= mean_gini_impurity}
for k, v in final_sorted_gini.items():
    print('--------------------------')
    print(f'Column: {k} - Gini Impurity: {v}')

--------------------------
Column: follow_up_over_130_days - Gini Impurity: 0.37932683799004374
--------------------------
Column: serum_creatinine_above_avg - Gini Impurity: 0.40320947500843796
--------------------------
Column: serum_sodium_above_137 - Gini Impurity: 0.4173952142633719
--------------------------
Column: ejection_fraction_over_38 - Gini Impurity: 0.42215656806441615


## Decision Tree (so far)

           - Follow-Up over 130 days -
                 /            \
                /              \
           (18/116)         (78/87)

In [30]:
# Next, we calculate Gini Impurity on the left side of the tree against Serum Creatinine, Serum Sodium and Ejection Fraction

# Make a copy of a DataFrame where all values for follow_up_over_130_days is True (1)
followup1 = final_binary[final_binary['follow_up_over_130_days']==1].copy()

# Calculate Gini Impurity for the remaining columns to determine which will go on the left tree side first
g = Gini(len(followup1))  # Should have a length of 18 + 116 == 134

In [31]:
# Serum Creatinine
print('=====================================================')
print('SERUM CREATININE ABOVE AVERAGE - 1.39 milligrams/decilitre (mg/dL)')
dead_fu1_cre1 = len(followup1['DEATH_EVENT'][(followup1['serum_creatinine_above_avg']==1) & (followup1['DEATH_EVENT']==1)])
alive_fu1_cre1 = len(followup1['DEATH_EVENT'][(followup1['serum_creatinine_above_avg']==1) & (followup1['DEATH_EVENT']==0)])
dead_fu1_cre0 = len(followup1['DEATH_EVENT'][(followup1['serum_creatinine_above_avg']==0) & (followup1['DEATH_EVENT']==1)])
alive_fu1_cre0 = len(followup1['DEATH_EVENT'][(followup1['serum_creatinine_above_avg']==0) & (followup1['DEATH_EVENT']==0)])

print('-----------------------------------------------------')
print((dead_fu1_cre1, alive_fu1_cre1), (dead_fu1_cre0, alive_fu1_cre0))
print('-----------------------------------------------------')

fu1_cre1 = g.gini_impurity_binary(dead_fu1_cre1, alive_fu1_cre1)
fu1_cre0 = g.gini_impurity_binary(dead_fu1_cre0, alive_fu1_cre0)

print('Total Gini Impurity: ', g.total_gini_impurity_binary([fu1_cre1, fu1_cre0], (dead_fu1_cre1 + alive_fu1_cre1), (dead_fu1_cre0 + alive_fu1_cre0)))


# Serum Sodium
print('=====================================================')
print('SERUM SODIUM ABOVE 137 milliequivalents/litre (mEq/L)')
dead_fu1_na1 = len(followup1['DEATH_EVENT'][(followup1['serum_sodium_above_137']==1) & (followup1['DEATH_EVENT']==1)])
alive_fu1_na1 = len(followup1['DEATH_EVENT'][(followup1['serum_sodium_above_137']==1) & (followup1['DEATH_EVENT']==0)])
dead_fu1_na0 = len(followup1['DEATH_EVENT'][(followup1['serum_sodium_above_137']==0) & (followup1['DEATH_EVENT']==1)])
alive_fu1_na0 = len(followup1['DEATH_EVENT'][(followup1['serum_sodium_above_137']==0) & (followup1['DEATH_EVENT']==0)])

print('-----------------------------------------------------')
print((dead_fu1_na1, alive_fu1_na1), (dead_fu1_na0, alive_fu1_na0))
print('-----------------------------------------------------')

fu1_na1 = g.gini_impurity_binary(dead_fu1_na1, alive_fu1_na1)
fu1_na0 = g.gini_impurity_binary(dead_fu1_na0, alive_fu1_na0)

print('Total Gini Impurity: ', g.total_gini_impurity_binary([fu1_na1, fu1_na0], (dead_fu1_na1 + alive_fu1_na1), (dead_fu1_na0 + alive_fu1_na0)))


# Ejection Fraction
print('=====================================================')
print('EJECTION FRACTION OVER 38%')
dead_fu1_ef1 = len(followup1['DEATH_EVENT'][(followup1['ejection_fraction_over_38']==1) & (followup1['DEATH_EVENT']==1)])
alive_fu1_ef1 = len(followup1['DEATH_EVENT'][(followup1['ejection_fraction_over_38']==1) & (followup1['DEATH_EVENT']==0)])
dead_fu1_ef0 = len(followup1['DEATH_EVENT'][(followup1['ejection_fraction_over_38']==0) & (followup1['DEATH_EVENT']==1)])
alive_fu1_ef0 = len(followup1['DEATH_EVENT'][(followup1['ejection_fraction_over_38']==0) & (followup1['DEATH_EVENT']==0)])

print('-----------------------------------------------------')
print((dead_fu1_ef1, alive_fu1_ef1), (dead_fu1_ef0, alive_fu1_ef0))
print('-----------------------------------------------------')

fu1_ef1 = g.gini_impurity_binary(dead_fu1_ef1, alive_fu1_ef1)
fu1_ef0 = g.gini_impurity_binary(dead_fu1_ef0, alive_fu1_ef0)

print('Total Gini Impurity: ', g.total_gini_impurity_binary([fu1_ef1, fu1_ef0], (dead_fu1_ef1 + alive_fu1_ef1), (dead_fu1_ef0 + alive_fu1_ef0)))

SERUM CREATININE ABOVE AVERAGE - 1.39 milligrams/decilitre (mg/dL)
-----------------------------------------------------
(17, 77) (1, 39)
-----------------------------------------------------
Total Gini Impurity:  0.22239599872975546
SERUM SODIUM ABOVE 137 milliequivalents/litre (mEq/L)
-----------------------------------------------------
(5, 66) (13, 50)
-----------------------------------------------------
Total Gini Impurity:  0.22336339763289534
EJECTION FRACTION OVER 38%
-----------------------------------------------------
(5, 68) (13, 48)
-----------------------------------------------------
Total Gini Impurity:  0.22219466333278595


In [32]:
# Serum Creatinine goes next

## Decision Tree (so far part 2)

                      - Follow-Up >= 130 days -
                           /            \
                          /              \
                 Ejection Fraction     (78/87)
                     >= 38%
                    /      \
                   /        \
               (5/68)    (13/48)

In [33]:
# Keep going left...
# Make a copy of a DataFrame where all values for follow_up_over_130_days is True (1) AND serum_creatinine_above_avg is True (1)
fu1_ef1 = followup1[followup1['ejection_fraction_over_38']==1].copy()

g = Gini(len(fu1_ef1))  # Should be len == 5 + 68 == 94

In [35]:
# Serum Creatinine
print('=====================================================')
print('SERUM CREATININE ABOVE 1.39 mg/dL')
dead_fu1_ef1_cre1 = len(fu1_ef1['DEATH_EVENT'][(fu1_ef1['serum_creatinine_above_avg']==1) & (fu1_ef1['DEATH_EVENT']==1)])
alive_fu1_ef1_cre1 = len(fu1_ef1['DEATH_EVENT'][(fu1_ef1['serum_creatinine_above_avg']==1) & (fu1_ef1['DEATH_EVENT']==0)])
dead_fu1_ef1_cre0 = len(fu1_ef1['DEATH_EVENT'][(fu1_ef1['serum_creatinine_above_avg']==0) & (fu1_ef1['DEATH_EVENT']==1)])
alive_fu1_ef1_cre0 = len(fu1_ef1['DEATH_EVENT'][(fu1_ef1['serum_creatinine_above_avg']==0) & (fu1_ef1['DEATH_EVENT']==0)])

print('-----------------------------------------------------')
print((dead_fu1_ef1_cre1, alive_fu1_ef1_cre1), (dead_fu1_ef1_cre0, alive_fu1_ef1_cre0))
print('-----------------------------------------------------')

fu1_ef1_cre1 = g.gini_impurity_binary(dead_fu1_ef1_cre1, alive_fu1_ef1_cre1)
fu1_ef1_cre0 = g.gini_impurity_binary(dead_fu1_ef1_cre0, alive_fu1_ef1_cre0)

print('Total Gini Impurity: ', g.total_gini_impurity_binary([fu1_ef1_cre1, fu1_ef1_cre0], (dead_fu1_ef1_cre1 + alive_fu1_ef1_cre1), (dead_fu1_ef1_cre0 + alive_fu1_ef1_cre0)))

# Serum Sodium
print('=====================================================')
print('SERUM SODIUM ABOVE 137 milliequivalents/litre (mEq/L)')
dead_fu1_ef1_na1 = len(fu1_ef1['DEATH_EVENT'][(fu1_ef1['serum_sodium_above_137']==1) & (fu1_ef1['DEATH_EVENT']==1)])
alive_fu1_ef1_na1 = len(fu1_ef1['DEATH_EVENT'][(fu1_ef1['serum_sodium_above_137']==1) & (fu1_ef1['DEATH_EVENT']==0)])
dead_fu1_ef1_na0 = len(fu1_ef1['DEATH_EVENT'][(fu1_ef1['serum_sodium_above_137']==0) & (fu1_ef1['DEATH_EVENT']==1)])
alive_fu1_ef1_na0 = len(fu1_ef1['DEATH_EVENT'][(fu1_ef1['serum_sodium_above_137']==0) & (fu1_ef1['DEATH_EVENT']==0)])

print('-----------------------------------------------------')
print((dead_fu1_ef1_na1, alive_fu1_ef1_na1), (dead_fu1_ef1_na0, alive_fu1_ef1_na0))
print('-----------------------------------------------------')

fu1_ef1_na1 = g.gini_impurity_binary(dead_fu1_ef1_na1, alive_fu1_ef1_na1)
fu1_ef1_na0 = g.gini_impurity_binary(dead_fu1_ef1_na0, alive_fu1_ef1_na0)

print('Total Gini Impurity: ', g.total_gini_impurity_binary([fu1_ef1_na1, fu1_ef1_na0], (dead_fu1_ef1_na1 + alive_fu1_ef1_na1), (dead_fu1_ef1_na0 + alive_fu1_ef1_na0)))

SERUM CREATININE ABOVE 1.39 mg/dL
-----------------------------------------------------
(4, 45) (1, 23)
-----------------------------------------------------
Total Gini Impurity:  0.1268987046873543
SERUM SODIUM ABOVE 137 milliequivalents/litre (mEq/L)
-----------------------------------------------------
(2, 46) (3, 22)
-----------------------------------------------------
Total Gini Impurity:  0.12484018264840173


In [36]:
# Serum Sodium goes next

## Decision Tree (so far part 3)

                      - Follow-Up >= 130 days -
                           /            \
                          /              \
                 Ejection Fraction     (78/87)
                     >= 38%
                    /      \
                   /        \
            Serum Sodium  (13/48)
            >= 137 mEq/L
              /   \
             /     \
          (2/46)  (3/22)