# Data Analysis and Predictive Modeling with SHAP and XGBoost

## Title  
**Integrating Machine Learning with Metabolic Models for Precision Trauma Care: Personalized ENDOTYPE Stratification and Metabolic Target Identification**

## Authors  
- **Igor Marin de Mas** (Copenhagen University Hospital, Rigshospitalet)  
- **Lincoln Moura** (Universidade Federal do Ceará)  
- **Fernando Luiz Marcelo Antunes** (Universidade Federal do Ceará)  
- **Josep Maria Guerrero** (Aalborg University)  
- **Pär Ingemar Johansson** (Copenhagen University Hospital, Rigshospitalet)  

## Description  
This notebook performs a data analysis based on patient features, leveraging machine learning techniques and interpretability tools for classification and pattern identification. The workflow includes:  

1. **Data Loading:** Importing preprocessed patient data and explainability variables.  
2. **Preprocessing:** Combining individual patient data into a single DataFrame and handling missing values.  
3. **Modeling:** Training a classification model using the XGBoost algorithm.  
4. **Interpretability:** Utilizing the SHAP library to understand the impact of each feature on the predictive outcomes.  
5. **Visualization:** Generating histograms of explained variance and visualizing the top 20 most important features for each group.  
6. **Evaluation:** Creating confusion matrices to assess model performance on training and test datasets.  


In [None]:
# Standard libraries
import os
import glob

# Third-party libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
from xgboost import XGBClassifier
import scikitplot as skplt
import warnings
import shap

# print the JS visualization code to the notebook
shap.initjs()

# Suppress warnings
warnings.filterwarnings("ignore")

In [None]:
# -----------------------------
# Utility Functions
# -----------------------------

def extract_patient_index(file_path):
    """
    Derive a zero-based patient index from the number in a file name.

    Only the base name of the path is inspected; the first run of digits
    found there is read as a one-based patient number and decremented by one.

    Args:
        file_path (str): Path to the file.

    Returns:
        int: The first number in the file name, minus one.

    Raises:
        ValueError: If the file name contains no digits.
    """
    import re

    file_name = os.path.basename(file_path)
    found = re.search(r'(\d+)', file_name)

    # Guard clause: without a digit run there is nothing to convert.
    if found is None:
        raise ValueError(f"Cannot extract index from file name: {file_path}")

    one_based = int(found.group(1))
    return one_based - 1

In [None]:
def load_preprocessed_data(preprocessed_path, num_patients):
    """
    Load preprocessed patient data from CSV files.

    Every path matching ``preprocessed_path + '*.csv*'`` is mapped to a
    patient slot via ``extract_patient_index`` and read with
    ``pd.read_csv(..., index_col=0)``. Files that cannot be mapped or read
    are reported on stdout and skipped.

    Args:
        preprocessed_path (str): Prefix prepended verbatim to the ``*.csv*``
            glob pattern, so a directory must include its trailing separator.
        num_patients (int): Number of patient slots to allocate.

    Returns:
        list: List of length ``num_patients``. Slots of loaded patients hold
            their DataFrame; all other slots keep the placeholder string
            ``"patient_<i>"``.
        list: Indices of the patients that were loaded successfully, in the
            order ``glob`` returned their files.
    """
    patients = ["patient_" + str(x) for x in range(num_patients)]
    test_indices = []

    for file_path in glob.glob(preprocessed_path + '*.csv*'):
        try:
            index = extract_patient_index(file_path)
            print(f"Index = {index}")
            # Reject indices outside the allocated slots explicitly: a
            # negative index would otherwise wrap around and silently
            # overwrite a slot at the end of the list.
            if not 0 <= index < num_patients:
                raise IndexError(
                    f"patient index {index} is outside 0..{num_patients - 1}"
                )
            frame = pd.read_csv(file_path, index_col=0)
        except ValueError as e:
            print(f"[ERROR] {e}")
            continue
        except Exception as e:
            print(f"[ERROR] Could not load file: {file_path}. Error: {e}")
            continue

        # Record the index only once the file has really been loaded, so
        # every entry of test_indices points at a DataFrame in patients.
        patients[index] = frame
        test_indices.append(index)

    return patients, test_indices

