In [1]:
import sys
import os
from argparse import ArgumentParser

import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from tqdm import tqdm
import torch
from sfm.models.progpt.progpt import ProGPTModel
from sfm.models.progpt.progpt_config import ProGPTConfig
from sfm.models.pfm.pfm_config import PFMConfig
from sfm.data.sci_data.SFMDecTokenizer import SFMDecTokenizer
from sfm.utils import arg_utils


  from .autonotebook import tqdm as notebook_tqdm


[2024-04-14 19:44:45,725] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[[32m2024-04-14 19:44:46.625[0m][[36mINFO[0m]: flash_attn not installed, use default attn


In [2]:
def get_args_and_tokenizer():
    """Create the argument namespace and tokenizer needed to build the model.

    The parser is populated from ``PFMConfig`` and ``ProGPTConfig`` via
    ``arg_utils.add_dataclass_to_parser`` and then run on an empty argument
    list, so nothing is read from the command line. A fixed set of
    attributes (paths, strategy, encoder dimensions) is then overwritten
    on the resulting namespace.

    Returns:
        tuple: ``(args, tokenizer)`` where ``args`` additionally carries
        ``vocab_size`` (``len(tokenizer)``) and ``pad_token_id`` copied from
        the tokenizer.
    """
    arg_parser = arg_utils.add_dataclass_to_parser(
        [PFMConfig, ProGPTConfig], ArgumentParser()
    )
    # Empty argv: the notebook kernel's own command line must not be parsed.
    args = arg_parser.parse_args(args=[])

    # Values specific to this conversion run, applied on top of the parsed namespace.
    overrides = {
        "llm_model_name_or_path": "/fastdata/peiran/scigpt/ckpt/stageB.prot/global_step224655",
        "tokenizer_path": "/fastdata/peiran/scigpt",
        "save_dir": '/fastdata/peiran/nlm/checkpoints/stageB/global_step12386/',
        "load_ckpt": False,
        "strategy": "DDP",
        "encoder_layers": 33,
        "encoder_embed_dim": 1280,
        "encoder_ffn_embed_dim": 5120,
        "encoder_attention_heads": 20,
    }
    for attr_name, attr_value in overrides.items():
        setattr(args, attr_name, attr_value)

    # Sentencepiece model prefixes for each sequence type live under tokenizer_path.
    spm_root = args.tokenizer_path
    tokenizer = SFMDecTokenizer.from_pretrained(
        args.llm_model_name_or_path,
        prot_spm_path=os.path.join(spm_root, "ur50bpe/bpe"),
        dna_spm_path=os.path.join(spm_root, "dnabpe/bpe"),
        rna_spm_path=os.path.join(spm_root, "rnabpe/bpe"),
    )

    # Record the tokenizer-derived sizes on args (the tokenizer may have added tokens).
    args.vocab_size = len(tokenizer)
    args.pad_token_id = tokenizer.pad_token_id

    return args, tokenizer

In [3]:
# Configuration and tokenizer come from the helper defined in the previous cell.
args, tokenizer = get_args_and_tokenizer()
num_tokens = len(tokenizer)

# Instantiate the model inside accelerate's init_empty_weights context; the
# tokenizer length is passed as ProGPTModel's second positional argument.
with init_empty_weights():
    model = ProGPTModel(args, num_tokens)
# print(model.state_dict().keys())


[[32m2024-04-14 19:44:50.453[0m][[36mINFO[0m]: Loading protein sentencepiece model from /fastdata/peiran/scigpt/ur50bpe/bpe.model and /fastdata/peiran/scigpt/ur50bpe/bpe.vocab
[[32m2024-04-14 19:44:50.456[0m][[36mINFO[0m]: Loading DNA sentencepiece model from /fastdata/peiran/scigpt/dnabpe/bpe.model and /fastdata/peiran/scigpt/dnabpe/bpe.vocab
[[32m2024-04-14 19:44:50.456[0m][[36mINFO[0m]: Loading RNA sentencepiece model from /fastdata/peiran/scigpt/rnabpe/bpe.model and /fastdata/peiran/scigpt/rnabpe/bpe.vocab
[[32m2024-04-14 19:44:50.689[0m][[36mINFO[0m]: Tokenizer has 40014 tokens
[[32m2024-04-14 19:44:50.758[0m][[36mINFO[0m]: Trainer args: Namespace(num_classes=1, encoder_attention_heads=20, encoder_ffn_embed_dim=5120, encoder_embed_dim=1280, encoder_layers=33, num_pred_attn_layer=4, num_3d_bias_kernel=128, max_length=1024, pbc_expanded_token_cutoff=512, pbc_expanded_num_cell_per_direction=10, multi_hop_max_dist=20, droppath_prob=0.0, act_dropout=0.0, attn_dropou

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LlamaTokenizer'. 
The class this function is called from is 'SFMDecTokenizer'.
You are using the default legacy behaviour of the <class 'sfm.data.sci_data.SFMDecTokenizer.SFMDecTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [4]:
# Rebuild one flat state dict from the per-layer checkpoint files
# (`layer_XX-model_states.pt`) in `args.save_dir` and load it into `model`.

NUM_DECODER_LAYERS = 32       # decoder blocks stored one per file
FIRST_DECODER_FILE_INDEX = 3  # files 00-02 hold the encoder / embedding / extra weights


def _load_layer_states(file_index):
    """Load `layer_<file_index, zero-padded to 2>-model_states.pt` from `args.save_dir` on CPU."""
    # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
    file_name = f"layer_{str(file_index).zfill(2)}-model_states.pt"
    return torch.load(os.path.join(args.save_dir, file_name), map_location=torch.device("cpu"))


ckpt_dict = {}

model_dict = model.state_dict()

# File 00: every entry is re-keyed under the "pfm_encoder." prefix.
layer0 = _load_layer_states(0)
for k, v in layer0.items():
    ckpt_dict["pfm_encoder." + k] = v

# File 01: only the token embedding matrix is taken.
layer1 = _load_layer_states(1)
ckpt_dict['decoder.model.embed_tokens.weight'] = layer1['embed_tokens.weight']

# File 02: keys are copied through unchanged.
layer2 = _load_layer_states(2)
ckpt_dict.update(layer2)

# Files 03..34: one decoder block each, re-keyed to `decoder.model.layers.<i>.*`.
for layer_idx in range(NUM_DECODER_LAYERS):
    layer = _load_layer_states(layer_idx + FIRST_DECODER_FILE_INDEX)
    for k, v in layer.items():
        # Entries whose name contains "dummy" or "rotary_emb" are not copied.
        if "dummy" in k or 'rotary_emb' in k:
            continue
        ckpt_dict[f"decoder.model.layers.{layer_idx}.{k}"] = v
    del layer  # drop the reference before loading the next file

# File 35: final norm weight. File 36: LM head weight.
layer = _load_layer_states(35)
ckpt_dict["decoder.model.norm.weight"] = layer["norm.weight"]

layer = _load_layer_states(36)
ckpt_dict["decoder.lm_head.weight"] = layer["lm_head.weight"]

# Report model entries the checkpoint does not provide; after the merge below
# they keep whatever (uninitialised) value the model currently holds.
not_in_ckpt = [k for k in model_dict if k not in ckpt_dict]
if not_in_ckpt:
    print(f"{len(not_in_ckpt)} model keys not found in checkpoint, e.g. {not_in_ckpt[:5]}")

# Overlay the checkpoint tensors on the model's own state dict. Without this
# merge the model would simply be reloaded from itself and "All keys matched"
# would be trivially true while no checkpoint weight is applied.
model_dict.update(ckpt_dict)

# The model was built under init_empty_weights(), so a regular copy into its
# parameters would be a no-op; assign=True (PyTorch >= 2.1) swaps in the loaded
# tensors instead. Strict loading still raises on unexpected checkpoint keys.
model.load_state_dict(model_dict, assign=True)


<All keys matched successfully>