# Decision tree for PCR testing

In [None]:
import plotly.express as px
import seaborn as sns
import pandas as pd
import numpy as np
import seaborn as sns
from IPython.display import HTML
from sklearn import tree
import plotly.express as px

import sklearn
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import RidgeClassifier, LogisticRegression, RidgeClassifierCV
from sklearn.feature_selection import SelectKBest
from sklearn.metrics import make_scorer, average_precision_score
from sklearn.inspection import permutation_importance
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, plot_precision_recall_curve
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, StratifiedShuffleSplit, cross_validate, cross_val_score, KFold, StratifiedKFold

from joblib import load

In [None]:
# Raw symptom columns of the survey (used to count and binarise symptoms).
symptoms = [
    'tiredness', 'fever', 'shivers', 'cough', 'breathlessness', 'aches',
    'chest_opression', 'chest_pain',
    'diarrhea', 'vomiting',
    'sensoriel', 'anosmia', 'ageusia',
    'anorexia',
    'rash', 'frostbites',
    'conjunctivitis',
    'other_sympt',
]

# Columns fed to the decision tree (individual and grouped symptoms).
FEATURES = [
    'tiredness', 'fever', 'shivers', 'cough', 'breathlessness', 'aches',
    'breath_thorac', 'digest', 'sensoriel', 'anorexia', 'cutan',
    'conjunctivitis',
]

# Human-readable label for each feature, used on plot axes and tree nodes.
FEATURE_NAMES = {
    'tiredness': 'Tiredness',
    'fever': 'Fever',
    'shivers': 'Shivers',
    'cough': 'Cough',
    'breathlessness': 'Short of breath',
    'aches': 'Myalgia',
    'breath_thorac': 'Cardiopulmonary',
    'digest': 'Digestive symptoms',
    'sensoriel': 'Anosmia or ageusia',
    'anorexia': 'Anorexia',
    'cutan': 'Cutaneous',
    'conjunctivitis': 'Conjunctivitis',
}

# Functions

In [None]:
def prepare(main, effect_col, cols, query_filter='', 
            replace_names=None):
    """Extract the design matrix, target and sample weights from `main`.

    Parameters
    ----------
    main : pandas.DataFrame
        Source table; it is never modified. Must contain `effect_col`,
        every column in `cols` and a 'sample_weight' column.
    effect_col : str
        Target column; rows where it is missing are dropped.
    cols : list of str
        Feature columns returned as `X`.
    query_filter : str, optional
        `DataFrame.query` expression applied after dropping missing targets;
        an empty value disables the filter.
    replace_names : dict, optional
        Mapping ``column -> replacement spec`` passed to `Series.replace`.

    Returns
    -------
    X, y, sample_weight : numpy.ndarray
        Feature matrix, target vector and per-row weights of the kept rows.
    """
    # Filter
    df = main.dropna(subset=[effect_col])
    if query_filter:
        df = df.query(query_filter)
    print(f'dropped {len(main)-len(df)} rows')

    # Replace: work on an explicit copy so the assignments below neither touch
    # `main` nor raise chained-assignment warnings on the filtered frame.
    df = df.copy()
    for col, names in (replace_names or {}).items():
        df[col] = df[col].replace(names)

    # X, y 
    X = df[cols].values
    y = df[effect_col].values
    sample_weight = df['sample_weight'].values
    
    return X, y, sample_weight


def gridsearch(X_train, y_train, sample_weight, hyper_params_grid, n_jobs=20):
    """Grid-search decision-tree hyper-parameters on average precision.

    Parameters
    ----------
    X_train, y_train : array-like
        Training features and binary target.
    sample_weight : array-like
        Per-sample weights passed to `GridSearchCV.fit`. The scorer itself is
        built without weights, so the cross-validated scores are unweighted.
    hyper_params_grid : dict
        Parameter grid handed to `GridSearchCV`.
    n_jobs : int, optional
        Number of parallel jobs used by the search (default 20, the value
        previously hard-coded).

    Returns
    -------
    best_params : dict
        Best parameter combination found.
    best_estimator : DecisionTreeClassifier
        Estimator refitted with `best_params`.
    """
    clf = GridSearchCV(
        DecisionTreeClassifier(), hyper_params_grid, 
        scoring=make_scorer(average_precision_score, needs_proba=True),
        n_jobs=n_jobs,
    )
    
    print('Fitting gridsearch ... ')
    clf.fit(X_train, y_train, sample_weight=sample_weight)
    print('Done, best params : ', clf.best_params_)
    
    # Describe the distribution of mean CV scores over the whole grid
    means = clf.cv_results_['mean_test_score']
    print(pd.Series(means).describe())    
    return clf.best_params_, clf.best_estimator_


