# Setting

In [1]:
import  os
import pandas as pd
import numpy as np
import math
#
from utils.bgem3 import cosine_filter, batch_encode
from utils.call_llm import extract_note, create_summary
from utils.clinical_longformer import langchain_chunk_embed, plain_truncate
from utils.constants import *
#
import torch
from FlagEmbedding import BGEM3FlagModel
from concurrent.futures import ThreadPoolExecutor, as_completed
#
from sklearn.model_selection import train_test_split
from collections import defaultdict, Counter
import pickle
import h5py
#
from tqdm import tqdm
# Never truncate DataFrame rendering: show every column and every row.
for _display_option in ("display.max_columns", "display.max_rows"):
    pd.set_option(_display_option, None)

  from .autonotebook import tqdm as notebook_tqdm
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 30016.49it/s]
Some weights of LongformerModel were not initialized from the model checkpoint at yikuan8/Clinical-Longformer and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda available: True
model device: cuda:0


# Part 1: Preprocess EHR-Notes Datasets

## I. Preprocess

### Keep only first episode

In [2]:
def clean_patient_ids(INPUT_CSV, OUTPUT_CSV):
    """Keep only each patient's first episode and normalise PatientID.

    IDs are expected to look like ``"<digits>"`` or ``"<digits>_<digits>"``,
    where the numeric suffix is the episode number. Rows with no suffix or
    with suffix 1 are kept, and their PatientID is rewritten to the bare
    numeric prefix. IDs that do not match that pattern are kept with their
    original value. Previously they were silently rewritten to the string
    ``"nan"``.

    Parameters
    ----------
    INPUT_CSV : str or path-like
        CSV with a ``PatientID`` column.
    OUTPUT_CSV : str or path-like
        Destination for the filtered CSV (written without the index).
    """
    df = pd.read_csv(INPUT_CSV, low_memory=False, encoding="utf-8")
    pid = df["PatientID"].astype(str)
    extracted = pid.str.extract(r"^(?P<base>\d+)(?:_(?P<suf>\d+))?$")

    suf_num = pd.to_numeric(extracted["suf"], errors="coerce")
    keep_mask = suf_num.isna() | (suf_num == 1)
    # When the regex did not match, `base` is NaN: fall back to the raw ID.
    base = extracted["base"].fillna(pid)

    clean = df.loc[keep_mask].copy()
    # Assign the whole column, which replaces it together with its dtype.
    # The previous `clean.loc[:, "PatientID"] = ...` wrote strings into the
    # existing (possibly integer) column in place and triggered a pandas
    # warning.
    clean["PatientID"] = base.loc[keep_mask].astype(str)
    clean.to_csv(OUTPUT_CSV, index=False)

# Produce the first-episode versions of both raw tables (EHR first, then notes).
for _raw_csv, _clean_csv in ((EHR_ZHU, EHR_BACH), (NOTES_ZHU, NOTES_BACH)):
    clean_patient_ids(_raw_csv, _clean_csv)

  clean.loc[:, "PatientID"] = extracted.loc[keep_mask, "base"].astype(str)


### Feature Engineering

In [3]:
# Reload the first-episode tables written above. PatientID is read as a string
# so that IDs compare equal between EHR and notes.
notes = pd.read_csv(NOTES_BACH, dtype={"PatientID": "string"}, encoding="utf-8", low_memory=False)
ehr   = pd.read_csv(EHR_BACH,   dtype={"PatientID": "string"}, encoding="utf-8", low_memory=False)

# Check for patients with conflicting Outcome or Readmission values
# (more than one distinct value within the same patient). An empty result
# means both labels are constant per patient.
conflict = ehr.groupby("PatientID")[["Outcome", "Readmission"]].nunique()
conflict_patients = conflict[(conflict["Outcome"] > 1) | (conflict["Readmission"] > 1)]

print("Patients with conflicting Outcome or Readmission values:")
print(conflict_patients)

dist = (
    ehr.groupby("PatientID").size()      # rows per patient
       .value_counts()                   # how many patients have that many rows
       .sort_index()                     # sort by row count
)

# Print the histogram in ascending row count, stopping after the first row
# count that is >= 10.
print("Distribution of records per patient:")
for rows, n_patients in dist.items():
    print(f"{rows} record(s): {n_patients} patients")
    if rows >= 10:
        break

