# PyTorch ECG Model Tutorial - SignXAI2 Time Series Explainability

This tutorial demonstrates how to use SignXAI2 to explain PyTorch ECG classification models, following the structure from the official documentation.

## ⚠️ Data Requirements

**The 12-lead visualization in Section 10 requires ECG data from the repository; the other sections run on synthetic data:**

```bash
# Clone the repository to get the complete dataset
git clone https://github.com/your-repo/signxai2.git
```

**Required data files:**
- ECG records in `examples/data/timeseries/`
- Pre-trained models in `examples/data/models/pytorch/ECG/`
- Utility functions for ECG processing

The PyPI package alone doesn't include ECG data to keep the package size manageable.

## What you'll learn:

1. **PyTorch ECG Models**: Building CNN models for time series classification
2. **ECG Data Processing**: Loading and preprocessing 12-lead ECG data
3. **Time Series XAI**: Applying various explainability methods to sequential data
4. **12-Lead Visualization**: Creating professional medical visualizations
5. **Multiple Methods**: Comparing different XAI approaches
6. **Dynamic Method Parsing**: Using the new unified API with parameters embedded in the method name (e.g. `lrp_epsilon_0_1` for LRP-ε with ε = 0.1)

## ECG Signal Components:
Understanding what the model should focus on:
- **P-wave**: Atrial depolarization (0.08-0.12 sec)
- **QRS complex**: Ventricular depolarization (0.06-0.10 sec)  
- **T-wave**: Ventricular repolarization (0.16 sec)
- **PR interval**: AV conduction time (0.12-0.20 sec)
- **QT interval**: Total ventricular activity (0.35-0.44 sec)
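These durations only become sample counts once a sampling rate is fixed. The 12-lead plot later in this notebook passes `sampling_rate=500` to `plot_ecg`; assuming that 500 Hz rate (the synthetic signals generated below have no physical rate), a small helper converts the ranges above into sample windows. The helper name is illustrative and not part of SignXAI2.

```python
# Hypothetical helper: convert ECG component durations (seconds) to sample counts.
# Assumes a 500 Hz sampling rate, the value passed to plot_ecg later in this notebook.
ECG_DURATIONS_S = {
    "P-wave": (0.08, 0.12),
    "QRS complex": (0.06, 0.10),
    "PR interval": (0.12, 0.20),
    "QT interval": (0.35, 0.44),
}

def duration_to_samples(seconds, sampling_rate_hz=500):
    """Round a duration in seconds to the nearest whole number of samples."""
    return round(seconds * sampling_rate_hz)

for name, (low_s, high_s) in ECG_DURATIONS_S.items():
    print(f"{name:12}: {duration_to_samples(low_s)}-{duration_to_samples(high_s)} samples")
```

At 500 Hz, for example, a 0.06-0.10 s QRS complex spans 30-50 samples, whereas the synthetic QRS-like pulse generated below is 20 samples wide.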

## 1. Import Libraries and Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
import warnings
import os
import sys

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

# SignXAI imports
from signxai import explain, list_methods

# Add project root to path for utility imports
current_dir = os.getcwd()
project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import ECG utilities (from quickstart files)
try:
    from utils.ecg_data import load_and_preprocess_ecg
    from utils.ecg_visualization import plot_ecg
    from utils.ecg_explainability import normalize_ecg_relevancemap
    print("✅ ECG utilities loaded successfully!")
except ImportError as e:
    print(f"⚠️ ECG utilities not available: {e}")
    print("Please ensure you have the complete repository with utils/ directory")

print(f"PyTorch version: {torch.__version__}")
print(f"GPU available: {torch.cuda.is_available()}")

# ECG lead names for reference
LEAD_NAMES = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
print(f"\n12-Lead ECG configuration: {', '.join(LEAD_NAMES)}")

## 2. Generate Synthetic ECG Data

Following the documentation approach (lines 75-92), we'll generate synthetic ECG data with characteristic patterns.

In [None]:
# Generate synthetic data (in practice, you would use real ECG datasets)
def generate_synthetic_ecg_data(n_samples=1000, seq_length=1000, n_classes=2):
    X = np.random.randn(n_samples, seq_length, 1) * 0.1
    # Add synthetic patterns for different classes
    for i in range(n_samples):
        if i % n_classes == 0:  # Class 0: Normal
            # Add normal QRS complex
            X[i, 400:420, 0] += np.sin(np.linspace(0, np.pi, 20)) * 1.0
            X[i, 350:370, 0] += np.sin(np.linspace(0, np.pi, 20)) * 0.2  # P wave
            X[i, 450:480, 0] += np.sin(np.linspace(0, np.pi, 30)) * 0.3  # T wave
        else:  # Class 1: Abnormal
            # Add a wider, earlier QRS-like complex with lower amplitude
            X[i, 380:410, 0] += np.sin(np.linspace(0, np.pi, 30)) * 0.8
            # Add an inverted (negative) deflection after it
            X[i, 420:460, 0] -= np.sin(np.linspace(0, np.pi, 40)) * 0.4

    # Create labels
    y = np.array([i % n_classes for i in range(n_samples)])
    return X, y

# Generate data
X_train, y_train = generate_synthetic_ecg_data(800, 1000, 2)
X_test, y_test = generate_synthetic_ecg_data(200, 1000, 2)

print(f"Training data shape: {X_train.shape}")
print(f"Test data shape: {X_test.shape}")

# Visualize two samples per class (labels alternate 0, 1, 0, 1, ...)
plt.figure(figsize=(15, 8))
for i, sample_idx in enumerate([0, 1, 50, 51]):
    plt.subplot(2, 2, i+1)
    plt.plot(X_test[sample_idx, :, 0])
    plt.title(f'Sample {sample_idx}, Class: {y_test[sample_idx]}')
    plt.xlabel('Time (samples)')
    plt.ylabel('Amplitude')
    plt.grid(True)

plt.tight_layout()
plt.show()

## 3. Create CNN Model for ECG Classification

Following the documentation structure (lines 98-113), we'll create a CNN model for ECG classification.

In [None]:
# Create a PyTorch CNN model for ECG classification
class ECG_CNN(nn.Module):
    def __init__(self, seq_length=1000):
        super(ECG_CNN, self).__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=5)
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=5)
        self.pool2 = nn.MaxPool1d(2)
        self.conv3 = nn.Conv1d(32, 64, kernel_size=5)  # Named for Grad-CAM compatibility
        self.pool3 = nn.MaxPool1d(2)
        
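        # Size after three conv (kernel 5, no padding) + max-pool (2) blocks; the outer
        # parentheses make the final // 2 apply to the length before multiplying by 64
        self.flat_size = 64 * ((((seq_length - 4) // 2 - 4) // 2 - 4) // 2)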
        
        self.fc1 = nn.Linear(self.flat_size, 64)
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(64, 2)  # No activation (logits)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        # Conv blocks
        x = self.pool1(self.relu(self.conv1(x)))
        x = self.pool2(self.relu(self.conv2(x)))
        x = self.pool3(self.relu(self.conv3(x)))
        
        # Flatten
        x = x.view(-1, self.flat_size)
        
        # Fully connected
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x

# Create the model
model = ECG_CNN()
print("Model created successfully!")
print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
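`flat_size` depends on floor division at every convolution and pooling stage, which is easy to get wrong by hand. A small sketch using the output-length formula given in the PyTorch `Conv1d` and `MaxPool1d` documentation computes it without building the model; the helper names are illustrative.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of nn.Conv1d / nn.MaxPool1d (formula from the PyTorch docs)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def ecg_cnn_flat_size(seq_length, channels=64):
    """Flattened feature count after ECG_CNN's three conv(k=5) + MaxPool1d(2) blocks."""
    length = seq_length
    for _ in range(3):
        length = conv1d_out_len(length, kernel_size=5)            # Conv1d, stride 1
        length = conv1d_out_len(length, kernel_size=2, stride=2)  # MaxPool1d(2): stride defaults to kernel size
    return channels * length

print(ecg_cnn_flat_size(1000))  # 64 * 121 = 7744
```

The same helper shows why a 3000-sample record cannot be fed to this model unchanged: it yields 23744 features, while `fc1` expects 7744.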

## 4. Train the Model

Following the documentation training procedure (lines 115-127).

In [None]:
# Convert to PyTorch tensors and prepare data loaders
# PyTorch expects [batch, channels, time] format
X_train_pt = torch.tensor(X_train.transpose(0, 2, 1), dtype=torch.float32)
y_train_pt = torch.tensor(y_train, dtype=torch.long)
X_test_pt = torch.tensor(X_test.transpose(0, 2, 1), dtype=torch.float32)
y_test_pt = torch.tensor(y_test, dtype=torch.long)

train_dataset = TensorDataset(X_train_pt, y_train_pt)
test_dataset = TensorDataset(X_test_pt, y_test_pt)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

# Initialize loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

print("Training the model...")
print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
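The `transpose(0, 2, 1)` calls above move the channel axis in front of the time axis, since `nn.Conv1d` expects input shaped (batch, channels, length). A quick NumPy check on a dummy array, with shapes chosen to match this notebook, shows that only the axis order changes:

```python
import numpy as np

# Dummy batch in the generator's layout: (batch, time, channels)
batch = np.zeros((4, 1000, 1), dtype=np.float32)
batch[:, 400:420, 0] = 1.0  # mark a QRS-like window

# Conv1d layout: (batch, channels, time)
batch_ncl = batch.transpose(0, 2, 1)
print(batch.shape, "->", batch_ncl.shape)  # (4, 1000, 1) -> (4, 1, 1000)

# The values are unchanged; only the axis order differs
assert np.array_equal(batch[:, :, 0], batch_ncl[:, 0, :])
```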

In [None]:
# Training loop
epochs = 10
train_losses = []
train_accuracies = []
val_accuracies = []

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()
    
    # Evaluate on the held-out test set (used here as the validation set)
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss/len(train_loader)
    epoch_acc = correct_train/total_train
    val_acc = correct/total
    
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)
    val_accuracies.append(val_acc)
    
    print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, Val Acc: {val_acc:.4f}')

print(f'\n✅ Training completed!')
print(f'Test accuracy: {val_acc:.4f}')

# Plot training history
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Training Accuracy')
plt.plot(val_accuracies, label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

## 5. Prepare Sample for XAI Analysis

Following the documentation approach (lines 129-149), we'll prepare a sample for explanation.

In [None]:
# Select the first test sample (the model trained above is reused)
ecg_sample = X_test[0, :, 0]  # First test sample, remove channel dimension for visualization

# Prepare input with batch dimension - PyTorch expects [batch, channels, time] format
x = torch.tensor(X_test[0:1].transpose(0, 2, 1), dtype=torch.float32)

print(f"Sample shape for XAI: {x.shape}")
print(f"Sample shape for visualization: {ecg_sample.shape}")

# Get prediction
model.eval()
with torch.no_grad():
    preds = model(x)
    predicted_class = torch.argmax(preds, dim=1).item()
    probabilities = torch.nn.functional.softmax(preds, dim=1)
    
class_names = ['Normal', 'Abnormal']

print(f"Predicted class: {predicted_class} ({class_names[predicted_class]})")
print(f"Confidence: {probabilities[0, predicted_class]:.4f}")

# Visualize the sample
plt.figure(figsize=(15, 4))
plt.plot(ecg_sample)
plt.title(f'ECG Sample for XAI Analysis\nPredicted: {class_names[predicted_class]} (confidence: {probabilities[0, predicted_class]:.3f})')
plt.xlabel('Time (samples)')
plt.ylabel('Amplitude')
plt.grid(True)
plt.show()

## 6. Generate Explanations with Multiple Methods

Following the documentation structure (lines 150-179), we'll calculate explanations with different methods including Grad-CAM.

In [None]:
# Calculate explanations with different methods
methods = [
    'gradient',
    'gradient_x_input',
    'integrated_gradients',
    'grad_cam',  # Works for time series too
    'lrp_z',
    'lrp_epsilon_0_1',
    'lrpsign_z'  # The SIGN method
]

explanations = {}
print("Calculating explanations...")

for method in methods:
    try:
        print(f"  Computing: {method}")
        if method == 'grad_cam':
            # For PyTorch grad_cam, we need to specify the target layer
            explanations[method] = explain(
                model=model,
                x=x,
                method_name=method,
                target_class=predicted_class,
                layer_name='conv3'  # Target the last conv layer
            )
        else:
            explanations[method] = explain(
                model=model,
                x=x,
                method_name=method,
                target_class=predicted_class
            )
        print(f"    ✅ Success")
    except Exception as e:
        print(f"    ❌ Failed: {e}")
        # Leave failed methods out of `explanations`; later cells check membership
        # and label them as failed instead of plotting an all-zero attribution

print(f"\n✅ Generated explanations for {len(explanations)} methods")
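The first two methods in the list differ only by a multiplication with the input. For a linear scorer f(x) = w·x + b the gradient is simply w, so a toy NumPy sketch (a stand-in for illustration, not SignXAI2's implementation) makes the difference visible: a sample whose value is exactly zero keeps its gradient but receives no gradient × input relevance.

```python
import numpy as np

rng = np.random.default_rng(0)
toy_x = rng.standard_normal(10)      # toy 1-D "signal"
toy_x[3] = 0.0                       # one sample is exactly zero
toy_w = rng.standard_normal(10)      # weights of a linear scorer
toy_b = 0.5

def toy_score(v):
    return float(toy_w @ v + toy_b)

gradient = toy_w.copy()              # d(score)/dx for a linear model
gradient_x_input = gradient * toy_x  # element-wise product

# The zero-valued sample keeps its gradient but gets no gradient x input relevance
print(gradient[3], gradient_x_input[3])
# For a linear model, gradient x input sums to the score minus the bias
print(np.isclose(gradient_x_input.sum(), toy_score(toy_x) - toy_b))
```

The toy variables are prefixed with `toy_` so that running this cell does not overwrite the input tensor `x`.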

## 7. Visualize Separate Explanations

Following the documentation approach (lines 180-201), we'll create separate plots for each method.

In [None]:
# Visualize explanations
fig, axs = plt.subplots(len(methods) + 1, 1, figsize=(15, 3*(len(methods) + 1)))

# Original signal
axs[0].plot(ecg_sample)
axs[0].set_title('Original ECG Signal')
axs[0].set_ylabel('Amplitude')
axs[0].grid(True)

# Explanations
for i, method in enumerate(methods):
    if method in explanations:
        # Reshape explanation to 1D (PyTorch format is [batch, channel, time])
        if isinstance(explanations[method], torch.Tensor):
            expl = explanations[method].detach().cpu().numpy()[0, 0, :]
        else:
            expl = explanations[method][0, 0, :]
        
        # Plot explanation
        axs[i+1].plot(expl)
        axs[i+1].set_title(f'Method: {method}')
        axs[i+1].set_ylabel('Attribution')
        axs[i+1].grid(True)
    else:
        axs[i+1].text(0.5, 0.5, 'Method failed', transform=axs[i+1].transAxes, ha='center')

plt.tight_layout()
plt.show()
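The tensor-to-NumPy conversion in this cell is repeated in the overlay, component and comparison cells. One way to factor it out is a small helper (hypothetical, not part of SignXAI2) that accepts either a `torch.Tensor` or a NumPy array shaped [batch, channel, time]:

```python
import numpy as np

def to_1d_attribution(expl, batch_index=0, channel=0):
    """Return one batch item / channel of an explanation as a 1-D NumPy array.

    Accepts a torch.Tensor (detected via its .detach method, so this cell does not
    need to import torch) or anything np.asarray understands, shaped [batch, channel, time].
    """
    if hasattr(expl, "detach"):
        expl = expl.detach().cpu().numpy()
    expl = np.asarray(expl)
    if expl.ndim != 3:
        raise ValueError(f"expected [batch, channel, time], got shape {expl.shape}")
    return expl[batch_index, channel, :]

# Usage with a NumPy stand-in for an explanation
demo = np.arange(12, dtype=float).reshape(1, 2, 6)
print(to_1d_attribution(demo))
print(to_1d_attribution(demo, channel=1))
```

In the notebook, `expl = to_1d_attribution(explanations[method])` would replace each `isinstance` branch.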

## 8. Overlay Visualizations

Following the documentation approach (lines 202-222), we'll create overlay visualizations.

In [None]:
# Alternative visualization: Overlay explanation on signal
plt.figure(figsize=(15, 10))

for i, method in enumerate(methods):
    if method not in explanations:
        continue
        
    plt.subplot(len(methods), 1, i+1)
    
    # Original signal
    plt.plot(ecg_sample, 'gray', alpha=0.5, label='ECG Signal')
    
    # Explanation
    if isinstance(explanations[method], torch.Tensor):
        expl = explanations[method].detach().cpu().numpy()[0, 0, :]
    else:
        expl = explanations[method][0, 0, :]
        
    expl_norm = (expl - expl.min()) / (expl.max() - expl.min()) if expl.max() > expl.min() else expl
    plt.plot(expl_norm, 'r', label='Attribution')
    
    plt.title(f'Method: {method}')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()
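Min-max scaling maps the most negative attribution to 0, so in the overlay a zero attribution is drawn at an arbitrary height and negative evidence looks like weak positive evidence. For signed methods such as `gradient` or `lrpsign_z`, dividing by the maximum absolute value keeps zero at zero. A sketch comparing the two (helper names are illustrative; this is not necessarily how SignXAI2 or `normalize_ecg_relevancemap` normalizes):

```python
import numpy as np

def normalize_signed(expl):
    """Scale attributions into [-1, 1] by the largest absolute value, keeping zero at zero."""
    expl = np.asarray(expl, dtype=float)
    max_abs = np.max(np.abs(expl))
    return expl / max_abs if max_abs > 0 else expl

def normalize_minmax(expl):
    """Min-max scaling as in the overlay cell: maps the minimum to 0 and the maximum to 1."""
    expl = np.asarray(expl, dtype=float)
    span = expl.max() - expl.min()
    return (expl - expl.min()) / span if span > 0 else expl

attr = np.array([-2.0, 0.0, 1.0])
print(normalize_minmax(attr))  # the zero attribution maps to 2/3
print(normalize_signed(attr))  # the zero attribution stays at 0
```

With signed scaling, a diverging colormap or a zero reference line makes positive and negative relevance directly readable.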

## 9. Advanced ECG Component Analysis

Following the documentation approach for advanced analysis, we'll analyze specific ECG components.

In [None]:
# Define characteristic ECG components (these would be expert-identified in real applications)
p_wave_region = slice(350, 370)
qrs_complex_region = slice(400, 420)
t_wave_region = slice(450, 480)

# Calculate the mean attribution for each region using LRP-SIGN method
if 'lrpsign_z' in explanations:
    if isinstance(explanations['lrpsign_z'], torch.Tensor):
        lrpsign_expl = explanations['lrpsign_z'].detach().cpu().numpy()[0, 0, :]
    else:
        lrpsign_expl = explanations['lrpsign_z'][0, 0, :]
    
    p_wave_attr = np.mean(np.abs(lrpsign_expl[p_wave_region]))
    qrs_complex_attr = np.mean(np.abs(lrpsign_expl[qrs_complex_region]))
    t_wave_attr = np.mean(np.abs(lrpsign_expl[t_wave_region]))
    
    # Visualize with region highlighting
    plt.figure(figsize=(15, 6))
    
    # Plot original ECG
    plt.subplot(2, 1, 1)
    plt.plot(ecg_sample)
    
    # Highlight ECG components
    plt.axvspan(350, 370, color='blue', alpha=0.2, label='P-wave')
    plt.axvspan(400, 420, color='red', alpha=0.2, label='QRS Complex')
    plt.axvspan(450, 480, color='green', alpha=0.2, label='T-wave')
    
    plt.title('ECG Signal with Components')
    plt.legend()
    plt.grid(True)
    
    # Plot explanation with component attribution
    plt.subplot(2, 1, 2)
    plt.plot(lrpsign_expl)
    
    # Highlight attribution in ECG components
    plt.axvspan(350, 370, color='blue', alpha=0.2)
    plt.axvspan(400, 420, color='red', alpha=0.2)
    plt.axvspan(450, 480, color='green', alpha=0.2)
    
    # Add component attribution values
    plt.text(360, max(lrpsign_expl), f'P-wave: {p_wave_attr:.4f}', 
             horizontalalignment='center', backgroundcolor='white')
    plt.text(410, max(lrpsign_expl), f'QRS: {qrs_complex_attr:.4f}', 
             horizontalalignment='center', backgroundcolor='white')
    plt.text(465, max(lrpsign_expl), f'T-wave: {t_wave_attr:.4f}', 
             horizontalalignment='center', backgroundcolor='white')
    
    plt.title('LRP-SIGN Attribution')
    plt.grid(True)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\n🔍 ECG Component Analysis:")
    print(f"  P-wave attribution: {p_wave_attr:.4f}")
    print(f"  QRS complex attribution: {qrs_complex_attr:.4f}")
    print(f"  T-wave attribution: {t_wave_attr:.4f}")
else:
    print("⚠️ LRP-SIGN method not available for component analysis")
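The regions above are hard-coded to where the synthetic generator places each wave for class 0; class 1 samples put their deflections at 380-460 instead. For real recordings the windows would come from annotations or a delineation algorithm. As a rough stand-in, assumed here purely for illustration, fixed-width windows can be anchored on the largest absolute peak; the offsets are hypothetical and only mirror the synthetic class-0 layout.

```python
import numpy as np

def regions_around_peak(signal, qrs_half_width=10,
                        p_offset=(-60, -40), t_offset=(40, 70)):
    """Build P/QRS/T slices relative to the largest absolute peak (hypothetical offsets).

    Offsets are in samples relative to the peak index; they mirror the synthetic
    class-0 layout (P at 350-370, QRS at 400-420, T at 450-480), not clinical values.
    """
    signal = np.asarray(signal)
    peak = int(np.argmax(np.abs(signal)))
    n = len(signal)

    def clipped(start, stop):
        # Clamp both ends into [0, n] so windows near the edges never wrap around
        return slice(min(max(start, 0), n), min(max(stop, 0), n))

    return {
        "P-wave": clipped(peak + p_offset[0], peak + p_offset[1]),
        "QRS Complex": clipped(peak - qrs_half_width, peak + qrs_half_width),
        "T-wave": clipped(peak + t_offset[0], peak + t_offset[1]),
    }

# Usage on a QRS-like pulse placed as in the synthetic generator
demo_signal = np.zeros(1000)
demo_signal[400:420] = np.sin(np.linspace(0, np.pi, 20))
print(regions_around_peak(demo_signal))
```

In the notebook, `regions_around_peak(ecg_sample)` could replace the hard-coded slices, with the caveat that a single-peak heuristic breaks down on noisy or multi-beat recordings.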

## 10. 12-Lead ECG Visualization (if ECG utilities available)

Using the utilities from the quickstart files to create professional ECG visualizations. The model was trained on synthetic single-lead data, so its prediction on this real record is not clinically meaningful; the cell demonstrates the explanation and plotting pipeline.

In [None]:
# Try to load real ECG data and create 12-lead visualization
try:
    # Load real ECG data from repository
    record_id = '03509_hr'
    ecg_src_dir = os.path.join(project_root, 'examples', 'data', 'timeseries', '')
    
    print(f"Loading ECG data for record: {record_id}...")
    ecg_data = load_and_preprocess_ecg(
        record_id=record_id,
        src_dir=ecg_src_dir,
        ecg_filters=['BWR', 'BLA', 'AC50Hz', 'LP40Hz'],
        subsampling_window_size=3000,
        subsample_start=0
    )
    
    if ecg_data is not None:
        print(f"✅ ECG data loaded: {ecg_data.shape}")
        
        # ECG_CNN's fc1 layer was sized for seq_length=1000, so crop the
        # (time, leads) array to that length; 3000 samples would not fit
        model_seq_length = 1000
        ecg_data = ecg_data[:model_seq_length]
        
        # Use a single lead (lead I) for model prediction
        ecg_single_lead = ecg_data[:, 0:1]  # Shape: (1000, 1)
        # (time, channels) -> (batch, channels, time)
        input_tensor = torch.from_numpy(ecg_single_lead).float().permute(1, 0).unsqueeze(0)
        
        # Get the prediction for the real ECG
        model.eval()
        with torch.no_grad():
            output = model(input_tensor)
        predicted_idx = torch.argmax(output, dim=1)
        
        # Calculate explanation with one method
        explanation = explain(
            model,
            input_tensor,
            method_name="gradient_x_input",
            target_class=predicted_idx.item()
        )
        
        # Process for visualization
        if isinstance(explanation, torch.Tensor):
            explanation_np = explanation.detach().cpu().numpy()
        else:
            explanation_np = explanation
        
        # Handle shape and expand to 12 leads
        if explanation_np.ndim == 3:
            relevance_map = explanation_np[0].transpose()  # (1, 1, 1000) -> (1000, 1)
        else:
            relevance_map = explanation_np.reshape(-1, 1)
        
        # Expand to 12 leads for visualization: the model only sees lead I,
        # so every lead displays lead I's relevance, not its own
        if relevance_map.shape[1] == 1 and ecg_data.shape[1] == 12:
            relevance_map = np.tile(relevance_map, (1, 12))
        
        # Normalize
        normalized_relevance = normalize_ecg_relevancemap(relevance_map)
        
        # Format for visualization
        ecg_for_visual = ecg_data.transpose()
        expl_for_visual = normalized_relevance.transpose()
        
        # Create 12-lead visualization
        plot_ecg(
            ecg=ecg_for_visual,
            explanation=expl_for_visual,
            sampling_rate=500,
            title=f"PyTorch XAI: gradient_x_input on {record_id}",
            show_colorbar=True,
            cmap='seismic',
            bubble_size=30,
            line_width=1.0,
            style='fancy',
            save_to=None,
            clim_min=-1,
            clim_max=1,
            colorbar_label='Relevance',
            shape_switch=False
        )
        
        print("\n✅ 12-lead ECG visualization created!")
    
except Exception as e:
    print(f"⚠️ Could not create 12-lead visualization: {e}")
    print("This requires the complete repository with ECG data and utilities.")

## 11. Method Comparison Analysis

Let's compare attribution across different methods for ECG components.

In [None]:
# Compare attribution across methods
methods_to_compare = ['gradient', 'gradient_x_input', 'lrp_z', 'lrpsign_z']
components = ['P-wave', 'QRS Complex', 'T-wave']
regions = [p_wave_region, qrs_complex_region, t_wave_region]

# Calculate attribution for each method and component
component_attribution = {}
for method in methods_to_compare:
    if method in explanations:
        if isinstance(explanations[method], torch.Tensor):
            expl = explanations[method].detach().cpu().numpy()[0, 0, :]
        else:
            expl = explanations[method][0, 0, :]
        component_attribution[method] = [np.mean(np.abs(expl[region])) for region in regions]
    else:
        component_attribution[method] = [0, 0, 0]

# Visualize component attribution comparison
plt.figure(figsize=(12, 6))

x_pos = np.arange(len(components))  # bar positions; avoids overwriting the input tensor `x`
width = 0.2
offsets = np.linspace(-0.3, 0.3, len(methods_to_compare))

for i, method in enumerate(methods_to_compare):
    plt.bar(x_pos + offsets[i], component_attribution[method], width, label=method)

plt.xlabel('ECG Component')
plt.ylabel('Mean Absolute Attribution')
plt.title('Attribution Comparison Across Methods')
plt.xticks(x_pos, components)
plt.legend()
plt.grid(True, axis='y')

plt.tight_layout()
plt.show()

print("\n📊 Method Comparison Summary:")
for method in methods_to_compare:
    if method in component_attribution:
        attrs = component_attribution[method]
        print(f"  {method:20} - P: {attrs[0]:.3f}, QRS: {attrs[1]:.3f}, T: {attrs[2]:.3f}")
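Raw mean absolute attributions are not directly comparable across methods: `gradient` measures output change per unit of input, while `gradient_x_input` and the LRP variants are expressed in units of the output score. Dividing each region's absolute attribution by the total over the whole signal gives a scale-free share per component. A sketch with an illustrative helper name:

```python
import numpy as np

def region_attribution_share(expl, regions):
    """Fraction of total absolute attribution that falls inside each region."""
    abs_expl = np.abs(np.asarray(expl, dtype=float))
    total = abs_expl.sum()
    if total == 0:
        return [0.0 for _ in regions]
    return [float(abs_expl[region].sum() / total) for region in regions]

# Toy attribution: 1000 samples, all relevance inside the QRS window
toy = np.zeros(1000)
toy[400:420] = 0.5
print(region_attribution_share(toy, [slice(350, 370), slice(400, 420), slice(450, 480)]))  # [0.0, 1.0, 0.0]
```

In the comparison cell, using this share instead of `np.mean(np.abs(...))` puts all four methods on a 0-1 scale before plotting.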

## 12. Summary and Key Insights

### What we've learned:

1. **PyTorch Time Series Models**: How to build and train CNN models for ECG classification
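2. **Multiple XAI Methods**: Different approaches provide different insights:
   - **Gradient**: Shows how sensitive the output score is to each time sample
   - **Gradient × Input**: Weights that sensitivity by the signal value, emphasizing high-amplitude regions
   - **Integrated Gradients**: Accumulates gradients along a path from a baseline to the input, so the attributions approximately sum to the output difference between input and baseline
   - **Grad-CAM**: Builds a coarse map from the last convolutional layer (`conv3`) by weighting its feature maps with their averaged gradients
   - **LRP methods**: Propagate the output score backward layer by layer using fixed rules (here the z-rule and the ε-rule with ε = 0.1)
   - **LRP-SIGN**: The SIGN variant of the LRP z-rule (`lrpsign_z`)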

3. **Visualization Techniques**: 
   - Separate plots for clear comparison
   - Overlay visualizations for intuitive understanding
   - Component-specific analysis for medical insights
   - 12-lead medical visualizations for clinical applications

4. **Time Series XAI**: Understanding temporal dependencies and pattern recognition in sequential data

### Clinical Insights:
- Most methods highlight the QRS region, which is clinically plausible; because the data here is synthetic, treat this as a sanity check rather than a clinical finding
- Different methods show varying sensitivity to P-waves and T-waves
- Component-specific analysis helps validate model focus against medical knowledge

### Next Steps:
- Try different ECG records to see consistency across samples
- Experiment with different model architectures
- Compare with TensorFlow implementation
- Apply to your own time series classification problems

### Method Selection Guide:
- **`gradient`**: Fast, good for real-time analysis
- **`gradient_x_input`**: Better for highlighting strong features
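- **`integrated_gradients`**: Satisfies completeness (attributions sum to the output difference from the baseline), but slower because it needs one gradient evaluation per integration step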
- **`grad_cam`**: Good for understanding convolutional focus
- **`lrp_*`**: Good for layer-wise understanding
- **`lrpsign_z`**: Enhanced attribution with sign information
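The cost of `integrated_gradients` comes from averaging gradients along a straight path from a baseline to the input. A NumPy sketch with an analytic toy function, f(x) = Σ x_i² with gradient 2x, shows the idea and the completeness property. The zero baseline, midpoint rule and 50 steps are assumptions for illustration; SignXAI2's defaults may differ.

```python
import numpy as np

def toy_score(v):
    return float(np.sum(v ** 2))

def toy_grad(v):
    return 2.0 * v  # analytic gradient of toy_score

def integrated_gradients(x, baseline, grad_fn, steps=50):
    """Midpoint Riemann approximation of integrated gradients."""
    alphas = (np.arange(steps) + 0.5) / steps  # midpoints in (0, 1)
    path_grads = [grad_fn(baseline + a * (x - baseline)) for a in alphas]
    avg_grad = np.mean(path_grads, axis=0)
    return (x - baseline) * avg_grad

x_demo = np.array([1.0, -2.0, 0.5])
baseline = np.zeros_like(x_demo)
ig_attr = integrated_gradients(x_demo, baseline, toy_grad)
print(ig_attr, ig_attr.sum(), toy_score(x_demo) - toy_score(baseline))
```

Each of the `steps` path points costs one gradient evaluation, which is the source of the slowdown relative to `gradient`.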

This tutorial demonstrates how SignXAI2 can provide insight into the decisions of time series models, which is particularly important in medical applications, where understanding AI reasoning is critical for clinical adoption.