In [None]:
import os, math, json, random, warnings, gc, time, re
from pathlib import Path
import numpy as np
import pandas as pd
from contextlib import contextmanager
from typing import Dict, Any
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from sklearn.metrics import (accuracy_score, balanced_accuracy_score, confusion_matrix,
                             f1_score, precision_score, recall_score, roc_auc_score,
                             classification_report)
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict

# Silence every Python warning for the rest of the run and keep stdout unbuffered.
warnings.filterwarnings("ignore")
os.environ["PYTHONUNBUFFERED"] = "1"

# ---- Paths and prompt ----
MODEL_PATH = r"G:/models/Qwen3-0.6B"   # local model directory (loaded with local_files_only=True)
DATA_PATH = "data.csv"                 # input table read by read_csv_any
# Instruction prepended to every sample's prompt text.
SYSTEM_PROMPT = "你是一个传染病预测专家。请判断下一个月相对于本月的艾滋病发病数是否出现“上涨≥10%”。"
# Output directory; created at import time if it does not exist.
RES_DIR = Path("results"); RES_DIR.mkdir(parents=True, exist_ok=True)

# ---- Task definition ----
TH_UP_PCT = 0.10                       # relative month-over-month increase that defines the positive class
CLS2TEXT = {0:"未上涨10%", 1:"上涨≥10%"}  # class id -> label text used in the exported predictions
NUM_CLASSES = 2                        # not referenced elsewhere in the visible part of this file
ENSEMBLE_SEEDS = [3407]                # only the first seed is passed to train_one_seed below
FAST_DEBUG = False                     # True: subsample the splits and train a single epoch

# ---- Optimisation hyper-parameters ----
BATCH_SIZE = 16
GRAD_ACCUM_STEPS = 1
# Tokenizer truncation length.
# NOTE(review): prompts longer than this are truncated by the tokenizer call; confirm 64 tokens covers the full prompt.
MAX_LENGTH = 64
LR = 1e-4
WEIGHT_DECAY = 0.01
DROPOUT = 0.30
EPOCHS_TRAIN = 20
GRAD_CLIP = 1.0                        # max gradient norm
WARMUP_RATIO = 0.10                    # fraction of scheduler steps used for linear warm-up
EARLY_PATIENCE = 9                     # epochs without validation improvement before early stopping

FIXED_THR_PROB = 0.40                  # fixed threshold applied to calibrated probabilities

# ---- Loss weighting ----
USE_SAMPLE_WEIGHTS = False             # True: multiply per-sample loss by an inverse-frequency class weight
POS_WEIGHT_MODE = "sqrt"               # "sqrt" | "log" | anything else = raw neg/pos ratio (see compute_pos_weight)
FOCAL_ALPHA_POS = 0.5
FOCAL_ALPHA_NEG = 0.5
FOCAL_GAMMA = 0.0                      # 0 disables the focal modulation term; only the alpha weighting remains

CALIBRATOR = "platt"                   # "platt" | "isotonic" | anything else = temperature scaling

# Only the first two ratios are read; the test split is whatever months remain.
SPLIT_RATIOS = (0.7, 0.2, 0.1)

# Device selection: first CUDA device when available, otherwise CPU; AMP is enabled only on CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
USE_AMP = torch.cuda.is_available()
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Best effort: skipped silently if the call raises (e.g. the installed torch lacks it).
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass
ATTN_IMPL = "sdpa"                     # forwarded as attn_implementation when loading the backbone

