In [1]:
import pandas as pd
import numpy as np
import os
import pickle

# explore the data

import argparse
import os
import json
from typing import List, Tuple
from loguru import logger
from ehrshot.utils import LABELING_FUNCTION_2_PAPER_NAME
import pandas as pd
from tqdm import tqdm
from femr.datasets import PatientDatabase
import femr
import datetime
from collections import defaultdict
from itertools import groupby

In [2]:
# one line of code to configure the patient database
def configure_database(path_to_database: str) -> PatientDatabase:
    return PatientDatabase(path_to_database)

In [3]:
def describe_events(events, code_description):
    # Organize events by code
    grouped_events = {}
    for event in events:
        if event.code in grouped_events:
            grouped_events[event.code].append(event)
        else:
            grouped_events[event.code] = [event]

    # Mock database dictionary to translate medical codes

    descriptions = []
    for code, group in grouped_events.items():
        # Start the description for this group
        event_type = code_description.get(code, code)
        times = ', '.join(set(e.start.strftime('%B %d, %Y, %H:%M') for e in group))
        # description = f"On {times}, {len(group)} events categorized under the code '{code}' ({event_type}) occurred."

        # description = f"On {times}, {len(group)} {group[0].omop_table} medical events ({event_type}) occurred."

        omop_code = group[0].omop_code
        if omop_code:
            if '_' in omop_code:
                omop_code = omop_code.replace('_', ' ')
            description = f"{len(group)} {omop_code} events:'{event_type}' recorded"
        else:
            description = f"{len(group)} events:'{event_type}' recorded"
        # Check if any event in the group has a value
        values = set(e.value for e in group if e.value is not None)
        if values:
            description += f" with values: {', '.join(map(str, values))}."

        # description += f" These events were recorded in the '{group[0].omop_table}' table."
        
        descriptions.append(description)
    
    # Combine all descriptions into a single paragraph
    return ' \n'.join(descriptions)

def add_patient_age(prediction_time_datatime_format, database, patient_id, text):
    time_lap = prediction_time_datatime_format.date() - database.get_patient_birth_date(patient_id)
    patient_age = time_lap.days // 365
    date = prediction_time_datatime_format.date().strftime('%B %d, %Y')
    return f"The patient was {patient_age} years old at the prediction time.\n" + text

def find_closest_under_datetime(datetime_list, target_datetime):
    # Filter the list to only include datetimes that are less than the target datetime
    filtered_list = [dt for dt in datetime_list if dt <= target_datetime]
    
    # Handle case where there might be no earlier datetimes
    if not filtered_list:
        return None  # Or handle this case as appropriate in your context
    
    # Calculate the time differences for the filtered list
    time_differences = [target_datetime - dt for dt in filtered_list]
    
    # Find the index of the minimum time difference in the filtered list
    min_difference_index = time_differences.index(min(time_differences))
    
    # Get the original index from the full list
    closest_datetime = filtered_list[min_difference_index]
    original_index = datetime_list.index(closest_datetime)
    
    return original_index

def add_demographic_info(patient_events, code_description, final_description):
    person_info_events = patient_events[0:10]
    person_info_code_list = []
    person_info_description = 'The patient has the following demographic information: '
    for e in person_info_events:
        if e.omop_table == 'person':
            if e.code != 'SNOMED/3950001':
                person_info_code_list.append(e.code)
    for i in range(len(person_info_code_list)):
        c = person_info_code_list[i]
        if i != len(person_info_code_list) - 1:
            person_info_description += f'{code_description[c]},'
        else:
            person_info_description += f'{code_description[c]}.' + '\n'
    return person_info_description + final_description

In [4]:
database = configure_database("EHRSHOT_ASSETS/femr/extract")

with open('ehrshot_code_description.json', 'r') as f:
    code_description = json.load(f)

df_common = pd.read_csv('white_box_common_data/new_diagnose_blood.csv')

for column in df_common.columns:
    if df_common[column].dtype == 'bool':
        df_common[column] = df_common[column].astype(int)

