In [None]:
# =========================================
# Full pipeline (3/11) — Change-Limit BASE + 3 Creative Mix
#  - BASE: Change-Limit(국소 MAD로 SG 평활과의 차이를 클램프)
#  - + Energy Subpixel(강도×기울기)
#  - + Crest(QRS) mask 보호(스무딩/보정/II 평균)
#  - + Two-pass Narrow Corridor 재탐색
#  - Non-color/test: mean fallback
# =========================================
import os, cv2, numpy as np, pandas as pd, matplotlib.pyplot as plt
from glob import glob
from collections import defaultdict
from tqdm import tqdm
from typing import Tuple

import scipy.optimize, scipy.signal

# -------------------------
# 하이퍼파라미터 (필요시 여기만 조정)
# -------------------------
CROP_TOP = 400          # header cut: image rows above this index are cropped away before tracing
PX_PER_MV = 80          # 80 px = 1 mV (type 3/11)
LEADS = ['I','II','III','aVR','aVL','aVF','V1','V2','V3','V4','V5','V6']  # leads that are predicted and scored
MAX_TIME_SHIFT = 0.2    # multiplied by row.fs to obtain the maximum alignment shift in samples
PERFECT_SCORE = 384     # upper cap returned by compute_snr; its negative is the floor of the dB score

# Subpixel energy
SUBPIX_UPSCALE      = 1          # 1 or 2 (2 is finer but slower)
ENERGY_W_INT        = 0.6        # weight of the intensity term
ENERGY_W_GRAD       = 0.4        # weight of the gradient term
ENERGY_SOFTARG_TAU  = 1.0        # soft-argmax temperature

# Two-pass narrow corridor
NARROW_BASE_HALF    = 26         # first-pass window half-height (px)
NARROW_MIN_HALF     = 10         # second-pass minimum half-height
NARROW_MAX_HALF     = 24         # second-pass maximum half-height
NARROW_K_SLOPE      = 120.0      # slope -> height mapping constant (larger means weaker narrowing)

# Crest mask
CREST_WIN_FRAC      = 0.004      # smoothing window (fraction of signal length) used before differentiating
CREST_PCT_DY        = 92.0       # percentile threshold on |dy|
CREST_PCT_D2        = 92.0       # percentile threshold on |d2|
CREST_DILATE_FRAC   = 0.008      # mask dilation radius as a fraction of the total length
CREST_MIN_LEN       = 7          # mask runs shorter than this are removed

# Change-Limit (BASE)
CH_WIN_FRAC         = 0.015      # SG / MAD window as a fraction of signal length
# NOTE(review): CH_POLY is not referenced in the visible code; _sg hard-codes polyorder 2.
CH_POLY             = 2
CH_K_OUT            = 3.5        # MAD multiplier outside crests
CH_K_IN             = 6.0        # MAD multiplier inside crests (more protection)

# II averaging / Einthoven correction constraints
II_W_OUT            = 0.6        # weight outside crests (II = 0.6, subset = 0.4)
II_W_IN             = 0.85       # inside crests II is weighted more (peak preservation)
EINTH_MAX_K_OUT     = 1.0        # correction scale outside crests
EINTH_MAX_K_IN      = 0.3        # correction scale inside crests (weakened correction)

# Safety clipping
CLIP_MV             = 1.25       # guards against physically implausible amplitudes; values are clipped, not zeroed

# -------------------------
# Metric helpers
# -------------------------
class ParticipantVisibleError(Exception):
    """Raised by the metric helpers when a label or prediction is invalid."""

def compute_power(label: np.ndarray, prediction: np.ndarray) -> Tuple[float, float]:
    """Return ``(signal_power, noise_power)`` for one lead.

    Signal power is the sum of squared label samples; noise power is the sum
    of squared differences between label and prediction. Non-finite
    prediction samples are treated as 0 (the caller's array is not modified).

    Raises:
        ParticipantVisibleError: if either input is not 1-D, or the
            prediction contains no finite value at all.
    """
    both_1d = label.ndim == 1 and prediction.ndim == 1
    if not both_1d:
        raise ParticipantVisibleError('Inputs must be 1-dimensional arrays.')

    finite = np.isfinite(prediction)
    if not finite.any():
        raise ParticipantVisibleError("prediction has no finite values.")

    # Replace NaN/inf by zero without touching the input array.
    cleaned = np.where(finite, prediction, 0)
    residual = label - cleaned
    signal_power = float(np.sum(np.square(label)))
    noise_power = float(np.sum(np.square(residual)))
    return signal_power, noise_power

def compute_snr(signal: float, noise: float) -> float:
    """Return the signal-to-noise power ratio, capped at PERFECT_SCORE.

    Zero noise yields PERFECT_SCORE; otherwise zero signal yields 0.0.
    """
    if noise == 0:
        return PERFECT_SCORE
    if signal == 0:
        return 0.0
    ratio = signal / noise
    # Same selection rule as min(ratio, PERFECT_SCORE): the cap only wins
    # when it is strictly smaller than the ratio.
    return PERFECT_SCORE if PERFECT_SCORE < ratio else ratio

