# Train/Evaluate: COUNT vs CLMBR (mortality + phenotyping)

## Purpose
Train logistic regression baselines on COUNT and CLMBR features for both tasks; includes both single-label and multi-label handling.

## Inputs
- COUNT and CLMBR feature PKLs from step 04
- Split definitions / shot files (if used)

## Outputs
- Model metrics, predictions, and result artifacts written to your output directories


In [None]:
import os, json, pickle, datetime, warnings
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
from scipy.sparse import load_npz, csr_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score, brier_score_loss, accuracy_score
from loguru import logger

from ehrshot.labelers.core import load_labeled_patients

In [None]:
# Root directory holding the per-split data folders used below.
ROOT = "/root/autodl-tmp/femr"
# Name of the labeling function (task); used below as a sub-directory name
# under each split's labels/features directory.
LABELING_FUNCTION = "mimic_icu_phenotyping"

# Per-split directory layout: "db" (extract), "labels" and "features".
# Training split.
TRAIN = {
    "db":       f"{ROOT}/train/extract",
    "labels":   f"{ROOT}/train/femr_labels",
    "features": f"{ROOT}/train/femr_features",
}
# Tuning (validation) split.
TUNING = {
    "db":       f"{ROOT}/tuning/extract",
    "labels":   f"{ROOT}/tuning/femr_labels",
    "features": f"{ROOT}/tuning/femr_features",
}
# Held-out (test) split.
HELDOUT = {
    "db":       f"{ROOT}/held_out/extract",
    "labels":   f"{ROOT}/held_out/femr_labels",
    "features": f"{ROOT}/held_out/femr_features",
}

In [None]:
# Shot strategy; it selects which "<strategy>_shots_data.json" file is read.
SHOT_STRAT = "all"  # alternative: "few"
# Shots file for the current task, located under the training labels directory.
SHOTS_JSON = os.path.join(TRAIN["labels"], LABELING_FUNCTION, f"{SHOT_STRAT}_shots_data.json")
# Shot sizes to run.
K_VALUES   = [-1]  # "all" strategy: [-1]; "few" strategy: [8,16,32]
# Feature families to evaluate.
MODELS = ["count", "clmbr"]

In [None]:
# Pickled COUNT feature files, one per split, under
# <split features dir>/<LABELING_FUNCTION>/count_features.pkl.
COUNT_FILES = {
    "train": os.path.join(TRAIN["features"], LABELING_FUNCTION, "count_features.pkl"),
    "tuning": os.path.join(TUNING["features"],LABELING_FUNCTION, "count_features.pkl"),
    "held":   os.path.join(HELDOUT["features"], LABELING_FUNCTION,"count_features.pkl"),
}

# Pickled CLMBR feature files, same layout with clmbr_features.pkl.
CLMBR_FILES = {
    "train": os.path.join(TRAIN["features"],LABELING_FUNCTION, "clmbr_features.pkl"),
    "tuning":os.path.join(TUNING["features"],LABELING_FUNCTION, "clmbr_features.pkl"),
    "held":  os.path.join(HELDOUT["features"], LABELING_FUNCTION, "clmbr_features.pkl"),
}
# =================================================================

In [None]:
# ------------------ Count features loader ------------------
import pickle
from scipy.sparse import csr_matrix, issparse

def _to_minute_iso(x):
    dt = pd.to_datetime(x).to_pydatetime()
    return dt.replace(second=0, microsecond=0).isoformat(timespec="minutes")

def load_count_split(path_pkl):
    """
    Load one split of count features from a pickle file.

    The pickle must unpack into four items:
    ``(X, patient_ids, label_values, label_times)``.

    Returns ``(X_csr, rows_df, y)`` where ``X_csr`` is the feature matrix in
    CSR format, ``rows_df`` has the columns ``patient_id`` (int) and ``time``
    (minute-precision ISO string) and ``y`` is an integer label array.

    Raises ``AssertionError`` when the three outputs disagree on row count.
    """
    with open(path_pkl, "rb") as fh:
        features, patient_ids, labels, label_times = pickle.load(fh)

    # Dense input is wrapped into CSR; sparse input is converted to CSR.
    features = features.tocsr() if issparse(features) else csr_matrix(np.asarray(features))

    rows_df = pd.DataFrame({
        "patient_id": np.asarray(patient_ids, dtype=int),
        "time": [_to_minute_iso(stamp) for stamp in label_times],
    })

    labels = np.asarray(labels)
    try:
        # Boolean, integer and other directly castable dtypes all take this path.
        y = labels.astype(int)
    except Exception:
        # Fallback for values that cannot be cast: treat "1"/"True"/"true" as 1.
        as_binary = pd.Series(labels).map(
            lambda v: 1 if str(v).strip() in {"1", "True", "true"} else 0
        )
        y = as_binary.to_numpy(dtype=int)

    assert features.shape[0] == len(rows_df) == len(y), \
        f": X={features.shape[0]}, rows={len(rows_df)}, y={len(y)}"

    return features, rows_df, y


In [None]:
import numpy as np
# Sanity check: load every split of the count features and print its shape,
# positive rate and first two rows.
# The paths are derived from ROOT (same values as the previous hard-coded
# absolute paths) so they cannot silently drift from the split configuration.
COUNT_FILES = {
    split: f"{ROOT}/{subdir}/femr_features/{LABELING_FUNCTION}/count_features.pkl"
    for split, subdir in (("train", "train"), ("tuning", "tuning"), ("held", "held_out"))
}

for split, p in COUNT_FILES.items():
    X, rows, y = load_count_split(p)
    pos = int(y.sum()); n = len(y)
    # Guard against an empty split, which would otherwise raise ZeroDivisionError.
    rate = pos / n if n else float("nan")
    print(f"{split:7s} | X={X.shape}, rows={rows.shape}, pos={pos}/{n} ({rate:.3%})")
    print(rows.head(2), "\n")

In [None]:

# Spot check: print three randomly chosen training rows together with their
# labels. The generator is seeded with 0, and rows are drawn without replacement.
rng = np.random.default_rng(0)
X, rows, y = load_count_split(COUNT_FILES["train"])
idxs = rng.choice(len(rows), size=3, replace=False)
print(rows.iloc[idxs].assign(y=y[idxs]).to_string(index=False))


In [None]:
from sklearn.metrics import precision_recall_curve, f1_score, precision_score, recall_score

def f1_at_threshold(y_true, prob, thr: float):
    """Precision, recall and F1 after binarising ``prob`` at ``thr``.

    A sample is predicted positive when ``prob >= thr``. Returns a dict with
    the keys ``precision``, ``recall``, ``f1`` and ``thr`` (all floats);
    undefined ratios are reported as 0 (``zero_division=0``).
    """
    hard = (prob >= thr).astype(int)
    scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    report = {name: float(fn(y_true, hard, zero_division=0)) for name, fn in scorers}
    report["thr"] = float(thr)
    return report

