In [6]:
# ===========================
# PD-only 분류 + 정상상태 용량탐색 (MLP 전용, 지표 확장, CSV 저장 없음)
# - 데이터: EstData.csv (PD 단위: ng/mL), 임계 3.3 ng/mL
# - 모델: MLP (은닉층 개수/히든/드롭아웃 하이퍼파라미터화)
# - 분할: 플라시보(0 mg, ID 1–12) 제외, 1/3/10 mg 고정 그룹별 70/15/15 ID 단위 분할
# ===========================
from torch.utils.data import DataLoader
import os, math, warnings, numpy as np, pandas as pd
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, confusion_matrix
)
from typing import Optional
from collections import defaultdict
from copy import deepcopy

# -------- Settings --------
CSV = "EstData.csv"       # use "/mnt/data/EstData.csv" if needed
PD_THRESHOLD = 3.3        # PD value at or below this is labelled 1 (see PDSamples)
EPOCHS = 60
LR = 5e-2
N_SUBJ = 300              # virtual subjects per dose level in the steady-state search
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# MLP hyperparameters
MLP_HIDDEN = 36
MLP_LAYERS = 2
MLP_DROPOUT = 0.2

# Raise a real exception rather than `assert`, which is stripped under `python -O`.
if not os.path.exists(CSV):
    raise FileNotFoundError(f"CSV not found at {CSV}")

# -------- Utility: column-name inference --------
def _find_col_like(df, name_opts):
    """Return the first column of *df* whose name matches one of *name_opts*.

    Matching ignores case and surrounding whitespace on both the column names
    and the options, so ``["id"]`` and ``["ID"]`` both find a column called
    ``"ID"`` or ``" Id "``. Options are tried in order and the original column
    label is returned; ``None`` means nothing matched. If several columns
    normalise to the same key, the last one wins.
    """
    normalized = {str(c).strip().lower(): c for c in df.columns}
    for opt in name_opts:
        key = str(opt).strip().lower()
        if key in normalized:
            return normalized[key]
    return None

# -------- Data loading / preprocessing --------
df = pd.read_csv(CSV)

# Resolve column names case-insensitively; the first alias that matches wins.
col_ID   = _find_col_like(df, ["id"])
col_TIME = _find_col_like(df, ["time"])
col_DVID = _find_col_like(df, ["dvid"])
col_DV   = _find_col_like(df, ["dv","pd","value"])
col_EVID = _find_col_like(df, ["evid"])
col_AMT  = _find_col_like(df, ["amt","dose","dosen","doses"])
col_BW   = _find_col_like(df, ["bw","weight","bodyweight"])
col_COMED= _find_col_like(df, ["comed","conmed","concom"])

need = [col_ID, col_TIME, col_DV, col_AMT]
miss = [n for n,v in zip(["ID","TIME","DV","AMT/DOSE"], need) if v is None]
if miss:
    # Later steps index these columns directly (sorting, dose extraction,
    # PDSamples). A warning here would only defer the failure to a less
    # clear error further down, so stop now.
    raise ValueError(f"Columns missing (minimum required): {miss}")

# PD observation rows: DVID == 2 when a DVID column exists, otherwise all rows.
if col_DVID is not None and col_DV is not None:
    pdf = df[df[col_DVID]==2].copy()
else:
    pdf = df.copy()

# Dosing events: EVID == 1 when available, otherwise rows with a non-missing amount.
if col_EVID is not None:
    dose_df = df[df[col_EVID]==1].copy()
else:
    dose_df = df[df[col_AMT].notna()].copy()

# Coerce to numeric; unparsable values become NaN and are dropped below.
for c in [col_TIME, col_DV, col_AMT, col_BW, col_COMED]:
    if c is not None:
        pdf[c] = pd.to_numeric(pdf[c], errors="coerce")
        dose_df[c] = pd.to_numeric(dose_df[c], errors="coerce")

# Keep only the needed columns, drop incomplete rows, sort by subject then time.
keep_pd = [c for c in [col_ID,col_TIME,col_DV,col_BW,col_COMED] if c is not None]
keep_dose = [c for c in [col_ID,col_TIME,col_AMT] if c is not None]
pdf = pdf[keep_pd].dropna().sort_values([col_ID, col_TIME])
dose_df = dose_df[keep_dose].dropna().sort_values([col_ID, col_TIME])

print("Detected columns:", dict(ID=col_ID, TIME=col_TIME, DV=col_DV, EVID=col_EVID, AMT=col_AMT, BW=col_BW, COMED=col_COMED))
print("PD rows:", len(pdf), "Dose rows:", len(dose_df))

# -------- Dataset --------
class PDSamples(Dataset):
    """One sample per PD observation of a subject that has dosing records.

    Each entry of ``self.samples`` is ``(sid, t, y, bw, cm)``:
    - ``sid``: subject id.
    - ``t``: observation time.
    - ``y``: label, ``1.0`` when the PD value is ``<= pd_threshold``, else ``0.0``.
    - ``bw``: body weight, 70.0 when ``col_BW`` is None.
    - ``cm``: COMED, 0.0 when ``col_COMED`` is None.

    Observations of subjects with no row in ``dose_df`` are skipped.
    ``__getitem__`` returns these values as float32 tensors, plus the
    subject's dose times ``dose_t`` and amounts ``dose_a``.

    Args:
        col_AMT: dose-amount column in ``dose_df``. When omitted it is inferred
            as the only column of ``dose_df`` other than ``col_ID``/``col_TIME``,
            so the class no longer depends on a module-level ``col_AMT``.

    Raises:
        ValueError: if ``col_AMT`` is omitted and cannot be inferred from a
            non-empty ``dose_df``.
    """

    def __init__(self, pd_df: pd.DataFrame, dose_df: pd.DataFrame,
                 col_ID: str, col_TIME: str, col_DV: str,
                 col_BW: Optional[str], col_COMED: Optional[str], pd_threshold: float=3.3,
                 col_AMT: Optional[str] = None):
        self.col_ID, self.col_TIME, self.col_DV = col_ID, col_TIME, col_DV
        self.col_BW, self.col_COMED = col_BW, col_COMED
        self.pd_threshold = pd_threshold
        self.col_AMT = self._resolve_amt_col(dose_df, col_ID, col_TIME, col_AMT)

        # Dosing history per subject, in dose_df row order: (times, amounts) as float32 arrays.
        d_groups = defaultdict(list)
        for _, row in dose_df.iterrows():
            d_groups[row[col_ID]].append((float(row[col_TIME]), float(row[self.col_AMT])))
        self.dose_map = {k: (np.array([t for t,a in v], dtype=np.float32),
                              np.array([a for t,a in v], dtype=np.float32)) for k,v in d_groups.items()}

        # One sample per PD observation time
        feats = []
        for _, row in pd_df.iterrows():
            sid = row[col_ID]
            if sid not in self.dose_map:
                continue
            t = float(row[col_TIME])
            val = float(row[col_DV])
            y = 1.0 if (val <= pd_threshold) else 0.0
            bw = float(row[col_BW]) if col_BW is not None else 70.0
            cm = float(row[col_COMED]) if col_COMED is not None else 0.0
            feats.append((sid, t, y, bw, cm))
        self.samples = feats

    @staticmethod
    def _resolve_amt_col(dose_df, col_ID, col_TIME, col_AMT):
        """Return the dose-amount column name, inferring it when not given."""
        if col_AMT is not None:
            return col_AMT
        candidates = [c for c in dose_df.columns if c not in (col_ID, col_TIME)]
        if len(candidates) == 1:
            return candidates[0]
        if len(dose_df) == 0:
            return None  # no dose rows will be read, so no amount column is needed
        raise ValueError(
            "col_AMT was not given and cannot be inferred from dose_df columns "
            f"{list(dose_df.columns)}; pass col_AMT explicitly")

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        sid, t, y, bw, cm = self.samples[idx]
        dose_t, dose_a = self.dose_map.get(sid, (np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32)))
        return {
            "t": torch.tensor(t, dtype=torch.float32),
            "y": torch.tensor(y, dtype=torch.float32),
            "bw": torch.tensor(bw, dtype=torch.float32),
            "cm": torch.tensor(cm, dtype=torch.float32),
            "dose_t": torch.tensor(dose_t, dtype=torch.float32),
            "dose_a": torch.tensor(dose_a, dtype=torch.float32),
        }

def _collate(batch):
    """Return *batch* (the list of sample dicts) unchanged.

    Used as the DataLoader ``collate_fn`` so samples, whose ``dose_t``/``dose_a``
    lengths can differ between subjects, are not stacked; callers iterate
    over the list instead.
    """
    return batch

# Build one sample per PD observation (PDSamples skips subjects without dosing rows).
dataset_all = PDSamples(pdf, dose_df, col_ID, col_TIME, col_DV, col_BW, col_COMED, pd_threshold=PD_THRESHOLD)

# -------- Subject-level (ID) split: placebo excluded, 70/15/15 within each fixed dose group (1/3/10 mg) --------
rng = 42                            # integer seed passed to split_70_15_15 for every group
placebo_ids = set(range(1, 13))     # 0 mg (excluded; not referenced again in this cell)
ids_1mg     = list(range(13, 25))   # 1 mg
ids_3mg     = list(range(25, 37))   # 3 mg
ids_10mg    = list(range(37, 49))   # 10 mg
# NOTE(review): the ID -> dose-group mapping above is hard-coded; confirm it matches the dataset.

