In [1]:
from pathlib import Path
from typing import Annotated, Union

import typer
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast
)

ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

def load_model_and_tokenizer(
        model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
    model_dir = Path(model_dir).expanduser().resolve()
    if (model_dir / 'adapter_config.json').exists():
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
        )
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
        )
        tokenizer_dir = model_dir
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=trust_remote_code, encode_special_tokens=True, use_fast=False
    )
    return model, tokenizer

In [2]:
model_dir = "../../output/checkpoint-1000"
model, tokenizer = load_model_and_tokenizer(model_dir)

Loading checkpoint shards:   0%|          | 0/10 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
from tools.utils import post_process,read_jsonl

json_file_path = 'data/dev.jsonl'
df = read_jsonl(json_file_path)
df.head()

Unnamed: 0,messages
0,"[{'role': 'user', 'content': ' <任务>  你现在是一个..."
1,"[{'role': 'user', 'content': ' <任务>  你现在是一个..."
2,"[{'role': 'user', 'content': ' <任务>  你现在是一个..."
3,"[{'role': 'user', 'content': ' <任务>  你现在是一个..."
4,"[{'role': 'user', 'content': ' <任务>  你现在是一个..."


In [20]:
print(df.iloc[80,0][0]['content'])


<任务>
    你现在是一个金融领域专家，负责在接收用户的**股票查询**相关的提问(query)后，通过详细分析问题，编排api来得到用户query的答案，为了保证逻辑的严谨性，你需要将问题一步步拆解，写出思考过程，最后给出json格式的标准答案。
    API描述：
        所有可用的api名称（共有三类，分别为：股票查询、数值计算和逻辑运算）：
            股票查询类api:
                名称：
                    代码,开盘价,最高价,最低价,当前价,收盘价,成交量,成交额,涨停价,跌停价,涨跌额,涨跌幅,主力资金流入,主力资金流出,主力资金净流入,总流入,总流出,
                    换手率,每股收益,静态市盈率,总市值,振幅,流通市值,每股收益ttm,市盈率ttm,净资产收益率,每股净资产,每股经营性现金流,毛利率,净利率,净利润,
                    净利润同比增长,营业收入,营收同比增长,投资收入,营业利润,营业利润同比增长,扣非净利润,流动资产,总资产,短期负债,总负债,股东权益,净经营性现金流,净投资性现金流,净融资性现金流,银行资本充足率,流通股本,总股本,高管名称
                说明：
                    其中第一行除了"代码"以外的api（从“开盘价”到“总流出”），有第二参数：时间。其余api均只有1个参数。
            数值计算类(4个)：
                名称：加法计算,减法计算,乘法计算,除法计算
                说明：输入两个数值，做相应计算
            逻辑运算类(2个)：
                名称：与运算,或运算
                说明：输入两个列表，分别是做交集和并集
    行业背景知识&Tips（可以辅助你分析解决问题）：
        1. 如无特殊说明，计算收益(率)统一使用收盘价计算
        2. 涉及到市盈率，一般指的是静态市盈率；动态市盈率指的是市盈率ttm
        3. 只能选用上文中"所有可用的api"中的api，不能捏造
        4. 切题，避免无关的api

In [12]:
%%time
from random import random
from tools.common_utils import highlight_diff

total_eval_count = 0.0
correct_count = 0.0
eval_num = 100

for index, row in df.iterrows():
    if random() <= 0.9:
        continue
    
    single_user_row = row["messages"][0]
    single_assis_row = row["messages"][1]
    
    messages = [single_user_row]
    output = single_assis_row["content"]
    
    inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt"
        ).to(model.device)
    
    generate_kwargs = {
            "input_ids": inputs,
            "max_new_tokens": 1024,
            "do_sample": True,
            "top_p": 0.8,
            "temperature": 0.8,
            "repetition_penalty": 1.2,
            "eos_token_id": model.config.eos_token_id,
        }
    
    outputs = model.generate(**generate_kwargs)
    response = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True).strip()

    response = post_process(response)
    output = post_process(output)
    # optimized_resp = optimize_parameters(response, fund_standard_name, stock_standard_name)
    
    total_eval_count += 1
    if response == output:
        correct_count += 1
    else:
        print("-----data index-----")
        print(index)
        print("-----query input-----")
        # print()
        print("-----output diff-----")
        print(highlight_diff(output, response))
        print(response)
        print()
    
    if total_eval_count == 20:
        print('='*20 , f' 第{total_eval_count}次评估 ', '='*20)
        print('='*20 , f' 正确比例：{correct_count / total_eval_count :.2%} ', '='*20)
    
print("预测正确的比例：" + f"{correct_count / total_eval_count :.2%}")