def set_seed(s=3407):
    """Seed the Python, NumPy and PyTorch random generators with ``s``.

    When CUDA is available, every CUDA device generator is seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)

def eval_metrics_binary(y_true, prob_pos, y_pred, *, lite=False):
    """Compute binary-classification metrics for hard predictions.

    Args:
        y_true: Ground-truth labels (cast to int).
        prob_pos: Positive-class scores; only used for ROC-AUC when ``lite`` is False.
        y_pred: Hard predictions (cast to int).
        lite: When True, skip ROC-AUC and the per-class report.

    Returns:
        Dict with accuracy, macro/weighted/binary F1, balanced accuracy,
        precision, recall, specificity and the 2x2 confusion matrix (as a
        nested list). Unless ``lite``, it also holds ``roc_auc`` (``None`` when
        the score cannot be computed) and ``report``.
    """
    truth = np.asarray(y_true).astype(int)
    pred = np.asarray(y_pred).astype(int)

    # labels=[0, 1] forces a 2x2 matrix even when one class is absent.
    cm = confusion_matrix(truth, pred, labels=[0,1])
    if cm.size == 4:
        tn, fp, fn, tp = cm.ravel()
    else:
        tn, fp, fn, tp = 0, 0, 0, 0
    n_negative = tn + fp
    specificity = tn/n_negative if n_negative > 0 else 0.0

    summary = {
        "accuracy": float(accuracy_score(truth, pred)),
        "macro_f1": float(f1_score(truth, pred, average="macro", zero_division=0)),
        "weighted_f1": float(f1_score(truth, pred, average="weighted", zero_division=0)),
        "binary_f1": float(f1_score(truth, pred, average="binary", zero_division=0)),
        "balanced_accuracy": float(balanced_accuracy_score(truth, pred)),
        "precision": float(precision_score(truth, pred, zero_division=0)),
        "recall": float(recall_score(truth, pred, zero_division=0)),
        "specificity": float(specificity),
        "confusion_matrix": cm.tolist(),
    }
    if lite:
        return summary

    try:
        summary["roc_auc"] = float(roc_auc_score(truth, prob_pos))
    except Exception:
        # Any failure to compute the AUC is reported as a missing value.
        summary["roc_auc"] = None
    summary["report"] = classification_report(truth, pred, output_dict=True, zero_division=0)
    return summary

def metrics_at_thr(y_true, prob_pos, thr, *, lite=True):
    """Binarize ``prob_pos`` at ``thr`` (inclusive) and score the result.

    Returns:
        ``(metrics, y_pred)`` where ``metrics`` comes from
        ``eval_metrics_binary`` and ``y_pred`` is the 0/1 prediction array.
    """
    cutoff = float(thr)
    y_pred = (np.asarray(prob_pos) >= cutoff).astype(int)
    metrics = eval_metrics_binary(y_true, prob_pos, y_pred, lite=lite)
    return metrics, y_pred

def build_calibrator_from_state(state: Dict[str, Any]):
    """Rebuild a score -> probability callable from a serialized calibrator state.

    Args:
        state: Dict with an optional ``"kind"`` (defaults to ``"temperature"``):
            ``"temperature"`` needs ``"T"``, ``"platt"`` needs ``"a"`` and
            ``"b"``, and any other kind is treated as a piecewise-linear map
            given by ``"xs"`` / ``"ys"``.

    Returns:
        A function mapping an array-like of scores to a flat NumPy array of
        probabilities clipped to ``[1e-6, 1 - 1e-6]``.
    """
    def _clip(p):
        return np.clip(p, 1e-6, 1-1e-6)

    def _flat(s):
        return np.asarray(s).reshape(-1)

    def _sigmoid(z):
        return 1/(1+np.exp(-z))

    kind = state.get("kind", "temperature")

    if kind == "temperature":
        temperature = float(state["T"])

        def calibrate(s):
            return _clip(_sigmoid(_flat(s)/float(temperature)))
        return calibrate

    if kind == "platt":
        slope = float(state["a"])
        offset = float(state["b"])

        def calibrate(s):
            return _clip(_sigmoid(slope*_flat(s) + offset))
        return calibrate

    # Any other kind: interpolate between stored knots, flat beyond the ends.
    knots_x = np.asarray(state["xs"], dtype=float)
    knots_y = np.asarray(state["ys"], dtype=float)

    def calibrate(s):
        interpolated = np.interp(_flat(s), knots_x, knots_y, left=knots_y[0], right=knots_y[-1])
        return _clip(interpolated)
    return calibrate

def calibrator_fit(scores, y_true, kind="temperature"):
    """Fit a probability calibrator on raw scores and return it with its state.

    Args:
        scores: Raw model scores (flattened).
        y_true: Binary labels (cast to int).
        kind: ``"isotonic"``, ``"platt"``, or anything else for a grid-searched
            temperature in ``[0.5, 3.0]``.

    Returns:
        ``(calibrate, state)``: the callable built by
        ``build_calibrator_from_state`` and the serializable state dict.
    """
    raw = np.asarray(scores).reshape(-1)
    labels = np.asarray(y_true).astype(int)

    if kind == "isotonic":
        from sklearn.isotonic import IsotonicRegression
        fitted = IsotonicRegression(out_of_bounds="clip").fit(raw, labels)
        # np.unique already returns the distinct scores in ascending order.
        knots = np.unique(raw)
        state = {"kind":"isotonic","xs": knots.tolist(), "ys": fitted.predict(knots).tolist()}
    elif kind == "platt":
        clf = LogisticRegression(C=1.0, solver="lbfgs").fit(raw.reshape(-1,1), labels)
        slope = float(clf.coef_[0,0])
        offset = float(clf.intercept_[0])
        state = {"kind":"platt","a": slope, "b": offset}
    else:
        eps = 1e-7
        candidates = np.linspace(0.5, 3.0, 51)

        def _nll(T):
            # Mean binary negative log-likelihood at temperature T.
            z = raw/float(T)
            p = 1/(1+np.exp(-z))
            return -np.mean(labels*np.log(p+eps)+(1-labels)*np.log(1-p+eps))

        losses = [_nll(T) for T in candidates]
        state = {"kind":"temperature","T": float(candidates[np.argmin(losses)])}

    return build_calibrator_from_state(state), state

def prob_thr_to_score_thr_prob_cal(state: Dict[str, Any], prob_thr: float) -> float:
    """Invert a parametric calibrator: probability threshold -> raw-score threshold.

    Args:
        state: Calibrator state; ``"kind"`` defaults to ``"temperature"``.
        prob_thr: Probability threshold, expected strictly inside ``(0, 1)``.

    Returns:
        The raw score whose calibrated probability equals ``prob_thr`` for
        temperature and Platt calibrators. ``nan`` when no closed form exists:
        any other calibrator kind, or a Platt slope of exactly 0 (the
        calibrated probability is then constant, which previously raised
        ``ZeroDivisionError``).

    Note:
        With a negative Platt slope the calibrated probability decreases with
        the score, so ``score >= returned value`` is NOT equivalent to
        ``probability >= prob_thr``.
    """
    kind = state.get("kind", "temperature")
    if kind == "temperature":
        temperature = float(state["T"])
        logit = math.log(prob_thr/(1-prob_thr))
        return float(temperature*logit)
    if kind == "platt":
        slope = float(state["a"])
        offset = float(state["b"])
        if slope == 0.0:
            # Degenerate fit: no score maps to the requested probability.
            return float("nan")
        logit = math.log(prob_thr/(1-prob_thr))
        return float((logit - offset)/slope)
    return float("nan")

def prob_thr_to_score_thr(scores, calibrate_func, prob_thr: float, state: Dict[str, Any]) -> float:
    """Map a probability threshold to a raw-score threshold.

    The closed-form inverse from ``prob_thr_to_score_thr_prob_cal`` is used
    when it is finite. Otherwise the threshold is located numerically on the
    distinct observed ``scores``: their calibrated probabilities are searched
    for ``prob_thr`` and the score is linearly interpolated between the two
    neighbouring points.

    Returns:
        The score threshold. If ``prob_thr`` lies outside the observed
        probability range, a value ``1e-6`` below the smallest or above the
        largest score is returned.
    """
    closed_form = prob_thr_to_score_thr_prob_cal(state, prob_thr)
    if np.isfinite(closed_form):
        return closed_form

    grid = np.sort(np.unique(np.asarray(scores).reshape(-1)))
    grid_prob = calibrate_func(grid)
    pos = np.searchsorted(grid_prob, prob_thr, side="left")
    if pos <= 0:
        return float(grid[0]) - 1e-6
    if pos >= len(grid):
        return float(grid[-1]) + 1e-6

    lo_score, hi_score = grid[pos-1], grid[pos]
    lo_prob = grid_prob[pos-1]
    hi_prob = grid_prob[min(pos, len(grid_prob)-1)]
    if hi_prob == lo_prob:
        return float(hi_score)
    frac = (prob_thr - lo_prob) / (hi_prob - lo_prob)
    return float(lo_score + frac*(hi_score - lo_score))

def read_csv_any(path: str) -> pd.DataFrame:
    """Read a CSV file, trying several text encodings in turn.

    Encodings are attempted in the order ``utf-8-sig``, ``gb18030``, ``utf-8``;
    the first successful parse is returned.

    Raises:
        RuntimeError: If every attempt fails. The last underlying error is
            attached as ``__cause__`` so the real reason (missing file, parse
            error, ...) is not lost.
    """
    last_error = None
    for enc in ("utf-8-sig", "gb18030", "utf-8"):
        try:
            return pd.read_csv(path, encoding=enc)
        except Exception as err:
            # Not a bare ``except``: KeyboardInterrupt/SystemExit must propagate.
            last_error = err
    raise RuntimeError(f"无法读取 {path}") from last_error

def ym_cn(p):
    """Format a ``YYYY-MM`` value as Chinese year/month text, e.g. ``2011年3月``."""
    year_txt, month_txt = str(p).split("-")
    return "{}年{}月".format(int(year_txt), int(month_txt))

def logratio(a, b):
    """Return ``log(1 + a) - log(1 + b)``."""
    log_a = math.log1p(a)
    log_b = math.log1p(b)
    return log_a - log_b

# Load the raw table and normalise it into a per-region monthly series.
raw = read_csv_any(DATA_PATH)

# Match the expected headers after removing all whitespace from the column
# names; fall back to the literal name when no stripped header matches.
cols = {c: re.sub(r"\s+", "", str(c)) for c in raw.columns}
rev = {v:k for k,v in cols.items()}
month_col = rev.get("月份", "月份")
region_col = rev.get("地区", "地区")
cases_col = rev.get("发病数", "发病数")
index_col = rev.get("全部", "全部")

raw = raw.rename(columns={month_col:"月份", region_col:"地区", cases_col:"发病数", index_col:"全部"})

# Keep the first 7 characters of the month string and parse it as a monthly period.
raw["月份"] = raw["月份"].astype(str).str.slice(0,7)
raw["period"] = pd.PeriodIndex(raw["月份"], freq="M")
raw["地区_std"] = raw["地区"].astype(str).str.strip()
# Unparseable numbers become 0 rather than NaN.
raw["发病数"] = pd.to_numeric(raw["发病数"], errors="coerce").fillna(0).astype(int)
raw["全部"] = pd.to_numeric(raw["全部"], errors="coerce").fillna(0.0)

# Restrict to 2011-01 .. 2020-12 and order each region chronologically.
raw = raw[(raw["period"]>=pd.Period("2011-01","M")) & (raw["period"]<=pd.Period("2020-12","M"))].copy()
raw = raw.sort_values(["地区_std","period"]).reset_index(drop=True)

# Lag features taken 12 ROWS earlier within each region.
# NOTE(review): shift(12) is row-based, so this equals a 12-month lag only when
# a region has no missing or duplicated months — confirm against the data.
raw["CASE_lag12"] = raw.groupby("地区_std")["发病数"].shift(12)
raw["IDX_lag12"] = raw.groupby("地区_std")["全部"].shift(12)

# Regions ordered by number of distinct months (descending); the position in
# this list is the integer region id fed to the embedding layer.
reg_months = raw.groupby("地区_std")["period"].nunique().sort_values(ascending=False)
REGIONS = reg_months.index.tolist()
REG2ID = {rg:i for i,rg in enumerate(REGIONS)}

# Build one sample per window of four consecutive months (t-2, t-1, t, t+1)
# within a region: features come from the first three months, the label from
# the change between month t and month t+1.
rows=[]
for rg, g in raw.groupby("地区_std", sort=False):
    g = g.sort_values("period").reset_index(drop=True)
    for i in range(len(g)-3):
        t2,t1,t0,tP1 = g.iloc[i], g.iloc[i+1], g.iloc[i+2], g.iloc[i+3]
        # Skip windows whose four rows are not strictly consecutive months.
        if (pd.Period(t2["period"],"M")+1) != pd.Period(t1["period"],"M"): continue
        if (pd.Period(t1["period"],"M")+1) != pd.Period(t0["period"],"M"): continue
        if (pd.Period(t0["period"],"M")+1) != pd.Period(tP1["period"],"M"): continue

        # Label: 1 when next month's count rises by at least TH_UP_PCT relative
        # to the current month; the 1e-6 floor avoids dividing by zero.
        y_t, y_tp1 = int(t0["发病数"]), int(tP1["发病数"])
        rel = (y_tp1 - y_t) / max(y_t, 1e-6)
        y = 1 if rel >= TH_UP_PCT else 0

        # Cyclical encoding of the feature month (last two characters of the period string).
        month = int(str(t0["period"])[-2:])
        sinm, cosm = np.sin(2*np.pi*month/12.0), np.cos(2*np.pi*month/12.0)

        # Missing lag values are replaced by 0.
        lag12_y = float(t0["CASE_lag12"]) if pd.notna(t0["CASE_lag12"]) else 0.0
        lag12_idx = float(t0["IDX_lag12"]) if pd.notna(t0["IDX_lag12"]) else 0.0

        # Year-over-year log ratios; 0 when the lagged value is not positive.
        yoy_case = (math.log1p(float(t0["发病数"])) - math.log1p(lag12_y)) if lag12_y>0 else 0.0
        yoy_idx = (math.log1p(float(t0["全部"])) - math.log1p(lag12_idx)) if lag12_idx>0 else 0.0

        row = {
            "地区_std":rg,"date_feature":str(t0["period"]),"date_target":str(tP1["period"]),
            "t_month":month,"y_cls":y,
            "x_case_t":float(t0["发病数"]),
            "x_case_t_prev":float(t1["发病数"]),
            "x_case_t_prev2":float(t2["发病数"]),
            "x_idx_t":float(t0["全部"]),
            "x_idx_t_prev":float(t1["全部"]),
            "x_idx_t_prev2":float(t2["全部"]),
            "logratio_case_10":logratio(float(t0["发病数"]),float(t1["发病数"])),
            "logratio_case_21":logratio(float(t1["发病数"]),float(t2["发病数"])),
            "logratio_idx_10":logratio(float(t0["全部"]),float(t1["全部"])),
            "logratio_idx_21":logratio(float(t1["全部"]),float(t2["全部"])),
            "y_t_lag12":float(lag12_y),
            "x_idx_t_lag12":float(lag12_idx),
            "yoy_logratio_case":float(yoy_case),
            "yoy_logratio_idx":float(yoy_idx),
            "month_sin":float(sinm),
            "month_cos":float(cosm),
        }
        # Natural-language prompt for the language-model branch: system prompt,
        # then the three months of case counts and search-index values.
        row["input_text"] = (
            f"System: {SYSTEM_PROMPT}\n"
            f"User: Question: （地区={rg}）已知{ym_cn(t2['period'])}的艾滋病发病数为{int(t2['发病数'])}、搜索指数为{int(round(t2['全部']))}；"
            f"{ym_cn(t1['period'])}的艾滋病发病数为{int(t1['发病数'])}、搜索指数为{int(round(t1['全部']))}；"
            f"{ym_cn(t0['period'])}的艾滋病发病数为{int(t0['发病数'])}、搜索指数为{int(round(t0['全部']))}。"
            f"请判断下一月相对于{ym_cn(t0['period'])}是否出现“上涨≥10%”。Answer:"
        )
        rows.append(row)

# All samples, ordered by feature month and then region; abort if none were produced.
data = pd.DataFrame(rows).sort_values(["date_feature","地区_std"]).reset_index(drop=True)
assert not data.empty, "没有可用样本"

# Numeric feature columns fed to the tabular branch of the model.
FEAT_COLS = [
 "x_case_t","x_case_t_prev","x_case_t_prev2",
 "x_idx_t","x_idx_t_prev","x_idx_t_prev2",
 "logratio_case_10","logratio_case_21",
 "logratio_idx_10","logratio_idx_21",
 "y_t_lag12","x_idx_t_lag12",
 "yoy_logratio_case","yoy_logratio_idx",
 "month_sin","month_cos"
]

# Chronological split on the TARGET month: earliest months train, next block
# validates, the remainder tests. "YYYY-MM" strings sort chronologically.
months = sorted(data["date_target"].unique().tolist())
M = len(months); assert M>=3
m_train = int(math.floor(M*SPLIT_RATIOS[0]))
m_val = int(math.floor(M*SPLIT_RATIOS[1]))
# Shrink the validation block if train+val would leave no month for testing.
if m_train+m_val>=M:
    m_val = max(1, M-m_train-1)

train_months = months[:m_train]
val_months   = months[m_train:m_train+m_val]
test_months  = months[m_train+m_val:]

df_tr = data[data["date_target"].isin(train_months)].reset_index(drop=True)
df_vl = data[data["date_target"].isin(val_months )].reset_index(drop=True)
df_te = data[data["date_target"].isin(test_months )].reset_index(drop=True)

def fit_regionwise_stats(train_df: pd.DataFrame, cols, group_col="地区_std"):
    """Compute per-group and global mean/std for the given columns.

    Args:
        train_df: Frame the statistics are estimated from.
        cols: Column names to summarise.
        group_col: Column holding the group key.

    Returns:
        ``(stats, gmu, gsd)`` where ``stats[group][col] == (mean, std)``,
        ``gmu[col]`` is the global mean and ``gsd[col]`` the global std.
        A std that is zero or not finite is replaced by ``1.0``, so the result
        is always safe to divide by. A group with a single row has a sample
        std of NaN, which ``float(nan) or 1.0`` did not catch.
    """
    def _safe_std(series):
        sd = float(series.std())
        return sd if (math.isfinite(sd) and sd != 0.0) else 1.0

    stats = {
        rg: {c: (float(g[c].mean()), _safe_std(g[c])) for c in cols}
        for rg, g in train_df.groupby(group_col)
    }
    gmu = {c: float(train_df[c].mean()) for c in cols}
    gsd = {c: _safe_std(train_df[c]) for c in cols}
    return stats, gmu, gsd

def apply_regionwise_standardize(df: pd.DataFrame, cols, stats, gmu, gsd, group_col="地区_std"):
    """Standardise ``cols`` with per-group statistics, returning a copy of ``df``.

    Args:
        df: Frame to transform; it is not modified.
        cols: Columns to standardise.
        stats: ``stats[group][col] == (mean, std)`` as produced by
            ``fit_regionwise_stats``.
        gmu, gsd: Global per-column mean / std, used for groups or columns
            missing from ``stats``.
        group_col: Column holding the group key (compared as ``str``).

    Returns:
        A copy of ``df`` where each listed column holds ``(value - mean) / std``.
        A std that is zero or not finite is treated as ``1.0`` so a degenerate
        statistic cannot turn the features into NaN/inf.
    """
    out = df.copy()
    group_keys = out[group_col].astype(str).values
    n_rows = len(group_keys)
    for c in cols:
        fallback = (gmu[c], gsd[c])
        values = out[c].astype(float).values
        mus = np.empty(n_rows, dtype=float)
        sds = np.empty(n_rows, dtype=float)
        for i, rg in enumerate(group_keys):
            mus[i], sds[i] = stats.get(rg, {}).get(c, fallback)
        # Guard the divisor once per column instead of once per row.
        sds[~np.isfinite(sds) | (sds == 0)] = 1.0
        out[c] = (values - mus) / sds
    return out

# Region-wise standardisation: statistics come from the training split only
# and are applied unchanged to the validation and test splits.
stats_rg, gmu, gsd = fit_regionwise_stats(df_tr, FEAT_COLS)
trN = apply_regionwise_standardize(df_tr, FEAT_COLS, stats_rg, gmu, gsd)
vlN = apply_regionwise_standardize(df_vl, FEAT_COLS, stats_rg, gmu, gsd)
teN = apply_regionwise_standardize(df_te, FEAT_COLS, stats_rg, gmu, gsd)

# Second, global scaling pass fitted on the already region-standardised training features.
x_scaler = StandardScaler().fit(trN[FEAT_COLS].astype("float32").values)
def scale_X(df):
    """Return the FEAT_COLS of ``df`` transformed by ``x_scaler`` as a float32 array."""
    return x_scaler.transform(df[FEAT_COLS].astype("float32").values).astype("float32")

# =============== Dataset/Loader ===============
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True, local_files_only=True)
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

class TxtBinDataset(Dataset):
    """Dataset pairing tokenized prompts with numeric features, label and region id.

    All prompts are tokenized once in the constructor (truncated to
    ``MAX_LENGTH``, padded to the longest prompt, rounded up to a multiple of
    8), so ``__getitem__`` only slices tensors that already exist.
    """

    def __init__(self, df, feats_scaled, tokenizer):
        """Store the per-sample arrays and pre-tokenize ``df["input_text"]``.

        Args:
            df: Frame with ``input_text``, ``y_cls`` and ``地区_std`` columns.
            feats_scaled: Array of scaled features aligned with ``df`` rows.
            tokenizer: Callable tokenizer returning a dict of token id lists.
        """
        self.texts = df["input_text"].tolist()
        self.feats = feats_scaled.astype("float32")
        self.y = df["y_cls"].astype("int64").values
        region_names = df["地区_std"].astype(str).tolist()
        self.region_id = np.array([REG2ID[name] for name in region_names], dtype=np.int64)
        # Row-aligned copy of the source frame, kept for exporting predictions.
        self.meta = df.reset_index(drop=True)
        encoded = tokenizer(
            self.texts,
            truncation=True,
            padding=True,
            max_length=MAX_LENGTH,
            pad_to_multiple_of=8,
        )
        self.enc = {
            field: torch.tensor(values, dtype=torch.long)
            for field, values in encoded.items()
        }

    def __len__(self):
        return len(self.y)

    def __getitem__(self, i):
        """Return sample ``i`` as a dict of tensors (token fields, feats, y, region_id)."""
        item = {field: tensor[i] for field, tensor in self.enc.items()}
        item["feats"] = torch.tensor(self.feats[i], dtype=torch.float32)
        item["y"] = torch.tensor(self.y[i], dtype=torch.long)
        item["region_id"] = torch.tensor(self.region_id[i], dtype=torch.long)
        return item

def collate_fn(b):
    """Stack a list of per-sample dicts into one batch dict of tensors.

    Only the five fields consumed by the model and loss are kept; any other
    key present in the samples is dropped.
    """
    fields = ("input_ids", "attention_mask", "feats", "y", "region_id")
    return {name: torch.stack([sample[name] for sample in b]) for name in fields}

def make_loader(df, feats_scaled, tokenizer, batch_size, shuffle):
    """Build a ``TxtBinDataset`` and a single-process ``DataLoader`` over it.

    Returns:
        ``(dataset, loader)``.
    """
    dataset = TxtBinDataset(df, feats_scaled, tokenizer)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=False,
    )
    return dataset, DataLoader(dataset, **loader_options)

# =============== LoRA ===============
def get_lora_base(train_lora=True):
    """Load the causal-LM backbone from ``MODEL_PATH`` and wrap it with LoRA adapters.

    The base model is loaded in float16 on GPU 0 when CUDA is available and in
    float32 otherwise. After wrapping, only parameters whose name contains
    ``"lora_"`` can be trainable, and only when ``train_lora`` is truthy;
    every other parameter is frozen.

    Returns:
        The PEFT-wrapped model.
    """
    on_gpu = torch.cuda.is_available()
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        dtype=torch.float16 if on_gpu else torch.float32,
        low_cpu_mem_usage=True,
        local_files_only=True,
        device_map={"":0} if on_gpu else None,
        attn_implementation=ATTN_IMPL,
    )
    # The classifier pools hidden states, so the KV cache is not needed.
    base.config.use_cache = False
    base.config.output_hidden_states = True

    adapter_cfg = LoraConfig(
        r=4,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj","v_proj","o_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
    wrapped = get_peft_model(base, adapter_cfg)

    adapters_trainable = bool(train_lora)
    for param_name, param in wrapped.named_parameters():
        param.requires_grad = ("lora_" in param_name) and adapters_trainable
    return wrapped

class QwenLoRABinary(nn.Module):
    """Binary classifier fusing a LoRA-adapted LM with numeric features.

    The text branch mean-pools the backbone's last hidden states over
    non-padding tokens. The tabular branch concatenates the numeric features
    with a learned region embedding and passes them through a small MLP.
    Both vectors are concatenated, fused by a linear layer and mapped to a
    single logit.
    """

    def __init__(self, feat_dim: int, dropout=0.3, train_lora=True, train_head=True, num_regions=None, reg_emb_dim=16):
        """Build the backbone and the fusion head.

        Args:
            feat_dim: Number of numeric features per sample.
            dropout: Dropout probability used in the MLP and on pooled/fused vectors.
            train_lora: Whether the LoRA adapter weights are trainable.
            train_head: Whether ``feat_mlp``, ``fuse`` and ``out`` are trainable.
            num_regions: Size of the region embedding table; falls back to
                ``len(REGIONS)`` when falsy.
            reg_emb_dim: Region embedding width.
        """
        super().__init__()
        # Module creation order is kept fixed so seeded initialisation is reproducible.
        self.backbone = get_lora_base(train_lora=train_lora)
        hidden_size = self.backbone.config.hidden_size
        n_regions = num_regions if num_regions else len(REGIONS)
        self.reg_emb = nn.Embedding(n_regions, reg_emb_dim)
        self.feat_mlp = nn.Sequential(
            nn.Linear(feat_dim+reg_emb_dim,128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128,128),
            nn.ReLU(),
        )
        self.drop = nn.Dropout(dropout)
        self.fuse = nn.Linear(hid_plus := hidden_size+128, hidden_size) if False else nn.Linear(hidden_size+128, hidden_size)
        self.act = nn.ReLU()
        self.out = nn.Linear(hidden_size,1)

        # NOTE(review): reg_emb is not covered by train_head and always stays trainable — confirm intent.
        head_trainable = bool(train_head)
        for head_part in (self.feat_mlp, self.fuse, self.out):
            for param in head_part.parameters():
                param.requires_grad = head_trainable

        # Start the output layer near zero so initial logits are close to 0.
        nn.init.normal_(self.out.weight, std=1e-3)
        nn.init.zeros_(self.out.bias)
        # Positive-class loss weight, stored as a buffer so it follows the module.
        self.register_buffer("pos_w", torch.ones(()))

    def forward(self, input_ids, attention_mask, feats, region_id):
        """Return ``{"logit": tensor}`` with one logit per batch row."""
        backbone_out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
            return_dict=True,
        )
        hidden = backbone_out.hidden_states[-1]

        # Masked mean over the sequence dimension; the clamp avoids dividing by zero.
        keep = attention_mask.unsqueeze(-1).float()
        token_sum = (hidden*keep).sum(1)
        token_count = keep.sum(1).clamp(min=1e-6)
        text_vec = self.drop(token_sum / token_count)

        region_vec = self.reg_emb(region_id)
        tab_vec = self.feat_mlp(torch.cat([feats, region_vec], -1))

        fused = self.act(self.fuse(torch.cat([text_vec, tab_vec], -1)))
        return {"logit": self.out(self.drop(fused)).squeeze(-1)}

def get_trainable_state_dict(m: nn.Module):
    """Collect the parameters worth checkpointing into one flat CPU state dict.

    The result holds the LoRA adapter weights (prefixed ``backbone.``), the
    state of ``reg_emb``, ``feat_mlp``, ``fuse`` and ``out`` (each prefixed
    with its attribute name) and, when present, the ``pos_w`` buffer.

    Extracting the LoRA weights stays best-effort: if it fails, the head
    weights are still returned, but a warning is printed so a checkpoint
    without adapter weights is never produced silently.
    """
    sd = {}
    try:
        lora_sd = get_peft_model_state_dict(m.backbone)
    except Exception as err:
        # print() rather than warnings.warn: this file silences all warnings.
        print(f"[WARN] LoRA weights could not be extracted and are NOT in the checkpoint: {err!r}", flush=True)
        lora_sd = {}
    for key, tensor in lora_sd.items():
        sd[f"backbone.{key}"] = tensor.detach().cpu()

    head_parts = [("reg_emb", m.reg_emb), ("feat_mlp", m.feat_mlp), ("fuse", m.fuse), ("out", m.out)]
    for name, module in head_parts:
        for key, tensor in module.state_dict().items():
            sd[f"{name}.{key}"] = tensor.detach().cpu()

    if hasattr(m, "pos_w"):
        sd["pos_w"] = m.pos_w.detach().cpu()
    return sd

def load_trainable_state_dict(m: nn.Module, sd: dict):
    """Restore a state dict produced by ``get_trainable_state_dict`` into ``m``.

    Args:
        m: Model exposing ``backbone``, ``reg_emb``, ``feat_mlp``, ``fuse``,
            ``out`` and optionally ``pos_w``.
        sd: Flat dict whose keys are prefixed with ``backbone.`` or the head
            module's attribute name, plus an optional ``pos_w`` entry.

    LoRA weights are restored with ``peft.set_peft_model_state_dict``. The keys
    returned by ``get_peft_model_state_dict`` have the adapter name removed, so
    a plain ``load_state_dict(strict=False)`` would match none of them and
    silently load nothing.
    """
    def _strip_prefix(prefix):
        # Remove only the leading prefix; str.replace would also strip any
        # later occurrence inside the key.
        cut = len(prefix)
        return {k[cut:]: v for k, v in sd.items() if k.startswith(prefix)}

    lora_sd = _strip_prefix("backbone.")
    if lora_sd:
        from peft import set_peft_model_state_dict
        set_peft_model_state_dict(m.backbone, lora_sd)

    head_parts = [("reg_emb", m.reg_emb), ("feat_mlp", m.feat_mlp), ("fuse", m.fuse), ("out", m.out)]
    for name, module in head_parts:
        sub = _strip_prefix(f"{name}.")
        if sub:
            module.load_state_dict(sub, strict=False)

    if "pos_w" in sd and hasattr(m, "pos_w"):
        m.pos_w.copy_(sd["pos_w"].to(m.pos_w.dtype))

def binary_focal_bce_with_logits(logits,y,pos_weight=None,gamma=FOCAL_GAMMA,
                                 alpha_pos=FOCAL_ALPHA_POS,alpha_neg=FOCAL_ALPHA_NEG,
                                 sample_weight=None):
    """Mean alpha-weighted (optionally focal) binary cross-entropy on logits.

    Args:
        logits: Raw scores, one per sample.
        y: 0/1 targets with the same shape as ``logits``.
        pos_weight: Optional positive-class weight forwarded to the BCE loss.
        gamma: Focal exponent; the ``(1 - p_t) ** gamma`` factor is applied
            only when ``gamma > 0``.
        alpha_pos, alpha_neg: Class weights for positive / negative targets.
        sample_weight: Optional per-sample multiplier.

    Returns:
        Scalar tensor: the mean of the per-sample weighted losses.
    """
    per_sample_bce = torch.nn.functional.binary_cross_entropy_with_logits(
        logits, y.float(), pos_weight=pos_weight, reduction="none")
    prob = torch.sigmoid(logits)
    # Probability the model assigns to the true class.
    pt = y*prob + (1-y)*(1-prob)
    alpha_t = alpha_pos*y + alpha_neg*(1-y)

    weight = alpha_t * ((1-pt).clamp(min=1e-6)**gamma) if gamma > 0 else alpha_t
    focal = weight * per_sample_bce
    if sample_weight is not None:
        focal = focal * sample_weight
    return focal.mean()

@contextmanager
def model_in_eval(m: nn.Module):
    """Context manager that runs its body with ``m`` in eval mode.

    On exit (including on error) the module is put back in training mode only
    if it was training when the context was entered.
    """
    restore_training = m.training
    m.eval()
    try:
        yield
    finally:
        if restore_training:
            m.train()

@torch.no_grad()
def run_validation(model, loader, desc="eval"):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        ``(mean_loss, probs, scores, labels)``: the mean of the per-batch
        losses (``0.0`` when no batch was seen), then the concatenated sigmoid
        probabilities, raw logits and labels as NumPy arrays.
    """
    batch_losses = []
    gathered = {"prob": [], "score": [], "label": []}
    tensor_fields = ("input_ids", "attention_mask", "feats", "y", "region_id")

    with model_in_eval(model):
        for batch in tqdm(loader, desc=desc, leave=False):
            ids, mask, feats, y, rid = (batch[name].to(device) for name in tensor_fields)
            with torch.cuda.amp.autocast(enabled=USE_AMP):
                logit = model(ids, mask, feats=feats, region_id=rid)["logit"]
                loss = binary_focal_bce_with_logits(logit,y,pos_weight=model.pos_w)
            batch_losses.append(float(loss.item()))
            gathered["prob"].append(torch.sigmoid(logit).detach().cpu().numpy())
            gathered["score"].append(logit.detach().cpu().numpy())
            gathered["label"].append(y.detach().cpu().numpy())

    mean_loss = np.mean(batch_losses) if len(batch_losses) else 0.0
    return (
        mean_loss,
        np.concatenate(gathered["prob"],0),
        np.concatenate(gathered["score"],0),
        np.concatenate(gathered["label"],0),
    )

