In [None]:
import os
import sys
# Put the local checkouts at the front of the import path. The resulting order
# is MedCAT first, medgpt second - the same as two successive insert(0, ...) calls.
sys.path[:0] = ["/data/zeljko/projects/MedCAT/", "/data/zeljko/projects/medgpt/"]

# Point both Hugging Face cache variables at the same directory.
os.environ.update({
    'HF_DATASETS_CACHE': "/data/zeljko/.cache/huggingface",
    'TRANSFORMERS_CACHE': "/data/zeljko/.cache/huggingface",
})

# IPython autoreload in mode 2: reload all imported modules before running each
# cell, so edits to the local checkouts are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import os
import pandas as pd
import datasets
import numpy as np
from collections import defaultdict
from medcat.cat import CAT
from datasets import DatasetDict
from medgpt.datasets import patient_concept_stream
from medgpt.datasets.filters import filter_by_count, filter_by_type
from medgpt.datasets.utils import get_embeddings_for_tokens, stream_to_separate_examples, add_to_stream, \
                                  remove_parents_from_stream, bucket_concepts, cleanup_stream, \
                                  split_stream, add_age, get_all_splits, add_ttd, add_position_ids, \
                                  fix_types_for_presence
from medgpt.utils.cdb_utils import get_parents_map 
from medgpt.utils.stream_utils import docs2stream, get_patient_count_per_token, get_token_counts_from_dataset
from medgpt.tokenizers.simple_map_tokenizer import SimpleMapTokenizer
from medgpt.tokenizers.utils import encode_stream
from medgpt.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF
from medcat.cdb import CDB
from medgpt.utils import pickle
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments, AutoTokenizer, pipeline, GPT2Tokenizer, LlamaTokenizerFast, LlamaTokenizer
import plotly.express as px
import pickle

from medgpt.config import Config

In [None]:
# Notebook configuration: a base YAML plus one extra YAML
# (presumably applied on top of the base - see medgpt.config.Config).
config = Config(
    yaml_path='/home/ubuntu/projects/medgpt/configs/mimic-mistral.yaml',
    extra_yaml_paths=['/home/ubuntu/projects/medgpt/configs/mimic-seq-len-4096.yaml'],
)

In [None]:
# Display the configured HF output folder as a quick check of the loaded config.
config.path.dataset.hf_output_folder

In [None]:
# Rebuild switch: if true a lot of things will be rebuilt.
# NOTE(review): confirm that the cached steps below actually receive this flag.
FORCE = False

In [None]:
# Torch device built from config.train.device.
DEVICE = torch.device(config.train.device)
# This is internal config, only for this notebook
BATCH_SIZE = 1000
# Passed as num_proc to the dataset filter/map calls further down.
NUM_PROC = 16

In [None]:
# Load the MedCAT model pack, overriding the MetaCAT device with the one from
# the config, and keep a handle on its concept database.
cat = CAT.load_model_pack(
    config.path.cat,
    meta_cat_config_dict={'general': {'device': config.cat.meta.device}},
)
cdb = cat.cdb

In [None]:
# Load the doc2info pickle. The context manager closes the file handle as soon
# as loading finishes instead of leaving it open for the life of the kernel.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open(config.path.dataset.doc2info, 'rb') as fh:
    doc2info = pickle.load(fh)

### Get counts

In [None]:
# Full paths of all "part_*" entries in the annotated-documents folder.
doc_paths = [
    os.path.join(config.path.dataset.annotated_documents, fname)
    for fname in os.listdir(config.path.dataset.annotated_documents)
    if fname.startswith("part_")
]

In [None]:
# Concept counts per patient, restricted to annotations whose 'Subject' meta
# value is 'Patient'. `force` now follows the notebook-wide FORCE switch instead
# of a hard-coded False, so setting FORCE = True actually triggers a rebuild
# (presumably the result is otherwise loaded from save_path - see
# get_token_counts_from_dataset).
pt2cui2cnt = get_token_counts_from_dataset(
                 doc_paths=doc_paths,
                 doc2info=doc2info,
                 meta_requirements={'Subject': 'Patient'},
                 save_path=config.path.dataset.pt2cui2cnt,
                 force=FORCE)
# Display the number of top-level entries.
len(pt2cui2cnt)

### Get pt2stream

In [None]:
# Load the doc2text pickle. The context manager closes the file handle as soon
# as loading finishes instead of leaving it open for the life of the kernel.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open(config.path.dataset.doc2text, 'rb') as fh:
    doc2text = pickle.load(fh)

In [None]:
from tokenizers.pre_tokenizers import WhitespaceSplit, Split, Sequence
from tokenizers import Regex

