In [1]:
# 1, convert the old patient database and labels to the new femr model format
# 2, structure of the dataset is: 
# {patient_id, sequence_medical_tokens, sequence_times_tokens, {task_name_1: {time_point_1, time_point_2, ...}, task_name_2: {time_point_1, time_point_2, ...}, ...}}

In [2]:
import pandas as pd
import numpy as np
import os
import json
from femr.datasets import PatientDatabase
import datetime
from tqdm import tqdm
from bisect import bisect_right, bisect_left
from ast import literal_eval

In [3]:
# Benchmark task names. Each entry is used as the `task_dir` sub-folder of the
# labels directory when a task's label CSV is read.
task_name_list =[
    "guo_los",
    "guo_readmission",
    "guo_icu",
    "new_hypertension",
    "new_hyperlipidemia",
    "new_pancan",
    "new_celiac",
    "new_lupus",
    "new_acutemi",
    "lab_thrombocytopenia",
    "lab_hyperkalemia",
    "lab_hyponatremia",
    "lab_anemia",
    "lab_hypoglycemia"
]
# 1, load the femr medical tokens, only keep the tokens that are within the femr tokenizer
def load_medical_tokens(femr_token_dict_path = 'femr_vocab.json'):
    """Parse the vocabulary JSON file and return its content.

    Args:
        femr_token_dict_path: Path to the JSON file to read.

    Returns:
        Whatever object the JSON document decodes to.
    """
    with open(femr_token_dict_path, 'r') as vocab_file:
        raw_text = vocab_file.read()
    return json.loads(raw_text)

# 2, load the patient database
def configure_database(path_to_database = "EHRSHOT_ASSETS/femr/extract"):
    """Open the patient database stored at ``path_to_database``.

    Args:
        path_to_database: Location passed straight to ``PatientDatabase``.

    Returns:
        The constructed ``PatientDatabase`` instance.
    """
    patient_database = PatientDatabase(path_to_database)
    return patient_database

# 3, load the patient labels and time
def load_patient_labels(path_to_labels_dir = "EHRSHOT_ASSETS/benchmark", task_dir = 'guo_los', label_file = 'labeled_patients.csv'):
    """Read one task's label CSV into a DataFrame.

    Args:
        path_to_labels_dir: Root folder holding one sub-folder per task.
        task_dir: Name of the task sub-folder.
        label_file: File name of the CSV inside the task sub-folder.

    Returns:
        The DataFrame produced by ``pd.read_csv`` for
        ``<path_to_labels_dir>/<task_dir>/<label_file>``.
    """
    return pd.read_csv(os.path.join(path_to_labels_dir, task_dir, label_file))

# 4, get the unique patient ids from a single label files. 
def get_unique_patient_ids(path_to_labels_dir = 'EHRSHOT_ASSETS/benchmark', task_dir = 'guo_los', label_file = 'labeled_patients.csv'):
    """Return the distinct ``patient_id`` values found in one task's label CSV.

    Args:
        path_to_labels_dir: Root folder holding one sub-folder per task.
        task_dir: Name of the task sub-folder.
        label_file: File name of the CSV inside the task sub-folder.

    Returns:
        A list with the unique values of the ``patient_id`` column, as
        produced by ``Series.unique()``.
    """
    labels_csv = os.path.join(path_to_labels_dir, task_dir, label_file)
    patient_column = pd.read_csv(labels_csv)['patient_id']
    return list(patient_column.unique())

# 5, get the unique patient ids from all the label files.
def get_unique_patient_ids_all_tasks(path_to_labels_dir = 'EHRSHOT_ASSETS/benchmark', task_name_list = task_name_list, label_file = 'labeled_patients.csv'):
    """Return the de-duplicated patient ids that appear in any task's label file.

    Args:
        path_to_labels_dir: Root folder holding one sub-folder per task.
        task_name_list: Task sub-folder names to scan.
        label_file: File name of the label CSV inside each task sub-folder.

    Returns:
        A list of the distinct patient ids gathered over all tasks. The order
        is whatever ``set`` iteration yields.
    """
    ids_across_tasks = [
        pid
        for task in task_name_list
        for pid in get_unique_patient_ids(path_to_labels_dir, task, label_file)
    ]
    return list(set(ids_across_tasks))