# Reverse cumulative sum: for each row count, the number of patients with at
# least that many rows.
dist_ge = dist.sort_index(ascending=False).cumsum().sort_index()
print("\nCumulative distribution (>= records):")
for rows, n_patients in dist_ge.items():
    print(f">= {rows} record(s): {n_patients} patients")
    if rows >= 10:
        break

Patients with conflicting Outcome or Readmission values:
Empty DataFrame
Columns: [Outcome, Readmission]
Index: []
Distribution of records per patient:
1 record(s): 970 patients
2 record(s): 6347 patients
3 record(s): 5820 patients
4 record(s): 4703 patients
5 record(s): 2934 patients
6 record(s): 2510 patients
7 record(s): 1598 patients
8 record(s): 1305 patients
9 record(s): 901 patients
10 record(s): 871 patients

Cumulative distribution (>= records):
>= 1 record(s): 33469 patients
>= 2 record(s): 32499 patients
>= 3 record(s): 26152 patients
>= 4 record(s): 20332 patients
>= 5 record(s): 15629 patients
>= 6 record(s): 12695 patients
>= 7 record(s): 10185 patients
>= 8 record(s): 8587 patients
>= 9 record(s): 7282 patients
>= 10 record(s): 6381 patients


In [4]:
# Combine all notes for each patient into a single string, separated by "\n"
# One row per patient: all of a patient's notes, in file order, joined by "\n".
notes = (
    notes
    .assign(Text=notes["Text"].astype(str))
    .groupby("PatientID", sort=False)["Text"]
    .agg("\n".join)
    .reset_index()
)

notes.head()

Unnamed: 0,PatientID,Text
0,10000,11:30 am chest ( portable ap ) clip # reason :...
1,10003,"admission note d : pt arrived from , sedated o..."
2,10004,respiratory care pt was admitted today from os...
3,10006,full code universal precautions allergy : hepa...
4,10007,12:24 pm chest port . line placement clip # re...


#### Method 1: Remove all patients with fewer than X records, and keep only the first X records of each remaining patient

In [5]:
# print(f"Limited to first {RECORDS} EHR rows per patient.")

# ehr_counts = ehr.groupby("PatientID").size()
# valid_ehr_ids = ehr_counts[ehr_counts >= RECORDS].index
# removed_patients = set(ehr["PatientID"].unique()) - set(valid_ehr_ids)
# ehr = ehr[ehr["PatientID"].isin(valid_ehr_ids)].reset_index(drop=True)

# print(f"Patients removed for having < {RECORDS} EHR rows: {len(removed_patients)}")

# ehr = (
#     ehr
#     .groupby("PatientID", group_keys=False)
#     .head(RECORDS)
#     .reset_index(drop=True)
# )

# print(f"EHR rows after truncation: {len(ehr):,}")
# print(f"EHR patients: {ehr['PatientID'].nunique():,}")

#### Method 2: Same as Method 1, but based on RecordTime (keep only patients that have every RecordTime from 1 to X, and drop later records)

In [6]:
# # Keep a copy for reporting
# _ids_before = set(ehr["PatientID"].unique())
# _rows_before = len(ehr)

# # Coerce RecordTime to numeric and drop invalid/missing
# ehr["RecordTime"] = pd.to_numeric(ehr["RecordTime"], errors="coerce")
# ehr = ehr[ehr["RecordTime"].between(1, RECORDS, inclusive="both")].copy()

# # Keep only patients that have all RecordTime values 1..RECORDS present
# rt_counts = ehr.groupby("PatientID")["RecordTime"].nunique()
# valid_ids = rt_counts[rt_counts == RECORDS].index

# removed_patients = _ids_before - set(valid_ids)
# ehr = ehr[ehr["PatientID"].isin(valid_ids)].sort_values(["PatientID", "RecordTime"]).reset_index(drop=True)

# print(f"Rows before: {_rows_before:,} | after filtering by RecordTime: {len(ehr):,}")
# print(f"Patients removed for missing any RecordTime in 1..{RECORDS}: {len(removed_patients):,}")
# print(f"EHR rows kept: {len(ehr):,}")
# print(f"EHR patients kept: {ehr['PatientID'].nunique():,}")

#### Method 3: Keep all patients; trim each patient to their first RECORDS (48) records by RecordTime and pad shorter histories with empty rows

In [7]:
id_col   = "PatientID"
time_col = "RecordTime"
keep_dup = ["Outcome", "Readmission", "Sex", "Age"]

print(f"Targeting exactly {RECORDS} rows per patient (pad with empties; {time_col}=-1).")