def find_best_f1_threshold(y_true, prob):
    """Return ``(threshold, f1)`` for the F1-maximising point of the PR curve.

    The final precision/recall entry is dropped so that both arrays line up
    with the threshold array returned by ``precision_recall_curve``. A tiny
    epsilon in the denominator avoids division by zero.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, prob)
    p, r = precision[:-1], recall[:-1]
    f1_curve = 2 * p * r / (p + r + 1e-12)
    best = int(np.argmax(f1_curve))
    return float(thresholds[best]), float(f1_curve[best])

In [None]:
import os, json, pickle, datetime, warnings
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
from scipy.sparse import load_npz, csr_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score, brier_score_loss, accuracy_score
from loguru import logger

from ehrshot.labelers.core import load_labeled_patients

In [None]:
def load_shots(shots_json: str, task: str, ks: List[int]) -> Dict[int, Dict[int, dict]]:
    """Read the shots file and return ``{k: {replicate: payload}}`` for ``task``.

    Only the shot sizes listed in ``ks`` are returned. A ``k`` that is missing
    from the file is skipped with a warning. Replicate keys are converted to
    ``int``. Raises ``AssertionError`` when ``task`` is not in the file.
    """
    with open(shots_json, "r") as fh:
        blob = json.load(fh)
    assert task in blob, f"{shots_json}  {task}"
    per_k = blob[task]

    shots: Dict[int, Dict[int, dict]] = {}
    for k in ks:
        # The stringified key is tried first, then the raw value.
        if str(k) in per_k:
            replicates = per_k[str(k)]
        elif k in per_k:
            replicates = per_k[k]
        else:
            logger.warning(f"[shots] k={k}  {shots_json} ，")
            continue
        shots[k] = {int(rep): payload for rep, payload in replicates.items()}
    return shots

def instances_from_labeled_csv(path_to_labels_dir: str, task: str) -> List[Tuple[int, str, int]]:
    """Read ``<dir>/<task>/labeled_patients.csv`` into ``(patient_id, time_iso, value)`` tuples.

    ``time_iso`` is the label time truncated to the minute; ``value`` is 1 for
    a truthy label value and 0 otherwise. Only the first label seen for each
    ``(patient_id, time_iso)`` pair is kept.
    """
    labeled = load_labeled_patients(
        os.path.join(path_to_labels_dir, task, "labeled_patients.csv")
    )
    instances: List[Tuple[int, str, int]] = []
    visited = set()
    for pid, labels in labeled.items():
        for label in labels:
            minute = label.time.replace(second=0, microsecond=0).isoformat(timespec="minutes")
            key = (int(pid), minute)
            if key not in visited:
                visited.add(key)
                instances.append((int(pid), minute, int(bool(label.value))))
    return instances

def _coerce_time_str(x: str) -> str:
    """ ISO（YYYY-mm-ddTHH:MM）"""
    if isinstance(x, str):
        try:
            dt = datetime.datetime.fromisoformat(x)
        except Exception:
            try:
                dt = pd.to_datetime(x).to_pydatetime()
            except Exception:
                raise
    else:
        dt = x
    return dt.replace(second=0, microsecond=0).isoformat(timespec="minutes")

def _load_rows_csv(rows_path: str) -> pd.DataFrame:
    """Load a rows CSV and return a frame with the columns ``patient_id`` and ``time``.

    Column names are matched case-insensitively: the id column is
    ``patient_id`` or ``pid`` (else the first column); the time column is
    ``time``, ``label_time`` or ``timestamp`` (else the second column).
    Ids are cast to ``int`` and times normalised with ``_coerce_time_str``.
    """
    raw = pd.read_csv(rows_path)
    by_lower = {name.lower(): name for name in raw.columns}

    pid_col = by_lower.get("patient_id") or by_lower.get("pid")
    if not pid_col:
        pid_col = list(raw.columns)[0]

    time_col = by_lower.get("time") or by_lower.get("label_time") or by_lower.get("timestamp")
    if not time_col:
        time_col = list(raw.columns)[1]

    rows = raw[[pid_col, time_col]].copy()
    rows.columns = ["patient_id", "time"]
    rows["patient_id"] = rows["patient_id"].astype(int)
    rows["time"] = rows["time"].apply(_coerce_time_str)
    return rows

def _join_rows(rows_df: pd.DataFrame, samples: List[Tuple[int, str, int]]) -> Tuple[np.ndarray, np.ndarray]:
    """ (row_indices, y)； rows_df ，。"""
    key2row = {(int(pid), t): i for i, (pid, t) in enumerate(zip(rows_df["patient_id"].values,
                                                                 rows_df["time"].values))}
    idxs = []
    ys = []
    missed = 0
    for pid, t_iso, y in samples:
        r = key2row.get((pid, t_iso))
        if r is None:
            missed += 1
            continue
        idxs.append(r)
        ys.append(int(y))
    if missed > 0:
        logger.warning(f"[align]  {missed}  rows.csv ，")
    return np.array(idxs, dtype=int), np.array(ys, dtype=int)

In [None]:
# ------------------ CLMBR features loader ------------------
import os, datetime
import pickle, numpy as np, pandas as pd
from typing import Tuple

def _first_present(d: dict, keys):
    """ (value, key)；， (None, None)。"""
    for k in keys:
        if k in d and d[k] is not None:
            return d[k], k
    return None, None

def load_clmbr_split(pkl_path: str) -> Tuple[object, pd.DataFrame]:
    """
    Load CLMBR features and their row index from a pickle file.

    Two pickle layouts are accepted:
    - a dict holding the features under one of
      "features"/"representations"/"X"/"data_matrix", the patient ids under
      "patient_id"/"patient_ids"/"pids", the times under
      "time"/"times"/"label_times"/"labeling_time", and optionally the labels
      under "label_values"/"labels"/"y";
    - a two-item list/tuple ``(features, rows)`` where ``rows`` is a DataFrame
      or a two-column array of ``(patient_id, time)``.

    Returns ``(features, rows_df)``. ``rows_df`` has an int ``patient_id``
    column and a minute-precision ISO ``time`` column. For the dict layout a
    ``label`` column is added when labels are present with a matching length.

    Raises ``AssertionError`` if the file is missing and ``ValueError`` for an
    unsupported layout or a dict without the required keys.
    """
    assert os.path.exists(pkl_path), f" CLMBR : {pkl_path}"
    with open(pkl_path, "rb") as fh:
        payload = pickle.load(fh)

    def _minute_iso(value):
        # Strings: try the stdlib ISO parser first; everything else (and any
        # string it rejects) goes through pandas.
        parsed = None
        if isinstance(value, str):
            try:
                parsed = datetime.datetime.fromisoformat(value)
            except Exception:
                parsed = None
        if parsed is None:
            parsed = pd.to_datetime(value).to_pydatetime()
        return parsed.replace(second=0, microsecond=0).isoformat(timespec="minutes")

    if isinstance(payload, dict):
        feats, _ = _first_present(payload, ["features", "representations", "X", "data_matrix"])
        pids, _ = _first_present(payload, ["patient_id", "patient_ids", "pids"])
        times, _ = _first_present(payload, ["time", "times", "label_times", "labeling_time"])
        labels, _ = _first_present(payload, ["label_values", "labels", "y"])

        if feats is None or pids is None or times is None:
            raise ValueError(
                f"{pkl_path} ；：{list(payload.keys())}\n"
                f" features/representations/X/data_matrix + patient_id/patient_ids/pids + time/times/label_times/labeling_time"
            )

        rows_df = pd.DataFrame({
            "patient_id": np.asarray(pids).astype(int),
            "time": [_minute_iso(stamp) for stamp in times],
        })
        # Labels are attached only when their length matches the row count.
        if labels is not None and len(labels) == len(rows_df):
            rows_df["label"] = np.asarray(labels)
        return feats, rows_df

    if isinstance(payload, (list, tuple)) and len(payload) == 2:
        feats, rows = payload
        if isinstance(rows, pd.DataFrame):
            rows_df = rows.copy()
        else:
            rows = np.asarray(rows)
            assert rows.shape[1] == 2, f" (patient_id, time)，: {rows.shape}"
            rows_df = pd.DataFrame(rows, columns=["patient_id", "time"])
        rows_df["patient_id"] = rows_df["patient_id"].astype(int)
        rows_df["time"] = rows_df["time"].apply(_minute_iso)
        return feats, rows_df

    raise ValueError(f" CLMBR pkl ：{type(payload)}（：{pkl_path}）")


In [None]:
import os, json
from datetime import datetime
try:
    from zoneinfo import ZoneInfo  # py>=3.9
except Exception:
    ZoneInfo = None

def _timestamp_str():
    tz = ZoneInfo("Asia/Jakarta") if ZoneInfo else None
    return datetime.now(tz).strftime("%Y%m%d-%H%M%S%z")

def _snapshot_lr_params(clf):
    """
     clf（ Pipeline  LogisticRegression ）。
    """
    lr = None
    if hasattr(clf, "named_steps") and "logisticregression" in clf.named_steps:
        lr = clf.named_steps["logisticregression"]
        scaler = type(clf.named_steps.get("maxabsscaler", None)).__name__ if "maxabsscaler" in clf.named_steps else None
    else:
        lr = clf
        scaler = None
    snap = {
        "is_pipeline": hasattr(clf, "named_steps"),
        "scaler": scaler,
        "solver": getattr(lr, "solver", None),
        "penalty": getattr(lr, "penalty", None),
        "C": getattr(lr, "C", None),
        "tol": getattr(lr, "tol", None),
        "max_iter": getattr(lr, "max_iter", None),
        "class_weight": getattr(lr, "class_weight", None),
        "random_state": getattr(lr, "random_state", None),
        "n_iter_": getattr(lr, "n_iter_", None),
    }
    if hasattr(snap["n_iter_"], "tolist"):
        snap["n_iter_"] = snap["n_iter_"].tolist()
    return snap

def _save_run_txt(summary_df, all_results, extra_params=None, out_dir="runs", fname_prefix="report"):
    """Write a plain-text run report and return its path.

    The file ``<out_dir>/<fname_prefix>-<timestamp>.txt`` holds, in order: the
    optional run configuration (``extra_params``), every distinct non-empty
    ``param_snapshot`` found in ``all_results``, the summary DataFrame, and
    each result as pretty-printed JSON. ``out_dir`` is created if needed.
    """
    os.makedirs(out_dir, exist_ok=True)
    stamp = _timestamp_str()
    report_path = os.path.join(out_dir, f"{fname_prefix}-{stamp}.txt")

    # Keep each distinct, non-empty snapshot once, in first-seen order.
    distinct_snaps = []
    for result in all_results:
        snap = result.get("param_snapshot")
        if snap is None:
            continue
        if snap and snap not in distinct_snaps:
            distinct_snaps.append(snap)

    def _pretty(obj):
        return json.dumps(obj, ensure_ascii=False, indent=2, sort_keys=True)

    with open(report_path, "w", encoding="utf-8") as out:
        out.write(f"===== Run Report =====\n")
        out.write(f"time: {stamp}\n")
        if extra_params:
            out.write("\n--- run_config ---\n")
            out.write(_pretty(extra_params))
            out.write("\n")

        out.write("\n--- training_params (unique heads) ---\n")
        if not distinct_snaps:
            out.write("(no param_snapshot found in results)\n")
        else:
            for number, snap in enumerate(distinct_snaps):
                out.write(f"[{number}] {_pretty(snap)}\n")

        out.write("\n--- SUMMARY (DataFrame) ---\n")
        out.write(summary_df.to_string(index=False))
        out.write("\n")

        out.write("\n--- ALL RESULTS (JSON per row) ---\n")
        for result in all_results:
            out.write(_pretty(result) + "\n")
    return report_path

In [None]:
@dataclass
class SplitData:
    """Container grouping one split's feature matrix, row index, selected rows and labels."""
    X:     object  # feature matrix: csr_matrix or ndarray
    rows:  pd.DataFrame  # presumably one (patient_id, time) row per row of X - TODO confirm
    idxs:  np.ndarray  # presumably row indices into X / rows - TODO confirm
    y:     np.ndarray  # labels (0/1)

