In [1]:
import time
import os
import json
import cv2
import yaml
import numpy as np
from PIL import Image

from scipy.spatial.distance import euclidean
from scipy.stats import pearsonr
from fastdtw import fastdtw
from sklearn.metrics import mean_squared_error

from scripts.grid_detection import get_grid_square_size
from scripts.extract_wave_tflite import WaveExtractor
from scripts.digititze import process_ecg_mask
from scripts.lead_segmentation_tflite import init_model as init_lead_model, inference_and_label_and_crop


2025-07-29 14:52:26.254779: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753782747.295986   61387 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753782747.521055   61387 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753782749.591810   61387 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753782749.591918   61387 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753782749.591931   61387 computation_placer.cc:177] computation placer alr

In [2]:
# --- Load configs ---
with open('./configs/lead_segmentation.yaml', 'r') as f:
    lead_cfg = yaml.safe_load(f)
with open('./configs/wave_extraction.yaml', 'r') as f:
    wave_cfg = yaml.safe_load(f)
with open('./configs/grid_detection.yaml', 'r') as f:
    grid_cfg = yaml.safe_load(f)
with open('./configs/digitize.yaml', 'r') as f:
    digitize_cfg = yaml.safe_load(f)

# --- Paths ---
INPUT_ROOT = '../data/digitization-dataset/digitization-dataset'
CROPPED_SAVE_DIR = lead_cfg['output_dir']
GRID_KERNEL = grid_cfg.get('closing_kernel', 10)
GRID_LENGTH_FRAC = grid_cfg.get('length_frac', 0.05)
WAVE_WEIGHTS_PATH = wave_cfg['weights_path']
WAVE_DEVICE = wave_cfg.get('device', 'cpu')
FINAL_OUTPUT_DIR = './data/test'
YOLO_WEIGHTS_PATH = lead_cfg['model_path']
os.makedirs(CROPPED_SAVE_DIR, exist_ok=True)
os.makedirs(FINAL_OUTPUT_DIR, exist_ok=True)

In [3]:

# --- Init models once ---
wave_extractor = WaveExtractor(WAVE_WEIGHTS_PATH, device=WAVE_DEVICE)
lead_model     = init_lead_model(YOLO_WEIGHTS_PATH)


initializing wave extractor...
Wave extractor initialized.


    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


In [4]:

# --- Evaluation helpers ---
def match_length(a, b):
    n = min(len(a), len(b))
    return a[:n], b[:n]

def scalar_euclid(a, b):
    # convert floats to 1‑element 1D arrays
    return euclidean([a], [b])


def evaluate_signals(pred, gt):
    p = np.ravel(pred).astype(float)
    g = np.ravel(gt).astype(float)
    p, g = match_length(p, g)
    # MSE & Pearson
    mse  = mean_squared_error(g, p)
    corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
    # DTW on plain lists
    g = g.tolist()
    p = p.tolist()

    dtw_dist, _ = fastdtw(g, p, dist=scalar_euclid)
    return {"MSE": mse, "Correlation": corr, "DTW": dtw_dist}


In [5]:
# --- Main evaluation loop ---
lead_names = ["I","II","III","aVR","aVL","aVF","V1","V2","V3","V4","V5","V6"]
folders = [d for d in os.listdir(INPUT_ROOT) if '6by2' in d and os.path.isdir(os.path.join(INPUT_ROOT,d))]
total_folders = len(folders)
total_slots   = total_folders * len(lead_names)
detected_count= 0
per_lead_detect = {ln:0 for ln in lead_names}
metrics_detected = []
metrics_all      = []

start = time.time()
for idx, fld in enumerate(folders,1):
    if idx%10==0:
        print(f"{idx}/{total_folders} folders processed…")
    fld_path = os.path.join(INPUT_ROOT, fld)
    img_path = os.path.join(fld_path, f"{fld}.jpg")
    img_bgr  = cv2.imread(img_path)
    if img_bgr is None:
        print(f"⚠️ Missing image: {img_path}")
        continue

    # 1) lead segmentation
    crops = inference_and_label_and_crop(lead_model, img_bgr, conf_threshold=lead_cfg['conf_threshold'])
    # crops: list of (crop_img, label)

    # 2) grid detect & digitize each crop
    lead_waveforms = {}
    lead_to_size = {}
    for crop_img, label in crops:
        # a) grid square
        sq = get_grid_square_size(crop_img, closing_kernel=GRID_KERNEL, length_frac=GRID_LENGTH_FRAC)
        lead_to_size[label] = sq

        # b) extract mask
        pil_crop = Image.fromarray(cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB))
        mask = wave_extractor.extract_wave(pil_crop)
        mask_np = np.array(mask)

        # c) resize mask back to original crop size
        h,w = crop_img.shape[:2]
        mask_rs = cv2.resize(mask_np, (w,h), interpolation=cv2.INTER_NEAREST)

        # d) digitize
        wf = process_ecg_mask(mask_rs, sq)
        lead_waveforms[label] = wf

    # 3) compare to JSON ground truth
    results_det = {}
    results_all = {}
    for i_lead, lead_label in enumerate(lead_names):
        gt_file = os.path.join(fld_path, f"{fld}_lead_{i_lead}.json")
        if not os.path.isfile(gt_file):
            continue
        gt_wave = np.array(json.load(open(gt_file)))

        # find matching digitized
        pred = None
        for lbl,wf in lead_waveforms.items():
            if lbl.lower()==lead_label.lower():
                pred = wf; break

        # detected-only
        if pred is not None:
            detected_count += 1
            per_lead_detect[lead_label] += 1
            m = evaluate_signals(pred, gt_wave)
            results_det[lead_label] = m
            metrics_detected.append(m)
        # zeros-for-missed
        p0 = pred if pred is not None else np.zeros_like(gt_wave,dtype=float)
        m_all = evaluate_signals(p0, gt_wave)
        results_all[lead_label] = m_all
        metrics_all.append(m_all)

    # print per-folder
    print(f"\nFolder {fld}: {len(results_det)}/{len(lead_names)} leads detected")
    for ln,m in results_det.items():
        print(f"  {ln}: MSE={m['MSE']:.4f}, Corr={m['Correlation']:.3f}, DTW={m['DTW']:.1f}")