def select_variables(X_train, y_train, best_params, features, 
                    plot=False):
    """Rank features by permutation importance and keep the most useful ones.

    A decision tree built with `best_params` is fitted on 100 stratified
    50/50 shuffle splits of the training data. On each held-out half the
    average precision and the permutation importance (10 repeats) of every
    feature are computed.

    Parameters
    ----------
    X_train, y_train : numpy.ndarray
        Training features (n_samples, n_features) and binary target.
    best_params : dict
        Keyword arguments for `DecisionTreeClassifier`.
    features : sequence of str
        Feature names, in the column order of `X_train`.
    plot : bool, optional
        When True, draw a bar plot of the importances.

    Returns
    -------
    selected_variables : numpy.ndarray
        Features whose cumulated mean importance (sorted in decreasing order)
        stays below ``mean AP - prevalence``.
    df : pandas.DataFrame
        One row per (feature, split) with the importance averaged over the
        permutation repeats and the per-feature mean in 'fimp_mean'.

    Notes
    -----
    Reads the module-level variable `seed` as the `random_state` of
    `permutation_importance`; it must be defined before this is called.
    """
    # Compute permutation importance
    cv = StratifiedShuffleSplit(n_splits=100, test_size=.5)
    n_repeats = 10
    scorer = make_scorer(average_precision_score, needs_proba=True)
    mAP = []
    fimp = np.zeros((X_train.shape[1], cv.get_n_splits(), n_repeats))
    for i, (train, test) in enumerate(cv.split(X_train, y_train)):
        clf = DecisionTreeClassifier(**best_params)
        clf.fit(X_train[train], y_train[train])
        y_pred = clf.predict_proba(X_train[test])[:, 1]
        mAP.append(average_precision_score(y_train[test], y_pred))
        result = permutation_importance(clf, X_train[test], y_train[test], 
                                        n_repeats=n_repeats, n_jobs=5, random_state=seed, 
                                        scoring=scorer)
        fimp[:, i] = result.importances
    mAP = np.mean(mAP)
    
    # Build df. `fimp.reshape(-1)` flattens in C order, i.e. the feature index
    # varies slowest, then the split, then the repeat; every label vector
    # below must follow that same layout.
    ncols, nsplits, n_repeats = fimp.shape
    df = pd.DataFrame({
        'feature_id': np.repeat(np.arange(ncols), nsplits*n_repeats), 
        'feature': np.repeat(features, nsplits*n_repeats), 
        'fimp': fimp.reshape(-1),
        'split': np.tile(np.repeat(np.arange(nsplits), n_repeats), ncols),
        'permut': np.tile(np.arange(n_repeats), ncols*nsplits),
    })
    df = df.groupby(['feature', 'split']).agg('mean').reset_index() # average on repeats 
    df['fimp_mean'] = df.groupby('feature')['fimp'].transform('mean')
    
    # Plot 
    if plot:
        fig, ax = plt.subplots()
        # Keyword x/y: positional data arguments are rejected by recent seaborn
        sns.barplot(x='fimp', y='feature',
                    data=df.sort_values('fimp_mean', ascending=False), ax=ax)
        labels = ax.get_yticklabels()
        ax.set_yticklabels([FEATURE_NAMES[i.get_text()] for i in labels])
        plt.xlabel('Permutation importance on test sets')
        plt.ylabel('')
        plt.show()
    
    # Select variables (cumulated sum better than chance)
    cumsum = df.groupby('feature').agg('mean').sort_values('fimp_mean', ascending=False)['fimp_mean'].cumsum() 
    selected_variables = cumsum.index[(cumsum < (mAP-(y_train.sum()/len(y_train))))].values
    print('Selected variables : ', selected_variables)
                   
    return selected_variables, df

# Load and preprocess data

In [None]:
# Load the survey table from CSV
merged = pd.read_csv('data/shuffled.csv')

In [None]:
# Show (rows, columns) of the raw table
merged.shape

In [None]:
from tools import enrich_survey

In [None]:
# Apply the project helper `enrich_survey` (imported from `tools`) to the table
merged = merged.pipe(enrich_survey)

In [None]:
# Show (rows, columns) again, after `enrich_survey`
merged.shape

In [None]:
# Preprocess: missing PCR results are flagged with -1 and the French
# 'Non'/'Oui' answers are recoded as 0/1 across the whole table.
na_labels = {'pcr_result':-1, }
main = merged.fillna(na_labels)
main = main.replace({'Non':0, 'Oui':1})
# 'sensoriel' is the row-wise max of anosmia and ageusia (1 if either is reported)
main['sensoriel'] = main[['anosmia', 'ageusia']].max(axis=1)
# NOTE(review): `symptoms` contains 'sensoriel' as well as 'anosmia' and
# 'ageusia', so a smell/taste loss contributes more than once to this sum —
# confirm this is intended.
main['num_symp'] = main[symptoms].sum(axis=1)
# Missing symptom answers are treated as "absent" (0)
main[symptoms] = main[symptoms].fillna(0).astype('int')

# Groupings
main['chest'] = main[['chest_pain', 'chest_opression']].max(axis=1)
main['cutan'] = main[['rash', 'frostbites']].max(axis=1)
main['digest'] = main[['vomiting', 'diarrhea']].max(axis=1)
# NOTE(review): this group uses min (chest symptom AND breathlessness), whereas
# the other groups use max (any of) — confirm the AND is intended.
main['breath_thorac'] = main[['chest', 'breathlessness']].min(axis=1)

In [None]:
# Number of rows after preprocessing
len(main)

# Stats