def slice_rows(X, idxs: np.ndarray):
    """Return the rows of ``X`` selected by ``idxs``.

    Parameters:
        X: a scipy sparse matrix/array of any format, or anything
           ``np.asarray`` can turn into an array.
        idxs: integer row indices.

    Returns:
        A CSR container for sparse input, otherwise an ndarray.

    Previously only ``csr_matrix`` was recognised as sparse; any other sparse
    container was passed to ``np.asarray`` and then failed to index.
    """
    from scipy.sparse import issparse  # local import keeps this block self-contained

    if issparse(X):
        # CSR supports row fancy-indexing; tocsr() is a no-op for CSR input.
        return X.tocsr()[idxs]
    return np.asarray(X)[idxs]

def evaluate_binary(y_true: np.ndarray, prob: np.ndarray, thr: float = 0.5) -> Dict[str, float]:
    """Compute binary-classification metrics from predicted probabilities.

    Returns a dict with ``auroc``, ``auprc``, ``brier``, ``acc@0.5`` and
    ``pos_rate``. A ranking/calibration metric that raises is reported as NaN.
    The accuracy uses ``prob >= thr``; its key is always ``acc@0.5``, whatever
    value ``thr`` has.
    """
    out: Dict[str, float] = {}
    score_fns = (
        ("auroc", roc_auc_score),
        ("auprc", average_precision_score),
        ("brier", brier_score_loss),
    )
    for key, score_fn in score_fns:
        try:
            out[key] = float(score_fn(y_true, prob))
        except Exception:
            out[key] = float("nan")

    hard = (prob >= thr).astype(int)
    out["acc@0.5"] = float(accuracy_score(y_true, hard))
    out["pos_rate"] = float(np.mean(y_true))
    return out

def fit_lr_lbfgs(X_train, y_train) -> LogisticRegression:
    """Fit and return a logistic regression using the lbfgs solver.

    Settings: at most 2000 iterations, all cores, no class weighting.
    """
    settings = dict(
        solver="lbfgs",
        max_iter=2000,
        n_jobs=-1,
        verbose=0,
        class_weight=None,
    )
    estimator = LogisticRegression(**settings)
    estimator.fit(X_train, y_train)
    return estimator

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MaxAbsScaler
from sklearn.linear_model import LogisticRegression
import numpy as np

