# In [None]:
# ── ONNX Inference on Raspberry Pi 5 (static data, no PTB-XL needed) ─────────

import numpy as np
import time
import onnxruntime as ort

# ── Configuration ─────────────────────────────────────────────────────────────
ONNX_MODEL = 'models/ecg_model_float.onnx'   # path to your .onnx file
THRESHOLD = 0.5                               # probability cutoff for a positive label

# Label order matters: entry i names the i-th value of the model's output vector.
SUPERCLASSES = ['NORM', 'MI', 'STTC', 'CD', 'HYP']
# Human-readable name for each label, listed in the same order as SUPERCLASSES.
DESCRIPTIONS = dict(zip(SUPERCLASSES, [
    'Normal ECG',
    'Myocardial Infarction',
    'ST/T-wave Change',
    'Conduction Disturbance',
    'Hypertrophy',
]))

# ── Static synthetic ECG signal (no dataset, no wfdb) ─────────────────────────
TARGET_CLASS = 'MI'   # change to: NORM / MI / STTC / CD / HYP

def make_synthetic_ecg(target_class='MI', seed=42):
    """Generate a normalized 12-lead synthetic ECG with class-specific morphology.

    Each lead is a train of Gaussian P / QRS / T bumps scaled by a fixed
    per-lead gain, plus Gaussian noise and a slow sinusoidal baseline wander.
    Each lead is then z-score normalized independently.

    Args:
        target_class: One of 'NORM', 'MI', 'STTC', 'CD', 'HYP'. Selects the
            heart rate and a morphology tweak: taller T waves for STTC,
            inverted T waves on leads 0, 1, 6, 7 for MI, a widened QRS for CD,
            and taller QRS complexes on leads 9-11 for HYP. Unknown values fall
            back to 72 bpm with no tweak.
        seed: Seed for the noise generator. A private ``RandomState`` is used,
            so NumPy's global RNG state is left untouched. The noise stream is
            identical to seeding the global RNG with the same value.

    Returns:
        np.ndarray of shape (12, 1000), dtype float32 (10 s sampled at 100 Hz).
    """
    # Private legacy generator: same stream as np.random.seed(seed) + np.random.randn,
    # but without resetting the caller's global random state.
    noise_rng = np.random.RandomState(seed)
    n, fs = 1000, 100.0
    duration = n / fs                      # 10 s record
    t = np.arange(n) / fs                  # sample instants on the fs grid: 0, 0.01, ..., 9.99 s
    signal = np.zeros((12, n), dtype=np.float32)
    hr_map = {'NORM': 72, 'MI': 85, 'STTC': 78, 'CD': 55, 'HYP': 68}
    f_hr = hr_map.get(target_class, 72) / 60.0          # beats per second
    lead_scales = [1.0, 1.4, 0.6, -0.5, 0.3, 0.8, 0.4, 0.9, 1.1, 1.2, 1.0, 0.7]

    for li, sc in enumerate(lead_scales):
        beat = np.zeros(n)
        for bt in np.arange(0, duration, 1.0 / f_hr):
            b = int(bt * fs)                            # sample index of the QRS centre
            if b >= n:
                break
            # P wave, centred 15 samples (150 ms) before the QRS
            pc = b - 15
            if 0 <= pc < n:
                idx = np.arange(max(0, pc - 8), min(n, pc + 8))
                beat[idx] += sc * 0.15 * np.exp(-((idx - pc) ** 2) / 18.0)
            # QRS
            idx = np.arange(max(0, b - 5), min(n, b + 6))
            beat[idx] += sc * 1.0 * np.exp(-((idx - b) ** 2) / 3.0)
            if 0 <= b - 4 < n:
                beat[b - 4] -= sc * 0.2    # Q
            if 0 <= b + 4 < n:
                beat[b + 4] -= sc * 0.25   # S
            # T wave, centred 18 samples after the QRS
            tc = b + 18
            if 0 <= tc < n:
                idx = np.arange(max(0, tc - 12), min(n, tc + 12))
                if target_class == 'STTC':
                    t_amp = 0.6
                elif target_class == 'MI' and li in (0, 1, 6, 7):
                    t_amp = -0.2           # inverted T on these leads
                else:
                    t_amp = 0.35
                beat[idx] += sc * t_amp * np.exp(-((idx - tc) ** 2) / 50.0)
            # CD: wide QRS
            if target_class == 'CD':
                idx = np.arange(max(0, b - 10), min(n, b + 11))
                beat[idx] += sc * 0.3 * np.exp(-((idx - b) ** 2) / 20.0)
            # HYP: tall precordial
            if target_class == 'HYP' and li in (9, 10, 11):
                idx = np.arange(max(0, b - 5), min(n, b + 6))
                beat[idx] += sc * 0.8 * np.exp(-((idx - b) ** 2) / 3.0)

        noise = noise_rng.randn(n).astype(np.float32) * 0.03
        wander = (0.05 * np.sin(2 * np.pi * 0.15 * t)).astype(np.float32)
        signal[li] = beat + noise + wander

    # Z-score normalize each lead independently (epsilon guards a flat lead).
    mu = signal.mean(axis=1, keepdims=True)
    sigma = signal.std(axis=1, keepdims=True)
    return ((signal - mu) / (sigma + 1e-8)).astype(np.float32)