In [None]:
# Delay statistics: days between the survey answer and (a) the first symptoms,
# (b) the 'start_time' column. Two histograms and a summary table are written
# to the output/ folder.
# NOTE(review): `timedelta` is imported but not used in this cell.
from datetime import timedelta
# Differences are rounded to whole days.
# NOTE(review): `.astype(int)` presumably fails if any timestamp is missing — confirm the inputs are complete.
merged['delay_symp'] = (pd.to_datetime(merged['survey_time'])-pd.to_datetime(merged['symptom_start_time'])).dt.round('D').dt.days.astype(int)
merged['delay_start'] = (pd.to_datetime(merged['survey_time'])-pd.to_datetime(merged['start_time'])).dt.round('D').dt.days.astype(int)

# Keep only rows whose survey was answered on or after 'start_time'
df = merged.query('delay_start>=0')
fig = plt.figure()
sns.distplot(df['delay_start'], kde=False, bins=np.arange(100),  hist_kws={'alpha':.8})
plt.xlim(0, 40)
plt.xlabel('Number of days between covidom inscription and survey answer')
plt.ylabel('Number of patients')
plt.show()
fig.savefig('output/charlotte_delay_covidom_answer.pdf', format='pdf', bbox_inches='tight')


# Same histogram for the delay since first symptoms (rows are filtered on
# 'delay_start' only, not on 'delay_symp')
fig = plt.figure()
sns.distplot(df['delay_symp'], kde=False, bins=np.arange(100), hist_kws={'alpha':.8})
plt.xlim(0, 40)
plt.xlabel('Number of days between first symptoms and survey answer')
plt.ylabel('Number of patients')
plt.show()
fig.savefig('output/charlotte_delay_symp_answer.pdf', format='pdf', bbox_inches='tight')

# NOTE(review): the exported summary is computed on the unfiltered `merged`,
# while the histograms above use the filtered `df` — confirm this is intended.
merged[['delay_symp', 'delay_start']].describe().to_csv('output/charlotte_delay_answer.csv')

In [None]:
# Summary statistics of both delays on the filtered rows (delay_start >= 0)
df[['delay_symp', 'delay_start']].describe()

In [None]:
# Delay between the first PCR test and each of the (up to 10) recorded tests,
# summarised per test rank and exported as CSV.
df = merged.copy()
pcr_cols = [f'eds_pcr_{rank}' for rank in np.arange(1, 11)]
delay_cols = [f'delay_pcr_{rank}' for rank in np.arange(1, 11)]
for col in pcr_cols:
    # Replace the column instead of `df.loc[:, col] = ...`: the `.loc` form
    # writes into the existing column in place and, on recent pandas, keeps
    # its original (object) dtype, so the dates would never become datetime64.
    df[col] = pd.to_datetime(df[col]) #.dt.round('D')
df[delay_cols] = df[pcr_cols].sub(df['eds_pcr_1'], axis='index')
df = df[delay_cols].describe().T
for col in [i for i in df.columns if i!='count']:
    # After describe().T the statistics may be stored as object columns;
    # convert explicitly so the `.dt` accessor is always available.
    df[col] = pd.to_timedelta(df[col]).dt.round('D')
df.to_csv('output/charlotte_pcr_delay.csv')

In [None]:
# Backup of the preprocessed table; the parameter cell below rebuilds `main` from it
main_ = main.copy()

# Set params

In [None]:
# Experiment switches
WEIGHTING = False          # use inverse-probability-of-testing sample weights
ONLY_AMBULATORY = False    # restrict the analysis to non-hospitalized patients

# Suffix appended to output file names so that runs with different switches
# do not overwrite each other
NAME = f'{"_correct" if WEIGHTING else ""}{"_ambulatory" if ONLY_AMBULATORY else ""}'

exp ={
    'target':'pcr_result',
    'query':'pcr_result!=-1',   # -1 marks a missing PCR result
    'cols':FEATURES, 
    'label': 'PCR', 
}

hyper_params_grid = {
            'class_weight' : ['balanced', None], 
            'max_depth': [1, 2, 3, 4, 5, 6, 7, 8, 10, 20, None], 
            'min_impurity_decrease':[1e-3, 1e-4, 5e-4, 1e-5, 5e-5, 1e-6, 1e-7, 1e-8], 
            'min_samples_split': [5, 10, 50, 100, 500, 1000],
            'min_samples_leaf':[50, 100, 500, 1000],
            }

# Number of shuffles per feature for permutation importance
N_REPEATS=50

if ONLY_AMBULATORY:
    # The mask is taken from the backup `main_` itself: `main` may already be
    # a filtered subset when this cell is re-run, and its index would then no
    # longer match `main_`.
    main = main_.loc[~main_['hospitalized']].copy()
else:
    main = main_.copy()

from collections import defaultdict
# Symptom groups: each key becomes a boolean column that is True when any of
# the listed source columns is set.
SYMPTOM_DICT = {}
for s in ['tiredness', 'fever', 'cough', 'breathlessness', 'aches',
          'anorexia', 'anosmia', 'ageusia', 'headache', 
#           'upper_respiratory',
          'conjunctivitis']:
    SYMPTOM_DICT[s] = [s]
SYMPTOM_DICT['cutaneous'] = ['rash', 'frostbites']
SYMPTOM_DICT['digestive'] = ['diarrhea', 'vomiting', 'abdo_pain']
SYMPTOM_DICT['cardiopulmonary'] = ['breathlessness', 'chest_opression', 'chest_pain']
for k, v in SYMPTOM_DICT.items():
    main[k] = np.any(main[v], axis=1)
