# Setting

In [None]:
import  os
import pandas as pd
import numpy as np
import math
#
from utils.bgem3 import cosine_filter, batch_encode
from utils.call_llm import extract_note, create_summary
from utils.clinical_longformer import langchain_chunk_embed, plain_truncate
from utils.constants import *
#
import torch
from FlagEmbedding import BGEM3FlagModel
from concurrent.futures import ThreadPoolExecutor, as_completed
#
from sklearn.model_selection import train_test_split
from collections import defaultdict, Counter
import pickle
import h5py
#
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
# Show every column and every row when a DataFrame is displayed (no truncation).
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

# Part 3: Entity Extraction and Summary Creation

In [None]:
def cols_between(_df, start_label, end_label=None):
    """Return the column labels from ``start_label`` through ``end_label``, inclusive.

    Args:
        _df: DataFrame whose ``columns`` index is sliced.
        start_label: Label of the first column to include.
        end_label: Label of the last column to include; ``None`` means
            "through the final column".

    Returns:
        The slice of ``_df.columns`` covering the requested range.

    Raises:
        KeyError: If a given label is not among the columns.
        ValueError: If ``start_label`` is positioned after ``end_label``.
    """
    columns = _df.columns
    first = columns.get_loc(start_label)
    if end_label is None:
        last = len(columns) - 1
    else:
        last = columns.get_loc(end_label)

    if first > last:
        raise ValueError(f"{start_label!r} comes after {end_label!r} in columns")

    # +1 because the end label itself is part of the result.
    return columns[first:last + 1]

# Load the knowledge-graph node table. `kg` is accessed with `.iloc[...]` and by
# column name elsewhere in this file, so it is presumably a pandas DataFrame — TODO confirm.
# NOTE(security): pickle.load can execute arbitrary code from the file; only load
# artifacts produced by a trusted pipeline.
with open(DISEASE_FEATURES, "rb") as f:
    kg = pickle.load(f)
# Lookup from each node's "node_index" value to its "mondo_name".
mapping = dict(zip(kg["node_index"], kg["mondo_name"]))

# Knowledge-graph adjacency. Elsewhere in this file `adj[n]` is iterated and each
# item is read as item[0] = neighbour node, item[1] = relation.
with open(KG_ADJACENCY, "rb") as f:
    adj = pickle.load(f)

# Per-patient note data, keyed by the integer stored in each group's "PatientID".
notes_emb = {}
notes_extract = {}
with h5py.File(NOTES_EMBEDDINGS, "r") as h5:
    for patient_id in h5.keys():  # each group is named by patient_id
        grp = h5[patient_id]
        pid = int(grp["PatientID"][()])
        notes_emb[pid] = np.array(grp["Note"])
        # NOTE(review): if "Extract" is stored as variable-length strings, h5py may
        # return bytes objects and astype(str) would yield "b'...'" text rather than
        # decoded strings — confirm against how the file was written.
        notes_extract[pid] = np.array(grp["Extract"]).astype(str)
# print(len(notes_emb.keys()))
# print(notes_emb[100].shape) # 768

def _to_float32_array(x):
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().float().numpy()
    if isinstance(x, np.ndarray):
        return x.astype("float32", copy=False)
    raise TypeError(f"Expected tensor/ndarray, got {type(x)}")

def store_patient(h5_path, p, ehr, target, notes, summary):
    """Append one patient's record to an HDF5 file as a new group.

    The group is named ``str(p)`` and receives the datasets ``PatientID``
    (int64), ``X``, ``Note`` and ``Summary`` (both float32) and ``Y`` (int8).
    ``X``, ``Note`` and ``Summary`` are gzip-compressed.

    Args:
        h5_path: Path of the HDF5 file, opened in append mode.
        p: Patient identifier.
        ehr: Data written to ``X`` as-is.
        target: Label data written to ``Y``.
        notes: Tensor/ndarray written to ``Note``.
        summary: Tensor/ndarray written to ``Summary``.
    """
    with h5py.File(h5_path, "a") as store:
        patient_group = store.create_group(str(p))
        patient_group.create_dataset("PatientID", data=np.asarray(p, dtype="int64"))
        patient_group.create_dataset("X", data=ehr, compression="gzip")
        # Both embeddings go through the same float32 conversion, in this order.
        for dataset_name, embedding in (("Note", notes), ("Summary", summary)):
            patient_group.create_dataset(
                dataset_name,
                data=_to_float32_array(embedding),
                compression="gzip",
            )
        patient_group.create_dataset("Y", data=np.asarray(target, dtype="int8"))

