# SignXAI2 with TensorFlow - ECG Time Series Explanation

This tutorial demonstrates how to use SignXAI2 to explain ECG (electrocardiogram) predictions with TensorFlow models. We'll be working with ECG data and pre-trained models for detecting cardiac pathologies.

## Setup Requirements

**Important**: SignXAI2 requires Python 3.9 or 3.10. Python 3.11 and later are not supported.

This tutorial assumes you have already cloned the signxai2 repository. Run the following commands from the repository root directory:

### Using conda:
```bash
# Create environment with Python 3.10
conda create -n signxai2 python=3.10
conda activate signxai2

# Install SignXAI2 with TensorFlow support
# (quotes keep shells such as zsh from interpreting the square brackets)
pip install "signxai2[tensorflow]"

# Download models and example data
git lfs install
bash ./prepare.sh
```

### Using venv:
```bash
# Create virtual environment
python3.10 -m venv signxai2_env
source signxai2_env/bin/activate  # On Windows: signxai2_env\Scripts\activate

# Install SignXAI2 with TensorFlow support
# (quotes keep shells such as zsh from interpreting the square brackets)
pip install "signxai2[tensorflow]"

# Download models and example data
git lfs install
bash ./prepare.sh
```
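
Because only Python 3.9 and 3.10 are supported, it can help to confirm the interpreter version inside the new environment before going further. The following is a minimal sketch; the helper name `is_supported_python` is hypothetical and not part of SignXAI2.

```python
import sys

# Versions named in the setup requirements above.
SUPPORTED_MINORS = {(3, 9), (3, 10)}


def is_supported_python(version_info=None):
    """Return True if the (major, minor) version is one the tutorial supports."""
    info = sys.version_info if version_info is None else version_info
    return tuple(info[:2]) in SUPPORTED_MINORS


if __name__ == "__main__":
    current = ".".join(str(part) for part in sys.version_info[:3])
    status = "supported" if is_supported_python() else "not supported"
    print(f"Python {current} is {status} for this tutorial")
```

Run it with the interpreter of the activated environment, so that it reports the version the notebook will actually use.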

## Overview

We'll cover:
1. Loading ECG data and models
2. Pre-processing ECG signals
3. Applying explainability methods to time series data through the unified API, including the time-series-specific `grad_cam_timeseries`
4. Visualizing the regions of the ECG that influenced model predictions

Let's get started!

## 1. Import Libraries

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# SignXAI2 unified API imports
from signxai import explain, list_methods

# Import utilities for ECG data handling
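# Ensure the utils directory is on the Python path.
# Notebooks do not define __file__, so the path is resolved relative to the
# current working directory (the directory the notebook is run from).
sys.path.append(os.path.join(os.getcwd(), "..", "utils"))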
from ecg.data import load_and_preprocess_ecg
from ecg.model import load_models_from_paths
from ecg.explainability import normalize_ecg_relevancemap
from ecg.viz import plot_ecg

## 2. Set Up Data Paths

In [None]:
# Set up paths relative to the current working directory
# (notebooks do not define __file__)
_THIS_DIR = os.getcwd()
DATA_DIR = os.path.realpath(os.path.join(_THIS_DIR, "..", "..", "data"))
TIMESERIES_DIR = os.path.join(DATA_DIR, "timeseries")
MODELS_DIR = os.path.join(DATA_DIR, "models", "tensorflow", "ECG")

# Main ECG model path
ECG_MODEL_PATH = os.path.join(MODELS_DIR, "ecg_model.h5")

# We can also use one of the pathology-specific models
# Options: 'AVB' (AV Block), 'ISCH' (Ischemia), 'LBBB' (Left Bundle Branch Block), 'RBBB' (Right Bundle Branch Block)
MODEL_ID = 'AVB'  # You can change this to any of the above options
MODEL_DIR = os.path.join(MODELS_DIR, MODEL_ID)

# Specific record to analyze
RECORD_ID = '03509_hr'  # This is a record with AV block

# Define model paths for pathology-specific model
MODEL_JSON_PATH = os.path.join(MODEL_DIR, "model.json")
MODEL_WEIGHTS_PATH = os.path.join(MODEL_DIR, "weights.h5")

# Check if files exist
if os.path.exists(ECG_MODEL_PATH):
    print(f"Found general ECG model at {ECG_MODEL_PATH}")
    USE_GENERAL_MODEL = True
else:
    print(f"General ECG model not found, using pathology-specific model: {MODEL_ID}")
    USE_GENERAL_MODEL = False
    assert os.path.exists(MODEL_JSON_PATH), f"Model JSON not found at {MODEL_JSON_PATH}"
    assert os.path.exists(MODEL_WEIGHTS_PATH), f"Model weights not found at {MODEL_WEIGHTS_PATH}"

## 3. Load the ECG Model and Data

In [None]:
# Load the ECG model
if USE_GENERAL_MODEL:
    # Load the general ECG model directly
    model = tf.keras.models.load_model(ECG_MODEL_PATH)
    model_no_softmax = model  # Assuming the model doesn't have softmax
    print("Loaded general ECG model")
else:
    # Load the pathology-specific model
    model, model_no_softmax = load_models_from_paths(
        modelpath=MODEL_JSON_PATH,
        weightspath=MODEL_WEIGHTS_PATH
    )
    print(f"Loaded pathology-specific model: {MODEL_ID}")

# Display model architecture
model.summary()

In [None]:
# Load and preprocess the ECG signal
ecg = load_and_preprocess_ecg(
    record_id=RECORD_ID,
    ecg_filters=['BWR', 'BLA', 'AC50Hz', 'LP40Hz'],  # Standard ECG filters
    subsampling_window_size=2000,  # Window size to analyze
    subsample_start=0,  # Start from the beginning of the record
    data_dir=TIMESERIES_DIR,  # Directory with ECG data
    model_id=MODEL_ID  # Pass model ID for specific preprocessing
)

# Plot the ECG signal
plt.figure(figsize=(15, 3))
plt.plot(ecg)
plt.title(f"ECG Signal - {RECORD_ID}")
plt.xlabel("Time (samples)")
plt.ylabel("Amplitude")
plt.grid(True)
plt.show()

# Convert to batch format for the model
ecg_batch = np.expand_dims(ecg, axis=0)
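
The last line of the cell above adds a leading batch axis, because Keras models expect input of shape `(batch, ...)` even for a single record. The sketch below shows the effect on a placeholder array. The shape `(2000, 12)`, meaning 2000 samples by 12 leads, is an assumption for illustration; check `ecg.shape` for the actual layout returned by `load_and_preprocess_ecg`.

```python
import numpy as np

# Placeholder for one preprocessed ECG window.
# ASSUMPTION: 2000 time steps x 12 leads; the real array may differ.
ecg_example = np.zeros((2000, 12), dtype=np.float32)

# A single record becomes a batch of one.
ecg_batch_example = np.expand_dims(ecg_example, axis=0)

print(ecg_example.shape)        # (2000, 12)
print(ecg_batch_example.shape)  # (1, 2000, 12)
```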

## 4. Get Model Prediction

In [None]:
# Get model prediction
prediction = model.predict(ecg_batch)
prediction_probability = prediction[0][0]

print(f"Model: {MODEL_ID}")
print(f"Record: {RECORD_ID}")
print(f"Prediction probability: {prediction_probability:.4f}")
print(f"Diagnosis: {'Positive' if prediction_probability > 0.5 else 'Negative'} for {MODEL_ID}")

## 5. Generate Explanations for the ECG Signal

In [None]:
# Define the explainability methods to use with the unified API
methods_to_test = [
    'grad_cam_timeseries',
    'gradient',
    'gradient_x_input',
    'gradient_x_sign',
    'lrp_alpha_1_beta_0',
    'lrp_epsilon_0_5',
    'lrpsign_epsilon_0_5'
]

# Method-specific parameters
# For GradCAM, we need to find the last activation layer
# This will depend on which model we're using
if USE_GENERAL_MODEL:
    # For the general model, we'll need to check its architecture
    # Let's find the last activation layer
    last_activation = None
    for layer in model.layers[::-1]:  # Iterate backwards
        if 'activation' in layer.name or isinstance(layer, tf.keras.layers.Activation):
            last_activation = layer.name
            break
    if not last_activation:
        # Try to find a conv layer as fallback
        for layer in model.layers[::-1]:
            if 'conv' in layer.name.lower():
                last_activation = layer.name
                break
    gradcam_layer = last_activation or 'activation'  # Default fallback
else:
    # For pathology-specific models, use the known layer
    gradcam_layer = 'activation_4'

method_params = {
    'grad_cam_timeseries': {'layer_name': gradcam_layer}
}

print(f"Using layer '{gradcam_layer}' for GradCAM")

# Storage for explanations
explanations = {}
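
The layer selected above supplies the feature maps that Grad-CAM works on. As a rough illustration of the idea for one-dimensional signals, the sketch below computes a Grad-CAM map from activations and gradients that are assumed to be already available. The helper `grad_cam_1d` is hypothetical and is not SignXAI2's implementation, which may differ in details such as upsampling the map to the input length.

```python
import numpy as np


def grad_cam_1d(activations, gradients):
    """Grad-CAM along the time axis.

    activations: feature maps of the chosen layer, shape (time_steps, channels)
    gradients:   gradient of the target output w.r.t. those feature maps,
                 same shape
    """
    # One weight per channel: the gradient averaged over time.
    channel_weights = gradients.mean(axis=0)   # (channels,)
    # Weighted sum of the channels at every time step.
    cam = activations @ channel_weights        # (time_steps,)
    # Keep only regions that push the target output up.
    return np.maximum(cam, 0.0)


# Toy values: 3 time steps, 2 channels.
acts = np.array([[1.0, 0.0],
                 [2.0, 1.0],
                 [0.0, 3.0]])
grads = np.array([[0.5, -1.0],
                  [0.5, -1.0],
                  [0.5, -1.0]])

cam = grad_cam_1d(acts, grads)
print(cam)
```

In this toy case only the first time step keeps a positive value, because the second channel carries a negative weight and cancels or outweighs the first channel at the later steps.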

In [None]:
# Generate explanations using the unified API
for method_name in methods_to_test:
    print(f"Generating {method_name} explanation...")
    
    # Get method-specific parameters
    params = method_params.get(method_name, {})
    
    # Use the unified explain API
    # Note: the ECG models here have a single output, so we explain class 0 (positive diagnosis)
    explanation = explain(
        model=model,
        x=ecg_batch,
        method_name=method_name,
        target_class=0,  # ECG models typically have binary output
        **params
    )
    
    # Only use positive values for visualization
    explanation = np.maximum(explanation, 0)
    
    # Normalize the explanation for visualization
    explanation = normalize_ecg_relevancemap(explanation)
    
    # Add to explanations dictionary
    explanations[method_name] = explanation
    
print("All explanations generated!")
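
Each explanation in the loop above is clipped to its positive part and then normalized. The tutorial does not show what `normalize_ecg_relevancemap` does internally, so the sketch below is only an assumption: it scales the clipped map so that its maximum becomes 1. The helper name `clip_and_scale` is hypothetical.

```python
import numpy as np


def clip_and_scale(relevance):
    """Keep positive relevance only, then scale so the peak equals 1.

    ASSUMPTION: max scaling stands in for normalize_ecg_relevancemap,
    whose actual behavior is not shown in this tutorial.
    """
    positive = np.maximum(relevance, 0.0)
    peak = positive.max()
    if peak == 0:
        # Nothing positive to show; avoid dividing by zero.
        return positive
    return positive / peak


toy_relevance = np.array([-2.0, 0.5, 1.0, 4.0])
print(clip_and_scale(toy_relevance))  # [0.    0.125 0.25  1.   ]
```

Clipping discards negative relevance, so the resulting plots show only the evidence in favor of the explained output.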

## 6. Visualize ECG Explanations

In [None]:
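# Filter explanations for better visualization (optional)
# Note: this cell overwrites the entries in `explanations`. Run it only once,
# because re-running it would add cmap_adjust to the surviving values again.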
posthresh = 0.2  # Threshold for filtering low values
cmap_adjust = 0.3  # Adjust colormap for better visibility

for method_name, explanation in explanations.items():
    # Filter low values
    filtered_explanation = explanation.copy()
    filtered_explanation[filtered_explanation <= posthresh] = 0
    filtered_explanation[filtered_explanation > posthresh] += cmap_adjust
    
    # Update the explanations
    explanations[method_name] = filtered_explanation
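
The filtering step above can also be written as a pure function that leaves its input untouched, which makes it safe to call repeatedly. The sketch below applies the same rule (values at or below the threshold become 0, the rest are shifted up by the offset) to a toy array. The name `filter_relevance` is hypothetical.

```python
import numpy as np


def filter_relevance(relevance, threshold=0.2, offset=0.3):
    """Zero out values <= threshold and add offset to the remaining ones.

    Returns a new array; the input is not modified.
    """
    return np.where(relevance > threshold, relevance + offset, 0.0)


toy_map = np.array([0.0, 0.1, 0.2, 0.5, 1.0])
filtered = filter_relevance(toy_map)

print(filtered)  # values at or below 0.2 are 0, the rest are shifted by 0.3
print(toy_map)   # unchanged
```

Note that the offset pushes the filtered values above 1, so the plotting code should not assume a `[0, 1]` range after this step.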

In [None]:
# Create a figure to visualize all explanations
fig, axes = plt.subplots(len(explanations), 1, figsize=(15, 4*len(explanations)))

# Ensure axes is always a list
if len(explanations) == 1:
    axes = [axes]

for i, (method_name, explanation) in enumerate(explanations.items()):
    # Plot the explanation using the utility function
    plot_ecg(ecg=ecg, 
             explanation=explanation[0], # Remove batch dimension
             title=f"{method_name.replace('_', ' ').upper()} - {MODEL_ID} - {RECORD_ID}",
             ax=axes[i])

plt.tight_layout()
plt.show()

## 7. Interpretation of ECG Explanations

Let's interpret what we're seeing in the ECG explanations:

- **grad_cam_timeseries**: This method adapts GradCAM for time series data, highlighting temporal regions that contributed most to the classification decision.

- **gradient**: Shows the sensitivity of the model output to changes in the input ECG signal. Peaks often correspond to key cardiac events (P-waves, QRS complexes, T-waves).

- **gradient_x_input**: Enhances the gradient by multiplying it with the input signal, emphasizing areas where both the signal and its importance are high.

- **gradient_x_sign**: Uses the sign of the input signal to modulate gradient importance, highlighting direction-dependent features.

- **lrp_alpha_1_beta_0**: A Layer-wise Relevance Propagation variant that emphasizes positive contributions to the prediction.

- **lrp_epsilon_0_5**: Standard LRP with a stabilization term (epsilon=0.5), useful for highlighting relevant signal components.

- **lrpsign_epsilon_0_5**: Combines LRP with a sign operator to provide directional relevance information.
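
To make the difference between the three gradient-based entries in this list concrete, the sketch below evaluates them on a toy linear model `f(x) = w · x`, where the gradient with respect to the input is simply `w`. This is a simplification for illustration: it uses `np.sign` directly, whereas the SIGN method as implemented in SignXAI2 may apply a threshold to the input before taking the sign.

```python
import numpy as np

# Toy linear model f(x) = w . x, so df/dx = w for every input.
w = np.array([0.5, -1.0, 2.0])
x = np.array([2.0, -3.0, 0.1])

gradient = w.copy()                      # sensitivity of the output to each sample
gradient_x_input = gradient * x          # scaled by the signal amplitude
gradient_x_sign = gradient * np.sign(x)  # scaled by the signal direction only

print("gradient:        ", gradient)
print("gradient x input:", gradient_x_input)
print("gradient x sign: ", gradient_x_sign)
```

The third sample has a small amplitude (0.1) but a large gradient (2.0). Gradient × input shrinks its attribution to 0.2, while gradient × sign keeps the full magnitude of 2.0, which illustrates why the two methods can highlight different parts of a low-amplitude ECG segment.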

The highlighted regions in the ECG indicate which parts of the signal were most important for detecting the cardiac condition. For AV block, look for irregularities in the PR interval or missing QRS complexes after P-waves.
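
For readers unfamiliar with LRP, the sketch below shows the epsilon rule for a single dense layer in NumPy, using the common formulation in which each pre-activation `z` is stabilized as `z + eps * sign(z)`. It is an illustration under that assumption, not SignXAI2's implementation, and the helper name `lrp_epsilon_dense` is hypothetical.

```python
import numpy as np


def lrp_epsilon_dense(a, W, b, relevance_out, eps=0.5):
    """Propagate relevance through one dense layer with the epsilon rule.

    a:             layer input, shape (n_in,)
    W:             weights, shape (n_in, n_out)
    b:             biases, shape (n_out,)
    relevance_out: relevance of the layer outputs, shape (n_out,)
    """
    z = a @ W + b                                      # pre-activations
    z_stab = z + eps * np.where(z >= 0, 1.0, -1.0)     # stabilized denominator
    s = relevance_out / z_stab
    return a * (W @ s)                                 # relevance per input


a = np.array([1.0, 2.0])
W = np.array([[1.0, -1.0],
              [0.5, 1.0]])
b = np.zeros(2)
relevance_out = a @ W + b   # start from the layer outputs: [2.0, 1.0]

print(lrp_epsilon_dense(a, W, b, relevance_out, eps=0.0))  # sums to 3.0
print(lrp_epsilon_dense(a, W, b, relevance_out, eps=0.5))  # sums to less than 3.0
```

With `eps=0` and zero biases the input relevance sums to the output relevance. A positive epsilon absorbs part of the relevance, which damps contributions from weakly activated units and tends to produce less noisy maps.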

## 8. Compare Multiple Pathologies (Optional)

For a more comprehensive understanding, we can analyze explanations across different cardiac pathology models.

In [None]:
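# Function to run the explanation process for different models using the unified API
def run_explanation(model_id, record_id, method='grad_cam_timeseries'):
    """Return (ecg, explanation, prediction) for one model.

    Returns (None, None, None) if the model files cannot be found.
    """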
    # Check if general model exists first
    general_model_path = os.path.join(MODELS_DIR, "ecg_model.h5")
    if os.path.exists(general_model_path) and model_id == 'GENERAL':
        # Use general model
        model = tf.keras.models.load_model(general_model_path)
        model_no_softmax = model
    else:
        # Load pathology-specific model
        model_path = os.path.join(MODELS_DIR, model_id, "model.json")
        weights_path = os.path.join(MODELS_DIR, model_id, "weights.h5")
        
        if not os.path.exists(model_path) or not os.path.exists(weights_path):
            print(f"Model files for {model_id} not found, skipping.")
            return None, None, None
        
        model, model_no_softmax = load_models_from_paths(model_path, weights_path)
    
    # Load the ECG data
    ecg = load_and_preprocess_ecg(
        record_id=record_id,
        ecg_filters=['BWR', 'BLA', 'AC50Hz', 'LP40Hz'],
        subsampling_window_size=2000,
        subsample_start=0,
        data_dir=TIMESERIES_DIR,
        model_id=model_id if model_id != 'GENERAL' else 'AVB'  # Use AVB preprocessing for general model
    )
    
    # Get prediction
    ecg_batch = np.expand_dims(ecg, axis=0)
    prediction = model.predict(ecg_batch)[0][0]
    
    # Generate explanation using the unified API
    params = {}
    if method == 'grad_cam_timeseries':
        # Find appropriate layer for this model
        if model_id == 'GENERAL':
            # Find last activation layer
            for layer in model.layers[::-1]:
                if 'activation' in layer.name:
                    params['layer_name'] = layer.name
                    break
        else:
            params['layer_name'] = 'activation_4'
    
    explanation = explain(
        model=model,
        x=ecg_batch,
        method_name=method,
        target_class=0,  # ECG models typically have binary output
        **params
    )
    
    explanation = np.maximum(explanation, 0)
    explanation = normalize_ecg_relevancemap(explanation)
    
    # Apply filtering
    explanation[explanation <= 0.2] = 0
    explanation[explanation > 0.2] += 0.3
    
    return ecg, explanation, prediction

In [None]:
# Define pathology models to compare
pathologies = ['AVB', 'ISCH', 'LBBB', 'RBBB']
selected_method = 'grad_cam_timeseries'  # Choose one method for comparison

# Create plot
fig, axes = plt.subplots(len(pathologies), 1, figsize=(15, 4*len(pathologies)))

for i, model_id in enumerate(pathologies):
    print(f"Processing {model_id} model...")
    ecg, explanation, prediction = run_explanation(model_id, RECORD_ID, selected_method)
    
    if ecg is not None:
        plot_ecg(ecg=ecg, 
                 explanation=explanation[0],
                 title=f"{model_id} - Prediction: {prediction:.4f} - {'Positive' if prediction > 0.5 else 'Negative'}",
                 ax=axes[i])
    else:
        axes[i].text(0.5, 0.5, f"Model for {model_id} not available", ha='center', va='center')
        axes[i].set_title(f"{model_id} - Not Available")

plt.tight_layout()
plt.show()

## 9. Conclusion

In this tutorial, we've demonstrated how to use SignXAI2's unified API to explain predictions of ECG models in TensorFlow. We've covered:

- Loading and preprocessing ECG time series data
- Working with specialized models for cardiac pathology detection
- Applying various explainability methods to ECG signals using the unified `explain()` function
- Visualizing and interpreting explanations

Key advantages of using SignXAI2's unified API:
- Consistent interface across all explanation methods
- Easy switching between different XAI techniques
- Automatic parameter handling for time series data
- Support for specialized methods like `grad_cam_timeseries`

These techniques help medical professionals understand which parts of the ECG signal contribute to the model's predictions, potentially enhancing trust and supporting model validation for clinical use.

The AIME2024 ECG models used in this tutorial are specifically designed for cardiac pathology detection, demonstrating SignXAI2's versatility in working with domain-specific time series data beyond image classification.