In [1]:
import sys
import os
import numpy as np
import tqdm

import mmap
import re
from argparse import ArgumentParser

import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import torch
from sfm.models.scigpt.scigpt import ScigptModel
from sfm.models.scigpt.config import ScigptConfig
from sfm.utils import arg_utils
from argparse import ArgumentParser

import multiprocessing as mp
from sfm.utils.science_tokens import SCIENCE_TAG_TOKENS, SCIENCE_TOKENS

from sfm.logging import logger

import struct
from multiprocessing import Lock


[2024-06-06 05:51:21,973] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[[32m2024-06-06 05:51:23.494[0m][[36mINFO[0m]: apex is installed, using FusedAdam with fp16 optimizer states


In [4]:

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

def get_args_and_tokenizer(use_llama=False):
    parser = ArgumentParser()
    cfg_classes = [ScigptConfig]
    parser = arg_utils.add_dataclass_to_parser(cfg_classes, parser)
    args = parser.parse_args(args=[])
    args.load_ckpt = False
    args.strategy = "DDP"
    args.encoder_layers = 33
    args.encoder_embed_dim = 1280
    args.encoder_ffn_embed_dim = 5120
    args.encoder_attention_heads = 20
    args.infer = True
    args.bf16 = True

    tokenizer = AutoTokenizer.from_pretrained('/home/v-zekunguo/sfm/llama/Meta-Llama-3-8B/original')
    # args.save_dir = "/home/v-zekunguo/hai1data/nlm/output/llama3_stageB/global_step1600/"
    args.save_dir = '/home/v-zekunguo/nlm/peiran/output/llama3_stageB_G256/global_step8572/'
    args.llm_model_name_or_path = '/home/v-zekunguo/sfm/llama/Meta-Llama-3-8B/original'

    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    # special_tokens_dict["additional_special_tokens"] = SCIENCE_TAG_TOKENS
    tokenizer.add_special_tokens(special_tokens_dict)
    tokenizer.tag_re = re.compile(f'{"|".join(SCIENCE_TAG_TOKENS)}')
    tokenizer.smiles_re = re.compile(
        "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    )

    tokenizer.add_special_tokens(
        {
            "pad_token": "[PAD]",
            "unk_token":"<unk>",
        },

    )

    tokenizer.add_tokens(SCIENCE_TAG_TOKENS)
    tokenizer.add_tokens(SCIENCE_TOKENS)
    extra_tokens = []
    # protein
    for i in range(26):
        extra_tokens.append(f"<a>{chr(65 + i)}")

    # DNA, RNA, including ambiguous bases
    for c in "ACTGURYSWKMBDHVN":
        extra_tokens.append(f"<d>{c}")
        extra_tokens.append(f"<r>{c}")

    # materials, non-elements
    for c in "0123456789()+-":
        extra_tokens.append(f"<i>{c}")
    for i in range(26):
        extra_tokens.append(f"<i>{chr(65 + i)}")
        extra_tokens.append(f"<i>{chr(97 + i)}")

    tokenizer.add_tokens(extra_tokens)
    tokenizer.split_special_tokens = True  # Ensure _tokenize() can access special tokens

    logger.info(f"Tokenizer has {len(tokenizer)} tokens")

    args.vocab_size=len(tokenizer)

    return args, tokenizer

args, tokenizer = get_args_and_tokenizer()
print(type(tokenizer))

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


<class 'transformers.tokenization_utils_fast.PreTrainedTokenizerFast'>


In [6]:
args.vocab_size=130304

In [24]:
tokenizer.encode("<protein>ASD</protien>")

[128264, 1950, 35, 524, 4490, 3675, 29]

In [7]:
# Loading the extended trained model
ckpt_dict = {}

model = ScigptModel(args)
model.decoder.resize_token_embeddings(args.vocab_size)
model_dict = model.state_dict()
print(f"model_dict: {model_dict.keys()}")
print(model_dict['decoder.model.layers.0.mlp.gate_proj.weight'].shape)
print(model_dict['decoder.model.layers.0.mlp.up_proj.weight'].shape)
weight1_size=model_dict['decoder.model.layers.0.mlp.gate_proj.weight'].size(0)
weight2_size=model_dict['decoder.model.layers.0.mlp.up_proj.weight'].size(0)
layer0 = torch.load(os.path.join(args.save_dir, "layer_00-model_00-model_states.pt"), map_location=torch.device("cpu"))
for k, v in layer0.items():
    if k=='word_embeddings.weight':
        ckpt_dict['decoder.model.embed_tokens.weight'] = v

for l in range(0, 32):
    l_index = str(l + 1).zfill(2)
    layer = torch.load(os.path.join(args.save_dir, f"layer_{l_index}-model_00-model_states.pt"), map_location=torch.device("cpu"))
    for k in layer:
        if "dummy" in k or 'rotary_emb' in k:
            continue
        if k=="self_attention.layernorm_qkv.layer_norm_weight":
            ckpt_dict[f"decoder.model.layers.{l}.input_layernorm.weight"] = layer[k]
        elif k=='self_attention.layernorm_qkv.query_weight':
            ckpt_dict[f"decoder.model.layers.{l}.self_attn.q_proj.weight"] = layer[k]
        elif k=='self_attention.layernorm_qkv.key_weight':
            ckpt_dict[f"decoder.model.layers.{l}.self_attn.k_proj.weight"] = layer[k]
        elif k=='self_attention.layernorm_qkv.value_weight':
            ckpt_dict[f"decoder.model.layers.{l}.self_attn.v_proj.weight"] = layer[k]
        elif k=='self_attention.proj.weight':
            ckpt_dict[f"decoder.model.layers.{l}.self_attn.o_proj.weight"] = layer[k]
        elif k=='layernorm_mlp.layer_norm_weight':
            ckpt_dict[f"decoder.model.layers.{l}.post_attention_layernorm.weight"] = layer[k]
        elif k=='layernorm_mlp.fc1_weight':
            weight1,weight2=torch.split(layer[k], [weight1_size, weight2_size], dim=0)
            ckpt_dict[f"decoder.model.layers.{l}.mlp.gate_proj.weight"] = weight1
            ckpt_dict[f"decoder.model.layers.{l}.mlp.up_proj.weight"] = weight2
        elif k=='layernorm_mlp.fc2_weight':
            ckpt_dict[f"decoder.model.layers.{l}.mlp.down_proj.weight"] = layer[k]
    del layer

layer = torch.load(os.path.join(args.save_dir, "layer_33-model_00-model_states.pt"), map_location=torch.device("cpu"))
ckpt_dict["decoder.model.norm.weight"] = layer["norm.weight"]

layer = torch.load(os.path.join(args.save_dir, "layer_34-model_00-model_states.pt"), map_location=torch.device("cpu"))
ckpt_dict["decoder.lm_head.weight"] = layer["lm_head.weight"]

print(f"ckpt_dict: {ckpt_dict.keys()}")
model_dict.update(ckpt_dict)
model.load_state_dict(model_dict)



model_dict: odict_keys(['decoder.model.embed_tokens.weight', 'decoder.model.layers.0.self_attn.q_proj.weight', 'decoder.model.layers.0.self_attn.k_proj.weight', 'decoder.model.layers.0.self_attn.v_proj.weight', 'decoder.model.layers.0.self_attn.o_proj.weight', 'decoder.model.layers.0.mlp.gate_proj.weight', 'decoder.model.layers.0.mlp.up_proj.weight', 'decoder.model.layers.0.mlp.down_proj.weight', 'decoder.model.layers.0.input_layernorm.weight', 'decoder.model.layers.0.post_attention_layernorm.weight', 'decoder.model.layers.1.self_attn.q_proj.weight', 'decoder.model.layers.1.self_attn.k_proj.weight', 'decoder.model.layers.1.self_attn.v_proj.weight', 'decoder.model.layers.1.self_attn.o_proj.weight', 'decoder.model.layers.1.mlp.gate_proj.weight', 'decoder.model.layers.1.mlp.up_proj.weight', 'decoder.model.layers.1.mlp.down_proj.weight', 'decoder.model.layers.1.input_layernorm.weight', 'decoder.model.layers.1.post_attention_layernorm.weight', 'decoder.model.layers.2.self_attn.q_proj.weight

<All keys matched successfully>

In [None]:
ckpt_dict = {}
# Load the original llama3 model
model = ScigptModel(args)

model_dict = model.state_dict()
print(f"model_dict: {model_dict.keys()}")

layer0 = torch.load(os.path.join(args.save_dir, "layer_00-model_states.pt"), map_location=torch.device("cpu"))
for k, v in layer0.items():
    new_k = "decoder.model." + k
    ckpt_dict[new_k] = v

for l in range(0, 32):
    l_index = str(l + 1).zfill(2)
    layer = torch.load(os.path.join(args.save_dir, f"layer_{l_index}-model_states.pt"), map_location=torch.device("cpu"))
    for k in layer:
        if "dummy" in k or 'rotary_emb' in k:
            continue
        ckpt_dict[f"decoder.model.layers.{l}.{k}"] = layer[k]
    del layer

layer = torch.load(os.path.join(args.save_dir, "layer_33-model_states.pt"), map_location=torch.device("cpu"))
ckpt_dict["decoder.model.norm.weight"] = layer["norm.weight"]

layer = torch.load(os.path.join(args.save_dir, "layer_34-model_states.pt"), map_location=torch.device("cpu"))
ckpt_dict["decoder.lm_head.weight"] = layer["lm_head.weight"]

print(f"ckpt_dict: {ckpt_dict.keys()}")
model_dict.update(ckpt_dict)
model.load_state_dict(model_dict)



In [8]:
device = torch.device("cuda")

model = model.to(torch.bfloat16).to(device)

model.eval()


ScigptModel(
  (loss): AutoregressiveCriterion(
    (cross_entropy): CrossEntropyLoss()
  )
  (decoder): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(130304, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          

In [16]:
import lmdb
from sfm.data.prot_data.util import bstr2obj
# load data
file_path='/home/v-zekunguo/sfm/nlm/valid_lmdb/new_valid.patent.v2.txt.lmdb'
env = lmdb.open(
    file_path, subdir=True, readonly=True, lock=False, readahead=False
)
txn = env.begin(write=False)

print(env.stat())
count=0
metadata = bstr2obj(txn.get("metadata".encode()))
cur_len, cur_keys = metadata["size"], metadata["keys"]
print(cur_len)

{'psize': 4096, 'depth': 2, 'branch_pages': 1, 'leaf_pages': 38, 'overflow_pages': 33698, 'entries': 3745}
3744


In [19]:
key=cur_keys[4]
value = txn.get(str(key).encode())
input_ids = np.frombuffer(value, dtype=np.uint32)
tokenizer.decode(input_ids)

' treatment time for step (iii) is from about 30 minutes to 18 hours, particularly from about 1 to 6 hours, more particularly from about 2 to 3 hours. Typically, the fabric is washed between each of the steps (i) to (iii) of the method of the present invention. For example, the fabric may be washed with water, for example with distilled water. The washing step substantially removes residual reagents present from the previous reaction step(s). Typically, after step the fibrous catalyst is dried before use. The catalyst may be dried using any conventional means, for example at temperatures up to 105° C. Any fabric comprising PAN fibres may be used in the present invention. The references herein to a fabric may refer simply to an arrangement of one or more PAN fibres. In one aspect of the invention, the fabric that comprises PAN fibres is a knitted fabric, such as a fibrous knitted mesh. Thus, in this aspect, the PAN fibres/yarn must be capable of being knitted. The knitted fabric may be 

In [None]:
import numpy as np
# Calculate loss
print(metadata.keys())
loss_list=[]
print(metadata['processed_seq_len'])
for key in cur_keys:
    value = txn.get(str(key).encode())
    input_ids = np.frombuffer(value, dtype=np.uint32)
    input_tensor = torch.from_numpy(input_ids.astype(np.int64)).unsqueeze(0).to(device)
    labels = input_tensor.clone()
    out = model.decoder(input_tensor, labels=labels)
    input_tensor.to("cpu")
    labels.to("cpu")
    print(out.loss.cpu().item())
    loss_list.append(out.loss.cpu().item())
    out = None
    torch.cuda.empty_cache()

    del out
print(sum(loss_list) / len(loss_list))

In [None]:
tokenizer.decode(input_ids)

In [None]:
tokenizer.decode(input_tensor[0])

In [None]:
tokenizer.encode("Football is a ", return_tensors="pt")
input_tensor

In [None]:
output = model.decoder.generate(
    input_ids=torch.tensor(input_tensor).to(device),
    num_beams=5,
    max_new_tokens=512,
    num_return_sequences=1,
    return_dict_in_generate=True,
    output_scores=True,
    do_sample=True,
    top_p=0.95,
    repetition_penalty=1.5,
)
res = tokenizer.decode(output.sequences[0], skip_special_tokens=False)
print(res)

In [None]:
input_ids=torch.tensor(tokenizer.encode("Football is a ", return_tensors="pt")).to(device)
labels = input_ids.clone()
out = model.decoder(input_ids,labels=labels)

In [23]:
output = model.decoder.generate(
    input_ids=torch.tensor(tokenizer.encode("The capital of China is", return_tensors="pt")).to(device),
    num_beams=5,
    max_new_tokens=512,
    num_return_sequences=1,
    return_dict_in_generate=True,
    output_scores=True,
    do_sample=True,
    top_p=0.95,
    repetition_penalty=1.5,
)
res = tokenizer.decode(output.sequences[0], skip_special_tokens=False)
print(res)

  input_ids=torch.tensor(tokenizer.encode("The capital of China is", return_tensors="pt")).to(device),
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


The capital of China is located in the east of the country. The city is one of the largest cities in the world and it is also one of the most popular tourist destinations in the world. There are many things to see and do in Beijing. You can visit the Great Wall of China, the Forbidden City, the Temple of Heaven, the Summer Palace, the Old Summer Palace, the Drum Tower, the Bell Tower, the National Museum of China, the National Art Museum of China, the National Gallery of China, the National Library of China, the National Theatre of China, the National Astronomical Observatory of China, the National Zoological Park of China, the Beijing National Stadium, the Beijing National Aquatics Center, the Beijing National Stadium, the Beijing National Stadium, the Beijing National Stadium, the Beijing National Stadium, the Beijing National Stadium, the Beijing National Stadium, the Beijing National Stadium, the Beijing National Stadium, the Beijing National Stadium, the Beijing National Stadium, 

In [None]:
out

In [25]:
a='''0.20717960596084595
0.29078763723373413
0.4763331115245819
0.1211240291595459
0.2743508517742157
0.18720833957195282
0.11175647377967834
0.3087017238140106
0.044937893748283386
0.11803882569074631
0.2398451566696167
0.15379402041435242
0.07513220608234406
0.21486173570156097
0.2845916152000427
0.26695480942726135
0.5750867128372192
0.2512124478816986
0.273259699344635
0.06262169778347015
0.2809830605983734
0.12392335385084152
0.5005030632019043
0.6178805232048035
0.15205508470535278
1.3323503732681274
1.335503339767456
0.20509064197540283
0.13405446708202362
0.341273695230484
0.21496140956878662
0.396593302488327
0.08487134426832199
0.281646728515625
1.3463149070739746
0.9981168508529663
0.04263443499803543
1.2692070007324219
0.5956265330314636
0.19295325875282288
0.5995656847953796
0.21511414647102356
1.4270107746124268
1.320778489112854
0.15251511335372925
0.5639299154281616
1.3715708255767822
1.453153371810913
0.4727802872657776
1.2407070398330688
0.5323099493980408
0.1223621666431427
0.5076682567596436
1.0528452396392822
1.0112085342407227
0.14629696309566498
0.1310725212097168
0.07658973336219788
0.08177613466978073
1.1829582452774048
0.8916467428207397
0.11947863548994064
0.423902690410614
0.14153778553009033
0.047447025775909424
0.20104463398456573
0.12043419480323792
0.0807940736413002
0.18193954229354858
0.47314414381980896
0.0679519921541214
0.12400541454553604
1.3700897693634033
0.9072701930999756
0.12065614014863968
1.7163435220718384
0.08381613343954086
0.0326957032084465
0.40456461906433105
0.692696213722229
0.3141028583049774
0.2531353235244751
0.2273268848657608
0.13474857807159424
0.563109278678894
0.156013622879982
1.658959150314331
0.09070580452680588
1.3941181898117065
0.4297913610935211
0.5922026038169861
0.22258101403713226
0.12181457132101059
0.16902956366539001
0.22528114914894104
0.9651983976364136
0.10283296555280685
0.28349676728248596
0.2935939133167267
0.22655513882637024
0.18640880286693573
0.2510978579521179
0.8008484840393066
1.3282228708267212
0.21326211094856262
0.14087621867656708
0.10235189646482468
0.11559632420539856
0.11227940768003464
0.2077503502368927
1.358059287071228
0.7917405962944031
0.22339721024036407
1.3450703620910645
0.630079984664917
0.13973328471183777
0.5261379480361938
0.06727326661348343
0.23248539865016937
0.31089815497398376
1.3583070039749146
0.07996442168951035
1.3741847276687622
0.5316010117530823
0.34544873237609863
0.07100202143192291
0.12925738096237183
1.365248680114746
0.18263107538223267
0.36207520961761475
1.3990753889083862
0.30359530448913574
0.09579917788505554
0.030381273478269577
0.19169947504997253
0.03946821019053459
0.24850626289844513
0.17042821645736694
0.7291426062583923
0.22262592613697052
0.06581363081932068
0.7285424470901489
0.392994225025177
0.18774549663066864
0.186307892203331
0.5569478869438171
0.6422935724258423
0.31806761026382446
0.10028895735740662
1.0810362100601196
0.20113618671894073
0.5040784478187561
1.3916536569595337
0.6414272785186768
0.1204301044344902
1.2862085103988647
0.24327492713928223
0.21865741908550262
0.6439153552055359
0.24383366107940674
0.31891173124313354
0.15836231410503387
0.10932537913322449
0.14307968318462372
0.12938232719898224
0.09723549336194992
0.5819553136825562
0.8949821591377258
1.3371659517288208
0.8995532393455505
0.18025264143943787
0.20069092512130737
0.18684826791286469
0.2974279522895813
0.9680314660072327
0.19020363688468933
0.5359116196632385
0.09533888846635818
0.12332974374294281
0.3095950484275818
0.25255653262138367
0.15088067948818207
0.18418022990226746
0.25024962425231934
1.6245007514953613
0.25705811381340027
1.2602794170379639
0.2825852930545807
0.4596610963344574
0.27140161395072937
1.3548699617385864
1.6391781568527222
0.21397431194782257
0.32775214314460754
0.08903037756681442
0.15756107866764069
1.3055980205535889
0.03688358888030052
1.3175513744354248
0.08926389366388321
0.15858736634254456
0.31540006399154663
0.28574520349502563
0.36452287435531616
0.10388889163732529
0.23763996362686157
0.09452733397483826
0.23180389404296875
0.23069925606250763
0.6979408860206604
0.10550756007432938
0.4723520874977112
0.14598198235034943
0.1557895392179489
0.6543782949447632
0.4490549862384796
0.103681780397892
0.19329038262367249
1.2767599821090698
0.8332330584526062
0.10007733106613159
0.15582333505153656
1.249241828918457
0.517615020275116
0.15385372936725616
1.3441174030303955
0.9407714009284973
0.12658777832984924
0.14014221727848053
1.3540621995925903
0.2357904464006424
0.8998426795005798
0.10065589100122452
0.41297203302383423
0.28044211864471436
1.3492896556854248
0.043217238038778305
0.21052215993404388
0.08350750803947449
1.1624751091003418
1.2785097360610962
0.20286978781223297
0.6564274430274963
0.6118311285972595
0.1830492466688156
1.2313441038131714
0.2561013996601105
0.0793098732829094
0.4527606666088104
0.30490294098854065
0.08304128795862198
0.26137775182724
0.5580241680145264
0.6373624205589294
0.25001731514930725
0.1880124807357788
0.24547462165355682
0.7612487077713013
0.3183324933052063
0.11884171515703201
0.07513443380594254
0.0670829489827156
0.38079485297203064
0.34268122911453247
0.2640489935874939
0.2765796482563019
1.6809512376785278
0.17735382914543152
0.11657076328992844
0.31248295307159424
0.3016645312309265
0.5279101729393005
0.11424760520458221
0.6875072121620178
1.610180139541626
1.3809417486190796
0.3671647310256958
1.3125965595245361
0.3456448018550873
0.13240839540958405
0.17399924993515015
0.07215277850627899
1.142114520072937
1.7568718194961548
0.09248217195272446
0.28186503052711487
0.069955013692379
0.38316670060157776
0.19083179533481598
0.5279322266578674
0.4683115482330322
1.3078995943069458
0.42920079827308655
0.1557501256465912
0.09976639598608017
0.17055709660053253'''
out = a.split('\n')
result=[]
for item in out:
    result.append(float(item))
print(len(result))
print(sum(result)/len(result))

296
0.4630786909444912


: 