# Keep only IDs present in the raw table; sorting makes the split input order deterministic.
ids_in_data = set(df[col_ID].unique())
ids_1mg  = sorted(ids_in_data.intersection(ids_1mg))
ids_3mg  = sorted(ids_in_data.intersection(ids_3mg))
ids_10mg = sorted(ids_in_data.intersection(ids_10mg))

def split_70_15_15(ids, seed=42):
    """Split *ids* into (train, valid, test) lists of roughly 70/15/15 %.

    A first shuffled split holds out 30 %, which is then halved into
    validation and test. Both splits use ``seed`` as ``random_state``.
    Empty input yields three empty lists.
    """
    if not len(ids):
        return [], [], []
    train_part, holdout = train_test_split(ids, test_size=0.30, random_state=seed, shuffle=True)
    valid_part, test_part = train_test_split(holdout, test_size=0.50, random_state=seed, shuffle=True)
    return list(train_part), list(valid_part), list(test_part)

# Split each dose group separately (same seed for all) so every split contains all three dose levels.
tr_1, va_1, te_1   = split_70_15_15(ids_1mg,  seed=rng)
tr_3, va_3, te_3   = split_70_15_15(ids_3mg,  seed=rng)
tr_10, va_10, te_10 = split_70_15_15(ids_10mg, seed=rng)

ids_tr = set(tr_1 + tr_3 + tr_10)
ids_va = set(va_1 + va_3 + va_10)
ids_te = set(te_1 + te_3 + te_10)

# Map subject IDs back to sample indices; an ID with no dosing rows has no samples and
# therefore contributes to the ID counts below but not to the sample counts.
sid_list = [s[0] for s in dataset_all.samples]  # (sid, t, y, bw, cm)
tr_idx = [i for i, sid in enumerate(sid_list) if sid in ids_tr]
va_idx = [i for i, sid in enumerate(sid_list) if sid in ids_va]
te_idx = [i for i, sid in enumerate(sid_list) if sid in ids_te]

# Index views over dataset_all; all observations of a subject land in the same split.
train_ds = Subset(dataset_all, tr_idx)
valid_ds = Subset(dataset_all, va_idx)
test_ds  = Subset(dataset_all, te_idx)

print("[ID split by fixed dose groups] (70/15/15)")
print(" train IDs:", len(ids_tr), "| valid IDs:", len(ids_va), "| test IDs:", len(ids_te))
print(" 1mg -> train/valid/test:", len(tr_1), len(va_1), len(te_1))
print(" 3mg -> train/valid/test:", len(tr_3), len(va_3), len(te_3))
print("10mg -> train/valid/test:", len(tr_10), len(va_10), len(te_10))
print(f"#samples -> train: {len(train_ds)} | valid: {len(valid_ds)} | test: {len(test_ds)}")

# -------- MLP model --------
class PDMLPClassifier(nn.Module):
    """Binary classifier on a covariate-adjusted exposure feature.

    Exposure at time ``t`` is the sum of exponentially decaying past doses,
    ``E(t) = sum_k a_k * exp(-(t - t_k) / tau)`` over doses with ``t_k <= t``.
    The time constant is ``log(tau) = b0 + b1 * (BW - bw_mean) / 10 + b2 * COMED``,
    with tau clamped to at least 1.

    The MLP input is ``[E(t), (BW - bw_mean) / 10, COMED]``, passed through
    ``n_layers`` blocks of Linear+ReLU(+Dropout) and then a Linear layer to a
    single logit.

    Args:
        bw_mean: body-weight centring constant (plain attribute, not learned).
        hidden: width of each hidden layer.
        n_layers: number of hidden blocks; 0 gives a single Linear(3, 1).
        dropout_p: dropout probability; a falsy or non-positive value uses
            ``nn.Identity`` instead of ``nn.Dropout``.
    """

    def __init__(self, bw_mean: float=70.0, hidden: int=32, n_layers: int=3, dropout_p: float=0.2):
        super().__init__()
        # tau link parameters; b0 starts at log(24), i.e. tau = 24 at the reference covariates
        self.b0 = nn.Parameter(torch.tensor(math.log(24.0)))
        self.b1 = nn.Parameter(torch.tensor(0.0))
        self.b2 = nn.Parameter(torch.tensor(0.0))
        self.bw_mean = float(bw_mean)

        use_dropout = bool(dropout_p and dropout_p > 0)
        blocks = []
        width_in = 3
        for _ in range(n_layers):
            blocks.append(nn.Linear(width_in, hidden))
            blocks.append(nn.ReLU())
            blocks.append(nn.Dropout(p=dropout_p) if use_dropout else nn.Identity())
            width_in = hidden
        blocks.append(nn.Linear(width_in, 1))
        self.mlp = nn.Sequential(*blocks)

    def _tau(self, bw: torch.Tensor, comed: torch.Tensor):
        """Time constant from centred BW and COMED, floored at 1."""
        centred_bw = (bw - self.bw_mean) / 10.0
        return torch.exp(self.b0 + self.b1*centred_bw + self.b2*comed).clamp_min(1.0)

    @torch.no_grad()
    def tau_from_cov_np(self, bw_np: np.ndarray, cm_np: np.ndarray, device=None):
        """NumPy wrapper around ``_tau`` (no autograd); returns a CPU array."""
        target = device or next(self.parameters()).device
        bw_t = torch.tensor(bw_np, dtype=torch.float32, device=target)
        cm_t = torch.tensor(cm_np, dtype=torch.float32, device=target)
        return self._tau(bw_t, cm_t).detach().cpu().numpy()

    def forward_single(self, t: torch.Tensor, dose_t: torch.Tensor, dose_a: torch.Tensor,
                       bw: float, comed: float):
        """Logit for one observation at time ``t`` given that subject's dose history."""
        dev = t.device
        bw_t = torch.tensor(bw, dtype=torch.float32, device=dev)
        cm_t = torch.tensor(comed, dtype=torch.float32, device=dev)
        tau = self._tau(bw_t, cm_t)
        elapsed = t - dose_t
        # Only doses given at or before t contribute; later doses are masked out.
        given = (elapsed >= 0).float()
        exposure = (dose_a * torch.exp(-elapsed.clamp_min(0) / tau) * given).sum()
        features = torch.stack([exposure, (bw_t - self.bw_mean) / 10.0, cm_t])
        return self.mlp(features).squeeze()

def _collate(batch):
    """Return *batch* (the list of sample dicts) unchanged.

    This redefines the identical identity collate function declared earlier
    in this cell; the list is iterated sample by sample by the callers.
    """
    return batch

# -------- Metric computation --------
def _compute_metrics(y_true: np.ndarray, prob: np.ndarray, thr: float = 0.5):
    """Binary-classification metrics for predicted probabilities.

    A sample is predicted positive when ``prob >= thr``. Array-likes (e.g.
    lists) are accepted.

    Returns a dict with:
        acc: accuracy (NaN for empty input).
        prec, rec, f1: precision, recall and F1, with ``zero_division=0``
            (so 0.0 for empty input).
        roc_auc, pr_auc: ROC-AUC and average precision (NaN unless
            ``y_true`` has more than one class).
        tn, fp, fn, tp: confusion counts for labels 0 and 1, as plain ints.
    """
    y_true = np.asarray(y_true)
    prob = np.asarray(prob, dtype=float)
    pred = (prob >= thr).astype(int)

    # Confusion counts for labels {0, 1}; other label values are ignored,
    # as with confusion_matrix(..., labels=[0, 1]).
    tn = int(np.sum((y_true == 0) & (pred == 0)))
    fp = int(np.sum((y_true == 0) & (pred == 1)))
    fn = int(np.sum((y_true == 1) & (pred == 0)))
    tp = int(np.sum((y_true == 1) & (pred == 1)))

    nan = float("nan")
    if y_true.size == 0:
        # Nothing to score: report "undefined" without relying on sklearn's empty-input behaviour.
        return {"acc":nan, "prec":0.0, "rec":0.0, "f1":0.0, "roc_auc":nan, "pr_auc":nan,
                "tn":tn, "fp":fp, "fn":fn, "tp":tp}

    acc = accuracy_score(y_true, pred)
    prec = precision_score(y_true, pred, zero_division=0)
    rec  = recall_score(y_true, pred, zero_division=0)
    f1   = f1_score(y_true, pred, zero_division=0)

    roc = ap = nan
    # Ranking metrics are undefined with a single class; any sklearn rejection leaves them NaN.
    if len(np.unique(y_true)) > 1:
        try:
            roc = roc_auc_score(y_true, prob)
        except ValueError:
            roc = nan
        try:
            ap = average_precision_score(y_true, prob)
        except ValueError:
            ap = nan
    return {"acc":acc, "prec":prec, "rec":rec, "f1":f1, "roc_auc":roc, "pr_auc":ap,
            "tn":tn, "fp":fp, "fn":fn, "tp":tp}