def align_signals(label: np.ndarray, pred: np.ndarray, max_shift: float = float('inf')) -> np.ndarray:
    """Shift and offset ``pred`` so that it best matches ``label``.

    The time shift is the lag that maximises the cross-correlation of the
    mean-removed signals; the peak value is taken over lags within
    ``[-max_shift, max_shift]`` and ties are resolved towards the smallest
    absolute lag. The shifted prediction is padded with NaN to the label
    length, and a constant vertical offset minimising the squared error is
    then subtracted.

    Raises:
        ParticipantVisibleError: if ``label`` has a non-finite value or
            ``pred`` has no finite value.
    """
    if not np.all(np.isfinite(label)):
        raise ParticipantVisibleError('values in label should all be finite')
    if not np.any(np.isfinite(pred)):
        raise ParticipantVisibleError('prediction can not all be infinite')

    ref = np.asarray(label, dtype=np.float64)
    est = np.asarray(pred, dtype=np.float64)

    xcorr = scipy.signal.correlate(ref - np.mean(ref), est - np.mean(est), mode='full')
    lags = scipy.signal.correlation_lags(ref.size, est.size, mode='full')

    within = (lags >= -max_shift) & (lags <= max_shift)
    peak = np.nanmax(xcorr[within])
    candidates = np.flatnonzero(xcorr == peak)
    # First candidate with the smallest |lag|.
    shift = lags[candidates[np.argmin(np.abs(lags[candidates]))]]

    lead_pad = max(shift, 0)
    take_from = max(-shift, 0)
    take_to = min(ref.size - shift, est.size)
    trail_pad = max(ref.size - est.size - shift, 0)
    aligned = np.concatenate([
        np.full(lead_pad, np.nan),
        est[take_from:take_to],
        np.full(trail_pad, np.nan),
    ])

    def _squared_error(offset):
        return np.nansum((ref - (aligned - offset))**2)

    if np.any(np.isfinite(ref) & np.isfinite(aligned)):
        best = scipy.optimize.minimize_scalar(_squared_error, method='Brent')
        aligned -= best.x
    return aligned

# -------------------------
# Mean model (fallback)
# -------------------------
def fit_mean_model(train_df, verbose=False):
    """Collect every training waveform, resampled to 20000 points, per lead.

    For each row of ``train_df`` the label CSV is read, and every column's
    non-NaN samples are linearly interpolated onto 20000 points.

    Returns:
        A defaultdict mapping each lead name to a 2-D array with one
        resampled waveform per record. With ``verbose`` the per-lead mean
        curve is also plotted.
    """
    mean_dict = defaultdict(list)
    n_records = len(train_df)
    for _, row in tqdm(train_df.iterrows(), total=n_records):
        labels = pd.read_csv(f'/kaggle/input/physionet-ecg-image-digitization/train/{row.id}/{row.id}.csv')
        for lead in labels.columns:
            samples = labels[lead].dropna().values
            src_x = np.arange(len(samples))
            dst_x = np.linspace(0, len(samples)-1, 20000)
            mean_dict[lead].append(np.interp(dst_x, src_x, samples))

    for k in list(mean_dict):
        mean_dict[k] = np.stack(mean_dict[k])
        if not verbose:
            continue
        m = mean_dict[k].mean(axis=0)
        plt.figure(figsize=(12,2))
        plt.title(f"Mean {k}")
        plt.plot(m)
        plt.axhline(0,color='gray')
        plt.show()
    return mean_dict

def validate_mean_model(val_df, mean_dict):
    """Score the per-lead mean waveform against the labels of ``val_df``.

    For each record the mean curve of every lead is resampled to the label
    length, aligned with align_signals, and the signal/noise powers are
    summed over leads into one SNR. The mean SNR and its dB value (floored
    at -PERFECT_SCORE) are printed; nothing is returned.
    """
    per_record = []
    for _, row in tqdm(val_df.iterrows(), total=len(val_df)):
        labels = pd.read_csv(f'/kaggle/input/physionet-ecg-image-digitization/train/{row.id}/{row.id}.csv')
        signal_total = 0.0
        noise_total = 0.0
        for lead in labels.columns:
            truth = labels[lead].dropna().values
            template = mean_dict[lead].mean(axis=0)
            resampled = np.interp(np.linspace(0,1,len(truth)),
                                  np.linspace(0,1,len(template)),
                                  template)
            shifted = align_signals(truth, resampled, int(row.fs * MAX_TIME_SHIFT))
            lead_signal, lead_noise = compute_power(truth, shifted)
            signal_total += lead_signal
            noise_total += lead_noise
        per_record.append(compute_snr(signal_total, noise_total))

    snr = float(np.mean(per_record))
    val_score = max(float(10*np.log10(snr)), -PERFECT_SCORE)
    print(f"# Validation SNR for mean prediction: {snr:.2f} {val_score=:.2f}")

