In [4]:
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [17]:
dataset = pd.read_csv('fullEDfullmodel.csv')
dataset.head()

Unnamed: 0,TRAUMATYPE,SYSBP,RR,GCS,EDMOTOR,SI,SIRANK,AGE,SEX,RTS,...,PhysiologicalAMPT,LungAMPT,AMPT,Mechanism,RTSCode,AgeGroups,SBPCode,MotorCode,AMPT2,class
0,0,100,21,15,6,0.72,2,80.0,0,7.8408,...,0,0,0,0,0,5,0,0,0,T
1,0,103,22,14,6,0.650485,1,80.0,1,7.8408,...,0,0,0,1,0,5,0,0,0,T
2,0,96,12,15,6,0.791667,2,67.0,0,7.8408,...,0,0,0,1,0,5,0,0,0,T
3,1,141,20,15,6,0.553191,1,63.0,1,7.8408,...,0,0,0,0,0,5,0,0,0,T
4,0,105,24,15,6,0.695238,1,66.0,0,7.8408,...,0,1,2,0,0,5,0,0,1,T


In [18]:
X = dataset.iloc[:, :-1] #removes ['class']
y = dataset['class']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify = y)

In [19]:
from sklearn import tree
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import StratifiedKFold

tre = tree.DecisionTreeClassifier(random_state = 42)

sfk = StratifiedKFold(n_splits = 10)

#Parameters can be tuned more. I found max_depth impacted sensitivity the most in my testing.
tre_param = {'criterion':['gini'], 'max_depth': [4], 'max_leaf_nodes': [3, 5, 7]}

gs_tre = GridSearchCV(tre, tre_param, cv=sfk, n_jobs=-1)

gs_tre.fit(X_train, y_train)

#predict probability of classification with threshold of >= 0.25. Increased sensitivity by around 12%.
#This could definitely be tuned more.
tre_pred = (gs_tre.predict_proba(X_test) >= 0.25)

#Returns probability of classification as a tuple.
pred_values = gs_tre.predict_proba(X_test)
print(pred_values)

#predict_proba returns a 2 length boolean tuple. The remainder of this block makes it comparable to ['class']
tre_pred = [item[1] for item in tre_pred]
for i in range(len(tre_pred)):
    if tre_pred[i] == False:
        tre_pred[i] = 'F'
    else:
        tre_pred[i] = 'T'

[[0.93949384 0.06050616]
 [0.93949384 0.06050616]
 [0.12681951 0.87318049]
 ...
 [0.72767648 0.27232352]
 [0.93949384 0.06050616]
 [0.79601048 0.20398952]]


In [20]:
print(f'Best parameters: {gs_tre.best_params_}')

Best parameters: {'criterion': 'gini', 'max_depth': 4, 'max_leaf_nodes': 7}


In [21]:
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.model_selection import cross_validate

rsfk = RepeatedStratifiedKFold(n_splits = 10, n_repeats = 10, random_state = 1)

tre_score = cross_validate(tre, X, y, scoring = 'accuracy', cv = rsfk, n_jobs = -1)

print("Accuracy: " + str(tre_score['test_score'].mean()))

Accuracy: 0.7804862512574541


In [24]:
from sklearn.metrics import confusion_matrix

tn, fp, fn, tp = confusion_matrix(y_test, tre_pred).ravel()

print(f'Sensitivity: {tp/(tp+fn)}')
print(f'Specificity: {tn/(tn+fp)}')

print(f'Confusion Matrix:\n{confusion_matrix(y_test, tre_pred)}')

Sensitivity: 0.8140602582496413
Specificity: 0.7663752692258964
Confusion Matrix:
[[6049 1844]
 [ 648 2837]]


In [25]:
#Prints out a long text form of the tree

tre.fit(X_train, y_train)
tree.plot_tree(tre, feature_names = (dataset.columns)[:-1], class_names = dataset['class'])

#Saves a png of the tree. This takes a superl long time
# plt.savefig('tree_visualization.png')

[Text(0.4802384345054791, 0.9857142857142858, 'AMPT <= 0.5\ngini = 0.425\nsamples = 45510\nvalue = [31573, 13937]\nclass = T'),
 Text(0.24123283201876855, 0.9571428571428572, 'HEADAMPT <= 0.5\ngini = 0.176\nsamples = 26585\nvalue = [23991, 2594]\nclass = T'),
 Text(0.13253976809381005, 0.9285714285714286, 'ABDAMPT <= 0.5\ngini = 0.114\nsamples = 19717\nvalue = [18524, 1193]\nclass = T'),
 Text(0.07783125739398392, 0.9, 'NECKAMPT <= 0.5\ngini = 0.07\nsamples = 16300\nvalue = [15710, 590]\nclass = T'),
 Text(0.05251879524487137, 0.8714285714285714, 'CHESTAMPT <= 0.5\ngini = 0.049\nsamples = 15289\nvalue = [14901, 388]\nclass = T'),
 Text(0.030759828446461083, 0.8428571428571429, 'SYSBP <= 262.5\ngini = 0.021\nsamples = 13233\nvalue = [13094, 139]\nclass = T'),
 Text(0.030591993821260058, 0.8142857142857143, 'LEGAMPT <= 0.5\ngini = 0.021\nsamples = 13232\nvalue = [13094, 138]\nclass = T'),
 Text(0.026705773035794012, 0.7857142857142857, 'Mechanism <= 0.5\ngini = 0.034\nsamples = 6294\nval

Error in callback <function flush_figures at 0x0000014A19FB1430> (for post_execute):


KeyboardInterrupt: 