# Lead-Minimal ECG Diagnosis — Demo

This notebook walks through the core functionality:
1. Loading and visualizing PTB-XL data
2. Lead subset extraction
3. Model inference
4. Lead-Robustness Score (LRS) computation

In [None]:
import sys
sys.path.append('../src')

import numpy as np
import matplotlib.pyplot as plt
import torch
from pathlib import Path

from dataset import PTBXLDataset, LEAD_NAMES, LEAD_TO_IDX
from model import get_model
from metrics import compute_lead_robustness_score

# Apply the 'seaborn-v0_8-whitegrid' style sheet to every figure created below.
# NOTE(review): this style name looks tied to newer matplotlib releases (older
# ones appear to use 'seaborn-whitegrid') — confirm against the pinned version.
plt.style.use('seaborn-v0_8-whitegrid')
%matplotlib inline

## 1. Load PTB-XL Data

In [None]:
# Test split: fold 10 only, all leads, augmentation switched off.
dataset = PTBXLDataset(h5_path='../data/processed/ptbxl_processed.h5', leads='all', folds=[10], augment=False)

# Summarise what was loaded: sample count, class names and lead names.
print(
    f"Test samples: {len(dataset)}",
    f"Classes: {dataset.classes}",
    f"Leads: {dataset.lead_names}",
    sep="\n",
)

## 2. Visualize a 12-Lead ECG

In [None]:
def plot_12lead_ecg(signal, title="12-Lead ECG", fs=100):
    """Plot one trace per entry of ``LEAD_NAMES`` on a 6x2 grid of axes.

    Args:
        signal: 2-D array indexed as ``signal[lead, sample]``. Row ``i`` is
            drawn under the title ``LEAD_NAMES[i]`` (assumes the rows are
            stored in ``LEAD_NAMES`` order — confirm against the dataset).
        title: Figure-level title.
        fs: Sampling rate in Hz, used to convert sample indices to seconds.
            Defaults to 100, the value this helper previously hard-coded.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        ValueError: If ``fs`` is not positive.
    """
    if fs <= 0:
        raise ValueError(f"fs must be positive, got {fs!r}")

    fig, axes = plt.subplots(6, 2, figsize=(14, 12))
    axes = axes.flatten()

    n_samples = signal.shape[1]
    time = np.arange(n_samples) / fs
    # Full recording length in seconds (10 s for 1000 samples at 100 Hz), so
    # the x-axis always spans the whole trace instead of a fixed 10 s window.
    duration = n_samples / fs

    for i, lead in enumerate(LEAD_NAMES):
        axes[i].plot(time, signal[i], 'b-', linewidth=0.8)
        axes[i].set_title(lead, fontsize=12)
        axes[i].set_ylabel('mV (norm)')
        axes[i].set_xlim([0, duration])
        axes[i].grid(True, alpha=0.3)

    # Only the bottom row of the grid carries the shared time-axis label.
    axes[-2].set_xlabel('Time (s)')
    axes[-1].set_xlabel('Time (s)')

    # y > 1 places the title above the top row of subplots.
    fig.suptitle(title, fontsize=14, y=1.02)
    fig.tight_layout()
    plt.show()

# Take the first record of the split and describe it.
signal, label = dataset[0]
print(f"Signal shape: {signal.shape}")
print(f"Label: {label}")

# Names of the classes whose entry in the label vector equals 1.
positive_classes = [
    dataset.classes[class_idx]
    for class_idx, is_present in enumerate(label)
    if is_present == 1
]
print(f"Diagnoses: {positive_classes}")

# Hand the plotting helper the signal's NumPy representation.
plot_12lead_ecg(signal.numpy(), title="Sample 12-Lead ECG")

## 3. Compare Lead Subsets

In [None]:
def plot_lead_comparison(signal, subsets, fs=100):
    """Draw one stacked panel per lead subset, overlaying that subset's leads.

    Args:
        signal: 2-D array indexed as ``signal[lead, sample]``. The row for a
            lead name is looked up through ``LEAD_TO_IDX``.
        subsets: Mapping of panel label -> list of lead names. Panels are
            drawn top to bottom in the mapping's iteration order.
        fs: Sampling rate in Hz, used to convert sample indices to seconds.
            Defaults to 100, the value this helper previously hard-coded.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        ValueError: If ``subsets`` is empty or ``fs`` is not positive.
        KeyError: If a lead name is missing from ``LEAD_TO_IDX``.
    """
    if not subsets:
        raise ValueError("subsets must contain at least one lead subset")
    if fs <= 0:
        raise ValueError(f"fs must be positive, got {fs!r}")

    n_subsets = len(subsets)
    # squeeze=False always returns a 2-D array of axes, so the single-panel
    # case needs no special handling; keep the only column.
    fig, axes = plt.subplots(
        n_subsets, 1, figsize=(14, 3 * n_subsets), squeeze=False
    )
    axes = axes[:, 0]

    n_samples = signal.shape[1]
    time = np.arange(n_samples) / fs
    # Full recording length in seconds (10 s for 1000 samples at 100 Hz).
    duration = n_samples / fs

    for ax, (name, leads) in zip(axes, subsets.items()):
        for lead in leads:
            idx = LEAD_TO_IDX[lead]
            ax.plot(time, signal[idx], label=lead, linewidth=1)

        ax.set_title(f"{name}: {', '.join(leads)}")
        ax.set_ylabel('mV (norm)')
        # One legend column per lead keeps the legend on a single row.
        ax.legend(loc='upper right', ncol=len(leads))
        ax.set_xlim([0, duration])
        ax.grid(True, alpha=0.3)

    # Only the bottom panel carries the time-axis label.
    axes[-1].set_xlabel('Time (s)')
    fig.tight_layout()
    plt.show()

