# Setup

In [None]:
import  os
import pandas as pd
import numpy as np
import math
#
from codebase.utils.bgem3 import cosine_filter, batch_encode
from codebase.utils.call_llm import extract_note, create_summary
from codebase.utils.clinical_longformer import langchain_chunk_embed
from codebase.utils.train_test_split import data_split
#
import torch
from FlagEmbedding import BGEM3FlagModel
from concurrent.futures import ThreadPoolExecutor, as_completed
#
from collections import defaultdict
import pickle
import h5py
#
from tqdm import tqdm
# Lift pandas' display truncation so wide/long frames are rendered in full.
for _display_option in ('display.max_columns', 'display.max_rows'):
    pd.set_option(_display_option, None)

In [None]:
# --- PrimeKG raw files (dataverse_files directory) ---
DATAVERSE_DIR   = "D:/Lab/Research/EMERGE-REPLICATE/datasets/dataverse_files/"
PRIMEKG_KG      = f"{DATAVERSE_DIR}kg.csv"
PRIMEKG_DISEASE = f"{DATAVERSE_DIR}disease_features.csv"

# --- Output root and global knobs ---
GOOD_DATA    = "D:/Lab/Research/EMERGE-REPLICATE/good_data/"
RECORDS      = 4   # number of EHR records kept per patient
RANDOM_STATE = 42  # for reproducibility

# --- Part 1: raw (ZHU) inputs and cleaned (BACH) outputs ---
NOTES_ZHU  = "D:/Lab/Research/EMERGE-REPLICATE/preprocessing-zhu/mimic-iii/data/processed/notes/notes.csv"
EHR_ZHU    = "D:/Lab/Research/EMERGE-REPLICATE/preprocessing-zhu/mimic-iii/data/processed/ehr/ehr.csv"
NOTES_BACH = f"{GOOD_DATA}processed/notes.csv"
EHR_BACH   = f"{GOOD_DATA}processed/ehr/records_{RECORDS}.csv"

# --- Part 2: cached intermediate artefacts ---
KG_ADJACENCY     = f"{GOOD_DATA}curated/kg_adjacency.pkl"
DISEASE_FEATURES = f"{GOOD_DATA}curated/disease_features_cleaned.pkl"
NOTES_EMBEDDINGS = f"{GOOD_DATA}curated/notes_embeddings.h5"

# --- Part 2: final per-split datasets (one directory per RECORDS setting) ---
TRAIN = f"{GOOD_DATA}complete/records_{RECORDS}/train.h5"
VAL   = f"{GOOD_DATA}complete/records_{RECORDS}/val.h5"
TEST  = f"{GOOD_DATA}complete/records_{RECORDS}/test.h5"

# Part 1: Preprocess EHR-Notes Datasets

## I. Preprocess

### Keep only first episode

In [None]:
def clean_patient_ids(INPUT_CSV, OUTPUT_CSV):
    """Keep only first-episode rows and strip the episode suffix from PatientID.

    PatientID values are matched against ``<digits>`` or ``<digits>_<digits>``.
    Rows whose suffix is missing or equal to 1 are kept and their PatientID is
    replaced by the leading digits; rows with any other numeric suffix are
    dropped. IDs that do not match the pattern at all are kept unchanged
    (previously they were overwritten with the string "nan").

    Args:
        INPUT_CSV: path of the CSV to read; must contain a "PatientID" column.
        OUTPUT_CSV: path the filtered CSV is written to (without the index).
    """
    df = pd.read_csv(INPUT_CSV, low_memory=False, encoding="utf-8")
    pid = df["PatientID"].astype(str)
    extracted = pid.str.extract(r"^(?P<base>\d+)(?:_(?P<suf>\d+))?$")

    # Determine which rows to keep:
    # - Keep if suffix is NaN (no suffix, or ID not matching the pattern) or equals 1
    # - Drop if suffix is any other number
    suf_num = pd.to_numeric(extracted["suf"], errors="coerce")
    keep_mask = suf_num.isna() | (suf_num == 1)

    # A non-matching ID has no extracted base; fall back to the original text
    # instead of letting the missing value turn into the literal "nan".
    base_ids = extracted["base"].fillna(pid)

    clean = df.loc[keep_mask].copy()
    # Replace the whole column (rather than writing through .loc) so that a
    # numeric PatientID column can receive string IDs without a dtype clash.
    clean["PatientID"] = base_ids.loc[keep_mask].astype(str)
    clean.to_csv(OUTPUT_CSV, index=False)

    # print("Before:")
    # print(df.head(3))
    # print("\nAfter (cleaned):")
    # print(clean.head(3))

