# TensorFlow ECG Model Tutorial - SignXAI2 Time Series Explainability

This tutorial demonstrates how to use SignXAI2 to explain TensorFlow ECG classification models, following the structure from the official documentation.

## ⚠️ Data Requirements

**This tutorial requires ECG data from the repository:**

```bash
# Clone the repository to get the complete dataset
git clone https://github.com/your-repo/signxai2.git
```

**Required data files:**
- ECG records in `examples/data/timeseries/`
- Pre-trained models in `examples/data/models/tensorflow/ECG/`
- ECG utility modules in the `utils/` directory (data loading, visualization, relevance normalization)

The PyPI package alone doesn't include ECG data to keep the package size manageable.
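Before running the notebook, it can help to check that the repository data is in place. This is a minimal sketch. It assumes the directories listed above are resolved relative to the repository root and that the `utils/` package sits at that root. The helper name `missing_data_dirs` is hypothetical.

```python
from pathlib import Path

# Directories this tutorial expects, relative to the repository root.
# Assumption: utils/ lives at the root, next to examples/.
REQUIRED_DIRS = [
    "examples/data/timeseries",
    "examples/data/models/tensorflow/ECG",
    "utils",
]

def missing_data_dirs(repo_root):
    """Return the required directories that do not exist under repo_root."""
    root = Path(repo_root)
    return [d for d in REQUIRED_DIRS if not (root / d).is_dir()]

# Replace "." with the path of your clone if the notebook runs elsewhere
missing = missing_data_dirs(".")
print("Missing directories:", ", ".join(missing) if missing else "none")
```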

## What you'll learn:

1. **TensorFlow ECG Models**: Building CNN models for time series classification
2. **ECG Data Processing**: Loading and preprocessing 12-lead ECG data
3. **Time Series XAI**: Applying various explainability methods to sequential data
4. **12-Lead Visualization**: Creating professional medical visualizations
5. **Multiple Methods**: Comparing different XAI approaches
6. **Dynamic Method Parsing**: Using the unified `explain()` API, with parameters embedded in method names such as `lrp_epsilon_0_1`

## ECG Signal Components:
Understanding what the model should focus on:
- **P-wave**: Atrial depolarization (0.08-0.12 sec)
- **QRS complex**: Ventricular depolarization (0.06-0.10 sec)  
- **T-wave**: Ventricular repolarization (0.16 sec)
- **PR interval**: AV conduction time (0.12-0.20 sec)
- **QT interval**: Total ventricular activity (0.35-0.44 sec)
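The code works with sample indices, so multiplying each duration by the sampling rate relates the two. The sketch below assumes 500 Hz, the `sampling_rate` later passed to `plot_ecg`. The T-wave is entered as its single 0.16 s value.

```python
# Assumption: 500 Hz sampling, as used for the 12-lead plot in Section 10
SAMPLING_RATE_HZ = 500

ECG_DURATIONS_S = {
    "P-wave": (0.08, 0.12),
    "QRS complex": (0.06, 0.10),
    "T-wave": (0.16, 0.16),
    "PR interval": (0.12, 0.20),
    "QT interval": (0.35, 0.44),
}

def seconds_to_samples(seconds, fs=SAMPLING_RATE_HZ):
    """Number of samples spanned by a duration, rounded to the nearest integer."""
    return round(seconds * fs)

for component, (low, high) in ECG_DURATIONS_S.items():
    print(f"{component:12}: {seconds_to_samples(low)}-{seconds_to_samples(high)} samples")
```

Suppose the synthetic signals in the next section were read at the same rate. The 20-sample QRS bump would then last only 0.04 s, which is shorter than a physiological QRS complex. The generator does not declare a sampling rate, so this is only a rough comparison.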


## 1. Import Libraries and Setup

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Dense, Flatten, Dropout
import matplotlib.pyplot as plt
import warnings
import os
import sys

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')

# SignXAI imports
from signxai import explain, list_methods

# Add project root to path for utility imports
current_dir = os.getcwd()
project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import ECG utilities (from quickstart files)
try:
    from utils.ecg_data import load_and_preprocess_ecg
    from utils.ecg_visualization import plot_ecg
    from utils.ecg_explainability import normalize_ecg_relevancemap
    print("✅ ECG utilities loaded successfully!")
except ImportError as e:
    print(f"⚠️ ECG utilities not available: {e}")
    print("Please ensure you have the complete repository with utils/ directory")

print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}")

# ECG lead names for reference
LEAD_NAMES = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
print(f"\n12-Lead ECG configuration: {', '.join(LEAD_NAMES)}")

## 2. Generate Synthetic ECG Data

Following the documentation approach (lines 75-92), we'll generate synthetic ECG data with characteristic patterns.

In [None]:
# Generate synthetic data (in practice, you would use real ECG datasets)
def generate_synthetic_ecg_data(n_samples=1000, seq_length=1000, n_classes=2):
    X = np.random.randn(n_samples, seq_length, 1) * 0.1
    # Add synthetic patterns for different classes
    for i in range(n_samples):
        if i % n_classes == 0:  # Class 0: Normal
            # Add normal QRS complex
            X[i, 400:420, 0] += np.sin(np.linspace(0, np.pi, 20)) * 1.0
            X[i, 350:370, 0] += np.sin(np.linspace(0, np.pi, 20)) * 0.2  # P wave
            X[i, 450:480, 0] += np.sin(np.linspace(0, np.pi, 30)) * 0.3  # T wave
        else:  # Class 1: Abnormal
            # Wider, lower-amplitude QRS complex
            X[i, 380:410, 0] += np.sin(np.linspace(0, np.pi, 30)) * 0.8
            # Negative deflection after the QRS (inverted T-wave)
            X[i, 420:460, 0] -= np.sin(np.linspace(0, np.pi, 40)) * 0.4
            
    # Create labels
    y = np.array([i % n_classes for i in range(n_samples)])
    return X, y

# Generate data
X_train, y_train = generate_synthetic_ecg_data(800, 1000, 2)
X_test, y_test = generate_synthetic_ecg_data(200, 1000, 2)

print(f"Training data shape: {X_train.shape}")
print(f"Test data shape: {X_test.shape}")

# Visualize some samples
plt.figure(figsize=(15, 8))
for i in range(4):
    plt.subplot(2, 2, i+1)
    sample_idx = i * 50 + i  # 0, 51, 102, 153: alternates between class 0 and class 1
    plt.plot(X_test[sample_idx, :, 0])
    plt.title(f'Sample {sample_idx}, Class: {y_test[sample_idx]}')
    plt.xlabel('Time (samples)')
    plt.ylabel('Amplitude')
    plt.grid(True)

plt.tight_layout()
plt.show()
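The two synthetic classes differ only in the injected waveforms. The difference between the per-class mean signals therefore marks where a faithful explanation should concentrate its relevance. Below is a sketch, checked on a small toy array. The helper `discriminative_mask` and the 0.1 threshold are illustrative choices.

```python
import numpy as np

def discriminative_mask(X, y, threshold=0.1):
    """Mark time steps where the per-class mean signals differ by more than `threshold`.

    X has shape (n_samples, seq_length, 1) and y holds binary labels 0/1,
    matching the synthetic generator above.
    """
    mean_0 = X[y == 0, :, 0].mean(axis=0)
    mean_1 = X[y == 1, :, 0].mean(axis=0)
    return np.abs(mean_1 - mean_0) > threshold

# Toy check: class 1 gets an extra bump at time steps 5..9
rng = np.random.default_rng(0)
X_toy = rng.normal(0.0, 0.01, size=(20, 30, 1))
y_toy = np.arange(20) % 2
X_toy[y_toy == 1, 5:10, 0] += 1.0

mask = discriminative_mask(X_toy, y_toy)
print(np.flatnonzero(mask))
```

On the tutorial data it would be called as `discriminative_mask(X_train, y_train)`. `np.flatnonzero` then lists the time steps to compare against the attribution plots later on.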

## 3. Create TensorFlow CNN Model for ECG Classification

Following the documentation structure (lines 98-113), we'll create a CNN model for ECG classification.

In [None]:
# Create a CNN model for ECG classification
def create_ecg_model(seq_length=1000):
    model = Sequential([
        Conv1D(16, kernel_size=5, activation='relu', input_shape=(seq_length, 1)),
        MaxPooling1D(pool_size=2),
        Conv1D(32, kernel_size=5, activation='relu'),
        MaxPooling1D(pool_size=2),
        Conv1D(64, kernel_size=5, activation='relu', name='conv1d_2'),  # Named for Grad-CAM
        MaxPooling1D(pool_size=2),
        Flatten(),
        Dense(64, activation='relu'),
        Dropout(0.2),
        Dense(2)  # No activation (logits)
    ])
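    # The last layer outputs logits, so the loss must apply the softmax itself
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])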
    return model

# Create the model (training follows in the next section)
model = create_ecg_model()
print("Model created successfully!")
model.summary()
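The `Flatten` → `Dense` head fixes the input length. Keras uses `'valid'` padding and stride 1 by default, so each `Conv1D(kernel_size=5)` shortens the sequence by 4 steps. `MaxPooling1D(pool_size=2)`, whose stride defaults to the pool size, then halves it, rounding down. Here is a plain-Python trace of that arithmetic:

```python
def conv_pool_lengths(seq_length, kernel_size=5, pool_size=2, n_blocks=3):
    """Trace the time dimension through Conv1D('valid') + MaxPooling1D blocks."""
    lengths = [seq_length]
    length = seq_length
    for _ in range(n_blocks):
        length = length - kernel_size + 1  # 'valid' convolution, stride 1
        length = length // pool_size       # pooling drops an odd trailing step
        lengths.append(length)
    return lengths

lengths = conv_pool_lengths(1000)
print(lengths)           # [1000, 498, 247, 121]
print(lengths[-1] * 64)  # 7744 features reach Flatten (64 filters in the last block)
```

The final 7,744 features are what the first `Dense` layer is built for, matching the `flatten` output shape in `model.summary()`. A 3000-sample input would yield 371 × 64 features instead. A longer recording therefore has to be cut to 1000 samples before this model can use it.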

## 4. Train the Model

Following the documentation training procedure (lines 115-127).

In [None]:
# Train the model
print("Training the model...")
history = model.fit(X_train, y_train, epochs=10, batch_size=32, validation_split=0.2, verbose=1)

# Evaluate the model
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
print('\n✅ Training completed!')
print(f'Test accuracy: {test_acc:.4f}')

# Plot training history
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()
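The final `Dense(2)` layer has no activation, so the network outputs logits. Keras's sparse categorical cross-entropy assumes its input is already a probability distribution unless the loss is created with `from_logits=True`. The NumPy sketch below computes the loss directly from logits, using the usual log-sum-exp shift for numerical stability. It illustrates the math, not Keras's code.

```python
import numpy as np

def sparse_ce_from_logits(logits, labels):
    """Mean sparse categorical cross-entropy computed directly from raw logits."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    # Shift each row by its maximum so exp() cannot overflow (log-sum-exp trick)
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class in each row
    return -log_probs[np.arange(len(labels)), labels].mean()

print(sparse_ce_from_logits([[2.0, -1.0], [0.5, 0.5]], [0, 1]))
print(sparse_ce_from_logits([[1000.0, 0.0]], [0]))  # large logits stay finite
```

The second call would overflow with a naive `exp(1000)`. Subtracting the row maximum avoids that.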

## 5. Prepare Sample for XAI Analysis

Following the documentation approach (lines 129-149), we'll prepare a sample for explanation.

In [None]:
# Select a test sample
ecg_sample = X_test[0, :, 0]  # First test sample, remove channel dimension for visualization

# Prepare input with batch dimension
x = ecg_sample.reshape(1, 1000, 1)

print(f"Sample shape for XAI: {x.shape}")
print(f"Sample shape for visualization: {ecg_sample.shape}")

# Get prediction: the model outputs logits, so softmax converts them to probabilities
preds = model.predict(x, verbose=0)
probs = tf.nn.softmax(preds).numpy()
predicted_class = int(np.argmax(preds[0]))
class_names = ['Normal', 'Abnormal']

print(f"Predicted class: {predicted_class} ({class_names[predicted_class]})")
print(f"Confidence: {probs[0, predicted_class]:.4f}")

# Visualize the sample
plt.figure(figsize=(15, 4))
plt.plot(ecg_sample)
plt.title(f'ECG Sample for XAI Analysis\nPredicted: {class_names[predicted_class]} (confidence: {probs[0, predicted_class]:.3f})')
plt.xlabel('Time (samples)')
plt.ylabel('Amplitude')
plt.grid(True)
plt.show()

## 6. Generate Explanations with Multiple Methods

Following the documentation structure (lines 150-179), we'll calculate explanations with different methods including Grad-CAM.

In [None]:
# Calculate explanations with different methods
methods = [
    'gradient',
    'gradient_x_input',
    'integrated_gradients',
    'grad_cam',  # Works for time series too
    'lrp_z',
    'lrp_epsilon_0_1',
    'lrpsign_z'  # The SIGN method
]

explanations = {}
print("Calculating explanations...")

for method in methods:
    try:
        print(f"  Computing: {method}")
        if method == 'grad_cam':
            explanations[method] = explain(
                model=model,
                x=x,
                method_name=method,
                target_class=predicted_class,
                last_conv_layer_name='conv1d_2'
            )
        else:
            explanations[method] = explain(
                model=model,
                x=x,
                method_name=method,
                target_class=predicted_class
            )
        print("    ✅ Success")
    except Exception as e:
        # Leave failed methods out; the plotting cells check membership and skip them
        print(f"    ❌ Failed: {e}")

print(f"\n✅ Generated explanations for {len(explanations)} of {len(methods)} methods")
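SignXAI computes these attributions on the Keras model. To see what two of them mean, here is a NumPy sketch on a toy model whose gradient is known analytically: a sigmoid of a weighted sum. The toy model, the zero baseline and the 200-step midpoint rule are assumptions made for illustration. They are not SignXAI's implementation.

```python
import numpy as np

def toy_model(x, w):
    """Toy differentiable model: sigmoid of a weighted sum of the inputs."""
    return 1.0 / (1.0 + np.exp(-np.dot(w, x)))

def toy_grad(x, w):
    """Analytic gradient of toy_model with respect to x."""
    s = toy_model(x, w)
    return s * (1.0 - s) * w

def toy_gradient_x_input(x, w):
    """Gradient times input, element-wise."""
    return toy_grad(x, w) * x

def toy_integrated_gradients(x, w, baseline=None, steps=200):
    """Integrated gradients via a midpoint Riemann sum along the straight path."""
    if baseline is None:
        baseline = np.zeros_like(x)
    alphas = (np.arange(steps) + 0.5) / steps
    grads = np.stack([toy_grad(baseline + a * (x - baseline), w) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)

w_toy = np.array([2.0, -1.0, 0.5])
x_toy = np.array([1.0, 2.0, -3.0])

ig = toy_integrated_gradients(x_toy, w_toy)
gxi = toy_gradient_x_input(x_toy, w_toy)
target_gap = toy_model(x_toy, w_toy) - toy_model(np.zeros_like(x_toy), w_toy)

# Completeness: IG attributions sum (approximately) to f(x) - f(baseline);
# gradient x input carries no such guarantee for a nonlinear model
print(f"sum IG = {ig.sum():.4f}, sum grad*input = {gxi.sum():.4f}, f(x) - f(0) = {target_gap:.4f}")
```

Integrated gradients average the gradient along the straight path from the baseline to the input. That is why its attributions add up to the change in output (the completeness property). Gradient × input only uses the gradient at the input itself.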

## 7. Visualize Separate Explanations

Following the documentation approach (lines 180-201), we'll create separate plots for each method.

In [None]:
# Visualize explanations
fig, axs = plt.subplots(len(methods) + 1, 1, figsize=(15, 3*(len(methods) + 1)))

# Original signal
axs[0].plot(ecg_sample)
axs[0].set_title('Original ECG Signal')
axs[0].set_ylabel('Amplitude')
axs[0].grid(True)

# Explanations
for i, method in enumerate(methods):
    if method in explanations:
        # Reshape explanation to 1D
        if isinstance(explanations[method], tf.Tensor):
            expl = explanations[method].numpy()[0, :, 0]
        else:
            expl = explanations[method][0, :, 0]
        
        # Plot explanation
        axs[i+1].plot(expl)
        axs[i+1].set_title(f'Method: {method}')
        axs[i+1].set_ylabel('Attribution')
        axs[i+1].grid(True)
    else:
        axs[i+1].text(0.5, 0.5, 'Method failed', transform=axs[i+1].transAxes, ha='center')

plt.tight_layout()
plt.show()
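The `grad_cam` curve is coarser than the gradient-based ones. It is computed on the output of `conv1d_2`, which has 243 time steps for a 1000-sample input (1000 → 996 → 498 → 494 → 247 → 243), and is then stretched back to the input length.

Below is a NumPy sketch of the standard Grad-CAM recipe over one time axis. It assumes the activations and gradients have already been extracted, for example with `tf.GradientTape`. SignXAI's own normalization and upsampling may differ, and plain linear interpolation ignores receptive-field offsets.

```python
import numpy as np

def grad_cam_1d(activations, gradients, output_length):
    """Grad-CAM over one time axis.

    activations, gradients: arrays of shape (time_steps, channels) holding a conv
    layer's output and the gradient of the target logit with respect to it.
    Returns a non-negative map linearly interpolated to `output_length` steps.
    """
    weights = gradients.mean(axis=0)              # global-average-pooled gradient per channel
    cam = np.maximum(activations @ weights, 0.0)  # channel-weighted sum, then ReLU
    src = np.linspace(0.0, 1.0, cam.shape[0])     # conv-layer time grid
    dst = np.linspace(0.0, 1.0, output_length)    # input time grid
    return np.interp(dst, src, cam)

# Toy check: a single active position in the middle of a 243-step feature map
acts_toy = np.zeros((243, 2))
acts_toy[121, 0] = 1.0
grads_toy = np.ones((243, 2))
cam_toy = grad_cam_1d(acts_toy, grads_toy, output_length=1000)
print(cam_toy.shape, int(np.argmax(cam_toy)))
```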

## 8. Overlay Visualizations

Following the documentation approach (lines 202-222), we'll create overlay visualizations.

In [None]:
# Alternative visualization: Overlay explanation on signal
plt.figure(figsize=(15, 2.5 * len(methods)))

for i, method in enumerate(methods):
    if method not in explanations:
        continue
        
    plt.subplot(len(methods), 1, i+1)
    
    # Original signal
    plt.plot(ecg_sample, 'gray', alpha=0.5, label='ECG Signal')
    
    # Explanation
    if isinstance(explanations[method], tf.Tensor):
        expl = explanations[method].numpy()[0, :, 0]
    else:
        expl = explanations[method][0, :, 0]
        
    expl_norm = (expl - expl.min()) / (expl.max() - expl.min()) if expl.max() > expl.min() else expl
    plt.plot(expl_norm, 'r', label='Attribution')
    
    plt.title(f'Method: {method}')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()
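Min-max scaling, as used in the overlay above, maps the most negative attribution to 0. Zero relevance then no longer sits at zero, and the sign of the attribution is lost. A sign-preserving alternative divides by the largest absolute value, which suits diverging colormaps such as `'seismic'` with limits of -1 and 1. The sketch below compares both; this tutorial does not say whether `normalize_ecg_relevancemap` works this way.

```python
import numpy as np

def normalize_min_max(expl):
    """Scale to [0, 1]; the same formula as the overlay cell above."""
    expl = np.asarray(expl, dtype=float)
    span = expl.max() - expl.min()
    return (expl - expl.min()) / span if span > 0 else expl

def normalize_max_abs(expl):
    """Scale to [-1, 1] by the largest absolute value, keeping 0 at 0 and the sign intact."""
    expl = np.asarray(expl, dtype=float)
    scale = np.abs(expl).max()
    return expl / scale if scale > 0 else expl

toy_attr = np.array([-2.0, 0.0, 1.0])
print(normalize_min_max(toy_attr))  # zero relevance is mapped to 2/3
print(normalize_max_abs(toy_attr))  # zero relevance stays at 0
```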

## 9. Advanced ECG Component Analysis

Following the documentation approach for advanced analysis, we'll analyze specific ECG components.

In [None]:
# Define characteristic ECG components (these would be expert-identified in real applications)
p_wave_region = slice(350, 370)
qrs_complex_region = slice(400, 420)
t_wave_region = slice(450, 480)

# Calculate the mean attribution for each region using LRP-SIGN method
if 'lrpsign_z' in explanations:
    if isinstance(explanations['lrpsign_z'], tf.Tensor):
        lrpsign_expl = explanations['lrpsign_z'].numpy()[0, :, 0]
    else:
        lrpsign_expl = explanations['lrpsign_z'][0, :, 0]
    
    p_wave_attr = np.mean(np.abs(lrpsign_expl[p_wave_region]))
    qrs_complex_attr = np.mean(np.abs(lrpsign_expl[qrs_complex_region]))
    t_wave_attr = np.mean(np.abs(lrpsign_expl[t_wave_region]))
    
    # Visualize with region highlighting
    plt.figure(figsize=(15, 6))
    
    # Plot original ECG
    plt.subplot(2, 1, 1)
    plt.plot(ecg_sample)
    
    # Highlight ECG components
    plt.axvspan(350, 370, color='blue', alpha=0.2, label='P-wave')
    plt.axvspan(400, 420, color='red', alpha=0.2, label='QRS Complex')
    plt.axvspan(450, 480, color='green', alpha=0.2, label='T-wave')
    
    plt.title('ECG Signal with Components')
    plt.legend()
    plt.grid(True)
    
    # Plot explanation with component attribution
    plt.subplot(2, 1, 2)
    plt.plot(lrpsign_expl)
    
    # Highlight attribution in ECG components
    plt.axvspan(350, 370, color='blue', alpha=0.2)
    plt.axvspan(400, 420, color='red', alpha=0.2)
    plt.axvspan(450, 480, color='green', alpha=0.2)
    
    # Add component attribution values
    plt.text(360, max(lrpsign_expl), f'P-wave: {p_wave_attr:.4f}', 
             horizontalalignment='center', backgroundcolor='white')
    plt.text(410, max(lrpsign_expl), f'QRS: {qrs_complex_attr:.4f}', 
             horizontalalignment='center', backgroundcolor='white')
    plt.text(465, max(lrpsign_expl), f'T-wave: {t_wave_attr:.4f}', 
             horizontalalignment='center', backgroundcolor='white')
    
    plt.title('LRP-SIGN Attribution')
    plt.grid(True)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nüîç ECG Component Analysis:")
    print(f"  P-wave attribution: {p_wave_attr:.4f}")
    print(f"  QRS complex attribution: {qrs_complex_attr:.4f}")
    print(f"  T-wave attribution: {t_wave_attr:.4f}")
else:
    print("⚠️ LRP-SIGN method not available for component analysis")
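The mean absolute attribution per window measures relevance per sample. A second number shows how much of the whole explanation each component accounts for: the window's share of the total absolute attribution. That share also makes methods with different attribution scales comparable. Below is a sketch, checked on a toy map; the helper `region_attribution` is hypothetical.

```python
import numpy as np

def region_attribution(expl, regions):
    """Return {name: (mean |attribution|, share of total |attribution|)} per region.

    `regions` maps a component name to a slice of time steps.
    """
    abs_expl = np.abs(np.asarray(expl, dtype=float))
    total = abs_expl.sum()
    summary = {}
    for name, region in regions.items():
        in_region = abs_expl[region]
        share = in_region.sum() / total if total > 0 else 0.0
        summary[name] = (in_region.mean(), share)
    return summary

# Toy map: 2 units of |relevance| in each of two windows, 4 in total
toy_expl = np.zeros(10)
toy_expl[2:4] = [1.0, -1.0]
toy_expl[6] = 2.0
print(region_attribution(toy_expl, {"A": slice(2, 4), "B": slice(6, 8)}))
```

On the tutorial data, `region_attribution(lrpsign_expl, {'P-wave': p_wave_region, 'QRS': qrs_complex_region, 'T-wave': t_wave_region})` returns both numbers for each component.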

## 10. 12-Lead ECG Visualization (if ECG utilities available)

Using the utilities from the quickstart files to create professional ECG visualizations.

In [None]:
# Try to load real ECG data and create 12-lead visualization
try:
    # Load real ECG data from repository
    record_id = '03509_hr'
    ecg_src_dir = os.path.join(project_root, 'examples', 'data', 'timeseries', '')
    
    print(f"Loading ECG data for record: {record_id}...")
    ecg_data = load_and_preprocess_ecg(
        record_id=record_id,
        src_dir=ecg_src_dir,
        ecg_filters=['BWR', 'BLA', 'AC50Hz', 'LP40Hz'],
        subsampling_window_size=3000,
        subsample_start=0
    )
    
    if ecg_data is not None:
        print(f"✅ ECG data loaded: {ecg_data.shape}")
        
        # The model's Flatten/Dense layers fix the input length at 1000 samples,
        # so keep the first 1000 samples (2 s at 500 Hz) of the 3000-sample record
        ecg_data = ecg_data[:1000]
        
        # Use a single lead for model prediction; shape (timesteps, channels) = (1000, 1)
        input_data = ecg_data[:, 0:1].astype(np.float32)
        batch_input = np.expand_dims(input_data, 0)  # (1, 1000, 1), as in Section 6
        
        # Get prediction and explanation for the real ECG
        # (the model was trained on synthetic data, so this only demonstrates the pipeline)
        predictions = model.predict(batch_input, verbose=0)
        predicted_idx = int(np.argmax(predictions[0]))
        
        # Calculate the explanation with one method
        explanation = explain(
            model=model,
            x=batch_input,
            method_name="gradient_x_input",
            target_class=predicted_idx
        )
        
        # Process for visualization
        if isinstance(explanation, tf.Tensor):
            explanation_np = explanation.numpy()
        else:
            explanation_np = explanation
        
        # Handle shape processing for relevance map
        if explanation_np.ndim == 1:
            relevance_map = explanation_np.reshape(-1, 1)
        elif explanation_np.ndim == 2:
            if explanation_np.shape[0] == 1:
                relevance_map = explanation_np.transpose()
            else:
                relevance_map = explanation_np
        else:
            relevance_map = explanation_np.reshape(-1, 1)
        
        # Expand to 12 leads for visualization
        if relevance_map.shape[1] == 1 and ecg_data.shape[1] == 12:
            relevance_map = np.tile(relevance_map, (1, 12))
        
        # Normalize
        normalized_relevance = normalize_ecg_relevancemap(relevance_map)
        
        # Format for visualization
        ecg_for_visual = ecg_data.transpose()
        expl_for_visual = normalized_relevance.transpose()
        
        # Create 12-lead visualization
        plot_ecg(
            ecg=ecg_for_visual,
            explanation=expl_for_visual,
            sampling_rate=500,
            title=f"TensorFlow XAI: gradient_x_input on {record_id}",
            show_colorbar=True,
            cmap='seismic',
            bubble_size=30,
            line_width=1.0,
            style='fancy',
            save_to=None,
            clim_min=-1,
            clim_max=1,
            colorbar_label='Relevance',
            shape_switch=False
        )
        
        print("\n✅ 12-lead ECG visualization created!")
    
except Exception as e:
    print(f"⚠️ Could not create 12-lead visualization: {e}")
    print("This requires the complete repository with ECG data and utilities.")
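The model accepts exactly 1000 time steps, while the record above is loaded with `subsampling_window_size=3000`. One way to use the whole recording is to split it into consecutive 1000-sample windows (2 s each at 500 Hz) and to predict and explain each window separately. This is a minimal sketch; the helper `split_into_windows` is hypothetical.

```python
import numpy as np

def split_into_windows(signal, window=1000, step=1000):
    """Cut a (time, leads) array into an array of shape (n_windows, window, leads).

    A trailing remainder shorter than `window` is dropped.
    """
    signal = np.asarray(signal)
    if signal.shape[0] < window:
        raise ValueError(f"signal has {signal.shape[0]} samples, need at least {window}")
    starts = range(0, signal.shape[0] - window + 1, step)
    return np.stack([signal[s:s + window] for s in starts])

# A 3000-sample, single-lead recording gives three non-overlapping windows
toy_record = np.random.default_rng(0).normal(size=(3000, 1))
print(split_into_windows(toy_record).shape)  # (3, 1000, 1)
```

Passing the resulting array to `model.predict` classifies all windows at once. This tutorial does not show whether `explain` accepts several samples in one call, so explaining one window at a time (`windows[i:i + 1]`) is the safer assumption.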

## 11. Method Comparison Analysis

Let's compare attribution across different methods for ECG components.

In [None]:
# Compare attribution across methods
methods_to_compare = ['gradient', 'gradient_x_input', 'lrp_z', 'lrpsign_z']
components = ['P-wave', 'QRS Complex', 'T-wave']
regions = [p_wave_region, qrs_complex_region, t_wave_region]

# Calculate attribution for each method and component
component_attribution = {}
for method in methods_to_compare:
    if method in explanations:
        if isinstance(explanations[method], tf.Tensor):
            expl = explanations[method].numpy()[0, :, 0]
        else:
            expl = explanations[method][0, :, 0]
        component_attribution[method] = [np.mean(np.abs(expl[region])) for region in regions]
    else:
        component_attribution[method] = [0, 0, 0]

# Visualize component attribution comparison
plt.figure(figsize=(12, 6))

x_pos = np.arange(len(components))  # bar positions; a new name keeps the input sample `x` intact
width = 0.2
offsets = np.linspace(-0.3, 0.3, len(methods_to_compare))

for i, method in enumerate(methods_to_compare):
    plt.bar(x_pos + offsets[i], component_attribution[method], width, label=method)

plt.xlabel('ECG Component')
plt.ylabel('Mean Absolute Attribution')
plt.title('Attribution Comparison Across Methods')
plt.xticks(x_pos, components)
plt.legend()
plt.grid(True, axis='y')

plt.tight_layout()
plt.show()

print("\nüìä Method Comparison Summary:")
for method in methods_to_compare:
    if method in component_attribution:
        attrs = component_attribution[method]
        print(f"  {method:20} - P: {attrs[0]:.3f}, QRS: {attrs[1]:.3f}, T: {attrs[2]:.3f}")
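The bar chart compares methods on three hand-picked windows only. A rank correlation is a simple way to compare two attribution maps over the whole signal, regardless of their scale. Below is a NumPy sketch of Spearman's correlation. It breaks ties by order, which is acceptable for continuous attributions; `scipy.stats.spearmanr` handles ties properly.

```python
import numpy as np

def ranks(values):
    """Rank positions 0..n-1; ties are broken by order of appearance."""
    values = np.asarray(values, dtype=float)
    order = np.argsort(values, kind="stable")
    result = np.empty(len(values))
    result[order] = np.arange(len(values))
    return result

def spearman(a, b):
    """Spearman rank correlation between two 1D attribution maps of equal length."""
    return float(np.corrcoef(ranks(a), ranks(b))[0, 1])

toy_a = np.array([0.1, 0.5, 0.3, 0.9])
print(spearman(toy_a, toy_a ** 3))  # same ordering
print(spearman(toy_a, -toy_a))      # reversed ordering
```

Take two flattened maps from `explanations`, here given the placeholder names `expl_a` and `expl_b`. Then `spearman(np.abs(expl_a), np.abs(expl_b))` is close to 1 when the methods rank time steps similarly.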

## 12. Summary and Key Insights

### What we've learned:

1. **TensorFlow Time Series Models**: How to build and train CNN models for ECG classification
2. **Multiple XAI Methods**: Different approaches provide different insights:
   - **Gradient**: Shows instantaneous importance
   - **Gradient × Input**: Emphasizes strong signal regions
   - **Integrated Gradients**: Provides theoretically grounded attributions
   - **Grad-CAM**: Adapts convolutional attention for time series
   - **LRP methods**: Show layer-wise relevance propagation
   - **LRP-SIGN**: The SIGN method for enhanced attribution

3. **Visualization Techniques**: 
   - Separate plots for clear comparison
   - Overlay visualizations for intuitive understanding
   - Component-specific analysis for medical insights
   - 12-lead medical visualizations for clinical applications

4. **Time Series XAI**: Understanding temporal dependencies and pattern recognition in sequential data
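The LRP methods redistribute the output score backwards, layer by layer. For one dense layer `z = a @ W + b` (weights in Keras's `(n_in, n_out)` layout), the standard LRP-ε rule gives input `i` the share `a[i] * W[i, j] / (z[j] + ε·sign(z[j]))` of each output relevance `R[j]`. With ε = 0 this reduces to the basic z-rule, which `lrp_z` presumably denotes. The `_0_1` suffix in `lrp_epsilon_0_1` presumably sets ε = 0.1. This sketch covers a single dense layer on a toy example and is not SignXAI's implementation.

```python
import numpy as np

def lrp_epsilon_dense(a, W, b, R_out, eps=0.1):
    """Redistribute relevance R_out from the outputs of z = a @ W + b to its inputs.

    a: (n_in,) activations, W: (n_in, n_out), b: (n_out,), R_out: (n_out,).
    R_in[i] = a[i] * sum_j W[i, j] * R_out[j] / (z[j] + eps * sign(z[j]))
    """
    z = a @ W + b
    denom = z + eps * np.where(z >= 0, 1.0, -1.0)  # stabilizer keeps denominators away from 0
    return a * (W @ (R_out / denom))

a_toy = np.array([1.0, 2.0])
W_toy = np.array([[1.0, 0.0], [1.0, 1.0]])
b_toy = np.zeros(2)
R_toy = np.array([1.0, 1.0])

print(lrp_epsilon_dense(a_toy, W_toy, b_toy, R_toy, eps=0.0).sum())  # conserves total relevance
print(lrp_epsilon_dense(a_toy, W_toy, b_toy, R_toy, eps=0.1).sum())  # epsilon absorbs a little
```

Without bias and with ε = 0, the input relevances sum exactly to the output relevance. A positive ε absorbs part of it, which damps contributions from weakly activated outputs and makes explanations less noisy.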

### Clinical Insights:
- Check whether the methods concentrate attribution on the QRS complex: the synthetic classes differ most strongly around it, and QRS morphology is central to clinical ECG reading
- Different methods show varying sensitivity to P-waves and T-waves
- Component-specific analysis helps validate model focus against medical knowledge

### Next Steps:
- Try different ECG records to see consistency across samples
- Experiment with different model architectures
- Compare with PyTorch implementation
- Apply to your own time series classification problems

### Method Selection Guide:
- **`gradient`**: Fast, good for real-time analysis
- **`gradient_x_input`**: Better for highlighting strong features
- **`integrated_gradients`**: Most theoretically sound, but slower
- **`grad_cam`**: Good for understanding convolutional focus
- **`lrp_*`**: Good for layer-wise understanding
- **`lrpsign_z`**: Enhanced attribution with sign information

This tutorial demonstrates how SignXAI2 can provide valuable insights into time series model decisions, particularly important in medical applications where understanding AI reasoning is critical for clinical adoption.