In [None]:
import os
import sys
import pickle
import random
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd

# Repository root. Defaults to the original cluster checkout; set the
# ODYSSEY_ROOT environment variable to run this notebook from another location
# without editing the cell.
ROOT = os.environ.get("ODYSSEY_ROOT", "/h/afallah/odyssey/odyssey")
# Data directory and the parquet file holding the patient sequences.
DATA_ROOT = f"{ROOT}/odyssey/data/bigbird_data"
DATASET = f"{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet"
# Passed as `max_len` to every process_* helper below.
MAX_LEN = 2048

# All relative paths used later in the notebook resolve against ROOT.
# NOTE(review): presumably this is also what makes the `odyssey` package
# importable in the next statements -- confirm.
os.chdir(ROOT)

from odyssey.utils.utils import seed_everything
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.data.dataset import FinetuneMultiDataset
from odyssey.data.processor import (
    filter_by_num_visit,
    filter_by_length_of_stay,
    get_last_occurence_index,
    check_readmission_label,
    get_length_of_stay,
    get_visit_cutoff_at_threshold,
    process_length_of_stay_dataset,
    process_condition_dataset,
    process_mortality_dataset,
    process_readmission_dataset,
    process_multi_dataset,
    stratified_train_test_split,
    sample_balanced_subset,
    get_pretrain_test_split,
    get_finetune_split,
)

# Single seed value for the notebook, handed to the project's seeding helper.
SEED: int = 23
seed_everything(
    seed=SEED,
)

In [None]:
# Read the complete patient-sequence table from the parquet file at DATASET.
dataset_2048 = pd.read_parquet(DATASET)

# Report the column index, then show the leading rows as the cell output.
print("Current columns: {}".format(dataset_2048.columns))
dataset_2048.head()

In [None]:
# Length-of-stay task: derive the dataset from a copy of the full table so the
# original frame is not handed to the helper directly.
# NOTE(review): threshold=7 presumably means days (cf. the "label_los_1week"
# column used later) -- confirm against process_length_of_stay_dataset.
dataset_2048_los = process_length_of_stay_dataset(
    dataset_2048.copy(deep=True),
    max_len=MAX_LEN,
    threshold=7,
)

In [None]:
# Condition task (rare and common conditions): derive the dataset from a copy
# of the full table.
dataset_2048_condition = process_condition_dataset(
    dataset_2048.copy(deep=True),
    max_len=MAX_LEN,
)

In [None]:
# Mortality task (two weeks or one month): derive the dataset from a copy of
# the full table.
dataset_2048_mortality = process_mortality_dataset(
    dataset_2048.copy(deep=True),
    max_len=MAX_LEN,
)

In [None]:
# Hospital readmission within one month: derive the dataset from a copy of the
# full table.
dataset_2048_readmission = process_readmission_dataset(
    dataset_2048.copy(deep=True),
    max_len=MAX_LEN,
)

In [None]:
# Build the multi-task dataset from the original table plus the four per-task
# datasets created above; the keyword names are the keys the helper receives.
multi_dataset = process_multi_dataset(
    datasets=dict(
        original=dataset_2048,
        mortality=dataset_2048_mortality,
        condition=dataset_2048_condition,
        readmission=dataset_2048_readmission,
        los=dataset_2048_los,
    ),
    max_len=MAX_LEN,
)

In [None]:
# Split data
# Container for the patient-ID splits: "pretrain" and "test" are filled from
# the saved split below; the empty "finetune" sub-dicts are passed on to
# get_finetune_split in a later cell.
patient_ids_dict = {
    "pretrain": [],
    "finetune": {"few_shot": {}, "kfold": {}},
    "test": [],
}

# Get train-test split
# (disabled alternative that computes a fresh split instead of loading the saved one)
# pretrain_ids, test_ids = get_pretrain_test_split(dataset_2048_readmission, stratify_target='label_readmission_1month', test_size=0.2)
# pretrain_ids, test_ids = get_pretrain_test_split(dataset_2048_condition, stratify_target='all_conditions', test_size=0.15)
# patient_ids_dict['pretrain'] = pretrain_ids
# patient_ids_dict['test'] = test_ids

# Load pretrain and test patient IDs
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
# The context manager closes the file handle as soon as loading finishes.
with open("patient_id_dict/dataset_2048_multi.pkl", "rb") as pid_file:
    pid = pickle.load(pid_file)
patient_ids_dict["pretrain"] = pid["pretrain"]
patient_ids_dict["test"] = pid["test"]