SYMPTOMS = list(SYMPTOM_DICT.keys())

# Explicit list, same content and order as the keys of SYMPTOM_DICT
SYMPTOMS = ['tiredness',
 'fever',
 'cough',
 'breathlessness',
 'aches',
 'anorexia',
 'anosmia',
 'ageusia',
 'headache',
 'conjunctivitis',
 'cutaneous',
 'digestive',
 'cardiopulmonary']

SEX = ['male', 'female', 'undetermined']
TOBACCO = ['smoker_current', 'no_smoker']
COMORBIDITIES = ['no_comorbidity', 'any_comorbidity', 'respiratory', 'cardio-vascular', 'diabetes', 'obesity']
HOSPITALIZED = ['hospitalized','non_hospitalized']
INCLUSION_REASONS = ['samu', 'urgence']
AGE = main['binned_age'].cat.categories.tolist()
X_weight = main[SEX + TOBACCO + COMORBIDITIES + AGE + SYMPTOMS + HOSPITALIZED]

# Boolean "was tested" indicator, used both as the regression target and as
# the row mask for the weights below
tested = main['test_done'].astype(bool)
y_weight = tested

from sklearn.linear_model import LogisticRegression

# Model of the probability of being tested given patient characteristics
lr = LogisticRegression(C=10000)

lr.fit(X_weight, y_weight)

# Inverse-probability weights: 1/p for tested rows, 1/(1-p) for untested rows.
# The masks come from `main` (not `merged`) so they stay aligned with the rows
# actually present, and the column starts as float so the weights fit its dtype.
main['sample_weight'] = 1.0
main['p_test'] = lr.predict_proba(X_weight)[:, 1]
main.loc[tested, 'sample_weight'] = 1 / main.loc[tested, 'p_test']
main.loc[~tested, 'sample_weight'] = 1 / (1 - main.loc[~tested, 'p_test'])
# Rescale so that the weights of the tested rows average to 1
main['sample_weight'] /= main.loc[tested, 'sample_weight'].sum() / tested.sum()

if not WEIGHTING:
    main['sample_weight'] = 1

X, y, sample_weight = prepare(main, exp['target'], exp['cols'], 
               query_filter=exp['query'])

seed = 42
np.random.seed(seed)
cv = StratifiedKFold(5, shuffle=True, random_state=seed)
splits = list(cv.split(X, y))

In [None]:
# Show the output-file suffix selected by the switches above
NAME

# Run experiment on one split

In [None]:
train, test = splits[0]
# NOTE(review): the parameters found by this grid search are overwritten by
# the hard-coded `best_params` just below; the search is kept for reference.
best_params, best_model = gridsearch(X[train], y[train], sample_weight[train], hyper_params_grid)
tree = DecisionTreeClassifier(**best_params)
tree.fit(X[train], y[train], sample_weight=sample_weight[train])

train, test = splits[0]
best_params = {'class_weight': 'balanced', 'max_depth': 10, 'min_impurity_decrease': 1e-07, 'min_samples_leaf': 50, 'min_samples_split': 5}
#best_params = {'class_weight': 'balanced', 'max_depth': 10, 'min_impurity_decrease': 1e-08, 'min_samples_leaf': 100, 'min_samples_split': 5}
tree = DecisionTreeClassifier(**best_params)
tree.fit(X[train], y[train], sample_weight[train])

# Feature importance on the test set 
sns.set_style('white')
fig = plt.figure()
# The test-set weights are bound into the scorer, so every permuted score is
# a weighted average precision on the same rows.
result = permutation_importance(tree, X[test], y[test], 
                                n_repeats=N_REPEATS, n_jobs=5, random_state=seed, 
                                scoring=make_scorer(average_precision_score, needs_proba=True, sample_weight=sample_weight[test]))

# `result.importances` has one row per feature and one column per repeat, so
# each feature name is repeated N_REPEATS times to match the flattened array.
df = pd.DataFrame({'feats':np.repeat(exp['cols'], N_REPEATS), 'fimp':result.importances.reshape(-1)})
df['fimp_mean'] = df.groupby('feats')['fimp'].transform('mean')

# Save as csv 
df.to_csv(f'output/charlotte_fimp{NAME}.csv')

# Plot (keyword x/y: positional data arguments are rejected by recent seaborn)
sns.barplot(x='fimp', y='feats', data=df.replace(FEATURE_NAMES).sort_values('fimp_mean', ascending=False), ci='sd')
sns.despine()
plt.xlabel('Decrease in average precision \n when randomly shuffling the feature')
plt.ylabel('')
plt.show()

fig.savefig(f'output/charlotte_fimp{NAME}.pdf', format='pdf', bbox_inches='tight')

# Plot performance 
fig = plt.figure()

pred = tree.predict_proba(X[test])[:,1]
precision, recall, _ = precision_recall_curve(y[test], pred, sample_weight=sample_weight[test])
AP = average_precision_score(y[test], pred, sample_weight=sample_weight[test])
sns.lineplot(x=recall, y=precision, label='Mean average precision: {:.2}'.format(AP))
sns.despine()
plt.xlabel('Recall (sensitivity)')
plt.ylabel('Precision \n (positive predictive value)')
# Chance level = weighted prevalence of positives in the test split
baseline = (sample_weight[test] * y[test]).sum()/sample_weight[test].sum()
plt.axhline(baseline, label='Chance level: {:.2}'.format(baseline),
            linestyle='--', linewidth=1, color='grey')                             
