In [None]:
import os
import random
import pickle

from os.path import join
from typing import Dict, Any, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
from skmultilearn.model_selection import iterative_train_test_split
from sklearn.preprocessing import MultiLabelBinarizer

# Root directory of the data; later cells write their pickles and parquet files relative to it.
# NOTE(review): absolute, user-specific path — consider making it configurable.
DATA_ROOT = '/h/afallah/odyssey/odyssey/data/bigbird_data'
# Parquet file with the patient sequences that the loading cell reads.
DATASET = f'{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet'
# Length that truncate_and_pad pads the per-token feature arrays to.
MAX_LEN = 2048

# Make DATA_ROOT the working directory so the relative output paths used below resolve against it.
os.chdir(DATA_ROOT)
# Seed the stdlib and NumPy global random generators.
random.seed(23)
np.random.seed(23)

In [93]:
# Load the complete patient-sequence dataset from the parquet file named by DATASET.
dataset_2048 = pd.read_parquet(DATASET)

# Disabled column drop.
# NOTE(review): these names lack the `_2048` suffix carried by the token columns in the
# displayed frame — confirm the names before re-enabling.
# dataset_2048.drop(
#     ['event_tokens', 'type_tokens', 'age_tokens', 'time_tokens', 'visit_tokens', 'position_tokens'],
#     axis=1,
#     inplace=True
# )

# Preview the first rows.
dataset_2048.head()