# Clean the EHR table first, then the notes table (raw ZHU file -> cleaned BACH file).
for _raw_csv, _cleaned_csv in ((EHR_ZHU, EHR_BACH), (NOTES_ZHU, NOTES_BACH)):
    clean_patient_ids(_raw_csv, _cleaned_csv)

### Feature Engineering

In [None]:
# Load the cleaned tables. PatientID is read as a string so IDs compare
# consistently between the two files.
print(f"Limited to first {RECORDS} EHR rows per patient.")
notes = pd.read_csv(NOTES_BACH, dtype={"PatientID": "string"}, encoding="utf-8", low_memory=False)
ehr   = pd.read_csv(EHR_BACH,   dtype={"PatientID": "string"}, encoding="utf-8", low_memory=False)

# Drop patients that have fewer than RECORDS EHR rows.
ehr_counts = ehr.groupby("PatientID").size()
valid_ehr_ids = ehr_counts[ehr_counts >= RECORDS].index
removed_patients = set(ehr["PatientID"].unique()) - set(valid_ehr_ids)
ehr = ehr[ehr["PatientID"].isin(valid_ehr_ids)].reset_index(drop=True)

print(f"Patients removed for having < {RECORDS} EHR rows: {len(removed_patients)}")

# Keep the first RECORDS rows of every remaining patient. This used to be a
# hard-coded 4, which silently disagreed with RECORDS (and with the
# records_{RECORDS} output paths) as soon as the setting was changed.
ehr = (
    ehr
    .groupby("PatientID", group_keys=False)
    .head(RECORDS)
    .reset_index(drop=True)
)

# Every surviving patient had >= RECORDS rows, so each now has exactly RECORDS.
assert ehr.groupby("PatientID").size().eq(RECORDS).all()

print(f"EHR rows after truncation: {len(ehr):,}")
print(f"EHR patients: {ehr['PatientID'].nunique():,}")

Patients removed for having < 4 EHR rows: 13137


In [38]:
# Sanity check: a patient's rows should all carry the same Outcome and the
# same Readmission label. Count distinct values per patient and list anyone
# with more than one.
conflict = ehr.groupby("PatientID")[["Outcome", "Readmission"]].nunique()
_has_conflict = conflict["Outcome"].gt(1) | conflict["Readmission"].gt(1)
conflict_patients = conflict[_has_conflict]

print("Patients with conflicting Outcome or Readmission values:")
print(conflict_patients)

Patients with conflicting Outcome or Readmission values:
Empty DataFrame
Columns: [Outcome, Readmission]
Index: []


In [None]:
# Data imputation.
# Columns from the first to the last label below (a label-based, inclusive
# slice) get missing values replaced by 0.
_first_flag_col = "Capillary refill rate->0.0"
_last_flag_col = "Glascow coma scale verbal response->3 Inapprop words"
ehr.loc[:, _first_flag_col:_last_flag_col] = ehr.loc[:, _first_flag_col:_last_flag_col].fillna(0)


def _fill_with_column_mean(col):
    """Replace missing entries of one column with that column's mean."""
    return col.fillna(col.mean())


# Columns from "Diastolic blood pressure" to the end get their own mean instead.
ehr.loc[:, "Diastolic blood pressure":] = ehr.loc[:, "Diastolic blood pressure":].apply(_fill_with_column_mean)

ehr.head()

