In [9]:
from pyhealth.datasets import MIMIC3Dataset, MIMIC4Dataset
import csv
import pickle
from copy import deepcopy
from collections import defaultdict
import json
import random
from pyhealth.datasets import SampleBaseDataset

In [2]:
import pandas as pd

noteevents_df = pd.read_csv("/mimic-iii-clinical-database-1.4/physionet.org/files/mimiciii/1.4/NOTEEVENTS.csv", low_memory=False)
physician_note = noteevents_df[noteevents_df["CATEGORY"] == "Physician "]


In [3]:
noteevents_df['CATEGORY'].unique()

array(['Discharge summary', 'Echo', 'ECG', 'Nursing', 'Physician ',
       'Rehab Services', 'Case Management ', 'Respiratory ', 'Nutrition',
       'General', 'Social Work', 'Pharmacy', 'Consult', 'Radiology',
       'Nursing/other'], dtype=object)

In [4]:
physician_note

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,CHARTDATE,CHARTTIME,STORETIME,CATEGORY,DESCRIPTION,CGID,ISERROR,TEXT
308694,316237,16605,109285.0,2138-03-21,2138-03-21 15:02:00,2138-03-21 15:03:05,Physician,Physician Resident Progress Note,21203.0,,Chief Complaint:\n 24 Hour Events:\n Continu...
308695,316238,29075,179159.0,2116-02-07,2116-02-07 15:37:00,2116-02-07 15:37:10,Physician,Physician Resident Progress Note,21203.0,,Chief Complaint:\n 24 Hour Events:\n EGD d...
308698,316241,29075,179159.0,2116-02-07,2116-02-07 15:37:00,2116-02-07 16:05:26,Physician,Physician Resident Progress Note,21203.0,,24 Hour Events:\n EGD demonstrated no eviden...
308699,316242,29075,179159.0,2116-02-07,2116-02-07 15:37:00,2116-02-07 16:08:06,Physician,Physician Resident Progress Note,21203.0,,24 Hour Events:\n EGD demonstrated no eviden...
308700,316243,31608,152365.0,2133-01-16,2133-01-16 16:12:00,2133-01-16 16:12:47,Physician,Physician Resident Progress Note,21203.0,,Chief Complaint:\n 24 Hour Events:\n Recei...
...,...,...,...,...,...,...,...,...,...,...,...
2066675,701664,77163,120851.0,2175-10-10,2175-10-10 06:30:00,2175-10-10 06:56:15,Physician,Physician Resident Progress Note,19796.0,,TITLE:\n Chief Complaint:\n 24 Hour Events...
2066676,701670,59113,169374.0,2194-11-04,2194-11-04 07:06:00,2194-11-04 07:06:25,Physician,Physician Resident Progress Note,20449.0,,Chief Complaint:\n 24 Hour Events:\n -fax ...
2066677,701673,72678,134826.0,2169-09-26,2169-09-26 07:14:00,2169-09-26 07:14:26,Physician,Physician Resident Progress Note,16654.0,,TITLE:\n Chief Complaint:\n 24 Hour Events...
2066678,701674,72678,134826.0,2169-09-26,2169-09-26 07:14:00,2169-09-26 07:15:51,Physician,Physician Resident Progress Note,16654.0,,TITLE:\n Chief Complaint:\n 24 Hour Events...


In [5]:
notes = physician_note[physician_note["SUBJECT_ID"] == int(3866)]

if notes.empty:
    print('')

notes = notes.sort_values("CHARTDATE")
combined_note = "\n\n".join(notes["TEXT"].dropna().tolist())




In [7]:
import openai

client = openai.OpenAI(api_key="_") #add api key here

def summary(text, model="gpt-4o-mini", seed=44):
    prompt = f"Summarize the following physician note in 2 concise lines:\n\n{text}"

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a medical assistant that summarizes discharge notes."},
            {"role": "user", "content": prompt}
        ],
        max_tokens=150,
        temperature=0.5,
        seed=seed
    )
    
    return response.choices[0].message.content.strip()




