# In [None]:
# ============================================================
# PhysioNet ECG Digitization — Full Pipeline (time-calibrated + phase-align)
# + Paper speed(25/50) detection & override
# + TTA with shared time-scale (spp_ref) and clamp
# + Light 1D-CNN denoiser (optional training & inference)
# ============================================================

import os, glob, cv2, math, warnings
warnings.filterwarnings("ignore")
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from scipy.signal import butter, filtfilt, find_peaks

# -----------------------
# Paths / Config
# -----------------------
TRAIN_DIR    = '/kaggle/input/physionet-ecg-image-digitization/train/'
TRAIN_CSV    = '/kaggle/input/physionet-ecg-image-digitization/train.csv'
TEST_DIR     = '/kaggle/input/physionet-ecg-image-digitization/test/'
TEST_CSV     = '/kaggle/input/physionet-ecg-image-digitization/test.csv'
WORK_DIR     = '/kaggle/working'

TEMPLATE_NPZ = os.path.join(WORK_DIR, 'lead_templates_beats.npz')  # written only; never loaded back
VIS_DIR      = os.path.join(WORK_DIR, 'train_vis')
os.makedirs(VIS_DIR, exist_ok=True)

SUBMISSION_CSV = os.path.join(WORK_DIR, 'submission.csv')

# Lead order (assumes a 3x4 panel layout)
LEAD_GRID = [
    ["I","II","III","aVR"],
    ["aVL","aVF","V1","V2"],
    ["V3","V4","V5","V6"],
]
LEADS = sum(LEAD_GRID, [])

# Submission value scale
MIN_VAL, MAX_VAL = 0.0, 0.07

# ROI / DP / blending parameters (lightly tuned)
INK_GRAY_THR     = 48
LOCAL_INK_THR    = 0.06
MARGIN_COLS      = 8

DP_LAMBDA        = 1.25
DP_WIN_FRAC      = 0.10
DP_EDGE_GAIN     = 0.45

CONF_BAND        = 3
USE_IMG_MIN_CONF = 0.20

# Templates (beat-wise)
TEMPLATE_BEAT_LEN = 360
R_PRE_S           = 0.22
R_POST_S          = 0.42

# ---- Paper speed override (can be forced via environment variable: '25' or '50') ----
PAPER_SPEED_OVERRIDE = os.getenv("ECG_PAPER_SPEED", "").strip()  # "25" or "50" or ""

# -----------------------
# TTA (augmentation) settings
# -----------------------
TTA_ENABLED = bool(int(os.getenv("ECG_TTA", "1")))   # 1=on, 0=off
MAX_AUG     = int(os.getenv("ECG_TTA_N", "6"))
TTA_AGG     = os.getenv("ECG_TTA_AGG", "weighted_mean")  # 'weighted_mean' or 'median'

# Small geometric + photometric presets (angle/shear kept very small to keep the time scale stable)
AUG_PRESETS = [
    {"angle": -0.2}, {"angle": +0.2},
    {"shear": -0.3}, {"shear": +0.3},
    {"tx": -6.0}, {"tx": +6.0},
    {"alpha": 1.10}, {"alpha": 0.90},
    {"gamma": 0.90}, {"gamma": 1.10},
]

# -----------------------
# Denoiser (training / inference) settings
# -----------------------
DENOISER_ENABLE      = bool(int(os.getenv("ECG_DENOISER", "1")))   # 1=use, 0=do not use
DENOISER_TRAIN       = bool(int(os.getenv("ECG_DENOISER_TRAIN", "0"))) # 1=run training
DENOISER_EPOCHS      = int(os.getenv("ECG_DENOISER_EPOCHS", "1"))
DENOISER_LR          = float(os.getenv("ECG_DENOISER_LR", "1e-3"))
DENOISER_PATH        = os.getenv("ECG_DENOISER_PATH", os.path.join(WORK_DIR, "denoiser1d.pt"))
DENOISER_FREQ_LOSS_W = float(os.getenv("ECG_DENOISER_FREQW", "0.2"))  # weight of the frequency-domain loss
DENOISER_USE_TPL_CH  = bool(int(os.getenv("ECG_DENOISER_TPLCH", "1"))) # feed the template as a second channel
# -----------------------
# Utils
# -----------------------
def lowpass(x, fs, cutoff_hz=15.0, order=2):
    """Apply a zero-phase Butterworth low-pass filter and return float32.

    Inputs with 10 or fewer samples are returned unfiltered (converted to
    float32). The normalized cutoff is capped at 0.99 of Nyquist.
    """
    samples = np.asarray(x, dtype=np.float32)
    if samples.size <= 10:
        return samples

    nyquist = 0.5 * float(fs)
    normalized_cutoff = min(cutoff_hz / max(nyquist, 1e-6), 0.99)
    num, den = butter(order, normalized_cutoff, btype='low')
    filtered = filtfilt(num, den, samples)
    return filtered.astype(np.float32)

def zscore(x):
    """Standardize *x* to zero mean and unit std (epsilon-guarded), as float32."""
    arr = np.asarray(x, dtype=np.float32)
    centered = arr - np.mean(arr)
    spread = np.std(arr) + 1e-8
    return centered / spread

def rescale_range(x, lo=MIN_VAL, hi=MAX_VAL):
    """Min-max rescale *x* linearly into ``[lo, hi]`` and return float32.

    An empty input is returned as an empty float32 array. If the input has
    no usable range (non-finite min/max, or max <= min) every element is
    set to the midpoint ``(lo + hi) / 2``.
    """
    x = np.asarray(x, dtype=np.float32)
    if x.size == 0:
        # np.min / np.max raise ValueError on empty arrays.
        return x.copy()
    mn, mx = float(np.min(x)), float(np.max(x))
    if not np.isfinite(mn) or not np.isfinite(mx) or mx <= mn:
        return np.full_like(x, (lo+hi)/2, dtype=np.float32)
    y = (x - mn) / (mx - mn)
    return (lo + y * (hi-lo)).astype(np.float32)

