# A. Imports

In [34]:
import pandas as pd
from codebase.utils.clinical_longformer import longformerize
from tqdm import tqdm
import torch
import gc
import numpy as np

pd.set_option('display.max_columns', None)

RANDOM_STATE = 42  # for reproducibility

NOTES_ZHU = "D:/Lab/Research/EMERGE-REPLICATE/preprocessing-zhu/mimic-iii/data/processed/notes/notes.csv"
EHR_ZHU   = "D:/Lab/Research/EMERGE-REPLICATE/preprocessing-zhu/mimic-iii/data/processed/ehr/ehr.csv"

NOTES_BACH = "D:/Lab/Research/EMERGE-REPLICATE/preprocessing-bach/processed/notes.csv"
EHR_BACH   = "D:/Lab/Research/EMERGE-REPLICATE/preprocessing-bach/processed/ehr.csv"

# B. Cleaning

## Keep only first episode

In [35]:
def clean_patient_ids(INPUT_CSV, OUTPUT_CSV):
    df = pd.read_csv(INPUT_CSV)
    pid = df["PatientID"].astype(str)
    extracted = pid.str.extract(r"^(?P<base>\d+)(?:_(?P<suf>\d+))?$")

    # Determine which rows to keep:
    # - Keep if suffix is NaN (no suffix) or equals 1
    # - Drop if suffix is >= 2
    suf_num = pd.to_numeric(extracted["suf"], errors="coerce")
    keep_mask = suf_num.isna() | (suf_num == 1)
    clean = df.loc[keep_mask].copy()
    clean.loc[:, "PatientID"] = extracted.loc[keep_mask, "base"].astype(str)
    clean.to_csv(OUTPUT_CSV, index=False)

    # print("Before:")
    # print(df.head(3))
    # print("\nAfter (cleaned):")
    # print(clean.head(3))

clean_patient_ids(EHR_ZHU, EHR_BACH)
clean_patient_ids(NOTES_ZHU, NOTES_BACH)

## Feature Engineering

In [36]:
notes = pd.read_csv(NOTES_BACH, dtype={"PatientID": "string"})
ehr   = pd.read_csv(EHR_BACH,   dtype={"PatientID": "string"})

MIN_ROWS_PER_PATIENT = 4
ehr_counts = ehr.groupby("PatientID").size()
valid_ehr_ids = ehr_counts[ehr_counts >= MIN_ROWS_PER_PATIENT].index
removed_patients = set(ehr["PatientID"].unique()) - set(valid_ehr_ids)
ehr = ehr[ehr["PatientID"].isin(valid_ehr_ids)].reset_index(drop=True)

print(f"Patients removed for having < {MIN_ROWS_PER_PATIENT} EHR rows: {len(removed_patients)}")

Patients removed for having < 4 EHR rows: 13137


In [37]:
ehr = (
    ehr
    .groupby("PatientID", group_keys=False)
    .head(4)
    .reset_index(drop=True)
)

print("Limited to first 4 EHR rows per patient.")
print(f"EHR rows after truncation: {len(ehr):,}")
print(f"EHR patients: {ehr['PatientID'].nunique():,}")

Limited to first 4 EHR rows per patient.
EHR rows after truncation: 81,328
EHR patients: 20,332


In [38]:
# Check for patients with conflicting Outcome or Readmission values
conflict = ehr.groupby("PatientID")[["Outcome", "Readmission"]].nunique()
conflict_patients = conflict[(conflict["Outcome"] > 1) | (conflict["Readmission"] > 1)]

print("Patients with conflicting Outcome or Readmission values:")
print(conflict_patients)

Patients with conflicting Outcome or Readmission values:
Empty DataFrame
Columns: [Outcome, Readmission]
Index: []