plt.legend()
plt.show()
fig.savefig(f'output/charlotte_pr{NAME}.pdf', format='pdf', bbox_inches='tight')

In [None]:
# Draw the fitted tree (first 4 levels) with readable feature labels;
# node statistics are those of the training data stored in the tree.
plt.figure(figsize=(15, 8))
feature_names = [FEATURE_NAMES[col] for col in exp['cols']]
t = plot_tree(
    tree,
    max_depth=4,
    feature_names=feature_names,
    label=None,
    filled=True,
    impurity=True,
    proportion=True,
    rounded=True,
    precision=2,
    fontsize=8,
)

# Display tree

In [None]:
# Plot 
from IPython.display import HTML
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.io as pio
import plotly.express as px
import matplotlib
import plotly
import copy
import statsmodels.api as sm
import re

def renumber(df_flux, nodes):
    """Renumber the kept nodes consecutively (0..n-1) and update the links.

    Parameters
    ----------
    df_flux : pandas.DataFrame
        Links with columns 'source', 'target' (node ids) and 'value'.
    nodes : dict
        Mapping ``old node id -> node dict`` (each with a 'label' entry),
        in the order the new ids should be assigned.

    Returns
    -------
    df_flux : pandas.DataFrame
        Copy of the links whose 'source' and 'target' use the new ids; other
        columns are left untouched.
    nodes : dict
        Mapping ``new id -> shallow copy of the node dict`` with 'id' set to
        the new id and any ``(#k)`` tag in 'label' rewritten accordingly.
    """
    hashtag = re.compile(r'\(#[0-9]*\)')
    mapping = {}
    mapped_nodes = {}
    for new_id, (old_id, node) in enumerate(nodes.items()):
        mapping[old_id] = new_id
        renamed = copy.copy(node)
        renamed['id'] = new_id
        renamed['label'] = hashtag.sub(f'(#{new_id:d})', node['label'])
        mapped_nodes[new_id] = renamed

    # Only the id columns are remapped: a frame-wide replace would also rewrite
    # any flow in 'value' that happens to equal a node id.
    df_flux = df_flux.copy()
    for col in ('source', 'target'):
        df_flux[col] = df_flux[col].map(lambda old: mapping.get(old, old))
    return df_flux, mapped_nodes

def build_flux(tree, cols, values=None, max_depth_plot=None, 
              use_hashtag=False, only_hashtag=False, add_stats=True):
    """Convert a fitted decision tree into Sankey nodes and links.

    Parameters
    ----------
    tree : fitted tree estimator
        Only its ``tree_`` attribute is read.
    cols : sequence of str
        Display name of each feature, in the column order used for fitting.
    values : array-like of shape (node_count, n_classes), optional
        Per-node class totals; defaults to ``tree.tree_.value`` squeezed.
    max_depth_plot : int, optional
        Deepest level to keep; nodes at that depth are treated as final.
        Defaults to the actual depth of the fitted tree.
    use_hashtag : bool
        Prefix node labels with ``(#id) ``.
    only_hashtag : bool
        Use ``(#id) `` as the whole label.
    add_stats : bool
        Currently unused; kept for backward compatibility.

    Returns
    -------
    df_flux : pandas.DataFrame
        Columns 'source', 'target', 'value' (sum of the child's `values`).
    nodes : dict
        Node dicts keyed by consecutive ids (see `renumber`).
    """
    t = tree.tree_
    if values is None:
        values = t.value.squeeze()
    if max_depth_plot is None:
        # Use the fitted depth: the `max_depth` hyper-parameter of the
        # estimator may be None (unlimited) and cannot be compared to depths.
        max_depth_plot = t.max_depth
    col_names = np.array(cols)

    # Compute depth and is_leave
    node_depth = np.zeros(shape=t.node_count, dtype=np.int64)
    is_leaves = np.zeros(shape=t.node_count, dtype=bool)
    stack = [(0, -1)]  # seed is the root node id and its parent depth
    while len(stack) > 0:
        node_id, parent_depth = stack.pop()
        node_depth[node_id] = parent_depth + 1

        # If we have a test node
        if (t.children_left[node_id] != t.children_right[node_id]):
            stack.append((t.children_left[node_id], parent_depth + 1))
            stack.append((t.children_right[node_id], parent_depth + 1))
        else:
            is_leaves[node_id] = True

    # Build flux and nodes 
    flux = {}
    # Hand-tuned horizontal positions used when exactly 4 levels are plotted
    x = {0:0, 1:.2, 2:.45, 3:.55, 4:1}
    depth_scale = max(t.max_depth, 1)  # avoid dividing by zero for a root-only tree
    nodes = {i:{} for i in range(t.node_count)}
    nodes[0]['y'] = .5
    for i in range(t.node_count):
        if node_depth[i] > max_depth_plot:
            nodes[i] = None
            continue

        final = is_leaves[i] or node_depth[i] == max_depth_plot
        nodes[i]['x'] = x[node_depth[i]] if max_depth_plot==4 else node_depth[i]/depth_scale
        nodes[i]['values'] = values[i]
        # Leaves carry a negative placeholder in `t.feature`; indexing the
        # column names with it would silently pick an unrelated feature.
        name = '' if is_leaves[i] else col_names[t.feature[i]]
        nodes[i]['name'] = name

        label = '(#{}) '.format(i) if use_hashtag else ''
        if not final:
            label += '{}'.format(name)                 
            flux[(i, t.children_left[i])] = values[t.children_left[i]].sum()
            flux[(i, t.children_right[i])] = values[t.children_right[i]].sum()
            # Children are spread around the parent, closer at each level
            step = 1/(2**node_depth[i])
            nodes[t.children_left[i]]['y'] = nodes[i]['y'] + step/3
            nodes[t.children_right[i]]['y'] = nodes[i]['y'] - step/3
        if only_hashtag : 
            label = '(#{}) '.format(i)
        nodes[i]['label'] = label
        nodes[i]['id'] = i
        nodes[i]['final'] = final

    nodes = {k:v for k,v in nodes.items() if v is not None}

    # Convert flux to dataframe (explicit records also handle an empty flux)
    df = pd.DataFrame(
        [(source, target, value) for (source, target), value in flux.items()],
        columns=['source', 'target', 'value'],
    )
    return renumber(df, nodes)


