# Inference Notebook 

### Step 1: Import necessary libraries

In [None]:
from src.models.tabularNN import TabularNN
import os
from sklearn.naive_bayes import GaussianNB
import numpy as np
from src.dataset.eeg_dataset import EEGDataset
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
import pandas as pd
from catboost import CatBoostClassifier
from tabicl import TabICLClassifier
import json
from src.utils import Utils
import timeit
import time
import matplotlib.pyplot as plt
import numpy as np
from src.inference.inference import EEGInference
import matplotlib.gridspec as gridspec
from sklearn.metrics import f1_score
import pandas as pd
from pathlib import Path
import joblib
from lime import lime_tabular
from lime.lime_tabular import LimeTabularExplainer
import pickle

### Step 2: Define the parameters exactly as they were defined during training, and set the paths accordingly

In [None]:
# These settings must match the training run: they are encoded in the names
# of the saved model, selected-feature and scaler files.
window_sizes = [1, 2, 5, 10, 30, 60]
step_sizes = [1, 2, 5, 10, 30, 60]
strategy = "FeatureBased"  # "rawEEG"
timing = window_sizes[5]
sampling_rate = 128
# Window and step are multiples of the sampling rate (presumably samples per
# window, with `timing` in seconds). The step uses the same timing as the window.
window_size = sampling_rate * timing
step_size = sampling_rate * timing
preprocessing = True
feature_selection = True
num_selected_features = 50
depth_of_anesthesia = True
random_seed = 42
# Majority-voting parameter: half a window.
for_majority = window_size // 2

base_path = Path.cwd()
volunteer_number = "5-3"  # "8-2"  # Choose a specific volunteer
case_number = "5_3"  # "8_2"
print(base_path)
# Raw EEG recordings of the selected case.
data_path = base_path / "EEG_data" / "Session3" / f"Case_{case_number}"
# Trained models, plus the selected-feature lists and scalers stored beside them.
file_path = base_path.parent.parent.parent / "Ablation" / "final"
selected_features_path = file_path / "selected_features"
scaler_path = file_path / "scaler"

### Step 4: Define a helper that builds all file paths

To run the neural-network model (`TabularNN`), change the model file extensions in `define_file_paths` from `.pkl` to `.pt`.

In [None]:
def define_file_paths(
    window_size,
    step_size,
    majority,
    data_path,
    file_path,
    model_names,
    volunteer_number,
    case_number,
    number_of_selected_feature=None,
    selected_features_path=None,
    random_seed=42,
    strategy=None,
    scaler_path=None,
    base_path=None,
    model_extension=None,
):
    """Define file paths for models, data, and other resources.

    Every filename encodes the training configuration: window size, step
    size, majority, random seed and, if used, the number of selected
    features. The arguments must therefore match the values used for
    training.

    Args:
        window_size (int): Size of the window for analysis.
        step_size (int): Step size for moving the window.
        majority (int): Majority voting parameter.
        data_path (Path): Path to the data directory of the case.
        file_path (Path): Path to the model directory.
        model_names (dict): Model type per task. Must contain the keys
            ``sleep``, ``cr``, ``sspl`` and ``burst_suppression``.
        volunteer_number (str): Identifier for the volunteer.
        case_number (str): Identifier for the case of that volunteer.
        number_of_selected_feature (int, optional): Number of selected
            features. If None, the models trained on all features are used,
            and no selected-feature or scaler files are returned.
        selected_features_path (Path, optional): Directory with the selected
            feature lists. Only used when ``number_of_selected_feature`` is set.
        random_seed (int, optional): Random seed used during training.
            Defaults to 42.
        strategy (str, optional): Prefix of the model filenames. Defaults to
            the notebook-level ``strategy`` variable.
        scaler_path (Path, optional): Directory with the fitted scalers.
            Defaults to the notebook-level ``scaler_path`` variable.
        base_path (Path, optional): Project root that contains ``EEG_data``.
            Defaults to the notebook-level ``base_path`` variable.
        model_extension (str, optional): Extension of the model files, e.g.
            ``".pkl"`` or ``".pt"``. Defaults to ``".pkl"`` with feature
            selection and ``".pt"`` without.

    Returns:
        tuple: ``(file_path, csv_files_groundtruth, model_filenames,
        model_names, selected_features_filename, scaler_filenames, csv_file,
        propofol_concentration_blood_plasma)``.
        ``selected_features_filename`` and ``scaler_filenames`` are None when
        ``number_of_selected_feature`` is None.
        ``propofol_concentration_blood_plasma`` is the DataFrame read from
        the volunteer's propofol infusion CSV.

    Raises:
        NameError: If ``strategy``, ``base_path`` or (with feature
            selection) ``scaler_path`` is omitted and the corresponding
            notebook-level variable is not defined.
    """

    def _from_notebook(name):
        # Fallback for arguments that existing calls do not pass: read the
        # variable defined in the notebook's parameter cell.
        try:
            return globals()[name]
        except KeyError:
            raise NameError(f"name {name!r} is not defined") from None

    use_selected_features = number_of_selected_feature is not None
    if strategy is None:
        strategy = _from_notebook("strategy")
    if base_path is None:
        base_path = _from_notebook("base_path")
    if use_selected_features and scaler_path is None:
        scaler_path = _from_notebook("scaler_path")

    tasks = ("sleep", "cr", "sspl", "burst_suppression")
    window_tag = f"ws{window_size}_ss{step_size}_majority{majority}"

    csv_file = data_path / f"prop_{case_number}eeg_Fp1Fp2.csv"
    propofol_concentration_blood_plasma = pd.read_csv(
        base_path
        / "EEG_data"
        / "propofol_infusion"
        / f"prop{volunteer_number}Ce_Fp1Fp2.csv"
    )
    csv_files_groundtruth = {
        annotation: data_path / f"prop{volunteer_number}{annotation}_Fp1Fp2.csv"
        for annotation in ("sleep", "bs1sec", "bs3sec", "cr", "sspl")
    }

    # Models trained on a feature subset end in the number of features;
    # models trained on all features end in "all".
    if use_selected_features:
        model_suffix = number_of_selected_feature
        extension = model_extension or ".pkl"
    else:
        model_suffix = "all"
        extension = model_extension or ".pt"
    if not extension.startswith("."):
        extension = f".{extension}"

    model_filenames = {
        task: file_path
        / (
            f"{strategy}_{task}_{window_tag}_type{model_names[task]}"
            f"_preprocTrue_randomseed{random_seed}_{model_suffix}{extension}"
        )
        for task in tasks
    }

    if not use_selected_features:
        return (
            file_path,
            csv_files_groundtruth,
            model_filenames,
            model_names,
            None,
            None,
            csv_file,
            propofol_concentration_blood_plasma,
        )

    feature_tag = (
        f"{window_tag}_preprocTrue_randomseed{random_seed}"
        f"_numfeatures{number_of_selected_feature}"
    )
    # The selected-feature list of the sleep task uses a capitalised prefix
    # ("Sleep_..."); all other tasks use the plain task name.
    feature_prefix = {"sleep": "Sleep"}
    selected_features_filename = {
        task: selected_features_path
        / f"{feature_prefix.get(task, task)}_FeatureBased_{feature_tag}.json"
        for task in tasks
    }
    scaler_filenames = {
        task: scaler_path / f"FeatureBased_{task}_{feature_tag}.joblib"
        for task in tasks
    }

    return (
        file_path,
        csv_files_groundtruth,
        model_filenames,
        model_names,
        selected_features_filename,
        scaler_filenames,
        csv_file,
        propofol_concentration_blood_plasma,
    )