# 6, given a patient id, extract the events, filter all the medical sequence that are in the femr tokenizer
def filter_medical_events(patient_id, database, femr_dict):
    """Collect one patient's vocabulary-filtered events and ``person``-table events.

    All timestamps are stored as ISO-8601 strings (via ``.isoformat()``), so
    the returned record contains only plain strings and lists of strings.

    Args:
        patient_id: Key used to look the patient up in ``database``.
        database: Patient store supporting ``database[patient_id]`` (returning
            an object with ``patient_id`` and ``events`` attributes) and
            ``database.get_patient_birth_date(patient_id)``.
        femr_dict: Container of accepted medical codes; only membership
            (``code in femr_dict``) is used.

    Returns:
        dict with the keys
            ``birth_date``: ISO string of the patient's birth date.
            ``medical_tokens``: codes of the events whose code is in
                ``femr_dict``, in event order.
            ``time_tokens``: ISO start time of each entry in ``medical_tokens``.
            ``person_info_tokens``: codes of every event whose ``omop_table``
                is ``'person'`` (whether or not the code is in ``femr_dict``).
            ``person_info_time_tokens``: ISO start time of each entry in
                ``person_info_tokens``.

    Raises:
        AssertionError: if the record returned by the database carries a
            different ``patient_id``, or if ``None`` is accepted as a code.
    """
    # Look the patient up once and reuse the record for both the sanity check
    # and the event list.
    patient_record = database[patient_id]
    assert patient_record.patient_id == patient_id
    birth_date = database.get_patient_birth_date(patient_id).isoformat()

    medical_tokens, time_tokens = [], []
    person_tokens, person_time_tokens = [], []

    for event in patient_record.events:
        # isoformat() is evaluated for every event, so an event without a
        # start time fails here rather than being silently recorded.
        start_iso = event.start.isoformat()
        code = event.code
        if code in femr_dict:
            # Validate before recording, not after.
            assert code is not None
            medical_tokens.append(code)
            time_tokens.append(start_iso)
        if event.omop_table == 'person':
            person_tokens.append(code)
            person_time_tokens.append(start_iso)

    return {
        'birth_date': birth_date,
        'medical_tokens': medical_tokens,
        'time_tokens': time_tokens,
        'person_info_tokens': person_tokens,
        'person_info_time_tokens': person_time_tokens,
    }

# 7, given a list of patient ids, construct the dataset; return a dataframe with one row per patient, or a dict keyed by patient id when return_dataframe is False
def construct_dataset(patient_id_list, database, femr_dict, return_dataframe = True):
    """Build the per-patient records for every id in ``patient_id_list``.

    Args:
        patient_id_list: Iterable of patient ids to process.
        database: Patient store handed through to ``filter_medical_events``.
        femr_dict: Accepted-code container handed through to
            ``filter_medical_events``.
        return_dataframe: When true, return a DataFrame; otherwise return the
            raw dict.

    Returns:
        If ``return_dataframe`` is false: a dict mapping patient id to the
        record returned by ``filter_medical_events``.
        Otherwise: a DataFrame with one row per patient, the record fields as
        columns, an extra ``patient_id`` column, and a fresh 0..n-1 index.
    """
    records = {
        pid: filter_medical_events(pid, database, femr_dict)
        for pid in tqdm(patient_id_list)
    }
    if not return_dataframe:
        return records

    # Patients come out of the constructor as columns; transpose to rows.
    frame = pd.DataFrame(records).T
    frame['patient_id'] = frame.index
    return frame.reset_index(drop=True)

# 8, save the dataset, writing each patient's record to its own JSON file named <patient_id>.json
def save_patient_dataset(patient_dataset, save_dir = 'new_femr_dataset/patient_info/'):
    """Write each patient's record to ``<save_dir>/<patient_id>.json``.

    Args:
        patient_dataset: Mapping from patient id to a JSON-serializable record
            (the dict form returned by
            ``construct_dataset(..., return_dataframe=False)``).
        save_dir: Output directory. It is created, including missing parents,
            if it does not exist yet.

    Raises:
        TypeError: propagated from ``json.dump`` if a record is not
            JSON-serializable.
        OSError: propagated if the directory or a file cannot be written.
    """
    # Without this, the first open() fails with FileNotFoundError whenever the
    # output folder has not been created by hand beforehand.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for patient_id, patient_info in tqdm(patient_dataset.items()):
        save_path = os.path.join(save_dir, str(patient_id) + '.json')
        with open(save_path, 'w') as f:
            json.dump(patient_info, f)