def tukey_window(n, alpha=0.25):
    """Return a symmetric Tukey (tapered-cosine) window of length *n* as float32.

    ``alpha`` is the fraction of the window covered by the two cosine tapers
    together. Each edge gets ``int(alpha * (n - 1) / 2)`` tapered samples
    (capped at ``n // 2``): the left edge rises from 0 towards 1 and the
    right edge is its mirror image. If that count is zero (including
    ``alpha <= 0``) or ``n <= 1`` the window is all ones.
    """
    if n <= 1:
        return np.ones(n, np.float32)
    w = np.ones(n, np.float32)
    # Cap so that the two tapers never overlap when alpha > 1.
    e = min(int(alpha * (n - 1) / 2.0), n // 2)
    if e > 0:
        # Half-cosine rise 0 -> (just below) 1 over e samples.
        phase = np.pi * np.arange(e, dtype=np.float32) / float(e)
        ramp = (0.5 * (1.0 - np.cos(phase))).astype(np.float32)
        w[:e] = ramp
        w[-e:] = ramp[::-1]  # mirrored fall at the right edge
    return w

def sigmoid_blend(alpha, k=8.0, bias=-0.10, lo=0.12, hi=0.92):
    """Map *alpha* through a logistic curve into a blend weight in ``[lo, hi]``.

    ``k`` is the logistic slope and ``bias`` is added to *alpha* before the
    logistic is evaluated. Returns a plain Python float.
    """
    logistic = 1.0 / (1.0 + np.exp(-k * (alpha + bias)))
    blended = lo + (hi - lo) * logistic
    return float(np.clip(blended, lo, hi))

# -----------------------
# Augmentation helpers
# -----------------------
def _photometric_bgr(img_bgr, alpha=1.0, beta=0.0, gamma=1.0):
    x = img_bgr.astype(np.float32)
    if alpha is None: alpha = 1.0
    if beta  is None: beta  = 0.0
    if gamma is None: gamma = 1.0
    x = x * float(alpha) + float(beta)
    x = np.clip(x, 0, 255)
    if abs(float(gamma) - 1.0) > 1e-6:
        x = (x / 255.0) ** (1.0 / float(gamma))
        x = np.clip(x * 255.0, 0, 255)
    return x.astype(np.uint8)

def _affine_shear_rotate_translate(img_bgr, angle=0.0, shear=0.0, tx=0.0, ty=0.0, scale=1.0):
    """Warp an image with shear, rotation/scale about its centre, then translation.

    ``angle`` and ``shear`` are in degrees; ``tx``/``ty`` are pixel offsets.
    The output has the same size as the input; borders are replicated.
    """
    height, width = img_bgr.shape[:2]
    cx = (width - 1) * 0.5
    cy = (height - 1) * 0.5

    # Homogeneous 3x3 building blocks (all float32).
    to_origin = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]], np.float32)
    from_origin = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]], np.float32)

    rotation = np.eye(3, dtype=np.float32)
    rotation[:2, :3] = cv2.getRotationMatrix2D((0, 0), float(angle), float(scale))

    shear_x = math.tan(math.radians(float(shear)))
    shear_m = np.array([[1, shear_x, 0], [0, 1, 0], [0, 0, 1]], np.float32)
    shift = np.array([[1, 0, float(tx)], [0, 1, float(ty)], [0, 0, 1]], np.float32)

    # Applied right-to-left: centre -> shear -> rotate/scale -> un-centre -> translate.
    full = shift @ from_origin @ rotation @ shear_m @ to_origin
    return cv2.warpAffine(img_bgr, full[:2, :], (width, height),
                          flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)

def augment_panel(panel_bgr, angle=0.0, shear=0.0, tx=0.0, ty=0.0, scale=1.0,
                  alpha=1.0, beta=0.0, gamma=1.0):
    """Apply the geometric warp first, then the photometric adjustment."""
    warped = _affine_shear_rotate_translate(
        panel_bgr, angle=angle, shear=shear, tx=tx, ty=ty, scale=scale
    )
    return _photometric_bgr(warped, alpha=alpha, beta=beta, gamma=gamma)

# -----------------------
# 0) 패널 분할 (3x4 균등)
# -----------------------
def split_3x4_panels(img_bgr, trim=6):
    """Cut an image into a uniform 3x4 grid and return ``{lead_name: panel}``.

    Cell size is ``H // 3`` by ``W // 4``; leads are assigned from LEAD_GRID
    in row-major order. When ``trim > 0`` that many pixels are removed from
    every side of each cell.
    """
    height, width = img_bgr.shape[:2]
    cell_h = height // 3
    cell_w = width // 4

    panels = {}
    for cell in range(12):
        r, c = divmod(cell, 4)
        roi = img_bgr[r * cell_h:(r + 1) * cell_h, c * cell_w:(c + 1) * cell_w]
        if trim > 0:
            roi = roi[trim:-trim, trim:-trim]
        panels[LEAD_GRID[r][c]] = roi
    return panels

# -----------------------
# 1) 격자 억제 + 활성구간
# -----------------------
def degrid_gray(bgr):
    """Return a uint8 grayscale image with reddish (grid) pixels flattened.

    Uses the HSV value channel; pixels whose hue is near red and whose
    saturation exceeds 60 are replaced by the median value, then the image
    is blurred (3x3 Gaussian) and min-max normalized to 0..255.
    """
    hue, sat, val = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV))
    is_red = ((hue < 12) | (hue > 168)) & (sat > 60)

    flattened = val.copy()
    flattened[is_red] = np.median(val)
    flattened = cv2.GaussianBlur(flattened, (3, 3), 0)
    normalized = cv2.normalize(flattened, None, 0, 255, cv2.NORM_MINMAX)
    return normalized.astype(np.uint8)

def find_active_columns_longest(gray, ink_thr=INK_GRAY_THR, min_ratio=LOCAL_INK_THR, margin=MARGIN_COLS):
    """Find the longest run of consecutive "inked" columns.

    A pixel counts as ink when ``255 - gray > ink_thr``; a column is active
    when it holds at least ``max(int(min_ratio * H), 5)`` ink pixels.

    Returns ``(lo, hi)``: the inclusive column bounds of the longest active
    run, widened by ``margin`` on both sides and clipped to the image. If no
    column is active the full width ``(0, W - 1)`` is returned.
    """
    H, W = gray.shape
    ink = (255 - gray) > ink_thr
    col_sum = ink.sum(axis=0)
    thr = max(int(min_ratio*H), 5)
    mask = col_sum >= thr

    best_len, best = 0, (0, W-1)
    start = None
    for i, active in enumerate(mask):
        if active and start is None:
            start = i
        run_ends_here = (not active) or i == W-1
        if run_ends_here and start is not None:
            # Last column that belongs to the run: the current one if it is
            # still active (run touches the right border), otherwise the
            # previous one.
            end = i if active else i - 1
            length = end - start + 1
            if length > best_len:
                best_len = length
                best = (max(0, start-margin), min(W-1, end+margin))
            start = None
    lo, hi = best
    return lo, hi

# -----------------------
# 1.5) 격자 간격 기반 보정 (NEW)
# -----------------------
def _red_mask(bgr):
    """Return a 0/255 uint8 mask of reddish pixels, lightly de-speckled.

    Red is hue <= 10 or hue >= 170 with saturation >= 50 and value >= 40.
    The mask is cleaned with a 3x3 median blur and a 3x3 morphological open.
    """
    hue, sat, val = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV))
    vivid = (sat >= 50) & (val >= 40)
    low_red = (hue <= 10) & vivid
    high_red = (hue >= 170) & vivid

    binary = (low_red | high_red).astype(np.uint8) * 255
    binary = cv2.medianBlur(binary, 3)
    kernel = np.ones((3, 3), np.uint8)
    return cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)

def _autocorr_period(sig, min_lag=6, max_lag=None):
    sig = np.asarray(sig, np.float32)
    sig = (sig - sig.mean())/(sig.std()+1e-6)
    ac = np.correlate(sig, sig, mode='full')[len(sig)-1:]
    if max_lag is None: max_lag = len(ac)//2
    min_lag = max(min_lag, 1); max_lag = max(max_lag, min_lag+1)
    if max_lag <= min_lag: return None
    k = np.argmax(ac[min_lag:max_lag]) + min_lag
    return int(k)

