# Generate few-shot / subsample splits: ICU phenotyping

## Purpose
Create few-shot subsets/splits for ICU phenotyping experiments.

## Inputs
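- Label CSVs produced earlier for the task: `all_labels.csv` (preferred) or `labeled_patients.csv` under each split's `femr_labels/<task>/` directory
- Optionally, `train/features/<task>/row_index.json` (used to attach feature row indices to the sampled instances)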

## Outputs
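- `split.csv` written into each split directory (`train/`, `tuning/`, `held_out/`) under `ROOT`
- `<SHOT_STRAT>_shots_data.json` written into `train/femr_labels/<task>/`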


In [None]:
from loguru import logger
import os, csv, json, random
from typing import Dict, List, Tuple
from ehrshot.labelers.core import load_labeled_patients

In [None]:
# Root directory that holds the train / tuning / held_out split folders.
ROOT = "/root/autodl-tmp/femr"
# Task name; used as the sub-directory name under each labels directory and
# as the top-level key of the output JSON.
LABELING_FUNCTION = "mimic_icu_phenotyping"

# Per-split label directories (each is expected to contain <task>/<labels csv>).
TRAIN_LABELS_DIR   = f"{ROOT}/train/femr_labels"
TUNING_LABELS_DIR  = f"{ROOT}/tuning/femr_labels"
HELDOUT_LABELS_DIR = f"{ROOT}/held_out/femr_labels"  # NOTE(review): in the visible cells the held-out labels are only used to write split.csv

# Destination of the patient_id -> split CSV written for each split.
TRAIN_SPLIT_CSV   = f"{ROOT}/train/split.csv"
TUNING_SPLIT_CSV  = f"{ROOT}/tuning/split.csv"
HELDOUT_SPLIT_CSV = f"{ROOT}/held_out/split.csv"

SHOT_STRAT = "few"  # "all" -> only k = -1 is generated; any other value (e.g. "few") -> every k in KS
KS = [1, 8, 16, 32, 64, 128, -1]  # shot sizes iterated unless SHOT_STRAT == "all"; a non-positive k selects the whole training pool
N_REPLICATES = 5  # number of replicates generated per k
SEED = 97  # base seed; replicate `rep` shuffles with SEED + rep
BALANCE = "balanced"  # "balanced" = k split evenly across classes; any other value (e.g. "prevalence") = k split in proportion to class frequency; not used for multilabel


In [None]:
def _write_split_csv_from_labeled(lp_path: str, out_csv: str, split_name: str):
    """Write a ``patient_id,split`` CSV listing every patient in a labels file.

    Each distinct patient id found in ``lp_path`` is written exactly once, in
    ascending numeric order, and tagged with ``split_name``.

    Args:
        lp_path: Path to the labels CSV, read with ``load_labeled_patients``.
        out_csv: Destination CSV path. A missing parent directory is created.
        split_name: Value written in the ``split`` column of every row
            (the notebook passes "train", "val" and "test").
    """
    lp = load_labeled_patients(lp_path)

    # Coerce ids to int so duplicates collapse and the ordering is numeric.
    pids = sorted({int(pid) for pid, _labels in lp.items()})

    # os.makedirs("") raises FileNotFoundError, so only create the parent when
    # out_csv actually has a directory component (a bare filename means cwd).
    parent = os.path.dirname(out_csv)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # newline="" lets the csv module control line endings itself.
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["patient_id", "split"])
        w.writerows([pid, split_name] for pid in pids)

    logger.info(f"[split.csv] {out_csv}: {split_name} = {len(pids)} patients")
    
def _pick_label_csv(base_dir: str, task: str) -> str:
    p_all = os.path.join(base_dir, task, "all_labels.csv")
    p_lp  = os.path.join(base_dir, task, "labeled_patients.csv")
    if os.path.exists(p_all):
        return p_all
    if os.path.exists(p_lp):
        return p_lp
    raise FileNotFoundError(f" {p_all}  {p_lp}")

from collections import Counter, defaultdict
from typing import Any, Iterable, Dict, List, Tuple
import random, json, csv, os
from ehrshot.labelers.core import load_labeled_patients

def _load_instances_any(lp_path: str) -> Tuple[List[Tuple[int, str, Any]], str]:
    """Flatten a labels file into ``(patient_id, time_iso, value)`` instances.

    Label times are truncated to the minute and rendered as ISO strings. When a
    patient has several labels falling in the same minute, only the first one
    encountered is kept.

    Args:
        lp_path: Path to the labels CSV, read with ``load_labeled_patients``.

    Returns:
        ``(instances, labeler_type)``. ``labeler_type`` is the loaded object's
        ``labeler_type`` attribute, or ``"unknown"`` when it has none.
    """
    labeled = load_labeled_patients(lp_path)
    labeler_type = getattr(labeled, "labeler_type", "unknown")

    # Insertion-ordered dict doubles as the "seen" set and the output order.
    first_value_by_key = {}
    for raw_pid, patient_labels in labeled.items():
        for label in patient_labels:
            minute = label.time.replace(second=0, microsecond=0)
            key = (int(raw_pid), minute.isoformat(timespec="minutes"))
            if key not in first_value_by_key:
                # The value is kept as-is (may be bool/int/str/list of str).
                first_value_by_key[key] = label.value

    instances = [
        (pid, stamp, value)
        for (pid, stamp), value in first_value_by_key.items()
    ]
    return instances, labeler_type

def _orders_boolean(train_list, seed):
    pos = [i for i,(_,_,v) in enumerate(train_list) if bool(v)]
    neg = [i for i,(_,_,v) in enumerate(train_list) if not bool(v)]
    rng = random.Random(seed); rng.shuffle(pos); rng.shuffle(neg)
    return pos, neg

def _orders_categorical(train_list, seed):
    cls2idxs: Dict[str, List[int]] = defaultdict(list)
    for i,(_,_,v) in enumerate(train_list):
        cls2idxs[str(v)].append(i)
    rng = random.Random(seed)
    for idxs in cls2idxs.values():
        rng.shuffle(idxs)
    return cls2idxs  # dict[str, List[int]]

def _order_multilabel(train_list, seed):
    labels_per_inst: List[List[str]] = []
    for _,_,v in train_list:
        if v is None:
            labels_per_inst.append([])
        elif isinstance(v, (list, set, tuple)):
            labels_per_inst.append([str(x) for x in v])
        else:
            labels_per_inst.append([str(v)])  # ：

    freq = Counter()
    for tags in labels_per_inst:
        for t in set(tags):  # set ，
            freq[t] += 1

    rng = random.Random(seed)
    n = len(train_list)
    idxs = list(range(n))
    rng.shuffle(idxs)

    def rarity_score(i: int) -> float:
        tags = labels_per_inst[i]
        if not tags:
            return 0.0
        return sum(1.0 / max(1, freq[t]) for t in set(tags)) + 1e-6 * rng.random()

    idxs.sort(key=rarity_score, reverse=True)
    return idxs

def _extend_to_k(order: List[int], k: int) -> List[int]:
    if k <= 0: return order[:]
    if not order: return []
    if len(order) >= k: return order[:k]
    reps, rem = divmod(k, len(order))
    return order * reps + order[:rem]

def _build_train_indices_boolean(pos_order, neg_order, k, balance="balanced"):
    if k <= 0:
        return pos_order + neg_order
    npos, nneg = len(pos_order), len(neg_order)
    if balance == "balanced":
        k_pos = k // 2; k_neg = k - k_pos
        if npos == 0: k_pos, k_neg = 0, k
        if nneg == 0: k_pos, k_neg = k, 0
    else:
        tot = max(1, npos + nneg); prop = npos / tot
        k_pos = int(round(k * prop)); k_neg = k - k_pos
    return _extend_to_k(pos_order, k_pos) + _extend_to_k(neg_order, k_neg)

def _build_train_indices_categorical(cls2idxs: Dict[str,List[int]], k: int, balance="balanced"):
    orders = {c: idxs[:] for c, idxs in cls2idxs.items()}
    classes = list(orders.keys())
    if k <= 0:
        merged = []
        for c in classes:
            merged += orders[c]
        return merged
    C = max(1, len(classes))
    sizes = {c: len(orders[c]) for c in classes}
    if balance == "balanced":
        k_each = {c: k // C for c in classes}
        rem = k - sum(k_each.values())
        for c in sorted(classes, key=lambda x: -sizes[x])[:rem]:
            k_each[c] += 1
    else:
        tot = sum(sizes.values()) or 1
        k_each = {c: int(round(k * sizes[c] / tot)) for c in classes}
        diff = k - sum(k_each.values())
        if diff != 0:
            for c in sorted(classes, key=lambda x: -sizes[x]):
                if diff == 0: break
                k_each[c] += 1 if diff > 0 else -1
                diff += -1 if diff > 0 else 1
    merged = []
    for c in classes:
        merged += _extend_to_k(orders[c], k_each[c])
    return merged

def _build_train_indices_multilabel(order_all: List[int], k: int):
    if k <= 0: return order_all[:]
    return order_all[:k]

def _attach_row_indices(entry: dict, row_index_map: Dict[str,int], split: str):
    key_pid = f"patient_ids_{split}_k"
    key_tim = f"label_times_{split}_k"
    rows = []
    for pid, t_iso in zip(entry[key_pid], entry[key_tim]):
        rid = row_index_map.get(f"{int(pid)}|{t_iso}")
        rows.append(None if rid is None else int(rid))
    entry[f"row_indices_{split}_k"] = rows


In [None]:
# Resolve the label CSV to use for each split (train / tuning / held-out).
train_lp   = _pick_label_csv(TRAIN_LABELS_DIR,   LABELING_FUNCTION)
tuning_lp  = _pick_label_csv(TUNING_LABELS_DIR,  LABELING_FUNCTION)
heldout_lp = _pick_label_csv(HELDOUT_LABELS_DIR, LABELING_FUNCTION)

# Write one split.csv per split, tagging its patients as train / val / test.
_write_split_csv_from_labeled(train_lp,   TRAIN_SPLIT_CSV,   "train")
_write_split_csv_from_labeled(tuning_lp,  TUNING_SPLIT_CSV,  "val")
_write_split_csv_from_labeled(heldout_lp, HELDOUT_SPLIT_CSV, "test")

# Load the instance pools that are sampled below; the held-out split is not
# loaded here.
train_pool, train_type = _load_instances_any(train_lp)   # [(pid, time_iso, value), ...]
val_pool,   val_type   = _load_instances_any(tuning_lp)
# Both pools must come from the same kind of labeler, since one sampling
# strategy is chosen from labeler_type for the whole run.
assert train_type == val_type, f"train={train_type}, val={val_type} "
labeler_type = train_type
logger.info(f" {labeler_type} ")

# SHOT_STRAT == "all" collapses to the single k = -1 (whole training pool);
# otherwise every k in KS is generated, in ascending order.
k_list = ([-1] if SHOT_STRAT == "all" else sorted(KS))
# Output structure: {task: {k: {replicate: entry}}}.
few_shots_dict = { LABELING_FUNCTION: {} }

# Optional row-index file; the ".." segments make this resolve to
# <ROOT>/train/features/<task>/row_index.json.
# NOTE(review): keys are presumably "<pid>|<time_iso>", which is the format
# _attach_row_indices looks up - confirm against the feature-generation step.
maybe_row_index_json = os.path.join(TRAIN_LABELS_DIR, LABELING_FUNCTION, "..", "..", "features", LABELING_FUNCTION, "row_index.json")
row_index_map = {}
if os.path.exists(maybe_row_index_json):
    try:
        with open(maybe_row_index_json, "r") as f:
            row_index_map = json.load(f)
        logger.info(f"[row_index] loaded: {maybe_row_index_json}")
    except Exception as e:
        # Best effort: without the map, entries simply carry no row indices.
        logger.warning(f"[row_index] load failed: {e}; will skip row indices")

# Column-wise views of the pools so entries can be built by position.
train_pids  = [p for p,_,_ in train_pool]
train_times = [t for _,t,_ in train_pool]
train_vals  = [v for _,_,v in train_pool]
val_pids    = [p for p,_,_ in val_pool]
val_times   = [t for _,t,_ in val_pool]
val_vals    = [v for _,_,v in val_pool]

for k in k_list:
    few_shots_dict[LABELING_FUNCTION][k] = {}
    for rep in range(N_REPLICATES):
        # The shuffle seed depends only on the replicate number, so a given
        # replicate uses the same shuffled orders for every k.
        if labeler_type in ("boolean",):
            pos_order, neg_order = _orders_boolean(train_pool, SEED + rep)
            train_idxs = _build_train_indices_boolean(pos_order, neg_order, k, BALANCE)
        elif labeler_type in ("categorical", "string"):
            cls2idxs = _orders_categorical(train_pool, SEED + rep)
            train_idxs = _build_train_indices_categorical(cls2idxs, k, BALANCE)
        else:
            # Any other labeler_type is handled with the multilabel ordering;
            # BALANCE is not used on this path.
            order_all = _order_multilabel(train_pool, SEED + rep)
            train_idxs = _build_train_indices_multilabel(order_all, k)

        # Validation always uses the whole tuning pool, regardless of k.
        val_idxs = list(range(len(val_pool)))

        entry = {
            "labeler_type":          labeler_type,
            "patient_ids_train_k":   [train_pids[i]  for i in train_idxs],
            "patient_ids_val_k":     [val_pids[i]    for i in val_idxs],
            "label_times_train_k":   [train_times[i] for i in train_idxs],
            "label_times_val_k":     [val_times[i]   for i in val_idxs],
            "label_values_train_k":  [train_vals[i]  for i in train_idxs],
            "label_values_val_k":    [val_vals[i]    for i in val_idxs],
            "train_idxs":            train_idxs,
            "val_idxs":              val_idxs,
        }

        # Row indices are attached only when a non-empty map was loaded; the
        # same map is used for both the train and the val instances.
        if row_index_map:
            _attach_row_indices(entry, row_index_map, split="train")
            _attach_row_indices(entry, row_index_map, split="val")

        few_shots_dict[LABELING_FUNCTION][k][rep] = entry

# The result is stored next to the training labels of the task. JSON object
# keys are strings, so the integer k / replicate keys are written as strings.
out_json = os.path.join(TRAIN_LABELS_DIR, LABELING_FUNCTION, f"{SHOT_STRAT}_shots_data.json")
os.makedirs(os.path.dirname(out_json), exist_ok=True)
with open(out_json, "w") as f:
    json.dump(few_shots_dict, f, ensure_ascii=False)
logger.success(f"[ok] wrote {out_json}")