# 9, given a datetime string with format '%Y-%m-%d %H:%M:%S', return the corresponding datetime object
def convert_datetime_string_to_timestamp(datetime_string):
    """Parse a ``'YYYY-MM-DD HH:MM:SS'`` string into a ``datetime.datetime``.

    Args:
        datetime_string: Text in the exact format ``'%Y-%m-%d %H:%M:%S'``.

    Returns:
        The parsed ``datetime.datetime``.

    Raises:
        ValueError: if the text does not match the expected format.
    """
    parse = datetime.datetime.strptime
    return parse(datetime_string, '%Y-%m-%d %H:%M:%S')

# 10, given patient's dataset, add the labels and the corresponding time points of a single medical task
def add_labels_to_dataset(dataset, path_to_labels_dir = 'EHRSHOT_ASSETS/benchmark', task_dir = 'guo_los', label_file = 'labeled_patients.csv'):
    """Create an empty per-task entry on every patient record (work in progress).

    The label CSV for ``task_dir`` is read, but its content is not attached to
    the records yet; each record only receives ``record[task_dir] = {}``.

    Args:
        dataset: Mapping from patient id to a mutable record; modified in place.
        path_to_labels_dir: Root folder holding one sub-folder per task.
        task_dir: Task sub-folder name; also the key added to each record.
        label_file: File name of the label CSV inside the task sub-folder.

    Returns:
        None.
    """
    labels_csv = os.path.join(path_to_labels_dir, task_dir, label_file)
    # TODO: the labels are loaded here but nothing below consumes them yet.
    pd.read_csv(labels_csv)
    for pid in tqdm(dataset.keys()):
        dataset[pid][task_dir] = {}

# 11, for a single task of a patient, given the input label dataframe of the patient, return the labels (TODO: not implemented yet)


In [4]:
# Load the token vocabulary from its default path and open the patient database.
femr_dict = load_medical_tokens()
database = configure_database("EHRSHOT_ASSETS/femr/extract")

# One-off export pipeline, kept commented out: builds the dict-form dataset
# for all labelled patients and writes one JSON file per patient.
# label_file = load_patient_labels(task_dir=task_name_list[0])
# patient_id_list = get_unique_patient_ids_all_tasks()
# patient_dataset = construct_dataset(patient_id_list, database, femr_dict, return_dataframe=False)
# save_patient_dataset(patient_dataset)

# df = pd.read_csv('patient_dataset.csv', converters={'medical_tokens': literal_eval, 'time_tokens': literal_eval, 'person_info_tokens': literal_eval, 'person_info_time_tokens': literal_eval})

In [6]:
len(os.listdir('new_femr_dataset/patient_info/'))
database[115967095].events