def fit_lr_saga(X, y, C=1.0, tol=1e-6, max_iter=5000, balanced=False, seed=0):
    """Fit and return a ``MaxAbsScaler`` + L2 logistic regression (saga) pipeline.

    ``balanced=True`` sets ``class_weight="balanced"``; ``seed`` is passed as
    the estimator's ``random_state``.
    """
    head = LogisticRegression(
        solver="saga",
        penalty="l2",
        C=C,
        tol=tol,
        max_iter=max_iter,
        n_jobs=-1,
        verbose=0,
        class_weight=("balanced" if balanced else None),
        random_state=seed,
    )
    pipeline = make_pipeline(MaxAbsScaler(), head)
    pipeline.fit(X, y)
    return pipeline


## MultiLabel

In [None]:
def instances_from_labeled_csv_multilabel(path_to_labels_dir: str, task: str):
    """Read ``<dir>/<task>/labeled_patients.csv`` into ``(patient_id, time_iso, labels)`` tuples.

    The labeler type must be "multilabel" (asserted). ``time_iso`` is the
    label time truncated to the minute; ``labels`` is a list of strings (a
    scalar value is wrapped into a one-element list). Only the first label
    seen for each ``(patient_id, time_iso)`` pair is kept.
    """
    labeled = load_labeled_patients(
        os.path.join(path_to_labels_dir, task, "labeled_patients.csv")
    )
    assert labeled.get_labeler_type() == "multilabel", " multilabel"

    instances = []
    visited = set()
    for pid, labels in labeled.items():
        for label in labels:
            minute = label.time.replace(second=0, microsecond=0).isoformat(timespec="minutes")
            key = (int(pid), minute)
            if key in visited:
                continue
            visited.add(key)
            values = label.value if isinstance(label.value, (list, tuple)) else [label.value]
            instances.append((int(pid), minute, list(map(str, values))))
    return instances

def load_label_vocab_from_json(phenotype_json_path: str):
    """Return the top-level keys of the JSON object in the given file, in file order."""
    with open(phenotype_json_path, "r") as fh:
        mapping = json.load(fh)
    vocab = [name for name in mapping.keys()]
    return vocab

def _join_rows_multilabel(rows_df: pd.DataFrame,
                          samples: list[tuple[int, str, list[str]]],
                          label_vocab: list[str]) -> tuple[np.ndarray, np.ndarray]:
    label_to_idx = {l:i for i,l in enumerate(label_vocab)}
    key2row = {(int(pid), t): i for i, (pid, t) in enumerate(zip(rows_df["patient_id"].values,
                                                                 rows_df["time"].values))}
    idxs, Y, missed, oov = [], [], 0, 0
    for pid, t_iso, labels in samples:
        r = key2row.get((pid, t_iso))
        if r is None:
            missed += 1
            continue
        y = np.zeros(len(label_vocab), dtype=int)
        for lab in labels:
            j = label_to_idx.get(lab)
            if j is None:
                oov += 1
                continue
            y[j] = 1
        idxs.append(r)
        Y.append(y)
    if missed > 0:
        logger.warning(f"[align]  {missed}  rows.csv ，")
    if oov > 0:
        logger.warning(f"[align]  {oov}  label_vocab，")
    return np.array(idxs, dtype=int), (np.vstack(Y) if Y else np.zeros((0, len(label_vocab)), dtype=int))

from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, accuracy_score

def evaluate_multilabel(Y_true: np.ndarray, P: np.ndarray, thr: float = 0.5) -> dict:
    """Compute multilabel metrics from a probability matrix ``P``.

    Returns micro/macro AUROC and AUPRC, micro/macro F1 and subset accuracy at
    threshold ``thr`` (``P >= thr``), and the mean per-label positive rate.
    Micro scores flatten all labels; macro scores average over the labels
    that have both classes present. A score that cannot be computed is NaN.

    NOTE: a function with this name is defined a second time further down in
    this file; that later definition is the one in effect once the whole file
    has run.
    """
    # Labels with a single class cannot be ranked and are left out of macro scores.
    rankable = [j for j in range(Y_true.shape[1]) if len(np.unique(Y_true[:, j])) >= 2]

    def _micro(score_fn):
        try:
            return float(score_fn(Y_true.ravel(), P.ravel()))
        except Exception:
            return float("nan")

    def _macro(score_fn):
        per_label = []
        for j in rankable:
            try:
                per_label.append(score_fn(Y_true[:, j], P[:, j]))
            except Exception:
                pass
        return float(np.mean(per_label)) if per_label else float("nan")

    out = {}
    out["auroc_micro"] = _micro(roc_auc_score)
    out["auroc_macro"] = _macro(roc_auc_score)
    out["auprc_micro"] = _micro(average_precision_score)
    out["auprc_macro"] = _macro(average_precision_score)

    hard = (P >= thr).astype(int)
    out["f1_micro"] = float(f1_score(Y_true, hard, average="micro", zero_division=0))
    out["f1_macro"] = float(f1_score(Y_true, hard, average="macro", zero_division=0))
    out["subset_acc"] = float(accuracy_score(Y_true, hard))

    out["pos_rate_macro"] = float(np.mean(Y_true.mean(axis=0))) if Y_true.size else float("nan")
    return out

def find_best_micro_threshold(Y_val: np.ndarray, P_val: np.ndarray, grid=None) -> float:
    """Return the threshold from ``grid`` that maximises micro-averaged F1.

    ``grid`` defaults to 19 evenly spaced values from 0.05 to 0.95. The first
    best value wins on ties, and 0.5 is returned for an empty grid.

    NOTE: a function with this name is defined a second time further down in
    this file; that later definition is the one in effect once the whole file
    has run.
    """
    candidates = np.linspace(0.05, 0.95, 19) if grid is None else grid
    winner, top_score = 0.5, -1
    for cut in candidates:
        hard = (P_val >= cut).astype(int)
        score = f1_score(Y_val, hard, average="micro", zero_division=0)
        if score > top_score:
            top_score = score
            winner = float(cut)
    return winner

def find_per_class_thresholds(Y_val: np.ndarray, P_val: np.ndarray, grid=None) -> np.ndarray:
    """Return one F1-maximising threshold per label column.

    ``grid`` defaults to 19 evenly spaced values from 0.05 to 0.95. A label
    with only one class in ``Y_val`` keeps the default threshold 0.5.

    NOTE: a function with this name is defined a second time further down in
    this file; that later definition is the one in effect once the whole file
    has run.
    """
    candidates = np.linspace(0.05, 0.95, 19) if grid is None else grid
    n_labels = Y_val.shape[1]
    thresholds = np.full(n_labels, 0.5, dtype=float)
    for col in range(n_labels):
        truth, scores = Y_val[:, col], P_val[:, col]
        if len(np.unique(truth)) < 2:
            continue
        winner, top_score = 0.5, -1
        for cut in candidates:
            score = f1_score(truth, (scores >= cut).astype(int), average="binary", zero_division=0)
            if score > top_score:
                top_score = score
                winner = float(cut)
        thresholds[col] = winner
    return thresholds

from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, accuracy_score

