In [1]:
import os
import time
import json
import yaml
import cv2
import numpy as np

from PIL import Image
from scipy.spatial.distance import euclidean
from scipy.stats import pearsonr
from fastdtw import fastdtw
from sklearn.metrics import mean_squared_error

import neurokit2 as nk

from scripts.grid_detection import get_grid_square_size
from scripts.digititze import process_ecg_mask
from scripts.lead_segmentation import init_model as init_lead_model, inference_and_label_and_crop
from scripts.extract_wave import WaveExtractor

# --- Load configs ---
with open('./configs/lead_segmentation.yaml') as f:
    lead_cfg = yaml.safe_load(f)
with open('./configs/wave_extraction.yaml') as f:
    wave_cfg = yaml.safe_load(f)
with open('./configs/grid_detection.yaml') as f:
    grid_cfg = yaml.safe_load(f)
with open('./configs/digitize.yaml') as f:
    digitize_cfg = yaml.safe_load(f)

# --- Paths & parameters ---
INPUT_ROOT        = '../data/digitization-dataset/digitization-dataset'
CROPPED_SAVE_DIR  = './data/test'
FINAL_OUTPUT_DIR  = './data/eval'
GRID_KERNEL       = grid_cfg.get('closing_kernel', 10)
GRID_LENGTH_FRAC  = grid_cfg.get('length_frac', 0.05)
WAVE_WEIGHTS_PATH = wave_cfg['weights_path']
WAVE_DEVICE       = wave_cfg.get('device', 'cpu')
YOLO_WEIGHTS_PATH = lead_cfg['model_path']
os.makedirs(CROPPED_SAVE_DIR, exist_ok=True)
os.makedirs(FINAL_OUTPUT_DIR, exist_ok=True)

# --- Lead names in 6×2 layout order ---
LEAD_NAMES = ["I","II","III","aVR","aVL","aVF","V1","V2","V3","V4","V5","V6"]


In [2]:

# --- Load models ---
lead_model     = init_lead_model(lead_cfg['model_path'])
wave_extractor = WaveExtractor(WAVE_WEIGHTS_PATH, device=WAVE_DEVICE)


initializing wave extractor...
Wave extractor initialized.


In [3]:

# --- Helper functions ---
def match_length(a, b):
    n = min(len(a), len(b))
    return a[:n], b[:n]

def scalar_euclid(x, y):
    return euclidean([x], [y])

def evaluate_signals(pred, gt):
    p = np.ravel(pred).astype(float)
    g = np.ravel(gt ).astype(float)
    p, g = match_length(p, g)
    mse  = mean_squared_error(g, p)
    corr = pearsonr(g, p)[0] if len(g)>2 else np.nan
    dtw, _ = fastdtw(g.tolist(), p.tolist(), dist=scalar_euclid)
    return {"MSE": mse, "Correlation": corr, "DTW": dtw}

def extract_pqrst_avg(signal, fs=400):
    """Return average P, Q, R, S, T amplitudes via neurokit2 peaks/delineation."""
    try:
        cleaned = nk.ecg_clean(signal, sampling_rate=fs)
        peaks, _ = nk.ecg_peaks(cleaned, sampling_rate=fs)
        dline, _= nk.ecg_delineate(cleaned, peaks["ECG_R_Peaks"], sampling_rate=fs, method="dwt")
        # gather mean amplitudes at annotated indices
        amp = {}
        for wave, key in [("ECG_P_Peaks","P"),("ECG_Q_Peaks","Q"),("ECG_R_Peaks","R"),
                          ("ECG_S_Peaks","S"),("ECG_T_Peaks","T")]:
            idx = dline.index[dline[wave]==1]
            amp[key] = cleaned[idx].mean() if len(idx)>0 else np.nan
        return amp
    except Exception:
        return {"P":np.nan,"Q":np.nan,"R":np.nan,"S":np.nan,"T":np.nan}

In [None]:


