In [None]:
import os
import sys
import pickle
import random
import json
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import polars as pl

# --- Notebook configuration ---------------------------------------------------
SEED = 23  # Seed handed to seed_everything() further down in this cell.

# Repository checkout to run from. The default is the original author's path;
# set the ODYSSEY_ROOT environment variable to run the notebook elsewhere
# without editing this cell.
ROOT = os.environ.get("ODYSSEY_ROOT", "/h/afallah/odyssey/odyssey")
DATA_ROOT = f"{ROOT}/odyssey/data/meds_data"  # alternative data dir: bigbird_data
# Source sequences read by the preprocessing cell.
DATASET = f"{DATA_ROOT}/patient_sequences/patient_sequences.parquet"
# Truncated sequences written (and read back) by the preprocessing cell.
DATASET_2048 = f"{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet"
MAX_LEN = 2048  # Maximum number of tokens kept per patient sequence.

# Must stay ahead of the `odyssey.*` imports that follow.
# NOTE(review): presumably those imports resolve relative to the working
# directory — confirm before moving this line.
os.chdir(ROOT)

from odyssey.utils.utils import save_object_to_disk, seed_everything
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.data.dataset import FinetuneMultiDataset
from odyssey.data.processor import *

# Seed the random number generators once, up front, so later shuffles/splits are repeatable.
# NOTE(review): which libraries seed_everything covers is defined in
# odyssey.utils.utils — confirm it includes NumPy's global RNG, since
# np.random.shuffle is used later in this notebook.
seed_everything(seed=SEED)

In [2]:
# (lengths <= 4096 ).sum() / len(lengths)

# lengths = dataset['event_tokens'].map_elements(len)
# (lengths.filter(lengths > 2048) - 2048).sum() / 1e6
# lengths.filter((lengths > 2048) & (lengths <= 4096)).sum() / 1e6
# lengths.sum() / 1e6

# sample_id = 2130
# sample = dataset[sample_id]['event_tokens']
# print(list(sample[0]), '\n', len(list(sample[0])))

In [None]:
# Build the length-capped version of the patient sequences and persist it.
dataset = pl.read_parquet(DATASET)
dataset = dataset.rename({"subject_id": "patient_id", "code": "event_tokens"})

# Keep only sequences with more than 5 tokens. `.list.len()` is the native
# list-length expression and avoids calling Python's len() once per row, which
# is what map_elements(len) did.
dataset = dataset.filter(pl.col("event_tokens").list.len() > 5)

dataset = dataset.with_columns(
    [
        pl.col("patient_id").cast(pl.String).alias("patient_id"),
        # Keep the first MAX_LEN - 1 tokens and append [EOS], so no sequence
        # is longer than MAX_LEN tokens.
        pl.concat_list(
            [pl.col("event_tokens").list.slice(0, MAX_LEN - 1), pl.lit(["[EOS]"])]
        ).alias("event_tokens"),
    ]
)

dataset = dataset.with_columns(
    [
        # list.len() yields an unsigned integer column; cast so token_length
        # stays Int64 as it was when computed through map_elements(len).
        pl.col("event_tokens").list.len().cast(pl.Int64).alias("token_length"),
    ]
)

print(dataset.head())
print(dataset.schema)

dataset.write_parquet(DATASET_2048)

# Read the file back to confirm that what was written is loadable and has the
# expected schema.
dataset_saved = pl.read_parquet(DATASET_2048)
print(dataset_saved.head())
print(dataset_saved.schema)

In [None]:
# patient_ids_dict = {
#     "pretrain": [],
#     "finetune": {"few_shot": {}, "kfold": {}},
#     "test": [],
# }

# import numpy as np
# import pickle

# # Set random seed
# np.random.seed(23)

# # Get unique patient IDs
# unique_patients = dataset_saved['patient_id'].unique()

# # Randomly shuffle patient IDs
# np.random.shuffle(unique_patients)

# # Calculate split sizes
# n_patients = len(unique_patients)
# n_pretrain = int(0.65 * n_patients)
# n_finetune = int(0.25 * n_patients)

# # Split patient IDs
# patient_ids_dict["pretrain"] = unique_patients[:n_pretrain].tolist()
# patient_ids_dict["finetune"]["few_shot"] = unique_patients[n_pretrain:n_pretrain+n_finetune].tolist()
# patient_ids_dict["test"] = unique_patients[n_pretrain+n_finetune:].tolist()