def plot_tree_flow_chart(df, nodes):
    """Build a plotly Sankey figure of the decision tree.

    Parameters
    ----------
    df : pandas.DataFrame
        Links with columns 'source', 'target' and 'value'.
    nodes : pandas.DataFrame
        One row per node with columns 'name', 'x', 'y', 'odds' and 'id_num'
        (as returned by `tune_nodes`).

    Returns
    -------
    plotly.graph_objects.Figure
        Sankey diagram whose node colour encodes the odds ratio on a
        diverging scale centred on 1 (range 0.1 to 6).
    """
    # Colormap registry on recent matplotlib, legacy accessor otherwise
    try:
        cmap = matplotlib.colormaps['coolwarm_r']
    except AttributeError:
        cmap = matplotlib.cm.get_cmap('coolwarm_r')
    # `DivergingNorm` was renamed `TwoSlopeNorm`; use whichever exists
    norm_cls = getattr(matplotlib.colors, 'TwoSlopeNorm', None)
    if norm_cls is None:
        norm_cls = matplotlib.colors.DivergingNorm
    norm = norm_cls(vmin=0.1, vcenter=1., vmax=6)

    def _css_rgba(odds):
        # Format the colour explicitly with 0-255 channels, as CSS-style
        # rgba() strings expect, instead of relying on the tuple repr.
        r, g, b, a = cmap(float(norm(odds)))
        return f"rgba({round(r * 255)},{round(g * 255)},{round(b * 255)},{a:g})"

    fig = go.Figure(
        data=[go.Sankey(
            node=dict(
                pad=10,
                thickness=20,
                line=dict(color="black", width=0.5),
                label = [v['name'] for k, v in nodes.iterrows()], 
                x = [v['x'] for k,v in nodes.iterrows()],
                y = [v['y'] for k,v in nodes.iterrows()],
                color = [_css_rgba(v['odds']) for k, v in nodes.iterrows()]
            ),
            link = dict(
                source = [nodes.iloc[k]['id_num'] for k in df.source],
                target = [nodes.iloc[k]['id_num'] for k in df.target], 
                value = df.value.values,
            ),
        )])
    return fig

def tune_nodes(nodes):
    """Turn the node dicts into a table with counts, probabilities and odds.

    For every node a logistic regression of the PCR outcome on "patient
    belongs to this node" is fitted on a synthetic population rebuilt from
    the (rounded) root counts; the exponentiated coefficient is the node's
    odds ratio. The reported interval is exp(coef -/+ one standard error).

    Parameters
    ----------
    nodes : dict
        Mapping ``id -> node dict`` with at least 'name', 'values'
        (``[PCR-, PCR+]`` totals), 'id', 'label' and 'final'. The first
        entry is taken as the root (whole population).

    Returns
    -------
    pandas.DataFrame
        One row per node with the display columns ('PCR+ patients',
        'PCR- patients', 'Number of patients', 'PCR+ probability',
        'Odds ratio', 'Odds ratio CI'), the raw 'odds', the numeric id in
        'id_num', and 'name' rewritten as the text to show on the plot.
    """
    df = pd.DataFrame(nodes).T
    df['Splitting criteria'] = df['name']
    df['PCR+ patients'] = [int(np.round(i[1], 0)) for i in df['values']]
    df['PCR- patients'] = [int(np.round(i[0], 0)) for i in df['values']]
    df['Number of patients'] = df[['PCR+ patients', 'PCR- patients']].sum(axis=1).round(0).astype(int)
    df['PCR+ probability'] = (df['PCR+ patients']/df['Number of patients']*100).round(0).astype(int).astype(str) + '%'
    
    # Synthetic outcome vector: the first `total_pcrp` patients are PCR+
    total_patient = df['Number of patients'].iloc[0]
    total_pcrp = df['PCR+ patients'].iloc[0]
    y = np.zeros(total_patient)
    y[:total_pcrp] = 1
    
    df['id_num'] = df['id']
    df['id'] =  '(#' + df['id'].astype(str) + ')'
    # Create the result columns with their final dtypes (float / text) so the
    # per-row assignments below never write a value of an incompatible dtype.
    df['Odds ratio'] = np.nan
    df['Odds ratio CI'] = ''
    df['odds'] = np.nan
    for i, row in df.iterrows():
        # Membership indicator: the node's PCR+ patients are placed among the
        # positives and its PCR- patients among the negatives.
        X = np.zeros(total_patient)
        X[:row['PCR+ patients']] = 1
        X[total_pcrp:total_pcrp + row['PCR- patients']] = 1
        X = sm.add_constant(X)
        res = sm.Logit(y, X).fit(disp=0)
        if len(res.params) == 1:
            # Only one parameter was fitted (the indicator is itself constant,
            # e.g. for the root): no effect can be estimated.
            odds_l, odds, odds_u = 1, 1, 1
        else:
            odds_l, odds, odds_u = np.exp(res.params[1] - res.bse[1]), np.exp(res.params[1]), np.exp(res.params[1] + res.bse[1])
        
        df.loc[i, "odds"] = odds
        df.loc[i, "Odds ratio"] = np.round(odds, 2)
        df.loc[i, "Odds ratio CI"] = f"[{odds_l:.2f}-{odds_u:.2f}]"
        if row['final']:
            df.loc[i, 'name'] = f'{row["label"]}{row["PCR+ probability"]} PCR+, odds {odds_l:.1f}-{odds_u:.1f}x'
        else:
            df.loc[i, 'name'] = row['label']
    return df 