# --- Main evaluation loop ---
folders       = [d for d in os.listdir(INPUT_ROOT)
                 if '6by2' in d and os.path.isdir(os.path.join(INPUT_ROOT,d))]
total_folders = len(folders)
total_slots   = total_folders * len(LEAD_NAMES)

detected_count  = 0
per_lead_detect = {ln:0 for ln in LEAD_NAMES}
metrics_detected = []
metrics_all      = []
pqrst_errors     = []

start = time.time()
for idx, fld in enumerate(folders,1):
    if idx%10==0:
        print(f"{idx}/{total_folders} processed…")
    fld_path = os.path.join(INPUT_ROOT, fld)
    img_path = os.path.join(fld_path, f"{fld}.jpg")
    if not os.path.isfile(img_path):
        print(f"⚠️ Missing image for {fld}")
        continue

    # try:
    # 1) Lead segmentation
    crops,_ = inference_and_label_and_crop(
        lead_model, img_path, CROPPED_SAVE_DIR,
        conf_threshold=lead_cfg['conf_threshold']
    )
    # save and collect
    all_crops = []
    for crop_img,label in crops:
        crop_path = os.path.join(CROPPED_SAVE_DIR, f"{fld}_{label}.jpg")
        cv2.imwrite(crop_path, crop_img)
        all_crops.append((crop_path,label))

    # 2) Grid detection
    lead_to_sq   = {}
    for cp,label in all_crops:
        img = cv2.imread(cp)
        lead_to_sq[label] = get_grid_square_size(
            img, closing_kernel=GRID_KERNEL,
            length_frac=GRID_LENGTH_FRAC
        )

    # 3) Wave extraction
    lead_to_mask = {}
    for cp,label in all_crops:
        lead_to_mask[label] = wave_extractor.extract_wave(cp)

    # 4) Digitization
    lead_waveforms = {}
    for cp,label in all_crops:
        sq = lead_to_sq.get(label)
        mask = lead_to_mask.get(label)
        if sq is None or mask is None: continue
        wf = process_ecg_mask(mask, sq)
        lead_waveforms[label] = wf

    # 5) Evaluation
    fs_target = 400
    duration_s = 2.0
    n_target = int(fs_target * duration_s)
    t_new = np.linspace(0, duration_s, n_target)

    for i, lead_label in enumerate(LEAD_NAMES):
        gt_file = os.path.join(fld_path, f"{fld}_lead_{i}.json")
        if not os.path.isfile(gt_file):
            continue

        # Load and resample GT to 400 Hz over 2 s
        raw = np.array(json.load(open(gt_file)))
        t_orig = np.linspace(0, duration_s, raw.size)
        gt_wave = np.interp(t_new, t_orig, raw)

        pred = lead_waveforms.get(lead_label)
        # Resample pred if present
        if pred is not None:
            t_pred = np.linspace(0, duration_s, pred.size)
            pred_rs = np.interp(t_new, t_pred, pred)
        else:
            pred_rs = None

        # 2) Waveform metrics
        try:
            if pred_rs is not None:
                m = evaluate_signals(pred_rs, gt_wave)
                metrics_detected.append(m)
                detected_count += 1
                per_lead_detect[lead_label] += 1
            p0 = pred_rs if pred_rs is not None else np.zeros_like(gt_wave, dtype=float)
            m_all = evaluate_signals(p0, gt_wave)
            metrics_all.append(m_all)
        except Exception as e:
            print(f"⚠️ Skipping metrics for {fld}/{lead_label}: {e}")

    # after loop
    print(f"📁 {fld}: detected {detected_count} leads so far")

    # except Exception as e:
    #     print(f"❌ Error on folder {fld}: {e}")
    #     continue

end = time.time()
print(f"\nDone in {end-start:.1f}s")
print(f"Total slots: {total_slots}, detected: {detected_count}")

