In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
sys.path.insert(0, os.path.join(os.path.abspath('.'),'..', 'src'))
import tree_utils, ctree

import numpy as np
import scipy as sc
from scipy import stats
import pandas as pd
import pickle
import json
import tqdm

from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import LabelEncoder, LabelBinarizer
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import RepeatedStratifiedKFold

from sklearn.metrics import accuracy_score
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import SimpleImputer, IterativeImputer, KNNImputer


In [2]:
# Coarse morphology groups: each key is a human-readable category name and each value
# lists the raw dot-separated morphology codes that fall into that category.
morphology_categories = {
    'only positive, no notch/acc': ['R'],
    'only negative, no notch/acc': ['S'],
    'both positive and negative, no notch/acc': ['Q.R', 'Q.R.S', 'R.S'],
    'only positive with notch/accent': [
        'R.R_acc', 'R.Rn', 'R.Rn.R_acc', 'Rn.R', 'Rn.R.R_acc', 'Rn.R.Rn',
    ],
    'only negative with notch/accent': ['S.Sn', 'Sn.S', 'Sn.S.Sn'],
    'both positive and negative with notch/accent': [
        'Q.R.R_acc', 'Q.R.R_acc.S', 'Q.R.Rn', 'Q.R.Rn.S', 'Q.R.S.R_acc',
        'Q.R.S.R_acc.S_acc', 'Q.R.S.Sn', 'Q.Rn.R', 'Q.Rn.R.S', 'R.R_acc.S',
        'R.R_acc.S.S_acc', 'R.Rn.S', 'R.S.R_acc', 'R.S.R_acc.S_acc', 'R.S.Rn',
        'R.S.Rn.Sn', 'R.S.S_acc', 'R.S.Sn', 'R.Sn.S', 'R.Sn.S.R_acc',
        'R.Sn.S.Sn', 'Rn.R.R_acc.S', 'Rn.R.Rn.S', 'Rn.R.S', 'Rn.R.S.R_acc',
        'Rn.R.S.R_acc.S_acc', 'Rn.R.S.Rn',
    ],
    'none': ['none'],
}

# Reverse lookup: raw morphology code -> category name. Should a code ever be listed
# under two categories, the later category wins (currently no code is listed twice).
inv_morpho_map = dict(
    (code, category)
    for category, codes in morphology_categories.items()
    for code in codes
)


In [3]:
# Both the input Parquet folder and the CustomTree output folder live under this share.
_PROJECT_ROOT = r'J:\Onderzoek\21-763_rvanes_MiniECG-2-Data'
data_dir = _PROJECT_ROOT + r'\E_ResearchData\2_ResearchData\Parquet'
output_dir = _PROJECT_ROOT + r'\G_Output\2_Data\CustomTree'

In [4]:
# Old -> new column-name table, stored one level above the Parquet data folder.
_name_map_path = os.path.join(data_dir, '..', 'Name_toSimpleName.parquet')
NameMap = pd.read_parquet(_name_map_path)

In [5]:
# Lookup used below to rename the raw DATA columns (last entry wins on duplicate old names).
_old_names = NameMap['Old_Name'].values
_new_names = NameMap['New_Name'].values
NameMapDict = dict(zip(_old_names, _new_names))

In [6]:
# Minimum share of rows a morphology category needs to get its own one-hot column.
# It is passed as OneHotEncoder(min_frequency=...), where a float is a FRACTION of the
# samples: 0.05 means 5 % of rows, not 0.05 %.
MIN_MORPHO_PRESENCE = 0.05
MULTI_CLASS = False            # False: diagnoses are collapsed to 'Disease' vs 'Normal'
num_splits = 10                # folds per repeat (RepeatedStratifiedKFold n_splits)
num_repeats = 10               # CV repeats (RepeatedStratifiedKFold n_repeats)
MISSINGNESS_INDICATOR = False  # passed to IterativeImputer(add_indicator=...)
MORPHO_MAP = True              # map raw morphology codes onto morphology_categories

