# Setup

In [None]:
import  os
import pandas as pd
import numpy as np
import math
#
from utils.bgem3 import cosine_filter, batch_encode
from utils.call_llm import extract_note, create_summary
from utils.clinical_longformer import langchain_chunk_embed, plain_truncate
from utils.constants import *
#
import torch
from FlagEmbedding import BGEM3FlagModel
from concurrent.futures import ThreadPoolExecutor, as_completed
#
from sklearn.model_selection import train_test_split
from collections import defaultdict, Counter
import pickle
import h5py
from contextlib import ExitStack
#
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
# Show every row and column when a DataFrame is displayed in the notebook.
for _display_option in ("display.max_columns", "display.max_rows"):
    pd.set_option(_display_option, None)

def cols_between(_df, start_label, end_label=None):
    """Return the columns of ``_df`` from ``start_label`` through ``end_label``.

    Both endpoints are included. With ``end_label=None`` the range extends
    to the last column of the frame.

    Raises:
        ValueError: if ``start_label`` appears after ``end_label``.
        KeyError: if either label is not a column (from ``Index.get_loc``).
    """
    labels = _df.columns
    lo = labels.get_loc(start_label)
    # Exclusive upper bound, so the end label itself is part of the slice.
    hi = len(labels) if end_label is None else labels.get_loc(end_label) + 1
    if lo >= hi:
        raise ValueError(f"{start_label!r} comes after {end_label!r} in columns")
    return labels[lo:hi]

def _to_float32_array(x):
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().float().numpy()
    if isinstance(x, np.ndarray):
        return x.astype("float32", copy=False)
    raise TypeError(f"Expected tensor/ndarray, got {type(x)}")

# Part 3: Entity Extraction and Summary Creation

In [None]:
# NOTE(review): pickle.load can execute arbitrary code; these files must be
# produced by this pipeline and never come from an untrusted source.
with open(DISEASE_FEATURES, "rb") as f:
    kg = pickle.load(f)
# kg "node_index" -> "mondo_name", used when rendering edge triples.
mapping_disease = {node: name for node, name in zip(kg["node_index"], kg["mondo_name"])}

with open(KG_ADJACENCY, "rb") as f:
    adj = pickle.load(f)

with open(NOTES_EXTRACTS, "rb") as f:
    extracts = pickle.load(f)
# PatientID -> "Extract" value, passed as notes= to create_summary().
notes_extract = {pid: text for pid, text in zip(extracts["PatientID"], extracts["Extract"])}

# PatientID (read from inside each HDF5 group) -> note embedding array.
# An earlier debug note suggested 768-dim vectors -- TODO confirm.
notes_emb = {}
with h5py.File(NOTES_EMBEDDINGS, "r") as h5:
    for group in h5.values():
        pid = int(group["PatientID"][()])
        notes_emb[pid] = np.array(group["Note"])

def store_patient(h5_path, p, ehr, target, notes, summary):
    """Append one patient to the HDF5 file at ``h5_path``.

    Creates a group named ``str(p)`` containing:
      * ``PatientID`` - ``p`` as an int64 scalar,
      * ``X``         - ``ehr`` as given (gzip-compressed),
      * ``Note``      - ``notes`` converted to float32 (gzip-compressed),
      * ``Summary``   - ``summary`` converted to float32 (gzip-compressed),
      * ``Y``         - ``target`` as int8.

    The file is opened in append mode, so repeated calls add groups to the
    same file. Creating a group whose name already exists fails.
    """
    compressed = {"compression": "gzip"}
    with h5py.File(h5_path, "a") as store:
        group = store.create_group(str(p))
        group.create_dataset("PatientID", data=np.asarray(p, dtype="int64"))
        group.create_dataset("X", data=ehr, **compressed)
        group.create_dataset("Note", data=_to_float32_array(notes), **compressed)
        group.create_dataset("Summary", data=_to_float32_array(summary), **compressed)
        group.create_dataset("Y", data=np.asarray(target, dtype="int8"))