def compute_pos_weight(cnt, mode=POS_WEIGHT_MODE):
    """Derive the positive-class loss weight from class counts.

    Args:
        cnt: Indexable with the negative count at ``[0]`` and the positive
            count at ``[1]``; the positive count is floored at 1.
        mode: ``"sqrt"`` or ``"log"`` (``log1p``) dampen the negative/positive
            ratio; any other value returns the raw ratio.
    """
    neg_over_pos = float(cnt[0] / max(cnt[1],1.0))
    if mode == "sqrt":
        weight = math.sqrt(neg_over_pos)
    elif mode == "log":
        weight = math.log1p(neg_over_pos)
    else:
        weight = neg_over_pos
    return weight

def _make_checkpoint_payload(model, split_outputs, thr_prob, thr_score, cal_state, epoch, note):
    """Assemble the checkpoint dict: trainable weights, decision rule and per-split snapshots."""
    snapshots = {
        split: {
            "scores": scores.astype(np.float32),
            "labels": labels.astype(np.int64),
            "meta": ds.meta.to_dict(orient="list"),
        }
        for split, (scores, labels, ds) in split_outputs.items()
    }
    return {
        "trainable": get_trainable_state_dict(model),
        "thr_prob": float(thr_prob),
        "thr_score": float(thr_score),
        "calibrator_state": cal_state,
        "epoch": epoch,
        "snapshots": snapshots,
        "note": note,
    }


