## Setup

In [None]:
%%capture

from datetime import datetime
from typing import Any, Dict, List, Optional, Union, Tuple, ClassVar
import os
import polars as pl

# PyHealth Packages
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks.multimodal_mimic4 import ClinicalNotesMIMIC4, ClinicalNotesICDLabsMIMIC4
from pyhealth.tasks.base_task import BaseTask

# Load MIMIC4 Files
# There are probably better ways of dealing with this on the cluster, but we are working locally for now
# (see: https://github.com/sunlabuiuc/PyHealth/blob/master/examples/mortality_prediction/multimodal_mimic4_minimal.py)

# Task configuration to run. The tasks are meant to be additive so that we can
# evaluate the value of adding more modalities.
TASK = "ClinicalNotesICDLabsMIMIC4"

# Root of the local PyHealth checkout. Can be overridden through the
# PYHEALTH_REPO_ROOT environment variable; the default keeps the original path.
PYHEALTH_REPO_ROOT = os.environ.get("PYHEALTH_REPO_ROOT", '/Users/wpang/Desktop/PyHealth')

EHR_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "local_data/local/data/physionet.org/files/mimiciv/2.2")
NOTE_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "local_data/local/data/physionet.org/files/mimic-iv-note/2.2")
CXR_ROOT = os.path.join(PYHEALTH_REPO_ROOT, "local_data/local/data/physionet.org/files/mimic-cxr-jpg/2.0.0")
CACHE_DIR = os.path.join(PYHEALTH_REPO_ROOT, "local_data/local/data/wp/pyhealth_cache")

# Fail early on a typo in TASK instead of leaving `dataset` undefined and
# hitting a NameError in a later cell.
_SUPPORTED_TASKS = ("ClinicalNotesMIMIC4", "ClinicalNotesICDLabsMIMIC4")
if TASK not in _SUPPORTED_TASKS:
    raise ValueError(f"Unsupported TASK {TASK!r}; expected one of {_SUPPORTED_TASKS}")

# Both task configurations load the dataset with exactly the same arguments,
# so it is built once here rather than once per branch.
# A bit janky setup at the moment and open to iteration, but conveys the point for now.
dataset = MIMIC4Dataset(
    ehr_root=EHR_ROOT,
    note_root=NOTE_ROOT,
    ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"],
    note_tables=["discharge", "radiology"],
    cache_dir=CACHE_DIR,
    num_workers=8,
    dev=True,
)

if TASK == "ClinicalNotesMIMIC4":
    # Apply the multimodal task to the whole dataset
    task = ClinicalNotesMIMIC4()
    samples = dataset.set_task(task)

    # Get and print one sample
    sample = samples[0]
    print(sample)
# For "ClinicalNotesICDLabsMIMIC4" only the dataset is loaded here; the task
# itself is applied to a single patient in a later cell.

In [None]:
# Patient identifier passed to dataset.get_patient() in a later cell.
# To list a few candidate IDs: dataset._unique_patient_ids[:5]
ID = "10095258"

In [None]:
# Apply multimodal task
# The task object is called directly on one patient below, rather than being
# passed to dataset.set_task().
task = ClinicalNotesICDLabsMIMIC4() 

# Single patient
# `patient` is reused by every exploration cell that follows.
patient = dataset.get_patient(ID)  
samples = task(patient)

## Radiology Notes Preview

In [None]:
# Note event type to preview; set to "discharge" to look at discharge notes instead.
TYPE = "radiology"
# Label printed in the per-admission headers; always "<TYPE>_notes".
NOTE = f"{TYPE}_notes"
# Admission whose notes are counted and listed in the single-admission cell.
HADM_ID = '22880743'

# Number of leading note characters echoed by the previews (0 prints no note text).
CHARACTERS_PREVIEW = 0

In [None]:
# %%capture

# Print every admission of the selected patient with its admission time and
# remember the hadm_ids for the per-admission cells below.
print("----")
print("Admission IDs (hadm_id)")
admission_ids = []
for admission in patient.get_events(event_type="admissions"):
    hadm_id = admission.attr_dict['hadm_id']
    print(f"{hadm_id} -> Admission Time: {admission.timestamp}")
    admission_ids.append(hadm_id)

In [None]:
%%capture

# Count and list the notes of type TYPE that belong to admission HADM_ID.
# The events are fetched once and reused for both the count and the listing.
hadm_notes = patient.get_events(
    event_type=TYPE, filters=[("hadm_id", "==", HADM_ID)]
)

print("----")
print(f"Count of {TYPE} notes for hadm_id: {HADM_ID}")
print(len(hadm_notes))
print("----")
print(f"Note ID for {TYPE} notes for hadm_id: {HADM_ID}")
for note in hadm_notes:
    # Preview length follows CHARACTERS_PREVIEW, like the per-admission
    # preview cell, instead of a separate hard-coded length.
    print(f"{note.attr_dict['note_id']} -> Note Timestamp: {note.timestamp} -> First {CHARACTERS_PREVIEW} Characters: {note.text[:CHARACTERS_PREVIEW]}")
print("----")

In [None]:
# For each admission, list its notes of type TYPE with ID, timestamp and a
# preview of the first CHARACTERS_PREVIEW characters.
for hadm_id in admission_ids:
    print(f"{NOTE}: Admission ID: {hadm_id}")
    print("----")
    notes_for_admission = patient.get_events(
        event_type=TYPE,
        filters=[("hadm_id", "==", hadm_id)],
    )
    for note in notes_for_admission:
        print(f"{note.attr_dict['note_id']} -> Note Timestamp: {note.timestamp} -> First {CHARACTERS_PREVIEW} Characters: {note.text[:CHARACTERS_PREVIEW]}")
        print("\n\n")