In [None]:
# Per-patient results written by offload_computations(), keyed by PatientID.
entities = defaultdict(list)
summary_entities, summary_nodes, summary_edges = (defaultdict(str) for _ in range(3))

def _read_split(mode):
    """Load the draft CSV for one split ("train", "val" or "test").

    Raises:
        ValueError: if ``mode`` is not one of the three splits.
    """
    if mode not in ("train", "val", "test"):
        raise ValueError("mode must be 'train', 'val', or 'test'")
    path = {"train": TRAIN_DRAFT, "val": VAL_DRAFT, "test": TEST_DRAFT}[mode]
    return pd.read_csv(path, encoding="utf-8", low_memory=False)


def _is_numeric_code(token):
    """Return True for integer-like codes such as '4' or '1.0'."""
    if token.endswith(".0"):
        token = token[:-2]
    return token.isdigit()


def _format_categorical(col):
    """Render an indicator column name such as 'X->4 Spontaneously' as an entity.

    A leading numeric code followed by a text label is dropped
    ('X->4 Spontaneously' -> 'X : Spontaneously'). A value that is only a
    number is kept ('Capillary refill rate->1.0' -> 'Capillary refill rate : 1.0').
    The old replace loop turned that case into 'Capillary refill rate : ' and
    lost the value. Columns for the Glasgow coma scale total always keep
    their numeric score.
    """
    name, sep, value = col.partition("->")
    if not sep:
        return col
    if "Glascow coma scale total" not in col:
        code, _, label = value.partition(" ")
        if label and _is_numeric_code(code):
            value = label
    return f"{name} : {value}"


def _describe_kg_nodes(nodes):
    """Build the comma-separated node-name and edge-triple strings for KG rows.

    Args:
        nodes: Positional row indices into the module-level ``kg`` frame,
            also used to index ``adj``.

    Returns:
        ``(summary_node, summary_edge)``. Each is "" when there is nothing to list.
    """
    node_names = []
    edge_triples = []
    for n in nodes:
        kg_row = kg.iloc[n]
        node_names.append(kg_row["Diseases"])
        node_x = kg_row["node_index"]
        for connect_to in adj[n]:
            node_y, rela = connect_to[0], connect_to[1]
            # mapping_disease is keyed by kg["node_index"], so this O(1) lookup
            # replaces the per-edge O(len(kg)) scan of kg["node_index"].values.
            if node_y not in mapping_disease:
                continue
            edge_triples.append(
                f"({mapping_disease[node_x]}, {rela}, {mapping_disease[node_y]})"
            )
    return ", ".join(node_names), ", ".join(edge_triples)


