# In [None]:
# -*- coding: utf-8 -*-
"""
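GAT Edge Risk (robust-B)
- ROAD: the width column is not used (no w_per_lane); width is only derived as lanes*3.5
- Y: risk_score (0~1), regressed in log space (Y_SCALE)
- Split: GroupShuffleSplit grouped by pair(u,v) (prevents leakage)
- Graph: message passing over both directions (2E); training/evaluation/saving use only the first E (original) edges
- Loss: Huber (SmoothL1) + quantile (original scale) mix, with dynamic residual weighting
- Calibration: train-CV based piecewise (lin_low@log + iso_high@orig) + smooth blend
- Safety: calibration is automatically disabled (fallback to raw) if it degrades validation performance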
"""

import os, math, random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.model_selection import GroupShuffleSplit, KFold
from sklearn.linear_model import LinearRegression
from sklearn.isotonic import IsotonicRegression

# =========================
# Paths / Config
# =========================
ROAD_CSV = "data/마포_서대문_은평_도로망_최종_v2.csv"      # road edge table; must contain 'u' and 'v' columns
Y_CSV    = "data/은마서_도로_위험도_calibrated.csv"  # label table; must contain 'risk_score' and a node-pair column
OUT_DIR  = "outputs30"  # predictions CSV and model checkpoint are written here
os.makedirs(OUT_DIR, exist_ok=True)

# Seed the Python, NumPy and PyTorch generators.
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Train
EPOCHS        = 800    # maximum number of full-graph training epochs
PRINT_EVERY   = 80     # metrics are logged when epoch % PRINT_EVERY == 1 and on the last epoch
PATIENCE      = 140    # epochs without validation-objective improvement before early stop
LR            = 3e-4   # Adam learning rate
WEIGHT_DECAY  = 1e-4   # Adam weight decay

# Model size: GAT stack followed by an edge-level MLP head.
GAT_HIDDEN = 128   # output channels per attention head
GAT_HEADS = 8      # attention heads per GAT layer
GAT_LAYERS = 3     # number of stacked GAT layers
MLP_HIDDEN = 512   # width of the first MLP layer (second is MLP_HIDDEN // 2)
GAT_DROPOUT = 0.3  # dropout inside and after every GAT layer
MLP_DROPOUT = 0.3  # dropout between MLP layers

# Split (fractions of the labeled edges, grouped by node pair)
TEST_SIZE     = 0.10
VAL_SIZE      = 0.10

# Target transform: t = log1p(clip(y, 0, 1) * Y_SCALE)
Y_SCALE       = 4000.0

# Quantile
QL_TAU    = 0.90  # quantile level of the pinball loss
QL_LAMBDA = 0.35  # weight of the quantile term; the Huber term gets 1 - QL_LAMBDA

# Weights
LEN_W, SIG_W, HI_W = 0.4, 0.2, 0.4   # mixing coefficients of the length / signal / high-risk weights
W_MIN, W_MAX       = 0.1, 10.0       # clamp range for per-edge sample weights
RES_ALPHA, RES_BETA = 2.0, 1.0       # residual weight = 1 + RES_ALPHA * |residual| ** RES_BETA
RES_CLIP_MINMAX     = (1.0, 8.0)     # clamp range for the residual weight

# High-risk
HI_PCTILE = 0.85  # quantile of the labeled risk scores used as the high-risk threshold
# (hotfix) reduced to mitigate over-weighting of high-risk edges
HI_MULT   = 6.0

# Calibration blend
BLEND_K       = 100.0 # (hotfix) sigmoid steepness of the blend; mitigates over-correction at the boundary

# =========================
# Utils
# =========================
def to_pair_tuple(a, b):
    """Return a direction-independent key for the node pair (a, b).

    Both ids are converted with ``int`` and made non-negative with ``abs``;
    the smaller value comes first in the returned 2-tuple.
    """
    lo, hi = sorted((abs(int(a)), abs(int(b))))
    return (lo, hi)

def progress_bar(curr, total, width=40):
    """Render a textual progress bar such as ``[##--]   2/4``.

    Args:
        curr: Current step (integer, formatted with ``:>3d``).
        total: Total number of steps.
        width: Number of character cells inside the brackets.

    Returns:
        The bar string. The filled part is clamped to ``[0, width]`` so the
        bar keeps a constant length even when ``curr`` exceeds ``total``, and
        a non-positive ``total`` yields an empty bar instead of raising
        ``ZeroDivisionError``.
    """
    done = int(width * curr / total) if total > 0 else 0
    done = max(0, min(width, done))
    return "[" + "#" * done + "-" * (width - done) + f"] {curr:>3d}/{total}"

def metrics(y_true, y_pred):
    """Compute regression metrics between targets and predictions.

    Both inputs are converted to float arrays. MAPE divides by ``|y_true|``
    floored at 1e-12 so that zero targets do not divide by zero.

    Returns:
        dict with keys ``"RMSE"``, ``"MAE"``, ``"R2"`` and ``"MAPE%"``.
    """
    truth = np.asarray(y_true, dtype=float)
    guess = np.asarray(y_pred, dtype=float)

    safe_denominator = np.clip(np.abs(truth), 1e-12, None)
    abs_pct_err = np.abs((truth - guess) / safe_denominator)

    return {
        "RMSE": math.sqrt(mean_squared_error(truth, guess)),
        "MAE": mean_absolute_error(truth, guess),
        "R2": r2_score(truth, guess),
        "MAPE%": float(np.mean(abs_pct_err) * 100.0),
    }

def y_to_t(y):
    """Map risk scores to the log-space training target.

    Values are clipped to [0, 1], multiplied by ``Y_SCALE`` and passed
    through ``log1p``; the result is returned as float32.
    """
    clipped = np.clip(np.asarray(y, dtype=float), 0.0, 1.0)
    scaled = clipped * Y_SCALE
    return np.log1p(scaled).astype(np.float32)

def t_to_y(t):
    """Invert ``y_to_t``: map log-space values back to risk scores.

    Applies ``expm1``, divides by ``Y_SCALE`` and clips to [0, 1]; the
    result is returned as float32.
    """
    restored = np.expm1(np.asarray(t, dtype=float)) / Y_SCALE
    return np.clip(restored, 0.0, 1.0).astype(np.float32)

def t_to_y_torch(t_tensor: torch.Tensor) -> torch.Tensor:
    """Tensor version of the inverse target transform.

    Computes ``expm1(t) / Y_SCALE`` and clamps the result to [0, 1] using
    torch ops, so it can be used inside the loss computation.
    """
    restored = torch.expm1(t_tensor) / Y_SCALE
    return restored.clamp(min=0.0, max=1.0)

def quantile_loss(y_true_torch, y_pred_torch, tau=0.9):
    """Mean pinball (quantile) loss at level ``tau``.

    With ``e = y_true - y_pred`` each element contributes
    ``max(tau * e, (tau - 1) * e)``; the mean over all elements is returned.
    """
    diff = y_true_torch - y_pred_torch
    when_under = tau * diff          # active when the prediction is too low
    when_over = (tau - 1) * diff     # active when the prediction is too high
    return torch.maximum(when_under, when_over).mean()

def group_splits_by_pair(df, labeled_idx, test_size=0.10, val_size=0.10, seed=42):
    """Split the labeled rows into train/val/test, grouped by the ``pair`` column.

    Rows sharing the same ``pair`` value always land in the same split.

    Args:
        df: Edge table with a ``pair`` column.
        labeled_idx: Positional row indices (into ``df``) of the labeled rows.
        test_size: Fraction of groups held out for the test split.
        val_size: Fraction of groups held out for the validation split.
        seed: Random state of the first split; the second uses ``seed + 1``.

    Returns:
        ``(train, val, test)`` arrays of positional row indices into ``df``.
    """
    labeled_rows = df.iloc[labeled_idx]
    holdout_frac = test_size + val_size

    # Stage 1: carve the combined val+test hold-out away from train.
    outer = GroupShuffleSplit(n_splits=1, test_size=holdout_frac, random_state=seed)
    train_pos, holdout_pos = next(
        outer.split(labeled_rows, groups=labeled_rows["pair"].astype(str).values)
    )

    # Stage 2: divide the hold-out into validation and test.
    holdout_rows = labeled_rows.iloc[holdout_pos]
    inner = GroupShuffleSplit(
        n_splits=1, test_size=test_size / holdout_frac, random_state=seed + 1
    )
    val_pos, test_pos = next(
        inner.split(holdout_rows, groups=holdout_rows["pair"].astype(str).values)
    )

    # Map split-relative positions back to row positions of df.
    return (
        labeled_idx[train_pos],
        labeled_idx[holdout_pos[val_pos]],
        labeled_idx[holdout_pos[test_pos]],
    )

# =========================
# 1) Load & Merge
# =========================
# Road edge table: one row per edge between nodes u and v.
roads = pd.read_csv(ROAD_CSV)
if not {"u","v"}.issubset(roads.columns):
    raise ValueError("ROAD_CSV에는 'u','v' 컬럼이 필요합니다.")
# Direction-independent key: (u, v) and (v, u) map to the same pair.
roads["pair"] = [to_pair_tuple(u, v) for u, v in zip(roads["u"], roads["v"])]

# Label table with one risk score per (possibly repeated) node pair.
ys = pd.read_csv(Y_CSV)
if "risk_score" not in ys.columns:
    raise ValueError("Y_CSV에 'risk_score' 컬럼이 필요합니다.")

# Auto-detect which columns hold the two node ids.
if {"u","v"}.issubset(ys.columns): au, av = "u","v"
elif {"u_id","v_id"}.issubset(ys.columns): au, av = "u_id","v_id"
elif {"node1","node2"}.issubset(ys.columns): au, av = "node1","node2"
else: raise ValueError("Y_CSV에서 노드쌍 컬럼(u,v / u_id,v_id / node1,node2) 미발견")

ys["pair"] = [to_pair_tuple(a,b) for a,b in zip(ys[au], ys[av])]
ys["risk_score"] = ys["risk_score"].astype(float).clip(0.0, 1.0)
# Several label rows for the same pair collapse to their maximum score.
ys = ys.groupby("pair", as_index=False)["risk_score"].max()

# Left join keeps every road edge; edges without a label get NaN risk_score.
df = roads.merge(ys, on="pair", how="left").copy()
print(f"[INFO] roads={len(roads):,}, labeled_edges={df['risk_score'].notna().sum():,} "
      f"({df['risk_score'].notna().mean()*100:.1f}% labeled)")

# =========================
# 2) Edge features  (width is derived only as lanes*3.5)
# =========================
# Cast to numeric; values that cannot be parsed become NaN.
for c in ["length","lanes","avg_height","avg_slope","up_lanes","down_lanes"]:
    if c in df.columns:
        df[c] = pd.to_numeric(df[c], errors="coerce")

# Defensive handling: clamp negative lengths, and fall back to constant
# columns when a source column is missing so later code can rely on them.
if "length" in df.columns:
    df["length"] = df["length"].clip(lower=0)
    df["log_length"] = np.log1p(df["length"])
else:
    df["length"] = 0.0
    df["log_length"] = 0.0

if "avg_slope" in df.columns:
    df["abs_slope"] = np.abs(df["avg_slope"])
else:
    df["abs_slope"] = 0.0

if {"length","avg_slope"}.issubset(df.columns):
    df["len_x_slope"] = df["length"] * np.abs(df["avg_slope"])
else:
    df["len_x_slope"] = 0.0

# Lane count: missing -> 1, then limited to the range 1..8.
if "lanes" in df.columns:
    df["lanes"] = df["lanes"].fillna(1).clip(1, 8)
else:
    df["lanes"] = 1.0

# width_len_ratio (width assumed to be lanes*3.5) plus its log1p variant.
# NOTE(review): the divisor is log_length, not length, so a zero-length edge
# gives est_width / 1e-6 — confirm this is the intended definition.
est_width = df["lanes"] * 3.5
df["width_len_ratio"] = est_width / (df["log_length"] + 1e-6)
df["log_width_len_ratio"] = np.log1p(df["width_len_ratio"])

# Normalise binary flag columns to integer 0/1 (missing -> 0).
for c in ["oneway","bridge","tunnel"]:
    if c in df.columns:
        df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0).round().astype(int)

# Candidate feature columns; only those present in df are used.
num_candidates = [
    "length","lanes","avg_height","avg_slope",
    "up_lanes","down_lanes",
    "log_length","abs_slope","len_x_slope",
    "width_len_ratio","log_width_len_ratio",
    "oneway","bridge","tunnel",
]
cat_candidates = ["highway","surface"]

# Missing-value imputation and one-hot encoding.
edge_num_cols = [c for c in num_candidates if c in df.columns]
for c in edge_num_cols:
    df[c] = pd.to_numeric(df[c], errors="coerce").astype(float)
# NOTE(review): the median is taken over all rows, including edges that later
# end up in the validation/test splits.
for c in edge_num_cols:
    df[c] = df[c].fillna(df[c].median())

# Only object-dtype categorical columns are encoded; levels seen fewer than
# 20 times are merged into "Other".
edge_cat_cols = [c for c in cat_candidates if c in df.columns and df[c].dtype == object]
for c in edge_cat_cols:
    df[c] = df[c].astype("object").fillna("Unknown")
    vc = df[c].value_counts(dropna=False)
    rare = vc[vc < 20].index
    df[c] = df[c].where(~df[c].isin(rare), other="Other")

if edge_cat_cols:
    dummies = pd.get_dummies(df[edge_cat_cols], prefix=edge_cat_cols, dtype=np.float32)
    df = pd.concat([df.drop(columns=edge_cat_cols), dummies], axis=1)

# Final feature list: numeric columns first, then every column whose name
# starts with "highway_" or "surface_".
# NOTE(review): the prefix match also selects any pre-existing column with
# such a name, not only the dummies created above — verify against the CSV.
edge_feat_cols = []
edge_feat_cols += edge_num_cols
edge_feat_cols += [c for c in df.columns if any(c.startswith(p+"_") for p in cat_candidates)]
if len(edge_feat_cols) == 0:
    raise ValueError("엣지 피처가 비었습니다.")
print(f"[INFO] Using {len(edge_feat_cols)} edge features. First 12: {edge_feat_cols[:12]}")

# =========================
# 3) Split / Scale (Group by pair)
# =========================
# Labeled edges are the rows whose risk_score is not NaN.
y_raw = df["risk_score"].astype(float).values
labeled_mask_np = ~np.isnan(y_raw)
labeled_idx = np.where(labeled_mask_np)[0]  # positional row numbers

lab_train, lab_val, lab_test = group_splits_by_pair(
    df, labeled_idx, test_size=TEST_SIZE, val_size=VAL_SIZE, seed=SEED
)

# Standardise the numeric features: statistics come from the training rows
# only, then the transform is applied in place to every row of df.
# lab_train holds positional indices, so the .loc lookup below matches them
# only while df keeps a default 0..N-1 index.
scaler = StandardScaler()
if len(edge_num_cols) > 0:
    scaler.fit(df.loc[lab_train, edge_num_cols].astype(float))
    df[edge_num_cols] = scaler.transform(df[edge_num_cols].astype(float))

# =========================
# 4) Graph build (bidirectional)
# =========================
# Assign a contiguous integer index to every node id seen in u or v.
all_nodes = pd.Index(pd.unique(pd.concat([df["u"], df["v"]])))
node_to_idx = {nid:i for i, nid in enumerate(all_nodes)}
num_nodes = len(all_nodes)

# Degree counts from the original (u -> v) direction of each edge.
in_deg  = np.zeros(num_nodes, dtype=np.float32)
out_deg = np.zeros(num_nodes, dtype=np.float32)
for u, v in zip(df["u"].values, df["v"].values):
    ui = node_to_idx[u]; vi = node_to_idx[v]
    out_deg[ui] += 1; in_deg[vi]  += 1

# Node features are purely structural: in/out/total degree and log1p degrees.
node_x = np.stack([
    in_deg,
    out_deg,
    in_deg + out_deg,
    np.log1p(in_deg),
    np.log1p(out_deg),
], axis=1).astype(np.float32)

u_idx = df["u"].map(node_to_idx).values
v_idx = df["v"].map(node_to_idx).values

# Edge list = original edges followed by their reversed copies (2E columns);
# column i and column E + i describe the same road in opposite directions.
ei     = np.vstack([u_idx, v_idx])
ei_rev = np.vstack([v_idx, u_idx])
edge_index = torch.tensor(np.hstack([ei, ei_rev]), dtype=torch.long)

# Edge attributes are duplicated so both directions carry the same features.
ea = df[edge_feat_cols].values.astype(np.float32)
edge_attr  = torch.tensor(np.vstack([ea, ea]), dtype=torch.float32)
x          = torch.tensor(node_x, dtype=torch.float32)

E = len(df)  # number of original edges

# Labels (log space) and masks, all of length E. Unlabeled edges get a
# placeholder target of 0.0 and are excluded through labeled_mask.
y_t_np = np.where(np.isnan(y_raw), np.nan, y_to_t(y_raw))
y_t_filled = np.where(np.isnan(y_t_np), 0.0, y_t_np)
y_t = torch.tensor(y_t_filled, dtype=torch.float32).view(-1,1)
labeled_mask = torch.tensor(labeled_mask_np, dtype=torch.bool)

edge_train_mask = torch.zeros(E, dtype=torch.bool); edge_train_mask[lab_train] = True
edge_val_mask   = torch.zeros(E, dtype=torch.bool); edge_val_mask[lab_val]   = True
edge_test_mask  = torch.zeros(E, dtype=torch.bool); edge_test_mask[lab_test] = True

data = Data(
    x=x, edge_index=edge_index, edge_attr=edge_attr, y=y_t,
    edge_train_mask=edge_train_mask,
    edge_val_mask=edge_val_mask,
    edge_test_mask=edge_test_mask,
    labeled_mask=labeled_mask
).to(DEVICE)

print(f"[INFO] nodes={num_nodes:,}, edges_dir={edge_index.size(1):,} (2E), E(original)={E:,}, labeled={labeled_mask_np.sum():,} "
      f"(train/val/test = {len(lab_train)}/{len(lab_val)}/{len(lab_test)})")

# =========================
# 4.5) HI_THR from labeled
# =========================
# High-risk threshold = HI_PCTILE quantile of the risk scores of all labeled
# edges (train, validation and test together).
y_labeled = df["risk_score"].dropna().values
HI_THR  = float(np.quantile(y_labeled, HI_PCTILE))   # e.g., p85
print(f"[INFO] HI_THR (p{int(HI_PCTILE*100)}) = {HI_THR:.6f}, HI_MULT = {HI_MULT}")

# =========================
# 5) Model
# =========================
class EdgeRegressor(nn.Module):
    """Edge-level regressor: a GATv2 node encoder followed by an MLP head.

    ``GAT_LAYERS`` edge-aware GATv2Conv layers produce node embeddings. For
    every edge the embeddings of its two endpoints are concatenated with the
    edge attributes and passed through a three-layer MLP that outputs one
    scalar per edge (the log-space target).
    """

    def __init__(self, in_node_dim, in_edge_dim):
        super().__init__()
        embed_dim = GAT_HIDDEN * GAT_HEADS  # heads are concatenated

        # The first layer consumes raw node features, the rest consume the
        # concatenated multi-head output of the previous layer.
        layer_inputs = [in_node_dim if i == 0 else embed_dim for i in range(GAT_LAYERS)]
        self.gats = nn.ModuleList([
            GATv2Conv(
                in_channels=width,
                out_channels=GAT_HIDDEN,
                heads=GAT_HEADS,
                dropout=GAT_DROPOUT,
                share_weights=True,
                edge_dim=in_edge_dim,
            )
            for width in layer_inputs
        ])
        node_out_dim = embed_dim if layer_inputs else in_node_dim

        half_hidden = MLP_HIDDEN // 2
        self.mlp = nn.Sequential(
            nn.Linear(2 * node_out_dim + in_edge_dim, MLP_HIDDEN),
            nn.ReLU(),
            nn.Dropout(MLP_DROPOUT),
            nn.Linear(MLP_HIDDEN, half_hidden),
            nn.ReLU(),
            nn.Dropout(MLP_DROPOUT),
            nn.Linear(half_hidden, 1),
        )

    def forward(self, x, edge_index, edge_attr):
        """Return one log-space prediction per column of ``edge_index``."""
        h = x
        for layer in self.gats:
            h = F.elu(layer(h, edge_index, edge_attr=edge_attr))
            h = F.dropout(h, p=GAT_DROPOUT, training=self.training)

        src, dst = edge_index
        edge_repr = torch.cat((h[src], h[dst], edge_attr), dim=1)
        return self.mlp(edge_repr).squeeze(1)  # log-target

model = EdgeRegressor(
    in_node_dim=data.x.size(1),
    in_edge_dim=data.edge_attr.size(1)
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Huber loss kept per-element so the sample weights below can be applied.
criterion = nn.SmoothL1Loss(reduction="none")

# ----------- Sample weights -----------
# Length weight: edge length relative to the mean length.
# df["length"] was standardised in place by `scaler`, so its mean is close to
# zero; dividing the z-scores by that mean would produce huge weights of
# arbitrary sign. Undo the scaling to recover the raw, non-negative lengths.
if "length" in edge_num_cols and hasattr(scaler, "mean_"):
    _len_col = edge_num_cols.index("length")
    _len_scale = scaler.scale_[_len_col] if scaler.scale_ is not None else 1.0
    len_raw = df["length"].to_numpy(dtype=np.float64) * _len_scale + scaler.mean_[_len_col]
elif "length" in df.columns:
    len_raw = df["length"].fillna(1).to_numpy(dtype=np.float64)
else:
    len_raw = np.ones(E, dtype=np.float64)
# Guard against NaN and against tiny negatives from the inverse transform.
len_raw = np.clip(np.nan_to_num(len_raw, nan=1.0), 0.0, None)
len_w = torch.tensor(len_raw, dtype=torch.float32, device=DEVICE)
len_w = len_w / (len_w.mean() + 1e-8)

# Signal weight: 1.5 for edges whose target exceeds the log-space equivalent
# of risk_score = 0.001, otherwise 1.0.
yt_log = data.y.view(-1)
sig_w = (1.0 + 0.5 * (yt_log[:E] > math.log1p(0.001*Y_SCALE))).float()

# High-risk weight: HI_MULT for edges at or above HI_THR, otherwise 1.0.
with torch.no_grad():
    y_orig_np = t_to_y(yt_log[:E].detach().cpu().numpy())
hi_mask_np = (y_orig_np >= HI_THR)
hi_w = np.where(hi_mask_np, HI_MULT, 1.0).astype(np.float32)
hi_w = torch.tensor(hi_w, device=DEVICE)

# Static per-edge weight, clamped to [W_MIN, W_MAX].
base_w = (LEN_W*len_w + SIG_W*sig_w + HI_W*hi_w).clamp_(W_MIN, W_MAX)

# =========================
# Eval helpers (raw, no calib during training)
# =========================
def eval_split(mask: torch.Tensor, hi_only=False):
    """Evaluate raw (uncalibrated) predictions on the labeled edges of a split.

    Args:
        mask: Boolean tensor of length E selecting the split.
        hi_only: If True, restrict to edges whose true score is >= HI_THR.

    Returns:
        The dict produced by ``metrics`` on the original risk scale, or a
        dict of NaNs when no edge is selected.
    """
    nan_result = dict.fromkeys(("R2", "RMSE", "MAE", "MAPE%"), float("nan"))

    selected = mask & data.labeled_mask
    if selected.sum() == 0:
        return nan_result

    model.eval()
    with torch.no_grad():
        log_pred = model(data.x, data.edge_index, data.edge_attr)[:E]

    pred = t_to_y(log_pred.detach().cpu().numpy())
    truth = t_to_y(data.y.detach().cpu().numpy().ravel()[:E])
    keep = selected.detach().cpu().numpy()
    pred, truth = pred[keep], truth[keep]

    if hi_only:
        high = truth >= HI_THR
        if high.sum() == 0:
            return nan_result
        pred, truth = pred[high], truth[high]

    return metrics(truth, pred)

def val_objective():
    """Early-stopping objective: RMSE_all + coef * RMSE_hi on validation.

    ``coef`` is 0.30, relaxed to 0.15 when fewer than 25 high-risk
    validation edges exist. With no high-risk edge at all, the overall RMSE
    stands in for the high-risk RMSE.
    """
    model.eval()
    with torch.no_grad():
        log_pred = model(data.x, data.edge_index, data.edge_attr)[:E].detach().cpu().numpy()

    pred_all = t_to_y(log_pred)
    true_all = t_to_y(data.y.detach().cpu().numpy().ravel()[:E])
    val_sel = (data.edge_val_mask & data.labeled_mask).cpu().numpy()
    pred, truth = pred_all[val_sel], true_all[val_sel]

    rmse_all = math.sqrt(mean_squared_error(truth, pred))

    high = truth >= HI_THR
    n_high = int(high.sum())
    if n_high > 0:
        rmse_hi = math.sqrt(mean_squared_error(truth[high], pred[high]))
    else:
        rmse_hi = rmse_all

    coef = 0.30 if n_high >= 25 else 0.15
    return rmse_all + coef * rmse_hi

# =========================
# 6) Train
# =========================
# Early-stopping state: best validation objective so far, a snapshot of the
# weights that achieved it, and the number of epochs since it last improved.
best_obj = float("inf")
best_state = None
pat = 0

for epoch in range(1, EPOCHS+1):
    model.train()
    optimizer.zero_grad()

    # Full-graph forward pass over all 2E directed edges; only the first E
    # (original direction) entries carry labels and masks.
    pred_t_all = model(data.x, data.edge_index, data.edge_attr)
    pred_t = pred_t_all[:E]
    mask = (data.edge_train_mask & data.labeled_mask)

    # Dynamic residual weighting: training edges with a larger absolute error
    # on the original scale get a larger weight. The residual is computed
    # without gradient, so the weights act as constants in the backward pass.
    pred_y_all_t = t_to_y_torch(pred_t)
    true_y_all_t = t_to_y_torch(data.y.view(-1)[:E])

    with torch.no_grad():
        resid = torch.zeros_like(pred_y_all_t)
        resid[mask] = (pred_y_all_t[mask] - true_y_all_t[mask]).abs()
    res_w = (1.0 + RES_ALPHA * (resid ** RES_BETA)).clamp_(*RES_CLIP_MINMAX)
    w = (base_w * res_w).clamp_(W_MIN, W_MAX)

    # Loss: weighted Huber in log space mixed with a quantile loss on the
    # original scale.
    loss_elem = criterion(pred_t[mask], data.y.view(-1)[:E][mask])
    q_loss = quantile_loss(true_y_all_t[mask], pred_y_all_t[mask], tau=QL_TAU)
    loss = (loss_elem * w[mask]).mean() * (1.0 - QL_LAMBDA) + q_loss * QL_LAMBDA

    loss.backward(); optimizer.step()

    if epoch % PRINT_EVERY == 1 or epoch == EPOCHS:
        tr_all = eval_split(data.edge_train_mask)
        va_all = eval_split(data.edge_val_mask)
        va_hi  = eval_split(data.edge_val_mask, hi_only=True)
        bar = progress_bar(epoch, EPOCHS, 40)
        print(f"{bar} | loss={loss.item():.6f} | "
              f"TRN R2={tr_all['R2']:.4f} MAE={tr_all['MAE']:.6f} RMSE={tr_all['RMSE']:.6f} | "
              f"VAL R2={va_all['R2']:.4f} MAE={va_all['MAE']:.6f} RMSE={va_all['RMSE']:.6f} | "
              f"VAL(hi) R2={va_hi['R2']:.4f} RMSE={va_hi['RMSE']:.6f}")

    obj = val_objective()
    if obj + 1e-9 < best_obj:
        best_obj = obj
        # state_dict() returns references to the live parameters and .cpu()
        # is a no-op for tensors already on the CPU, so clone() is required;
        # otherwise the snapshot would follow every later optimizer step.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        pat = 0
    else:
        pat += 1
        if pat >= PATIENCE:
            print(f"[Early Stop] epoch={epoch}, best Val Objective={best_obj:.6f}")
            break

# =========================
# 7) CV Calibration (+Smooth Blend) with safety switch
# =========================
# Restore the weights with the best validation objective, if any were saved.
if best_state is not None:
    model.load_state_dict({k: v.to(DEVICE) for k, v in best_state.items()})

# Switch to eval mode once; the calibration and final-evaluation code below
# runs the model without changing the mode again.
model.eval()
with torch.no_grad():
    pt_full = model(data.x, data.edge_index, data.edge_attr).detach().cpu().numpy().ravel()
pt_all = pt_full[:E]                              # log-space predictions, original edges
tt_all = data.y.detach().cpu().numpy().ravel()[:E]  # log-space targets, original edges

def _t2y(x):
    """Map log-space values to the original [0, 1] risk scale (no dtype cast)."""
    restored = np.expm1(x) / Y_SCALE
    return np.clip(restored, 0.0, 1.0)

# K-Fold on TRAIN edges to learn calibration
kf = KFold(n_splits=3, shuffle=True, random_state=SEED)
tr_mask_np = (data.edge_train_mask & data.labeled_mask).cpu().numpy()
idx_tr = np.where(tr_mask_np)[0]
pt_tr = pt_all[idx_tr]  # log-space predictions on training edges
tt_tr = tt_all[idx_tr]  # log-space targets on training edges

# Per-fold calibrators and their weights (inverse of the fold's held-out RMSE).
A_list, B_list, iso_list, w_list = [], [], [], []
for tr_idx, va_idx in kf.split(pt_tr):
    pt_tr_tr, pt_tr_va = pt_tr[tr_idx], pt_tr[va_idx]
    tt_tr_tr, tt_tr_va = tt_tr[tr_idx], tt_tr[va_idx]

    # Low branch: linear map in log space, fitted on edges whose true score
    # is below HI_THR; identity when fewer than two such edges exist.
    low_mask = _t2y(tt_tr_tr) < HI_THR
    if low_mask.sum() >= 2:
        lin = LinearRegression().fit(pt_tr_tr[low_mask].reshape(-1,1), tt_tr_tr[low_mask])
        A_cv, B_cv = float(lin.coef_[0]), float(lin.intercept_)
    else:
        A_cv, B_cv = 1.0, 0.0

    # High branch: isotonic regression on the original scale, fitted on the
    # remaining edges; identity on [0, 1] when fewer than two exist.
    iso_cv = IsotonicRegression(out_of_bounds='clip')
    high_mask = ~low_mask
    if high_mask.sum() >= 2:
        iso_cv.fit(_t2y(pt_tr_tr[high_mask]), _t2y(tt_tr_tr[high_mask]))
    else:
        iso_cv.fit(np.array([0.0,1.0]), np.array([0.0,1.0]))

    # local piecewise + smooth blend
    # Uses this fold's A_cv / B_cv / iso_cv; it is called right below, inside
    # the same iteration.
    def _apply_piecewise_blend_local(pred_t, k=BLEND_K):
        y_low  = _t2y(A_cv*pred_t + B_cv)
        y_high = iso_cv.transform(_t2y(pred_t))
        s = 1.0 / (1.0 + np.exp(-k*(y_low - HI_THR)))  # sigmoid blend
        return (1.0 - s)*y_low + s*y_high

    # Score the fold's calibrator on its held-out part.
    yp_va = _apply_piecewise_blend_local(pt_tr_va)
    rmse_va = math.sqrt(mean_squared_error(_t2y(tt_tr_va), yp_va))
    weight = 1.0 / (rmse_va + 1e-9)

    A_list.append(A_cv); B_list.append(B_cv); iso_list.append(iso_cv); w_list.append(weight)

# Final calibrator: weighted average of the linear coefficients across folds,
# and the isotonic model of the single fold with the largest weight.
A_low = float(np.average(A_list, weights=w_list)) if len(w_list) else 1.0
B_low = float(np.average(B_list, weights=w_list)) if len(w_list) else 0.0
iso   = iso_list[int(np.argmax(w_list))] if len(w_list) else IsotonicRegression(out_of_bounds='clip').fit([0,1],[0,1])

def apply_piecewise_blend(pred_t_vec: np.ndarray, k: float = BLEND_K) -> np.ndarray:
    """Calibrate log-space predictions with the blended piecewise model.

    The low branch applies the linear map (``A_low``, ``B_low``) in log space
    and converts to the original scale; the high branch applies the isotonic
    model ``iso`` on the original scale. A sigmoid of steepness ``k``, centred
    at ``HI_THR`` on the low-branch output, mixes the two.

    Returns:
        Calibrated risk scores on the original scale.
    """
    low_branch = _t2y(A_low * pred_t_vec + B_low)
    high_branch = iso.transform(_t2y(pred_t_vec))
    gate = 1.0 / (1.0 + np.exp(-k * (low_branch - HI_THR)))  # in 0~1
    return (1.0 - gate) * low_branch + gate * high_branch

# ---- Compare calibrated vs raw on validation, fallback if needed ----
def _eval_rmse_on_val(predict_fn):
    """Return the validation RMSE (original scale) of ``predict_fn``.

    ``predict_fn`` receives the model's log-space predictions for the E
    original edges and must return risk scores on the original scale.
    """
    with torch.no_grad():
        model_out = model(data.x, data.edge_index, data.edge_attr)
    log_pred = model_out.detach().cpu().numpy().ravel()[:E]

    pred = predict_fn(log_pred)
    truth = t_to_y(data.y.detach().cpu().numpy().ravel()[:E])
    val_sel = (data.edge_val_mask & data.labeled_mask).cpu().numpy()

    return math.sqrt(mean_squared_error(truth[val_sel], pred[val_sel]))

# Safety switch: calibration is kept only if it does not degrade validation
# RMSE. The previous threshold (raw * 1.001) enabled calibration even when it
# was up to 0.1% worse than raw, contradicting the messages printed below.
rmse_val_raw  = _eval_rmse_on_val(t_to_y)
rmse_val_cal  = _eval_rmse_on_val(lambda pt: apply_piecewise_blend(pt, k=BLEND_K))
USE_CAL = rmse_val_cal <= rmse_val_raw

if not USE_CAL:
    print(f"[CALIB] Disabled (raw better): VAL_RMSE_RAW={rmse_val_raw:.6f} < VAL_RMSE_CAL={rmse_val_cal:.6f}")
else:
    print(f"[CALIB] Enabled: VAL_RMSE_CAL={rmse_val_cal:.6f} <= RAW={rmse_val_raw:.6f}")

def _predict_final(pt):
    """Convert log-space predictions to final risk scores.

    Uses the calibrated piecewise blend when ``USE_CAL`` is set, otherwise
    the plain inverse target transform.
    """
    return apply_piecewise_blend(pt, k=BLEND_K) if USE_CAL else t_to_y(pt)

# =========================
# 8) Final eval (calib switch applied) & save
# =========================
def evaluate_final(split_mask, hi_only=False):
    """Evaluate the final predictor on the labeled edges of a split.

    Args:
        split_mask: Boolean tensor of length E selecting the split.
        hi_only: If True, restrict to edges whose true score is >= HI_THR.

    Returns:
        The dict produced by ``metrics`` on the original risk scale, or a
        dict of NaNs when no edge is selected.
    """
    nan_result = {name: float("nan") for name in ("R2", "RMSE", "MAE", "MAPE%")}

    chosen = (split_mask & data.labeled_mask).cpu().numpy()
    if not chosen.any():
        return nan_result

    with torch.no_grad():
        model_out = model(data.x, data.edge_index, data.edge_attr)
    log_pred = model_out.detach().cpu().numpy().ravel()[:E]

    pred = _predict_final(log_pred)[chosen]
    truth = t_to_y(data.y.detach().cpu().numpy().ravel()[:E])[chosen]

    if hi_only:
        high = truth >= HI_THR
        if not high.any():
            return nan_result
        pred, truth = pred[high], truth[high]

    return metrics(truth, pred)

# Final metrics per split, using the calibrated predictor when it is enabled.
tr   = evaluate_final(data.edge_train_mask)
va   = evaluate_final(data.edge_val_mask)
te   = evaluate_final(data.edge_test_mask)
va_h = evaluate_final(data.edge_val_mask, hi_only=True)
te_h = evaluate_final(data.edge_test_mask, hi_only=True)

print("\n== Final Metrics (original scale, calibrated if helpful) ==")
print(f"Train: R2={tr['R2']:.4f} | RMSE={tr['RMSE']:.6f} | MAE={tr['MAE']:.6f}")
print(f"Valid: R2={va['R2']:.4f} | RMSE={va['RMSE']:.6f} | MAE={va['MAE']:.6f}")
print(f" Test: R2={te['R2']:.4f} | RMSE={te['RMSE']:.6f} | MAE={te['MAE']:.6f}")
print(f"[High-risk ≥ {HI_THR:.6f}]  Valid: R2={va_h['R2']:.4f} RMSE={va_h['RMSE']:.6f} MAE={va_h['MAE']:.6f} | "
      f"Test: R2={te_h['R2']:.4f} RMSE={te_h['RMSE']:.6f} MAE={te_h['MAE']:.6f}")

# Save predictions for every original edge, labeled or not.
with torch.no_grad():
    pred_all_t_full = model(data.x, data.edge_index, data.edge_attr).detach().cpu().numpy().ravel()
    pred_all_t = pred_all_t_full[:E]
    pred_all   = _predict_final(pred_all_t)

out_df = df[["u","v"]].copy()
out_df["risk_score_true"] = df["risk_score"].values
out_df["risk_score_pred"] = pred_all
save_pred = os.path.join(OUT_DIR, "gat_edge_predictions_fullgraph_v4.csv")
out_df.to_csv(save_pred, index=False, encoding="utf-8-sig")
print(f"[SAVE] {save_pred}")

# Save the model checkpoint.
# The isotonic high branch is stored as its fitted breakpoints so the
# calibrated predictor can be rebuilt from the checkpoint; None when the
# installed scikit-learn does not expose these attributes.
_iso_x = getattr(iso, "X_thresholds_", None)
_iso_y = getattr(iso, "y_thresholds_", None)

ckpt_path = os.path.join(OUT_DIR, "gat_edge_regressor_fullgraph_v4.pt")
torch.save({
    "state_dict": model.state_dict(),
    "config": {
        "EPOCHS": EPOCHS, "H": GAT_HIDDEN, "HEADS": GAT_HEADS,
        "LAYERS": GAT_LAYERS, "DROPOUT": GAT_DROPOUT,
        "MLP_HIDDEN": MLP_HIDDEN, "MLP_DROPOUT": MLP_DROPOUT,
        "Y_SCALE": Y_SCALE,
        "HI_THR": HI_THR, "HI_PCTILE": HI_PCTILE, "HI_MULT": HI_MULT,
        "RES_ALPHA": RES_ALPHA, "RES_BETA": RES_BETA,
        "QL_TAU": QL_TAU, "QL_LAMBDA": QL_LAMBDA,
        "CALIB": "CV piecewise + smooth blend",
        "BLEND_K": BLEND_K,
        "USE_CAL": bool(USE_CAL),
    },
    "edge_feat_cols": edge_feat_cols,
    # scaler_mean_ / scaler_scale_ are aligned with edge_num_cols, not with
    # the full edge_feat_cols list.
    "edge_num_cols": list(edge_num_cols),
    "node_feat_dim": int(data.x.size(1)),
    "edge_feat_dim": int(data.edge_attr.size(1)),
    "scaler_mean_": scaler.mean_.tolist() if hasattr(scaler, "mean_") else None,
    "scaler_scale_": scaler.scale_.tolist() if hasattr(scaler, "scale_") else None,
    "lin_low": {"A": float(A_low), "B": float(B_low)},
    "iso_high": {
        "x": np.asarray(_iso_x, dtype=float).tolist() if _iso_x is not None else None,
        "y": np.asarray(_iso_y, dtype=float).tolist() if _iso_y is not None else None,
    },
}, ckpt_path)
print(f"[SAVE] {ckpt_path}")
