In [1]:
import os
import random
import pickle

from os.path import join
from typing import Dict, Any, List, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
from skmultilearn.model_selection import iterative_train_test_split
from sklearn.preprocessing import MultiLabelBinarizer

DATA_ROOT = '/h/afallah/odyssey/odyssey/data/bigbird_data'
DATASET = f'{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet'
MAX_LEN = 2048

SEED = 23
os.chdir(DATA_ROOT)
random.seed(SEED)
np.random.seed(SEED)

In [2]:
# Load complete dataset
dataset_2048 = pd.read_parquet(DATASET)

# dataset_2048.drop(
#     ['event_tokens', 'type_tokens', 'age_tokens', 'time_tokens', 'visit_tokens', 'position_tokens'],
#     axis=1,
#     inplace=True
# )

print(f'Current columns: {dataset_2048.columns}')
dataset_2048.head()

Current columns: Index(['patient_id', 'num_visits', 'deceased', 'death_after_start',
       'death_after_end', 'length', 'token_length', 'event_tokens_2048',
       'type_tokens_2048', 'age_tokens_2048', 'time_tokens_2048',
       'visit_tokens_2048', 'position_tokens_2048', 'elapsed_tokens_2048',
       'common_conditions', 'rare_conditions'],
      dtype='object')


Unnamed: 0,patient_id,num_visits,deceased,death_after_start,death_after_end,length,token_length,event_tokens_2048,type_tokens_2048,age_tokens_2048,time_tokens_2048,visit_tokens_2048,position_tokens_2048,elapsed_tokens_2048,common_conditions,rare_conditions
0,35581927-9c95-5ae9-af76-7d74870a349c,1,0,,,50,54,"[[CLS], [VS], 00006473900, 00904516561, 510790...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, ...","[0, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85...","[0, 5902, 5902, 5902, 5902, 5902, 5902, 5902, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 1.97, 2.02, 2.02, 2.02, 2.02, 2.0...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1,f5bba8dd-25c0-5336-8d3d-37424c185026,2,0,,,148,156,"[[CLS], [VS], 52135_2, 52075_2, 52074_2, 52073...","[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83...","[0, 6594, 6594, 6594, 6594, 6594, 6594, 6594, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
2,f4938f91-cadb-5133-8541-a52fb0916cea,2,0,,,78,86,"[[CLS], [VS], 0RB30ZZ, 0RG10A0, 00071101441, 0...","[1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44...","[0, 8150, 8150, 8150, 8150, 8150, 8150, 8150, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 1.08, 1.08, 13.89, 13.8...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
3,6fe2371b-a6f0-5436-aade-7795005b0c66,2,0,,,86,94,"[[CLS], [VS], 63739057310, 49281041688, 005970...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72...","[0, 6093, 6093, 6093, 6093, 6093, 6093, 6093, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.75, 0.75, 0.75, 0.75, 0.75, 0.7...","[1, 0, 0, 0, 0, 0, 0, 1, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
4,6f7590ae-f3b9-50e5-9e41-d4bb1000887a,1,0,,,72,76,"[[CLS], [VS], 50813_0, 52135_0, 52075_3, 52074...","[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47...","[0, 6379, 6379, 6379, 6379, 6379, 6379, 6379, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 1]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"


In [3]:
# test_ids = pickle.load(open('new_data/patient_id_dict/sample_pretrain_test_patient_ids_with_conditions.pkl', 'rb'))['test']
# dataset_2048.loc[dataset_2048['patient_id'].isin(test_ids)]['rare_conditions'].transform(lambda x: x[0]).sum()

In [17]:
def truncate_and_pad(row: pd.Series) -> Any:
    """Return a truncated and padded version of row.

    Args:
        row (pd.Series): The input row.

    Returns:
        Any: The truncated and padded row.

    Note:
        This function assumes the presence of the following columns in row:
        - 'event_tokens_2048'
        - 'type_tokens_2048'
        - 'age_tokens_2048'
        - 'time_tokens_2048'
        - 'visit_tokens_2048'
        - 'position_tokens_2048'
        - 'elapsed_tokens_2048'
    """
    seq_len = len(row['event_tokens_2048'])
    row['type_tokens_2048'] = np.pad(row['type_tokens_2048'][:seq_len], (0, MAX_LEN - seq_len), mode='constant')
    row['age_tokens_2048'] = np.pad(row['age_tokens_2048'][:seq_len], (0, MAX_LEN - seq_len), mode='constant')
    row['time_tokens_2048'] = np.pad(row['time_tokens_2048'][:seq_len], (0, MAX_LEN - seq_len), mode='constant')
    row['visit_tokens_2048'] = np.pad(row['visit_tokens_2048'][:seq_len], (0, MAX_LEN - seq_len), mode='constant')
    row['position_tokens_2048'] = np.pad(row['position_tokens_2048'][:seq_len], (0, MAX_LEN - seq_len), mode='constant')
    row['elapsed_tokens_2048'] = np.pad(row['elapsed_tokens_2048'][:seq_len], (0, MAX_LEN - seq_len), mode='constant')
    return row


def filter_by_num_visit(dataset: pd.DataFrame, minimum_num_visits: int) -> pd.DataFrame:
    """ Filter the patients based on num_visits threshold.

    Args:
        dataset (pd.DataFrame): The input dataset.
        minimum_num_visits (int): The threshold num_visits

    Returns:
        pd.DataFrame: The filtered dataset.
    """
    filtered_dataset = dataset.loc[dataset['num_visits'] >= minimum_num_visits]
    filtered_dataset.reset_index(drop=True, inplace=True)
    return filtered_dataset


def filter_by_length_of_stay(dataset: pd.DataFrame, threshold: int = 1) -> pd.DataFrame:
    """ Filter the patients based on length of stay threshold.

    Args:
        dataset (pd.DataFrame): The input dataset.
        minimum_num_visits (int): The threshold length of stay

    Returns:
        pd.DataFrame: The filtered dataset.
    """
    filtered_dataset = dataset.loc[dataset['length_of_stay'] >= threshold]

    # Only keep the patients that their first event happens within threshold
    filtered_dataset = filtered_dataset[
        filtered_dataset.apply(
        lambda row: row['elapsed_tokens_2048'][row['last_VS_index'] + 1] < threshold*24,
        axis=1)]

    filtered_dataset.reset_index(drop=True, inplace=True)
    return filtered_dataset


def get_last_index(seq: List[str], target: str) -> int:
    """Return the index of the last occurrence of target in seq.

    Args:
        seq (List[str]): The input sequence.
        target (str): The target string to find.

    Returns:
        int: The index of the last occurrence of target in seq.

    Examples:
        >>> get_last_index(['A', 'B', 'A', 'A', 'C', 'D'], 'A')
        3
    """
    return len(seq) - (seq[::-1].index(target) + 1)


def check_readmission_label(row: pd.Series) -> int:
    """Check if the label indicates readmission within one month.

    Args:
        row (pd.Series): The input row.

    Returns:
        bool: True if readmission label is present, False otherwise.
    """
    last_vs_index = row['last_VS_index']
    return int(row['event_tokens_2048'][last_vs_index - 1] in ('[W_0]', '[W_1]', '[W_2]', '[W_3]', '[M_1]'))


def remove_last_visit(row: pd.Series) -> pd.Series:
    """ Remove the event tokens of last visit in the row

    Args:
        row (pd.Series): The input row.

    Returns:
        pd.Series: The preprocessed row.
    """
    last_vs_index = row['last_VS_index']
    return row['event_tokens_2048'][:last_vs_index - 1]


def get_length_of_stay(row: pd.Series) -> pd.Series:
    """ Determine the length of a given visit. 
    
    Args:
        row (pd.Series): The input row.

    Returns:
        pd.Series: The preprocessed row.    
    """
    admission_time = row['last_VS_index'] + 1
    discharge_time = row['last_VE_index'] - 1
    return (discharge_time - admission_time) / 24


def truncate_visit_after_threshold(row: pd.Series, threshold: int = 24) -> pd.Series:
    """ Remove the event tokens of last visit that occur after threshold hours.

    Args:
        row (pd.Series): The input row.
        threshold (int): The cut of threshold.

    Returns:
        pd.Series: The preprocessed row.
    """
    last_vs_index = row['last_VS_index']
    last_ve_index = row['last_VE_index']

    for i in range(last_vs_index+1, last_ve_index):
        if row['elapsed_tokens_2048'][i] > threshold:
            return row['event_tokens_2048'][:i]
    
    return row['event_tokens_2048']

In [18]:
def process_length_of_stay_dataset(dataset: pd.DataFrame, threshold: int = 7) -> pd.DataFrame:
    """Process the length of stay dataset to extract required features.

    Args:
        dataset (pd.DataFrame): The input dataset.
        threshold (int): The threshold length of stay.

    Returns:
        pd.DataFrame: The processed dataset.
    """
    dataset['last_VS_index'] = dataset['event_tokens_2048'].transform(lambda seq: get_last_index(list(seq), '[VS]'))
    dataset['last_VE_index'] = dataset['event_tokens_2048'].transform(lambda seq: get_last_index(list(seq), '[VE]'))
    dataset['length_of_stay'] = dataset.apply(get_length_of_stay, axis=1)

    dataset = filter_by_length_of_stay(dataset, threshold=1)
    dataset['label_length_of_stay_1week'] = dataset['length_of_stay'] >= threshold
    dataset['event_tokens_2048'] = dataset.apply(lambda row: truncate_visit_after_threshold(row, threshold=24), axis=1)

    dataset['token_length'] = dataset['event_tokens_2048'].apply(len)    
    dataset = dataset.apply(truncate_and_pad, axis=1)
    dataset['event_tokens_2048'] = dataset['event_tokens_2048'].transform(lambda token_list: ' '.join(token_list))

    return dataset

# Process the dataset for hospital readmission in one month task
dataset_2048_los = process_length_of_stay_dataset(dataset_2048.copy(), threshold=7)
dataset_2048_los

Unnamed: 0,patient_id,num_visits,deceased,death_after_start,death_after_end,length,token_length,event_tokens_2048,type_tokens_2048,age_tokens_2048,time_tokens_2048,visit_tokens_2048,position_tokens_2048,elapsed_tokens_2048,common_conditions,rare_conditions,last_VS_index,last_VE_index,length_of_stay,label_length_of_stay_1week
0,35581927-9c95-5ae9-af76-7d74870a349c,1,0,,,50,40,[CLS] [VS] 00006473900 00904516561 51079000220...,"[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, ...","[0, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85...","[0, 5902, 5902, 5902, 5902, 5902, 5902, 5902, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 1.97, 2.02, 2.02, 2.02, 2.02, 2.0...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",1,52,2.041667,False
1,f5bba8dd-25c0-5336-8d3d-37424c185026,2,0,,,148,81,[CLS] [VS] 52135_2 52075_2 52074_2 52073_3 520...,"[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83...","[0, 6594, 6594, 6594, 6594, 6594, 6594, 6594, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",68,154,3.500000,False
2,f4938f91-cadb-5133-8541-a52fb0916cea,2,0,,,78,86,[CLS] [VS] 0RB30ZZ 0RG10A0 00071101441 0090419...,"[1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44...","[0, 8150, 8150, 8150, 8150, 8150, 8150, 8150, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 1.08, 1.08, 13.89, 13.8...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",49,84,1.375000,False
3,6fe2371b-a6f0-5436-aade-7795005b0c66,2,0,,,86,91,[CLS] [VS] 63739057310 49281041688 00597026010...,"[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72...","[0, 6093, 6093, 6093, 6093, 6093, 6093, 6093, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.75, 0.75, 0.75, 0.75, 0.75, 0.7...","[1, 0, 0, 0, 0, 0, 0, 1, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",55,92,1.458333,False
4,6f7590ae-f3b9-50e5-9e41-d4bb1000887a,1,0,,,72,56,[CLS] [VS] 50813_0 52135_0 52075_3 52074_3 520...,"[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47...","[0, 6379, 6379, 6379, 6379, 6379, 6379, 6379, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 1]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",1,74,2.958333,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
143474,3f300d4e-4554-5f1f-9dff-f209a4916cbc,7,0,,,536,564,[CLS] [VS] 51484_0 51146_3 51200_1 51221_4 512...,"[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64...","[0, 6921, 6921, 6921, 6921, 6921, 6921, 6921, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...","[1, 1, 1, 1, 0, 0, 0, 0, 1, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",520,562,1.666667,False
143475,cf2115d7-937e-511d-b159-dd7eb3d5d420,2,0,,,166,142,[CLS] [VS] 33332001001 00781305714 10019017644...,"[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 6, 5, ...","[0, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32...","[0, 5481, 5481, 5481, 5481, 5481, 5481, 5481, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 1.16, 1.25, 1.3, 1.3, 1.31, 1.53,...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",109,172,2.541667,False
143476,31338a39-28f9-54a5-a810-2d05fbaa5166,3,0,,,283,221,[CLS] [VS] 00338011704 00409128331 63323026201...,"[1, 2, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50...","[0, 4997, 4997, 4997, 4997, 4997, 4997, 4997, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 1.48, 1.49, 1.52, 1.52, 1.6, 1.6,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",192,293,4.125000,False
143477,0989415d-394c-5f42-8dac-75dc7306a23c,2,0,,,470,140,[CLS] [VS] 49281041550 51079043620 51079088120...,"[1, 2, 6, 6, 6, 6, 6, 6, 6, 3, 8, 4, 2, 7, 6, ...","[0, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 0,...","[0, 5309, 5309, 5309, 5309, 5309, 5309, 5309, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 1, 1, 1, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, ...","[-2.0, -1.0, 1.3, 1.78, 1.78, 1.78, 1.84, 1.94...","[0, 0, 0, 0, 1, 0, 0, 0, 1, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",12,476,19.250000,True


In [None]:
def process_condition_dataset(dataset: pd.DataFrame) -> pd.DataFrame:
    """Process the condition dataset to extract required features.

    Args:
        dataset (pd.DataFrame): The input condition dataset.

    Returns:
        pd.DataFrame: The processed condition dataset.
    """
    dataset['all_conditions'] = dataset.apply(
        lambda row: np.concatenate(
            [row['common_conditions'], row['rare_conditions']], dtype=np.float64), axis=1
    )
    
    dataset = dataset.apply(truncate_and_pad, axis=1)
    dataset['event_tokens_2048'] = dataset['event_tokens_2048'].transform(lambda token_list: ' '.join(token_list))

    return dataset

# Process the dataset for conditions including rare and common
dataset_2048_condition = process_condition_dataset(dataset_2048.copy())

In [None]:
def process_mortality_dataset(dataset: pd.DataFrame) -> pd.DataFrame:
    """Process the mortality dataset to extract required features.

    Args:
        dataset (pd.DataFrame): The input mortality dataset.

    Returns:
        pd.DataFrame: The processed mortality dataset.
    """
    dataset['label_mortality_2weeks'] = ((dataset['death_after_start'] >= 0) & (dataset['death_after_end'] <= 15)).astype(int)
    dataset['label_mortality_1month'] = ((dataset['death_after_start'] >= 0) & (dataset['death_after_end'] <= 32)).astype(int)
    
    dataset = dataset.apply(truncate_and_pad, axis=1)
    dataset['event_tokens_2048'] = dataset['event_tokens_2048'].transform(lambda token_list: ' '.join(token_list))
    return dataset

# Process the dataset for mortality in two weeks or one month task
dataset_2048_mortality = process_mortality_dataset(dataset_2048.copy())

In [None]:
def process_readmission_dataset(dataset: pd.DataFrame) -> pd.DataFrame:
    """Process the readmission dataset to extract required features.

    Args:
        dataset (pd.DataFrame): The input dataset.

    Returns:
        pd.DataFrame: The processed dataset.
    """
    dataset['last_VS_index'] = dataset['event_tokens_2048'].transform(lambda seq: get_last_index(list(seq), '[VS]'))
    dataset['label_readmission_1month'] = dataset.apply(check_readmission_label, axis=1)
    dataset['event_tokens_2048'] = dataset.apply(remove_last_visit, axis=1)
    dataset['num_visits'] -= 1
    dataset['token_length'] = dataset['event_tokens_2048'].apply(len)
    
    dataset = dataset.apply(truncate_and_pad, axis=1)
    dataset['event_tokens_2048'] = dataset['event_tokens_2048'].transform(lambda token_list: ' '.join(token_list))
    return dataset

# Process the dataset for hospital readmission in one month task
dataset_2048_readmission = filter_by_num_visit(dataset_2048.copy(), minimum_num_visits=2)
dataset_2048_readmission = process_readmission_dataset(dataset_2048_readmission)

In [28]:
def stratified_train_test_split(dataset: pd.DataFrame, target: str, test_size: float, return_test: Optional[bool] = False):
    """
    Split the given dataset into training and testing sets using iterative stratification on given multi-label target.
    """
    # Convert all_conditions into a format suitable for multi-label stratification
    Y = np.array(dataset[target].values.tolist())
    X = dataset['patient_id'].to_numpy().reshape(-1, 1)
    is_single_label = type(dataset.iloc[0][target]) == np.int64

    # Perform stratified split
    if is_single_label:
        X_train, X_test, y_train, y_test = train_test_split(X, Y, stratify=Y, test_size=test_size, random_state=SEED)

    else:
        X_train, y_train, X_test, y_test = iterative_train_test_split(X, Y, test_size=test_size)
    
    X_train = X_train.flatten().tolist()
    X_test = X_test.flatten().tolist()

    if return_test:
        return X_test
    else:
        return X_train, X_test


def sample_balanced_subset(dataset: pd.DataFrame, target: str, sample_size: int):
    """
    Sample a subset of dataset with balanced target labels.
    """
    # Sampling positive and negative patients
    pos_patients = dataset[dataset[target] == True].sample(n=sample_size // 2, random_state=SEED)
    neg_patients = dataset[dataset[target] == False].sample(n=sample_size // 2, random_state=SEED)

    # Combining and shuffling patient IDs
    sample_patients = pos_patients['patient_id'].tolist() + neg_patients['patient_id'].tolist()
    random.shuffle(sample_patients)

    return sample_patients


def get_pretrain_test_split(dataset: pd.DataFrame, stratify_target: Optional[str] = None, test_size: float = 0.15):
    """ Split dataset into pretrain and test set. Stratify on a given target column if needed. """

    if stratify_target:
        pretrain_ids, test_ids = stratified_train_test_split(dataset, target=stratify_target, test_size=test_size)
    
    else:
        test_patients = dataset.sample(n=test_size, random_state=SEED)
        test_ids = test_patients['patient_id'].tolist()
        pretrain_ids = dataset[~dataset['patient_id'].isin(test_patients)]['patient_id'].tolist()
    
    random.shuffle(pretrain_ids)

    return pretrain_ids, test_ids

In [29]:
# Split data
patient_ids_dict = {'pretrain': [], 'finetune': {'few_shot': {}, 'kfold': {}}, 'test': []}

# Get train-test split
# pretrain_ids, test_ids = get_pretrain_test_split(dataset_2048_readmission, stratify_target='label_readmission_1month', test_size=0.2)
# pretrain_ids, test_ids = get_pretrain_test_split(process_condition_dataset, stratify_target='all_conditions', test_size=0.15)
# patient_ids_dict['pretrain'] = pretrain_ids
# patient_ids_dict['test'] = test_ids

pid = pickle.load(open('patient_id_dict/dataset_2048_condition.pkl', 'rb'))
patient_ids_dict['pretrain'] = pid['pretrain']
patient_ids_dict['test'] = pid['test']

In [31]:
class config:

    task_splits = {

        # 'mortality': {
        #     'dataset': dataset_2048_mortality,
        #     'label_col': 'label_mortality_1month',
        #     'finetune_size': [250, 500, 1000, 5000, 20000],
        #     'save_path': 'patient_id_dict/dataset_2048_mortality.pkl',
        #     'split_mode': 'single_label_balanced'
        # },

        # 'readmission': {
        #     'dataset': dataset_2048_readmission,
        #     'label_col': 'label_readmission_1month',
        #     'finetune_size': [250, 1000, 5000, 20000, 60000],
        #     'save_path': 'patient_id_dict/dataset_2048_readmission.pkl',
        #     'split_mode': 'single_label_stratified'
        # },

        'length_of_stay': {
            'dataset': dataset_2048_los,
            'label_col': 'label_length_of_stay_1week',
            'finetune_size': [250, 1000, 5000, 20000, 50000],
            'save_path': 'patient_id_dict/dataset_2048_los.pkl',
            'split_mode': 'single_label_balanced'
        },

        # 'condition': {
        #     'dataset': dataset_2048_condition,
        #     'label_col': 'all_conditions',
        #     'finetune_size': [50000],
        #     'save_path': 'patient_id_dict/dataset_2048_condition.pkl',
        #     'split_mode': 'multi_label_stratified'
        # }
    }

    all_tasks = list(task_splits.keys())

In [32]:
def get_finetune_split(
        config: config,
        patient_ids_dict: Dict[str, Any],
) -> Dict[str, Dict[str, List[str]]]:
    """
    Splits the dataset into training and cross-finetuneation sets using k-fold cross-finetuneation
    while ensuring balanced label distribution in each fold. Saves the resulting dictionary to disk.
    """
    # Extract task-specific configuration
    task_config = config.task_splits[task]
    dataset = task_config['dataset']
    label_col = task_config['label_col']
    finetune_sizes = task_config['finetune_size']
    save_path = task_config['save_path']
    split_mode = task_config['split_mode']

    # Get pretrain dataset
    pretrain_ids = patient_ids_dict['pretrain']
    dataset = dataset[dataset['patient_id'].isin(pretrain_ids)]

    # Few-shot finetune patient ids
    for finetune_num in finetune_sizes:

        if split_mode == 'single_label_balanced':
            finetune_ids = sample_balanced_subset(dataset, target=label_col, sample_size=finetune_num)
        
        elif split_mode == 'single_label_stratified':
            finetune_ids = stratified_train_test_split(dataset, target=label_col, test_size=finetune_num / len(dataset), return_test=True)
        
        elif split_mode == 'multi_label_stratified':
            finetune_ids = stratified_train_test_split(dataset, target=label_col, test_size=finetune_num / len(dataset), return_test=True)

        patient_ids_dict['finetune']['few_shot'][f'{finetune_num}'] = finetune_ids
    
    # Save the dictionary to disk
    with open(save_path, 'wb') as f:
        pickle.dump(patient_ids_dict, f)
        print(f'File saved to disk: {save_path}')

    return patient_ids_dict


for task in config.all_tasks:
    patient_ids_dict = get_finetune_split(
        config=config,
        patient_ids_dict=patient_ids_dict
    )

File saved to disk: patient_id_dict/dataset_2048_los.pkl


In [38]:
# dataset_2048_mortality.to_parquet('patient_sequences/patient_sequences_2048_mortality.parquet')
# dataset_2048_readmission.to_parquet('patient_sequences/patient_sequences_2048_readmission.parquet')
dataset_2048_los.to_parquet('patient_sequences/patient_sequences_2048_los.parquet')
# dataset_2048_condition.to_parquet('patient_sequences/patient_sequences_2048_condition.parquet')

In [None]:
dataset_2048_condition = pd.read_parquet('patient_sequences/patient_sequences_2048_condition.parquet')
pid = pickle.load(open('patient_id_dict/dataset_2048_condition.pkl', 'rb'))
condition_finetune = dataset_2048_condition.loc[dataset_2048_condition['patient_id'].isin(pid['finetune']['few_shot']['50000'])]
condition_finetune

In [None]:
freq = np.array(condition_finetune['all_conditions'].tolist()).sum(axis=0)
weights = np.clip(0, 50, sum(freq) / freq)
np.max(np.sqrt(freq)) / np.sqrt(freq)

In [None]:
# sorted(patient_ids_dict['pretrain']) == sorted(pickle.load(open('new_data/patient_id_dict/sample_pretrain_test_patient_ids_with_conditions.pkl', 'rb'))['pretrain'])

In [None]:
# merged_df = pd.merge(dataset_2048_mortality, dataset_2048_readmission, how='outer', on='patient_id')
# final_merged_df = pd.merge(merged_df, dataset_2048_condition, how='outer', on='patient_id')
# final_merged_df

In [None]:
# Performing stratified k-fold split
    # skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=SEED)

    # for i, (train_index, cv_index) in enumerate(skf.split(dataset, dataset[label_col])):

    #     dataset_cv = dataset.iloc[cv_index]
    #     dataset_finetune = dataset.iloc[train_index]

    #     # Separate positive and negative labeled patients
    #     pos_patients = dataset_cv[dataset_cv[label_col] == True]['patient_id'].tolist()
    #     neg_patients = dataset_cv[dataset_cv[label_col] == False]['patient_id'].tolist()

    #     # Calculate the number of positive and negative patients needed for balanced CV set
    #     num_pos_needed = cv_size // 2
    #     num_neg_needed = cv_size // 2

    #     # Select positive and negative patients for CV set ensuring balanced distribution
    #     cv_patients = pos_patients[:num_pos_needed] + neg_patients[:num_neg_needed]
    #     remaining_finetune_patients = pos_patients[num_pos_needed:] + neg_patients[num_neg_needed:]

    #     # Extract patient IDs for training set
    #     finetune_patients = dataset_finetune['patient_id'].tolist()
    #     finetune_patients += remaining_finetune_patients

    #     # Shuffle each list of patients
    #     random.shuffle(cv_patients)
    #     random.shuffle(finetune_patients)

    #     patient_ids_dict['finetune']['kfold'][f'group{i+1}'] = {'finetune': finetune_patients, 'cv': cv_patients}

In [None]:
# Assuming dataset.event_tokens is your DataFrame column
dataset.event_tokens.transform(len).plot(kind='hist', bins=100)
plt.xlim(1000, 8000)  # Limit x-axis to 5000
plt.ylim(0, 6000)
plt.xlabel('Length of Event Tokens')
plt.ylabel('Frequency')
plt.title('Histogram of Event Tokens Length')
plt.show()

In [None]:
# len(patient_ids_dict['group3']['cv'])

# dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids_dict['group1']['cv'])]['label_mortality_1month']

# s = set()
# for i in range(1, 6):
#     s = s.union(set(patient_ids_dict[f'group{i}']['cv']))
#
# len(s)

In [None]:
##### DEAD ZONE | DO NOT ENTER #####

# patient_ids = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_1month.pkl'), 'rb'))
# patient_ids['finetune']['few_shot'].keys()

# patient_ids2 = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_2weeks.pkl'), 'rb'))['pretrain']
#
# patient_ids1.sort()
# patient_ids2.sort()
#
# patient_ids1 == patient_ids2
# # dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids['pretrain'])]

In [None]:
# dataset_2048_readmission = dataset_2048.loc[dataset_2048['num_visits'] > 1]
# dataset_2048_readmission.reset_index(drop=True, inplace=True)
#
# dataset_2048_readmission['last_VS_index'] = dataset_2048_readmission['event_tokens_2048'].transform(lambda seq: get_last_index(list(seq), '[VS]'))
#
# dataset_2048_readmission['label_readmission_1month'] = dataset_2048_readmission.apply(
#     lambda row: row['event_tokens_2048'][row['last_VS_index'] - 1] in ('[W_0]', '[W_1]', '[W_2]', '[W_3]', '[M_1]'), axis=1
# )
# dataset_2048_readmission['event_tokens_2048'] = dataset_2048_readmission.apply(
#     lambda row: row['event_tokens_2048'][:row['last_VS_index'] - 1], axis=1
# )
# dataset_2048_readmission.drop(['deceased', 'death_after_start', 'death_after_end', 'length'], axis=1, inplace=True)
# dataset_2048_readmission['num_visits'] -= 1
# dataset_2048_readmission['token_length'] = dataset_2048_readmission['event_tokens_2048'].apply(len)
# dataset_2048_readmission = dataset_2048_readmission.apply(lambda row: truncate_and_pad(row), axis=1)
# dataset_2048_readmission['event_tokens_2048'] = dataset_2048_readmission['event_tokens_2048'].transform(
#     lambda token_list: ' '.join(token_list)
# )
#
# dataset_2048_readmission