In [None]:
import os
import sys
# Put the local project checkouts at the front of the import path. Each insert
# goes to position 0, so the last directory listed ends up first on sys.path.
for _repo_dir in ("/data/zeljko/projects/medgpt/", "/data/zeljko/projects/MedCAT/"):
    sys.path.insert(0, _repo_dir)

# Both Hugging Face cache variables point at the same directory. They are set
# here, ahead of the cell that imports `transformers` and `datasets`.
_hf_cache_dir = "/data/zeljko/.cache/huggingface"
os.environ.update(HF_DATASETS_CACHE=_hf_cache_dir, TRANSFORMERS_CACHE=_hf_cache_dir)

# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from transformers import GPT2Config, Trainer, TrainingArguments, AutoTokenizer, pipeline, GPT2Tokenizer, LlamaForCausalLM, AutoModelForCausalLM
from medgpt.tokenizers.simple_map_tokenizer import SimpleMapTokenizer
from medgpt.models.utils import add_cuis_to_model_and_tokenizer
from medgpt.tokenizers.utils import pack_text
import re
import pickle
from medcat.cat import CAT
import pandas as pd
import datasets
import random
import math
import yaml
from medgpt.config import Config

In [None]:
# Experiment configuration; every path and model name used below is read from it.
# NOTE(review): this lives under /home/ubuntu while the sys.path and cache
# directories above are under /data/zeljko — confirm the path for your machine.
CONFIG_YAML_PATH = '/home/ubuntu/projects/medgpt/configs/mimic-mistral.yaml'
config = Config(yaml_path=CONFIG_YAML_PATH)

In [None]:
# Load the MedCAT model pack, overriding the meta-annotation device with the
# value from the config.
meta_cat_overrides = {'general': {'device': config.cat.meta.device}}
cat = CAT.load_model_pack(config.path.cat, meta_cat_config_dict=meta_cat_overrides)

## Prepare the model and tokenizer

In [None]:
# Instantiate the base causal LM named in the config and move it to the first GPU.
base_model_name = config.model.base_name
model = AutoModelForCausalLM.from_pretrained(base_model_name)
# Assigning the result keeps the returned module's repr out of the cell output.
_ = model.to('cuda:0')

In [None]:
# Tokenizer that matches the base model loaded above (same config entry).
tokenizer_source = config.model.base_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

In [None]:
# Load the pickled object stored at `config.path.dataset.cuis_in_text`; it is
# passed to `add_cuis_to_model_and_tokenizer` in the next cell.
# The context manager closes the file handle as soon as loading finishes.
# NOTE(review): unpickling can execute arbitrary code — only load files you trust.
with open(config.path.dataset.cuis_in_text, 'rb') as cuis_file:
    tokens = pickle.load(cuis_file)

In [None]:
# Register the loaded tokens plus the configured special/additional tokens.
# NOTE(review): presumably this extends the tokenizer vocabulary and resizes the
# model embeddings in place — confirm in medgpt.models.utils.
cui_special_tokens = config.tokenizer.special_tokens.to_dict()
cui_additional_tokens = config.tokenizer.additional_tokens.to_list()
add_cuis_to_model_and_tokenizer(
    tokens,
    tokenizer,
    cat,
    model,
    special_tokens=cui_special_tokens,
    additional_tokens=cui_additional_tokens,
)

In [None]:
# Grab the raw weight tensors of the model's input and output embedding layers
# so their sizes can be inspected in the next cell.
input_layer = model.get_input_embeddings()
output_layer = model.get_output_embeddings()
input_embeddings, output_embeddings = input_layer.weight.data, output_layer.weight.data

In [None]:
# Number of rows in the input and output embedding weights.
input_embeddings.shape[0], output_embeddings.shape[0]

In [None]:
# Build lookup tables describing every entry in the tokenizer vocabulary:
#   tkn2type           token string -> type label
#   tkn_id2type        token id     -> type label
#   id2tkn             token id     -> token string
#   token_type2tokens  CDB type id  -> set of concept tokens carrying that type
# The type label is 'number' for short digit-only tokens, the first type id the
# MedCAT CDB lists for tokens that are known CUIs, and 'text' for everything else.
from collections import defaultdict

# Word-boundary markers that tokenizers attach to tokens; removed before the digit test.
#   '_'       plain underscore
#   '\u2581'  SentencePiece marker (Llama/Mistral-style vocabularies)
#   '\u0120'  byte-level BPE marker (GPT-2-style vocabularies)
# Escapes are used so the markers survive re-encoding of the notebook file.
WORD_BOUNDARY_MARKERS = ('_', '\u2581', '\u0120')
# Digit tokens of this length or longer (marker included) are not treated as
# plain numbers, because CUIs can also be long digit strings.
MAX_NUMBER_TOKEN_LEN = 6

tkn2type = {}
tkn_id2type = {}
id2tkn = {}
token_type2tokens = defaultdict(set)

cui2type_ids = cat.cdb.cui2type_ids
# `get_vocab()` is available on both slow and fast tokenizers.
# `tkn_id` (rather than `id`) keeps the builtin `id()` usable in later cells.
for tkn, tkn_id in tokenizer.get_vocab().items():
    id2tkn[tkn_id] = tkn

    bare_tkn = tkn
    for marker in WORD_BOUNDARY_MARKERS:
        bare_tkn = bare_tkn.replace(marker, '')

    type_ids = cui2type_ids.get(tkn.strip())

    if bare_tkn.isdigit() and len(tkn) < MAX_NUMBER_TOKEN_LEN:
        tkn_type = 'number'
    elif type_ids:
        # NOTE(review): if `type_ids` is a set, "first" is arbitrary for CUIs
        # with several types and may differ between kernel sessions.
        tkn_type = next(iter(type_ids))
        token_type2tokens[tkn_type].add(tkn)
    else:
        tkn_type = 'text'

    tkn2type[tkn] = tkn_type
    tkn_id2type[tkn_id] = tkn_type

In [None]:
# Display the configured destination path for the tkn2type map (written in the next cell).
config.path.tokenizer.tkn2type

In [None]:
# Persist the four token lookup tables to the paths given in the config.
# Each file is opened in a context manager so it is flushed and closed
# deterministically, rather than whenever the handle is garbage collected.
_token_map_outputs = (
    (tkn2type, config.path.tokenizer.tkn2type),
    (tkn_id2type, config.path.tokenizer.tkn_id2type),
    (id2tkn, config.path.tokenizer.id2tkn),
    (token_type2tokens, config.path.tokenizer.token_type2tokens),
)
for _token_map, _out_path in _token_map_outputs:
    with open(_out_path, 'wb') as _out_file:
        pickle.dump(_token_map, _out_file)

In [None]:
# Write the tokenizer and model, including the tokens added above, to the
# output locations from the config.
tokenizer_dir = config.path.tokenizer.self
tokenizer.save_pretrained(tokenizer_dir)
model_dir = config.path.model
model.save_pretrained(model_dir)

In [None]:
# Read back the tokenizer and model that were just written. This rebinds
# `tokenizer` and `model`; unlike the initial load, there is no `.to(...)`
# call here.
saved_tokenizer_dir = config.path.tokenizer.self
tokenizer = AutoTokenizer.from_pretrained(saved_tokenizer_dir)
saved_model_dir = config.path.model
model = AutoModelForCausalLM.from_pretrained(saved_model_dir)