<span style="color:red">**WARNING: Results are logged to Weights & Biases (wandb), so you will be prompted for a wandb API key.**</span>

## Imports and configuration

In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
np.int = int
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.multioutput import MultiOutputClassifier
import yaml
import wandb 
import random
import warnings
# import BaseEstimator
from sklearn.base import BaseEstimator
from sklearn.metrics import roc_auc_score
import pandas as pd
warnings.filterwarnings("ignore")

# Python scripts
from PTBXLModel import PTBXLModel
    
# Column names used to identify records and to split the data.
ID = "ecg_id"
GROUP = "patient_id"
FOLD = "strat_fold"

RANDOM_STATE = 2024
N_JOBS = 8                 # parallel jobs for MultiOutputClassifier
HYPERPARAM_ITERATIONS = 2  # number of sweep trials run by wandb.agent

# Allowed values for the two experiment switches below.
FEATURE_METHODS = ("12sl", "unig", "ecgdeli", "3dfmmecg")
PROBLEMS = ("all", "diag", "superdiag", "subdiag", "form", "rhythm")

FEATURE_METHOD = "12sl"  # one of FEATURE_METHODS
PROBLEM = "rhythm"       # one of PROBLEMS

# Fail fast on a typo instead of a confusing missing-file error when the
# feature / label CSVs (whose paths are built from these values) are read.
if FEATURE_METHOD not in FEATURE_METHODS:
    raise ValueError(f"FEATURE_METHOD must be one of {FEATURE_METHODS}, got {FEATURE_METHOD!r}")
if PROBLEM not in PROBLEMS:
    raise ValueError(f"PROBLEM must be one of {PROBLEMS}, got {PROBLEM!r}")

## Dataset loading

In [2]:
# The 3dfmmecg feature file is read here as the reference list of ECGs and as
# the source of each record's stratified fold:
# folds 1-8 -> train, 9 -> validation, 10 -> test.
reference = pd.read_csv("../data/3dfmmecg_features.csv").sort_values(by=ID)
train_idx = reference.loc[reference[FOLD] <= 8, ID].values
val_idx = reference.loc[reference[FOLD] == 9, ID].values
test_idx = reference.loc[reference[FOLD] == 10, ID].values
extra_ids = reference[ID]
print(reference.shape)

DATASET = "../data/" + FEATURE_METHOD + "_features.csv"
features = pd.read_csv(DATASET)
# Reindexing below requires one row per ECG.
assert features[ID].is_unique, f"duplicate {ID} values in {DATASET}"

# ECGs present in the reference but missing from this feature file are added
# as all-NaN rows (imputed later by the pipeline's SimpleImputer). Reindexing
# on the sorted union of ids:
#   * keeps X sorted by ecg_id, so its row order lines up with the labels
#     (appending the missing rows at the end would break that order), and
#   * fills with np.nan, keeping numeric dtypes (filling with pd.NA would turn
#     the feature columns into object dtype).
all_ids = pd.Index(np.union1d(features[ID], extra_ids), name=ID)
X = (
    features.set_index(ID)
    .reindex(all_ids)
    .rename_axis(ID)
    .reset_index()  # puts the ecg_id column first
)

print(X.shape)
X.head()

(21799, 318)
(21799, 783)


Unnamed: 0,ecg_id,P_Area_I,P_PeakTime_I,Q_Area_I,Q_PeakTime_I,R_Area_I,R_PeakTime_I,S_Area_I,S_PeakTime_I,QRS_Balance_I,...,T+_Dur_aVF,T-_Dur_aVF,T+_Amp_aVF,T-_Amp_aVF,T_Morph_aVF,T_DurFull_aVF,P_Dur_Global,P_Found_Global,HR__Global,P_Term_V1
0,1,0.264,64.0,0.0,0.0,0.737,30.0,0.0,0.0,629.0,...,208.0,0.0,0.151,0.0,1,208.0,112.0,1,64.0,0.0
1,2,0.256,48.0,0.0,0.0,0.702,38.0,0.12,56.0,420.0,...,166.0,0.0,0.249,0.0,1,166.0,108.0,1,47.0,4.002
2,3,0.223,44.0,0.013,8.0,0.913,40.0,0.0,0.0,835.0,...,240.0,0.0,0.063,0.0,1,240.0,92.0,1,64.0,0.0
3,4,0.329,66.0,0.0,0.0,0.576,36.0,0.508,62.0,25.0,...,161.0,0.0,0.4,0.0,1,161.0,114.0,1,75.0,2.88
4,5,0.13,50.0,0.0,0.0,0.503,34.0,0.17,54.0,318.0,...,194.0,0.0,0.38,0.0,1,194.0,114.0,1,66.0,0.0