# # Save the dictionary
# save_object_to_disk(patient_ids_dict, f"{DATA_ROOT}/patient_id_dict/patient_splits.pkl")
# len(patient_ids_dict["pretrain"]), len(patient_ids_dict["finetune"]["few_shot"]), len(patient_ids_dict["test"]), patient_ids_dict["pretrain"][2323]

In [None]:
# patient_ids_dict = load_object_from_disk(f"{DATA_ROOT}/patient_id_dict/patient_splits.pkl")
# len(patient_ids_dict["pretrain"]), len(patient_ids_dict["finetune"]["few_shot"]), len(patient_ids_dict["test"]), patient_ids_dict["pretrain"][2323]

In [None]:
# Load complete dataset
# NOTE(review): this reads the DATASET file, while the cells below index the
# columns "patient_id" and "event_tokens_2048". The preprocessing cell above
# renames the source columns ("subject_id", "code") and writes its result to
# DATASET_2048 with an "event_tokens" column — confirm which file and column
# names the downstream processors actually expect.
dataset = pd.read_parquet(DATASET)

In [None]:
# Count the "[VS]" markers in every patient's token sequence and store the
# result as a new column.
visit_counts = dataset["event_tokens_2048"].transform(
    lambda tokens: list(tokens).count("[VS]")
)
dataset["num_visits"] = visit_counts

print(f"Current columns: {dataset.columns}")
dataset.head()

In [None]:
# Display the first row's token sequence as a quick visual sanity check.
dataset["event_tokens_2048"].iloc[0]

In [None]:
# Process the dataset for length of stay prediction above a threshold.
# A copy is passed, so whatever the helper does to its input cannot affect
# `dataset`, which the following task cells reuse.
# NOTE(review): threshold=7 is presumably in days (the label used later is
# "label_los_1week") — confirm against process_length_of_stay_dataset.
dataset_los = process_length_of_stay_dataset(
    dataset.copy(), threshold=7, max_len=MAX_LEN
)

In [None]:
# Process the dataset for conditions including rare and common.
# Works on a copy so `dataset` itself is left as loaded for the other task cells.
dataset_condition = process_condition_dataset(dataset.copy())

In [None]:
# Process the dataset for mortality in two weeks or one month task.
# Works on a copy so `dataset` itself is left as loaded for the other task cells.
dataset_mortality = process_mortality_dataset(dataset.copy())

In [None]:
# Process the dataset for hospital readmission in one month task.
# Works on a copy so `dataset` itself is left as loaded for the other task cells.
dataset_readmission = process_readmission_dataset(dataset.copy(), max_len=MAX_LEN)

In [None]:
# Process the multi dataset
# Hands the original frame and the four task-specific frames built above to
# process_multi_dataset, keyed by task name.
# NOTE(review): "original" is passed without .copy(), unlike the task cells
# above — confirm process_multi_dataset does not modify `dataset` in place,
# because `dataset["patient_id"]` is read again in the next cell.
multi_dataset = process_multi_dataset(
    datasets={
        "original": dataset,
        "mortality": dataset_mortality,
        "condition": dataset_condition,
        "readmission": dataset_readmission,
        "los": dataset_los,
    },
    max_len=MAX_LEN,
)

In [None]:
# Split data
# Container for all patient-ID splits; the finetune sub-dicts are filled in by
# later cells.
patient_ids_dict = {
    "pretrain": [],
    "finetune": {"few_shot": {}, "kfold": {}},
    "test": [],
}

# Get train-test split (alternative to loading a stored split; kept for reference)
# pretrain_ids, test_ids = get_pretrain_test_split(dataset_readmission, stratify_target='label_readmission_1month', test_size=0.2)
# pretrain_ids, test_ids = get_pretrain_test_split(dataset_condition, stratify_target='all_conditions', test_size=0.15)
# patient_ids_dict['pretrain'] = pretrain_ids
# patient_ids_dict['test'] = test_ids

# Load pretrain and test patient IDs
# The context manager closes the file handle, which a bare
# pickle.load(open(...)) leaves open.
# NOTE: unpickling can execute arbitrary code — only load split files from a
# trusted location.
with open(f"{DATA_ROOT}/patient_id_dict/dataset_multi.pkl", "rb") as split_file:
    pid = pickle.load(split_file)
patient_ids_dict["pretrain"] = pid["pretrain"]
patient_ids_dict["test"] = pid["test"]

# Displayed as the cell output: True when the stored pretrain + test IDs are
# exactly the set of patient IDs in the loaded dataset.
set(pid["pretrain"] + pid["test"]) == set(dataset["patient_id"])