Unnamed: 0,PatientID,Outcome,Readmission,Sex,Age,Capillary refill rate->0.0,Capillary refill rate->1.0,Glascow coma scale eye opening->To Pain,Glascow coma scale eye opening->3 To speech,Glascow coma scale eye opening->1 No Response,Glascow coma scale eye opening->4 Spontaneously,Glascow coma scale eye opening->None,Glascow coma scale eye opening->To Speech,Glascow coma scale eye opening->Spontaneously,Glascow coma scale eye opening->2 To pain,Glascow coma scale motor response->1 No Response,Glascow coma scale motor response->3 Abnorm flexion,Glascow coma scale motor response->Abnormal extension,Glascow coma scale motor response->No response,Glascow coma scale motor response->4 Flex-withdraws,Glascow coma scale motor response->Localizes Pain,Glascow coma scale motor response->Flex-withdraws,Glascow coma scale motor response->Obeys Commands,Glascow coma scale motor response->Abnormal Flexion,Glascow coma scale motor response->6 Obeys Commands,Glascow coma scale motor response->5 Localizes Pain,Glascow coma scale motor response->2 Abnorm extensn,Glascow coma scale total->11,Glascow coma scale total->10,Glascow coma scale total->13,Glascow coma scale total->12,Glascow coma scale total->15,Glascow coma scale total->14,Glascow coma scale total->3,Glascow coma scale total->5,Glascow coma scale total->4,Glascow coma scale total->7,Glascow coma scale total->6,Glascow coma scale total->9,Glascow coma scale total->8,Glascow coma scale verbal response->1 No Response,Glascow coma scale verbal response->No Response,Glascow coma scale verbal response->Confused,Glascow coma scale verbal response->Inappropriate Words,Glascow coma scale verbal response->Oriented,Glascow coma scale verbal response->No Response-ETT,Glascow coma scale verbal response->5 Oriented,Glascow coma scale verbal response->Incomprehensible sounds,Glascow coma scale verbal response->1.0 ET/Trach,Glascow coma scale verbal response->4 Confused,Glascow coma scale verbal response->2 Incomp sounds,Glascow coma scale 
verbal response->3 Inapprop words,Diastolic blood pressure,Fraction inspired oxygen,Glucose,Heart Rate,Height,Mean blood pressure,Oxygen saturation,Respiratory rate,Systolic blood pressure,Temperature,Weight,pH
0,100,0,0,0,71.990441,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,58.0,0.4,108.0,72.0,169.124771,69.0,98.0,14.0,100.0,37.3,83.409587,7.46
1,100,0,0,0,71.990441,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,71.0,0.5,98.0,92.0,169.124771,92.0,97.0,19.0,132.0,37.0,63.200001,7.37
2,100,0,0,0,71.990441,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,55.0,0.504306,120.0,108.0,169.124771,67.0,95.0,12.0,95.0,37.0,83.409587,7.38
3,100,0,0,0,71.990441,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,71.0,0.504306,147.0,76.0,169.124771,91.0,100.0,17.0,131.0,36.611112,63.5,7.36
4,1000,1,1,1,69.754941,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,58.0,0.504306,137.055896,87.0,169.124771,76.666702,97.0,20.0,114.0,36.777802,83.409587,7.219703


In [None]:
# Combine all notes for each patient into a single string, separated by "\n"
notes = (
    notes
    .groupby("PatientID", sort=False)["Text"]
    .apply(lambda s: "\n".join(s.astype(str)))
    .reset_index(name="Text")
)

notes.head()

Unnamed: 0,PatientID,Text
0,10000,11:30 am chest ( portable ap ) clip # reason :...
1,10003,"admission note d : pt arrived from , sedated o..."
2,10004,respiratory care pt was admitted today from os...
3,10006,full code universal precautions allergy : hepa...
4,10007,12:24 pm chest port . line placement clip # re...


### Keep only patients with both EHR and notes

In [41]:
# Normalise IDs: strip stray whitespace and drop rows whose ID is missing.
notes["PatientID"] = notes["PatientID"].str.strip()
ehr["PatientID"]   = ehr["PatientID"].str.strip()
notes = notes.dropna(subset=["PatientID"])
ehr   = ehr.dropna(subset=["PatientID"])

