In [6]:
from transformers import AutoTokenizer, AutoConfig, AddedToken
from transformers.generation.utils import GenerationConfig
import torch
from loguru import logger
import copy
import json
import time
import os

from utils import ModelUtils
from template import template_dict


def build_prompt_chatglm3(tokenizer, query, history, system=None):
    """Build the ChatGLM3 input token ids for a conversation ending in ``query``.

    Layout: prefix tokens, an optional ``<|system|>`` segment, then every turn.
    A user turn is ``<|user|>`` + message tokens + ``<|assistant|>``; any other
    role is message tokens + ``eos_token_id``.

    Args:
        tokenizer: object providing ``get_prefix_tokens()``, ``get_command(name)``,
            ``encode(text, add_special_tokens=False)`` and ``eos_token_id``.
        query: text of the new user message.
        history: list of ``{"role": ..., "message": ...}`` dicts for the previous
            turns. It is NOT modified; ``None`` is treated as an empty history.
        system: optional system prompt. When ``None`` the ``<|system|>`` segment
            is omitted (previously ``None`` was passed straight to
            ``tokenizer.encode``).

    Returns:
        A flat ``list`` of token ids.
    """
    # Work on a copy: appending to the caller's list made repeated calls
    # silently grow the conversation.
    turns = list(history or []) + [{"role": 'user', 'message': query}]

    # list(...) guarantees the `+=` below never extends a list owned by the tokenizer.
    input_ids = list(tokenizer.get_prefix_tokens())

    # system
    if system is not None:
        input_ids += [tokenizer.get_command("<|system|>")] + \
                     tokenizer.encode(system, add_special_tokens=False)

    # convs
    for item in turns:
        role, message = item['role'], item['message']
        if role == 'user':
            tokens = [tokenizer.get_command("<|user|>")] + \
                     tokenizer.encode(message, add_special_tokens=False) + \
                     [tokenizer.get_command("<|assistant|>")]
        else:
            tokens = tokenizer.encode(message, add_special_tokens=False) + [tokenizer.eos_token_id]
        input_ids += tokens

    return input_ids

def build_prompt(tokenizer, template, query, history, system=None):
    """Encode a conversation plus a new user ``query`` into model input ids.

    Dispatches on ``template.template_name``:

    * ``'chatglm2'`` - delegates to ``tokenizer.build_prompt(query, history)``
      and encodes the resulting string.
    * ``'chatglm3'`` - delegates to ``build_prompt_chatglm3``.
    * anything else - formats an optional system message and each turn with
      ``template.system_format`` / ``user_format`` / ``assistant_format`` and
      concatenates the encoded pieces.

    Args:
        tokenizer: tokenizer providing ``encode`` (and ``eos_token`` for the
            generic branch).
        template: object with ``template_name``, ``system_format``,
            ``user_format``, ``assistant_format`` and ``system`` attributes.
        query: text of the new user message.
        history: list of ``{"role": ..., "message": ...}`` dicts. The caller's
            list is NOT modified by the generic and chatglm3 branches.
        system: optional system prompt; falls back to ``template.system``.

    Returns:
        ``torch.LongTensor`` of shape ``(1, sequence_length)``.
    """
    template_name = template.template_name
    system_format = template.system_format
    user_format = template.user_format
    assistant_format = template.assistant_format
    system = system if system is not None else template.system

    if template_name == 'chatglm2':
        prompt = tokenizer.build_prompt(query, history)
        input_ids = tokenizer.encode(prompt)
    elif template_name == 'chatglm3':
        # Hand over a copy so the helper can never grow the caller's history.
        input_ids = build_prompt_chatglm3(tokenizer, query, list(history or []), system)
    else:
        # Copy instead of appending in place: mutating the argument made the
        # new query show up twice when the caller also recorded it.
        turns = list(history or []) + [{"role": 'user', 'message': query}]
        input_ids = []

        # The system message is emitted only when the template has a system
        # format AND a system text is available.
        if system_format is not None and system is not None:
            system_text = system_format.format(content=system)
            input_ids += tokenizer.encode(system_text, add_special_tokens=False)

        # Concatenate the conversation turns.
        for item in turns:
            role, message = item['role'], item['message']
            turn_format = user_format if role == 'user' else assistant_format
            text = turn_format.format(content=message, stop_token=tokenizer.eos_token)
            input_ids += tokenizer.encode(text, add_special_tokens=False)

    # Add the batch dimension expected by ``model.generate``.
    return torch.tensor([input_ids], dtype=torch.long)