Input shape to TFLite model: (1, 640, 640, 3)
[INFO] processed signal length: 1712 samples
[INFO] processed signal length: 1716 samples
[INFO] processed signal length: 1931 samples
[INFO] processed signal length: 2013 samples
[INFO] processed signal length: 2104 samples
[INFO] processed signal length: 1700 samples


  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan



Folder 1670_6by2: 6/12 leads detected
  I: MSE=0.0265, Corr=-0.135, DTW=24.0
  II: MSE=0.0577, Corr=0.029, DTW=60.2
  III: MSE=0.0843, Corr=-0.131, DTW=78.1
  aVR: MSE=0.0468, Corr=0.036, DTW=57.0
  aVL: MSE=0.0302, Corr=-0.029, DTW=32.6
  aVF: MSE=0.0306, Corr=0.049, DTW=27.7
Input shape to TFLite model: (1, 640, 640, 3)
[INFO] processed signal length: 1935 samples
[INFO] processed signal length: 1969 samples
[INFO] processed signal length: 1950 samples
[INFO] processed signal length: 1952 samples
[INFO] processed signal length: 1922 samples


  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan



Folder 2_6by2: 5/12 leads detected
  I: MSE=0.0281, Corr=0.030, DTW=24.9
  II: MSE=0.0233, Corr=0.242, DTW=39.5
  aVR: MSE=0.0300, Corr=-0.005, DTW=36.4
  aVL: MSE=0.0303, Corr=-0.115, DTW=35.5
  aVF: MSE=0.0149, Corr=-0.132, DTW=29.0
Input shape to TFLite model: (1, 640, 640, 3)
[INFO] processed signal length: 1888 samples
[INFO] processed signal length: 1795 samples
[INFO] processed signal length: 1888 samples
[INFO] processed signal length: 1879 samples
[INFO] processed signal length: 1886 samples
[INFO] processed signal length: 2031 samples


  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan



Folder 1234_6by2: 6/12 leads detected
  I: MSE=0.0086, Corr=0.088, DTW=20.3
  II: MSE=0.0420, Corr=0.011, DTW=55.5
  III: MSE=0.0297, Corr=0.120, DTW=36.5
  aVR: MSE=0.0249, Corr=-0.224, DTW=41.5
  aVL: MSE=0.1345, Corr=0.071, DTW=110.3
  aVF: MSE=0.0349, Corr=0.072, DTW=43.5
Input shape to TFLite model: (1, 640, 640, 3)
[INFO] processed signal length: 1909 samples
[INFO] processed signal length: 1891 samples
[INFO] processed signal length: 1985 samples
[INFO] processed signal length: 1914 samples
[INFO] processed signal length: 1983 samples


  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan
  corr = pearsonr(g, p)[0] if len(g) > 2 else np.nan



Folder 1738_6by2: 5/12 leads detected
  II: MSE=0.0274, Corr=0.133, DTW=32.3
  III: MSE=0.0314, Corr=0.115, DTW=37.9
  aVR: MSE=0.0359, Corr=0.009, DTW=42.8
  aVL: MSE=0.0334, Corr=-0.031, DTW=35.4
  aVF: MSE=0.0408, Corr=0.083, DTW=39.2
Input shape to TFLite model: (1, 640, 640, 3)


RuntimeError: Detected only 11 boxes, expected 12.

In [None]:

# 4) final summary
print("\n=== Detection Rates ===")
print(f"Overall: {detected_count}/{total_slots} = {detected_count/total_slots:.1%}")
for ln in lead_names:
    r = per_lead_detect[ln]/total_folders
    print(f"  {ln}: {per_lead_detect[ln]}/{total_folders} = {r:.1%}")

def summarize(lst):
    keys = lst[0].keys()
    s={}
    for k in keys:
        arr = np.array([m[k] for m in lst])
        s[k] = (np.nanmean(arr), np.nanstd(arr))
    return s

sd = summarize(metrics_detected)
sa = summarize(metrics_all)
print("\n=== Metrics (detected-only) ===")
for k,(mu,sdv) in sd.items():
    print(f"{k}: {mu:.4f} ± {sdv:.4f}")
print("\n=== Metrics (all leads, zeros) ===")
for k,(mu,sdv) in sa.items():
    print(f"{k}: {mu:.4f} ± {sdv:.4f}")

print(f"\nTotal time: {time.time()-start:.1f}s")