import matplotlib
import seaborn as sns
# Stand-alone colour-bar legend ("Odds ratio") saved as a PDF next to the
# Sankey plots.
cmap = matplotlib.cm.get_cmap('coolwarm_r')
# norm = matplotlib.colors.TwoSlopeNorm(vmin=0.1, vcenter=1., vmax=6)

sns.set_style('whitegrid')
fig, ax = plt.subplots(figsize=(.3, 1.3))

# NOTE(review): no `norm` is passed here, while the Sankey colours are computed
# with a diverging norm (0.1 / 1 / 6) — confirm the legend matches the plot.
cb1 = matplotlib.colorbar.ColorbarBase(ax, spacing='proportional',drawedges=False, 
                                       cmap=cmap, orientation='vertical')
sns.despine(left=True)

ax.set_title('Odds ratio')
cb1.set_ticks([0.1, 1, 6])
#cb1.set_ticklabels([ 1, 4, 6])
fig.savefig(f'output/charlotte_cmap.pdf', format='pdf', bbox_inches='tight')

### Plot on test

In [None]:
# Set parameters
cols=[FEATURE_NAMES[i] for i in np.array(exp['cols'])]

# Weighted class totals per tree node on the test split.
# `decision_path` returns a sparse (n_samples, n_nodes) indicator matrix, so
# the totals of one class are a single sparse product with the weights of the
# samples of that class; no dense matrix or per-sample loop is needed.
y_test = y[test].astype(int)
w_test = sample_weight[test]
path = tree.decision_path(X[test])
values = np.zeros((tree.tree_.node_count, 2))
for target in (0, 1):
    values[:, target] = path.T.dot(w_test * (y_test == target))

# Columns exported to the CSV summary of the tree
columns = ['label', 'Number of patients', 'PCR+ patients', 'PCR- patients', 'PCR+ probability', 'Odds ratio', 'Odds ratio CI']

# Build flux
df_flux, nodes = build_flux(tree, cols, values=values.copy(), max_depth_plot=4, use_hashtag=False)

# Save nodes 
df_nodes = tune_nodes(nodes)
df_nodes[columns].to_csv(f'output/charlotte_tree{NAME}.csv')

# Plot 
fig = plot_tree_flow_chart(df_flux, df_nodes)

fig.update_layout(
    title_text="Decision tree to predict PCR",
    width=1000,
    height=450,
    showlegend=True,
)

s = fig.to_html('')

plotly.offline.plot(fig, filename = f'output/charlotte_tree{NAME}.html', auto_open=False)
HTML(s)

### Plot on test with hashtag

In [None]:
# Plot on test with hashtag (appendix): same Sankey diagram as above, with
# every node label prefixed by its "(#id)" tag.
# Reuses `values`, `cols` and `columns` defined in the previous cell.

# Build flux
df_flux, nodes = build_flux(tree, cols, values=values.copy(), max_depth_plot=4, 
                            use_hashtag=True, add_stats=False)

# Save nodes 
df_nodes = tune_nodes(nodes)
df_nodes[columns].to_csv(f'output/charlotte_tree_hashtag{NAME}.csv')

# Plot 
fig = plot_tree_flow_chart(df_flux, df_nodes)

fig.update_layout(
    title_text=f"Full decision tree to predict PCR on held-out patients",
    width=1000,
    height=450,
    showlegend=True,
)

# HTML string for inline display; the same figure is also written to disk
s = fig.to_html('')

plotly.offline.plot(fig, filename = f'output/charlotte_tree_hashtag{NAME}.html', auto_open=False)
HTML(s)

### Plot on whole data

In [None]:
# Plot on whole data with hashtag (appendix)