# Filename suffixes encoding the settings above; appended to every output path.
MULTI_CLASS_STRING = "_MultiClass" if MULTI_CLASS else ""
MISSINGNESS_INDICATOR_STRING = "_Missing" if MISSINGNESS_INDICATOR else ""
MORPHO_MAP_STRING = "_MorphoMap" if MORPHO_MAP else ""

In [7]:
def _clean_morphology(raw):
    """Return the cleaned morphology code held in the first element of ``raw``.

    Commas and spaces are stripped from both ends in any order (stripping ``","``
    then ``" "`` separately would leave ``"R,"`` for an input like ``"R, "``).
    A value that is blank after cleaning becomes ``"none"``.
    """
    # NOTE(review): assumes each raw cell is a non-empty sequence of strings — confirm.
    code = raw[0].strip(", ")
    return code if code.strip() != "" else "none"


DATA = pd.read_parquet(os.path.join(data_dir, 'DATA.parquet'))
# Rename raw columns to their simplified names (KeyError if a column is not in the map).
DATA.columns = [NameMapDict[c] for c in DATA.columns]

morphology_columns = [c for c in DATA.columns if 'morphology' in c.lower()]
lead_columns = [
    c for c in DATA.columns
    if 'lead' in c.lower() and 'morphology' not in c.lower()
]

for c in morphology_columns:
    DATA[c] = DATA[c].apply(_clean_morphology)
    

In [8]:
if MORPHO_MAP:
    # Accept raw codes *and* already-mapped category names, so re-running this cell is
    # idempotent instead of turning every value into NaN on the second run.
    _morpho_lookup = {**inv_morpho_map, **{cat: cat for cat in morphology_categories}}
    for c in morphology_columns:
        _mapped = DATA[c].map(_morpho_lookup)
        _unmapped = DATA.loc[_mapped.isna(), c].unique()
        if len(_unmapped) > 0:
            # Fail here with a clear message instead of propagating NaN into the
            # vocabulary / one-hot steps further down.
            raise ValueError(
                f"Column {c!r} has morphology codes missing from morphology_categories: "
                f"{sorted(map(str, _unmapped))}"
            )
        DATA[c] = _mapped

In [9]:
# Every dot-separated token occurring in the morphology columns (missing values skipped).
vocab = {
    token
    for c in morphology_columns
    for codes in DATA[c].dropna().str.split(".")
    for token in codes
}
# Sort before numbering: iterating a set of strings follows hash order, which changes
# between interpreter sessions, so unsorted ids would not be reproducible.
Vocab = dict(enumerate(sorted(vocab)))

In [10]:
# One-hot encode the morphology columns; categories rarer than MIN_MORPHO_PRESENCE are
# grouped into a per-column "infrequent" bucket.
# NOTE(review): the encoder is fitted on all rows here, i.e. before the CV split and
# before rows with an unmapped Diagnosis are dropped further down, so the infrequent
# categories are not determined per training fold.
OneHot = OneHotEncoder(
    drop=None,
    sparse_output=False,
    min_frequency=MIN_MORPHO_PRESENCE,
    handle_unknown='infrequent_if_exist',
)
_morpho_encoded = OneHot.fit_transform(DATA[morphology_columns])
MorphologyOneHot = pd.DataFrame(
    _morpho_encoded,
    columns=OneHot.get_feature_names_out(morphology_columns),
    index=DATA.index,
)

In [11]:
# Swap the categorical morphology columns for their one-hot encoding (index-aligned).
DATA = pd.concat(
    [DATA.drop(columns=morphology_columns), MorphologyOneHot],
    axis=1,
)

In [12]:
# Collapse composite labels (e.g. 'LVH , RBBB') onto a single diagnosis. Labels not
# listed here become NaN and are removed by the dropna at the end of this cell.
_diagnosis_map = {
    'SR': 'SR',
    'BF': 'BF',
    'RBBB': 'RBBB',
    'LBBB': 'LBBB',
    'LAFB': 'LAFB',
    'LAFB , LVH': 'LAFB',
    'Microvoltages , BF': 'BF',
    'Microvoltages , RBBB': 'RBBB',
    'Microvoltages , LAFB': 'LAFB',
    'LVH , BF': 'BF',
    'LVH , RBBB': 'RBBB',
    'LVH , LBBB': 'LBBB',
}
DATA = DATA.assign(Diagnosis=DATA['Diagnosis'].map(_diagnosis_map))