Unnamed: 0,patient_id,num_visits,deceased,death_after_start,death_after_end,length,token_length,event_tokens_2048,type_tokens_2048,age_tokens_2048,time_tokens_2048,visit_tokens_2048,position_tokens_2048,elapsed_tokens_2048,common_conditions,rare_conditions
0,35581927-9c95-5ae9-af76-7d74870a349c,1,0,,,50,54,"[[CLS], [VS], 00006473900, 00904516561, 510790...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, ...","[0, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85...","[0, 5902, 5902, 5902, 5902, 5902, 5902, 5902, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 1.97, 2.02, 2.02, 2.02, 2.02, 2.0...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1,f5bba8dd-25c0-5336-8d3d-37424c185026,2,0,,,148,156,"[[CLS], [VS], 52135_2, 52075_2, 52074_2, 52073...","[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83...","[0, 6594, 6594, 6594, 6594, 6594, 6594, 6594, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
2,f4938f91-cadb-5133-8541-a52fb0916cea,2,0,,,78,86,"[[CLS], [VS], 0RB30ZZ, 0RG10A0, 00071101441, 0...","[1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44...","[0, 8150, 8150, 8150, 8150, 8150, 8150, 8150, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 1.08, 1.08, 13.89, 13.8...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
3,6fe2371b-a6f0-5436-aade-7795005b0c66,2,0,,,86,94,"[[CLS], [VS], 63739057310, 49281041688, 005970...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72...","[0, 6093, 6093, 6093, 6093, 6093, 6093, 6093, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.75, 0.75, 0.75, 0.75, 0.75, 0.7...","[1, 0, 0, 0, 0, 0, 0, 1, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
4,6f7590ae-f3b9-50e5-9e41-d4bb1000887a,1,0,,,72,76,"[[CLS], [VS], 50813_0, 52135_0, 52075_3, 52074...","[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47...","[0, 6379, 6379, 6379, 6379, 6379, 6379, 6379, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 1]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"


In [94]:
def process_condition_dataset(dataset: pd.DataFrame) -> pd.DataFrame:
    """Prepare a patient-sequence frame for the condition-prediction task.

    Adds an ``all_conditions`` column that holds, for every row, the
    ``common_conditions`` labels followed by the ``rare_conditions`` labels as
    one float64 array, and replaces ``event_tokens_2048`` with its tokens
    joined by single spaces.

    The frame passed in is modified in place and is also returned.

    Args:
        dataset (pd.DataFrame): Frame with ``common_conditions``,
            ``rare_conditions`` and ``event_tokens_2048`` columns.

    Returns:
        pd.DataFrame: The same frame object with the columns described above.
    """
    # Work on the argument rather than on a notebook-level global so the
    # function behaves the same for whichever frame the caller passes in.
    dataset['all_conditions'] = dataset.apply(
        lambda row: np.concatenate(
            [row['common_conditions'], row['rare_conditions']], dtype=np.float64
        ),
        axis=1,
    )

    # Values that are already strings are kept as they are: joining a string
    # a second time would insert a space between every character.
    dataset['event_tokens_2048'] = dataset['event_tokens_2048'].map(
        lambda tokens: tokens if isinstance(tokens, str) else ' '.join(tokens)
    )

    return dataset

# Process the dataset for conditions including rare and common.
# NOTE(review): this rebinds `dataset_2048` to the processed frame, so any later cell that reads
# `dataset_2048['event_tokens_2048']` sees space-joined strings instead of token sequences.
dataset_2048 = process_condition_dataset(dataset_2048)
# Persist the processed frame, plus its first 100 rows as a small sample file.
dataset_2048.to_parquet(f'{DATA_ROOT}/patient_sequences/patient_sequences_2048_with_conditions.parquet')
dataset_2048.iloc[:100].to_parquet(f'{DATA_ROOT}/patient_sequences/patient_sequences_2048_with_conditions_100patients.parquet')

In [None]:
def get_stratified_data_split(dataset: pd.DataFrame, target: str, test_size: float = 0.15):
    """Split patient IDs into train and test lists via iterative multi-label stratification.

    Args:
        dataset (pd.DataFrame): Frame with a ``patient_id`` column and the
            multi-label ``target`` column.
        target (str): Name of the column whose per-row label vectors drive
            the stratification.
        test_size (float): Value forwarded to ``iterative_train_test_split``
            as its ``test_size``.

    Returns:
        Two lists of patient IDs, ``(train_ids, test_ids)``.
    """
    # Stack the per-row label vectors into one array (a 2-D label matrix when
    # every row holds a vector of the same length).
    labels = np.array(dataset[target].values.tolist())

    # The splitter works on a feature matrix; a single column of patient IDs
    # keeps each label row tied to its patient.
    ids = dataset['patient_id'].to_numpy().reshape(-1, 1)

    # Only the ID partitions are needed; the label partitions are discarded.
    ids_train, _, ids_test, _ = iterative_train_test_split(ids, labels, test_size=test_size)

    return ids_train.flatten().tolist(), ids_test.flatten().tolist()

# Stratify on the combined condition labels, passing test_size=0.15 to the splitter.
pretrain_ids, test_ids = get_stratified_data_split(dataset_2048, 'all_conditions', test_size=0.15)

# Same key layout as the dictionary built by split_dataset_train_test_finetune_datasets;
# the fine-tune groups are left empty here.
patient_ids_dict = {'pretrain': pretrain_ids, 'finetune': {'few_shot': {}, 'kfold': {}}, 'test': test_ids}

# Relative path: the file lands in the current working directory.
with open('sample_pretrain_test_patient_ids_with_conditions.pkl', 'wb') as f:
        pickle.dump(patient_ids_dict, f)


In [None]:
def process_mortality_dataset(dataset: pd.DataFrame) -> pd.DataFrame:
    """Add the mortality labels and flatten the event tokens to one string per row.

    Two integer (0/1) label columns are added:

    - ``label_mortality_2weeks``: 1 when ``death_after_start >= 0`` and
      ``death_after_end <= 15``.
    - ``label_mortality_1month``: 1 when ``death_after_start >= 0`` and
      ``death_after_end <= 32``.

    Rows whose death columns are missing compare as False and get label 0.
    The thresholds are presumably in days — TODO confirm against the code
    that produces ``death_after_start`` / ``death_after_end``.

    The frame passed in is modified in place and is also returned.

    Args:
        dataset (pd.DataFrame): Frame with ``event_tokens_2048``,
            ``death_after_start`` and ``death_after_end`` columns.

    Returns:
        pd.DataFrame: The same frame object with the added label columns.
    """
    # Values that are already strings (e.g. the frame was flattened by an
    # earlier cell) are kept as they are: joining a string again would insert
    # a space between every character.
    dataset['event_tokens_2048'] = dataset['event_tokens_2048'].map(
        lambda tokens: tokens if isinstance(tokens, str) else ' '.join(tokens)
    )

    # Both labels share the first condition; only the upper bound differs.
    death_after_start_ok = dataset['death_after_start'] >= 0
    for label_name, upper_bound in (
        ('label_mortality_2weeks', 15),
        ('label_mortality_1month', 32),
    ):
        within_bound = dataset['death_after_end'] <= upper_bound
        dataset[label_name] = (death_after_start_ok & within_bound).astype(int)

    return dataset

# Process the dataset for mortality in two weeks or one month task.
# A copy is passed because process_mortality_dataset assigns columns on its argument.
dataset_2048_mortality = process_mortality_dataset(dataset_2048.copy())

In [None]:
def process_readmission_dataset(dataset: pd.DataFrame) -> pd.DataFrame:
    """Turn multi-visit patient sequences into the one-month readmission task.

    For every row the last ``'[VS]'`` token is located, the readmission label
    is read from the token right before it, and the sequence is cut just
    before that token. Visit count and token length are updated, the other
    token arrays are re-aligned with ``truncate_and_pad``, and the event
    tokens are finally joined into a space-separated string.

    The input frame is modified in place up to the padding step; the frame
    that is returned is the new one produced by that step.

    Args:
        dataset (pd.DataFrame): Patient sequences to process.

    Returns:
        pd.DataFrame: The frame prepared for the readmission task.
    """
    events = 'event_tokens_2048'

    # These columns are dropped for this task.
    dataset.drop(
        columns=['deceased', 'death_after_start', 'death_after_end', 'length'],
        inplace=True,
    )

    # Find the final visit, label the row, then strip that visit.
    dataset['last_VS_index'] = dataset[events].transform(
        lambda tokens: get_last_index(list(tokens), '[VS]')
    )
    dataset['label_readmission_1month'] = dataset.apply(check_readmission_label, axis=1)
    dataset[events] = dataset.apply(remove_last_visit, axis=1)

    # Bookkeeping for the shortened sequence.
    dataset['num_visits'] -= 1
    dataset['token_length'] = dataset[events].apply(len)

    # Align the remaining token arrays, then flatten the event tokens.
    aligned = dataset.apply(truncate_and_pad, axis=1)
    aligned[events] = aligned[events].transform(lambda tokens: ' '.join(tokens))
    return aligned

def filter_by_num_visit(dataset: pd.DataFrame, minimum_num_visits: int) -> pd.DataFrame:
    """Keep only the rows whose ``num_visits`` reaches a minimum.

    Args:
        dataset (pd.DataFrame): The input dataset; it is not modified.
        minimum_num_visits (int): Smallest ``num_visits`` value to keep.

    Returns:
        pd.DataFrame: A new frame with the matching rows and a fresh 0..n-1 index.
    """
    has_enough_visits = dataset['num_visits'] >= minimum_num_visits
    return dataset[has_enough_visits].reset_index(drop=True)

def get_last_index(seq: List[str], target: str) -> int:
    """Find the position of the final occurrence of ``target`` in ``seq``.

    Args:
        seq (List[str]): The sequence to search.
        target (str): The value to look for.

    Returns:
        int: Zero-based index of the last element equal to ``target``.

    Raises:
        ValueError: If ``target`` does not occur in ``seq``.

    Examples:
        >>> get_last_index(['A', 'B', 'A', 'A', 'C', 'D'], 'A')
        3
    """
    # Searching the reversed list gives the distance of the match from the end.
    offset_from_end = seq[::-1].index(target)
    return len(seq) - offset_from_end - 1

def truncate_and_pad(row: pd.Series) -> Any:
    """Align the auxiliary token arrays of a row with its event tokens.

    Each auxiliary array is cut to the current length of
    ``row['event_tokens_2048']`` and then right-padded to ``MAX_LEN`` with
    ``np.pad(..., mode='constant')`` (NumPy's default constant, zero).

    Args:
        row (pd.Series): The input row; it is updated and returned.

    Returns:
        Any: The row with the re-aligned arrays.

    Note:
        The row must provide these columns:
        - 'event_tokens_2048' (only its length is read)
        - 'type_tokens_2048'
        - 'age_tokens_2048'
        - 'time_tokens_2048'
        - 'visit_tokens_2048'
        - 'position_tokens_2048'

        NOTE(review): any other per-token column (for example an
        ``elapsed_tokens_2048`` column) is left untouched — confirm that is
        intended.
    """
    kept_len = len(row['event_tokens_2048'])
    pad_width = (0, MAX_LEN - kept_len)
    for column in (
        'type_tokens_2048',
        'age_tokens_2048',
        'time_tokens_2048',
        'visit_tokens_2048',
        'position_tokens_2048',
    ):
        row[column] = np.pad(row[column][:kept_len], pad_width, mode='constant')
    return row

def check_readmission_label(row: pd.Series) -> int:
    """Derive the one-month readmission label for a row.

    The token immediately before the last ``'[VS]'`` (at
    ``row['last_VS_index'] - 1``) is compared against a fixed set of marker
    tokens — presumably the time-interval tokens covering up to one month;
    confirm against the tokenizer.

    Args:
        row (pd.Series): Row with 'event_tokens_2048' and 'last_VS_index'.

    Returns:
        int: 1 if that token is one of the markers, otherwise 0.
    """
    readmission_markers = ('[W_0]', '[W_1]', '[W_2]', '[W_3]', '[M_1]')
    preceding_token = row['event_tokens_2048'][row['last_VS_index'] - 1]
    return int(preceding_token in readmission_markers)

def remove_last_visit(row: pd.Series) -> Any:
    """Return the event tokens of a row with the last visit cut off.

    The sequence is sliced so that it ends just before the token that
    precedes the last ``'[VS]'``; that preceding token is therefore removed
    together with the visit.

    Args:
        row (pd.Series): Row with 'event_tokens_2048' and 'last_VS_index'.

    Returns:
        Any: The leading slice of ``row['event_tokens_2048']`` (same sequence
        type as the column value), not the row itself.
    """
    cut_position = row['last_VS_index'] - 1
    return row['event_tokens_2048'][:cut_position]

# Process the dataset for hospital readmission in one month task:
# keep patients with at least two visits, then derive the label and strip the final visit.
# NOTE(review): if the condition cell ran first, `dataset_2048['event_tokens_2048']` already holds
# joined strings, so the '[VS]' lookup in process_readmission_dataset would search single
# characters — confirm the intended run order / input frame.
dataset_2048_readmission = filter_by_num_visit(dataset_2048, minimum_num_visits=2)
dataset_2048_readmission = process_readmission_dataset(dataset_2048_readmission)

In [None]:
def split_dataset_train_test_finetune_datasets(
        dataset: pd.DataFrame,
        label_col: str,
        cv_size: int,
        test_size: int,
        finetune_size: List[int],
        num_splits: int,
        save_path: str
) -> Dict[str, Any]:
    """
    Splits the patients into a test set, a pretraining set and class-balanced
    few-shot fine-tuning sets, and saves the resulting dictionary to disk.

    The test patients are sampled first; every remaining patient goes to the
    pretraining list. Each few-shot set is then sampled from those remaining
    patients, half with a true label and half with a false label. The input
    frame is not modified.

    Args:
        dataset (pd.DataFrame): The input dataset; needs ``patient_id`` and ``label_col`` columns.
        label_col (str): The name of the column containing the (boolean-like) labels.
        cv_size (int): The number of patients in each cross-validation split.
            Currently unused because the k-fold section is disabled.
        test_size (int): The number of patients in the test set (0 skips the test split).
        finetune_size (List[int]): The number of patients in each few-shot fine-tune set.
        num_splits (int): The number of splits to create (k value).
            Currently unused because the k-fold section is disabled.
        save_path (str): The path to save the resulting dictionary (pickle).

    Returns:
        Dict[str, Any]: ``{'pretrain': [ids], 'test': [ids],
        'finetune': {'few_shot': {'<size>': [ids]}, 'kfold': {}}}``.

    Raises:
        ValueError: Raised by ``DataFrame.sample`` when a requested sample is
            larger than the pool it is drawn from.
    """

    # Dictionary to hold patient IDs for different sets
    patient_ids_dict = {'pretrain': [], 'finetune': {'few_shot': {}, 'kfold': {}}, 'test': []}

    # Sample test patients and remove them from the working frame. A new frame
    # is built instead of dropping in place so the caller's frame stays intact.
    if test_size > 0:
        test_patients = dataset.sample(n=test_size, random_state=23)
        dataset = dataset.drop(test_patients.index)
        patient_ids_dict['test'] = test_patients['patient_id'].tolist()

    # Any remaining data is used for pretraining
    patient_ids_dict['pretrain'] = dataset['patient_id'].tolist()
    random.shuffle(patient_ids_dict['pretrain'])

    # The class pools do not change between few-shot sizes, so filter once.
    positive_pool = dataset[dataset[label_col] == True]  # noqa: E712
    negative_pool = dataset[dataset[label_col] == False]  # noqa: E712

    # Few-shot fine-tune sets: half positive, half negative for every size
    for each_finetune_size in finetune_size:
        subset_size = each_finetune_size // 2

        pos_patients_ids = positive_pool.sample(n=subset_size, random_state=23)['patient_id'].tolist()
        neg_patients_ids = negative_pool.sample(n=subset_size, random_state=23)['patient_id'].tolist()

        # Combining and shuffling patient IDs
        finetune_patients = pos_patients_ids + neg_patients_ids
        random.shuffle(finetune_patients)
        patient_ids_dict['finetune']['few_shot'][f'{each_finetune_size}'] = finetune_patients

    # Stratified k-fold split (disabled; kept for reference)
    # skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=23)

    # for i, (train_index, cv_index) in enumerate(skf.split(dataset, dataset[label_col])):

    #     dataset_cv = dataset.iloc[cv_index]
    #     dataset_finetune = dataset.iloc[train_index]

    #     # Separate positive and negative labeled patients
    #     pos_patients = dataset_cv[dataset_cv[label_col] == True]['patient_id'].tolist()
    #     neg_patients = dataset_cv[dataset_cv[label_col] == False]['patient_id'].tolist()

    #     # Calculate the number of positive and negative patients needed for balanced CV set
    #     num_pos_needed = cv_size // 2
    #     num_neg_needed = cv_size // 2

    #     # Select positive and negative patients for CV set ensuring balanced distribution
    #     cv_patients = pos_patients[:num_pos_needed] + neg_patients[:num_neg_needed]
    #     remaining_finetune_patients = pos_patients[num_pos_needed:] + neg_patients[num_neg_needed:]

    #     # Extract patient IDs for training set
    #     finetune_patients = dataset_finetune['patient_id'].tolist()
    #     finetune_patients += remaining_finetune_patients

    #     # Shuffle each list of patients
    #     random.shuffle(cv_patients)
    #     random.shuffle(finetune_patients)

    #     patient_ids_dict['finetune']['kfold'][f'group{i+1}'] = {'finetune': finetune_patients, 'cv': cv_patients}

    # Save the dictionary to disk, creating the target directory if needed
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(save_path, 'wb') as f:
        pickle.dump(patient_ids_dict, f)

    return patient_ids_dict

# Build and save the patient-ID split for the condition task.
# NOTE(review): `dataset_2048_condition` is not assigned in any cell visible here (the condition
# cell rebinds `dataset_2048` instead) — confirm where this frame is meant to come from.
# NOTE(review): `all_conditions` holds one label array per row, while the few-shot sampling in
# split_dataset_train_test_finetune_datasets filters rows with `== True` / `== False` — confirm
# this label column is the intended one.
patient_ids_dict = split_dataset_train_test_finetune_datasets(
    dataset=dataset_2048_condition.copy(),
    label_col='all_conditions',
    cv_size=4000,
    test_size=20000,    # for mortality used to be 20000, then 15000, then 20000 | readmission 10000 then, now 8000
    finetune_size=[250, 1000, 5000, 20000, 50000],  # [250, 500, 1000, 5000, 20000] [250, 1000, 5000, 20000, 50000]
    num_splits=5,
    save_path='patient_id_dict/dataset_2048_condition.pkl'
)

In [None]:
# Persist the task-specific frame (path is relative to the working directory).
# The mortality and readmission outputs are currently disabled.
# dataset_2048_mortality.to_parquet('patient_sequences/patient_sequences_2048_mortality.parquet')
# dataset_2048_readmission.to_parquet('patient_sequences/patient_sequences_2048_readmission.parquet')
# NOTE(review): `dataset_2048_condition` is not assigned in any cell visible here — confirm its source.
dataset_2048_condition.to_parquet('patient_sequences/patient_sequences_2048_condition.parquet')

In [None]:
# Histogram of the per-row length of the event-token sequences.
# NOTE(review): `dataset` and its `event_tokens` column are not defined in the cells visible
# here — confirm which frame/column this cell is meant to plot.
dataset.event_tokens.transform(len).plot(kind='hist', bins=100)
plt.xlim(1000, 8000)  # Show only lengths between 1000 and 8000
plt.ylim(0, 6000)
plt.xlabel('Length of Event Tokens')
plt.ylabel('Frequency')
plt.title('Histogram of Event Tokens Length')
plt.show()

In [None]:
# len(patient_ids_dict['group3']['cv'])

# dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids_dict['group1']['cv'])]['label_mortality_1month']

# s = set()
# for i in range(1, 6):
#     s = s.union(set(patient_ids_dict[f'group{i}']['cv']))
#
# len(s)

In [None]:
##### DEAD ZONE | DO NOT ENTER #####

# patient_ids = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_1month.pkl'), 'rb'))
# patient_ids['finetune']['few_shot'].keys()

# patient_ids2 = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_2weeks.pkl'), 'rb'))['pretrain']
#
# patient_ids1.sort()
# patient_ids2.sort()
#
# patient_ids1 == patient_ids2
# # dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids['pretrain'])]

In [None]:
# dataset_2048_readmission = dataset_2048.loc[dataset_2048['num_visits'] > 1]
# dataset_2048_readmission.reset_index(drop=True, inplace=True)
#
# dataset_2048_readmission['last_VS_index'] = dataset_2048_readmission['event_tokens_2048'].transform(lambda seq: get_last_index(list(seq), '[VS]'))
#
# dataset_2048_readmission['label_readmission_1month'] = dataset_2048_readmission.apply(
#     lambda row: row['event_tokens_2048'][row['last_VS_index'] - 1] in ('[W_0]', '[W_1]', '[W_2]', '[W_3]', '[M_1]'), axis=1
# )
# dataset_2048_readmission['event_tokens_2048'] = dataset_2048_readmission.apply(
#     lambda row: row['event_tokens_2048'][:row['last_VS_index'] - 1], axis=1
# )
# dataset_2048_readmission.drop(['deceased', 'death_after_start', 'death_after_end', 'length'], axis=1, inplace=True)
# dataset_2048_readmission['num_visits'] -= 1
# dataset_2048_readmission['token_length'] = dataset_2048_readmission['event_tokens_2048'].apply(len)
# dataset_2048_readmission = dataset_2048_readmission.apply(lambda row: truncate_and_pad(row), axis=1)
# dataset_2048_readmission['event_tokens_2048'] = dataset_2048_readmission['event_tokens_2048'].transform(
#     lambda token_list: ' '.join(token_list)
# )
#
# dataset_2048_readmission