In [13]:
from dlhpcstarter.utils import load_config_and_update_args
from dlhpcstarter.command_line_arguments import read_command_line_arguments
from longitudinal_model.modelling_longitudinal import LongitudinalPromptVariableCXREncoderDecoderModel, CvtWithProjectionHeadConfig
import torch
import transformers
import os
import warnings

In [14]:
# Repository identifier on the Hugging Face Hub that the artefacts are pushed to:
hub_ckpt_name: str = "aehrc/cxrmate"

In [15]:
# Directory holding the locally cached pre-trained checkpoints:
ckpt_zoo_dir = '/scratch/pawsey0864/anicolson/checkpoints'

# Training checkpoint to convert (trial directory + checkpoint file name):
ckpt_path = (
    '/scratch/pawsey0864/anicolson/experiments/mimic_cxr/098_gen_prompt_cxr_bert/trial_0/'
    'epoch=0-step=3917-val_report_chexbert_f1_macro=0.425015.ckpt'
)

In [16]:
# Read the checkpoint onto the CPU and keep only its state dict
# (its keys still use the deprecated names, which are renamed further down):
cpu_device = torch.device('cpu')
state_dict = torch.load(ckpt_path, map_location=cpu_device)['state_dict']

In [17]:
# Decoder config (BERT, as it includes token_type_ids):
decoder_kwargs = dict(vocab_size=30000, num_hidden_layers=6, type_vocab_size=2)
config_decoder = transformers.BertConfig(**decoder_kwargs)

# Encoder config, read from the local checkpoint directory; the projection
# size is set to the decoder's hidden size:
encoder_ckpt_name = 'microsoft/cvt-21-384-22k'
encoder_ckpt_dir = os.path.join(ckpt_zoo_dir, encoder_ckpt_name)
config_encoder = CvtWithProjectionHeadConfig.from_pretrained(
    encoder_ckpt_dir,
    local_files_only=True,
    projection_size=config_decoder.hidden_size,
)

# Combined encoder-decoder config:
config = transformers.VisionEncoderDecoderConfig.from_encoder_decoder_configs(
    config_encoder,
    config_decoder,
)

# Register the model class for AutoModel, then build the encoder-to-decoder instance:
LongitudinalPromptVariableCXREncoderDecoderModel.register_for_auto_class("AutoModel")
encoder_decoder = LongitudinalPromptVariableCXREncoderDecoderModel(config=config)

trainable params: 147456 || all params: 80916528 || trainable%: 0.18223223814051934


In [18]:
# Deprecated key fragment -> replacement. Fragments are tried in this order and
# only the first one found in a key is applied:
key_renames = (
    ('encoder_projection', 'encoder.projection_head.projection'),
    ('last_hidden_state_layer_norm', 'encoder.projection_head.layer_norm'),
    ('encoder.encoder', 'encoder.cvt.encoder'),
    ('encoder_decoder.', ''),
)

# Iterate over a snapshot of the keys, as the dict is modified inside the loop:
for old_key in tuple(state_dict):
    for fragment, replacement in key_renames:
        if fragment in old_key:
            state_dict[old_key.replace(fragment, replacement)] = state_dict.pop(old_key)
            break
    else:
        # No fragment matched; the entry is left under its original key.
        warnings.warn(f'Key not found: {old_key}')

In [19]:
# Copy the renamed weights into the model (strict, so a missing or unexpected key raises):
load_result = encoder_decoder.load_state_dict(state_dict, strict=True)
load_result

<All keys matched successfully>

In [20]:
# Directory that receives the converted model files:
save_path = (
    '/scratch/pawsey0864/anicolson/experiments/cxrmate/'
    'huggingface_variable'
)

# Write the model to that directory:
encoder_decoder.save_pretrained(save_path)

In [26]:
# Tokenizer, read from the local checkpoint directory only:
tokenizer_dir = os.path.join(ckpt_zoo_dir, 'mimic-cxr-tokenizers', 'bpe_prompt')
tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(
    tokenizer_dir,
    local_files_only=True,
)

In [27]:
# Image processor, taken from the encoder's checkpoint directory and saved next to the model:
image_processor_dir = os.path.join(ckpt_zoo_dir, encoder_ckpt_name)
image_processor = transformers.AutoFeatureExtractor.from_pretrained(image_processor_dir)
saved_files = image_processor.save_pretrained(save_path)
saved_files

['/scratch/pawsey0864/anicolson/experiments/cxrmate/huggingface_variable/preprocessor_config.json']

In [30]:
# Hub login:
from huggingface_hub import login

# The access token is read from a file rather than written into the notebook.
with open('/home/anicolson/hf_token.txt', 'r') as f:
    # Strip surrounding whitespace instead of slicing off the last character:
    # slicing truncates the token when the file has no trailing newline and
    # leaves a stray '\r' when the line ends in '\r\n'.
    token = f.readline().strip()
login(token=token)

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid.
Your token has been saved to /scratch/pawsey0864/anicolson/checkpoints/token
Login successful


In [31]:
# Upload the model, tokenizer and image processor (in that order) to the same Hub repository:
upload_order = (encoder_decoder, tokenizer, image_processor)
commit_infos = [artefact.push_to_hub(hub_ckpt_name) for artefact in upload_order]

# Display the result of the final upload:
commit_infos[-1]

pytorch_model.bin: 100%|██████████| 450M/450M [00:49<00:00, 9.13MB/s]
Upload 1 LFS files: 100%|██████████| 1/1 [00:50<00:00, 50.02s/it]


CommitInfo(commit_url='https://huggingface.co/aehrc/cxrmate/commit/f468c8a98b73a47e1fcaace298224f589bc6f503', commit_message='Upload feature extractor', commit_description='', oid='f468c8a98b73a47e1fcaace298224f589bc6f503', pr_url=None, pr_revision=None, pr_num=None)