# -------------------------
# Marker finder (스캔 전용)
# -------------------------
class MarkerFinder:
    """Locate 17 reference markers on a scanned sheet by template matching.

    Templates are cut from the element-wise maximum of three fixed training
    images. Points are stored as (row, col): the first coordinate is used as
    the row index when slicing images, the second as the column index.
    """
    def __init__(self, show_templates=False):
        """Build the marker templates from three hard-coded training images.

        NOTE(review): ``show_templates`` is accepted but never used here.
        """
        # Element-wise maximum over the three reference images.
        ima = np.max([
            cv2.imread('/kaggle/input/physionet-ecg-image-digitization/train/4292118763/4292118763-0001.png'),
            cv2.imread('/kaggle/input/physionet-ecg-image-digitization/train/4289880010/4289880010-0001.png'),
            cv2.imread('/kaggle/input/physionet-ecg-image-digitization/train/4284351157/4284351157-0001.png'),
        ], axis=0)
        # Expected marker locations: points 5*i .. 5*i+4 share row 707 + 284*i
        # and sit at columns 118 + 492*j (j = 0..4); points 15 and 16 are on
        # row 1535 at the first and last of those columns.
        absolute_points = np.zeros((17, 2), dtype=int)
        for i in range(3):
            absolute_points[5*i] = np.array([707 + 284*i, 118])
            for j in range(1,5):
                absolute_points[5*i + j] = np.array([707 + 284*i, 118 + 492*j])
        absolute_points[15] = np.array([1535, 118])
        absolute_points[16] = np.array([1535, 118 + 492*4])
        # Top-left corner of each template. Points in the last column
        # (indices 4, 9, 14, 16) get no template and stay None.
        template_positions = [None]*17
        for i in range(len(absolute_points)):
            if absolute_points[i][1] < 118 + 492*4:
                template_positions[i] = (absolute_points[i][0] - (87 if i%5==0 else 37),
                                         absolute_points[i][1] - (50 if i%5==0 else 13))
        # Every template is 105 rows by 60 columns.
        template_sizes = np.array([(105, 60)]*17)
        # Offset of the marker point inside its own template.
        template_points = [
            (np.array([absolute_points[i][0]-template_positions[i][0],
                       absolute_points[i][1]-template_positions[i][1]])
             if template_positions[i] is not None else None)
            for i in range(17)
        ]
        templates = [None]*17
        for i in range(17):
            if template_points[i] is not None:
                t,l = template_positions[i]; h,w = template_sizes[i]
                templates[i] = ima[t:t+h, l:l+w]
        self._template_positions = template_positions
        self._template_sizes = template_sizes
        self._template_points = template_points
        self._templates = templates

    def find_markers(self, ima, warn=False, plot=False, title=''):
        """Return a list of 17 markers, each a (row, col) array or None.

        Markers with a template are found by matching inside a window around
        the expected position; the remaining ones (4, 9, 14, 16) are
        extrapolated from their neighbours.

        Raises:
            ValueError: if the image height is not 1652.

        NOTE(review): ``warn`` is accepted but never used here.
        """
        if ima.shape[0] != 1652: raise ValueError("scanned only (3,4,11,12)")
        markers = [None]*17
        for j in range(len(self._templates)):
            if self._template_points[j] is None: continue
            # Search window: 100 rows above the expected template position to
            # 100 rows below it, and 100 columns left (clamped at 0) to 250
            # columns right.
            t0 = self._template_positions[j][0]-100
            l0 = max(self._template_positions[j][1]-100, 0)
            h  = self._template_sizes[j][0]; w  = self._template_sizes[j][1]
            sr = ima[t0:self._template_positions[j][0]+100+h,
                     l0:self._template_positions[j][1]+250+w]
            if sr.size == 0: continue
            res = cv2.matchTemplate(sr, self._templates[j], cv2.TM_CCOEFF)
            _, _, _, max_loc = cv2.minMaxLoc(res)
            # max_loc is (x, y): index 1 is the row offset, index 0 the column offset.
            top_left = max_loc
            m = np.array((t0 + top_left[1] + self._template_points[j][0],
                          l0 + top_left[0] + self._template_points[j][1]))
            markers[j] = m
        # Last marker of each of the three rows: linear extrapolation from the
        # two preceding markers.
        for i in range(3):
            if (markers[5*i+3] is not None) and (markers[5*i+2] is not None):
                markers[5*i+4] = markers[5*i+3]*2 - markers[5*i+2]
        # Marker 16: extrapolated from markers 9 and 14 using the 284/260 ratio.
        if (markers[14] is not None) and (markers[9] is not None):
            markers[16] = ((markers[14]*(284+260) - markers[9]*260) / 284).astype(int)
        if plot:
            # Draw an 80x80 box around every located marker on a copy.
            vis = ima.copy()
            for m in markers:
                if m is not None:
                    cv2.rectangle(vis, (m[1]-40, m[0]-40), (m[1]+40, m[0]+40), (255,0,0), 2)
            plt.imshow(vis); plt.title(title); plt.show()
        return markers

    @staticmethod
    def lead_info(lead):
        """Return ``(line, begin, end)`` for a lead name.

        ``begin`` and ``end`` are the indices of the markers delimiting the
        lead; ``line`` is ``begin // 5``. Raises KeyError for an unknown name.
        """
        begin, end = {
            'I': (0, 1), 'II-subset': (5, 6), 'III': (10, 11),
            'aVR': (1, 2), 'aVL': (6, 7), 'aVF': (11, 12),
            'V1': (2, 3), 'V2': (7, 8), 'V3': (12, 13),
            'V4': (3, 4), 'V5': (8, 9), 'V6': (13, 14), 'II': (15, 16),
        }[lead]
        return begin // 5, begin, end

# Module-level finder shared by the conversion and validation code below;
# constructing it reads the three template images from disk.
mf = MarkerFinder(show_templates=False)