def measure_grid_spacing_bgr(panel_bgr):
    """Estimate the large-box grid spacing in pixels along x and y.

    The small-box period is taken from the autocorrelation of the red-mask
    column / row sums and multiplied by 5. Returns ``(dx_big, dy_big)``;
    either entry is ``None`` when no period was found or the result does not
    fit inside the panel.
    """
    height, width = panel_bgr.shape[:2]
    red = _red_mask(panel_bgr)

    def _big_box(profile, extent):
        small = _autocorr_period(profile, min_lag=max(6, extent // 200), max_lag=extent // 2)
        if not small:
            return None
        big = int(small * 5)
        return big if 0 < big <= extent else None

    dx_big = _big_box(red.sum(axis=0), width)
    dy_big = _big_box(red.sum(axis=1), height)
    return dx_big, dy_big

# ---- 속도표기(25/50) 감지 ----
def detect_paper_speed_label(img_bgr):
    """Detect a printed paper-speed label ("25" or "50") on the ECG sheet.

    Returns 25, 50, or None.

    - If the ECG_PAPER_SPEED environment override is "25" or "50" it wins.
    - Otherwise "25" / "50" templates are matched in the bottom 30% of the
      image. Red pixels are used as the text mask when at least 0.2% of the
      region is red; otherwise an adaptive threshold on the grayscale
      region is used instead.
    - None is returned when the image is missing/empty or when neither
      template reaches a normalized correlation of 0.35.
    """
    if PAPER_SPEED_OVERRIDE in ("25","50"):
        return int(PAPER_SPEED_OVERRIDE)

    # cv2.imread returns None for unreadable files; treat that as "unknown".
    if img_bgr is None or getattr(img_bgr, "size", 0) == 0:
        return None

    H = img_bgr.shape[0]
    roi = img_bgr[int(H*0.70):, :]
    hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
    h,s,v = cv2.split(hsv)
    red = (((h < 15) | (h > 165)) & (s > 60) & (v > 40)).astype(np.uint8)*255

    gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
    mask = red
    # Compare a pixel COUNT against the area fraction; the mask holds 0/255,
    # so summing its values would overstate the red coverage by 255x.
    if np.count_nonzero(red) < 0.002 * red.size:
        mask = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                     cv2.THRESH_BINARY_INV, 31, 5)
    vis = cv2.bitwise_and(gray, mask)

    def best_score_for(label):
        # Render the label once and match it at a few scales.
        base = np.zeros((40, 90), np.uint8)
        cv2.putText(base, label, (2, 34), cv2.FONT_HERSHEY_SIMPLEX, 1.2, 255, 2, cv2.LINE_AA)
        best = -1.0
        for sc in (0.8, 0.9, 1.0, 1.1, 1.2):
            t = cv2.resize(base, (int(base.shape[1]*sc), int(base.shape[0]*sc)), interpolation=cv2.INTER_AREA)
            if vis.shape[0] < t.shape[0] or vis.shape[1] < t.shape[1]:
                continue
            res = cv2.matchTemplate(vis, t, cv2.TM_CCOEFF_NORMED)
            if res.size:
                best = max(best, float(res.max()))
        return best

    s25 = best_score_for("25")
    s50 = best_score_for("50")
    thr = 0.35
    if s25 < thr and s50 < thr:
        return None
    return 25 if s25 >= s50 else 50

# -----------------------
# (A) Hough fallback (기존)
# -----------------------
def sec_per_pixel_from_grid_calibrated(gray, x_lo, x_hi, T_known):
    """Hough-based fallback estimate of seconds-per-pixel.

    Near-vertical line segments are detected; the median gap between their
    x-centres is taken as a grid spacing ``d``. The candidates ``0.2 / d``
    and ``0.1 / d`` are compared and the one whose implied ROI duration is
    closest to ``T_known`` is returned. Without a usable spacing the result
    is ``T_known / ROI width``.
    """
    equalized = cv2.equalizeHist(gray)
    edge_map = cv2.Canny(equalized, 40, 100)
    segments = cv2.HoughLinesP(edge_map, 1, np.pi/180, threshold=120,
                               minLineLength=40, maxLineGap=6)

    spacing = None
    if segments is not None:
        centers = set()
        for x1, y1, x2, y2 in segments[:, 0]:
            # Keep only near-vertical, reasonably long segments.
            if abs(x1 - x2) < 3 and abs(y1 - y2) > 25:
                centers.add(int(round((x1 + x2) / 2)))
        centers = np.array(sorted(centers))
        if len(centers) >= 3:
            gaps = np.diff(centers)
            gaps = gaps[(gaps > 5) & (gaps < 2000)]
            if len(gaps):
                spacing = float(np.median(gaps))

    roi_width = max(1, (x_hi - x_lo + 1))
    if not (spacing and spacing > 0):
        return T_known / roi_width

    chosen, lowest_err = None, 1e9
    for candidate in (0.2 / spacing, 0.1 / spacing):
        err = abs(roi_width * candidate - T_known)
        if err < lowest_err:
            lowest_err, chosen = err, candidate
    return chosen

# ---- s/px 산출 (속도표기 반영) ----
def sec_per_pixel_from_grid(panel_bgr, x_lo, x_hi, T_known, gray_fallback=None, paper_speed=None):
    """Estimate seconds-per-pixel from the grid spacing, honouring paper speed.

    With a measured large-box spacing, the candidate is ``0.2 / dx`` for
    ``paper_speed == 25``, ``0.1 / dx`` for 50, or both when the speed is
    unknown. Without a spacing, the Hough fallback is tried on
    ``gray_fallback``; without any candidate the result is
    ``T_known / ROI width``. Among candidates the one whose implied ROI
    duration is closest to ``T_known`` wins, except that two candidates
    whose errors differ by less than 6% of ``T_known`` resolve to the first.
    """
    dx_big, _ = measure_grid_spacing_bgr(panel_bgr)
    roi_width = max(1, (x_hi - x_lo + 1))

    candidates = []
    if dx_big:
        box = float(dx_big)
        if paper_speed == 25:
            candidates = [0.2 / box]
        elif paper_speed == 50:
            candidates = [0.1 / box]
        else:
            candidates = [0.2 / box, 0.1 / box]
    elif gray_fallback is not None:
        hough_spp = sec_per_pixel_from_grid_calibrated(gray_fallback, x_lo, x_hi, T_known)
        if hough_spp is not None:
            candidates.append(hough_spp)

    if not candidates:
        return T_known / float(roi_width)

    errors = [abs(roi_width * spp - T_known) for spp in candidates]
    chosen, lowest_err = None, 1e18
    for spp, err in zip(candidates, errors):
        if err < lowest_err:
            lowest_err, chosen = err, spp

    # Near-tie between the two speeds: prefer the first (25 mm/s) candidate.
    if len(candidates) == 2 and abs(errors[0] - errors[1]) < 0.06 * T_known:
        chosen = candidates[0]

    return chosen

# -----------------------
# 2) DP centerline
# -----------------------
def dp_trace_center(gray, x_lo, x_hi):
    """Trace the waveform centre line across columns with dynamic programming.

    Per-pixel cost is low where the image is dark and where the gradient
    magnitude is high (weighted by DP_EDGE_GAIN). Moving ``dy`` rows between
    adjacent columns costs ``DP_LAMBDA * |dy|`` and is limited to a window
    of ``max(1, int(DP_WIN_FRAC * H))`` rows.

    ``x_lo`` / ``x_hi`` are inclusive column bounds (clipped to the image; an
    empty or inverted range falls back to the full width).

    Returns ``(xs, ys)``: int32 arrays with one traced row per column.
    """
    H, W = gray.shape
    x_lo = max(0, int(x_lo)); x_hi = min(W-1, int(x_hi))
    if x_hi <= x_lo: x_lo, x_hi = 0, W-1

    g = gray.astype(np.float32)/255.0
    inv = 1.0 - g
    gx = cv2.Sobel(g, cv2.CV_32F, 1, 0, ksize=3)
    gy = cv2.Sobel(g, cv2.CV_32F, 0, 1, ksize=3)
    edge = cv2.magnitude(gx, gy); edge = edge / (edge.max()+1e-6)

    cost_img = (1.0 - (inv[:, x_lo:x_hi+1] + DP_EDGE_GAIN*edge[:, x_lo:x_hi+1]))
    cost_img = np.clip(cost_img, 0.0, 2.0).astype(np.float32)

    Hc, Wc = cost_img.shape
    win = max(1, int(DP_WIN_FRAC*Hc)); lam = DP_LAMBDA
    C = np.full((Hc, Wc), np.inf, np.float32)   # accumulated cost
    P = np.full((Hc, Wc), -1, np.int32)         # back-pointer: predecessor ROW in column x-1
    C[:,0] = cost_img[:,0]
    row_idx = np.arange(Hc, dtype=np.int32)

    for x in range(1, Wc):
        prev = C[:, x-1]
        cur = C[:, x]          # views: in-place updates write through to C / P
        back = P[:, x]
        col_cost = cost_img[:, x]
        for dy in range(-win, win+1):
            # Destination row = source row + dy; keep only rows where both exist.
            src_rows = slice(max(0, -dy), Hc - max(0, dy))
            dst_rows = slice(max(0, dy), Hc - max(0, -dy))
            cand = prev[src_rows] + col_cost[dst_rows] + lam*abs(dy)
            better = cand < cur[dst_rows]
            cur[dst_rows][better] = cand[better]
            # Store the ABSOLUTE predecessor row. Storing the index within
            # the source slice is only correct for dy >= 0; for dy < 0 it
            # equals the destination row and hides upward moves on backtrack.
            back[dst_rows][better] = row_idx[src_rows][better]

    y = int(np.argmin(C[:, -1]))
    ys = [y]
    for x in range(Wc-1, 0, -1):
        py = int(P[ys[-1], x])
        if py < 0: py = ys[-1]
        ys.append(py)
    ys = ys[::-1]
    xs = np.arange(x_lo, x_hi+1, dtype=np.int32)
    if len(xs) != len(ys): xs = np.linspace(x_lo, x_hi, len(ys), dtype=np.int32)
    return xs, np.array(ys, np.int32)

def path_confidence(gray, xs, ys, band=CONF_BAND):
    """Score in [0, 1] how well the traced path sits on ink and on edges.

    For each path point a vertical strip of ``±band`` rows is inspected: the
    share of pixels darker than INK_GRAY_THR (after inversion) and the mean
    Sobel gradient magnitude. The result is
    ``0.35 * ink_term + 0.65 * edge_term``, clipped to [0, 1].
    """
    H, W = gray.shape
    cols = np.clip(np.asarray(xs, np.int32), 0, W-1)
    rows = np.clip(np.asarray(ys, np.int32), 0, H-1)

    darkness = 255 - gray
    grad_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
    grad = cv2.magnitude(grad_x, grad_y)

    total_px, inked_px, edge_sum = 0, 0, 0.0
    for col, row in zip(cols, rows):
        top = max(0, row - band)
        bottom = min(H-1, row + band)
        strip = darkness[top:bottom+1, col]
        total_px += strip.size
        inked_px += int((strip > INK_GRAY_THR).sum())
        edge_sum += float(np.mean(grad[top:bottom+1, col]))

    ink_ratio = inked_px / float(total_px) if total_px != 0 else 0.0
    edge_norm = float(np.clip((edge_sum / max(len(cols), 1)) / 50.0, 0.0, 1.0))
    ink_term = np.clip((ink_ratio - 0.10) / 0.35, 0.0, 1.0)
    score = 0.35 * ink_term + 0.65 * edge_norm
    return float(np.clip(score, 0.0, 1.0))

# ROI → 값 시퀀스 (시간보정 포함)
def roi_series_from_path_t(gray, xs, ys, spp, n_out, fs):
    """Sample darkness along the traced path and resample it onto a time axis.

    For every path point the inverted-gray mean over ``±2`` rows is taken;
    the sequence is z-scored and low-passed at 20 Hz. Pixel ``i`` is placed
    at time ``i * spp`` (clipped to the output duration ``(n_out - 1) / fs``)
    and the series is linearly interpolated to ``n_out`` float32 samples.
    """
    darkness = 255 - gray
    half_band = 2
    last_row = gray.shape[0] - 1

    strip_means = [
        np.mean(darkness[max(0, y - half_band):min(last_row, y + half_band) + 1, x])
        for x, y in zip(xs, ys)
    ]
    samples = np.array(strip_means, np.float32)
    samples = lowpass(zscore(samples), fs, 20.0, 2)

    duration = (n_out - 1) / float(fs)
    pixel_times = np.arange(len(samples), dtype=np.float32) * float(spp)
    pixel_times -= pixel_times[0]
    pixel_times = np.clip(pixel_times, 0.0, duration)

    out_times = np.linspace(0.0, duration, n_out, dtype=np.float32)
    return np.interp(out_times, pixel_times, samples).astype(np.float32)

# -----------------------
# 서로상관 기반 시프트 정렬
# -----------------------
def phase_align_to_template(y_img, y_tpl):
    """Shift ``y_img`` so that it lines up with ``y_tpl`` (zero-filled).

    The shift is the lag that maximizes the cross-correlation of the
    z-scored signals over their common length ``L``. Inputs with ``L < 8``
    are returned unchanged. The output has the same length and dtype as
    ``y_img``; samples shifted in from outside are zeros.
    """
    a = zscore(y_img); b = zscore(y_tpl)
    L = min(len(a), len(b))
    if L < 8: return y_img

    # np.correlate(a, b, "full")[k + L - 1] = sum_n a[n + k] * b[n], so a peak
    # at lag k > 0 means the image trace is DELAYED by k samples relative to
    # the template (a[n + k] ~ b[n]).
    c = np.correlate(a[:L], b[:L], mode="full")
    lag = int(np.argmax(c)) - (L-1)
    if lag == 0:
        return y_img

    pad = np.zeros(abs(lag), dtype=y_img.dtype)
    if lag > 0:
        # Trace is late: advance it (drop the first `lag` samples).
        return np.r_[y_img[lag:], pad]
    # Trace is early: delay it (drop the last |lag| samples).
    return np.r_[pad, y_img[:lag]]

# -----------------------
# 3) Beat-wise Templates (train)
# -----------------------
def bandpass_ecg(x, fs, lo=5.0, hi=25.0, order=2):
    """Zero-phase Butterworth band-pass (``lo``..``hi`` Hz), returned as float32.

    Band edges are normalized by Nyquist and clamped to [1e-3, 0.99].
    """
    nyquist = 0.5 * fs
    low_edge = max(lo / nyquist, 1e-3)
    high_edge = min(hi / nyquist, 0.99)
    num, den = butter(order, [low_edge, high_edge], btype='band')
    filtered = filtfilt(num, den, x)
    return filtered.astype(np.float32)

def build_lead_templates_beatwise(train_csv_path, train_dir,
                                  leads=LEADS,
                                  beat_len=TEMPLATE_BEAT_LEN,
                                  pre_R=R_PRE_S, post_R=R_POST_S):
    """Build one median beat template per lead from the training signals.

    For every record listed in ``train_csv_path`` (columns ``id`` and ``fs``
    are read) the per-record CSV ``<train_dir>/<id>/<id>.csv`` is loaded.
    Each lead is z-scored, peaks are found on a 5-25 Hz band-passed copy,
    and a window of ``pre_R`` seconds before to ``post_R`` seconds after
    each peak is cut from the z-scored signal and resampled to ``beat_len``
    samples.

    Returns ``(templates, used)``: ``templates[lead]`` is the z-scored
    median beat (a single sine cycle when a lead has no beats), and
    ``used[lead]`` is the number of beats collected. Missing or unreadable
    record CSVs are skipped.
    """
    meta = pd.read_csv(train_csv_path)
    beats = {ld: [] for ld in leads}
    used  = {ld: 0 for ld in leads}
    beat_axis = np.linspace(0, 1, beat_len, dtype=np.float32)

    for row in tqdm(meta.itertuples(index=False), total=len(meta), desc="Build beatwise templates"):
        rid = str(row.id); fs = int(row.fs)
        csvp = os.path.join(train_dir, rid, f"{rid}.csv")
        if not os.path.exists(csvp): continue
        try:
            df = pd.read_csv(csvp)
        except Exception:
            # Best effort: skip unreadable / malformed record files, but let
            # KeyboardInterrupt and SystemExit propagate.
            continue

        # Window sizes depend only on fs, so compute them once per record.
        n_pre  = int(round(pre_R*fs)); n_post = int(round(post_R*fs))
        for ld in leads:
            if ld not in df.columns: continue
            y = df[ld].dropna().to_numpy(np.float32)
            if y.size < 50: continue
            y0 = zscore(y)
            yf = bandpass_ecg(y0, fs, 5, 25, 2)
            peaks, _ = find_peaks(yf, distance=int(0.35*fs), prominence=max(0.5*np.std(yf), 0.2))
            if len(peaks) < 2: continue
            for pk in peaks:
                a = pk - n_pre; b = pk + n_post
                if a < 0 or b >= len(y0): continue  # window must lie fully inside the signal
                seg = y0[a:b+1]
                seg_rs = np.interp(beat_axis,
                                   np.linspace(0,1,len(seg), dtype=np.float32),
                                   seg).astype(np.float32)
                beats[ld].append(seg_rs); used[ld]+=1

    templates = {}
    for ld in leads:
        if beats[ld]:
            arr = np.vstack(beats[ld])
            tpl = np.median(arr, axis=0).astype(np.float32)
        else:
            # No beats for this lead: fall back to one sine cycle.
            tpl = np.sin(2*np.pi*beat_axis).astype(np.float32)
        templates[ld] = zscore(tpl)
    return templates, used

def stretch_template_for_bpm(template, fs, bpm):
    """Resample one beat template to the sample count of a beat at *bpm*.

    The beat length is ``round(fs * 60 / bpm)`` samples, with a minimum of 6.
    Returns float32.
    """
    samples_per_beat = int(round(fs * 60.0 / max(bpm, 1e-3)))
    samples_per_beat = max(6, samples_per_beat)
    target_phase = np.linspace(0, 1, samples_per_beat, dtype=np.float32)
    source_phase = np.linspace(0, 1, len(template), dtype=np.float32)
    return np.interp(target_phase, source_phase, template).astype(np.float32)

def template_series(template, fs, n_out, bpm, amp=1.0):
    """Tile a beat template at *bpm* into an ``n_out``-sample series.

    The tiled signal is z-scored, scaled by ``amp`` and low-passed at 20 Hz.
    """
    beat = stretch_template_for_bpm(template, fs, bpm)
    n_repeats = int(np.ceil(n_out / len(beat)))
    tiled = np.tile(beat, n_repeats)[:n_out]
    scaled = zscore(tiled) * float(amp)
    return lowpass(scaled, fs, 20.0, 2)

def estimate_bpm_from_series(y, fs):
    """Estimate heart rate (BPM) from the autocorrelation peak of *y*.

    The signal is z-scored and low-passed at 15 Hz. Lags shorter than 0.30 s
    are ignored and the search stops at 2.0 s. Returns 75.0 when no positive
    lag is found; otherwise the rate clipped to [40, 160].
    """
    sig = lowpass(zscore(y), fs, 15.0, 2)
    ac = np.correlate(sig, sig, mode='full')[len(sig)-1:]
    ac[:int(0.30*fs)] = 0  # suppress the zero-lag lobe

    search_end = max(int(2.0*fs), 1)
    peak_lag = int(np.argmax(ac[:search_end]))
    if peak_lag <= 0:
        return 75.0

    rr_seconds = peak_lag / float(fs)
    return float(np.clip(60.0 / max(rr_seconds, 1e-3), 40.0, 160.0))

# -----------------------
# 4) 이미지 패널 → 시리즈 (시간보정 + 시프트정렬 + 블렌딩)
#     - 원본에서 s/px 한 번만 추정(spp_ref) → 모든 TTA에 공유
#     - s/px를 T_known/Wroi의 ±5%로 클램프
#     - paper_speed(25/50) 지정 시 해당 후보만 사용
# -----------------------
def panel_to_series(panel_bgr, fs, n_out, template, debug_path=None,
                    tta=TTA_ENABLED, aug_presets=AUG_PRESETS, aggregate=TTA_AGG,
                    paper_speed=None):
    """Decode one lead panel image into an ``n_out``-sample series.

    The un-augmented panel is decoded first; its seconds-per-pixel estimate
    becomes the reference that every augmented decode is forced to use.
    Each decode traces the waveform, converts it to a time series, shifts it
    against a BPM-matched template series, and blends image and template by
    path confidence before rescaling to [MIN_VAL, MAX_VAL].

    With ``tta`` enabled, up to MAX_AUG entries of ``aug_presets`` are
    decoded as well and combined with the original: per-sample median when
    ``aggregate == "median"``, otherwise an average weighted by
    confidence squared. Augmentations that raise are skipped.

    If ``debug_path`` is given, an overlay of the ROI and traced path for
    the un-augmented decode is written there.

    Returns ``(series, confidence, bpm)``.
    """

    def _decode_one(img_bgr, spp_override=None, dbg=False):
        gray = degrid_gray(img_bgr)
        x_lo, x_hi = find_active_columns_longest(gray)
        xs, ys = dp_trace_center(gray, x_lo, x_hi)
        conf = path_confidence(gray, xs, ys, band=CONF_BAND)

        T = (n_out-1)/float(fs)
        if spp_override is None:
            spp = sec_per_pixel_from_grid(img_bgr, x_lo, x_hi, T_known=T,
                                          gray_fallback=gray, paper_speed=paper_speed)
        else:
            spp = float(spp_override)

        # Guardrail: keep the overall time span (clamp s/px to ±5% of T / ROI width)
        Wroi  = max(1, (x_hi - x_lo + 1))
        ideal = T / float(Wroi)
        spp   = float(np.clip(float(spp), 0.95*ideal, 1.05*ideal))

        # ROI → series (time-calibrated)
        y_img = roi_series_from_path_t(gray, xs, ys, spp, n_out, fs)

        # BPM estimate + template (fixed 75 BPM when the path confidence is too low)
        bpm_est = estimate_bpm_from_series(y_img, fs) if conf > 0.15 else 75.0
        y_tpl   = template_series(template, fs, n_out, bpm=bpm_est, amp=1.0)

        # Cross-correlation based shift alignment
        y_img = phase_align_to_template(y_img, y_tpl)

        # Confidence-based blending (image weight is halved below USE_IMG_MIN_CONF)
        a = sigmoid_blend(conf)
        if conf < USE_IMG_MIN_CONF: a *= 0.5
        y_mix = rescale_range(a*y_img + (1.0-a)*y_tpl, MIN_VAL, MAX_VAL)

        if dbg and (debug_path is not None):
            H,W = gray.shape
            color_dbg = panel_bgr.copy()
            cv2.rectangle(color_dbg, (x_lo,0), (x_hi,H-1), (0,255,255), 2)
            for x,y in zip(xs,ys):
                cv2.circle(color_dbg, (int(x),int(y)), 1, (0,0,255), -1)
            cv2.imwrite(debug_path, color_dbg)

        return y_mix, float(conf), float(bpm_est), float(spp)

    # 1) Decode the un-augmented panel first → obtain the reference s/px
    y0, c0, b0, spp_ref = _decode_one(panel_bgr, spp_override=None, dbg=True)

    if not tta:
        return y0, c0, b0

    # 2) Every augmentation is forced to use spp_ref (time scale stays fixed)
    ys, confs, bpms = [y0], [c0], [b0]
    use_presets = (aug_presets or [])[:max(0, int(MAX_AUG))]

    for cfg in use_presets:
        try:
            aug = augment_panel(
                panel_bgr,
                angle=cfg.get("angle", 0.0),
                shear=cfg.get("shear", 0.0),
                tx=cfg.get("tx", 0.0),
                ty=cfg.get("ty", 0.0),
                scale=cfg.get("scale", 1.0),
                alpha=cfg.get("alpha", 1.0),
                beta=cfg.get("beta", 0.0),
                gamma=cfg.get("gamma", 1.0),
            )
            y, c, b, _ = _decode_one(aug, spp_override=spp_ref, dbg=False)
            ys.append(y); confs.append(c); bpms.append(b)
        except Exception:
            # Best effort: a failed augmentation is dropped from the ensemble.
            continue

    Y = np.vstack(ys)
    C = np.asarray(confs, np.float32)
    W = np.maximum(C**2, 1e-3)  # confidence² weights, floored so they never sum to zero

    if aggregate == "median":
        y_agg = np.median(Y, axis=0)
    else:
        y_agg = np.average(Y, axis=0, weights=W)

    y_agg   = rescale_range(y_agg, MIN_VAL, MAX_VAL)
    conf_ag = float(np.average(C, weights=W))
    bpm_ag  = float(np.average(np.asarray(bpms, np.float32), weights=W))
    return y_agg, conf_ag, bpm_ag

# ============================================================
# 5) Light 1D-CNN Denoiser (옵션): 학습/로드/적용
# ============================================================
try:
    import torch, torch.nn as nn
    TORCH_OK = True
except Exception:
    # Deliberately broad: a broken torch install can fail with errors other
    # than ImportError. Keep the names bound so later references are safe.
    torch = None
    nn = None
    TORCH_OK = False

# Without torch, `nn.Module` does not exist. Fall back to `object` so that the
# class statement itself never raises and the rest of the pipeline still runs.
_DenoiserBase = nn.Module if TORCH_OK else object

class Denoiser1D(_DenoiserBase):
    """Small residual 1D-CNN denoiser.

    Input is ``(batch, in_ch, n)``; the network predicts a one-channel
    residual that is added to the first input channel, giving an output of
    shape ``(batch, 1, n)``.

    Raises RuntimeError on construction when PyTorch is unavailable.
    """
    def __init__(self, in_ch=2):
        if not TORCH_OK:
            raise RuntimeError("Denoiser1D requires PyTorch, which could not be imported.")
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(in_ch, 32, 9, padding=4), nn.ReLU(),
            nn.Conv1d(32, 32, 9, padding=4), nn.ReLU(),
            nn.Conv1d(32, 16, 5, padding=2), nn.ReLU(),
            nn.Conv1d(16, 1,  1)
        )
    def forward(self, x):
        # Residual: add the predicted correction to the first input channel.
        res = self.net(x)
        return x[:, :1, :] + res

