In [2]:
from pathlib import Path
from typing import Annotated, Union

import typer
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast
)

# Type of the model object returned by load_model_and_tokenizer: either a plain
# transformers model or a PEFT causal-LM wrapper.
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
# Type of the tokenizer object returned by load_model_and_tokenizer.
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

def load_model_and_tokenizer(
        model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
    """Load a causal language model and its tokenizer from a checkpoint directory.

    If the directory contains an ``adapter_config.json`` file, the checkpoint is
    loaded through ``AutoPeftModelForCausalLM`` and the tokenizer is taken from
    the base model path recorded in the adapter's ``'default'`` PEFT config.
    Otherwise the directory is loaded with ``AutoModelForCausalLM`` and the
    tokenizer is read from that same directory.

    Args:
        model_dir: Checkpoint location; ``~`` is expanded and the path is
            resolved to an absolute path before use.
        trust_remote_code: Forwarded unchanged to every ``from_pretrained`` call.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = Path(model_dir).expanduser().resolve()
    is_adapter = (checkpoint / 'adapter_config.json').exists()

    # Both loaders receive exactly the same arguments; only the class differs.
    loader = AutoPeftModelForCausalLM if is_adapter else AutoModelForCausalLM
    model = loader.from_pretrained(
        checkpoint, trust_remote_code=trust_remote_code, device_map='auto'
    )

    if is_adapter:
        # The tokenizer is loaded from the base model the adapter points at,
        # not from the adapter checkpoint directory itself.
        tokenizer_source = model.peft_config['default'].base_model_name_or_path
    else:
        tokenizer_source = checkpoint

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source,
        trust_remote_code=trust_remote_code,
        encode_special_tokens=True,
        use_fast=False,
    )
    return model, tokenizer

In [None]:
# Checkpoint to evaluate. The path is relative, so it is resolved against the
# kernel's current working directory inside load_model_and_tokenizer.
model_dir = "../../train-result/2024-06-29/checkpoint-100"
# `model` and `tokenizer` are notebook-level globals used by the evaluation cell below.
model, tokenizer = load_model_and_tokenizer(model_dir)

In [None]:
# Evaluation data file (relative path); presumably JSON Lines given the
# extension and the reader's name -- TODO confirm.
json_file_path = 'data/dev.jsonl'
# NOTE(review): `read_jsonl` and `pd` are not defined or imported in any cell
# visible here; they must come from an earlier cell or this fails on a fresh
# kernel (Restart & Run All). Verify and add the missing import/definition.
data_list = read_jsonl(json_file_path)

# The evaluation cell below reads a "messages" column from each row of `df`.
df = pd.DataFrame(data_list)
df.head()

In [None]:
from random import random

# A row is skipped when random() <= SKIP_PROBABILITY, so roughly 10% of the
# rows are evaluated on each run.
SKIP_PROBABILITY = 0.9
# Stop once this many rows have actually been evaluated.
MAX_EVAL_COUNT = 20

# Sampling settings shared by every generate() call; built once because none
# of them depend on the current row.
GENERATION_KWARGS = {
    "max_new_tokens": 1024,
    "do_sample": True,
    "top_p": 0.8,
    "temperature": 0.8,
    "repetition_penalty": 1.2,
    "eos_token_id": model.config.eos_token_id,
}

total_eval_count = 0
correct_count = 0

# Exact-match evaluation on a random subsample of `df`.
# NOTE: both the row subsampling and the generation (do_sample=True) are
# unseeded, so the reported ratio varies from run to run.
for index, row in df.iterrows():
    if random() <= SKIP_PROBABILITY:
        continue

    # assumes messages[0] is the user turn and messages[1] the reference
    # assistant turn -- TODO confirm against the dataset schema
    single_user_row = row["messages"][0]
    single_assis_row = row["messages"][1]

    messages = [single_user_row]
    output = single_assis_row["content"]

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt"
    ).to(model.device)

    outputs = model.generate(input_ids=inputs, **GENERATION_KWARGS)
    # Drop the prompt tokens so only the newly generated continuation is decoded.
    response = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True).strip()

    # optimized_resp = optimize_parameters(response, fund_standard_name, stock_standard_name)

    total_eval_count += 1
    if response == output:
        correct_count += 1
    else:
        print("-----data index-----")
        print(index)
        print("-----query input-----")
        # Print the actual user query; the previous `print(input)` printed the
        # builtin `input` callable instead.
        print(single_user_row["content"])
        print("-----output diff-----")
        print(highlight_diff(output, response))
        print(response)
        print()

    if total_eval_count == MAX_EVAL_COUNT:
        break

# Guard against ZeroDivisionError when no row was sampled (empty frame or
# every row skipped by the random filter).
if total_eval_count:
    print("预测正确的比例：" + f"{correct_count / total_eval_count :.2%}")
else:
    print("No rows were sampled for evaluation; accuracy is undefined.")