In [None]:
def create_summary_embeddings(mode = "train"):
    """Build per-patient text summaries and their embeddings for one data split.

    NOTE: a later definition with the same name appears further down in this
    file and rebinds ``create_summary_embeddings``; when the cells run top to
    bottom, that later version is the one that gets called.

    Steps performed here:
      1. Read the split's draft CSV and truncate the split's output HDF5 file.
      2. For every row, collect "entities": active categorical columns and
         numeric columns whose z-score is beyond +/-2.
      3. Look each entity up with ``cosine_filter`` and collect the knowledge-
         graph nodes and edges for the returned indices.
      4. Call ``create_summary`` for every patient in a thread pool.
      5. Write ``PatientID``, ``SummaryText`` and ``SummaryEmbedding`` per patient.

    Args:
        mode: One of ``"train"``, ``"val"`` or ``"test"``.

    Raises:
        ValueError: If ``mode`` is not one of the three split names.
    """
    if mode not in ["train", "val", "test"]:
        raise ValueError("mode must be 'train', 'val', or 'test'")
    if mode == "train":
        df = pd.read_csv(TRAIN_DRAFT, encoding="utf-8", low_memory=False)
        h5_path = SUMEMB_TRAIN
    elif mode == "val":
        df = pd.read_csv(VAL_DRAFT, encoding="utf-8", low_memory=False)
        h5_path = SUMEMB_VAL
    else:
        df = pd.read_csv(TEST_DRAFT, encoding="utf-8", low_memory=False)
        h5_path = SUMEMB_TEST

    # Opening in "w" mode creates/truncates the output file before any work starts.
    with h5py.File(h5_path, "w") as f:
        pass
    
    # Column ranges are selected by position between two labels (inclusive).
    cat_cols = list(cols_between(
        df,
        "Capillary refill rate->0.0",
        "Glascow coma scale verbal response->3 Inapprop words"
    ))
    num_cols = list(cols_between(df, "Diastolic blood pressure", None))

    patients = list(df["PatientID"].unique())
    # Statistics over all rows of this split, used for the z-scores below.
    col_mean = df[num_cols].mean()
    col_std = df[num_cols].std()

    preprocess_patients = {}
    for idx, row in tqdm(df.iterrows(), total=len(df), desc=f"Processing {mode} data"):
        PatientID = row["PatientID"]
        entities = []

        # record = ""
        # if row["Sex"] == 1:
        #     record += "Gender: Male\n"
        # else:
        #     record += "Gender: Female\n"
        # record += f"Age: {row['Age']}\n"
        
        # Categorical columns: a value of 1 marks the category as active for this row.
        for c in cat_cols:
            if row[c] == 1:
                cat = c
                # Replace "->N.0" / "->N" suffix markers (N in 0..29) with " : ",
                # except for the "Glascow coma scale total" columns.
                if "Glascow coma scale total" not in cat:
                    for i in range(0, 30, 1):
                        cat = cat.replace(f"->{i}.0", " : ")
                        cat = cat.replace(f"->{i}", " : ")
                cat = cat.replace("->", " : ")
                entities.append(cat)
        
        # Numeric columns: flag values more than two standard deviations from the mean.
        for c in num_cols:
            if math.isnan(row[c]):
                continue
            z_score = (row[c] - col_mean[c]) / col_std[c]
            if z_score > 2:
                entities.append(f"{c} too high")
            elif z_score < -2:
                entities.append(f"{c} too low")

        # De-duplicate; set() does not preserve the original order.
        entities = list(set(entities))
        summary_entities = ""
        summary_nodes = ""
        summary_edges = ""
        nodes = []
        for e in entities:
            summary_entities += e + ", "
            # This rebinds the loop variable `idx` from iterrows, which is not used otherwise.
            # presumably cosine_filter returns indices of matching KG rows — they are
            # used below as positional indices into `kg`; TODO confirm.
            idx = cosine_filter(None, e, threshold=0.6, top_k=3)
            nodes.extend(idx)
        # Drop the trailing ", ".
        summary_entities = summary_entities[:-2]

        nodes = list(set(nodes))
        
        for n in nodes:
            summary_nodes += kg.iloc[n]["Diseases"] + ", "
            node_x = kg.iloc[n]["node_index"]
            # Each adjacency item is read as item[0] = neighbour node, item[1] = relation.
            for connect_to in adj[n]:
                rela = connect_to[1]
                node_y = connect_to[0]
                # Skip neighbours that have no entry in the node table.
                if node_y not in kg["node_index"].values:
                    continue
                e = "(" + mapping[node_x] + ", " + str(rela) + ", " + mapping[node_y] + ")"
                # print(e)
                summary_edges += e + ", "

        summary_edges = summary_edges[:-2]
        summary_nodes = summary_nodes[:-2]
        summary_notes = notes_extract.get(PatientID, "")

        # Assigned once per row: if a patient has several rows, only the last
        # row's values remain in this dict.
        preprocess_patients[PatientID] = {
            "notes": summary_notes,
            "ehr": summary_entities,
            "nodes": summary_nodes,
            "edges": summary_edges,
        }
    
    def get_summary(p):
        # Build the summary for one patient from the pre-computed text pieces.
        return create_summary(notes=preprocess_patients[p]["notes"],
                              ehr=preprocess_patients[p]["ehr"],
                              nodes=preprocess_patients[p]["nodes"],
                              edges=preprocess_patients[p]["edges"],
                              )

    summaries = {}
    def worker_summary(pid):
        # Return the id alongside the result so completed futures can be matched up.
        return pid, get_summary(pid)
    progress_bar = tqdm(total=len(patients), desc="Generating summaries", unit="pt")
    with ThreadPoolExecutor(max_workers=8) as executor:
        future_to_pid = {executor.submit(worker_summary, pid): pid for pid in patients}
        for future in as_completed(future_to_pid):
            # future.result() re-raises any exception raised inside the worker.
            pid, summary = future.result()
            # Patients whose summary is None are left out of the output.
            if summary is not None:
                summaries[pid] = summary
            progress_bar.update(1)
    progress_bar.close()

    text_dtype = h5py.string_dtype(encoding="utf-8")
    with h5py.File(h5_path, "a") as h5:
        for pid, summary in tqdm(summaries.items(), desc="Writing HDF5", unit="pt"):
            gname = str(pid)
            # Replace any existing group of the same name.
            if gname in h5:
                del h5[gname]
            grp = h5.create_group(gname)
            grp.create_dataset("PatientID", data=np.asarray(pid, dtype="int64"))
            grp.create_dataset("SummaryText", data=np.asarray(summary, dtype=object), dtype=text_dtype)
            # USE_CHUNKING selects between the chunked embedding and plain truncation to 256.
            grp.create_dataset("SummaryEmbedding", data=langchain_chunk_embed(summary) if USE_CHUNKING else plain_truncate(text=summary, max_length=256), compression="gzip")

