# In [None]:
# =======================================================================
# ECG Image Digitization — FULL CORRECTED CELL (process *all* images)
# - New: CFG.SCOPE = "all" | "test" | "train"  (default "all")
# - Scans the entire dataset root recursively for any image extensions
# - Processes TRAIN + TEST + any stray folders when SCOPE="all"
# - Saves artifacts to /kaggle/working/out
# - ALWAYS writes /kaggle/working/submission.csv
# =======================================================================
import os, sys, glob, json, math, re, warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
import pandas as pd
import cv2
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from scipy.ndimage import gaussian_filter1d, median_filter
from scipy.signal import butter, filtfilt, find_peaks, welch
from scipy.interpolate import PchipInterpolator

warnings.filterwarnings("ignore")

# --------------------------
# Config
# --------------------------
@dataclass
class CFG:
    """Run-wide settings for the ECG image digitization pipeline."""

    ROOT: str = "/kaggle/input/physionet-ecg-image-digitization"  # auto-detected if empty/missing
    OUT_DIR: str = "/kaggle/working/out"
    SUBMISSION_PARQUET: str = "/kaggle/working/submission.parquet"
    SUBMISSION_CSV: str = "/kaggle/working/submission.csv"

    # ECG paper / signal
    SWEEP_MM_PER_S: float = 25.0      # paper speed used to convert px -> seconds
    GAIN_MM_PER_MV: float = 10.0      # vertical gain used to convert px -> millivolts
    TARGET_FS: int = 500              # output sampling rate (Hz)

    # Layout & names. None -> the 12 standard lead names (filled in __post_init__).
    LEAD_NAMES_12: Optional[List[str]] = None

    # Batch behavior
    OVERWRITE: bool = False           # set True to recompute even if outputs exist
    SCOPE: str = "all"                # "all" | "test" | "train"

    # Strict redo thresholds
    STRICT_REDO_TRIGGER_POOR_COUNT: int = 8
    STRICT_REDO_TRIGGER_NO_GOOD: bool = True

    # Seam refinement: extra ribbon width (px) added in strict mode
    STRICT_BAND_EXTRA: int = 4

    # Scale search. NOTE(review): neither field is referenced by the code shown here.
    MM_PER_PX_Y_GUESS: float = 0.1
    SCALES_MM_PER_PX_X_BASE: Tuple[float, ...] = (1/12, 1/11, 1/10, 1/9, 1/8)

    def __post_init__(self):
        # Give every instance its own list: fill the default when omitted and copy a
        # caller-supplied list so later mutation of either side cannot leak across.
        if self.LEAD_NAMES_12 is None:
            self.LEAD_NAMES_12 = ["I","II","III","aVR","aVL","aVF","V1","V2","V3","V4","V5","V6"]
        else:
            self.LEAD_NAMES_12 = list(self.LEAD_NAMES_12)

# Rebind the name CFG to a single module-wide config instance; later code reads and
# reassigns its attributes (e.g. CFG.ROOT after auto-detection).
CFG = CFG()
# Create the artifact directory up front so later writes into it cannot fail on a missing dir.
os.makedirs(CFG.OUT_DIR, exist_ok=True)