def spectral_mse(a, b):
    """Mean squared magnitude of the difference between the real FFTs of a and b.

    Both tensors are transformed along the last dimension.
    """
    diff = torch.fft.rfft(a, dim=-1) - torch.fft.rfft(b, dim=-1)
    return torch.mean(diff.real ** 2 + diff.imag ** 2)

def train_denoiser(templates, train_meta_path, train_dir, leads=LEADS,
                   epochs=1, lr=1e-3, use_tpl=True, device="cpu",
                   freq_w=0.2, paper_speed_fn=detect_paper_speed_label):
    """Train the 1D denoiser on (pipeline output → ground truth) pairs.

    For each training record the first matching PNG is split into panels,
    each lead is decoded with ``panel_to_series`` and the z-scored result
    (plus, when ``use_tpl`` is set, a z-scored template series as a second
    channel) is regressed onto the z-scored ground-truth lead. One optimizer
    step is taken per lead. Loss is SmoothL1 plus ``freq_w`` times the
    spectral MSE.

    Records whose CSV or image is missing, or whose image cannot be decoded,
    are skipped. The checkpoint (``state_dict`` and ``in_ch``) is written to
    DENOISER_PATH and the model is returned in eval mode.
    """
    model = Denoiser1D(in_ch=2 if use_tpl else 1).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    huber = nn.SmoothL1Loss()

    meta = pd.read_csv(train_meta_path)
    ids = meta["id"].astype(int).tolist()

    model.train()
    for ep in range(epochs):
        np.random.shuffle(ids)
        running = 0.0; steps = 0
        for rid in tqdm(ids, desc=f"Denoiser epoch {ep+1}/{epochs}"):
            csvp = os.path.join(train_dir, str(rid), f"{rid}.csv")
            imgs = sorted(glob.glob(os.path.join(train_dir, str(rid), f"{rid}-*.png")))
            if not (os.path.exists(csvp) and imgs): continue
            img = cv2.imread(imgs[0])
            if img is None:
                # cv2.imread signals an unreadable file with None.
                continue
            df  = pd.read_csv(csvp)
            fs  = int(meta.loc[meta['id']==rid,'fs'].iloc[0])
            paper_speed = paper_speed_fn(img)
            panels = split_3x4_panels(img)

            for ld in leads:
                if ld not in df.columns or ld not in panels: continue
                y_gt = df[ld].dropna().to_numpy(np.float32)
                n    = len(y_gt)
                tpl  = templates.get(ld, templates['II'])
                # Reconstruct the lead with the image pipeline
                y_mix, conf, bpm = panel_to_series(panels[ld], fs, n, tpl,
                                                   debug_path=None, paper_speed=paper_speed)
                # Standardize input and target
                yy = zscore(y_gt).astype(np.float32)
                xx = zscore(y_mix).astype(np.float32)
                if use_tpl:
                    ytpl = zscore(template_series(tpl, fs, n, bpm, 1.0)).astype(np.float32)
                    x_np = np.stack([xx, ytpl])[None,...]   # (1,2,n)
                else:
                    x_np = xx[None,None,:]                  # (1,1,n)

                t_np = yy[None,None,:]                      # (1,1,n)

                x = torch.tensor(x_np, dtype=torch.float32, device=device)
                t = torch.tensor(t_np, dtype=torch.float32, device=device)

                opt.zero_grad()
                y_hat = model(x)               # (1,1,n)
                loss  = huber(y_hat, t) + freq_w * spectral_mse(y_hat, t)
                loss.backward()
                opt.step()

                running += float(loss.detach().cpu().item()); steps += 1

        print(f"[Denoiser] epoch {ep+1}: avg loss={running/max(1,steps):.5f}")

    # Save. os.makedirs('') raises, so only create a directory when the
    # configured path actually has one.
    out_dir = os.path.dirname(DENOISER_PATH)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save({"state_dict": model.state_dict(),
                "in_ch": 2 if use_tpl else 1}, DENOISER_PATH)
    print(f"[Denoiser] saved -> {DENOISER_PATH}")
    # Match load_denoiser(): callers use the returned model for inference.
    model.eval()
    return model

