In [18]:
# ===========================
# PD-only 분류 + 정상상태 용량탐색 (MLP 전용)
# - 데이터: EstData.csv (PD: ng/mL), 임계 3.3 ng/mL
# - 모델: MLP (레이어/히든/드롭아웃)
# - 분할: 플라시보(0 mg, ID 1–12) 제외, 1/3/10 mg 고정 그룹별 70/15/15 ID 단위
# - 추가: 베스트 스냅샷 복원, EarlyStopping, LR 스케줄러, Weight Decay,
#         검증 F1 최대 임계값 튜닝 → 테스트/시뮬에 일관 적용
# ===========================
import os, math, warnings, numpy as np, pandas as pd
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, confusion_matrix
)
from collections import defaultdict
from typing import Optional
from copy import deepcopy
import random

# -------- 설정 --------
CSV = "EstData.csv"       # 필요 시 절대경로로 교체
PD_THRESHOLD = 3.3

# 학습/최적화 설정
EPOCHS = 120
LR = 5e-2
WEIGHT_DECAY = 1e-4
PATIENCE = 12               # EarlyStopping 인내 에폭
SCHED_FACTOR = 0.5          # ReduceLROnPlateau 감쇠 비율
SCHED_PATIENCE = 5          # 스케줄러 인내 에폭(지표 기준)
CLIP_NORM = 1.0             # 그라디언트 클리핑 (안쓰려면 None)

# 시뮬레이션/검색 설정
N_SUBJ = 300
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# MLP 하이퍼파라미터
MLP_HIDDEN = 36
MLP_LAYERS = 2
MLP_DROPOUT = 0.1

# 재현성(완전 결정적 보장은 환경 의존. CUDA에서 완전 결정적 필요시 CUBLAS_WORKSPACE_CONFIG 설정 필요)
def set_global_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # 완전 결정적이 필요하면 환경변수 설정 필요(CuBLAS). 여기서는 사용하지 않음.

set_global_seed(42)
assert os.path.exists(CSV), f"CSV not found at {CSV}"

# -------- 유틸: 열 이름 추론 --------
def _find_col_like(df, name_opts):
    low = {c.lower(): c for c in df.columns}
    for n in name_opts:
        if n in low: return low[n]
    return None

# -------- 데이터 로딩/전처리 --------
df = pd.read_csv(CSV)

col_ID   = _find_col_like(df, ["id"])
col_TIME = _find_col_like(df, ["time"])
col_DVID = _find_col_like(df, ["dvid"])
col_DV   = _find_col_like(df, ["dv","pd","value"])
col_EVID = _find_col_like(df, ["evid"])
col_AMT  = _find_col_like(df, ["amt","dose","dosen","doses"])
col_BW   = _find_col_like(df, ["bw","weight","bodyweight"])
col_COMED= _find_col_like(df, ["comed","conmed","concom"])

need = [col_ID, col_TIME, col_DV, col_AMT]
miss = [n for n,v in zip(["ID","TIME","DV","AMT/DOSE"], need) if v is None]
if miss:
    warnings.warn(f"Columns missing (minimum required): {miss}")

# PD 행 추출 (DVID==2가 있으면 그걸 사용)
if col_DVID is not None and col_DV is not None:
    pdf = df[df[col_DVID]==2].copy()
else:
    pdf = df.copy()

# 투약 이벤트 (EVID==1 우선, 아니면 AMT/DOSE notna)
if col_EVID is not None:
    dose_df = df[df[col_EVID]==1].copy()
else:
    dose_df = df[df[col_AMT].notna()].copy()

# 숫자 변환
for c in [col_TIME, col_DV, col_AMT, col_BW, col_COMED]:
    if c is not None:
        pdf[c] = pd.to_numeric(pdf[c], errors="coerce")
        dose_df[c] = pd.to_numeric(dose_df[c], errors="coerce")

# 정렬/필요 열만
keep_pd = [c for c in [col_ID,col_TIME,col_DV,col_BW,col_COMED] if c is not None]
keep_dose = [c for c in [col_ID,col_TIME,col_AMT] if c is not None]
pdf = pdf[keep_pd].dropna().sort_values([col_ID, col_TIME])
dose_df = dose_df[keep_dose].dropna().sort_values([col_ID, col_TIME])

print("Detected columns:", dict(ID=col_ID, TIME=col_TIME, DV=col_DV, EVID=col_EVID, AMT=col_AMT, BW=col_BW, COMED=col_COMED))
print("PD rows:", len(pdf), "Dose rows:", len(dose_df))

# -------- 데이터셋 --------
class PDSamples(Dataset):
    def __init__(self, pd_df: pd.DataFrame, dose_df: pd.DataFrame,
                 col_ID: str, col_TIME: str, col_DV: str,
                 col_BW: Optional[str], col_COMED: Optional[str], pd_threshold: float=3.3):
        self.col_ID, self.col_TIME, self.col_DV = col_ID, col_TIME, col_DV
        self.col_BW, self.col_COMED = col_BW, col_COMED
        self.pd_threshold = pd_threshold

        # ID별 투약 히스토리
        d_groups = defaultdict(list)
        for _, row in dose_df.iterrows():
            d_groups[row[col_ID]].append((float(row[col_TIME]), float(row[col_AMT])))
        self.dose_map = {k: (np.array([t for t,a in v], dtype=np.float32),
                              np.array([a for t,a in v], dtype=np.float32)) for k,v in d_groups.items()}

        # 샘플: 각 PD 관측 시점
        feats = []
        for _, row in pd_df.iterrows():
            sid = row[col_ID]
            if sid not in self.dose_map:
                continue
            t = float(row[col_TIME])
            val = float(row[col_DV])
            y = 1.0 if (val <= pd_threshold) else 0.0
            bw = float(row[col_BW]) if col_BW is not None else 70.0
            cm = float(row[col_COMED]) if col_COMED is not None else 0.0
            feats.append((sid, t, y, bw, cm))
        self.samples = feats

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        sid, t, y, bw, cm = self.samples[idx]
        dose_t, dose_a = self.dose_map.get(sid, (np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32)))
        return {
            "t": torch.tensor(t, dtype=torch.float32),
            "y": torch.tensor(y, dtype=torch.float32),
            "bw": torch.tensor(bw, dtype=torch.float32),
            "cm": torch.tensor(cm, dtype=torch.float32),
            "dose_t": torch.tensor(dose_t, dtype=torch.float32),
            "dose_a": torch.tensor(dose_a, dtype=torch.float32),
        }

def _collate(batch): return batch

dataset_all = PDSamples(pdf, dose_df, col_ID, col_TIME, col_DV, col_BW, col_COMED, pd_threshold=PD_THRESHOLD)

# -------- ID 단위 분할: 플라시보 제외 + 1/3/10 mg 그룹별 70/15/15 --------
rng = 42
placebo_ids = set(range(1, 13))     # 0 mg (제외)
ids_1mg     = list(range(13, 25))   # 1 mg
ids_3mg     = list(range(25, 37))   # 3 mg
ids_10mg    = list(range(37, 49))   # 10 mg

ids_in_data = set(df[col_ID].unique())
ids_1mg  = sorted(ids_in_data.intersection(ids_1mg))
ids_3mg  = sorted(ids_in_data.intersection(ids_3mg))
ids_10mg = sorted(ids_in_data.intersection(ids_10mg))

def split_70_15_15(ids, seed=42):
    if len(ids) == 0:
        return [], [], []
    tr_ids, temp_ids = train_test_split(ids, test_size=0.30, random_state=seed, shuffle=True)
    va_ids, te_ids = train_test_split(temp_ids, test_size=0.50, random_state=seed, shuffle=True)
    return list(tr_ids), list(va_ids), list(te_ids)

tr_1, va_1, te_1     = split_70_15_15(ids_1mg,  seed=rng)
tr_3, va_3, te_3     = split_70_15_15(ids_3mg,  seed=rng)
tr_10, va_10, te_10  = split_70_15_15(ids_10mg, seed=rng)

ids_tr = set(tr_1 + tr_3 + tr_10)
ids_va = set(va_1 + va_3 + va_10)
ids_te = set(te_1 + te_3 + te_10)

sid_list = [s[0] for s in dataset_all.samples]
tr_idx = [i for i, sid in enumerate(sid_list) if sid in ids_tr]
va_idx = [i for i, sid in enumerate(sid_list) if sid in ids_va]
te_idx = [i for i, sid in enumerate(sid_list) if sid in ids_te]

train_ds = Subset(dataset_all, tr_idx)
valid_ds = Subset(dataset_all, va_idx)
test_ds  = Subset(dataset_all, te_idx)

print("[ID split by fixed dose groups] (70/15/15)")
print(" train IDs:", len(ids_tr), "| valid IDs:", len(ids_va), "| test IDs:", len(ids_te))
print(" 1mg -> train/valid/test:", len(tr_1), len(va_1), len(te_1))
print(" 3mg -> train/valid/test:", len(tr_3), len(va_3), len(te_3))
print("10mg -> train/valid/test:", len(tr_10), len(va_10), len(te_10))
print(f"#samples -> train: {len(train_ds)} | valid: {len(valid_ds)} | test: {len(test_ds)}")

# -------- MLP 모델 --------
class PDMLPClassifier(nn.Module):
    """
    입력 x=[E(t), (BW-mean)/10, COMED] -> [Linear+ReLU(+Dropout)]*n_layers -> Linear(→1)
    노출 링크: log(tau) = b0 + b1*((BW-mean)/10) + b2*COMED
    """
    def __init__(self, bw_mean: float=70.0, hidden: int=32, n_layers: int=3, dropout_p: float=0.2):
        super().__init__()
        self.b0 = nn.Parameter(torch.tensor(math.log(24.0)))
        self.b1 = nn.Parameter(torch.tensor(0.0))
        self.b2 = nn.Parameter(torch.tensor(0.0))
        self.bw_mean = float(bw_mean)

        layers = []
        in_dim = 3
        for i in range(n_layers):
            layers += [
                nn.Linear(in_dim if i==0 else hidden, hidden),
                nn.ReLU(),
                nn.Dropout(p=dropout_p) if dropout_p and dropout_p > 0 else nn.Identity(),
            ]
        layers += [nn.Linear(hidden if n_layers>0 else in_dim, 1)]
        self.mlp = nn.Sequential(*layers)

    def _tau(self, bw: torch.Tensor, comed: torch.Tensor):
        bwc = (bw - self.bw_mean) / 10.0
        log_tau = self.b0 + self.b1*bwc + self.b2*comed
        return torch.exp(log_tau).clamp_min(1.0)

    @torch.no_grad()
    def tau_from_cov_np(self, bw_np: np.ndarray, cm_np: np.ndarray, device=None):
        device = device or next(self.parameters()).device
        bw = torch.tensor(bw_np, dtype=torch.float32, device=device)
        cm = torch.tensor(cm_np, dtype=torch.float32, device=device)
        return self._tau(bw, cm).detach().cpu().numpy()

    def forward_single(self, t: torch.Tensor, dose_t: torch.Tensor, dose_a: torch.Tensor,
                       bw: float, comed: float):
        tau = self._tau(torch.tensor(bw, dtype=torch.float32, device=t.device),
                        torch.tensor(comed, dtype=torch.float32, device=t.device))
        dt = t - dose_t
        mask = (dt >= 0).float()
        exposure = (dose_a * torch.exp(-dt.clamp_min(0) / tau) * mask).sum()
        x = torch.stack([
            exposure,
            (torch.tensor(bw, dtype=torch.float32, device=t.device) - self.bw_mean) / 10.0,
            torch.tensor(comed, dtype=torch.float32, device=t.device)
        ])
        return self.mlp(x).squeeze()

