# Inference Notebook 

### Step 1: Import necessary libraries

In [None]:
from src.models.tabularNN import TabularNN
import os
from sklearn.naive_bayes import GaussianNB
import numpy as np
from src.dataset.eeg_dataset import EEGDataset
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
import pandas as pd
from catboost import CatBoostClassifier
from tabicl import TabICLClassifier
import json
from src.utils import Utils
import timeit
import time
import matplotlib.pyplot as plt
import numpy as np
from src.inference.inference import EEGInference
import matplotlib.gridspec as gridspec
from sklearn.metrics import f1_score
import pandas as pd
from pathlib import Path
import joblib
from lime import lime_tabular
from lime.lime_tabular import LimeTabularExplainer
import pickle

### Step 2: Define the parameters exactly as they were defined during training, and set the paths accordingly

In [None]:
# --- Windowing -------------------------------------------------------------
# Candidate window/step lengths; `timing` picks the one used for this run and
# is multiplied by `sampling_rate` below (presumably seconds x Hz = samples).
window_sizes = [1, 2, 5, 10, 30, 60]
step_sizes = [1, 2, 5, 10, 30, 60]
strategy = "FeatureBased"  # "rawEEG"
timing = window_sizes[5]
sampling_rate = 128
window_size = sampling_rate * timing
step_size = sampling_rate * timing
# Half a window; forwarded as the majority-voting parameter further down.
for_majority = window_size // 2

# --- Pipeline switches (must match the training configuration) --------------
preprocessing = True
feature_selection = True
num_selected_features = 50
depth_of_anesthesia = True
random_seed = 42

# --- Paths -------------------------------------------------------------------
base_path = Path.cwd()
volunteer_number = "5-3"  # Choose a specific volunteer
case_number = "5_3"
print(base_path)
# `data_path` is defined exactly once: the recording folder of the chosen case.
data_path = base_path / "EEG_data" / "Session3" / f"Case_{case_number}"
# Trained models live three directory levels above the working directory.
file_path = base_path.parent.parent.parent / "Ablation" / "final"
selected_features_path = file_path / "selected_features"
scaler_path = file_path / "scaler"

### Step 4: Function for defining file paths

To run the neural network (`TabularNN`), change the model file extension in `define_file_paths` from `.pkl` to `.pt`.

