# Random Forest Baseline Evaluation

In [1]:
import numpy as np
import pandas as pd
import statistics
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay, precision_score, recall_score, f1_score, roc_auc_score, roc_curve, auc, RocCurveDisplay, classification_report
from sklearn.preprocessing import LabelBinarizer
from scipy.stats import mode
import matplotlib.pyplot as plt
from itertools import cycle

pd.options.mode.copy_on_write = True

### Load tabular data

In [2]:
synthetic_data = pd.read_csv('../../data/syn_data.csv', index_col=0)
synthetic_data.head()

Unnamed: 0,hospital_stay_length,gcs,nb_acte,gender,entry,outcome,entry_code,ica,ttt,ica_therapy,...,ivh,age,nimodipine,paracetamol,nad,corotrop,morphine,dve,atl,iot
0,41.089445,17.086233,34.307297,0,0,0.0,0,0,0,0,...,0,38.712762,-1,-1,-1,-1,-1,-1,-1,25
1,21.702298,18.805639,133.523169,0,1,1.0,2,2,1,0,...,0,58.565461,89,58,26,-1,-1,116,-1,-1
2,4.627752,19.516216,85.648533,0,2,0.0,1,2,2,0,...,1,76.432889,12,-1,-1,-1,-1,40,-1,-1
3,12.830087,19.940518,17.982208,1,1,2.0,3,4,2,0,...,1,87.351874,29,-1,-1,-1,-1,-1,-1,53
4,75.675201,21.665547,132.859962,0,3,1.0,4,5,2,0,...,0,75.440254,26,-1,-1,-1,-1,79,-1,52


### Encode events

In [3]:
events = [
    "nimodipine",
    "paracetamol",
    "nad",
    "corotrop",
    "morphine",
    "dve",
    "atl",
    "iot",
]

# for each row, get pairs of ordered events
def get_sequence_pairs(row):
    pairs = []
    
    # get a subset of the columns that are events
    sub_df = row[events].copy()

    # sort the events in the row
    sub_df = sub_df.sort_values(ascending=True)
    # print(sub_df)

    # iterate over the sorted events and get pairs
    for i in range(len(sub_df) - 1):
        e1 = sub_df.index[i]
        e2 = sub_df.index[i+1]
        if sub_df[e1] > 0 and sub_df[e2] > 0: 
            # print(sub_df[e1])
            pairs.append((e1, e2))
            row[e1+"_before_"+e2] = 1

    return row

synthetic_data = synthetic_data.apply(get_sequence_pairs, axis=1)

# drop the events columns
synthetic_data = synthetic_data.drop(columns=events)
# replace all NaN with 0
synthetic_data = synthetic_data.replace(
    np.nan, 0
)

In [5]:
X = synthetic_data.drop(['outcome'], axis=1).to_numpy()
y = synthetic_data['outcome'].to_numpy()

### Train and evaluation

In [6]:
def k_fold(X, y, folds):
    skf = StratifiedKFold(folds, shuffle=True, random_state=77)
    train_indices, val_indices, test_indices,  = [], [], []
    for (non_test_idx, test_idx) in skf.split(X, y):
        test_indices.append(test_idx)
        train_idx, val_idx, _, _ = train_test_split(non_test_idx, y[non_test_idx], test_size=1/9, random_state=77)
        train_indices.append(train_idx)
        val_indices.append(val_idx)
    return train_indices, val_indices, test_indices


accs, aucs = [], []
for fold, (train_idx, val_idx, test_idx) in enumerate(zip(*k_fold(X, y, 10))):

    X_train, X_val, X_test = X[train_idx], X[val_idx], X[test_idx]
    y_train, y_val, y_test = y[train_idx], y[val_idx], y[test_idx]

    model = RandomForestClassifier()
    
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)

    accuracy = accuracy_score(y_test, y_pred)
    auc_score = roc_auc_score(y_test, y_prob, average='weighted', multi_class='ovr')

    accs.append(accuracy)
    aucs.append(auc_score)
    
    print(f"FOLD {fold}")
    print(f"Accuracy: {accuracy} | AUC: {auc_score}")
    

mean_acc = statistics.mean(accs) 
mean_auc = statistics.mean(aucs)
print(f"AVG. accuracy: {mean_acc} | AVG. AUC: {mean_auc}")

FOLD 0
Accuracy: 0.564 | AUC: 0.7095702968652217
FOLD 1
Accuracy: 0.568 | AUC: 0.7190451342094784
FOLD 2
Accuracy: 0.574 | AUC: 0.7258405265972577
FOLD 3
Accuracy: 0.546 | AUC: 0.7012839660920522
FOLD 4
Accuracy: 0.566 | AUC: 0.7154155565799967
FOLD 5
Accuracy: 0.582 | AUC: 0.7072345180629217
FOLD 6
Accuracy: 0.553 | AUC: 0.6940204711051953
FOLD 7
Accuracy: 0.564 | AUC: 0.6973973152585963
FOLD 8
Accuracy: 0.554 | AUC: 0.6977965539177422
FOLD 9
Accuracy: 0.546 | AUC: 0.6950150900402395
AVG. accuracy: 0.5617 | AVG. AUC: 0.7062619428728701