def _collate(batch): return batch

# -------- 지표 계산/도움 함수 --------
def _compute_metrics(y_true: np.ndarray, prob: np.ndarray, thr: float = 0.5):
    pred = (prob >= thr).astype(int)
    acc = accuracy_score(y_true, pred) if len(y_true) else float("nan")
    prec = precision_score(y_true, pred, zero_division=0)
    rec  = recall_score(y_true, pred, zero_division=0)
    f1   = f1_score(y_true, pred, zero_division=0)
    try:
        roc = roc_auc_score(y_true, prob) if (len(np.unique(y_true))>1) else float("nan")
    except Exception:
        roc = float("nan")
    try:
        ap  = average_precision_score(y_true, prob) if (len(np.unique(y_true))>1) else float("nan")
    except Exception:
        ap  = float("nan")
    try:
        tn, fp, fn, tp = confusion_matrix(y_true, pred, labels=[0,1]).ravel()
    except Exception:
        tn=fp=fn=tp = 0
    return {"acc":acc, "prec":prec, "rec":rec, "f1":f1, "roc_auc":roc, "pr_auc":ap,
            "tn":tn, "fp":fp, "fn":fn, "tp":tp}

def _predict_dataset(model, dataset, batch_size=512):
    device = next(model.parameters()).device
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=_collate)
    ys, ps = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            for b in batch:
                logit = model.forward_single(
                    b["t"].to(device), b["dose_t"].to(device), b["dose_a"].to(device),
                    float(b["bw"]), float(b["cm"])
                )
                ys.append(float(b["y"]))
                ps.append(torch.sigmoid(logit).item())
    return np.array(ys), np.array(ps)

def find_best_threshold(y, p, grid=None, metric="f1"):
    if grid is None:
        grid = np.linspace(0.05, 0.95, 91)  # 0.05~0.95
    best_thr, best_val = 0.5, -1.0
    for thr in grid:
        m = _compute_metrics(y, p, thr=thr)
        val = m["f1"] if metric=="f1" else (0.5*m["prec"] + 0.5*m["rec"])
        if val > best_val:
            best_val, best_thr = val, thr
    return best_thr, best_val

# -------- 학습 루틴(베스트 스냅샷/ES/스케줄러/WD) --------
def train_classifier(train_ds, valid_ds,
                     epochs=60, lr=5e-2, seed=42,
                     mlp_hidden=32, mlp_layers=3, mlp_dropout=0.2,
                     weight_decay=0.0, patience=10,
                     sched_factor=0.5, sched_patience=5,
                     clip_norm=None):
    device = DEVICE
    g = torch.Generator().manual_seed(seed)
    tr_loader = DataLoader(train_ds, batch_size=128, shuffle=True, collate_fn=_collate, generator=g)
    va_loader = DataLoader(valid_ds, batch_size=256, shuffle=False, collate_fn=_collate)

    bw_mean = float(np.mean([b["bw"].item() for b in train_ds]))
    model = PDMLPClassifier(bw_mean=bw_mean, hidden=mlp_hidden, n_layers=mlp_layers, dropout_p=mlp_dropout).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.BCEWithLogitsLoss()
    # ReduceLROnPlateau: 검증 F1 기준, 상승 없으면 LR 감소
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode='max', factor=sched_factor, patience=sched_patience, verbose=True
    )

    best = (-1.0, None)
    best_state = None
    es_counter = 0  # EarlyStopping 카운터

    for ep in range(1, epochs+1):
        model.train()
        for batch in tr_loader:
            opt.zero_grad()
            loss = 0.0
            for b in batch:
                logit = model.forward_single(
                    b["t"].to(device), b["dose_t"].to(device), b["dose_a"].to(device),
                    float(b["bw"]), float(b["cm"])
                )
                loss = loss + loss_fn(logit.view(()), b["y"].to(device).view(()))
            loss = loss/len(batch)
            loss.backward()
            if clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            opt.step()

        # ---- validation ----
        y, p = _predict_dataset(model, valid_ds, batch_size=256)
        # 모니터링 지표: F1(thr=0.5). (임계값 튜닝은 학습 후 별도)
        metrics = _compute_metrics(y, p, thr=0.5)

        # 스케줄러 스텝(F1 기준)
        scheduler.step(metrics["f1"])

        # 베스트 갱신/스냅샷
        if metrics["f1"] > best[0]:
            best = (metrics["f1"], {"epoch":ep, **metrics})
            best_state = deepcopy(model.state_dict())
            es_counter = 0
        else:
            es_counter += 1

        if ep%10==0 or ep<=3:
            lr_cur = opt.param_groups[0]["lr"]
            print(f"[ep {ep:03d}] lr={lr_cur:.5f} acc={metrics['acc']:.3f} f1={metrics['f1']:.3f} "
                  f"prec={metrics['prec']:.3f} rec={metrics['rec']:.3f} "
                  f"roc_auc={metrics['roc_auc']:.3f} pr_auc={metrics['pr_auc']:.3f}")

        # EarlyStopping
        if es_counter >= patience:
            print(f"EarlyStopping at epoch {ep} (no F1 improvement for {patience} epochs).")
            break

    # 베스트로 복원
    if best_state is not None:
        model.load_state_dict(best_state)

    print("best(valid @thr=0.5):", best[1])

    # ---- 검증셋에서 임계값 최적화 (베스트 모델로) ----
    y_val, p_val = _predict_dataset(model, valid_ds, batch_size=256)
    best_thr, best_f1 = find_best_threshold(y_val, p_val, grid=None, metric="f1")
    tuned_metrics = _compute_metrics(y_val, p_val, thr=best_thr)
    print(f"tuned threshold on valid: thr={best_thr:.3f} | F1={tuned_metrics['f1']:.4f} "
          f"(prec={tuned_metrics['prec']:.4f}, rec={tuned_metrics['rec']:.4f})")

    return model, best[1], best_thr

# -------- 인구 공변량 샘플러 --------
def subject_cov_sampler(pd_df: pd.DataFrame, col_BW: Optional[str], col_COMED: Optional[str],
                        n:int, scenario:str="base", rng_seed:int=123):
    obs_bw = pd_df[col_BW].dropna().to_numpy(dtype=float) if col_BW is not None else np.full(len(pd_df), 80.0)
    obs_cm = pd_df[col_COMED].dropna().to_numpy(dtype=float) if col_COMED is not None else np.zeros(len(pd_df))
    rng = np.random.default_rng(rng_seed)
    if scenario=="base":
        idx = rng.integers(0, len(obs_bw), size=n)
        return obs_bw[idx], obs_cm[idx]
    elif scenario=="bw_wide":
        bw = rng.uniform(70.0, 140.0, size=n)   # 필요시 관찰 범위로 조정 가능
        cm = obs_cm[rng.integers(0, len(obs_cm), size=n)] if len(obs_cm)>0 else np.zeros(n)
        return bw, cm
    elif scenario=="no_comed":
        idx = rng.integers(0, len(obs_bw), size=n)
        return obs_bw[idx], np.zeros(n)
    else:
        raise ValueError("unknown scenario")

# -------- 정상상태(SS) 평가: MLP --------
@torch.no_grad()
def success_fraction_for_dose_ss_mlp(model: PDMLPClassifier, dose_mg: float, freq_h: int,
                                     last_window_h: int, pd_df, col_BW: Optional[str], col_COMED: Optional[str],
                                     Nsubj=300, scenario="base", decision_threshold=0.5,
                                     grid_step_h=1.0) -> float:
    bw_arr, cm_arr = subject_cov_sampler(pd_df, col_BW, col_COMED, Nsubj, scenario=scenario)
    tau = model.tau_from_cov_np(bw_arr, cm_arr, device=next(model.parameters()).device)
    tgrid = np.arange(0.0, last_window_h + 1e-6, grid_step_h, dtype=float)
    ok = 0
    for i in range(Nsubj):
        denom = (1.0 - np.exp(-float(freq_h) / max(tau[i],1e-6)))
        denom = max(denom, 1e-6)
        e_t = dose_mg * np.exp(-tgrid / max(tau[i],1e-6)) / denom
        bwc = (bw_arr[i] - model.bw_mean) / 10.0
        cm  = cm_arr[i]
        X = torch.tensor(np.stack([e_t, np.full_like(e_t, bwc), np.full_like(e_t, cm)], axis=1),
                         dtype=torch.float32, device=next(model.parameters()).device)
        logits = model.mlp(X).squeeze(1)
        probs = torch.sigmoid(logits).detach().cpu().numpy()
        # 윈도우 전체 시간에서 성공 확률이 임계 이상이면 성공(아주 보수적 기준)
        if probs.min() >= decision_threshold:
            ok += 1
    return ok / Nsubj

def search_min_dose_ss(model: PDMLPClassifier, grid, freq_h, last_window_h, pd_df,
                       col_BW: Optional[str], col_COMED: Optional[str],
                       Nsubj=300, scenario="base", target=0.9, decision_threshold=0.5):
    rows = []
    for d in grid:
        frac = success_fraction_for_dose_ss_mlp(model, d, freq_h, last_window_h, pd_df, col_BW, col_COMED,
                                                Nsubj=Nsubj, scenario=scenario, decision_threshold=decision_threshold)
        rows.append({"dose": d, "fraction": frac})
    df_res = pd.DataFrame(rows).sort_values("dose")
    feas = df_res[df_res["fraction"]>=target]
    best = feas.iloc[0]["dose"] if len(feas)>0 else None
    return df_res, best

