# Generate features: COUNT and CLMBR

## Purpose
Compute COUNT features (Age + ontology-expanded CountFeaturizer) and CLMBR representations for each split and task.

## Inputs
- FEMR database at <BASE>/<split>/extract/
- all_labels.csv at <BASE>/<split>/femr_labels/<TASK>/all_labels.csv
- CLMBR binaries available in the active environment (clmbr_create_batches, clmbr_compute_representations)
- CLMBR assets: model weights and dictionary

## Outputs
- COUNT features: <BASE>/<split>/femr_features/<TASK>/count_features.pkl
- CLMBR features: <BASE>/<split>/femr_features/<TASK>/clmbr_features.pkl
- CLMBR batches: <BASE>/<split>/femr_features/<TASK>/clmbr_batches/


In [None]:
from loguru import logger
import os, pickle
from ehrshot.labelers.core import load_labeled_patients
from femr.featurizers import AgeFeaturizer, CountFeaturizer, FeaturizerList
import pandas as pd

In [None]:
# Root directory holding one sub-directory per split (train / tuning / held_out).
BASE = "/root/autodl-tmp/femr" 
# Task name; used below as the sub-directory under the labels and features roots.
task = "mimic_icu_phenotyping"

In [None]:
# Per-split locations: FEMR database ("db"), labels root ("labels") and
# features root ("features"). Task-specific files live one level below the
# labels/features roots, in a sub-directory named after the task.
SPLITS = {
    "train":    {"db": f"{BASE}/train/extract",    "labels": f"{BASE}/train/femr_labels", "features": f"{BASE}/train/femr_features"},
    "tuning":   {"db": f"{BASE}/tuning/extract",   "labels": f"{BASE}/tuning/femr_labels", "features": f"{BASE}/tuning/femr_features"},
    "held_out": {"db": f"{BASE}/held_out/extract", "labels": f"{BASE}/held_out/femr_labels", "features": f"{BASE}/held_out/femr_features"},
}

# Passed as the last argument to preprocess_featurizers()/featurize() below
# (presumably the number of worker threads -- confirm against the FEMR API).
NUM_THREADS = 15
# When False, a step whose output file already exists is skipped.
FORCE = True

def gen_count_features_one_split(db_dir, labels_dir, features_dir):
    """Fit Age + Count featurizers on one split and pickle the resulting features.

    Args:
        db_dir: Path of the FEMR database for the split.
        labels_dir: Directory that directly contains ``all_labels.csv``
            (no task sub-directory is appended here).
        features_dir: Directory that receives ``count_features.pkl``;
            created if it does not exist.

    Returns:
        A ``(out_path, feats)`` tuple. ``feats`` is the ``FeaturizerList``
        that was preprocessed on this split, or ``None`` when the step was
        skipped because the output already exists and ``FORCE`` is False.
        Both paths return a 2-tuple so callers can always unpack the result.
    """
    labels_path = os.path.join(labels_dir, "all_labels.csv")
    out_path = os.path.join(features_dir, "count_features.pkl")
    if (not FORCE) and os.path.exists(out_path):
        logger.info(f"[skip] {out_path} ")
        # No featurizers were fitted on this path, hence None.
        return out_path, None

    logger.info(f"Loading labels: {labels_path}")
    labeled_patients = load_labeled_patients(labels_path)

    age = AgeFeaturizer()
    count = CountFeaturizer(is_ontology_expansion=True)
    feats = FeaturizerList([age, count])

    logger.info("Preprocess featurizers")
    feats.preprocess_featurizers(db_dir, labeled_patients, NUM_THREADS)

    logger.info("Featurize patients")
    results = feats.featurize(db_dir, labeled_patients, NUM_THREADS)
    # Only the matrix is needed here (for the log line); the full 4-tuple is pickled.
    feature_matrix, _patient_ids, _label_values, _label_times = results

    os.makedirs(features_dir, exist_ok=True)
    with open(out_path, "wb") as f:
        pickle.dump(results, f)
    logger.success(f"[ok] wrote {out_path}  "
                   f"rows={feature_matrix.shape[0]}  cols={feature_matrix.shape[1]}")
    return out_path, feats