-----data index-----
6
-----query input-----
-----output diff-----
{"relevant APIs": [{"api_id": "0", "api_name": "查询代码", "required_parameters": [["金鹰稳进配置六个月持有期混合型发起式基金中基金(FOF)C类"]], "rely_apis": [], "tool_name": "基金查询"}, {"api_id": "1", "api_name": "查询申购费率", "required_parameters": ["api_0的结果"], "rely_apis": ["0"], "tool_name": "基金查询"}, {"api_id": "2", "api_name": "乘法计算", "required_parameters": ["[31ma[0m[31mp[0m[31mi[0m[31m_[0m[31m1[0m[31m的[0m[31m结[0m[31m果[0m"[31m,[0m[31m [0m[31m"[0m[31m1[0m[31m0[0m[31m0[0m[31m0[0m[31m0[0m"], "rely_apis": ["1"], "tool_name": "数值计算"}], "result": ["api_2的结果"]}
{"relevant APIs": [{"api_id": "0", "api_name": "查询代码", "required_parameters": [["金鹰稳进配置六个月持有期混合型发起式基金中基金(FOF)C类"]], "rely_apis": [], "tool_name": "基金查询"}, {"api_id": "1", "api_name": "查询申购费率", "required_parameters": ["api_0的结果"], "rely_apis": ["0"], "tool_name": "基金查询"}, {"api_id": "2", "api_name": "乘法计算", "required_parameters": ["10000", "api_1的结果"], "rely_apis": [

TypeError: 'NoneType' object is not iterable

In [21]:
output

'{"relevant APIs": [{"api_id": "0", "api_name": "查询代码", "required_parameters": [["西部利得沪深300指数增强型证券投资基金C类"]], "rely_apis": [], "tool_name": "基金查询"}, {"api_id": "1", "api_name": "查询单位净值", "required_parameters": ["api_0的结果"], "rely_apis": ["0"], "tool_name": "基金查询"}, {"api_id": "2", "api_name": "除法计算", "required_parameters": ["1000", "api_1的结果"], "rely_apis": ["1"], "tool_name": "数值计算"}, {"api_id": "3", "api_name": "乘法计算", "required_parameters": ["api_2的结果", "12"], "rely_apis": ["2"], "tool_name": "数值计算"}, {"api_id": "4", "api_name": "查询近期收益率", "required_parameters": ["api_0的结果", "1年"], "rely_apis": ["0"], "tool_name": "基金查询"}, {"api_id": "5", "api_name": "乘法计算", "required_parameters": ["api_3的结果", "api_4的结果"], "rely_apis": ["3", "4"], "tool_name": "数值计算"}], "result": ["api_5的结果"]}'

In [25]:
response
glm4_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True).strip()

In [28]:
standard_output = glm4_output.split('于是最终标准的json格式结果为:')[1].replace('</output>','').strip()
# 格式验证
try:
    standard_output = json.loads(standard_output)
except:
    print(f'Json格式不正确:解析失败 {standard_output}')

if 'relevant APIs' not in standard_output:
    print(f'Json格式不正确:relevant APIs未找到 {standard_output}')
else:
    for api in standard_output['relevant APIs']:
        if 'tool_name' not in api:
            print(f'Json格式不正确:tool_name未找到 {standard_output}')
        if 'api_name' not in api:
            print(f'Json格式不正确:api_name未找到 {standard_output}')
        if 'required_parameters' not in api:
            print(f'Json格式不正确:required_parameters未找到 {standard_output}')
        if 'rely_apis' not in api:
            print(f'Json格式不正确:rely_apis未找到 {standard_output}')
        # 后处理
        if api['tool_name'] in {'基金查询','条件选基','股票查询','条件选股'}:
            api['api_name'] = '查询'+api['api_name']
            if api['tool_name'] == '条件选基' and api['api_name']=='查询基金份额类型':
                api['api_name'] = '查询基金份额类型(A、B、C)'
            elif api['tool_name'] == '条件选股' and api['api_name']=='查询每股经营性现金流':
                api['api_name'] = '查询每股经营性现资金流'
if 'result' not in standard_output:
    print(f'Json格式不正确:result未找到 {standard_output}')

print(standard_output)

Json格式不正确:解析失败 {"relevant APIs": [{"api_id": "0", "api_name": "代码", "required_parameters": [["西部利得沪深300指数增强型证券投资基金C类"]], "rely_apis": [], "tool_name": "基金查询"}, {"api_id": "1", "api_name": "近期收益率", "required_parameters": ["api_0的结果", "1年"], "rely_apis": ["0"], "tool_name": "基金查询"}, {"api_id": "2", "api_name": "乘法计算", "required_parameters": ["api_1的结果", "1000"], "rely_apis": ["1"], "tool_name": "数值计算"}, {"api_id": "3", "api_name": "加法计算", "required_parameters": ["api_2的结果", "api_2的结果", ..., "api_2的结果"], "rely_apis": ["2"], "tool_name": "数值计算"}], "result": ["api_3的结果"]}


TypeError: string indices must be integers

In [32]:
output

'{"relevant APIs": [{"api_id": "0", "api_name": "查询代码", "required_parameters": [["西部利得沪深300指数增强型证券投资基金C类"]], "rely_apis": [], "tool_name": "基金查询"}, {"api_id": "1", "api_name": "查询单位净值", "required_parameters": ["api_0的结果"], "rely_apis": ["0"], "tool_name": "基金查询"}, {"api_id": "2", "api_name": "除法计算", "required_parameters": ["1000", "api_1的结果"], "rely_apis": ["1"], "tool_name": "数值计算"}, {"api_id": "3", "api_name": "乘法计算", "required_parameters": ["api_2的结果", "12"], "rely_apis": ["2"], "tool_name": "数值计算"}, {"api_id": "4", "api_name": "查询近期收益率", "required_parameters": ["api_0的结果", "1年"], "rely_apis": ["0"], "tool_name": "基金查询"}, {"api_id": "5", "api_name": "乘法计算", "required_parameters": ["api_3的结果", "api_4的结果"], "rely_apis": ["3", "4"], "tool_name": "数值计算"}], "result": ["api_5的结果"]}'