# -------- 학습 실행 (개선 루틴) --------
model, valid_best, tuned_thr = train_classifier(
    train_ds, valid_ds,
    epochs=EPOCHS, lr=LR,
    mlp_hidden=MLP_HIDDEN, mlp_layers=MLP_LAYERS, mlp_dropout=MLP_DROPOUT,
    weight_decay=WEIGHT_DECAY, patience=PATIENCE,
    sched_factor=SCHED_FACTOR, sched_patience=SCHED_PATIENCE,
    clip_norm=CLIP_NORM
)
print("\nValidation best @thr=0.5:", valid_best)
print(f"Validated tuned decision threshold: {tuned_thr:.3f}")

# -------- 용량 탐색 (튜닝 임계값 적용) --------
daily_grid  = [0.5*i for i in range(0, 121)]   # 0..60 mg, 0.5 mg
weekly_grid = [5*i   for i in range(0, 41)]    # 0..200 mg, 5 mg

daily_base,  best_daily_base  = search_min_dose_ss(model, daily_grid,  24, 24,  pdf, col_BW, col_COMED,
                                                   Nsubj=N_SUBJ, scenario="base",
                                                   target=0.90, decision_threshold=tuned_thr)
weekly_base, best_weekly_base = search_min_dose_ss(model, weekly_grid, 168, 168, pdf, col_BW, col_COMED,
                                                   Nsubj=N_SUBJ, scenario="base",
                                                   target=0.90, decision_threshold=tuned_thr)

daily_bw,   best_daily_bw   = search_min_dose_ss(model, daily_grid,  24, 24,  pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="bw_wide",
                                                 target=0.90, decision_threshold=tuned_thr)
weekly_bw,  best_weekly_bw  = search_min_dose_ss(model, weekly_grid, 168, 168, pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="bw_wide",
                                                 target=0.90, decision_threshold=tuned_thr)

daily_nocm, best_daily_nocm = search_min_dose_ss(model, daily_grid,  24, 24,  pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="no_comed",
                                                 target=0.90, decision_threshold=tuned_thr)
weekly_nocm,best_weekly_nocm= search_min_dose_ss(model, weekly_grid, 168, 168, pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="no_comed",
                                                 target=0.90, decision_threshold=tuned_thr)

daily_75,   best_daily_75   = search_min_dose_ss(model, daily_grid,  24, 24,  pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="base",
                                                 target=0.75, decision_threshold=tuned_thr)
weekly_75,  best_weekly_75  = search_min_dose_ss(model, weekly_grid, 168, 168, pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="base",
                                                 target=0.75, decision_threshold=tuned_thr)

summary = pd.DataFrame([
    {"scenario":"Base (Phase 1-like)", "target":"90%", "once-daily (mg)": best_daily_base,  "once-weekly (mg)": best_weekly_base},
    {"scenario":"BW 70–140 kg",        "target":"90%", "once-daily (mg)": best_daily_bw,    "once-weekly (mg)": best_weekly_bw},
    {"scenario":"No COMED allowed",    "target":"90%", "once-daily (mg)": best_daily_nocm,  "once-weekly (mg)": best_weekly_nocm},
    {"scenario":"Base (Phase 1-like)", "target":"75%", "once-daily (mg)": best_daily_75,    "once-weekly (mg)": best_weekly_75},
])
print("\n=== Dose recommendations summary (thr tuned) ===")
print(summary.to_string(index=False))

# -------- 테스트 평가 (튜닝 임계값 적용) --------
def evaluate_dataset(model, dataset, threshold=0.5, batch_size=512):
    y, p = _predict_dataset(model, dataset, batch_size=batch_size)
    return _compute_metrics(y, p, thr=threshold)

test_metrics = evaluate_dataset(model, test_ds, threshold=tuned_thr)
print("\n=== Final TEST metrics (unseen IDs, tuned thr) ===")
for k, v in test_metrics.items():
    if isinstance(v, float):
        print(f"{k:>8s}: {v:.4f}")
    else:
        print(f"{k:>8s}: {v}")


Detected columns: {'ID': 'ID', 'TIME': 'TIME', 'DV': 'DV', 'EVID': 'EVID', 'AMT': 'AMT', 'BW': 'BW', 'COMED': 'COMED'}
PD rows: 1200 Dose rows: 756
[ID split by fixed dose groups] (70/15/15)
 train IDs: 24 | valid IDs: 6 | test IDs: 6
 1mg -> train/valid/test: 8 2 2
 3mg -> train/valid/test: 8 2 2
10mg -> train/valid/test: 8 2 2
#samples -> train: 600 | valid: 150 | test: 150




[ep 001] lr=0.05000 acc=0.787 f1=0.000 prec=0.000 rec=0.000 roc_auc=0.701 pr_auc=0.462
[ep 002] lr=0.05000 acc=0.787 f1=0.000 prec=0.000 rec=0.000 roc_auc=0.583 pr_auc=0.457
[ep 003] lr=0.05000 acc=0.787 f1=0.000 prec=0.000 rec=0.000 roc_auc=0.547 pr_auc=0.420
[ep 010] lr=0.05000 acc=0.787 f1=0.000 prec=0.000 rec=0.000 roc_auc=0.730 pr_auc=0.554
[ep 020] lr=0.05000 acc=0.913 f1=0.772 prec=0.880 rec=0.688 roc_auc=0.908 pr_auc=0.841
[ep 030] lr=0.05000 acc=0.940 f1=0.836 prec=1.000 rec=0.719 roc_auc=0.966 pr_auc=0.924
[ep 040] lr=0.05000 acc=0.953 f1=0.881 prec=0.963 rec=0.812 roc_auc=0.976 pr_auc=0.941
[ep 050] lr=0.02500 acc=0.947 f1=0.857 prec=1.000 rec=0.750 roc_auc=0.988 pr_auc=0.963
EarlyStopping at epoch 56 (no F1 improvement for 12 epochs).
best(valid @thr=0.5): {'epoch': 44, 'acc': 0.96, 'prec': 1.0, 'rec': 0.8125, 'f1': 0.896551724137931, 'roc_auc': np.float64(0.9883474576271187), 'pr_auc': np.float64(0.9648562601644852), 'tn': np.int64(118), 'fp': np.int64(0), 'fn': np.int64(6

In [21]:
# ===========================
# PD-only 분류 + 정상상태 용량탐색 (MLP) 
# └ 외부 루프: "테스트셋"을 여러 번 임의로 바꿔가며 반복 평가
# └ 내부 루프: train+valid에서 K-Fold(Group=ID) + 시드 반복
# ===========================
import os, math, warnings, numpy as np, pandas as pd
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, confusion_matrix
)
from collections import defaultdict
from typing import Optional, List, Tuple
from copy import deepcopy
import random

# -------- 경로/기본 설정 --------
CSV = "EstData.csv"
assert os.path.exists(CSV), f"CSV not found at {CSV}"

PD_THRESHOLD = 3.3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 학습/최적화
EPOCHS = 120
LR = 5e-2
WEIGHT_DECAY = 1e-4
PATIENCE = 12
SCHED_FACTOR = 0.5
SCHED_PATIENCE = 5
CLIP_NORM = 1.0

# MLP 하이퍼파라미터
MLP_HIDDEN = 36
MLP_LAYERS = 2
MLP_DROPOUT = 0.1

# 내부 CV
N_SPLITS = 5
SEEDS_INNER = [42, 3407, 777, 2021, 123]   # 내부(K-Fold×Seed) 반복

# 외부 테스트 분할 반복 (테스트셋을 바꿔가며)
TEST_SEEDS = [10, 11, 12, 13, 14]          # 원하는 만큼 추가
SAVE_CSV = True
OUT_CSV = "results_random_test_repeats.csv"

def set_global_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def _find_col_like(df, name_opts):
    low = {c.lower(): c for c in df.columns}
    for n in name_opts:
        if n in low: return low[n]
    return None

# -------- 데이터 로딩/전처리 --------
df = pd.read_csv(CSV)

col_ID   = _find_col_like(df, ["id"])
col_TIME = _find_col_like(df, ["time"])
col_DVID = _find_col_like(df, ["dvid"])
col_DV   = _find_col_like(df, ["dv","pd","value"])
col_EVID = _find_col_like(df, ["evid"])
col_AMT  = _find_col_like(df, ["amt","dose","dosen","doses"])
col_BW   = _find_col_like(df, ["bw","weight","bodyweight"])
col_COMED= _find_col_like(df, ["comed","conmed","concom"])

need = [col_ID, col_TIME, col_DV, col_AMT]
miss = [n for n,v in zip(["ID","TIME","DV","AMT/DOSE"], need) if v is None]
if miss:
    warnings.warn(f"Columns missing (minimum required): {miss}")

# PD 프레임
if col_DVID is not None and col_DV is not None:
    pdf = df[df[col_DVID]==2].copy()
else:
    pdf = df.copy()

# 투약 이벤트
if col_EVID is not None:
    dose_df = df[df[col_EVID]==1].copy()
else:
    dose_df = df[df[col_AMT].notna()].copy()

# 숫자 변환
for c in [col_TIME, col_DV, col_AMT, col_BW, col_COMED]:
    if c is not None:
        pdf[c] = pd.to_numeric(pdf[c], errors="coerce")
        dose_df[c] = pd.to_numeric(dose_df[c], errors="coerce")

# 정렬/필요 열
keep_pd = [c for c in [col_ID,col_TIME,col_DV,col_BW,col_COMED] if c is not None]
keep_dose = [c for c in [col_ID,col_TIME,col_AMT] if c is not None]
pdf = pdf[keep_pd].dropna().sort_values([col_ID, col_TIME])
dose_df = dose_df[keep_dose].dropna().sort_values([col_ID, col_TIME])

print("Detected columns:", dict(ID=col_ID, TIME=col_TIME, DV=col_DV, EVID=col_EVID, AMT=col_AMT, BW=col_BW, COMED=col_COMED))
print("PD rows:", len(pdf), "Dose rows:", len(dose_df))

# -------- Dataset --------
class PDSamples(Dataset):
    def __init__(self, pd_df: pd.DataFrame, dose_df: pd.DataFrame,
                 col_ID: str, col_TIME: str, col_DV: str,
                 col_BW: Optional[str], col_COMED: Optional[str], pd_threshold: float=3.3):
        self.col_ID, self.col_TIME, self.col_DV = col_ID, col_TIME, col_DV
        self.col_BW, self.col_COMED = col_BW, col_COMED
        self.pd_threshold = pd_threshold

        d_groups = defaultdict(list)
        for _, row in dose_df.iterrows():
            d_groups[row[col_ID]].append((float(row[col_TIME]), float(row[col_AMT])))
        self.dose_map = {k: (np.array([t for t,a in v], dtype=np.float32),
                              np.array([a for t,a in v], dtype=np.float32)) for k,v in d_groups.items()}

        feats = []
        for _, row in pd_df.iterrows():
            sid = row[col_ID]
            if sid not in self.dose_map:
                continue
            t = float(row[col_TIME])
            val = float(row[col_DV])
            y = 1.0 if (val <= pd_threshold) else 0.0
            bw = float(row[col_BW]) if col_BW is not None else 70.0
            cm = float(row[col_COMED]) if col_COMED is not None else 0.0
            feats.append((sid, t, y, bw, cm))
        self.samples = feats

    def __len__(self): return len(self.samples)
    def __getitem__(self, idx):
        sid, t, y, bw, cm = self.samples[idx]
        dose_t, dose_a = self.dose_map.get(sid, (np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32)))
        return {
            "sid": sid,
            "t": torch.tensor(t, dtype=torch.float32),
            "y": torch.tensor(y, dtype=torch.float32),
            "bw": torch.tensor(bw, dtype=torch.float32),
            "cm": torch.tensor(cm, dtype=torch.float32),
            "dose_t": torch.tensor(dose_t, dtype=torch.float32),
            "dose_a": torch.tensor(dose_a, dtype=torch.float32),
        }

