In [11]:
import os
import sys
import pickle
import random
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd

ROOT = "/h/afallah/odyssey/odyssey"
DATA_ROOT = f"{ROOT}/odyssey/data/bigbird_data"
DATASET = f"{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet"
MAX_LEN = 2048

os.chdir(ROOT)

from odyssey.utils.utils import save_object_to_disk, seed_everything
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.data.dataset import FinetuneMultiDataset
from odyssey.data.processor import (
    filter_by_num_visit,
    filter_by_length_of_stay,
    get_last_occurence_index,
    check_readmission_label,
    get_length_of_stay,
    get_visit_cutoff_at_threshold,
    process_length_of_stay_dataset,
    process_condition_dataset,
    process_mortality_dataset,
    process_readmission_dataset,
    process_multi_dataset,
    stratified_train_test_split,
    sample_balanced_subset,
    get_pretrain_test_split,
    get_finetune_split,
)

SEED = 23
seed_everything(seed=SEED)

[rank: 0] Seed set to 23


In [2]:
# Load complete dataset
dataset_2048 = pd.read_parquet(DATASET)
dataset_2048['num_visits'] = dataset_2048['event_tokens_2048'].transform(lambda series: list(series).count('[VS]'))

print(f"Current columns: {dataset_2048.columns}")
dataset_2048.head()

Current columns: Index(['patient_id', 'num_visits', 'deceased', 'death_after_start',
       'death_after_end', 'length', 'token_length', 'event_tokens_2048',
       'type_tokens_2048', 'age_tokens_2048', 'time_tokens_2048',
       'visit_tokens_2048', 'position_tokens_2048', 'elapsed_tokens_2048',
       'common_conditions', 'rare_conditions'],
      dtype='object')