### Step 5: Choose the model to run inference with

In [None]:
# Model type per classification task. Each name is inserted into the
# "type..." part of the saved model filenames.
model_names = {
    "sleep": "RandomForestClassifier",
    "cr": "RandomForestClassifier",
    "sspl": "RandomForestClassifier",
    "burst_suppression": "RandomForestClassifier",
}

# Further examples (kept as comments so the notebook does not echo them as
# cell output):
# model_names = {
#     "sleep": "TabICLClassifier",
#     "cr": "TabICLClassifier",
#     "sspl": "TabICLClassifier",
#     "burst_suppression": "TabICLClassifier",
# }
#
# model_names = {
#     "sleep": "TabularNN",
#     "cr": "TabularNN",
#     "sspl": "TabularNN",
#     "burst_suppression": "TabularNN",
# }

### Step 6: Run the inference

- If you don't have any ground-truth labels, you can still run the inference by setting `inference_mode` to `True`.

In [None]:
# Set to True when no ground-truth annotations are available.
inference_mode = False

start = time.perf_counter()
(
    file_path,
    csv_files_groundtruth,
    model_filenames,
    model_names,
    selected_features_filename,
    scaler_filenames,
    csv_file,
    propofol_concentration_blood_plasma,
) = define_file_paths(
    window_size=window_size,
    step_size=step_size,
    majority=for_majority,
    number_of_selected_feature=num_selected_features,
    data_path=data_path,
    file_path=file_path,
    model_names=model_names,
    selected_features_path=selected_features_path,
    volunteer_number=volunteer_number,
    case_number=case_number,
    random_seed=random_seed,
)

# All configuration comes from the parameter cell (Step 2), so the run
# matches the training setup encoded in the file names.
inference = EEGInference(
    csv_file=csv_file,
    sampling_rate=sampling_rate,
    window_size=window_size,
    step_size=step_size,
    for_majority=for_majority,
    preprocessing=preprocessing,
    majority_voting=True,
    strategy=strategy,
    random_seed=random_seed,
    feature_selection=feature_selection,
    scaler_applicable=True,
    device="cpu",
    inference_mode=inference_mode,
    depth_of_anesthesia=depth_of_anesthesia,
    visualize_results=True,
    csv_files_groundtruth=csv_files_groundtruth,
    model_filenames=model_filenames,
    model_names=model_names,
    selected_features_filename=selected_features_filename,
    scaler_filenames=scaler_filenames,
    propofol_concentration_blood_plasma=propofol_concentration_blood_plasma,
    zoomed_in=True,
)
end = time.perf_counter()
print(f"Final Execution time: {end - start:.4f} seconds")