# ---- find intersection ----
ids_notes = set(notes["PatientID"].unique())
ids_ehr   = set(ehr["PatientID"].unique())
ids_both  = ids_notes & ids_ehr

# ---- filter to the same set ----
notes_f = notes[notes["PatientID"].isin(ids_both)].copy()
ehr_f   = ehr[ehr["PatientID"].isin(ids_both)].copy()

# ---- sort by patient ----
# A stable sort is required here: the default single-key sort (quicksort) may
# reorder rows that share a PatientID, which would scramble the order of each
# patient's EHR records before they are written back to disk.
notes_f = notes_f.sort_values(["PatientID"], kind="stable").reset_index(drop=True)
ehr_f   = ehr_f.sort_values(["PatientID"], kind="stable").reset_index(drop=True)

# ---- save (overwrites the cleaned files in place) ----
notes_f.to_csv(NOTES_BACH, index=False)
ehr_f.to_csv(EHR_BACH,   index=False)

# ---- report ----
print("=== BEFORE ===")
print(f"Notes: {len(notes):,} rows | {len(ids_notes):,} unique patients")
print(f"EHR  : {len(ehr):,} rows | {len(ids_ehr):,} unique patients")
print("\n=== AFTER (kept only patients present in BOTH) ===")
print(f"Common patients kept: {len(ids_both):,}")
print(f"Notes.filtered.csv: {len(notes_f):,} rows | {notes_f['PatientID'].nunique():,} patients")
print(f"ehr.filtered.csv  : {len(ehr_f):,} rows | {ehr_f['PatientID'].nunique():,} patients")

=== BEFORE ===
Notes: 31,027 rows | 31,027 unique patients
EHR  : 81,328 rows | 20,332 unique patients

=== AFTER (kept only patients present in BOTH) ===
Common patients kept: 19,307
Notes.filtered.csv: 19,307 rows | 19,307 patients
ehr.filtered.csv  : 77,228 rows | 19,307 patients


## II. Check

### Check Notes

In [None]:
# --- Load the CSV ---
df = pd.read_csv(NOTES_BACH, encoding="utf-8", low_memory=False)

# --- Functions to count ---
def char_count(text):
    return len(text)

def word_count(text):
    return len(text.split())

# --- Apply counts ---
df["char_count"] = df["Text"].astype(str).apply(char_count)
df["word_count"] = df["Text"].astype(str).apply(word_count)

# --- Overall stats ---
overall = {
    "max_chars": df["char_count"].max(),
    "avg_chars": df["char_count"].mean(),
    "max_words": df["word_count"].max(),
    "avg_words": df["word_count"].mean(),
}

print("\n🌍 Overall stats:")
print(overall)


🌍 Overall stats:
{'max_chars': 52597, 'avg_chars': np.float64(9399.021960946808), 'max_words': 9494, 'avg_words': np.float64(1778.0428859998965)}


### Check EHR

In [None]:
# Re-read the saved EHR file and report label prevalence, first over all rows
# and then over patients (a patient counts as positive if any row is positive).
df = pd.read_csv(EHR_BACH, encoding="utf-8", low_memory=False)
print("Shape:", df.shape)

_label_cols = ["Outcome", "Readmission"]
for _label in _label_cols:
    print(f"{_label} positive ratio (each episode):", df[_label].mean())

print("===")
pat_any = df.groupby("PatientID")[_label_cols].max()

for _label in _label_cols:
    print(f"{_label} positive ratio (each patient):", pat_any[_label].mean())

Shape: (77228, 64)
Outcome positive ratio (each episode): 0.11384471953177604
Readmission positive ratio (each episode): 0.14973843683638058
===
Outcome positive ratio (each patient): 0.11384471953177604
Readmission positive ratio (each patient): 0.14973843683638058


# Part 2: Preprocess PrimeKG and Extract Entities from Datasets

## PrimeKG

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# BGE-M3 model instance.
# NOTE(review): no cell shown here passes `model` to a helper — confirm how the codebase utilities obtain it.
model = BGEM3FlagModel(
    "BAAI/bge-m3",
    use_fp16=True,
    device=device,
    trust_remote_code=True,
)