# Binary setting: 'SR' becomes 'Normal'; BF, LBBB, RBBB and LAFB become 'Disease'.
Reduction_map = {'BF': 'Disease',
                 'LBBB': 'Disease',
                 'RBBB': 'Disease',
                 'LAFB': 'Disease',
                 'SR': 'Normal'}

if not MULTI_CLASS:
    DATA = DATA.assign(Diagnosis=DATA['Diagnosis'].map(Reduction_map))
DATA = DATA.dropna(subset=['Diagnosis'])


In [13]:
# Per morphology column: the categories OneHotEncoder lumped into its "infrequent" bucket.
# sklearn stores None (not an empty array) for a column without infrequent categories,
# so map that case to [] instead of crashing on list(None).
Infreq_cat_dict = {
    col: [] if inf_cats is None else list(inf_cats)
    for col, inf_cats in zip(morphology_columns, OneHot.infrequent_categories_)
}

In [14]:
# Persist the infrequent-category lookup. The context manager closes (and flushes) the
# file deterministically instead of leaving the handle to garbage collection.
with open(os.path.join(output_dir, 'infrequent_categories_map.json'), 'w') as fh:
    json.dump(Infreq_cat_dict, fh)

# Make tree

In [15]:
# Rules file consumed by ctree.LoadRules in the next cell.
rules_path = r'T:\laupodteam\AIOS\Bram\notebooks\code_dev\miniECG_interpretation\TreeBuilder\assets\conduction_tree.json'

# Hyper-parameters shared by the plain DecisionTreeClassifier and the custom tree.
TreeKwargs = dict(
    criterion='gini',
    splitter='best',
    max_depth=5,
    min_samples_split=10,
    min_samples_leaf=5,
    min_weight_fraction_leaf=0.05,
    max_features=None,
    random_state=7,
    max_leaf_nodes=50,
    class_weight='balanced',
)

In [16]:
# Parse the rules file. NameMapDict is handed over as name_map — presumably so rule
# feature names can be translated to the renamed columns (verify in ctree.LoadRules).
rules_loader = ctree.LoadRules(
    rules_path,
    name_map=NameMapDict,
)
processed_rules = rules_loader.get_processed_rules()

In [17]:
# Display `features_to_use_next` of the processed rules as the cell output
# (an empty list in the saved run).
processed_rules.features_to_use_next

[]

In [18]:
# Column roles declared by the rules file (attributes read in the same order as before).
SplitColumn, TargetCol, IgnoreCols, FeaturesToUse = (
    rules_loader.fold_split_col,
    rules_loader.target_col,
    rules_loader.ignore_cols,
    rules_loader.features_to_use,
)

In [19]:
_ignore_cols = set(IgnoreCols)
if len(FeaturesToUse) > 0:
    # Keep the rules file's order (duplicates dropped). A set difference here would make
    # X's column order depend on string-hash randomisation (PYTHONHASHSEED), which can
    # change the imputer and tree fits from one kernel session to the next.
    keep_columns = [c for c in dict.fromkeys(FeaturesToUse) if c not in _ignore_cols]
else:
    keep_columns = [c for c in DATA.columns if c not in _ignore_cols]

# num_repeats x num_splits stratified folds; seeded for reproducibility.
Splitter = RepeatedStratifiedKFold(n_splits=num_splits,
                                   n_repeats=num_repeats,
                                   random_state=7)

In [20]:
# Model inputs: every kept column except the fold-split and target columns.
_non_feature_cols = (SplitColumn, TargetCol)
X = DATA[[col for col in keep_columns if col not in _non_feature_cols]]
Y = DATA[TargetCol]

In [21]:
lb = LabelBinarizer()  # NOTE(review): not referenced in any of the visible cells
lbe = LabelEncoder().fit(Y)
# Encoded class index -> original label (LabelEncoder stores classes_ in sorted order).
TargetMap = dict(enumerate(lbe.classes_))

