In [2]:
import random
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold, StratifiedKFold

from typing import Dict, Any, List
from os.path import join

# Seed the stdlib `random` module's global generator so runs are repeatable.
random.seed(23)

In [12]:
# Load a previously saved split dictionary and list the available few-shot subsets.
# The file is opened in a `with` block so the handle is closed right after loading.
# NOTE: `pickle.load` can execute arbitrary code; only load files from a trusted source.
# NOTE(review): the directory is an absolute, machine-specific path; consider a configurable data dir.
with open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_1month.pkl'), 'rb') as split_file:
    patient_ids = pickle.load(split_file)
patient_ids['valid']['few_shot'].keys()

dict_keys(['100_patients', '500_patients', '1000_patients', '5000_patients', '20000_patients'])

In [7]:
# patient_ids2 = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_2weeks.pkl'), 'rb'))['pretrain']
#
# patient_ids1.sort()
# patient_ids2.sort()
#
# patient_ids1 == patient_ids2
# # dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids['pretrain'])]

True

In [None]:
dataset_2048 = pd.read_parquet('patient_sequences_2048.parquet')
# dataset_2048.sort_values(by='patient_id', inplace=True)

# Token columns that are not needed downstream; `event_tokens_2048` is kept.
unused_token_columns = ['event_tokens', 'type_tokens', 'age_tokens', 'time_tokens', 'visit_tokens', 'position_tokens']
dataset_2048 = dataset_2048.drop(columns=unused_token_columns)

# Collapse each token sequence into a single space-separated string.
dataset_2048['event_tokens_2048'] = dataset_2048['event_tokens_2048'].transform(lambda tokens: ' '.join(tokens))

# Label is 1 when death_after_start >= 0 and death_after_end is within the cut-off, else 0.
# NOTE(review): the cut-offs (15 and 32) are presumably in days; confirm against the data dictionary.
death_not_before_start = dataset_2048['death_after_start'] >= 0
dataset_2048['label_mortality_2weeks'] = (death_not_before_start & (dataset_2048['death_after_end'] <= 15)).astype(int)
dataset_2048['label_mortality_1month'] = (death_not_before_start & (dataset_2048['death_after_end'] <= 32)).astype(int)

dataset_2048

In [None]:
# Read 'patient_sequences_2048_labeled.parquet' and rebind `dataset_2048` to it,
# replacing whatever frame the name referred to before this cell ran.
dataset_2048 = pd.read_parquet('patient_sequences_2048_labeled.parquet')
dataset_2048