Unnamed: 0,patient_id,num_visits,deceased,death_after_start,death_after_end,length,token_length,event_tokens_2048,type_tokens_2048,age_tokens_2048,time_tokens_2048,visit_tokens_2048,position_tokens_2048,elapsed_tokens_2048,common_conditions,rare_conditions
0,35581927-9c95-5ae9-af76-7d74870a349c,1,0,,,50,54,"[[CLS], [VS], 00006473900, 00904516561, 510790...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, ...","[0, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85...","[0, 5902, 5902, 5902, 5902, 5902, 5902, 5902, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 1.97, 2.02, 2.02, 2.02, 2.02, 2.0...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1,f5bba8dd-25c0-5336-8d3d-37424c185026,2,0,,,148,156,"[[CLS], [VS], 52135_2, 52075_2, 52074_2, 52073...","[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83...","[0, 6594, 6594, 6594, 6594, 6594, 6594, 6594, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
2,f4938f91-cadb-5133-8541-a52fb0916cea,2,0,,,78,86,"[[CLS], [VS], 0RB30ZZ, 0RG10A0, 00071101441, 0...","[1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44...","[0, 8150, 8150, 8150, 8150, 8150, 8150, 8150, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 1.08, 1.08, 13.89, 13.8...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
3,6fe2371b-a6f0-5436-aade-7795005b0c66,2,0,,,86,94,"[[CLS], [VS], 63739057310, 49281041688, 005970...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72...","[0, 6093, 6093, 6093, 6093, 6093, 6093, 6093, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.75, 0.75, 0.75, 0.75, 0.75, 0.7...","[1, 0, 0, 0, 0, 0, 0, 1, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
4,6f7590ae-f3b9-50e5-9e41-d4bb1000887a,1,0,,,72,76,"[[CLS], [VS], 50813_0, 52135_0, 52075_3, 52074...","[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47...","[0, 6379, 6379, 6379, 6379, 6379, 6379, 6379, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 1]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"


In [3]:
# Process the dataset for length of stay prediction above a threshold
dataset_2048_los = process_length_of_stay_dataset(
    dataset_2048.copy(), threshold=7, max_len=MAX_LEN
)

In [4]:
# Process the dataset for conditions including rare and common
dataset_2048_condition = process_condition_dataset(dataset_2048.copy())

In [5]:
# Process the dataset for mortality in two weeks or one month task
dataset_2048_mortality = process_mortality_dataset(dataset_2048.copy())

In [6]:
# Process the dataset for hospital readmission in one month task
dataset_2048_readmission = process_readmission_dataset(
    dataset_2048.copy(), max_len=MAX_LEN
)

In [7]:
# Process the multi dataset
multi_dataset = process_multi_dataset(
    datasets={
        "original": dataset_2048,
        "mortality": dataset_2048_mortality,
        "condition": dataset_2048_condition,
        "readmission": dataset_2048_readmission,
        "los": dataset_2048_los,
    },
    max_len=MAX_LEN,
)

In [8]:
# Split data
patient_ids_dict = {
    "pretrain": [],
    "finetune": {"few_shot": {}, "kfold": {}},
    "test": [],
}

# Get train-test split
# pretrain_ids, test_ids = get_pretrain_test_split(dataset_2048_readmission, stratify_target='label_readmission_1month', test_size=0.2)
# pretrain_ids, test_ids = get_pretrain_test_split(process_condition_dataset, stratify_target='all_conditions', test_size=0.15)
# patient_ids_dict['pretrain'] = pretrain_ids
# patient_ids_dict['test'] = test_ids

# Load pretrain and test patient IDs
pid = pickle.load(open(f"{DATA_ROOT}/patient_id_dict/dataset_2048_multi.pkl", "rb"))
patient_ids_dict["pretrain"] = pid["pretrain"]
patient_ids_dict["test"] = pid["test"]
set(pid["pretrain"] + pid["test"]) == set(dataset_2048["patient_id"])

True

In [9]:
multi_dataset_pretrain = multi_dataset.loc[multi_dataset['patient_id'].isin(patient_ids_dict["pretrain"])]
multi_dataset_test = multi_dataset.loc[multi_dataset['patient_id'].isin(patient_ids_dict["test"])]

In [10]:
# DataFrame assumed loaded as multi_dataset_pretrain
# Define the requirements for each label
label_requirements = {
    'label_mortality_1month': 25000,
    'label_readmission_1month': 30000,
    'label_los_1week': 40000,
    # 'label_c0': 25000,
    # 'label_c1': 25000,
    # 'label_c2': 25000
}

# Prepare a dictionary to store indices for each label and category
selected_indices = {label: {'0': set(), '1': set()} for label in label_requirements}

# Initialize a dictionary to track usage of indices across labels
index_usage = {}

# Function to select indices while maximizing overlap
def select_indices(label, num_required, category):
    # Candidates are those indices matching the category requirement
    candidates = set(multi_dataset_pretrain[multi_dataset_pretrain[label] == category].index)
    # Prefer candidates that are already used elsewhere to maximize overlap
    preferred = candidates & set(index_usage.keys())
    additional = list(candidates - preferred)
    np.random.shuffle(additional)  # Shuffle to avoid any unintended order bias

    # Determine how many more are needed
    needed = num_required - len(selected_indices[label][str(category)] & candidates)
    if needed > 0:
        # Select as many as possible from preferred, then from additional
        selected = list(preferred - selected_indices[label][str(category)])[:needed]
        selected += additional[:needed - len(selected)]
        # Update the selected indices for this label and category
        selected_indices[label][str(category)].update(selected)
        # Update overall index usage
        for idx in selected:
            index_usage[idx] = index_usage.get(idx, 0) + 1

# Process each label and category
for label in label_requirements:
    num_required = label_requirements[label] // 2  # Divide by 2 for 50-50 distribution
    select_indices(label, num_required, 0)
    select_indices(label, num_required, 1)

# Combine all selected indices from both categories
all_selected_indices = set()
for indices in selected_indices.values():
    all_selected_indices.update(indices['0'])
    all_selected_indices.update(indices['1'])

# Create the balanced DataFrame
multi_dataset_finetune = multi_dataset_pretrain.loc[list(all_selected_indices)]
multi_dataset_finetune

Unnamed: 0,patient_id,num_visits,event_tokens_2048,type_tokens_2048,age_tokens_2048,time_tokens_2048,visit_tokens_2048,position_tokens_2048,elapsed_tokens_2048,cutoff_los_1week,...,label_c10,label_c11,label_c12,label_c13,label_c14,label_c15,label_c16,label_c17,label_c18,label_c19
2,f4938f91-cadb-5133-8541-a52fb0916cea,2,"[[CLS], [VS], 0RB30ZZ, 0RG10A0, 00071101441, 0...","[1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44...","[0, 8150, 8150, 8150, 8150, 8150, 8150, 8150, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 1.08, 1.08, 13.89, 13.8...",86,...,0,0,0,0,0,0,0,0,0,0
131078,778ba11a-0549-5d3c-af88-506284980a37,2,"[[CLS], [VS], 58177032304, 00182844789, 50910_...","[1, 2, 6, 6, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, ...","[0, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35...","[0, 7251, 7251, 7251, 7251, 7251, 7251, 7251, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 2.74, 2.74, 9.73, 9.73, 9.73, 9.7...",108,...,0,0,0,0,0,0,0,0,0,0
7,e0bb4cea-1ae3-5716-baa0-b93d56001be8,4,"[[CLS], [VS], 51006_0, 50983_0, 50971_1, 50946...","[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28...","[0, 8342, 8342, 8342, 8342, 8342, 8342, 8342, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...",175,...,0,0,0,0,0,0,0,0,0,0
131080,eb87674f-7491-504c-87a2-0bb00855dfdb,5,"[[CLS], [VS], 596, 7078, 7050, 7095, 581600873...","[1, 2, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85...","[0, 8188, 8188, 8188, 8188, 8188, 8188, 8188, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 5.16, 5.32, 5...",344,...,0,0,0,0,0,0,0,0,0,0
131083,a5f91b16-c05f-5f4d-97f4-c86a4854ff5f,7,"[[CLS], [VS], 0F798DZ, BF10YZZ, 51768_1, 51765...","[1, 2, 7, 7, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78...","[0, 8600, 8600, 8600, 8600, 8600, 8600, 8600, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 0.27, 0.27, 0.27, 0.27,...",1054,...,0,1,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
131064,5d00ae3c-6ea4-5836-bb15-95e59c694392,1,"[[CLS], [VS], 5185, 5188, 00172567310, 0078157...","[1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80...","[0, 7040, 7040, 7040, 7040, 7040, 7040, 7040, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 0.0, 3.01, 3.01, 3.01, 3.01,...",16,...,0,0,0,0,0,0,0,0,0,0
131065,d6e81808-dccb-5c9d-a120-a693230da4a4,3,"[[CLS], [VS], 5187, 50813_3, 51146_2, 51200_1,...","[1, 2, 7, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66...","[0, 7687, 7687, 7687, 7687, 7687, 7687, 7687, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.0, 4.63, 11.8, 11.8, 11.8, 11.8...",639,...,0,0,0,0,0,0,0,0,0,0
131066,55f5672c-229b-588e-b67c-476ebe518263,4,"[[CLS], [VS], 00006473900, 33332001201, 633230...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, ...","[0, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91...","[0, 8734, 8734, 8734, 8734, 8734, 8734, 8734, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 1.47, 1.47, 1.51, 1.51, 1.51, 1.6...",-1,...,0,0,0,0,0,0,0,0,0,0
131068,8c62d596-9658-54be-bece-20adc700a80f,5,"[[CLS], [VS], 63323026201, 49281001350, 000080...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 5, ...","[0, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63...","[0, 8469, 8469, 8469, 8469, 8469, 8469, 8469, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-2.0, -1.0, 0.21, 0.21, 0.21, 1.36, 6.06, 6.0...",284,...,0,0,0,0,0,0,0,0,0,0


In [12]:
for label in ['label_mortality_1month', 'label_readmission_1month', 'label_los_1week', 'label_c0', 'label_c1', 'label_c2']:
    print(f'Label: {label} | Mean: {multi_dataset_finetune[label].mean()}\n{multi_dataset_finetune[label].value_counts()}\n')

Label: label_mortality_1month | Mean: 0.2708074272345986
label_mortality_1month
0    34598
1    12849
Name: count, dtype: int64

Label: label_readmission_1month | Mean: -0.0013699496280059856
label_readmission_1month
 0    16168
-1    15672
 1    15607
Name: count, dtype: int64

Label: label_los_1week | Mean: 0.3155099374038401
label_los_1week
 0    22417
 1    20000
-1     5030
Name: count, dtype: int64

Label: label_c0 | Mean: 0.49739709570678864
label_c0
0    23847
1    23600
Name: count, dtype: int64

Label: label_c1 | Mean: 0.49037873838177337
label_c1
0    24180
1    23267
Name: count, dtype: int64

Label: label_c2 | Mean: 0.3007144814213754
label_c2
0    33179
1    14268
Name: count, dtype: int64



In [28]:
patient_ids_dict['finetune']['few_shot']['all'] = multi_dataset_finetune['patient_id'].tolist()

multi_dataset_pretrain = multi_dataset_pretrain.loc[
    ~multi_dataset_pretrain['patient_id'].isin(multi_dataset_finetune['patient_id'])]

patient_ids_dict['pretrain'] = multi_dataset_pretrain['patient_id'].tolist()

save_object_to_disk(patient_ids_dict, f"{DATA_ROOT}/patient_id_dict/dataset_2048_multi_v2.pkl")

# "mortality_1month=0.5, los_1week=0.5, readmission_1month=0.5, c0=0.5, c1=0.5, c2=0.5"

File saved to disk: /h/afallah/odyssey/odyssey/odyssey/data/bigbird_data/patient_id_dict/dataset_2048_multi_v2.pkl


In [None]:
"""
Current Approach:
    - Pretrain: 141234 Patients
    - Test: 24924 Patients, 132682 Datapoints
    - Finetune: 139514 Unique Patients, 434270 Datapoints
        - Mortality: 26962 Patients
        - Readmission: 48898 Patients
        - Length of Stay: 72686 Patients
        - Condition 0: 122722 Patients
        - Condition 1: 94048 Patients
        - Condition 2: 68954 Patients
"""

In [None]:
task_config = {
    "mortality": {
        "dataset": dataset_2048_mortality,
        "label_col": "label_mortality_1month",
        "finetune_size": [250, 500, 1000, 5000, 20000],
        "save_path": "patient_id_dict/dataset_2048_mortality.pkl",
        "split_mode": "single_label_balanced",
    },
    "readmission": {
        "dataset": dataset_2048_readmission,
        "label_col": "label_readmission_1month",
        "finetune_size": [250, 1000, 5000, 20000, 60000],
        "save_path": "patient_id_dict/dataset_2048_readmission.pkl",
        "split_mode": "single_label_stratified",
    },
    "length_of_stay": {
        "dataset": dataset_2048_los,
        "label_col": "label_los_1week",
        "finetune_size": [250, 1000, 5000, 20000, 50000],
        "save_path": "patient_id_dict/dataset_2048_los.pkl",
        "split_mode": "single_label_balanced",
    },
    "condition": {
        "dataset": dataset_2048_condition,
        "label_col": "all_conditions",
        "finetune_size": [50000],
        "save_path": "patient_id_dict/dataset_2048_condition.pkl",
        "split_mode": "multi_label_stratified",
    },
}

In [None]:
# Get finetune split
for task in task_config.keys():
    patient_ids_dict = get_finetune_split(
        task_config=task_config,
        task=task,
        patient_ids_dict=patient_ids_dict,
    )

In [32]:
# dataset_2048_mortality.to_parquet(
#     "patient_sequences/patient_sequences_2048_mortality.parquet",
# )
# dataset_2048_readmission.to_parquet(
#     "patient_sequences/patient_sequences_2048_readmission.parquet",
# )
# dataset_2048_los.to_parquet("patient_sequences/patient_sequences_2048_los.parquet")
# dataset_2048_condition.to_parquet(
#     "patient_sequences/patient_sequences_2048_condition.parquet",
# )
multi_dataset.to_parquet(f"{DATA_ROOT}/patient_sequences/patient_sequences_2048_multi_v2.parquet")

In [None]:
# # Load data
# multi_dataset = pd.read_parquet('patient_sequences/patient_sequences_2048_multi.parquet')
# pid = pickle.load(open('patient_id_dict/dataset_2048_multi.pkl', 'rb'))
# multi_dataset = multi_dataset[multi_dataset['patient_id'].isin(pid['finetune']['few_shot']['all'])]

# # Train Tokenizer
# tokenizer = ConceptTokenizer(data_dir='/h/afallah/odyssey/odyssey/odyssey/data/vocab')
# tokenizer.fit_on_vocab(with_tasks=True)

# # Load datasets
# tasks = ['mortality_1month', 'los_1week'] + [f'c{i}' for i in range(5)]

# train_dataset = FinetuneMultiDataset(
#     data=multi_dataset,
#     tokenizer=tokenizer,
#     tasks=tasks,
#     balance_guide={'mortality_1month': 0.5, 'los_1week': 0.5},
#     max_len=2048,
# )

In [None]:
# dataset_2048_condition = pd.read_parquet('patient_sequences/patient_sequences_2048_condition.parquet')
# pid = pickle.load(open('patient_id_dict/dataset_2048_condition.pkl', 'rb'))
# condition_finetune = dataset_2048_condition.loc[dataset_2048_condition['patient_id'].isin(pid['finetune']['few_shot']['50000'])]
# condition_finetune

In [None]:
# freq = np.array(condition_finetune['all_conditions'].tolist()).sum(axis=0)
# weights = np.clip(0, 50, sum(freq) / freq)
# np.max(np.sqrt(freq)) / np.sqrt(freq)

In [None]:
# sorted(patient_ids_dict['pretrain']) == sorted(pickle.load(open('new_data/patient_id_dict/sample_pretrain_test_patient_ids_with_conditions.pkl', 'rb'))['pretrain'])

In [None]:
# merged_df = pd.merge(dataset_2048_mortality, dataset_2048_readmission, how='outer', on='patient_id')
# final_merged_df = pd.merge(merged_df, dataset_2048_condition, how='outer', on='patient_id')
# final_merged_df

In [None]:
# Performing stratified k-fold split
# skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=SEED)

# for i, (train_index, cv_index) in enumerate(skf.split(dataset, dataset[label_col])):

#     dataset_cv = dataset.iloc[cv_index]
#     dataset_finetune = dataset.iloc[train_index]

#     # Separate positive and negative labeled patients
#     pos_patients = dataset_cv[dataset_cv[label_col] == True]['patient_id'].tolist()
#     neg_patients = dataset_cv[dataset_cv[label_col] == False]['patient_id'].tolist()

#     # Calculate the number of positive and negative patients needed for balanced CV set
#     num_pos_needed = cv_size // 2
#     num_neg_needed = cv_size // 2

#     # Select positive and negative patients for CV set ensuring balanced distribution
#     cv_patients = pos_patients[:num_pos_needed] + neg_patients[:num_neg_needed]
#     remaining_finetune_patients = pos_patients[num_pos_needed:] + neg_patients[num_neg_needed:]

#     # Extract patient IDs for training set
#     finetune_patients = dataset_finetune['patient_id'].tolist()
#     finetune_patients += remaining_finetune_patients

#     # Shuffle each list of patients
#     random.shuffle(cv_patients)
#     random.shuffle(finetune_patients)

#     patient_ids_dict['finetune']['kfold'][f'group{i+1}'] = {'finetune': finetune_patients, 'cv': cv_patients}

In [None]:
# Assuming dataset.event_tokens is your DataFrame column
# dataset.event_tokens.transform(len).plot(kind='hist', bins=100)
# plt.xlim(1000, 8000)  # Limit x-axis to 5000
# plt.ylim(0, 6000)
# plt.xlabel('Length of Event Tokens')
# plt.ylabel('Frequency')
# plt.title('Histogram of Event Tokens Length')
# plt.show()

In [None]:
# len(patient_ids_dict['group3']['cv'])

# dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids_dict['group1']['cv'])]['label_mortality_1month']

# s = set()
# for i in range(1, 6):
#     s = s.union(set(patient_ids_dict[f'group{i}']['cv']))
#
# len(s)

In [None]:
##### DEAD ZONE | DO NOT ENTER #####

# patient_ids = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_1month.pkl'), 'rb'))
# patient_ids['finetune']['few_shot'].keys()

# patient_ids2 = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_2weeks.pkl'), 'rb'))['pretrain']
#
# patient_ids1.sort()
# patient_ids2.sort()
#
# patient_ids1 == patient_ids2
# # dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids['pretrain'])]

In [None]:
# dataset_2048_readmission = dataset_2048.loc[dataset_2048['num_visits'] > 1]
# dataset_2048_readmission.reset_index(drop=True, inplace=True)
#
# dataset_2048_readmission['last_VS_index'] = dataset_2048_readmission['event_tokens_2048'].transform(lambda seq: get_last_occurence_index(list(seq), '[VS]'))
#
# dataset_2048_readmission['label_readmission_1month'] = dataset_2048_readmission.apply(
#     lambda row: row['event_tokens_2048'][row['last_VS_index'] - 1] in ('[W_0]', '[W_1]', '[W_2]', '[W_3]', '[M_1]'), axis=1
# )
# dataset_2048_readmission['event_tokens_2048'] = dataset_2048_readmission.apply(
#     lambda row: row['event_tokens_2048'][:row['last_VS_index'] - 1], axis=1
# )
# dataset_2048_readmission.drop(['deceased', 'death_after_start', 'death_after_end', 'length'], axis=1, inplace=True)
# dataset_2048_readmission['num_visits'] -= 1
# dataset_2048_readmission['token_length'] = dataset_2048_readmission['event_tokens_2048'].apply(len)
# dataset_2048_readmission = dataset_2048_readmission.apply(lambda row: truncate_and_pad(row), axis=1)
# dataset_2048_readmission['event_tokens_2048'] = dataset_2048_readmission['event_tokens_2048'].transform(
#     lambda token_list: ' '.join(token_list)
# )
#
# dataset_2048_readmission