# Train/Evaluate: multi-label phenotyping (alternative notebook)

## Purpose
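Standalone training/evaluation notebook for the 25-label `mimic_icu_phenotyping` task. It fits one-vs-rest logistic-regression heads on count or CLMBR features and reports micro/macro AUROC, AUPRC and F1, plus subset (exact-match) accuracy.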

## Inputs
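- Count or CLMBR feature pickles for the train, tuning and held-out splits
- The shots JSON (train/validation samples per K and replicate), the held-out `labeled_patients.csv`, and the 25-phenotype label-vocabulary JSON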

## Outputs
- A per-replicate summary table (printed) and a JSON report written to `runs/<model>-multilabel-summary.json`


In [None]:
import os, json, pickle, datetime, warnings
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
from scipy.sparse import load_npz, csr_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score, brier_score_loss, accuracy_score
from loguru import logger

from ehrshot.labelers.core import load_labeled_patients

In [None]:
ROOT = "/root/autodl-tmp/femr"


def _femr_split_dirs(split_name: str) -> dict:
    """Return the FEMR extract / labels / features directories for one split under ROOT."""
    split_root = f"{ROOT}/{split_name}"
    return {
        "db": f"{split_root}/extract",
        "labels": f"{split_root}/femr_labels",
        "features": f"{split_root}/femr_features",
    }


# Per-split directory layout.
TRAIN = _femr_split_dirs("train")
TUNING = _femr_split_dirs("tuning")
HELDOUT = _femr_split_dirs("held_out")

In [None]:
# Task and shot configuration.
LABELING_FUNCTION = "mimic_icu_phenotyping"
SHOT_STRAT = "all"  # or "few"
# K values: [-1] for all-shot; e.g. [8, 16, 32] for few-shot.
K_VALUES = [-1]
SHOTS_JSON = os.path.join(
    TRAIN["labels"],
    LABELING_FUNCTION,
    SHOT_STRAT + "_shots_data.json",
)
MODELS = ["count", "clmbr"]

In [None]:
# Feature directory per split, keyed the way the training code expects
# ("train" / "tuning" / "held").
_FEATURE_DIRS_BY_SPLIT = {
    "train": TRAIN["features"],
    "tuning": TUNING["features"],
    "held": HELDOUT["features"],
}
COUNT_FILES = {
    split: os.path.join(feat_dir, LABELING_FUNCTION, "count_features.pkl")
    for split, feat_dir in _FEATURE_DIRS_BY_SPLIT.items()
}
CLMBR_FILES = {
    split: os.path.join(feat_dir, LABELING_FUNCTION, "clmbr_features.pkl")
    for split, feat_dir in _FEATURE_DIRS_BY_SPLIT.items()
}
# =================================================================

In [None]:
import os, json, pickle, datetime, warnings
from typing import List, Tuple, Dict, Any
from dataclasses import dataclass

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix, issparse

from loguru import logger
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MaxAbsScaler
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, accuracy_score

from ehrshot.labelers.core import load_labeled_patients

In [None]:
def _to_minute_iso(x) -> str:
    """Normalize a timestamp-like value to an ISO-8601 string at minute precision.

    ``x`` may be anything ``pd.to_datetime`` accepts (string, ``datetime``,
    ``Timestamp``). Seconds and sub-second parts are dropped, e.g.
    ``"2020-01-02 03:04:05"`` -> ``"2020-01-02T03:04"``.
    """
    as_py = pd.to_datetime(x).to_pydatetime()
    truncated = as_py.replace(microsecond=0, second=0)
    return truncated.isoformat(timespec="minutes")

In [None]:
def load_count_split_multilabel(path_pkl: str) -> Tuple[csr_matrix, pd.DataFrame]:
    """Load one split of count features together with its (patient_id, time) row index.

    The pickle is expected to hold a 4-tuple
    ``(X, patient_ids, label_values, label_times)``. The stored label values are
    ignored. Multilabel targets are joined in separately (from the shots payload or
    ``labeled_patients.csv``).

    Args:
        path_pkl: Path to ``count_features.pkl``. The file is unpickled, so it must
            come from a trusted source.

    Returns:
        ``(X, rows_df)``. ``X`` is a CSR matrix; dense input is converted.
        ``rows_df`` has an int ``patient_id`` column and a minute-precision ISO
        ``time`` column, with one row per row of ``X``.

    Raises:
        ValueError: if the number of feature rows differs from the number of
            (patient_id, time) entries.
    """
    # NOTE: pickle.load can execute arbitrary code; only load feature dumps
    # produced by this pipeline.
    with open(path_pkl, "rb") as f:
        X, pids, _y_ignore, times = pickle.load(f)

    # Normalize to CSR so downstream row slicing is uniform.
    X = X.tocsr() if issparse(X) else csr_matrix(np.asarray(X))

    rows_df = pd.DataFrame({
        "patient_id": np.asarray(pids, dtype=int),
        "time": [_to_minute_iso(t) for t in times],
    })
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if X.shape[0] != len(rows_df):
        raise ValueError(
            f"{path_pkl}: feature matrix has {X.shape[0]} rows but "
            f"{len(rows_df)} (patient_id, time) entries"
        )
    return X, rows_df

In [None]:
def _first_present(d: dict, keys):
    """Return ``(value, key)`` for the first key in ``keys`` whose value in ``d`` is not None.

    Returns ``(None, None)`` when none of the candidate keys has a non-None value.
    """
    return next(
        ((d[key], key) for key in keys if key in d and d[key] is not None),
        (None, None),
    )

def load_clmbr_split(pkl_path: str) -> Tuple[Any, pd.DataFrame]:
    """Load one split of CLMBR features together with its (patient_id, time) row index.

    Two pickle layouts are accepted:

    - a dict with the features under the first non-None of ``"features"``,
      ``"representations"``, ``"X"``, ``"data_matrix"``; patient ids under
      ``"patient_id"``, ``"patient_ids"`` or ``"pids"``; and times under
      ``"time"``, ``"times"``, ``"label_times"`` or ``"labeling_time"``;
    - a 2-element tuple/list ``(features, rows)``. ``rows`` is either a DataFrame
      with ``patient_id`` / ``time`` columns or anything convertible to an
      ``(n, 2)`` array of ``(patient_id, time)`` pairs.

    Args:
        pkl_path: Path to ``clmbr_features.pkl``. The file is unpickled, so it must
            come from a trusted source.

    Returns:
        ``(features, rows_df)``. ``features`` is returned as stored. ``rows_df``
        has an int ``patient_id`` column and a minute-precision ISO ``time`` column.

    Raises:
        FileNotFoundError: if ``pkl_path`` does not exist.
        ValueError: if required entries are missing, ``rows`` is not ``(n, 2)``,
            the layout is unrecognized, or the number of feature rows differs from
            the number of (patient_id, time) entries.
    """
    if not os.path.exists(pkl_path):
        raise FileNotFoundError(f"CLMBR feature file not found: {pkl_path}")
    # NOTE: pickle.load can execute arbitrary code; only load trusted dumps.
    with open(pkl_path, "rb") as f:
        obj = pickle.load(f)

    if isinstance(obj, dict):
        feats, _ = _first_present(obj, ["features", "representations", "X", "data_matrix"])
        pids, _ = _first_present(obj, ["patient_id", "patient_ids", "pids"])
        times, _ = _first_present(obj, ["time", "times", "label_times", "labeling_time"])
        if feats is None or pids is None or times is None:
            raise ValueError(
                f"{pkl_path}: missing features / patient ids / times; keys = {list(obj.keys())}"
            )
        rows_df = pd.DataFrame({
            "patient_id": np.asarray(pids).astype(int),
            "time": [_to_minute_iso(t) for t in times],
        })
    elif isinstance(obj, (list, tuple)) and len(obj) == 2:
        feats, rows = obj
        if isinstance(rows, pd.DataFrame):
            rows_df = rows.copy()
        else:
            rows = np.asarray(rows)
            # Check ndim first: rows.shape[1] would raise IndexError on 1-D input.
            if rows.ndim != 2 or rows.shape[1] != 2:
                raise ValueError(
                    f"{pkl_path}: expected (n, 2) (patient_id, time) rows, got shape {rows.shape}"
                )
            rows_df = pd.DataFrame(rows, columns=["patient_id", "time"])
        rows_df["patient_id"] = rows_df["patient_id"].astype(int)
        rows_df["time"] = rows_df["time"].apply(_to_minute_iso)
    else:
        raise ValueError(f" CLMBR pkl ：{type(obj)}")

    # Features and row index must line up one-to-one, otherwise later row slicing
    # would silently pair patients with the wrong representations.
    shape = getattr(feats, "shape", None)
    if shape:
        n_feat_rows = shape[0]
    elif hasattr(feats, "__len__"):
        n_feat_rows = len(feats)
    else:
        n_feat_rows = None
    if n_feat_rows is not None and n_feat_rows != len(rows_df):
        raise ValueError(
            f"{pkl_path}: features have {n_feat_rows} rows but "
            f"{len(rows_df)} (patient_id, time) entries"
        )
    return feats, rows_df


def _label_value_to_tags(v) -> List[str]:
    """Convert one label value (None / str / list / tuple / set / scalar) to a list of strings."""
    if v is None:
        return []
    if isinstance(v, str):
        return [v]
    if isinstance(v, (list, tuple, set)):
        return [str(x) for x in v]
    return [str(v)]


def instances_from_labeled_csv_multilabel(path_to_labels_dir: str, task: str) -> List[Tuple[int, str, List[str]]]:
    """Read ``<labels_dir>/<task>/labeled_patients.csv`` into multilabel instances.

    Args:
        path_to_labels_dir: Labels root directory.
        task: Labeling-function name (sub-directory of ``path_to_labels_dir``).

    Returns:
        A list of ``(patient_id, time_iso, tags)`` tuples, one per distinct
        (patient_id, minute-truncated time), in first-seen order.
        ``time_iso`` has the form ``YYYY-MM-DDTHH:MM``. ``tags`` holds the label
        values as strings. When several labels share the same
        (patient_id, time), their tags are merged rather than dropped, keeping
        first-seen order without repeats.

    Raises:
        ValueError: if the labels are not of labeler type ``"multilabel"``.
    """
    lp_csv = os.path.join(path_to_labels_dir, task, "labeled_patients.csv")
    lp = load_labeled_patients(lp_csv)
    labeler_type = lp.get_labeler_type()
    if labeler_type != "multilabel":
        raise ValueError(f"{task}: expected a multilabel labeler, got {labeler_type!r}")

    # dict keeps insertion order, so output order matches first occurrence.
    tags_by_key: Dict[Tuple[int, str], List[str]] = {}
    for pid, labels in lp.items():
        for lab in labels:
            t_iso = lab.time.replace(second=0, microsecond=0).isoformat(timespec="minutes")
            merged = tags_by_key.setdefault((int(pid), t_iso), [])
            for tag in _label_value_to_tags(lab.value):
                if tag not in merged:
                    merged.append(tag)
    return [(pid, t_iso, tags) for (pid, t_iso), tags in tags_by_key.items()]


import ast
def _parse_bracketed_list(s: str) -> List[str]:
    """Parse a list literal such as "['A', 'B']" into ["A", "B"].

    Uses ``ast.literal_eval``, which only evaluates literals. If that fails
    (e.g. unquoted items, "[A, B]"), it falls back to splitting on commas and
    stripping surrounding quotes.
    """
    try:
        return [str(i) for i in ast.literal_eval(s)]
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # The exceptions ast.literal_eval documents for malformed input.
        return [t.strip().strip("'\"") for t in s.strip("[]").split(",") if t.strip()]


def _ensure_list_of_str(x) -> List[str]:
    """Coerce a label-value field into ``list[str]``.

    - None -> [] (same convention as the labeled-patients reader)
    - "['A','B']" -> ["A", "B"]
    - ["['A','B']"] -> ["A", "B"] (single-element container holding a stringified list)
    - ["A", "B"] / {"A", "B"} / ("A", "B") -> ["A", "B"]
    - "A" / 123 -> ["A"] / ["123"] (plain strings are whitespace-stripped)
    """
    if x is None:
        return []

    if isinstance(x, (list, tuple, set)) and len(x) == 1:
        only = next(iter(x))
        if isinstance(only, str):
            s = only.strip()
            if s.startswith("[") and s.endswith("]"):
                return _parse_bracketed_list(s)

    if isinstance(x, str):
        s = x.strip()
        if s.startswith("[") and s.endswith("]"):
            return _parse_bracketed_list(s)
        return [s]

    if isinstance(x, (list, tuple, set)):
        return [str(i) for i in x]
    return [str(x)]

def _join_rows_multilabel(rows_df, samples, label_vocab):
    """Align labeled samples to feature rows and build a multi-hot target matrix.

    Args:
        rows_df: DataFrame with ``patient_id`` and ``time`` columns, one row per
            feature row. The loaders in this notebook produce minute-precision
            ISO times.
        samples: Iterable of ``(patient_id, time, labels)``. ``time`` is
            normalized with ``_to_minute_iso`` before matching, so values with
            seconds or other parseable formats still match. ``labels`` is
            anything ``_ensure_list_of_str`` accepts.
        label_vocab: Ordered label names; column ``j`` of the output is
            ``label_vocab[j]``.

    Returns:
        ``(idxs, Y)``. ``idxs`` is an int array of matched feature-row positions.
        ``Y`` is the aligned ``(len(idxs), len(label_vocab))`` 0/1 int matrix.
        Samples without a matching feature row are skipped, and labels outside
        the vocabulary are ignored; both are counted and logged as warnings.
    """
    lab2idx = {lab: i for i, lab in enumerate(label_vocab)}
    key2row = {
        (int(pid), t): i
        for i, (pid, t) in enumerate(zip(rows_df["patient_id"].values, rows_df["time"].values))
    }
    n_labels = len(label_vocab)
    idxs, Y, missed, oov = [], [], 0, 0
    for pid, t_raw, labs in samples:
        try:
            t_key = _to_minute_iso(t_raw)
        except (ValueError, TypeError, AttributeError):
            t_key = t_raw  # unparseable time: fall back to exact-string matching
        r = key2row.get((int(pid), t_key))
        if r is None:
            missed += 1
            continue
        y = np.zeros(n_labels, dtype=int)
        for lab in _ensure_list_of_str(labs):
            j = lab2idx.get(lab)
            if j is None:
                oov += 1
                continue
            y[j] = 1
        idxs.append(r)
        Y.append(y)
    if missed:
        logger.warning(f"[align] {missed}  rows_df ，")
    if oov:
        logger.warning(f"[align] {oov}  label_vocab，")
    Y_mat = np.vstack(Y) if Y else np.zeros((0, n_labels), dtype=int)
    return np.array(idxs, dtype=int), Y_mat

In [None]:
def _metric_or_none(metric, y_true, y_score):
    """Return ``float(metric(y_true, y_score))``, or None if the metric raises ValueError.

    scikit-learn raises ValueError for undefined cases, e.g. ROC AUC with a single
    class in ``y_true``, or scores containing NaN.
    """
    try:
        return float(metric(y_true, y_score))
    except ValueError:
        return None


def evaluate_multilabel(Y_true: np.ndarray, P: np.ndarray, thr: float = 0.5) -> Dict[str, float]:
    """Compute ranking and thresholded metrics for a multilabel prediction.

    Args:
        Y_true: ``(n, L)`` 0/1 ground-truth matrix.
        P: ``(n, L)`` predicted probabilities aligned with ``Y_true``.
        thr: Decision threshold for F1 and subset accuracy. A length-``L`` array
            also works (per-label thresholds) because it broadcasts over ``P``.

    Returns:
        Dict with these keys:

        - ``auroc_micro`` / ``auprc_micro``: computed over all flattened cells.
        - ``auroc_macro`` / ``auprc_macro``: mean over labels whose column
          contains both classes.
        - ``f1_micro``, ``f1_macro`` and ``subset_acc`` (exact-match accuracy).

        Metrics that cannot be computed are NaN. With zero rows, every metric
        is NaN.
    """
    nan = float("nan")
    Y_true = np.asarray(Y_true)
    P = np.asarray(P)
    if Y_true.shape[0] == 0:
        return {key: nan for key in ("auroc_micro", "auprc_micro", "auroc_macro",
                                     "auprc_macro", "f1_micro", "f1_macro", "subset_acc")}

    # Micro: treat every (row, label) cell as one binary prediction.
    y_flat, p_flat = Y_true.ravel(), P.ravel()
    auroc_micro = _metric_or_none(roc_auc_score, y_flat, p_flat)
    auprc_micro = _metric_or_none(average_precision_score, y_flat, p_flat)

    # Macro: average per-label scores, skipping labels with a single class here.
    aucs, aprs = [], []
    for j in range(Y_true.shape[1]):
        yj, pj = Y_true[:, j], P[:, j]
        if len(np.unique(yj)) < 2:
            continue
        auc = _metric_or_none(roc_auc_score, yj, pj)
        if auc is not None:
            aucs.append(auc)
        apr = _metric_or_none(average_precision_score, yj, pj)
        if apr is not None:
            aprs.append(apr)

    pred = (P >= thr).astype(int)
    return {
        "auroc_micro": nan if auroc_micro is None else auroc_micro,
        "auprc_micro": nan if auprc_micro is None else auprc_micro,
        "auroc_macro": float(np.mean(aucs)) if aucs else nan,
        "auprc_macro": float(np.mean(aprs)) if aprs else nan,
        "f1_micro": float(f1_score(Y_true, pred, average="micro", zero_division=0)),
        "f1_macro": float(f1_score(Y_true, pred, average="macro", zero_division=0)),
        "subset_acc": float(accuracy_score(Y_true, pred)),
    }

def find_best_micro_threshold(Y_val: np.ndarray, P_val: np.ndarray,
                              grid: np.ndarray | None = None) -> float:
    if grid is None:
        grid = np.linspace(0.05, 0.95, 19)
    best_thr, best_f1 = 0.5, -1.0
    for t in grid:
        f1 = f1_score(Y_val, (P_val >= t).astype(int), average="micro", zero_division=0)
        if f1 > best_f1:
            best_f1, best_thr = f1, float(t)
    return best_thr


In [None]:
def _slice_rows(X, idxs: np.ndarray):
    """Select the rows ``idxs`` from a feature matrix.

    Any SciPy sparse input (CSR, CSC, COO, sparse arrays, ...) is converted to
    CSR before row indexing, so the result is CSR. Everything else (ndarray, a
    list of per-row vectors, ...) goes through ``np.asarray`` and fancy indexing.
    """
    if issparse(X):
        # tocsr() returns CSR input unchanged. COO has no row indexing, and
        # np.asarray on a sparse matrix would produce a useless 0-d object array.
        return X.tocsr()[idxs]
    return np.asarray(X)[idxs]

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MaxAbsScaler, StandardScaler
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.naive_bayes import ComplementNB
from sklearn.multiclass import OneVsRestClassifier
import numpy as np

def build_multilabel_head(model_type: str, head: str, rep: int):
    """Build an unfitted one-vs-rest multilabel classifier head.

    Args:
        model_type: Feature type. "count" selects MaxAbsScaler for the SGD head
            and is the only type allowed for the NB head. Any other value (e.g.
            "clmbr") selects StandardScaler(with_mean=False) for SGD.
        head: Case-insensitive head name:
            - "saga": OVR(MaxAbsScaler -> L2 LogisticRegression, saga solver,
              C=0.5, tol=1e-3, max_iter=800, class_weight="balanced").
            - "sgd": OVR(scaler -> SGDClassifier, log_loss, alpha=1e-4,
              max_iter=15, early stopping on a 10% validation fraction).
            - "nb": OVR(ComplementNB(alpha=0.1)), count features only.
        rep: Used as ``random_state`` for the "saga" and "sgd" heads.

    Returns:
        ``(clf, needs_scaled_X)``. The second element is True for the "saga" and
        "sgd" heads, which include a scaler step, and False for "nb".

    Raises:
        ValueError: for head "nb" with non-count features, or an unknown head.
    """
    head = head.lower()

    def _saga_head():
        logreg = LogisticRegression(
            solver="saga",
            penalty="l2",
            C=0.5,
            tol=1e-3,  # looser than scikit-learn's default tolerance (1e-4)
            max_iter=800,
            n_jobs=-1,
            class_weight="balanced",
            random_state=rep,
            verbose=0,
        )
        return OneVsRestClassifier(make_pipeline(MaxAbsScaler(), logreg), n_jobs=-1), True

    def _sgd_head():
        if model_type == "count":
            scaler = MaxAbsScaler()
        else:
            scaler = StandardScaler(with_mean=False)
        sgd = SGDClassifier(
            loss="log_loss",
            penalty="l2",
            alpha=1e-4,
            max_iter=15,
            early_stopping=True,
            n_iter_no_change=3,
            validation_fraction=0.1,
            learning_rate="optimal",
            class_weight=None,  # no class reweighting inside the estimator
            random_state=rep,
        )
        return OneVsRestClassifier(make_pipeline(scaler, sgd), n_jobs=-1), True

    def _nb_head():
        if model_type != "count":
            raise ValueError("head='nb'  count 。")
        # No scaler step: ComplementNB is fitted on X as given.
        return OneVsRestClassifier(ComplementNB(alpha=0.1), n_jobs=-1), False

    builders = {"saga": _saga_head, "sgd": _sgd_head, "nb": _nb_head}
    if head not in builders:
        raise ValueError(f" head: {head}")
    return builders[head]()

    

def _load_feature_splits(model_type: str,
                         COUNT_FILES: Dict[str, str],
                         CLMBR_FILES: Dict[str, str]):
    """Load ``(X, rows_df)`` for the train, tuning and held splits of one feature type."""
    if model_type == "count":
        loader, files = load_count_split_multilabel, COUNT_FILES
    else:
        loader, files = load_clmbr_split, CLMBR_FILES
    return [loader(files[split]) for split in ("train", "tuning", "held")]


def _payload_samples(payload: dict, split: str) -> List[tuple]:
    """Zip a shot payload's ``patient_ids_/label_times_/label_values_{split}_k`` lists."""
    return list(zip(payload[f"patient_ids_{split}_k"],
                    payload[f"label_times_{split}_k"],
                    payload[f"label_values_{split}_k"]))


def _make_ovr_logreg(rep: int, use_scaler: bool):
    """Build a one-vs-rest L2 logistic regression (saga), optionally preceded by MaxAbsScaler."""
    logreg = LogisticRegression(
        solver="saga", penalty="l2", C=1.0, tol=1e-6, max_iter=5000,
        n_jobs=-1, class_weight="balanced", random_state=rep, verbose=0,
    )
    steps = [MaxAbsScaler(), logreg] if use_scaler else [logreg]
    return OneVsRestClassifier(make_pipeline(*steps), n_jobs=-1)


def run_one_model_multilabel(model_type: str,
                             shots_for_k: Dict[int, dict],
                             label_vocab: List[str],
                             COUNT_FILES: Dict[str, str],
                             CLMBR_FILES: Dict[str, str],
                             HELDOUT_LABELS_DIR: str,
                             TASK_NAME: str,
                             use_pipeline_scaler: bool = True) -> List[Dict[str, Any]]:
    """Train and evaluate one feature type on every shot replicate of a single K.

    For each ``rep`` the function does three things:

    1. Fits a one-vs-rest logistic-regression head on the rep's train rows.
    2. Picks one micro-F1 threshold on the rep's validation rows.
    3. Evaluates on the full held-out split, whose labels come from
       ``<HELDOUT_LABELS_DIR>/<TASK_NAME>/labeled_patients.csv``.

    Args:
        model_type: "count" loads count features; any other value loads CLMBR features.
        shots_for_k: ``{rep: payload}``. Each payload holds
            ``patient_ids_{train,val}_k``, ``label_times_{train,val}_k`` and
            ``label_values_{train,val}_k``. Label values are lists of strings or
            stringified lists.
        label_vocab: Ordered label names defining the target columns.
        COUNT_FILES: Maps split key ("train" / "tuning" / "held") to the
            count-feature pickle path.
        CLMBR_FILES: Maps split key ("train" / "tuning" / "held") to the
            CLMBR-feature pickle path.
        HELDOUT_LABELS_DIR: Held-out labels root directory.
        TASK_NAME: Labeling-function name.
        use_pipeline_scaler: If True (default), each binary logistic regression is
            preceded by MaxAbsScaler. If False, it is fitted on the raw features.

    Returns:
        One dict per trained rep, with keys ``model``, ``rep``, ``metrics_train``,
        ``metrics_val``, ``metrics_test``, ``thresholds`` and ``data_stats``.
        Reps whose train or validation samples align to no feature row are
        skipped with a warning.

    Raises:
        ValueError: if no held-out labeled instance aligns to a held-out feature row.
    """
    logger.info(f"=== [{model_type}] (multilabel)  ===")
    (X_tr, rows_tr), (X_va, rows_va), (X_te, rows_te) = _load_feature_splits(
        model_type, COUNT_FILES, CLMBR_FILES)

    held_instances = instances_from_labeled_csv_multilabel(HELDOUT_LABELS_DIR, TASK_NAME)
    idx_te, Y_te = _join_rows_multilabel(rows_te, held_instances, label_vocab)
    logger.info(f"[{model_type}] held_out : {len(idx_te)}/{len(held_instances)}")
    if len(idx_te) == 0:
        raise ValueError(
            f"[{model_type}] no held-out instance of {TASK_NAME} aligns to a held-out feature row"
        )
    # Held-out rows are the same for every rep: slice them once.
    Xte = _slice_rows(X_te, idx_te)

    results = []
    for rep, payload in shots_for_k.items():
        tr_samples = _payload_samples(payload, "train")
        va_samples = _payload_samples(payload, "val")
        idx_tr, Y_tr = _join_rows_multilabel(rows_tr, tr_samples, label_vocab)
        idx_va, Y_va = _join_rows_multilabel(rows_va, va_samples, label_vocab)
        logger.info(f"[{model_type}|rep={rep}]  train={len(idx_tr)}/{len(tr_samples)}, val={len(idx_va)}/{len(va_samples)}")
        if len(idx_tr) == 0 or len(idx_va) == 0:
            # sklearn cannot fit or predict on zero rows; skip instead of crashing mid-run.
            logger.warning(f"[{model_type}|rep={rep}] no aligned train or val rows; skipping this rep")
            continue

        Xtr = _slice_rows(X_tr, idx_tr)
        Xva = _slice_rows(X_va, idx_va)

        clf = _make_ovr_logreg(rep, use_pipeline_scaler)
        clf.fit(Xtr, Y_tr)

        P_tr = clf.predict_proba(Xtr)    # (n_tr, L)
        P_va = clf.predict_proba(Xva)    # (n_va, L)
        P_te = clf.predict_proba(Xte)    # (n_te, L)

        # A single threshold is tuned on validation and reused for every split.
        best_thr = find_best_micro_threshold(Y_va, P_va)

        m_tr = evaluate_multilabel(Y_tr, P_tr, thr=best_thr)
        m_va = evaluate_multilabel(Y_va, P_va, thr=best_thr)
        m_te = evaluate_multilabel(Y_te, P_te, thr=best_thr)

        logger.success(
            f"[{model_type}|rep={rep}] "
            f"AUROC micro/macro tr/va/te = "
            f"{m_tr['auroc_micro']:.4f}/{m_tr['auroc_macro']:.4f} | "
            f"{m_va['auroc_micro']:.4f}/{m_va['auroc_macro']:.4f} | "
            f"{m_te['auroc_micro']:.4f}/{m_te['auroc_macro']:.4f} || "
            f"F1 micro/macro tr/va/te @best = "
            f"{m_tr['f1_micro']:.3f}/{m_tr['f1_macro']:.3f} | "
            f"{m_va['f1_micro']:.3f}/{m_va['f1_macro']:.3f} | "
            f"{m_te['f1_micro']:.3f}/{m_te['f1_macro']:.3f}"
        )

        results.append({
            "model": model_type,
            "rep": rep,
            "metrics_train": m_tr,
            "metrics_val": m_va,
            "metrics_test": m_te,
            "thresholds": {"best_on_val_micro": best_thr},
            "data_stats": {
                "n_train": int(Y_tr.shape[0]),
                "n_val":   int(Y_va.shape[0]),
                "n_test":  int(Y_te.shape[0]),
                # Mean over labels of each label's positive rate.
                "pos_rate_macro_train": float(np.mean(Y_tr.mean(axis=0))) if Y_tr.size else float("nan"),
                "pos_rate_macro_val":   float(np.mean(Y_va.mean(axis=0))) if Y_va.size else float("nan"),
                "pos_rate_macro_test":  float(np.mean(Y_te.mean(axis=0))) if Y_te.size else float("nan"),
            },
        })
    return results

In [None]:
def load_shots(shots_json: str, task: str, ks: List[int]) -> Dict[int, Dict[int, dict]]:
    """Load the shot payloads of ``task`` for each requested K.

    The JSON is expected to look like ``{task: {str(k): {str(rep): payload}}}``.

    Args:
        shots_json: Path to the shots JSON (e.g. ``all_shots_data.json``).
        task: Labeling-function name (top-level key).
        ks: K values to load. K values absent from the file are skipped with a warning.

    Returns:
        ``{k: {rep: payload}}`` with integer K and rep keys, in the order of ``ks``.

    Raises:
        KeyError: if ``task`` is not a top-level key of the JSON.
    """
    with open(shots_json, "r", encoding="utf-8") as f:
        blob = json.load(f)
    if task not in blob:
        raise KeyError(f"{shots_json} has no entry for task {task!r}")
    task_dict = blob[task]

    out: Dict[int, Dict[int, dict]] = {}
    for k in ks:
        # JSON object keys are always strings, so K is looked up as str(k).
        reps = task_dict.get(str(k))
        if reps is None:
            logger.warning(f"[shots] k={k}  {shots_json} ，")
            continue
        out[k] = {int(rep): v for rep, v in reps.items()}
    return out

def _is_multilabel_task(path_to_labels_dir: str, task: str) -> bool:
    """Return True if ``<labels_dir>/<task>/labeled_patients.csv`` has labeler type "multilabel"."""
    labeled = load_labeled_patients(
        os.path.join(path_to_labels_dir, task, "labeled_patients.csv")
    )
    labeler_type = labeled.get_labeler_type()
    return labeler_type == "multilabel"

# Column order of the per-rep summary table (also the key order of each JSON row).
_SUMMARY_COLUMNS = [
    "k", "model", "rep",
    "val_auroc_micro", "val_auroc_macro",
    "val_auprc_micro", "val_auprc_macro",
    "test_auroc_micro", "test_auroc_macro",
    "test_auprc_micro", "test_auprc_macro",
    "val_f1_micro", "val_f1_macro",
    "test_f1_micro", "test_f1_macro",
    "test_subset_acc",
    "best_thr",
]


def _summary_row(r: dict) -> dict:
    """Flatten one result dict into a summary row of val/test metrics plus the chosen threshold."""
    va, te = r["metrics_val"], r["metrics_test"]
    return {
        "k": r["k"], "model": r["model"], "rep": r["rep"],
        "val_auroc_micro": va["auroc_micro"], "val_auroc_macro": va["auroc_macro"],
        "val_auprc_micro": va["auprc_micro"], "val_auprc_macro": va["auprc_macro"],
        "test_auroc_micro": te["auroc_micro"], "test_auroc_macro": te["auroc_macro"],
        "test_auprc_micro": te["auprc_micro"], "test_auprc_macro": te["auprc_macro"],
        "val_f1_micro": va["f1_micro"], "val_f1_macro": va["f1_macro"],
        "test_f1_micro": te["f1_micro"], "test_f1_macro": te["f1_macro"],
        "test_subset_acc": te["subset_acc"],
        "best_thr": r["thresholds"]["best_on_val_micro"],
    }


def eval_one_model_multitask(model_name: str,
                             SHOTS_JSON: str,
                             LABELING_FUNCTION: str,
                             K_VALUES: List[int],
                             COUNT_FILES: Dict[str, str],
                             CLMBR_FILES: Dict[str, str],
                             TRAIN_LABELS_DIR: str,
                             TUNING_LABELS_DIR: str,
                             HELDOUT_LABELS_DIR: str,
                             LABEL_VOCAB: List[str],
                             out_dir: str = "runs"):
    """Run every requested K for one feature type, print a summary, and save a JSON report.

    Only multilabel tasks are supported. The task type is read from the held-out
    ``labeled_patients.csv``, and a non-multilabel task raises
    NotImplementedError once there is a K to run.

    Args:
        model_name: Feature type passed to ``run_one_model_multilabel``
            ("count" or "clmbr").
        SHOTS_JSON: Shots JSON path.
        LABELING_FUNCTION: Task name.
        K_VALUES: K values to run. Those missing from the shots JSON are skipped.
        COUNT_FILES: Maps split key to the count-feature pickle path.
        CLMBR_FILES: Maps split key to the CLMBR-feature pickle path.
        TRAIN_LABELS_DIR: Accepted for interface compatibility; not used here.
        TUNING_LABELS_DIR: Accepted for interface compatibility; not used here.
        HELDOUT_LABELS_DIR: Held-out labels root.
        LABEL_VOCAB: Ordered label names.
        out_dir: Directory for ``<model_name>-multilabel-summary.json``.

    Returns:
        ``(summary, all_results)``. ``summary`` is a DataFrame with one row per
        (k, rep), sorted by k / model / rep. ``all_results`` is the list of
        per-rep result dicts.
    """
    shots = load_shots(SHOTS_JSON, LABELING_FUNCTION, K_VALUES)
    IS_MULTI = _is_multilabel_task(HELDOUT_LABELS_DIR, LABELING_FUNCTION)

    all_results = []
    for k, reps in shots.items():
        logger.info(f"******** K={k} | model={model_name} | task={'multilabel' if IS_MULTI else 'binary'} ********")
        if not IS_MULTI:
            raise NotImplementedError(" multilabel； run_one_model。")
        res = run_one_model_multilabel(
            model_type=model_name,
            shots_for_k=reps,
            label_vocab=LABEL_VOCAB,
            COUNT_FILES=COUNT_FILES,
            CLMBR_FILES=CLMBR_FILES,
            HELDOUT_LABELS_DIR=HELDOUT_LABELS_DIR,
            TASK_NAME=LABELING_FUNCTION,
        )
        for r in res:
            r["k"] = k
            r["model"] = model_name
        all_results.extend(res)

    rows = [_summary_row(r) for r in all_results]
    if not rows:
        logger.warning(f"[{model_name}] no results produced (no matching K values or usable reps)")
    # Explicit columns keep sort_values valid even when `rows` is empty.
    summary = (pd.DataFrame(rows, columns=_SUMMARY_COLUMNS)
               .sort_values(by=["k", "model", "rep"])
               .reset_index(drop=True))
    print(f"\n===== Summary for model: {model_name} =====")
    print(summary)

    os.makedirs(out_dir, exist_ok=True)
    report_path = os.path.join(out_dir, f"{model_name}-multilabel-summary.json")
    with open(report_path, "w") as f:
        json.dump({"summary": rows, "results": all_results}, f, ensure_ascii=False, indent=2)
    print(f"[Saved] JSON report -> {report_path}")

    return summary, all_results

In [None]:
def load_label_vocab_from_json(path: str) -> list[str]:
    import json
    with open(path, "r") as f:
        d = json.load(f)

    if isinstance(d, dict) and "root" in d and isinstance(d["root"], dict):
        labels = list(d["root"].keys())
    elif isinstance(d, dict):
        labels = list(d.keys())
    elif isinstance(d, list):
        labels = [str(x) for x in d]
    else:
        raise ValueError(f" {path} ，={type(d)}")

    def _clean(s: str) -> str:
        return " ".join(str(s).split())

    labels = [_clean(x) for x in labels]

    assert len(labels) == 25, f" 25 ， {len(labels)}； JSON "

    return labels


In [None]:
# ---- Run configuration ------------------------------------------------------
LABELING_FUNCTION = "mimic_icu_phenotyping"
PHENOTYPE_JSON_PATH = "/root/autodl-tmp/phenotypes_ccs_from_parent.json"  # 25-label phenotype vocabulary

LABEL_VOCAB = load_label_vocab_from_json(PHENOTYPE_JSON_PATH)

FEMR_ROOT = "/root/autodl-tmp/femr"
# Feature-dict key -> split directory name under FEMR_ROOT.
_SPLIT_DIR_BY_KEY = {"train": "train", "tuning": "tuning", "held": "held_out"}
COUNT_FILES = {
    key: f"{FEMR_ROOT}/{split_dir}/femr_features/{LABELING_FUNCTION}/count_features.pkl"
    for key, split_dir in _SPLIT_DIR_BY_KEY.items()
}
CLMBR_FILES = {
    key: f"{FEMR_ROOT}/{split_dir}/femr_features/{LABELING_FUNCTION}/clmbr_features.pkl"
    for key, split_dir in _SPLIT_DIR_BY_KEY.items()
}
TRAIN_LABELS_DIR   = f"{FEMR_ROOT}/train/femr_labels"
TUNING_LABELS_DIR  = f"{FEMR_ROOT}/tuning/femr_labels"
HELDOUT_LABELS_DIR = f"{FEMR_ROOT}/held_out/femr_labels"

# The shots file and the K values must agree. All-shot runs use K = -1 only;
# few-shot runs use the few-shot K values. K values absent from the shots JSON
# are skipped with a warning by load_shots.
SHOT_STRAT = "all"  # "all" or "few"
SHOTS_JSON = f"{TRAIN_LABELS_DIR}/{LABELING_FUNCTION}/{SHOT_STRAT}_shots_data.json"
K_VALUES = [-1] if SHOT_STRAT == "all" else [8, 16, 32, 64]

summary, all_results = eval_one_model_multitask(
    model_name="count",  # or "clmbr"
    SHOTS_JSON=SHOTS_JSON,
    LABELING_FUNCTION=LABELING_FUNCTION,
    K_VALUES=K_VALUES,
    COUNT_FILES=COUNT_FILES,
    CLMBR_FILES=CLMBR_FILES,
    TRAIN_LABELS_DIR=TRAIN_LABELS_DIR,
    TUNING_LABELS_DIR=TUNING_LABELS_DIR,
    HELDOUT_LABELS_DIR=HELDOUT_LABELS_DIR,
    LABEL_VOCAB=LABEL_VOCAB,
)
