# TCAV (Testing with Concept Activation Vectors) for ECG Models

TCAV is an interpretability (XAI) method that quantifies how strongly a deep learning model's predictions depend on specific human-defined **concepts**.

## Core Idea
- Measure how much user-defined concepts (e.g., "atrial fibrillation", "normal sinus rhythm") influence model predictions
- Compare with random concepts to verify statistical significance
- Analyze concept importance across different layers
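
As a rough illustration of the idea above, the sketch below derives a concept activation vector (CAV) and a sign-count style score using NumPy only. The activations and gradients are synthetic stand-ins, and the least-squares linear separator is an assumption made for brevity; the `execg` implementation may train its classifier differently.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 200, 16  # samples per set and flattened activation size (illustrative values)

# Synthetic stand-ins for layer activations: concept examples are shifted
# along one hidden direction, random examples are not.
true_direction = np.zeros(d)
true_direction[0] = 1.0
concept_acts = rng.normal(size=(n, d)) + 2.0 * true_direction
random_acts = rng.normal(size=(n, d))

# Fit a linear separator (least squares on +/-1 labels, with a bias column).
X = np.vstack([concept_acts, random_acts])
y = np.concatenate([np.ones(n), -np.ones(n)])
X_bias = np.hstack([X, np.ones((2 * n, 1))])
weights, *_ = np.linalg.lstsq(X_bias, y, rcond=None)

# The CAV is the unit normal of the separating hyperplane, pointing toward the concept.
cav = weights[:-1] / np.linalg.norm(weights[:-1])

# Stand-in for d(class logit)/d(activation) on the inputs being explained.
grads = rng.normal(size=(50, d)) + 0.5 * true_direction

# Directional derivative along the CAV; the score is the fraction that is positive.
directional = grads @ cav
tcav_score = float(np.mean(directional > 0))
print(f"cos(CAV, true direction) = {cav @ true_direction:.2f}, TCAV score = {tcav_score:.2f}")
```

Repeating the fit with different random sets gives a distribution of scores, which is what makes a significance test against random concepts possible.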

## Contents of This Notebook
0. Setup
1. Model loading
2. TCAV initialization (concept dataset generation)
3. Test input preparation
4. TCAV score computation
5. Results visualization
6. Results interpretation
7. Advanced: multi-class analysis

## 0. Setup

In [None]:
import os
import sys
import json
import numpy as np
import torch

# Path setup
sys.path.append("../")

from samples.models import get_model, MODEL_REGISTRY

## 1. Model Loading

In [None]:
# Load AFib classification model
model = get_model(
    name="afib_binary",
    model_dir="../tmp/models/afib_binary/",
    download=True,
)

print(f"Model loaded: {model.__class__.__name__}")
print(f"Device: {next(model.parameters()).device}")

In [None]:
# Check model layer names for TCAV
from execg.models.wrapper import TorchModelWrapper

temp_wrapper = TorchModelWrapper(model)
layer_names = temp_wrapper.get_layer_names()

# Find conv layers (recommended for TCAV)
conv_layers = [name for name in layer_names if 'conv' in name.lower()]
print(f"Total layers: {len(layer_names)}")
print("\nConv layers (last 10):")
for layer in conv_layers[-10:]:
    print(f"  - {layer}")

## 2. TCAV Initialization

During TCAV initialization, configure the following:
- `target_concepts`: Concepts to analyze (based on PhysioNet2021 labels)
- `target_layers`: Layers to compute TCAV scores
- `num_random_concepts`: Number of random concepts for statistical testing
- `n_samples`: Number of samples per concept
- `download`: If True, automatically download concept data when not found

### Data Download
When `download=True`, the PhysioNet2021 numpy data will be automatically downloaded:
- **Download size**: ~5 GB (compressed tar.gz)
- **Extracted size**: ~10 GB
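- **Total space needed**: ~10 GB after extraction (the archive is removed afterwards; if the archive and the extracted files coexist during extraction, peak usage may reach ~15 GB)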
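
Because the download is large, it can help to check free disk space before initializing TCAV. The following is a minimal sketch using only the standard library; the helper name `has_free_space` and the 15 GB threshold are assumptions for illustration, not part of `execg`.

```python
import shutil
from pathlib import Path

def has_free_space(path: str, required_gb: float) -> bool:
    """Return True if the filesystem holding `path` has at least `required_gb` GB free."""
    probe = Path(path).resolve()
    # Walk up to the nearest existing directory, since the data directory may not exist yet.
    while not probe.exists():
        probe = probe.parent
    free_gb = shutil.disk_usage(probe).free / 1e9
    return free_gb >= required_gb

# 15 GB is a conservative assumption: archive (~5 GB) plus extracted data (~10 GB).
if not has_free_space("../data/physionet2021/", required_gb=15):
    print("Warning: less than 15 GB free; the download may fail.")
```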

### Available Concepts (PhysioNet2021)
```
- atrial fibrillation
- atrial flutter
- bundle branch block
- bradycardia
- 1st degree av block
- sinus rhythm
- sinus bradycardia
- sinus tachycardia
- t wave abnormal
- qwave abnormal
... and more
```

In [None]:
from execg.concept.tcav import TCAV

# Data path (PhysioNet2021 numpy data will be downloaded here if not found)
data_dir = "../data/physionet2021/"

# Target concepts to analyze
target_concepts = [
    "atrial fibrillation",
    "sinus rhythm",
    "t wave abnormal",
]

# Target layers for TCAV analysis
# Typically, later conv layers capture more high-level features
target_layers = [
    "blk1d.1.2.conv2",  # Mid-level features
    "blk1d.2.2.conv2",  # High-level features (last conv)
]

# Initialize TCAV
# Set download=True to automatically download concept data (~5GB compressed, ~10GB extracted)
tcav = TCAV(
    model=model,
    target_layers=target_layers,
    sampling_rate=250,  # Model input sampling rate
    duration=10,        # Model input duration (seconds)
    data_name="physionet2021",
    data_dir=data_dir,
    target_concepts=target_concepts,
    num_random_concepts=10,       # Number of random concepts for statistical test
    n_samples=100,                # Samples per concept (reduce for faster testing)
    random_seed=42,
    download=True,  # Download data if not found (requires ~10GB disk space)
)

print("\nTCAV initialized successfully!")
print(f"Target concepts: {list(tcav.target_concept_dict.keys())}")
print(f"Random concepts: {len(tcav.random_concept_dict)}")
print(f"Target layers: {target_layers}")

## 3. Prepare Test Input

Prepare test input data for running TCAV.

In [None]:
# Load sample ECG data
try:
    with open("../samples/data/sample.json", "r") as f:
        sample_data = json.load(f)
    ecg_signal = np.array(list(sample_data["data"].values()))
    print(f"Loaded sample ECG data: shape = {ecg_signal.shape}")
except (OSError, KeyError, json.JSONDecodeError):
    # Generate synthetic data if the sample is missing or malformed
    n_leads = 12
    seq_length = 2500  # 10sec @ 250Hz
    ecg_signal = np.random.randn(n_leads, seq_length).astype(np.float32)
    print(f"Generated synthetic ECG data: shape = {ecg_signal.shape}")

# Convert to tensor and prepare for TCAV
# TCAV expects batch input: (batch_size, n_leads, seq_length)
ecg_tensor = torch.tensor(ecg_signal, dtype=torch.float32)

# Downsample if needed (sample data is 500Hz, model expects 250Hz).
# Note: keeping every second sample applies no anti-aliasing filter.
if ecg_tensor.shape[-1] == 5000:
    ecg_tensor = ecg_tensor[:, ::2]  # 500Hz -> 250Hz

# Add batch dimension
ecg_tensor = ecg_tensor.unsqueeze(0)  # (1, 12, 2500)

print(f"Input tensor shape: {ecg_tensor.shape}")
print("Expected shape: (batch_size, 12, 2500)")
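
The cell above halves the sampling rate by keeping every second sample. A hedged alternative is to average adjacent sample pairs before decimating, which acts as a crude low-pass filter. The helper name `downsample_by_two` is hypothetical and the input is synthetic; a polyphase resampler such as `scipy.signal.resample_poly` would be a more rigorous choice.

```python
import numpy as np

def downsample_by_two(signal: np.ndarray) -> np.ndarray:
    """Halve the sampling rate of a (n_leads, seq_length) array by averaging sample pairs."""
    n_leads, seq_length = signal.shape
    usable = seq_length - (seq_length % 2)  # drop a trailing odd sample, if any
    pairs = signal[:, :usable].reshape(n_leads, usable // 2, 2)
    return pairs.mean(axis=-1)

ecg_500hz = np.random.randn(12, 5000).astype(np.float32)  # synthetic stand-in
ecg_250hz = downsample_by_two(ecg_500hz)
print(ecg_250hz.shape)  # (12, 2500)
```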

In [None]:
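# Check model prediction for the input
model.eval()  # inference mode: disables dropout, uses running batch-norm statistics
device = next(model.parameters()).device
with torch.no_grad():
    prediction = model(ecg_tensor.to(device))  # keep input and model on the same device
    if isinstance(prediction, tuple):
        prediction = prediction[0]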

print(f"Model output shape: {prediction.shape}")
print(f"Model output: {prediction.squeeze().cpu().numpy()}")

if prediction.shape[-1] > 1:
    print(f"Predicted class: {prediction.argmax(dim=-1).item()}")

## 4. Run TCAV Interpretation

TCAV measures how much each concept influences the prediction of a specific class.

- **TCAV score > 0.5**: The concept has positive influence on target class prediction
- **TCAV score < 0.5**: The concept has negative influence on target class prediction
- **TCAV score ≈ 0.5**: The concept has no significant influence on prediction
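
The 0.5 threshold follows from how a sign-count style score is defined. The sketch below shows one common pair of definitions for the `"sign_count"` and `"magnitude"` score types, modeled on those used in Captum's TCAV; whether `execg` computes them identically is an assumption, and the directional derivatives are hypothetical values.

```python
import numpy as np

def sign_count_score(directional: np.ndarray) -> float:
    """Fraction of directional derivatives that are strictly positive."""
    return float(np.mean(directional > 0))

def magnitude_score(directional: np.ndarray) -> float:
    """Share of total absolute sensitivity contributed by positive derivatives."""
    total = np.abs(directional).sum()
    if total == 0:
        return 0.5  # assumption: treat an all-zero signal as neutral
    return float(directional[directional > 0].sum() / total)

# Hypothetical directional derivatives of the class logit along one CAV.
directional = np.array([0.8, 0.1, -0.2, 0.4, -0.1])
print(sign_count_score(directional))          # 0.6
print(f"{magnitude_score(directional):.4f}")  # 0.8125
```

Both scores lie in [0, 1], and a value of 0.5 means positive and negative sensitivities balance out.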

In [None]:
# Run TCAV interpretation
# target: index of the model output to analyze; it must be a valid index for the
# output shape printed above (a single-logit binary model only has index 0)
target_class = 1

print(f"Running TCAV for target class: {target_class}")
print("This may take a few minutes...\n")

results = tcav.explain(
    inputs=ecg_tensor,
    target=target_class,
    n_steps=50,              # Integration steps (higher = more accurate)
    score_type="sign_count"  # "sign_count" or "magnitude"
)

print("TCAV computation completed!")

In [None]:
# Print TCAV results
print("=" * 60)
print("TCAV Results")
print("=" * 60)

for concept_name, layer_results in results.items():
    print(f"\nConcept: {concept_name}")
    print("-" * 40)
    
    for layer_name, (mean_score, (ci_lower, ci_upper)) in layer_results.items():
        if ci_lower > 0.5:
            significance = " [Positive influence]"
        elif ci_upper < 0.5:
            significance = " [Negative influence]"
        else:
            significance = " [Not significant]"
        
        print(f"  {layer_name}:")
        print(f"    Score: {mean_score:.3f} (95% CI: [{ci_lower:.3f}, {ci_upper:.3f}]){significance}")

## 5. Visualization

Visualize TCAV results as a heatmap.

In [None]:
import matplotlib.pyplot as plt
from execg.visualizer import plot_tcav_scores

# Plot TCAV scores as heatmap
fig, ax = plot_tcav_scores(results, cmap="RdYlGn")

ax.set_title(f"TCAV Scores (Target Class: {target_class})")
ax.set_xlabel("Concepts")
ax.set_ylabel("Layers")

plt.tight_layout()
plt.show()
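
If a custom plot is needed, the nested results can first be flattened into a matrix. This sketch assumes the structure printed in section 4, `{concept: {layer: (mean, (ci_lower, ci_upper))}}`, and uses made-up scores; `results_to_matrix` is a hypothetical helper, not part of `execg`.

```python
import numpy as np

def results_to_matrix(results):
    """Return (layers, concepts, matrix) with matrix[i, j] = mean score of concept j at layer i."""
    concepts = list(results)
    layers = list(next(iter(results.values())))
    matrix = np.array(
        [[results[concept][layer][0] for concept in concepts] for layer in layers]
    )
    return layers, concepts, matrix

# Made-up values for illustration only.
example = {
    "atrial fibrillation": {"layer_a": (0.82, (0.74, 0.90)), "layer_b": (0.91, (0.85, 0.97))},
    "sinus rhythm": {"layer_a": (0.31, (0.22, 0.40)), "layer_b": (0.18, (0.10, 0.26))},
}
layers, concepts, matrix = results_to_matrix(example)
print(matrix.shape)  # (2, 2)
```

The matrix can then be drawn with `ax.imshow(matrix, vmin=0, vmax=1)`, using `layers` and `concepts` as tick labels.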

In [None]:
# Bar plot for easier comparison using the visualization module
from execg.visualizer import plot_tcav_comparison

# Plot TCAV scores comparison across layers
fig, axes = plot_tcav_comparison(results, layers=target_layers)
plt.suptitle(f'TCAV Scores by Layer (Target Class: {target_class})', fontsize=14)
plt.tight_layout()
plt.show()

## 6. Results Interpretation

### TCAV Score Interpretation

| TCAV Score | Interpretation |
|------------|----------------|
| > 0.5 (entire 95% CI above 0.5) | The concept has **significantly positive** influence on target class prediction |
| < 0.5 (entire 95% CI below 0.5) | The concept has **significantly negative** influence on target class prediction |
| ≈ 0.5 (95% CI includes 0.5) | The concept has **no significant** influence on prediction |
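
How the 95% CI is computed is not stated in this notebook. One plausible approach, shown here purely as an assumption, is a normal-approximation interval over the scores obtained against each random concept, followed by the decision rule from the table. The function names and the scores are hypothetical.

```python
import numpy as np

def mean_with_ci(scores, z=1.96):
    """Mean and normal-approximation 95% CI of per-random-concept TCAV scores."""
    scores = np.asarray(scores, dtype=float)
    mean = scores.mean()
    sem = scores.std(ddof=1) / np.sqrt(len(scores))  # standard error of the mean
    return mean, (mean - z * sem, mean + z * sem)

def label_influence(ci_lower, ci_upper):
    """Apply the table's decision rule to a confidence interval."""
    if ci_lower > 0.5:
        return "positive"
    if ci_upper < 0.5:
        return "negative"
    return "not significant"

# Hypothetical scores for one concept, one per random-concept comparison.
scores = [0.72, 0.68, 0.80, 0.75, 0.70, 0.66, 0.78, 0.74, 0.71, 0.69]
mean, (lo, hi) = mean_with_ci(scores)
print(f"{mean:.3f} [{lo:.3f}, {hi:.3f}] -> {label_influence(lo, hi)}")
```

With only ten random concepts, a t-distribution interval would be somewhat wider than this normal approximation.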

### Example Interpretations
- **"atrial fibrillation" concept TCAV > 0.5**: Shifting the layer activations toward the atrial fibrillation concept tends to increase the model's output for the target class
- **"sinus rhythm" concept TCAV < 0.5**: Shifting the layer activations toward the normal sinus rhythm concept tends to decrease the model's output for the target class

### Layer-wise Differences
- **Early layers**: Low-level features (basic waveform shapes)
- **Later layers**: High-level features (disease-specific patterns)

## 7. Advanced: Multiple Classes Analysis

In [None]:
# Uncomment and run for multi-class analysis
# This will take longer but provides comprehensive results

# num_classes = prediction.shape[-1] if len(prediction.shape) > 1 else 1
# 
# if num_classes > 1:
#     all_results = {}
#     
#     for class_idx in range(num_classes):
#         print(f"\nAnalyzing class {class_idx}...")
#         class_results = tcav.explain(
#             inputs=ecg_tensor,
#             target=class_idx,
#             n_steps=50,
#             score_type="sign_count"
#         )
#         all_results[f"Class {class_idx}"] = class_results
#     
#     print("\nMulti-class analysis completed!")

## Summary

In this notebook, we performed concept-based interpretation of ECG models using TCAV.

### Key Takeaways
1. **TCAV** quantifies the influence of user-defined concepts on model predictions
2. Statistical significance is verified through comparison with **random concepts**
3. Changes in concept importance can be observed across **different layers**
4. In the ECG domain, **medically meaningful concepts** can be analyzed

### Next Steps
- Add more concepts for analysis
- Compare results across different layers
- Compute TCAV for various input samples