def evaluate_multilabel(Y_true: np.ndarray, P: np.ndarray, thr: float = 0.5) -> dict:
    """Compute multilabel metrics from a probability matrix ``P``.

    Returns micro/macro AUROC and AUPRC, micro/macro F1 and subset accuracy at
    threshold ``thr`` (``P >= thr``), and the mean per-label positive rate.
    Micro scores flatten all labels; macro scores average over the labels
    that have both classes present. A score that cannot be computed is NaN.

    NOTE: this re-defines a function of the same name declared earlier in
    this file.
    """
    # Labels with a single class cannot be ranked and are left out of macro scores.
    rankable = [j for j in range(Y_true.shape[1]) if len(np.unique(Y_true[:, j])) >= 2]

    def _micro(score_fn):
        try:
            return float(score_fn(Y_true.ravel(), P.ravel()))
        except Exception:
            return float("nan")

    def _macro(score_fn):
        per_label = []
        for j in rankable:
            try:
                per_label.append(score_fn(Y_true[:, j], P[:, j]))
            except Exception:
                pass
        return float(np.mean(per_label)) if per_label else float("nan")

    out = {}
    out["auroc_micro"] = _micro(roc_auc_score)
    out["auroc_macro"] = _macro(roc_auc_score)
    out["auprc_micro"] = _micro(average_precision_score)
    out["auprc_macro"] = _macro(average_precision_score)

    hard = (P >= thr).astype(int)
    out["f1_micro"] = float(f1_score(Y_true, hard, average="micro", zero_division=0))
    out["f1_macro"] = float(f1_score(Y_true, hard, average="macro", zero_division=0))
    out["subset_acc"] = float(accuracy_score(Y_true, hard))

    out["pos_rate_macro"] = float(np.mean(Y_true.mean(axis=0))) if Y_true.size else float("nan")
    return out

def find_best_micro_threshold(Y_val: np.ndarray, P_val: np.ndarray, grid=None) -> float:
    """Return the threshold from ``grid`` that maximises micro-averaged F1.

    ``grid`` defaults to 19 evenly spaced values from 0.05 to 0.95. The first
    best value wins on ties, and 0.5 is returned for an empty grid.

    NOTE: this re-defines a function of the same name declared earlier in
    this file.
    """
    candidates = np.linspace(0.05, 0.95, 19) if grid is None else grid
    winner, top_score = 0.5, -1
    for cut in candidates:
        hard = (P_val >= cut).astype(int)
        score = f1_score(Y_val, hard, average="micro", zero_division=0)
        if score > top_score:
            top_score = score
            winner = float(cut)
    return winner

def find_per_class_thresholds(Y_val: np.ndarray, P_val: np.ndarray, grid=None) -> np.ndarray:
    """Return one F1-maximising threshold per label column.

    ``grid`` defaults to 19 evenly spaced values from 0.05 to 0.95. A label
    with only one class in ``Y_val`` keeps the default threshold 0.5.

    NOTE: this re-defines a function of the same name declared earlier in
    this file.
    """
    candidates = np.linspace(0.05, 0.95, 19) if grid is None else grid
    n_labels = Y_val.shape[1]
    thresholds = np.full(n_labels, 0.5, dtype=float)
    for col in range(n_labels):
        truth, scores = Y_val[:, col], P_val[:, col]
        if len(np.unique(truth)) < 2:
            continue
        winner, top_score = 0.5, -1
        for cut in candidates:
            score = f1_score(truth, (scores >= cut).astype(int), average="binary", zero_division=0)
            if score > top_score:
                top_score = score
                winner = float(cut)
        thresholds[col] = winner
    return thresholds

def run_one_model_multilabel(model_type: str, shots_for_k: dict, label_vocab: list[str], sage=False):
    """Train and evaluate a one-vs-rest logistic regression per replicate (multilabel).

    Parameters:
        model_type: "count" loads the count features and rescales them with
            MaxAbs factors fitted on the training split; any other value loads
            the CLMBR features and uses them as loaded.
        shots_for_k: mapping ``{replicate: payload}``; each payload must hold
            the ``patient_ids_*_k`` / ``label_times_*_k`` / ``label_values_*_k``
            lists for ``train`` and ``val``.
        label_vocab: ordered label names defining the columns of the targets.
        sage: not used anywhere in this function.

    Returns:
        A list with one dict per replicate, holding the train/val/test
        metrics, the threshold chosen on validation, and basic data statistics.

    NOTE: ``_ensure_csr32`` and ``inplace_column_scale`` are defined/imported
    further down in this file, so that part must have been executed before
    this function is called with ``model_type == "count"``.
    """
    logger.info(f"=== [{model_type}] (multilabel)  ===")
    if model_type == "count":
        # Labels stored in the count pickles are ignored here (third item dropped).
        X_tr, rows_tr, _ = load_count_split(COUNT_FILES["train"])
        X_va, rows_va, _ = load_count_split(COUNT_FILES["tuning"])
        X_te, rows_te, _ = load_count_split(COUNT_FILES["held"])
        X_tr = _ensure_csr32(X_tr); X_va = _ensure_csr32(X_va); X_te = _ensure_csr32(X_te)
        # Fit the scaler on train only, then apply 1/scale to copies of every
        # split; columns whose scale is 0 keep a factor of 1.
        scaler = MaxAbsScaler(copy=False); scaler.fit(X_tr)
        scale = scaler.scale_.astype(np.float32); inv = np.ones_like(scale, dtype=np.float32)
        nz = scale != 0; inv[nz] = 1.0 / scale[nz]
        X_tr_s = X_tr.copy(); inplace_column_scale(X_tr_s, inv)
        X_va_s = X_va.copy(); inplace_column_scale(X_va_s, inv)
        X_te_s = X_te.copy(); inplace_column_scale(X_te_s, inv)
    else:
        X_tr, rows_tr = load_clmbr_split(CLMBR_FILES["train"])
        X_va, rows_va = load_clmbr_split(CLMBR_FILES["tuning"])
        X_te, rows_te = load_clmbr_split(CLMBR_FILES["held"])

    # Held-out targets come from the labeled-patients CSV, not from the shots
    # file; they are the same for every replicate.
    held_instances = instances_from_labeled_csv_multilabel(HELDOUT["labels"], LABELING_FUNCTION)
    idx_te, Y_te = _join_rows_multilabel(rows_te, held_instances, label_vocab)
    logger.info(f"[{model_type}] held_out : {len(idx_te)}/{len(held_instances)}")

    results = []
    for rep, payload in shots_for_k.items():
        # Train/validation samples for this replicate: (patient_id, time, labels).
        tr_samples = list(zip(payload["patient_ids_train_k"],
                              payload["label_times_train_k"],
                              payload["label_values_train_k"]))  # <- list[str]
        va_samples = list(zip(payload["patient_ids_val_k"],
                              payload["label_times_val_k"],
                              payload["label_values_val_k"]))    # <- list[str]

        idx_tr, Y_tr = _join_rows_multilabel(rows_tr, tr_samples, label_vocab)
        idx_va, Y_va = _join_rows_multilabel(rows_va, va_samples, label_vocab)
        logger.info(f"[{model_type}|rep={rep}] train : {len(idx_tr)}/{len(tr_samples)}, "
                    f"val : {len(idx_va)}/{len(va_samples)}")

        # Select the matched rows (scaled matrices for "count").
        if model_type == "count":
            Xtr = slice_rows(X_tr_s, idx_tr)
            Xva = slice_rows(X_va_s, idx_va)
            Xte = slice_rows(X_te_s, idx_te)
        else:
            Xtr = slice_rows(X_tr, idx_tr)
            Xva = slice_rows(X_va, idx_va)
            Xte = slice_rows(X_te, idx_te)

        # One balanced L2 logistic regression per label; the replicate id is
        # used as the random seed.
        base_lr = LogisticRegression(
            solver="saga", penalty="l2", C=1.0, tol=1e-6, max_iter=5000,
            n_jobs=-1, class_weight="balanced", random_state=rep, verbose=0
        )
        clf = OneVsRestClassifier(base_lr, n_jobs=-1)
        clf.fit(Xtr, Y_tr)

        P_tr = clf.predict_proba(Xtr)   # (n_tr, L)
        P_va = clf.predict_proba(Xva)   # (n_va, L)
        P_te = clf.predict_proba(Xte)   # (n_te, L)

        # A single global threshold is chosen on validation (micro F1) and
        # then applied to all three splits.
        best_thr = find_best_micro_threshold(Y_va, P_va)

        m_tr = evaluate_multilabel(Y_tr, P_tr, thr=best_thr)
        m_va = evaluate_multilabel(Y_va, P_va, thr=best_thr)
        m_te = evaluate_multilabel(Y_te, P_te, thr=best_thr)

        logger.success(
            f"[{model_type}|rep={rep}] AUROC micro/macro (tr/va/te) = "
            f"{m_tr['auroc_micro']:.4f}/{m_tr['auroc_macro']:.4f} | "
            f"{m_va['auroc_micro']:.4f}/{m_va['auroc_macro']:.4f} | "
            f"{m_te['auroc_micro']:.4f}/{m_te['auroc_macro']:.4f}  ||  "
            f"F1@best micro/macro (tr/va/te) = "
            f"{m_tr['f1_micro']:.3f}/{m_tr['f1_macro']:.3f} | "
            f"{m_va['f1_micro']:.3f}/{m_va['f1_macro']:.3f} | "
            f"{m_te['f1_micro']:.3f}/{m_te['f1_macro']:.3f}"
        )

        results.append({
            "model": model_type,
            "rep": rep,
            "metrics_train": m_tr,
            "metrics_val": m_va,
            "metrics_test": m_te,
            "thresholds": {"best_on_val_micro": best_thr},
            "data_stats": {
                "n_train": int(Y_tr.shape[0]),
                "pos_train_macro": float(np.mean(Y_tr.mean(axis=0))),
                "n_val":   int(Y_va.shape[0]),
                "pos_val_macro":  float(np.mean(Y_va.mean(axis=0))),
                "n_test":  int(Y_te.shape[0]),
                "pos_test_macro": float(np.mean(Y_te.mean(axis=0))),
            },
        })
    return results


