# Explainability

In this notebook we examine the feature importance of each trained model using SHAP summary plots.

In [None]:
import os
import joblib
import numpy as np
import pandas as pd

from CogniPredictAD.visualization import ModelExplainer

# Configure the pandas display limits used throughout the notebook.
display_limits = {
    'display.max_rows': 116,
    'display.max_columns': 40,
    'display.max_info_columns': 40,
}
for option_name, limit in display_limits.items():
    pd.set_option(option_name, limit)

# Read the training split from disk.
train = pd.read_csv("../data/train.csv")

# Features: every column except the `DX` target column.
X_train = train.drop(columns=['DX'])

# Target: the `DX` column.
y_train = train['DX']

## Dataset with `CDRSB`, `LDELTOTAL`, and `mPACCdigit` with Classification

In [None]:
# Load every pickled model (*.pkl) found in the results folder.
# NOTE: joblib.load unpickles arbitrary objects, so this folder must only
# contain files from a trusted source.
models_dir = "../results/all_models/1"
models = []

# os.listdir returns entries in arbitrary order; sort them so the list of
# models handed to ModelExplainer has the same order on every run.
for fname in sorted(os.listdir(models_dir)):
    if not fname.endswith(".pkl"):
        continue
    model_path = os.path.join(models_dir, fname)
    try:
        model = joblib.load(model_path)
    except Exception as e:
        # Best effort: report the unreadable file and keep loading the rest.
        print(f"Could not load {fname}: {e}")
        continue
    # The file name without its extension is used as the model's label.
    model_name = os.path.splitext(fname)[0]
    models.append((model_name, model))
    print(f"Loaded model: {model_name}")

# Initialize ModelExplainer with the loaded (name, model) pairs and the
# full training data.
explainer1 = ModelExplainer(
    models=models,
    X_train=X_train,
    y_train=y_train,
    feature_names=list(X_train.columns),
    class_names=np.unique(y_train).tolist()
)

# Generate SHAP summary plots
explainer1.shap_summary_plots()

## Dataset without `CDRSB`, `LDELTOTAL`, and `mPACCdigit` with Classification

In [None]:
# Build the reduced feature matrix from `train` instead of dropping columns
# from X_train in place. An in-place drop mutates the very DataFrame object
# that was already passed to the first explainer, and it raises KeyError if
# this cell is executed a second time. Rebinding X_train to a fresh frame
# avoids both problems while leaving X_train holding the reduced features.
X_train = train.drop(columns=['DX', 'CDRSB', 'LDELTOTAL', 'mPACCdigit'])

# Load every pickled model (*.pkl) found in the results folder.
# NOTE: joblib.load unpickles arbitrary objects, so this folder must only
# contain files from a trusted source.
models_dir = "../results/all_models/2"
models = []

# os.listdir returns entries in arbitrary order; sort them so the list of
# models handed to ModelExplainer has the same order on every run.
for fname in sorted(os.listdir(models_dir)):
    if not fname.endswith(".pkl"):
        continue
    model_path = os.path.join(models_dir, fname)
    try:
        model = joblib.load(model_path)
    except Exception as e:
        # Best effort: report the unreadable file and keep loading the rest.
        print(f"Could not load {fname}: {e}")
        continue
    # The file name without its extension is used as the model's label.
    model_name = os.path.splitext(fname)[0]
    models.append((model_name, model))
    print(f"Loaded model: {model_name}")

# Initialize ModelExplainer with the loaded (name, model) pairs and the
# reduced training data.
explainer2 = ModelExplainer(
    models=models,
    X_train=X_train,
    y_train=y_train,
    feature_names=list(X_train.columns),
    class_names=np.unique(y_train).tolist()
)

# Generate SHAP summary plots
explainer2.shap_summary_plots()