In [22]:
# Display the class labels in encoded-index order (saved run: ['Disease', 'Normal']).
TargetMap.values()

dict_values(['Disease', 'Normal'])

In [23]:
# Snapshot of the (not yet imputed) model inputs for this configuration.
_x_snapshot_path = os.path.join(
    output_dir,
    f'data{MULTI_CLASS_STRING}{MISSINGNESS_INDICATOR_STRING}{MORPHO_MAP_STRING}.parquet',
)
X.to_parquet(_x_snapshot_path)

In [None]:
# Cross-validated comparison of a plain sklearn tree and the rule-based custom tree.
# For every (repeat, fold): impute on the training fold, fit both trees, export both as
# JSON + HTML, and collect class probabilities for the train and test rows.
_run_suffix = f"{MULTI_CLASS_STRING}{MISSINGNESS_INDICATOR_STRING}{MORPHO_MAP_STRING}"


def _export_tree(tree, prefix, fold, repeat):
    """Write ``tree`` to ``<prefix>_Fold<fold>_<repeat><suffix>.json`` and ``.html``.

    Args:
        tree: JSON-serialisable tree from ``get_custom_rules_model()``.
        prefix: filename prefix, e.g. ``"tree"`` or ``"sklearn_tree"``.
        fold: fold number within the repeat.
        repeat: repeat number.
    """
    stem = os.path.join(output_dir, f"{prefix}_Fold{fold}_{repeat}{_run_suffix}")
    # Close the JSON file deterministically rather than relying on garbage collection.
    with open(f"{stem}.json", mode='w') as fh:
        json.dump(tree, fh)
    ctree.update_html(tree=tree,
                      html_path="../src/treeTemplate.html",
                      output_path=f"{stem}.html")


results_list = []
for i, (train_index, test_index) in tqdm.tqdm(enumerate(Splitter.split(X, Y)),
                                              total=num_splits * num_repeats):
    # The splitter yields all num_splits folds of one repeat before starting the next.
    Fold = i % num_splits
    Repeat = i // num_splits

    X_train, X_test = X.iloc[train_index], X.iloc[test_index]
    Y_train, Y_test = Y.iloc[train_index], Y.iloc[test_index]

    y_train_encoded = lbe.transform(Y_train)
    y_test_encoded = lbe.transform(Y_test)

    # Fit the imputer on the training fold only, then apply it to both folds.
    Imputer = IterativeImputer(LinearRegression(),
                               add_indicator=MISSINGNESS_INDICATOR, max_iter=25, verbose=0)
    Imputer.fit(X_train)
    imputed_columns = Imputer.get_feature_names_out()
    # The imputed frames get a fresh 0..n-1 index; everything below works positionally.
    X_train_imputed = pd.DataFrame(data=Imputer.transform(X_train), columns=imputed_columns)
    X_test_imputed = pd.DataFrame(data=Imputer.transform(X_test), columns=imputed_columns)

    clf = ctree.CustomDecisionTreeV2(custom_rules=processed_rules,
                                     prune_threshold=None,
                                     Tree_kwargs=TreeKwargs,
                                     TargetMap=TargetMap,
                                     tot_max_depth=5)
    clf_base = DecisionTreeClassifier(**TreeKwargs)

    print("Training classifiers...")
    clf_base.fit(X_train_imputed, y_train_encoded)
    clf.fit(X_train_imputed, y_train_encoded)
    enriched_rules = clf.get_enriched_rules()
    final_tree = clf.get_custom_rules_model()
    _export_tree(final_tree, "tree", Fold, Repeat)

    # Convert the fitted sklearn tree through ctree so it is exported in the same format.
    sklearn_tree = clf.load_from_sklearn_tree(clf_base, X_train_imputed, y_train_encoded)
    final_tree_sklearn = sklearn_tree.get_custom_rules_model()
    _export_tree(final_tree_sklearn, "sklearn_tree", Fold, Repeat)

    cust_probas_train = clf.predict_proba(X_train_imputed)
    cust_probas_test = clf.predict_proba(X_test_imputed)

    base_probas_train = clf_base.predict_proba(X_train_imputed)
    base_probas_test = clf_base.predict_proba(X_test_imputed)

    # 'indices' are row positions in X (as used with .iloc above), not DATA index labels.
    result_df = pd.DataFrame()
    result_df['indices'] = np.hstack([train_index, test_index])
    result_df['Fold'] = Fold
    result_df['Repeat'] = Repeat
    result_df['Y_true'] = np.hstack([Y_train.values, Y_test.values])
    result_df[[f'Y_pred_normal_DT_{cname}' for cname in TargetMap.values()]] = np.vstack([base_probas_train, base_probas_test])
    result_df[[f'Y_pred_custom_DT_{cname}' for cname in TargetMap.values()]] = np.vstack([cust_probas_train, cust_probas_test])
    result_df['Dataset'] = ['train'] * len(train_index) + ['test'] * len(test_index)

    results_list.append(result_df)