In [None]:
# Number of threads used to generate summaries in create_summary_embeddings.
MAX_WORKERS = 1


def _clean_category_label(column):
    """Rewrite a one-hot column name into a readable ``"<feature> : <value>"`` label.

    For example ``"Capillary refill rate->0.0"`` becomes
    ``"Capillary refill rate : "``. Columns containing
    ``"Glascow coma scale total"`` only have the ``"->"`` separator replaced,
    so their numeric value is kept.
    """
    label = column
    if "Glascow coma scale total" not in label:
        # Sequential replacement, ".0" form first, for codes 0..29.
        for i in range(0, 30, 1):
            label = label.replace(f"->{i}.0", " : ")
            label = label.replace(f"->{i}", " : ")
    return label.replace("->", " : ")


def create_summary_embeddings(mode = "train"):
    """Build per-patient text summaries and their embeddings for one data split.

    Entities are collected across *all* rows of a patient: active categorical
    columns, plus numeric columns whose z-score (against the split-wide mean
    and standard deviation) lies beyond +/-2. Each entity is matched to
    knowledge-graph nodes with ``cosine_filter``; the matched nodes and their
    edges, together with the patient's note extract, are passed to
    ``create_summary``. Results are written to the split's HDF5 file with one
    group per patient holding ``PatientID``, ``SummaryText`` and
    ``SummaryEmbedding``.

    The output is written to ``<h5_path>.tmp`` and moved into place only after
    every patient has been written, so a failed run does not leave behind a
    file that looks complete to an ``os.path.exists`` check.

    Args:
        mode: One of ``"train"``, ``"val"`` or ``"test"``.

    Raises:
        ValueError: If ``mode`` is not one of the three split names.
    """
    if mode not in ["train", "val", "test"]:
        raise ValueError("mode must be 'train', 'val', or 'test'")
    if mode == "train":
        csv_path, h5_path = TRAIN_DRAFT, SUMEMB_TRAIN
    elif mode == "val":
        csv_path, h5_path = VAL_DRAFT, SUMEMB_VAL
    else:
        csv_path, h5_path = TEST_DRAFT, SUMEMB_TEST
    df = pd.read_csv(csv_path, encoding="utf-8", low_memory=False)

    # Column ranges are selected by position between two labels (inclusive).
    cat_cols = list(cols_between(
        df,
        "Capillary refill rate->0.0",
        "Glascow coma scale verbal response->3 Inapprop words"
    ))
    num_cols = list(cols_between(df, "Diastolic blood pressure", None))

    # Split-wide statistics used for the z-scores below.
    col_mean = df[num_cols].mean()
    col_std = df[num_cols].std()

    # Patient id -> entity strings gathered over all of that patient's rows.
    entities = defaultdict(list)
    for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Processing {mode} data"):
        patient_entities = entities[row["PatientID"]]

        # Categorical columns: a value of 1 marks the category as active.
        for c in cat_cols:
            if row[c] == 1:
                patient_entities.append(_clean_category_label(c))

        # Numeric columns: flag values more than two standard deviations out.
        for c in num_cols:
            if math.isnan(row[c]):
                continue
            z_score = (row[c] - col_mean[c]) / col_std[c]
            if z_score > 2:
                patient_entities.append(f"{c} too high")
            elif z_score < -2:
                patient_entities.append(f"{c} too low")

    # Match entities to knowledge graph
    patients = list(df["PatientID"].unique())

    def get_summary(p):
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # generated prompt is the same from run to run. The shared `entities`
        # dict is only read here, never modified from worker threads.
        unique_entities = list(dict.fromkeys(entities.get(p, [])))

        nodes = []
        for entity in unique_entities:
            # presumably returns indices of matching KG rows; they are used as
            # positional indices into `kg` and as keys into `adj` — TODO confirm.
            nodes.extend(cosine_filter(None, entity, threshold=0.6, top_k=3))
        nodes = list(dict.fromkeys(nodes))

        node_names = []
        edge_texts = []
        for n in nodes:
            node_row = kg.iloc[n]
            node_names.append(node_row["Diseases"])
            node_x = node_row["node_index"]
            # Each adjacency item is read as item[0] = neighbour, item[1] = relation.
            for connect_to in adj[n]:
                node_y = connect_to[0]
                rela = connect_to[1]
                # `mapping` is keyed by kg["node_index"], so this is the same
                # membership test as scanning that column, without the scan.
                if node_y not in mapping:
                    continue
                edge_texts.append(f"({mapping[node_x]}, {rela}, {mapping[node_y]})")

        return create_summary(notes=notes_extract.get(p, ""),
                              ehr=", ".join(unique_entities),
                              nodes=", ".join(node_names),
                              edges=", ".join(edge_texts),
                              )

    def worker_summary(pid):
        # Return the id alongside the result so completed futures can be matched up.
        return pid, get_summary(pid)

    summaries = {}
    with tqdm(total=len(patients), desc="Generating summaries", unit="pt") as progress_bar:
        # The pool size comes from the module-level MAX_WORKERS setting.
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            futures = [executor.submit(worker_summary, pid) for pid in patients]
            for future in as_completed(futures):
                # future.result() re-raises any exception raised inside the worker.
                pid, summary = future.result()
                # Patients whose summary is None are left out of the output.
                if summary is not None:
                    summaries[pid] = summary
                progress_bar.update(1)

    text_dtype = h5py.string_dtype(encoding="utf-8")
    tmp_path = f"{h5_path}.tmp"
    # "w" starts from an empty file, so every group name is new.
    with h5py.File(tmp_path, "w") as h5:
        for pid, summary in tqdm(summaries.items(), desc="Writing HDF5", unit="pt"):
            grp = h5.create_group(str(pid))
            grp.create_dataset("PatientID", data=np.asarray(pid, dtype="int64"))
            grp.create_dataset("SummaryText", data=np.asarray(summary, dtype=object), dtype=text_dtype)
            # USE_CHUNKING selects between the chunked embedding and plain truncation to 256.
            if USE_CHUNKING:
                embedding = langchain_chunk_embed(summary)
            else:
                embedding = plain_truncate(text=summary, max_length=256)
            grp.create_dataset("SummaryEmbedding", data=embedding, compression="gzip")
    # Publish the finished file in one step.
    os.replace(tmp_path, h5_path)