# Alternative (disabled): split on runs of sentence punctuation / newlines.
#pu = Split(Regex(r'[.;:!?\n]+'), behavior='isolated')
# Each regex match is: an optional run of '$' or space characters, then one or
# more characters that are neither a space nor a newline, then any newlines that
# directly follow. With behavior='isolated' each match should come out as its own
# piece - confirm against the `tokenizers` Split documentation.
ws = Split(Regex(r'[$ ]*[^ \n]+[\n]*'), behavior='isolated')
pre_tokenizer = Sequence([ws]) # Only space, ignore everything else

In [None]:
# Quick check of how the pre-tokenizer splits text with punctuation and blank lines.
pre_tokenizer.pre_tokenize_str("I was. - \n\nrunning")

In [None]:
# Reset the name before the stream is (re)built in the next cell
# (presumably to drop any previously built stream first).
pt2stream = None

In [None]:
# Build the per-patient concept stream from the annotated documents.
# `force` now follows the notebook-wide FORCE switch instead of a hard-coded
# False, so setting FORCE = True actually triggers a rebuild.
pt2stream = docs2stream(doc_paths,
                        doc2info=doc2info,
                        pt2cui2cnt=pt2cui2cnt,
                        entity_type_column='type_ids',
                        meta_requirements={'Subject': 'Patient'}, # Presence will be an option to filter by later
                        historical_meta=None,
                        skip_cuis={'S-418023006', '17971005'},
                        require_time=True,
                        save_path=config.path.dataset.self,
                        tokenizer=pre_tokenizer.pre_tokenize_str,
                        doc2text=doc2text,
                        force=FORCE,
                        cntx_size=config.train.cntx_size,
                        # sentence_limits is optional in the config: pass None when it is missing or empty.
                        sentence_limits=tuple(config.train.sentence_limits) if 'sentence_limits' in config.train and config.train.sentence_limits else None)

In [None]:
# Patient count per token, computed from the stream. `force` now follows the
# notebook-wide FORCE switch instead of a hard-coded False.
cui_by_pt = get_patient_count_per_token(pt2stream, force=FORCE, save_path=config.path.dataset.cui_by_pt)

### Load datasets

In [None]:
# Load the saved stream file through the patient_concept_stream dataset script
# and take its 'train' split.
dataset = datasets.load_dataset(
    os.path.abspath(patient_concept_stream.__file__),
    data_files=[config.path.dataset.self],
)['train']

In [None]:
# Do not run unless you are testing stuff.
# Uncommenting the lines below replaces `dataset` with a random sample of 200 examples.
import random
#from datasets import Dataset
#inds = random.sample([i for i in range(len(dataset))], k=200)
#dataset = Dataset.from_dict(dataset[inds])

### Filter by count, split and checkpoint

In [None]:
# Test-set patient ids as strings, taken from the subject_id column of the test CSV.
patient_ids_test_set = {
    str(subject_id)
    for subject_id in pd.read_csv(config.path.dataset.test_df)['subject_id'].values
}

In [None]:
# Apply the count and minimum-length thresholds from the config
# (max_length=-1 presumably disables the upper length limit - see filter_by_count).
dataset = filter_by_count(
    dataset,
    token_cnt=cui_by_pt,
    min_count=config.train.min_count,
    min_count_global=config.train.min_global_count,
    min_length=config.train.min_length,
    max_length=-1,
    num_proc=NUM_PROC,
)
#dataset = dataset.train_test_split(test_size = 0.05)
# Split by patient id: examples of patients listed in the test CSV go to
# 'test', everything else to 'train'.
train_ds = dataset.filter(
    lambda row: row['patient_id'] not in patient_ids_test_set,
    num_proc=NUM_PROC,
)
test_ds = dataset.filter(
    lambda row: row['patient_id'] in patient_ids_test_set,
    num_proc=NUM_PROC,
)
dataset = DatasetDict({'train': train_ds, 'test': test_ds})

In [None]:
# Rebuild the train/test DatasetDict from train_ds / test_ds
# (same statement as the last line of the previous cell).
dataset = DatasetDict({'train': train_ds, 'test': test_ds})

In [None]:
# Checkpoint the train/test split to disk.
dataset.save_to_disk(config.path.dataset.splits_data)

### Bucket examples and remove parents

In [None]:
# Reload the checkpointed split (presumably so the notebook can be resumed from here).
dataset = datasets.load_from_disk(config.path.dataset.splits_data)

In [None]:
# Inspect the stream of the first training example.
dataset['train'][0]['stream']

