In [7]:
import pandas as pd
import numpy as np
import os
import json
from femr.datasets import PatientDatabase
import datetime
from tqdm import tqdm

In [8]:
# Names of the benchmark tasks to process. Each entry is joined onto the
# benchmark directory to locate the task's label file and is reused as the
# stem of the output CSV. Entries are grouped by their name prefix; the
# overall order is unchanged.
task_name_list = (
    # "guo_*" tasks
    ["guo_los", "guo_readmission", "guo_icu"]
    # "new_*" tasks
    + [
        "new_hypertension",
        "new_hyperlipidemia",
        "new_pancan",
        "new_celiac",
        "new_lupus",
        "new_acutemi",
    ]
    # "lab_*" tasks
    + [
        "lab_thrombocytopenia",
        "lab_hyperkalemia",
        "lab_hyponatremia",
        "lab_anemia",
        "lab_hypoglycemia",
    ]
)

In [9]:
# df_labels = pd.read_csv('EHRSHOT_ASSETS/benchmark/lab_anemia/labeled_patients.csv')
# df_labels

In [10]:
def find_time_indices(patient_label_time_list, patient_all_time_list):
    """Find, for every label time, the last timeline entry at or before it.

    Args:
        patient_label_time_list: Label (prediction) times, in any order.
            Values only need to be mutually comparable with the timeline
            entries (e.g. ``datetime`` objects).
        patient_all_time_list: The patient's timeline timestamps. The forward
            scan stops at the first entry later than the label time, so this
            list is assumed to be in ascending order.

    Returns:
        A list of ints aligned with ``patient_label_time_list``: element ``i``
        is the index into ``patient_all_time_list`` of the last entry that is
        ``<=`` the ``i``-th label time.

    Raises:
        ValueError: If a label time is earlier than the first timeline entry
            (or the timeline is empty), i.e. no entry precedes the label.
    """
    label_times = list(patient_label_time_list)
    results = [None] * len(label_times)

    # The cursor below only ever moves forward, which is correct only if the
    # labels are visited in chronological order. Visit them sorted by time and
    # write each answer back to the label's original position, so unsorted
    # input no longer gets the index of a *later* event silently.
    chronological_positions = sorted(
        range(len(label_times)), key=lambda position: label_times[position]
    )

    all_index = 0
    for position in chronological_positions:
        label_time = label_times[position]
        while all_index < len(patient_all_time_list) and patient_all_time_list[all_index] <= label_time:
            all_index += 1
        if all_index == 0:
            # Nothing in the timeline happened at or before this label time.
            raise ValueError("No time found")
        results[position] = all_index - 1

    return results

In [11]:
def extract_patient_labels_index_by_task(
    patient_info_dir='new_femr_dataset/patient_info/',
    task_dir='EHRSHOT_ASSETS/benchmark/lab_anemia/',
    task_label_file='labeled_patients.csv',
):
    """Attach a timeline index to every label row of one task.

    Reads ``<task_dir>/<task_label_file>`` and, for each distinct
    ``patient_id`` in it, loads ``<patient_info_dir>/<patient_id>.json``. The
    JSON is expected to hold equally long ``'medical_tokens'`` and
    ``'time_tokens'`` lists, with ``'time_tokens'`` as ISO-format timestamps.
    Each label's ``prediction_time`` is mapped to an index into the patient's
    ``'time_tokens'`` by ``find_time_indices``.

    Args:
        patient_info_dir: Directory containing one ``<patient_id>.json`` per
            patient. A trailing separator is optional.
        task_dir: Directory containing the task's label CSV.
        task_label_file: File name of the label CSV inside ``task_dir``; it
            must provide ``patient_id``, ``prediction_time`` and ``value``
            columns.

    Returns:
        A copy of the label DataFrame with an extra integer ``'index'`` column
        holding, for each row, the index computed for that row's patient and
        prediction time.

    Raises:
        AssertionError: If a patient's token lists differ in length, the number
            of computed indices does not match the patient's label rows, or
            some label row ends up without an index.
        ValueError: Propagated from ``find_time_indices`` when no timeline
            entry precedes a label time.
    """
    df_labels = pd.read_csv(os.path.join(task_dir, task_label_file))

    # Positional row numbers of each patient's labels, computed once instead
    # of re-filtering the whole frame for every patient.
    rows_by_patient = df_labels.groupby('patient_id', sort=False).indices

    # -1 marks "not filled yet"; every row must be overwritten below.
    label_index_column = np.full(len(df_labels), -1, dtype=np.int64)

    for patient_id, row_positions in tqdm(rows_by_patient.items()):
        patient_json_name = str(patient_id) + '.json'
        # os.path.join keeps this working whether or not patient_info_dir
        # ends with a path separator.
        with open(os.path.join(patient_info_dir, patient_json_name)) as f:
            patient_info = json.load(f)

        assert len(patient_info['medical_tokens']) == len(patient_info['time_tokens'])

        df_label_patient = df_labels.iloc[row_positions]

        patient_label_time_list = [
            datetime.datetime.fromisoformat(date)
            for date in df_label_patient['prediction_time'].values
        ]
        patient_all_time_list = [
            datetime.datetime.fromisoformat(date)
            for date in patient_info['time_tokens']
        ]

        label_index = find_time_indices(patient_label_time_list, patient_all_time_list)

        assert len(label_index) == len(df_label_patient['value'].values)

        # Write each index back to the exact row it was computed for, so the
        # result stays aligned even when a patient's rows are not contiguous
        # in the CSV.
        label_index_column[row_positions] = label_index

    # Rows whose patient_id was skipped by the grouping would still be -1.
    assert (label_index_column >= 0).all()

    df_labels_extracted = df_labels.copy()
    df_labels_extracted['index'] = label_index_column

    return df_labels_extracted