# -------- Training routine (MLP only; best-validation weights restored) --------
def train_classifier(train_ds, valid_ds, pd_threshold=3.3, epochs=60, lr=5e-2, seed=42,
                     mlp_hidden=32, mlp_layers=3, mlp_dropout=0.2):
    """Train a ``PDMLPClassifier`` and return it with its best-validation weights.

    Each epoch makes one Adam pass over ``train_ds`` and then scores
    ``valid_ds`` at a 0.5 probability threshold. Training uses batches of 128,
    shuffled by a generator seeded with ``seed``; the loss is per-sample
    BCE-with-logits averaged over the batch. The epoch with the highest
    validation F1 is kept (the earliest wins ties). Its weights are loaded
    back into the model before returning, so the returned model is the one
    the reported metrics describe.

    Args:
        train_ds, valid_ds: datasets yielding PDSamples-style dicts.
        pd_threshold: unused; kept for call compatibility (labels are already
            computed by the dataset).
        epochs, lr, seed: number of epochs, Adam learning rate, shuffle seed.
        mlp_hidden, mlp_layers, mlp_dropout: forwarded to PDMLPClassifier.

    Returns:
        ``(model, best)`` where ``best`` is ``{"epoch": ..., **metrics}`` for the
        selected epoch.

    Raises:
        ValueError: if ``train_ds`` is empty (the BW centring mean would be NaN).
    """
    if len(train_ds) == 0:
        raise ValueError("train_ds is empty; cannot compute bw_mean or train")

    device = DEVICE
    g = torch.Generator().manual_seed(seed)
    tr_loader = DataLoader(train_ds, batch_size=128, shuffle=True, collate_fn=_collate, generator=g)
    va_loader = DataLoader(valid_ds, batch_size=256, shuffle=False, collate_fn=_collate)

    # Body-weight centring constant, taken from the training split only.
    bw_mean = float(np.mean([b["bw"].item() for b in train_ds]))
    model = PDMLPClassifier(bw_mean=bw_mean, hidden=mlp_hidden, n_layers=mlp_layers, dropout_p=mlp_dropout).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()
    best = (-1.0, None)
    best_state = None

    for ep in range(1, epochs+1):
        model.train()
        for batch in tr_loader:
            opt.zero_grad()
            loss = 0.0
            # Dose histories may differ in length between subjects, so samples are scored one by one.
            for b in batch:
                logit = model.forward_single(b["t"].to(device), b["dose_t"].to(device), b["dose_a"].to(device),
                                             float(b["bw"]), float(b["cm"]))
                loss = loss + loss_fn(logit.view(()), b["y"].to(device).view(()))
            (loss/len(batch)).backward()
            opt.step()

        # validation
        model.eval(); ys=[]; ps=[]
        with torch.no_grad():
            for batch in va_loader:
                for b in batch:
                    logit = model.forward_single(b["t"].to(device), b["dose_t"].to(device), b["dose_a"].to(device),
                                                 float(b["bw"]), float(b["cm"]))
                    p = torch.sigmoid(logit).item()
                    ys.append(float(b["y"])); ps.append(p)
        y = np.array(ys); p = np.array(ps)
        metrics = _compute_metrics(y, p, thr=0.5)

        if metrics["f1"] > best[0]:
            best = (metrics["f1"], {"epoch":ep, **metrics})
            # Deep copy: later optimizer steps would otherwise mutate the snapshot.
            best_state = deepcopy(model.state_dict())

        if ep%10==0 or ep<=3:
            print(f"[ep {ep:03d}] acc={metrics['acc']:.3f} f1={metrics['f1']:.3f} "
                  f"prec={metrics['prec']:.3f} rec={metrics['rec']:.3f} "
                  f"roc_auc={metrics['roc_auc']:.3f} pr_auc={metrics['pr_auc']:.3f}")

    # Restore the best-validation weights so the returned model matches `best`.
    if best_state is not None:
        model.load_state_dict(best_state)

    print("best(valid):", best[1])
    return model, best[1]


# -------- Population covariate sampler --------
def subject_cov_sampler(pd_df: pd.DataFrame, col_BW: Optional[str], col_COMED: Optional[str],
                        n:int, scenario:str="base", rng_seed:int=123):
    """Draw ``n`` virtual subjects' (BW, COMED) covariates.

    Scenarios:
      * ``"base"``: resample observed (BW, COMED) rows jointly, with replacement.
      * ``"bw_wide"``: BW ~ Uniform(70, 140); COMED resampled from observed
        values (zeros if none).
      * ``"no_comed"``: resample observed BW; COMED fixed at 0.

    A missing ``col_BW`` means BW = 80.0 for every row and a missing
    ``col_COMED`` means COMED = 0. Rows with a NaN in either used column are
    dropped together, so resampled BW/COMED pairs stay aligned.

    Returns:
        ``(bw, cm)``: two float arrays of length ``n``.

    Raises:
        ValueError: for an unknown scenario, or when observed rows must be
            resampled but none are complete.
    """
    # NOTE(review): PDSamples falls back to BW = 70.0 when the BW column is missing,
    # while this sampler uses 80.0 - confirm which default is intended.
    used = [c for c in (col_BW, col_COMED) if c is not None]
    cov = pd_df[used].dropna() if used else pd_df
    n_rows = len(cov)
    obs_bw = cov[col_BW].to_numpy(dtype=float) if col_BW is not None else np.full(n_rows, 80.0)
    obs_cm = cov[col_COMED].to_numpy(dtype=float) if col_COMED is not None else np.zeros(n_rows)
    rng = np.random.default_rng(rng_seed)

    def _require_rows():
        # Shared guard for the scenarios that resample observed rows.
        if n_rows == 0:
            raise ValueError("no complete covariate rows available to resample")

    if scenario=="base":
        _require_rows()
        idx = rng.integers(0, n_rows, size=n)
        return obs_bw[idx], obs_cm[idx]
    elif scenario=="bw_wide":
        # Extrapolation: BW drawn uniformly from 70-140 (draw order kept: BW first, then COMED).
        bw = rng.uniform(70.0, 140.0, size=n)
        cm = obs_cm[rng.integers(0, len(obs_cm), size=n)] if len(obs_cm)>0 else np.zeros(n)
        return bw, cm
    elif scenario=="no_comed":
        _require_rows()
        idx = rng.integers(0, n_rows, size=n)
        return obs_bw[idx], np.zeros(n)
    else:
        raise ValueError("unknown scenario")

# -------- Steady-state (SS) evaluation: MLP --------
@torch.no_grad()
def success_fraction_for_dose_ss_mlp(model: PDMLPClassifier, dose_mg: float, freq_h: int,
                                     last_window_h: int, pd_df, col_BW: Optional[str], col_COMED: Optional[str],
                                     Nsubj=300, scenario="base", decision_threshold=0.5,
                                     grid_step_h=1.0) -> float:
    """Fraction of virtual subjects whose steady-state prediction stays positive.

    ``Nsubj`` subjects are drawn with ``subject_cov_sampler`` (default seed),
    and each one's ``tau`` comes from the model. The steady-state exposure for
    ``dose_mg`` given every ``freq_h`` hours is
    ``E(t) = dose * exp(-t/tau) / (1 - exp(-freq_h/tau))``, with the
    denominator floored at 1e-6. It is evaluated on
    ``t = 0, grid_step_h, ..., <= last_window_h``.

    A subject counts as a success when the predicted probability is
    ``>= decision_threshold`` at every grid time.

    The MLP is evaluated in eval mode (dropout off, so results are
    deterministic), and the model's previous train/eval mode is restored
    afterwards. All subjects are scored in one batched forward pass.
    """
    device = next(model.parameters()).device
    bw_arr, cm_arr = subject_cov_sampler(pd_df, col_BW, col_COMED, Nsubj, scenario=scenario)
    tau = model.tau_from_cov_np(bw_arr, cm_arr, device=device)
    tau = np.maximum(np.asarray(tau, dtype=float), 1e-6)                          # (Nsubj,)
    tgrid = np.arange(0.0, last_window_h + 1e-6, grid_step_h, dtype=float)          # (T,)

    denom = np.maximum(1.0 - np.exp(-float(freq_h) / tau), 1e-6)                  # (Nsubj,)
    e_t = dose_mg * np.exp(-tgrid[None, :] / tau[:, None]) / denom[:, None]        # (Nsubj, T)
    bwc = (np.asarray(bw_arr, dtype=float) - model.bw_mean) / 10.0
    cm = np.asarray(cm_arr, dtype=float)
    feats = np.stack([e_t,
                      np.broadcast_to(bwc[:, None], e_t.shape),
                      np.broadcast_to(cm[:, None], e_t.shape)], axis=-1)           # (Nsubj, T, 3)

    was_training = model.training
    model.eval()
    try:
        X = torch.tensor(feats, dtype=torch.float32, device=device)
        probs = torch.sigmoid(model.mlp(X).squeeze(-1)).cpu().numpy()              # (Nsubj, T)
    finally:
        model.train(was_training)

    ok = int(np.sum(probs.min(axis=1) >= decision_threshold))
    return ok / Nsubj

def search_min_dose_ss(model: PDMLPClassifier, grid, freq_h, last_window_h, pd_df,
                       col_BW: Optional[str], col_COMED: Optional[str],
                       Nsubj=300, scenario="base", target=0.9, decision_threshold=0.5):
    """Score every dose in *grid* and report the smallest one meeting *target*.

    Returns ``(table, best)``. ``table`` has columns ``dose``/``fraction``,
    sorted by dose, and covers the whole grid. ``best`` is the lowest dose
    whose success fraction is ``>= target``, or ``None`` if no dose reaches it.
    """
    records = [
        {"dose": dose,
         "fraction": success_fraction_for_dose_ss_mlp(model, dose, freq_h, last_window_h, pd_df,
                                                      col_BW, col_COMED, Nsubj=Nsubj, scenario=scenario,
                                                      decision_threshold=decision_threshold)}
        for dose in grid
    ]
    table = pd.DataFrame(records).sort_values("dose")
    feasible = table[table["fraction"] >= target]
    if feasible.empty:
        return table, None
    return table, feasible.iloc[0]["dose"]