In [39]:
ehr.loc[:, "Capillary refill rate->0.0":"Glascow coma scale verbal response->3 Inapprop words"] = ehr.loc[:, "Capillary refill rate->0.0":"Glascow coma scale verbal response->3 Inapprop words"].fillna(0)
ehr.loc[:, "Diastolic blood pressure":] = ehr.loc[:, "Diastolic blood pressure":].apply(lambda col: col.fillna(col.mean()))

ehr.head()

Unnamed: 0,PatientID,Outcome,Readmission,Sex,Age,Capillary refill rate->0.0,Capillary refill rate->1.0,Glascow coma scale eye opening->To Pain,Glascow coma scale eye opening->3 To speech,Glascow coma scale eye opening->1 No Response,Glascow coma scale eye opening->4 Spontaneously,Glascow coma scale eye opening->None,Glascow coma scale eye opening->To Speech,Glascow coma scale eye opening->Spontaneously,Glascow coma scale eye opening->2 To pain,Glascow coma scale motor response->1 No Response,Glascow coma scale motor response->3 Abnorm flexion,Glascow coma scale motor response->Abnormal extension,Glascow coma scale motor response->No response,Glascow coma scale motor response->4 Flex-withdraws,Glascow coma scale motor response->Localizes Pain,Glascow coma scale motor response->Flex-withdraws,Glascow coma scale motor response->Obeys Commands,Glascow coma scale motor response->Abnormal Flexion,Glascow coma scale motor response->6 Obeys Commands,Glascow coma scale motor response->5 Localizes Pain,Glascow coma scale motor response->2 Abnorm extensn,Glascow coma scale total->11,Glascow coma scale total->10,Glascow coma scale total->13,Glascow coma scale total->12,Glascow coma scale total->15,Glascow coma scale total->14,Glascow coma scale total->3,Glascow coma scale total->5,Glascow coma scale total->4,Glascow coma scale total->7,Glascow coma scale total->6,Glascow coma scale total->9,Glascow coma scale total->8,Glascow coma scale verbal response->1 No Response,Glascow coma scale verbal response->No Response,Glascow coma scale verbal response->Confused,Glascow coma scale verbal response->Inappropriate Words,Glascow coma scale verbal response->Oriented,Glascow coma scale verbal response->No Response-ETT,Glascow coma scale verbal response->5 Oriented,Glascow coma scale verbal response->Incomprehensible sounds,Glascow coma scale verbal response->1.0 ET/Trach,Glascow coma scale verbal response->4 Confused,Glascow coma scale verbal response->2 Incomp sounds,Glascow coma scale verbal response->3 Inapprop words,Diastolic blood pressure,Fraction inspired oxygen,Glucose,Heart Rate,Height,Mean blood pressure,Oxygen saturation,Respiratory rate,Systolic blood pressure,Temperature,Weight,pH
0,100,0,0,0,71.990441,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,58.0,0.4,108.0,72.0,169.124771,69.0,98.0,14.0,100.0,37.3,83.409587,7.46
1,100,0,0,0,71.990441,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,71.0,0.5,98.0,92.0,169.124771,92.0,97.0,19.0,132.0,37.0,63.200001,7.37
2,100,0,0,0,71.990441,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,55.0,0.504306,120.0,108.0,169.124771,67.0,95.0,12.0,95.0,37.0,83.409587,7.38
3,100,0,0,0,71.990441,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,71.0,0.504306,147.0,76.0,169.124771,91.0,100.0,17.0,131.0,36.611112,63.5,7.36
4,1000,1,1,1,69.754941,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,58.0,0.504306,137.055896,87.0,169.124771,76.666702,97.0,20.0,114.0,36.777802,83.409587,7.219703


In [40]:
notes = (
    notes
    .groupby("PatientID", sort=False)["Text"]
    .apply(lambda s: "/n".join(s.astype(str)))
    .reset_index(name="Text")
)

notes.head()

Unnamed: 0,PatientID,Text
0,10000,11:30 am chest ( portable ap ) clip # reason :...
1,10003,"admission note d : pt arrived from , sedated o..."
2,10004,respiratory care pt was admitted today from os...
3,10006,full code universal precautions allergy : hepa...
4,10007,12:24 pm chest port . line placement clip # re...


