# ECG-RAMBA: Demo Inference Notebook

This notebook demonstrates end-to-end inference using a pre-trained ECG-RAMBA model.

**Pipeline:**
1. Load a sample ECG signal
2. Preprocess and extract features (MiniRocket + HRV)
3. Load the pre-trained model
4. Run inference
5. Visualize results

## 1. Setup & Imports

In [None]:
import sys
import os
import glob
import numpy as np
import torch
import matplotlib.pyplot as plt

# Make the project root importable.
# - Insert at the front so the local `configs`/`src` packages take precedence over any
#   same-named installed package.
# - Only insert once, so re-running this cell does not keep growing sys.path.
PROJECT_ROOT = os.path.abspath("..")  # assumes the notebook is run from a direct subdirectory of the repo root
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from configs.config import CONFIG, PATHS, DEVICE, CLASSES
from src.model import ECGRambaV7Advanced
from src.features import MiniRocketNative, extract_hrv_features, extract_amplitude_features, extract_global_record_stats
from src.utils import normalize_signal

print(f"Device: {DEVICE}")
print(f"Classes: {len(CLASSES)}")

## 2. Load Sample ECG Signal

We'll create a synthetic ECG signal for demonstration. In practice, replace this with actual data loaded via `wfdb` or from a `.mat` file.

In [None]:
# Create synthetic ECG (12-lead, 5000 samples at 500Hz = 10 seconds)
np.random.seed(42)
FS = 500          # sampling rate (Hz)
SEQ_LEN = 5000    # samples per lead
N_LEADS = 12

# Sample times derived from FS so they line up exactly with the sample indices used
# for the QRS spikes below (np.linspace(0, 10, SEQ_LEN) would step by 10/4999 s, not 1/FS).
t = np.arange(SEQ_LEN) / FS
duration_s = SEQ_LEN / FS

# Gaussian bump used as a stylized QRS complex; identical for every beat, so build it once.
qrs_width = 20
qrs_shape = np.exp(-0.5 * ((np.arange(qrs_width) - 10) / 3) ** 2)

ecg_signal = np.zeros((N_LEADS, SEQ_LEN))
for lead in range(N_LEADS):
    # Slow sinusoidal baseline at 1.2 Hz (~72 BPM)
    lead_signal = np.sin(2 * np.pi * 1.2 * t) * 0.3
    # Add a QRS spike every 0.83 s (~72 BPM); spike amplitude grows with the lead index
    for peak_t in np.arange(0.5, duration_s, 0.83):
        idx = int(peak_t * FS)
        if idx < SEQ_LEN - 50:
            lead_signal[idx:idx + qrs_width] += qrs_shape * (1.5 + 0.2 * lead)
    # Additive Gaussian noise (drawn per lead, in lead order, from the seeded global RNG)
    ecg_signal[lead] = lead_signal + np.random.randn(SEQ_LEN) * 0.05

# Normalize
ecg_signal = normalize_signal(ecg_signal).astype(np.float32)

print(f"ECG Signal Shape: {ecg_signal.shape}")

# Visualize Lead II (row index 1) over the first 2 seconds
n_show = 2 * FS
fig, ax = plt.subplots(figsize=(14, 3))
ax.plot(t[:n_show], ecg_signal[1, :n_show], 'b-', linewidth=0.8)
ax.set_title("Lead II (First 2 seconds)")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude (normalized)")
ax.grid(True, alpha=0.3)
plt.show()

## 3. Feature Extraction

Extract:
- **MiniRocket features** (20,000 dim → PCA to `hydra_dim`; this demo skips the fitted PCA and simply truncates or zero-pads to `hydra_dim`)
- **HRV + Amplitude + Global Stats** (36 dim)

In [None]:
# MiniRocket transform on CPU. Channel count and length come from the signal itself,
# so this cell stays consistent with however the ECG above was built.
n_leads, seq_len = ecg_signal.shape
rocket = MiniRocketNative(c_in=n_leads, seq_len=seq_len, num_kernels=10000, seed=42).cpu().eval()

with torch.no_grad():
    # Batch of one: (1, leads, samples)
    x_tensor = torch.tensor(ecg_signal[np.newaxis, ...], dtype=torch.float32)
    rocket_feats = rocket(x_tensor).numpy()

print(f"Raw MiniRocket features: {rocket_feats.shape}")

# Demo shortcut: the PCA fitted during training is NOT applied here. Instead, the raw
# features are truncated to `hydra_dim` (or zero-padded if there are fewer of them).
# These are therefore not the features the trained model saw, so predictions in this demo
# are illustrative only. In production, load the fitted PCA from training and apply it.
hydra_dim = CONFIG['hydra_dim']
hydra_feats = rocket_feats[:, :hydra_dim]
if hydra_feats.shape[1] < hydra_dim:
    hydra_feats = np.pad(hydra_feats, ((0, 0), (0, hydra_dim - hydra_feats.shape[1])))

assert hydra_feats.shape == (1, hydra_dim), f"unexpected hydra feature shape {hydra_feats.shape}"
print(f"Hydra features (after dim match): {hydra_feats.shape}")

In [None]:
# Hand-crafted record features: HRV (computed with fs=500 Hz), amplitude features and
# global record statistics. They are joined into one vector and given a leading batch
# axis so the result matches the batch-of-one layout used for the other model inputs.
# (The section header states 36 dims in total; that is not checked here.)
hrv = extract_hrv_features(ecg_signal, fs=500)
amp = extract_amplitude_features(ecg_signal)
gstat = extract_global_record_stats(ecg_signal)

hrv_feats = np.expand_dims(np.concatenate((hrv, amp, gstat)), axis=0)
print(f"HRV features: {hrv_feats.shape}")

## 4. Load Pre-trained Model

In [None]:
# Find per-fold checkpoints; the demo uses only the first one.
model_dir = PATHS.get('model_dir', '../models')
ckpts = sorted(glob.glob(os.path.join(model_dir, 'fold*_best.pt')))

if ckpts:
    print(f"Found {len(ckpts)} checkpoints")
    ckpt_path = ckpts[0]  # Use first fold for demo
else:
    print("No checkpoints found. Using randomly initialized model for demo.")
    ckpt_path = None

# Seed torch so the randomly initialized fallback (and any parameters a checkpoint
# does not cover) are reproducible across Restart & Run All.
torch.manual_seed(42)

# Initialize model
model = ECGRambaV7Advanced(cfg=CONFIG).to(DEVICE)

if ckpt_path:
    # NOTE(security): torch.load unpickles the file, so only load checkpoints from trusted sources.
    state = torch.load(ckpt_path, map_location=DEVICE)
    # Checkpoints may either wrap the weights under a 'model' key or be a bare state dict.
    state_dict = state['model'] if isinstance(state, dict) and 'model' in state else state

    # strict=False tolerates key mismatches, so surface them instead of failing silently.
    # Without this report, a checkpoint whose keys don't match (e.g. a 'module.' prefix)
    # would leave the model at random init while still printing "Loaded checkpoint".
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"WARNING: {len(incompatible.missing_keys)} model keys missing from checkpoint, "
              f"e.g. {incompatible.missing_keys[:3]}")
    if incompatible.unexpected_keys:
        print(f"WARNING: {len(incompatible.unexpected_keys)} unexpected keys in checkpoint, "
              f"e.g. {incompatible.unexpected_keys[:3]}")
    print(f"Loaded checkpoint: {ckpt_path}")

model.eval()
print(f"Model Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

## 5. Run Inference

In [None]:
# Prepare batch-of-one tensors for the model's three inputs:
# raw signal, MiniRocket-derived features, and HRV/amplitude/global features.
x = torch.tensor(ecg_signal[np.newaxis, ...], dtype=torch.float32, device=DEVICE)
xh = torch.tensor(hydra_feats, dtype=torch.float32, device=DEVICE)
xhr = torch.tensor(hrv_feats, dtype=torch.float32, device=DEVICE)

# DEVICE may be a str ('cuda', 'cuda:0') or a torch.device. A torch.device never compares
# equal to a plain string, so normalize it before deciding whether to use mixed precision.
use_amp = torch.device(DEVICE).type == 'cuda'

# Inference
with torch.no_grad():
    if use_amp:
        with torch.amp.autocast('cuda'):
            logits = model(x, xh, xhr)
    else:
        logits = model(x, xh, xhr)
    # Autocast may return float16 logits; upcast so the probabilities are float32.
    probs = torch.sigmoid(logits.float()).cpu().numpy()[0]

print(f"Output probabilities shape: {probs.shape}")

## 6. Visualize Results

In [None]:
# Get top predictions; classes at or above `threshold` are marked as positive.
threshold = 0.5
top_k = 5

# Class indices ordered by descending probability (one entry per model output).
sorted_indices = np.argsort(probs)[::-1]

print("\n" + "="*50)
print("TOP PREDICTIONS")
print("="*50)

# Slicing bounds this by the number of model outputs (len(probs)), not len(CLASSES),
# so a label list longer than the output vector cannot cause an IndexError.
for idx in sorted_indices[:top_k]:
    class_name = CLASSES[idx] if idx < len(CLASSES) else f"Class_{idx}"
    prob = probs[idx]
    status = "✅" if prob >= threshold else "  "
    print(f"{status} {class_name:12s}: {prob:.4f}")

print("="*50)

# Bar chart of the highest-probability classes. Clamp to the number of outputs so the
# bar positions, heights and tick labels always have matching lengths.
top_n = min(10, len(probs))
top_indices = sorted_indices[:top_n]
top_probs = probs[top_indices]
top_classes = [CLASSES[i] if i < len(CLASSES) else f"C{i}" for i in top_indices]

colors = ['#2ECC71' if p >= threshold else '#E74C3C' for p in top_probs]
fig, ax = plt.subplots(figsize=(12, 6))
ax.barh(range(top_n), top_probs, color=colors)
ax.set_yticks(range(top_n))
ax.set_yticklabels(top_classes)
ax.set_xlabel('Probability')
ax.set_title(f'ECG-RAMBA Predictions (Top {top_n} Classes)')
ax.axvline(x=threshold, color='k', linestyle='--', label=f'Threshold={threshold}')
ax.legend()
ax.invert_yaxis()  # highest probability at the top
fig.tight_layout()
plt.show()

---

**Note:** This demo uses a synthetic signal. For real inference:
1. Load actual ECG data using `wfdb.rdrecord()` or from `.mat` files
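2. Apply the PCA fitted during training (saved as `global_pca_zeroshot.pkl`) instead of the truncation/zero-padding used in this demo. Only unpickle files from trusted sources.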
3. Use ensemble inference across all folds for best performance