In [57]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [58]:
from medicap.modelling_medicap import MedICapEncoderDecoderModel, CvtWithProjectionHeadConfig
from medicap.configuration_medicap import MedICapConfig
import torch
import transformers
import os
import warnings

In [59]:
# Display the transformers and torch versions used for this conversion:
transformers.__version__, torch.__version__

('4.31.0', '2.0.1+cu117')

In [60]:
# Hub checkpoint name (used as the repo_id when the saved folder is uploaded):
hub_ckpt_name = 'aehrc/medicap'

In [61]:
# Paths:
# Checkpoint file whose 'state_dict' entry is loaded into the model:
ckpt_path = '/datasets/work/hb-mlaifsp-mm/work/experiments/imageclefmed_caption_2023/007_no_ca_scst/trial_2/epoch=2-step=5712-val_bertscore_f1=0.645100.ckpt'
# Directory under which the encoder and decoder checkpoint folders are looked up:
ckpt_zoo_dir = '/datasets/work/hb-mlaifsp-mm/work/checkpoints'

In [62]:
# Load state dict with deprecated keys (tensors are mapped onto the CPU).
# NOTE: torch.load unpickles the file, so only use it on checkpoints from a trusted source.
state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict']

In [63]:
# Checkpoint names, resolved relative to the local checkpoint zoo:
decoder_ckpt_name = 'distilgpt2'
encoder_ckpt_name = 'microsoft/cvt-21-384-22k'
decoder_ckpt_dir = os.path.join(ckpt_zoo_dir, decoder_ckpt_name)
encoder_ckpt_dir = os.path.join(ckpt_zoo_dir, encoder_ckpt_name)

# Decoder config: GPT2 flagged as a decoder, with cross attention switched off:
config_decoder = transformers.GPT2Config.from_pretrained(decoder_ckpt_dir, local_files_only=True)
config_decoder.is_decoder = True
config_decoder.add_cross_attention = False

# Two extra vocabulary entries for the padding and beginning of sentence tokens:
config_decoder.vocab_size = config_decoder.vocab_size + 2

# Encoder config: the projection size is set to the decoder's hidden size:
config_encoder = CvtWithProjectionHeadConfig.from_pretrained(
    encoder_ckpt_dir,
    projection_size=config_decoder.hidden_size,
    local_files_only=True,
)

# Combined encoder-to-decoder config:
config = MedICapConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

# Register the model class for "AutoModel", then build the model from the combined config:
MedICapEncoderDecoderModel.register_for_auto_class("AutoModel")
encoder_decoder = MedICapEncoderDecoderModel(config=config)

In [64]:
# Drop every entry whose key contains one of these fragments.
# NOTE(review): presumably these are GPT2 attention buffers stored in the checkpoint
# that the current model does not expect -- confirm.
unwanted_key_fragments = ('masked_bias', '.attn.bias')
state_dict = {
    name: value
    for name, value in state_dict.items()
    if all(fragment not in name for fragment in unwanted_key_fragments)
}

In [65]:
# Load the filtered state dict into the model (the result is shown as the cell output):
encoder_decoder.load_state_dict(state_dict)

<All keys matched successfully>

In [66]:
# Tokenizer, loaded from the same local folder as the decoder config:
tokenizer_dir = os.path.join(ckpt_zoo_dir, decoder_ckpt_name)
tokenizer = transformers.GPT2TokenizerFast.from_pretrained(tokenizer_dir, local_files_only=True)

# Register the beginning of sentence and padding tokens; the call stays the last
# expression so that its return value is displayed as the cell output:
special_tokens = {'bos_token': '[BOS]', 'pad_token': '[PAD]'}
tokenizer.add_special_tokens(special_tokens)

2

In [67]:
# Image processor, restricted to local files like the config and tokenizer loaders
# so that a missing folder fails instead of falling back to a Hub lookup:
image_processor = transformers.AutoFeatureExtractor.from_pretrained(
    os.path.join(ckpt_zoo_dir, encoder_ckpt_name),
    local_files_only=True,
)

In [68]:
# Save the model, config, tokenizer and image processor into one directory:
save_path = '/scratch2/nic261/checkpoints/medicap'
for artefact in (encoder_decoder, config, tokenizer):
    artefact.save_pretrained(save_path)

# Kept as the last expression so that the list of written files is displayed:
image_processor.save_pretrained(save_path)

['/scratch2/nic261/checkpoints/medicap/preprocessor_config.json']

In [69]:
# Hub login:
from huggingface_hub import login

# The access token is read from a file so that it never appears in the notebook.
with open('/home/nic261/hf_token.txt', 'r') as f:
    # strip() removes the line ending (including '\r\n') without truncating the
    # token when the file has no trailing newline, unlike slicing with [:-1].
    token = f.readline().strip()
login(token=token)
del token  # Do not keep the secret in the kernel namespace.

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/nic261/.cache/huggingface/token
Login successful


In [70]:
# Push to hub:
from huggingface_hub import HfApi

api = HfApi()
# Upload the contents of save_path to the model repository named by hub_ckpt_name.
# The call is the last expression, so its return value is shown as the cell output.
api.upload_folder(
    folder_path=save_path,
    repo_id=hub_ckpt_name,
    repo_type='model',
)

'https://huggingface.co/aehrc/medicap/tree/main/'