In [8]:
def get_discharge_notes(visit, physician_note=physician_note):
    hadm_id = visit.visit_id  # This must be present in your visit object

    notes = physician_note[physician_note["HADM_ID"] == int(hadm_id)]

    if notes.empty:
        return ""

    notes = notes.sort_values("CHARTDATE")  # or CHARTTIME if available
    combined_note = "\n\n".join(notes["TEXT"].dropna().tolist())
    summarized_note= summary(combined_note[:128000])
    print(summarized_note)
    
    
    return summarized_note

In [11]:
def load_mappings():
    condition_mapping_file = "./resources/CCSCM.csv"
    procedure_mapping_file = "./resources/CCSPROC.csv"
    drug_file = "./resources/ATC.csv"

    condition_dict = {}
    with open(condition_mapping_file, newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            condition_dict[row['code']] = row['name'].lower()

    procedure_dict = {}
    with open(procedure_mapping_file, newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            procedure_dict[row['code']] = row['name'].lower()

    drug_dict = {}
    with open(drug_file, newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            if row['level'] == '3.0':
                drug_dict[row['code']] = row['name'].lower()

    return condition_dict, procedure_dict, drug_dict


def load_dataset(dataset="mimic3"):
    if dataset == "mimic3":
        ds = MIMIC3Dataset(
        root="/scratch/hht9zt/mimic-iii-clinical-database-1.4/physionet.org/files/mimiciii/1.4/", 
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],      
        code_mapping={
            "NDC": ("ATC", {"target_kwargs": {"level": 3}}),
            "ICD9CM": "CCSCM",
            "ICD9PROC": "CCSPROC"
            },
        dev=False,
        refresh_cache=False,
        )
        
    return ds
        

def assign_task(dataset, ds, task="readmission30"):

   
    if task == "readmission30":
        if dataset == "mimic3":
            sample_dataset = ds.set_task(readmission_prediction_mimic3_fn)
        
    return sample_dataset

def expand_and_map(l, dict_):
    if type(l[0]) == list:
        return [dict_[item] for sublist in l for item in sublist]
    if type(l[0]) == str:
        return [dict_[item] for item in l]
    


In [14]:
from collections import defaultdict

def process_dataset(sample_dataset, condition_dict, procedure_dict, drug_dict):
    patient_data = defaultdict(dict)
    patient_data_no_label = defaultdict(dict)
    patient_to_index = sample_dataset.patient_to_index
    
    for patient, idxs in patient_to_index.items():
        for i in range(len(idxs)):
            label = sample_dataset.samples[idxs[i]]['label']
            patient_id = patient + f"_{i}"
            patient_data[patient_id]['label'] = label
            
            for j in range(i+1):
                idx = idxs[j]
                data = sample_dataset.samples[idx]
                
                conditions = expand_and_map(data['conditions'], condition_dict)
                procedures = expand_and_map(data['procedures'], procedure_dict)
                drugs = expand_and_map(data['drugs'], drug_dict)
                
                # Grab discharge notes as string, fallback to empty string if missing
                discharge_notes = data.get('physician_notes', "") or ""
                
                patient_data[patient_id][f'visit {j}'] = {
                    'conditions': conditions,
                    'procedures': procedures,
                    'drugs': drugs,
                    'physician_notes': discharge_notes
                }
                patient_data_no_label[patient_id][f'visit {j}'] = {
                    'conditions': conditions,
                    'procedures': procedures,
                    'drugs': drugs,
                    'physician_notes': discharge_notes
                }
            
    return patient_data, patient_data_no_label


In [13]:
def readmission_prediction_mimic3_fn(patient, time_window=30):
    samples = []

    # we will drop the last visit
    pat_id = patient.patient_id 
    patient = sorted(patient, key=lambda visit: visit.encounter_time)
    
    for i in range(len(patient) - 1):
        visit: Visit = patient[i]
        next_visit: Visit = patient[i + 1]

        # get time difference between current visit and next visit
        time_diff = (next_visit.encounter_time - visit.encounter_time).days
        readmission_label = 1 if time_diff < time_window else 0

        conditions = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")
        drugs = visit.get_code_list(table="PRESCRIPTIONS")
        # exclude: visits without condition, procedure, or drug code
        if len(conditions) * len(procedures) * len(drugs) == 0:
            continue
            
        try:
            discharge_notes = get_discharge_notes(visit)
        except AttributeError:
            discharge_notes = None  # Fallback in case method doesn't exist

        # TODO: should also exclude visit with age < 18
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": pat_id,
                "conditions": [conditions],
                "procedures": [procedures],
                "drugs": [drugs],
                "physician_notes": discharge_notes,
                "label": readmission_label,
            }
        )
        
    if len(samples) < 1:
        return []
    # no cohort selection
    return samples