def load_denoiser(path=DENOISER_PATH, device="cpu"):
    """Load a saved denoiser checkpoint and return the model in eval mode.

    Returns None when PyTorch is unavailable or *path* does not exist. The
    input channel count comes from the checkpoint's ``in_ch`` (default 2).
    """
    if not (TORCH_OK and os.path.exists(path)):
        return None
    checkpoint = torch.load(path, map_location=device)
    net = Denoiser1D(in_ch=checkpoint.get("in_ch", 2)).to(device)
    net.load_state_dict(checkpoint["state_dict"], strict=True)
    net.eval()
    return net

def apply_denoiser(y_mix, y_tpl, model, device="cpu"):
    """Run the denoiser on ``y_mix`` and rescale to the submission range.

    ``y_mix`` is returned unchanged when there is no model, the denoiser is
    disabled, or PyTorch is unavailable. Inputs are z-scored; ``y_tpl``
    (when given) is stacked as a second channel.

    The channel layout is reconciled with the model's first layer when that
    can be inspected: a one-channel model ignores ``y_tpl``, and a
    two-channel model without ``y_tpl`` leaves ``y_mix`` untouched instead
    of failing inside the convolution.
    """
    if (model is None) or (not DENOISER_ENABLE) or (not TORCH_OK):
        return y_mix

    # The checkpoint's in_ch may disagree with the current template-channel
    # setting (e.g. trained with templates, applied without).
    try:
        expected_ch = int(model.net[0].in_channels)
    except (AttributeError, IndexError, TypeError, ValueError):
        expected_ch = None  # unknown model layout: keep the caller's choice
    if expected_ch == 1:
        y_tpl = None
    elif expected_ch == 2 and y_tpl is None:
        return y_mix

    # Standardized input → correction → rescale to the submission range
    xx = zscore(y_mix).astype(np.float32)
    if y_tpl is not None:
        xt = zscore(y_tpl).astype(np.float32)
        x_np = np.stack([xx, xt])[None,...]   # (1,2,n)
    else:
        x_np = xx[None,None,:]
    x = torch.tensor(x_np, dtype=torch.float32, device=device)
    with torch.no_grad():
        y_hat = model(x).cpu().numpy()[0,0]
    return rescale_range(y_hat, MIN_VAL, MAX_VAL)

