# Global and Local Explanation for CNN model on MIT shifted data with multiple target variable

Input files:  
Mode: `model_mit_multiple_shift_cnn.h5` generated in notebook `modeling_mit_multiple_shift_CNN.ipynb`     

Dataset: Binary classification datate genereted in notebook `preprocessing_mit_shift_minmax_oversampling.ipynb` 


mitbih_multipleclass_train_shift_minmax_oversampling.csv \
mitbih_multipleclass_test_shift_minmax_oversampling.csv  

In [5]:
import sys
import os 

data_path = ''
model_output_path = ''
# check if the enviorment is Google Colab 

if 'google.colab' in sys.modules:
    print("Running on Google Colab")
    # Install required libraries
    !pip install scikit-learn -q
    !pip install pandas -q
    !pip install numpy -q
    !pip install imbalanced-learn -q

    # Mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')
    # set the path where the csv file stored in your google drive. 
    data_path = '/content/drive/MyDrive/Heartbeat_Project/'
    #model_output_path = data_path
    model_path = data_path + 'model_mit_dnn_shift.h5' # pkl'

else:
    print("Running on local environment")

    current_path = os.getcwd()
    print("Current working directory:", current_path)
    data_path = '../data/processed/'
    model_output_path = '../models/'
    model_path = '../models/' + 'model_mit_multiple_shift_cnn.h5' 

Running on local environment
Current working directory: g:\Meine Ablage\heartbeat-analysis-ai\notebooks


In [6]:
# Verify installation and import libraries
import sys
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import shap
from sklearn.metrics import f1_score, confusion_matrix, classification_report, roc_curve, auc
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
import lime
import lime.lime_tabular

## load data

In [7]:
RawFiles = dict({
    'train': data_path + 'mitbih_multipleclass_train_shift_minmax_oversampling.csv', 
    'test': data_path + 'mitbih_multipleclass_test_shift_minmax_oversampling.csv'  
}) 

train = pd.read_csv(RawFiles.get('train'),sep=',',header=0)
test = pd.read_csv(RawFiles.get('test'),sep=',',header=0)

y_train = train['target']
X_train = train.drop('target', axis=1)

y_test = test['target']
X_test = test.drop('target', axis=1)

## load model

In [None]:
# Load and compile the model
try:
    loaded_model = load_model(model_path, compile=False)
except OSError as e:
    print(f"Error loading model: {e}")
    sys.exit("Check model path and try again.")

# Compile the model (ensuring metrics match expected inputs)
loaded_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Evaluate the model with multiple class target variable
evaluation_metrics = loaded_model.evaluate(X_test, y_test)

print(f"Model Evaluation - Loss: {evaluation_metrics[0]:.4f}, Accuracy: {evaluation_metrics[1]:.4f}")