def pad_or_trim(g):
    """Return exactly RECORDS rows for one patient's EHR group.

    The group is stable-sorted by ``time_col`` and truncated to its first
    RECORDS rows. If fewer rows remain, all-NaN filler rows are appended.
    Each filler row copies the patient ID and the per-patient columns in
    ``keep_dup`` (taken from the group's first row), and has
    ``time_col = -1``.
    """
    g = g.sort_values(time_col, kind="stable").head(RECORDS)
    missing = RECORDS - len(g)
    if missing <= 0:
        return g

    # Build a template "empty" row from the group's own columns.
    base = {c: np.nan for c in g.columns}
    base[id_col] = g[id_col].iat[0]
    for c in keep_dup:
        if c in g:
            base[c] = g[c].iat[0]
    base[time_col] = -1  # negative so it won't look like a latest record

    filler = pd.DataFrame([base] * missing, columns=g.columns)
    return pd.concat([g, filler], ignore_index=True)

# Iterate over the groups explicitly instead of calling
# `groupby(...).apply(pad_or_trim)`. In pandas, apply() on a frame whose groups
# still contain the grouping column is deprecated (this is the warning the
# cell used to print). Once pandas stops passing that column, `g[id_col]`
# inside pad_or_trim would fail. Groups are visited in sorted-key order, as
# before.
_padded_groups = [
    pad_or_trim(group)
    for _, group in tqdm(ehr.groupby(id_col), total=ehr[id_col].nunique(), desc="Pad/trim EHR")
]
ehr = pd.concat(_padded_groups, ignore_index=True) if _padded_groups else ehr.iloc[0:0]
del _padded_groups  # the per-patient frames duplicate the whole table in memory

n_patients = ehr[id_col].nunique()
n_rows     = len(ehr)
n_pad      = int((ehr[time_col] == -1).sum())

print(f"Patients: {n_patients:,}")
print(f"Rows after pad/trim: {n_rows:,} (expected {n_patients*RECORDS:,})")
print(f"Padded rows added: {n_pad:,}")
assert n_rows == n_patients * RECORDS, "pad/trim did not produce RECORDS rows per patient"


Targeting exactly 48 rows per patient (pad with empties; RecordTime=-1).


  .apply(pad_or_trim)


Patients: 33,469
Rows after pad/trim: 1,606,512 (expected 1,606,512)
Padded rows added: 1,364,295


### Keep only patients with both EHR and notes

In [8]:
# normalize IDs a bit (handle stray spaces / empty)
notes["PatientID"] = notes["PatientID"].str.strip()
ehr["PatientID"]   = ehr["PatientID"].str.strip()

# ---- find intersection ----
ids_notes = set(notes["PatientID"].unique())
ids_ehr   = set(ehr["PatientID"].unique())
ids_both  = ids_notes & ids_ehr

# ---- filter to the same set ----
notes_f = notes[notes["PatientID"].isin(ids_both)].copy()
ehr_f   = ehr[ehr["PatientID"].isin(ids_both)].copy()

# ---- sort for readability ----
notes_f = notes_f.sort_values(["PatientID"]).reset_index(drop=True)
ehr_f   = ehr_f.sort_values(["PatientID", "RecordTime"]).reset_index(drop=True)

ehr_f.drop(columns=["RecordTime"], inplace=True)

# ---- save ----
notes_f.to_csv(NOTES_BACH, index=False)
ehr_f.to_csv(EHR_BACH,   index=False)

# ---- report ----
print("=== BEFORE ===")
print(f"Notes: {len(notes):,} rows | {len(ids_notes):,} unique patients")
print(f"EHR  : {len(ehr):,} rows | {len(ids_ehr):,} unique patients")
print("\n=== AFTER (kept only patients present in BOTH) ===")
print(f"Common patients kept: {len(ids_both):,}")
print(f"Notes.filtered.csv: {len(notes_f):,} rows | {notes_f['PatientID'].nunique():,} patients")
print(f"ehr.filtered.csv  : {len(ehr_f):,} rows | {ehr_f['PatientID'].nunique():,} patients")

=== BEFORE ===
Notes: 31,074 rows | 31,074 unique patients
EHR  : 1,606,512 rows | 33,469 unique patients

=== AFTER (kept only patients present in BOTH) ===
Common patients kept: 31,027
Notes.filtered.csv: 31,027 rows | 31,027 patients
ehr.filtered.csv  : 1,489,296 rows | 31,027 patients