In [None]:
# Inline COUNT pipeline for the train split: load the task labels, fit the
# Age + Count featurizers, featurize, and pickle the result. `feats` stays in
# the notebook namespace so later cells can reuse the fitted featurizers.
cfg=SPLITS['train']
db_dir = cfg["db"]
labels_dir = cfg["labels"]
features_dir = cfg["features"]
task = "mimic_icu_phenotyping"
# Labels are read from, and features written to, the task sub-directory.
labels_path = os.path.join(labels_dir,task,"all_labels.csv")
out_path = os.path.join(features_dir,task,"count_features.pkl")
os.makedirs(os.path.dirname(out_path), exist_ok=True)
logger.info(f"Loading labels: {labels_path}")
labeled_patients = load_labeled_patients(labels_path)

age = AgeFeaturizer()
count = CountFeaturizer(is_ontology_expansion=True)
feats = FeaturizerList([age, count])

logger.info("Preprocess featurizers")
feats.preprocess_featurizers(db_dir, labeled_patients, NUM_THREADS)

logger.info("Featurize patients")
results = feats.featurize(db_dir, labeled_patients, NUM_THREADS)
feature_matrix, patient_ids, label_values, label_times = results

# features_dir is a parent of the directory created above, so this call is a no-op.
os.makedirs(features_dir, exist_ok=True)
with open(out_path, "wb") as f:
    pickle.dump(results, f)
logger.success(f"[ok] wrote {out_path}  "
               f"rows={feature_matrix.shape[0]}  cols={feature_matrix.shape[1]}")

In [None]:
# Fit the featurizers on the train split and keep them (`feats`) for the
# remaining splits. gen_count_features_one_split() joins "all_labels.csv" and
# "count_features.pkl" directly onto the directories it is given, so the task
# sub-directory must be appended here to match the
# <split>/femr_labels/<TASK>/ and <split>/femr_features/<TASK>/ layout.
logger.info(f"=== Count features | {'train'} ===")
cfg = SPLITS['train']
_, feats = gen_count_features_one_split(
    cfg["db"],
    os.path.join(cfg["labels"], task),
    os.path.join(cfg["features"], task),
)

In [None]:
def gen_count_features_one_split_rest(feats, task, db_dir, labels_dir, features_dir):
    """Featurize one split with an existing featurizer list and pickle the result.

    Unlike ``gen_count_features_one_split`` this function never calls
    ``preprocess_featurizers``; it only calls ``feats.featurize``, so ``feats``
    is expected to have been preprocessed beforehand.

    Args:
        feats: Featurizer list whose ``featurize`` method is applied.
        task: Task name, used as the sub-directory under both roots.
        db_dir: Path of the FEMR database for the split.
        labels_dir: Labels root; labels are read from
            ``<labels_dir>/<task>/all_labels.csv``.
        features_dir: Features root; output is written to
            ``<features_dir>/<task>/count_features.pkl``.

    Returns:
        A ``(out_path, feats)`` tuple, on the skip path as well, so the
        return shape does not depend on ``FORCE``.
    """
    labels_path = os.path.join(labels_dir, task, "all_labels.csv")
    out_dir = os.path.join(features_dir, task)
    out_path = os.path.join(out_dir, "count_features.pkl")
    if (not FORCE) and os.path.exists(out_path):
        logger.info(f"[skip] {out_path} ")
        return out_path, feats

    logger.info(f"Loading labels: {labels_path}")
    labeled_patients = load_labeled_patients(labels_path)

    logger.info("Featurize patients")
    results = feats.featurize(db_dir, labeled_patients, NUM_THREADS)
    # Only the matrix is needed here (for the log line); the full 4-tuple is pickled.
    feature_matrix, _patient_ids, _label_values, _label_times = results

    os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "wb") as f:
        pickle.dump(results, f)
    logger.success(f"[ok] wrote {out_path}  "
                   f"rows={feature_matrix.shape[0]}  cols={feature_matrix.shape[1]}")
    return out_path, feats

In [None]:
# Featurize the tuning split, reusing the `feats` object from the cells above.
logger.info(f"=== Count features | {'tuning'} ===")
cfg = SPLITS["tuning"]
gen_count_features_one_split_rest(
    feats=feats,
    task=task,
    db_dir=cfg["db"],
    labels_dir=cfg["labels"],
    features_dir=cfg["features"],
)