def load_tokenizer(model_name_or_path):
    """Load the tokenizer for ``model_name_or_path`` and normalise its special tokens.

    The slow tokenizer implementation is always requested (the original note
    says llama does not support the fast one) and remote code is trusted.

    Args:
        model_name_or_path: value forwarded to ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer. A ``QWenTokenizer`` gets its pad/bos/eos ids set
        to ``eod_id``; a tokenizer without a pad token reuses its EOS token.
    """
    tok = AutoTokenizer.from_pretrained(
        model_name_or_path,
        use_fast=False,
        trust_remote_code=True,
    )

    # QWen: every special id is mapped onto the tokenizer's ``eod_id``.
    if tok.__class__.__name__ == 'QWenTokenizer':
        for attr in ('pad_token_id', 'bos_token_id', 'eos_token_id'):
            setattr(tok, attr, tok.eod_id)

    # Fall back to EOS when no padding token is defined.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    return tok



# --- Run configuration -------------------------------------------------------
# Base model checkpoint and the name of the prompt template looked up in
# `template_dict` below.
model_name_or_path = '/DATA/jupyter/share/LLM_NBS/Baichuan2-13B-Chat'
template_name = 'baichuan2'
# NOTE(review): the first `file` assignment is dead - it is overwritten on the
# very next line. Both paths are hard-coded; consider a single configurable
# data directory.
file = './data/v1_prompt_label/test.json'
file = '/DATA/jupyter/personal/ChatGLM3/finetune_demo/data/v1_prompt_label/test.json'
# Generation settings are read from the fine-tuning output directory and logged.
generation_config = GenerationConfig.from_pretrained('/DATA/jupyter/personal/Firefly/output/')
logger.info(generation_config)
template = template_dict[template_name]
# Whether to run inference in 4-bit; saves a lot of GPU memory, but quality may drop somewhat.
load_in_4bit = True
# Disabled batch evaluation: sweeps checkpoints 200..900 (step 100), loads each
# adapter, runs `model.chat` on every record of `file` and writes the answers
# to a per-checkpoint JSONL file.
# for step in range(200, 901, 100):
#     torch.cuda.empty_cache()
#     adapter_name_or_path = './output/baichuan2_13b_b1_acc16_epoch3_rk64_a16_lr2e4/checkpoint-%d' % step
#     save_file = './data/v1_prompt_label/13b-checkpoint-%d.json' % step
#     # load the model
#     logger.info(f'Loading model from: {model_name_or_path}')
#     logger.info(f'adapter_name_or_path: {adapter_name_or_path}')
# #         try:
#     model = ModelUtils.load_model(
#         model_name_or_path,
#         load_in_4bit=load_in_4bit,
#         adapter_name_or_path=adapter_name_or_path
#     ).eval()
#     tokenizer = load_tokenizer(model_name_or_path if adapter_name_or_path is None else adapter_name_or_path)
#     wf = open(save_file, 'w')
#     st = time.time()
#     with open(file, 'r') as rf:
#         for cnt, line in enumerate(rf):
#             if not line:
#                 continue
#             sample = json.loads(line.strip('\n'))
#             prompt = [sample['conversations'][0]]
#             logger.info(f'{cnt} {prompt}')
#             response = model.chat(tokenizer, prompt, generation_config=generation_config)
#             sample['ans'] = response
#             wf.write(json.dumps(sample, ensure_ascii=False) + '\n')
#     wf.close()