# -----------------------
# 6) Train 시각화 (GT vs Pred) with optional denoiser
# -----------------------
def plot_train_gt_vs_pred(rec_id, templates, leads=('II','V2','V5'), denoiser=None, device="cpu"):
    """Plot z-scored ground truth against the pipeline prediction for one record.

    One subplot is drawn per requested lead, titled with Pearson r, RMSE and
    R² computed on the z-scored signals. The figure is saved to
    ``VIS_DIR/gt_pred_<id>.png``, shown, and then closed. Records with a
    missing CSV/image or an unreadable image are skipped with a message.
    """
    rid = str(int(rec_id))
    csvp = os.path.join(TRAIN_DIR, rid, f"{rid}.csv")
    imgs = sorted(glob.glob(os.path.join(TRAIN_DIR, rid, f"{rid}-*.png")))
    if not os.path.exists(csvp) or not imgs:
        print(f"[Skip] {rid} (csv/image missing)")
        return
    img = cv2.imread(imgs[0])
    if img is None:
        # cv2.imread signals an unreadable file with None.
        print(f"[Skip] {rid} (image unreadable)")
        return
    df = pd.read_csv(csvp)
    fs = int(pd.read_csv(TRAIN_CSV).loc[lambda d: d['id']==int(rid),'fs'].iloc[0])

    paper_speed = detect_paper_speed_label(img)
    panels = split_3x4_panels(img)

    rows = len(leads)
    fig, axes = plt.subplots(rows, 1, figsize=(12, 2.8*rows), squeeze=False)
    axes = axes.ravel()

    for ax, ld in zip(axes, leads):
        if ld not in df.columns or ld not in panels:
            ax.set_axis_off(); continue
        y_gt = df[ld].dropna().to_numpy(np.float32)
        n = len(y_gt)
        tpl = templates.get(ld, templates['II'])
        y_pred, conf, bpm = panel_to_series(
            panels[ld], fs, n, tpl,
            debug_path=None, paper_speed=paper_speed
        )
        # (optional) apply the denoiser
        y_tpl = template_series(tpl, fs, n, bpm, 1.0)
        y_pred = apply_denoiser(y_pred, y_tpl if DENOISER_USE_TPL_CH else None, denoiser, device=device)

        yy = zscore(y_gt); pp = zscore(y_pred)
        pear = np.corrcoef(yy, pp)[0,1] if (np.std(yy)>0 and np.std(pp)>0) else 0.0
        rmse = float(np.sqrt(np.mean((yy-pp)**2)))
        ss_tot = float(np.sum((yy - np.mean(yy))**2))
        r2 = float(1.0 - np.sum((yy-pp)**2)/ss_tot) if ss_tot>0 else 0.0

        t = np.arange(n)/float(fs)
        ax.plot(t, yy, color='k', lw=1.0, label=f'GT {ld}')
        ax.plot(t, pp, color='crimson', lw=1.0, label=f'Pred {ld}')
        ax.set_title(f"Train ID {rid} — Lead {ld}   r={pear:.3f}  RMSE={rmse:.3f}  R²={r2:.3f}")
        ax.set_xlabel("Time (s)"); ax.set_ylabel("Norm amp")
        ax.grid(True); ax.legend(loc='upper right', fontsize=9)

    plt.tight_layout()
    outp = os.path.join(VIS_DIR, f"gt_pred_{rid}.png")
    plt.savefig(outp, dpi=200, bbox_inches='tight')
    print(f"[Saved] {outp}")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# -----------------------