In [None]:
def define_file_paths(
    window_size,
    step_size,
    majority,
    data_path,
    file_path,
    model_names,
    volunteer_number,
    case_number,
    number_of_selected_feature=None,
    selected_features_path=None,
    random_seed=42,
):
    """Build every path the inference run needs and load the propofol trace.

    Besides its arguments, this function reads three notebook-level globals:
    ``base_path`` (root of the propofol infusion CSV), ``strategy`` (prefix of
    the model filenames) and ``scaler_path`` (directory of the fitted scalers).

    Args:
        window_size (int): Size of the window for analysis.
        step_size (int): Step size for moving the window.
        majority (int): Majority voting parameter.
        data_path (Path): Path to the data directory.
        file_path (Path): Path to the model directory.
        model_names (dict): Model class name per task; must contain the keys
            "sleep", "cr", "sspl" and "burst_suppression".
        volunteer_number (str): Identifier for the volunteer.
        case_number (str): Identifier for the case of that volunteer.
        number_of_selected_feature (int, optional): Number of selected features.
            When None, the "all features" model files (``.pt``) are used and no
            selected-feature or scaler paths are produced. Defaults to None.
        selected_features_path (Path, optional): Directory holding the
            selected-feature JSON files. Required when
            ``number_of_selected_feature`` is given. Defaults to None.
        random_seed (int, optional): Seed embedded in every generated filename.
            Defaults to 42.

    Returns:
        tuple: ``(file_path, csv_files_groundtruth, model_filenames,
        model_names, selected_features_filename, scaler_filenames, csv_file,
        propofol_concentration_blood_plasma)``. The selected-feature and scaler
        entries are None when ``number_of_selected_feature`` is None.

    Raises:
        ValueError: If ``number_of_selected_feature`` is given without
            ``selected_features_path``.
    """
    tasks = ("sleep", "cr", "sspl", "burst_suppression")
    # Filename fragments shared by models, selected features and scalers.
    window_tag = f"ws{window_size}_ss{step_size}_majority{majority}"
    # The seed now comes from `random_seed` everywhere; previously only the
    # feature-selected model filenames honoured it and the rest assumed 42.
    seed_tag = f"preprocTrue_randomseed{random_seed}"

    csv_file = data_path / f"prop_{case_number}eeg_Fp1Fp2.csv"
    propofol_concentration_blood_plasma = pd.read_csv(
        base_path
        / "EEG_data"
        / "propofol_infusion"
        / f"prop{volunteer_number}Ce_Fp1Fp2.csv"
    )
    csv_files_groundtruth = {
        signal: data_path / f"prop{volunteer_number}{signal}_Fp1Fp2.csv"
        for signal in ("sleep", "bs1sec", "bs3sec", "cr", "sspl")
    }

    if number_of_selected_feature is None:
        # Models trained on all features: no feature list, no scaler.
        model_filenames = {
            task: file_path
            / f"{strategy}_{task}_{window_tag}_type{model_names[task]}_{seed_tag}_all.pt"
            for task in tasks
        }
        return (
            file_path,
            csv_files_groundtruth,
            model_filenames,
            model_names,
            None,
            None,
            csv_file,
            propofol_concentration_blood_plasma,
        )

    if selected_features_path is None:
        raise ValueError(
            "selected_features_path is required when "
            "number_of_selected_feature is given"
        )

    feature_tag = f"numfeatures{number_of_selected_feature}"
    model_filenames = {
        task: file_path
        / f"{strategy}_{task}_{window_tag}_type{model_names[task]}_{seed_tag}_{number_of_selected_feature}.pkl"
        for task in tasks
    }
    # The sleep feature list is stored with a capitalised prefix on disk.
    feature_file_prefix = {task: task for task in tasks}
    feature_file_prefix["sleep"] = "Sleep"
    selected_features_filename = {
        task: selected_features_path
        / f"{feature_file_prefix[task]}_FeatureBased_{window_tag}_{seed_tag}_{feature_tag}.json"
        for task in tasks
    }
    scaler_filenames = {
        task: scaler_path
        / f"FeatureBased_{task}_{window_tag}_{seed_tag}_{feature_tag}.joblib"
        for task in tasks
    }

    return (
        file_path,
        csv_files_groundtruth,
        model_filenames,
        model_names,
        selected_features_filename,
        scaler_filenames,
        csv_file,
        propofol_concentration_blood_plasma,
    )

### Step 5: Define which model you would like to run the inference for

In [None]:
# Model class used for each prediction task. The value is embedded in the
# model filename as "type<value>", so it must match the name used in training.
model_names = {
    "sleep": "RandomForestClassifier",
    "cr": "RandomForestClassifier",
    "sspl": "RandomForestClassifier",
    "burst_suppression": "RandomForestClassifier",
}

# Further examples (kept as comments so the cell does not echo a bare string
# literal as its output):
#
# model_names = {
#     "sleep": "TabICLClassifier",
#     "cr": "TabICLClassifier",
#     "sspl": "TabICLClassifier",
#     "burst_suppression": "TabICLClassifier",
# }
#
# model_names = {
#     "sleep": "TabularNN",
#     "cr": "TabularNN",
#     "sspl": "TabularNN",
#     "burst_suppression": "TabularNN",
# }

### Step 6: Run the Inference:

- If you have no ground-truth labels, you can still run the inference by setting `inference_mode=True`.

In [None]:
# Time the whole run: path resolution plus everything EEGInference does on
# construction.
start = time.perf_counter()

# Resolve all input/model paths for the configuration chosen in Step 2.
(
    file_path,
    csv_files_groundtruth,
    model_filenames,
    model_names,
    selected_features_filename,
    scaler_filenames,
    csv_file,
    propofol_concentration_blood_plasma,
) = define_file_paths(
    window_size=window_size,
    step_size=step_size,
    majority=for_majority,
    number_of_selected_feature=num_selected_features,
    data_path=data_path,
    file_path=file_path,
    model_names=model_names,
    selected_features_path=selected_features_path,
    volunteer_number=volunteer_number,
    case_number=case_number,
    random_seed=random_seed,
)

