In [None]:
from pathlib import Path
from typing import Annotated, Union

import typer
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

In [None]:
# Model type returned by `load_model_and_tokenizer`: a plain Hugging Face model, or a
# PEFT wrapper when the checkpoint directory holds an adapter (`adapter_config.json`).
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
# Tokenizer type returned by `load_model_and_tokenizer`.
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

def _resolve_path(path: Union[str, Path]) -> Path:
    return Path(path).expanduser().resolve()


def load_model_and_tokenizer(model_dir: Union[str, Path]) -> tuple[ModelType, TokenizerType]:
    """Load a causal-LM checkpoint together with its tokenizer.

    If *model_dir* contains ``adapter_config.json`` it is treated as a PEFT
    adapter checkpoint. The model is then loaded through
    ``AutoPeftModelForCausalLM``, and the tokenizer comes from the adapter's
    base model (``peft_config['default'].base_model_name_or_path``).
    Otherwise the directory is loaded as a full model and also supplies the
    tokenizer.

    Both loads pass ``trust_remote_code=True``, which lets code shipped with
    the checkpoint run, so only load checkpoints you trust.

    Args:
        model_dir: Checkpoint directory; ``~`` and relative paths are resolved.

    Returns:
        ``(model, tokenizer)``.
    """
    resolved_dir = _resolve_path(model_dir)
    is_adapter_checkpoint = (resolved_dir / 'adapter_config.json').exists()

    loader = AutoPeftModelForCausalLM if is_adapter_checkpoint else AutoModelForCausalLM
    model = loader.from_pretrained(resolved_dir, trust_remote_code=True, device_map='auto')

    # An adapter directory does not carry the base model's tokenizer path directly;
    # read it from the PEFT config instead.
    if is_adapter_checkpoint:
        tokenizer_source = model.peft_config['default'].base_model_name_or_path
    else:
        tokenizer_source = resolved_dir

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, trust_remote_code=True)
    return model, tokenizer

In [None]:
# Fine-tuned checkpoint to evaluate: either a full model or a PEFT adapter directory.
model_dir = "output/checkpoint-1000"
model, tokenizer = load_model_and_tokenizer(model_dir=model_dir)

In [None]:
import pandas as pd
from tools.common_utils import highlight_diff, read_jsonl

# Load the (presumably previously mis-predicted) training samples for error analysis.
json_file_path = 'data/train_err.json'
data_list = read_jsonl(json_file_path)
df = pd.DataFrame(data=data_list)

df.head()

In [None]:
from tools.standard_name_utils import get_standard_api, optimize_api_chain, optimize_parameters

standard_api = get_standard_api("./data/api_定义.json")

# Both standard-name lists live in one workbook, one sheet per asset class.
standard_name_xlsx = './raw_data/标准名.xlsx'
data_stock = pd.read_excel(standard_name_xlsx, sheet_name='股票标准名')
data_fund = pd.read_excel(standard_name_xlsx, sheet_name='基金标准名')

# Each list must come from its own sheet: fund names from the fund sheet and
# stock names from the stock sheet. Crossing them hands every consumer the
# wrong list.
# NOTE(review): `optimize_parameters(response, fund_standard_name, stock_standard_name)`
# is called in that (fund, stock) order below; confirm the signature in
# tools.standard_name_utils matches.
fund_standard_name = data_fund['标准基金名称'].to_list()
stock_standard_name = data_stock['标准股票名称'].to_list()

In [None]:
# Quick sanity check: the first five entries of each reference list, one per line.
print(*(names[:5] for names in (standard_api, fund_standard_name, stock_standard_name)), sep="\n")

In [None]:
import random

# Generation settings for `model.chat`. The submission cell further down reuses
# `top_p` and `temperature`, so this cell must run before it.
top_p = 0.7
temperature = 0.95

# Spot-check a random ~10% of the error samples, stopping after `max_eval_count`.
# The seeded generator makes the row selection repeatable across runs. Generation
# itself still samples (temperature/top_p), so model outputs may still vary.
sample_rate = 0.1
max_eval_count = 20
row_sampler = random.Random(42)

total_eval_count = 0
correct_count = 0

for index, row in df.iterrows():
    if row_sampler.random() >= sample_rate:
        continue

    # The first conversation turn is the user prompt; the second is the reference answer.
    query = row["conversations"][0]["content"]
    expected = row["conversations"][1]["content"]

    response, _history = model.chat(tokenizer=tokenizer, query=query, history=[], top_p=top_p, temperature=temperature)

    # Post-processing is disabled here so the raw model output is scored. Swap in
    # one of these lines to measure its effect:
    # optimized_resp = optimize_api_chain(response, standard_api)
    # optimized_resp = optimize_parameters(response, fund_standard_name, stock_standard_name)
    optimized_resp = response

    total_eval_count += 1
    if optimized_resp == expected:
        correct_count += 1
    else:
        print("-----data index-----")
        print(index)
        print("-----query input-----")
        print(query)
        print("-----output diff-----")
        print(highlight_diff(expected, optimized_resp))
        print(optimized_resp)
        print()

    if total_eval_count >= max_eval_count:
        break

# Guard the ratio: with a small frame (or an unlucky draw) nothing may be sampled.
if total_eval_count:
    print(f"预测正确的比例：{correct_count / total_eval_count:.2%}")
else:
    print("没有抽样到任何数据，无法计算正确率")

#### 输出新的提交文件

In [8]:
# Test set whose model predictions go into the submission file.
test_json_file_path = 'data/test_a.json'
test_data_list = read_jsonl(test_json_file_path)
test_df = pd.DataFrame(data=test_data_list)

test_df.head()

Unnamed: 0,conversations
0,"[{'role': 'user', 'content': '你现在是一个金融领域专家，你需要..."
1,"[{'role': 'user', 'content': '你现在是一个金融领域专家，你需要..."
2,"[{'role': 'user', 'content': '你现在是一个金融领域专家，你需要..."
3,"[{'role': 'user', 'content': '你现在是一个金融领域专家，你需要..."
4,"[{'role': 'user', 'content': '你现在是一个金融领域专家，你需要..."


In [None]:
# Run the model over every test sample and write one post-processed answer per line
# to the submission file, in `test_df` row order.
# NOTE: `top_p` / `temperature` are defined in the evaluation cell; run that cell first.
total_eval_count = 0

with open('data/submit.txt', 'w', encoding="utf-8") as submit_file:
    for index, row in test_df.iterrows():
        # Only the user turn is needed as the prompt. The assistant turn is never
        # read, so don't index into it: test rows may not carry one.
        query = row["conversations"][0]["content"]

        response, _history = model.chat(tokenizer=tokenizer, query=query, history=[], top_p=top_p, temperature=temperature)

        # Presumably maps fund/stock names in the answer onto the standard-name lists.
        optimized_resp = optimize_parameters(response, fund_standard_name, stock_standard_name)

        total_eval_count += 1
        # NOTE(review): one answer per line. An answer containing '\n' would break
        # the line-to-row alignment; confirm the model never emits newlines.
        submit_file.write(optimized_resp + '\n')
        if total_eval_count % 10 == 0:
            print(f"现在是第{total_eval_count}条数据")