In [None]:
# Partition the multi-task frame using the patient-ID lists held in
# patient_ids_dict.
pretrain_mask = multi_dataset["patient_id"].isin(patient_ids_dict["pretrain"])
test_mask = multi_dataset["patient_id"].isin(patient_ids_dict["test"])

multi_dataset_pretrain = multi_dataset.loc[pretrain_mask]
multi_dataset_test = multi_dataset.loc[test_mask]

In [None]:
# Expects `multi_dataset_pretrain` to be loaded already.
# Total number of rows wanted per label; the driver loop below halves each
# figure to get the per-class quota.
label_requirements = dict(
    label_mortality_1month=25000,
    label_readmission_1month=30000,
    label_los_1week=40000,
    # 'label_c0': 25000,
    # 'label_c1': 25000,
    # 'label_c2': 25000
)

# Chosen row indices per label, bucketed by class ("0" / "1"). Every bucket is
# its own fresh set.
selected_indices = {
    label_name: {str(class_value): set() for class_value in (0, 1)}
    for label_name in label_requirements
}

# How many (label, class) buckets each row index has been picked for so far.
index_usage = dict()


# Function to select indices while maximizing overlap
def select_indices(label, num_required, category):
    """Pick rows of ``multi_dataset_pretrain`` for one (label, class) bucket.

    Rows whose ``label`` column equals ``category`` are the candidates. Rows
    already recorded in ``index_usage`` (i.e. picked for some bucket before)
    are taken first; any remaining quota is filled from the other candidates
    in shuffled order.

    Reads the module-level ``multi_dataset_pretrain`` and updates the
    module-level ``selected_indices`` and ``index_usage`` in place.

    Args:
        label: Column name in ``multi_dataset_pretrain``; also the key into
            ``selected_indices``.
        num_required: Number of rows this bucket should end up with.
        category: Class value to match; ``str(category)`` is the bucket key.

    Returns:
        None. If fewer candidates exist than required, the bucket is simply
        left short.
    """
    matching_rows = multi_dataset_pretrain[multi_dataset_pretrain[label] == category]
    candidates = set(matching_rows.index)

    # Rows already in use elsewhere come first, so the final union stays small.
    preferred = candidates & set(index_usage.keys())
    additional = list(candidates - preferred)
    np.random.shuffle(additional)  # uses NumPy's global RNG

    bucket = selected_indices[label][str(category)]
    needed = num_required - len(bucket & candidates)
    if needed <= 0:
        return

    # Fill from the preferred pool first, then top up from the shuffled rest.
    selected = list(preferred - bucket)[:needed]
    selected.extend(additional[: needed - len(selected)])

    bucket.update(selected)
    for idx in selected:
        index_usage[idx] = index_usage.get(idx, 0) + 1


# Fill both class buckets of every label; each class gets half of the label's
# total so the per-label request is split 50-50.
for label, total_required in label_requirements.items():
    per_class_quota = total_required // 2
    for category in (0, 1):
        select_indices(label, per_class_quota, category)

# Union of every bucket across all labels.
all_selected_indices = set()
for per_label_buckets in selected_indices.values():
    for bucket_key in ("0", "1"):
        all_selected_indices.update(per_label_buckets[bucket_key])

# Create the balanced DataFrame from the selected rows and display it.
multi_dataset_finetune = multi_dataset_pretrain.loc[list(all_selected_indices)]
multi_dataset_finetune

In [None]:
# Print the positive rate and class counts of each label in the finetune frame.
reported_labels = [
    "label_mortality_1month",
    "label_readmission_1month",
    "label_los_1week",
    "label_c0",
    "label_c1",
    "label_c2",
]

for label in reported_labels:
    label_values = multi_dataset_finetune[label]
    print(
        f"Label: {label} | Mean: {label_values.mean()}\n{label_values.value_counts()}\n"
    )

In [None]:
# Record the patient ID of every row selected for finetuning.
finetune_patient_ids = multi_dataset_finetune["patient_id"]
patient_ids_dict["finetune"]["few_shot"]["all"] = finetune_patient_ids.tolist()

# Drop those patients from the pretraining frame so the two splits do not
# share patients, then store what is left as the pretrain ID list.
is_finetune_patient = multi_dataset_pretrain["patient_id"].isin(finetune_patient_ids)
multi_dataset_pretrain = multi_dataset_pretrain.loc[~is_finetune_patient]
patient_ids_dict["pretrain"] = multi_dataset_pretrain["patient_id"].tolist()

# Persist the updated split dictionary.
split_output_path = f"{DATA_ROOT}/patient_id_dict/dataset_multi_v2.pkl"
save_object_to_disk(patient_ids_dict, split_output_path)