inference = EEGInference(
    csv_file=csv_file,
    sampling_rate=sampling_rate,
    window_size=window_size,
    step_size=step_size,
    for_majority=for_majority,
    preprocessing=preprocessing,
    majority_voting=True,
    strategy=strategy,
    random_seed=random_seed,
    feature_selection=feature_selection,
    scaler_applicable=False,
    device="cpu",
    # Set to True when no ground-truth files are available.
    inference_mode=False,
    # Use the Step 2 setting instead of a second, hard-coded True.
    depth_of_anesthesia=depth_of_anesthesia,
    visualize_results=True,
    csv_files_groundtruth=csv_files_groundtruth,
    model_filenames=model_filenames,
    model_names=model_names,
    selected_features_filename=selected_features_filename,
    scaler_filenames=scaler_filenames,
    propofol_concentration_blood_plasma=propofol_concentration_blood_plasma,
    zoomed_in=True,
)

end = time.perf_counter()
print(f"Final Execution time: {end - start:.4f} seconds")

### Step 7 (optional): Run a local error analysis with LIME for the TabICLClassifier

### Step 7.1 Define/adjust your file paths



In [None]:
import os, json, pickle, joblib
import numpy as np
import pandas as pd
from lime.lime_tabular import LimeTabularExplainer
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import font_manager as fm
import textwrap
import matplotlib as mpl
import textwrap

# Helper object; only its `feature_name_map` is used by the LIME cells below.
utils = Utils(
    for_majority=for_majority,
    window_size=window_size,
    step_size=step_size,
    random_seed=random_seed,
    preprocessing=preprocessing,
    sampling_rate=sampling_rate,
    results_validation_csv_path=base_path
    / "doA_classification"
    / "ml_models"
    / f"{strategy}_validation_results_df.csv",
    results_test_csv_path=base_path
    / "doA_classification"
    / "ml_models"
    / f"{strategy}_test_results_df.csv",
    model_dir=base_path / "doA_classification" / "ml_models",
)
label = "burst_suppression"
model_name = "TabICLClassifier"
features = inference.features_processed
predictions = inference.prediction

# `base_path` from Step 2 is reused here; it is no longer reassigned.
my_path = base_path.parent.parent.parent / "Ablation" / "final"
# Both filenames are derived from the Step 2 parameters so they stay in sync
# with the configuration the inference was run with.
model_path = my_path / (
    f"FeatureBased_{label}_ws{window_size}_ss{step_size}_majority{for_majority}"
    f"_type{model_name}_preprocTrue_randomseed{random_seed}"
    f"_{num_selected_features}.pkl"
)
# Stored under its own name so the `selected_features_path` directory from
# Step 2 (an argument of `define_file_paths`) is not overwritten by a file path.
selected_features_file = (
    my_path
    / "selected_features"
    / (
        f"{label}_FeatureBased_ws{window_size}_ss{step_size}_majority{for_majority}"
        f"_preprocTrue_randomseed{random_seed}"
        f"_numfeatures{num_selected_features}.json"
    )
)
training_data = pd.read_csv(
    base_path / "EEG_data" / "dataset" / f"training_data_{window_size}_{step_size}.csv"
)

# --- 0) Load artifacts safely ---
if not model_path.exists():
    raise FileNotFoundError(f"Model file not found: {model_path}")

if not selected_features_file.exists():
    raise FileNotFoundError(
        f"Selected-features file not found: {selected_features_file}"
    )

# NOTE: unpickling executes arbitrary code; only load model files you trust.
with open(model_path, "rb") as f:
    model = pickle.load(f)

with open(selected_features_file, "r", encoding="utf-8") as f:
    selected_feature_names = json.load(f)

### Step 7.3 Run LIME

In [None]:
# --- 4) Predict-proba wrapper (returns probs IN model.classes_ order) ---
def predict_proba_lime(X, model=model):
    """Return ``model.predict_proba`` output as a 2-D probability array.

    Args:
        X: One sample (1-D) or a batch of samples (2-D); converted with
            ``np.asarray``.
        model: Object exposing ``predict_proba``. Defaults to the notebook's
            loaded ``model`` (bound when this function is defined).

    Returns:
        Array of shape (n_samples, n_classes). A single-column result is
        expanded to two columns ``[1 - p, p]``.

    Raises:
        ValueError: If ``predict_proba`` does not return a 2-D array.
    """
    batch = np.asarray(X)
    # A lone sample is promoted to a one-row batch.
    batch = batch.reshape(1, -1) if batch.ndim == 1 else batch

    # Column order is assumed to follow model.classes_.
    probabilities = model.predict_proba(batch)
    if probabilities.ndim != 2:
        raise ValueError(
            f"predict_proba returned shape {probabilities.shape}, expected (n, n_classes)"
        )

    if probabilities.shape[1] != 1:
        return probabilities
    # Binary edge case: only one probability column was returned, so the
    # complementary column is put in front to get [p0, p1].
    return np.hstack([1 - probabilities, probabilities])