### Split into Train/Val/Test sets

In [9]:
def cols_between(_df, start_label, end_label=None):
    """Return the columns of ``_df`` from ``start_label`` to ``end_label``, inclusive.

    If ``end_label`` is None, the slice runs to the last column.
    Raises ValueError if ``start_label`` appears after ``end_label``.
    """
    columns = _df.columns
    first = columns.get_loc(start_label)
    last = (len(columns) - 1) if end_label is None else columns.get_loc(end_label)
    if first <= last:
        return columns[first:last + 1]
    raise ValueError(f"{start_label!r} comes after {end_label!r} in columns")

def data_split_and_impute():
    """Split EHR_BACH by patient into train/val/test, impute it, and save drafts.

    Patients (not rows) are split 80/20 into train+val/test, then 87.5/12.5
    into train/val (about 70/10/20 overall). Both splits are stratified on the
    joint label ``Outcome*2 + Readmission`` and seeded with RANDOM_STATE.

    Missing values are imputed with TRAIN statistics only:
    - categorical columns (the "Capillary refill rate->0.0" ... "Glascow coma
      scale verbal response->3 Inapprop words" block) -> 0;
    - numeric columns ("Diastolic blood pressure" to the last column) -> the
      train column mean.

    The imputed splits are written to TRAIN_DRAFT, VAL_DRAFT and TEST_DRAFT.
    """
    df = pd.read_csv(EHR_BACH, encoding='utf-8', low_memory=False)

    # One row per patient; labels are taken from each patient's first row.
    pat = (
        df.groupby('PatientID', as_index=False)
            .agg(Outcome=('Outcome','first'),
                Readmission=('Readmission','first'))
    )
    pat['joint'] = pat['Outcome'].astype(int)*2 + pat['Readmission'].astype(int)
    print("Patient counts per joint class (O*2+R):", Counter(pat['joint']))

    # 1) TEST = 20% (stratified on joint)
    pat_trainval, pat_test = train_test_split(
        pat, test_size=0.20, stratify=pat['joint'], random_state=RANDOM_STATE
    )

    # 2) VAL = 12.5% of remaining (i.e., ~10% overall)
    pat_train, pat_val = train_test_split(
        pat_trainval, test_size=0.125, stratify=pat_trainval['joint'], random_state=RANDOM_STATE
    )

    print("Train/Val/Test patients:", len(pat_train), len(pat_val), len(pat_test))

    train_ids = set(pat_train['PatientID'])
    val_ids   = set(pat_val['PatientID'])
    test_ids  = set(pat_test['PatientID'])

    train_df = df[df['PatientID'].isin(train_ids)].copy()
    val_df   = df[df['PatientID'].isin(val_ids)].copy()
    test_df  = df[df['PatientID'].isin(test_ids)].copy()

    for name, d in [('Train', train_df), ('Val', val_df), ('Test', test_df)]:
        print(f"{name}: rows={len(d):,}, patients={d['PatientID'].nunique():,}")

    # ====== DEFINE FEATURE BLOCKS ======
    cat_cols = list(cols_between(
        df,
        "Capillary refill rate->0.0",
        "Glascow coma scale verbal response->3 Inapprop words"
    ))
    num_cols = list(cols_between(df, "Diastolic blood pressure", None))

    # ====== COMPUTE IMPUTATION VALUES ON TRAIN ONLY ======
    # Categorical: all NaNs -> 0 (no stats needed).
    # Numeric: per-column mean from TRAIN. Values that fail numeric parsing are
    # coerced to NaN first.
    train_num = train_df[num_cols].apply(pd.to_numeric, errors='coerce')
    num_impute = train_num.mean()  # pandas Series indexed by column name

    # ====== APPLY IMPUTATION (USING TRAIN STATS) ======
    def apply_impute(d):
        """Return a copy of ``d`` with categorical NaN -> 0 and numeric NaN -> train mean."""
        d = d.copy()
        # Assign whole columns instead of `d.loc[:, cols] = ...`. The in-place
        # .loc write keeps the old column dtypes and can raise pandas'
        # incompatible-dtype warnings; column assignment replaces them.
        if cat_cols:
            d[cat_cols] = d[cat_cols].fillna(0)
        if num_cols:
            d[num_cols] = d[num_cols].apply(pd.to_numeric, errors='coerce').fillna(num_impute)
        return d

    train_df_i = apply_impute(train_df)
    val_df_i   = apply_impute(val_df)
    test_df_i  = apply_impute(test_df)

    # ====== QUICK CHECKS AFTER IMPUTATION ======
    def summarize_split(name, df_rows, df_pat):
        """Print row- and patient-level label prevalence plus the remaining NaN count."""
        o_row = df_rows['Outcome'].mean()
        r_row = df_rows['Readmission'].mean()
        o_pat = df_pat['Outcome'].mean()
        r_pat = df_pat['Readmission'].mean()
        # Bug fix: the NaN count was glued to the readmission percentage
        # ("13.542%NaN: 0") because the separator was missing.
        print(
            f"{name} — Outcome: rows={o_row:.3%}, patients={o_pat:.3%} | "
            f"Readmission: rows={r_row:.3%}, patients={r_pat:.3%} | "
            f"NaN: {(df_rows.isna().sum().sum())}"
        )

    summarize_split("Train", train_df_i, pat_train)
    summarize_split("Val",   val_df_i,   pat_val)
    summarize_split("Test",  test_df_i,  pat_test)

    train_df_i.to_csv(TRAIN_DRAFT, index=False)
    val_df_i.to_csv(VAL_DRAFT, index=False)
    test_df_i.to_csv(TEST_DRAFT, index=False)