# Compare common lead subsets
# Each key is the panel label shown in the plot title; each value is the list
# of lead names overlaid in that panel (plot_lead_comparison resolves every
# name to a signal row through LEAD_TO_IDX).
subsets = {
    "Single Lead (II)": ["II"],
    "2-Lead (I, II)": ["I", "II"],
    "3-Lead (I, II, V2)": ["I", "II", "V2"],
    "6-Lead (Limb)": ["I", "II", "III", "aVR", "aVL", "aVF"]
}

# Reuses the `signal` sample unpacked from dataset[0] in the previous section.
plot_lead_comparison(signal.numpy(), subsets)

## 4. Model Architecture

In [None]:
# Build one "resnet1d" model per input-lead count (12, 2 and 1).
model_12lead = get_model("resnet1d", n_leads=12, n_classes=5)
model_2lead = get_model("resnet1d", n_leads=2, n_classes=5)
model_1lead = get_model("resnet1d", n_leads=1, n_classes=5)


def _n_params(net):
    """Return the total element count over all parameters of ``net``."""
    return sum(p.numel() for p in net.parameters())


print("\nParameter counts (same architecture, different input channels):")
for _tag, _net in (
    ("12-lead:", model_12lead),
    ("2-lead:", model_2lead),
    ("1-lead:", model_1lead),
):
    # Pad the tag to 8 characters so the counts line up in one column.
    print(f"  {_tag:<8} {_n_params(_net):,}")

## 5. Lead-Robustness Score (LRS)

In [None]:
# Simulated results (replace with actual results after training)
simulated_results = {
    "12-lead": {"auroc": 0.92, "brier": 0.08, "n_leads": 12},
    "6-lead":  {"auroc": 0.89, "brier": 0.09, "n_leads": 6},
    "3-lead":  {"auroc": 0.86, "brier": 0.11, "n_leads": 3},
    "2-lead":  {"auroc": 0.83, "brier": 0.12, "n_leads": 2},
    "1-lead":  {"auroc": 0.78, "brier": 0.14, "n_leads": 1},
}

# The 12-lead configuration is the reference every row is compared against.
baseline = simulated_results["12-lead"]

print("Lead-Robustness Score (LRS) Analysis")
print("=" * 50)
print(f"{'Config':<12} {'AUROC':<8} {'Retain%':<10} {'LRS':<8}")
print("-" * 50)

for name, res in simulated_results.items():
    lrs = compute_lead_robustness_score(
        baseline_auroc=baseline["auroc"],
        subset_auroc=res["auroc"],
        baseline_brier=baseline["brier"],
        subset_brier=res["brier"],
        # Read the lead count from the baseline entry instead of repeating 12,
        # so the table stays correct if the baseline configuration changes.
        n_leads_baseline=baseline["n_leads"],
        n_leads_subset=res["n_leads"],
    )

    # Share of the baseline AUROC kept by this configuration, in percent.
    retention = res["auroc"] / baseline["auroc"] * 100
    retention_cell = f"{retention:.1f}%"
    # Use the header's column widths so every row lines up, including the
    # six-character "100.0%" cell that hand-typed spacing pushed out of line.
    print(f"{name:<12} {res['auroc']:<8.3f} {retention_cell:<10} {lrs:.3f}")

## 6. Load and Run Trained Model

Uncomment after training to test inference on real models.

In [None]:
# Uncomment after training models:
# The inference code below is wrapped in a bare triple-quoted string, so none
# of it executes; the only statement this cell runs is the final print().
# To enable it, delete the opening and closing triple quotes.
# When enabled, the snippet picks the last 'resnet1d_all_*' directory in sorted
# order under ../outputs/models/, loads 'best_model.pt' from it, restores the
# 'model_state_dict' entry into a 12-lead model, and prints one sigmoid
# probability per class for the `signal` sample loaded earlier.
# NOTE(review): torch.load is called without map_location — confirm the
# checkpoint loads on a machine without the device it was saved from.
"""
model_path = Path('../outputs/models/')
model_dirs = list(model_path.glob('resnet1d_all_*'))

if model_dirs:
    latest_model = sorted(model_dirs)[-1]
    print(f"Loading model from: {latest_model}")
    
    checkpoint = torch.load(latest_model / 'best_model.pt')
    model = get_model("resnet1d", n_leads=12, n_classes=5)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    # Inference on single sample
    with torch.no_grad():
        signal_batch = signal.unsqueeze(0)
        logits = model(signal_batch)
        probs = torch.sigmoid(logits)
    
    print("Predictions:")
    for i, cls in enumerate(dataset.classes):
        print(f"  {cls}: {probs[0, i].item():.3f}")
"""
print("Run training first, then uncomment this cell to test inference.")

## Summary

This notebook demonstrated:
- How 12-lead ECGs can be reduced to fewer leads
- The architecture adapts automatically to different lead counts  
- LRS quantifies how well performance is retained under lead reduction

**Next steps:**
1. Run `python run_experiments.py --full` to train all configurations
2. Run `python src/evaluate.py` to compute LRS
3. Run `python src/visualize.py` to generate figures