In [3]:
def split_dataset_train_test_valid_datasets(
        dataset: pd.DataFrame,
        label_col: str,
        cv_size: int,
        test_size: int,
        finetune_size: List[int],
        num_splits: int,
        save_path: str,
        seed: int = 23) -> Dict[str, Any]:
    """
    Split patients into test, pretrain, few-shot and stratified k-fold groups, then pickle the result.

    The input DataFrame is NOT modified: the test rows are removed from a local copy only,
    so the caller's frame still contains every patient after the call and the function can
    be re-run on the same frame with the same outcome.

    Parameters:
        dataset (pd.DataFrame): The input dataset; must contain a 'patient_id' column and `label_col`.
        label_col (str): The name of the column containing the binary labels (0/1 or bool).
        cv_size (int): Target number of patients in each cross-validation set
            (`cv_size // 2` positives plus `cv_size // 2` negatives, fewer if a fold lacks them).
        test_size (int): The number of patients in the test set (no test set if <= 0).
        finetune_size (List[int]): Sizes of the few-shot fine-tune sets; each set holds
            `size // 2` positive and `size // 2` negative patients.
        num_splits (int): The number of folds to create (k value).
        save_path (str): The path to save the resulting dictionary (pickle format).
        seed (int): Seed used for row sampling, the stratified split and list shuffling.
            With the default of 23 the shuffles match what the global `random` module
            produces immediately after `random.seed(23)`.

    Returns:
        Dict[str, Any]: A dictionary of patient IDs with the layout
            {
                'pretrain': [ids of all non-test patients],
                'test': [ids of the test patients],
                'valid': {
                    'few_shot': {'<size>_patients': [ids], ...},
                    'kfold': {'group<k>': {'finetune': [ids], 'cv': [ids]}, ...},
                },
            }

    Raises:
        ValueError: Propagated from pandas when more rows are requested by `sample`
            than are available (test set or one class of a few-shot set).
    """

    # Private generator: shuffles neither depend on nor disturb the global `random` state,
    # so re-running the function gives the same lists.
    rng = random.Random(seed)

    # Dictionary to hold patient IDs for different sets
    patient_ids_dict = {'pretrain': [], 'valid': {'few_shot': {}, 'kfold': {}}, 'test': []}

    # Sample test patients and remove them from a local copy of the dataset.
    # `drop` without `inplace` returns a new frame, leaving the caller's DataFrame intact.
    if test_size > 0:
        test_patients = dataset.sample(n=test_size, random_state=seed)
        dataset = dataset.drop(test_patients.index)
        patient_ids_dict['test'] = test_patients['patient_id'].tolist()

    # Any remaining data is used for pretraining
    patient_ids_dict['pretrain'] = dataset['patient_id'].tolist()
    rng.shuffle(patient_ids_dict['pretrain'])

    # Value comparison (not identity) so that 0/1 integer labels work as well as booleans.
    positive_rows = dataset[dataset[label_col] == True]
    negative_rows = dataset[dataset[label_col] == False]

    # Few-shot fine-tune sets: balanced samples drawn from all non-test patients.
    # They are sampled independently of the k-fold groups below.
    for each_finetune_size in finetune_size:
        subset_size = each_finetune_size // 2

        pos_patients_ids = positive_rows.sample(n=subset_size, random_state=seed)['patient_id'].tolist()
        neg_patients_ids = negative_rows.sample(n=subset_size, random_state=seed)['patient_id'].tolist()

        # Combining and shuffling patient IDs
        finetune_patients = pos_patients_ids + neg_patients_ids
        rng.shuffle(finetune_patients)
        patient_ids_dict['valid']['few_shot'][f'{each_finetune_size}_patients'] = finetune_patients

    # Performing stratified k-fold split
    skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=seed)

    # Number of patients of each class wanted in every CV set
    per_class_cv = cv_size // 2

    for i, (train_index, cv_index) in enumerate(skf.split(dataset, dataset[label_col])):

        dataset_cv = dataset.iloc[cv_index]
        dataset_finetune = dataset.iloc[train_index]

        # Separate positive and negative labeled patients of the held-out fold
        pos_patients = dataset_cv[dataset_cv[label_col] == True]['patient_id'].tolist()
        neg_patients = dataset_cv[dataset_cv[label_col] == False]['patient_id'].tolist()

        # Balanced CV set: the first `per_class_cv` patients of each class.
        # If the fold holds fewer patients of a class, the slice simply returns all of them.
        cv_patients = pos_patients[:per_class_cv] + neg_patients[:per_class_cv]

        # Held-out patients not used for CV are returned to the fine-tune pool
        finetune_patients = dataset_finetune['patient_id'].tolist()
        finetune_patients += pos_patients[per_class_cv:] + neg_patients[per_class_cv:]

        # Shuffle each list of patients
        rng.shuffle(cv_patients)
        rng.shuffle(finetune_patients)

        patient_ids_dict['valid']['kfold'][f'group{i+1}'] = {'finetune': finetune_patients, 'cv': cv_patients}

    # Save the dictionary to disk
    with open(save_path, 'wb') as f:
        pickle.dump(patient_ids_dict, f)

    return patient_ids_dict


# Build the splits for the 1-month mortality label and pickle them to
# 'dataset_2048_mortality_1month.pkl'.
patient_ids_dict = split_dataset_train_test_valid_datasets(
    dataset=dataset_2048,
    label_col='label_mortality_1month',
    cv_size=4000,
    test_size=20000,
    finetune_size=[100, 500, 1000, 5000, 20000],
    num_splits=5,
    save_path='dataset_2048_mortality_1month.pkl',
)

In [None]:
# Sanity checks on the k-fold CV groups (kept commented out).
# len(patient_ids_dict['valid']['kfold']['group3']['cv'])

# dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids_dict['valid']['kfold']['group1']['cv'])]['label_mortality_1month']

# s = set()
# for i in range(1, 6):
#     s = s.union(set(patient_ids_dict['valid']['kfold'][f'group{i}']['cv']))
#
# len(s)

In [None]:
# Write the current `dataset_2048` frame to parquet. The same file name is read back by the
# `pd.read_parquet('patient_sequences_2048_labeled.parquet')` cell.
dataset_2048.to_parquet('patient_sequences_2048_labeled.parquet')

In [None]:
# Histogram of the lengths of the entries in `dataset.event_tokens`.
# NOTE(review): `dataset` is not defined in the visible cells; confirm which frame is meant.
ax = dataset.event_tokens.transform(len).plot.hist(bins=100)
ax.set_xlim(1000, 8000)  # Show only lengths between 1000 and 8000
ax.set_ylim(0, 6000)
ax.set_xlabel('Length of Event Tokens')
ax.set_ylabel('Frequency')
ax.set_title('Histogram of Event Tokens Length')
plt.show()

In [None]:
# Count the entries of `dataset.event_tokens` whose length exceeds 2048.
# The boolean mask is summed with the vectorized Series method instead of iterating in Python;
# `int(...)` keeps the displayed value a plain Python int, as the builtin `sum` produced.
int((dataset.event_tokens.transform(len) > 2048).sum())