In [1]:
import random
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold, StratifiedKFold

from typing import Dict, Any, List
from os.path import join

# Seed Python's global `random` module so the `random.shuffle` calls used when
# building the patient-ID splits below give the same order on every fresh run.
random.seed(23)

In [3]:
# Load the patient-sequence table; the bare name on the last line makes the
# notebook render the frame as the cell output.
dataset = pd.read_parquet("patient_sequences.parquet")
dataset

Unnamed: 0,patient_id,num_visits,deceased,death_after_start,death_after_end,length,token_length,event_tokens,type_tokens,age_tokens,time_tokens,visit_tokens,position_tokens
0,f8f3289a-057f-5fcc-a714-5f6109ca16c4,2,0,,,5,4,"[[CLS], [VS], 8938, [VE], [REG]]","[1, 2, 7, 3, 8]","[0, 18, 18, 18, 18]","[0, 8262, 8262, 8262, 8262]","[0, 2, 2, 2, 2]","[0, 2, 2, 2, 2]"
1,9b62c9f4-3fdc-5020-82b5-ae5b8292445a,4,0,,,55,52,"[[CLS], [VS], 7569, 66689036430, 00904224461, ...","[1, 2, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 3, 8, ...","[0, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28...","[0, 5963, 5963, 5963, 5963, 5963, 5963, 5963, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ..."
2,2ca522eb-dd89-5f79-8155-9599ea46b0b2,2,1,244.0,242.0,55,54,"[[CLS], [VS], 00904629261, 00904642281, 009046...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86...","[0, 8016, 8016, 8016, 8016, 8016, 8016, 8016, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ..."
3,54ee4964-056b-5c38-a607-b95e63176fc3,19,1,13.0,0.0,1882,1864,"[[CLS], [VS], 63323026201, 00603385521, 005970...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71...","[0, 5520, 5520, 5520, 5520, 5520, 5520, 5520, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ..."
4,02adf8a6-8bc0-55d3-81ae-4d8582094896,9,1,20.0,11.0,672,664,"[[CLS], [VS], 51079045420, 00006494300, 177140...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, ...","[0, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65...","[0, 8002, 8002, 8002, 8002, 8002, 8002, 8002, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
173666,cf2115d7-937e-511d-b159-dd7eb3d5d420,3,0,,,174,172,"[[CLS], [VS], 66591018442, 63323026201, 001350...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32...","[0, 5481, 5481, 5481, 5481, 5481, 5481, 5481, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ..."
173667,31338a39-28f9-54a5-a810-2d05fbaa5166,4,0,,,295,292,"[[CLS], [VS], 5014, 5123, 00338011704, 6332302...","[1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50...","[0, 4997, 4997, 4997, 4997, 4997, 4997, 4997, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ..."
173668,0989415d-394c-5f42-8dac-75dc7306a23c,3,0,,,478,476,"[[CLS], [VS], 51079043620, 51079088120, 492810...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 3, 8, 4, 2, 7, 7, ...","[0, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 0,...","[0, 5309, 5309, 5309, 5309, 5309, 5309, 5309, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 1, 1, 1, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, ..."
173669,26fb8fef-b976-5c55-859d-cc190261f94b,3,0,,,101,99,"[[CLS], [VS], 0SRD0J9, 60505251903, 0090422446...","[1, 2, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61...","[0, 5085, 5085, 5085, 5085, 5085, 5085, 5085, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ..."


In [None]:
# Build the labeled 2048-token table in one non-mutating chain:
#   1. read the raw parquet file,
#   2. drop the un-suffixed token columns (the "*_2048" columns are kept),
#   3. flatten each event-token sequence into a single space-separated string,
#   4. derive the two binary mortality labels.
# (Disabled step, kept for reference: sort the rows by 'patient_id'.)
dataset_2048 = (
    pd.read_parquet("patient_sequences_2048.parquet")
    .drop(
        columns=[
            "event_tokens",
            "type_tokens",
            "age_tokens",
            "time_tokens",
            "visit_tokens",
            "position_tokens",
        ]
    )
    .assign(
        event_tokens_2048=lambda frame: frame["event_tokens_2048"].transform(
            " ".join
        ),
        # A label is 1 only when both death offsets are present and satisfy the
        # bounds; rows with missing (NaN) offsets compare False and get 0.
        # NOTE(review): the cut-offs 15 and 32 are presumably days standing in
        # for "2 weeks" / "1 month" -- confirm the unit and the exact bounds.
        label_mortality_2weeks=lambda frame: (
            frame["death_after_start"].ge(0) & frame["death_after_end"].le(15)
        ).astype(int),
        label_mortality_1month=lambda frame: (
            frame["death_after_start"].ge(0) & frame["death_after_end"].le(32)
        ).astype(int),
    )
)

dataset_2048

In [2]:
# Reload the labeled table from disk and display it.
# NOTE(review): "patient_sequences_2048_labeled.parquet" is written by a
# `to_parquet` cell further down in this notebook, so on a fresh
# Restart & Run All this read only succeeds if the file already exists.
dataset_2048 = pd.read_parquet("patient_sequences_2048_labeled.parquet")
dataset_2048

Unnamed: 0,patient_id,num_visits,deceased,death_after_start,death_after_end,length,token_length,event_tokens_2048,type_tokens_2048,age_tokens_2048,time_tokens_2048,visit_tokens_2048,position_tokens_2048,label_mortality_2weeks,label_mortality_1month
1,9b62c9f4-3fdc-5020-82b5-ae5b8292445a,4,0,,,55,52,[CLS] [VS] 7569 66689036430 00904224461 665530...,"[1, 2, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 3, 8, ...","[0, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28...","[0, 5963, 5963, 5963, 5963, 5963, 5963, 5963, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",0,0
2,2ca522eb-dd89-5f79-8155-9599ea46b0b2,2,1,244.0,242.0,55,54,[CLS] [VS] 00904629261 00904642281 00904652261...,"[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86...","[0, 8016, 8016, 8016, 8016, 8016, 8016, 8016, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",0,0
3,54ee4964-056b-5c38-a607-b95e63176fc3,19,1,13.0,0.0,1882,1864,[CLS] [VS] 63323026201 00603385521 00597007575...,"[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71...","[0, 5520, 5520, 5520, 5520, 5520, 5520, 5520, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",1,1
4,02adf8a6-8bc0-55d3-81ae-4d8582094896,9,1,20.0,11.0,672,664,[CLS] [VS] 51079045420 00006494300 17714001110...,"[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, ...","[0, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65...","[0, 8002, 8002, 8002, 8002, 8002, 8002, 8002, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",1,1
5,744fe3c4-9b03-55ae-ac9f-6bc4e967cde7,3,0,,,88,86,[CLS] [VS] 7813 7813 7902 7902 9604 0053633810...,"[1, 2, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29...","[0, 7582, 7582, 7582, 7582, 7582, 7582, 7582, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
173666,cf2115d7-937e-511d-b159-dd7eb3d5d420,3,0,,,174,172,[CLS] [VS] 66591018442 63323026201 00135019502...,"[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32...","[0, 5481, 5481, 5481, 5481, 5481, 5481, 5481, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",0,0
173667,31338a39-28f9-54a5-a810-2d05fbaa5166,4,0,,,295,292,[CLS] [VS] 5014 5123 00338011704 63323026201 0...,"[1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50...","[0, 4997, 4997, 4997, 4997, 4997, 4997, 4997, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",0,0
173668,0989415d-394c-5f42-8dac-75dc7306a23c,3,0,,,478,476,[CLS] [VS] 51079043620 51079088120 49281041550...,"[1, 2, 6, 6, 6, 6, 6, 6, 6, 3, 8, 4, 2, 7, 7, ...","[0, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 0,...","[0, 5309, 5309, 5309, 5309, 5309, 5309, 5309, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 1, 1, 1, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, ...",0,0
173669,26fb8fef-b976-5c55-859d-cc190261f94b,3,0,,,101,99,[CLS] [VS] 0SRD0J9 60505251903 00904224461 006...,"[1, 2, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61...","[0, 5085, 5085, 5085, 5085, 5085, 5085, 5085, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",0,0


In [13]:
# Distinct values that occur in the `num_visits` column.
set(dataset_2048["num_visits"])

{2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 59,
 60,
 61,
 63,
 64,
 65,
 66,
 67,
 69,
 70,
 75,
 76,
 77,
 78,
 80,
 89,
 90,
 91,
 92}

In [3]:
def split_dataset_train_test_valid_datasets(
    dataset: pd.DataFrame,
    label_col: str,
    cv_size: int,
    test_size: int,
    finetune_size: List[int],
    num_splits: int,
    save_path: str,
) -> Dict[str, Any]:
    """
    Split patients into test, pretrain, few-shot and stratified k-fold groups.

    The input frame is NOT modified: the held-out test rows are removed from an
    internal copy only, so the caller's ``dataset`` keeps every row.

    Steps:
        1. Sample ``test_size`` rows as the test set (fixed ``random_state=23``).
        2. Every remaining patient ID goes into the pretrain list (shuffled).
        3. For each size in ``finetune_size``, sample ``size // 2`` positive and
           ``size // 2`` negative patients from the remaining rows.
        4. Run a stratified k-fold over the remaining rows. In each fold, up to
           ``cv_size // 2`` positives and ``cv_size // 2`` negatives of the
           held-out part form the CV set; everything else is the finetune set.
        5. Pickle the resulting dictionary to ``save_path``.

    Parameters:
        dataset (pd.DataFrame): Input data; must contain ``patient_id`` and
            ``label_col`` columns.
        label_col (str): Name of the binary label column (compared against
            ``True`` / ``False``, so 1/0 integer labels work as well).
        cv_size (int): Target number of patients in each balanced CV set.
        test_size (int): Number of patients in the test set; no test set is
            carved out when this is not positive.
        finetune_size (List[int]): Sizes of the few-shot finetune sets.
        num_splits (int): Number of stratified folds (k).
        save_path (str): File path the dictionary is pickled to.

    Returns:
        Dict[str, Any]: Patient IDs organised as::

            {
                "pretrain": [...],
                "valid": {
                    "few_shot": {"<size>_patients": [...], ...},
                    "kfold": {"group<k>": {"finetune": [...], "cv": [...]}, ...},
                },
                "test": [...],
            }

    Raises:
        ValueError: Propagated from ``DataFrame.sample`` when more rows are
            requested than are available (test set or a few-shot class subset).

    Note:
        The shuffles use the global ``random`` module, so their order depends on
        the seed the caller has set.
    """
    patient_ids_dict: Dict[str, Any] = {
        "pretrain": [],
        "valid": {"few_shot": {}, "kfold": {}},
        "test": [],
    }

    # Hold out the test patients. `drop` without `inplace` returns a new frame,
    # so the caller's DataFrame is left untouched.
    remaining = dataset
    if test_size > 0:
        test_patients = dataset.sample(n=test_size, random_state=23)
        remaining = dataset.drop(test_patients.index)
        patient_ids_dict["test"] = test_patients["patient_id"].tolist()

    # Everything that is not in the test set is available for pretraining.
    patient_ids_dict["pretrain"] = remaining["patient_id"].tolist()
    random.shuffle(patient_ids_dict["pretrain"])

    # Class subsets are the same for every few-shot size, so build them once.
    positives = remaining[remaining[label_col] == True]  # noqa: E712
    negatives = remaining[remaining[label_col] == False]  # noqa: E712

    # Few-shot finetune sets: half positive, half negative.
    for each_finetune_size in finetune_size:
        subset_size = each_finetune_size // 2

        pos_ids = positives.sample(n=subset_size, random_state=23)[
            "patient_id"
        ].tolist()
        neg_ids = negatives.sample(n=subset_size, random_state=23)[
            "patient_id"
        ].tolist()

        finetune_patients = pos_ids + neg_ids
        random.shuffle(finetune_patients)
        patient_ids_dict["valid"]["few_shot"][f"{each_finetune_size}_patients"] = (
            finetune_patients
        )

    # Stratified k-fold over the remaining (non-test) patients.
    skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=23)

    # Number of patients of each class wanted in a balanced CV set.
    per_class = cv_size // 2

    for group_number, (train_index, cv_index) in enumerate(
        skf.split(remaining, remaining[label_col]), start=1
    ):
        fold_cv = remaining.iloc[cv_index]
        fold_finetune = remaining.iloc[train_index]

        pos_patients = fold_cv[fold_cv[label_col] == True][  # noqa: E712
            "patient_id"
        ].tolist()
        neg_patients = fold_cv[fold_cv[label_col] == False][  # noqa: E712
            "patient_id"
        ].tolist()

        # Slicing never raises: if a class has fewer than `per_class` patients
        # in this fold, all of them are taken and the CV set ends up smaller
        # than `cv_size` (and no longer balanced).
        cv_patients = pos_patients[:per_class] + neg_patients[:per_class]

        # Held-out patients not needed for the CV set are returned to finetune.
        finetune_patients = fold_finetune["patient_id"].tolist()
        finetune_patients += pos_patients[per_class:] + neg_patients[per_class:]

        random.shuffle(cv_patients)
        random.shuffle(finetune_patients)

        patient_ids_dict["valid"]["kfold"][f"group{group_number}"] = {
            "finetune": finetune_patients,
            "cv": cv_patients,
        }

    # Persist the split so other scripts can reuse exactly the same patients.
    with open(save_path, "wb") as f:
        pickle.dump(patient_ids_dict, f)

    return patient_ids_dict


# Build the patient-ID splits for the 1-month mortality label: a 20000-patient
# test set, five few-shot sets of increasing size, and 5 stratified folds with
# a CV target of 4000 patients each. The result is also pickled to `save_path`.
patient_ids_dict = split_dataset_train_test_valid_datasets(
    dataset=dataset_2048,
    label_col="label_mortality_1month",
    cv_size=4000,
    test_size=20000,
    finetune_size=[100, 500, 1000, 5000, 20000],
    num_splits=5,
    save_path="dataset_2048_mortality_1month.pkl",
)

In [None]:
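# Sanity checks on the k-fold CV groups (kept commented out).
# The fold groups live under patient_ids_dict['valid']['kfold'].

# len(patient_ids_dict['valid']['kfold']['group3']['cv'])

# dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids_dict['valid']['kfold']['group1']['cv'])]['label_mortality_1month']

# s = set()
# for i in range(1, 6):
#     s = s.union(set(patient_ids_dict['valid']['kfold'][f'group{i}']['cv']))
#
# len(s)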

In [None]:
# Persist the labeled table so it can be reloaded without recomputing labels.
# NOTE(review): this cell sits after the split cell. If the split helper
# removes the test patients from `dataset_2048` in place, the file written
# here will not contain them -- confirm the intended cell order.
dataset_2048.to_parquet("patient_sequences_2048_labeled.parquet")

In [None]:
# Histogram of per-patient sequence length (number of entries in `event_tokens`).
token_counts = dataset["event_tokens"].apply(len)
ax = token_counts.plot(kind="hist", bins=100)

# Zoom in on the long sequences: x from 1000 to 8000 tokens, y capped at 6000.
ax.set_xlim(1000, 8000)
ax.set_ylim(0, 6000)

ax.set_xlabel("Length of Event Tokens")
ax.set_ylabel("Frequency")
ax.set_title("Histogram of Event Tokens Length")
plt.show()

In [None]:
# Number of patients whose event-token sequence is longer than 2048 tokens.
# `int(...)` keeps the displayed value a plain Python integer.
int((dataset["event_tokens"].apply(len) > 2048).sum())

In [None]:
##### DEAD ZONE | DO NOT ENTER #####

# patient_ids = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_1month.pkl'), 'rb'))
# patient_ids['valid']['few_shot'].keys()

# patient_ids2 = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_2weeks.pkl'), 'rb'))['pretrain']
#
# patient_ids1.sort()
# patient_ids2.sort()
#
# patient_ids1 == patient_ids2
# # dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids['pretrain'])]