# -------------------------
# 라인(밴드) 탐색 — top-down sweep
# -------------------------
def find_line_by_topdown_sweep(ima_bool):
    """Find the topmost band in a boolean image and erase it in place.

    Per column, ``top`` is the first False row. Everything above ``top`` is
    cleared, and ``bottom`` is the first True row after that. Rows above
    ``max(bottom, median(top) + 100)`` are then set True (widened by one
    column to each side), so the next call finds the following band.

    Args:
        ima_bool: 2-D boolean array; it is modified in place.

    Returns:
        ``(top, bottom)``, one row index per column.
    """
    n_rows, _ = ima_bool.shape
    row_idx = np.arange(n_rows)[:, None]

    top = np.argmin(ima_bool, axis=0)
    ima_bool &= (row_idx >= top[None, :])

    bottom = np.argmax(ima_bool, axis=0)
    floor_row = np.median(top) + 100
    erase_until = np.maximum(bottom, floor_row).astype(int)
    erase = row_idx < erase_until[None, :]

    ima_bool |= erase
    # Spread the erased region by one column to each side.
    ima_bool[:, :-1] |= erase[:, 1:]
    ima_bool[:, 1:] |= erase[:, :-1]
    return top, bottom

# -------------------------
# (Creative #1) Energy-based subpixel centerline
# -------------------------
def _softargmax_idx(vals, tau=1.0):
    v = (vals - np.max(vals)) / max(tau, 1e-6)
    p = np.exp(v); p /= (p.sum() + 1e-8)
    ys = np.arange(len(vals), dtype=np.float32)
    return float((p * ys).sum())

def _quad_peak_idx(vals, idx):
    if idx <= 0 or idx >= len(vals)-1: return float(idx)
    fm, f0, fp = float(vals[idx-1]), float(vals[idx]), float(vals[idx+1])
    denom = (fm - 2*f0 + fp)
    if abs(denom) < 1e-6: return float(idx)
    delta = 0.5 * (fm - fp) / denom
    return float(idx + np.clip(delta, -0.5, 0.5))

def _norm01(x):
    x = x.astype(np.float32)
    if x.size == 0: return x
    mn, mx = float(x.min()), float(x.max())
    if mx - mn < 1e-6: return np.zeros_like(x, dtype=np.float32)
    return (x - mn) / (mx - mn)

def _centerline_energy(grayR, gradY, yt, yb, x):
    colI = grayR[yt:yb+1, x].astype(np.float32)     # 잉크 = R반전
    colG = gradY[yt:yb+1, x].astype(np.float32)     # 수직기울기
    if colI.size == 0: return None
    e = ENERGY_W_INT*_norm01(colI) + ENERGY_W_GRAD*_norm01(np.abs(colG))
    return e

def _energy_peak_y(e, method="soft"):
    if e is None or len(e)==0: return 0.0
    j = int(np.argmax(e))
    if method == "soft":  return _softargmax_idx(e, tau=ENERGY_SOFTARG_TAU)
    else:                 return _quad_peak_idx(e, j)