# Rebuild the split drafts unless all three already exist on disk.
if not all(os.path.exists(_draft) for _draft in (TRAIN_DRAFT, VAL_DRAFT, TEST_DRAFT)):
    data_split_and_impute()

Patient counts per joint class (O*2+R): Counter({0: 26775, 3: 3186, 1: 1015, 2: 51})
Train/Val/Test patients: 21718 3103 6206
Train: rows=1,042,464, patients=21,718
Val: rows=148,944, patients=3,103
Test: rows=297,888, patients=6,206
Train — Outcome: rows=10.434%, patients=10.434% | Readmission: rows=13.542%, patients=13.542%NaN: 0
Val — Outcome: rows=10.442%, patients=10.442% | Readmission: rows=13.535%, patients=13.535%NaN: 0
Test — Outcome: rows=10.425%, patients=10.425% | Readmission: rows=13.535%, patients=13.535%NaN: 0


## II. Check

### Check Notes

In [10]:
df = pd.read_csv(NOTES_BACH, encoding="utf-8", low_memory=False)

def char_count(text):
    """Length of ``text`` in characters."""
    return len(text)

def word_count(text):
    """Number of whitespace-separated tokens in ``text``."""
    return len(text.split())

# Per-patient note length, in characters and in words.
_note_text = df["Text"].astype(str)
df["char_count"] = _note_text.map(char_count)
df["word_count"] = _note_text.map(word_count)

overall = dict(
    max_chars=df["char_count"].max(),
    avg_chars=df["char_count"].mean(),
    max_words=df["word_count"].max(),
    avg_words=df["word_count"].mean(),
)

print("\n🌍 Overall stats:")
print(overall)


🌍 Overall stats:
{'max_chars': 52570, 'avg_chars': np.float64(7748.050439939408), 'max_words': 9521, 'avg_words': np.float64(1474.1109034067103)}


### Check EHR

In [11]:
df = pd.read_csv(EHR_BACH, encoding="utf-8", low_memory=False)
print("Shape:", df.shape)

_label_cols = ["Outcome", "Readmission"]

# Row-level prevalence. After pad/trim every patient has the same number of
# rows, so this matches the patient-level figure below.
for _col in _label_cols:
    print(f"{_col} positive ratio (each episode):", df[_col].mean())

print("===")
# Patient-level prevalence: a patient counts as positive if any of their rows is.
pat_any = df.groupby("PatientID")[_label_cols].max()
for _col in _label_cols:
    print(f"{_col} positive ratio (each patient):", pat_any[_col].mean())

Shape: (1489296, 64)
Outcome positive ratio (each episode): 0.1043284880910175
Readmission positive ratio (each episode): 0.13539820156637766
===
Outcome positive ratio (each patient): 0.1043284880910175
Readmission positive ratio (each patient): 0.13539820156637766


# Part 2: Preprocess PrimeKG and Extract Entities from Datasets

## PrimeKG

In [12]:
# Load BGE-M3 (fp16) on the GPU when one is available.
# NOTE(review): `model` is not referenced by any code visible in this notebook.
# create_disease_features embeds through utils.bgem3.batch_encode instead. If
# that helper loads its own BGE-M3 instance, this is a second copy in GPU
# memory — confirm whether `model` is needed.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True, device=device, trust_remote_code=True)