def create_adj():
    """Build an adjacency list from the PrimeKG edge table and pickle it.

    Reads PRIMEKG_KG and maps every ``x_index`` (as int) to a list of
    ``(y_index, relation)`` tuples (int, str), one per edge row. The resulting
    ``defaultdict(list)`` is written to KG_ADJACENCY with pickle protocol 5.
    """
    edges = pd.read_csv(PRIMEKG_KG, low_memory=False, encoding="utf-8")
    edges = edges[["relation", "x_index", "y_index"]]

    sources = edges["x_index"].values
    targets = edges["y_index"].values
    relations = edges["relation"].values

    neighbours = defaultdict(list)
    edge_triples = zip(sources, targets, relations)
    for src, dst, rel in tqdm(edge_triples, total=len(edges), desc="Creating adjacency list"):
        neighbours[int(src)].append((int(dst), str(rel)))

    with open(KG_ADJACENCY, "wb") as out_file:
        pickle.dump(neighbours, out_file, protocol=5)

def create_disease_features():
    """Build a text description and an embedding for every disease row and pickle them.

    Reads PRIMEKG_DISEASE, orders the rows by ``node_index``, composes a
    "Diseases" string from the name, definition (MONDO, falling back to
    Orphanet) and UMLS description, embeds it with ``batch_encode`` and writes
    the columns node_index / mondo_name / Diseases / embed to DISEASE_FEATURES.
    """
    features = pd.read_csv(PRIMEKG_DISEASE, low_memory=False, encoding="utf-8")
    features = features.sort_values("node_index").reset_index(drop=True)

    # Missing pieces become empty strings so the tagged template is always complete.
    name_part = "[disease name]" + features["mondo_name"].fillna("")
    definition = features["mondo_definition"].combine_first(features["orphanet_definition"])
    definition_part = "[definition]" + definition.fillna("")
    description_part = "[description]" + features["umls_description"].fillna("")
    features["Diseases"] = name_part + " " + definition_part + " " + description_part

    # One embedding row per disease text, stored as a list of numpy arrays.
    encoded = batch_encode(features["Diseases"].tolist(), batch_size=64, max_length=8192)
    features["embed"] = list(encoded.cpu().numpy())
    features = features[["node_index", "mondo_name", "Diseases", "embed"]]

    with open(DISEASE_FEATURES, "wb") as out_file:
        pickle.dump(features, out_file, protocol=5)

def create_notes_embeddings():
    """Embed every patient's notes and store them in the NOTES_EMBEDDINGS HDF5 file.

    For each row of NOTES_BACH a group named after the PatientID is created
    with two datasets: "PatientID" (int64 scalar) and "Note" (the output of
    ``langchain_chunk_embed`` for the row's text, gzip-compressed).

    The file is built under ``NOTES_EMBEDDINGS + ".tmp"`` and renamed onto
    NOTES_EMBEDDINGS only once every row has been written. Callers that skip
    this step when NOTES_EMBEDDINGS already exists therefore never pick up a
    half-written file left behind by an interrupted run.
    """
    notes_df = pd.read_csv(NOTES_BACH, encoding="utf-8", low_memory=False)

    tmp_path = NOTES_EMBEDDINGS + ".tmp"
    # Mode "w" starts from an empty file; a single handle is kept open for the
    # whole loop instead of reopening the file for every patient.
    with h5py.File(tmp_path, "w") as h5:
        for _, row in tqdm(notes_df.iterrows(), total=len(notes_df), desc="Embedding notes and saving to HDF5"):
            patient_id = row["PatientID"]
            text = row["Text"]

            grp = h5.create_group(str(patient_id))
            grp.create_dataset("PatientID", data=np.asarray(patient_id, dtype="int64"))
            grp.create_dataset("Note", data=langchain_chunk_embed(text), compression="gzip")

    # Publish the finished file (replaces any existing file at the destination).
    os.replace(tmp_path, NOTES_EMBEDDINGS)

In [None]:
# Build each cached artefact only when its file is not on disk yet.
for _cache_path, _builder in (
    (KG_ADJACENCY, create_adj),
    (DISEASE_FEATURES, create_disease_features),
    (NOTES_EMBEDDINGS, create_notes_embeddings),
):
    if not os.path.exists(_cache_path):
        _builder()

