In [2]:
from dlhpcstarter.utils import load_config_and_update_args
from dlhpcstarter.command_line_arguments import read_command_line_arguments
from multi_model.modelling_multi import MultiCXREncoderDecoderModel, CvtWithProjectionHeadConfig
import torch
import transformers
import os
import warnings

In [3]:
# Record the library versions that this conversion was run with:
(transformers.__version__, torch.__version__)

('4.41.2', '2.1.1')

In [4]:
# Hub checkpoint name: the Hugging Face Hub repository that the model,
# tokenizer and image processor are pushed to at the end of this notebook.
hub_ckpt_name = 'aehrc/cxrmate-multi-tf'

In [5]:
# Paths:

# Note: multi-image CXR report generation was named variable-image CXR report generation during development; hence the experiment name 083_any_variable.
# Checkpoint whose 'state_dict' entry is loaded and converted by this notebook:
ckpt_path = '/datasets/work/hb-mlaifsp-mm/work/experiments/mimic_cxr/083_any_variable/trial_1/epoch=28-val_report_chexbert_f1_macro=0.383505.ckpt'
# Root directory under which the local encoder checkpoint and the tokenizer are looked up:
ckpt_zoo_dir = '/datasets/work/hb-mlaifsp-mm/work/checkpoints'

In [6]:
# Read the checkpoint onto the CPU and keep only its 'state_dict' entry.
# The keys still use deprecated names; they are renamed in a later cell.
checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
state_dict = checkpoint['state_dict']
del checkpoint  # only the state dict is needed from here on

In [7]:
# Local copy of the encoder checkpoint inside the checkpoint zoo:
encoder_ckpt_name = 'microsoft/cvt-21-384-22k'
encoder_ckpt_dir = os.path.join(ckpt_zoo_dir, encoder_ckpt_name)

# Decoder config; BERT is used as it includes token_type_ids:
config_decoder = transformers.BertConfig(vocab_size=30000, num_hidden_layers=6, type_vocab_size=2)

# Encoder config; the projection head maps onto the decoder's hidden size:
config_encoder = CvtWithProjectionHeadConfig.from_pretrained(
    encoder_ckpt_dir,
    local_files_only=True,
    projection_size=config_decoder.hidden_size,
)

# Combined encoder-decoder config:
config = transformers.VisionEncoderDecoderConfig.from_encoder_decoder_configs(
    encoder_config=config_encoder,
    decoder_config=config_decoder,
)

# Register the custom class for AutoModel, then build the encoder-to-decoder instance:
MultiCXREncoderDecoderModel.register_for_auto_class("AutoModel")
encoder_decoder = MultiCXREncoderDecoderModel(config=config)

In [8]:
# Ordered (deprecated substring, replacement) rules. For each key, only the
# first rule whose substring occurs in the key is applied.
key_renames = (
    ('encoder_projection', 'encoder.projection_head.projection'),
    ('last_hidden_state_layer_norm', 'encoder.projection_head.layer_norm'),
    ('encoder.encoder', 'encoder.cvt.encoder'),
    ('encoder_decoder.', ''),
)

# Iterate over a snapshot of the keys, because the dict is modified in the loop.
for old_key in tuple(state_dict):
    for deprecated, current in key_renames:
        if deprecated in old_key:
            # Move the tensor to its renamed key.
            state_dict[old_key.replace(deprecated, current)] = state_dict.pop(old_key)
            break
    else:
        # No rule matched: the entry is left under its original key.
        warnings.warn(f'Key not found: {old_key}')

In [9]:
# Load the renamed state dict into the model and display the result of the key matching:
load_result = encoder_decoder.load_state_dict(state_dict)
load_result

<All keys matched successfully>

In [10]:
# Write the converted model to a local directory (also used later for the image processor):
save_path = '/scratch3/nic261/checkpoints/cxrmate/huggingface_multi'
encoder_decoder.save_pretrained(save_directory=save_path)

[2024-06-28 15:58:16,764] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)



* 'schema_extra' has been renamed to 'json_schema_extra'


In [11]:
# Load the tokenizer from its sub-directory of the checkpoint zoo (local files only):
tokenizer_subdir = ('mimic-cxr-tokenizers', 'bpe_prompt')
tokenizer_dir = os.path.join(ckpt_zoo_dir, *tokenizer_subdir)
tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(
    tokenizer_dir,
    local_files_only=True,
)

In [14]:
# Image processor: loaded from the local copy of the encoder checkpoint.
image_processor_dir = os.path.join(ckpt_zoo_dir, encoder_ckpt_name)
image_processor = transformers.AutoFeatureExtractor.from_pretrained(image_processor_dir)

# Set the feature extractor type explicitly so that it is stored in the saved config:
image_processor.feature_extractor_type = "ConvNextFeatureExtractor"



In [15]:
# Display the image processor configuration before it is saved:
image_processor

ConvNextFeatureExtractor {
  "_valid_processor_keys": [
    "images",
    "do_resize",
    "size",
    "crop_pct",
    "resample",
    "do_rescale",
    "rescale_factor",
    "do_normalize",
    "image_mean",
    "image_std",
    "return_tensors",
    "data_format",
    "input_data_format"
  ],
  "crop_pct": 0.875,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "ConvNextFeatureExtractor",
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "ConvNextFeatureExtractor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 384
  }
}

In [16]:
# Save the image processor config into the same directory as the model;
# the list of written files is displayed as the cell output.
image_processor.save_pretrained(save_directory=save_path)

['/scratch3/nic261/checkpoints/cxrmate/huggingface_multi/preprocessor_config.json']

In [17]:
# Hub login:
from huggingface_hub import login

# The access token is read from a file so that it never appears in a cell or its output.
with open('/home/nic261/hf_token.txt', 'r') as f:
    # strip() removes surrounding whitespace, including an optional trailing newline.
    # The previous token[:-1] dropped the last character of the token whenever the
    # file did not end with a newline.
    token = f.readline().strip()
if not token:
    raise ValueError('The Hugging Face token file is empty.')
login(token=token)
del token  # do not keep the secret in the kernel namespace

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /scratch3/nic261/.cache/token
Login successful


In [18]:
# Push to hub: upload the model, then the tokenizer, to the same repository.
for artifact in (encoder_decoder, tokenizer):
    artifact.push_to_hub(hub_ckpt_name)

# The image processor is pushed last, as the final expression, so that its
# CommitInfo is displayed as the cell output.
image_processor.push_to_hub(hub_ckpt_name)

CommitInfo(commit_url='https://huggingface.co/aehrc/cxrmate-multi-tf/commit/77dc7a7ae05bc3b6f65109a58afbc8bb886a0039', commit_message='Upload feature extractor', commit_description='', oid='77dc7a7ae05bc3b6f65109a58afbc8bb886a0039', pr_url=None, pr_revision=None, pr_num=None)