In [14]:
import os
import warnings

import torch
import transformers
from huggingface_hub import login

from dlhpcstarter.command_line_arguments import read_command_line_arguments
from dlhpcstarter.utils import load_config_and_update_args
from longitudinal_model.modelling_longitudinal import LongitudinalPromptMultiCXREncoderDecoderModel, CvtWithProjectionHeadConfig

In [15]:
# Versions of the libraries used for this conversion, displayed as (transformers, torch).
tuple(lib.__version__ for lib in (transformers, torch))

('4.31.0', '2.0.1+cu117')

In [16]:
# Hugging Face Hub repository (<organisation>/<model>) that the converted model,
# tokenizer and image processor are pushed to at the end of this notebook.
hub_ckpt_name = 'aehrc/cxrmate'

In [17]:
# Local directory holding pretrained assets used below: the encoder config,
# the tokenizers and the image processor.
ckpt_zoo_dir = '/datasets/work/hb-mlaifsp-mm/work/checkpoints'

# Training checkpoint to convert (the `epoch=...-step=...` file name suggests a
# PyTorch Lightning checkpoint — confirm if reusing for another run).
ckpt_path = '/datasets/work/hb-mlaifsp-mm/work/experiments/mimic_cxr/098_gen_prompt_cxr_bert/trial_1/epoch=0-step=1567-val_report_chexbert_f1_macro=0.413190.ckpt'

In [18]:
# Load the training checkpoint onto the CPU and keep only its 'state_dict' entry,
# whose keys still use the deprecated (pre-conversion) parameter names.
# NOTE: torch.load unpickles the file, so only use it on trusted checkpoints.
checkpoint = torch.load(ckpt_path, map_location='cpu')
state_dict = checkpoint['state_dict']
del checkpoint  # release every other entry of the checkpoint; only the weights are needed

In [19]:
# Decoder: a 6-layer BERT config (BERT is used because it includes token_type_ids).
config_decoder = transformers.BertConfig(
    num_hidden_layers=6,
    type_vocab_size=2,
    vocab_size=30000,  # NOTE(review): hard-coded; must cover the tokenizer loaded later.
)

# Encoder: CvT config read from the local checkpoint zoo, with a projection head whose
# output size equals the decoder's hidden size (presumably so encoder outputs can be
# consumed directly by the decoder).
encoder_ckpt_name = 'microsoft/cvt-21-384-22k'
config_encoder = CvtWithProjectionHeadConfig.from_pretrained(
    os.path.join(ckpt_zoo_dir, encoder_ckpt_name),
    projection_size=config_decoder.hidden_size,
    local_files_only=True,
)

# Combined encoder-decoder config.
config = transformers.VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

# Register the custom model class with AutoModel (custom-code model), then build the
# instance that the converted weights are loaded into.
LongitudinalPromptMultiCXREncoderDecoderModel.register_for_auto_class("AutoModel")
encoder_decoder = LongitudinalPromptMultiCXREncoderDecoderModel(config=config)

trainable params: 147456 || all params: 80916528 || trainable%: 0.18223223814051934


In [20]:
# Rename legacy checkpoint keys to the parameter names of the Hugging Face model.
# Rules are checked in order and only the FIRST matching rule is applied to a key
# (same semantics as an if/elif chain), so the order of this table matters.
LEGACY_KEY_RULES = (
    ('encoder_projection', 'encoder.projection_head.projection'),
    ('last_hidden_state_layer_norm', 'encoder.projection_head.layer_norm'),
    ('encoder.encoder', 'encoder.cvt.encoder'),
    ('encoder_decoder.', ''),
)


def rename_legacy_key(key, rules=LEGACY_KEY_RULES):
    """Return the new name for a legacy state-dict key.

    Args:
        key: Parameter name from the legacy checkpoint.
        rules: Ordered ``(old_substring, new_substring)`` pairs. The first rule whose
            ``old_substring`` occurs in ``key`` is applied (every occurrence is replaced).

    Returns:
        The renamed key, or ``key`` unchanged (with a warning) when no rule matches.
    """
    for old, new in rules:
        if old in key:
            return key.replace(old, new)
    warnings.warn(f'No renaming rule matched key, keeping it unchanged: {key}')
    return key