# Cell output: True when the pretrain and test IDs together are exactly the set
# of patient IDs in dataset_2048 (assumes both ID collections are lists, so
# `+` concatenates them).
set(pid["pretrain"] + pid["test"]) == set(dataset_2048["patient_id"])

In [None]:
# Per-task settings passed to get_finetune_split: the task's dataset, its
# label column, the fine-tuning subset sizes, the pickle path for the
# resulting patient IDs, and the name of the split strategy.
task_config = dict(
    mortality=dict(
        dataset=dataset_2048_mortality,
        label_col="label_mortality_1month",
        finetune_size=[250, 500, 1000, 5000, 20000],
        save_path="patient_id_dict/dataset_2048_mortality.pkl",
        split_mode="single_label_balanced",
    ),
    readmission=dict(
        dataset=dataset_2048_readmission,
        label_col="label_readmission_1month",
        finetune_size=[250, 1000, 5000, 20000, 60000],
        save_path="patient_id_dict/dataset_2048_readmission.pkl",
        split_mode="single_label_stratified",
    ),
    length_of_stay=dict(
        dataset=dataset_2048_los,
        label_col="label_los_1week",
        finetune_size=[250, 1000, 5000, 20000, 50000],
        save_path="patient_id_dict/dataset_2048_los.pkl",
        split_mode="single_label_balanced",
    ),
    condition=dict(
        dataset=dataset_2048_condition,
        label_col="all_conditions",
        finetune_size=[50000],
        save_path="patient_id_dict/dataset_2048_condition.pkl",
        split_mode="multi_label_stratified",
    ),
)

In [None]:
# Run the fine-tune split once per configured task. Each call receives the
# current patient_ids_dict and its return value replaces it, so results
# carry over from one task to the next.
for task in task_config:
    patient_ids_dict = get_finetune_split(
        task_config=task_config, task=task, patient_ids_dict=patient_ids_dict
    )

In [None]:
# Write each task dataset and the multi-task dataset to parquet, in the same
# order as before. Paths are relative to the current working directory.
for _frame, _destination in (
    (dataset_2048_mortality, "patient_sequences/patient_sequences_2048_mortality.parquet"),
    (dataset_2048_readmission, "patient_sequences/patient_sequences_2048_readmission.parquet"),
    (dataset_2048_los, "patient_sequences/patient_sequences_2048_los.parquet"),
    (dataset_2048_condition, "patient_sequences/patient_sequences_2048_condition.parquet"),
    (multi_dataset, "patient_sequences/patient_sequences_2048_multi.parquet"),
):
    _frame.to_parquet(_destination)

In [None]:
# # Load data
# multi_dataset = pd.read_parquet('patient_sequences/patient_sequences_2048_multi.parquet')
# pid = pickle.load(open('patient_id_dict/dataset_2048_multi.pkl', 'rb'))
# multi_dataset = multi_dataset[multi_dataset['patient_id'].isin(pid['finetune']['few_shot']['all'])]

# # Train Tokenizer
# tokenizer = ConceptTokenizer(data_dir='/h/afallah/odyssey/odyssey/odyssey/data/vocab')
# tokenizer.fit_on_vocab(with_tasks=True)

# # Load datasets
# tasks = ['mortality_1month', 'los_1week'] + [f'c{i}' for i in range(5)]

# train_dataset = FinetuneMultiDataset(
#     data=multi_dataset,
#     tokenizer=tokenizer,
#     tasks=tasks,
#     balance_guide={'mortality_1month': 0.5, 'los_1week': 0.5},
#     max_len=2048,
# )

In [None]:
# dataset_2048_condition = pd.read_parquet('patient_sequences/patient_sequences_2048_condition.parquet')
# pid = pickle.load(open('patient_id_dict/dataset_2048_condition.pkl', 'rb'))
# condition_finetune = dataset_2048_condition.loc[dataset_2048_condition['patient_id'].isin(pid['finetune']['few_shot']['50000'])]
# condition_finetune

In [None]:
# freq = np.array(condition_finetune['all_conditions'].tolist()).sum(axis=0)
# weights = np.clip(0, 50, sum(freq) / freq)
# np.max(np.sqrt(freq)) / np.sqrt(freq)

In [None]:
# sorted(patient_ids_dict['pretrain']) == sorted(pickle.load(open('new_data/patient_id_dict/sample_pretrain_test_patient_ids_with_conditions.pkl', 'rb'))['pretrain'])