In [None]:
# Featurize the held_out split, reusing the `feats` object from the cells above.
logger.info(f"=== Count features | {'held_out'} ===")
cfg = SPLITS["held_out"]
gen_count_features_one_split_rest(
    feats=feats,
    task=task,
    db_dir=cfg["db"],
    labels_dir=cfg["labels"],
    features_dir=cfg["features"],
)

In [None]:
import os, pickle
import numpy as np
import pandas as pd

# Sanity-check the COUNT features written for the tuning split.
FEATURE_DIR   = "/root/autodl-tmp/femr/tuning/femr_features/mimic_icu_phenotyping"
COUNT_PKL    = os.path.join(FEATURE_DIR, "count_features.pkl")  # pickled 4-tuple written by the COUNT step

# NOTE: pickle.load can execute arbitrary code; only load files produced by this notebook.
with open(COUNT_PKL, "rb") as f:
    results = pickle.load(f)

# Same order as the tuple unpacked when the file was written:
# (feature_matrix, patient_ids, label_values, label_times).
X         = results[0]
patient_ids   = np.asarray(results[1])
label_values  = np.asarray(results[2])
label_times   = np.asarray(results[3])

print("X shape:", X.shape)  # (rows, columns)
print("#patients:", len(np.unique(patient_ids)))
print("labels: pos =", int(label_values.sum()), "/", len(label_values))

import pandas as pd
# Preview the top-left 5x20 corner; .toarray() implies X is a sparse matrix (TODO confirm type).
# The first expression's value is discarded; only the last expression of a cell is displayed.
pd.DataFrame(X[:5, :20].toarray()).head()
pd.DataFrame(X[:5, :20].toarray()).style.format("{:.4f}")

# 5. CLMBR features

In [None]:
import sys
# Locate the CLMBR command-line tools in the same directory as the running
# Python interpreter and fail early, with a readable message, if one is missing.
python_executable_path = sys.executable
path_to_bin = os.path.dirname(python_executable_path)
path_to_clmbr_create_batches = os.path.join(path_to_bin, "clmbr_create_batches")

if not os.path.exists(path_to_clmbr_create_batches):
    raise FileNotFoundError(
        f"Could not find 'clmbr_create_batches' at {path_to_clmbr_create_batches}; "
        "make sure it is installed in the active environment."
    )


path_to_clmbr_compute_representations = os.path.join(path_to_bin, "clmbr_compute_representations")

if not os.path.exists(path_to_clmbr_compute_representations):
    raise FileNotFoundError(
        f"Could not find 'clmbr_compute_representations' at {path_to_clmbr_compute_representations}; "
        "make sure it is installed in the active environment."
    )

In [None]:
import os
from loguru import logger
import subprocess, shlex

# Parent directory of the per-model asset folders (e.g. clmbr/).
MODELS_DIR = "/root/code/MEDS_Process/ehrshot-benchmark/EHRSHOT_ASSETS/models"
# Name of the model sub-folder to use; "motor" is the other value handled below.
MODEL = "clmbr"
# When True, existing batches and representations are deleted before regenerating.
FORCE = True

def run(cmd: str):
    """Log *cmd*, execute it through the shell, and raise if it exits non-zero.

    Raises:
        RuntimeError: when the command's return code is not 0.
    """
    logger.info("RUN: " + cmd)
    exit_code = subprocess.call(cmd, shell=True)
    if exit_code == 0:
        return
    raise RuntimeError(f"Command failed (rc={exit_code}):\n{cmd}")