In [12]:
# For every benchmark task, compute the per-label timeline index and write
# the augmented label table to `<out_dir>/<task>.csv`.
labeled_dir = 'EHRSHOT_ASSETS/benchmark/'
out_dir = 'new_femr_dataset/patient_label/'
task_label_file = 'labeled_patients.csv'

# DataFrame.to_csv does not create missing parent directories, so make sure
# the output directory exists before the (multi-minute) loop starts.
os.makedirs(out_dir, exist_ok=True)

for task in tqdm(task_name_list):
    task_dir = os.path.join(labeled_dir, task)
    df_task = extract_patient_labels_index_by_task(
        patient_info_dir='new_femr_dataset/patient_info/',
        task_dir=task_dir,
        task_label_file=task_label_file,
    )
    df_task.to_csv(os.path.join(out_dir, task + '.csv'), index=False)

100%|██████████| 3855/3855 [00:14<00:00, 266.69it/s]
100%|██████████| 3718/3718 [00:14<00:00, 252.82it/s]
100%|██████████| 3617/3617 [00:13<00:00, 269.48it/s]
100%|██████████| 2328/2328 [00:07<00:00, 308.87it/s]
100%|██████████| 2650/2650 [00:09<00:00, 290.88it/s]
100%|██████████| 3864/3864 [00:14<00:00, 260.64it/s]
100%|██████████| 3899/3899 [00:14<00:00, 261.34it/s]
100%|██████████| 3864/3864 [00:14<00:00, 264.44it/s]
100%|██████████| 3834/3834 [00:14<00:00, 267.02it/s]
100%|██████████| 6063/6063 [00:18<00:00, 320.63it/s]
100%|██████████| 5931/5931 [00:19<00:00, 310.84it/s]
100%|██████████| 5921/5921 [00:19<00:00, 311.34it/s]
100%|██████████| 6086/6086 [00:19<00:00, 318.72it/s]
100%|██████████| 5974/5974 [00:19<00:00, 303.98it/s]
100%|██████████| 14/14 [03:36<00:00, 15.47s/it]


In [13]:
# Display the benchmark directory of the last task handled by the export loop
# (`task_dir` is the loop variable left at notebook scope once the loop ends).
task_dir

'EHRSHOT_ASSETS/benchmark/lab_hypoglycemia'

In [None]:
# NOTE(review): in the visible cells `patient_info` is only assigned inside
# extract_patient_labels_index_by_task, so this cell presumably relies on a
# variable left over from an earlier interactive session — confirm it still
# runs after Restart Kernel -> Run All.
# Only the final expression of a cell is displayed; the bare lookup on the
# next line produces no visible output.
patient_info['time_tokens'][0]
# Parse the first timestamp string of the timeline into a datetime object.
datetime.datetime.fromisoformat(patient_info['time_tokens'][0])

datetime.datetime(1951, 2, 21, 23, 59)