def _is_multilabel_task(path_to_labels_dir: str, task: str) -> bool:
    """Return True when ``<dir>/<task>/labeled_patients.csv`` has labeler type "multilabel"."""
    labeled = load_labeled_patients(
        os.path.join(path_to_labels_dir, task, "labeled_patients.csv")
    )
    labeler_kind = labeled.get_labeler_type()
    return labeler_kind == "multilabel"

## Single Label

In [None]:
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import MaxAbsScaler
from sklearn.linear_model import LogisticRegression
from scipy.sparse import csr_matrix, issparse
from sklearn.utils.sparsefuncs import inplace_column_scale

def _ensure_csr32(X):
    if not issparse(X):
        X = csr_matrix(X)
    X = X.tocsr()
    if X.dtype != np.float32:
        X = X.astype(np.float32)
    return X

def fit_lr_saga_staged(X, y, C=1.0, seed=0, max_total=2000, step=200, tol=1e-4, balanced=True, verbose=0):
    """Fit an L2 logistic regression (saga) in warm-started stages.

    Each stage continues from the previous coefficients (``warm_start=True``)
    and runs at most ``step`` iterations. Training stops as soon as a stage
    converges before using up its allowance, or when ``max_total`` iterations
    have been spent in total.

    Parameters:
        X, y: training features and labels.
        C, tol: passed to ``LogisticRegression``.
        seed: ``random_state`` of the estimator.
        max_total: upper bound on the summed iterations across all stages.
        step: iteration allowance per stage; must be positive.
        balanced: use ``class_weight="balanced"`` when True.
        verbose: passed to ``LogisticRegression``.

    Returns:
        The fitted estimator (unfitted if ``max_total <= 0``). Its
        ``max_iter`` and ``n_iter_`` describe the last stage only.

    Raises:
        ValueError: if ``step`` is not positive.

    Fix: the previous version increased ``max_iter`` by ``step`` after every
    stage. Because a warm-started ``fit`` runs up to ``max_iter`` additional
    iterations, the stages ran 200, 400, 600, ... iterations and the real
    total far exceeded ``max_total``.
    """
    if step <= 0:
        raise ValueError(f"step must be a positive integer, got {step!r}")

    class_w = "balanced" if balanced else None
    clf = LogisticRegression(
        solver="saga", penalty="l2", C=C,
        tol=tol, max_iter=step, warm_start=True,
        n_jobs=1, verbose=verbose, class_weight=class_w,
        random_state=seed,
    )
    remaining = max_total
    while remaining > 0:
        # The last stage may be shorter so the budget is never exceeded.
        clf.max_iter = min(step, remaining)
        clf.fit(X, y)
        remaining -= clf.max_iter
        # n_iter_ below the allowance means the solver met the tolerance.
        if np.all(clf.n_iter_ < clf.max_iter):
            break
    return clf