def gen_clmbr_features_one_split(db_dir, labels_dir, features_dir, task):
    """Create CLMBR batches and compute representations for one split and task.

    Runs the two command-line tools located earlier in the notebook:
    ``clmbr_create_batches`` followed by ``clmbr_compute_representations``.

    Args:
        db_dir: Path of the FEMR database for the split.
        labels_dir: Labels root; labels are read from
            ``<labels_dir>/<task>/all_labels.csv``.
        features_dir: Features root; batches go to
            ``<features_dir>/<task>/<MODEL>_batches`` and representations to
            ``<features_dir>/<task>/<MODEL>_features.pkl``.
        task: Task name used as the sub-directory under both roots.

    Raises:
        FileNotFoundError: if the model weights or the dictionary are missing.
        RuntimeError: if either command exits with a non-zero status.
    """
    import shutil  # local import: only needed for the FORCE cleanup below

    # Both assets are resolved from the selected model's folder, so switching
    # MODEL also switches the dictionary. For MODEL == "clmbr" this is the same
    # path as the previously hard-coded one.
    model_dir = os.path.join(MODELS_DIR, MODEL)
    path_clmbr_model = os.path.join(model_dir, "clmbr_model")
    path_dictionary = os.path.join(model_dir, "dictionary")

    # Explicit exceptions instead of assert, which is stripped under `python -O`.
    if not os.path.exists(path_clmbr_model):
        raise FileNotFoundError(f"Missing model weights @ {path_clmbr_model}")
    if not os.path.exists(path_dictionary):
        raise FileNotFoundError(f"Missing dictionary @ {path_dictionary}")

    task_features_dir = os.path.join(features_dir, task)
    labels_path = os.path.join(labels_dir, task, "all_labels.csv")
    batches_dir = os.path.join(task_features_dir, f"{MODEL}_batches")
    repr_path = os.path.join(task_features_dir, f"{MODEL}_features.pkl")

    if FORCE:
        # Remove stale outputs (directory or file) without shelling out to `rm -rf`.
        for stale in (batches_dir, repr_path):
            if os.path.isdir(stale) and not os.path.islink(stale):
                shutil.rmtree(stale)
            elif os.path.lexists(stale):
                os.remove(stale)
    elif os.path.exists(repr_path):
        # Same convention as the COUNT steps: keep an existing output when FORCE is off.
        logger.info(f"[skip] {repr_path} ")
        return

    # Make sure the parent directory of both outputs exists before the tools run.
    os.makedirs(task_features_dir, exist_ok=True)

    # 1) create_batches
    hier_flag = "--is_hierarchical " if MODEL == "motor" else ""
    cmd_batches = (
        f"{path_to_clmbr_create_batches} {shlex.quote(batches_dir)}"
        f" --data_path {shlex.quote(db_dir)}"
        f" --dictionary {shlex.quote(path_dictionary)}"
        f" --task labeled_patients"
        f" --batch_size 131072"
        f" --val_start 100"
        f" --test_start 100"
        f" {hier_flag}"
        f" --labeled_patients_path {shlex.quote(labels_path)}"
    )
    run(cmd_batches)

    # 2) compute_representations
    cmd_repr = (
        f"{path_to_clmbr_compute_representations} {shlex.quote(repr_path)}"
        f" --data_path {shlex.quote(db_dir)}"
        f" --batches_path {shlex.quote(batches_dir)}"
        f" --model_dir {shlex.quote(path_clmbr_model)}"
    )
    run(cmd_repr)
    logger.success(f"[ok] wrote {repr_path}")

# Generate CLMBR batches and representations for every split of the current task.
for split, cfg in SPLITS.items():
    logger.info(f"=== CLMBR features | {split} ===")
    gen_clmbr_features_one_split(cfg["db"], cfg["labels"], cfg["features"], task)


In [None]:
import os, pprint
from femr.extension import dataloader

# Exploratory check of the CLMBR batches written for the train split:
# print batch counts per partition and the structure of the first train batch.
DATA_PATH = "/root/autodl-tmp/femr/train/extract"
BATCH_INFO = "/root/autodl-tmp/femr/train/femr_features/mimic_icu_phenotyping/clmbr_batches/batch_info.msgpack"

loader = dataloader.BatchLoader(DATA_PATH, BATCH_INFO)

print("num train/dev/test batches:",
      loader.get_number_of_batches("train"),
      loader.get_number_of_batches("dev"),
      loader.get_number_of_batches("test"))

# `batch` is kept in the notebook namespace; the next cell displays batch["task"].
batch = loader.get_batch("train", 0)

print("\nTOP KEYS:", list(batch.keys()))
if "task" in batch:
    print("TASK TYPE:", type(batch["task"]))
    if isinstance(batch["task"], dict):
        print("TASK KEYS:", list(batch["task"].keys()))
        # Probe a handful of candidate key names; only those present are reported.
        for k in ["labels", "label_values", "values", "y", "label_ages"]:
            if k in batch["task"]:
                v = batch["task"][k]
                # Best-effort 2-D shape; falls back to None when v is empty,
                # not indexable, or its first element has no length.
                try:
                    shape = (len(v), len(v[0])) if hasattr(v, "__getitem__") else None
                except Exception:
                    shape = None
                print(f"task['{k}'] present; example type:", type(v), "example:", (v[0] if len(v)>0 else None), "shape:", shape)
    else:
        print("TASK (non-dict):", batch["task"])

