In [1]:
from dlhpcstarter.utils import load_config_and_update_args
from dlhpcstarter.command_line_arguments import read_command_line_arguments
from single_model.modelling_single import SingleCXREncoderDecoderModel, CvtWithProjectionHeadConfig
import torch
import transformers
import os
import warnings

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Paths. The defaults are absolute, cluster-specific locations; override them with environment
# variables when running elsewhere:
#   CXRMATE_CKPT_PATH - checkpoint file to convert.
#   CKPT_ZOO_DIR      - local directory holding the pretrained encoder checkpoint and tokenizers
#                       (an alternative location used previously: /datasets/work/hb-mlaifsp-mm/work/checkpoints).
ckpt_path = os.environ.get(
    'CXRMATE_CKPT_PATH',
    '/datasets/work/hb-mlaifsp-mm/work/experiments/mimic_cxr/082_any_single/trial_2/'
    'epoch=17-val_report_chexbert_f1_macro=0.348207.ckpt',
)
ckpt_zoo_dir = os.environ.get('CKPT_ZOO_DIR', '/scratch/pawsey0864/anicolson/checkpoints')


In [3]:
# Load the checkpoint's 'state_dict' entry onto the CPU; its keys use deprecated names that are
# renamed in a later cell.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code — only load
# checkpoints from a trusted source.
state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']

In [4]:
# Build the vision encoder-decoder configuration and instantiate the model.
encoder_ckpt_name = 'microsoft/cvt-21-384-22k'

# Decoder: a 6-layer BERT, used because it accepts token_type_ids.
config_decoder = transformers.BertConfig(vocab_size=30000, num_hidden_layers=6, type_vocab_size=2)

# Encoder: CvT config read from the local checkpoint zoo, with a projection head whose output size
# matches the decoder's hidden size.
config_encoder = CvtWithProjectionHeadConfig.from_pretrained(
    os.path.join(ckpt_zoo_dir, encoder_ckpt_name),
    projection_size=config_decoder.hidden_size,
    local_files_only=True,
)
config = transformers.VisionEncoderDecoderConfig.from_encoder_decoder_configs(
    config_encoder,
    config_decoder,
)

# Register the custom class for AutoModel, then instantiate it from the config; the trained
# weights are loaded from the checkpoint in a later cell.
SingleCXREncoderDecoderModel.register_for_auto_class("AutoModel")
encoder_decoder = SingleCXREncoderDecoderModel(config=config)

In [5]:
# Rename legacy checkpoint keys to the parameter names expected by SingleCXREncoderDecoderModel.
# Every rule whose pattern occurs in the original key is applied, in order, to the same key, and
# the entry is popped exactly once. The previous version popped the original key once per matching
# rule, so a key matching several rules (e.g. 'encoder_decoder.encoder.encoder.…', which matches
# both of the last two rules) raised KeyError on the second pop.
LEGACY_KEY_RENAMES = (
    ('encoder_projection', 'encoder.projection_head.projection'),
    ('last_hidden_state_layer_norm', 'encoder.projection_head.layer_norm'),
    ('encoder.encoder', 'encoder.cvt.encoder'),
    ('encoder_decoder.', ''),
)

# Iterate over a snapshot of the keys because the dict is modified inside the loop.
for key in list(state_dict.keys()):
    new_key = key
    for old, new in LEGACY_KEY_RENAMES:
        # Match against the original key, as the previous per-rule checks did.
        if old in key:
            new_key = new_key.replace(old, new)

    # Keys outside the 'encoder_decoder.' sub-module are flagged for manual inspection.
    if 'encoder_decoder.' not in key:
        warnings.warn(f"{key!r} is not under 'encoder_decoder.'; renamed to {new_key!r}")

    if new_key != key:
        state_dict[new_key] = state_dict.pop(key)

In [6]:
# Load the renamed weights into the model; strict=True requires the checkpoint keys and the
# model's parameter names to match exactly.
encoder_decoder.load_state_dict(state_dict, strict=True)

In [7]:
# Save the model. The default output directory is an absolute cluster path; override it with the
# CXRMATE_SAVE_PATH environment variable when running elsewhere.
save_path = os.environ.get(
    'CXRMATE_SAVE_PATH',
    '/datasets/work/hb-mlaifsp-mm/work/experiments/cxrmate/huggingface_single',
)
encoder_decoder.save_pretrained(save_path)

In [8]:
# Load the report tokenizer from the local checkpoint zoo and save it to save_path. The model and
# image processor are also saved there, so the local export contains every artifact that is pushed
# to the Hub.
tokenizer_dir = os.path.join(ckpt_zoo_dir, 'mimic-cxr-tokenizers', 'bpe_prompt')
tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(tokenizer_dir, local_files_only=True)
tokenizer.save_pretrained(save_path)

In [9]:
# Image processor: load the encoder's preprocessing config from the local checkpoint zoo and save
# it next to the model. local_files_only=True matches the encoder-config and tokenizer loads, so a
# missing local directory fails loudly instead of being looked up on the Hub.
image_processor = transformers.AutoFeatureExtractor.from_pretrained(
    os.path.join(ckpt_zoo_dir, encoder_ckpt_name),
    local_files_only=True,
)
image_processor.save_pretrained(save_path)

In [10]:
# Hub login. Never hard-code an access token in a notebook: it is saved with the notebook and its
# version history. The token previously embedded in this cell has been exposed and must be revoked
# (https://huggingface.co/settings/tokens).
# The token is read from the HF_TOKEN environment variable; if it is unset, it is prompted for
# without echoing.
import getpass

from huggingface_hub import login

hf_token = os.environ.get('HF_TOKEN') or getpass.getpass('Hugging Face token: ')
login(token=hf_token)

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid.
Your token has been saved to /home/anicolson/.cache/huggingface/token
Login successful


In [11]:
# Push the model, tokenizer and image processor to the same Hub repository. The repo id is defined
# once so the three artifacts cannot end up in different repositories through a typo.
hub_repo_id = 'aehrc/mimic-cxr-report-gen-single'
for artifact in (encoder_decoder, tokenizer, image_processor):
    artifact.push_to_hub(hub_repo_id)