In [100]:
def mortality_prediction_mimic3_fn(patient):
    samples = []
    
    pat_id = patient.patient_id 
    patient = sorted(patient, key=lambda visit: visit.encounter_time)
    

    for i in range(len(patient) - 1):
        visit = patient[i]
        next_visit = patient[i + 1]

        if next_visit.discharge_status not in [0, 1]:
            mortality_label = 0
        else:
            mortality_label = int(next_visit.discharge_status)

        conditions = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")
        drugs = visit.get_code_list(table="PRESCRIPTIONS")
        
        try:
            discharge_notes = get_discharge_notes(visit)
        except AttributeError:
            discharge_notes = None  # Fallback in case method doesn't exist

        if len(conditions) * len(procedures) * len(drugs) == 0:
            continue
        
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": pat_id,
                "conditions": conditions,
                "procedures": procedures,
                "drugs": drugs,
                "physician_notes": discharge_notes,
                "label": mortality_label,
            }
        )
    # print(samples)

    return samples

In [20]:
task = ['readmission30', 'mortality']
dataset = 'mimic3'
base_ds = load_dataset(dataset)

In [24]:
sample_dataset = assign_task(dataset, base_ds, task)


In [23]:
out_dir = "/ehr_prepare"
condition_dict, procedure_dict, drug_dict = load_mappings()
    
for dataset in datasets:
    print(f"Loading dataset: {dataset}")
    ds = load_dataset(dataset)
    for task in tasks:
        base_ds = deepcopy(ds)
        print(f"Assigning task: {task}")
        sample_dataset = assign_task(dataset, base_ds, task)
print(f"Dataset: {dataset}, Task: {task}, Number of samples: {len(sample_dataset)}")
print(f"Saving dataset to {out_dir}/{dataset}_{task}_physician.pkl")
if dataset == "mimic3":
with open(f"{out_dir}/{dataset}_{task}_physician_summary.pkl", "wb") as f:
    pickle.dump(sample_dataset, f)

sample_dataset_path = f"{out_dir}/{dataset}_{task}_physician_summary.pkl"
sample_dataset = pickle.load(open(sample_dataset_path, "rb"))
patient_data, patient_data_no_label = process_dataset(sample_dataset, condition_dict, procedure_dict, drug_dict)
if dataset == "mimic3":
    with open(f"{out_dir}/patient_{dataset}_{task}_physician_summary.json", "w") as f:
        json.dump(patient_data, f, indent=4)

print("Done!")

Done!


In [109]:
sample_dataset.samples[7612]['physician_notes']

'55-year-old male with a history of hypertension and COPD was transferred for worsening gallstone pancreatitis, complicated by respiratory failure, acute renal failure, and hypotension requiring pressor support. He remains critically ill, intubated, on high levels of oxygen and PEEP, with ongoing management for septic shock and renal failure, including CVVH and monitoring for abdominal compartment syndrome.'