In [5]:
import sys
import os

import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from tqdm import tqdm
import torch

from sfm.models.pfm.pfm_config import PFMConfig
from sfm.models.pfm.pfmmodel import PFMModel
from sfm.pipeline.accelerator.dataclasses import DistributedTrainConfig
from sfm.criterions.mae3d import ProteinMAE3dCriterions

from sfm.utils.cli_utils import cli


In [7]:
@cli(DistributedTrainConfig, PFMConfig)
def infer(args) -> PFMModel:
    """Build a ``PFMModel`` in sample mode with a fixed small architecture.

    ``args`` is the namespace produced by the ``cli`` decorator from
    ``DistributedTrainConfig`` and ``PFMConfig``; the fields set below
    overwrite whatever values it carried.

    The constructed model is bound to the module-level name ``model`` so later
    notebook cells can use it, and it is also returned.
    """
    args.sample_mode = True
    # Architecture overrides: 12 layers, 2 prediction-attention layers,
    # hidden size 512, FFN size 2048, 32 heads, 8 3D-bias kernels.
    args.layers = 12
    args.num_pred_attn_layer = 2
    args.hidden_size = 512
    args.ffn_size = 2048
    args.num_head = 32
    args.num_3d_bias_kernel = 8

    # Declared global so the model survives this call: a plain local would be
    # discarded on return, and downstream cells referencing `model` would hit a
    # NameError on a fresh kernel. `global` is used rather than relying on
    # `model = infer()` because the `cli` wrapper may not forward the return
    # value (NOTE(review): confirm against sfm.utils.cli_utils.cli).
    global model
    model = PFMModel(args, loss_fn=ProteinMAE3dCriterions)
    return model


# NOTE(review): inside a Jupyter kernel sys.argv holds ipykernel's own flags;
# confirm `cli` tolerates them (or set sys.argv explicitly before this call).
infer()



In [None]:
from accelerate import load_checkpoint_and_dispatch

# Pipeline layout: NUM_GPUS devices, each holding LAYERS_PER_GPU consecutive
# decoder layers (defaults give 8 x 10 = 80 decoder layers).
NUM_GPUS = 8
LAYERS_PER_GPU = 10


def build_device_map(num_gpus: int = NUM_GPUS, layers_per_gpu: int = LAYERS_PER_GPU) -> dict:
    """Return an accelerate ``device_map`` that spreads decoder layers across GPUs.

    - ``graphormer_encoder``, ``decoder.model.embed_tokens``, ``adaptor`` and
      ``decoder.lm_head`` are placed on GPU 0.
    - ``decoder.model.layers.{k}`` are split into contiguous runs of
      ``layers_per_gpu`` layers per GPU, in GPU order.
    - ``decoder.model.norm`` goes on the last GPU, i.e. the one holding the final
      decoder layer (derived from ``num_gpus`` rather than hardcoded).

    Raises:
        ValueError: if ``num_gpus`` or ``layers_per_gpu`` is less than 1.
    """
    if num_gpus < 1 or layers_per_gpu < 1:
        raise ValueError(
            f"num_gpus and layers_per_gpu must be >= 1, got {num_gpus} and {layers_per_gpu}"
        )

    dmap = {"graphormer_encoder": 0, "decoder.model.embed_tokens": 0, "adaptor": 0}
    for gpu in range(num_gpus):
        for layer in range(gpu * layers_per_gpu, (gpu + 1) * layers_per_gpu):
            dmap[f"decoder.model.layers.{layer}"] = gpu
    dmap["decoder.model.norm"] = num_gpus - 1
    dmap["decoder.lm_head"] = 0
    return dmap


device_map = build_device_map()

# Directory of the converted checkpoint to dispatch onto the GPUs.
CHECKPOINT_DIR = "/mnt/shiyu/models/converted/ft_100MMFM_70Bllama2_full_mix1/global_step2000/"

# `model` must come from an earlier cell; fail with an actionable message rather
# than a bare NameError when this cell is run on a fresh kernel.
if "model" not in globals():
    raise NameError(
        "`model` is not defined; run the model-construction cell before loading the checkpoint."
    )

# NOTE(review): `device_map` addresses modules named `graphormer_encoder`,
# `adaptor` and `decoder.model.*` (decoder layers under `decoder.model.layers.*`);
# confirm `model` actually has that structure before dispatching.
model = load_checkpoint_and_dispatch(
    model,
    CHECKPOINT_DIR,
    device_map=device_map,
    no_split_module_classes=["LlamaDecoderLayer"],
)