def run_one_model(model_type: str, shots_for_k: dict, sage=False):
    """Train and evaluate a binary logistic regression on one feature type.

    For every repetition in ``shots_for_k`` the train/validation samples are
    joined against the feature rows, a classifier is fitted, and metrics are
    computed on the train, validation and held-out splits, including F1 at a
    fixed 0.5 threshold and at the threshold selected on validation data.

    Args:
        model_type: ``"count"`` loads the count features, converts them to CSR
            float32 and scales each column by the training max-abs value; any
            other value loads the CLMBR features without scaling.
        shots_for_k: Mapping ``rep -> payload``; each payload provides the
            ``patient_ids_*_k``, ``label_times_*_k`` and ``label_values_*_k``
            sequences for the ``train`` and ``val`` splits.
        sage: When true, fit with ``fit_lr_saga_staged`` for each ``C`` in a
            small grid and keep the model with the best validation AUROC;
            otherwise fit once with ``fit_lr_lbfgs``.

    Returns:
        A list with one result dict per repetition (metrics, thresholds,
        F1 summaries, a parameter snapshot and data statistics).
    """
    logger.info(f"=== [{model_type}]  ===")
    if model_type == "count":
        # The third value returned by load_count_split is not used here;
        # labels come from the shots payload / labeled-patients CSV instead.
        X_tr, rows_tr, _ = load_count_split(COUNT_FILES["train"])
        X_va, rows_va, _ = load_count_split(COUNT_FILES["tuning"])
        X_te, rows_te, _ = load_count_split(COUNT_FILES["held"])

        X_tr = _ensure_csr32(X_tr)
        X_va = _ensure_csr32(X_va)
        X_te = _ensure_csr32(X_te)

        # Column scale factors are learned on the training split only and
        # then applied to all three splits.
        scaler = MaxAbsScaler(copy=False)
        scaler.fit(X_tr)
        scale = scaler.scale_.astype(np.float32)
        inv = np.ones_like(scale, dtype=np.float32)
        nz = scale != 0
        inv[nz] = 1.0 / scale[nz]

        def _scaled(X):
            # Scale a copy so the matrix returned by the loader stays untouched.
            X_s = X.copy()
            inplace_column_scale(X_s, inv)
            return X_s

        X_tr, X_va, X_te = _scaled(X_tr), _scaled(X_va), _scaled(X_te)
    else:
        X_tr, rows_tr = load_clmbr_split(CLMBR_FILES["train"])
        X_va, rows_va = load_clmbr_split(CLMBR_FILES["tuning"])
        X_te, rows_te = load_clmbr_split(CLMBR_FILES["held"])

    held_instances = instances_from_labeled_csv(HELDOUT["labels"], LABELING_FUNCTION)
    idx_te, y_te = _join_rows(rows_te, held_instances)
    logger.info(f"[{model_type}] held_out : {len(idx_te)}/{len(held_instances)}")
    # The held-out slice does not depend on the repetition, so build it once.
    Xte = slice_rows(X_te, idx_te)

    def _prf(m):
        # Keep only the three headline numbers of an f1_at_threshold result.
        return {"f1": m["f1"], "precision": m["precision"], "recall": m["recall"]}

    results = []
    for rep, payload in shots_for_k.items():
        tr_samples = list(zip(payload["patient_ids_train_k"],
                              payload["label_times_train_k"],
                              payload["label_values_train_k"]))
        va_samples = list(zip(payload["patient_ids_val_k"],
                              payload["label_times_val_k"],
                              payload["label_values_val_k"]))
        idx_tr, y_tr = _join_rows(rows_tr, tr_samples)
        idx_va, y_va = _join_rows(rows_va, va_samples)

        logger.info(f"[{model_type}|rep={rep}] train : {len(idx_tr)}/{len(tr_samples)}, "
                    f"val : {len(idx_va)}/{len(va_samples)}")

        Xtr = slice_rows(X_tr, idx_tr)
        Xva = slice_rows(X_va, idx_va)

        if sage:
            print("sage")
            C_grid = [0.3, 1.0, 3.0]
            best = None
            for C in C_grid:
                # NOTE(review): `rep` is used as the solver seed, so it is
                # assumed to be an integer-like key -- confirm against load_shots.
                clf_try = fit_lr_saga_staged(Xtr, y_tr, C=C, seed=rep, max_total=2000, step=200, tol=1e-4, balanced=True, verbose=0)
                p_va_try = clf_try.predict_proba(Xva)[:, 1]
                auroc = roc_auc_score(y_va, p_va_try)
                # Strict '>' keeps the first (smallest) C on ties.
                if best is None or auroc > best[0]:
                    best = (auroc, C, clf_try)
            # Tag the line with the actual model type: the SAGA path is not
            # limited to the count features.
            logger.info(f"[{model_type}|rep={rep}] SAGA best C={best[1]} (AUROC={best[0]:.4f})")
            clf = best[2]
        else:
            clf = fit_lr_lbfgs(Xtr, y_tr)

        p_tr = clf.predict_proba(Xtr)[:, 1]
        p_va = clf.predict_proba(Xva)[:, 1]
        p_te = clf.predict_proba(Xte)[:, 1]
        m_tr = evaluate_binary(y_tr, p_tr)
        m_va = evaluate_binary(y_va, p_va)
        m_te = evaluate_binary(y_te, p_te)

        # The decision threshold is chosen on validation data only and then
        # reused unchanged for the train and held-out splits.
        best_thr, _ = find_best_f1_threshold(y_va, p_va)
        f1_05_tr = f1_at_threshold(y_tr, p_tr, 0.5)
        f1_05_va = f1_at_threshold(y_va, p_va, 0.5)
        f1_05_te = f1_at_threshold(y_te, p_te, 0.5)
        f1b_tr = f1_at_threshold(y_tr, p_tr, best_thr)
        f1b_va = f1_at_threshold(y_va, p_va, best_thr)
        f1b_te = f1_at_threshold(y_te, p_te, best_thr)

        logger.success(
            f"[{model_type}|rep={rep}] "
            f"AUROC (tr/va/te) = {m_tr['auroc']:.4f} / {m_va['auroc']:.4f} / {m_te['auroc']:.4f}  | "
            f"AUPRC (tr/va/te) = {m_tr['auprc']:.4f} / {m_va['auprc']:.4f} / {m_te['auprc']:.4f}  | "
            f"F1@0.5 (tr/va/te) = {f1_05_tr['f1']:.3f} / {f1_05_va['f1']:.3f} / {f1_05_te['f1']:.3f}"
        )
        logger.info(
            f"[{model_type}|rep={rep}] F1@val-best({best_thr:.3f}) "
            f"(tr/va/te) = {f1b_tr['f1']:.3f} / {f1b_va['f1']:.3f} / {f1b_te['f1']:.3f}  | "
            f"P/R@test = {f1b_te['precision']:.3f}/{f1b_te['recall']:.3f}"
        )

        param_snapshot = _snapshot_lr_params(clf)
        results.append({
            "model": model_type,
            "rep": rep,
            "metrics_train": m_tr,
            "metrics_val": m_va,
            "metrics_test": m_te,
            "thresholds": {"fixed_0.5": 0.5, "best_on_val": best_thr},
            "f1_at_0.5": {
                "train": _prf(f1_05_tr),
                "val":   _prf(f1_05_va),
                "test":  _prf(f1_05_te),
            },
            "f1_at_best": {
                "train": _prf(f1b_tr),
                "val":   _prf(f1b_va),
                "test":  _prf(f1b_te),
            },
            "param_snapshot": param_snapshot,
            "data_stats": {
                "n_train": int(len(y_tr)), "pos_train": int(y_tr.sum()),
                "n_val":   int(len(y_va)), "pos_val":  int(y_va.sum()),
                "n_test":  int(len(y_te)), "pos_test": int(y_te.sum()),
            },
        })
    return results


In [None]:
# Silence every UserWarning-category warning for the rest of the session.
# This is deliberately broad: it hides all warnings in that category.
warnings.filterwarnings("ignore", category=UserWarning)

# Fail with an explicit error instead of an assert, which is stripped when
# Python runs with -O and would let a missing file surface later, less clearly.
if not os.path.exists(SHOTS_JSON):
    raise FileNotFoundError(f"shots JSON not found: {SHOTS_JSON}")