In [None]:
# merged_df = pd.merge(dataset_2048_mortality, dataset_2048_readmission, how='outer', on='patient_id')
# final_merged_df = pd.merge(merged_df, dataset_2048_condition, how='outer', on='patient_id')
# final_merged_df

In [None]:
# Performing stratified k-fold split
# skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=SEED)

# for i, (train_index, cv_index) in enumerate(skf.split(dataset, dataset[label_col])):

#     dataset_cv = dataset.iloc[cv_index]
#     dataset_finetune = dataset.iloc[train_index]

#     # Separate positive and negative labeled patients
#     pos_patients = dataset_cv[dataset_cv[label_col] == True]['patient_id'].tolist()
#     neg_patients = dataset_cv[dataset_cv[label_col] == False]['patient_id'].tolist()

#     # Calculate the number of positive and negative patients needed for balanced CV set
#     num_pos_needed = cv_size // 2
#     num_neg_needed = cv_size // 2

#     # Select positive and negative patients for CV set ensuring balanced distribution
#     cv_patients = pos_patients[:num_pos_needed] + neg_patients[:num_neg_needed]
#     remaining_finetune_patients = pos_patients[num_pos_needed:] + neg_patients[num_neg_needed:]

#     # Extract patient IDs for training set
#     finetune_patients = dataset_finetune['patient_id'].tolist()
#     finetune_patients += remaining_finetune_patients

#     # Shuffle each list of patients
#     random.shuffle(cv_patients)
#     random.shuffle(finetune_patients)

#     patient_ids_dict['finetune']['kfold'][f'group{i+1}'] = {'finetune': finetune_patients, 'cv': cv_patients}

In [None]:
# Assuming dataset.event_tokens is your DataFrame column
# dataset.event_tokens.transform(len).plot(kind='hist', bins=100)
# plt.xlim(1000, 8000)  # Limit x-axis to the 1000-8000 range
# plt.ylim(0, 6000)
# plt.xlabel('Length of Event Tokens')
# plt.ylabel('Frequency')
# plt.title('Histogram of Event Tokens Length')
# plt.show()

In [None]:
# len(patient_ids_dict['group3']['cv'])

# dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids_dict['group1']['cv'])]['label_mortality_1month']

# s = set()
# for i in range(1, 6):
#     s = s.union(set(patient_ids_dict[f'group{i}']['cv']))
#
# len(s)

In [None]:
##### DEAD ZONE | DO NOT ENTER #####

# patient_ids = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_1month.pkl'), 'rb'))
# patient_ids['finetune']['few_shot'].keys()

# patient_ids2 = pickle.load(open(join("/h/afallah/odyssey/odyssey/data/bigbird_data", 'dataset_2048_mortality_2weeks.pkl'), 'rb'))['pretrain']
#
# patient_ids1.sort()
# patient_ids2.sort()
#
# patient_ids1 == patient_ids2
# # dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids['pretrain'])]

In [None]:
# dataset_2048_readmission = dataset_2048.loc[dataset_2048['num_visits'] > 1]
# dataset_2048_readmission.reset_index(drop=True, inplace=True)
#
# dataset_2048_readmission['last_VS_index'] = dataset_2048_readmission['event_tokens_2048'].transform(lambda seq: get_last_occurence_index(list(seq), '[VS]'))
#
# dataset_2048_readmission['label_readmission_1month'] = dataset_2048_readmission.apply(
#     lambda row: row['event_tokens_2048'][row['last_VS_index'] - 1] in ('[W_0]', '[W_1]', '[W_2]', '[W_3]', '[M_1]'), axis=1
# )
# dataset_2048_readmission['event_tokens_2048'] = dataset_2048_readmission.apply(
#     lambda row: row['event_tokens_2048'][:row['last_VS_index'] - 1], axis=1
# )
# dataset_2048_readmission.drop(['deceased', 'death_after_start', 'death_after_end', 'length'], axis=1, inplace=True)
# dataset_2048_readmission['num_visits'] -= 1
# dataset_2048_readmission['token_length'] = dataset_2048_readmission['event_tokens_2048'].apply(len)
# dataset_2048_readmission = dataset_2048_readmission.apply(lambda row: truncate_and_pad(row), axis=1)
# dataset_2048_readmission['event_tokens_2048'] = dataset_2048_readmission['event_tokens_2048'].transform(
#     lambda token_list: ' '.join(token_list)
# )
#
# dataset_2048_readmission