def _choose_centerline_energy(ima_crop_bgr, top, bottom, x0, x1,
                              upscale=1, method="soft", narrow_map=None):
    """
    Trace the energy centerline for the columns x in [x0, x1).

    For every column a vertical window centred on the band midpoint
    ``(top[x] + bottom[x]) / 2`` is scored with _centerline_energy, and the
    peak row of that score is taken as the trace position.

    Args:
        ima_crop_bgr: cropped image; channel 2 is inverted to get ink intensity.
        top, bottom: per-column band limits, indexed by absolute column.
        x0, x1: absolute column range, ``x1`` exclusive.
        upscale: integer supersampling factor; values > 1 resize the image
            before tracing.
        method: "soft" for soft-argmax, anything else for a quadratic fit.
        narrow_map: optional per-column window half-heights of length
            ``x1 - x0``. It is indexed by the offset from ``x0``, not by the
            absolute column, and clipped to
            [NARROW_MIN_HALF, NARROW_MAX_HALF]. ``None`` selects
            NARROW_BASE_HALF for every column.

    Returns:
        float32 array of length ``max(1, x1 - x0)`` holding the trace row per
        column, in the pixel coordinates of ``ima_crop_bgr`` whatever
        ``upscale`` is.
    """
    gray_r = 255 - ima_crop_bgr[:,:,2]
    grad_y = cv2.Sobel(gray_r, cv2.CV_32F, 0, 1, ksize=3)
    n_cols = max(1, x1 - x0)

    def _half_height(rel):
        # Window half-height (original-resolution px) for relative column `rel`.
        if narrow_map is None:
            return NARROW_BASE_HALF
        rel = int(np.clip(rel, 0, len(narrow_map)-1))
        return int(np.clip(narrow_map[rel], NARROW_MIN_HALF, NARROW_MAX_HALF))

    if upscale > 1:
        H, W = gray_r.shape
        gray_r = cv2.resize(gray_r, (W*upscale, H*upscale), interpolation=cv2.INTER_CUBIC)
        grad_y = cv2.resize(grad_y, (W*upscale, H*upscale), interpolation=cv2.INTER_CUBIC) * upscale

        top_u = (top * upscale).astype(np.int32)
        bot_u = (bottom * upscale).astype(np.int32)
        x0_u, x1_u = x0*upscale, x1*upscale

        ys_u = np.zeros(max(1, x1_u - x0_u), np.float32)
        H2, W2 = gray_r.shape

        for i, xu in enumerate(range(x0_u, min(x1_u, W2))):
            # Absolute upscaled column -> relative original-resolution column.
            half_u = _half_height((xu - x0_u) // upscale) * upscale

            xg = xu // upscale
            # Band centre in upscaled rows.
            yc_u = int(0.5 * (top_u[xg] + bot_u[xg]))
            yt = int(np.clip(yc_u - half_u, 0, H2-1))
            yb = int(np.clip(yc_u + half_u, 0, H2-1))

            e = _centerline_energy(gray_r, grad_y, yt, yb, xu)
            ys_u[i] = yt + _energy_peak_y(e, method=method)

        ys = cv2.resize(ys_u.reshape(1,-1), (n_cols, 1), interpolation=cv2.INTER_AREA).ravel()
        # The rows above are in upscaled pixels. Convert back to the original
        # resolution, because callers compare the result with baselines and
        # band limits expressed in original pixels.
        return (ys / float(upscale)).astype(np.float32)

    H, W = gray_r.shape
    ys = np.zeros(n_cols, np.float32)

    for i, x in enumerate(range(x0, min(x1, W))):
        # The loop offset i is the relative column used for narrow_map.
        half = _half_height(i)

        yc = int(0.5 * (top[x] + bottom[x]))
        yt = int(np.clip(yc - half, 0, H-1))
        yb = int(np.clip(yc + half, 0, H-1))

        e = _centerline_energy(gray_r, grad_y, yt, yb, x)
        ys[i] = yt + _energy_peak_y(e, method=method)

    if len(ys) >= 3:
        # NOTE(review): cv2.blur takes ksize as (width, height), so (1, 3) on
        # a 1xN row averages along the single-row axis and appears to leave
        # the trace unsmoothed. Kept as is because the hyperparameters were
        # tuned against this behaviour; use (3, 1) if a 3-tap smoothing along
        # x is wanted.
        ys = cv2.blur(ys.reshape(1,-1), (1,3)).ravel()
    return ys


# -------------------------
# (Creative #2) Crest mask (QRS 보호)
# -------------------------
def _sg(y, win):
    win = max(5, win + (win%2==0))
    if len(y) < win: return y.copy()
    return scipy.signal.savgol_filter(y, win, 2, mode='interp').astype(np.float32)

def build_crest_mask(y):
    """Return a boolean mask of the steep or sharply curved parts of ``y``.

    The signal is smoothed, and samples whose absolute first or second
    difference reaches the CREST_PCT_DY / CREST_PCT_D2 percentile are
    selected. The selection is dilated by ``max(3, n * CREST_DILATE_FRAC)``
    samples to each side, and runs shorter than CREST_MIN_LEN are dropped.
    Signals shorter than 9 samples give an all-False mask.
    """
    n = len(y)
    if n < 9:
        return np.zeros(n, dtype=bool)

    win = max(5, int(n*CREST_WIN_FRAC))
    if win % 2 == 0:
        win += 1
    smooth = _sg(y, win)
    d1 = np.diff(smooth, prepend=smooth[:1])
    d2 = np.diff(d1, prepend=d1[:1])

    abs_d1, abs_d2 = np.abs(d1), np.abs(d2)
    steep = abs_d1 >= np.percentile(abs_d1, CREST_PCT_DY)
    curved = abs_d2 >= np.percentile(abs_d2, CREST_PCT_D2)
    seed = steep | curved

    # Dilation: every selected sample marks its +/-radius neighbourhood.
    radius = max(3, int(n*CREST_DILATE_FRAC))
    dilated = np.zeros_like(seed)
    for i in np.flatnonzero(seed):
        dilated[max(0, i-radius):min(n, i+radius+1)] = True

    # Remove short runs: find run boundaries from the padded 0/1 edges.
    clean = dilated.copy()
    padded = np.concatenate(([False], dilated, [False]))
    edges = np.diff(padded.astype(np.int8))
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)      # exclusive end of each run
    for start, stop in zip(run_starts, run_ends):
        if (stop - start) < CREST_MIN_LEN:
            clean[start:stop] = False
    return clean

# -------------------------
# (BASE + Creative #3) Change-Limit + crest-aware
# -------------------------
def change_limit_filter(y, crest_mask):
    """Clamp the deviation of ``y`` from its smoothed trend.

    The trend is the Savitzky-Golay smooth over a window of
    ``max(7, n * CH_WIN_FRAC)`` samples. At each sample the residual is
    limited to +/- k times the local MAD of the residual, where k is CH_K_IN
    inside ``crest_mask`` and CH_K_OUT outside. The result is clipped to
    +/-CLIP_MV and returned as float32. Signals shorter than 9 samples are
    only cast to float32.
    """
    n = len(y)
    if n < 9:
        return y.astype(np.float32)

    win = max(7, int(n*CH_WIN_FRAC))
    win += (win % 2 == 0)
    trend = _sg(y, win)
    resid = y - trend

    # Rolling MAD of the residual over a centred window.
    reach = win // 2
    mad = np.zeros(n, np.float32)
    for i in range(n):
        lo, hi = max(0, i-reach), min(n, i+reach+1)
        window = resid[lo:hi]
        centre = np.median(window)
        mad[i] = 1.4826*np.median(np.abs(window - centre)) + 1e-6

    tolerance = np.where(crest_mask, CH_K_IN, CH_K_OUT) * mad
    limited = trend + np.clip(resid, -tolerance, tolerance)
    return np.clip(limited, -CLIP_MV, CLIP_MV).astype(np.float32)