def preprocessing_lime(training_data, selected_feature_names, utils, predictions):
    """Prepare the LIME explainer and the matrix of rows to explain.

    Reads the notebook-level global ``model`` to order the class names by
    ``model.classes_``.

    Args:
        training_data (pd.DataFrame): Background data for the explainer; must
            contain every column in ``selected_feature_names``.
        selected_feature_names (list[str]): Feature columns, in the order the
            model expects them.
        utils: Object exposing ``feature_name_map`` (raw name -> display name).
        predictions (pd.DataFrame): Rows to explain; must contain every column
            in ``selected_feature_names``.

    Returns:
        tuple: ``(explainer, X_test, pretty_feature_names,
        class_names_in_model_order)``.

    Raises:
        KeyError: If a selected feature is missing from either DataFrame.
        RuntimeError: If class names cannot be derived from ``model.classes_``.
    """
    # Validate both frames up front so a missing column is reported by name
    # instead of surfacing as a bare pandas KeyError for `predictions`.
    for frame_name, frame in (
        ("training_data", training_data),
        ("predictions", predictions),
    ):
        missing_cols = [c for c in selected_feature_names if c not in frame.columns]
        if missing_cols:
            raise KeyError(f"Missing features in {frame_name}: {missing_cols}")

    # --- 1) Feature matrices in the model's column order ---
    # NaNs are replaced by 0.0. No scaler is applied in this function.
    # NOTE(review): confirm the pickled model expects unscaled features (the
    # inference cell also runs with scaler_applicable=False).
    X_train = training_data.loc[:, selected_feature_names].fillna(0.0).values
    X_test = predictions.loc[:, selected_feature_names].fillna(0.0).values

    # --- 2) Pretty feature names (must align with column order above) ---
    pretty_feature_names = [
        utils.feature_name_map.get(f, f) for f in selected_feature_names
    ]

    # --- 3) Class names MUST follow model.classes_ order ---
    # Mapping from numeric class to display label; adjust for other targets.
    names_by_class = {0: "no Burst Suppression", 1: "Burst Suppression"}
    try:
        class_names_in_model_order = [names_by_class[c] for c in model.classes_]
    except Exception as e:
        # Chain the original error so the root cause stays in the traceback.
        raise RuntimeError(
            f"Could not build class names from model.classes_: {e}"
        ) from e

    # --- 4) Build the LIME explainer on the same representation as X_test ---
    explainer = LimeTabularExplainer(
        training_data=X_train,
        mode="classification",
        feature_names=pretty_feature_names,  # aligned with selected_feature_names order
        class_names=class_names_in_model_order,
        discretize_continuous=True,
        random_state=42,
    )
    return (
        explainer,
        X_test,
        pretty_feature_names,
        class_names_in_model_order,
    )