def create_adj():
    """Build a PrimeKG adjacency list and pickle it to KG_ADJACENCY.

    The result maps each ``x_index`` (int) to a list of ``(y_index, relation)``
    tuples, one per row of PRIMEKG_KG, in file order (edge direction x -> y).
    """
    edges = pd.read_csv(PRIMEKG_KG, low_memory=False, encoding="utf-8")[["relation", "x_index", "y_index"]]

    adj_list = defaultdict(list)
    triples = zip(edges["x_index"].values, edges["y_index"].values, edges["relation"].values)
    for src, dst, rel in tqdm(triples, total=len(edges), desc="Creating adjacency list"):
        adj_list[int(src)].append((int(dst), str(rel)))

    with open(KG_ADJACENCY, "wb") as fh:
        pickle.dump(adj_list, fh, protocol=5)

def create_disease_features():
    """Build per-disease text plus embeddings and pickle them to DISEASE_FEATURES.

    Reads PRIMEKG_DISEASE and sorts it by ``node_index``. For each disease it
    builds a ``Diseases`` text
    ``"[disease name]<name> [definition]<definition> [description]<description>"``.
    The definition is the MONDO one, falling back to Orphanet; missing parts
    become empty strings.

    The texts are embedded with ``batch_encode``. The pickled DataFrame keeps
    ``node_index``, ``mondo_name``, ``Diseases`` and ``embed``. ``embed`` holds
    one row of the encoder output per disease; this presumes batch_encode
    returns a (n_texts, dim) tensor — TODO confirm.
    """
    diseases = pd.read_csv(PRIMEKG_DISEASE, low_memory=False, encoding="utf-8")
    diseases = diseases.sort_values("node_index").reset_index(drop=True)

    name = diseases["mondo_name"].fillna("")
    definition = diseases["mondo_definition"].combine_first(diseases["orphanet_definition"]).fillna("")
    description = diseases["umls_description"].fillna("")
    diseases["Diseases"] = (
        "[disease name]" + name + " "
        + "[definition]" + definition + " "
        + "[description]" + description
    )

    vectors = batch_encode(diseases["Diseases"].tolist(), batch_size=64, max_length=8192)
    diseases["embed"] = list(vectors.cpu().numpy())
    diseases = diseases[["node_index", "mondo_name", "Diseases", "embed"]]

    with open(DISEASE_FEATURES, "wb") as out:
        pickle.dump(diseases, out, protocol=5)

def create_notes_embeddings():
    """Embed every patient's notes and write them to NOTES_EMBEDDINGS (HDF5).

    The output file is rebuilt from scratch. For each row of NOTES_BACH, a
    group named ``str(PatientID)`` is created with two datasets:
    - ``PatientID``: int64 scalar;
    - ``Note``: ``langchain_chunk_embed(text)``, gzip-compressed.
    """
    notes_df = pd.read_csv(NOTES_BACH, encoding="utf-8", low_memory=False)

    # Write to a temporary file, opened once, and move it into place only after
    # every patient has been embedded. The notebook skips this step whenever
    # NOTES_EMBEDDINGS exists, so an interrupted run must not leave a partial
    # file at that path. The previous version created the target file before
    # doing any work and reopened it for every patient.
    tmp_path = f"{NOTES_EMBEDDINGS}.tmp"
    with h5py.File(tmp_path, "w") as h5:  # "w" truncates any stale temp file
        for patient_id, text in tqdm(
            zip(notes_df["PatientID"], notes_df["Text"]),
            total=len(notes_df),
            desc="Embedding notes and saving to HDF5",
        ):
            grp = h5.create_group(str(patient_id))
            grp.create_dataset("PatientID", data=np.asarray(patient_id, dtype="int64"))
            grp.create_dataset("Note", data=langchain_chunk_embed(text), compression="gzip")
    os.replace(tmp_path, NOTES_EMBEDDINGS)

Fetching 30 files: 100%|██████████| 30/30 [00:00<?, ?it/s]


In [13]:
# Build each cached artifact only if it is not already on disk, in this order.
for _artifact, _builder in (
    (KG_ADJACENCY, create_adj),
    (DISEASE_FEATURES, create_disease_features),
    (NOTES_EMBEDDINGS, create_notes_embeddings),
):
    if not os.path.exists(_artifact):
        _builder()