def train_one_seed(seed, trN, vlN, teN):
    """Train the classifier for one seed and checkpoint the best validation epoch.

    Each epoch trains on ``trN``, fits the calibrator on validation scores,
    applies the fixed probability threshold and evaluates train/val/test. The
    epoch with the highest mean of validation accuracy, binary F1, macro F1,
    precision, recall and ROC-AUC is saved; training stops after
    ``EARLY_PATIENCE`` epochs without improvement.

    Args:
        seed: Random seed applied before the model is built.
        trN, vlN, teN: Standardised train / validation / test frames.

    Returns:
        ``(checkpoint_path, per_epoch)`` where ``per_epoch`` is a list with one
        metrics dict per completed epoch.

    Raises:
        RuntimeError: If no epoch ran, so there is nothing to checkpoint.

    Differences from the previous implementation:
        * The last-epoch fallback builds its own snapshots; it used to read
          variables that only exist once a best epoch was found (NameError).
        * Scheduler length and optimizer stepping honour ``GRAD_ACCUM_STEPS``
          (identical behaviour when it is 1).
        * A missing or non-finite validation AUC counts as ``-1.0``; a real AUC
          of ``0.0`` is no longer mistaken for "missing".
    """
    set_seed(seed)
    model = QwenLoRABinary(len(FEAT_COLS), dropout=DROPOUT,
                           train_lora=True, train_head=True,
                           num_regions=len(REGIONS), reg_emb_dim=16).to(device)

    if FAST_DEBUG:
        trN = trN.iloc[: max(256, min(2048, len(trN)))].reset_index(drop=True)
        vlN = vlN.iloc[: max(128, min(1024, len(vlN)))].reset_index(drop=True)
        teN = teN.iloc[: max(128, min(1024, len(teN)))].reset_index(drop=True)

    ds_tr, dl_tr = make_loader(trN, scale_X(trN), tokenizer, BATCH_SIZE, shuffle=True)
    ds_vl, dl_vl = make_loader(vlN, scale_X(vlN), tokenizer, BATCH_SIZE, shuffle=False)
    ds_te, dl_te = make_loader(teN, scale_X(teN), tokenizer, BATCH_SIZE, shuffle=False)

    # Positive-class weight from the training label balance.
    ytr = trN["y_cls"].values.astype(int)
    cnt = np.bincount(ytr, minlength=2).astype(float)
    pos_weight = compute_pos_weight(cnt, POS_WEIGHT_MODE)
    model.pos_w = torch.tensor(float(pos_weight), dtype=torch.float32, device=device)

    epochs_run = 1 if FAST_DEBUG else EPOCHS_TRAIN
    print(f"pos_weight(+1)={float(pos_weight):.4f} | class_counts={cnt.tolist()} | USE_SAMPLE_WEIGHTS={USE_SAMPLE_WEIGHTS}", flush=True)
    print(f"Loader sizes | train={len(dl_tr)} val={len(dl_vl)} test={len(dl_te)} | epochs={epochs_run}", flush=True)

    if USE_SAMPLE_WEIGHTS:
        # Inverse-frequency class weights, normalised to mean 1.
        inv = cnt.sum()/np.clip(cnt,1.0,None); inv = inv/inv.mean()
        cls_weight_vec = torch.tensor(inv, dtype=torch.float32, device=device)
    else:
        cls_weight_vec = None

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable_params, lr=LR, weight_decay=WEIGHT_DECAY, fused=False)

    # The scheduler advances once per optimizer update, not once per micro-batch.
    n_batches = len(dl_tr)
    updates_per_epoch = max(1, math.ceil(n_batches / GRAD_ACCUM_STEPS))
    steps_total = max(1, updates_per_epoch * epochs_run)
    sch = get_linear_schedule_with_warmup(opt, int(steps_total*WARMUP_RATIO), steps_total)
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

    ckpt_file = RES_DIR/"best_by_val_score_no_spec.pt"
    best_payload = None
    best_val_score = -1.0
    pat = 0
    per_epoch = []
    last_epoch_state = None  # inputs for the fallback payload if no epoch is ever selected

    for ep in range(1, epochs_run+1):
        model.train()
        total=0.0; step=0
        pbar = tqdm(dl_tr, desc=f"Epoch {ep:02d} [train]", leave=False)
        for b in pbar:
            ids=b["input_ids"].to(device); mask=b["attention_mask"].to(device)
            feats=b["feats"].to(device); y=b["y"].to(device); rid=b["region_id"].to(device)
            sw = cls_weight_vec[y] if cls_weight_vec is not None else None
            with torch.cuda.amp.autocast(enabled=USE_AMP):
                logit = model(ids, mask, feats=feats, region_id=rid)["logit"]
                loss = binary_focal_bce_with_logits(
                    logit, y, pos_weight=model.pos_w,
                    gamma=FOCAL_GAMMA,
                    alpha_pos=FOCAL_ALPHA_POS, alpha_neg=FOCAL_ALPHA_NEG,
                    sample_weight=sw
                ) / GRAD_ACCUM_STEPS
            scaler.scale(loss).backward()
            step += 1
            # Also step on the last micro-batch so leftover gradients are applied
            # instead of leaking into the next epoch.
            if step % GRAD_ACCUM_STEPS == 0 or step == n_batches:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(trainable_params, GRAD_CLIP)
                scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True)
                sch.step()
            total += float(loss.item())*GRAD_ACCUM_STEPS
            if step % 10 == 0:
                pbar.set_postfix(loss=f"{(total/max(1,step)):.4f}")

        tr_loss = total/max(1,len(dl_tr))
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Calibrate on validation scores, then apply the same calibrator and
        # fixed probability threshold to every split.
        vl_loss, _, sc_vl, y_vl = run_validation(model, dl_vl, desc=f"Epoch {ep:02d} [val]")
        calibrate, cal_state = calibrator_fit(sc_vl, y_vl, kind=CALIBRATOR)
        pr_vl = calibrate(sc_vl)

        thr_prob = FIXED_THR_PROB
        thr_score = prob_thr_to_score_thr(sc_vl, calibrate, thr_prob, cal_state)

        _, _, sc_tr, y_tr = run_validation(model, dl_tr, desc=f"Epoch {ep:02d} [train-eval]")
        _, _, sc_te, y_te = run_validation(model, dl_te, desc=f"Epoch {ep:02d} [test]")
        pr_tr = calibrate(sc_tr); pr_te = calibrate(sc_te)

        m_tr,_ = metrics_at_thr(y_tr, pr_tr, thr_prob, lite=False)
        m_vl,_ = metrics_at_thr(y_vl, pr_vl, thr_prob, lite=False)
        m_te,_ = metrics_at_thr(y_te, pr_te, thr_prob, lite=False)

        val_auc = m_vl.get("roc_auc")
        if val_auc is None or not np.isfinite(val_auc):
            val_auc = -1.0
        val_score = np.mean([
            m_vl["accuracy"],
            m_vl["binary_f1"],
            m_vl["macro_f1"],
            m_vl["precision"],
            m_vl["recall"],
            val_auc
        ])

        print(
            f"\nEpoch {ep:02d} | tr_loss={tr_loss:.4f} | val_loss={vl_loss:.4f} | "
            f"AUC_val={val_auc:.3f} | THR(fixed)={thr_prob:.3f} | VAL_score_no_spec={val_score:.3f}",
            flush=True
        )
        for tag, mm in (("TRAIN", m_tr), ("VAL", m_vl), ("TEST", m_te)):
            print(
                f" {tag:<5} Acc={mm['accuracy']:.3f} F1={mm['binary_f1']:.3f} "
                f"P={mm['precision']:.3f} R={mm['recall']:.3f} "
                f"Spec={mm['specificity']:.3f} BalAcc={mm['balanced_accuracy']:.3f}", flush=True
            )

        per_epoch.append({
            "epoch": ep,
            "thr_prob": float(thr_prob),
            "thr_score": float(thr_score),
            "calibrator": CALIBRATOR,
            "val_score_no_spec": float(val_score),
            "train": m_tr,
            "val": m_vl,
            "test": m_te
        })

        last_epoch_state = {
            "split_outputs": {
                "train": (sc_tr, y_tr, ds_tr),
                "val": (sc_vl, y_vl, ds_vl),
                "test": (sc_te, y_te, ds_te),
            },
            "thr_prob": thr_prob,
            "thr_score": thr_score,
            "cal_state": cal_state,
            "epoch": ep,
        }

        if val_score > best_val_score + 1e-4:
            best_val_score = val_score
            best_payload = _make_checkpoint_payload(
                model, note="BEST_BY_VAL_MULTI_NO_SPEC_FIXED_THR", **last_epoch_state)
            torch.save(best_payload, ckpt_file)
            pat = 0
        else:
            pat += 1
            if pat >= EARLY_PATIENCE and not FAST_DEBUG:
                print(f"早停：patience={EARLY_PATIENCE}", flush=True)
                break

        torch.cuda.empty_cache(); gc.collect()

    if best_payload is None:
        # No epoch was selected (e.g. the validation score was never finite):
        # fall back to the model state and outputs of the last completed epoch.
        if last_epoch_state is None:
            raise RuntimeError("no training epoch was run; nothing to checkpoint")
        best_payload = _make_checkpoint_payload(
            model, note="FALLBACK_LAST_EPOCH", **last_epoch_state)
        torch.save(best_payload, ckpt_file)

    print(f"[SAVE] 使用 best_by_val_score_no_spec @ epoch {best_payload['epoch']} VAL_score_no_spec={best_val_score:.4f}", flush=True)
    return ckpt_file, per_epoch

