In [2]:
import random
import pickle

from os.path import join
from typing import Dict, Any, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold, StratifiedKFold

# Parquet file holding the patient sequences that the next cell loads
DATASET = 'patient_sequences/patient_sequences_2048.parquet'
# Length that truncate_and_pad pads the auxiliary token columns to
MAX_LEN = 2048

# Seed Python's `random` module, whose shuffle() is used when building the patient-ID splits
random.seed(23)

In [3]:
# Load the complete patient-sequence dataset from the parquet file at DATASET
dataset_2048 = pd.read_parquet(DATASET)

# Disabled: drop the token columns that do not carry the `_2048` suffix
# dataset_2048.drop(
#     ['event_tokens', 'type_tokens', 'age_tokens', 'time_tokens', 'visit_tokens', 'position_tokens'],
#     axis=1,
#     inplace=True
# )

# Display the loaded frame as the cell output
dataset_2048

Unnamed: 0,patient_id,num_visits,deceased,death_after_start,death_after_end,length,token_length,event_tokens_2048,type_tokens_2048,age_tokens_2048,time_tokens_2048,visit_tokens_2048,position_tokens_2048
1,9b62c9f4-3fdc-5020-82b5-ae5b8292445a,3,0,,,43,55,"[[CLS], [VS], 7569, 66689036430, 00904224461, ...","[1, 2, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 3, 8, ...","[0, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28...","[0, 5963, 5963, 5963, 5963, 5963, 5963, 5963, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,2ca522eb-dd89-5f79-8155-9599ea46b0b2,1,1,244.0,242.0,51,55,"[[CLS], [VS], 00904629261, 00904642281, 009046...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86...","[0, 8016, 8016, 8016, 8016, 8016, 8016, 8016, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
3,54ee4964-056b-5c38-a607-b95e63176fc3,18,1,13.0,0.0,1810,1882,"[[CLS], [VS], 63323026201, 00603385521, 005970...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71...","[0, 5520, 5520, 5520, 5520, 5520, 5520, 5520, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,02adf8a6-8bc0-55d3-81ae-4d8582094896,8,1,20.0,11.0,640,672,"[[CLS], [VS], 51079045420, 00006494300, 177140...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, ...","[0, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65...","[0, 8002, 8002, 8002, 8002, 8002, 8002, 8002, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
5,744fe3c4-9b03-55ae-ac9f-6bc4e967cde7,2,0,,,80,88,"[[CLS], [VS], 7813, 7813, 7902, 7902, 9604, 00...","[1, 2, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29...","[0, 7582, 7582, 7582, 7582, 7582, 7582, 7582, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
173666,cf2115d7-937e-511d-b159-dd7eb3d5d420,2,0,,,166,174,"[[CLS], [VS], 66591018442, 63323026201, 001350...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32...","[0, 5481, 5481, 5481, 5481, 5481, 5481, 5481, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
173667,31338a39-28f9-54a5-a810-2d05fbaa5166,3,0,,,283,295,"[[CLS], [VS], 5014, 5123, 00338011704, 6332302...","[1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50...","[0, 4997, 4997, 4997, 4997, 4997, 4997, 4997, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
173668,0989415d-394c-5f42-8dac-75dc7306a23c,2,0,,,470,478,"[[CLS], [VS], 51079043620, 51079088120, 492810...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 3, 8, 4, 2, 7, 7, ...","[0, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 0,...","[0, 5309, 5309, 5309, 5309, 5309, 5309, 5309, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 1, 1, 1, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, ..."
173669,26fb8fef-b976-5c55-859d-cc190261f94b,2,0,,,93,101,"[[CLS], [VS], 0SRD0J9, 60505251903, 0090422446...","[1, 2, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61...","[0, 5085, 5085, 5085, 5085, 5085, 5085, 5085, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


In [4]:
def process_mortality_dataset(dataset: pd.DataFrame) -> pd.DataFrame:
    """Add the mortality labels and flatten the event tokens into one string per patient.

    The frame is modified in place (new columns are written onto ``dataset``) and the
    same object is returned.

    Args:
        dataset (pd.DataFrame): Frame with the columns 'event_tokens_2048',
            'death_after_start' and 'death_after_end'.

    Returns:
        pd.DataFrame: The same frame, with 'event_tokens_2048' space-joined and the
            integer columns 'label_mortality_2weeks' and 'label_mortality_1month' added.
    """
    # Upper bound applied to 'death_after_end' for each label column
    label_thresholds = {'label_mortality_2weeks': 15, 'label_mortality_1month': 32}

    dataset['event_tokens_2048'] = dataset['event_tokens_2048'].map(lambda tokens: ' '.join(tokens))

    # Missing values compare as False, so rows without a recorded death get label 0
    died_after_start = dataset['death_after_start'].ge(0)
    for label_name, max_after_end in label_thresholds.items():
        within_window = dataset['death_after_end'].le(max_after_end)
        dataset[label_name] = (died_after_start & within_window).astype(int)

    return dataset

# Process the dataset for the mortality-in-two-weeks and mortality-in-one-month tasks.
# A copy is passed so that dataset_2048 itself is left unmodified.
dataset_2048_mortality = process_mortality_dataset(dataset_2048.copy())

In [5]:
def process_readmission_dataset(dataset: pd.DataFrame) -> pd.DataFrame:
    """Build the one-month readmission dataset from patient sequences.

    Steps, in order:
        1. Drop the mortality-related columns and 'length'.
        2. Locate the last '[VS]' token of every sequence ('last_VS_index').
        3. Derive 'label_readmission_1month' from the token preceding that '[VS]'.
        4. Cut the last visit off 'event_tokens_2048', decrement 'num_visits' and
           recompute 'token_length'.
        5. Truncate/pad the remaining token columns via ``truncate_and_pad``.
        6. Join the event tokens into one space-separated string.

    The input frame is not modified; all work happens on a new frame.

    Args:
        dataset (pd.DataFrame): The input dataset. Every sequence in
            'event_tokens_2048' must contain a '[VS]' token.

    Returns:
        pd.DataFrame: The processed dataset.
    """
    # drop() without inplace returns a new frame, so the caller's DataFrame is not
    # left half-processed (the original mutated it and then returned a different object).
    processed = dataset.drop(columns=['deceased', 'death_after_start', 'death_after_end', 'length'])

    processed['last_VS_index'] = processed['event_tokens_2048'].transform(
        lambda seq: get_last_index(list(seq), '[VS]')
    )
    # The label must be read before the last visit is removed from the sequence
    processed['label_readmission_1month'] = processed.apply(check_readmission_label, axis=1)
    processed['event_tokens_2048'] = processed.apply(remove_last_visit, axis=1)

    # Non-inplace arithmetic: never writes through to data shared with the input frame
    processed['num_visits'] = processed['num_visits'] - 1
    processed['token_length'] = processed['event_tokens_2048'].apply(len)

    # Must run while 'event_tokens_2048' is still a token sequence: its length drives the padding
    processed = processed.apply(truncate_and_pad, axis=1)
    processed['event_tokens_2048'] = processed['event_tokens_2048'].transform(
        lambda token_list: ' '.join(token_list)
    )
    return processed

def filter_by_num_visit(dataset: pd.DataFrame, minimum_num_visits: int) -> pd.DataFrame:
    """Keep only the patients that have at least ``minimum_num_visits`` visits.

    Args:
        dataset (pd.DataFrame): The input dataset; must contain a 'num_visits' column.
        minimum_num_visits (int): Inclusive lower bound on 'num_visits'.

    Returns:
        pd.DataFrame: A new frame with the matching rows, re-indexed from 0.
            The input frame is left untouched.
    """
    has_enough_visits = dataset['num_visits'] >= minimum_num_visits
    # Boolean selection already yields new data; reset_index then renumbers the rows
    return dataset[has_enough_visits].reset_index(drop=True)

def get_last_index(seq: List[str], target: str) -> int:
    """Find the position of the final occurrence of ``target`` in ``seq``.

    Args:
        seq (List[str]): The sequence to search.
        target (str): The element to look for.

    Returns:
        int: Zero-based index of the last element equal to ``target``.

    Raises:
        ValueError: If ``target`` does not occur in ``seq``.

    Examples:
        >>> get_last_index(['A', 'B', 'A', 'A', 'C', 'D'], 'A')
        3
    """
    # The first hit in the mirrored list is the last hit in the original one
    mirrored = list(reversed(seq))
    distance_from_end = mirrored.index(target)
    return len(seq) - 1 - distance_from_end

def truncate_and_pad(row: pd.Series) -> Any:
    """Align the auxiliary token columns of ``row`` with its event tokens.

    Each auxiliary column is cut to the length of 'event_tokens_2048' and then
    right-padded with zeros up to ``MAX_LEN``. 'event_tokens_2048' itself is not changed.

    Args:
        row (pd.Series): The input row; it is updated in place.

    Returns:
        Any: The same row, with the auxiliary columns truncated and padded.

    Note:
        The row is expected to contain these columns:
        - 'event_tokens_2048'
        - 'type_tokens_2048'
        - 'age_tokens_2048'
        - 'time_tokens_2048'
        - 'visit_tokens_2048'
        - 'position_tokens_2048'
    """
    auxiliary_columns = (
        'type_tokens_2048',
        'age_tokens_2048',
        'time_tokens_2048',
        'visit_tokens_2048',
        'position_tokens_2048',
    )
    kept_len = len(row['event_tokens_2048'])
    # Nothing is added in front; the tail is filled up to MAX_LEN
    pad_widths = (0, MAX_LEN - kept_len)

    for column in auxiliary_columns:
        row[column] = np.pad(row[column][:kept_len], pad_widths, mode='constant')
    return row

def check_readmission_label(row: pd.Series) -> int:
    """Return the one-month readmission label for a patient row.

    The label is read from the token directly before the last '[VS]' token: it is 1
    when that token is one of '[W_0]', '[W_1]', '[W_2]', '[W_3]' or '[M_1]', else 0.

    Args:
        row (pd.Series): The input row; must provide 'event_tokens_2048' and 'last_VS_index'.

    Returns:
        int: 1 if the token before the last '[VS]' is one of the listed tokens, 0 otherwise.
    """
    last_vs_index = row['last_VS_index']
    if last_vs_index < 1:
        # Nothing precedes the last '[VS]'; index -1 would silently wrap to the final token
        return 0
    return int(row['event_tokens_2048'][last_vs_index - 1] in ('[W_0]', '[W_1]', '[W_2]', '[W_3]', '[M_1]'))

def remove_last_visit(row: pd.Series) -> Any:
    """Return the row's event tokens with the last visit cut off.

    The sequence is cut one position before the last '[VS]' token, so the token
    directly preceding that '[VS]' is removed together with the visit.

    Args:
        row (pd.Series): The input row; must provide 'event_tokens_2048' and 'last_VS_index'.

    Returns:
        Any: The leading slice of 'event_tokens_2048' (same container type as the
            stored sequence); empty when the last '[VS]' sits at index 0 or 1.
    """
    # Clamp at 0: a cut of -1 would keep everything except the final token
    cut = max(row['last_VS_index'] - 1, 0)
    return row['event_tokens_2048'][:cut]

# Process the dataset for the hospital-readmission-in-one-month task:
# keep only patients with at least two visits, then derive the label and cut off the last visit.
dataset_2048_readmission = filter_by_num_visit(dataset_2048, minimum_num_visits=2)
dataset_2048_readmission = process_readmission_dataset(dataset_2048_readmission)

In [7]:
def split_dataset_train_test_finetune_datasets(
        dataset: pd.DataFrame,
        label_col: str,
        cv_size: int,
        test_size: int,
        finetune_size: List[int],
        num_splits: int,
        save_path: str
) -> Dict[str, Any]:
    """
    Split patient IDs into test, pretrain, few-shot fine-tune and k-fold fine-tune/CV groups,
    then pickle the resulting dictionary to disk.

    The test set is sampled first and removed from the pool. Everything that remains is the
    pretrain set, and the few-shot and k-fold groups are drawn from that same pool, so those
    groups overlap with each other and with the pretrain set. The caller's DataFrame is not
    modified.

    Shuffling uses the global state of Python's ``random`` module; sampling and the
    stratified split use a fixed ``random_state`` of 23.

    Args:
        dataset (pd.DataFrame): The input dataset; needs 'patient_id' and ``label_col``.
        label_col (str): The name of the column containing the binary labels.
        cv_size (int): Target number of patients in each cross-validation set
            (half positive, half negative).
        test_size (int): The number of patients in the test set; no test set if not positive.
        finetune_size (List[int]): Sizes of the balanced few-shot fine-tune sets.
        num_splits (int): The number of stratified folds to create (k value).
        save_path (str): The path to save the resulting dictionary (pickle).

    Returns:
        Dict[str, Any]: Patient IDs per group, structured as
            {'pretrain': [...],
             'test': [...],
             'finetune': {'few_shot': {'<size>': [...]},
                          'kfold': {'group<i>': {'finetune': [...], 'cv': [...]}}}}.

    Raises:
        ValueError: Raised by pandas when a requested sample is larger than the
            available number of patients (test set or either few-shot class).
    """

    # Dictionary to hold patient IDs for different sets
    patient_ids_dict = {'pretrain': [], 'finetune': {'few_shot': {}, 'kfold': {}}, 'test': []}

    # Sample test patients and remove them from the working pool
    if test_size > 0:
        test_patients = dataset.sample(n=test_size, random_state=23)
        # Rebind instead of dropping in place so the caller's DataFrame stays intact
        dataset = dataset.drop(test_patients.index)
        patient_ids_dict['test'] = test_patients['patient_id'].tolist()

    # Any remaining data is used for pretraining
    patient_ids_dict['pretrain'] = dataset['patient_id'].tolist()
    random.shuffle(patient_ids_dict['pretrain'])

    # Few-shot fine-tune sets: equal numbers of positive and negative patients
    for each_finetune_size in finetune_size:
        subset_size = each_finetune_size // 2

        # Sampling positive and negative patients
        pos_patients = dataset[dataset[label_col] == True].sample(n=subset_size, random_state=23)
        neg_patients = dataset[dataset[label_col] == False].sample(n=subset_size, random_state=23)

        # Extracting patient IDs
        pos_patients_ids = pos_patients['patient_id'].tolist()
        neg_patients_ids = neg_patients['patient_id'].tolist()

        # Combining and shuffling patient IDs
        finetune_patients = pos_patients_ids + neg_patients_ids
        random.shuffle(finetune_patients)
        patient_ids_dict['finetune']['few_shot'][f'{each_finetune_size}'] = finetune_patients

    # Performing stratified k-fold split
    skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=23)

    for i, (train_index, cv_index) in enumerate(skf.split(dataset, dataset[label_col])):

        dataset_cv = dataset.iloc[cv_index]
        dataset_finetune = dataset.iloc[train_index]

        # Separate positive and negative labeled patients of the held-out fold
        pos_patients = dataset_cv[dataset_cv[label_col] == True]['patient_id'].tolist()
        neg_patients = dataset_cv[dataset_cv[label_col] == False]['patient_id'].tolist()

        # Calculate the number of positive and negative patients needed for balanced CV set
        num_pos_needed = cv_size // 2
        num_neg_needed = cv_size // 2

        # Take the first patients of each class for the CV set; the rest of the fold goes back
        # to fine-tuning. If a class has fewer patients than needed, slicing silently yields a
        # smaller (unbalanced) CV set.
        cv_patients = pos_patients[:num_pos_needed] + neg_patients[:num_neg_needed]
        remaining_finetune_patients = pos_patients[num_pos_needed:] + neg_patients[num_neg_needed:]

        # Fine-tune set: the other folds plus the unused part of the held-out fold
        finetune_patients = dataset_finetune['patient_id'].tolist()
        finetune_patients += remaining_finetune_patients

        # Shuffle each list of patients
        random.shuffle(cv_patients)
        random.shuffle(finetune_patients)

        patient_ids_dict['finetune']['kfold'][f'group{i+1}'] = {'finetune': finetune_patients, 'cv': cv_patients}

    # Save the dictionary to disk
    with open(save_path, 'wb') as f:
        pickle.dump(patient_ids_dict, f)

    return patient_ids_dict


# Build and pickle the patient-ID splits for the one-month mortality task.
# A copy is passed so that dataset_2048_mortality itself cannot be affected by the split.
patient_ids_dict = split_dataset_train_test_finetune_datasets(
    dataset=dataset_2048_mortality.copy(),
    label_col='label_mortality_1month',
    cv_size=4000,
    test_size=20000,    # 20000
    finetune_size=[250, 500, 1000, 5000, 20000],  # alternatives tried: [250, 500, 1000, 5000, 20000] [250, 1000, 5000, 20000, 50000]
    num_splits=5,
    save_path='patient_id_dict/dataset_2048_mortality_1month.pkl'
)

In [8]:
# Sanity check: reload the saved split and list the few-shot set sizes.
# The `with` block closes the file handle, which the bare open() call never did.
# NOTE: only unpickle files this notebook wrote itself; pickle.load can execute arbitrary code.
with open('patient_id_dict/dataset_2048_mortality_1month.pkl', 'rb') as f:
    saved_patient_ids = pickle.load(f)

saved_patient_ids['finetune']['few_shot'].keys()

dict_keys(['250', '500', '1000', '5000', '20000'])

In [54]:
# The mortality export is disabled; only the readmission dataset is written to parquet here.
# dataset_2048_mortality.to_parquet('patient_sequences_2048_mortality.parquet')
dataset_2048_readmission.to_parquet('patient_sequences/patient_sequences_2048_readmission.parquet')

In [None]:
# Histogram of event-token sequence lengths.
# NOTE(review): `dataset` is not defined in any visible cell, and the loaded frame's token
# columns carry a `_2048` suffix — presumably this targets a different, untruncated frame
# with an `event_tokens` column; confirm before running on a fresh kernel.
dataset.event_tokens.transform(len).plot(kind='hist', bins=100)
plt.xlim(1000, 8000)  # Show only lengths between 1000 and 8000
plt.ylim(0, 6000)  # Cap the frequency axis at 6000
plt.xlabel('Length of Event Tokens')
plt.ylabel('Frequency')
plt.title('Histogram of Event Tokens Length')
plt.show()

In [None]:
# len(patient_ids_dict['group3']['cv'])

# dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids_dict['group1']['cv'])]['label_mortality_1month']

# s = set()
# for i in range(1, 6):
#     s = s.union(set(patient_ids_dict[f'group{i}']['cv']))
#
# len(s)

In [None]:
##### DEAD ZONE | DO NOT ENTER #####

# patient_ids = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_1month.pkl'), 'rb'))
# patient_ids['finetune']['few_shot'].keys()

# patient_ids2 = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_2weeks.pkl'), 'rb'))['pretrain']
#
# patient_ids1.sort()
# patient_ids2.sort()
#
# patient_ids1 == patient_ids2
# # dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids['pretrain'])]

In [None]:
# dataset_2048_readmission = dataset_2048.loc[dataset_2048['num_visits'] > 1]
# dataset_2048_readmission.reset_index(drop=True, inplace=True)
#
# dataset_2048_readmission['last_VS_index'] = dataset_2048_readmission['event_tokens_2048'].transform(lambda seq: get_last_index(list(seq), '[VS]'))
#
# dataset_2048_readmission['label_readmission_1month'] = dataset_2048_readmission.apply(
#     lambda row: row['event_tokens_2048'][row['last_VS_index'] - 1] in ('[W_0]', '[W_1]', '[W_2]', '[W_3]', '[M_1]'), axis=1
# )
# dataset_2048_readmission['event_tokens_2048'] = dataset_2048_readmission.apply(
#     lambda row: row['event_tokens_2048'][:row['last_VS_index'] - 1], axis=1
# )
# dataset_2048_readmission.drop(['deceased', 'death_after_start', 'death_after_end', 'length'], axis=1, inplace=True)
# dataset_2048_readmission['num_visits'] -= 1
# dataset_2048_readmission['token_length'] = dataset_2048_readmission['event_tokens_2048'].apply(len)
# dataset_2048_readmission = dataset_2048_readmission.apply(lambda row: truncate_and_pad(row), axis=1)
# dataset_2048_readmission['event_tokens_2048'] = dataset_2048_readmission['event_tokens_2048'].transform(
#     lambda token_list: ' '.join(token_list)
# )
#
# dataset_2048_readmission