In [None]:
# We need to remove parents early on, because it can mess up other things like temporality.
# The context manager closes the pickle file as soon as the CUIs are loaded.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open(config.path.dataset.cuis_in_text, 'rb') as fh:
    cuis = pickle.load(fh)
# Child -> parents map for the CUIs, built from the CDB's 'pt2ch' info with depth=2.
ch2parents = get_parents_map(cuis, cdb.addl_info['pt2ch'], depth=2)
dataset = dataset.map(
        lambda examples: remove_parents_from_stream(examples, ch2parents=ch2parents, separator=None),
        batched=True,
        load_from_cache_file=False,
        num_proc=NUM_PROC)

In [None]:
# Bucket concepts into windows of config.train.days days (converted to seconds).
# time_prefix is left empty; a non-empty prefix such as '<TIME> ' requires a space at the end.
dataset = dataset.map(
    lambda batch: bucket_concepts(
        batch,
        bucket_size_seconds=config.train.days*24*60*60,
        time_prefix='',
    ),
    num_proc=NUM_PROC,
    batched=True,
    load_from_cache_file=False,
)

In [None]:
# Drop the reference to the stream and fetch cui_by_pt again with no stream passed in.
# NOTE(review): with pt2stream=None and force=False this presumably loads the
# counts saved at save_path earlier - confirm in get_patient_count_per_token.
pt2stream = None
cui_by_pt = get_patient_count_per_token(pt2stream, force=False, save_path=config.path.dataset.cui_by_pt)

In [None]:
# Length filter only (both count thresholds are 0). The upper bound is
# 8 x config.train.max_timeline_len, measured in concepts; it only exists to
# stop ultra-long timelines, which in practice never reach this length.
dataset = filter_by_count(
    dataset,
    token_cnt=cui_by_pt,
    min_count=0,
    min_count_global=0,
    min_length=config.train.min_length,
    max_length=8*config.train.max_timeline_len,
    num_proc=NUM_PROC,
)

### Change token type to match presence

In [None]:
# Adjust token types to match presence, using the token-type prefix from the config.
dataset = dataset.map(
    lambda batch: fix_types_for_presence(batch, config.train.token_type_prefix),
    num_proc=NUM_PROC,
    batched=True,
    load_from_cache_file=False,
)

### Add demographics 

In [None]:
# Load the pt2info pickle (used for the demographic tokens added below). The
# context manager closes the file handle as soon as loading finishes.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open(config.path.dataset.pt2info, 'rb') as fh:
    pt2info = pickle.load(fh)

In [None]:
# Add Sex: take the 'Sex' entry from pt2info and add it to each stream with
# token type 'sex', keeping the original casing.
dataset = dataset.map(
    lambda batch: add_to_stream(
        batch,
        pt2info,
        last=False,
        prefix=None,
        key='Sex',
        token_type='sex',
        lowercase=False,
    ),
    num_proc=NUM_PROC,
    batched=True,
    load_from_cache_file=False,
)

In [None]:
# Add Eth: take the 'eth' entry from pt2info and add it to each stream with
# token type 'ethnicity', lowercased.
dataset = dataset.map(
    lambda batch: add_to_stream(
        batch,
        pt2info,
        last=False,
        prefix=None,
        key='eth',
        token_type='ethnicity',
        lowercase=True,
    ),
    num_proc=NUM_PROC,
    batched=True,
    load_from_cache_file=False,
)

In [None]:
# Add age information to each stream, using pt2info.
dataset = dataset.map(
    lambda batch: add_age(batch, pt2info=pt2info),
    num_proc=NUM_PROC,
    batched=True,
    load_from_cache_file=False,
)

### Add start and end tokens `<s> </s>`

In [None]:
# Add <s>
# The BOS token comes from the tokenizer section of the config and is added
# without a space (add_space=False).
dataset = dataset.map(
    lambda batch: add_to_stream(
        batch,
        one_token=config.tokenizer.special_tokens.bos_token,
        token_type='bos_token',
        add_space=False,
    ),
    num_proc=NUM_PROC,
    batched=True,
    load_from_cache_file=False,
)

In [None]:
# Add </s>
# The EOS token comes from the tokenizer section of the config and is added
# with last=True.
dataset = dataset.map(
    lambda batch: add_to_stream(
        batch,
        one_token=config.tokenizer.special_tokens.eos_token,
        token_type='eos_token',
        last=True,
    ),
    num_proc=NUM_PROC,
    batched=True,
    load_from_cache_file=False,
)

In [None]:
# Just in case: checkpoint the dataset to disk before cleanup and encoding.
dataset.save_to_disk(config.path.dataset.just_before_encoding_dataset_split)

### Cleanup