def _collate(batch): return batch
dataset_all = PDSamples(pdf, dose_df, col_ID, col_TIME, col_DV, col_BW, col_COMED, pd_threshold=PD_THRESHOLD)

# -------- 투약 군 ID (플라시보 제외) --------
placebo_ids = set(range(1, 13))     # 제외
ids_1mg     = list(range(13, 25))   # 후보
ids_3mg     = list(range(25, 37))
ids_10mg    = list(range(37, 49))

ids_in_data = set(df[col_ID].unique())
ids_1mg  = sorted(ids_in_data.intersection(ids_1mg))
ids_3mg  = sorted(ids_in_data.intersection(ids_3mg))
ids_10mg = sorted(ids_in_data.intersection(ids_10mg))

def split_70_15_15(ids, seed=42):
    if len(ids) == 0:
        return [], [], []
    tr_ids, temp_ids = train_test_split(ids, test_size=0.30, random_state=seed, shuffle=True)
    va_ids, te_ids = train_test_split(temp_ids, test_size=0.50, random_state=seed, shuffle=True)
    return list(tr_ids), list(va_ids), list(te_ids)

def indices_by_ids(id_set: set) -> List[int]:
    sids = [s[0] for s in dataset_all.samples]
    return [i for i, sid in enumerate(sids) if sid in id_set]

# -------- 모델/유틸 --------
class PDMLPClassifier(nn.Module):
    def __init__(self, bw_mean: float=70.0, hidden: int=32, n_layers: int=3, dropout_p: float=0.2):
        super().__init__()
        self.b0 = nn.Parameter(torch.tensor(math.log(24.0)))
        self.b1 = nn.Parameter(torch.tensor(0.0))
        self.b2 = nn.Parameter(torch.tensor(0.0))
        self.bw_mean = float(bw_mean)

        layers = []
        in_dim = 3
        for i in range(n_layers):
            layers += [
                nn.Linear(in_dim if i==0 else hidden, hidden),
                nn.ReLU(),
                nn.Dropout(p=dropout_p) if dropout_p and dropout_p > 0 else nn.Identity(),
            ]
        layers += [nn.Linear(hidden if n_layers>0 else in_dim, 1)]
        self.mlp = nn.Sequential(*layers)

    def _tau(self, bw: torch.Tensor, comed: torch.Tensor):
        bwc = (bw - self.bw_mean) / 10.0
        log_tau = self.b0 + self.b1*bwc + self.b2*comed
        return torch.exp(log_tau).clamp_min(1.0)

    @torch.no_grad()
    def tau_from_cov_np(self, bw_np: np.ndarray, cm_np: np.ndarray, device=None):
        device = device or next(self.parameters()).device
        bw = torch.tensor(bw_np, dtype=torch.float32, device=device)
        cm = torch.tensor(cm_np, dtype=torch.float32, device=device)
        return self._tau(bw, cm).detach().cpu().numpy()

    def forward_single(self, t: torch.Tensor, dose_t: torch.Tensor, dose_a: torch.Tensor,
                       bw: float, comed: float):
        tau = self._tau(torch.tensor(bw, dtype=torch.float32, device=t.device),
                        torch.tensor(comed, dtype=torch.float32, device=t.device))
        dt = t - dose_t
        mask = (dt >= 0).float()
        exposure = (dose_a * torch.exp(-dt.clamp_min(0) / tau) * mask).sum()
        x = torch.stack([
            exposure,
            (torch.tensor(bw, dtype=torch.float32, device=t.device) - self.bw_mean) / 10.0,
            torch.tensor(comed, dtype=torch.float32, device=t.device)
        ])
        return self.mlp(x).squeeze()

def _compute_metrics(y_true: np.ndarray, prob: np.ndarray, thr: float = 0.5):
    pred = (prob >= thr).astype(int)
    acc = accuracy_score(y_true, pred) if len(y_true) else float("nan")
    prec = precision_score(y_true, pred, zero_division=0)
    rec  = recall_score(y_true, pred, zero_division=0)
    f1   = f1_score(y_true, pred, zero_division=0)
    try:
        roc = roc_auc_score(y_true, prob) if (len(np.unique(y_true))>1) else float("nan")
    except Exception:
        roc = float("nan")
    try:
        ap  = average_precision_score(y_true, prob) if (len(np.unique(y_true))>1) else float("nan")
    except Exception:
        ap  = float("nan")
    try:
        tn, fp, fn, tp = confusion_matrix(y_true, pred, labels=[0,1]).ravel()
    except Exception:
        tn=fp=fn=tp = 0
    return {"acc":acc, "prec":prec, "rec":rec, "f1":f1, "roc_auc":roc, "pr_auc":ap,
            "tn":tn, "fp":fp, "fn":fn, "tp":tp}

def _predict_dataset(model, dataset, batch_size=512):
    device = next(model.parameters()).device
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=_collate)
    ys, ps = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            for b in batch:
                logit = model.forward_single(
                    b["t"].to(device), b["dose_t"].to(device), b["dose_a"].to(device),
                    float(b["bw"]), float(b["cm"])
                )
                ys.append(float(b["y"]))
                ps.append(torch.sigmoid(logit).item())
    return np.array(ys), np.array(ps)

def find_best_threshold(y, p, grid=None, metric="f1"):
    if grid is None:
        grid = np.linspace(0.05, 0.95, 91)
    best_thr, best_val = 0.5, -1.0
    for thr in grid:
        m = _compute_metrics(y, p, thr=thr)
        val = m["f1"] if metric=="f1" else (0.5*m["prec"] + 0.5*m["rec"])
        if val > best_val:
            best_val, best_thr = val, thr
    return best_thr, best_val

def train_classifier(train_ds, valid_ds,
                     epochs=60, lr=5e-2, seed=42,
                     mlp_hidden=32, mlp_layers=3, mlp_dropout=0.2,
                     weight_decay=0.0, patience=10,
                     sched_factor=0.5, sched_patience=5,
                     clip_norm=None):
    device = DEVICE
    set_global_seed(seed)
    g = torch.Generator().manual_seed(seed)
    tr_loader = DataLoader(train_ds, batch_size=128, shuffle=True, collate_fn=_collate, generator=g)
    va_loader = DataLoader(valid_ds, batch_size=256, shuffle=False, collate_fn=_collate)

    bw_mean = float(np.mean([b["bw"].item() for b in train_ds]))
    model = PDMLPClassifier(bw_mean=bw_mean, hidden=mlp_hidden, n_layers=mlp_layers, dropout_p=mlp_dropout).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.BCEWithLogitsLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode='max', factor=sched_factor, patience=sched_patience, verbose=False
    )

    best = (-1.0, None)
    best_state = None
    es_counter = 0

    for ep in range(1, epochs+1):
        model.train()
        for batch in tr_loader:
            opt.zero_grad()
            loss = 0.0
            for b in batch:
                logit = model.forward_single(
                    b["t"].to(device), b["dose_t"].to(device), b["dose_a"].to(device),
                    float(b["bw"]), float(b["cm"])
                )
                loss = loss + loss_fn(logit.view(()), b["y"].to(device).view(()))
            loss = loss/len(batch)
            loss.backward()
            if clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            opt.step()

        # validation (모니터링만; 임계값 튜닝은 아래서 별도)
        y, p = _predict_dataset(model, valid_ds, batch_size=256)
        metrics = _compute_metrics(y, p, thr=0.5)
        scheduler.step(metrics["f1"])

        if metrics["f1"] > best[0]:
            best = (metrics["f1"], {"epoch":ep, **metrics})
            best_state = deepcopy(model.state_dict())
            es_counter = 0
        else:
            es_counter += 1

        if es_counter >= patience:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    # 임계값 튜닝 (검증 확률 기반)
    y_val, p_val = _predict_dataset(model, valid_ds, batch_size=256)
    best_thr, _ = find_best_threshold(y_val, p_val, grid=None, metric="f1")
    tuned_metrics = _compute_metrics(y_val, p_val, thr=best_thr)

    return model, best[1], best_thr, tuned_metrics

# -------- CV 폴드 도우미 (용량군 균형 유지) --------
def make_balanced_group_folds(ids_group: List[int], n_splits:int, seed:int) -> List[List[int]]:
    rng = np.random.default_rng(seed)
    ids = np.array(ids_group, dtype=int)
    rng.shuffle(ids)
    folds = np.array_split(ids, n_splits)
    return [list(f) for f in folds]

def make_cv_folds(ids_1: List[int], ids_3: List[int], ids_10: List[int],
                  candidate_ids: List[int], n_splits:int, seed:int) -> List[Tuple[List[int], List[int]]]:
    folds_1  = make_balanced_group_folds([i for i in ids_1  if i in candidate_ids], n_splits, seed)
    folds_3  = make_balanced_group_folds([i for i in ids_3  if i in candidate_ids], n_splits, seed+1)
    folds_10 = make_balanced_group_folds([i for i in ids_10 if i in candidate_ids], n_splits, seed+2)
    folds = []
    for k in range(n_splits):
        val_ids = set(folds_1[k]) | set(folds_3[k]) | set(folds_10[k])
        all_ids = set(candidate_ids)
        tr_ids  = sorted(list(all_ids - val_ids))
        folds.append((tr_ids, sorted(list(val_ids))))
    return folds

def subset_by_ids(ids: List[int]) -> Subset:
    sids = [s[0] for s in dataset_all.samples]
    idx = [i for i, sid in enumerate(sids) if sid in set(ids)]
    return Subset(dataset_all, idx)

def evaluate_dataset(model, dataset, threshold=0.5, batch_size=512):
    y, p = _predict_dataset(model, dataset, batch_size=batch_size)
    return _compute_metrics(y, p, thr=threshold), y, p