# -------- Training run (applies the MLP hyperparameters) --------
model, valid_best = train_classifier(train_ds, valid_ds,
                                     pd_threshold=PD_THRESHOLD, epochs=EPOCHS, lr=LR,
                                     mlp_hidden=MLP_HIDDEN, mlp_layers=MLP_LAYERS, mlp_dropout=MLP_DROPOUT)
print("\nValidation best:", valid_best)

# -------- Dose search (MLP only) --------
# Candidate doses; both grids start at 0.
daily_grid  = [0.5*i for i in range(0, 121)]  # 0..60 mg, 0.5 mg steps
weekly_grid = [5*i   for i in range(0, 41)]   # 0..200 mg, 5 mg steps

# Each call scores the whole grid; (freq_h, last_window_h) = (24, 24) for once-daily and
# (168, 168) for once-weekly. Covariates are resampled from every row of `pdf` (all subjects,
# not only the training split; per-observation rows, so subjects with more PD rows weigh more).
# The sampler's seed is the default, so every dose is scored on the same virtual population.
# `best_*` is the smallest grid dose reaching the target fraction, or None.

# Base covariates, 90% target
daily_base,  best_daily_base  = search_min_dose_ss(model, daily_grid,  24, 24,  pdf, col_BW, col_COMED,
                                                   Nsubj=N_SUBJ, scenario="base", target=0.90)
weekly_base, best_weekly_base = search_min_dose_ss(model, weekly_grid, 168, 168, pdf, col_BW, col_COMED,
                                                   Nsubj=N_SUBJ, scenario="base", target=0.90)

# BW drawn uniformly from 70-140 (extrapolation), 90% target
daily_bw,   best_daily_bw   = search_min_dose_ss(model, daily_grid,  24, 24,  pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="bw_wide", target=0.90)
weekly_bw,  best_weekly_bw  = search_min_dose_ss(model, weekly_grid, 168, 168, pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="bw_wide", target=0.90)

# COMED forced to 0, 90% target
daily_nocm, best_daily_nocm = search_min_dose_ss(model, daily_grid,  24, 24,  pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="no_comed", target=0.90)
weekly_nocm,best_weekly_nocm= search_min_dose_ss(model, weekly_grid, 168, 168, pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="no_comed", target=0.90)

# Base covariates, relaxed 75% target
daily_75,   best_daily_75   = search_min_dose_ss(model, daily_grid,  24, 24,  pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="base", target=0.75)
weekly_75,  best_weekly_75  = search_min_dose_ss(model, weekly_grid, 168, 168, pdf, col_BW, col_COMED,
                                                 Nsubj=N_SUBJ, scenario="base", target=0.75)

# One row per scenario/target; None means no grid dose reached the target.
summary = pd.DataFrame([
    {"scenario":"Base (Phase 1-like)", "target":"90%", "once-daily (mg)": best_daily_base,  "once-weekly (mg)": best_weekly_base},
    {"scenario":"BW 70–140 kg",        "target":"90%", "once-daily (mg)": best_daily_bw,    "once-weekly (mg)": best_weekly_bw},
    {"scenario":"No COMED allowed",    "target":"90%", "once-daily (mg)": best_daily_nocm,  "once-weekly (mg)": best_weekly_nocm},
    {"scenario":"Base (Phase 1-like)", "target":"75%", "once-daily (mg)": best_daily_75,    "once-weekly (mg)": best_weekly_75},
])
print("\n=== Dose recommendations summary ===")
print(summary.to_string(index=False))
# -------- Test evaluation --------


def evaluate_dataset(model, dataset, batch_size=512):
    """Score *dataset* with *model* and return ``_compute_metrics`` at threshold 0.5.

    The model is switched to eval mode (and left there). Samples are scored in
    dataset order, one ``forward_single`` call each, with gradients disabled.
    """
    dev = next(model.parameters()).device
    labels, probs = [], []
    model.eval()
    with torch.no_grad():
        for chunk in DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=_collate):
            for sample in chunk:
                logit = model.forward_single(
                    sample["t"].to(dev), sample["dose_t"].to(dev), sample["dose_a"].to(dev),
                    float(sample["bw"]), float(sample["cm"]),
                )
                labels.append(float(sample["y"]))
                probs.append(torch.sigmoid(logit).item())
    return _compute_metrics(np.array(labels), np.array(probs), thr=0.5)

# Score the held-out test subjects (IDs in neither the training nor the validation split) at threshold 0.5.
test_metrics = evaluate_dataset(model, test_ds)
print("\n=== Final TEST metrics (unseen IDs) ===")
for k, v in test_metrics.items():
    # Float-valued metrics get 4 decimals; anything else (e.g. confusion counts) prints as-is.
    if isinstance(v, float):
        print(f"{k:>8s}: {v:.4f}")
    else:
        print(f"{k:>8s}: {v}")

Detected columns: {'ID': 'ID', 'TIME': 'TIME', 'DV': 'DV', 'EVID': 'EVID', 'AMT': 'AMT', 'BW': 'BW', 'COMED': 'COMED'}
PD rows: 1200 Dose rows: 756
[ID split by fixed dose groups] (70/15/15)
 train IDs: 24 | valid IDs: 6 | test IDs: 6
 1mg -> train/valid/test: 8 2 2
 3mg -> train/valid/test: 8 2 2
10mg -> train/valid/test: 8 2 2
#samples -> train: 600 | valid: 150 | test: 150
[ep 001] acc=0.780 f1=0.492 prec=0.485 rec=0.500 roc_auc=0.705 pr_auc=0.492
[ep 002] acc=0.787 f1=0.000 prec=0.000 rec=0.000 roc_auc=0.729 pr_auc=0.472
[ep 003] acc=0.787 f1=0.000 prec=0.000 rec=0.000 roc_auc=0.764 pr_auc=0.436
[ep 010] acc=0.833 f1=0.545 prec=0.652 rec=0.469 roc_auc=0.757 pr_auc=0.517
[ep 020] acc=0.873 f1=0.698 prec=0.710 rec=0.688 roc_auc=0.865 pr_auc=0.728
[ep 030] acc=0.913 f1=0.812 prec=0.757 rec=0.875 roc_auc=0.975 pr_auc=0.940
[ep 040] acc=0.947 f1=0.871 prec=0.900 rec=0.844 roc_auc=0.989 pr_auc=0.968
[ep 050] acc=0.953 f1=0.881 prec=0.963 rec=0.812 roc_auc=0.991 pr_auc=0.970
[ep 060] acc=

In [5]:
print(ids_tr,ids_va, ids_te)

{14, 15, 16, 17, 18, 19, 20, 24, 26, 27, 28, 29, 30, 31, 32, 36, 38, 39, 40, 41, 42, 43, 44, 48} {35, 37, 13, 47, 23, 25} {33, 34, 45, 46, 21, 22}


In [5]:
print(ids_tr,ids_va, ids_te)

{14, 15, 16, 17, 18, 19, 20, 24, 26, 27, 28, 29, 30, 31, 32, 36, 38, 39, 40, 41, 42, 43, 44, 48} {35, 37, 13, 47, 23, 25} {33, 34, 45, 46, 21, 22}


In [2]:
# ===========================
# PD-only 분류 + 정상상태 용량탐색 (MLP 전용, 지표 확장, 베스트 가중치 복원 + 검증 임계값 적용)
# - 데이터: EstData.csv (PD 단위: ng/mL), 임계 3.3 ng/mL
# - 모델: MLP (은닉층 개수/히든/드롭아웃 하이퍼파라미터화)
# - 분할: 플라시보(0 mg, ID 1–12) 제외, 1/3/10 mg 고정 그룹별 70/15/15 ID 단위 분할
# ===========================
import os, math, warnings, numpy as np, pandas as pd, random
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, confusion_matrix
)
from typing import Optional
from collections import defaultdict
from copy import deepcopy

# -------- (Optional) fix RNG seeds for reproducible splits/shuffling --------
# The ID split (`rng`) and the training-loader shuffle (`train_classifier`'s `seed`)
# are seeded explicitly as well; these global seeds cover code that uses the default
# generators, e.g. weight initialisation and dropout.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# -------- Settings --------
CSV = "EstData.csv"       # use "/mnt/data/EstData.csv" if needed
PD_THRESHOLD = 3.3        # PD value at or below this is labelled 1 (see PDSamples)
EPOCHS = 60
LR = 5e-2
N_SUBJ = 300              # virtual subjects per dose level in the steady-state search
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# MLP hyperparameters
MLP_HIDDEN = 36
MLP_LAYERS = 2
MLP_DROPOUT = 0.2

# Raise a real exception rather than `assert`, which is stripped under `python -O`.
if not os.path.exists(CSV):
    raise FileNotFoundError(f"CSV not found at {CSV}")

# -------- Utility: column-name inference --------
def _find_col_like(df, name_opts):
    """Return the first column of *df* whose name matches one of *name_opts*.

    Matching ignores case and surrounding whitespace on both the column names
    and the options, so ``["id"]`` and ``["ID"]`` both find a column called
    ``"ID"`` or ``" Id "``. Options are tried in order and the original column
    label is returned; ``None`` means nothing matched. If several columns
    normalise to the same key, the last one wins.
    """
    normalized = {str(c).strip().lower(): c for c in df.columns}
    for opt in name_opts:
        key = str(opt).strip().lower()
        if key in normalized:
            return normalized[key]
    return None

# -------- Data loading / preprocessing --------
df = pd.read_csv(CSV)