# Training is only attempted on a CUDA machine; exit otherwise.
if not torch.cuda.is_available():
    raise SystemExit("未检测到 GPU 环境。")
torch.cuda.set_device(0)
print("===== 开始训练=====", flush=True)

# The tokenizer is loaded again here with the same arguments as above, and the
# same pad-token fallback is applied.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True, local_files_only=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Single run with the first configured seed.
ckpt_path, per_epoch_log = train_one_seed(ENSEMBLE_SEEDS[0], trN, vlN, teN)

def load_payload(ckpt_path: Path):
    """Load a checkpoint written by ``train_one_seed`` and unpack its decision rule.

    Args:
        ckpt_path: Path of the ``.pt`` file.

    Returns:
        ``(cal_state, thr_prob, thr_score, epoch, snapshots)``. Missing entries
        default to an identity temperature calibrator, ``FIXED_THR_PROB``,
        ``0.0``, ``-1`` and ``{}`` respectively.
    """
    # The payload stores NumPy arrays and plain Python objects in addition to
    # tensors. torch.load defaults to weights_only=True from PyTorch 2.6 on,
    # which refuses to unpickle them. The file is produced by this script, so
    # full unpickling is acceptable; do not point this at untrusted files.
    payload = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    cal_state = payload.get("calibrator_state", {"kind": "temperature", "T": 1.0})
    thr_prob = float(payload["thr_prob"]) if "thr_prob" in payload else float(FIXED_THR_PROB)
    thr_score = float(payload.get("thr_score", 0.0))
    epoch = int(payload.get("epoch", -1))
    snaps = payload.get("snapshots", {})
    return cal_state, thr_prob, thr_score, epoch, snaps