Final_results = pd.concat(results_list, axis=0)
_results_stem = os.path.join(output_dir, f"results{_run_suffix}")
Final_results.to_csv(f"{_results_stem}.csv", index=False, sep=";")
Final_results.to_parquet(f"{_results_stem}.parquet")


  0%|          | 0/100 [00:00<?, ?it/s]INFO:ctree:Starting fit method
INFO:ctree:Number of features: 82
INFO:ctree:Number of classes: 2
INFO:ctree:Features to consider: ['QRS vector Lead A1', 'P vector Lead A1', 'T vector Lead A1', 'QRS amplitude Lead A1', 'QRS vector Lead L1', 'P vector Lead L1', 'T vector Lead L1', 'QRS amplitude Lead L1', 'QRS vector Lead S1', 'P vector Lead S1', 'T vector Lead S1', 'QRS amplitude Lead S1', 'QRS vector Lead A2', 'P vector Lead A2', 'T vector Lead A2', 'QRS amplitude Lead A2', 'QRS vector Lead L2', 'P vector Lead L2', 'T vector Lead L2', 'QRS amplitude Lead L2', 'QRS vector Lead I1', 'P vector Lead I1', 'T vector Lead I1', 'QRS amplitude Lead I1', 'QRS vector Lead I2', 'P vector Lead I2', 'T vector Lead I2', 'QRS amplitude Lead I2', 'QRS vector Lead I3', 'P vector Lead I3', 'T vector Lead I3', 'QRS amplitude Lead I3', 'QQ intervals', 'Heart rate', 'PQ duration', 'QT duration', 'QTc duration', 'T duration', 'P duration', 'QRS duration', 'Morphology Le

Training classifiers...


INFO:ctree:Tree building completed
INFO:ctree:Tree enrichment completed
  1%|          | 1/100 [00:22<37:41, 22.84s/it]INFO:ctree:Starting fit method
INFO:ctree:Number of features: 82
INFO:ctree:Number of classes: 2
INFO:ctree:Features to consider: ['QRS vector Lead A1', 'P vector Lead A1', 'T vector Lead A1', 'QRS amplitude Lead A1', 'QRS vector Lead L1', 'P vector Lead L1', 'T vector Lead L1', 'QRS amplitude Lead L1', 'QRS vector Lead S1', 'P vector Lead S1', 'T vector Lead S1', 'QRS amplitude Lead S1', 'QRS vector Lead A2', 'P vector Lead A2', 'T vector Lead A2', 'QRS amplitude Lead A2', 'QRS vector Lead L2', 'P vector Lead L2', 'T vector Lead L2', 'QRS amplitude Lead L2', 'QRS vector Lead I1', 'P vector Lead I1', 'T vector Lead I1', 'QRS amplitude Lead I1', 'QRS vector Lead I2', 'P vector Lead I2', 'T vector Lead I2', 'QRS amplitude Lead I2', 'QRS vector Lead I3', 'P vector Lead I3', 'T vector Lead I3', 'QRS amplitude Lead I3', 'QQ intervals', 'Heart rate', 'PQ duration', 'QT durat

Training classifiers...


INFO:ctree:Tree building completed
INFO:ctree:Tree enrichment completed
  2%|▏         | 2/100 [00:48<39:49, 24.39s/it]