def offload_computations(mode):
    """Extract per-patient EHR entities and knowledge-graph context for one split.

    For each patient this collects:
      * categorical entities: indicator columns equal to 1, rendered by
        ``_format_categorical``;
      * numerical entities: "<col> too high" / "<col> too low" when the
        value's z-score is > 2 / < -2;
      * KG nodes returned by ``cosine_filter`` for each entity, plus their
        names and edges (``_describe_kg_nodes``).

    Results are written into the module-level dicts ``entities``,
    ``summary_entities``, ``summary_nodes`` and ``summary_edges``, keyed by
    PatientID. Entity order is first-appearance order, so the same input
    always yields the same summary text.

    Args:
        mode: "train", "val" or "test".

    Returns:
        The split's unique PatientIDs in first-appearance order.

    Raises:
        ValueError: if ``mode`` is not one of the three splits.
    """
    df = _read_split(mode)

    cat_cols = list(cols_between(
        df,
        "Capillary refill rate->0.0",
        "Glascow coma scale verbal response->3 Inapprop words",
    ))
    num_cols = list(cols_between(df, "Diastolic blood pressure", None))
    # The entity text for a column never depends on the row: format it once.
    cat_entity = {c: _format_categorical(c) for c in cat_cols}

    patients = list(df["PatientID"].unique())
    # NOTE(review): mean/std come from the split being processed, so val/test
    # thresholds differ from train's -- confirm this is intended.
    col_mean = df[num_cols].mean()
    col_std = df[num_cols].std()
    # Group once instead of re-scanning the whole frame for every patient.
    rows_by_patient = {pid: rows for pid, rows in df.groupby("PatientID", sort=False)}
    no_rows = df.iloc[0:0]

    def process_patient(patient_id):
        found = []
        for _, row in rows_by_patient.get(patient_id, no_rows).iterrows():
            found.extend(cat_entity[c] for c in cat_cols if row[c] == 1)
            for c in num_cols:
                value = row[c]
                if pd.isna(value):
                    continue
                z_score = (value - col_mean[c]) / col_std[c]
                if z_score > 2:
                    found.append(f"{c} too high")
                elif z_score < -2:
                    found.append(f"{c} too low")

        # Order-preserving de-duplication. list(set(...)) made the entity order,
        # and therefore the LLM prompt, vary between runs.
        patient_entities = list(dict.fromkeys(found))

        nodes = []
        for e in patient_entities:
            nodes.extend(cosine_filter(query=e, threshold=0.6, top_k=3))
        nodes = list(dict.fromkeys(nodes))

        summary_node, summary_edge = _describe_kg_nodes(nodes)
        return patient_id, {
            "entities": patient_entities,
            "summary_entity": ", ".join(patient_entities),
            "summary_node": summary_node,
            "summary_edge": summary_edge,
        }

    with ThreadPoolExecutor(max_workers=10) as executor:
        results = list(tqdm(
            executor.map(process_patient, patients),
            total=len(patients),
            desc=f"Processing {mode} data",
        ))

    # Publish into the shared dicts only after all workers have finished.
    for patient_id, result in results:
        entities[patient_id] = result["entities"]
        summary_entities[patient_id] = result["summary_entity"]
        summary_nodes[patient_id] = result["summary_node"]
        summary_edges[patient_id] = result["summary_edge"]

    return patients

def create_summary_embeddings():
    """Generate a summary and its embedding for every patient and cache them in HDF5.

    Runs ``offload_computations`` for train, val and test (in that order).
    Then calls ``create_summary`` once per patient on 10 worker threads.
    Finally it writes one group per PatientID to ``SUMMARIES_EMBEDDINGS``,
    holding ``PatientID`` (int64), ``SummaryText`` (UTF-8 string) and
    ``SummaryEmbedding``. The embedding comes from ``langchain_chunk_embed``
    when ``USE_CHUNKING`` is truthy, otherwise from ``plain_truncate`` with
    ``max_length=256``.

    The file is written under a temporary name and moved into place only
    after every group is written. An interrupted run therefore never leaves
    a partial file that the ``os.path.exists`` cache check would accept.
    """
    patients = (
        offload_computations("train")
        + offload_computations("val")
        + offload_computations("test")
    )

    def summary_for_pid(pid):
        return pid, create_summary(
            notes=notes_extract[pid],
            ehr=summary_entities[pid],
            nodes=summary_nodes[pid],
            edges=summary_edges[pid],
        )

    with ThreadPoolExecutor(max_workers=10) as ex:
        summaries = dict(
            tqdm(ex.map(summary_for_pid, patients),
                 total=len(patients),
                 desc="Generating summaries",
                 unit="pt")
        )

    text_dtype = h5py.string_dtype(encoding="utf-8")
    tmp_path = f"{SUMMARIES_EMBEDDINGS}.tmp"
    # "w" creates the file fresh, discarding any stale temp file from a crash.
    with h5py.File(tmp_path, "w") as h5:
        for pid, summary in tqdm(summaries.items(), desc="Writing HDF5", unit="pt"):
            gname = str(pid)
            # Defensive: distinct keys could still share the same str() form.
            if gname in h5:
                del h5[gname]
            grp = h5.create_group(gname)
            grp.create_dataset("PatientID", data=np.asarray(pid, dtype="int64"))
            grp.create_dataset("SummaryText", data=summary, dtype=text_dtype)
            embedding = (
                langchain_chunk_embed(summary) if USE_CHUNKING
                else plain_truncate(text=summary, max_length=256)
            )
            grp.create_dataset("SummaryEmbedding", data=embedding, compression="gzip")
    os.replace(tmp_path, SUMMARIES_EMBEDDINGS)

