# Build FEMR database (EHRSHOT-style) from MEDS/OMOP outputs

## Purpose
Create an EHRSHOT/FEMR-style event stream and build the FEMR patient database for downstream labeling, featurization, and modeling. This is where the missing-rate threshold is applied (code filtering).

## Inputs
- MEDS/OMOP-derived patient parquet files
- Output directory root containing train/tuning/held_out splits

## Outputs
- Per-split FEMR database under: <BASE>/<split>/extract/
- Intermediate event CSV(s) used by etl_simple_femr
- Code-frequency / missing-rate filtering artifacts (if enabled)

## Notes
To run multiple missing-rate thresholds τ, use a different <BASE> (or output folder) per τ to avoid overwriting.


## 0. Imports

In [None]:
import os
import shutil
import datetime
from typing import List, Dict, Tuple
from functools import partial

import pandas as pd
import pyarrow.parquet as pq
from tqdm import tqdm
from loguru import logger
from multiprocessing import Pool, cpu_count

# !pip install femr loguru tqdm pyarrow pandas
import femr.datasets
from femr import Patient, Event

In [None]:

# Input directory holding the MEDS/OMOP-derived parquet files for one split.
PATH_TO_MEDS_DIR = "/root/code/MEDS_Process/MEDS_to_OMOP/data/held_out"

# Output directory; Step 2 writes the event CSV (mimiciv.csv) here.
PATH_TO_OUTPUT_DIR = "/root/autodl-tmp/femr/held_out"

# Scratch directory for the per-patient parquet files created in Step 1 (wiped on each run).
PATH_TO_TEMP_DIR = "/root/autodl-tmp/held_out/temp_patient_data"

# Number of worker processes for the multiprocessing pool in Step 2.
NUM_PROCESSES = 15


print(f"MEDS input dir: {PATH_TO_MEDS_DIR}")
print(f"Output dir: {PATH_TO_OUTPUT_DIR}")
print(f"Temp dir: {PATH_TO_TEMP_DIR}")
print(f"Num processes: {NUM_PROCESSES}")

In [None]:
import os

def count_files_in_directory(directory_path):
    """Return how many regular files (sub-directories excluded) are directly in *directory_path*.

    Prints a message and returns 0 when the directory does not exist.
    """
    try:
        entries = os.listdir(directory_path)
    except FileNotFoundError:
        print(f"Directory not found: '{directory_path}'")
        return 0
    return sum(1 for name in entries if os.path.isfile(os.path.join(directory_path, name)))

# Sanity check: how many per-patient files are in the temp directory.
directory_path = PATH_TO_TEMP_DIR
file_count = count_files_in_directory(directory_path)
print(f"Number of files in '{directory_path}': {file_count}")


import os

def list_first_10_files(directory_path):
    """Return up to 10 regular-file names from *directory_path*, in sorted name order.

    Names are sorted first because ``os.listdir`` order is arbitrary, so the preview
    is the same on every call. Sub-directories are skipped. Prints a message and
    returns ``[]`` when the directory does not exist.
    """
    try:
        entries = sorted(os.listdir(directory_path))
    except FileNotFoundError:
        print(f"Directory not found: '{directory_path}'")
        return []

    first_files = []
    for name in entries:
        if os.path.isfile(os.path.join(directory_path, name)):
            first_files.append(name)
            # Stop early: no need to stat the rest of a large directory.
            if len(first_files) == 10:
                break
    return first_files

# Preview a few file names from the directory counted above.
first_10_files = list_first_10_files(directory_path)

if first_10_files:
    print(f"First {len(first_10_files)} files in '{directory_path}':")
    for file in first_10_files:
        print(file)
else:
    print(f"No files found in '{directory_path}'.")


In [None]:
# NOTE(review): hard-coded to the *tuning* split's temp directory, whereas
# PATH_TO_TEMP_DIR in the config cell points at held_out; the single-file check
# that uses `directory_path` therefore inspects the tuning split. Confirm intended.
directory_path = '/root/autodl-tmp/tuning/temp_patient_data'

In [None]:
import os

def check_file_in_directory(directory_path, filename):
    """Return True if *filename* exists as a regular file directly inside *directory_path*.

    Directories with that name, and a missing *directory_path*, both yield False.
    """
    return os.path.isfile(os.path.join(directory_path, filename))

# Spot-check that one specific patient's file was produced.
filename = "10040908.parquet"

# The two branches previously printed the same text; make the outcome visible.
if check_file_in_directory(directory_path, filename):
    print(f"File '{filename}' found in '{directory_path}'.")
else:
    print(f"File '{filename}' NOT found in '{directory_path}'.")


In [None]:
import os

def count_files_in_directory(directory_path):
    """Return how many regular files (sub-directories excluded) are directly in *directory_path*.

    Prints a message and returns 0 when the directory does not exist.
    """
    try:
        entries = os.listdir(directory_path)
    except FileNotFoundError:
        print(f"Directory not found: '{directory_path}'")
        return 0
    return sum(1 for name in entries if os.path.isfile(os.path.join(directory_path, name)))

# Sanity check: how many per-patient files are in the temp directory.
directory_path = PATH_TO_TEMP_DIR
file_count = count_files_in_directory(directory_path)
print(f"Number of files in '{directory_path}': {file_count}")


import os

def list_first_10_files(directory_path):
    """Return up to 10 regular-file names from *directory_path*, in sorted name order.

    Names are sorted first because ``os.listdir`` order is arbitrary, so the preview
    is the same on every call. Sub-directories are skipped. Prints a message and
    returns ``[]`` when the directory does not exist.
    """
    try:
        entries = sorted(os.listdir(directory_path))
    except FileNotFoundError:
        print(f"Directory not found: '{directory_path}'")
        return []

    first_files = []
    for name in entries:
        if os.path.isfile(os.path.join(directory_path, name)):
            first_files.append(name)
            # Stop early: no need to stat the rest of a large directory.
            if len(first_files) == 10:
                break
    return first_files

# Preview a few file names from the directory counted above.
first_10_files = list_first_10_files(directory_path)

if first_10_files:
    print(f"First {len(first_10_files)} files in '{directory_path}':")
    for file in first_10_files:
        print(file)
else:
    print(f"No files found in '{directory_path}'.")


In [None]:
def process_patient_file(patient_file_path: str) -> Patient | None:
    """Build a FEMR ``Patient`` from one per-patient parquet file.

    The patient id is parsed from the file name (``<patient_id>.parquet``). Every row
    whose ``code`` is present and non-empty becomes an ``Event`` with
    ``start=time``, ``value=numeric_value`` (``None`` when missing) and
    ``omop_table=ordercategorydescription`` (``"MIMIC_MEDS"`` when the column is
    absent or its value is missing). Events are sorted by ``start``.

    Returns:
        The ``Patient``, or ``None`` when no event survives or the file cannot be
        processed (the error is logged).
    """

    def _missing(x) -> bool:
        # pd.isna only on scalars: on list-like values it would return an array.
        return x is None or (pd.api.types.is_scalar(x) and pd.isna(x))

    try:
        patient_id = int(os.path.basename(patient_file_path).split('.')[0])
        df = pd.read_parquet(patient_file_path)

        events = []
        for row in df.to_dict('records'):
            code = row.get("code")
            # NaN / pd.NA codes would otherwise become the literal codes "nan" / "<NA>".
            if _missing(code) or code == '':
                continue

            value = row.get("numeric_value")
            if pd.isna(value):
                value = None

            # A present-but-null category must fall back to the default too, not NaN.
            omop_table = row.get("ordercategorydescription")
            if _missing(omop_table):
                omop_table = "MIMIC_MEDS"

            events.append(Event(start=row["time"], code=str(code), value=value, omop_table=omop_table))

        if not events:
            return None

        events.sort(key=lambda ev: ev.start)
        return Patient(patient_id=patient_id, events=events)
    except Exception as e:
        logger.error(f" {patient_file_path} : {e}")
        return None

# Step 1: Split the MEDS parquet files into one parquet file per patient

In [None]:


logger.info("---  1:  ---")
# Start from an empty scratch directory so rows from a previous run are never appended to.
if os.path.exists(PATH_TO_TEMP_DIR):
    shutil.rmtree(PATH_TO_TEMP_DIR)
os.makedirs(PATH_TO_TEMP_DIR)

# Sorted so that, for a patient spread over several input files, the row order inside the
# concatenated per-patient file is the same on every run (os.listdir order is arbitrary).
parquet_files = sorted(
    os.path.join(PATH_TO_MEDS_DIR, f) for f in os.listdir(PATH_TO_MEDS_DIR) if f.endswith('.parquet')
)
REQUIRED_COLUMNS = ['subject_id', 'time', 'code', 'numeric_value', 'ordercategorydescription']

for file_path in tqdm(parquet_files, desc=""):
    df = pd.read_parquet(file_path, columns=REQUIRED_COLUMNS)
    for subject_id, group in df.groupby('subject_id'):
        patient_file = os.path.join(PATH_TO_TEMP_DIR, f"{subject_id}.parquet")
        # Append by read-modify-write: keeps memory bounded to one input file at a time,
        # at the cost of re-reading patients that appear in many input files.
        if os.path.exists(patient_file):
            existing_data = pd.read_parquet(patient_file)
            combined_data = pd.concat([existing_data, group], ignore_index=True)
            combined_data.to_parquet(patient_file, index=False)
        else:
            group.to_parquet(patient_file, index=False)

logger.success("---  1:  ---")

# Step 2: Merge all per-patient parquet files into a single event CSV (mimiciv.csv)

In [None]:
import csv
import pandas as pd  # pandas

def process_patient_file_to_csv_rows(patient_file_path: str) -> list:
    """Turn one per-patient parquet file into CSV rows for FEMR's simple ETL.

    Each returned row is ``[patient_id, start, code, value]``:
      - ``patient_id`` comes from the file name (``<patient_id>.parquet``);
      - ``start`` is ``time`` as an ISO-8601 string (missing times become ``"NaT"``);
      - ``value`` is ``numeric_value`` or ``None`` when missing, which ``csv.writer``
        writes as an empty field.
    Rows with a missing ``code`` are dropped.

    Returns:
        The list of rows, or ``[]`` when nothing is left or the file cannot be
        processed (the error is logged).
    """
    try:
        patient_id = int(os.path.basename(patient_file_path).split('.')[0])
        df = pd.read_parquet(patient_file_path)

        df = df.dropna(subset=['code'])
        if df.empty:
            return []

        df['patient_id'] = patient_id
        df['time'] = df['time'].apply(lambda x: x.isoformat())
        # Cast to object first: on a float column `.where(mask, None)` silently turns None
        # back into NaN, which csv.writer would emit as the string "nan".
        values = df['numeric_value']
        df['numeric_value'] = values.astype(object).where(values.notna(), None)

        df = df.rename(columns={'time': 'start', 'numeric_value': 'value'})
        return df[['patient_id', 'start', 'code', 'value']].values.tolist()
    except Exception as e:
        logger.error(f" {patient_file_path} : {e}")
        return []

In [None]:
import sys
import shutil  # shutil


logger.info("---  2:  mimiciv.csv () ---")

# The single event CSV that the later cleaning / filtering steps read.
EVENTS_CSV_PATH = os.path.join(PATH_TO_OUTPUT_DIR, "mimiciv.csv")

if not os.path.exists(PATH_TO_OUTPUT_DIR):
    os.makedirs(PATH_TO_OUTPUT_DIR)

# Every per-patient parquet produced in Step 1.
patient_files_to_process = [
    os.path.join(PATH_TO_TEMP_DIR, name)
    for name in os.listdir(PATH_TO_TEMP_DIR)
    if name.endswith('.parquet')
]

with open(EVENTS_CSV_PATH, 'w', newline='') as events_file:
    csv_out = csv.writer(events_file)
    csv_out.writerow(['patient_id', 'start', 'code', 'value'])

    # Patients are converted in parallel; their rows arrive in completion order.
    with Pool(processes=NUM_PROCESSES) as workers:
        per_patient_rows = workers.imap_unordered(process_patient_file_to_csv_rows, patient_files_to_process)
        for patient_rows in tqdm(per_patient_rows, total=len(patient_files_to_process), desc="CSV"):
            if not patient_rows:
                continue
            csv_out.writerows(patient_rows)

logger.success(f"---  2:  mimiciv.csv ，: {EVENTS_CSV_PATH} ---")


# Step 3: Data cleaning on the CSV

In [None]:
import csv
import os
import pandas as pd
from loguru import logger
from tqdm import tqdm

# Step 3 (cleaning): rewrite the `code` column of mimiciv.csv in one streaming pass.
#   - codes containing 'MEDS_BIRTH' / 'MEDS_DEATH' are replaced by the SNOMED birth / death codes;
#   - a code wrapped as "['...']" is unwrapped;
#   - any code that is not exactly one "VOCAB/CODE" pair (one '/', no '//') gets a "MIMIC/" prefix.
# The result goes to a temp file, which is then moved over the original with os.replace.
# NOTE(review): not idempotent -- re-running this cell prefixes codes that contain '//' or
# several '/' a second time (e.g. "MIMIC/MIMIC/..."). Run it once per freshly written CSV.
EVENTS_CSV_PATH = os.path.join(PATH_TO_OUTPUT_DIR, "mimiciv.csv")
TEMP_CSV_PATH = os.path.join(PATH_TO_OUTPUT_DIR, "mimiciv_final_fixed.csv")

# Replacement codes for the MEDS birth / death markers.
OMOP_BIRTH_CODE = 'SNOMED/3950001'
OMOP_DEATH_CODE = 'SNOMED/419620001'
# Prefix added to codes that are not a single "VOCAB/CODE" pair.
DEFAULT_PREFIX = "MIMIC"

logger.info(f" {EVENTS_CSV_PATH} ...")

# The line count is only used as the tqdm total.
try:
    with open(EVENTS_CSV_PATH, 'r') as f:
        total_lines = sum(1 for line in f)
except Exception as e:
    logger.error(f": {e}")
    total_lines = 0

try:
    with open(EVENTS_CSV_PATH, 'r') as fin, open(TEMP_CSV_PATH, 'w', newline='') as fout:
        reader = csv.reader(fin)
        writer = csv.writer(fout)
        
        header = next(reader)
        writer.writerow(header)
        
        code_col_idx = header.index('code')
        
        for row in tqdm(reader, total=total_lines - 1 if total_lines > 0 else None, desc=""):
            # Rows too short to contain the code column are copied through unchanged.
            if len(row) > code_col_idx:
                code_value = str(row[code_col_idx]).strip()
                
                if 'MEDS_BIRTH' in code_value: 
                    code_value = OMOP_BIRTH_CODE 
                elif 'MEDS_DEATH' in code_value:
                     code_value = OMOP_DEATH_CODE 

                # Unwrap a stringified single-element list, e.g. "['X']" -> "X".
                if code_value.startswith("['") and code_value.endswith("']"): 
                    code_value = code_value[2:-2]

                # Exactly one '/' and no '//' is treated as an already-namespaced "VOCAB/CODE".
                is_standard_omop = (code_value.count('/') == 1 and '//' not in code_value)
                
                if not is_standard_omop:
                    code_value = f"{DEFAULT_PREFIX}/{code_value}"
                
                row[code_col_idx] = code_value
            
            writer.writerow(row)

    # Only reached after the whole pass succeeded.
    os.replace(TEMP_CSV_PATH, EVENTS_CSV_PATH)
    logger.success(f" mimiciv.csv ！")

except Exception as e:
    # On failure the partial temp file is removed; mimiciv.csv is not replaced.
    logger.error(f": {e}")
    if os.path.exists(TEMP_CSV_PATH):
        os.remove(TEMP_CSV_PATH)

In [None]:
def fix_nat_start_times_in_directory(directory_path: str):
    """
    CSV，
    'start'NaT， 'code' == 'SNOMED/3950001' 'start'。

    Args:
        directory_path (str): CSV。
    """
    
    directory_path = os.path.expanduser(directory_path)
    
    logger.info(f": {directory_path}")

    try:
        all_csv_files = [f for f in os.listdir(directory_path) if f.endswith('.csv')]
        if not all_csv_files:
            logger.warning(f" '{directory_path}'  .csv 。")
            return
    except FileNotFoundError:
        logger.error(f"： '{directory_path}'。。")
        return

    logger.info(f" {len(all_csv_files)} CSV。")
    
    files_processed = 0
    files_with_changes = 0

    for filename in tqdm(all_csv_files, desc="CSV"):
        file_path = os.path.join(directory_path, filename)
        
        try:
            df = pd.read_csv(file_path, parse_dates=['start'])

            birth_event = df[df['code'] == 'SNOMED/3950001']
            
            nat_rows = df['start'].isna()
            
            if not birth_event.empty and nat_rows.any():
                birth_time = birth_event['start'].iloc[0]
                
                df.loc[nat_rows, 'start'] = birth_time
                
                df.to_csv(file_path, index=False)
                
                logger.success(f" '{filename}':  {nat_rows.sum()} NaT。")
                files_with_changes += 1
            
            files_processed += 1

        except Exception as e:
            logger.error(f" '{filename}' : {e}")
            
    logger.info("---  ---")
    logger.info(f" {files_processed} 。")
    logger.info(f" {files_with_changes} 。")

# Fill NaT 'start' values in every CSV of the output directory (where Step 2 wrote mimiciv.csv).
fix_nat_start_times_in_directory(PATH_TO_OUTPUT_DIR)

logger.success("！")

## Step 3.1: Missing-rate threshold for experiments

In [None]:
# Default missing-rate threshold τ: a code is kept only if its missing rate is below this value.
DEFAULT_THRESHOLD = 0.9
PROTECTED_CODES = {"SNOMED/3950001", "SNOMED/419620001"}  # birth / death codes; never filtered out
ICU_ADMIT_PREFIX = "MIMIC/ICU_ADMISSION"  # codes starting with this mark an ICU admission
ICU_DISCHARGE_PREFIX = "MIMIC/ICU_DISCHARGE"  # codes starting with this mark an ICU discharge

# Column names in the event CSV.
PAT_COL = "patient_id"
TIME_COL = "start"
CODE_COL = "code"

LOOKBACK_HOURS = 24  # look-back window before each ICU admission, in hours
LOOKAHEAD_HOURS = 24  # observation window after each ICU admission, in hours


In [None]:
import os, json
from typing import Dict, List, Tuple, Iterable
from collections import defaultdict, Counter
import pandas as pd
from tqdm import tqdm

def _count_lines_fast(path: str, has_header: bool = True) -> int:
    total = 0
    with open(path, "rb") as f:
        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            total += chunk.count(b"\n")
    if has_header and total > 0:
        total -= 1
    return max(total, 0)

def _is_protected(code: str) -> bool:
    """Return True for codes that must never be filtered out.

    These are the entries of ``PROTECTED_CODES`` plus any code starting with the ICU
    admission or ICU discharge prefix.
    """
    if code in PROTECTED_CODES:
        return True
    return code.startswith((ICU_ADMIT_PREFIX, ICU_DISCHARGE_PREFIX))

from collections import defaultdict
import pandas as pd
from tqdm import tqdm

def first_pass_index_patients(
    csv_path: str,
    sep: str = ",",
    chunksize: int = 1_000_000,
    encoding: str | None = None,
) -> dict[str, dict[str, list[pd.Timestamp]]]:
    """Collect, per patient, the ICU admission / ICU discharge / death timestamps.

    Streams *csv_path* in chunks and records the ``TIME_COL`` value of every row whose
    code starts with ``ICU_ADMIT_PREFIX`` ("admits"), starts with ``ICU_DISCHARGE_PREFIX``
    ("discharges"), or equals the death code ``SNOMED/419620001`` ("deaths").
    Times are parsed with ``pd.to_datetime(errors="coerce")``, the same parsing the later
    counting / filtering passes apply, so the windows derived from these markers agree
    with the times those passes compare against. Unparseable times are skipped. This
    also replaces the deprecated ``read_csv(infer_datetime_format=...)`` argument.

    Returns:
        ``{patient_id (str): {"admits": [...], "discharges": [...], "deaths": [...]}}``
        with every list de-duplicated and sorted ascending. Only patients with at
        least one such marker appear.
    """
    info = defaultdict(lambda: {"admits": [], "discharges": [], "deaths": []})

    def _accumulate(pid: pd.Series, t: pd.Series, mask: pd.Series, key: str) -> None:
        # Append the non-missing times of the rows selected by `mask` to info[pid][key].
        mask = mask & t.notna()
        if not mask.any():
            return
        sub = pd.DataFrame({PAT_COL: pid[mask].to_numpy(), TIME_COL: t[mask].to_numpy()})
        for p, ts_list in sub.groupby(PAT_COL, sort=False)[TIME_COL].agg(list).items():
            info[str(p)][key].extend(ts_list)

    total_rows = _count_lines_fast(csv_path, has_header=True)
    with tqdm(total=total_rows or None, unit="row", desc="Indexing ICU markers") as pbar:
        for chunk in pd.read_csv(
            csv_path,
            chunksize=chunksize,
            sep=sep,
            encoding=encoding,
            usecols=[PAT_COL, TIME_COL, CODE_COL],
            dtype={PAT_COL: "string", CODE_COL: "string"},
        ):
            pid = chunk[PAT_COL]
            t = pd.to_datetime(chunk[TIME_COL], errors="coerce")
            code = chunk[CODE_COL]

            _accumulate(pid, t, code.str.startswith(ICU_ADMIT_PREFIX, na=False), "admits")
            _accumulate(pid, t, code.str.startswith(ICU_DISCHARGE_PREFIX, na=False), "discharges")
            _accumulate(pid, t, code.eq("SNOMED/419620001").fillna(False), "deaths")

            pbar.update(len(chunk))

    for d in info.values():
        for key in ("admits", "discharges", "deaths"):
            d[key] = sorted(set(d[key]))
    return info



def build_blocks_and_select_patients(
    icu_info: Dict[str, Dict[str, List[pd.Timestamp]]],
    drop_if_any_short_block: bool = True,
    min_duration_hours: int = 24
) -> Tuple[Dict[str, List[Tuple[pd.Timestamp, pd.Timestamp | None]]],
           Dict[str, List[Tuple[pd.Timestamp, pd.Timestamp]]],
           set, set]:
    """Split each patient's timeline into ICU blocks and decide which patients to keep.

    A block starts at each ICU admission and ends at the earliest of: the first death
    after it, the first ICU discharge after it, and the next admission. When there is
    no later admission, discharge or death, the block ends at ``pd.Timestamp.max``
    (never ``None``, despite the annotation).

    Selection rules:
      1. patients without any ICU admission are dropped;
      2. when *drop_if_any_short_block* is true, patients with any closed block shorter
         than *min_duration_hours* are dropped.

    Returns:
        ``(blocks, pre_windows, keep_patients, drop_patients)`` where
        ``blocks[pid] = [(admit, end), ...]``,
        ``pre_windows[pid] = [(admit - LOOKBACK_HOURS, admit), ...]`` (kept patients only),
        and the last two are sets of patient ids.
    """
    lookback = pd.Timedelta(hours=LOOKBACK_HOURS)
    min_dur = pd.Timedelta(hours=min_duration_hours)
    open_end = pd.Timestamp.max

    blocks, pre_windows = {}, {}
    keep_patients, drop_patients = set(), set()

    for pid, markers in icu_info.items():
        admits = markers["admits"]
        if not admits:
            drop_patients.add(pid)
            continue

        following = list(admits[1:]) + [open_end]
        segs = []
        for admit, next_admit in zip(admits, following):
            end_candidates = [next_admit]
            first_death = min((x for x in markers["deaths"] if x > admit), default=None)
            if first_death is not None:
                end_candidates.append(first_death)
            first_discharge = min((x for x in markers["discharges"] if x > admit), default=None)
            if first_discharge is not None:
                end_candidates.append(first_discharge)
            segs.append((admit, min(end_candidates)))

        has_short_block = any(end - admit < min_dur for admit, end in segs if end != open_end)
        if drop_if_any_short_block and has_short_block:
            drop_patients.add(pid)
            continue

        keep_patients.add(pid)
        blocks[pid] = segs
        pre_windows[pid] = [(admit - lookback, admit) for admit, _ in segs]

    return blocks, pre_windows, keep_patients, drop_patients

def build_codes_to_keep_by_row_fraction_with_progress(
    csv_path: str,
    out_dir: str,
    threshold: float = DEFAULT_THRESHOLD,
    sep: str = ",",
    chunksize: int = 1_000_000,
    encoding: str | None = None,
) -> str:
    """Decide which codes to keep from how often they occur right after ICU admission.

    Steps:
      1. Index ICU / death markers (``first_pass_index_patients``) and split patients
         into kept / dropped (``build_blocks_and_select_patients``). Both patient lists
         are written to ``kept_patients.json`` / ``dropped_patients.json`` in *out_dir*.
      2. For kept patients, count every non-protected code whose time falls in a
         post-admission window ``[admit, min(admit + LOOKAHEAD_HOURS, block end))``.
      3. Per code: ``row_fraction = n_rows / total rows counted``,
         ``missing_rate = 1 - row_fraction``. The code is kept when
         ``missing_rate < threshold``; protected codes are always kept.

    Writes ``codes_to_keep.csv`` (per-code statistics, kept codes first, then by
    descending row fraction) and ``codes_to_keep.json`` (sorted kept codes plus
    ``PROTECTED_CODES``) to *out_dir*. Both are written, possibly empty, even when no
    row falls in any window; the previous version crashed in that case.

    Returns:
        Path of ``codes_to_keep.csv``.
    """
    os.makedirs(out_dir, exist_ok=True)

    icu_info = first_pass_index_patients(csv_path, sep=sep, chunksize=chunksize, encoding=encoding)
    blocks, _pre_windows_unused, keep_patients, drop_patients = build_blocks_and_select_patients(icu_info)

    with open(os.path.join(out_dir, "kept_patients.json"), "w") as f:
        json.dump(sorted(keep_patients), f)
    with open(os.path.join(out_dir, "dropped_patients.json"), "w") as f:
        json.dump(sorted(drop_patients), f)

    # Observation window per ICU block, cut at the block end.
    obs_delta = pd.Timedelta(hours=LOOKAHEAD_HOURS)
    post_windows = {
        pid: [(a, min(a + obs_delta, end)) for a, end in segs]
        for pid, segs in blocks.items()
    }

    def _in_post_window(p: str, ti) -> bool:
        if pd.isna(ti):
            return False
        return any(lo <= ti < hi for lo, hi in post_windows.get(p, ()))

    total_rows_considered = 0
    code_counts = Counter()
    total_rows = _count_lines_fast(csv_path, has_header=True)

    with tqdm(total=total_rows or None, unit="row", desc=f"Counting (post {LOOKAHEAD_HOURS}h)") as pbar:
        for chunk in pd.read_csv(csv_path, chunksize=chunksize, sep=sep, encoding=encoding,
                                 usecols=[PAT_COL, TIME_COL, CODE_COL]):
            pbar.update(len(chunk))
            pid = chunk[PAT_COL].astype(str)
            t = pd.to_datetime(chunk[TIME_COL], errors="coerce")
            code = chunk[CODE_COL].astype(str)

            m_keep_pid = pid.isin(keep_patients)
            if not m_keep_pid.any():
                continue
            pid, t, code = pid[m_keep_pid], t[m_keep_pid], code[m_keep_pid]

            in_post = pd.Series(
                [_in_post_window(p, ti) for p, ti in zip(pid, t)], index=pid.index, dtype=bool
            )
            c = code[in_post]
            c = c[~c.map(_is_protected).astype(bool)]

            total_rows_considered += len(c)
            for k, v in c.value_counts().items():
                code_counts[k] += int(v)

    denom = max(total_rows_considered, 1)
    rows = []
    for cd, n in code_counts.items():
        row_fraction = n / denom
        missing_rate = 1.0 - row_fraction
        is_protected = _is_protected(cd)
        rows.append({
            "code": cd,
            "n_rows": n,
            "total_rows": denom,
            "row_fraction": row_fraction,
            "missing_rate": missing_rate,
            "is_protected": is_protected,
            # Protected codes are forced to keep=True *before* sorting so the
            # "kept first" order of the CSV stays correct.
            "keep": bool(missing_rate < threshold or is_protected),
        })
    # Explicit columns so an empty result still has them (sorting/selecting would KeyError).
    stats_df = pd.DataFrame(
        rows,
        columns=["code", "n_rows", "total_rows", "row_fraction", "missing_rate", "is_protected", "keep"],
    ).sort_values(by=["keep", "row_fraction"], ascending=[False, False])

    keep_csv = os.path.join(out_dir, "codes_to_keep.csv")
    stats_df.to_csv(keep_csv, index=False)

    keep_list = set(stats_df.loc[stats_df["keep"].astype(bool), "code"].astype(str))
    keep_list |= PROTECTED_CODES  # prefix-protected ICU codes are matched via _is_protected later
    with open(os.path.join(out_dir, "codes_to_keep.json"), "w", encoding="utf-8") as f:
        json.dump(sorted(keep_list), f, ensure_ascii=False, indent=2)

    return keep_csv




def apply_codes_to_keep_on_csv_with_progress(
    csv_path: str,
    keep_file: str,  # codes_to_keep.csv or codes_to_keep.json
    out_path: str,
    sep: str = ",",
    chunksize: int = 1_000_000,
    encoding: str | None = None,
) -> Tuple[str, int]:
    """Write a filtered copy of *csv_path* to *out_path*.

    Filtering rules:
      - keep only patients listed in ``kept_patients.json``, looked up in the directory
        of *out_path* (as written by ``build_codes_to_keep_by_row_fraction_with_progress``);
      - keep only codes in the keep set from *keep_file* (``codes_to_keep.csv`` rows with
        ``keep`` true, or the list in ``codes_to_keep.json``), plus protected codes;
      - inside each ICU block ``(admit, end)``, drop non-protected rows whose time is
        strictly between ``admit + LOOKAHEAD_HOURS`` and ``end``. Protected rows
        (birth / death / ICU markers) and rows with an unparseable time are kept.

    *out_path* is always (re)created, header-only when no row survives. Previously a
    stale file from an earlier run could be left in place and returned.

    Returns:
        ``(out_path, number of data rows written)``.

    Raises:
        FileNotFoundError: if ``kept_patients.json`` is missing.
        KeyError: if the CSV lacks the patient / time / code columns.
    """
    out_dir = os.path.dirname(out_path) or "."
    kept_pat_json = os.path.join(out_dir, "kept_patients.json")
    if not os.path.exists(kept_pat_json):
        raise FileNotFoundError(" kept_patients.json，（build_codes_to_keep_...）")

    if keep_file.endswith(".json"):
        with open(keep_file, "r", encoding="utf-8") as fh:
            keep_set = set(json.load(fh))
    else:
        kdf = pd.read_csv(keep_file, usecols=["code", "keep"])
        # Compare as text so both a bool column and a header-only (object) column work.
        keep_mask = kdf["keep"].astype(str).str.lower().eq("true")
        keep_set = set(kdf.loc[keep_mask, "code"].astype(str))
    keep_set |= PROTECTED_CODES | {ICU_ADMIT_PREFIX, ICU_DISCHARGE_PREFIX}

    with open(kept_pat_json, "r") as fh:
        kept_patients = set(json.load(fh))

    icu_info = first_pass_index_patients(csv_path, sep=sep, chunksize=chunksize, encoding=encoding)
    blocks, _, _, _ = build_blocks_and_select_patients(icu_info)

    lookahead = pd.Timedelta(hours=LOOKAHEAD_HOURS)
    drop_intervals = {
        pid: [(a + lookahead, end) for a, end in segs] for pid, segs in blocks.items()
    }

    def _keep_row(p: str, ti, ci: str) -> bool:
        if pd.isna(ti) or _is_protected(ci):
            return True
        return not any(lo < ti < hi for lo, hi in drop_intervals.get(p, ()))

    total_rows = _count_lines_fast(csv_path, has_header=True)
    os.makedirs(out_dir, exist_ok=True)
    first = True
    written = 0
    header = None

    with tqdm(total=total_rows or None, unit="row", desc="Filtering") as pbar:
        for chunk in pd.read_csv(csv_path, chunksize=chunksize, sep=sep, encoding=encoding):
            pbar.update(len(chunk))
            if CODE_COL not in chunk.columns or PAT_COL not in chunk.columns or TIME_COL not in chunk.columns:
                raise KeyError(f"CSV ：{[PAT_COL, TIME_COL, CODE_COL]}")
            header = list(chunk.columns)
            pid = chunk[PAT_COL].astype(str)
            t = pd.to_datetime(chunk[TIME_COL], errors="coerce")
            code = chunk[CODE_COL].astype(str)

            m_keep_pid = pid.isin(kept_patients)
            if not m_keep_pid.any():
                continue
            pid, t, code = pid[m_keep_pid], t[m_keep_pid], code[m_keep_pid]

            m_code = code.map(lambda c: c in keep_set or _is_protected(c)).astype(bool)
            pid, t, code = pid[m_code], t[m_code], code[m_code]
            if pid.empty:
                continue

            rows_keep = pd.Series(
                [_keep_row(p, ti, ci) for p, ti, ci in zip(pid, t, code)],
                index=pid.index,
                dtype=bool,
            )
            out_chunk = chunk.loc[rows_keep.index[rows_keep.to_numpy()]]
            out_chunk.to_csv(out_path, mode="w" if first else "a", header=first, index=False)
            first = False
            written += len(out_chunk)

    if first:
        # No chunk produced output: still (re)create out_path with just the header.
        if header is None:
            header = list(pd.read_csv(csv_path, sep=sep, encoding=encoding, nrows=0).columns)
        pd.DataFrame(columns=header).to_csv(out_path, index=False)

    return out_path, written

def build_keep_only(
    src_csv: str,
    out_dir: str,
    threshold: float = DEFAULT_THRESHOLD,
    sep: str = ",",
    chunksize: int = 1_000_000,
    encoding: str | None = None,
) -> str:
    """Run only the statistics stage, without filtering the CSV.

    Writes ``codes_to_keep.csv`` / ``codes_to_keep.json`` plus ``kept_patients.json`` /
    ``dropped_patients.json`` into *out_dir* and returns the path of
    ``codes_to_keep.csv``.
    """
    return build_codes_to_keep_by_row_fraction_with_progress(
        csv_path=src_csv,
        out_dir=out_dir,
        threshold=threshold,
        sep=sep,
        chunksize=chunksize,
        encoding=encoding,
    )

def filter_only(
    src_csv: str,
    out_dir: str,
    keep_file: str,  # build_keep_only  codes_to_keep.csv/json
    sep: str = ",",
    chunksize: int = 1_000_000,
    encoding: str | None = None,
) -> Tuple[str, int]:
    """Run only the filtering stage using an existing *keep_file* and ``kept_patients.json``.

    ``kept_patients.json`` must already be in *out_dir*. The result is written to
    ``out_dir/filtered.csv``; returns ``(path, rows written)``.
    """
    target_csv = os.path.join(out_dir, "filtered.csv")
    return apply_codes_to_keep_on_csv_with_progress(
        csv_path=src_csv,
        keep_file=keep_file,
        out_path=target_csv,
        sep=sep,
        chunksize=chunksize,
        encoding=encoding,
    )

In [None]:

# Compute code statistics and keep lists for the *train* split with τ = 0.9.
# NOTE(review): these paths are hard-coded to the train split and do not use the
# PATH_TO_* settings from the config cell (held_out). Confirm this is intended.
keep_csv = build_keep_only(
    src_csv="/root/autodl-tmp/femr/train/mimiciv.csv",
    out_dir="/root/autodl-tmp/femr_filtered/train",  
    threshold=0.9,
    sep=",",
    chunksize=1_000_000,
)
print("codes_to_keep ->", keep_csv)



In [None]:
# save as filter_icu_24h_then_stats.py
import os, json
from typing import Dict, List, Tuple
from collections import defaultdict, Counter
import numpy as np
import pandas as pd
from tqdm import tqdm

# Column names in the event CSV (re-binds the same names and values used earlier in the notebook).
PAT_COL  = "patient_id"
TIME_COL = "start"  # event-time column (parsed by pandas)
CODE_COL = "code"  # event-code column

PROTECTED_CODES = {"SNOMED/3950001", "SNOMED/419620001"}  # birth / death codes
ICU_ADMIT_PREFIX     = "MIMIC/ICU_ADMISSION"
ICU_DISCHARGE_PREFIX = "MIMIC/ICU_DISCHARGE"

# 24-hour window length; presumably passed on as build_stays(window_hours=...) — confirm with caller.
WINDOW_HOURS = 24

def _count_lines_fast(path: str, has_header: bool = True) -> int:
    total = 0
    with open(path, "rb") as f:
        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            total += chunk.count(b"\n")
    if has_header and total > 0:
        total -= 1
    return max(total, 0)

def _is_birth(c: pd.Series) -> pd.Series:
    return c == "SNOMED/3950001"

def _is_death(c: pd.Series) -> pd.Series:
    return c == "SNOMED/419620001"

def _is_icu_admit(c: pd.Series) -> pd.Series:
    return c.str.startswith(ICU_ADMIT_PREFIX, na=False)

def _is_icu_discharge(c: pd.Series) -> pd.Series:
    return c.str.startswith(ICU_DISCHARGE_PREFIX, na=False)

def _is_protected_any(c: pd.Series) -> pd.Series:
    return _is_birth(c) | _is_death(c) | _is_icu_admit(c) | _is_icu_discharge(c)

def _is_allowed_after_24h(c: pd.Series) -> pd.Series:
    return _is_death(c) | _is_icu_discharge(c)

def prefilter_drop_mimic_except_icu(
    src_csv: str,
    out_csv: str,
    sep: str = ",",
    chunksize: int = 2_000_000,
    encoding: str | None = None,
):
    """Copy *src_csv* to *out_csv*, dropping rows whose code starts with ``MIMIC/``.

    ICU admission / discharge marker rows are kept even though they start with
    ``MIMIC/``. The source column order is preserved. The output is comma-separated and
    is always (re)created, header-only when no data row survives.

    Returns:
        ``(out_csv, number of data rows written)``.
    """
    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)

    # Read the header with pandas so `sep` and `encoding` are honoured exactly as for the
    # data chunks. A plain csv.reader always split on "," and kept a UTF-8 BOM in the
    # first column name, which made reindex() blank that column.
    original_cols = list(pd.read_csv(src_csv, sep=sep, encoding=encoding, nrows=0).columns)

    total_rows = _count_lines_fast(src_csv, has_header=True)
    written = 0

    # Write the header up front so the file exists even if every row is dropped.
    pd.DataFrame(columns=original_cols).to_csv(out_csv, index=False)

    with tqdm(total=total_rows or None, unit="row", desc="Pre-filter: drop MIMIC/* (keep ICU markers)") as pbar:
        for chunk in pd.read_csv(src_csv, sep=sep, chunksize=chunksize, encoding=encoding):
            code = chunk[CODE_COL].astype("string")
            is_mimic = code.str.startswith("MIMIC/", na=False)
            keep_icu = _is_icu_admit(code) | _is_icu_discharge(code)
            to_drop = is_mimic & (~keep_icu)

            out_chunk = chunk.loc[~to_drop].reindex(columns=original_cols)
            out_chunk.to_csv(out_csv, mode="a", header=False, index=False)
            written += len(out_chunk)
            pbar.update(len(chunk))
    return out_csv, written



def index_icu_markers(csv_path: str, sep=",", chunksize=2_000_000, encoding=None):
    """Collect per-patient ICU admission / ICU discharge / death timestamps from *csv_path*.

    Rows are classified with ``_is_icu_admit``, ``_is_icu_discharge`` and ``_is_death``.
    Times that cannot be parsed are skipped.

    Returns:
        ``{patient_id (str): {"admits": [...], "discharges": [...], "deaths": [...]}}``
        with each list de-duplicated and sorted ascending.
    """
    info = defaultdict(lambda: {"admits": [], "discharges": [], "deaths": []})

    def _acc(chunk: pd.DataFrame, mask: pd.Series, key: str) -> None:
        # Append the valid timestamps of the rows selected by `mask` to info[pid][key].
        if not mask.any():
            return
        sub = chunk.loc[mask, [PAT_COL, TIME_COL]].copy()
        # parse_dates leaves an object column when some value cannot be parsed; coerce
        # those to NaT. is_datetime64_any_dtype also accepts tz-aware dtypes, on which
        # np.issubdtype raised a TypeError.
        if not pd.api.types.is_datetime64_any_dtype(sub[TIME_COL]):
            sub[TIME_COL] = pd.to_datetime(sub[TIME_COL], errors="coerce")
        sub = sub.dropna(subset=[TIME_COL])
        if sub.empty:
            return
        for p, ts_list in sub.groupby(PAT_COL, sort=False)[TIME_COL].agg(list).items():
            info[str(p)][key].extend(ts_list)

    total_rows = _count_lines_fast(csv_path, has_header=True)
    with tqdm(total=total_rows or None, unit="row", desc="Indexing ICU markers") as pbar:
        reader = pd.read_csv(
            csv_path, sep=sep, chunksize=chunksize, encoding=encoding,
            usecols=[PAT_COL, TIME_COL, CODE_COL],
            dtype={PAT_COL: "string", CODE_COL: "string"},
            parse_dates=[TIME_COL],
        )
        for chunk in reader:
            code = chunk[CODE_COL].astype("string")
            _acc(chunk, _is_icu_admit(code), "admits")
            _acc(chunk, _is_icu_discharge(code), "discharges")
            _acc(chunk, _is_death(code), "deaths")
            pbar.update(len(chunk))

    for d in info.values():
        for key in ("admits", "discharges", "deaths"):
            d[key] = sorted(set(d[key]))
    return info

def save_icu_index(icu_info: dict, path: str):
    """Write the ICU marker index as JSON with ISO-8601 timestamps.

    The output is gzip-compressed when *path* ends with ``.gz``. Parent directories are
    created as needed. Returns *path*.
    """
    marker_keys = ("admits", "discharges", "deaths")
    rec = {}
    for patient, markers in icu_info.items():
        rec[patient] = {key: [ts.isoformat() for ts in markers[key]] for key in marker_keys}

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    if not path.endswith(".gz"):
        with open(path, "w", encoding="utf-8") as out:
            json.dump(rec, out)
        return path

    import gzip
    with gzip.open(path, "wt", encoding="utf-8") as out:
        json.dump(rec, out)
    return path

def load_icu_index(path: str) -> dict:
    """Load an ICU marker index written by ``save_icu_index`` (plain JSON or ``.gz``).

    Timestamp strings are converted back to ``pd.Timestamp``.

    Returns:
        ``{patient_id: {"admits": [...], "discharges": [...], "deaths": [...]}}``.
    """
    if path.endswith(".gz"):
        import gzip
        with gzip.open(path, "rt", encoding="utf-8") as f:
            rec = json.load(f)
    else:
        # Use a context manager so the handle is closed deterministically
        # (the old json.load(open(...)) leaked it).
        with open(path, "r", encoding="utf-8") as f:
            rec = json.load(f)

    marker_keys = ("admits", "discharges", "deaths")
    return {
        p: {key: [pd.Timestamp(x) for x in d[key]] for key in marker_keys}
        for p, d in rec.items()
    }

def build_stays(icu_info: Dict[str, Dict[str, List[pd.Timestamp]]], window_hours=24):
    """Split each patient's ICU admissions into kept and dropped stays.

    For admission A, the stay end E is the earliest of: the next admission, the first
    discharge after A, and the first death after A (``pd.Timestamp.max`` if none).
    A stay is dropped when ``E - A`` is shorter than *window_hours* or a death occurs
    within *window_hours* after A. Otherwise it is kept as ``(A, min(A + window, E), E)``.

    Returns:
        ``(kept_stays, dropped_stays)``: dicts keyed by patient id holding lists of
        ``(A, cutoff, E)`` and ``(A, E)`` tuples respectively. Patients without
        admissions, or without entries of a kind, are absent from that dict.
    """
    window = pd.Timedelta(hours=window_hours)
    kept_stays, dropped_stays = {}, {}

    for pid, markers in icu_info.items():
        admits = markers["admits"]
        if not admits:
            continue
        discharges, deaths = markers["discharges"], markers["deaths"]

        kept, dropped = [], []
        for idx, admit in enumerate(admits):
            stay_end = admits[idx + 1] if idx + 1 < len(admits) else pd.Timestamp.max
            later_deaths = [x for x in deaths if x > admit]
            later_discharges = [x for x in discharges if x > admit]
            for later in (later_discharges, later_deaths):
                if later:
                    stay_end = min(stay_end, min(later))

            too_short = (stay_end - admit) < window
            dies_within_window = any((x - admit) <= window for x in later_deaths)
            if too_short or dies_within_window:
                dropped.append((admit, stay_end))
                continue
            kept.append((admit, min(admit + window, stay_end), stay_end))

        if kept:
            kept_stays[pid] = kept
        if dropped:
            dropped_stays[pid] = dropped

    return kept_stays, dropped_stays

def save_stays(kept_stays: dict, dropped_stays: dict, kept_path: str, dropped_path: str):
    """Write kept and dropped stays to two JSON files (timestamps as strings).

    kept_stays[pid]    -> [(admit, cutoff, end), ...] saved as lists of 3 strings
    dropped_stays[pid] -> [(admit, end), ...]         saved as lists of 2 strings

    Parent directories of both paths are created if needed. Files are written
    as UTF-8 and closed before returning. Returns ``(kept_path, dropped_path)``.
    """
    kept_ser = {p: [(str(a), str(c), str(e)) for (a, c, e) in lst] for p, lst in kept_stays.items()}
    dropped_ser = {p: [(str(a), str(e)) for (a, e) in lst] for p, lst in dropped_stays.items()}
    for path, payload in ((kept_path, kept_ser), (dropped_path, dropped_ser)):
        # Each file may live in its own directory.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False)
    return kept_path, dropped_path

def load_stays(kept_path: str, dropped_path: str):
    """Read the JSON files written by ``save_stays`` back into Timestamp tuples.

    Returns ``(kept, drop)`` where
      kept[pid] -> [(admit, cutoff, end), ...]
      drop[pid] -> [(admit, end), ...]
    with every element a ``pd.Timestamp``.
    """
    with open(kept_path, "r", encoding="utf-8") as f:
        ks = json.load(f)
    with open(dropped_path, "r", encoding="utf-8") as f:
        ds = json.load(f)
    kept = {p: [(pd.Timestamp(a), pd.Timestamp(c), pd.Timestamp(e)) for (a, c, e) in lst] for p, lst in ks.items()}
    drop = {p: [(pd.Timestamp(a), pd.Timestamp(e)) for (a, e) in lst] for p, lst in ds.items()}
    return kept, drop



def _membership_via_searchsorted(times: np.ndarray, starts: np.ndarray, ends: np.ndarray, inclusive_end=False):
    """
     [starts[i], ends[i])， times 。
    ：(in_interval, idx)  idx （<0 ）。
    """
    if len(starts) == 0:
        return np.zeros_like(times, dtype=bool), np.full(times.shape, -1, dtype=int)
    idx = np.searchsorted(starts, times, side="right") - 1  # -1 =  start<=t
    valid = idx >= 0
    e_sel = np.empty_like(times)
    e_sel[~valid] = np.datetime64("NaT")
    e_sel[valid]  = ends[idx[valid]]
    if inclusive_end:
        in_itv = valid & (times <= e_sel)
    else:
        in_itv = valid & (times < e_sel)
    return in_itv, idx

def _group_positions(keys: np.ndarray, positions: np.ndarray) -> Dict[object, np.ndarray]:
    """Map each distinct value in ``keys`` to the entries of ``positions`` aligned with it."""
    groups = pd.Series(positions).groupby(keys, sort=False).indices
    return {k: positions[v] for k, v in groups.items()}


def _stay_keep_mask(times: np.ndarray, codes: pd.Series, kept_list: list, dropped_list: list) -> np.ndarray:
    """Return the row-keep mask for one patient's rows given its kept/dropped stays.

    Priority (stays come from ``build_stays``: sorted by admit, non-overlapping):
      - inside a kept stay [admit, end): keep every row before the cutoff, and
        after it only codes passing ``_is_allowed_after_24h``;
      - otherwise, inside a dropped stay [admit, end] (end inclusive): remove;
      - everything else: keep only codes passing ``_is_protected_any``.
    """
    protected = _is_protected_any(codes).to_numpy()

    d_starts = np.array([np.datetime64(a, "ns") for a, _ in dropped_list])
    d_ends = np.array([np.datetime64(e, "ns") for _, e in dropped_list])
    in_dropped, _ = _membership_via_searchsorted(times, d_starts, d_ends, inclusive_end=True)

    k_starts = np.array([np.datetime64(a, "ns") for a, _, _ in kept_list])
    k_cuts = np.array([np.datetime64(c, "ns") for _, c, _ in kept_list])
    k_ends = np.array([np.datetime64(e, "ns") for _, _, e in kept_list])
    in_kept, idxk = _membership_via_searchsorted(times, k_starts, k_ends, inclusive_end=False)

    keep = ~in_kept & ~in_dropped & protected
    if in_kept.any():
        cut_sel = np.full(times.shape, np.datetime64("NaT"), dtype="datetime64[ns]")
        cut_sel[in_kept] = k_cuts[idxk[in_kept]]
        before_cut = in_kept & (times < cut_sel)
        # in_kept already guarantees time < end, so no separate end check is needed.
        after_cut = in_kept & (times >= cut_sel)
        allowed_after = _is_allowed_after_24h(codes).to_numpy()
        keep |= before_cut | (after_cut & allowed_after)
    return keep


def filter_csv_by_icu_rules(
    src_csv: str,
    out_csv: str,
    kept_stays: Dict[str, List[Tuple[pd.Timestamp, pd.Timestamp, pd.Timestamp]]],
    dropped_stays: Dict[str, List[Tuple[pd.Timestamp, pd.Timestamp]]],
    sep: str = ",",
    chunksize: int = 2_000_000,
    encoding: str | None = None,
):
    """Filter an event CSV row by row according to per-patient ICU stay rules.

    Rows are streamed in chunks, keyed by ``PAT_COL`` / ``TIME_COL`` / ``CODE_COL``:
      1) patient with no stay: keep only codes passing ``_is_protected_any``;
      2) rows inside a dropped stay [admit, end]: removed, even for patients
         that also have kept stays;
      3) rows inside a kept stay [admit, end):
           - [admit, cutoff): keep everything;
           - [cutoff, end):   keep only codes passing ``_is_allowed_after_24h``;
      4) rows of a patient with stays but outside all of them: keep only
         protected codes.

    The output keeps the source column order and uses the same ``sep``. An
    input with no data rows still produces a header-only output file.

    Returns ``(out_csv, rows_written)``.
    """
    total_rows = _count_lines_fast(src_csv, has_header=True)
    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)

    # Parse the header with the same parser/separator as the data chunks, so
    # the column reindex below is correct for any ``sep``.
    original_cols = pd.read_csv(src_csv, sep=sep, encoding=encoding, nrows=0).columns.tolist()

    kept_pids = set(kept_stays)
    dropped_pids = set(dropped_stays)
    first = True
    written = 0

    with tqdm(total=total_rows or None, unit="row", desc="Filtering by ICU rules") as pbar:
        reader = pd.read_csv(src_csv, sep=sep, chunksize=chunksize, encoding=encoding, parse_dates=[TIME_COL])
        for chunk in reader:
            if not np.issubdtype(chunk[TIME_COL].dtype, np.datetime64):
                chunk[TIME_COL] = pd.to_datetime(chunk[TIME_COL], errors="coerce")

            pid = chunk[PAT_COL].astype("string")
            time = chunk[TIME_COL].to_numpy(dtype="datetime64[ns]", copy=False)
            code = chunk[CODE_COL].astype("string")

            m_kept_pid = pid.isin(kept_pids).to_numpy()
            m_dropped_pid = pid.isin(dropped_pids).to_numpy()
            has_stay = m_kept_pid | m_dropped_pid
            no_stay = ~has_stay

            keep_mask = np.zeros(len(chunk), dtype=bool)
            if no_stay.any():
                keep_mask[no_stay] = _is_protected_any(code[no_stay]).to_numpy()

            if has_stay.any():
                # Group row positions by patient once per chunk (O(rows)) instead
                # of re-scanning all rows for every patient.
                positions = np.flatnonzero(has_stay)
                for p, ids_p in _group_positions(pid.to_numpy()[positions], positions).items():
                    keep_mask[ids_p] = _stay_keep_mask(
                        time[ids_p],
                        code.iloc[ids_p],
                        kept_stays.get(str(p), []),
                        dropped_stays.get(str(p), []),
                    )

            out_chunk = chunk.loc[keep_mask].reindex(columns=original_cols)
            out_chunk.to_csv(out_csv, sep=sep, mode="w" if first else "a", header=first, index=False)
            first = False
            written += len(out_chunk)
            pbar.update(len(chunk))

    if first:
        # No data chunks at all: still create the file so downstream readers find it.
        pd.DataFrame(columns=original_cols).to_csv(out_csv, sep=sep, index=False)

    return out_csv, written

def stats_row_fraction(csv_path: str, out_csv: str, sep=",", chunksize=2_000_000, encoding=None, exclude_protected=True):
    """Compute per-code row fractions and "missing rates" over an event CSV.

    For each code: row_fraction = n_rows(code) / total_rows and
    missing_rate = 1 - row_fraction. With ``exclude_protected=True``, rows
    whose code passes ``_is_protected_any`` are excluded from both the counts
    and the denominator.

    Codes are read as strings so they match ``filter_to_top_codes``, which
    also reads the code column as strings.

    Writes columns code, n_rows, total_rows, row_fraction, missing_rate to
    ``out_csv``. Rows are sorted by row_fraction (descending), with ties
    broken by code so the ranking is reproducible. An input with no countable
    rows yields a header-only CSV. Returns ``out_csv``.
    """
    columns = ["code", "n_rows", "total_rows", "row_fraction", "missing_rate"]
    total = 0
    counts = Counter()
    total_rows = _count_lines_fast(csv_path, has_header=True)
    with tqdm(total=total_rows or None, unit="row", desc="Counting codes (filtered CSV)") as pbar:
        reader = pd.read_csv(
            csv_path, sep=sep, chunksize=chunksize, encoding=encoding,
            usecols=[CODE_COL], dtype={CODE_COL: "string"},
        )
        for chunk in reader:
            c = chunk[CODE_COL].astype("string")
            if exclude_protected:
                m = ~_is_protected_any(c)
                c = c[m]
                total += int(m.sum())
            else:
                total += len(c)
            for k, v in c.value_counts().items():
                counts[str(k)] += int(v)
            pbar.update(len(chunk))
    total = max(total, 1)
    rows = [
        {
            "code": code,
            "n_rows": n,
            "total_rows": total,
            "row_fraction": n / total,
            "missing_rate": 1.0 - n / total,
        }
        for code, n in counts.items()
    ]
    # Explicit columns keep the sort (and the written header) valid when no codes were counted.
    df = pd.DataFrame(rows, columns=columns).sort_values(
        by=["row_fraction", "code"], ascending=[False, True]
    )
    df.to_csv(out_csv, index=False)
    return out_csv

def run_filter_then_stats(src_csv: str, out_dir: str, sep=",", chunksize=2_000_000):
    """Run ICU indexing, stay building, ICU-rule filtering and code statistics.

    Writes into ``out_dir``: filtered.csv, code_row_fraction_stats.csv,
    kept_stays.json and dropped_stays.json. Returns ``(filtered_csv, stats_csv)``.

    NOTE(review): unlike the step-by-step cells, this indexes ``src_csv``
    directly and does not call ``prefilter_drop_mimic_except_icu`` first —
    confirm that is intended.
    """
    os.makedirs(out_dir, exist_ok=True)

    icu_info = index_icu_markers(src_csv, sep=sep, chunksize=chunksize)
    kept_stays, dropped_stays = build_stays(icu_info, window_hours=WINDOW_HOURS)

    filtered_csv = os.path.join(out_dir, "filtered.csv")
    filter_csv_by_icu_rules(src_csv, filtered_csv, kept_stays, dropped_stays, sep=sep, chunksize=chunksize)

    stats_csv = os.path.join(out_dir, "code_row_fraction_stats.csv")
    stats_row_fraction(filtered_csv, stats_csv, sep=sep, chunksize=chunksize, exclude_protected=True)

    # Reuse save_stays so the on-disk format is the one load_stays reads.
    save_stays(
        kept_stays,
        dropped_stays,
        os.path.join(out_dir, "kept_stays.json"),
        os.path.join(out_dir, "dropped_stays.json"),
    )

    return filtered_csv, stats_csv

In [None]:
# Pre-filter the tuning split's event CSV into <out_dir>/pre_filtered.csv.
# prefilter_drop_mimic_except_icu presumably drops MIMIC/* rows other than ICU
# markers (inferred from its name — confirm against its definition).
# Edit src_csv / out_dir to process another split.
src_csv  = "/root/autodl-tmp/femr/tuning/mimiciv.csv"
out_dir  = "/root/autodl-tmp/femr_filtered/tuning"
os.makedirs(out_dir, exist_ok=True)

pre_csv = os.path.join(out_dir, "pre_filtered.csv")
prefilter_drop_mimic_except_icu(src_csv, pre_csv, chunksize=2_000_000)

In [None]:
# Build the per-patient ICU marker index (admits / discharges / deaths, the
# keys build_stays reads) from the pre-filtered CSV and cache it on disk so it
# can be reloaded with load_icu_index.
src_csv  = "/root/autodl-tmp/femr/tuning/mimiciv.csv"
out_dir  = "/root/autodl-tmp/femr_filtered/tuning"
pre_csv = os.path.join(out_dir, "pre_filtered.csv")

icu_index_path = os.path.join(out_dir, "icu_index.json.gz")
icu_info = index_icu_markers(pre_csv, chunksize=2_000_000)
save_icu_index(icu_info, icu_index_path)



In [None]:
# Turn ICU markers into kept/dropped stays (24 h window) and persist them as JSON.
# NOTE(review): run_filter_then_stats uses WINDOW_HOURS here instead of the
# literal 24 — confirm both agree.
kept_path    = os.path.join(out_dir, "kept_stays.json")
dropped_path = os.path.join(out_dir, "dropped_stays.json")
kept_stays, dropped_stays = build_stays(icu_info, window_hours=24)
save_stays(kept_stays, dropped_stays, kept_path, dropped_path)



In [None]:
# Apply the ICU stay rules to the pre-filtered CSV -> <out_dir>/filtered.csv.
filtered_csv = os.path.join(out_dir, "filtered.csv")
filter_csv_by_icu_rules(pre_csv, filtered_csv, kept_stays, dropped_stays, chunksize=2_000_000)

In [None]:
# Per-code row-fraction / missing-rate statistics over filtered.csv (protected
# codes excluded); filter_to_top_codes later ranks codes from this file.
stats_csv = os.path.join(out_dir, "code_row_fraction_stats.csv")
stats_row_fraction(filtered_csv, stats_csv, exclude_protected=True, chunksize=2_000_000)

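Delete events by missing rate: keep only the top-N most frequent codes (ranked by row fraction on the train split) plus protected and ICU admission/discharge codes.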

In [None]:
import os, csv
from typing import Set
import pandas as pd
from tqdm import tqdm

# Codes that filter_to_top_codes always keeps regardless of their frequency rank
# (presumably birth/death markers — confirm against the SNOMED concepts).
# NOTE(review): if an earlier cell also defines PROTECTED_CODES (e.g. for
# _is_protected_any), this rebinds it — confirm the two sets agree.
PROTECTED_CODES = {"SNOMED/3950001", "SNOMED/419620001"}
# Code prefixes of ICU stay boundary events; filter_to_top_codes keeps every
# code starting with either prefix.
ICU_ADMIT_PREFIX     = "MIMIC/ICU_ADMISSION"
ICU_DISCHARGE_PREFIX = "MIMIC/ICU_DISCHARGE"

def _count_lines_fast(path: str, has_header: bool = True) -> int:
    total = 0
    with open(path, "rb") as f:
        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            total += chunk.count(b"\n")
    if has_header and total > 0:
        total -= 1
    return max(total, 0)

def _load_top_codes(stats_csv: str, top_n: int = 144) -> Set[str]:
    """ stats_csv  code （ row_fraction ）。"""
    df = pd.read_csv(stats_csv, dtype={"code": "string"})
    if "row_fraction" in df.columns:
        df = df.sort_values("row_fraction", ascending=False)
    top = df["code"].dropna().astype(str).head(top_n).tolist()
    return set(top)

def filter_to_top_codes(
    filtered_csv: str,
    stats_csv: str,
    out_csv: str,
    top_n: int = 144,
    sep: str = ",",
    chunksize: int = 1_000_000,
    encoding: str | None = None,
):
    """Keep only rows whose code is frequent enough, streaming in chunks.

    A row is kept when its ``code`` is any of:
      - among the ``top_n`` codes of ``stats_csv`` (see ``_load_top_codes``);
      - in ``PROTECTED_CODES``;
      - a code starting with ``ICU_ADMIT_PREFIX`` or ``ICU_DISCHARGE_PREFIX``.
    Row and column order are preserved, and the output uses the same ``sep``
    as the input. An input with no data rows yields a header-only output.

    Returns ``(out_csv, rows_written)``.
    """
    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)

    keep_codes = _load_top_codes(stats_csv, top_n=top_n) | PROTECTED_CODES

    # Parse the header with the same parser/separator as the data chunks, so
    # the column reindex below is correct for any ``sep``.
    original_cols = pd.read_csv(filtered_csv, sep=sep, encoding=encoding, nrows=0).columns.tolist()

    total_rows = _count_lines_fast(filtered_csv, has_header=True)
    first = True
    written = 0

    desc = f"Filtering to top {top_n} codes (+protected)"
    with tqdm(total=total_rows or None, unit="row", desc=desc) as pbar:
        reader = pd.read_csv(
            filtered_csv, sep=sep, chunksize=chunksize, encoding=encoding,
            dtype={"code": "string"},  # keep codes verbatim, e.g. 'SNOMED/2.26E+14'
        )
        for chunk in reader:
            codes = chunk["code"].astype("string")
            mask = (
                codes.isin(keep_codes)
                | codes.str.startswith(ICU_ADMIT_PREFIX, na=False)
                | codes.str.startswith(ICU_DISCHARGE_PREFIX, na=False)
            )
            out_chunk = chunk.loc[mask].reindex(columns=original_cols)
            out_chunk.to_csv(out_csv, sep=sep, mode="w" if first else "a",
                             header=first, index=False)
            first = False
            written += len(out_chunk)
            pbar.update(len(chunk))

    if first:
        # No data chunks at all: still create the file so downstream steps find it.
        pd.DataFrame(columns=original_cols).to_csv(out_csv, sep=sep, index=False)

    return out_csv, written


In [None]:
# Keep the top-144 codes for the held_out split. Codes are ranked by the *train*
# split's statistics, presumably so all splits share one code vocabulary.
out_dir = "/root/autodl-tmp/femr_filtered/held_out"
stats_csv = os.path.join('/root/autodl-tmp/femr_filtered/train', "code_row_fraction_stats.csv")
src      = os.path.join(out_dir, "filtered.csv")  # output of filter_csv_by_icu_rules for this split
dst      = os.path.join(out_dir, "filtered_top144.csv")

filter_to_top_codes(
    filtered_csv=src,
    stats_csv=stats_csv,
    out_csv=dst,
    top_n=144,
    sep=",",
    chunksize=1_000_000,
)

# Step 4: Split CSV into multiple shards

In [None]:
# Split-specific settings for sharding and the FEMR ETL below.
# PATH_TO_OUTPUT_DIR also receives the shards/, extract/ and logs/ directories.
PATH_TO_OUTPUT_DIR = "/root/autodl-tmp/femr_filtered/held_out"
EVENTS_CSV_PATH = os.path.join(PATH_TO_OUTPUT_DIR, "filtered_top144.csv")
NUM_PROCESSES = 8  # passed to etl_simple_femr as --num_threads

In [None]:
import pandas as pd
from pathlib import Path


# Shard output directory; it is deleted and recreated on every run.
outdir_main = Path(PATH_TO_OUTPUT_DIR) / "shards"

if os.path.exists(outdir_main): shutil.rmtree(outdir_main)
outdir_main.mkdir(parents=True, exist_ok=True)

# Upper bound on the number of shards written by split_for_multicore.
NUM_SPLITS = 30

# ==========================================


def split_for_multicore():
    """Split EVENTS_CSV_PATH into at most NUM_SPLITS CSV shards in ``outdir_main``.

    Shards are named shard_00.csv, shard_01.csv, ... and each repeats the
    header. Every field is read and written as a string (``dtype=str``,
    ``keep_default_na=False``), so values are copied unchanged instead of
    being re-formatted by pandas type inference. For example, an integer
    column with a blank cell would otherwise be written back as ``123.0``.
    """
    with open(EVENTS_CSV_PATH, encoding="utf-8") as f:
        total_lines = sum(1 for _ in f) - 1
    # Guard against an empty file (total_lines == -1), which would make chunksize 0.
    rows_per_split = (max(total_lines, 0) // NUM_SPLITS) + 1
    print(f"[]  {total_lines},  {rows_per_split} ")

    with pd.read_csv(EVENTS_CSV_PATH, chunksize=rows_per_split,
                     dtype=str, keep_default_na=False) as reader:
        for idx, chunk in enumerate(reader):
            shard_path = outdir_main / f"shard_{idx:02d}.csv"
            chunk.to_csv(shard_path, index=False)
            print(f"[]  {shard_path} ({len(chunk)} )")
              
# Write the shards, then report where they went.
split_for_multicore()
print("\n✅ ！")
print(f": {outdir_main}")

# Step 5: Build FEMR database

In [None]:
# ETL input directory: the shard_XX.csv files written by split_for_multicore.
FEMR_INPUT_DIR_CLEAN = outdir_main

In [None]:
import sys
import shlex
import subprocess

logger.info("---  3:  etl_simple_femr  ---")
python_executable_path = sys.executable
path_to_bin = os.path.dirname(python_executable_path)
# etl_simple_femr is expected next to the active Python interpreter.
path_to_etl_simple_femr = os.path.join(path_to_bin, "etl_simple_femr")

if not os.path.exists(path_to_etl_simple_femr):
    raise FileNotFoundError(f" {path_to_etl_simple_femr}  'etl_simple_femr'。")

FEMR_EXTRACT_DIR = os.path.join(PATH_TO_OUTPUT_DIR, "extract")
FEMR_LOGS_DIR = os.path.join(PATH_TO_OUTPUT_DIR, "logs")
# Remove outputs of any previous run so the extract is rebuilt from scratch.
if os.path.exists(FEMR_EXTRACT_DIR): shutil.rmtree(FEMR_EXTRACT_DIR)
if os.path.exists(FEMR_LOGS_DIR): shutil.rmtree(FEMR_LOGS_DIR)

# Argument list (no shell): paths with spaces or shell metacharacters are
# passed through intact.
command = [
    path_to_etl_simple_femr,
    str(FEMR_INPUT_DIR_CLEAN),
    FEMR_EXTRACT_DIR,
    FEMR_LOGS_DIR,
    "--num_threads",
    str(NUM_PROCESSES),
]

logger.info(f": {shlex.join(command)}")
# returncode is the tool's real exit status; os.system returned a raw wait
# status on POSIX (exit code shifted left by 8 bits).
exit_code = subprocess.run(command, check=False).returncode

if exit_code == 0:
    logger.success("---  3: FEMR ！---")
else:
    logger.error(f"---  3: etl_simple_femr ，: {exit_code} ---")
    raise ChildProcessError("etl_simple_femr failed. Check logs in " + FEMR_LOGS_DIR + " for details.")

In [None]:
# FEMR_EXTRACT_DIR = os.path.join(PATH_TO_OUTPUT_DIR, "extract")
# FEMR_LOGS_DIR = os.path.join(PATH_TO_OUTPUT_DIR, "logs")

# Open the extract produced by etl_simple_femr and log its patient count.
logger.info("...")
database = femr.datasets.PatientDatabase(FEMR_EXTRACT_DIR)
logger.info(f": {len(database)}")
logger.info("！")



# Step 6: Verify

In [None]:
# Logging
# Sanity check: reopen the extract and log a summary plus the first event of
# the first patient. `events` and `patient_id` are reused by the next cell.
database = femr.datasets.PatientDatabase(FEMR_EXTRACT_DIR)
all_patient_ids = list(database)  # materializes every patient id just to pick the first
patient_id: int = all_patient_ids[0]  # IndexError if the database is empty
patient = database[patient_id]
events = patient.events
logger.info(f"FEMR database saved to: {PATH_TO_OUTPUT_DIR}")
logger.info(f"Num patients: {len(database)}")
logger.info(f"Number of events in patient '{patient_id}': {len(events)}")
logger.info(f"First event of patient '{patient_id}': {events[0]}")  # IndexError if the patient has no events
logger.success("Done!")

In [None]:
# Log every event of the sampled patient, labelled with its position.
for i, event in enumerate(events):
    logger.info(f"Event {i} of patient '{patient_id}': {event}")

In [None]:
from pathlib import Path

def human_bytes(n):
    """Render a byte count with 1024-based units (B ... EB) and one decimal."""
    units = ("B", "KB", "MB", "GB", "TB", "PB", "EB")
    level = 0
    # Step up a unit while the value is not below 1024; EB is the last unit.
    while level < 6 and not n < 1024:
        n /= 1024
        level += 1
    return f"{n:.1f}{units[level]}"

def dir_overview(path):
    """Print one summary line per entry directly inside ``path``.

    Directories show their recursive file count and total size; files show
    their own size. A missing path prints a marker and returns.
    """
    root = Path(path)
    print("Listing:", root.resolve())
    if not root.exists():
        print("❌ ")
        return
    for entry in sorted(root.iterdir()):
        if not entry.is_dir():
            print(f"[FILE] {entry.name:30} {human_bytes(entry.stat().st_size)}")
            continue
        nested_files = [q for q in entry.rglob("*") if q.is_file()]
        total_size = sum(q.stat().st_size for q in nested_files)
        print(f"[DIR]  {entry.name:30} files={len(nested_files):6d}  size={human_bytes(total_size)}")

# dir_overview(FEMR_INPUT_DIR_CLEAN)
# dir_overview("data/mimiciv_meds/data")


In [None]:
# Quick look at what the pipeline has written under the data root.
dir_overview('/root/autodl-tmp')

In [None]:
import csv, os
from pathlib import Path
import pandas as pd

def human_bytes(n):
    """Format a byte count with 1024-based units and one decimal place."""
    for u in ["B","KB","MB","GB","TB","PB"]:
        if n < 1024: return f"{n:.1f}{u}"
        n /= 1024
    return f"{n:.1f}EB"


def quick_csv_check(csv_path, sample_rows=5, required_cols=None):
    """Print a quick diagnostic summary of a CSV file.

    The summary contains:
      - file size;
      - detected delimiter;
      - the first ``sample_rows`` rows;
      - dtypes and NA counts inferred from the first 2,000 rows;
      - whether all ``required_cols`` are present;
      - the physical line count (header included).

    Returns None. Raises AssertionError if ``csv_path`` does not exist.
    """
    path = Path(csv_path)
    assert path.exists(), f"：{path}"
    size = path.stat().st_size
    print(f"File: {path.resolve()}\nSize: {human_bytes(size)}")

    # Only a small prefix is needed for sniffing; close the handle right away.
    with path.open("rb") as f:
        raw = f.read(4096)
    try:
        # Restrict the candidates: event rows contain ':' (timestamps) and '/'
        # (codes like "SNOMED/...") on every line, which an unrestricted
        # Sniffer may pick as the delimiter.
        dialect = csv.Sniffer().sniff(raw.decode("utf-8", "ignore"), delimiters=",;\t|")
        delim = dialect.delimiter
    except csv.Error:
        delim = ","
    print("Delimiter:", repr(delim))

    df_head = pd.read_csv(path, sep=delim, nrows=sample_rows, low_memory=False)
    print("\nColumns:", list(df_head.columns))
    print(f"\nHead ({sample_rows} rows):")
    print(df_head)

    try:
        df_probe = pd.read_csv(path, sep=delim, nrows=2000, low_memory=False)
        print("\nInferred dtypes (sample of 2k rows):")
        print(df_probe.dtypes)
        print("\nNA count (sample of 2k rows):")
        print(df_probe.isna().sum())
    except Exception as e:
        print("\n(/) ：", e)

    if required_cols:
        missing = [c for c in required_cols if c not in df_head.columns]
        if missing:
            print("❌ ：", missing)
        else:
            print("✅ ：", required_cols)

    try:
        with path.open("rb") as f:
            n_lines = sum(1 for _ in f)
        print(f"\nLine count (): {n_lines}")
    except Exception as e:
        print("\n()：", e)

# quick_csv_check(Path(FEMR_INPUT_DIR_CLEAN) / "a.csv",
#                 sample_rows=5,


In [None]:
# Inspect one CSV in the ETL input directory.
# NOTE(review): split_for_multicore writes shard_XX.csv files into this
# directory; confirm that "a.csv" actually exists here.
quick_csv_check(Path(FEMR_INPUT_DIR_CLEAN) / "a.csv",
                sample_rows=50,
                required_cols=None)  # optionally e.g. ['patient_id','time','concept_id']


# Test

In [None]:
import pandas as pd
from pathlib import Path

# Build a small test CSV containing every row of the first n_patients patients.
# NOTE(review): split_for_multicore writes shard_XX.csv into this directory;
# confirm that "a.csv" exists here.
csv_path = Path(FEMR_INPUT_DIR_CLEAN) / "a.csv"
out_path = Path(FEMR_INPUT_DIR_CLEAN) / "a_small.csv"
patient_col = "patient_id"  # adjust if the patient column is named differently
n_patients = 5  # number of distinct patients to sample

# Pass 1: first n_patients distinct patient ids, in file order.
uids = []
seen = set()
with pd.read_csv(csv_path, chunksize=200_000, dtype={patient_col: str}) as reader:
    for chunk in reader:
        for uid in chunk[patient_col].astype(str).tolist():
            if uid not in seen:
                seen.add(uid)
                uids.append(uid)
                if len(uids) >= n_patients:
                    break
        if len(uids) >= n_patients:
            break

print("Selected patients:", uids)

# Pass 2: copy all rows of the selected patients. The first write truncates
# out_path, so re-running the cell does not append duplicate rows/headers.
written_header = False
with pd.read_csv(csv_path, chunksize=200_000, dtype={patient_col: str}) as reader:
    for chunk in reader:
        sub = chunk[chunk[patient_col].astype(str).isin(uids)]
        if not sub.empty:
            sub.to_csv(out_path, mode="a" if written_header else "w",
                       index=False, header=not written_header)
            written_header = True

if written_header:
    print("Wrote:", out_path, "size:", out_path.stat().st_size, "bytes")
else:
    print("No rows written for:", uids)


In [None]:

logger.info("---  3:  etl_simple_femr  ---")

# ##################################################################
# ##################################################################

EVENTS_CSV_PATH_t = os.path.join(FEMR_INPUT_DIR_CLEAN, "a_small.csv")
FEMR_INPUT_DIR_CLEAN_t = os.path.join(PATH_TO_OUTPUT_DIR, "femr_input_temp_test")
CLEAN_CSV_PATH_t = os.path.join(FEMR_INPUT_DIR_CLEAN_t, "a_small.csv")

# NOTE(review): these are the same paths as the full build in Step 5, so the
# rmtree calls below delete the full extract/logs — confirm this test run is
# meant to overwrite them.
FEMR_EXTRACT_DIR = os.path.join(PATH_TO_OUTPUT_DIR, "extract")
FEMR_LOGS_DIR = os.path.join(PATH_TO_OUTPUT_DIR, "logs")

if os.path.exists(FEMR_EXTRACT_DIR): shutil.rmtree(FEMR_EXTRACT_DIR)
if os.path.exists(FEMR_LOGS_DIR): shutil.rmtree(FEMR_LOGS_DIR)
# if os.path.exists(FEMR_INPUT_DIR_CLEAN): shutil.rmtree(FEMR_INPUT_DIR_CLEAN)
# os.makedirs(FEMR_INPUT_DIR_CLEAN)

# Move the sample out of the shard directory into its own input directory.
# The destination directory must exist for shutil.move to succeed.
os.makedirs(FEMR_INPUT_DIR_CLEAN_t, exist_ok=True)
logger.info(f" a_small.csv : {FEMR_INPUT_DIR_CLEAN_t}")
shutil.move(EVENTS_CSV_PATH_t, CLEAN_CSV_PATH_t)