# Resolve column names case-insensitively; the first alias that matches wins.
col_ID   = _find_col_like(df, ["id"])
col_TIME = _find_col_like(df, ["time"])
col_DVID = _find_col_like(df, ["dvid"])
col_DV   = _find_col_like(df, ["dv","pd","value"])
col_EVID = _find_col_like(df, ["evid"])
col_AMT  = _find_col_like(df, ["amt","dose","dosen","doses"])
col_BW   = _find_col_like(df, ["bw","weight","bodyweight"])
col_COMED= _find_col_like(df, ["comed","conmed","concom"])

need = [col_ID, col_TIME, col_DV, col_AMT]
miss = [n for n,v in zip(["ID","TIME","DV","AMT/DOSE"], need) if v is None]
if miss:
    # Later steps index these columns directly (sorting, dose extraction,
    # PDSamples). A warning here would only defer the failure to a less
    # clear error further down, so stop now.
    raise ValueError(f"Columns missing (minimum required): {miss}")

# PD observation rows: DVID == 2 when a DVID column exists, otherwise all rows.
if col_DVID is not None and col_DV is not None:
    pdf = df[df[col_DVID]==2].copy()
else:
    pdf = df.copy()

# Dosing events: EVID == 1 when available, otherwise rows with a non-missing amount.
if col_EVID is not None:
    dose_df = df[df[col_EVID]==1].copy()
else:
    dose_df = df[df[col_AMT].notna()].copy()

# Coerce to numeric; unparsable values become NaN and are dropped below.
for c in [col_TIME, col_DV, col_AMT, col_BW, col_COMED]:
    if c is not None:
        pdf[c] = pd.to_numeric(pdf[c], errors="coerce")
        dose_df[c] = pd.to_numeric(dose_df[c], errors="coerce")

# Keep only the needed columns, drop incomplete rows, sort by subject then time.
keep_pd = [c for c in [col_ID,col_TIME,col_DV,col_BW,col_COMED] if c is not None]
keep_dose = [c for c in [col_ID,col_TIME,col_AMT] if c is not None]
pdf = pdf[keep_pd].dropna().sort_values([col_ID, col_TIME])
dose_df = dose_df[keep_dose].dropna().sort_values([col_ID, col_TIME])

print("Detected columns:", dict(ID=col_ID, TIME=col_TIME, DV=col_DV, EVID=col_EVID, AMT=col_AMT, BW=col_BW, COMED=col_COMED))
print("PD rows:", len(pdf), "Dose rows:", len(dose_df))

# -------- Dataset --------
class PDSamples(Dataset):
    """One sample per PD observation of a subject that has dosing records.

    Each entry of ``self.samples`` is ``(sid, t, y, bw, cm)``:
    - ``sid``: subject id.
    - ``t``: observation time.
    - ``y``: label, ``1.0`` when the PD value is ``<= pd_threshold``, else ``0.0``.
    - ``bw``: body weight, 70.0 when ``col_BW`` is None.
    - ``cm``: COMED, 0.0 when ``col_COMED`` is None.

    Observations of subjects with no row in ``dose_df`` are skipped.
    ``__getitem__`` returns these values as float32 tensors, plus the
    subject's dose times ``dose_t`` and amounts ``dose_a``.

    Args:
        col_AMT: dose-amount column in ``dose_df``. When omitted it is inferred
            as the only column of ``dose_df`` other than ``col_ID``/``col_TIME``,
            so the class no longer depends on a module-level ``col_AMT``.

    Raises:
        ValueError: if ``col_AMT`` is omitted and cannot be inferred from a
            non-empty ``dose_df``.
    """

    def __init__(self, pd_df: pd.DataFrame, dose_df: pd.DataFrame,
                 col_ID: str, col_TIME: str, col_DV: str,
                 col_BW: Optional[str], col_COMED: Optional[str], pd_threshold: float=3.3,
                 col_AMT: Optional[str] = None):
        self.col_ID, self.col_TIME, self.col_DV = col_ID, col_TIME, col_DV
        self.col_BW, self.col_COMED = col_BW, col_COMED
        self.pd_threshold = pd_threshold
        self.col_AMT = self._resolve_amt_col(dose_df, col_ID, col_TIME, col_AMT)

        # Dosing history per subject, in dose_df row order: (times, amounts) as float32 arrays.
        d_groups = defaultdict(list)
        for _, row in dose_df.iterrows():
            d_groups[row[col_ID]].append((float(row[col_TIME]), float(row[self.col_AMT])))
        self.dose_map = {k: (np.array([t for t,a in v], dtype=np.float32),
                              np.array([a for t,a in v], dtype=np.float32)) for k,v in d_groups.items()}

        # One sample per PD observation time
        feats = []
        for _, row in pd_df.iterrows():
            sid = row[col_ID]
            if sid not in self.dose_map:
                continue
            t = float(row[col_TIME])
            val = float(row[col_DV])
            y = 1.0 if (val <= pd_threshold) else 0.0
            bw = float(row[col_BW]) if col_BW is not None else 70.0
            cm = float(row[col_COMED]) if col_COMED is not None else 0.0
            feats.append((sid, t, y, bw, cm))
        self.samples = feats

    @staticmethod
    def _resolve_amt_col(dose_df, col_ID, col_TIME, col_AMT):
        """Return the dose-amount column name, inferring it when not given."""
        if col_AMT is not None:
            return col_AMT
        candidates = [c for c in dose_df.columns if c not in (col_ID, col_TIME)]
        if len(candidates) == 1:
            return candidates[0]
        if len(dose_df) == 0:
            return None  # no dose rows will be read, so no amount column is needed
        raise ValueError(
            "col_AMT was not given and cannot be inferred from dose_df columns "
            f"{list(dose_df.columns)}; pass col_AMT explicitly")

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        sid, t, y, bw, cm = self.samples[idx]
        dose_t, dose_a = self.dose_map.get(sid, (np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32)))
        return {
            "t": torch.tensor(t, dtype=torch.float32),
            "y": torch.tensor(y, dtype=torch.float32),
            "bw": torch.tensor(bw, dtype=torch.float32),
            "cm": torch.tensor(cm, dtype=torch.float32),
            "dose_t": torch.tensor(dose_t, dtype=torch.float32),
            "dose_a": torch.tensor(dose_a, dtype=torch.float32),
        }

def _collate(batch):
    """Return *batch* (the list of sample dicts) unchanged.

    Used as the DataLoader ``collate_fn`` so samples, whose ``dose_t``/``dose_a``
    lengths can differ between subjects, are not stacked; callers iterate
    over the list instead.
    """
    return batch

# Build one sample per PD observation (PDSamples skips subjects without dosing rows).
dataset_all = PDSamples(pdf, dose_df, col_ID, col_TIME, col_DV, col_BW, col_COMED, pd_threshold=PD_THRESHOLD)

# -------- Subject-level (ID) split: placebo excluded, 70/15/15 within each fixed dose group (1/3/10 mg) --------
rng = 42                            # integer seed passed to split_70_15_15 for every group
placebo_ids = set(range(1, 13))     # 0 mg (excluded; not referenced again in this cell)
ids_1mg     = list(range(13, 25))   # 1 mg
ids_3mg     = list(range(25, 37))   # 3 mg
ids_10mg    = list(range(37, 49))   # 10 mg
# NOTE(review): the ID -> dose-group mapping above is hard-coded; confirm it matches the dataset.

# Keep only IDs present in the raw table; sorting makes the split input order deterministic.
ids_in_data = set(df[col_ID].unique())
ids_1mg  = sorted(ids_in_data.intersection(ids_1mg))
ids_3mg  = sorted(ids_in_data.intersection(ids_3mg))
ids_10mg = sorted(ids_in_data.intersection(ids_10mg))

def split_70_15_15(ids, seed=42):
    """Split *ids* into (train, valid, test) lists of roughly 70/15/15 %.

    A first shuffled split holds out 30 %, which is then halved into
    validation and test. Both splits use ``seed`` as ``random_state``.
    Empty input yields three empty lists.
    """
    if not len(ids):
        return [], [], []
    train_part, holdout = train_test_split(ids, test_size=0.30, random_state=seed, shuffle=True)
    valid_part, test_part = train_test_split(holdout, test_size=0.50, random_state=seed, shuffle=True)
    return list(train_part), list(valid_part), list(test_part)

# Split each dose group separately (same seed for all) so every split contains all three dose levels.
tr_1, va_1, te_1    = split_70_15_15(ids_1mg,  seed=rng)
tr_3, va_3, te_3    = split_70_15_15(ids_3mg,  seed=rng)
tr_10, va_10, te_10 = split_70_15_15(ids_10mg, seed=rng)

ids_tr = set(tr_1 + tr_3 + tr_10)
ids_va = set(va_1 + va_3 + va_10)
ids_te = set(te_1 + te_3 + te_10)

# Map subject IDs back to sample indices (sid is the first field of each
# (sid, t, y, bw, cm) sample); an ID without dosing rows has no samples.
sid_list = [s[0] for s in dataset_all.samples]
tr_idx = [i for i, sid in enumerate(sid_list) if sid in ids_tr]
va_idx = [i for i, sid in enumerate(sid_list) if sid in ids_va]
te_idx = [i for i, sid in enumerate(sid_list) if sid in ids_te]

# Index views over dataset_all; all observations of a subject land in the same split.
train_ds = Subset(dataset_all, tr_idx)
valid_ds = Subset(dataset_all, va_idx)
test_ds  = Subset(dataset_all, te_idx)

