# Generate labels: ICU mortality (binary)

## Purpose
Generate ICU mortality labels with the EHRSHOT/FEMR labeler, one split at a time (train/tuning/held_out). Re-run the notebook for each split after updating the paths in the configuration cell.

## Inputs
- FEMR database at <BASE>/<split>/extract/
- Label output directory <BASE>/<split>/femr_labels/

## Outputs
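- Labels at: <BASE>/<split>/femr_labels/mimic_icu_mortality/labeled_patients.csv (other steps in your flow may additionally produce files such as all_labels.csv in this directory)
- Run log at: <BASE>/<split>/femr_labels/mimic_icu_mortality/info.log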


In [None]:
import os
import json
import pandas as pd
import random
from multiprocessing import Pool, Manager
from loguru import logger
from typing import List, Dict

# Core femr and ehrshot imports
import femr.datasets
from femr.labelers import LabeledPatients, Label
from ehrshot.labelers.mimic import Mimic_ICUEventStreamMortalityLabeler

In [None]:
# --- Please modify your configuration here ---

# 0. SPLIT: Which data split to label. Run this notebook once per split
#    ("train", "tuning", "held_out"). The input and output paths below are
#    derived from these two values, so they can never point at different splits.
PATH_TO_BASE_DIR = "/root/autodl-tmp/femr"
SPLIT = "held_out"

# 1. INPUT: Path to your successfully created FEMR database extract
PATH_TO_FEMR_DATABASE = f"{PATH_TO_BASE_DIR}/{SPLIT}/extract"

# 2. OUTPUT: A directory where the generated label files will be saved
PATH_TO_OUTPUT_DIR = f"{PATH_TO_BASE_DIR}/{SPLIT}/femr_labels/"

# 3. TASK NAME: A subdirectory will be created with this name
TASK_NAME = "mimic_icu_mortality"

# 4. PARAMETERS
# Number of CPU cores to use for parallel processing
NUM_PROCESSES = 15
# Set to True if you want to randomly sample only one ICU stay per patient.
# Set to False to generate a label for every ICU stay.
IS_SAMPLE_ONE_LABEL_PER_PATIENT = False

# --- End of configuration ---

In [None]:
# Automatically create the full path for the output file
PATH_TO_TASK_OUTPUT_DIR = os.path.join(PATH_TO_OUTPUT_DIR, TASK_NAME)
PATH_TO_OUTPUT_FILE = os.path.join(PATH_TO_TASK_OUTPUT_DIR, "labeled_patients.csv")

# Create directories if they don't exist
os.makedirs(PATH_TO_TASK_OUTPUT_DIR, exist_ok=True)

# Setup logging.
# loguru's logger.add() registers an additional sink on every call, so re-running
# this cell in the same kernel would write every later log line several times
# (and the stale sink would keep the deleted log file open). Detach the sink this
# cell added on a previous run, if any, before recreating the log file.
_previous_log_sink_id = globals().get("_LOG_SINK_ID")
if _previous_log_sink_id is not None:
    try:
        logger.remove(_previous_log_sink_id)
    except ValueError:
        # Already removed elsewhere (e.g. a bare logger.remove()); nothing to detach.
        pass

path_to_log_file = os.path.join(PATH_TO_TASK_OUTPUT_DIR, 'info.log')
# Start every run with a fresh log file; a missing file is fine.
try:
    os.remove(path_to_log_file)
except FileNotFoundError:
    pass
_LOG_SINK_ID = logger.add(path_to_log_file, level="INFO")

logger.info(f"Task: {TASK_NAME}")
logger.info(f"FEMR Database Path: {PATH_TO_FEMR_DATABASE}")
logger.info(f"Output Directory: {PATH_TO_TASK_OUTPUT_DIR}")
logger.info(f"Sample one label per patient: {IS_SAMPLE_ONE_LABEL_PER_PATIENT}")
logger.info(f"Number of threads: {NUM_PROCESSES}")

In [None]:
def save_labeled_patients_to_csv(labeled_patients: LabeledPatients, path_to_csv: str):
    """Flatten ``labeled_patients`` into one CSV row per label and write it out.

    Columns are ``patient_id``, ``prediction_time``, ``value`` and
    ``label_type`` (the collection's ``labeler_type``, repeated on every row).
    Rows are ordered by patient id, then prediction time, then value, and the
    DataFrame index is not written.
    """
    records = [
        (pid, label.time, label.value, labeled_patients.labeler_type)
        for pid, patient_labels in labeled_patients.items()
        for label in patient_labels
    ]
    frame = pd.DataFrame(
        records,
        columns=['patient_id', 'prediction_time', 'value', 'label_type'],
    ).sort_values(['patient_id', 'prediction_time', 'value'])
    frame.to_csv(path_to_csv, index=False)
    logger.success(f"Successfully saved {len(frame)} labels to {path_to_csv}")

# This function is needed for the IS_SAMPLE_ONE_LABEL_PER_PATIENT logic
def process_patient_ids_for_sampling(args):
    """Keep at most one randomly chosen adult label for each patient in a subset.

    Designed to run in a ``multiprocessing.Pool`` worker: it opens its own
    ``PatientDatabase`` from the given path and rebuilds ``LabeledPatients``
    from the dict it was sent.

    Args:
        args: Tuple ``(pid_subset, labeled_patients_dict, path_to_database)``:
            the patient ids to process, the dict form of the labeled patients,
            and the path of the FEMR database extract.

    Returns:
        Dict mapping every pid in ``pid_subset`` to a list of zero or one Label.
        Patients with no events or no adult labels map to an empty list, so they
        are still represented (and counted) downstream.
    """
    pid_subset, labeled_patients_dict, path_to_database = args
    local_results = {}
    database = femr.datasets.PatientDatabase(path_to_database)
    # NOTE(review): relies on the installed femr LabeledPatients providing from_dict()/to_dict() — confirm.
    labeled_patients = LabeledPatients.from_dict(labeled_patients_dict)

    for pid in pid_subset:
        events = database[pid].events
        if not events:
            # No events => no birth year to filter on. Keep the patient with no labels
            # (consistent with the "no valid labels" case) instead of dropping it.
            local_results[pid] = []
            continue

        # Assumes the first event marks birth (as in the EHRSHOT pipeline) — TODO confirm with the ETL.
        birth_year = events[0].start.year
        labels = labeled_patients.get_labels_from_patient_idx(pid)

        # Drop labels where the patient is under 18 (age approximated by calendar-year difference).
        valid_labels = [label for label in labels if (label.time.year - birth_year) >= 18]

        if len(valid_labels) <= 1:
            local_results[pid] = valid_labels
        else:
            # A per-patient RNG seeded with the pid gives the same draw as
            # random.seed(pid); random.choice(...) without touching global random state.
            local_results[pid] = [random.Random(int(pid)).choice(valid_labels)]

    return local_results

In [None]:
# Open the FEMR PatientDatabase and fetch its ontology; the ontology is what the
# ICU mortality labeler is constructed from.
logger.info("Start | Loading PatientDatabase and Ontology")
database = femr.datasets.PatientDatabase(PATH_TO_FEMR_DATABASE)
ontology = database.get_ontology()
logger.info("Finish | Loading PatientDatabase and Ontology")

labeler = Mimic_ICUEventStreamMortalityLabeler(ontology)

# Run the labeler over the whole database. `apply` is given the database *path*
# (not the `database` handle opened above) together with the worker count.
logger.info("Start | Applying labeler to all patients")
labeled_patients = labeler.apply(
    num_threads=NUM_PROCESSES,
    path_to_patient_database=PATH_TO_FEMR_DATABASE,
)
logger.info("Finish | Applying labeler")

In [None]:
# Optional: Randomly sample one label per patient
if IS_SAMPLE_ONE_LABEL_PER_PATIENT:
    # `tqdm` is otherwise only defined in a later cell, so on a fresh kernel this
    # branch raised NameError. Import it here, falling back to a pass-through.
    try:
        from tqdm import tqdm
    except ImportError:
        def tqdm(iterable, **kwargs):
            """Return *iterable* unchanged (no-op stand-in for tqdm.tqdm)."""
            return iterable

    logger.info("Start | Sampling one label per patient")
    pids = list(labeled_patients.keys())
    # Round-robin split of patients across workers. Each worker re-opens the
    # database, so never start more workers than patients; keep at least one so
    # Pool() stays valid even when there are no patients.
    num_workers = max(1, min(NUM_PROCESSES, len(pids)))
    pid_subsets = [pids[i::num_workers] for i in range(num_workers)]

    # We pass a dictionary representation of labeled_patients to avoid pickling issues
    # NOTE(review): assumes the installed femr LabeledPatients provides to_dict() — confirm.
    labeled_patients_dict = labeled_patients.to_dict()

    with Pool(num_workers) as pool:
        results_list = list(tqdm(
            pool.imap(
                process_patient_ids_for_sampling,
                [(subset, labeled_patients_dict, PATH_TO_FEMR_DATABASE) for subset in pid_subsets],
            ),
            total=len(pid_subsets),
            desc="Sampling labels",
        ))

    # Combine results from all processes
    combined_results = {pid: labels for partial in results_list for pid, labels in partial.items()}
    labeled_patients = LabeledPatients(combined_results, labeler_type=labeler.get_labeler_type())
    logger.info("Finish | Sampling one label per patient")

# Force labels to be minute-level resolution for FEMR compatibility
logger.info("Start | Adjusting label timestamps to minute-level resolution")
# Rebuild each patient's labels with the time truncated to the start of its
# minute (seconds and microseconds zeroed), then write the list back in place.
for pid_key, pid_labels in list(labeled_patients.items()):
    labeled_patients[pid_key] = [
        Label(time=lbl.time.replace(second=0, microsecond=0), value=lbl.value)
        for lbl in pid_labels
    ]
logger.info("Finish | Adjusting label timestamps")

# Persist the final labeled patients as a CSV at the task output path.
logger.info(f"Saving final labeled patients to CSV format at {PATH_TO_OUTPUT_FILE}")
save_labeled_patients_to_csv(labeled_patients=labeled_patients, path_to_csv=PATH_TO_OUTPUT_FILE)

# Final logging of statistics
logger.info("--- Final Label Statistics ---")
num_patients_total = labeled_patients.get_num_patients(is_include_empty_labels=True)
num_patients_with_labels = labeled_patients.get_num_patients(is_include_empty_labels=False)
num_labels = labeled_patients.get_num_labels()
# as_numpy_arrays() returns a 3-tuple; only the middle element (label values) is needed here.
_, label_values, _ = labeled_patients.as_numpy_arrays()
num_positive_labels = int(label_values.sum())

logger.info(f"Total # of patients in database: {num_patients_total}")
logger.info(f"Total # of patients with at least one label: {num_patients_with_labels}")
logger.info(f"Total # of labels (ICU stays): {num_labels}")
logger.info(f"Total # of positive labels (deaths): {num_positive_labels}")
# The mortality rate is undefined when there are no labels.
if num_labels > 0:
    logger.info(f"Mortality Rate: {num_positive_labels / num_labels:.2%}")

# The original message was UTF-8 emoji mis-decoded as cp1252 ("ðŸŽ‰"); restore it.
logger.success("🎉 Done! 🎉")

In [None]:
# Quick inspection: show the arrays returned by as_numpy_arrays() as the cell output.
# Kept as a bare trailing expression on purpose, so the notebook renders its value.
labeled_patients.as_numpy_arrays()

# My Labeler: ICU episode summary computed directly from the FEMR database

In [None]:
from datetime import timedelta
from femr.datasets import PatientDatabase

# Event-code prefixes that mark ICU admission / discharge events in the FEMR extract.
ICU_ADMIT_PREFIX = "MIMIC/ICU_ADMISSION"
ICU_DISCHARGE_PREFIX = "MIMIC/ICU_DISCHARGE"
# Codes treated as a death event when checking the prediction window.
# NOTE(review): SNOMED 419620001 is presumably the "Death" concept — confirm against the ETL vocabulary.
DEATH_CODES = {"SNOMED/419620001"}

# Progress bars are optional: fall back to a pass-through wrapper when tqdm is not
# installed. Only ImportError is caught, so genuine failures inside tqdm still surface.
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, **kwargs):
        """Return *iterable* unchanged (no-op stand-in for tqdm.tqdm)."""
        return iterable

def _has_code_prefix(event, prefix: str) -> bool:
    """Return True if *event* has a string ``code`` that starts with *prefix*."""
    code = getattr(event, "code", None)
    return isinstance(code, str) and code.startswith(prefix)


def _scan_patient_icu_episodes(evs) -> tuple:
    """Count ICU episodes in one patient's start-time-sorted events.

    An episode is an admission paired with the first discharge event after it.
    Scanning resumes after that discharge, so admissions nested inside an open
    stay are not counted separately.

    Returns:
        ``(n_episodes, n_episodes_ge24_no_early_death)``, where the second count
        covers episodes lasting at least 24h with no death code in
        ``(admission, admission + 24h]``.
    """
    n = len(evs)
    n_episodes = 0
    n_ge24_no_early_death = 0
    i = 0
    while i < n:
        if not _has_code_prefix(evs[i], ICU_ADMIT_PREFIX):
            i += 1
            continue

        start_t = evs[i].start
        j = next(
            (k for k in range(i + 1, n) if _has_code_prefix(evs[k], ICU_DISCHARGE_PREFIX)),
            None,
        )
        if j is None:
            # No discharge anywhere after this admission, so neither it nor any later
            # admission can be paired. Stop instead of rescanning to the end for each.
            break

        end_t = evs[j].start
        n_episodes += 1

        if (end_t - start_t).total_seconds() >= 24 * 3600:
            t_pred = start_t + timedelta(hours=24)
            # Look for a death code strictly after admission and up to the prediction time.
            early_death = any(
                getattr(evs[k], "code", None) in DEATH_CODES and start_t < evs[k].start <= t_pred
                for k in range(i + 1, j)
            )
            if not early_death:
                n_ge24_no_early_death += 1

        i = j + 1

    return n_episodes, n_ge24_no_early_death


def summarize_icu_from_femr(path_to_db: str):
    """Summarize ICU stays found directly in a FEMR patient database.

    Args:
        path_to_db: Path to the FEMR database extract.

    Returns:
        Dict with:
        - ``n_patients_total``: number of patients scanned.
        - ``n_patients_any_icu``: patients with at least one ICU admission event.
        - ``n_icu_episodes``: admission -> discharge pairs found.
        - ``n_episodes_ge24_no_early_death``: episodes lasting >= 24h with no death
          code within the first 24h after admission.
    """
    db = PatientDatabase(path_to_db)

    try:
        pids = list(db)  # patient ids (MIMIC subject_id)
    except TypeError:
        # Fallback when the database is not iterable.
        # NOTE(review): assumes db[...] accepts 0..len-1 in that case — confirm.
        pids = list(range(len(db)))

    n_patients_any_icu = 0
    n_icu_episodes = 0
    n_episodes_ge24_no_early_death = 0

    for pid in tqdm(pids, desc="Scanning ICU", mininterval=0.5):
        evs = sorted(db[pid].events, key=lambda e: e.start)

        if any(_has_code_prefix(e, ICU_ADMIT_PREFIX) for e in evs):
            n_patients_any_icu += 1

        n_episodes, n_ok = _scan_patient_icu_episodes(evs)
        n_icu_episodes += n_episodes
        n_episodes_ge24_no_early_death += n_ok

    return {
        "n_patients_total": len(pids),
        "n_patients_any_icu": n_patients_any_icu,  # patients with any ICU admission event
        "n_icu_episodes": n_icu_episodes,  # admission -> discharge pairs
        "n_episodes_ge24_no_early_death": n_episodes_ge24_no_early_death,
    }


In [None]:
# Scan the configured database for raw ICU episodes and print the resulting counts.
stats = summarize_icu_from_femr(path_to_db=PATH_TO_FEMR_DATABASE)
print(stats)