# "mortality_1month=0.5, los_1week=0.5, readmission_1month=0.5, c0=0.5, c1=0.5, c2=0.5"

In [None]:
"""
Current Approach:
    - Pretrain: 141234 Patients
    - Test: 24924 Patients, 132682 Datapoints
    - Finetune: 139514 Unique Patients, 434270 Datapoints
        - Mortality: 26962 Patients
        - Readmission: 48898 Patients
        - Length of Stay: 72686 Patients
        - Condition 0: 122722 Patients
        - Condition 1: 94048 Patients
        - Condition 2: 68954 Patients
"""

In [None]:
# Per-task settings handed to get_finetune_split in the next cell.
# NOTE(review): the save_path values are relative, unlike the DATA_ROOT-based
# paths used elsewhere in this notebook — confirm how get_finetune_split
# resolves them.
task_config = dict(
    mortality=dict(
        dataset=dataset_mortality,
        label_col="label_mortality_1month",
        finetune_size=[250, 500, 1000, 5000, 20000],
        save_path="patient_id_dict/dataset_mortality.pkl",
        split_mode="single_label_balanced",
    ),
    readmission=dict(
        dataset=dataset_readmission,
        label_col="label_readmission_1month",
        finetune_size=[250, 1000, 5000, 20000, 60000],
        save_path="patient_id_dict/dataset_readmission.pkl",
        split_mode="single_label_stratified",
    ),
    length_of_stay=dict(
        dataset=dataset_los,
        label_col="label_los_1week",
        finetune_size=[250, 1000, 5000, 20000, 50000],
        save_path="patient_id_dict/dataset_los.pkl",
        split_mode="single_label_balanced",
    ),
    condition=dict(
        dataset=dataset_condition,
        label_col="all_conditions",
        finetune_size=[50000],
        save_path="patient_id_dict/dataset_condition.pkl",
        split_mode="multi_label_stratified",
    ),
)

In [None]:
# Get finetune split
# Runs once per configured task. Each call receives the current
# patient_ids_dict, and its return value replaces it for the next task.
for task in task_config:
    patient_ids_dict = get_finetune_split(
        task=task,
        task_config=task_config,
        patient_ids_dict=patient_ids_dict,
    )

In [None]:
# Per-task exports, currently disabled:
# dataset_mortality.to_parquet(
#     "patient_sequences/patient_sequences_2048_mortality.parquet",
# )
# dataset_readmission.to_parquet(
#     "patient_sequences/patient_sequences_2048_readmission.parquet",
# )
# dataset_los.to_parquet("patient_sequences/patient_sequences_2048_los.parquet")
# dataset_condition.to_parquet(
#     "patient_sequences/patient_sequences_2048_condition.parquet",
# )

# Write the combined multi-task frame next to the other sequence files.
multi_output_path = (
    f"{DATA_ROOT}/patient_sequences/patient_sequences_2048_multi_v2.parquet"
)
multi_dataset.to_parquet(multi_output_path)

In [None]:
# # Load data
# multi_dataset = pd.read_parquet('patient_sequences/patient_sequences_2048_multi.parquet')
# pid = pickle.load(open('patient_id_dict/dataset_multi.pkl', 'rb'))
# multi_dataset = multi_dataset[multi_dataset['patient_id'].isin(pid['finetune']['few_shot']['all'])]

# # Train Tokenizer
# tokenizer = ConceptTokenizer(data_dir='/h/afallah/odyssey/odyssey/odyssey/data/vocab')
# tokenizer.fit_on_vocab(with_tasks=True)

# # Load datasets
# tasks = ['mortality_1month', 'los_1week'] + [f'c{i}' for i in range(5)]

# train_dataset = FinetuneMultiDataset(
#     data=multi_dataset,
#     tokenizer=tokenizer,
#     tasks=tasks,
#     balance_guide={'mortality_1month': 0.5, 'los_1week': 0.5},
#     max_len=2048,
# )

In [None]:
# dataset_condition = pd.read_parquet('patient_sequences/patient_sequences_2048_condition.parquet')
# pid = pickle.load(open('patient_id_dict/dataset_condition.pkl', 'rb'))
# condition_finetune = dataset_condition.loc[dataset_condition['patient_id'].isin(pid['finetune']['few_shot']['50000'])]
# condition_finetune

In [None]:
# freq = np.array(condition_finetune['all_conditions'].tolist()).sum(axis=0)
# weights = np.clip(0, 50, sum(freq) / freq)
# np.max(np.sqrt(freq)) / np.sqrt(freq)