# Build a new dict rather than popping/inserting into `state_dict` while renaming.
# Values are shared, not copied. A rename that collides with another key raises
# instead of silently overwriting (and later re-moving) a different parameter.
renamed_state_dict = {}
for legacy_key, value in state_dict.items():
    new_key = rename_legacy_key(legacy_key)
    if new_key in renamed_state_dict:
        raise ValueError(f'Renaming {legacy_key!r} produced a duplicate key: {new_key!r}')
    renamed_state_dict[new_key] = value
state_dict = renamed_state_dict

In [21]:
# Load the renamed weights into the model. strict=True (load_state_dict's default) raises on
# any missing or unexpected key; the returned result lists both and is displayed.
load_result = encoder_decoder.load_state_dict(state_dict, strict=True)
load_result

<All keys matched successfully>

In [22]:
# Local export directory for the converted model; later cells also save to this directory.
save_path = '/scratch2/nic261/checkpoints/cxrmate/huggingface_cxrmate'

# Write the model with save_pretrained.
encoder_decoder.save_pretrained(save_path)

In [23]:
# Load the BPE tokenizer from the local checkpoint zoo.
tokenizer_dir = os.path.join(ckpt_zoo_dir, 'mimic-cxr-tokenizers', 'bpe_prompt')
tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(tokenizer_dir, local_files_only=True)

# Every token id must index into the decoder's word-embedding table, whose size is the
# hard-coded decoder vocab_size. Fail loudly here, before anything is pushed, if they disagree.
assert len(tokenizer) <= config_decoder.vocab_size, (
    f'Tokenizer has {len(tokenizer)} tokens but the decoder vocabulary size is {config_decoder.vocab_size}.'
)

# Save the tokenizer into the local export directory as well, so the local copy is as
# complete as the Hub repository (previously it was only pushed, never saved locally).
tokenizer.save_pretrained(save_path)

In [24]:
# Image processor for the encoder checkpoint, loaded from the local checkpoint zoo.
# AutoImageProcessor is used because the vision feature-extractor classes returned by
# AutoFeatureExtractor are deprecated in favour of image processors. local_files_only
# matches how the encoder config and tokenizer are loaded (no network access).
image_processor = transformers.AutoImageProcessor.from_pretrained(
    os.path.join(ckpt_zoo_dir, encoder_ckpt_name),
    local_files_only=True,
)
image_processor.save_pretrained(save_path)

['/scratch2/nic261/checkpoints/cxrmate/huggingface_cxrmate/preprocessor_config.json']

In [25]:
# Hub login: read the access token from a local file so it never appears in the notebook.
with open('/home/nic261/hf_token.txt', 'r', encoding='utf-8') as token_file:
    # strip() instead of [:-1]: slicing drops a real character when the file has no
    # trailing newline, and leaves a stray '\r' on CRLF files.
    hf_token = token_file.readline().strip()
if not hf_token:
    raise ValueError('Hugging Face token file is empty.')
login(token=hf_token)
del hf_token  # don't keep the secret around in the kernel namespace

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid.
Your token has been saved to /home/nic261/.cache/huggingface/token
Login successful


In [26]:
# Upload each artefact (model, tokenizer, image processor — in that order) to the same
# Hub repository. Each push creates its own commit; the last commit's info is displayed.
hub_commits = [
    artefact.push_to_hub(hub_ckpt_name)
    for artefact in (encoder_decoder, tokenizer, image_processor)
]
hub_commits[-1]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

pytorch_model.bin:   0%|          | 0.00/450M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/aehrc/cxrmate/commit/1f014633b98564f21316b32e167b5796381690d8', commit_message='Upload feature extractor', commit_description='', oid='1f014633b98564f21316b32e167b5796381690d8', pr_url=None, pr_revision=None, pr_num=None)