In [1]:
from cxrmate_ed.modelling_cxrmate_ed import MIMICIVEDCXRMultimodalModel
from cxrmate_ed.modelling_uniformer import MultiUniFormerWithProjectionHead
from cxrmate_ed.configuration_uniformer import UniFormerWithProjectionHeadConfig
from cxrmate_ed.records import EDCXRSubjectRecords
from cxrmate_ed.tables import NUM_ED_CXR_TOKEN_TYPE_IDS
import torch
import transformers
import os
import warnings
import math
import datetime

In [2]:
# Display the library versions this notebook was executed with (transformers first, then torch):
(transformers.__version__, torch.__version__)

('4.40.2', '2.1.1')

In [3]:
# Local inputs: the checkpoint zoo directory (the tokenizer path is built from it)
# and the database file handed to the subject records.
ckpt_zoo_dir = '/datasets/work/hb-mlaifsp-mm/work/checkpoints'
database_path = '/scratch3/nic261/database/mimic_iv_duckdb_rev_d.db'

# Hub repository name that the model and tokenizer are pushed to:
hub_ckpt_name = 'aehrc/cxrmate-ed'

In [4]:
# Path of the training checkpoint whose weights are exported by this notebook.
# The literal is split across lines for readability only; the pieces are
# concatenated at compile time into a single path.
ckpt_path = (
    '/datasets/work/hb-mlaifsp-mm/work/repositories/transmodal/cxrmate2/'
    'experiments/cxrmate2/cxrmate2/005_scst_cxrbert_bertscore/trial_3/'
    'epoch=22-step=54924-val_findings_bertscore_f1=0.443515.ckpt'
)

In [5]:
# Read the checkpoint onto the CPU and keep only its 'state_dict' entry.
# The keys still carry the deprecated 'encoder_decoder.' prefix at this point.
state_dict = torch.load(
    ckpt_path,
    map_location=torch.device('cpu'),
)['state_dict']

In [6]:
# Register the custom model class under the "AutoModel" auto class.
# NOTE(review): presumably this is what lets the exported checkpoint be loaded
# through `AutoModel` -- confirm against the transformers documentation
# for `register_for_auto_class`.
MIMICIVEDCXRMultimodalModel.register_for_auto_class("AutoModel")

In [7]:
# Subject records built from `database_path`; time deltas are mapped through 1 / sqrt(x + 1):
records = EDCXRSubjectRecords(
    database_path=database_path,
    time_delta_map=lambda x: 1 / math.sqrt(x + 1),
)

# Keep only the tables used here, and set the text columns of the sectioned reports table:
records.ed_module_tables = {
    name: records.ed_module_tables[name] for name in ('edstays', 'triage', 'vitalsign')
}
records.mimic_cxr_tables = {
    name: records.mimic_cxr_tables[name] for name in ('mimic_cxr_sectioned',)
}
records.mimic_cxr_tables['mimic_cxr_sectioned'].text_columns = ['indication', 'history']

# Table name -> total number of indices, for every loaded table that has value or index columns:
index_value_encoder_config = {
    table_name: table.total_indices
    for table_name, table in {**records.ed_module_tables, **records.mimic_cxr_tables}.items()
    if table.load and (table.value_columns or table.index_columns)
}

# Tokenizer for the decoder, loaded from the checkpoint zoo directory:
encoder_decoder_ckpt_name = (
    f'{ckpt_zoo_dir}/mimic_iv_tokenizers/'
    'bpe_cxr_findings_impression_indication_history_ed_medrecon_vitalsign_triage'
)
tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(encoder_decoder_ckpt_name)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# List each special token with its identifier:
print('Description, Special token, Index')
for description, special_token in tokenizer.special_tokens_map.items():
    if description == 'additional_special_tokens':
        # Additional special tokens are printed one per line.
        additional_pairs = zip(tokenizer.additional_special_tokens, tokenizer.additional_special_tokens_ids)
        for extra_token, extra_token_id in additional_pairs:
            print(f'additional_special_token, {extra_token}, {extra_token_id}')
        continue
    # The identifier is read from the tokenizer attribute named '<description>_id'.
    special_token_id = getattr(tokenizer, description + "_id")
    print(f'{description}, {special_token}, {special_token_id}')

# Llama decoder configuration; the vocabulary size follows the tokenizer:
config_decoder = transformers.LlamaConfig(
    vocab_size=len(tokenizer),
    hidden_size=768,
    intermediate_size=3072,
    num_hidden_layers=6,
    num_attention_heads=12,
    max_position_embeddings=2048,
)

