In [1]:
from pathlib import Path
from typing import Annotated, Union

import typer
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast
)

ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

def load_model_and_tokenizer(
        model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
    model_dir = Path(model_dir).expanduser().resolve()
    if (model_dir / 'adapter_config.json').exists():
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
        )
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
        )
        tokenizer_dir = model_dir
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=trust_remote_code, encode_special_tokens=True, use_fast=False
    )
    return model, tokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_dir = "output/checkpoint-250"

model, tokenizer = load_model_and_tokenizer(model_dir)

Loading checkpoint shards: 100%|██████████| 10/10 [04:35<00:00, 27.58s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
import pandas as pd
from tools.common_utils import highlight_diff, read_jsonl

json_file_path = 'data/dev.jsonl'
data_list = read_jsonl(json_file_path)

df = pd.DataFrame(data_list)
df.head()

Unnamed: 0,messages
0,"[{'role': 'user', 'content': '你现在是一个金融领域专家，你需要..."
1,"[{'role': 'user', 'content': '你现在是一个金融领域专家，你需要..."
2,"[{'role': 'user', 'content': '你现在是一个金融领域专家，你需要..."
3,"[{'role': 'user', 'content': '你现在是一个金融领域专家，你需要..."
4,"[{'role': 'user', 'content': '你现在是一个金融领域专家，你需要..."


In [4]:
from random import random

total_eval_count = 0.0
correct_count = 0.0

for index, row in df.iterrows():
    if random() <= 0.9:
        continue
    
    single_user_row = row["messages"][0]
    single_assis_row = row["messages"][1]
    
    messages = [single_user_row]
    output = single_assis_row["content"]
    
    inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt"
        ).to(model.device)
    
    generate_kwargs = {
            "input_ids": inputs,
            "max_new_tokens": 1024,
            "do_sample": True,
            "top_p": 0.8,
            "temperature": 0.8,
            "repetition_penalty": 1.2,
            "eos_token_id": model.config.eos_token_id,
        }
    
    outputs = model.generate(**generate_kwargs)
    response = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True).strip()
    
    # optimized_resp = optimize_parameters(response, fund_standard_name, stock_standard_name)
    
    total_eval_count += 1
    if response == output:
        correct_count += 1
    else:
        print("-----data index-----")
        print(index)
        print("-----query input-----")
        print(input)
        print("-----output diff-----")
        print(highlight_diff(output, response))
        print(response)
        print()
    
    if total_eval_count == 20:
        break
    
print("预测正确的比例：" + f"{correct_count / total_eval_count :.2%}")

-----data index-----
1
-----query input-----
<bound method Kernel.raw_input of <ipykernel.ipkernel.IPythonKernel object at 0x7f6fa146d5d0>>
-----output diff-----
{"relevant APIs": [{"api_id": "0", "api_name": "查询代码", "required_parameters": [["汇丰晋信2016生命周期开放式证券投资基金C类"]], "rely_apis": [], "tool_name": "基金查询"}, {"api_id": "1", "api_name": "查询[31m基[0m[31m金[0m经理", "required_parameters": ["api_0的结果"], "rely_apis": ["0"], "tool_name": "基金查询"}, {"api_id": "2", "api_name": "查询[31m规[0m[31m模[0m[31m"[0m[31m,[0m[31m [0m[31m"[0m[31mr[0m[31me[0m[31mq[0m[31mu[0m[31mi[0m[31mr[0m[31me[0m[31md[0m[31m_[0m[31mp[0m[31ma[0m[31mr[0m[31ma[0m[31mm[0m[31me[0m[31mt[0m[31me[0m[31mr[0m[31ms[0m[31m"[0m[31m:[0m[31m [0m[31m[[0m"[31ma[0m[31mp[0m[31mi[0m[31m_[0m[31m0[0m[31m的[0m[31m结[0m[31m果[0m[31m"[0m[31m][0m[31m,[0m[31m [0m"[31mr[0m[31me[0m[31ml[0m[31my[0m[31m_[0m[31ma[0m[31mp[0m[31mi[0m[31ms[0m[31m"[0m[31m:[0m[