print("[ID split by fixed dose groups] (70/15/15)")
print(" train IDs:", len(ids_tr), "| valid IDs:", len(ids_va), "| test IDs:", len(ids_te))
print(" 1mg -> train/valid/test:", len(tr_1), len(va_1), len(te_1))
print(" 3mg -> train/valid/test:", len(tr_3), len(va_3), len(te_3))
print("10mg -> train/valid/test:", len(tr_10), len(va_10), len(te_10))
print(f"#samples -> train: {len(train_ds)} | valid: {len(valid_ds)} | test: {len(test_ds)}")

# -------- MLP model --------
class PDMLPClassifier(nn.Module):
    """Binary classifier on a covariate-adjusted exposure feature.

    Exposure at time ``t`` is the sum of exponentially decaying past doses,
    ``E(t) = sum_k a_k * exp(-(t - t_k) / tau)`` over doses with ``t_k <= t``.
    The time constant is ``log(tau) = b0 + b1 * (BW - bw_mean) / 10 + b2 * COMED``,
    with tau clamped to at least 1.

    The MLP input is ``[E(t), (BW - bw_mean) / 10, COMED]``, passed through
    ``n_layers`` blocks of Linear+ReLU(+Dropout) and then a Linear layer to a
    single logit.

    Args:
        bw_mean: body-weight centring constant (plain attribute, not learned).
        hidden: width of each hidden layer.
        n_layers: number of hidden blocks; 0 gives a single Linear(3, 1).
        dropout_p: dropout probability; a falsy or non-positive value uses
            ``nn.Identity`` instead of ``nn.Dropout``.
    """

    def __init__(self, bw_mean: float=70.0, hidden: int=32, n_layers: int=3, dropout_p: float=0.2):
        super().__init__()
        # tau link parameters; b0 starts at log(24), i.e. tau = 24 at the reference covariates
        self.b0 = nn.Parameter(torch.tensor(math.log(24.0)))
        self.b1 = nn.Parameter(torch.tensor(0.0))
        self.b2 = nn.Parameter(torch.tensor(0.0))
        self.bw_mean = float(bw_mean)

        use_dropout = bool(dropout_p and dropout_p > 0)
        blocks = []
        width_in = 3
        for _ in range(n_layers):
            blocks.append(nn.Linear(width_in, hidden))
            blocks.append(nn.ReLU())
            blocks.append(nn.Dropout(p=dropout_p) if use_dropout else nn.Identity())
            width_in = hidden
        blocks.append(nn.Linear(width_in, 1))
        self.mlp = nn.Sequential(*blocks)

    def _tau(self, bw: torch.Tensor, comed: torch.Tensor):
        """Time constant from centred BW and COMED, floored at 1."""
        centred_bw = (bw - self.bw_mean) / 10.0
        return torch.exp(self.b0 + self.b1*centred_bw + self.b2*comed).clamp_min(1.0)

    @torch.no_grad()
    def tau_from_cov_np(self, bw_np: np.ndarray, cm_np: np.ndarray, device=None):
        """NumPy wrapper around ``_tau`` (no autograd); returns a CPU array."""
        target = device or next(self.parameters()).device
        bw_t = torch.tensor(bw_np, dtype=torch.float32, device=target)
        cm_t = torch.tensor(cm_np, dtype=torch.float32, device=target)
        return self._tau(bw_t, cm_t).detach().cpu().numpy()

    def forward_single(self, t: torch.Tensor, dose_t: torch.Tensor, dose_a: torch.Tensor,
                       bw: float, comed: float):
        """Logit for one observation at time ``t`` given that subject's dose history."""
        dev = t.device
        bw_t = torch.tensor(bw, dtype=torch.float32, device=dev)
        cm_t = torch.tensor(comed, dtype=torch.float32, device=dev)
        tau = self._tau(bw_t, cm_t)
        elapsed = t - dose_t
        # Only doses given at or before t contribute; later doses are masked out.
        given = (elapsed >= 0).float()
        exposure = (dose_a * torch.exp(-elapsed.clamp_min(0) / tau) * given).sum()
        features = torch.stack([exposure, (bw_t - self.bw_mean) / 10.0, cm_t])
        return self.mlp(features).squeeze()

def _collate(batch):
    """Return *batch* (the list of sample dicts) unchanged.

    This redefines the identical identity collate function declared earlier
    in this cell; the list is iterated sample by sample by the callers.
    """
    return batch

# -------- Metric computation --------
def _compute_metrics(y_true: np.ndarray, prob: np.ndarray, thr: float = 0.5):
    """Binary-classification metrics for predicted probabilities.

    A sample is predicted positive when ``prob >= thr``. Array-likes (e.g.
    lists) are accepted.

    Returns a dict with:
        acc: accuracy (NaN for empty input).
        prec, rec, f1: precision, recall and F1, with ``zero_division=0``
            (so 0.0 for empty input).
        roc_auc, pr_auc: ROC-AUC and average precision (NaN unless
            ``y_true`` has more than one class).
        tn, fp, fn, tp: confusion counts for labels 0 and 1, as plain ints.
    """
    y_true = np.asarray(y_true)
    prob = np.asarray(prob, dtype=float)
    pred = (prob >= thr).astype(int)

    # Confusion counts for labels {0, 1}; other label values are ignored,
    # as with confusion_matrix(..., labels=[0, 1]).
    tn = int(np.sum((y_true == 0) & (pred == 0)))
    fp = int(np.sum((y_true == 0) & (pred == 1)))
    fn = int(np.sum((y_true == 1) & (pred == 0)))
    tp = int(np.sum((y_true == 1) & (pred == 1)))

    nan = float("nan")
    if y_true.size == 0:
        # Nothing to score: report "undefined" without relying on sklearn's empty-input behaviour.
        return {"acc":nan, "prec":0.0, "rec":0.0, "f1":0.0, "roc_auc":nan, "pr_auc":nan,
                "tn":tn, "fp":fp, "fn":fn, "tp":tp}

    acc = accuracy_score(y_true, pred)
    prec = precision_score(y_true, pred, zero_division=0)
    rec  = recall_score(y_true, pred, zero_division=0)
    f1   = f1_score(y_true, pred, zero_division=0)

    roc = ap = nan
    # Ranking metrics are undefined with a single class; any sklearn rejection leaves them NaN.
    if len(np.unique(y_true)) > 1:
        try:
            roc = roc_auc_score(y_true, prob)
        except ValueError:
            roc = nan
        try:
            ap = average_precision_score(y_true, prob)
        except ValueError:
            ap = nan
    return {"acc":acc, "prec":prec, "rec":rec, "f1":f1, "roc_auc":roc, "pr_auc":ap,
            "tn":tn, "fp":fp, "fn":fn, "tp":tp}

# -------- Training routine (saves and restores the best weights) --------
def train_classifier(train_ds, valid_ds, epochs=60, lr=5e-2,
                     seed=42, mlp_hidden=32, mlp_layers=3, mlp_dropout=0.2):
    """Train a ``PDMLPClassifier`` and return it with its best-validation weights.

    Each epoch makes one Adam pass over ``train_ds`` and then scores
    ``valid_ds`` at a 0.5 probability threshold. Training uses batches of 128,
    shuffled by a generator seeded with ``seed``; the loss is per-sample
    BCE-with-logits averaged over the batch. The epoch with the highest
    validation F1 is kept (the earliest wins ties), and its weights are
    restored before returning.

    Args:
        train_ds, valid_ds: datasets yielding PDSamples-style dicts.
        epochs, lr, seed: number of epochs, Adam learning rate, shuffle seed.
        mlp_hidden, mlp_layers, mlp_dropout: forwarded to PDMLPClassifier.

    Returns:
        ``(model, best)`` where ``best`` is ``{"epoch": ..., **metrics}`` for the
        selected epoch.

    Raises:
        ValueError: if ``train_ds`` is empty (the BW centring mean would be NaN).
    """
    if len(train_ds) == 0:
        raise ValueError("train_ds is empty; cannot compute bw_mean or train")

    device = DEVICE
    g = torch.Generator().manual_seed(seed)
    tr_loader = DataLoader(train_ds, batch_size=128, shuffle=True, collate_fn=_collate, generator=g)
    va_loader = DataLoader(valid_ds, batch_size=256, shuffle=False, collate_fn=_collate)

    # Body-weight centring constant, taken from the training split only.
    bw_mean = float(np.mean([b["bw"].item() for b in train_ds]))
    model = PDMLPClassifier(bw_mean=bw_mean, hidden=mlp_hidden, n_layers=mlp_layers, dropout_p=mlp_dropout).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()

    best = (-1.0, None)
    best_state = None

    for ep in range(1, epochs+1):
        model.train()
        for batch in tr_loader:
            opt.zero_grad()
            loss = 0.0
            # Dose histories may differ in length between subjects, so samples are scored one by one.
            for b in batch:
                logit = model.forward_single(
                    b["t"].to(device), b["dose_t"].to(device), b["dose_a"].to(device),
                    float(b["bw"]), float(b["cm"])
                )
                loss = loss + loss_fn(logit.view(()), b["y"].to(device).view(()))
            (loss/len(batch)).backward()
            opt.step()

        # validation
        model.eval(); ys=[]; ps=[]
        with torch.no_grad():
            for batch in va_loader:
                for b in batch:
                    logit = model.forward_single(
                        b["t"].to(device), b["dose_t"].to(device), b["dose_a"].to(device),
                        float(b["bw"]), float(b["cm"])
                    )
                    ys.append(float(b["y"]))
                    ps.append(torch.sigmoid(logit).item())
        y = np.array(ys); p = np.array(ps)
        metrics = _compute_metrics(y, p, thr=0.5)

        # Best snapshot (deep copy: later optimizer steps would otherwise mutate it)
        if metrics["f1"] > best[0]:
            best = (metrics["f1"], {"epoch":ep, **metrics})
            best_state = deepcopy(model.state_dict())

        if ep%10==0 or ep<=3:
            print(f"[ep {ep:03d}] acc={metrics['acc']:.3f} f1={metrics['f1']:.3f} "
                  f"prec={metrics['prec']:.3f} rec={metrics['rec']:.3f} "
                  f"roc_auc={metrics['roc_auc']:.3f} pr_auc={metrics['pr_auc']:.3f}")

    # Restore the best-validation weights so the returned model matches `best`.
    if best_state is not None:
        model.load_state_dict(best_state)

    print("best(valid):", best[1])
    return model, best[1]