# 7) TEST Inference & Submission (with optional denoiser)
# -----------------------
def run_test_submission(templates, denoiser=None, device="cpu"):
    """Predict every row of TEST_CSV and write the submission CSV.

    Each test row (``id``, ``lead``, ``fs``, ``number_of_rows``) yields
    ``number_of_rows`` values with ids ``<id>_<i>_<lead>``. The lead is
    decoded from the first PNG found for the record. When the image is
    missing or unreadable, or the lead is not one of LEADS, a 75-BPM
    template series is used instead. The denoiser is applied to decoded
    leads when enabled.

    Returns the submission DataFrame (columns ``id``, ``value``), which is
    also written to SUBMISSION_CSV.
    """
    test = pd.read_csv(TEST_CSV)
    id2img = {}
    for tid in test['id'].unique():
        p = sorted(glob.glob(os.path.join(TEST_DIR, str(int(tid)), f"{int(tid)}-*.png")))
        id2img[int(tid)] = p[0] if p else None

    def _template_only(lead, fs, n):
        # Fallback when no usable image panel exists for this row.
        y_fb = template_series(templates.get(lead, templates['II']), fs, n, bpm=75.0, amp=1.0)
        return rescale_range(y_fb, MIN_VAL, MAX_VAL)

    # One-entry cache: each record has one row per lead, so decoding the
    # image once per record avoids re-reading and re-splitting it per lead.
    cached_id, cached_panels, cached_speed = None, None, None

    sub_rows = []
    for r in tqdm(test.itertuples(index=False), total=len(test), desc="Predict(test)"):
        base_id, lead, fs, n = int(r.id), str(r.lead), int(r.fs), int(r.number_of_rows)

        if lead not in LEADS:
            y = _template_only(lead, fs, n)
        else:
            if base_id != cached_id:
                cached_id, cached_panels, cached_speed = base_id, None, None
                imgp = id2img.get(base_id, None)
                img = cv2.imread(imgp) if (imgp is not None and os.path.exists(imgp)) else None
                if img is not None:  # cv2.imread returns None for unreadable files
                    cached_speed = detect_paper_speed_label(img)
                    cached_panels = split_3x4_panels(img)

            if cached_panels is None or lead not in cached_panels:
                y = _template_only(lead, fs, n)
            else:
                tpl = templates.get(lead, templates['II'])
                y, conf, bpm = panel_to_series(
                    cached_panels[lead], fs, n, tpl,
                    debug_path=None, paper_speed=cached_speed
                )
                # (optional) denoiser
                if DENOISER_ENABLE and TORCH_OK:
                    y_tpl = template_series(tpl, fs, n, bpm, 1.0)
                    y = apply_denoiser(y, y_tpl if DENOISER_USE_TPL_CH else None, denoiser, device=device)

        ids = [f"{base_id}_{i}_{lead}" for i in range(n)]
        sub_rows.extend(zip(ids, y.tolist()))
    sub = pd.DataFrame(sub_rows, columns=['id','value'])
    sub.to_csv(SUBMISSION_CSV, index=False)
    print(sub.head(10))
    print(f"[OK] Wrote submission: {SUBMISSION_CSV} (rows={len(sub)})")
    return sub