shots = load_shots(SHOTS_JSON, LABELING_FUNCTION, K_VALUES)

## Evaluation

In [None]:
# JSON file the multilabel label vocabulary is loaded from.
PHENOTYPE_JSON_PATH = "/root/autodl-tmp/phenotypes_ccs_from_parent.json"


def eval_one_model(model_name: str, save_txt: bool = True, out_dir: str = "runs", sage=False):
    """Run one feature model over every K in the global ``shots`` and summarise it.

    The task type (binary vs. multilabel) is detected from the held-out
    labels, the matching runner is called for each K, and one summary row per
    (k, model, rep) is collected into a DataFrame that is printed and,
    optionally, written out through ``_save_run_txt``.

    Args:
        model_name: Feature model identifier forwarded to the runner.
        save_txt: Write a TXT report via ``_save_run_txt`` when true.
        out_dir: Output directory forwarded to ``_save_run_txt``.
        sage: Forwarded unchanged to the runner.

    Returns:
        Tuple ``(summary, all_results)``: the sorted summary DataFrame and the
        list of raw per-repetition result dicts.
    """
    multilabel = _is_multilabel_task(HELDOUT["labels"], LABELING_FUNCTION)
    task_name = "multilabel" if multilabel else "binary"
    # The vocabulary is only needed (and only loaded) for multilabel tasks.
    vocab = load_label_vocab_from_json(PHENOTYPE_JSON_PATH) if multilabel else None

    all_results = []
    for k, reps in shots.items():
        logger.info(f"******** K={k} | model={model_name} | task={task_name} ********")
        if multilabel:
            per_rep = run_one_model_multilabel(model_name, reps, label_vocab=vocab, sage=sage)
        else:
            per_rep = run_one_model(model_name, reps, sage=sage)
        for entry in per_rep:
            entry["k"] = k
            entry["model"] = model_name
        all_results.extend(per_rep)

    def _summary_row(entry):
        # Flatten one raw result dict into a single summary-table row.
        te, va = entry["metrics_test"], entry["metrics_val"]
        row = {"k": entry["k"], "model": entry["model"], "rep": entry["rep"]}
        if multilabel:
            row.update({
                "val_auroc_micro": va["auroc_micro"], "val_auroc_macro": va["auroc_macro"],
                "val_auprc_micro": va["auprc_micro"], "val_auprc_macro": va["auprc_macro"],
                "test_auroc_micro": te["auroc_micro"], "test_auroc_macro": te["auroc_macro"],
                "test_auprc_micro": te["auprc_micro"], "test_auprc_macro": te["auprc_macro"],
                "test_subset_acc": te["subset_acc"],
                "val_f1_micro": va["f1_micro"], "test_f1_micro": te["f1_micro"],
                "val_f1_macro": va["f1_macro"], "test_f1_macro": te["f1_macro"],
                "best_thr": entry["thresholds"]["best_on_val_micro"],
            })
        else:
            row.update({
                "val_auroc": va["auroc"], "val_auprc": va["auprc"], "val_brier": va["brier"],
                "test_auroc": te["auroc"], "test_auprc": te["auprc"], "test_brier": te["brier"],
                "test_acc@0.5": te["acc@0.5"], "test_pos_rate": te["pos_rate"],
                "val_f1@best": entry["f1_at_best"]["val"]["f1"],
                "test_f1@best": entry["f1_at_best"]["test"]["f1"],
                "best_thr": entry["thresholds"]["best_on_val"],
            })
        return row

    table = pd.DataFrame([_summary_row(entry) for entry in all_results])
    summary = table.sort_values(by=["k", "model", "rep"]).reset_index(drop=True)
    print(f"\n===== Summary for model: {model_name} =====")
    print(summary)

    if not save_txt:
        return summary, all_results

    extra = {"model": model_name, "k_values": list(shots.keys()), "multilabel": multilabel}
    out_fp = _save_run_txt(summary, all_results, extra_params=extra, out_dir=out_dir, fname_prefix=f"{model_name}")
    print(f"[Saved] TXT report -> {out_fp}")
    return summary, all_results

## Count Model

In [None]:
# Evaluate the first entry of MODELS with sage=False.
# NOTE: eval_one_model returns a (summary, all_results) tuple, so summary_m0
# holds that tuple rather than the summary DataFrame alone.
summary_m0 = eval_one_model(MODELS[0],sage=False)

## CLMBR Model

In [None]:
# Evaluate the second entry of MODELS with sage=True.
# NOTE: eval_one_model returns a (summary, all_results) tuple, so summary_m1
# holds that tuple rather than the summary DataFrame alone.
summary_m1 = eval_one_model(MODELS[1],sage=True)

In [None]:


# Run every model in MODELS for every K in `shots` through run_one_model
# (this cell always uses the binary runner) and tabulate validation/test metrics.
all_results = []
for k, reps in shots.items():
    logger.info(f"******** K={k}（{'ALL' if k==-1 else 'FEW'}）********")
    for m in MODELS:
        res = run_one_model(m, reps)
        # Tag each per-repetition result with the K it was produced for.
        for r in res:
            r["k"] = k
        all_results += res

# One flat row per (k, model, rep) holding the headline metrics.
rows = []
for r in all_results:
    val_metrics = r["metrics_val"]
    test_metrics = r["metrics_test"]
    row = {"k": r["k"], "model": r["model"], "rep": r["rep"]}
    row["val_auroc"] = val_metrics["auroc"]
    row["val_auprc"] = val_metrics["auprc"]
    row["val_brier"] = val_metrics["brier"]
    row["test_auroc"] = test_metrics["auroc"]
    row["test_auprc"] = test_metrics["auprc"]
    row["test_brier"] = test_metrics["brier"]
    row["test_acc@0.5"] = test_metrics["acc@0.5"]
    row["test_pos_rate"] = test_metrics["pos_rate"]
    rows.append(row)

summary = pd.DataFrame(rows)
summary = summary.sort_values(by=["k", "model", "rep"]).reset_index(drop=True)
summary

In [None]:
import pickle, numpy as np, pandas as pd

# Exploratory cell: load one feature pickle and print the structure of its
# top-level object (dict entries, or the first ten items of a list/tuple).
PKL = "/root/autodl-tmp/femr/train/femr_features/count_features.pkl"

with open(PKL, "rb") as f:
    obj = pickle.load(f)


def _describe(v):
    """Return the type of ``v``, plus its shape or length for sized containers."""
    if not isinstance(v, (np.ndarray, pd.DataFrame, list, tuple)):
        return f"{type(v)}"
    try:
        shape = v.shape if hasattr(v, "shape") else len(v)
    except Exception:
        # Best effort only: fall back to a placeholder if the size is unavailable.
        shape = "?"
    return f"{type(v)}, shape/len={shape}"


print("TOP:", type(obj))
if isinstance(obj, dict):
    print("  keys:", list(obj.keys())[:50])
    for k, v in obj.items():
        print(f"   - {k}: {_describe(v)}")

elif isinstance(obj, (list, tuple)):
    print("  len =", len(obj))
    for i, v in enumerate(obj[:10]):
        print(f"   [{i}] {_describe(v)}")

else:
    print("  (unknown top-level)")