## Keep only patients with both ehr and notes

In [41]:
# normalize IDs a bit (handle stray spaces / empty)
notes["PatientID"] = notes["PatientID"].str.strip()
ehr["PatientID"]   = ehr["PatientID"].str.strip()
notes = notes.dropna(subset=["PatientID"])
ehr   = ehr.dropna(subset=["PatientID"])

# ---- find intersection ----
ids_notes = set(notes["PatientID"].unique())
ids_ehr   = set(ehr["PatientID"].unique())
ids_both  = ids_notes & ids_ehr

# ---- filter to the same set ----
notes_f = notes[notes["PatientID"].isin(ids_both)].copy()
ehr_f   = ehr[ehr["PatientID"].isin(ids_both)].copy()

# ---- (optional) sort for readability ----
notes_f = notes_f.sort_values(["PatientID"]).reset_index(drop=True)
ehr_f   = ehr_f.sort_values(["PatientID"]).reset_index(drop=True)

# ---- save ----
notes_f.to_csv(NOTES_BACH, index=False)
ehr_f.to_csv(EHR_BACH,   index=False)

# ---- report ----
print("=== BEFORE ===")
print(f"Notes: {len(notes):,} rows | {len(ids_notes):,} unique patients")
print(f"EHR  : {len(ehr):,} rows | {len(ids_ehr):,} unique patients")
print("\n=== AFTER (kept only patients present in BOTH) ===")
print(f"Common patients kept: {len(ids_both):,}")
print(f"Notes.filtered.csv: {len(notes_f):,} rows | {notes_f['PatientID'].nunique():,} patients")
print(f"ehr.filtered.csv  : {len(ehr_f):,} rows | {ehr_f['PatientID'].nunique():,} patients")

=== BEFORE ===
Notes: 31,027 rows | 31,027 unique patients
EHR  : 81,328 rows | 20,332 unique patients

=== AFTER (kept only patients present in BOTH) ===
Common patients kept: 19,307
Notes.filtered.csv: 19,307 rows | 19,307 patients
ehr.filtered.csv  : 77,228 rows | 19,307 patients


# C. Check

## Check Notes

In [42]:
# --- Load the CSV ---
df = pd.read_csv(NOTES_BACH)

# --- Functions to count ---
def char_count(text):
    return len(text)

def word_count(text):
    return len(text.split())

# --- Apply counts ---
df["char_count"] = df["Text"].astype(str).apply(char_count)
df["word_count"] = df["Text"].astype(str).apply(word_count)

# --- Overall stats ---
overall = {
    "max_chars": df["char_count"].max(),
    "avg_chars": df["char_count"].mean(),
    "max_words": df["word_count"].max(),
    "avg_words": df["word_count"].mean(),
}

print("\n🌍 Overall stats:")
print(overall)


🌍 Overall stats:
{'max_chars': 52597, 'avg_chars': np.float64(9399.021960946808), 'max_words': 9494, 'avg_words': np.float64(1778.0428859998965)}


## Check EHR

In [43]:
df = pd.read_csv(EHR_BACH)
print("Shape:", df.shape)

print("Outcome positive ratio (each episode):", df["Outcome"].mean())
print("Readmission positive ratio (each episode):", df["Readmission"].mean())

print("===")
pat_any = df.groupby("PatientID")[["Outcome","Readmission"]].max()

print("Outcome positive ratio (each patient):", pat_any["Outcome"].mean())
print("Readmission positive ratio (each patient):", pat_any["Readmission"].mean())

Shape: (77228, 64)
Outcome positive ratio (each episode): 0.11384471953177604
Readmission positive ratio (each episode): 0.14973843683638058
===
Outcome positive ratio (each patient): 0.11384471953177604
Readmission positive ratio (each patient): 0.14973843683638058