# =======================
# RUN
# =======================
# 1) Always rebuild the templates from the training data
templates, used = build_lead_templates_beatwise(TRAIN_CSV, TRAIN_DIR, leads=LEADS)
# (optional) Save only (the file is never loaded back)
np.savez_compressed(TEMPLATE_NPZ, **templates)
print("[OK] (re)built templates and saved ->", TEMPLATE_NPZ)
for ld in LEADS:
    print(f"  {ld:>3}: beats={used.get(ld,0)} tpl_len={len(templates[ld])}")

# 2) (optional) Train or load the denoiser
# NOTE(review): "cuda" is chosen only when CUDA_VISIBLE_DEVICES is set to a
# non-empty value AND torch reports CUDA as available; with the variable unset
# this resolves to "cpu" — confirm that is intended.
device = "cuda" if TORCH_OK and (os.getenv("CUDA_VISIBLE_DEVICES","") != "") and (getattr(__import__('torch'), 'cuda', None) and __import__('torch').cuda.is_available()) else "cpu"
denoiser = None
if DENOISER_ENABLE and TORCH_OK:
    if DENOISER_TRAIN:
        denoiser = train_denoiser(templates, TRAIN_CSV, TRAIN_DIR, leads=LEADS,
                                  epochs=DENOISER_EPOCHS, lr=DENOISER_LR,
                                  use_tpl=DENOISER_USE_TPL_CH, device=device,
                                  freq_w=DENOISER_FREQ_LOSS_W)
    else:
        denoiser = load_denoiser(DENOISER_PATH, device=device)
        if denoiser is None:
            print("[Denoiser] no weights found; run with ECG_DENOISER_TRAIN=1 to train.")

# 3) Train visualization (first 2 records as examples)
train_meta = pd.read_csv(TRAIN_CSV)
example_ids = [int(train_meta.iloc[i]['id']) for i in range(min(2, len(train_meta)))]
for rid in example_ids:
    plot_train_gt_vs_pred(rid, templates, leads=('II','V2','V5'), denoiser=denoiser, device=device)

# 4) Test → submission.csv
_ = run_test_submission(templates, denoiser=denoiser, device=device)