In [None]:
# Generate the summary cache only when the file is absent; delete the file to
# force regeneration.
summaries_cached = os.path.exists(SUMMARIES_EMBEDDINGS)
if not summaries_cached:
    create_summary_embeddings()

In [None]:
# Load every cached summary into memory, keyed by the HDF5 group name as int.
summaries = {}
with h5py.File(SUMMARIES_EMBEDDINGS, "r") as h5:
    for name in h5:
        group = h5[name]
        summaries[int(name)] = dict(
            PatientID=int(group["PatientID"][()]),
            SummaryText=group["SummaryText"].asstr()[()],
            SummaryEmbedding=group["SummaryEmbedding"][()],
        )

def create_dataset(mode="train"):
    """Write the per-patient HDF5 dataset for one split.

    Reads the split's draft CSV. For each PatientID, in first-appearance
    order, it calls ``store_patient`` with:
      * the patient's rows restricted to every column except
        PatientID/Outcome/Readmission, as a NumPy array;
      * ``notes_emb[p]``;
      * ``summaries[p]["SummaryEmbedding"]``;
      * ``(Outcome, Readmission)`` from ``groupby("PatientID").first()``,
        cast to int.

    The file is built as ``<path>.tmp`` and renamed onto the split's path
    (TRAIN/VAL/TEST) only when complete. An interrupted run therefore never
    leaves a partial file that the ``os.path.exists`` checks would accept.

    Args:
        mode: "train", "val" or "test".

    Raises:
        ValueError: if ``mode`` is not one of the three splits.
        KeyError: if a patient has no note embedding or no summary.
    """
    if mode not in ["train", "val", "test"]:
        raise ValueError("mode must be 'train', 'val', or 'test'")
    draft_path, h5_path = {
        "train": (TRAIN_DRAFT, TRAIN),
        "val": (VAL_DRAFT, VAL),
        "test": (TEST_DRAFT, TEST),
    }[mode]
    df = pd.read_csv(draft_path, encoding="utf-8", low_memory=False)

    patients = list(df["PatientID"].unique())
    feature_cols = [c for c in df.columns if c not in ("PatientID", "Outcome", "Readmission")]
    target_map = df.groupby("PatientID")[["Outcome", "Readmission"]].first()
    # Group once instead of masking the whole frame for every patient (O(P*N)).
    rows_by_patient = {pid: rows for pid, rows in df.groupby("PatientID", sort=False)}

    tmp_path = f"{h5_path}.tmp"
    # Start from an empty temp file; store_patient opens it in append mode.
    with h5py.File(tmp_path, "w"):
        pass

    for p in tqdm(patients, total=len(patients), desc=f"Storing {mode} data to HDF5"):
        data_ehr = rows_by_patient[p][feature_cols].to_numpy()
        data_notes = notes_emb[p]
        data_summary = summaries[p]["SummaryEmbedding"]
        outcome, readm = target_map.loc[p].astype(int)
        data_target = (int(outcome), int(readm))
        store_patient(tmp_path, p, data_ehr, data_target, data_notes, data_summary)

    os.replace(tmp_path, h5_path)

In [None]:
# Rebuild all three split files whenever any one of them is missing.
if not all(os.path.exists(path) for path in (TRAIN, VAL, TEST)):
    for split_mode in ("train", "val", "test"):
        create_dataset(mode=split_mode)