# -------------------------
# 3/11 중심선→전위 변환 (Two-pass + Energy subpixel)
# -------------------------
def get_lead_from_energy_two_pass(ima_crop_bgr, tops, bottoms, lead, number_of_rows, markers):
    """Convert one lead's traced centerline into millivolts.

    The lead's band and its begin/end markers are looked up through
    ``mf.lead_info``. The centerline is traced twice: first with the default
    window, then with a per-column corridor that narrows where the first
    trace is steep. The distance between the marker baseline and the trace
    is divided by PX_PER_MV, a few marker artefacts at the edges are
    patched, and the result is resampled to ``number_of_rows`` samples.

    Returns:
        float32 array of length ``number_of_rows``, clipped to +/-CLIP_MV.
    """
    row_idx, start_i, stop_i = mf.lead_info(lead)
    band_top, band_bottom = tops[row_idx], bottoms[row_idx]
    m_start, m_stop = markers[start_i], markers[stop_i]
    width = len(band_top)

    if (m_start is not None) and (m_stop is not None):
        x0 = int(np.clip(m_start[1], 0, width-1))
        x1 = int(np.clip(m_stop[1],  0, width))
        if x1 <= x0:
            x1 = min(width, x0+1)
        # Baseline runs linearly between the two marker rows (crop-relative).
        baseline = np.linspace(m_start[0]-CROP_TOP, m_stop[0]-CROP_TOP, max(1, x1-x0))
    else:
        # A marker is missing: use the whole width with a rough baseline.
        x0, x1 = 0, width
        baseline = np.linspace(band_top[x0], band_bottom[x1-1], max(1, x1-x0))

    def _trace(corridor):
        return _choose_centerline_energy(ima_crop_bgr, band_top, band_bottom, x0, x1,
                                         upscale=SUBPIX_UPSCALE, method="soft",
                                         narrow_map=corridor)

    # Pass 1: default window half-height.
    first_pass = _trace(None)

    # Per-column slope -> corridor half-height map.
    if len(first_pass) >= 3:
        steepness = _norm01(np.abs(np.diff(first_pass, prepend=first_pass[:1])))
        corridor = NARROW_MAX_HALF - (steepness * (NARROW_MAX_HALF - NARROW_MIN_HALF) * (first_pass.size / NARROW_K_SLOPE))
        corridor = np.clip(corridor, NARROW_MIN_HALF, NARROW_MAX_HALF)
    else:
        corridor = np.full_like(first_pass, NARROW_BASE_HALF, dtype=np.float32)

    # Pass 2: search again inside the narrowed corridor.
    second_pass = _trace(corridor)

    baseline = baseline[:len(second_pass)]
    pred = (baseline - second_pass) / PX_PER_MV

    # Soften marker occlusion (avoid jumps at the segment edges).
    if lead in ['aVR','aVL','aVF','V1','V2','V3','V4','V5','V6'] and len(pred) >= 5:
        head = pred[:4]
        head[head > 0.2] = pred[4]
    if lead in ['I','II-subset','III','aVR','aVL','aVF','V1','V2','V3'] and len(pred) >= 6:
        tail = pred[-5:]
        tail[tail > 0.2] = pred[-6]
    if (lead in ['I','II-subset','III','II']
            and len(pred) >= 2 and 0.9 < pred[0] < 1.1 and pred[1] < 0.5):
        pred[0] = pred[1]

    target_grid = np.linspace(0,1,number_of_rows)
    source_grid = np.linspace(0,1,len(pred))
    pred = np.interp(target_grid, source_grid, pred)
    return np.clip(pred, -CLIP_MV, CLIP_MV).astype(np.float32)

# -------------------------
# Einthoven 보정(crest-aware) + II 평균(crest-aware)
# -------------------------
def crest_aware_einthoven_and_II(preds):
    """Enforce the limb-lead sum constraints on ``preds`` in place.

    Two residuals are spread equally over the three leads involved:
    ``I + III - II`` and ``aVR + aVL + aVF``. Where any of the three leads is
    inside its crest mask the correction is scaled by EINTH_MAX_K_IN,
    elsewhere by EINTH_MAX_K_OUT. Only the common leading part of the three
    arrays is touched. Returns None.
    """
    # Crest mask of every lead.
    masks = {ld: build_crest_mask(preds[ld]) for ld in LEADS}

    def _correction_scale(names, length):
        protected = np.zeros(length, dtype=bool)
        for name in names:
            protected |= masks[name][:length]
        return np.where(protected, EINTH_MAX_K_IN, EINTH_MAX_K_OUT).astype(np.float32)

    # The II / II-subset merge is done by the caller, so it is skipped here.

    # I + III ~ II (smaller correction inside crests)
    limb = ('I', 'III', 'II')
    n_limb = min(len(preds[name]) for name in limb)
    if n_limb > 0:
        imbalance = preds['I'][:n_limb] + preds['III'][:n_limb] - preds['II'][:n_limb]
        delta = (imbalance/3) * _correction_scale(limb, n_limb)
        preds['I'][:n_limb]   -= delta
        preds['III'][:n_limb] -= delta
        preds['II'][:n_limb]  += delta

    # aVR + aVL + aVF ~ 0
    augmented = ('aVR', 'aVL', 'aVF')
    n_aug = min(len(preds[name]) for name in augmented)
    if n_aug > 0:
        total = preds['aVR'][:n_aug] + preds['aVL'][:n_aug] + preds['aVF'][:n_aug]
        delta = (total/3) * _correction_scale(augmented, n_aug)
        for name in augmented:
            preds[name][:n_aug] -= delta