## Extract Entities

In [None]:
def _to_float32_array(x):
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().float().numpy()
    if isinstance(x, np.ndarray):
        return x.astype("float32", copy=False)
    raise TypeError(f"Expected tensor/ndarray, got {type(x)}")

def store_patient(h5_path, p, ehr, target, notes, summary):
    """Append one patient's record as a new group in an HDF5 file.

    Args:
        h5_path: HDF5 file to open in append mode.
        p: patient identifier; used (as a string) for the group name and
            stored as an int64 scalar.
        ehr: data for the "X" dataset, stored as given (gzip-compressed).
        target: data for the "Y" dataset, stored as int8.
        notes: tensor/ndarray for the "Note" dataset, stored as float32.
        summary: tensor/ndarray for the "Summary" dataset, stored as float32.
    """
    with h5py.File(h5_path, "a") as store:
        patient_grp = store.create_group(str(p))
        add = patient_grp.create_dataset
        add("PatientID", data=np.asarray(p, dtype="int64"))
        add("X", data=ehr, compression="gzip")
        add("Note", data=_to_float32_array(notes), compression="gzip")
        add("Summary", data=_to_float32_array(summary), compression="gzip")
        add("Y", data=np.asarray(target, dtype="int8"))

def create_complete_dataset():
    """Build the per-patient train/val/test HDF5 datasets at TRAIN, VAL and TEST.

    For every patient in the EHR table at EHR_BACH this stores (through
    ``store_patient``) the patient's EHR feature rows, the note embedding read
    from NOTES_EMBEDDINGS, a "Summary" array and the (Outcome, Readmission)
    pair. A patient goes to the train or validation file when its ID is in the
    corresponding collection returned by ``data_split()``, and to the test
    file otherwise.

    Summary generation (``get_summary``) is currently disabled; the note
    embedding is stored in place of the summary embedding.

    Each output is first written to ``<path>.tmp`` and renamed onto its final
    path only after every patient has been stored. Callers that skip this step
    when the output files already exist therefore never pick up a partial
    dataset left behind by an interrupted run.
    """
    final_paths = (TRAIN, VAL, TEST)
    tmp_paths = {path: path + ".tmp" for path in final_paths}
    for path in final_paths:
        # The output directory depends on RECORDS, so it may not exist yet.
        out_dir = os.path.dirname(path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with h5py.File(tmp_paths[path], "w"):
            pass  # start from an empty temporary file

    # Preparing datasets
    df = pd.read_csv(EHR_BACH, encoding="utf-8", low_memory=False)

    # NOTE(review): these pickles are produced by this notebook; never point
    # the paths at files from an untrusted source (pickle can execute code).
    with open(DISEASE_FEATURES, "rb") as f:
        kg = pickle.load(f)
    mapping = dict(zip(kg["node_index"], kg["mondo_name"]))

    with open(KG_ADJACENCY, "rb") as f:
        adj = pickle.load(f)

    notes_emb = {}
    with h5py.File(NOTES_EMBEDDINGS, "r") as h5:
        for patient_id in h5.keys():  # each group is named by patient_id
            grp = h5[patient_id]
            pid = int(grp["PatientID"][()])
            embedding = np.array(grp["Note"])
            notes_emb[pid] = embedding
    # NOTE(review): a leftover debug comment here suggested each embedding has
    # 768 values — confirm against langchain_chunk_embed.

    # Preprocess EHR data to extract entities.
    # Column layout assumed by position: 5 leading columns, then the one-hot
    # categorical columns, then the last 12 numeric columns.
    cat_col = df.columns[5:-12]
    num_col = df.columns[-12:]

    col_mean = df[num_col].mean()
    col_std = df[num_col].std()

    entities = defaultdict(list)
    for _, row in tqdm(df.iterrows(), total=len(df)):
        record = ""
        PatientID = row["PatientID"]

        if row["Sex"] == 1:
            record += "Gender: Male\n"
        else:
            record += "Gender: Female\n"
        record += f"Age: {row['Age']}\n"

        # Every active one-hot column becomes a readable "feature : value" entity.
        for c in cat_col:
            if row[c] == 1:
                cat = c
                if "Glascow coma scale total" not in cat:
                    # Strip a leading numeric code after "->" (e.g. "->3", "->1.0");
                    # the total score keeps its number.
                    for i in range(0, 30, 1):
                        cat = cat.replace(f"->{i}.0", " : ")
                        cat = cat.replace(f"->{i}", " : ")
                cat = cat.replace("->", " : ")
                entities[PatientID].append(cat)

        # Numeric values more than 2 standard deviations from the column mean
        # are flagged as abnormal entities.
        for c in num_col:
            if math.isnan(row[c]):
                continue
            z_score = (row[c] - col_mean[c]) / col_std[c]
            if z_score > 2:
                entities[PatientID].append(f"{c} too high")
            elif z_score < -2:
                entities[PatientID].append(f"{c} too low")

    # Match entities to knowledge graph
    patients = list(df["PatientID"].unique())

    train_ids, val_ids, test_ids = data_split()

    def get_summary(p):
        """Return the embedded summary for patient ``p`` (currently not called)."""
        # Sorted so the generated text does not depend on set iteration order.
        entities[p] = sorted(set(entities[p]))
        summary_entities = ", ".join(entities[p])

        nodes = []
        for e in entities[p]:
            idx = cosine_filter(None, e, threshold=0.6, top_k=3)
            nodes.extend(idx)
        nodes = list(set(nodes))

        node_texts = []
        edge_texts = []
        for n in nodes:
            # `n` is used as a row position in `kg`; the graph itself is keyed
            # by node_index, so neighbours must be looked up with node_x
            # (they used to be looked up with the row position).
            # NOTE(review): assumes cosine_filter returns row positions — confirm.
            node_texts.append(kg.iloc[n]["Diseases"])
            node_x = kg.iloc[n]["node_index"]
            # .get avoids inserting empty entries into the defaultdict.
            for node_y, rela in adj.get(node_x, ()):
                # Only keep edges whose target is also a known disease node.
                if node_y not in mapping:
                    continue
                edge_texts.append("(" + mapping[node_x] + ", " + str(rela) + ", " + mapping[node_y] + ")")
        summary_nodes = ", ".join(node_texts)
        summary_edges = ", ".join(edge_texts)
        # NOTE(review): this passes the stored note embedding, not raw text —
        # confirm that is what extract_note expects.
        summary_notes = extract_note(notes=notes_emb[p])

        summary = create_summary(summary_entities, summary_notes, summary_nodes, summary_edges)
        return langchain_chunk_embed(summary)

    # Summary generation is disabled for now; re-enable together with the
    # `data_summary = summaries[p]` line below.
    # summaries = defaultdict(list)
    # for p in tqdm(patients, total=len(patients), desc="Generating summaries"):
    #     summaries[p] = get_summary(p)

    feature_cols = [c for c in df.columns if c not in ["PatientID","Outcome","Readmission"]]
    target_map = df.groupby("PatientID")[["Outcome","Readmission"]].first()

    for p in tqdm(patients):
        data_ehr = df.loc[df["PatientID"] == p, feature_cols].to_numpy()
        data_notes = notes_emb[p]
        data_summary = data_notes  # default to notes if summary fails
        # data_summary = summaries[p]
        outcome, readm = target_map.loc[p].astype(int)
        data_target = (int(outcome), int(readm))

        if p in train_ids:
            h5_path = TRAIN
        elif p in val_ids:
            h5_path = VAL
        else:
            h5_path = TEST
        store_patient(tmp_paths[h5_path], p, data_ehr, data_target, data_notes, data_summary)

    # Everything was written successfully: publish the finished files.
    for path in final_paths:
        os.replace(tmp_paths[path], path)

In [None]:
# Rebuild all three split files as soon as any one of them is missing.
_splits_ready = all(os.path.exists(_split_path) for _split_path in (TRAIN, VAL, TEST))
if not _splits_ready:
    create_complete_dataset()