# Generate few-shot / subsample splits: ICU mortality

## Purpose
Create few-shot training subsets for ICU mortality experiments (several random replicates per shot size k), plus a `split.csv` patient list for each data split.

## Inputs
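- `labeled_patients.csv` files produced earlier by the `mimic_icu_mortality` labeler, one per split, under `{ROOT}/{train,tuning,held_out}/femr_labels/mimic_icu_mortality/`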

## Outputs
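- `split.csv` (columns `patient_id`, `split`) written to each of `{ROOT}/train`, `{ROOT}/tuning`, and `{ROOT}/held_out`
- `{SHOT_STRAT}_shots_data.json` (e.g. `few_shots_data.json`), written to `{ROOT}/train/femr_labels/mimic_icu_mortality/`. It holds the k-shot training subsets and the full validation set for every replicate.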


In [None]:
from loguru import logger
import os, csv, json, random
from typing import Dict, List, Tuple
from femr.labelers import load_labeled_patients

In [None]:
# Root directory that holds one sub-directory per data split.
ROOT = "/root/autodl-tmp/femr"
# Labeler whose outputs are consumed; also the sub-directory name under each
# split's femr_labels directory.
LABELING_FUNCTION = "mimic_icu_mortality"

# Per-split label directories; each is expected to contain
# <LABELING_FUNCTION>/labeled_patients.csv.
TRAIN_LABELS_DIR = os.path.join(ROOT, "train", "femr_labels")
TUNING_LABELS_DIR = os.path.join(ROOT, "tuning", "femr_labels")
HELDOUT_LABELS_DIR = os.path.join(ROOT, "held_out", "femr_labels")

# Per-split patient lists (columns: patient_id, split) written by this notebook.
TRAIN_SPLIT_CSV = os.path.join(ROOT, "train", "split.csv")
TUNING_SPLIT_CSV = os.path.join(ROOT, "tuning", "split.csv")
HELDOUT_SPLIT_CSV = os.path.join(ROOT, "held_out", "split.csv")

In [None]:
# Shot sizes to generate: "few" uses every k in KS; "all" uses only k = -1
# (the full training pool).
SHOT_STRAT = "few"
# Training-subset sizes; -1 means "every training instance".
# Only consulted when SHOT_STRAT == "few".
KS = [1, 8, 16, 32, 64, 128, -1]
# Replicates per k; replicate r shuffles with seed SEED + r.
N_REPLICATES, SEED = 5, 97
# Class mix of each subset: "balanced" splits k roughly 1:1 between positives
# and negatives; "prevalence" follows the training pool's positive rate.
BALANCE = "balanced"


def _write_split_csv_from_labeled(lp_path: str, out_csv: str, split_name: str):
    """Write a ``split.csv`` listing every patient found in a labeled-patients file.

    Each distinct patient id in ``lp_path`` becomes one row
    ``(patient_id, split_name)``. Rows are sorted by numeric patient id and
    preceded by a ``patient_id,split`` header.

    Args:
        lp_path: Path to a ``labeled_patients.csv`` readable by
            ``femr.labelers.load_labeled_patients``.
        out_csv: Destination CSV path; its parent directory is created if needed.
        split_name: Value written to the ``split`` column (e.g. "train", "val", "test").
    """
    lp = load_labeled_patients(lp_path)
    # Mapping keys are treated as patient ids; normalize to int so they sort
    # numerically and de-duplicate consistently.
    pids = sorted({int(pid) for pid, _labels in lp.items()})

    out_dir = os.path.dirname(out_csv)
    if out_dir:  # os.makedirs("") raises, so skip it for a bare filename
        os.makedirs(out_dir, exist_ok=True)
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["patient_id", "split"])
        w.writerows([pid, split_name] for pid in pids)

    logger.info(f"[split.csv] {out_csv}: {split_name} = {len(pids)} patients")


def _load_instances(lp_path: str) -> List[Tuple[int, str, int]]:
    """Load boolean labels as de-duplicated ``(patient_id, time_iso, value)`` tuples.

    ``time_iso`` is the label time truncated to minute resolution
    (``isoformat(timespec="minutes")``). ``value`` is 1 when the label value is
    truthy and 0 otherwise.

    Labels that collide on ``(patient_id, time_iso)`` are collapsed to the first
    one encountered. Collisions whose values disagree are counted and reported
    with a warning instead of being dropped silently.

    Raises:
        ValueError: if the labeled-patients file does not hold boolean labels.
    """
    lp = load_labeled_patients(lp_path)
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if lp.labeler_type != "boolean":
        raise ValueError(f"expected boolean labels in {lp_path}, got: {lp.labeler_type}")

    inst: List[Tuple[int, str, int]] = []
    seen: Dict[Tuple[int, str], int] = {}  # (pid, minute) -> first value kept
    n_conflicts = 0
    for pid, labels in lp.items():
        pid = int(pid)
        for lab in labels:
            t_iso = lab.time.replace(second=0, microsecond=0).isoformat(timespec="minutes")
            key = (pid, t_iso)
            v = 1 if bool(lab.value) else 0
            if key in seen:
                if seen[key] != v:
                    n_conflicts += 1
                continue  # keep only the first label for this (patient, minute)
            seen[key] = v
            inst.append((pid, t_iso, v))

    if n_conflicts:
        logger.warning(
            f"{lp_path}: {n_conflicts} duplicate (patient_id, time) labels disagreed "
            f"in value; kept the first of each"
        )
    return inst

def _stratified_orders(train_list, seed):
    """（/）。。"""
    pos = [i for i,(_,_,v) in enumerate(train_list) if v == 1]
    neg = [i for i,(_,_,v) in enumerate(train_list) if v == 0]
    rng = random.Random(seed); rng.shuffle(pos); rng.shuffle(neg)
    return pos, neg

def _extend_to_k(order: List[int], k: int) -> List[int]:
    """>=k，；order """
    if k <= 0: return order[:]
    if not order: return []  # 0，
    if len(order) >= k: return order[:k]
    reps, rem = divmod(k, len(order))
    return order * reps + order[:rem]

def _build_train_indices(pos_order, neg_order, k, balance="balanced"):
    """ k （balanced / prevalence）（）"""
    if k <= 0:
        return pos_order + neg_order

    npos, nneg = len(pos_order), len(neg_order)
    if balance == "balanced":
        k_pos = k // 2
        k_neg = k - k_pos
        if npos == 0: k_pos, k_neg = 0, k
        if nneg == 0: k_pos, k_neg = k, 0
    else:
        tot = max(1, npos + nneg)
        prop = npos / tot
        k_pos = int(round(k * prop))
        k_neg = k - k_pos

    pos_take = _extend_to_k(pos_order, k_pos)
    neg_take = _extend_to_k(neg_order, k_neg)
    return pos_take + neg_take

In [None]:

# Per-split labeled_patients.csv files produced by the labeler.
train_lp   = os.path.join(TRAIN_LABELS_DIR,   LABELING_FUNCTION, "labeled_patients.csv")
tuning_lp  = os.path.join(TUNING_LABELS_DIR,  LABELING_FUNCTION, "labeled_patients.csv")
heldout_lp = os.path.join(HELDOUT_LABELS_DIR, LABELING_FUNCTION, "labeled_patients.csv")

# Explicit check (``assert`` is stripped under ``python -O``) that reports every
# missing file at once instead of stopping at the first.
_missing_lp = [p for p in (train_lp, tuning_lp, heldout_lp) if not os.path.exists(p)]
if _missing_lp:
    raise FileNotFoundError(f"labeled_patients.csv not found: {_missing_lp}")

# Regenerate split.csv for each split; tuning is tagged "val", held-out "test".
_write_split_csv_from_labeled(train_lp,   TRAIN_SPLIT_CSV,   "train")
_write_split_csv_from_labeled(tuning_lp,  TUNING_SPLIT_CSV,  "val")
_write_split_csv_from_labeled(heldout_lp, HELDOUT_SPLIT_CSV, "test")

In [None]:

# De-duplicated (patient_id, time_iso, label) instances for train and val.
train_pool = _load_instances(train_lp)   # [(pid, time_iso, y), ...]
val_pool   = _load_instances(tuning_lp)

logger.info(f"train ={len(train_pool)}, val ={len(val_pool)}")

# Column views of the pools, addressed by the instance indices below.
train_pids  = [p for p, _, _ in train_pool]
train_times = [t for _, t, _ in train_pool]
train_vals  = [v for _, _, v in train_pool]
val_pids    = [p for p, _, _ in val_pool]
val_times   = [t for _, t, _ in val_pool]
val_vals    = [v for _, _, v in val_pool]

# SHOT_STRAT == "all" keeps only the full training pool (k = -1).
k_list = [-1] if SHOT_STRAT == "all" else sorted(KS)

# The orderings depend only on the replicate seed, so compute them once per
# replicate and reuse them for every k. Recomputing per k gave identical lists.
rep_orders = [_stratified_orders(train_pool, SEED + rep) for rep in range(N_REPLICATES)]

# Every entry uses the full validation pool; build its columns once.
val_idxs_all = list(range(len(val_pool)))
val_vals_int = [int(v) for v in val_vals]

few_shots_dict = {LABELING_FUNCTION: {}}

for k in k_list:
    per_rep = few_shots_dict[LABELING_FUNCTION][k] = {}
    for rep, (pos_order, neg_order) in enumerate(rep_orders):
        train_idxs = _build_train_indices(pos_order, neg_order, k, BALANCE)
        # Fresh list copies per entry so entries never alias each other.
        per_rep[rep] = {
            "patient_ids_train_k":  [train_pids[i] for i in train_idxs],
            "patient_ids_val_k":    list(val_pids),
            "label_times_train_k":  [train_times[i] for i in train_idxs],
            "label_times_val_k":    list(val_times),
            "label_values_train_k": [int(train_vals[i]) for i in train_idxs],
            "label_values_val_k":   list(val_vals_int),
            "train_idxs":           train_idxs,
            "val_idxs":             list(val_idxs_all),
        }

# json.dump turns the int k / replicate keys into strings (e.g. "-1", "0").
out_json = os.path.join(TRAIN_LABELS_DIR, LABELING_FUNCTION, f"{SHOT_STRAT}_shots_data.json")
os.makedirs(os.path.dirname(out_json), exist_ok=True)
# ensure_ascii=False may emit non-ASCII text, so pin the file encoding.
with open(out_json, "w", encoding="utf-8") as f:
    json.dump(few_shots_dict, f, ensure_ascii=False)

logger.success(f"[ok] wrote {out_json}")

def _rate(vals): return (sum(vals)/len(vals)) if vals else 0.0
# Report the size and positive rate of the train/val parts of every subset.
for k in k_list:
    replicates = few_shots_dict[LABELING_FUNCTION][k]
    for rep in replicates:
        blob = replicates[rep]
        trn_n = len(blob["train_idxs"])
        val_n = len(blob["val_idxs"])
        trn_r = _rate(blob["label_values_train_k"])
        val_r = _rate(blob["label_values_val_k"])
        logger.info(f"k={k}, rep={rep}: train={trn_n} (pos={trn_r:.3f}), val={val_n} (pos={val_r:.3f})")