### Step 7 (optional): Local error analysis with LIME

### Step 7.1: Define or adjust your file paths



In [None]:
# Which task, model type and window segment to explain.
label = "burst_suppression"
model_name = "TabICLClassifier"
sample = 156  # Index of the window segment to explain.
features = inference.features_processed
predictions = inference.prediction
utils = Utils(
    for_majority=for_majority,
    window_size=window_size,
    step_size=step_size,
    random_seed=random_seed,
    preprocessing=preprocessing,
    sampling_rate=sampling_rate,
    results_validation_csv_path=base_path
    / "doA_classification"
    / "ml_models"
    / f"{strategy}_validation_results_df.csv",
    results_test_csv_path=base_path
    / "doA_classification"
    / "ml_models"
    / f"{strategy}_test_results_df.csv",
    model_dir=base_path / "doA_classification" / "ml_models",
)

# Directory with the trained models; selected features and scalers live in
# sub-directories. All names are derived from the Step 2 parameters, so they
# stay in sync when the window size, seed or feature count change.
my_path = base_path.parent.parent.parent / "Ablation" / "final"
config_tag = (
    f"ws{window_size}_ss{step_size}_majority{for_majority}"
    f"_preprocTrue_randomseed{random_seed}"
)
model_path = (
    my_path
    / f"FeatureBased_{label}_ws{window_size}_ss{step_size}_majority{for_majority}"
    f"_type{model_name}_preprocTrue_randomseed{random_seed}_{num_selected_features}.pkl"
)
# NOTE: this overwrites the `selected_features_path` directory defined in
# Step 2 with a file path. Re-run Step 2 before re-running Step 6.
selected_features_path = (
    my_path
    / "selected_features"
    / f"{label}_FeatureBased_{config_tag}_numfeatures{num_selected_features}.json"
)
scaler_filename = (
    my_path
    / "scaler"
    / f"FeatureBased_{label}_{config_tag}_numfeatures{num_selected_features}.joblib"
)
training_data = pd.read_csv(
    base_path / "EEG_data" / "dataset" / f"training_data_{window_size}_{step_size}.csv"
)

### Step 7.2: Run LIME

In [None]:
if not Path(model_path).is_file():
    raise FileNotFoundError(f"Model file for {model_path} not found.")

# NOTE: pickle can execute arbitrary code on load; only use trusted model files.
with open(model_path, "rb") as f:
    model = pickle.load(f)

with open(selected_features_path, "r") as f:
    selected_feature_names = json.load(f)
scaler = joblib.load(scaler_filename)

# Training data is LIME's background distribution. It is restricted to the
# selected features and scaled exactly like the rows being explained.
X_train = np.nan_to_num(
    scaler.transform(training_data.loc[:, selected_feature_names]), nan=0
)
# NOTE(review): the rows to explain are taken from `predictions`
# (inference.prediction), while `features` (inference.features_processed) is
# unused. Confirm which frame holds the feature columns.
# NaNs are zeroed here too, so the explained rows are treated like X_train.
X_test = np.nan_to_num(
    scaler.transform(predictions.loc[:, selected_feature_names]), nan=0
)

np.random.seed(random_seed)


def predict_proba_lime(input_numpy, model):
    """Predict function for LIME that returns class probabilities.

    LIME expects one probability column per class. A single column (or flat
    vector) of positive-class probabilities ``p`` is expanded to the two
    columns ``[1 - p, p]``.

    Args:
        input_numpy (np.ndarray): A single sample (1-D) or a batch of
            samples (2-D).
        model: Trained model exposing ``predict_proba``.

    Returns:
        np.ndarray: Probabilities of shape ``(n_samples, n_classes)``.
    """
    samples = np.asarray(input_numpy)
    if samples.ndim == 1:
        samples = samples.reshape(1, -1)

    probs = np.asarray(model.predict_proba(samples))
    # Handle models that return a flat vector of positive-class probabilities.
    if probs.ndim == 1:
        probs = probs.reshape(-1, 1)
    if probs.shape[1] == 1:
        probs = np.hstack([1 - probs, probs])

    return probs


pretty_feature_names = [
    utils.feature_name_map.get(f, f) for f in selected_feature_names
]

# Human-readable name of the positive class for the selected label. Without
# this, class names would claim "Burst Suppression" for every label.
positive_class = {"burst_suppression": "Burst Suppression"}.get(label, label)

explainer = LimeTabularExplainer(
    training_data=X_train,
    mode="classification",
    feature_names=pretty_feature_names,
    class_names=[f"no {positive_class}", positive_class],
    discretize_continuous=True,
    # Dedicated RNG so the perturbation sampling is reproducible.
    random_state=random_seed,
)
features_new = pd.DataFrame(X_test)
exp = explainer.explain_instance(
    data_row=features_new.iloc[sample].values,
    predict_fn=lambda x: predict_proba_lime(x, model),
    num_features=7,  # Number of features to display in explanation
)

exp.show_in_notebook(show_table=True)