# Fixed seed per class so each class always yields the same reproducible signal.
seed_map = {'NORM': 1, 'MI': 2, 'STTC': 3, 'CD': 4, 'HYP': 5}
# Fail early with a readable message on a typo. make_synthetic_ecg itself would
# silently fall back to a 72 bpm rhythm with no class-specific morphology.
if TARGET_CLASS not in seed_map:
    raise ValueError(
        f'TARGET_CLASS must be one of {list(seed_map)}, got {TARGET_CLASS!r}'
    )
ecg_signal = make_synthetic_ecg(target_class=TARGET_CLASS, seed=seed_map[TARGET_CLASS])
print(f'Signal ready: shape={ecg_signal.shape}  class={TARGET_CLASS}')

# ── Load ONNX model ────────────────────────────────────────────────────────────
# Session tuning: full graph optimizations, operators executed one after another,
# and intra-op parallelism across the RPi 5's 4 cores.
opts = ort.SessionOptions()
opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
opts.intra_op_num_threads = 4

session = ort.InferenceSession(
    ONNX_MODEL,
    opts,
    providers=['CPUExecutionProvider'],
)
# Only the first graph input is fed; fetch its metadata once.
model_input = session.get_inputs()[0]
input_name = model_input.name
print(f'Model loaded : {ONNX_MODEL}')
print(f'Input name   : {input_name}  {model_input.shape}')

# ── Warmup (important for accurate latency on RPi) ────────────────────────────
# Prepend a batch axis, (12, 1000) -> (1, 12, 1000), as a fresh float32 copy.
inp = ecg_signal[None, ...].astype(np.float32)
# A few untimed runs keep any first-call overhead out of the timing loop below.
warmup_feed = {input_name: inp}
for _ in range(5):
    session.run(None, warmup_feed)

# ── Timed inference ───────────────────────────────────────────────────────────
N_RUNS = 50
times = []
for _ in range(N_RUNS):
    t0 = time.perf_counter()
    output = session.run(None, {input_name: inp})
    times.append((time.perf_counter() - t0) * 1000)   # milliseconds

# Post-process the last run's output (every run received the same input).
# NOTE(review): assumes output[0] holds raw logits shaped (1, n_classes), with no
# sigmoid already inside the graph — confirm against the ONNX export.
logits = output[0][0]                              # (5,)
# Numerically stable sigmoid: sigmoid(x) == 0.5 * (1 + tanh(x / 2)). Unlike
# 1 / (1 + exp(-x)), it cannot overflow exp() for large negative logits.
probs = (0.5 * (1.0 + np.tanh(0.5 * logits))).astype(np.float32)
preds = (probs >= THRESHOLD).astype(int)
times = np.array(times)

# ── Print results ─────────────────────────────────────────────────────────────
# Box geometry: every line is padded to _BOX_W characters between the outer │
# borders so the frame stays aligned. The per-class table splits that width into
# _COLS (sum(_COLS) + 3 inner separators == _BOX_W). Padding counts characters
# with len(), so it assumes the ─ │ — ✓ ✗ glyphs render one column wide.
_COLS = (8, 26, 11, 9)
_BOX_W = sum(_COLS) + len(_COLS) - 1   # 57


def _box_row(text=''):
    """Return *text* left-justified in a full-width ``│ … │`` box line."""
    return f'│{text:<{_BOX_W}}│'


def _table_row(*cells):
    """Return one table line with each cell left-justified to its column width."""
    return '│' + '│'.join(f'{cell:<{w}}' for cell, w in zip(cells, _COLS)) + '│'


def _table_rule(left, mid, right):
    """Return a horizontal table rule using the given edge/junction glyphs."""
    return left + mid.join('─' * w for w in _COLS) + right


print()
print('┌' + '─' * _BOX_W + '┐')
print(f'│{"ONNX INFERENCE — RASPBERRY PI 5 RESULT":^{_BOX_W}}│')
print(_table_rule('├', '┬', '┤'))
print(_table_row(' Class', ' Description', '   Prob', ' Pred'))
print(_table_rule('├', '┼', '┤'))
for cls, prob, pred in zip(SUPERCLASSES, probs, preds):
    flag = ' DETECTED' if pred else ''
    print(_table_row(f' {cls}', f' {DESCRIPTIONS[cls]}', f' {prob:.4f}', flag))
print(_table_rule('├', '┴', '┤'))

detected = [c for c, p in zip(SUPERCLASSES, preds) if p]
result_str = ', '.join(detected) if detected else 'No abnormality detected'
print(_box_row(f'  Result : {result_str}'))
print('├' + '─' * _BOX_W + '┤')

p50 = np.median(times)
p95 = np.percentile(times, 95)
print(_box_row(f'  Latency  mean={times.mean():>6.2f}ms   p50={p50:>6.2f}ms'))
print(_box_row(f'           p95 ={p95:>6.2f}ms   min={times.min():>6.2f}ms'))
edge_ok = p95 < 200
print(_box_row(f'  Edge target (<200ms p95): {"PASS ✓" if edge_ok else "FAIL ✗"}'))
print('└' + '─' * _BOX_W + '┘')