## ICD Codes

In [None]:
# Event type passed to patient.get_events() when listing events per admission below.
EVENT_TYPE = 'procedures_icd'

In [None]:
# Print every EVENT_TYPE event of each admission.
for admission_id in admission_ids:
    # The header is labelled with EVENT_TYPE: this cell lists ICD events, not
    # the notes that NOTE refers to.
    print(f"{EVENT_TYPE}: Admission ID: {admission_id}")
    print("----")
    for event in patient.get_events(event_type=EVENT_TYPE, filters=[("hadm_id", "==", admission_id)]):
        print(event)
        print("\n\n")

## Lab Events

In [None]:
# Lab categories and the `labevents` item IDs (as strings) that belong to each.
# Within a category the IDs are tried in the listed order when a lab vector is
# built, so the order expresses priority.
# These are plain module-level constants, so they are annotated with their
# value types directly (ClassVar is only meaningful inside a class body).
LAB_CATEGORIES: Dict[str, List[str]] = {
    "Sodium": ["50824", "52455", "50983", "52623"],
    "Potassium": ["50822", "52452", "50971", "52610"],
    "Chloride": ["50806", "52434", "50902", "52535"],
    "Bicarbonate": ["50803", "50804"],
    "Glucose": ["50809", "52027", "50931", "52569"],
    "Calcium": ["50808", "51624"],
    "Magnesium": ["50960"],
    "Anion Gap": ["50868", "52500"],
    "Osmolality": ["52031", "50964", "51701"],
    "Phosphate": ["50970"],
}

# Category order of every lab vector. Derived from LAB_CATEGORIES (dicts keep
# insertion order) so the two can never drift apart.
LAB_CATEGORY_NAMES: List[str] = list(LAB_CATEGORIES)

# Flat list of every item ID of interest, used to pre-filter the lab events.
LABITEMS: List[str] = [
    item for itemids in LAB_CATEGORIES.values() for item in itemids
]

# Accumulators filled by the lab-events cell.
all_lab_times = []
lab_vector = []
all_lab_values = []

# Placeholder stored for a category that has no measurement at a timestamp.
TOKEN_REPRESENTING_MISSING_FLOAT = float("nan")

In [None]:
# admission_infos is a list with one 3-tuple per admission event:
#   (admission_id, admission_timestamp, discharge_time)
# The discharge time is stored exactly as found in attr_dict['dischtime']; the
# lab-events cell parses it with datetime.strptime.
admission_infos = [
    (
        admission.attr_dict['hadm_id'],
        # Same attribute access the admission-listing cell uses for the timestamp.
        admission.timestamp,
        admission.attr_dict['dischtime'],
    )
    for admission in patient.get_events(event_type="admissions")
]

In [None]:
# Build, for the first admission, one lab vector per lab timestamp (one slot
# per entry of LAB_CATEGORY_NAMES) together with the timestamp's offset from
# admission in hours.

# Reset the accumulators so that re-running this cell does not append
# duplicates to results left over from a previous run.
all_lab_times = []
all_lab_values = []

for hadm_id, admission_time, dischtime_str in admission_infos:
    # Labelled "labevents": this cell lists lab events, not notes.
    print(f"labevents: Admission ID: {hadm_id}")
    print("----")
    admission_dischtime = datetime.strptime(dischtime_str, "%Y-%m-%d %H:%M:%S")

    # Lab events between admission and discharge, restricted to the item IDs
    # of interest.
    labevents_df = patient.get_events(
        event_type="labevents",
        start=admission_time,
        end=admission_dischtime,
        return_df=True,
    )
    labevents_df = labevents_df.filter(
        pl.col("labevents/itemid").is_in(LABITEMS)
    )
    display(labevents_df)

    if labevents_df.height > 0:
        # Keep only results whose storetime is not after discharge.
        labevents_df = labevents_df.with_columns(
            pl.col("labevents/storetime").str.strptime(pl.Datetime, "%Y-%m-%d %H:%M:%S")
        )
        labevents_df = labevents_df.filter(
            pl.col("labevents/storetime") <= admission_dischtime
        )
        if labevents_df.height > 0:
            print(f"Lab Events Height: {labevents_df.height}")
            labevents_df = labevents_df.select(
                pl.col("timestamp"),
                pl.col("labevents/itemid"),
                pl.col("labevents/valuenum").cast(pl.Float64),
            )
            for lab_ts in sorted(labevents_df["timestamp"].unique().to_list()):
                ts_labs = labevents_df.filter(pl.col("timestamp") == lab_ts)
                lab_vector: List[Any] = []
                for category_name in LAB_CATEGORY_NAMES:
                    # Take the first value of the first item ID (in the listed
                    # priority order) that was measured at this timestamp;
                    # otherwise keep the missing-value placeholder.
                    category_value = TOKEN_REPRESENTING_MISSING_FLOAT
                    for itemid in LAB_CATEGORIES[category_name]:
                        matching = ts_labs.filter(pl.col("labevents/itemid") == itemid)
                        if matching.height > 0:
                            category_value = matching["labevents/valuenum"][0]
                            break
                    lab_vector.append(category_value)
                all_lab_values.append(lab_vector)
                # Offset of this lab timestamp from admission, in hours.
                all_lab_times.append((lab_ts - admission_time).total_seconds() / 3600.0)
    # NOTE(review): only the first admission is processed; remove this `break`
    # to cover every admission of the patient.
    break

In [None]:
# Display the hours-since-admission offset recorded for each lab timestamp.
all_lab_times

In [None]:
# Display the lab vector built for the last processed lab timestamp.
# NOTE(review): `all_lab_values` holds the vectors for every timestamp; confirm
# whether that was the intended output here.
lab_vector