# SignXAI2 with PyTorch - ECG Time Series Explanation

This tutorial shows how to use SignXAI2 to explain the predictions of a PyTorch model on ECG (electrocardiogram) time series. We'll work with a sample ECG recording and a 1D convolutional model for detecting cardiac conditions.

## Setup Requirements

**Important**: SignXAI2 requires Python 3.9 or 3.10. Python 3.11 and later are not supported.

Since you're running this tutorial, you should already have cloned the signxai2 repository. From the repository root directory:

### Using conda:
```bash
# Create environment with Python 3.10
conda create -n signxai2 python=3.10
conda activate signxai2

# Install PyTorch dependencies only
pip install -r requirements/common.txt -r requirements/pytorch.txt

# Download models and example data
git lfs install
bash ./prepare.sh
```

### Using venv:
```bash
# Create virtual environment
python3.10 -m venv signxai2_env
source signxai2_env/bin/activate  # On Windows: signxai2_env\Scripts\activate

# Install PyTorch dependencies only
pip install -r requirements/common.txt -r requirements/pytorch.txt

# Download models and example data
git lfs install
bash ./prepare.sh
```
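
Once the environment is active, a quick check can confirm that the interpreter matches the version requirement stated above. This is a minimal sketch for illustration; the helper name `is_supported_python` is hypothetical and not part of SignXAI2.

```python
import sys

# (major, minor) pairs this tutorial lists as supported.
SUPPORTED_VERSIONS = {(3, 9), (3, 10)}

def is_supported_python(version_info=sys.version_info):
    """Return True if the (major, minor) version is one the tutorial lists."""
    return (version_info[0], version_info[1]) in SUPPORTED_VERSIONS

current = f"{sys.version_info[0]}.{sys.version_info[1]}"
print(f"Python {current} supported: {is_supported_python()}")
```

If this prints `False`, recreate the environment with Python 3.10 before installing the requirements.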

## Overview

We'll cover:
1. Loading ECG data and a PyTorch model
2. Pre-processing ECG signals
3. Applying explainability methods to time series data using the unified API
4. Visualizing the regions of the ECG that influenced model predictions

Let's get started!

## 1. Import Libraries

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

# SignXAI2 unified API imports
from signxai import explain, list_methods

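# Import utilities for ECG data handling.
# Notebooks do not define __file__, so abspath("__file__") resolves against the
# current working directory; the utils directory is expected one level above it.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath("__file__")), "..", "utils"))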
try:
    from ecg.explainability import normalize_ecg_relevancemap
except ImportError:
    # Define a simple normalization function if the utility is not available
    def normalize_ecg_relevancemap(relevancemap):
        """Normalize relevance map to [0, 1] range."""
        if relevancemap.max() > relevancemap.min():
            return (relevancemap - relevancemap.min()) / (relevancemap.max() - relevancemap.min())
        return relevancemap
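
The fallback above is plain min-max scaling. As a small worked example of what it does to a relevance map, here is a standalone sketch; the name `minmax_normalize` is hypothetical and chosen only to avoid clashing with the tutorial's function.

```python
import numpy as np

def minmax_normalize(relevance):
    """Scale values linearly so the minimum maps to 0 and the maximum to 1."""
    relevance = np.asarray(relevance, dtype=float)
    lo, hi = relevance.min(), relevance.max()
    if hi > lo:
        return (relevance - lo) / (hi - lo)
    # Constant input: there is no range to scale, so return it unchanged
    return relevance

example = minmax_normalize([-1.0, 0.0, 3.0])
print(example)  # values 0.0, 0.25, 1.0
```

Note that a constant map is returned as-is, so its values are not guaranteed to lie in [0, 1].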

## 2. Set Up Data Paths

In [None]:
# Set up paths
_THIS_DIR = os.path.dirname(os.path.abspath("__file__"))
DATA_DIR = os.path.realpath(os.path.join(_THIS_DIR, "..", "..", "data"))
TIMESERIES_DIR = os.path.join(DATA_DIR, "timeseries")
ECG_MODEL_PATH = os.path.join(DATA_DIR, "models", "pytorch", "ECG", "ecg_ported_weights.pt")

# Check that model file exists
if not os.path.exists(ECG_MODEL_PATH):
    print(f"ECG model not found at {ECG_MODEL_PATH}")
    print("We'll create a demo model for this tutorial")
else:
    print(f"Found ECG model at {ECG_MODEL_PATH}")

## 3. Define a PyTorch ECG Model

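Let's first define a PyTorch model architecture for ECG processing: a 1D convolutional neural network for time series data. The number in the comment after each pooling layer is the sequence length at that point for a 2000-sample input, which is why the first fully connected layer expects `256 * 25` features.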

In [None]:
class ECGModel(nn.Module):
    """Demo ECG model for tutorial - matches the structure of ECG_PyTorch"""
    def __init__(self, input_channels=1, num_classes=1):
        super(ECGModel, self).__init__()
        self.features = nn.Sequential(
            # First convolutional block
            nn.Conv1d(input_channels, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),  # 1000
            
            # Second convolutional block
            nn.Conv1d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),  # 250
            
            # Third convolutional block
            nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),  # 125
            
            # Fourth convolutional block
            nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=5, stride=5),  # 25
        )
        
        # Flatten and fully connected layers
        self.fc = nn.Sequential(
            nn.Linear(256 * 25, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),  # Binary classification
            nn.Sigmoid()  # Use sigmoid for binary classification
        )
        
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc(x)
        return x
    
    # Convenience method to get the last convolutional layer (for GradCAM)
    def get_last_conv_layer(self):
        return self.features[12]  # The 4th Conv1d layer (index 9 is a BatchNorm1d)
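
The sequence lengths noted in the pooling comments can be checked by hand. The sketch below mirrors the order of the layers in `features` as plain tuples (an illustrative stand-in, not the real modules) and applies the standard output-length formula for `Conv1d` and `MaxPool1d` with dilation 1. It also locates the index of the last convolutional layer in that sequence, which is the layer Grad-CAM typically attaches to.

```python
def out_length(length, kernel_size, stride=1, padding=0):
    """Output length of Conv1d/MaxPool1d (dilation 1): floor((L + 2p - k) / s) + 1."""
    return (length + 2 * padding - kernel_size) // stride + 1

# (kind, kernel_size, stride, padding) for each entry of `features`, in order.
# BatchNorm1d and ReLU leave the sequence length unchanged.
LAYERS = [
    ("conv", 5, 1, 2), ("bn",), ("relu",), ("pool", 2, 2, 0),
    ("conv", 5, 1, 2), ("bn",), ("relu",), ("pool", 4, 4, 0),
    ("conv", 3, 1, 1), ("bn",), ("relu",), ("pool", 2, 2, 0),
    ("conv", 3, 1, 1), ("bn",), ("relu",), ("pool", 5, 5, 0),
]

def trace_lengths(input_length):
    """Return the sequence length after every layer in LAYERS."""
    lengths = []
    length = input_length
    for layer in LAYERS:
        if layer[0] in ("conv", "pool"):
            _, kernel_size, stride, padding = layer
            length = out_length(length, kernel_size, stride, padding)
        lengths.append(length)
    return lengths

lengths = trace_lengths(2000)
last_conv_index = max(i for i, layer in enumerate(LAYERS) if layer[0] == "conv")

print(lengths[3], lengths[7], lengths[11], lengths[15])  # 1000 250 125 25
print(last_conv_index)                                   # 12
```

A 2000-sample input therefore reaches the classifier as 256 channels of length 25. Inputs of a different length would not match `nn.Linear(256 * 25, 512)`.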

## 4. Load Model and ECG Data

In [None]:
# Import the actual ECG model if available
sys.path.append(os.path.join(DATA_DIR, "models", "pytorch", "ECG"))

# Load the PyTorch ECG model
try:
    from ecg_model import ECG_PyTorch
    # Initialize the model architecture
    ecg_model = ECG_PyTorch(input_channels=1, num_classes=3)
    # Load the pre-trained weights
    ecg_model.load_state_dict(torch.load(ECG_MODEL_PATH, map_location=torch.device('cpu')))
    print("Loaded pre-trained ECG model weights")
except Exception as e:
    print(f"Could not load saved model: {e}\nInitializing demo model instead")
    # Use the demo model defined above
    ecg_model = ECGModel()

# Set model to evaluation mode
ecg_model.eval()

# Print the model class name
print(f"Model type: {type(ecg_model).__name__}")

In [None]:
# Load a sample ECG
try:
    # Try to load from saved numpy file
    ecg_sample_path = os.path.join(TIMESERIES_DIR, "ecg_sample.npy")
    ecg_data = np.load(ecg_sample_path)
    print(f"Loaded ECG sample from {ecg_sample_path}")
except Exception as e:
    print(f"Could not load ECG sample: {e}\nGenerating synthetic data instead")
    # Generate synthetic ECG data if loading fails
    # This is simplified and not realistic, just for illustration
    from scipy import signal
    t = np.linspace(0, 10, 2000)
    ecg_data = signal.square(2 * np.pi * 1.2 * t, duty=0.3) + np.sin(2 * np.pi * 3 * t)
    ecg_data += np.random.normal(0, 0.1, size=ecg_data.shape)

# Plot the ECG data
plt.figure(figsize=(15, 3))
plt.plot(ecg_data)
plt.title("ECG Sample")
plt.xlabel("Time (samples)")
plt.ylabel("Amplitude")
plt.grid(True)
plt.show()

# Preprocess for PyTorch model (add channel dimension and batch dimension)
ecg_tensor = torch.tensor(ecg_data, dtype=torch.float32).view(1, 1, -1)
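
The demo model's fully connected layers assume a 2000-sample input, so a recording of a different length needs to be cropped or padded before it is reshaped to `[batch, channel, time]`. The following is a minimal sketch of one way to do that with NumPy. The helper names `fit_length` and `to_model_input` are hypothetical, and zero-padding at the end is only one of several reasonable choices.

```python
import numpy as np

def fit_length(signal, target_len=2000):
    """Crop or zero-pad a 1D signal at the end so it has exactly target_len samples."""
    signal = np.asarray(signal, dtype=np.float32)
    if signal.shape[0] >= target_len:
        return signal[:target_len]
    return np.pad(signal, (0, target_len - signal.shape[0]))

def to_model_input(signal, target_len=2000):
    """Return an array shaped [batch=1, channel=1, time=target_len]."""
    return fit_length(signal, target_len)[np.newaxis, np.newaxis, :]

short_signal = np.sin(np.linspace(0, 10, 1500))
batch = to_model_input(short_signal)
print(batch.shape)  # (1, 1, 2000)
```

The resulting array can be handed to `torch.from_numpy(batch)` to obtain a tensor with the same layout as `ecg_tensor`.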

## 5. Get Model Prediction

In [None]:
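# Get model prediction
with torch.no_grad():
    prediction = ecg_model(ecg_tensor)

if prediction.numel() == 1:
    # Demo model: one sigmoid output, read as the probability of an abnormal ECG
    prediction_probability = prediction.item()
    print(f"Model prediction: {prediction_probability:.4f}")
    print(f"Diagnosis: {'Positive' if prediction_probability > 0.5 else 'Negative'} for abnormal ECG")
else:
    # Pre-trained model built with num_classes=3: .item() would fail on several outputs
    prediction_probability = prediction.max().item()  # highest class score
    print(f"Predicted class index: {prediction.argmax(dim=1).item()}, score: {prediction_probability:.4f}")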
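
The two models in this tutorial produce differently shaped outputs: the demo model returns one sigmoid value, while the pre-trained model is built with `num_classes=3`. As an illustration of how a class index can be derived in both cases, here is a NumPy sketch. The helper `predicted_class` is hypothetical and only mirrors the thresholding logic used above.

```python
import numpy as np

def predicted_class(output, threshold=0.5):
    """Map one model output row to a class index.

    A single value is treated as a sigmoid probability and thresholded;
    several values are treated as per-class scores and the largest wins.
    """
    scores = np.asarray(output, dtype=float).reshape(-1)
    if scores.size == 1:
        return int(scores[0] > threshold)
    return int(np.argmax(scores))

print(predicted_class([0.73]))            # 1
print(predicted_class([0.1, 2.3, -0.4]))  # 1
```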

## 6. Generate Explanations for the ECG Signal Using the Unified API

In [None]:
# Define the explainability methods to use with the unified API
methods_to_test = [
    'gradient',
    'gradient_x_input',
    'grad_cam',
    'guided_backprop',
    'lrp_z',
    'lrp_epsilon_0_5'
]

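# Method-specific parameters
method_params = {
    # Name of the convolutional layer Grad-CAM should attach to.
    # Note: in the demo ECGModel above the last Conv1d is 'features.12'
    # ('features.9' there is a BatchNorm1d); adjust this to match the loaded model.
    'grad_cam': {'target_layer': 'features.9'}
}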

# Storage for explanations
explanations = {}

In [None]:
# Generate explanations using the unified API
# Explain the predicted class: output index 0 for a single-output (sigmoid) model,
# otherwise the index of the highest class score
target_class = 0 if prediction.shape[-1] == 1 else int(prediction.argmax(dim=-1).item())

for method_name in methods_to_test:
    print(f"Generating {method_name} explanation...")
    
    # Get method-specific parameters
    params = method_params.get(method_name, {})
    
    # Use the unified explain API
    explanation = explain(
        model=ecg_model,
        x=ecg_tensor,
        method_name=method_name,
        target_class=target_class,
        **params
    )
    
    # Convert to numpy if it's a tensor
    if torch.is_tensor(explanation):
        explanation = explanation.detach().cpu().numpy()
    
    # Only use positive values for visualization
    explanation = np.maximum(explanation, 0)
    
    # Normalize the explanation for visualization
    explanation = normalize_ecg_relevancemap(explanation)
    
    # Add to explanations dictionary
    explanations[method_name] = explanation
    
print("All explanations generated!")

## 7. Visualize ECG Explanations

In [None]:
# Convert explanations to proper shape for visualization
def reshape_for_visualization(explanation):
    """Return a 1D relevance curve with the same length as the ECG signal."""
    # Drop singleton batch/channel dimensions, e.g. [1, 1, time] -> [time]
    expl = np.squeeze(np.asarray(explanation))
    if expl.ndim != 1:
        # Several channels left: keep the first one
        expl = expl.reshape(-1, expl.shape[-1])[0]
    # Grad-CAM maps can be shorter than the input, so resample them linearly
    orig_size = ecg_data.shape[0]
    if expl.shape[0] != orig_size:
        x_orig = np.linspace(0, 1, expl.shape[0])
        x_new = np.linspace(0, 1, orig_size)
        expl = np.interp(x_new, x_orig, expl)
    return expl

# Process explanations
processed_explanations = {}
for method_name, explanation in explanations.items():
    processed_explanations[method_name] = reshape_for_visualization(explanation)
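
Grad-CAM maps are computed at the resolution of a convolutional layer, so they are usually much shorter than the input signal and have to be stretched before they can be overlaid on the ECG. A minimal sketch of that resampling step with linear interpolation follows; `upsample_to_length` is a hypothetical helper name.

```python
import numpy as np

def upsample_to_length(coarse_map, target_len):
    """Linearly interpolate a 1D map onto target_len evenly spaced points."""
    coarse_map = np.asarray(coarse_map, dtype=float)
    x_coarse = np.linspace(0.0, 1.0, coarse_map.shape[0])
    x_target = np.linspace(0.0, 1.0, target_len)
    return np.interp(x_target, x_coarse, coarse_map)

coarse = np.array([0.0, 1.0, 0.0])
fine = upsample_to_length(coarse, 5)
print(fine)  # values 0.0, 0.5, 1.0, 0.5, 0.0
```

Linear interpolation keeps the first and last values fixed and never exceeds the range of the coarse map, which is convenient when the map has already been normalized.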

In [None]:
# Filter explanations for better visualization (optional)
posthresh = 0.2  # Threshold for filtering low values
cmap_adjust = 0.3  # Adjust colormap for better visibility

for method_name, explanation in processed_explanations.items():
    # Filter low values
    filtered_explanation = explanation.copy()
    filtered_explanation[filtered_explanation <= posthresh] = 0
    filtered_explanation[filtered_explanation > posthresh] += cmap_adjust
    
    # Update the explanations
    processed_explanations[method_name] = filtered_explanation
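
To make the effect of the two filtering lines concrete, the sketch below applies the same rule to a few hand-picked values. The function name `threshold_and_boost` is hypothetical; the default `threshold` and `boost` values match `posthresh` and `cmap_adjust` above.

```python
import numpy as np

def threshold_and_boost(relevance, threshold=0.2, boost=0.3):
    """Zero values at or below threshold and add boost to the remaining ones."""
    relevance = np.asarray(relevance, dtype=float)
    keep = relevance > threshold
    return np.where(keep, relevance + boost, 0.0)

filtered = threshold_and_boost([0.1, 0.2, 0.5, 1.0])
print(filtered)  # approximately 0.0, 0.0, 0.8, 1.3
```

Because the boost is added after normalization, the largest values can exceed 1.0. If the plot axis is limited to the normalized range, `np.clip(filtered, 0.0, 1.0)` keeps them within it.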

In [None]:
# Create a figure to visualize all explanations
fig, axes = plt.subplots(len(processed_explanations), 1, figsize=(15, 4*len(processed_explanations)))

# Ensure axes is always a list
if len(processed_explanations) == 1:
    axes = [axes]

for i, (method_name, explanation) in enumerate(processed_explanations.items()):
    ax = axes[i]
    
    # Plot the ECG
    ax.plot(ecg_data, color='blue', alpha=0.7)
    
    # Create a colormap for the explanation
    ax_twin = ax.twinx()
    ax_twin.fill_between(range(len(ecg_data)), 0, explanation, 
                         color='red', alpha=0.5, label='Explanation')
    ax_twin.set_ylim(0, 1.1)
    ax_twin.set_ylabel('Relevance', color='red')
    ax_twin.tick_params(axis='y', labelcolor='red')
    
    # Set title and labels
    ax.set_title(f"{method_name.replace('_', ' ').upper()} Explanation")
    ax.set_xlabel('Time (samples)')
    ax.set_ylabel('ECG Amplitude', color='blue')
    ax.tick_params(axis='y', labelcolor='blue')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Interpretation of ECG Explanations

Let's interpret what we're seeing in the ECG explanations:

- **gradient**: Shows how sensitive the model output is to small changes in each sample of the input ECG signal. Peaks may coincide with key cardiac events (P-waves, QRS complexes, T-waves).

- **gradient_x_input**: Multiplies the gradient element-wise with the input signal, so a sample is emphasized only where the model is sensitive to it and the signal amplitude is non-zero.

- **grad_cam**: Adapts the Grad-CAM technique to 1D time series data. It weights the feature maps of a convolutional layer by their gradients and highlights the temporal regions that contributed most to the classification decision.

- **guided_backprop**: Produces sharper attributions by passing only positive gradients backward through ReLU layers.

- **lrp_z**: Layer-wise Relevance Propagation with the z-rule redistributes the output score backward through the network, layer by layer, in proportion to each input's contribution.

- **lrp_epsilon_0_5**: A variant of LRP that adds a stabilizing term (epsilon=0.5) to the denominator of the propagation rule. This avoids division by values close to zero and damps weak contributions, so the maps differ somewhat from those of lrp_z.
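
The difference between the first two methods is easiest to see on a model small enough to differentiate by hand. The sketch below uses a toy linear model with made-up weights and inputs; it illustrates the definitions only and does not call SignXAI2.

```python
import numpy as np

# Toy linear "model": f(x) = w . x + b (weights and inputs are made up)
w = np.array([0.5, -2.0, 0.0, 1.0])
b = 0.1
x = np.array([2.0, 0.1, 5.0, -1.0])

output = float(w @ x + b)

# For a linear model the gradient of f with respect to x is just w
gradient = w
# Gradient x input weights each sensitivity by the actual input value
gradient_x_input = gradient * x

print(gradient)          # sensitivities: 0.5, -2.0, 0.0, 1.0
print(gradient_x_input)  # contributions: 1.0, -0.2, 0.0, -1.0
```

The second sample has the largest sensitivity but a small contribution because its amplitude is small, while the third sample has a large amplitude but no influence at all. For this linear model the contributions sum to `output - b`.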

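The highlighted regions in the ECG indicate which parts of the signal were most important for the model's prediction. Because negative values are clipped before plotting, the figures show only positive attributions. In cardiology, such regions may correspond to specific cardiac events like abnormal QRS complexes or ST-segment abnormalities, but this should be confirmed against the signal rather than assumed.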
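
The role of the epsilon term in LRP can be shown on a single dense layer. The sketch below implements the textbook epsilon rule in NumPy with made-up numbers. How SignXAI2 implements `lrp_z` and `lrp_epsilon_0_5` internally is not shown in this tutorial, so treat this as an illustration of the rule rather than of the library. The function name `lrp_epsilon_dense` is hypothetical.

```python
import numpy as np

def lrp_epsilon_dense(a, W, b, relevance_out, eps=0.5):
    """Redistribute relevance through a dense layer z = a @ W + b (epsilon rule)."""
    z = a @ W + b                                   # pre-activations, shape [out]
    stabilizer = eps * np.where(z >= 0, 1.0, -1.0)  # pushes the denominator away from zero
    s = relevance_out / (z + stabilizer)            # relevance per unit of pre-activation
    return a * (W @ s)                              # relevance per input, shape [in]

a = np.array([1.0, 2.0])             # layer inputs
W = np.array([[1.0, -1.0],
              [0.5,  1.0]])          # weights, shape [in, out]
b = np.zeros(2)
R_out = np.array([1.0, 0.0])         # relevance arriving at the layer outputs

R_plain = lrp_epsilon_dense(a, W, b, R_out, eps=0.0)
R_eps = lrp_epsilon_dense(a, W, b, R_out, eps=0.5)
print(R_plain)  # values 0.5, 0.5
print(R_eps)    # values 0.4, 0.4
```

With `eps=0` and zero bias the input relevances sum to the output relevance (1.0). With `eps=0.5` part of the relevance is absorbed by the stabilizer, which is why the epsilon variant tends to give more conservative maps.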

In [None]:
# Import the TensorFlow-style API
from signxai.torch_signxai import tf_calculate_relevancemap

# Use the TensorFlow-style API with the PyTorch model
methods_tf_style = ['gradient', 'gradient_x_input', 'guided_backprop', 'lrp_z', 'lrp_epsilon_0_1']
explanations_tf_style = {}

for method in methods_tf_style:
    print(f"Generating explanation using TF-style API: {method}...")
    
    # Pass the model loaded above. If your model ends in a softmax layer,
    # pass a copy with that layer removed instead.
    explanation = tf_calculate_relevancemap(method, ecg_tensor, ecg_model)
    
    # Convert to numpy if it's a tensor
    if torch.is_tensor(explanation):
        explanation = explanation.detach().cpu().numpy()
    
    # Process for visualization
    explanation = np.maximum(explanation, 0)  # Only use positive values
    explanation = normalize_ecg_relevancemap(explanation)
    explanation = reshape_for_visualization(explanation)
    
    # Filter low values
    explanation[explanation <= posthresh] = 0
    explanation[explanation > posthresh] += cmap_adjust
    
    # Store explanation
    explanations_tf_style[method] = explanation

## 9. Conclusion

In this tutorial, we've demonstrated how to use SignXAI2's unified API to explain predictions of ECG models in PyTorch. We've covered:

- Creating a PyTorch model for ECG classification
- Loading and preprocessing ECG time series data
- Applying various explainability methods to ECG signals using the unified `explain()` function
- Visualizing and interpreting explanations

Key advantages of using SignXAI2's unified API:
- Consistent interface across all explanation methods
- Easy switching between different XAI techniques
- Automatic parameter handling for time series data
- Framework-agnostic approach (same API for TensorFlow and PyTorch)

These techniques can help medical professionals see which parts of the ECG signal contribute to the model's predictions, which may increase trust and support model validation for clinical use.

The ability to use the same API for both PyTorch and TensorFlow models makes SignXAI2 a valuable tool for researchers and practitioners working with either framework.