# Unlike "task", the "transformer" key is accessed unconditionally (KeyError if absent).
print("\nTRANSFORMER KEYS:", list(batch["transformer"].keys()))
print("num_indices:", batch.get("num_indices"))
print("patient_ids len:", len(batch.get("patient_ids", [])))


In [None]:
# Display the raw "task" entry of the `batch` loaded above.
batch["task"]

In [None]:
import msgpack, os

def load_regular_map(dict_path):
    """Return the "regular" entry of the first matching dict in a msgpack stream.

    The file at *dict_path* is read as a stream of msgpack objects. The first
    object that is a dict containing a "regular" or "ontology_rollup" key ends
    the scan; its keys and the value types of its other entries are printed.

    Returns:
        The value stored under "regular" in that dict, or None when no such
        dict is found (or the matching dict has no "regular" key).
    """
    regular = None
    seen_types = []
    with open(dict_path, "rb") as stream:
        for item in msgpack.Unpacker(stream, raw=False):
            seen_types.append(type(item).__name__)
            if not isinstance(item, dict):
                continue
            if "regular" not in item and "ontology_rollup" not in item:
                continue
            regular = item.get("regular")
            other_types = {}
            for key, value in item.items():
                if key != "regular":
                    other_types[key] = type(value).__name__
            print("Found packed dict keys:", list(item.keys()))
            print("Meta value types:", other_types)
            break
    print("Stream objects seen (types):", seen_types[:5], "...")
    return regular

from itertools import islice

# Load the "regular" section of the CLMBR dictionary and preview a few entries.
REG = load_regular_map("/root/code/MEDS_Process/ehrshot-benchmark/EHRSHOT_ASSETS/models/clmbr/dictionary")
print("regular size:", len(REG) if REG else None)
# load_regular_map() may return None, and "regular" is not guaranteed to be a
# mapping, so .items() is only used for dicts; other containers are sampled
# directly. islice avoids materialising the whole map just to show 5 entries.
if REG:
    entries = REG.items() if isinstance(REG, dict) else REG
    some = list(islice(entries, 5))
else:
    some = []
print("sample:", some)

In [None]:
import os, io, json, msgpack

# Folder holding the CLMBR assets; both paths below are resolved inside it.
ASSETS_DIR = "/root/code/MEDS_Process/ehrshot-benchmark/EHRSHOT_ASSETS/models/clmbr"
PATH_T2C   = os.path.join(ASSETS_DIR, "token_2_code.json")  # input mapping file
OUT_DICT   = os.path.join(ASSETS_DIR, "dictionary.vocab.msgpack")  # msgpack file to write

# Optional age statistics to embed in the dictionary, e.g. {"mean": 45.3, "std": 18.7}.
AGE_STATS = None

def _load_any_jsons(path: str):
    """Read every JSON value stored in *path*, tolerating several layouts.

    The layouts are tried from strictest to most lenient:

    1. the whole file is a single JSON document;
    2. the file is a clean sequence of whitespace-separated JSON values
       (JSONL, or concatenated / pretty-printed documents);
    3. JSONL with some unparsable lines, which are skipped;
    4. JSON values embedded in other text, recovered by scanning.

    Step 2 runs before the lenient line pass so that a pretty-printed
    multi-document file is not mistaken for JSONL, where a line holding only
    a scalar such as ``1`` would be returned as if it were a record.

    Returns:
        A non-empty list of decoded JSON values, in file order.

    Raises:
        ValueError: if no JSON value can be decoded from the file.
    """
    with open(path, "rb") as f:
        data = f.read()

    # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
    parse_errors = (ValueError, RecursionError)
    decoder = json.JSONDecoder()

    # 1) Single JSON document (a UTF-8 BOM is tolerated).
    try:
        return [json.loads(data.decode("utf-8-sig"))]
    except parse_errors:
        pass

    # 2) Strict stream: every value must decode back-to-back with nothing but
    #    whitespace in between; any error rejects this interpretation.
    try:
        text = data.decode("utf-8-sig")
        stream_objs = []
        pos, end = 0, len(text)
        while True:
            while pos < end and text[pos].isspace():
                pos += 1
            if pos >= end:
                break
            obj, pos = decoder.raw_decode(text, pos)
            stream_objs.append(obj)
        if stream_objs:
            return stream_objs
    except parse_errors:
        pass

    # 3) Lenient JSONL: keep the lines that parse, skip the rest.
    line_objs = []
    for line in data.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            line_objs.append(json.loads(line))
        except parse_errors:
            continue
    if line_objs:
        return line_objs

    # 4) Last resort: scan the text and decode wherever a JSON value starts,
    #    advancing one character at a time past anything that does not parse.
    s = data.decode("utf-8", errors="ignore")
    scan_objs = []
    idx = 0
    n = len(s)
    while idx < n:
        while idx < n and s[idx].isspace():
            idx += 1
        if idx >= n:
            break
        try:
            obj, idx = decoder.raw_decode(s, idx)
            scan_objs.append(obj)
        except parse_errors:
            idx += 1
    if not scan_objs:
        raise ValueError(f"No JSON value could be parsed from {path!r}")
    return scan_objs