In [None]:
def combine_patient_data(patients, test_indices, target):
    """
    Stack the selected patients into one labelled DataFrame.

    Each selected patient frame is transposed, tagged with its label in a
    ``target`` column, and the pieces are concatenated row-wise with a fresh
    0..n-1 index. If the result contains missing values, a warning is printed
    and they are replaced by the mean of their column.

    Args:
        patients (list): Per-patient DataFrames, addressed by patient index.
        test_indices (list): Indices of the patients to include.
        target (list): Target labels, addressed by the same patient index.

    Returns:
        pd.DataFrame: Combined DataFrame with all patient data and target labels.
    """
    def _labelled(patient_idx):
        # Transposing yields a new frame, so adding the label column does
        # not touch the caller's patient data.
        rows = patients[patient_idx].T
        rows["target"] = target[patient_idx]
        return rows

    pieces = [_labelled(patient_idx) for patient_idx in test_indices]
    stacked = pd.concat(pieces, axis=0)
    df = stacked.reset_index(drop=True)

    # Impute only when at least one cell is actually missing.
    has_missing = df.isna().to_numpy().any()
    if has_missing:
        print("[WARNING] Missing values detected. Filling with column means.")
        df.fillna(df.mean(), inplace=True)

    return df

In [None]:
def plot_explainability_histogram(explainability):
    """
    Plot a histogram of explained variance from PCA data.

    The x axis is labelled as a percentage with two decimals; tick values are
    therefore treated as fractions (0.5 is shown as "50.00%").

    Args:
        explainability (pd.DataFrame): DataFrame containing explained variance.
    """
    from matplotlib.ticker import FuncFormatter

    explainability.hist(figsize=(10, 5))
    ax = plt.gca()

    plt.title("")
    plt.xlabel("Explained Variance Using PCA with 600 Components (%)", fontsize=12)
    plt.ylabel("Number of patients", fontsize=12)
    plt.tick_params(axis='both', which='major', labelsize=12)

    # Format ticks through a formatter instead of freezing label strings with
    # set_xticklabels(): fixed labels are only reliable after the tick
    # positions have been fixed too, otherwise they can end up on the wrong
    # ticks once the axis is (re)drawn.
    ax.xaxis.set_major_formatter(FuncFormatter(lambda value, _pos: f'{value:.2%}'))
    plt.show()

In [None]:
def conf_matrix(clf, X_train, X_test, y_train, y_test):
    """
    Plot confusion matrices for training and testing predictions.

    The training matrix is drawn on the left (orange colormap) and the
    testing matrix on the right (purple colormap) of a single figure.

    Args:
        clf: Trained classifier exposing ``predict``.
        X_train: Training data.
        X_test: Testing data.
        y_train: Training labels.
        y_test: Testing labels.
    """
    Y_train_pred = clf.predict(X_train)
    Y_test_pred = clf.predict(X_test)

    fig = plt.figure(figsize=(15, 6))

    # scikit-plot expects (y_true, y_pred); passing them the other way round
    # transposes the matrix relative to its "True label"/"Predicted label" axes.
    ax1 = fig.add_subplot(121)
    skplt.metrics.plot_confusion_matrix(y_train, Y_train_pred, normalize=False, title="Confusion Matrix", cmap="Oranges", ax=ax1)
    ax2 = fig.add_subplot(122)
    skplt.metrics.plot_confusion_matrix(y_test, Y_test_pred, normalize=False, title="Confusion Matrix", cmap="Purples", ax=ax2)
    plt.show()

In [None]:
# Read the semicolon-separated group table and keep the "Metabo-group"
# column as an array of labels (one entry per row of the table).
target_df = pd.read_csv("Patient_Trauma_Groups.csv", sep=";")
target = target_df["Metabo-group"].to_numpy()
target

In [None]:
# Read the explained-variance table (first column becomes the index),
# transpose it, and show its histogram.
explainability = pd.read_csv("explainability.csv", sep=",", index_col=0).transpose()
plot_explainability_histogram(explainability)

In [None]:
# Location of the per-patient CSV files and the number of patient slots
preprocessed_path = "preprocess_PCA/"
num_patients = 95

# Read every patient file that can be found and parsed
patients, test_indices = load_preprocessed_data(
    preprocessed_path=preprocessed_path,
    num_patients=num_patients,
)