# -------- 정상상태 시뮬레이션(옵션) --------
def subject_cov_sampler(pd_df: pd.DataFrame, col_BW: Optional[str], col_COMED: Optional[str],
                        n:int, scenario:str="base", rng_seed:int=123):
    obs_bw = pd_df[col_BW].dropna().to_numpy(dtype=float) if col_BW is not None else np.full(len(pd_df), 80.0)
    obs_cm = pd_df[col_COMED].dropna().to_numpy(dtype=float) if col_COMED is not None else np.zeros(len(pd_df))
    rng = np.random.default_rng(rng_seed)
    if scenario=="base":
        idx = rng.integers(0, len(obs_bw), size=n)
        return obs_bw[idx], obs_cm[idx]
    elif scenario=="bw_wide":
        bw = rng.uniform(70.0, 140.0, size=n)
        cm = obs_cm[rng.integers(0, len(obs_cm), size=n)] if len(obs_cm)>0 else np.zeros(n)
        return bw, cm
    elif scenario=="no_comed":
        idx = rng.integers(0, len(obs_bw), size=n)
        return obs_bw[idx], np.zeros(n)
    else:
        raise ValueError("unknown scenario")

@torch.no_grad()
def success_fraction_for_dose_ss_mlp(model: PDMLPClassifier, dose_mg: float, freq_h: int,
                                     last_window_h: int, pd_df, col_BW: Optional[str], col_COMED: Optional[str],
                                     Nsubj=300, scenario="base", decision_threshold=0.5,
                                     grid_step_h=1.0) -> float:
    bw_arr, cm_arr = subject_cov_sampler(pd_df, col_BW, col_COMED, Nsubj, scenario=scenario)
    tau = model.tau_from_cov_np(bw_arr, cm_arr, device=next(model.parameters()).device)
    tgrid = np.arange(0.0, last_window_h + 1e-6, grid_step_h, dtype=float)
    ok = 0
    for i in range(Nsubj):
        denom = (1.0 - np.exp(-float(freq_h) / max(tau[i],1e-6)))
        denom = max(denom, 1e-6)
        e_t = dose_mg * np.exp(-tgrid / max(tau[i],1e-6)) / denom
        bwc = (bw_arr[i] - model.bw_mean) / 10.0
        cm  = cm_arr[i]
        X = torch.tensor(np.stack([e_t, np.full_like(e_t, bwc), np.full_like(e_t, cm)], axis=1),
                         dtype=torch.float32, device=next(model.parameters()).device)
        logits = model.mlp(X).squeeze(1)
        probs = torch.sigmoid(logits).detach().cpu().numpy()
        if probs.min() >= decision_threshold:
            ok += 1
    return ok / Nsubj

# ==========================================================
# 메인: 외부 테스트 분할 반복 (TEST_SEEDS) × 내부 CV(SEEDS_INNER)
# ==========================================================
all_rows = []

for test_seed in TEST_SEEDS:
    # 1) 용량군별 70/15/15로 "새로운" train/valid/test 분할 생성
    tr_1, va_1, te_1     = split_70_15_15(ids_1mg,  seed=test_seed)
    tr_3, va_3, te_3     = split_70_15_15(ids_3mg,  seed=test_seed)
    tr_10, va_10, te_10  = split_70_15_15(ids_10mg, seed=test_seed)

    ids_tr = set(tr_1 + tr_3 + tr_10)
    ids_va = set(va_1 + va_3 + va_10)
    ids_te = set(te_1 + te_3 + te_10)

    tr_idx = indices_by_ids(ids_tr)
    va_idx = indices_by_ids(ids_va)
    te_idx = indices_by_ids(ids_te)

    train_ds_full = Subset(dataset_all, tr_idx)
    valid_ds_full = Subset(dataset_all, va_idx)
    test_ds       = Subset(dataset_all, te_idx)

    print(f"\n=== [TEST SEED={test_seed}] New random split (70/15/15 per dose group) ===")
    print(" train IDs:", len(ids_tr), "| valid IDs:", len(ids_va), "| test IDs:", len(ids_te))
    print(f"#samples -> train: {len(train_ds_full)} | valid: {len(valid_ds_full)} | test: {len(test_ds)}")

    # 2) 내부 CV: train+valid 묶음으로 K-Fold(Group) 구성 (용량군 균형)
    ids_trva = sorted(list(ids_tr | ids_va))
    cv_folds = make_cv_folds(ids_1mg, ids_3mg, ids_10mg, ids_trva, N_SPLITS, seed=42)

    # 3) 내부: 시드 반복 × KFold → 임계값 중앙값 산출
    tuned_thresholds = []
    cv_rows = []
    for seed in SEEDS_INNER:
        for k, (tr_ids_k, va_ids_k) in enumerate(cv_folds, start=1):
            tr_ds_k = subset_by_ids(tr_ids_k)
            va_ds_k = subset_by_ids(va_ids_k)
            model_k, best_raw, tuned_thr_k, tuned_metrics_k = train_classifier(
                tr_ds_k, va_ds_k,
                epochs=EPOCHS, lr=LR, seed=seed,
                mlp_hidden=MLP_HIDDEN, mlp_layers=MLP_LAYERS, mlp_dropout=MLP_DROPOUT,
                weight_decay=WEIGHT_DECAY, patience=PATIENCE,
                sched_factor=SCHED_FACTOR, sched_patience=SCHED_PATIENCE,
                clip_norm=CLIP_NORM
            )
            tuned_thresholds.append(tuned_thr_k)
            cv_rows.append({
                "test_seed": test_seed,
                "inner_seed": seed,
                "fold": k,
                "val_acc": tuned_metrics_k["acc"],
                "val_prec": tuned_metrics_k["prec"],
                "val_rec": tuned_metrics_k["rec"],
                "val_f1": tuned_metrics_k["f1"],
                "val_roc_auc": tuned_metrics_k["roc_auc"],
                "val_pr_auc": tuned_metrics_k["pr_auc"],
                "thr": tuned_thr_k
            })

    thr_cv_final = float(np.median(tuned_thresholds)) if len(tuned_thresholds)>0 else 0.5
    print(f"[TEST SEED={test_seed}] CV tuned threshold (median over inner seed×fold): {thr_cv_final:.3f}")

    # 4) train+valid 전체로 재학습 → 해당 test로 평가 (thr = thr_cv_final)
    trainval_ds = Subset(dataset_all, tr_idx + va_idx)
    final_model, _, _, _ = train_classifier(
        trainval_ds, valid_ds_full,  # valid_ds_full은 모니터링용
        epochs=EPOCHS, lr=LR, seed=SEEDS_INNER[0],
        mlp_hidden=MLP_HIDDEN, mlp_layers=MLP_LAYERS, mlp_dropout=MLP_DROPOUT,
        weight_decay=WEIGHT_DECAY, patience=PATIENCE,
        sched_factor=SCHED_FACTOR, sched_patience=SCHED_PATIENCE,
        clip_norm=CLIP_NORM
)
    test_metrics, y_te, p_te = evaluate_dataset(final_model, test_ds, threshold=thr_cv_final)

    row = {"test_seed": test_seed, "thr_cv_final": thr_cv_final}
    row.update({f"test_{k}": v for k, v in test_metrics.items()})
    all_rows.append(row)

    # (선택) 내부 CV 통계 출력
    cv_df = pd.DataFrame(cv_rows)
    if not cv_df.empty:
        print("\n[Inner CV summary] (mean over inner seed×fold)")
        print(cv_df[["val_f1","val_pr_auc","val_roc_auc"]].mean().rename({
            "val_f1":"F1(mean)","val_pr_auc":"PR-AUC(mean)","val_roc_auc":"ROC-AUC(mean)"
        }))

# -------- 반복 결과 요약 --------
res_df = pd.DataFrame(all_rows)
print("\n=== Repeated RANDOM TEST splits summary ===")
if not res_df.empty:
    cols_print = [c for c in res_df.columns if c.startswith("test_")]  # test_지표만
    print(res_df[["test_seed","thr_cv_final"] + cols_print].to_string(index=False))
    print("\n[Aggregate over test splits] mean ± std")
    agg = res_df[cols_print].agg(['mean','std']).T
    print(agg.to_string())

if SAVE_CSV and not res_df.empty:
    res_df.to_csv(OUT_CSV, index=False)
    print(f"\nSaved: {OUT_CSV}")


Detected columns: {'ID': 'ID', 'TIME': 'TIME', 'DV': 'DV', 'EVID': 'EVID', 'AMT': 'AMT', 'BW': 'BW', 'COMED': 'COMED'}
PD rows: 1200 Dose rows: 756

=== [TEST SEED=10] New random split (70/15/15 per dose group) ===
 train IDs: 24 | valid IDs: 6 | test IDs: 6
#samples -> train: 600 | valid: 150 | test: 150




[TEST SEED=10] CV tuned threshold (median over inner seed×fold): 0.500





[Inner CV summary] (mean over inner seed×fold)
F1(mean)         0.819886
PR-AUC(mean)     0.841951
ROC-AUC(mean)    0.921435
dtype: float64

=== [TEST SEED=11] New random split (70/15/15 per dose group) ===
 train IDs: 24 | valid IDs: 6 | test IDs: 6
#samples -> train: 600 | valid: 150 | test: 150




[TEST SEED=11] CV tuned threshold (median over inner seed×fold): 0.490





[Inner CV summary] (mean over inner seed×fold)
F1(mean)         0.860310
PR-AUC(mean)     0.886408
ROC-AUC(mean)    0.957458
dtype: float64

=== [TEST SEED=12] New random split (70/15/15 per dose group) ===
 train IDs: 24 | valid IDs: 6 | test IDs: 6
#samples -> train: 600 | valid: 150 | test: 150




[TEST SEED=12] CV tuned threshold (median over inner seed×fold): 0.500





[Inner CV summary] (mean over inner seed×fold)
F1(mean)         0.779051
PR-AUC(mean)     0.809282
ROC-AUC(mean)    0.894194
dtype: float64

=== [TEST SEED=13] New random split (70/15/15 per dose group) ===
 train IDs: 24 | valid IDs: 6 | test IDs: 6
#samples -> train: 600 | valid: 150 | test: 150




[TEST SEED=13] CV tuned threshold (median over inner seed×fold): 0.490





[Inner CV summary] (mean over inner seed×fold)
F1(mean)         0.836191
PR-AUC(mean)     0.859612
ROC-AUC(mean)    0.926580
dtype: float64

=== [TEST SEED=14] New random split (70/15/15 per dose group) ===
 train IDs: 24 | valid IDs: 6 | test IDs: 6
#samples -> train: 600 | valid: 150 | test: 150




[TEST SEED=14] CV tuned threshold (median over inner seed×fold): 0.480