In [None]:
# Reload the pre-encoding checkpoint.
dataset = datasets.load_from_disk(config.path.dataset.just_before_encoding_dataset_split)

In [None]:
#config.train.use_context = False

In [None]:
# Sentence-end markers for the cleanup step. sentence_limits is optional in the
# config, so check for the key first (the same guard the docs2stream call uses)
# and fall back to the default markers when it is missing or empty.
ends = list(config.train.sentence_limits
            if 'sentence_limits' in config.train and config.train.sentence_limits
            else ['.', '!', '?', ';', '_'])
# Clean up every stream; context is added only when config.train.use_context is set.
dataset = dataset.map(
        lambda examples: cleanup_stream(examples, separator='... ', add_context=config.train.use_context, ends=ends),
        batched=True,
        load_from_cache_file=False,
        num_proc=NUM_PROC)

In [None]:
# Inspect one training example after cleanup.
dataset['train'][15]

### Encode

In [None]:
# Load the tokenizer from the path given in the config.
tokenizer = AutoTokenizer.from_pretrained(config.path.tokenizer.self)

In [None]:
# Encode every stream with the tokenizer; the raw "stream" column is dropped
# from the result.
encoded_dataset = dataset.map(
    lambda batch: encode_stream(batch, tokenizer),
    remove_columns=["stream"],
    num_proc=NUM_PROC,
    batched=True,
)

In [None]:
# Persist the encoded dataset.
encoded_dataset.save_to_disk(config.path.dataset.prepared_dataset_split)

In [None]:
# Show where the encoded dataset was written.
config.path.dataset.prepared_dataset_split

# Tests

In [None]:
# Reload the encoded dataset from disk for the checks below.
encoded_dataset = datasets.load_from_disk(config.path.dataset.prepared_dataset_split)

In [None]:
# Show the path the encoded dataset was loaded from.
config.path.dataset.prepared_dataset_split

In [None]:
# Display the dataset summary.
encoded_dataset

In [None]:
# Print one encoded training example position by position:
# token text, token id, time and token type.
c = encoded_dataset['train'][39]
tkns = tokenizer.convert_ids_to_tokens(c['input_ids'])
row_fmt = "{:15} {:7} {:15} {:10}"
for pos, token_id in enumerate(c['input_ids']):
    # To look at selected token types only, guard the print with e.g.
    # c['token_type'][pos] in ['T-11', 'time_sep'].
    print(row_fmt.format(tkns[pos], token_id, c['time'][pos], c['token_type'][pos]))

In [None]:
# Decode the same example back to text.
print(tokenizer.decode(c['input_ids']))

## Prepare the DS for the test folder

In [None]:
from medgpt.tokenizers.utils import pack_text, create_labels, pack_examples
# Load the tokenizer from the configured path (repeats the load done before encoding).
tokenizer = AutoTokenizer.from_pretrained(config.path.tokenizer.self)

In [None]:
# Build packed, labelled copies of the test split for several timeline lengths.
dataset = encoded_dataset['test']
dataset = dataset.remove_columns(['patient_id', 'token_type', 'time'])

# Loop-invariant inputs for create_labels: load the CUIs once (closing the file
# afterwards) and map them to token ids once, instead of on every iteration.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open(config.path.dataset.cuis_in_text, 'rb') as fh:
    cuis = pickle.load(fh)
cui_ids = set(tokenizer.convert_tokens_to_ids(list(cuis)))

# The loop overrides config.train.max_timeline_len (create_labels receives the
# whole config); remember the configured value so it can be put back afterwards
# and the override does not leak into the rest of the notebook.
original_max_timeline_len = config.train.max_timeline_len
try:
    for max_len in [512, 1024, 2048]:
        config.train.max_timeline_len = max_len

        # Do test if needed
        _dataset = dataset.map(
            lambda examples: pack_text(examples, max_len=max_len),
            batched=True,
            batch_size=BATCH_SIZE,
            num_proc=1,
        )
        # Create labels for supervised training
        _dataset = _dataset.map(
            lambda examples: create_labels(examples, config, cui_ids),
            batched=True,
            batch_size=BATCH_SIZE,
            num_proc=8,
        )

        # Output name: second-to-last component of the metrics folder path with
        # its last 7 characters removed, plus 'test_set.hf'.
        # NOTE(review): unless metrics_folder changes with max_timeline_len, every
        # iteration writes to the same location - confirm.
        name = config.path.dataset.metrics_folder.split("/")[-2][:-7] + 'test_set.hf'
        _dataset.save_to_disk(config.path.dataset.test_sets_folder + name)
finally:
    config.train.max_timeline_len = original_max_timeline_len