# Reload the saved checkpoint and rebuild its calibrator; the module-level
# `calibrate` defined here is the one used by finalize_from_snapshot below.
CAL_STATE, DECISION_THR_PROB, DECISION_THR_SCORE, SELECTED_EPOCH, SNAPSHOTS = load_payload(ckpt_path)
calibrate = build_calibrator_from_state(CAL_STATE)

def finalize_from_snapshot(snap: Dict[str, Any], thr_prob: float, thr_score: float):
    """Turn a stored split snapshot into final metrics and a prediction table.

    Args:
        snap: Dict with ``"meta"`` (column -> list), ``"scores"`` (raw logits)
            and ``"labels"``.
        thr_prob: Probability threshold that defines the predicted class.
        thr_score: Equivalent raw-score threshold; recorded in the output only.

    Returns:
        ``(metrics, table)``: the full metrics dict from
        ``eval_metrics_binary`` and a copy of the metadata frame extended with
        probability, logit, predicted class, label texts and both thresholds.

    Predictions use ``calibrated probability >= thr_prob``, the same rule used
    for model selection during training. Thresholding raw scores with
    ``>= thr_score`` is only equivalent when the calibrator is increasing in
    the score; with a negative Platt slope it inverts every prediction.

    Relies on the module-level ``calibrate`` callable.
    """
    meta = pd.DataFrame(snap["meta"])
    scores = np.asarray(snap["scores"]).reshape(-1)
    labels = np.asarray(snap["labels"]).astype(int).reshape(-1)

    prob = calibrate(scores)
    yhat = (prob >= float(thr_prob)).astype(int)
    m_full = eval_metrics_binary(labels, prob, yhat, lite=False)

    out_meta = meta.copy()
    out_meta["prob_pos"] = prob
    out_meta["score_logit"] = scores
    out_meta["pred_cls"] = yhat
    out_meta["pred_text"] = out_meta["pred_cls"].map(CLS2TEXT)
    out_meta["true_text"] = out_meta["y_cls"].map(CLS2TEXT)
    out_meta["thr_used_prob"] = float(thr_prob)
    out_meta["thr_used_score"] = float(thr_score)
    return m_full, out_meta