# Stack the loaded patients into one labelled DataFrame and display it
df = combine_patient_data(patients=patients, test_indices=test_indices, target=target)
df

In [None]:
# Feature matrix: every column except the label
X = df.drop(columns='target')
# Labels shifted down by one (presumably from 1-based group numbers to the
# 0-based class ids the classifier expects; confirm against the group file)
y = df['target'] - 1

In [None]:
# Initialize variables
acc = []  # Accuracy on the held-out split, one entry per run
# Per-class tables of mean |SHAP| values: one row per feature, one column per run
sv_c0, sv_c1, sv_c2, sv_c3 = pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

# Features and shifted labels do not change between runs, so build them once
X, y = df.drop('target', axis=1), df['target'] - 1

# Run the model training and SHAP analysis 100 times, each on a fresh random split
for i in range(100):
    print("Iteration:", i)

    # Random 75/25 split (no fixed seed, so every run sees a different split)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)

    # Initialize and train the XGBClassifier
    model = XGBClassifier(objective='multi:softmax')
    model.fit(X_train, y_train)

    # Make predictions and compute accuracy (scikit-learn order: y_true, y_pred)
    predictions = model.predict(X_test)
    accuracy_model = accuracy_score(y_test, predictions)
    acc.append(accuracy_model)
    print("Accuracy:", accuracy_model)

    # Compute SHAP values to explain the model's predictions
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_train)

    # Mean absolute SHAP value per class and feature -> (n_classes, n_features).
    # Older SHAP releases return a list with one (n_samples, n_features) array
    # per class; newer releases return a single array laid out as
    # (n_samples, n_features, n_classes), which needs a different sample axis.
    if isinstance(shap_values, list):
        aggs = np.abs(shap_values).mean(axis=1)
    else:
        aggs = np.abs(shap_values).mean(axis=0).T
    sv_df = pd.DataFrame(aggs.T, index=X.columns)

    # Define column names dynamically for each class and iteration
    col_names = [f"class{c}_run_{i}" for c in range(4)]
    sv_df.columns = col_names

    # Prepend this run's column to each per-class table; ignore_index
    # renumbers the columns, so the tables end up with integer column labels
    sv_c0 = pd.concat([sv_df[[col_names[0]]], sv_c0], axis=1, ignore_index=True)
    sv_c1 = pd.concat([sv_df[[col_names[1]]], sv_c1], axis=1, ignore_index=True)
    sv_c2 = pd.concat([sv_df[[col_names[2]]], sv_c2], axis=1, ignore_index=True)
    sv_c3 = pd.concat([sv_df[[col_names[3]]], sv_c3], axis=1, ignore_index=True)

In [None]:
# Adjusting indices to match MATLAB-style indexing (starting from 1).
# NOTE: the tables are modified in place, so re-running this cell shifts the
# index by one again.
groups = [sv_c0, sv_c1, sv_c2, sv_c3]
for sv in groups:
    sv.index += 1  # Increment index by 1 (assumes a numeric feature index)
    # Average over every run column. Slicing with iloc[:, :-1] would drop the
    # last run whenever the "mean" column does not exist yet, so select the
    # run columns by name instead.
    run_columns = [col for col in sv.columns if col != "mean"]
    sv["mean"] = sv[run_columns].mean(axis=1)  # Mean SHAP value per feature

# Creating the figure and subplots (one row per group)
fig, axes = plt.subplots(figsize=(15, 12), nrows=4, ncols=1, sharex=False, constrained_layout=True)

# Titles for each group
group_titles = [
    f"Top 20 Most Important Features for Group {group_number} by Average"
    for group_number in range(1, 5)
]

# Generating bar plots for each group: the 20 features with the largest mean
for axis, sv, title in zip(axes, groups, group_titles):
    top_features = sv.sort_values(by='mean', ascending=False).head(20)
    ax = top_features.plot.bar(ax=axis, legend=False)
    ax.set_xlabel(title, fontsize=12)  # Set x-axis label
    ax.set_ylabel("Mean Value for 100 Models", fontsize=12)  # Set y-axis label
    ax.tick_params(axis='both', which='major', labelsize=12)  # Adjust tick label size

# Display the plots
plt.show()