# Optionally: aggregate and print global stats
avg_mse  = np.nanmean([m["MSE"] for m in metrics_all])
avg_corr = np.nanmean([m["Correlation"] for m in metrics_all])
print(f"Overall MSE={avg_mse:.4f}, Corr={avg_corr:.3f}")



image 1/1 /home/abdullah-bin-mansoor/Desktop/ECG Project/full-pipeline/../data/digitization-dataset/digitization-dataset/1670_6by2/1670_6by2.jpg: 320x640 12 lead_containers, 1 label_II, 1 label_III, 1 label_aVR, 1 label_aVL, 1 label_aVF, 1 label_V1, 1 label_V2, 1 label_V3, 1 label_V4, 1 label_V5, 1 label_V6, 733.7ms
Speed: 62.5ms preprocess, 733.7ms inference, 158.2ms postprocess per image at shape (1, 3, 320, 640)
📁 1670_6by2: detected 12 leads so far

image 1/1 /home/abdullah-bin-mansoor/Desktop/ECG Project/full-pipeline/../data/digitization-dataset/digitization-dataset/2_6by2/2_6by2.jpg: 320x640 13 lead_containers, 1 label_II, 1 label_III, 1 label_aVR, 1 label_aVL, 1 label_aVF, 1 label_V1, 1 label_V2, 1 label_V3, 1 label_V4, 1 label_V5, 1 label_V6, 10048.6ms
Speed: 709.3ms preprocess, 10048.6ms inference, 626.1ms postprocess per image at shape (1, 3, 320, 640)
📁 2_6by2: detected 24 leads so far

image 1/1 /home/abdullah-bin-mansoor/Desktop/ECG Project/full-pipeline/../data/digitiza

  corr = pearsonr(g, p)[0] if len(g)>2 else np.nan
  corr = pearsonr(g, p)[0] if len(g)>2 else np.nan
  corr = pearsonr(g, p)[0] if len(g)>2 else np.nan
  corr = pearsonr(g, p)[0] if len(g)>2 else np.nan
  corr = pearsonr(g, p)[0] if len(g)>2 else np.nan
  corr = pearsonr(g, p)[0] if len(g)>2 else np.nan
  corr = pearsonr(g, p)[0] if len(g)>2 else np.nan
  corr = pearsonr(g, p)[0] if len(g)>2 else np.nan
  corr = pearsonr(g, p)[0] if len(g)>2 else np.nan
  corr = pearsonr(g, p)[0] if len(g)>2 else np.nan
  corr = pearsonr(g, p)[0] if len(g)>2 else np.nan
  corr = pearsonr(g, p)[0] if len(g)>2 else np.nan


📁 66_6by2: detected 48 leads so far

image 1/1 /home/abdullah-bin-mansoor/Desktop/ECG Project/full-pipeline/../data/digitization-dataset/digitization-dataset/1630_6by2/1630_6by2.jpg: 320x640 12 lead_containers, 1 label_II, 1 label_III, 1 label_aVR, 1 label_aVL, 1 label_aVF, 1 label_V1, 1 label_V2, 1 label_V3, 1 label_V4, 1 label_V5, 1 label_V6, 255.2ms
Speed: 3.5ms preprocess, 255.2ms inference, 1.7ms postprocess per image at shape (1, 3, 320, 640)
📁 1630_6by2: detected 60 leads so far

image 1/1 /home/abdullah-bin-mansoor/Desktop/ECG Project/full-pipeline/../data/digitization-dataset/digitization-dataset/578_6by2/578_6by2.jpg: 320x640 13 lead_containers, 1 label_II, 1 label_III, 1 label_aVR, 1 label_aVL, 1 label_aVF, 1 label_V1, 1 label_V2, 1 label_V3, 1 label_V4, 1 label_V5, 1 label_V6, 286.6ms
Speed: 5.7ms preprocess, 286.6ms inference, 1.3ms postprocess per image at shape (1, 3, 320, 640)
📁 578_6by2: detected 72 leads so far

image 1/1 /home/abdullah-bin-mansoor/Desktop/ECG Project