In [None]:
# sorted(patient_ids_dict['pretrain']) == sorted(pickle.load(open('new_data/patient_id_dict/sample_pretrain_test_patient_ids_with_conditions.pkl', 'rb'))['pretrain'])

In [None]:
# merged_df = pd.merge(dataset_mortality, dataset_readmission, how='outer', on='patient_id')
# final_merged_df = pd.merge(merged_df, dataset_condition, how='outer', on='patient_id')
# final_merged_df

In [None]:
# Performing stratified k-fold split
# skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=SEED)

# for i, (train_index, cv_index) in enumerate(skf.split(dataset, dataset[label_col])):

#     dataset_cv = dataset.iloc[cv_index]
#     dataset_finetune = dataset.iloc[train_index]

#     # Separate positive and negative labeled patients
#     pos_patients = dataset_cv[dataset_cv[label_col] == True]['patient_id'].tolist()
#     neg_patients = dataset_cv[dataset_cv[label_col] == False]['patient_id'].tolist()

#     # Calculate the number of positive and negative patients needed for balanced CV set
#     num_pos_needed = cv_size // 2
#     num_neg_needed = cv_size // 2

#     # Select positive and negative patients for CV set ensuring balanced distribution
#     cv_patients = pos_patients[:num_pos_needed] + neg_patients[:num_neg_needed]
#     remaining_finetune_patients = pos_patients[num_pos_needed:] + neg_patients[num_neg_needed:]

#     # Extract patient IDs for training set
#     finetune_patients = dataset_finetune['patient_id'].tolist()
#     finetune_patients += remaining_finetune_patients

#     # Shuffle each list of patients
#     random.shuffle(cv_patients)
#     random.shuffle(finetune_patients)

#     patient_ids_dict['finetune']['kfold'][f'group{i+1}'] = {'finetune': finetune_patients, 'cv': cv_patients}

In [None]:
# Assuming dataset.event_tokens is your DataFrame column
# dataset.event_tokens.transform(len).plot(kind='hist', bins=100)
# plt.xlim(1000, 8000)  # Limit x-axis to the 1000-8000 range
# plt.ylim(0, 6000)
# plt.xlabel('Length of Event Tokens')
# plt.ylabel('Frequency')
# plt.title('Histogram of Event Tokens Length')
# plt.show()

In [None]:
# len(patient_ids_dict['group3']['cv'])

# dataset.loc[dataset['patient_id'].isin(patient_ids_dict['group1']['cv'])]['label_mortality_1month']

# s = set()
# for i in range(1, 6):
#     s = s.union(set(patient_ids_dict[f'group{i}']['cv']))
#
# len(s)

In [None]:
##### DEAD ZONE | DO NOT ENTER #####

# patient_ids = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_mortality_1month.pkl'), 'rb'))
# patient_ids['finetune']['few_shot'].keys()

# patient_ids2 = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_mortality_2weeks.pkl'), 'rb'))['pretrain']
#
# patient_ids1.sort()
# patient_ids2.sort()
#
# patient_ids1 == patient_ids2
# # dataset.loc[dataset['patient_id'].isin(patient_ids['pretrain'])]

In [None]:
# dataset_readmission = dataset.loc[dataset['num_visits'] > 1]
# dataset_readmission.reset_index(drop=True, inplace=True)
#
# dataset_readmission['last_VS_index'] = dataset_readmission['event_tokens_2048'].transform(lambda seq: get_last_occurence_index(list(seq), '[VS]'))
#
# dataset_readmission['label_readmission_1month'] = dataset_readmission.apply(
#     lambda row: row['event_tokens_2048'][row['last_VS_index'] - 1] in ('[W_0]', '[W_1]', '[W_2]', '[W_3]', '[M_1]'), axis=1
# )
# dataset_readmission['event_tokens_2048'] = dataset_readmission.apply(
#     lambda row: row['event_tokens_2048'][:row['last_VS_index'] - 1], axis=1
# )
# dataset_readmission.drop(['deceased', 'death_after_start', 'death_after_end', 'length'], axis=1, inplace=True)
# dataset_readmission['num_visits'] -= 1
# dataset_readmission['token_length'] = dataset_readmission['event_tokens_2048'].apply(len)
# dataset_readmission = dataset_readmission.apply(lambda row: truncate_and_pad(row), axis=1)
# dataset_readmission['event_tokens_2048'] = dataset_readmission['event_tokens_2048'].transform(
#     lambda token_list: ' '.join(token_list)
# )
#
# dataset_readmission