# Final metrics and prediction tables for each split, computed from the scores
# stored in the selected checkpoint.
m_tr,  meta_tr  = finalize_from_snapshot(SNAPSHOTS["train"], DECISION_THR_PROB, DECISION_THR_SCORE)
m_vl,  meta_vl  = finalize_from_snapshot(SNAPSHOTS["val"]  , DECISION_THR_PROB, DECISION_THR_SCORE)
m_te,  meta_te  = finalize_from_snapshot(SNAPSHOTS["test"] , DECISION_THR_PROB, DECISION_THR_SCORE)

# Per-epoch training log.
with open(RES_DIR/"per_epoch_log.json","w",encoding="utf-8") as f:
    json.dump(per_epoch_log, f, ensure_ascii=False, indent=2)

# Final metrics together with the configuration needed to interpret them.
with open(RES_DIR/"metrics.json","w",encoding="utf-8") as f:
    json.dump({
        "train": m_tr, "val": m_vl, "test": m_te,
        "selected_regions": REGIONS,
        "feature_cols": FEAT_COLS,
        "class_mapping": CLS2TEXT,
        "threshold_pct_for_positive": TH_UP_PCT,
        "decision_threshold_on_val_prob": float(DECISION_THR_PROB),
        "decision_threshold_on_val_score": float(DECISION_THR_SCORE),
        "calibrator_state": CAL_STATE,
        "use_sample_weights": USE_SAMPLE_WEIGHTS,
        "pos_weight_mode": POS_WEIGHT_MODE,
        "seeds": ENSEMBLE_SEEDS,
        "attn_impl": ATTN_IMPL,
        "calibrator": CALIBRATOR,
        "dataloader": {"num_workers": 0, "pin_memory": False},
        "split_by": "date_target_month",
        "split_months": {"train": train_months, "val": val_months, "test": test_months},
        "selected_epoch": SELECTED_EPOCH,
        "select_strategy": "best_val_multi_no_spec_fixed_thr",
        "fixed_thr_prob": FIXED_THR_PROB,
        "data_path": DATA_PATH,
        "data_columns_mapping": {"月份":"period","地区":"地区_std","发病数":"cases","全部":"index_all"}
    }, f, ensure_ascii=False, indent=2)

# One worksheet per split with row-level predictions (uses the xlsxwriter engine).
with pd.ExcelWriter(RES_DIR/"cls_predictions.xlsx", engine="xlsxwriter") as writer:
    meta_tr.to_excel(writer, sheet_name="train", index=False)
    meta_vl.to_excel(writer, sheet_name="val"  , index=False)
    meta_te.to_excel(writer, sheet_name="test" , index=False)

