-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
134 lines (118 loc) · 6.07 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
'''
@info covid-19 predictability tool
@version 1.0
'''
import logging
import pickle

import dash
import dash_html_components as html
import numpy as np
from dash.dependencies import Input, Output

import gui_utils
# =========================== #
# ====== A: PARAMETERS ====== #
# =========================== #
pagetitle = 'COVID-19: predictability tool'

# Options for the "stage" radio button; the selected stage chooses which
# prediction targets are reported (see `to_use`).
stage = ['positive test', 'hospitalized', 'intensive care']

# Comorbidity label shown in the dropdown -> short code.
# Only the labels (keys) are read in this file. The iteration order of this
# dict is the order of the 0/1 comorbidity features built in update_map, so
# reordering or inserting entries changes the model input (presumably it must
# match the features the models were trained on — confirm before editing).
comorbidities = {
    'Asthma': 'ASTH',
    'Cancer': 'CANC',
    'Cardiovascular disease': 'CARDIACDIS',
    'Chronic hematologic disease': 'CHD',
    'Diabetes': 'DIAB',
    'HIV': 'HIV',
    'Kidney disease': 'KIDNEY',
    'Hepatic disease': 'LIVER',
    'Lung disease': 'LUNG',
    'Neuromuscular disease': 'NEUROMUS',
    'Dyslipidemia': 'DISLIP',
    'Pregnant 1st quarter': 'PREG1',
    'Pregnant 2nd quarter': 'PREG2',
    'Pregnant 3rd quarter': 'PREG3',
    'Post-pregnancy': 'PREGPOST',
}
# Content of the result panel before any click: update_map hands this list
# back unchanged while the button's n_clicks is still None.
result = [
    html.Br(),
    html.B('Result:'),
    html.Div('Input the patient profile...'),
]

# Form specification consumed by gui_utils: a list of (section title, fields)
# pairs. Each field tuple is presumably (component id, options or default text,
# widget kind, initial value) — confirm against gui_utils.
parameters = [
    (
        "Patient profile",
        [
            ('stage', stage, gui_utils.Button.radio, 'positive test'),
            ('age', '67', gui_utils.Button.input, None),
            ('gender', ['male', 'female'], gui_utils.Button.radio, 'male'),
        ],
    ),
    (
        "Clinical history",
        [
            (
                'comorbidities',
                comorbidities.keys(),
                gui_utils.Button.multidrop,
                ['Cancer', 'Diabetes', 'Hepatic disease'],
            ),
        ],
    ),
]

# Page layout: the form above plus an HTML output area registered under the
# id 'result' (the component update_map writes into).
layout = gui_utils.get_layout(
    pagetitle,
    parameters,
    [('result', result, gui_utils.Button.html)],
)
# Pre-trained classifiers stored as pickle files under `models_path`
# (resolved relative to the current working directory). Each target has two
# models, one maximizing recall and one maximizing F1-score, as the file names
# indicate. The order of `files` is load-bearing: target_dict picks models by
# index from `classifiers`.
models_path = "./calc_models/"
files = ["calc_hosp_recall", "calc_hosp_f1", "calc_ic1_recall", "calc_ic1_f1", "calc_ic2_recall", "calc_ic2_f1",
         "calc_out1_recall", "calc_out1_f1", "calc_out2_recall", "calc_out2_f1", "calc_out3_recall", "calc_out3_f1",
         "calc_rs_recall", "calc_rs_f1"]
classifiers = []
for file in files:
    filename = models_path + file + ".sav"
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only deploy model files from a trusted source.
    # The context manager closes each handle as soon as the model is loaded.
    with open(filename, 'rb') as model_file:
        classifiers.append(pickle.load(model_file))
# For each prediction target: the phrase pair used in the result sentence
# ("The patient will <phrase[0]> (NN% of <phrase[1]>)") and its
# (recall-maximizing, F1-maximizing) classifier pair. Each pair is two
# consecutive entries of `classifiers`, whose order follows `files`
# (hosp, ic1, ic2, out1, out2, out3, rs).
target_dict = {
    name: {
        "phrase": phrase,
        "classifiers": (classifiers[2 * pos], classifiers[2 * pos + 1]),
    }
    for pos, (name, phrase) in enumerate([
        ("Hosp", ("need hospitalization", "needing")),
        ("IC1", ("need intensive care", "needing")),
        ("IC2", ("need intensive care", "needing")),
        ("Outcome1", ("pass away", "passing")),
        ("Outcome2", ("pass away", "passing")),
        ("Outcome3", ("pass away", "passing")),
        ("RespSupport", ("need respiratory support", "needing")),
    ])
}
# Targets reported for each care stage; keys match the `stage` radio options.
to_use = {
    "positive test": ["Hosp", "IC1", "Outcome1"],
    "hospitalized": ["IC2", "RespSupport", "Outcome2"],
    "intensive care": ["Outcome3"],
}