## Extract Entities

In [14]:
# Disease table and adjacency list produced by this notebook. These are local,
# trusted pickles; never point these paths at untrusted files.
with open(DISEASE_FEATURES, "rb") as f:
    kg = pickle.load(f)
mapping = dict(zip(kg["node_index"], kg["mondo_name"]))  # node_index -> disease name

with open(KG_ADJACENCY, "rb") as f:
    adj = pickle.load(f)

# Note embeddings keyed by integer PatientID (one HDF5 group per patient).
with h5py.File(NOTES_EMBEDDINGS, "r") as h5:
    notes_emb = {
        int(h5[group_name]["PatientID"][()]): np.array(h5[group_name]["Note"])
        for group_name in h5.keys()
    }
print(len(notes_emb))
# print(notes_emb[91199].shape) # 768

def _to_float32_array(x):
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().float().numpy()
    if isinstance(x, np.ndarray):
        return x.astype("float32", copy=False)
    raise TypeError(f"Expected tensor/ndarray, got {type(x)}")

def store_patient(h5_path, p, ehr, target, notes, summary):
    """Append one patient's arrays to the HDF5 file at ``h5_path``.

    Creates group ``str(p)`` with these datasets:
    - ``PatientID``: int64 scalar;
    - ``X``: ``ehr``;
    - ``Note`` / ``Summary``: float32 versions of ``notes`` / ``summary``;
    - ``Y``: ``target`` as int8.

    ``X``, ``Note`` and ``Summary`` are gzip-compressed.
    """
    gz = {"compression": "gzip"}
    with h5py.File(h5_path, "a") as h5:
        group = h5.create_group(str(p))
        group.create_dataset("PatientID", data=np.asarray(p, dtype="int64"))
        group.create_dataset("X", data=ehr, **gz)
        group.create_dataset("Note", data=_to_float32_array(notes), **gz)
        group.create_dataset("Summary", data=_to_float32_array(summary), **gz)
        group.create_dataset("Y", data=np.asarray(target, dtype="int8"))