[Inner CV summary] (mean over inner seed×fold)
F1(mean)         0.795207
PR-AUC(mean)     0.814495
ROC-AUC(mean)    0.903829
dtype: float64

=== Repeated RANDOM TEST splits summary ===
 test_seed  thr_cv_final  test_seed  test_acc  test_prec  test_rec  test_f1  test_roc_auc  test_pr_auc  test_tn  test_fp  test_fn  test_tp
        10          0.50         10  0.866667   0.558140  0.960000 0.705882      0.967680     0.903632      106       19        1       24
        11          0.49         11  0.946667   0.818182  1.000000 0.900000      0.998538     0.996032      106        8        0       36
        12          0.50         12  0.980000   1.000000  0.911765 0.953846      0.998225     0.994332      116        0        3       31
        13          0.49         13  0.906667   0.733333  0.785714 0.758621      0.958138     0.850582      114        8        6       22
        14          0.48         14  0.953333   0.931034  0.843750 0.885246      0.991261     0.971836      116        

In [9]:
# residual 추가버전
import os, math, warnings, numpy as np, pandas as pd
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, confusion_matrix
)
from collections import defaultdict
from typing import Optional
from copy import deepcopy
import random

# -------- 설정 --------
CSV = "EstData.csv"       # 필요 시 절대경로로 교체
PD_THRESHOLD = 3.3

# 학습/최적화 설정
EPOCHS = 120
LR = 5e-2
WEIGHT_DECAY = 1e-4
PATIENCE = 12               # EarlyStopping 인내 에폭
SCHED_FACTOR = 0.5          # ReduceLROnPlateau 감쇠 비율
SCHED_PATIENCE = 5          # 스케줄러 인내 에폭(지표 기준)
CLIP_NORM = 1.0             # 그라디언트 클리핑 (안쓰려면 None)

# 시뮬레이션/검색 설정
N_SUBJ = 300
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ===== 아래 두 줄을 설정 블록(MLP 하이퍼파라미터 근처)에 추가 =====
USE_RESIDUAL = True      # True: 잔차 버전 / False: 기존(순차형) 버전
RES_BLOCKS   = 2         # 잔차 버전일 때의 블록 수(= “층” 느낌). 2면 네 원래 2층 감각
INPUT_LN     = True      # 잔차 버전일 때 입력 LayerNorm(3) 사용 여부


# MLP 하이퍼파라미터
MLP_HIDDEN = 36
MLP_LAYERS = 2
MLP_DROPOUT = 0.2

# 재현성(완전 결정적 보장은 환경 의존. CUDA에서 완전 결정적 필요시 CUBLAS_WORKSPACE_CONFIG 설정 필요)
def set_global_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # 완전 결정적이 필요하면 환경변수 설정 필요(CuBLAS). 여기서는 사용하지 않음.

set_global_seed(42)
assert os.path.exists(CSV), f"CSV not found at {CSV}"

# -------- 유틸: 열 이름 추론 --------
def _find_col_like(df, name_opts):
    low = {c.lower(): c for c in df.columns}
    for n in name_opts:
        if n in low: return low[n]
    return None

# -------- 데이터 로딩/전처리 --------
df = pd.read_csv(CSV)

col_ID   = _find_col_like(df, ["id"])
col_TIME = _find_col_like(df, ["time"])
col_DVID = _find_col_like(df, ["dvid"])
col_DV   = _find_col_like(df, ["dv","pd","value"])
col_EVID = _find_col_like(df, ["evid"])
col_AMT  = _find_col_like(df, ["amt","dose","dosen","doses"])
col_BW   = _find_col_like(df, ["bw","weight","bodyweight"])
col_COMED= _find_col_like(df, ["comed","conmed","concom"])

need = [col_ID, col_TIME, col_DV, col_AMT]
miss = [n for n,v in zip(["ID","TIME","DV","AMT/DOSE"], need) if v is None]
if miss:
    warnings.warn(f"Columns missing (minimum required): {miss}")

# PD 행 추출 (DVID==2가 있으면 그걸 사용)
if col_DVID is not None and col_DV is not None:
    pdf = df[df[col_DVID]==2].copy()
else:
    pdf = df.copy()

# 투약 이벤트 (EVID==1 우선, 아니면 AMT/DOSE notna)
if col_EVID is not None:
    dose_df = df[df[col_EVID]==1].copy()
else:
    dose_df = df[df[col_AMT].notna()].copy()

# 숫자 변환
for c in [col_TIME, col_DV, col_AMT, col_BW, col_COMED]:
    if c is not None:
        pdf[c] = pd.to_numeric(pdf[c], errors="coerce")
        dose_df[c] = pd.to_numeric(dose_df[c], errors="coerce")

# 정렬/필요 열만
keep_pd = [c for c in [col_ID,col_TIME,col_DV,col_BW,col_COMED] if c is not None]
keep_dose = [c for c in [col_ID,col_TIME,col_AMT] if c is not None]
pdf = pdf[keep_pd].dropna().sort_values([col_ID, col_TIME])
dose_df = dose_df[keep_dose].dropna().sort_values([col_ID, col_TIME])

print("Detected columns:", dict(ID=col_ID, TIME=col_TIME, DV=col_DV, EVID=col_EVID, AMT=col_AMT, BW=col_BW, COMED=col_COMED))
print("PD rows:", len(pdf), "Dose rows:", len(dose_df))

# -------- 데이터셋 --------
class PDSamples(Dataset):
    def __init__(self, pd_df: pd.DataFrame, dose_df: pd.DataFrame,
                 col_ID: str, col_TIME: str, col_DV: str,
                 col_BW: Optional[str], col_COMED: Optional[str], pd_threshold: float=3.3):
        self.col_ID, self.col_TIME, self.col_DV = col_ID, col_TIME, col_DV
        self.col_BW, self.col_COMED = col_BW, col_COMED
        self.pd_threshold = pd_threshold

        # ID별 투약 히스토리
        d_groups = defaultdict(list)
        for _, row in dose_df.iterrows():
            d_groups[row[col_ID]].append((float(row[col_TIME]), float(row[col_AMT])))
        self.dose_map = {k: (np.array([t for t,a in v], dtype=np.float32),
                              np.array([a for t,a in v], dtype=np.float32)) for k,v in d_groups.items()}

        # 샘플: 각 PD 관측 시점
        feats = []
        for _, row in pd_df.iterrows():
            sid = row[col_ID]
            if sid not in self.dose_map:
                continue
            t = float(row[col_TIME])
            val = float(row[col_DV])
            y = 1.0 if (val <= pd_threshold) else 0.0
            bw = float(row[col_BW]) if col_BW is not None else 70.0
            cm = float(row[col_COMED]) if col_COMED is not None else 0.0
            feats.append((sid, t, y, bw, cm))
        self.samples = feats

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        sid, t, y, bw, cm = self.samples[idx]
        dose_t, dose_a = self.dose_map.get(sid, (np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32)))
        return {
            "t": torch.tensor(t, dtype=torch.float32),
            "y": torch.tensor(y, dtype=torch.float32),
            "bw": torch.tensor(bw, dtype=torch.float32),
            "cm": torch.tensor(cm, dtype=torch.float32),
            "dose_t": torch.tensor(dose_t, dtype=torch.float32),
            "dose_a": torch.tensor(dose_a, dtype=torch.float32),
        }

def _collate(batch): return batch

dataset_all = PDSamples(pdf, dose_df, col_ID, col_TIME, col_DV, col_BW, col_COMED, pd_threshold=PD_THRESHOLD)

# -------- ID 단위 분할: 플라시보 제외 + 1/3/10 mg 그룹별 70/15/15 --------
rng = 42
placebo_ids = set(range(1, 13))     # 0 mg (제외)
ids_1mg     = list(range(13, 25))   # 1 mg
ids_3mg     = list(range(25, 37))   # 3 mg
ids_10mg    = list(range(37, 49))   # 10 mg

ids_in_data = set(df[col_ID].unique())
ids_1mg  = sorted(ids_in_data.intersection(ids_1mg))
ids_3mg  = sorted(ids_in_data.intersection(ids_3mg))
ids_10mg = sorted(ids_in_data.intersection(ids_10mg))

def split_70_15_15(ids, seed=42):
    if len(ids) == 0:
        return [], [], []
    tr_ids, temp_ids = train_test_split(ids, test_size=0.30, random_state=seed, shuffle=True)
    va_ids, te_ids = train_test_split(temp_ids, test_size=0.50, random_state=seed, shuffle=True)
    return list(tr_ids), list(va_ids), list(te_ids)

tr_1, va_1, te_1     = split_70_15_15(ids_1mg,  seed=rng)
tr_3, va_3, te_3     = split_70_15_15(ids_3mg,  seed=rng)
tr_10, va_10, te_10  = split_70_15_15(ids_10mg, seed=rng)

ids_tr = set(tr_1 + tr_3 + tr_10)
ids_va = set(va_1 + va_3 + va_10)
ids_te = set(te_1 + te_3 + te_10)

sid_list = [s[0] for s in dataset_all.samples]
tr_idx = [i for i, sid in enumerate(sid_list) if sid in ids_tr]
va_idx = [i for i, sid in enumerate(sid_list) if sid in ids_va]
te_idx = [i for i, sid in enumerate(sid_list) if sid in ids_te]

train_ds = Subset(dataset_all, tr_idx)
valid_ds = Subset(dataset_all, va_idx)
test_ds  = Subset(dataset_all, te_idx)

print("[ID split by fixed dose groups] (70/15/15)")
print(" train IDs:", len(ids_tr), "| valid IDs:", len(ids_va), "| test IDs:", len(ids_te))
print(" 1mg -> train/valid/test:", len(tr_1), len(va_1), len(te_1))
print(" 3mg -> train/valid/test:", len(tr_3), len(va_3), len(te_3))
print("10mg -> train/valid/test:", len(tr_10), len(va_10), len(te_10))
print(f"#samples -> train: {len(train_ds)} | valid: {len(valid_ds)} | test: {len(test_ds)}")

class ResBlock(nn.Module):
    def __init__(self, h, p=0.2):
        super().__init__()
        self.ln  = nn.LayerNorm(h)
        self.fc  = nn.Linear(h, h)
        self.act = nn.ReLU()
        self.do  = nn.Dropout(p) if p and p>0 else nn.Identity()
    def forward(self, x):
        r = x
        x = self.ln(x)
        x = self.fc(x)
        x = self.act(x)
        x = self.do(x)
        return x + r

