In [None]:
import os
import sys
sys.path.insert(0, "/data/zeljko/projects/medgpt/")
sys.path.insert(0, "/data/zeljko/projects/MedCAT/")

os.environ['HF_DATASETS_CACHE'] = "/data/zeljko/.cache/huggingface"
os.environ['TRANSFORMERS_CACHE'] = "/data/zeljko/.cache/huggingface"

%load_ext autoreload
%autoreload 2

In [None]:
import torch
import os
import pandas as pd
import datasets
import numpy as np
from collections import defaultdict
from medcat.cat import CAT
from datasets import DatasetDict
from medgpt.datasets import patient_concept_stream
from medgpt.datasets.filters import filter_by_count, filter_by_type
from medgpt.datasets.utils import get_embeddings_for_tokens, stream_to_separate_examples, add_to_stream, \
                                  remove_parents_from_stream, bucket_concepts, cleanup_stream, \
                                  split_stream, add_age, get_all_splits, add_ttd, add_position_ids, \
                                  fix_types_for_presence, create_risk_prediction_timelines, create_risk_prediction_timelines_but_better
from medgpt.utils.cdb_utils import get_parents_map 
from medgpt.utils.stream_utils import docs2stream, get_patient_count_per_token, get_token_counts_from_dataset
from medgpt.tokenizers.simple_map_tokenizer import SimpleMapTokenizer
from medgpt.tokenizers.utils import encode_stream
from medgpt.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF
from medcat.cdb import CDB
from medgpt.utils import pickle
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments, AutoTokenizer, pipeline, GPT2Tokenizer, LlamaTokenizerFast, LlamaTokenizer
import plotly.express as px
import pickle
from medgpt.tokenizers.utils import pack_text, create_labels, pack_examples, partial_pack_for_risk, trim_ds

from medgpt.config import Config

In [None]:
# Base MIMIC/Mistral config plus an extra YAML file (its name suggests it sets the
# sequence length to 4096 — confirm how Config merges extra_yaml_paths).
# NOTE(review): these paths live under /home/ubuntu while sys.path above points at
# /data/zeljko checkouts — make sure both refer to the same medgpt version.
_config_sources = dict(
    yaml_path='/home/ubuntu/projects/medgpt/configs/mimic-mistral.yaml',
    extra_yaml_paths=['/home/ubuntu/projects/medgpt/configs/mimic-seq-len-4096.yaml'],
)
config = Config(**_config_sources)

In [None]:
# Preview the configured risk-prediction prompts as a list
# (later cells unpack each entry into a (time, text) pair).
config.risk_prediction['prompts'].to_list()

In [None]:
# Rebuild switch: when True, a lot of cached intermediate artefacts are rebuilt.
# NOTE(review): FORCE is not referenced in any visible cell of this notebook — confirm it is still needed.
FORCE = False # If True, a lot of things will be rebuilt

In [None]:
# Torch device, taken from the `train` section of the YAML config.
DEVICE = torch.device(config.train.device)

# Notebook-only settings (not read from the YAML config):
# batch size and worker-process count for the datasets.map calls below.
BATCH_SIZE, NUM_PROC = 1000, 24

In [None]:
# Load the MedCAT model pack, overriding the MetaCAT 'general.device' setting
# with the device configured under config.cat.meta.
_meta_cat_overrides = {'general': {'device': config.cat.meta.device}}
cat = CAT.load_model_pack(config.path.cat, meta_cat_config_dict=_meta_cat_overrides)

# Direct handle on the concept database of the loaded pack.
cdb = cat.cdb

## Split timeline in the middle

In [None]:
# Load the full pre-encoding dataset with all of its splits. Later cells index
# both dataset['train'] and dataset['test'], so this is NOT test-only.
dataset = datasets.load_from_disk(config.path.dataset.just_before_encoding_dataset_split)

In [None]:
# Optional quick-test mode. Set N_DEBUG_SAMPLES to a small integer (e.g. 200) to
# keep a seeded random subset of *every* split, so later cells that index
# dataset['train'] / dataset['test'] keep working. Leave it as None for a full run.
N_DEBUG_SAMPLES = None
DEBUG_SEED = 42

if N_DEBUG_SAMPLES:
    dataset = DatasetDict({
        split: split_ds.shuffle(seed=DEBUG_SEED).select(range(min(N_DEBUG_SAMPLES, len(split_ds))))
        for split, split_ds in dataset.items()
    })

In [None]:
# Sanity check: show the prompts and the min_past_length that are passed to
# create_risk_prediction_timelines in the next cell.
config.risk_prediction.prompts, config.risk_prediction.min_past_length