# --------------------------
# ROOT auto-detection & diagnostics
# --------------------------
def auto_detect_root(preferred: str) -> str:
    """Return the dataset root, keeping *preferred* when it already has data.

    A directory "has data" when it contains ``test.csv`` or ``sample_submission.csv`` at
    its top level, or at least one png/jpg/jpeg/tif/tiff/bmp file anywhere below it.
    If *preferred* does not qualify, every qualifying directory directly under
    ``/kaggle/input`` is scored: +2 per name token (physionet/ecg/digit/image/cardio) and
    +3 for a top-level ``test.csv``. The highest score wins; ties go to the alphabetically
    first path. When nothing qualifies, *preferred* is returned unchanged.
    """
    image_exts = (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp")

    def has_data(root):
        if not os.path.isdir(root):
            return False
        # The CSV checks are two stat calls; do them before walking the tree.
        if os.path.exists(os.path.join(root, "test.csv")):
            return True
        if os.path.exists(os.path.join(root, "sample_submission.csv")):
            return True
        # Stop at the first image instead of globbing the whole tree once per extension.
        # Hidden files/dirs are skipped to match glob("**/*.ext") semantics.
        for _dirpath, dirnames, filenames in os.walk(root, followlinks=True):
            dirnames[:] = [d for d in dirnames if not d.startswith(".")]
            if any(f.endswith(image_exts) and not f.startswith(".") for f in filenames):
                return True
        return False

    if has_data(preferred):
        return preferred

    def score_path(p):
        name = os.path.basename(p).lower()
        score = 2 * sum(token in name for token in ("physionet", "ecg", "digit", "image", "cardio"))
        if os.path.exists(os.path.join(p, "test.csv")):
            score += 3
        return score

    candidates = [d for d in sorted(glob.glob("/kaggle/input/*")) if os.path.isdir(d) and has_data(d)]
    if candidates:
        # max() keeps the first maximal element, i.e. the alphabetically first on ties.
        best = max(candidates, key=score_path)
        print(f"[auto-detect] Using dataset root: {best}")
        return best

    print("[auto-detect] No dataset found under /kaggle/input — proceeding without images.")
    return preferred  # will produce dummy submission if empty

# Swap in an auto-detected dataset root when the configured one is missing or holds no data.
CFG.ROOT = auto_detect_root(CFG.ROOT)

def list_ext_counts(root: str) -> Dict[str,int]:
    """Count image files below *root* (recursively), keyed by lower-case extension.

    Every supported extension appears in the result, with 0 when nothing matches.
    """
    return {
        suffix: len(glob.glob(os.path.join(root, "**", f"*.{suffix}"), recursive=True))
        for suffix in ("png", "jpg", "jpeg", "tif", "tiff", "bmp")
    }

# Echo the resolved root, per-extension image counts and scope so a misconfigured run is
# visible in the notebook log before any processing starts.
print("Dataset root:", CFG.ROOT)
print("Image counts by extension:", list_ext_counts(CFG.ROOT))
print("Processing scope:", CFG.SCOPE)

# --------------------------
# Helpers (IO & listing)
# --------------------------
def find_all_images(root: str) -> List[str]:
    """Return every image path below *root* (any depth), de-duplicated and sorted.

    Matches png/jpg/jpeg/tif/tiff/bmp regardless of train/test placement.
    """
    found = set()
    for pattern in ("*.png", "*.jpg", "*.jpeg", "*.tif", "*.tiff", "*.bmp"):
        found.update(glob.glob(os.path.join(root, "**", pattern), recursive=True))
    return sorted(found)

def collect_images_from(dirpath: str) -> List[str]:
    """Return the images directly inside *dirpath* (no recursion), unique and sorted.

    Used for explicit train/ and test/ folders.
    """
    patterns = ("*.png", "*.jpg", "*.jpeg", "*.tif", "*.tiff", "*.bmp")
    hits = [match for pat in patterns for match in glob.glob(os.path.join(dirpath, pat))]
    return sorted(set(hits))

def save_preview_grid(leads: Dict[str, np.ndarray], fs: int, png_path: str):
    """Render recovered traces to a PNG preview at *png_path*.

    With exactly one lead a single 10x3 in. strip is drawn (an empty lead yields an empty,
    titled axis instead of raising). Otherwise a 3x4 grid is filled in
    ``CFG.LEAD_NAMES_12`` order: missing or empty leads get a blank axis, and keys not in
    ``CFG.LEAD_NAMES_12`` are not drawn. Time is ``sample_index / fs`` seconds, and the
    x-axis is clipped to at most 11 s.
    """
    def _clip_xlim(ax, t):
        # A single sample would set identical limits (matplotlib warns); only clip a real span.
        if len(t) > 1:
            ax.set_xlim(0, min(11, t[-1]))

    names = list(leads.keys())
    if len(names) == 1:
        nm = names[0]
        y = np.asarray(leads[nm])
        fig, ax = plt.subplots(figsize=(10, 3))
        try:
            if len(y) > 0:
                t = np.arange(len(y)) / fs
                ax.plot(t, y, linewidth=0.9)
                _clip_xlim(ax, t)
            ax.set_title(f"{nm}  len={len(y)}  fs={fs}Hz")
            ax.grid(True, alpha=0.3)
            fig.savefig(png_path, dpi=160, bbox_inches="tight")
        finally:
            # Always release the figure, even if saving fails, to avoid leaking memory in batch runs.
            plt.close(fig)
        return

    fig, axes = plt.subplots(3, 4, figsize=(12, 6), constrained_layout=True)
    try:
        for ax, nm in zip(axes.ravel(), CFG.LEAD_NAMES_12):
            if nm in leads and len(leads[nm]) > 0:
                y = leads[nm]
                t = np.arange(len(y)) / fs
                ax.plot(t, y, linewidth=0.8)
                ax.grid(True, alpha=0.3)
                _clip_xlim(ax, t)
                ax.set_title(nm, fontsize=9)
            else:
                ax.axis("off")
        fig.suptitle("Recovered ECG", fontsize=12)
        fig.savefig(png_path, dpi=160, bbox_inches="tight")
    finally:
        plt.close(fig)

def npz_exists(prefix: str) -> bool:
    """Return True only when <prefix>.npz, <prefix>.json and <prefix>.png are all present."""
    for suffix in (".npz", ".json", ".png"):
        if not os.path.exists(prefix + suffix):
            return False
    return True

# ==========================================
# DFEE, RDSE, BioShed core
# ==========================================
# Module-level mutable switches shared by the pipeline functions:
#   "strict"  - read by degrid_color_to_gray and _adaptive_band_px; set per image by the batch loop.
#   "row_idx" - set per panel by process_image_to_leads; nothing in the code shown here reads it.
_PATCH = {"strict": False, "row_idx": None}

def _auto_contrast(gray: np.ndarray) -> np.ndarray:
    p1,p99 = np.percentile(gray, [1,99])
    return np.clip((gray - p1) * (255.0/max(1.0,(p99-p1))), 0,255).astype(np.uint8)

def _rotate_bound(img: np.ndarray, angle_deg: float) -> np.ndarray:
    """Rotate *img* about its centre by *angle_deg* without clipping any corner.

    The output canvas is enlarged to the rotated bounding box, and uncovered pixels are
    filled with white. The sign convention is that of ``cv2.getRotationMatrix2D``.
    """
    h, w = img.shape[:2]
    cX, cY = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D((cX, cY), angle_deg, 1.0)
    cos = abs(M[0, 0]); sin = abs(M[0, 1])
    nW = int((h * sin) + (w * cos)); nH = int((h * cos) + (w * sin))
    # Shift so the original centre lands at the centre of the enlarged canvas.
    M[0, 2] += (nW / 2) - cX; M[1, 2] += (nH / 2) - cY
    # A 3-tuple paints every colour channel white; a bare 255 would only set the first
    # (blue) channel of a BGR image. Single-channel images use just the first value.
    return cv2.warpAffine(img, M, (nW, nH), flags=cv2.INTER_LINEAR, borderValue=(255, 255, 255))

def dfee_angle_and_pitch(gray: np.ndarray) -> Tuple[float, float, float]:
    """Estimate page skew and grid pitch from a grayscale ECG image.

    Skew: Hough lines on Canny edges of the contrast-stretched image. Lines within 2 deg
    of vertical or horizontal vote with their deviation. The median deviation is used only
    when it is under 5 deg; otherwise the angle is 0.

    Pitch: the central row/column of the shifted FFT magnitude is smoothed. Its strongest
    bin at least 10 bins away from DC gives a period in pixels, which is interpreted as px
    per mm (assumes the dominant periodicity is the 1 mm grid — unverified). Values outside
    4..20 px/mm, or images too small to search, fall back to 10 px/mm.

    Returns:
        (angle_deg, mm_per_px_x, mm_per_px_y)
    """
    g = _auto_contrast(gray)
    edges = cv2.Canny(g, 50, 150)
    lines = cv2.HoughLines(edges, 1, np.pi/180, threshold=max(150, int(0.0005*g.size)))
    angle_deg = 0.0
    if lines is not None:
        angs = []
        for _rho, theta in lines[:, 0, :]:
            deg = (theta*180/np.pi) % 180
            if deg < 2 or deg > 178:
                angs.append(deg if deg < 2 else deg-180)   # near-vertical line
            elif 88 < deg < 92:
                angs.append(deg-90)                        # near-horizontal line
        if angs:
            m = float(np.median(angs))
            if abs(m) < 5.0:
                angle_deg = m

    F = np.fft.fft2(g.astype(np.float32))
    P = np.fft.fftshift(np.abs(F)); P /= (P.max()+1e-6)
    vx = P[P.shape[0]//2, :]; vy = P[:, P.shape[1]//2]

    def dominant_period(sig, total_len, min_offset=10):
        sg = gaussian_filter1d(sig, 5)
        mid = len(sg)//2                 # DC bin after fftshift
        tail = sg[mid+min_offset:]       # skip the low-frequency guard band around DC
        if tail.size == 0:
            # Too few bins beyond the guard band (tiny image): argmax would raise.
            return None
        k = int(np.argmax(tail)) + min_offset
        return total_len / k

    px_per_mm_x = dominant_period(vx, gray.shape[1])
    px_per_mm_y = dominant_period(vy, gray.shape[0])
    if px_per_mm_x is None or not (4 <= px_per_mm_x <= 20): px_per_mm_x = 10.0
    if px_per_mm_y is None or not (4 <= px_per_mm_y <= 20): px_per_mm_y = 10.0

    return angle_deg, 1.0/px_per_mm_x, 1.0/px_per_mm_y

def degrid_color_to_gray(bgr: np.ndarray) -> np.ndarray:
    """Convert a BGR scan to uint8 grayscale while fading the red/pink grid toward white.

    None and non-3-D (already grayscale) input are returned unchanged; an alpha channel,
    if present, is ignored.

    The darkest channel (min of B, G, R) keeps the black trace dark. Pixels whose red
    channel exceeds the mean of G and B ("red excess", normalised to [0, 1]) are lifted by
    ``alpha`` of their remaining headroom toward the brightest base value (alpha = 0.5, or
    0.55 in strict mode). The result is blurred 3x3 and min-max scaled to 0..255.
    """
    if bgr is None or bgr.ndim != 3:
        return bgr
    B, G, R = cv2.split(bgr.astype(np.float32))[:3]
    base = np.minimum(np.minimum(R, G), B)
    red_excess = np.clip(R - 0.5*(G + B), 0, None)
    alpha = 0.5 + (0.05 if _PATCH["strict"] else 0.0)
    red_mask = red_excess / (red_excess.max() + 1e-6)
    # Lift (not lower) grid pixels toward paper white: lowering them would make the grid
    # look like the dark trace that the downstream seam search is looking for.
    suppr = base + alpha * red_mask * (base.max() - base)
    suppr = cv2.GaussianBlur(suppr, (3,3), 0)
    suppr -= suppr.min(); suppr /= (suppr.max() + 1e-6)
    return (suppr*255.0).astype(np.uint8)

def preprocess(gray_or_bgr):
    """Turn a page/panel image into an enhanced single-channel uint8 image.

    3-D input is first degridded to gray. Then: percentile contrast stretch, CLAHE
    (clip 2.0, 8x8 tiles) and a 3x3 median filter.
    """
    if gray_or_bgr.ndim == 3:
        img = degrid_color_to_gray(gray_or_bgr)
    else:
        img = gray_or_bgr
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    enhanced = equalizer.apply(_auto_contrast(img))
    return median_filter(enhanced, size=3).astype(np.uint8)

def _seam_dp(cost, lam_smooth=0.0025, lam_curve=0.003):
    """Find a minimum-cost left-to-right path through a 2-D cost map by dynamic programming.

    The path visits one row per column and may move at most one row up or down between
    neighbouring columns. Every up/down step adds ``lam_smooth + lam_curve`` to the
    accumulated cost (both penalties fire on the same condition). The backtracked row
    sequence is Gaussian-smoothed (sigma=1) before returning.

    Args:
        cost: 2-D array indexed ``cost[row, col]``; the DP prefers low-cost cells.
        lam_smooth: penalty per non-horizontal step.
        lam_curve: additional penalty per non-horizontal step.

    Returns:
        float32 array of length W (number of columns) with the smoothed, fractional row of
        the path in each column.
    """
    H,W = cost.shape
    INF = 1e9
    C = np.full((H,W), INF, dtype=np.float32)   # accumulated path cost
    P = np.full((H,W), 0, dtype=np.int8)        # back-pointer: row offset (-1/0/+1) of predecessor
    C[:,0] = cost[:,0]
    for x in range(1,W):
        ccol = cost[:,x]
        # Candidate predecessors in column x-1: the row above (row 0 has none -> +1e6),
        # the same row, and the row below (last row has none -> +1e6).
        up = np.r_[C[0:1,x-1]+1e6, C[:-1,x-1]]
        mi = C[:,x-1]
        dn = np.r_[C[1:,x-1], C[-1:,x-1]+1e6]
        stack = np.stack([up,mi,dn], axis=0)
        idx = np.argmin(stack, axis=0)           # 0 = from above, 1 = same row, 2 = from below
        base = stack[idx, np.arange(H)]
        curve_pen = lam_curve * (idx!=1).astype(np.float32)
        C[:,x] = ccol + base + lam_smooth*np.abs(idx-1) + curve_pen
        P[:,x] = (idx-1).astype(np.int8)
    # Backtrack from the cheapest end row.
    y = int(np.argmin(C[:,-1]))
    path = np.empty(W, dtype=np.float32); path[-1]=y
    for x in range(W-1,0,-1):
        dy = int(P[y,x]); y = max(0, min(H-1, y+dy)); path[x-1]=y
    return gaussian_filter1d(path, sigma=1.0).astype(np.float32)

def seam_centerline_multi(g):
    """Consensus DP seam: per-column median of three seams with increasing smoothness.

    The per-pixel cost mixes min-max normalised intensity (50%, dark is cheap), absence of
    Canny edges (45%) and absence of x-direction Sobel gradient (5%). The median path is
    smoothed (sigma=1.2) and returned as float32 row positions per column.
    """
    img = g.astype(np.float32)
    img = (img - img.min())/max(1e-6,(img.max()-img.min()))
    edge_map = cv2.Canny((img*255).astype(np.uint8), 40, 120).astype(np.float32)/255.0
    gx = np.abs(cv2.Sobel(img, cv2.CV_32F, 1, 0, ksize=3))
    gx /= max(1e-6, gx.max())
    cost_map = 0.5*img + 0.45*(1.0-edge_map) + 0.05*(1.0-gx)
    penalty_sets = ((0.0015, 0.002), (0.0030, 0.003), (0.0060, 0.004))
    candidates = [_seam_dp(cost_map, lam_smooth=s, lam_curve=c) for s, c in penalty_sets]
    consensus = np.median(np.vstack(candidates), axis=0).astype(np.float32)
    return gaussian_filter1d(consensus, sigma=1.2).astype(np.float32)

def seam_centerline_simple(g):
    """Darkest-row-per-column trace estimate, smoothed with sigma=2.

    The panel is blurred 3x3. A 9x9 white top-hat response is then 40% subtracted before
    taking the per-column argmin.
    """
    blurred = cv2.GaussianBlur(g.astype(np.float32), (3,3), 0)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (9,9))
    bright_detail = cv2.morphologyEx(blurred.astype(np.uint8), cv2.MORPH_TOPHAT, kernel).astype(np.float32)
    darkest_rows = np.argmin(blurred - 0.4*bright_detail, axis=0).astype(np.float32)
    return gaussian_filter1d(darkest_rows, sigma=2.0).astype(np.float32)

def seam_centerline_ridge(g):
    """Trace estimate from the morphological black-hat (dark thin-stroke) response.

    A 7x7 black-hat (closing minus image) is large on dark strokes thinner than the kernel,
    which is where the ECG trace is. After subtracting each column's median response and a
    3x3 blur, the row with the strongest response per column is chosen. The resulting path
    is then smoothed (sigma=1.5).
    """
    k = cv2.getStructuringElement(cv2.MORPH_RECT, (7,7))
    bh = cv2.morphologyEx(g.astype(np.uint8), cv2.MORPH_BLACKHAT, k).astype(np.float32)
    medc = np.median(bh, axis=0, keepdims=True).astype(np.float32)
    hp = cv2.GaussianBlur(bh - medc, (3,3), 0)
    # The trace is where the black-hat response PEAKS; argmin would pick flat background.
    y = np.argmax(hp, axis=0).astype(np.float32)
    return gaussian_filter1d(y, sigma=1.5).astype(np.float32)

def is_flat(y):
    """True when *y* barely varies after heavy (sigma=20) Gaussian smoothing (std < 0.35)."""
    smoothed = gaussian_filter1d(y, 20)
    return np.std(smoothed) < 0.35

def _adaptive_band_px(panel_gray: np.ndarray) -> int:
    """Choose the seam-refinement ribbon width (px) for one panel.

    Starts at 20 px. It widens for low relative contrast (+4) or sparse Canny edges (+2)
    and narrows for high contrast (-2) or dense edges (-2). Strict mode adds
    ``CFG.STRICT_BAND_EXTRA``. The result is clipped to [12, 32].
    """
    _rows, _cols = panel_gray.shape   # keeps the 2-D shape requirement explicit
    rel_contrast = float(np.std(panel_gray) / (np.mean(panel_gray)+1e-6))
    edge_density = float(np.mean(cv2.Canny(panel_gray, 40, 120) > 0))
    adjustments = (
        (rel_contrast < 0.35, 4.0),
        (edge_density < 0.05, 2.0),
        (rel_contrast > 0.6, -2.0),
        (edge_density > 0.15, -2.0),
    )
    width = 20.0 + sum(delta for triggered, delta in adjustments if triggered)
    if _PATCH["strict"]:
        width += CFG.STRICT_BAND_EXTRA
    return int(np.clip(width, 12, 32))

def refine_seam_around_path(g, y_path, band_px=18):
    """Re-run the seam DP with a soft prior that keeps the path near *y_path*.

    cost = 0.7 * (0.5*intensity + 0.45*(1-Canny edges) + 0.05*(1-x-gradient))
         + 0.3 * (1 - Gaussian ribbon of width band_px centred on y_path, min-max scaled)
    """
    img = g.astype(np.float32)
    img = (img - img.min())/max(1e-6,(img.max()-img.min()))
    n_rows, _n_cols = img.shape
    edge_map = cv2.Canny((img*255).astype(np.uint8), 40, 120).astype(np.float32)/255.0
    gx = np.abs(cv2.Sobel(img, cv2.CV_32F, 1, 0, ksize=3))
    gx /= max(1e-6, gx.max())
    data_cost = 0.5*img + 0.45*(1.0-edge_map) + 0.05*(1.0-gx)
    row_coords = np.arange(n_rows, dtype=np.float32)[:, None]
    prior = np.exp(-0.5*((row_coords - y_path[None, :])/(band_px))**2)
    prior = (prior - prior.min())/(prior.max()-prior.min()+1e-6)
    combined = data_cost*0.7 + 0.3*(1.0 - prior)
    return _seam_dp(combined, lam_smooth=0.0025, lam_curve=0.003)

def to_timeseries_from_path(y_path: np.ndarray, H: int,
                            mm_per_px_x: float, mm_per_px_y: float) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a seam (row index per column) into (trace_mV, t_seconds), both float32.

    Amplitude is measured upward from the panel's vertical midline (H/2) using
    mm_per_px_y and CFG.GAIN_MM_PER_MV. Column i maps to time
    i * mm_per_px_x / CFG.SWEEP_MM_PER_S.
    """
    seconds_per_col = (mm_per_px_x / CFG.SWEEP_MM_PER_S)
    millivolts_per_row = (1.0 / CFG.GAIN_MM_PER_MV) / mm_per_px_y
    midline = H/2.0
    amplitude = ((midline - y_path) * millivolts_per_row).astype(np.float32)
    times = np.arange(len(y_path), dtype=np.float32) * seconds_per_col
    return amplitude, times

def butter_bandpass(sig, fs, low=0.67, high=35.0, order=4):
    """Zero-phase Butterworth band-pass of *sig* between *low* and *high* Hz.

    Uses second-order sections (SOS): with a low cut of 0.67 Hz at fs=500 the normalised
    cutoff is ~0.003, where the (b, a) polynomial form is numerically fragile. If *high*
    is at or above Nyquist, the upper edge is dropped and a high-pass at *low* is applied
    instead of raising.

    Raises:
        ValueError: from SciPy for invalid cutoffs (e.g. low >= Nyquist) or when *sig* is
        too short for the zero-phase padding. Callers in this file catch it.
    """
    from scipy.signal import sosfiltfilt
    nyq = 0.5*fs
    if high >= nyq:
        sos = butter(order, low/nyq, btype='highpass', output='sos')
    else:
        sos = butter(order, [low/nyq, high/nyq], btype='band', output='sos')
    return sosfiltfilt(sos, sig)

def hp_detrend(sig, fs, fc=0.5, order=2):
    """Remove baseline wander with a zero-phase Butterworth high-pass at *fc* Hz.

    Always returns a float32 array. If filtering fails (e.g. *sig* is shorter than
    filtfilt's padding), the input is passed through unchanged, converted to float32, as a
    deliberate best-effort fallback.
    """
    nyq = 0.5*fs
    b, a = butter(order, fc/nyq, btype='highpass')
    try:
        return filtfilt(b, a, sig).astype(np.float32)
    except Exception:
        # Narrower than the old bare `except:` (no longer swallows KeyboardInterrupt/SystemExit).
        return np.asarray(sig, dtype=np.float32)

def resample_to_fs(y_mV, t, fs):
    """Resample a (t, y) trace onto a uniform grid, band-pass it and remove baseline drift.

    The grid is 0..t[-1] in steps of 1/fs. If that yields fewer than 5 points, a
    linspace of 32..4096 points over the same span is used instead. PCHIP interpolation
    is applied without extrapolation; NaNs at either end take the nearest valid value, and
    an all-NaN result becomes zeros. A 0.67-35 Hz band-pass follows (skipped if it raises),
    then hp_detrend. Returns (filtered float32 signal, last grid time in s). Inputs shorter
    than 2 samples are returned as float32 with duration 0.0.
    """
    t = np.asarray(t, dtype=np.float32)
    y = np.asarray(y_mV, dtype=np.float32)
    if len(t) < 2:
        return y, 0.0
    t_end = t[-1]
    grid = np.arange(0, max(1e-6, t_end), 1.0/fs, dtype=np.float32)
    if len(grid) < 5:
        n_points = min(4096, max(32, len(y)))
        grid = np.linspace(0, max(1e-3, t_end), num=n_points, dtype=np.float32)
    resampled = PchipInterpolator(t, y, extrapolate=False)(grid).astype(np.float32)
    missing = np.isnan(resampled)
    if missing.any():
        valid = ~missing
        if not valid.any():
            resampled = np.zeros_like(resampled)
        else:
            lo = np.argmax(valid)
            hi = len(resampled)-1 - np.argmax(valid[::-1])
            resampled[:lo] = resampled[lo]
            resampled[hi+1:] = resampled[hi]
            resampled[missing] = np.interp(np.flatnonzero(missing), np.flatnonzero(valid), resampled[valid])
    try:
        filtered = butter_bandpass(resampled, fs=fs, low=0.67, high=35.0).astype(np.float32)
    except Exception:
        filtered = resampled
    filtered = hp_detrend(filtered, fs, fc=0.5)
    return filtered, float(grid[-1])

def snr_estimate(x, fs):
    """Rough ECG "SNR" in dB from a Welch power spectrum.

    Numerator: mean PSD in the 5-15 Hz band (QRS energy). Denominator: mean PSD over the
    remaining 0.5-5 Hz and 15-45 Hz bins. Bins below 0.5 Hz or above 45 Hz are ignored.
    Empty bands count as 1e-9. Returns 0.0 for traces shorter than 16 samples.
    """
    if len(x) < 16:
        return 0.0
    freqs, psd = welch(x, fs=fs, nperseg=min(1024, len(x)))
    in_qrs_band = (freqs >= 5) & (freqs <= 15)
    ignored = (freqs < 0.5) | (freqs > 45)
    reference = ~in_qrs_band & ~ignored
    p_signal = float(np.mean(psd[in_qrs_band])) if in_qrs_band.any() else 1e-9
    p_reference = float(np.mean(psd[reference])) if reference.any() else 1e-9
    return 10.0*np.log10((p_signal+1e-12)/(p_reference+1e-12))

def rpeaks_adaptive(x, fs):
    """Detect R-peak sample indices with an energy-envelope detector.

    The trace is band-passed to 5-15 Hz (raw trace if that raises), differentiated,
    squared and smoothed with a 120 ms moving average. Peaks above median + 1.8*std at
    least 280 ms apart are kept. If fewer than two are found, the search is relaxed to
    median + 1.2*std and 250 ms. If the trace still gives fewer than two peaks, the whole
    procedure is retried on the inverted trace.
    """
    def _detect(signal):
        try:
            filtered = butter_bandpass(signal, fs, 5, 15)
        except Exception:
            filtered = signal
        energy = np.square(np.diff(filtered, prepend=filtered[0]))
        width = max(1, int(0.12*fs))
        envelope = np.convolve(energy, np.ones(width)/width, mode='same')
        for k_std, min_gap_s in ((1.8, 0.28), (1.2, 0.25)):
            level = np.median(envelope) + k_std*np.std(envelope)
            found, _ = find_peaks(envelope, height=level, distance=int(min_gap_s*fs))
            if len(found) >= 2:
                break
        return found

    peaks = _detect(x)
    return peaks if len(peaks) >= 2 else _detect(-x)

def bioshed_metrics(x: np.ndarray, fs: int) -> Dict[str, Any]:
    """Quality summary for one resampled lead.

    Reports RR-interval mean/SDNN/count from rpeaks_adaptive, the fraction of samples whose
    high-frequency residual exceeds 3 std ("artifact_rate"), and snr_estimate. Quality is:
    "poor" if SNR < 0, artifacts > 15% or fewer than 2 RR intervals; otherwise "fair" if
    SNR < 3 or artifacts > 8%; otherwise "good".
    """
    peaks = rpeaks_adaptive(x, fs)
    rr_mean, rr_sdnn, rr_n = 0.0, 0.0, 0
    if len(peaks) >= 2:
        intervals = np.diff(peaks)/fs
        rr_mean = float(np.mean(intervals))
        rr_sdnn = float(np.std(intervals, ddof=1)) if len(intervals) > 1 else 0.0
        rr_n = int(len(intervals))
    residual = x - gaussian_filter1d(x, sigma=max(1, int(fs*0.04)))
    art = float(np.mean(np.abs(residual) > (3*np.std(residual)+1e-6)))
    snr = snr_estimate(x, fs)
    if snr < 0 or art > 0.15 or rr_n < 2:
        quality = "poor"
    elif snr < 3 or art > 0.08:
        quality = "fair"
    else:
        quality = "good"
    return dict(snr_db=float(snr), rr_mean=rr_mean, rr_sdnn=rr_sdnn, rr_n=rr_n,
                artifact_rate=art, quality=quality)

def auto_layout_panels(gray_img: np.ndarray) -> Tuple[List[Tuple[int,int,int,int]], str]:
    """Split the page into a fixed grid of 3 rows x 4 columns.

    Returns ``(boxes, "4x3")``. Each box is (x, y, width, height) in pixels, listed
    row-major (top row left to right, then the next row).
    """
    height, width = gray_img.shape
    x_edges = np.round(np.linspace(0, width, 5)).astype(int)
    y_edges = np.round(np.linspace(0, height, 4)).astype(int)
    boxes = [
        (int(x_edges[c]), int(y_edges[r]),
         int(x_edges[c+1]-x_edges[c]), int(y_edges[r+1]-y_edges[r]))
        for r in range(3)
        for c in range(4)
    ]
    return boxes, "4x3"

def extract_panel_bestof(panel_gray: np.ndarray, panel_name: str,
                         mm_per_px_x: float, mm_per_px_y: float) -> Dict[str, Any]:
    """Extract the best-scoring trace from one panel.

    Three seed paths (multi-DP, darkest-row, black-hat ridge) are computed. Each non-flat
    seed is refined and blended 25/75 with its refinement, digitized, and scored as
    snr_db + 8 (if >= 2 RR intervals) + 1 (if quality is "good"). The highest score wins.
    If every seed is flat, the raw "simple" seed is digitized instead.
    """
    n_rows, _n_cols = panel_gray.shape
    calibration = {"mm_per_px_x": float(mm_per_px_x), "mm_per_px_y": float(mm_per_px_y)}

    def _digitize(path):
        y_mV, t = to_timeseries_from_path(path, n_rows, mm_per_px_x, mm_per_px_y)
        y_f, dur = resample_to_fs(y_mV, t, CFG.TARGET_FS)
        return y_f, dur, bioshed_metrics(y_f, CFG.TARGET_FS)

    def _package(y_f, dur, qa, seam_label):
        return dict(name=panel_name, y=y_f, fs=CFG.TARGET_FS, duration_s=dur, qa=qa,
                    picked={"seam": seam_label, **calibration})

    seeds = [
        ("multi", seam_centerline_multi(panel_gray)),
        ("simple", seam_centerline_simple(panel_gray)),
        ("ridge", seam_centerline_ridge(panel_gray)),
    ]
    band = _adaptive_band_px(panel_gray)
    best_score, best_result = None, None
    for label, seed in seeds:
        if is_flat(seed):
            continue
        refined = refine_seam_around_path(panel_gray, seed, band_px=band)
        y_f, dur, qa = _digitize(0.25*seed + 0.75*refined)
        score = qa["snr_db"] + (8.0 if qa["rr_n"]>=2 else 0.0) + (1.0 if qa["quality"]=="good" else 0.0)
        if best_score is None or score > best_score:
            best_score = score
            best_result = _package(y_f, dur, qa, f"{label}+ref")
    if best_result is None:
        y_f, dur, qa = _digitize(seeds[1][1])
        return _package(y_f, dur, qa, "simple-raw")
    return best_result

def process_image_to_leads(img_path: str) -> Optional[Dict[str, Any]]:
    """Digitize one ECG image file into per-lead time series.

    Steps:
    1. Decode the file as BGR. If OpenCV cannot decode it, print a message and return None.
    2. Convert to gray and estimate skew/grid pitch.
    3. De-rotate when |angle| > 0.3 deg, then preprocess.
    4. Split into the fixed 3x4 panel grid and extract one trace per panel.
    5. Truncate all leads to the shortest one and recompute their QA.

    Returns a dict with ``layout``, ``angle_deg``, ``mm_per_px_x``, ``mm_per_px_y``, ``fs``
    and ``leads``, or None. ``leads`` maps lead name to the dict returned by
    extract_panel_bestof, ordered as CFG.LEAD_NAMES_12 for the 4x3 grid.
    """
    raw = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if raw is None:
        print("!! Cannot decode:", img_path); return None
    gray0 = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
    angle_deg, mmx, mmy = dfee_angle_and_pitch(gray0)
    gray = _rotate_bound(gray0, -angle_deg) if abs(angle_deg)>0.3 else gray0
    g = preprocess(gray)
    panels, layout = auto_layout_panels(g)

    names12 = CFG.LEAD_NAMES_12
    is_grid = layout == "4x3" and len(panels) == 12
    if is_grid and len(names12) == 12:
        # Panels come row-major (3 rows x 4 columns), but the standard 3x4 printout lists
        # leads column by column: I/II/III | aVR/aVL/aVF | V1/V2/V3 | V4/V5/V6.
        # Row-major naming would mislabel 8 of the 12 leads.
        # NOTE(review): many printouts add a full-width lead-II rhythm strip as a 4th row,
        # which auto_layout_panels does not model — confirm against the dataset.
        lead_names = [names12[(i % 4) * 3 + i // 4] for i in range(12)]
    elif is_grid:
        lead_names = names12
    else:
        lead_names = ["Lead0"]

    results = {}
    for idx, (x, y, w, h) in enumerate(panels[:len(lead_names)]):
        _PATCH["row_idx"] = idx//4 if layout=="4x3" else None
        panel = g[y:y+h, x:x+w]
        name = lead_names[idx]
        results[name] = extract_panel_bestof(panel, name, mmx, mmy)
    if not results:
        return None
    if is_grid:
        # Present leads in canonical order for the NPZ/JSON outputs and previews.
        results = {nm: results[nm] for nm in names12 if nm in results}

    # Panels may differ by a pixel in width; make all leads the same length, then re-score.
    min_len = min(len(res["y"]) for res in results.values())
    for res in results.values():
        res["y"] = res["y"][:min_len]
        res["qa"] = bioshed_metrics(res["y"], CFG.TARGET_FS)
    return dict(layout=layout, angle_deg=float(angle_deg),
                mm_per_px_x=float(mmx), mm_per_px_y=float(mmy),
                fs=CFG.TARGET_FS, leads=results)

# ==========================================
# Batch digitization (train+test+others) + QC exports
# ==========================================
# Discover explicit train/test dirs (if present) plus every image anywhere under ROOT.
test_dir = os.path.join(CFG.ROOT, "test")
train_dir = os.path.join(CFG.ROOT, "train")

test_pngs = collect_images_from(test_dir) if os.path.isdir(test_dir) else []
train_pngs = collect_images_from(train_dir) if os.path.isdir(train_dir) else []
all_imgs = find_all_images(CFG.ROOT)

print(f"Discovered counts → train(top-level): {len(train_pngs)}, test(top-level): {len(test_pngs)}, all(recursive): {len(all_imgs)}")

# Choose scope. Unknown values used to fall back to "all" silently; now they are reported.
scope = CFG.SCOPE.strip().lower()
if scope not in ("all", "test", "train"):
    print(f"!! Unknown CFG.SCOPE={CFG.SCOPE!r}; expected 'all' | 'test' | 'train'. Using 'all'.")
    scope = "all"

if scope == "test":
    if not test_pngs:
        print("!! No top-level test images found; falling back to all images.")
    use_imgs = test_pngs if test_pngs else all_imgs
elif scope == "train":
    if not train_pngs:
        print("!! No top-level train images found; falling back to all images.")
    use_imgs = train_pngs if train_pngs else all_imgs
else:  # "all"
    # Use recursive discovery of ALL images under ROOT
    use_imgs = all_imgs

print(f"Selected {len(use_imgs)} images for processing based on scope='{CFG.SCOPE}'")
print(f"Outputs will be saved to: {CFG.OUT_DIR}")

new_cnt=0; skip_cnt=0; fail_cnt=0

# Output basename -> first image path that claimed it during this run.
_claimed_bases: Dict[str, str] = {}


def _artifact_prefix(img_path: str) -> str:
    """Return the output path prefix (no extension) for *img_path*.

    The first image with a given basename gets ``OUT_DIR/<basename>``, which is the name the
    submission builders look up. With recursive discovery, a second image with the same
    basename in another folder would otherwise be skipped as "existing" or overwrite the
    first. It instead gets a prefix derived from its path relative to CFG.ROOT.
    """
    base = os.path.splitext(os.path.basename(img_path))[0]
    owner = _claimed_bases.setdefault(base, img_path)
    if owner != img_path:
        rel = os.path.splitext(os.path.relpath(img_path, CFG.ROOT))[0]
        base = rel.replace(os.sep, "__")
        print(f"!! Basename collision with {owner}; saving {img_path} as {base}")
    return os.path.join(CFG.OUT_DIR, base)


def _write_artifacts(out_prefix: str, img_path: str, R: Dict[str, Any], strict_used: bool) -> None:
    """Write <prefix>.npz (lead arrays + fs), <prefix>.json (calibration/QA/picks) and <prefix>.png."""
    os.makedirs(CFG.OUT_DIR, exist_ok=True)
    npz_dict = {k: R["leads"][k]["y"] for k in R["leads"]}
    npz_dict["fs"] = R["fs"]
    np.savez_compressed(out_prefix + ".npz", **npz_dict)
    with open(out_prefix + ".json", "w") as f:
        json.dump({
            "input": os.path.basename(img_path),
            "layout": R["layout"],
            "fs": R["fs"],
            "calibration": {
                "angle_deg": R["angle_deg"],
                "mm_per_px_x": R["mm_per_px_x"],
                "mm_per_px_y": R["mm_per_px_y"],
                "speed_mm_s": CFG.SWEEP_MM_PER_S,
                "gain_mm_mV": CFG.GAIN_MM_PER_MV
            },
            "qa": {k: R["leads"][k]["qa"] for k in R["leads"]},
            "picked": {k: R["leads"][k]["picked"] for k in R["leads"]},
            "strict_mode": bool(strict_used)
        }, f, indent=2)
    leads_ts = {k: R["leads"][k]["y"] for k in R["leads"]}
    save_preview_grid(leads_ts, R["fs"], out_prefix + ".png")


for p in use_imgs:
    out_prefix = _artifact_prefix(p)
    if (not CFG.OVERWRITE) and npz_exists(out_prefix):
        skip_cnt += 1
        continue
    try:
        _PATCH["strict"] = False
        strict_used = False
        R = process_image_to_leads(p)
        if R is None or not R.get("leads"):
            fail_cnt += 1; continue
        qc = [R["leads"][k]["qa"]["quality"] for k in R["leads"]]
        redo = (CFG.STRICT_REDO_TRIGGER_NO_GOOD and ("good" not in qc)) or \
               (qc.count("poor") >= CFG.STRICT_REDO_TRIGGER_POOR_COUNT)
        if redo:
            _PATCH["strict"] = True
            R_strict = process_image_to_leads(p)
            # Only record strict mode when the strict result is what actually gets saved.
            if R_strict is not None:
                R, strict_used = R_strict, True
        _write_artifacts(out_prefix, p, R, strict_used)
        new_cnt += 1
        picks = ", ".join([f"{k}:{R['leads'][k]['picked']['seam']}/{R['leads'][k]['picked']['mm_per_px_x']:.3f}"
                           for k in list(R["leads"].keys())[:4]])
        qccnt = {"good":0,"fair":0,"poor":0}
        for k in R["leads"]: qccnt[R["leads"][k]["qa"]["quality"]] += 1
        print(f"OK {os.path.basename(p)} layout={R['layout']} leads={len(R['leads'])} qc={qccnt} picks={picks}")
    except Exception as e:
        fail_cnt += 1
        print("!! Failed:", p, "→", repr(e))
    finally:
        # Never let strict mode leak into the next image or into later cells.
        _PATCH["strict"] = False

print("\nBatch complete.",
      f"New processed: {new_cnt}, skipped existing: {skip_cnt}, failed: {fail_cnt}")
print("Digitized artifacts (.npz, .json, preview .png) are in:", CFG.OUT_DIR)

# --------------------------
# QC aggregation CSVs
# --------------------------
# Flatten every per-image JSON sidecar into one row per (image, lead) QA record.
jfiles = sorted(glob.glob(os.path.join(CFG.OUT_DIR, "*.json")))
rows = []
for json_path in jfiles:
    try:
        with open(json_path, "r") as fh:
            meta = json.load(fh)
        for lead_name, metrics in meta.get("qa", {}).items():
            rows.append(dict(
                image=meta.get("input"),
                lead=lead_name,
                snr_db=metrics.get("snr_db"),
                rr_mean=metrics.get("rr_mean"),
                rr_sdnn=metrics.get("rr_sdnn"),
                artifact_rate=metrics.get("artifact_rate"),
                quality=metrics.get("quality"),
                strict_mode=meta.get("strict_mode", False),
            ))
    except Exception as err:
        print("Skip QC:", json_path, "→", err)

if not rows:
    print("No QC rows produced.")
else:
    df_qc = pd.DataFrame(rows)
    df_qc.to_csv("/kaggle/working/qc_all_traces.csv", index=False)
    quality_share = df_qc["quality"].value_counts(normalize=True) * 100
    summary = quality_share.round(1).to_dict()
    print("QC saved: /kaggle/working/qc_all_traces.csv")
    print("Overall quality %:", summary)

# ==========================================
# Robust submission.csv builder (no exit; always writes)
# ==========================================
LEADS_12 = CFG.LEAD_NAMES_12
# Fallback samples per lead when no schema CSV gives lengths: 1250 for every lead, with
# "II" forced to 5000 (presumably 2.5 s panels vs. a 10 s rhythm strip at 500 Hz).
DEFAULT_ROWS = dict.fromkeys(LEADS_12, 1250)
DEFAULT_ROWS["II"] = 5000

def _safe_pchip(x_src, y, x_tgt):
    try:
        y_tgt = PchipInterpolator(x_src, y, extrapolate=False)(x_tgt).astype(np.float32)
        if np.isnan(y_tgt).any():
            m = np.isnan(y_tgt)
            if (~m).any():
                y_tgt[m] = np.interp(np.flatnonzero(m), np.flatnonzero(~m), y_tgt[~m]).astype(np.float32)
            else:
                y_tgt = np.zeros_like(y_tgt, dtype=np.float32)
        return y_tgt
    except Exception:
        return np.interp(x_tgt, x_src, y).astype(np.float32)

def _fetch_vec(img_id: str, lead: str, num_rows: int) -> np.ndarray:
    """Return the digitized trace for (img_id, lead), resampled to exactly *num_rows* values.

    Looks for ``OUT_DIR/<img_id>.npz``, falling back to the basename of *img_id*. Zeros are
    returned when the file is missing, holds no leads, or the chosen trace has fewer than
    2 samples. If *lead* is absent, lead "II" is used; failing that, the lead with the
    highest ``snr_db`` in the sidecar JSON; failing that, the alphabetically first lead.
    Length mismatches are resolved by PCHIP over a normalised [0, 1] axis, so the trace is
    stretched or squeezed in time.
    """
    base = str(img_id)
    npz_path = os.path.join(CFG.OUT_DIR, base + ".npz")
    json_path = os.path.join(CFG.OUT_DIR, base + ".json")
    if not os.path.exists(npz_path):
        short = os.path.basename(base)
        alt = os.path.join(CFG.OUT_DIR, short + ".npz")
        if os.path.exists(alt):
            npz_path = alt
            json_path = os.path.join(CFG.OUT_DIR, short + ".json")
    if not os.path.exists(npz_path):
        return np.zeros(num_rows, dtype=np.float32)
    # The context manager closes the NpzFile handle (otherwise one lingers per call until GC).
    # The artifacts hold plain numeric arrays, so pickle loading is not needed.
    with np.load(npz_path, allow_pickle=False) as z:
        leads = {k: z[k] for k in z.files if k != "fs"}
    if not leads:
        return np.zeros(num_rows, dtype=np.float32)
    if lead in leads:
        y = leads[lead].astype(np.float32)
    elif "II" in leads:
        y = leads["II"].astype(np.float32)
    else:
        chosen = None
        try:
            with open(json_path, "r") as f:
                J = json.load(f)
            best = None
            for k, q in (J.get("qa") or {}).items():
                s = q.get("snr_db", -1e9)
                if (best is None) or (s > best[1]): best = (k, s)
            if best and best[0] in leads:
                chosen = best[0]
        except Exception:
            # Missing/corrupt sidecar: fall through to the deterministic choice below.
            pass
        if chosen is None:
            chosen = sorted(leads.keys())[0]
        y = leads[chosen].astype(np.float32)
    if len(y) == num_rows:
        return y
    if len(y) < 2:
        return np.zeros(num_rows, dtype=np.float32)
    x_src = np.linspace(0, 1, len(y), dtype=np.float32)
    x_tgt = np.linspace(0, 1, num_rows, dtype=np.float32)
    return _safe_pchip(x_src, y, x_tgt)

def build_from_test_csv(path: str) -> pd.DataFrame:
    """Build submission rows ``<id>_<row>_<lead>`` -> value from a test.csv schema.

    test.csv must have ``id``, ``lead`` and ``number_of_rows`` columns. Each (id, lead)
    contributes exactly ``number_of_rows`` values via _fetch_vec.

    Raises:
        RuntimeError: if a required column is missing.
    """
    df = pd.read_csv(path)
    need = {"id","lead","number_of_rows"}
    if not need.issubset(df.columns):
        raise RuntimeError(f"{os.path.basename(path)} missing required columns {need}; found {list(df.columns)}")
    # Column-wise iteration avoids iterrows' per-row Series construction, and two flat
    # lists are far lighter than one dict per output sample.
    ids: List[str] = []
    values: List[float] = []
    for img_id, lead, nrows in zip(df["id"].astype(str), df["lead"].astype(str), df["number_of_rows"]):
        vec = _fetch_vec(img_id, lead, int(nrows))
        ids.extend(f"{img_id}_{i}_{lead}" for i in range(len(vec)))
        values.extend(np.asarray(vec, dtype=np.float64).tolist())
    return pd.DataFrame({"id": ids, "value": values})

def build_from_sample_submission(path: str) -> pd.DataFrame:
    """Fill every id of sample_submission.csv (format ``<img_id>_<index>_<lead>``).

    The required trace length per (img_id, lead) is the largest index seen plus one. Each
    trace is fetched once and cached, and out-of-range indices get 0.0.

    Raises:
        RuntimeError: if the file has no ``id`` column.
    """
    df_sample = pd.read_csv(path)
    if "id" not in df_sample.columns:
        raise RuntimeError("sample_submission.csv must contain 'id' column")
    pieces = df_sample["id"].astype(str).str.rsplit("_", n=2)
    meta = pd.DataFrame({
        "img_id": pieces.str[0],
        "index": pieces.str[1].astype(int),
        "lead": pieces.str[2],
    })
    lengths = meta.groupby(["img_id","lead"])["index"].max().astype(int) + 1
    traces = {}
    out_rows = []
    for rid, img_id, idx, lead in zip(df_sample["id"].tolist(), meta["img_id"], meta["index"], meta["lead"]):
        key = (img_id, lead)
        vec = traces.get(key)
        if vec is None:
            vec = _fetch_vec(img_id, lead, int(lengths.loc[key]))
            traces[key] = vec
        value = float(vec[idx]) if 0 <= idx < len(vec) else 0.0
        out_rows.append({"id": rid, "value": value})
    return pd.DataFrame(out_rows)

def build_from_outputs_only() -> pd.DataFrame:
    """Synthesize a submission from whatever ``OUT_DIR/*.npz`` artifacts exist.

    Every artifact contributes all CFG.LEAD_NAMES_12 leads, with lengths taken from the
    shared DEFAULT_ROWS table (1250 for any lead missing from it). With no artifacts at all,
    a one-row placeholder is returned so a CSV still gets written.
    """
    npzs = sorted(glob.glob(os.path.join(CFG.OUT_DIR, "*.npz")))
    if not npzs:
        # Ensure Kaggle UI sees a file; write a tiny valid CSV
        return pd.DataFrame([{"id": "dummy_0_II", "value": 0.0}])
    ids = []
    values = []
    for path in npzs:
        base = os.path.splitext(os.path.basename(path))[0]
        for lead in CFG.LEAD_NAMES_12:
            # Single source of truth for expected lengths (II: 5000, others: 1250).
            nrows = int(DEFAULT_ROWS.get(lead, 1250))
            vec = _fetch_vec(base, lead, nrows)
            ids.extend(f"{base}_{i}_{lead}" for i in range(len(vec)))
            values.extend(np.asarray(vec, dtype=np.float64).tolist())
    return pd.DataFrame({"id": ids, "value": values})

# Build the submission with the best available schema: test.csv, then sample_submission.csv,
# then whatever artifacts exist in OUT_DIR. A CSV is always written.
test_csv = os.path.join(CFG.ROOT, "test.csv")
sample_csv = os.path.join(CFG.ROOT, "sample_submission.csv")

built = None
if os.path.exists(test_csv):
    try:
        print("Building submission from test.csv …")
        built = build_from_test_csv(test_csv)
    except Exception as e:
        print("test.csv route failed:", e)

if (built is None or built.empty) and os.path.exists(sample_csv):
    try:
        print("Building submission from sample_submission.csv …")
        built = build_from_sample_submission(sample_csv)
    except Exception as e:
        print("sample_submission.csv route failed:", e)

if built is None or built.empty:
    print("No schema CSV found or both routes failed → synthesizing from outputs in /out.")
    built = build_from_outputs_only()

built.to_csv(CFG.SUBMISSION_CSV, index=False)
try:
    built.to_parquet(CFG.SUBMISSION_PARQUET, index=False)
except Exception as e:
    # Parquet is a convenience copy (needs pyarrow or fastparquet); the CSV above is what
    # matters. Report the failure instead of hiding it.
    print("submission.parquet not written:", repr(e))

print(f"\n✅ submission.csv written to: {CFG.SUBMISSION_CSV}")
print(f"Rows: {len(built)}")