# -------- 검증셋 기반 임계값 탐색 --------
def predict_dataset(model, dataset, batch_size=512):
    """Run ``model`` over every sample of ``dataset`` and collect labels and probabilities.

    The dataset is iterated in order (``shuffle=False``) through a DataLoader using the
    file's ``_collate``. Each yielded batch is iterated sample-by-sample, so ``_collate`` is
    presumably returning a list of per-sample dicts (TODO confirm). Each sample dict must
    provide the keys "t", "dose_t", "dose_a", "bw", "cm" and "y".

    Side effect: switches ``model`` to eval mode and leaves it there.

    Returns:
        (y, p): two 1-D numpy arrays with the float labels and the sigmoid probabilities
        from ``model.forward_single``, in dataset order.
    """
    device = next(model.parameters()).device
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=_collate)
    labels, probs = [], []
    model.eval()
    with torch.no_grad():
        # Flatten batches into a single stream of per-sample dicts.
        samples = (sample for batch in loader for sample in batch)
        for sample in samples:
            logit = model.forward_single(
                sample["t"].to(device),
                sample["dose_t"].to(device),
                sample["dose_a"].to(device),
                float(sample["bw"]),
                float(sample["cm"]),
            )
            labels.append(float(sample["y"]))
            probs.append(torch.sigmoid(logit).item())
    return np.array(labels), np.array(probs)

def _youden_j(y, p, thr):
    """Return Youden's J statistic (sensitivity + specificity - 1) at threshold ``thr``.

    A label counts as positive when y > 0.5 (y is expected to hold 0/1 values). A
    prediction counts as positive when p >= thr.
    NOTE(review): confirm that this matches the comparison used inside _compute_metrics.
    If a class has no samples, its rate contributes 0.0.
    """
    y_pos = np.asarray(y, dtype=float) > 0.5
    p_pos = np.asarray(p, dtype=float) >= thr
    n_pos = int(y_pos.sum())
    n_neg = int(y_pos.size - n_pos)
    sens = float((y_pos & p_pos).sum()) / n_pos if n_pos else 0.0
    spec = float((~y_pos & ~p_pos).sum()) / n_neg if n_neg else 0.0
    return sens + spec - 1.0


def find_optimal_threshold(y, p, mode="f1", thr_grid=None):
    """Pick the decision threshold that maximizes a validation score.

    Parameters
    ----------
    y : array-like of 0/1 labels.
    p : array-like of predicted probabilities, aligned with ``y``.
    mode : str
        "f1"       : F1 as reported by ``_compute_metrics``.
        "youden"   : true Youden's J (sensitivity + specificity - 1).
        "prec_rec" : precision + recall - 1. Before this change, every non-"f1" mode
                     silently used this score.
    thr_grid : iterable of float, optional
        Candidate thresholds. Defaults to 0.05, 0.10, ..., 0.95 (19 points).

    Returns
    -------
    The first threshold in grid order that reaches the maximal score. Returns 0.5 when
    no candidate scores above -1.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the modes listed above.
    """
    if mode not in ("f1", "youden", "prec_rec"):
        raise ValueError(f"unknown mode: {mode!r}")
    if thr_grid is None:
        thr_grid = np.linspace(0.05, 0.95, 19)

    best_thr, best_score = 0.5, -1.0
    for thr in thr_grid:
        if mode == "youden":
            score = _youden_j(y, p, thr)
        else:
            m = _compute_metrics(y, p, thr)
            score = m["f1"] if mode == "f1" else (m["rec"] + m["prec"] - 1)
        # Strict ">" keeps the earliest threshold among ties.
        if score > best_score:
            best_score, best_thr = score, thr
    return best_thr

# -------- 인구 공변량 샘플러 --------
def subject_cov_sampler(pd_df: pd.DataFrame, col_BW: Optional[str], col_COMED: Optional[str],
                        n:int, scenario:str="base", rng_seed:int=123):
    """Draw body-weight and COMED covariates for ``n`` virtual subjects.

    Scenarios
    ---------
    "base"     : resample (BW, COMED) pairs jointly from rows of ``pd_df`` where both
                 values are present, so each drawn pair comes from a single row.
    "bw_wide"  : BW ~ Uniform(70, 140) kg (extrapolation). COMED is resampled from the
                 observed non-missing values, or 0 if there are none.
    "no_comed" : BW resampled from observed non-missing values; COMED forced to 0.

    When ``col_BW`` is None, every row's BW is 80.0. When ``col_COMED`` is None, every
    row's COMED is 0.
    Sampling is over rows of ``pd_df``, not unique subject IDs, so subjects with more rows
    are drawn more often. NOTE(review): confirm that row-level sampling is intended.

    Returns
    -------
    (bw, cm): float numpy arrays of length ``n``.

    Raises
    ------
    ValueError
        On an unknown scenario, or when the pool the scenario samples from is empty.
    """
    n_rows = len(pd_df)
    if col_BW is not None:
        bw_all = pd_df[col_BW].to_numpy(dtype=float)
        has_bw = pd_df[col_BW].notna().to_numpy()
    else:
        bw_all = np.full(n_rows, 80.0)
        has_bw = np.ones(n_rows, dtype=bool)
    if col_COMED is not None:
        cm_all = pd_df[col_COMED].to_numpy(dtype=float)
        has_cm = pd_df[col_COMED].notna().to_numpy()
    else:
        cm_all = np.zeros(n_rows)
        has_cm = np.ones(n_rows, dtype=bool)

    rng = np.random.default_rng(rng_seed)
    if scenario=="base":
        # Resample whole rows so BW and COMED stay paired. Dropping NaNs per column
        # independently would misalign (or out-of-range index) the two pools.
        both = has_bw & has_cm
        pair_bw, pair_cm = bw_all[both], cm_all[both]
        if len(pair_bw) == 0:
            raise ValueError("no rows with both BW and COMED observed")
        idx = rng.integers(0, len(pair_bw), size=n)
        return pair_bw[idx], pair_cm[idx]
    elif scenario=="bw_wide":
        # Extrapolation: uniform BW over 70-140 kg.
        obs_cm = cm_all[has_cm]
        bw = rng.uniform(70.0, 140.0, size=n)
        cm = obs_cm[rng.integers(0, len(obs_cm), size=n)] if len(obs_cm)>0 else np.zeros(n)
        return bw, cm
    elif scenario=="no_comed":
        obs_bw = bw_all[has_bw]
        if len(obs_bw) == 0:
            raise ValueError("no rows with BW observed")
        idx = rng.integers(0, len(obs_bw), size=n)
        return obs_bw[idx], np.zeros(n)
    else:
        raise ValueError("unknown scenario")

# -------- 정상상태(SS) 평가: MLP --------
@torch.no_grad()
def success_fraction_for_dose_ss_mlp(
    model: PDMLPClassifier, dose_mg: float, freq_h: int, last_window_h: int,
    pd_df, col_BW: Optional[str], col_COMED: Optional[str],
    Nsubj=300, scenario="base", decision_threshold=0.5,
    grid_step_h=1.0, agg="min", alpha=0.9
) -> float:
    """Return the fraction of simulated subjects whose steady-state prediction meets a criterion.

    Covariates for ``Nsubj`` subjects come from ``subject_cov_sampler``, and per-subject
    ``tau`` comes from ``model.tau_from_cov_np``. For each subject, a steady-state
    exposure curve is evaluated on t = 0, grid_step_h, ..., last_window_h:

        e(t) = dose_mg * exp(-t / tau) / (1 - exp(-freq_h / tau))

    This equals the sum of identical doses given every ``freq_h`` hours, each decaying
    exponentially. ``tau`` and the denominator are floored at 1e-6.

    The feature rows [e(t), (BW - model.bw_mean) / 10, COMED] for all subjects and times
    go through ``model.mlp`` in one batch, and a sigmoid turns the outputs into
    probabilities. ``model.mlp`` is assumed to return one logit per row, i.e. shape (rows, 1).

    The model is put in eval mode during the forward pass, so any dropout is disabled.
    Its previous train/eval mode is restored afterwards.

    agg:
      - "min"      : probability >= threshold at every time point in the window
      - "trough"   : probability >= threshold at the last time point (t = last_window_h)
      - "mean"     : mean probability over the window >= threshold
      - "coverage" : fraction of time points with probability >= threshold is >= alpha

    Returns
    -------
    float in [0, 1]: number of successful subjects divided by ``Nsubj``.

    Raises
    ------
    ValueError
        If ``Nsubj`` < 1 or ``agg`` is unknown.
    """
    if Nsubj < 1:
        raise ValueError("Nsubj must be >= 1")
    if agg not in ("min", "trough", "mean", "coverage"):
        raise ValueError("Unknown agg")

    device = next(model.parameters()).device
    bw_arr, cm_arr = subject_cov_sampler(pd_df, col_BW, col_COMED, Nsubj, scenario=scenario)
    # Presumably tau_from_cov_np returns a numpy array of length Nsubj (TODO confirm).
    tau = np.asarray(model.tau_from_cov_np(bw_arr, cm_arr, device=device), dtype=float).reshape(-1)
    tau = np.maximum(tau, 1e-6)
    tgrid = np.arange(0.0, last_window_h + 1e-6, grid_step_h, dtype=float)
    n_t = tgrid.size

    # Steady-state accumulation factor per subject, then exposure over the window: (Nsubj, n_t).
    denom = np.maximum(1.0 - np.exp(-float(freq_h) / tau), 1e-6)
    e_t = dose_mg * np.exp(-tgrid[None, :] / tau[:, None]) / denom[:, None]
    bwc = (np.asarray(bw_arr, dtype=float) - float(model.bw_mean)) / 10.0
    cm = np.asarray(cm_arr, dtype=float)
    feats = np.stack(
        [e_t, np.broadcast_to(bwc[:, None], e_t.shape), np.broadcast_to(cm[:, None], e_t.shape)],
        axis=-1,
    ).reshape(-1, 3)  # subject-major rows, matching the reshape below
    X = torch.tensor(feats, dtype=torch.float32, device=device)

    was_training = model.training
    model.eval()  # deterministic inference: disables dropout
    try:
        logits = model.mlp(X)
    finally:
        model.train(was_training)
    probs = torch.sigmoid(logits).reshape(Nsubj, n_t).cpu().numpy()

    if agg == "min":
        success = probs.min(axis=1) >= decision_threshold
    elif agg == "trough":
        success = probs[:, -1] >= decision_threshold
    elif agg == "mean":
        success = probs.mean(axis=1) >= decision_threshold
    else:  # "coverage"
        success = (probs >= decision_threshold).mean(axis=1) >= alpha

    return int(success.sum()) / Nsubj