def plot_lime_topk(
    exp,
    label_idx,
    pos_class_name,
    neg_class_name,
    top_k=7,
    pred_proba=None,
    wrap_width=30,
):
    """Draw the strongest LIME contributions for one label as horizontal bars.

    Args:
        exp: LIME explanation exposing ``available_labels()``,
            ``predict_proba`` and ``as_list(label=...)``.
        label_idx (int): Label whose contributions are plotted.
        pos_class_name (str): Artist label for bars with weight >= 0.
        neg_class_name (str): Artist label for bars with weight < 0.
        top_k (int): Number of contributions kept, ranked by absolute weight.
        pred_proba (float, optional): If None, it is read from
            ``exp.predict_proba[label_idx]``; the value is not used further
            inside this function.
        wrap_width (int): Maximum characters per line of a feature label; a
            falsy value disables wrapping.

    Returns:
        tuple: ``(fig, ax)``. No legend is drawn here, so the class names only
        become visible if the caller adds one.

    Raises:
        ValueError: If ``label_idx`` is not among ``exp.available_labels()``.

    Note:
        Updates ``mpl.rcParams`` globally (fonts and sizes).
    """
    # --- sanity guards ---
    valid_labels = set(exp.available_labels())
    if label_idx not in valid_labels:
        raise ValueError(
            f"label_idx {label_idx} not in exp.available_labels()={sorted(valid_labels)}"
        )
    if pred_proba is None:
        pred_proba = float(exp.predict_proba[label_idx])
    # Warn when the plotted label is not the most probable one.
    most_likely_idx = int(np.argmax(exp.predict_proba))
    if most_likely_idx != label_idx:
        print(
            f"[warn] label_idx={label_idx} != argmax(exp.predict_proba)={most_likely_idx}"
        )

    # ---- (feature description, weight) pairs for this label ----
    contributions = exp.as_list(label=label_idx)
    names = np.array([pair[0] for pair in contributions])
    weights = np.array([pair[1] for pair in contributions], dtype=float)

    # Rank by |weight|, keep the top_k, then reverse so the strongest entry
    # gets the highest y position (top of the chart).
    keep = np.argsort(np.abs(weights))[::-1][:top_k][::-1]
    names = names[keep]
    weights = weights[keep]

    if wrap_width:
        names = np.array(
            ["\n".join(textwrap.wrap(text, width=wrap_width)) for text in names]
        )

    # ---- typography ----
    serif_fallbacks = [
        "Times New Roman",
        "Times",
        "Nimbus Roman No9 L",
        "DejaVu Serif",
    ]
    style = {
        "font.family": "serif",
        "font.serif": serif_fallbacks,
        "font.size": 17,
        "axes.labelsize": 20,
        "xtick.labelsize": 20,
        "ytick.labelsize": 20,
        "legend.fontsize": 20,
        "axes.linewidth": 0.8,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "text.usetex": False,
    }
    mpl.rcParams.update(style)

    # ---- figure ----
    width_in, height_in = 10, 5
    fig, ax = plt.subplots(figsize=(width_in, height_in))
    positions = np.arange(len(names))
    supports = weights >= 0  # supports explained class

    # Same outline for both groups; negative bars are darker and hatched.
    bar_groups = (
        (supports, pos_class_name, {"color": "0.7"}),
        (~supports, neg_class_name, {"color": "0.5", "hatch": "//"}),
    )
    for mask, group_label, look in bar_groups:
        ax.barh(
            positions[mask],
            weights[mask],
            edgecolor="grey",
            linewidth=0.3,
            label=group_label,
            **look,
        )

    ax.axvline(0, linewidth=0.7)
    ax.set_yticks(positions)
    ax.set_yticklabels(names)
    ax.set_xlabel("Contribution (LIME weight)")
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

    plt.tight_layout(rect=[0, 0, 1, 0.98])

    return fig, ax

In [None]:

# Build the LIME explainer and the matrix of rows that can be explained.
(
    explainer,
    X_test,
    pretty_feature_names,
    class_names_in_model_order,
) = preprocessing_lime(
    training_data=training_data,
    selected_feature_names=selected_feature_names,
    utils=utils,
    predictions=predictions,
)

# Same rows as X_test, with the display names as column labels.
features_new = pd.DataFrame(data=X_test, columns=pretty_feature_names)

In [None]:
sample = 156  # Decide which window you would like to look at

# Class probabilities for the chosen window; the most probable class is the
# label that gets explained.
probs_row = model.predict_proba(np.atleast_2d(X_test[sample]))[0]
label_idx = int(np.argmax(probs_row))

exp = explainer.explain_instance(
    data_row=features_new.iloc[sample].values,
    predict_fn=predict_proba_lime,
    num_features=5,
    labels=[label_idx],
)

# (feature description, weight) pairs for the explained label.
pairs = exp.as_list(label=label_idx)

print("model.classes_:", model.classes_)
print("Explaining label idx:", label_idx, "->", class_names_in_model_order[label_idx])
for feature_rule, weight in pairs:
    print(f"{feature_rule:>40s}: {weight:+.3f}")

exp.show_in_notebook(show_table=True)


In [None]:
# Positive LIME weights push towards the explained label (`label_idx`), so the
# "positive" group must be named after that class rather than always
# "Burst Suppression"; the remaining class name(s) label the opposing group.
explained_class_name = class_names_in_model_order[label_idx]
opposing_class_name = " / ".join(
    name
    for idx, name in enumerate(class_names_in_model_order)
    if idx != label_idx
)

fig, ax = plot_lime_topk(
    exp,
    label_idx=label_idx,
    top_k=5,  # how many features to show
    pos_class_name=explained_class_name,
    neg_class_name=opposing_class_name,
)

plt.show()