def create_dataset(mode = "train"):
    """Build the HDF5 dataset for one split from its imputed draft CSV.

    Parameters
    ----------
    mode : {"train", "val", "test"}
        Split to build. Reads TRAIN_DRAFT / VAL_DRAFT / TEST_DRAFT and writes
        TRAIN / VAL / TEST respectively.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the three splits.

    One group per patient is written through ``store_patient``. It contains:
    - the EHR feature matrix (every column except PatientID/Outcome/
      Readmission, rows in file order);
    - the note embedding from ``notes_emb``;
    - a "summary" embedding, which is currently the note embedding because
      the LLM summary path is disabled;
    - the (Outcome, Readmission) target pair.

    The file is built under a temporary name and moved into place at the end,
    so an interrupted run never leaves a partial file at the final path (the
    notebook skips this step when TRAIN/VAL/TEST all exist).
    """
    if mode not in ["train", "val", "test"]:
        raise ValueError("mode must be 'train', 'val', or 'test'")
    draft_csv, h5_path = {
        "train": (TRAIN_DRAFT, TRAIN),
        "val":   (VAL_DRAFT, VAL),
        "test":  (TEST_DRAFT, TEST),
    }[mode]
    df = pd.read_csv(draft_csv, encoding="utf-8", low_memory=False)
    tmp_h5_path = f"{h5_path}.tmp"
    with h5py.File(tmp_h5_path, "w"):
        pass  # truncate; store_patient() appends one group per patient

    cat_cols = list(cols_between(
        df,
        "Capillary refill rate->0.0",
        "Glascow coma scale verbal response->3 Inapprop words"
    ))
    num_cols = list(cols_between(df, "Diastolic blood pressure", None))

    # NOTE(review): these statistics come from the split being processed, so
    # val/test "too high/low" flags use val/test mean/std rather than train's —
    # confirm this is intended.
    col_mean = df[num_cols].mean()
    col_std = df[num_cols].std()

    # Per-patient textual "entities": every active categorical value, plus
    # every numeric value whose z-score is above 2 or below -2.
    # (An unused `record` string that was built here has been removed.)
    entities = defaultdict(list)
    for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Processing {mode} data"):
        PatientID = row["PatientID"]

        for c in cat_cols:
            if row[c] == 1:
                cat = c
                if "Glascow coma scale total" not in cat:
                    # Replace "->N.0" / "->N" (N = 0..29) with " : ".
                    # NOTE(review): this also erases the value of numeric-coded
                    # categories, e.g. "Capillary refill rate->0.0" becomes
                    # "Capillary refill rate : " — confirm intended.
                    for i in range(0, 30, 1):
                        cat = cat.replace(f"->{i}.0", " : ")
                        cat = cat.replace(f"->{i}", " : ")
                cat = cat.replace("->", " : ")
                entities[PatientID].append(cat)

        for c in num_cols:
            if math.isnan(row[c]):
                continue
            z_score = (row[c] - col_mean[c]) / col_std[c]
            if z_score > 2:
                entities[PatientID].append(f"{c} too high")
            elif z_score < -2:
                entities[PatientID].append(f"{c} too low")

    # Match entities to knowledge graph
    patients = list(df["PatientID"].unique())
    disease_node_ids = set(kg["node_index"].values)  # O(1) membership for edge filtering

    def get_summary(p):
        """Return the embedded LLM summary for patient ``p``.

        Combines the patient's entities, the KG disease nodes matched to them,
        the edges between matched diseases, and the note extract.
        """
        entities[p] = list(set(entities[p]))
        summary_entities = ", ".join(entities[p])

        nodes = []
        for e in entities[p]:
            nodes.extend(cosine_filter(None, e, threshold=0.6, top_k=3))
        nodes = list(set(nodes))

        node_texts = []
        edge_texts = []
        for n in nodes:
            node_texts.append(kg.iloc[n]["Diseases"])
            node_x = kg.iloc[n]["node_index"]
            # Bug fix: `n` is a row position in `kg` (used with .iloc), but
            # `adj` is keyed by PrimeKG node index, so look edges up via
            # node_x. The previous code indexed `adj[n]`.
            for node_y, rela in adj.get(node_x, ()):
                if node_y not in disease_node_ids:
                    continue
                edge_texts.append("(" + mapping[node_x] + ", " + str(rela) + ", " + mapping[node_y] + ")")
        summary_nodes = ", ".join(node_texts)
        summary_edges = ", ".join(edge_texts)

        summary_notes = extract_note(notes=notes_emb[p])
        summary = create_summary(summary_entities, summary_notes, summary_nodes, summary_edges)
        return langchain_chunk_embed(summary)

    # summaries = defaultdict(list)
    # for p in tqdm(patients, total=len(patients), desc="Generating summaries"):
    #     summaries[p] = get_summary(p)

    feature_cols = [c for c in df.columns if c not in ["PatientID","Outcome","Readmission"]]
    target_map = df.groupby("PatientID")[["Outcome","Readmission"]].first()

    # Convert the features to NumPy once and slice each patient's rows by
    # position (original row order). The previous
    # `df.loc[df["PatientID"] == p, ...]` rescanned the whole frame for every
    # patient, i.e. O(rows x patients).
    feature_values = df[feature_cols].to_numpy()
    rows_by_patient = df.groupby("PatientID", sort=False).indices

    for p in tqdm(patients, total=len(patients), desc=f"Storing {mode} data to HDF5"):
        data_ehr = feature_values[rows_by_patient[p]]
        data_notes = notes_emb[p]
        data_summary = data_notes  # default to notes if summary fails
        # data_summary = summaries[p]
        outcome, readm = target_map.loc[p].astype(int)
        data_target = (int(outcome), int(readm))
        store_patient(tmp_h5_path, p, data_ehr, data_target, data_notes, data_summary)

    os.replace(tmp_h5_path, h5_path)

31027


In [15]:
# (Re)build all three HDF5 splits unless every one of them already exists.
if not all(os.path.exists(_split_h5) for _split_h5 in (TRAIN, VAL, TEST)):
    for _split in ("train", "val", "test"):
        create_dataset(mode=_split)

Processing train data: 100%|██████████| 1042464/1042464 [01:45<00:00, 9883.10it/s] 
Storing train data to HDF5: 100%|██████████| 21718/21718 [46:59<00:00,  7.70it/s]  
Processing val data: 100%|██████████| 148944/148944 [00:17<00:00, 8502.45it/s]
Storing val data to HDF5: 100%|██████████| 3103/3103 [01:55<00:00, 26.98it/s]
Processing test data: 100%|██████████| 297888/297888 [00:45<00:00, 6506.24it/s]
Storing test data to HDF5: 100%|██████████| 6206/6206 [05:12<00:00, 19.88it/s]