# Weighted class totals per tree node on the whole tested cohort (train and
# test rows together). `decision_path` returns a sparse (n_samples, n_nodes)
# indicator matrix, so the totals of one class are a single sparse product
# with the weights of the samples of that class.
y_all = y.astype(int)
path = tree.decision_path(X)
values = np.zeros((tree.tree_.node_count, 2))
for target in (0, 1):
    values[:, target] = path.T.dot(sample_weight * (y_all == target))

# Build flux
df_flux, nodes = build_flux(tree, cols, values=values, max_depth_plot=4, add_stats=False, 
                            use_hashtag=True)

# Save nodes 
df_nodes = tune_nodes(nodes)
df_nodes[columns].to_csv(f'output/charlotte_tree_whole_data{NAME}.csv')

# Plot 
fig = plot_tree_flow_chart(df_flux, df_nodes)

fig.update_layout(
    title_text="Decision tree to predict PCR on the entire tested cohort",
    width=1000,
    height=450,
    showlegend=True,
)

s = fig.to_html('')

plotly.offline.plot(fig, filename = f'output/charlotte_tree_whole_data{NAME}.html', auto_open=False)
HTML(s)

# Cross val

In [None]:
# 5-fold cross-validation: for every split, tune and fit a tree on the
# training fold, then compute its precision-recall curve, average precision
# and permutation importances on the held-out fold.
all_fimp = []
pr = []
mAP = []
PLOT=True
for i, (train, test) in enumerate(splits):
    best_params, best_model = gridsearch(X[train], y[train], sample_weight[train], hyper_params_grid)
    tree = DecisionTreeClassifier(**best_params)
    tree.fit(X[train], y[train], sample_weight=sample_weight[train])
    
    # Precision, recall 
    pred = tree.predict_proba(X[test])[:,1]
    precision, recall, _ = precision_recall_curve(y[test], pred, sample_weight=sample_weight[test])
    pr.append((precision, recall))
    AP = average_precision_score(y[test], pred, sample_weight=sample_weight[test]) 
    mAP.append(AP)

    # Permutation importance 
    result = permutation_importance(tree, X[test], y[test],
                                n_repeats=N_REPEATS, n_jobs=5, random_state=seed, 
                                scoring=make_scorer(average_precision_score, needs_proba=True, sample_weight=sample_weight[test]))
    fimp_df = pd.DataFrame({'fimp':result.importances.reshape(-1), 
                            'feats':np.repeat(exp['cols'], N_REPEATS), 
                           })
    fimp_df['cross_val_split'] = i
    all_fimp.append(fimp_df)
    
    # Plot tree 
    if PLOT:
        plt.figure(figsize=(15, 8))
        feature_names=[FEATURE_NAMES[i] for i in np.array(exp['cols'])]
        t = plot_tree(tree, feature_names=feature_names, filled=True, 
                      fontsize=8, impurity=True, label=None, 
                           proportion=True, rounded=True, precision=2, max_depth=4)

all_fimp = pd.concat(all_fimp, axis=0)

# Save as csv 
all_fimp.to_csv('output/charlotte_fimp_cross_val.csv')
# The curves have a different length in each split, so they are stored in an
# explicit object array (a ragged list cannot be converted implicitly).
pr_array = np.empty((len(pr), 2), dtype=object)
for k, (split_precision, split_recall) in enumerate(pr):
    pr_array[k, 0] = split_precision
    pr_array[k, 1] = split_recall
np.save('output/charlotte_pr_cross_val.npy', pr_array)

# Plot decision-tree performance on test sets 
sns.set_style()
fig = plt.figure()
for i, (precision, recall) in enumerate(pr):
    # Keyword x/y: positional data arguments are rejected by recent seaborn
    sns.lineplot(x=recall, y=precision, label=str(i))
sns.despine()
plt.xlabel('Recall (sensitivity)')
plt.ylabel('Precision \n (positive predictive value)')
# Chance level is the weighted prevalence in the last split's test fold
plt.axhline((sample_weight[test] * y[test]).sum() / sample_weight[test].sum(), label='Chance level', linestyle='--', linewidth=1, color='grey')
plt.title('Precision recall curve on test set \n for 5 cross validation splits')
plt.legend(title='Cross validation splits', loc='upper right')
plt.show()
fig.savefig(f'output/charlotte_pr_cross_val{NAME}.pdf', format='pdf', bbox_inches='tight')

# Plot variables 
all_fimp = pd.read_csv('output/charlotte_fimp_cross_val.csv')
fig, ax = plt.subplots()
all_fimp['feature_mean'] = all_fimp.groupby('feats')['fimp'].transform('mean')
sns.barplot(x='fimp', y='feats', data=all_fimp.sort_values('feature_mean', ascending=False), 
            ax=ax, hue='cross_val_split', ci='sd')
sns.despine()
labels = ax.get_yticklabels()
ax.set_yticklabels([FEATURE_NAMES[i.get_text()] for i in labels])
plt.ylabel('')
plt.xlabel('Decrease in average precision \n when randomly shuffling the feature')
plt.title('Feature permutation importance on test sets \n for 5 cross validation splits')
plt.legend(title='Cross validation splits')
plt.show()
fig.savefig(f'output/charlotte_fimp_cross_val{NAME}.pdf', format='pdf', bbox_inches='tight')