### Predictors

In [3]:
def select_features(frame: pd.DataFrame, ids) -> pd.DataFrame:
    """Return the predictor matrix for the ECGs whose id is in ``ids``.

    Rows are sorted by ecg_id, so they line up with labels selected the same
    way regardless of ``frame``'s row order. The index is reset to 0..n-1.
    The id / patient / fold columns are dropped when present, so they are
    never used as predictors.
    """
    return (
        frame.loc[frame[ID].isin(ids)]
        .sort_values(by=ID, kind="stable")
        .reset_index(drop=True)
        .drop(columns=[ID, GROUP, FOLD], errors="ignore")
    )


X_train = select_features(X, train_idx)
X_val = select_features(X, val_idx)
X_test = select_features(X, test_idx)
for split in (X_train, X_val, X_test):
    print(split.shape)

X_test.head()

(17418, 782)
(2183, 782)
(2198, 782)


Unnamed: 0,P_Area_I,P_PeakTime_I,Q_Area_I,Q_PeakTime_I,R_Area_I,R_PeakTime_I,S_Area_I,S_PeakTime_I,QRS_Balance_I,T_Area_I,...,T+_Dur_aVF,T-_Dur_aVF,T+_Amp_aVF,T-_Amp_aVF,T_Morph_aVF,T_DurFull_aVF,P_Dur_Global,P_Found_Global,HR__Global,P_Term_V1
0,0.227,44.0,0.0,0.0,0.336,34.0,0.105,54.0,254.0,1.346,...,176.0,0.0,0.19,0.0,1,176.0,124.0,1,61.0,0.0
1,0.147,92.0,0.0,0.0,0.255,36.0,0.08,62.0,146.0,0.525,...,180.0,0.0,0.214,0.0,1,180.0,122.0,1,73.0,0.0
2,0.243,70.0,0.009,8.0,1.817,48.0,0.0,0.0,1186.0,0.939,...,152.0,0.0,0.136,0.0,1,152.0,116.0,1,64.0,1.104
3,0.331,64.0,0.054,12.0,0.988,38.0,0.0,0.0,718.0,2.002,...,206.0,0.0,0.19,0.0,1,206.0,110.0,1,65.0,0.0
4,0.128,54.0,0.0,0.0,0.36,34.0,0.369,66.0,98.0,0.918,...,168.0,0.0,0.156,0.0,1,168.0,112.0,1,71.0,0.0


### Labels

In [4]:
# Load labels: the first column is ecg_id, every other column is one label.
labels_raw = pd.read_csv(f"../data/y_{PROBLEM}.csv")

# lh: the raw label values with missing entries replaced by 0
# (lh_test is later forwarded to PTBXLModel.save).
lh = labels_raw.fillna(0)

# y: binary targets -- 1 wherever the CSV holds any value (even 0), 0 where it
# is missing. NOTE(review): the original comment only said "Because of
# versioning"; presumably the CSV stores per-label values whose presence, not
# magnitude, marks the label -- confirm.
y = labels_raw.copy()
y.iloc[:, 1:] = labels_raw.iloc[:, 1:].notna().astype(int)


def split_labels(frame: pd.DataFrame, ids) -> pd.DataFrame:
    """Return the rows of ``frame`` for the ECGs whose id is in ``ids``.

    Rows are sorted by ecg_id so they line up with the predictor matrices.
    The id column is dropped and the index is reset to 0..n-1.
    """
    return (
        frame.loc[frame[ID].isin(ids)]
        .sort_values(by=ID, kind="stable")
        .drop(columns=ID)
        .reset_index(drop=True)
    )


y_train = split_labels(y, train_idx)
lh_train = split_labels(lh, train_idx)

y_val = split_labels(y, val_idx)
lh_val = split_labels(lh, val_idx)

y_test = split_labels(y, test_idx)
lh_test = split_labels(lh, test_idx)

### Models

In [5]:
def replace_none_strings(d):
    """Recursively turn the string ``"None"`` into Python ``None``.

    Lists and dicts are rebuilt as new containers with every nested value
    converted. Any other value equal to ``"None"`` becomes ``None``; all
    remaining values are returned unchanged.
    """
    if isinstance(d, list):
        return list(map(replace_none_strings, d))
    if isinstance(d, dict):
        converted = {}
        for key, value in d.items():
            converted[key] = replace_none_strings(value)
        return converted
    return None if d == "None" else d