def crest_aware_merge_II(ii, ii_subset):
    """Merge the long lead-II trace with the short II-subset trace.

    Over the first ``n = min(len(ii), len(ii_subset))`` samples:

    * inside the union of both crest masks, the sample with the larger
      absolute value is kept (peak protection);
    * elsewhere the two traces are blended with weight II_W_OUT on ``ii``.

    The merged part is then lightly smoothed, and the whole output is clipped
    to +/-CLIP_MV.

    NOTE(review): the original also computed an II_W_IN-weighted blend for
    the crest interior but never used it; that dead computation was removed,
    so II_W_IN has no effect here.

    Returns:
        float32 array with the length of ``ii``; if either input is empty,
        ``ii`` is returned untouched.
    """
    n = min(len(ii), len(ii_subset))
    if n == 0:
        return ii

    long_part, short_part = ii[:n], ii_subset[:n]
    mask = build_crest_mask(long_part) | build_crest_mask(short_part)

    out = ii.copy()
    # Inside crests: keep whichever trace has the larger magnitude.
    peak_keep = np.where(np.abs(long_part) >= np.abs(short_part), long_part, short_part)
    blend = II_W_OUT*long_part + (1.0-II_W_OUT)*short_part
    out[:n] = np.where(mask, peak_keep, blend)

    # Smooth the transitions between masked and unmasked regions.
    k = max(3, int(0.004*n)); k += (k%2==0)
    out[:n] = _sg(out[:n], k)
    return np.clip(out, -CLIP_MV, CLIP_MV).astype(np.float32)

# -------------------------
# 변환(3/11 컬러)
# -------------------------
def convert_scanned_color(ima, markers, n_timesteps, verbose=False):
    """Digitise all 12 leads from a colour scan.

    Steps: crop the header, binarise the R channel with a 3x3 majority vote,
    find four bands with repeated top-down sweeps, trace every lead (plus
    'II-subset'), apply the change-limit filter, merge II with II-subset,
    apply the Einthoven corrections, and clip.

    Args:
        ima: image whose channel 2 is used for binarisation.
        markers: marker list as returned by ``mf.find_markers``.
        n_timesteps: mapping lead -> number of output samples.
        verbose: if true, the binarised image is shown.

    Returns:
        dict mapping each name in LEADS to a float32 array.
    """
    crop_top = CROP_TOP

    # Binarise on the R channel (used only for band finding).
    bright = (ima[crop_top:, :, 2] > 160).astype(np.uint8)
    shifts = (slice(None, -2), slice(1, -1), slice(2, None))
    votes = sum(bright[rows, cols] for rows in shifts for cols in shifts)
    bw = votes >= 7
    if verbose:
        plt.figure(figsize=(6,3))
        plt.imshow(bw, cmap='gray')
        plt.title('BW majority (R>160)')
        plt.show()

    # Find the four bands; each sweep erases the band it found from `work`.
    work = bw.copy()
    swept = [find_line_by_topdown_sweep(work) for _ in range(4)]
    tops = [band[0] for band in swept]
    bottoms = [band[1] for band in swept]

    # Per-lead prediction.
    n_ts = dict(n_timesteps)
    n_ts['II-subset'] = n_ts['I']
    ima_crop = ima[crop_top:, :, :]
    preds = {}
    for lead in (LEADS + ['II-subset']):
        raw = get_lead_from_energy_two_pass(ima_crop, tops, bottoms, lead, n_ts[lead], markers)
        # BASE: Change-Limit + crest-aware
        preds[lead] = change_limit_filter(raw, build_crest_mask(raw))

    # Crest-aware merge of II and II-subset.
    L = min(len(preds['II']), len(preds['II-subset']))
    if L > 0:
        preds['II'][:L] = crest_aware_merge_II(preds['II'][:L], preds['II-subset'][:L])
    preds.pop('II-subset', None)

    # Crest-aware Einthoven correction.
    crest_aware_einthoven_and_II(preds)

    # Final clip.
    preds.update({ld: np.clip(preds[ld], -CLIP_MV, CLIP_MV).astype(np.float32) for ld in LEADS})
    return preds

# -------------------------
# 간단 컬러 감지
# -------------------------
def is_color_image(ima):
    """Return True when any pixel's channels differ.

    The test is that the mean per-pixel standard deviation across the last
    axis is non-zero.
    """
    channel_spread = ima.std(axis=2)
    return channel_spread.mean() != 0