In [5]:
final_description_list = []
num_events = 300
for i in tqdm(range(len(df_common))):
# for i in range(1):
    patient_id = df_common.loc[i].patient_id
    prediction_time = df_common.loc[i].prediction_time
    prediction_time = datetime.datetime.fromisoformat(prediction_time)
    patient_femr_object = database[patient_id]
    assert patient_id == patient_femr_object.patient_id
    patient_events = list(patient_femr_object.events)

    all_event = []
    #     time_start = e.start
    #     if time_start == prediction_time + datetime.timedelta(minutes=1):
    #         all_event.append(e)
    # if len(all_event) == 0:
    #     print(i)
    time_list = [e.start for e in patient_events]
    time_idx = find_closest_under_datetime(time_list, prediction_time)
    if time_idx > num_events:
        start_index = time_idx - num_events
    else:
        start_index = 3
    
    all_event = patient_events[start_index:time_idx+1]
    all_time = time_list[start_index:time_idx+1]

    paired_list = list(zip(all_time, all_event))
    grouped = [(date, [event for _, event in group]) for date, group in groupby(paired_list, key=lambda x: x[0].date())]
    final_description = ''
    for temp_date, temp_group in grouped:
        current_date = temp_date.strftime('%B %d, %Y')
        description = describe_events(temp_group, code_description)
        final_description += f'At {current_date}:\n' + description + '\n'
    final_description = add_demographic_info(patient_events, code_description, final_description)
    final_description = add_patient_age(prediction_time, database, patient_id, final_description)

    final_description_list.append(final_description)

  0%|          | 3/2987 [00:00<01:50, 26.98it/s]

100%|██████████| 2987/2987 [01:06<00:00, 44.79it/s]


In [6]:
df_common['description'] = final_description_list

df_common.to_csv('new_diagnose_blood_patient_description_full.csv', index=False)

In [8]:
columns_to_check = ['value_new_hypertension','value_new_hyperlipidemia', 'value_new_pancan', 'value_new_celiac', 'value_new_lupus', 'value_new_acutemi']
# df = df_common.loc[:, ['value_new_hypertension','value_new_hyperlipidemia', 'value_new_pancan', 'value_new_celiac', 'value_new_lupus', 'value_new_acutemi']]
df = df_common.loc[:, ['value_new_hypertension','value_new_hyperlipidemia', 'value_new_acutemi']]

In [71]:
pos_idx_hypertension = np.where(df_common['value_new_hypertension'] == 1)[0]
pos_idx_hyperlipidemia = np.where(df_common['value_new_hyperlipidemia'] == 1)[0]
# pos_idx_pancan = np.where(df_common['value_new_pancan'] == 1)[0]
# pos_idx_celiac = np.where(df_common['value_new_celiac'] == 1)[0]
# pos_idx_lupus = np.where(df_common['value_new_lupus'] == 1)[0]
pos_idx_acutemi = np.where(df_common['value_new_acutemi'] == 1)[0]

def sample_from_array(arr, n=12):
    np.random.seed(42)  # For reproducibility
    return np.random.choice(arr, n, replace=False)  # Set replace=False to sample without replacement

# Sampling from each array
samples_hypertension = sample_from_array(pos_idx_hypertension)
samples_hyperlipidemia = sample_from_array(pos_idx_hyperlipidemia)
samples_pancan = sample_from_array(pos_idx_pancan)
samples_celiac = sample_from_array(pos_idx_celiac)
samples_lupus = sample_from_array(pos_idx_lupus)
samples_acutemi = sample_from_array(pos_idx_acutemi)

# Concatenate all samples into a single array
all_samples_idx = np.concatenate([
    samples_hypertension,
    samples_hyperlipidemia,
    samples_pancan,
    samples_celiac,
    samples_lupus,
    samples_acutemi
])
all_samples_idx.sort()
all_samples_idx = np.unique(all_samples_idx)
all_samples_idx.shape
all_index = 

(69,)

In [73]:
excluded_array = np.setdiff1d(all_index, all_samples_idx)
sampled_reset_values = np.random.choice(excluded_array, (100 - len(all_samples_idx)), replace=False)

NameError: name 'all_index' is not defined

RangeIndex(start=0, stop=2794, step=1)