In [None]:
# Maximum timeline length passed to create_risk_prediction_timelines
# (presumably counted in concepts/CUIs, per the name — confirm in medgpt).
cui_max_timeline_len = 50


def _make_risk_timelines(examples):
    """Build risk-prediction timelines for one batch of examples.

    Uses the configured prompt prefixes and min_past_length, with
    token_type='T-11' and n_risk=1.
    """
    return create_risk_prediction_timelines(
        examples,
        prefixes=config.risk_prediction.prompts.to_list(),
        token_type='T-11',
        n_risk=1,
        min_past_length=config.risk_prediction.min_past_length,
        max_timeline_len=cui_max_timeline_len,
    )


dataset = dataset.map(_make_risk_timelines, batched=True, load_from_cache_file=False, num_proc=NUM_PROC)

In [None]:
def _cleanup_batch(batch):
    """Run cleanup_stream on one batch with the '... ' separator; context is added per config.train.use_context."""
    return cleanup_stream(batch, separator='... ', add_context=config.train.use_context)


dataset = dataset.map(_cleanup_batch, batched=True, load_from_cache_file=False, num_proc=NUM_PROC)

In [None]:
# Load the project tokenizer from the path configured at config.path.tokenizer.self.
tokenizer = AutoTokenizer.from_pretrained(config.path.tokenizer.self)

In [None]:
def _encode_batch(batch):
    """Encode one batch of streams with the loaded tokenizer."""
    return encode_stream(batch, tokenizer)


# The raw 'stream' column is dropped once it has been encoded.
dataset = dataset.map(_encode_batch, batched=True, num_proc=NUM_PROC, remove_columns=["stream"])

In [None]:
# Drop the patient_id column from every split (no later visible cell uses it).
dataset = dataset.remove_columns(['patient_id'])

In [None]:
# The test split only needs trimming: trim_ds is applied with
# max_len=config.train.max_timeline_len (presumably truncating each encoded
# timeline to that length — confirm in medgpt.tokenizers.utils).
# NOTE(review): other splits are not trimmed in this notebook; confirm that long
# training examples are handled downstream.
dataset['test'] = dataset['test'].map(
    lambda examples: trim_ds(examples, max_len=config.train.max_timeline_len),
    batched=True,
    batch_size=BATCH_SIZE,  # notebook-wide batch size instead of a duplicated literal
    num_proc=1,  # single process, as originally chosen (reason not documented)
)
dataset

In [None]:
# One risk type name per configured prompt, of the form 'risk-<time>-T-11'.
# Only the time part of each (time, text) prompt pair is needed here.
type_names = [f'risk-{time}-T-11' for time, _text in config.risk_prediction.prompts.to_list()]

In [None]:
# Build labels for every split. create_labels receives the risk type names from
# the previous cell and the tokenizer's EOS id as an extra label id.
# A fresh set is built per batch, in case create_labels mutates it.
dataset = dataset.map(
    lambda examples: create_labels(examples, config, type_names=type_names, extra_label_ids={tokenizer.eos_token_id}),
    batched=True,
    batch_size=BATCH_SIZE,  # notebook-wide batch size instead of a duplicated literal
    num_proc=8,  # NOTE(review): deliberately lower than NUM_PROC? confirm (e.g. memory limits)
)

In [None]:
# Token-by-token view of the first training example:
# token, id, time, token type, label and the CDB name looked up for the token.
c = dataset['train'][0]
tkns = tokenizer.convert_ids_to_tokens(c['input_ids'])
print(len(tkns))
row_fmt = "{:15} {:7} {:15} {:20} {} {}"
for i, tkn in enumerate(tkns):
    print(row_fmt.format(tkn, c['input_ids'][i], c['time'][i], c['token_type'][i], c['labels'][i],
                         cat.cdb.get_name(tkn)))

In [None]:
# Token-by-token view of one test example (index 4):
# token, id, time, token type and label.
c = dataset['test'][4]
tkns = tokenizer.convert_ids_to_tokens(c['input_ids'])
print(len(tkns))
row_fmt = "{:15} {:7} {:15} {:20} {}"
for i, tkn in enumerate(tkns):
    print(row_fmt.format(tkn, c['input_ids'][i], c['time'][i], c['token_type'][i], c['labels'][i]))

In [None]:
# Display the DatasetDict summary (splits, columns, row counts) before saving.
dataset

In [None]:
# Persist the prepared risk-prediction dataset to config.path.dataset.prepared_risk_dataset.
dataset.save_to_disk(config.path.dataset.prepared_risk_dataset)

In [None]:
# Display the configured trained_model_risk path (presumably where the risk model
# trained on this dataset is stored — confirm against the training notebook).
config.path.trained_model_risk