# -------------------------
# 검증 루틴(3/11만)
# -------------------------
def validate_algorithm(train_df, image_types, convert_fn):
    """Score ``convert_fn`` on the training images of the given types.

    For every row of ``train_df`` the label CSV is read and each image
    ``<id>-NNNN.png`` whose four-digit suffix is in ``image_types`` is
    converted. Per image, signal and noise power are summed over LEADS into
    one SNR, which is printed. The first processed image is also plotted
    (markers, conversion debug output, and truth vs prediction per lead).

    Finally the average SNR and its dB value are printed, and the per-image
    SNRs are written to '~snr.csv'. Returns None.
    """
    snr_list = []; index_list = []; is_first = True
    for idx, row in train_df.iterrows():
        labels = pd.read_csv(f'/kaggle/input/physionet-ecg-image-digitization/train/{row.id}/{row.id}.csv')
        for path in sorted(glob(f'/kaggle/input/physionet-ecg-image-digitization/train/{row.id}/{row.id}-*.png')):
            # Image type = the four digits before ".png".
            t = int(path[-8:-4])
            if t not in image_types: continue
            ima = cv2.imread(path)
            markers = mf.find_markers(ima, plot=is_first, title='markers' if is_first else '')
            # Number of non-NaN label samples per lead = required output length.
            n_ts = {ld: (~labels[ld].isna()).sum() for ld in LEADS}
            preds = convert_fn(ima, markers, n_ts, verbose=is_first)

            if is_first:
                _, axs = plt.subplots(6,2, figsize=(12,18))
            sum_signal = 0; sum_noise = 0
            for i, ld in enumerate(LEADS):
                y = labels[ld].dropna().values; p = preds[ld]
                aligned = align_signals(y, p, int(row.fs*MAX_TIME_SHIFT))
                ps, pn = compute_power(y, aligned); sum_signal += ps; sum_noise += pn
                if is_first:
                    # The plot shows the raw prediction, not the aligned one.
                    ax = axs.T.ravel()[i]; ax.set_title(ld)
                    ax.plot(y, label='y_true'); ax.plot(p, label='y_pred'); ax.legend()
            if is_first: plt.tight_layout(); plt.suptitle('y_true vs y_pred', y=1.01); plt.show()

            snr = compute_snr(sum_signal, sum_noise)
            print(f"idx={idx:4d} id={row.id:10d} type={t:2d} SNR: {snr:5.2f}")
            snr_list.append(snr); index_list.append([idx, t])
            is_first = False

    # Mean of the per-image SNRs; 0.0 when nothing was processed.
    snr = float(np.mean(snr_list)) if snr_list else 0.0
    # dB score, floored at -PERFECT_SCORE (also used when snr is not positive).
    val_score = max(float(10*np.log10(snr)) if snr>0 else -PERFECT_SCORE, -PERFECT_SCORE)
    print(f"# Average SNR: {snr:.2f} {val_score=:.2f}")
    pd.DataFrame(index_list, columns=['idx','type']).assign(snr=snr_list).to_csv('~snr.csv', index=False)

# -------------------------
# Main
# -------------------------
if __name__ == "__main__":
    # Script entry point: fit the mean fallback, validate on a few colour
    # scans, then write submission.csv for the test set.
    train = pd.read_csv('/kaggle/input/physionet-ecg-image-digitization/train.csv')
    test  = pd.read_csv('/kaggle/input/physionet-ecg-image-digitization/test.csv')

    # Mean fallback: fit on the first `split` rows, validate on the rest,
    # then refit on all training rows for the actual submission.
    split = 780
    mean_dict = fit_mean_model(train.iloc[:split], verbose=False)
    validate_mean_model(train.iloc[split:], mean_dict)
    mean_dict = fit_mean_model(train, verbose=False)

    # Validation: colour scans (types 3 and 11) only.
    print("\n[Validate] Color scans (3,11) — BASE: Change-Limit + Energy + Crest + Narrow")
    validate_algorithm(train.iloc[100:110], image_types=[3,11], convert_fn=convert_scanned_color)

    # Test conversion / submission: convert colour scans, otherwise use the mean.
    submission = []; old_id=None; preds_cache=None
    for _, row in test.iterrows():
        if row.id != old_id:
            # First row of a new image: load and convert it once, then reuse
            # the cached per-lead predictions for the following rows.
            path = f"/kaggle/input/physionet-ecg-image-digitization/test/{row.id}.png"
            ima  = cv2.imread(path); preds_cache=None
            # Set the route for every image. Previously a failed conversion
            # left `reason` unset (NameError on the first image) or stale
            # from the previous image.
            reason = "mean_fallback"
            if ima is not None and ima.shape[0]==1652 and is_color_image(ima):
                try:
                    ms  = mf.find_markers(ima, plot=False)
                    n_ts = {ld: (row.fs*10 if ld=='II' else row.fs*10//4) for ld in LEADS}
                    preds_cache = convert_scanned_color(ima, ms, n_ts, verbose=False)
                    reason = "converted_color_energyCL"
                except Exception as e:
                    # Deliberate best-effort: any conversion failure falls
                    # back to the mean model for this image.
                    print(f"[WARN] convert failed {row.id}: {e} -> mean fallback"); preds_cache=None
            print(f"[TEST] id={row.id} route={reason}")
            old_id = row.id

        # Lead II needs fs*10 samples; every other lead a quarter of that.
        need = row.fs*10 if row.lead=='II' else row.fs*10//4
        if preds_cache is not None:
            pred = preds_cache[row.lead]
        else:
            pred = mean_dict[row.lead].mean(axis=0)
            pred = np.interp(np.linspace(0,1,need), np.linspace(0,1,len(pred)), pred)
        assert len(pred)==need
        submission.extend(
            {"id": f"{row.id}_{t}_{row.lead}", "value": float(pred[t])}
            for t in range(need)
        )

    sub = pd.DataFrame(submission)
    sub.to_csv('submission.csv', index=False)
    print("Length:", len(sub))
    print(sub.head())