# Extra attributes attached to the decoder configuration:
config_decoder.is_decoder = True
config_decoder.index_value_encoder_config = index_value_encoder_config
config_decoder.index_value_encoder_intermediate_size = 2048
config_decoder.ed_module_columns = [
    f'{table_name}_{column}'
    for table_name, table in records.ed_module_tables.items()
    for column in table.text_columns
]
config_decoder.mimic_cxr_columns = [
    column
    for table in records.mimic_cxr_tables.values()
    for column in table.text_columns
]
config_decoder.token_type_to_token_type_id = records.token_type_to_token_type_id
config_decoder.num_token_types = NUM_ED_CXR_TOKEN_TYPE_IDS
config_decoder.include_time_delta = True
config_decoder.time_delta_monotonic_inversion = True

# Time delta computed between two identical timestamps (both timestamp 0):
timestamp_zero = datetime.datetime.fromtimestamp(0)
config_decoder.zero_time_delta_value = records.compute_time_delta(
    timestamp_zero,
    timestamp_zero,
    to_tensor=False,
)
config_decoder.add_time_deltas = True

# Token type identifiers of the report sections (findings, then impression):
config_decoder.section_ids = [
    records.token_type_to_token_type_id[section] for section in ('findings', 'impression')
]

# The decoder's padding token identifier comes from the tokenizer:
config_decoder.pad_token_id = tokenizer.pad_token_id

# Encoder configuration; its projection size equals the decoder's hidden size:
config_encoder = UniFormerWithProjectionHeadConfig(projection_size=config_decoder.hidden_size)
encoder_ckpt_name = 'uniformer_base_tl_384'

# Combine both configurations, switch off decoder cross-attention, and build the model:
config = transformers.VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
config.decoder.add_cross_attention = False
encoder_decoder = MIMICIVEDCXRMultimodalModel(
    config=config,
    DefaultEncoderClass=MultiUniFormerWithProjectionHead,
    DefaultDecoderClass=transformers.LlamaForCausalLM,
)

Description, Special token, Index
bos_token, [BOS], 1
eos_token, [EOS], 2
unk_token, [UNK], 0
sep_token, [SEP], 3
pad_token, [PAD], 4
cls_token, [BOS], 1
mask_token, [MASK], 5


In [8]:
# Rename the checkpoint keys by stripping the 'encoder_decoder.' prefix so that
# they can be loaded into `encoder_decoder`. Only a *leading* prefix is removed:
# testing for the substring anywhere in the key and replacing every occurrence
# could silently mangle a name that merely contains 'encoder_decoder.'.
for key in list(state_dict.keys()):  # iterate over a snapshot; the dict is mutated below
    if key.startswith('encoder_decoder.'):
        state_dict[key.removeprefix('encoder_decoder.')] = state_dict.pop(key)
    else:
        # Keys without the prefix are left in the dict unchanged and only reported.
        warnings.warn(f'Key not found: {key}')

In [9]:
# Load renamed state dict:
# The call's return value is the cell's last expression, so the notebook
# displays it as the key-matching report.
# NOTE(review): no `strict` argument is passed, so the default applies --
# confirm the model class does not override `load_state_dict`.
encoder_decoder.load_state_dict(state_dict)

<All keys matched successfully>

In [10]:
# Save model:
# Write the model to a local directory through `save_pretrained`.
save_path = '/scratch3/nic261/checkpoints/cxrmate_ed'
encoder_decoder.save_pretrained(save_path)

In [11]:
# Hub login:
from huggingface_hub import login

# Read the access token from the first line of a local file so that the secret
# never appears in the notebook itself.
with open('/home/nic261/hf_token.txt', 'r') as f:
    # strip() removes a trailing newline (or '\r\n') when present and, unlike
    # slicing off the last character, leaves the token intact when the file
    # has no trailing newline.
    token = f.readline().strip()
if not token:
    raise ValueError('No Hugging Face token found on the first line of /home/nic261/hf_token.txt')
login(token=token)
del token  # do not keep the secret in the kernel namespace

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /scratch3/nic261/.cache/token
Login successful


In [12]:
# Push to hub:
# Upload the model first and then the tokenizer, both to the repository named
# by `hub_ckpt_name`. The tokenizer call is the cell's last expression, so its
# return value is what the notebook displays.
encoder_decoder.push_to_hub(hub_ckpt_name)
tokenizer.push_to_hub(hub_ckpt_name)

CommitInfo(commit_url='https://huggingface.co/aehrc/cxrmate-ed/commit/d7831418c6fed23bc77f78eade260f758fb61d56', commit_message='Upload tokenizer', commit_description='', oid='d7831418c6fed23bc77f78eade260f758fb61d56', pr_url=None, pr_revision=None, pr_num=None)