(Event(start=1933-12-22 00:00:00, code=SNOMED/3950001, value=None, =110267, omop_table=person),
 Event(start=1933-12-22 23:59:00, code=Race/5, value=None, =80, omop_table=person),
 Event(start=1933-12-22 23:59:00, code=Gender/F, value=None, =149, omop_table=person),
 Event(start=1933-12-22 23:59:00, code=Ethnicity/Hispanic, value=None, =161558, omop_table=person),
 Event(start=2008-10-07 23:40:00, code=LOINC/8480-6, value=116.0, =427208, omop_table=measurement),
 Event(start=2008-10-07 23:40:00, code=LOINC/8462-4, value=58.0, =299869, omop_table=measurement),
 Event(start=2008-10-07 23:40:00, code=LOINC/8310-5, value=98.5999984741211, =510293, omop_table=measurement),
 Event(start=2008-10-07 23:40:00, code=LOINC/9279-1, value=16.0, =247313, omop_table=measurement),
 Event(start=2008-10-07 23:40:00, code=LOINC/8327-9, value=1.0, =200894, omop_table=observation),
 Event(start=2008-10-07 23:40:00, code=LOINC/8867-4, value=88.0, =362967, omop_table=measurement),
 Event(start=2008-10-07 23:

In [4]:
# NOTE(review): `patient_dataset` is not assigned by any active cell above this
# one (only by a commented-out line and by a later cell), so this fails with
# NameError on Restart & Run All.
a = patient_dataset[115967095]

In [33]:
# Round-trip check: value -> ISO string -> datetime must give back the same value.
# NOTE(review): filter_medical_events stores `.isoformat()` strings in
# 'person_info_time_tokens', so with the current code `d` would be a str and
# `d.isoformat()` would fail; this cell appears to date from a version that
# stored datetime objects - confirm or remove.
d = patient_dataset[115967096]['person_info_time_tokens'][0]
print(d)
d_i = d.isoformat()
print(d_i)
d_j = datetime.datetime.fromisoformat(d_i)
assert d == d_j

1940-06-24 00:00:00
1940-06-24T00:00:00


In [5]:
# Write the single patient record `a` to a scratch JSON file.
with open('new_femr_dataset/temp.json', 'w') as f:
    json.dump(a, f)

In [6]:
patient_dataset.head(5)

Unnamed: 0,birth_date,medical_tokens,time_tokens,person_info_tokens,person_info_time_tokens,patient_id
0,1933-12-22,"['Race/5', 'Gender/F', 'Ethnicity/Hispanic', '...","[datetime.datetime(1933, 12, 22, 23, 59), date...","['SNOMED/3950001', 'Race/5', 'Gender/F', 'Ethn...","[datetime.datetime(1933, 12, 22, 0, 0), dateti...",115967095
1,1940-06-24,"['Race/5', 'Gender/F', 'Ethnicity/Not Hispanic...","[datetime.datetime(1940, 6, 24, 23, 59), datet...","['SNOMED/3950001', 'Race/5', 'Gender/F', 'Ethn...","[datetime.datetime(1940, 6, 24, 0, 0), datetim...",115967096
2,1957-08-10,"['Race/5', 'Gender/F', 'Ethnicity/Not Hispanic...","[datetime.datetime(1957, 8, 10, 23, 59), datet...","['SNOMED/3950001', 'Race/5', 'Gender/F', 'Ethn...","[datetime.datetime(1957, 8, 10, 0, 0), datetim...",115967097
3,1943-12-15,"['Race/5', 'Gender/F', 'Ethnicity/Not Hispanic...","[datetime.datetime(1943, 12, 15, 23, 59), date...","['SNOMED/3950001', 'Race/5', 'Gender/F', 'Ethn...","[datetime.datetime(1943, 12, 15, 0, 0), dateti...",115967098
4,1952-03-03,"['Race/5', 'Gender/F', 'Ethnicity/Not Hispanic...","[datetime.datetime(1952, 3, 3, 23, 59), dateti...","['SNOMED/3950001', 'Race/5', 'Gender/F', 'Ethn...","[datetime.datetime(1952, 3, 3, 0, 0), datetime...",115967099


In [19]:
patient_dataset.loc[0]['time_tokens'], type(patient_dataset.loc[0]['time_tokens'][0])

('[datetime.datetime(1933, 12, 22, 23, 59), datetime.datetime(1933, 12, 22, 23, 59), datetime.datetime(1933, 12, 22, 23, 59), datetime.datetime(2008, 10, 7, 23, 40), datetime.datetime(2008, 10, 7, 23, 40), datetime.datetime(2008, 10, 7, 23, 40), datetime.datetime(2008, 10, 7, 23, 40), datetime.datetime(2008, 10, 7, 23, 40), datetime.datetime(2008, 10, 7, 23, 40), datetime.datetime(2008, 10, 7, 23, 40), datetime.datetime(2008, 10, 7, 23, 40), datetime.datetime(2008, 10, 7, 23, 50), datetime.datetime(2008, 10, 7, 23, 56), datetime.datetime(2008, 10, 7, 23, 56), datetime.datetime(2008, 10, 7, 23, 56), datetime.datetime(2008, 10, 7, 23, 59), datetime.datetime(2008, 10, 7, 23, 59), datetime.datetime(2008, 10, 7, 23, 59), datetime.datetime(2008, 10, 7, 23, 59), datetime.datetime(2008, 10, 7, 23, 59), datetime.datetime(2008, 10, 7, 23, 59), datetime.datetime(2008, 10, 7, 23, 59), datetime.datetime(2008, 10, 7, 23, 59), datetime.datetime(2008, 10, 7, 23, 59), datetime.datetime(2008, 10, 8, 0, 

In [25]:
# Input locations for the end-to-end build.
path_to_database = 'EHRSHOT_ASSETS/femr/extract'
path_to_dataset = 'patient_dataset.csv'  # not used in this cell
path_to_medical_tokens = 'femr_vocab.json'
path_to_labels_dir = 'EHRSHOT_ASSETS/benchmark'

# Vocabulary + database -> ids of every labelled patient -> per-patient records.
# construct_dataset is called with its default return_dataframe=True, so
# `patient_dataset` is a DataFrame here.
femr_dict = load_medical_tokens(femr_token_dict_path=path_to_medical_tokens)
database = configure_database(path_to_database)
patient_id_list = get_unique_patient_ids_all_tasks(path_to_labels_dir=path_to_labels_dir)
patient_dataset = construct_dataset(patient_id_list, database, femr_dict)

100%|██████████| 6275/6275 [05:01<00:00, 20.84it/s]


In [28]:
patient_dataset.loc[0]['time_tokens']

[datetime.datetime(1933, 12, 22, 23, 59),
 datetime.datetime(1933, 12, 22, 23, 59),
 datetime.datetime(1933, 12, 22, 23, 59),
 datetime.datetime(2008, 10, 7, 23, 40),
 datetime.datetime(2008, 10, 7, 23, 40),
 datetime.datetime(2008, 10, 7, 23, 40),
 datetime.datetime(2008, 10, 7, 23, 40),
 datetime.datetime(2008, 10, 7, 23, 40),
 datetime.datetime(2008, 10, 7, 23, 40),
 datetime.datetime(2008, 10, 7, 23, 40),
 datetime.datetime(2008, 10, 7, 23, 40),
 datetime.datetime(2008, 10, 7, 23, 50),
 datetime.datetime(2008, 10, 7, 23, 56),
 datetime.datetime(2008, 10, 7, 23, 56),
 datetime.datetime(2008, 10, 7, 23, 56),
 datetime.datetime(2008, 10, 7, 23, 59),
 datetime.datetime(2008, 10, 7, 23, 59),
 datetime.datetime(2008, 10, 7, 23, 59),
 datetime.datetime(2008, 10, 7, 23, 59),
 datetime.datetime(2008, 10, 7, 23, 59),
 datetime.datetime(2008, 10, 7, 23, 59),
 datetime.datetime(2008, 10, 7, 23, 59),
 datetime.datetime(2008, 10, 7, 23, 59),
 datetime.datetime(2008, 10, 7, 23, 59),
 datetime.dat

In [8]:
df_labels = pd.read_csv('EHRSHOT_ASSETS/benchmark/lab_anemia/labeled_patients.csv')
df_labels

Unnamed: 0,patient_id,prediction_time,value,label_type
0,115973769,2011-10-01 11:19:00,0,categorical
1,115973769,2011-10-05 14:24:00,0,categorical
2,115973769,2011-10-06 02:59:00,1,categorical
3,115973769,2011-10-08 04:49:00,0,categorical
4,115973769,2012-08-24 14:44:00,0,categorical
...,...,...,...,...
184875,115967121,2023-02-10 09:23:00,2,categorical
184876,115967121,2023-02-12 08:44:00,2,categorical
184877,115967121,2023-02-16 10:45:00,2,categorical
184878,115967121,2023-02-23 09:17:00,2,categorical


In [9]:
# Peek at the first label row only: the loop breaks after its first iteration.
for i in range(df_labels.shape[0]):
    current_row = df_labels.iloc[i]
    patient_id = current_row['patient_id']
    label = current_row['value']
    time = current_row['prediction_time']
    print(patient_id, label, time)
    break

115973769 0 2011-10-01 11:19:00


In [10]:
# Extract the unique patient ids, then collect each patient's prediction times
# (parsed to datetime objects) and label values.
# NOTE: nothing is accumulated across iterations; after the loop the variables
# hold the values of the LAST patient only, and later cells rely on them.
patient_id_list = df_labels['patient_id'].unique()
for patient_id in patient_id_list:
    # Rows of the label table belonging to the current patient.
    df_current_patient = df_labels[df_labels['patient_id'] == patient_id]
    prediction_time_list = df_current_patient['prediction_time'].values
    value_list = df_current_patient['value'].values
    prediction_time_list = [convert_datetime_string_to_timestamp(time) for time in prediction_time_list]
    

In [11]:
df_labels

Unnamed: 0,patient_id,prediction_time,value,label_type
0,115973769,2011-10-01 11:19:00,0,categorical
1,115973769,2011-10-05 14:24:00,0,categorical
2,115973769,2011-10-06 02:59:00,1,categorical
3,115973769,2011-10-08 04:49:00,0,categorical
4,115973769,2012-08-24 14:44:00,0,categorical
...,...,...,...,...
184875,115967121,2023-02-10 09:23:00,2,categorical
184876,115967121,2023-02-12 08:44:00,2,categorical
184877,115967121,2023-02-16 10:45:00,2,categorical
184878,115967121,2023-02-23 09:17:00,2,categorical


In [12]:
# for event in patient_events:
#     start_time = event.start
#     medical_code = event.code
#     if start_time in prediction_time_list:
# `patient_id` is whatever the per-patient loop cell left behind (its last
# iteration). Build parallel lists of event start times and codes.
patient_events = database[patient_id].events
patient_time_list = [event.start for event in patient_events]
patient_code_list = [event.code for event in patient_events]

In [13]:
# Despite the `*_time` names, both results are integer positions in
# patient_time_list: [start, end) is the slice of events whose start time
# equals `prediction_time`.
# assumes patient_time_list is sorted ascending, as bisect requires - TODO confirm
prediction_time = prediction_time_list[1]
current_prediction_end_time = bisect_right(patient_time_list, prediction_time)
current_prediction_start_time = bisect_left(patient_time_list, prediction_time)

In [14]:
current_prediction_start_time, current_prediction_end_time


(2605, 2612)

In [15]:
def find_max_indices(patient_time_list, prediction_time_list):
    """Locate, for each prediction time, the last event at or before it.

    Args:
        patient_time_list: Event times sorted in ascending order (required by
            ``bisect_right``).
        prediction_time_list: Times to look up.

    Returns:
        A list with one entry per prediction time: the largest index ``i``
        such that ``patient_time_list[i] <= prediction_time``, or ``None``
        when the prediction time precedes every event.
    """
    last_positions = (
        bisect_right(patient_time_list, moment) - 1
        for moment in prediction_time_list
    )
    return [None if pos < 0 else pos for pos in last_positions]

In [16]:
print(patient_time_list[current_prediction_start_time:current_prediction_end_time])
patient_code_list[current_prediction_start_time:current_prediction_end_time]

[datetime.datetime(2023, 1, 16, 11, 26), datetime.datetime(2023, 1, 16, 11, 26), datetime.datetime(2023, 1, 16, 11, 26), datetime.datetime(2023, 1, 16, 11, 26), datetime.datetime(2023, 1, 16, 11, 26), datetime.datetime(2023, 1, 16, 11, 26), datetime.datetime(2023, 1, 16, 11, 26)]


['LOINC/8480-6',
 'LOINC/3151-8',
 'LOINC/20112-9',
 'LOINC/8462-4',
 'LOINC/9279-1',
 'LOINC/8867-4',
 'LOINC/8478-0']

In [17]:
# import argparse
# import os
# import json
# from typing import List
# from loguru import logger
# from ehrshot.utils import LABELING_FUNCTION_2_PAPER_NAME
# import pandas as pd

# from femr.datasets import PatientDatabase
# from femr.labelers.core import LabeledPatients, Label
# from femr.labelers.benchmarks import (
#     Guo_LongLOSLabeler,
#     Guo_30DayReadmissionLabeler,
#     Guo_ICUAdmissionLabeler,
#     PancreaticCancerCodeLabeler,
#     CeliacDiseaseCodeLabeler,
#     LupusCodeLabeler,
#     AcuteMyocardialInfarctionCodeLabeler,
#     EssentialHypertensionCodeLabeler,
#     HyperlipidemiaCodeLabeler,
#     HyponatremiaInstantLabValueLabeler,
#     ThrombocytopeniaInstantLabValueLabeler,
#     HyperkalemiaInstantLabValueLabeler,
#     HypoglycemiaInstantLabValueLabeler,
#     AnemiaInstantLabValueLabeler,
# )
# from femr.labelers.omop import (
#     ChexpertLabeler,
# )

# database = PatientDatabase("EHRSHOT_ASSETS/femr/extract")
# ontology = database.get_ontology()
# labeler = AnemiaInstantLabValueLabeler(ontology)

# labeled_patients = labeler.apply(
#         path_to_patient_database="EHRSHOT_ASSETS/femr/extract",
#         num_threads=20,
#     )

# for patient, labels in labeled_patients.items():
#     new_labels: List[Label] = [ Label(time=l.time.replace(second=0, microsecond=0), value=l.value) for l in labels ]
#     # labeled_patients[patient] = new_labels
#     # if new_labels > 1:
#     #         print(patient, new_labels)
#     for e in new_labels:
#         print(patient, e.time, e.value)