In [None]:
# Rebuild the summary-embedding files for every split unless all three already exist.
if not all(map(os.path.exists, (SUMEMB_TRAIN, SUMEMB_VAL, SUMEMB_TEST))):
    for _split_name in ("train", "val", "test"):
        create_summary_embeddings(_split_name)

In [None]:
# Module-level cache: integer patient id -> summary record. Filled by
# load_summary_embeddings and shared across calls.
summaries = {}
def load_summary_embeddings(h5_path):
    """Read every patient group from ``h5_path`` into the global ``summaries``.

    Each group is stored under ``int(<group name>)`` as a dict with the keys
    ``PatientID``, ``SummaryText`` and ``SummaryEmbedding``. Entries already
    present under the same id are overwritten.

    Args:
        h5_path: Path of a summary-embedding HDF5 file.
    """
    with h5py.File(h5_path, "r") as store:
        for group_name in store:
            group = store[group_name]
            record = {
                "PatientID": int(group["PatientID"][()]),
                # asstr() decodes the stored text to str.
                "SummaryText": group["SummaryText"].asstr()[()],
                "SummaryEmbedding": group["SummaryEmbedding"][()]
            }
            summaries[int(group_name)] = record

def create_dataset(mode = "train"):
    """Write the final per-patient HDF5 dataset for one data split.

    For every patient in the split's draft CSV this stores, via
    ``store_patient``: the patient's feature rows (all columns except
    ``PatientID``, ``Outcome`` and ``Readmission``), the note embedding, the
    summary embedding, and the ``(Outcome, Readmission)`` targets.

    The summary embedding is taken from the split's summary-embedding file;
    when a patient has no entry there, the note embedding is used instead.

    Args:
        mode: One of ``"train"``, ``"val"`` or ``"test"``.

    Raises:
        ValueError: If ``mode`` is not one of the three split names.
        KeyError: If a patient has no entry in ``notes_emb``.
    """
    if mode not in ["train", "val", "test"]:
        raise ValueError("mode must be 'train', 'val', or 'test'")
    if mode == "train":
        csv_path, h5_path, summary_path = TRAIN_DRAFT, TRAIN, SUMEMB_TRAIN
    elif mode == "val":
        csv_path, h5_path, summary_path = VAL_DRAFT, VAL, SUMEMB_VAL
    else:
        csv_path, h5_path, summary_path = TEST_DRAFT, TEST, SUMEMB_TEST
    df = pd.read_csv(csv_path, encoding="utf-8", low_memory=False)
    load_summary_embeddings(summary_path)

    # Start from an empty file: store_patient appends and creates a new group
    # per patient.
    with h5py.File(h5_path, "w"):
        pass

    feature_cols = [c for c in df.columns if c not in ["PatientID","Outcome","Readmission"]]
    # One target pair per patient: the first non-null value of each column.
    target_map = df.groupby("PatientID")[["Outcome","Readmission"]].first()

    # A single grouped pass replaces filtering the whole frame once per
    # patient; sort=False keeps patients in order of first appearance.
    grouped = df.groupby("PatientID", sort=False)
    for p, patient_rows in tqdm(grouped, total=grouped.ngroups, desc=f"Storing {mode} data to HDF5"):
        data_ehr = patient_rows[feature_cols].to_numpy()
        data_notes = notes_emb[p]
        # Use the generated summary embedding; fall back to the notes
        # embedding if the summary is missing for this patient.
        summary_record = summaries.get(int(p))
        if summary_record is not None:
            data_summary = summary_record["SummaryEmbedding"]
        else:
            data_summary = data_notes
        outcome, readm = target_map.loc[p].astype(int)
        data_target = (int(outcome), int(readm))
        store_patient(h5_path, p, data_ehr, data_target, data_notes, data_summary)

In [None]:
# Rebuild the final datasets for every split unless all three files already exist.
if not all(map(os.path.exists, (TRAIN, VAL, TEST))):
    for _split_name in ("train", "val", "test"):
        create_dataset(mode=_split_name)