[32m2024-03-01 15:42:28.427[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m96[0m - [1mGenerationConfig {
  "do_sample": true,
  "max_new_tokens": 4,
  "temperature": 0.35,
  "top_p": 0.9
}
[0m


In [7]:
# Load the base model together with the adapter saved at checkpoint 200 and
# switch it to eval mode.
adapter_name_or_path = '/DATA/jupyter/personal/Firefly/output/baichuan2_13b_b1_acc16_epoch3_rk64_a16_lr2e4/checkpoint-%d' % 200
model = ModelUtils.load_model(
    model_name_or_path,
    load_in_4bit=load_in_4bit,
    adapter_name_or_path=adapter_name_or_path
).eval()
# `adapter_name_or_path` was just set to a string, so the tokenizer is always
# loaded from the adapter directory here; the base-model fallback only matters
# if the assignment above is changed to None.
tokenizer = load_tokenizer(model_name_or_path if adapter_name_or_path is None else adapter_name_or_path)

    PyTorch 2.1.0a0+4136153 with CUDA 1201 (you have 2.0.0+cu117)
    Python  3.10.12 (you have 3.10.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [8]:
# wf = open(save_file, 'w')
# st = time.time()
# Read only the first non-blank JSON-lines record of `file` and keep the first
# entry of its 'conversations' list as the prompt.
with open(file, 'r', encoding='utf-8') as rf:
    for cnt, line in enumerate(rf):
        # Lines read from a file keep their trailing newline, so a blank line
        # is "\n" (truthy). Strip before testing, otherwise json.loads("")
        # raises on blank lines.
        line = line.strip()
        if not line:
            continue
        sample = json.loads(line)
        prompt = [sample['conversations'][0]]
        break
    else:
        # Fail here with a clear message instead of a NameError on `prompt`
        # in a later cell.
        raise ValueError(f'no JSON record found in {file}')
#         logger.info(f'{cnt} {prompt}')
#         response = model.chat(tokenizer, prompt, generation_config=generation_config)
#         sample['ans'] = response
#         wf.write(json.dumps(sample, ensure_ascii=False) + '\n')
# wf.close()

In [9]:
# Re-fetch the prompt template by name (same lookup as in the configuration cell).
template = template_dict[template_name]

In [10]:
# Resolve the id of the token that ends generation.
# If the template defines no stop word, fall back to the tokenizer's EOS token.
# Note: this writes to the template object itself, i.e. the same object that is
# stored in `template_dict`.
if template.stop_word is None:
    template.stop_word = tokenizer.eos_token
# `stop_token_id` is temporarily a list of ids here ...
stop_token_id = tokenizer.encode(template.stop_word, add_special_tokens=False)
# ... the stop word must encode to exactly one token ...
assert len(stop_token_id) == 1
# ... and is then narrowed to that single id.
stop_token_id = stop_token_id[0]

In [11]:
# Encode the sample's first conversation entry as a single-turn prompt (empty
# history) and move the tensor to the model's device. With system=None,
# build_prompt falls back to `template.system`.
input_ids = build_prompt(tokenizer, template, prompt[0]['content'], [], system=None).to(model.device)

In [12]:
# Force sampling on; the config logged when it was loaded already shows "do_sample": true.
generation_config.do_sample = True

In [13]:
# Display the generation settings that the next cell passes to model.generate.
generation_config

GenerationConfig {
  "do_sample": true,
  "max_new_tokens": 4,
  "temperature": 0.35,
  "top_p": 0.9
}

In [14]:
%%time
# Timed generation for the single prompt built above; decoding is controlled
# entirely by `generation_config`. (`%%time` must stay the first line of the cell.)
outputs = model.generate(
    input_ids=input_ids, generation_config=generation_config
)

CPU times: user 6.35 s, sys: 37.5 s, total: 43.9 s
Wall time: 5.28 s


In [42]:
%%time
outputs = outputs.tolist()[0][len(input_ids[0]):]
response = tokenizer.decode(outputs)
response = response.strip().replace(template.stop_word, "").strip()
response

CPU times: user 1.21 ms, sys: 0 ns, total: 1.21 ms
Wall time: 920 µs


'判断为0'

In [31]:
def memory_stats():
    """Print current CUDA memory usage in MiB.

    Prints two lines: the memory currently allocated for tensors, then the
    memory reserved by the caching allocator. Returns ``None``.
    """
    mib = 1024 ** 2
    print(torch.cuda.memory_allocated() / mib)
    # `torch.cuda.memory_cached()` is deprecated; `memory_reserved()` is its
    # replacement and reports the same quantity.
    print(torch.cuda.memory_reserved() / mib)

In [32]:
# Print the current CUDA memory figures using the helper defined above.
memory_stats()

22813.427734375
24468.0


In [28]:
# Ask PyTorch to release cached CUDA memory that is no longer in use.
torch.cuda.empty_cache()

In [24]:
# Display the last decoded response.
# NOTE(review): the execution count of this cell is lower than that of the
# decode cell above, so the recorded output presumably comes from an earlier
# run - re-run top to bottom to refresh it.
response

'判断为0</s'

In [16]:
# Display the resolved stop token id.
# NOTE(review): the recorded output is a list, while the stop-token cell
# narrows `stop_token_id` to a single int - the output looks stale; confirm by
# re-running in order.
stop_token_id

[2]

In [None]:

history = []

# Interactive multi-turn chat. Interrupt the kernel (or send EOF) at the
# prompt to leave the loop.
while True:
    try:
        query = input('User：').strip()
    except (EOFError, KeyboardInterrupt):
        break

    # build_prompt receives a deep copy, so `history` is only changed by the
    # explicit appends below.
    input_ids = build_prompt(tokenizer, template, query, copy.deepcopy(history), system=None).to(model.device)
    # The sampling settings come from `generation_config` (as in the
    # single-prompt cell). The earlier version referenced max_new_tokens,
    # top_p, temperature and repetition_penalty, none of which are defined in
    # this notebook.
    # NOTE(review): the loaded config has max_new_tokens=4; raise it on
    # `generation_config` if longer chat replies are wanted.
    outputs = model.generate(
        input_ids=input_ids, generation_config=generation_config,
        do_sample=True, eos_token_id=stop_token_id
    )
    # Drop the echoed prompt tokens and keep only the newly generated ids.
    generated_ids = outputs.tolist()[0][len(input_ids[0]):]
    response = tokenizer.decode(generated_ids)
    response = response.strip().replace(template.stop_word, "").strip()
    # update history
    history.append({"role": 'user', 'message': query})
    history.append({"role": 'assistant', 'message': response})

    print("Firefly：{}".format(response))