def search_min_dose_ss(
    model: "PDMLPClassifier", grid, freq_h, last_window_h, pd_df,
    col_BW: Optional[str], col_COMED: Optional[str],
    Nsubj=300, scenario="base", target=0.9, decision_threshold=0.5,
    agg="min", alpha=0.9, grid_step_h=1.0
):
    """Scan a dose grid and return the smallest dose whose success fraction reaches ``target``.

    Each dose in ``grid`` is evaluated with ``success_fraction_for_dose_ss_mlp``. All
    remaining arguments are forwarded unchanged; ``grid_step_h`` sets the time resolution
    inside the evaluation window.

    Returns
    -------
    (df_res, best):
        df_res : DataFrame with columns ["dose", "fraction"], sorted by dose. It is empty
                 when ``grid`` is empty.
        best   : the smallest dose with fraction >= target, or None if no dose qualifies.
                 Because this is the first feasible dose, a non-monotone dose-response
                 curve can still fall below ``target`` at larger doses.
    """
    rows = [
        {
            "dose": d,
            "fraction": success_fraction_for_dose_ss_mlp(
                model, d, freq_h, last_window_h, pd_df, col_BW, col_COMED,
                Nsubj=Nsubj, scenario=scenario, decision_threshold=decision_threshold,
                grid_step_h=grid_step_h, agg=agg, alpha=alpha
            ),
        }
        for d in grid
    ]
    if not rows:
        # pd.DataFrame([]) has no "dose" column, so sort_values would raise KeyError.
        return pd.DataFrame(columns=["dose", "fraction"]), None
    df_res = pd.DataFrame(rows).sort_values("dose")
    feas = df_res[df_res["fraction"]>=target]
    best = feas.iloc[0]["dose"] if len(feas)>0 else None
    return df_res, best

# -------- Training run (best validation checkpoint restored) --------
model, valid_best = train_classifier(
    train_ds,
    valid_ds,
    epochs=EPOCHS,
    lr=LR,
    mlp_hidden=MLP_HIDDEN,
    mlp_layers=MLP_LAYERS,
    mlp_dropout=MLP_DROPOUT,
    seed=SEED,
)
print("\nValidation best:", valid_best)

# -------- Tune the decision threshold on the validation set, then reuse it --------
y_va, p_va = predict_dataset(model, valid_ds)
thr_opt = find_optimal_threshold(y_va, p_va, mode="f1")
print(f"\n[validation] optimal threshold (by F1): {thr_opt:.3f}")

# -------- Dose search (thr_opt applied, configurable aggregation) --------
daily_grid = [i * 0.5 for i in range(121)]   # 0..60 mg in 0.5 mg steps
weekly_grid = [i * 5 for i in range(41)]     # 0..200 mg in 5 mg steps


def _search_regimen(regimen, scenario, target):
    """Run search_min_dose_ss for one dosing regimen using the globals above.

    Assumed regimens: once-daily uses a 24 h interval and window with "min" aggregation;
    once-weekly uses a 168 h interval and window with "trough" aggregation.
    """
    grid, interval_h, agg = {
        "daily": (daily_grid, 24, "min"),
        "weekly": (weekly_grid, 168, "trough"),
    }[regimen]
    return search_min_dose_ss(
        model, grid, interval_h, interval_h, pdf, col_BW, col_COMED,
        Nsubj=N_SUBJ, scenario=scenario, target=target,
        decision_threshold=thr_opt, agg=agg,
    )


daily_base, best_daily_base = _search_regimen("daily", "base", 0.90)
weekly_base, best_weekly_base = _search_regimen("weekly", "base", 0.90)

daily_bw, best_daily_bw = _search_regimen("daily", "bw_wide", 0.90)
weekly_bw, best_weekly_bw = _search_regimen("weekly", "bw_wide", 0.90)

daily_nocm, best_daily_nocm = _search_regimen("daily", "no_comed", 0.90)
weekly_nocm, best_weekly_nocm = _search_regimen("weekly", "no_comed", 0.90)

daily_75, best_daily_75 = _search_regimen("daily", "base", 0.75)
weekly_75, best_weekly_75 = _search_regimen("weekly", "base", 0.75)

# One summary row per (scenario, target): the minimal qualifying dose for each regimen.
summary = pd.DataFrame([
    {"scenario": label, "target": tgt, "once-daily (mg)": daily_mg, "once-weekly (mg)": weekly_mg}
    for label, tgt, daily_mg, weekly_mg in (
        ("Base (Phase 1-like)", "90%", best_daily_base, best_weekly_base),
        ("BW 70–140 kg", "90%", best_daily_bw, best_weekly_bw),
        ("No COMED allowed", "90%", best_daily_nocm, best_weekly_nocm),
        ("Base (Phase 1-like)", "75%", best_daily_75, best_weekly_75),
    )
])
print("\n=== Dose recommendations summary (thr_opt applied) ===")
print(summary.to_string(index=False))

# -------- TEST 평가 (thr_opt로) --------
def evaluate_dataset(model, dataset, thr=0.5, batch_size=512):
    """Score ``model`` on ``dataset`` at decision threshold ``thr``.

    Collects labels and probabilities with ``predict_dataset``, which leaves the model in
    eval mode. Returns whatever ``_compute_metrics`` produces for them. Elsewhere in this
    file that result is a dict with keys such as "acc", "f1", "prec", "rec", "roc_auc"
    and "pr_auc".
    """
    labels_and_probs = predict_dataset(model, dataset, batch_size=batch_size)
    return _compute_metrics(*labels_and_probs, thr=thr)

# Final held-out evaluation at the validation-tuned threshold.
test_metrics = evaluate_dataset(model, test_ds, thr=thr_opt)
print("\n=== Final TEST metrics (unseen IDs, thr_opt applied) ===")
for k, v in test_metrics.items():
    # Floats (including numpy float64) get 4 decimals; anything else prints via str().
    print(f"{k:>8s}: {v:.4f}" if isinstance(v, float) else f"{k:>8s}: {v}")


Detected columns: {'ID': 'ID', 'TIME': 'TIME', 'DV': 'DV', 'EVID': 'EVID', 'AMT': 'AMT', 'BW': 'BW', 'COMED': 'COMED'}
PD rows: 1200 Dose rows: 756
[ID split by fixed dose groups] (70/15/15)
 train IDs: 24 | valid IDs: 6 | test IDs: 6
 1mg -> train/valid/test: 8 2 2
 3mg -> train/valid/test: 8 2 2
10mg -> train/valid/test: 8 2 2
#samples -> train: 600 | valid: 150 | test: 150
[ep 001] acc=0.787 f1=0.000 prec=0.000 rec=0.000 roc_auc=0.705 pr_auc=0.387
[ep 002] acc=0.787 f1=0.000 prec=0.000 rec=0.000 roc_auc=0.585 pr_auc=0.456
[ep 003] acc=0.787 f1=0.000 prec=0.000 rec=0.000 roc_auc=0.599 pr_auc=0.465
[ep 010] acc=0.827 f1=0.629 prec=0.579 rec=0.688 roc_auc=0.796 pr_auc=0.562
[ep 020] acc=0.880 f1=0.735 prec=0.694 rec=0.781 roc_auc=0.914 pr_auc=0.852
[ep 030] acc=0.947 f1=0.867 prec=0.929 rec=0.812 roc_auc=0.976 pr_auc=0.939
[ep 040] acc=0.953 f1=0.881 prec=0.963 rec=0.812 roc_auc=0.991 pr_auc=0.972
[ep 050] acc=0.953 f1=0.881 prec=0.963 rec=0.812 roc_auc=0.984 pr_auc=0.956
[ep 060] acc=