# Bold heading displayed above each target's predictions.
translations = dict(
    Hosp="Hospitalization",
    IC1="Intensive Care",
    IC2="Intensive Care",
    Outcome1="Outcome",
    Outcome2="Outcome",
    Outcome3="Outcome",
    RespSupport="Respiratory Support",
)

# Mean and standard deviation update_map uses to standardize the age input
# (presumably statistics of the training data — confirm).
age_mean = 48.026216
age_std = 24.804093
# ========================== #
# ====== C: UPDATE GUI ====== #
# ========================== #
# The Dash application object. include_assets_files=True makes Dash pick up
# the files found in the 'assets' folder.
app = dash.Dash(
    __name__,
    assets_folder='assets',
    include_assets_files=True,
)
def _binary_proba(clf, data, target):
    """Return [P(no), P(yes)] predicted by ``clf`` for the single row in ``data``.

    The RespSupport classifiers report at least three class probabilities;
    classes 1 and 2 are merged into one "yes" probability so that every target
    is presented the same way. For other targets the ``predict_proba`` row is
    returned unchanged.
    """
    proba = clf.predict_proba(data)[0]
    if target == "RespSupport":
        return [proba[0], proba[1] + proba[2]]
    return proba


def _verdict(p_yes, phrase, verb):
    """Build one classifier's verdict sentence.

    ``p_yes`` above 0.5 yields "will", otherwise "won't". The percentage is
    ``p_yes * 100`` truncated to an int, e.g.
    "The patient will need hospitalization (72% of needing)".
    """
    outcome = 'will' if p_yes > 0.5 else 'won\'t'
    return 'The patient {} {} ({}% of {})'.format(outcome, phrase, int(p_yes * 100), verb)


@app.callback(Output('result', 'children'),
              [Input('button', 'n_clicks')],
              gui_utils.get_states(parameters))
def update_map(inp, *args):
    """Render the predictions for the patient profile entered in the form.

    Called by Dash when the ``button`` component's ``n_clicks`` changes. The
    form values arrive as the States from ``gui_utils.get_states(parameters)``
    and are read through ``dash.callback_context.states`` under keys such as
    ``'stage.value'``. ``*args`` carries the same values positionally and is
    unused.

    :param inp: the button's ``n_clicks``; ``None`` before the first click.
    :returns: list of Dash HTML components for the ``result`` container.
    """
    log = logging.getLogger(__name__)
    states = dash.callback_context.states
    # The states contain the patient's profile, so log them only at debug level.
    log.debug("states %s", states)
    if inp is None:
        return result

    stage = states['stage.value']
    try:
        # Standardize age with the module-level mean/std.
        age = (int(states['age.value']) - age_mean) / age_std
    except (TypeError, ValueError):
        # Empty or non-integer age: report it instead of raising in the callback.
        return [html.Br(), html.B('Result:'), html.Div('Please enter the age as a whole number.')]
    gender = 0 if states['gender.value'] == "female" else 1
    data = [age, gender]
    # A cleared multi-dropdown can report None rather than an empty list.
    sel_comorbidities = states['comorbidities.value'] or []

    #################################################
    # Stage flags: "hospitalized" adds one 1, "intensive care" adds two 1s and
    # "positive test" adds none, so each stage yields a different-length row
    # for its own set of models.
    if stage in ("hospitalized", "intensive care"):
        data.append(1)
    if stage == "intensive care":
        data.append(1)
    # One 0/1 indicator per comorbidity, in the declaration order of `comorbidities`.
    data.extend(1 if comorb in sel_comorbidities else 0 for comorb in comorbidities)
    data = np.array(data).reshape(1, -1)
    #################################################

    to_return = [html.Br(), html.B('Result:'), html.Br(), html.Br()]
    for target in to_use[stage]:
        log.debug("Target: %s", target)
        clf1, clf2 = target_dict[target]["classifiers"]
        result1 = _binary_proba(clf1, data, target)
        result2 = _binary_proba(clf2, data, target)
        log.debug("Results for target: %s\n%s\n%s", target, result1, result2)
        phrase, verb = target_dict[target]["phrase"]
        to_return.append(html.B(translations[target]))
        to_return.append(html.Div('According to classifier that maximizes recall:'))
        to_return.append(html.Div(_verdict(result1[1], phrase, verb)))
        to_return.append(html.Br())
        to_return.append(html.Div('According to classifier that maximizes F1-score:'))
        to_return.append(html.Div(_verdict(result2[1], phrase, verb)))
        to_return.append(html.Br())
        # Horizontal separator between targets.
        to_return.append(html.Div(style={"border": "1px black solid"}))
        to_return.append(html.Br())
    to_return.append(html.Div('Disclaimer: This is a clinical decision support tool and cannot be used to make a final decision without medical advice.'))
    return to_return
# ===================== #
# ====== D: MAIN ====== #
# ===================== #
# Configure the app at import time, not only under __main__, so the layout is
# also attached when this module is imported by a WSGI server via `app.server`.
app.config.suppress_callback_exceptions = True
app.layout = layout

if __name__ == '__main__':
    app.run_server()