# SignXAI with TensorFlow - ECG Time Series Explanation

This tutorial demonstrates how to use SignXAI to explain ECG (electrocardiogram) predictions with TensorFlow models. We'll be working with ECG data and pre-trained models for detecting cardiac pathologies.

## Setup Requirements

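For this TensorFlow tutorial, install SignXAI together with its TensorFlow dependencies. The relative `../../requirements/` paths below assume the commands are run from the directory containing this tutorial: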

```bash
# For conda users
conda create -n signxai-tensorflow python=3.10
conda activate signxai-tensorflow
pip install -r ../../requirements/common.txt
pip install -r ../../requirements/tensorflow.txt

# Or with venv and pip
python -m venv signxai_tensorflow_env
source signxai_tensorflow_env/bin/activate  # On Windows: signxai_tensorflow_env\Scripts\activate
pip install -r ../../requirements/common.txt
pip install -r ../../requirements/tensorflow.txt
```
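
Once the environment is active, a quick check like the following can confirm that the required packages are importable. This is a minimal sketch: the package names are assumptions based on the imports used later in this tutorial, and `check_packages` is a hypothetical helper, not part of SignXAI.

```python
import importlib.util


def check_packages(names):
    """Return a dict mapping each top-level package name to True if it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


# Package names are assumptions taken from the imports used in this tutorial.
status = check_packages(["numpy", "matplotlib", "tensorflow", "signxai"])
for name, found in status.items():
    print(f"{name}: {'OK' if found else 'MISSING'}")
```

`find_spec` only locates a package without importing it, so the check stays fast even for large libraries.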

## Overview

We'll cover:
1. Loading ECG data and models
2. Pre-processing ECG signals
3. Applying explainability methods to time series data, including a GradCAM variant adapted for time series
4. Visualizing the regions of the ECG that influenced model predictions

Let's get started!

## 1. Import Libraries

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# SignXAI imports
from signxai.tf_signxai.methods import (
    SIGN,
    GradCAMTimeseries,
    GuidedBackprop,
    LRPZ,
    LRPEpsilon,
    LRPSignedEpsilon,
)
from signxai.common.visualization import visualize_attribution
from signxai.utils.utils import remove_softmax

# Import utilities for ECG data handling
# Notebooks do not define __file__, so resolve ../utils relative to the current working directory
sys.path.append(os.path.join(os.getcwd(), "..", "utils"))
from ecg.data import load_and_preprocess_ecg
from ecg.model import load_models_from_paths
from ecg.explainability import normalize_ecg_relevancemap
from ecg.viz import plot_ecg

## 2. Set Up Data Paths

In [None]:
# Set up paths
# Notebooks do not define __file__, so use the current working directory
# (assumed to be the directory containing this notebook)
_THIS_DIR = os.getcwd()
DATA_DIR = os.path.realpath(os.path.join(_THIS_DIR, "..", "..", "data"))
TIMESERIES_DIR = os.path.join(DATA_DIR, "timeseries")
MODELS_DIR = os.path.join(DATA_DIR, "models", "tensorflow", "ECG_AIME")

# We'll use one of the ECG pathology models from AIME2024
# Options: 'AVB' (AV Block), 'ISCH' (Ischemia), 'LBBB' (Left Bundle Branch Block), 'RBBB' (Right Bundle Branch Block)
MODEL_ID = 'AVB'  # You can change this to any of the above options
MODEL_DIR = os.path.join(MODELS_DIR, MODEL_ID)

# Specific record to analyze
RECORD_ID = '03509_hr'  # This is a record with AV block

# Define model paths
MODEL_JSON_PATH = os.path.join(MODEL_DIR, "model.json")
MODEL_WEIGHTS_PATH = os.path.join(MODEL_DIR, "weights.h5")

# Verify files exist
assert os.path.exists(MODEL_JSON_PATH), f"Model JSON not found at {MODEL_JSON_PATH}"
assert os.path.exists(MODEL_WEIGHTS_PATH), f"Model weights not found at {MODEL_WEIGHTS_PATH}"
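
Assertions are skipped when Python runs with the `-O` flag, so an explicit check can be more robust. The following is a minimal sketch of such a check; `require_files` is a hypothetical helper, and the temporary files only stand in for the real model files.

```python
import tempfile
from pathlib import Path


def require_files(*paths):
    """Raise FileNotFoundError listing every path that is not an existing file."""
    missing = [str(p) for p in map(Path, paths) if not p.is_file()]
    if missing:
        raise FileNotFoundError("Missing required files: " + ", ".join(missing))


# Demonstration with temporary stand-in files (not the real model files).
with tempfile.TemporaryDirectory() as tmp:
    existing = Path(tmp) / "model.json"
    existing.write_text("{}")
    require_files(existing)  # passes silently

    try:
        require_files(existing, Path(tmp) / "weights.h5")
    except FileNotFoundError as err:
        error_message = str(err)
        print(error_message)
```

In this tutorial the call would be `require_files(MODEL_JSON_PATH, MODEL_WEIGHTS_PATH)`, which reports all missing files at once instead of stopping at the first one.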

## 3. Load the ECG Model and Data

In [None]:
# Load the ECG model for the specified pathology
model, model_no_softmax = load_models_from_paths(
    modelpath=MODEL_JSON_PATH,
    weightspath=MODEL_WEIGHTS_PATH
)

# Display model architecture
model.summary()

In [None]:
# Load and preprocess the ECG signal
ecg = load_and_preprocess_ecg(
    record_id=RECORD_ID,
    ecg_filters=['BWR', 'BLA', 'AC50Hz', 'LP40Hz'],  # Standard ECG filters
    subsampling_window_size=2000,  # Window size to analyze
    subsample_start=0,  # Start from the beginning of the record
    data_dir=TIMESERIES_DIR,  # Directory with ECG data
    model_id=MODEL_ID  # Pass model ID for specific preprocessing
)

# Plot the ECG signal
plt.figure(figsize=(15, 3))
plt.plot(ecg)
plt.title(f"ECG Signal - {RECORD_ID}")
plt.xlabel("Time (samples)")
plt.ylabel("Amplitude")
plt.grid(True)
plt.show()

# Add a leading batch dimension, since the model expects batched input
ecg_batch = np.expand_dims(ecg, axis=0)
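
The internals of `load_and_preprocess_ecg` are not shown here. As a rough illustration of what the `subsample_start` and `subsampling_window_size` parameters suggest, the sketch below cuts a fixed-length window from a synthetic record and adds the batch axis. The `(timesteps, leads)` layout, the 12 leads, and the `take_window` helper are assumptions for illustration only.

```python
import numpy as np

# Synthetic stand-in for a multi-lead ECG: 5000 samples, 12 leads (assumed layout).
rng = np.random.default_rng(0)
full_record = rng.standard_normal((5000, 12)).astype(np.float32)


def take_window(signal, start, size):
    """Return `size` consecutive samples beginning at index `start`."""
    if start < 0 or start + size > signal.shape[0]:
        raise ValueError("window exceeds the record length")
    return signal[start:start + size]


window = take_window(full_record, start=0, size=2000)
batch = np.expand_dims(window, axis=0)  # add a leading batch axis of size 1

print(window.shape)  # (2000, 12)
print(batch.shape)   # (1, 2000, 12)
```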

## 4. Get Model Prediction

In [None]:
# Get model prediction
prediction = model.predict(ecg_batch)
prediction_probability = prediction[0][0]

print(f"Model: {MODEL_ID}")
print(f"Record: {RECORD_ID}")
print(f"Prediction probability: {prediction_probability:.4f}")
print(f"Diagnosis: {'Positive' if prediction_probability > 0.5 else 'Negative'} for {MODEL_ID}")

## 5. Generate Explanations for the ECG Signal

In [None]:
# Define the explainability methods to use
methods = {
    'GradCAM-Timeseries': GradCAMTimeseries(model, layer_name='activation_4'),  # Layer specific to the ECG model
    'Gradient': SIGN(model_no_softmax),
    'Gradient × Input': SIGN(model_no_softmax),  # Multiplied by the input in the loop below
    'Gradient × Sign': SIGN(model_no_softmax),   # Multiplied by the input's sign in the loop below
    'LRP-Alpha-1-Beta-0': LRPZ(model_no_softmax),
    'LRP-Epsilon': LRPEpsilon(model_no_softmax, epsilon=0.5),
    'LRP-Signed-Epsilon': LRPSignedEpsilon(model_no_softmax, epsilon=0.5)
}

# Storage for explanations
explanations = {}
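
For intuition about the GradCAM entry, the following NumPy sketch applies the standard Grad-CAM formulation along the time axis: average the gradients per channel, form a weighted sum of the layer's activations, apply a ReLU, and interpolate to the input length. It is a toy illustration with made-up values and is not necessarily how `GradCAMTimeseries` is implemented; `grad_cam_1d` is a hypothetical name.

```python
import numpy as np


def grad_cam_1d(activations, gradients, input_length):
    """Grad-CAM for one sample of a 1D convolutional layer.

    activations, gradients: arrays of shape (steps, channels) holding the layer
    output and the gradient of the class score with respect to that output.
    """
    channel_weights = gradients.mean(axis=0)            # average gradient per channel
    cam = np.maximum(activations @ channel_weights, 0)  # weighted sum over channels, then ReLU
    # Linearly interpolate from the layer's time resolution to the input length.
    layer_positions = np.linspace(0, input_length - 1, num=cam.shape[0])
    return np.interp(np.arange(input_length), layer_positions, cam)


# Toy values: 4 time steps, 2 channels (not taken from a real model).
activations = np.array([[1.0, 0.0], [2.0, 1.0], [0.0, 3.0], [1.0, 1.0]])
gradients = np.array([[0.5, -0.5], [0.5, -0.5], [0.5, -0.5], [0.5, -0.5]])
cam = grad_cam_1d(activations, gradients, input_length=8)
print(cam.shape)  # (8,)
```

Because the map is computed at the resolution of an intermediate layer, it is usually coarser than gradient-based attributions computed directly on the input samples.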

In [None]:
# Convert data to TensorFlow tensor for explanation
ecg_tensor = tf.convert_to_tensor(ecg_batch, dtype=tf.float32)

# Generate explanations for each method
for method_name, explainer in methods.items():
    print(f"Generating {method_name} explanation...")
    
    if method_name == 'Gradient × Input':
        # Special case for gradient × input
        raw_explanation = explainer.attribute(ecg_tensor).numpy()
        explanation = raw_explanation * ecg_batch
    elif method_name == 'Gradient × Sign':
        # Special case for gradient × sign
        raw_explanation = explainer.attribute(ecg_tensor, vlow=-1, vhigh=1).numpy()
        sign = np.where(ecg_batch < 0, -1.0, 1.0)  # Zeros count as +1; avoids 0/0 warnings
        explanation = raw_explanation * sign
    else:
        explanation = explainer.attribute(ecg_tensor).numpy()
    
    # Only use positive values for visualization
    explanation = np.maximum(explanation, 0)
    
    # Normalize the explanation for visualization
    explanation = normalize_ecg_relevancemap(explanation)
    
    # Add to explanations dictionary
    explanations[method_name] = explanation
    
print("All explanations generated!")
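
The two special cases and the post-processing can be illustrated on a toy linear model whose gradient is known exactly. This is a sketch under stated assumptions: the values are made up, and `clip_and_normalize` uses a simple divide-by-maximum scheme that may differ from what `normalize_ecg_relevancemap` does.

```python
import numpy as np

# Toy linear "model": score = sum(weights * x), so the gradient w.r.t. x equals `weights`.
x = np.array([0.5, -2.0, 0.0, 1.5])
weights = np.array([1.0, -0.5, 3.0, -1.0])
gradient = weights.copy()

grad_times_input = gradient * x     # values: 0.5, 1.0, 0.0, -1.5
sign = np.where(x < 0, -1.0, 1.0)   # zeros are treated as +1
grad_times_sign = gradient * sign   # values: 1.0, 0.5, 3.0, -1.0


def clip_and_normalize(relevance):
    """Keep positive relevance and scale it so the maximum is 1 (hypothetical scheme)."""
    positive = np.maximum(relevance, 0)
    peak = positive.max()
    return positive / peak if peak > 0 else positive


print(clip_and_normalize(grad_times_input))
print(clip_and_normalize(grad_times_sign))
```

Note how the sample with value 0.0 receives no relevance under Gradient × Input but the highest relevance under Gradient × Sign, since the sign variant ignores the input's magnitude.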

## 6. Visualize ECG Explanations

In [None]:
# Optional: suppress weak relevance so the strongest regions stand out
posthresh = 0.2  # Normalized relevance at or below this value is set to 0
cmap_adjust = 0.3  # Offset added to the remaining values for better visibility in the colormap

for method_name, explanation in explanations.items():
    # Work on a copy so the original array is not modified in place
    filtered_explanation = explanation.copy()
    filtered_explanation[filtered_explanation <= posthresh] = 0
    filtered_explanation[filtered_explanation > posthresh] += cmap_adjust
    
    # Overwrite the stored explanation; re-running this cell would add the offset again
    explanations[method_name] = filtered_explanation
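
The same thresholding can be written as a small function that returns a new array instead of overwriting stored results. This is a minimal sketch; `filter_relevance` is a hypothetical name and the demo values are made up.

```python
import numpy as np


def filter_relevance(relevance, threshold=0.2, offset=0.3):
    """Zero out values at or below `threshold`; add `offset` to the rest."""
    relevance = np.asarray(relevance, dtype=float)
    return np.where(relevance > threshold, relevance + offset, 0.0)


demo = np.array([0.05, 0.2, 0.25, 1.0])
filtered = filter_relevance(demo)
print(filtered)  # values: 0.0, 0.0, 0.55, 1.3
```

Since the input is left untouched, calling the function repeatedly on the unfiltered explanation always gives the same result.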

In [None]:
# Create a figure to visualize all explanations
fig, axes = plt.subplots(len(explanations), 1, figsize=(15, 4*len(explanations)))

for i, (method_name, explanation) in enumerate(explanations.items()):
    # Plot the explanation using the utility function
    plot_ecg(ecg=ecg, 
             explanation=explanation[0], # Remove batch dimension
             title=f"{method_name} - {MODEL_ID} - {RECORD_ID}",
             ax=axes[i])

plt.tight_layout()
plt.show()

## 7. Interpretation of ECG Explanations

Let's interpret what we're seeing in the ECG explanations:

- **GradCAM-Timeseries**: This method adapts GradCAM for time series data, highlighting temporal regions that contributed most to the classification decision.

- **Gradient**: Shows the sensitivity of the model output to changes in the input ECG signal. Peaks often correspond to key cardiac events (P-waves, QRS complexes, T-waves).

- **Gradient × Input**: Multiplies the gradient element-wise by the input signal, emphasizing samples where both the signal amplitude and the model's sensitivity are large.

- **Gradient × Sign**: Multiplies the gradient by the sign of the input signal (+1 or -1), so the attribution depends on the input's direction but not on its magnitude.

- **LRP-Alpha-1-Beta-0**: A Layer-wise Relevance Propagation (LRP) variant with α = 1 and β = 0, which propagates relevance only along positive contributions to the prediction.

- **LRP-Epsilon**: LRP with a stabilizing term ε (here `epsilon=0.5`) added to the denominator of the propagation rule to avoid division by values close to zero.

- **LRP-Signed-Epsilon**: Combines LRP with a sign operator to provide directional relevance information.
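
To make the stabilization term concrete, the following is a minimal NumPy sketch of the textbook epsilon rule for a single dense layer. It is a toy illustration under assumed shapes, not SignXAI's implementation; `lrp_epsilon_dense` and the example values are hypothetical.

```python
import numpy as np


def lrp_epsilon_dense(activations, weights, bias, relevance_out, epsilon=0.5):
    """Redistribute relevance from a dense layer's outputs to its inputs (epsilon rule).

    activations: shape (inputs,), weights: shape (inputs, outputs),
    bias and relevance_out: shape (outputs,).
    """
    z = activations @ weights + bias                    # pre-activations per output
    stabilizer = epsilon * np.where(z >= 0, 1.0, -1.0)  # pushes z away from zero
    s = relevance_out / (z + stabilizer)                # relevance per unit of pre-activation
    return activations * (weights @ s)                  # relevance per input


activations = np.array([1.0, 2.0])
weights = np.array([[1.0, -1.0],
                    [0.5, 1.0]])
bias = np.zeros(2)
relevance_out = np.array([1.0, 0.0])

relevance_in = lrp_epsilon_dense(activations, weights, bias, relevance_out, epsilon=0.5)
print(relevance_in, relevance_in.sum())
```

With `epsilon=0.5` the input relevances sum to 0.8 rather than 1.0: the stabilizer absorbs part of the relevance, which is the price paid for numerical stability. With a very small epsilon and zero bias, the sum approaches the output relevance.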

The highlighted regions in the ECG indicate which parts of the signal were most important for detecting the cardiac condition. Because negative relevance is clipped to zero before plotting, the plots show only positive relevance values. For AV block, look for irregularities in the PR interval or missing QRS complexes after P-waves.
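
To relate highlighted regions to specific intervals of the recording, one option is to extract contiguous runs of nonzero relevance and convert their sample indices to seconds. The sketch below does this for a single 1D relevance trace; `highlighted_regions` is a hypothetical helper, and the 500 Hz sampling rate is an assumption that should be replaced with the actual rate of the data.

```python
import numpy as np


def highlighted_regions(relevance, threshold=0.0):
    """Return (start, end) index pairs of contiguous runs where relevance > threshold.

    `end` is exclusive, so relevance[start:end] lies entirely above the threshold.
    """
    mask = np.asarray(relevance) > threshold
    padded = np.concatenate(([False], mask, [False]))
    # Indices where the mask switches between False and True mark run boundaries.
    changes = np.flatnonzero(padded[1:] != padded[:-1])
    return [(int(s), int(e)) for s, e in zip(changes[::2], changes[1::2])]


relevance = np.array([0.0, 0.55, 0.8, 0.0, 0.6])
regions = highlighted_regions(relevance)
print(regions)  # [(1, 3), (4, 5)]

sampling_rate_hz = 500  # assumed value for illustration
regions_seconds = [(s / sampling_rate_hz, e / sampling_rate_hz) for s, e in regions]
print(regions_seconds)
```

For multi-lead explanations, the same function can be applied to each lead separately.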

## 8. Compare Multiple Pathologies (Optional)

For a more comprehensive understanding, we can analyze explanations across different cardiac pathology models.

In [None]:
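# Run the explanation pipeline for one pathology model and one record
def run_explanation(model_id, record_id, method='GradCAM-Timeseries'):
    """Return (ecg, explanation, prediction), or (None, None, None) if the model files are missing."""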
    # Load the model
    model_path = os.path.join(MODELS_DIR, model_id, "model.json")
    weights_path = os.path.join(MODELS_DIR, model_id, "weights.h5")
    
    if not os.path.exists(model_path) or not os.path.exists(weights_path):
        print(f"Model files for {model_id} not found, skipping.")
        return None, None, None
    
    model, model_no_softmax = load_models_from_paths(
        modelpath=model_path,
        weightspath=weights_path
    )
    
    # Load the ECG data
    ecg = load_and_preprocess_ecg(
        record_id=record_id,
        ecg_filters=['BWR', 'BLA', 'AC50Hz', 'LP40Hz'],
        subsampling_window_size=2000,
        subsample_start=0,
        data_dir=TIMESERIES_DIR,
        model_id=model_id
    )
    
    # Get prediction
    ecg_batch = np.expand_dims(ecg, axis=0)
    prediction = model.predict(ecg_batch)[0][0]
    
    # Generate explanation
    ecg_tensor = tf.convert_to_tensor(ecg_batch, dtype=tf.float32)
    
    if method == 'GradCAM-Timeseries':
        explainer = GradCAMTimeseries(model, layer_name='activation_4')
    elif method == 'Gradient':
        explainer = SIGN(model_no_softmax)
    elif method == 'LRP-Z':
        explainer = LRPZ(model_no_softmax)
    else:
        raise ValueError(f"Method {method} not supported")
    
    explanation = explainer.attribute(ecg_tensor).numpy()
    explanation = np.maximum(explanation, 0)
    explanation = normalize_ecg_relevancemap(explanation)
    
    # Apply the same filtering as in Section 6 (threshold 0.2, offset 0.3)
    explanation[explanation <= 0.2] = 0
    explanation[explanation > 0.2] += 0.3
    
    return ecg, explanation, prediction
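
As the number of supported methods grows, the `if`/`elif` chain can be replaced by a lookup table of factories. The following is a minimal sketch of that design; `build_explainer` is a hypothetical helper, and the string-returning lambdas are stand-ins for the real SignXAI constructors.

```python
def build_explainer(method, factories):
    """Look up `method` in `factories` and call the matching zero-argument factory."""
    if method not in factories:
        supported = ", ".join(sorted(factories))
        raise ValueError(f"Method {method!r} not supported; choose from: {supported}")
    return factories[method]()


# Stand-in factories. In the tutorial these would wrap the real constructors,
# for example: 'Gradient': lambda: SIGN(model_no_softmax)
factories = {
    "GradCAM-Timeseries": lambda: "grad-cam explainer",
    "Gradient": lambda: "gradient explainer",
    "LRP-Z": lambda: "lrp-z explainer",
}

explainer = build_explainer("Gradient", factories)
print(explainer)  # gradient explainer
```

Using factories rather than pre-built explainers means only the requested explainer is constructed, and the error message lists the valid method names automatically.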

In [None]:
# Define pathology models to compare
pathologies = ['AVB', 'ISCH', 'LBBB', 'RBBB']
selected_method = 'GradCAM-Timeseries'  # Choose one method for comparison

# Create plot
fig, axes = plt.subplots(len(pathologies), 1, figsize=(15, 4*len(pathologies)))

for i, model_id in enumerate(pathologies):
    print(f"Processing {model_id} model...")
    ecg, explanation, prediction = run_explanation(model_id, RECORD_ID, selected_method)
    
    if ecg is not None:
        plot_ecg(ecg=ecg, 
                 explanation=explanation[0],
                 title=f"{model_id} - Prediction: {prediction:.4f} - {'Positive' if prediction > 0.5 else 'Negative'}",
                 ax=axes[i])
    else:
        axes[i].text(0.5, 0.5, f"Model for {model_id} not available", ha='center', va='center')
        axes[i].set_title(f"{model_id} - Not Available")

plt.tight_layout()
plt.show()

## 9. Conclusion

In this tutorial, we've demonstrated how to use SignXAI to explain predictions of ECG models in TensorFlow. We've covered:

- Loading and preprocessing ECG time series data
- Working with specialized models for cardiac pathology detection
- Applying various explainability methods to ECG signals
- Visualizing and interpreting explanations

These techniques can help medical professionals understand which parts of the ECG signal contribute to the model's predictions, potentially enhancing trust and supporting model validation for clinical use.

The AIME2024 ECG models used in this tutorial are specifically designed for cardiac pathology detection, demonstrating SignXAI's versatility in working with domain-specific time series data beyond image classification.