Error loading model: [Errno 2] Unable to open file (unable to open file: name = '../models/model_mit_multiple_shift_cnn.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)


AttributeError: 'tuple' object has no attribute 'tb_frame'

### In the following we apply explanation strategies (SHAP, LIME) on the resampled (Oversampling) and rescaled (MinMax Scaler) and shifted dataset to explain the most important features within the CNN model due to to given strategies:

## SHAP Values

In [None]:
import shap
# # Prepare background data for SHAP and calculate SHAP values
# background = X_train.sample(1000, random_state=42).values  # Sample background data for SHAP
# X_test_array = X_test.values  # Convert X_test to a numpy array

# Reshape X_train and X_test to match the expected input shape of the model
background = X_train.sample(1000, random_state=42).values.reshape(-1, X_train.shape[1], 1)
X_test_array = X_test.values.reshape(-1, X_test.shape[1], 1)


In [None]:
# Initialize DeepExplainer with reshaped background data
explainer = shap.DeepExplainer(loaded_model, background)

# Compute SHAP values for test data
shap_values = explainer.shap_values(X_test_array)

# Visualize SHAP values (e.g., summary plot)
shap.summary_plot(shap_values[0], X_test_array.squeeze(), feature_names=X_test.columns)


In [None]:
# # Initialize DeepExplainer with background data
# explainer = shap.DeepExplainer(loaded_model, background)

In [None]:
# # Compute SHAP values for each feature and sample
# shap_values = explainer.shap_values(X_test_array)

In [None]:
if isinstance(shap_values, list):
    shap_values = shap_values[0]

# Calculate mean absolute SHAP values across all test samples for each feature
mean_shap_values = np.mean(np.abs(shap_values), axis=0)

# Verify that mean_shap_values is 1-dimensional
print("Shape of mean_shap_values before flattening:", mean_shap_values.shape)
if mean_shap_values.ndim > 1:
    mean_shap_values = mean_shap_values.flatten()  # Flatten if it's multi-dimensional
print("Shape of mean_shap_values after flattening:", mean_shap_values.shape)

# Verify the number of features matches
print("Number of features (X_test.columns):", len(X_test.columns))
print("Length of mean_shap_values:", len(mean_shap_values))

# Ensure lengths match before creating DataFrame
if len(mean_shap_values) == len(X_test.columns):
    # Create DataFrame for feature importance
    shap_importance_df = pd.DataFrame({
        'feature': X_test.columns,
        'importance': mean_shap_values
    })

    # Sort by importance and display top features
    shap_importance_df = shap_importance_df.sort_values(by='importance', ascending=False)
    print("Top 10 most important features based on SHAP values:\n", shap_importance_df.head(10))
    

In [None]:
# Plot the top 10 most important features
plt.figure(figsize=(12, 8))
sns.barplot(
    x='importance', 
    y='feature', 
    data=shap_importance_df.head(10), 
    palette='viridis'
)
plt.title('SHAP: Top 10 Most Important Features for DNN Model')
plt.xlabel('Mean Absolute SHAP Value')
plt.ylabel('Features')
plt.show()


## LIME


In [None]:
# Wrapper function for predict_proba behavior
def predict_proba_wrapper(data):
    # Predict using the model (outputs probabilities for binary classification)
    predictions = loaded_model.predict(data)
    # Reshape predictions to have two columns for binary classification: [1 - prediction, prediction]
    return np.column_stack([1 - predictions, predictions])

# Initialize the LIME explainer
explainer = lime.lime_tabular.LimeTabularExplainer(
    training_data=np.array(X_train),  # Features from the training data
    training_labels=np.array(y_train),  # Target labels for training data
    mode="classification",  # Set to 'classification' for binary classification
    feature_names=X_train.columns,  # Feature names
    class_names=['Class 0', 'Class 1'],  # Class names for binary classification
    discretize_continuous=True  # Discretizes continuous features
)

# Select a random instance from the test set
idx = 200  # You can change this index to select a different instance
instance = X_test.iloc[idx]  # The input instance
true_label = y_test.iloc[idx]

print("True Label for selected instance:", true_label)
print("Instance features:\n", instance)

# Explain the instance using the LIME explainer
exp = explainer.explain_instance(
    data_row=instance,  # Instance to explain
    predict_fn=predict_proba_wrapper,  # Use the wrapper function here
    num_features=10  # Number of features to include in the explanation
)

# Extract feature contributions from the explanation object
feature_importance = exp.as_list()  # Returns a list of (feature, contribution) tuples

# Convert the feature importance to a DataFrame for easy visualization
lime_df = pd.DataFrame(feature_importance, columns=['Feature', 'Contribution'])

# Plot the LIME explanation using Seaborn
plt.figure(figsize=(10, 6))
sns.barplot(
    x='Contribution', 
    y='Feature', 
    data=lime_df, 
    palette='viridis', 
    orient='h'
)
plt.title(f'LIME Explanation for Instance {idx} (True Label: {true_label})')
plt.xlabel('Feature Contribution')
plt.ylabel('Feature')
plt.grid(axis='x', linestyle='--', alpha=0.6)
plt.show()


In [None]:
from datetime import datetime
# Display the running time
print("Current time:", datetime.now())