class PDMLPClassifierResidual(nn.Module):
    """
    입력 x=[E(t), (BW-mean)/10, COMED]
    구조: in_ln(선택) -> Linear(3->hidden) -> (ResBlock x n_blocks) -> Linear(hidden->1)
    τ 링크는 기존 그대로: log τ = b0 + b1*bwc + b2*COMED
    """
    def __init__(self, bw_mean: float=70.0, hidden: int=32, n_blocks: int=2,
                 dropout_p: float=0.2, use_input_ln: bool=True):
        super().__init__()
        self.b0 = nn.Parameter(torch.tensor(math.log(24.0)))
        self.b1 = nn.Parameter(torch.tensor(0.0))
        self.b2 = nn.Parameter(torch.tensor(0.0))
        self.bw_mean = float(bw_mean)

        self.in_ln  = nn.LayerNorm(3) if use_input_ln else nn.Identity()
        self.fc_in  = nn.Linear(3, hidden)
        self.blocks = nn.ModuleList([ResBlock(hidden, p=dropout_p) for _ in range(n_blocks)])
        self.head   = nn.Linear(hidden, 1)

    def _tau(self, bw: torch.Tensor, comed: torch.Tensor):
        bwc = (bw - self.bw_mean) / 10.0
        log_tau = self.b0 + self.b1*bwc + self.b2*comed
        return torch.exp(log_tau).clamp_min(1.0)

    @torch.no_grad()
    def tau_from_cov_np(self, bw_np: np.ndarray, cm_np: np.ndarray, device=None):
        device = device or next(self.parameters()).device
        bw = torch.tensor(bw_np, dtype=torch.float32, device=device)
        cm = torch.tensor(cm_np, dtype=torch.float32, device=device)
        return self._tau(bw, cm).detach().cpu().numpy()

    def forward_single(self, t: torch.Tensor, dose_t: torch.Tensor, dose_a: torch.Tensor,
                       bw: float, comed: float):
        # 노출 계산(기존과 동일)
        tau = self._tau(torch.tensor(bw, dtype=torch.float32, device=t.device),
                        torch.tensor(comed, dtype=torch.float32, device=t.device))
        dt = t - dose_t
        mask = (dt >= 0).float()
        exposure = (dose_a * torch.exp(-dt.clamp_min(0) / tau) * mask).sum()

        # 피처 순서도 기존과 동일: [E(t), bwc, COMED]
        x = torch.stack([
            exposure,
            (torch.tensor(bw, dtype=torch.float32, device=t.device) - self.bw_mean) / 10.0,
            torch.tensor(comed, dtype=torch.float32, device=t.device)
        ])
        # 잔차 경로
        h = self.in_ln(x)
        h = self.fc_in(h)
        for blk in self.blocks:
            h = blk(h)
        logit = self.head(h).squeeze()
        return logit


# -------- MLP 모델 --------
class PDMLPClassifier(nn.Module):
    """
    입력 x=[E(t), (BW-mean)/10, COMED] -> [Linear+ReLU(+Dropout)]*n_layers -> Linear(→1)
    노출 링크: log(tau) = b0 + b1*((BW-mean)/10) + b2*COMED
    """
    def __init__(self, bw_mean: float=70.0, hidden: int=32, n_layers: int=3, dropout_p: float=0.2):
        super().__init__()
        self.b0 = nn.Parameter(torch.tensor(math.log(24.0)))
        self.b1 = nn.Parameter(torch.tensor(0.0))
        self.b2 = nn.Parameter(torch.tensor(0.0))
        self.bw_mean = float(bw_mean)

        layers = []
        in_dim = 3
        for i in range(n_layers):
            layers += [
                nn.Linear(in_dim if i==0 else hidden, hidden),
                nn.ReLU(),
                nn.Dropout(p=dropout_p) if dropout_p and dropout_p > 0 else nn.Identity(),
            ]
        layers += [nn.Linear(hidden if n_layers>0 else in_dim, 1)]
        self.mlp = nn.Sequential(*layers)

    def _tau(self, bw: torch.Tensor, comed: torch.Tensor):
        bwc = (bw - self.bw_mean) / 10.0
        log_tau = self.b0 + self.b1*bwc + self.b2*comed
        return torch.exp(log_tau).clamp_min(1.0)

    @torch.no_grad()
    def tau_from_cov_np(self, bw_np: np.ndarray, cm_np: np.ndarray, device=None):
        device = device or next(self.parameters()).device
        bw = torch.tensor(bw_np, dtype=torch.float32, device=device)
        cm = torch.tensor(cm_np, dtype=torch.float32, device=device)
        return self._tau(bw, cm).detach().cpu().numpy()

    def forward_single(self, t: torch.Tensor, dose_t: torch.Tensor, dose_a: torch.Tensor,
                       bw: float, comed: float):
        tau = self._tau(torch.tensor(bw, dtype=torch.float32, device=t.device),
                        torch.tensor(comed, dtype=torch.float32, device=t.device))
        dt = t - dose_t
        mask = (dt >= 0).float()
        exposure = (dose_a * torch.exp(-dt.clamp_min(0) / tau) * mask).sum()
        x = torch.stack([
            exposure,
            (torch.tensor(bw, dtype=torch.float32, device=t.device) - self.bw_mean) / 10.0,
            torch.tensor(comed, dtype=torch.float32, device=t.device)
        ])
        return self.mlp(x).squeeze()

def _collate(batch): return batch

# -------- 지표 계산/도움 함수 --------
def _compute_metrics(y_true: np.ndarray, prob: np.ndarray, thr: float = 0.5):
    pred = (prob >= thr).astype(int)
    acc = accuracy_score(y_true, pred) if len(y_true) else float("nan")
    prec = precision_score(y_true, pred, zero_division=0)
    rec  = recall_score(y_true, pred, zero_division=0)
    f1   = f1_score(y_true, pred, zero_division=0)
    try:
        roc = roc_auc_score(y_true, prob) if (len(np.unique(y_true))>1) else float("nan")
    except Exception:
        roc = float("nan")
    try:
        ap  = average_precision_score(y_true, prob) if (len(np.unique(y_true))>1) else float("nan")
    except Exception:
        ap  = float("nan")
    try:
        tn, fp, fn, tp = confusion_matrix(y_true, pred, labels=[0,1]).ravel()
    except Exception:
        tn=fp=fn=tp = 0
    return {"acc":acc, "prec":prec, "rec":rec, "f1":f1, "roc_auc":roc, "pr_auc":ap,
            "tn":tn, "fp":fp, "fn":fn, "tp":tp}

def _predict_dataset(model, dataset, batch_size=512):
    device = next(model.parameters()).device
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=_collate)
    ys, ps = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            for b in batch:
                logit = model.forward_single(
                    b["t"].to(device), b["dose_t"].to(device), b["dose_a"].to(device),
                    float(b["bw"]), float(b["cm"])
                )
                ys.append(float(b["y"]))
                ps.append(torch.sigmoid(logit).item())
    return np.array(ys), np.array(ps)

def find_best_threshold(y, p, grid=None, metric="f1"):
    if grid is None:
        grid = np.linspace(0.05, 0.95, 91)  # 0.05~0.95
    best_thr, best_val = 0.5, -1.0
    for thr in grid:
        m = _compute_metrics(y, p, thr=thr)
        val = m["f1"] if metric=="f1" else (0.5*m["prec"] + 0.5*m["rec"])
        if val > best_val:
            best_val, best_thr = val, thr
    return best_thr, best_val

# -------- 학습 루틴(베스트 스냅샷/ES/스케줄러/WD) --------
def train_classifier(train_ds, valid_ds,
                     epochs=60, lr=5e-2, seed=42,
                     mlp_hidden=32, mlp_layers=3, mlp_dropout=0.2,
                     weight_decay=0.0, patience=10,
                     sched_factor=0.5, sched_patience=5,
                     clip_norm=None):
    device = DEVICE
    g = torch.Generator().manual_seed(seed)
    tr_loader = DataLoader(train_ds, batch_size=128, shuffle=True, collate_fn=_collate, generator=g)
    va_loader = DataLoader(valid_ds, batch_size=256, shuffle=False, collate_fn=_collate)

    bw_mean = float(np.mean([b["bw"].item() for b in train_ds]))

    
    if USE_RESIDUAL:
            model = PDMLPClassifierResidual(
                bw_mean=bw_mean,
                hidden=mlp_hidden,
                n_blocks=RES_BLOCKS,        # 잔차 “층 수”
                dropout_p=mlp_dropout,
                use_input_ln=INPUT_LN
            ).to(device)
    else:
            model = PDMLPClassifier(
                bw_mean=bw_mean,
                hidden=mlp_hidden,
                n_layers=mlp_layers,        # 기존 순차형 “층 수”
                dropout_p=mlp_dropout
            ).to(device)

    
    model = PDMLPClassifier(bw_mean=bw_mean, hidden=mlp_hidden, n_layers=mlp_layers, dropout_p=mlp_dropout).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.BCEWithLogitsLoss()
    # ReduceLROnPlateau: 검증 F1 기준, 상승 없으면 LR 감소
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode='max', factor=sched_factor, patience=sched_patience, verbose=True
    )

    best = (-1.0, None)
    best_state = None
    es_counter = 0  # EarlyStopping 카운터

    for ep in range(1, epochs+1):
        model.train()
        for batch in tr_loader:
            opt.zero_grad()
            loss = 0.0
            for b in batch:
                logit = model.forward_single(
                    b["t"].to(device), b["dose_t"].to(device), b["dose_a"].to(device),
                    float(b["bw"]), float(b["cm"])
                )
                loss = loss + loss_fn(logit.view(()), b["y"].to(device).view(()))
            loss = loss/len(batch)
            loss.backward()
            if clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            opt.step()

        # ---- validation ----
        y, p = _predict_dataset(model, valid_ds, batch_size=256)
        # 모니터링 지표: F1(thr=0.5). (임계값 튜닝은 학습 후 별도)
        metrics = _compute_metrics(y, p, thr=0.5)

        # 스케줄러 스텝(F1 기준)
        scheduler.step(metrics["f1"])

        # 베스트 갱신/스냅샷
        if metrics["f1"] > best[0]:
            best = (metrics["f1"], {"epoch":ep, **metrics})
            best_state = deepcopy(model.state_dict())
            es_counter = 0
        else:
            es_counter += 1

        if ep%10==0 or ep<=3:
            lr_cur = opt.param_groups[0]["lr"]
            print(f"[ep {ep:03d}] lr={lr_cur:.5f} acc={metrics['acc']:.3f} f1={metrics['f1']:.3f} "
                  f"prec={metrics['prec']:.3f} rec={metrics['rec']:.3f} "
                  f"roc_auc={metrics['roc_auc']:.3f} pr_auc={metrics['pr_auc']:.3f}")

        # EarlyStopping
        if es_counter >= patience:
            print(f"EarlyStopping at epoch {ep} (no F1 improvement for {patience} epochs).")
            break

    # 베스트로 복원
    if best_state is not None:
        model.load_state_dict(best_state)

    print("best(valid @thr=0.5):", best[1])

    # ---- 검증셋에서 임계값 최적화 (베스트 모델로) ----
    y_val, p_val = _predict_dataset(model, valid_ds, batch_size=256)
    best_thr, best_f1 = find_best_threshold(y_val, p_val, grid=None, metric="f1")
    tuned_metrics = _compute_metrics(y_val, p_val, thr=best_thr)
    print(f"tuned threshold on valid: thr={best_thr:.3f} | F1={tuned_metrics['f1']:.4f} "
          f"(prec={tuned_metrics['prec']:.4f}, rec={tuned_metrics['rec']:.4f})")

    return model, best[1], best_thr