def build_dictionary_from_token2code(path_t2c: str, out_path: str, age_stats=None):
    """Convert a token/code mapping file into a msgpack dictionary file.

    Every JSON value read from *path_t2c* (single JSON, concatenated JSON or
    JSONL) is merged into one ``code -> token_id`` map. Accepted shapes:

    * Format A -- dict whose (sampled) keys are digit strings:
      ``{"<token_id>": code_string}``.
    * Format B -- any other dict: ``{code_string: token_id}``.
    * list of ``[token, code]`` pairs or ``{"token": ..., "code": ...}``
      dicts; list items of any other shape are ignored.

    Later entries overwrite earlier ones that share the same code.

    Args:
        path_t2c: Path of the mapping file to read.
        out_path: Path of the msgpack file to write.
        age_stats: Optional mapping with "mean" and "std"; when given it is
            stored under "age_stats" as floats.

    Returns:
        *out_path*.

    Raises:
        ValueError: if a dict value has the wrong type for its detected
            format, or if no code/token pair is found at all.
    """
    objs = _load_any_jsons(path_t2c)

    code2token = {}
    for obj in objs:
        if isinstance(obj, dict):
            keys = list(obj.keys())
            # Detect the layout from (at most) the first five keys only.
            is_a = bool(keys) and all(isinstance(k, str) and k.isdigit() for k in keys[:5])
            if is_a:
                for tok_str, code in obj.items():
                    if not isinstance(code, str):
                        raise ValueError(
                            "Format A (token_id -> code): every value must be a code string (str)"
                        )
                    code2token[code] = int(tok_str)
            else:
                for code, tok in obj.items():
                    if not isinstance(tok, int):
                        raise ValueError(
                            "Format B (code -> token_id): every value must be a token id (int)"
                        )
                    code2token[code] = tok
        elif isinstance(obj, list):
            for it in obj:
                if isinstance(it, (list, tuple)) and len(it) == 2:
                    tok, code = it
                elif isinstance(it, dict) and "token" in it and "code" in it:
                    tok, code = it["token"], it["code"]
                else:
                    # Unrecognised item shape: skipped, as before.
                    continue
                code2token[str(code)] = int(tok) if not isinstance(tok, int) else tok

    if not code2token:
        raise ValueError(
            f"No code -> token_id pairs could be extracted; check the contents of {path_t2c!r}"
        )

    dictionary_obj = {"regular": code2token}
    if age_stats is not None:
        dictionary_obj["age_stats"] = {"mean": float(age_stats["mean"]), "std": float(age_stats["std"])}

    with open(out_path, "wb") as f:
        msgpack.pack(dictionary_obj, f, use_bin_type=True)

    print(f"[ok] wrote dictionary: {out_path}")
    print("vocab size =", len(code2token),
          "| min_id =", min(code2token.values()),
          "| max_id =", max(code2token.values()))
    return out_path

# Build the msgpack dictionary from token_2_code.json; age statistics are only
# embedded when AGE_STATS is not None.
build_dictionary_from_token2code(PATH_T2C, OUT_DICT, age_stats=AGE_STATS)