# One random forest per label column (MultiOutputClassifier), preceded by a
# SimpleImputer for the NaN feature values. NOTE: the imputer step is named
# "encoder"; only the "model" step name is used in this notebook (via the
# model__estimator__* parameters).
forest = RandomForestClassifier(verbose=0, random_state=RANDOM_STATE)
pipe = Pipeline(
    steps=[
        ("encoder", SimpleImputer()),
        ("model", MultiOutputClassifier(forest, n_jobs=N_JOBS)),
    ]
)

# Sweep search space for wandb. Values written as the string "None" in the
# YAML file are converted to Python None.
with open("../models/sweep.yaml", "r") as sweep_file:
    sweep_config = replace_none_strings(yaml.safe_load(sweep_file))

### Training and Weights & Biases logging

In [None]:
def hyperparameter_tuning(
        estimator: BaseEstimator,
        X_train: pd.DataFrame,
        y_train: pd.DataFrame,
        X_val: pd.DataFrame,
        y_val: pd.DataFrame,
        X_test: pd.DataFrame,
        y_test: pd.DataFrame,
        lh_test: pd.DataFrame,
        name_prefix: str,
    ) -> None:
    """Run a single sweep trial and log it to Weights & Biases.

    Intended to be called by ``wandb.agent`` once per trial. Steps:

    1. Read the sampled hyperparameters from the run config.
    2. Fit a fresh copy of ``estimator``, wrapped in ``PTBXLModel``, on the
       training split.
    3. Log the validation ROC AUC as ``macro_auc``.
    4. Call ``PTBXLModel.save`` with the test split.

    Parameters
    ----------
    estimator : BaseEstimator
        Template pipeline exposing ``model__estimator__*`` parameters. It is
        cloned, so the caller's object is neither modified nor fitted.
    X_train, y_train : pd.DataFrame
        Training features and labels.
    X_val, y_val : pd.DataFrame
        Validation features and labels used for the logged metric.
    X_test, y_test, lh_test : pd.DataFrame
        Test data forwarded unchanged to ``PTBXLModel.save``.
    name_prefix : str
        Prefix forwarded to ``PTBXLModel.save``.
    """
    from sklearn.base import clone

    # The context manager finishes the run even if fitting or saving raises,
    # so a failed trial does not leave a dangling run behind.
    with wandb.init() as run:
        config = run.config

        # Clone so every trial starts from an unfitted template and the shared
        # pipeline passed by the caller is not mutated/fitted across trials.
        trial_estimator = clone(estimator).set_params(
            model__estimator__n_estimators=config.n_estimators,
            model__estimator__max_depth=config.max_depth,
            model__estimator__min_samples_split=config.min_samples_split,
            model__estimator__min_samples_leaf=config.min_samples_leaf,
            model__estimator__max_features=config.max_features,
        )

        ptbxl_model = PTBXLModel(trial_estimator)
        ptbxl_model.fit(X_train, y_train)

        # roc_auc_score uses average="macro" by default (unweighted mean of
        # per-label AUCs), which is what "macro_auc" reports.
        y_pred = ptbxl_model.predict_proba(X_val)
        macro_auc = roc_auc_score(y_val, y_pred)
        del y_pred
        run.log({"macro_auc": macro_auc})

        ptbxl_model.save(name_prefix, X_test, y_test, lh_test)
        del ptbxl_model
    

In [None]:
# Seed Python's and NumPy's global RNGs before launching the sweep.
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)

# The sweep (and the saved models) are named after the task and feature set.
name_prefix = f"{PROBLEM}_{FEATURE_METHOD}"
sweep_config["name"] = name_prefix
sweep_id = wandb.sweep(sweep_config, project="PTB-XL")


def _run_sweep_trial():
    """Run one sweep trial on the notebook-level data splits (called by wandb.agent)."""
    hyperparameter_tuning(
        estimator=pipe,
        X_train=X_train,
        y_train=y_train,
        X_val=X_val,
        y_val=y_val,
        X_test=X_test,
        y_test=y_test,
        lh_test=lh_test,
        name_prefix=name_prefix,
    )


wandb.agent(sweep_id, function=_run_sweep_trial, count=HYPERPARAM_ITERATIONS)