# -------- 인구 공변량 샘플러 --------
def subject_cov_sampler(pd_df: pd.DataFrame, col_BW: Optional[str], col_COMED: Optional[str],
                        n:int, scenario:str="base", rng_seed:int=123):
    obs_bw = pd_df[col_BW].dropna().to_numpy(dtype=float) if col_BW is not None else np.full(len(pd_df), 80.0)
    obs_cm = pd_df[col_COMED].dropna().to_numpy(dtype=float) if col_COMED is not None else np.zeros(len(pd_df))
    rng = np.random.default_rng(rng_seed)
    if scenario=="base":
        idx = rng.integers(0, len(obs_bw), size=n)
        return obs_bw[idx], obs_cm[idx]
    elif scenario=="bw_wide":
        bw = rng.uniform(70.0, 140.0, size=n)   # 필요시 관찰 범위로 조정 가능
        cm = obs_cm[rng.integers(0, len(obs_cm), size=n)] if len(obs_cm)>0 else np.zeros(n)
        return bw, cm
    elif scenario=="no_comed":
        idx = rng.integers(0, len(obs_bw), size=n)
        return obs_bw[idx], np.zeros(n)
    else:
        raise ValueError("unknown scenario")

# -------- 정상상태(SS) 평가: MLP --------
@torch.no_grad()
def success_fraction_for_dose_ss_mlp(model: PDMLPClassifier, dose_mg: float, freq_h: int,
                                     last_window_h: int, pd_df, col_BW: Optional[str], col_COMED: Optional[str],
                                     Nsubj=300, scenario="base", decision_threshold=0.5,
                                     grid_step_h=1.0) -> float:
    bw_arr, cm_arr = subject_cov_sampler(pd_df, col_BW, col_COMED, Nsubj, scenario=scenario)
    tau = model.tau_from_cov_np(bw_arr, cm_arr, device=next(model.parameters()).device)
    tgrid = np.arange(0.0, last_window_h + 1e-6, grid_step_h, dtype=float)
    ok = 0
    for i in range(Nsubj):
        denom = (1.0 - np.exp(-float(freq_h) / max(tau[i],1e-6)))
        denom = max(denom, 1e-6)
        e_t = dose_mg * np.exp(-tgrid / max(tau[i],1e-6)) / denom
        bwc = (bw_arr[i] - model.bw_mean) / 10.0
        cm  = cm_arr[i]
        X = torch.tensor(np.stack([e_t, np.full_like(e_t, bwc), np.full_like(e_t, cm)], axis=1),
                         dtype=torch.float32, device=next(model.parameters()).device)
        logits = model.mlp(X).squeeze(1)
        probs = torch.sigmoid(logits).detach().cpu().numpy()
        # 윈도우 전체 시간에서 성공 확률이 임계 이상이면 성공(아주 보수적 기준)
        if probs.min() >= decision_threshold:
            ok += 1
    return ok / Nsubj

def search_min_dose_ss(model: PDMLPClassifier, grid, freq_h, last_window_h, pd_df,
                       col_BW: Optional[str], col_COMED: Optional[str],
                       Nsubj=300, scenario="base", target=0.9, decision_threshold=0.5):
    rows = []
    for d in grid:
        frac = success_fraction_for_dose_ss_mlp(model, d, freq_h, last_window_h, pd_df, col_BW, col_COMED,
                                                Nsubj=Nsubj, scenario=scenario, decision_threshold=decision_threshold)
        rows.append({"dose": d, "fraction": frac})
    df_res = pd.DataFrame(rows).sort_values("dose")
    feas = df_res[df_res["fraction"]>=target]
    best = feas.iloc[0]["dose"] if len(feas)>0 else None
    return df_res, best

# -------- 학습 실행 (개선 루틴) --------
model, valid_best, tuned_thr = train_classifier(
    train_ds, valid_ds,
    epochs=EPOCHS, lr=LR,
    mlp_hidden=MLP_HIDDEN, mlp_layers=MLP_LAYERS, mlp_dropout=MLP_DROPOUT,
    weight_decay=WEIGHT_DECAY, patience=PATIENCE,
    sched_factor=SCHED_FACTOR, sched_patience=SCHED_PATIENCE,
    clip_norm=CLIP_NORM
)
print("\nValidation best @thr=0.5:", valid_best)
print(f"Validated tuned decision threshold: {tuned_thr:.3f}")

# -------- 용량 탐색 (튜닝 임계값 적용) --------
daily_grid  = [0.5*i for i in range(0, 121)]   # 0..60 mg, 0.5 mg
weekly_grid = [5*i   for i in range(0, 41)]    # 0..200 mg, 5 mg

daily_base,  best_daily_base  = search_min_dose_ss(model, daily_grid,  24, 24,  pdf, col_BW, col_COMED,
                                                   Nsubj=N_SUBJ, scenario="base",
                                                   target=0.90, decision_threshold=tuned_thr)
weekly_base, best_weekly_base = search_min_dose_ss(model, weekly_grid, 168, 168, pdf, col_BW, col_COMED,
                                                   Nsubj=N_SUBJ, scenario="base",
                                                   target=0.90, decision_threshold=tuned_thr)

daily_bw,   best_daily_bw   = search_min_dose_ss(model, daily_grid,  24, 24,  pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="bw_wide",
                                                 target=0.90, decision_threshold=tuned_thr)
weekly_bw,  best_weekly_bw  = search_min_dose_ss(model, weekly_grid, 168, 168, pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="bw_wide",
                                                 target=0.90, decision_threshold=tuned_thr)

daily_nocm, best_daily_nocm = search_min_dose_ss(model, daily_grid,  24, 24,  pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="no_comed",
                                                 target=0.90, decision_threshold=tuned_thr)
weekly_nocm,best_weekly_nocm= search_min_dose_ss(model, weekly_grid, 168, 168, pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="no_comed",
                                                 target=0.90, decision_threshold=tuned_thr)

daily_75,   best_daily_75   = search_min_dose_ss(model, daily_grid,  24, 24,  pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="base",
                                                 target=0.75, decision_threshold=tuned_thr)
weekly_75,  best_weekly_75  = search_min_dose_ss(model, weekly_grid, 168, 168, pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="base",
                                                 target=0.75, decision_threshold=tuned_thr)

summary = pd.DataFrame([
    {"scenario":"Base (Phase 1-like)", "target":"90%", "once-daily (mg)": best_daily_base,  "once-weekly (mg)": best_weekly_base},
    {"scenario":"BW 70–140 kg",        "target":"90%", "once-daily (mg)": best_daily_bw,    "once-weekly (mg)": best_weekly_bw},
    {"scenario":"No COMED allowed",    "target":"90%", "once-daily (mg)": best_daily_nocm,  "once-weekly (mg)": best_weekly_nocm},
    {"scenario":"Base (Phase 1-like)", "target":"75%", "once-daily (mg)": best_daily_75,    "once-weekly (mg)": best_weekly_75},
])
print("\n=== Dose recommendations summary (thr tuned) ===")
print(summary.to_string(index=False))

# -------- 테스트 평가 (튜닝 임계값 적용) --------
def evaluate_dataset(model, dataset, threshold=0.5, batch_size=512):
    y, p = _predict_dataset(model, dataset, batch_size=batch_size)
    return _compute_metrics(y, p, thr=threshold)

test_metrics = evaluate_dataset(model, test_ds, threshold=tuned_thr)
print("\n=== Final TEST metrics (unseen IDs, tuned thr) ===")
for k, v in test_metrics.items():
    if isinstance(v, float):
        print(f"{k:>8s}: {v:.4f}")
    else:
        print(f"{k:>8s}: {v}")


Detected columns: {'ID': 'ID', 'TIME': 'TIME', 'DV': 'DV', 'EVID': 'EVID', 'AMT': 'AMT', 'BW': 'BW', 'COMED': 'COMED'}
PD rows: 1200 Dose rows: 756
[ID split by fixed dose groups] (70/15/15)
 train IDs: 24 | valid IDs: 6 | test IDs: 6
 1mg -> train/valid/test: 8 2 2
 3mg -> train/valid/test: 8 2 2
10mg -> train/valid/test: 8 2 2
#samples -> train: 600 | valid: 150 | test: 150




[ep 001] lr=0.05000 acc=0.760 f1=0.471 prec=0.444 rec=0.500 roc_auc=0.750 pr_auc=0.527
[ep 002] lr=0.05000 acc=0.787 f1=0.000 prec=0.000 rec=0.000 roc_auc=0.731 pr_auc=0.472
[ep 003] lr=0.05000 acc=0.787 f1=0.000 prec=0.000 rec=0.000 roc_auc=0.697 pr_auc=0.484
[ep 010] lr=0.02500 acc=0.853 f1=0.607 prec=0.708 rec=0.531 roc_auc=0.718 pr_auc=0.620
[ep 020] lr=0.02500 acc=0.880 f1=0.690 prec=0.769 rec=0.625 roc_auc=0.852 pr_auc=0.737
[ep 030] lr=0.02500 acc=0.913 f1=0.764 prec=0.913 rec=0.656 roc_auc=0.907 pr_auc=0.833
[ep 040] lr=0.01250 acc=0.920 f1=0.793 prec=0.885 rec=0.719 roc_auc=0.962 pr_auc=0.914
[ep 050] lr=0.00625 acc=0.933 f1=0.828 prec=0.923 rec=0.750 roc_auc=0.979 pr_auc=0.941
EarlyStopping at epoch 59 (no F1 improvement for 12 epochs).
best(valid @thr=0.5): {'epoch': 47, 'acc': 0.9333333333333333, 'prec': 0.8928571428571429, 'rec': 0.78125, 'f1': 0.8333333333333334, 'roc_auc': np.float64(0.982521186440678), 'pr_auc': np.float64(0.947851131556032), 'tn': np.int64(115), 'fp': 

In [None]:
전체 데이터셋 분포같은거 체크해서 test set 만들때 train set이랑 다른 분포 많이 섞어서 test하면 좋을듯~