In [1]:
import json
from itertools import islice
from pathlib import Path

import jsonlines
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the MiniCPM-2B SFT checkpoint from a local directory.
# The weights are cast to bfloat16 and placed on the CUDA device.
# trust_remote_code is presumably needed for the checkpoint's custom modeling
# code (e.g. the non-standard `model.chat` helper used below).
path = "MiniCPM-2B-sft-fp32"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
)

In [3]:
# Smoke test: ask one TriviaQA-style question with its supporting document and
# display the answer together with the returned chat history.
# NOTE(review): the literal <user>/<assistant> tags are sent as part of the query
# text (the echoed history keeps them verbatim); if `model.chat` applies its own
# chat template, these tags end up nested inside it — confirm they are intended.
res, history = model.chat(
    tokenizer,
    query=(
        "<user>Answer the question according to the following content \n"
        " Question: Rita Coolidge sang the title song for which Bond film?\n"
        "Following is the content.\n"
        "RITA COOLIDGE  ALL TIME HIGH James Bond 007 OCTOPUSSY The val doonican show 1983 - YouTube\n"
        "RITA COOLIDGE  ALL TIME HIGH James Bond 007 OCTOPUSSY The val doonican show 1983\n"
        "Want to watch this again later?\n"
        "Sign in to add this video to a playlist.\n"
        "Need to report the video?\n"
        "Sign in to report inappropriate content.\n"
        "Rating is available when the video has been rented.\n"
        "This feature is not available right now. Please try again later.\n"
        "Published on Sep 17, 2012\n"
        "Clip from THE VAL DOONICAN MUSIC SHOW 1983 Featuring Rita Coolidge Performing The title track to the JAMES BOND film OCTOPUSSY.\n"
        "Category\n"
        "<assistant>"
    ),
    max_length=1024,
    top_p=0.5,
)
res, history

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.


('The title song for the Bond film Octopussy was sung by Rita Coolidge.',
 [{'role': 'user',
   'content': '<user>Answer the question according to the following content \n Question: Rita Coolidge sang the title song for which Bond film?\nFollowing is the content.\nRITA COOLIDGE  ALL TIME HIGH James Bond 007 OCTOPUSSY The val doonican show 1983 - YouTube\nRITA COOLIDGE  ALL TIME HIGH James Bond 007 OCTOPUSSY The val doonican show 1983\nWant to watch this again later?\nSign in to add this video to a playlist.\nNeed to report the video?\nSign in to report inappropriate content.\nRating is available when the video has been rented.\nThis feature is not available right now. Please try again later.\nPublished on Sep 17, 2012\nClip from THE VAL DOONICAN MUSIC SHOW 1983 Featuring Rita Coolidge Performing The title track to the JAMES BOND film OCTOPUSSY.\nCategory\n<assistant>'},
  {'role': 'assistant',
   'content': 'The title song for the Bond film Octopussy was sung by Rita Coolidge.'}])

In [4]:
# Load at most MAX_TEST_SAMPLES records from the TriviaQA web-domain test split.
# islice stops after the limit without reading (and parsing) an extra record.
MAX_TEST_SAMPLES = 2000
with jsonlines.open("test_data/triviaqa-rc_web-ver_test.jsonl", 'r') as reader:
    test_sample_list = list(islice(reader, MAX_TEST_SAMPLES))
len(test_sample_list)

In [5]:
# Run the baseline over the web-domain samples and collect answers keyed by
# the full sample index.
# `max_input_length` is used in two different units: as a character cut-off on
# the raw prompt, and as the `max_length` passed to `model.chat`.
max_input_length = 8000
# Decoding samples (top_p / temperature), so fix the RNG to make the baseline
# reproducible across runs.
SEED = 42

torch.manual_seed(SEED)
output_dict = {}
for sample in tqdm(test_sample_list, desc="triviaqa web"):
    index = sample["index"]
    user_message = sample["messages"][0]["content"]

    # Truncate the prompt by characters (not tokens) to bound its size.
    user_message = user_message[:max_input_length]

    # NOTE(review): max_length is forwarded to generation; under HF `generate`
    # semantics it bounds prompt + generated tokens, not only the answer —
    # confirm against the model's `chat` implementation.
    res, history = model.chat(
        tokenizer,
        query=f"<user>{user_message}<assistant>",
        max_length=max_input_length,
        top_p=0.5,
        temperature=0.8,
    )
    # Web predictions are keyed by the full sample index.
    output_dict[index] = res

  0%|          | 0/410 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  0%|          | 1/410 [00:00<02:57,  2.30it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  0%|          | 2/410 [00:00<03:21,  2.03it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  1%|          | 3/410 [00:01<03:23,  2.00it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  1%|          | 4/410 [00:01<03:09,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  1%|          | 5/410 [00:02<04:22,  1.54it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  1%|▏         | 6/410 [00:03<04:08,  1.62it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  2%|▏         | 7/410 [00:03<03:57,  1.69it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  2%|▏         | 8/410 [00:04<03:44,  1.79it/s]Setting `pad_token_id` to `eos_token_id`:

In [6]:
output_file = 'result_data/baseline_web.json'
# Write the web predictions (output_dict) to JSON. Create the output directory
# first so a fresh checkout does not fail with FileNotFoundError.
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(output_dict, f, ensure_ascii=False, indent=4)

In [8]:
# Load at most MAX_TEST_SAMPLES records from the TriviaQA wiki-domain test split.
# islice stops after the limit without reading (and parsing) an extra record.
MAX_TEST_SAMPLES = 2000
with jsonlines.open("test_data/triviaqa-rc_wiki-ver_test.jsonl", 'r') as reader:
    test_sample_list = list(islice(reader, MAX_TEST_SAMPLES))
len(test_sample_list)

637


In [9]:
# Run the baseline over the wiki-domain samples and collect answers keyed by
# question id (the part of the sample index before '--').
# `max_input_length` is used in two different units: as a character cut-off on
# the raw prompt, and as the `max_length` passed to `model.chat`.
max_input_length = 8000
# Decoding samples (top_p / temperature), so fix the RNG to make the baseline
# reproducible across runs.
SEED = 42

torch.manual_seed(SEED)
output_dict = {}
n_overwritten = 0  # answers replaced by a later sample with the same qid
for sample in tqdm(test_sample_list, desc="triviaqa wiki"):
    index = sample["index"]
    # Unlike the web cell (keyed by the full index), wiki predictions are keyed
    # by question id only.
    qid = index.split('--')[0]
    user_message = sample["messages"][0]["content"]

    # Truncate the prompt by characters (not tokens) to bound its size.
    user_message = user_message[:max_input_length]

    # NOTE(review): max_length is forwarded to generation; under HF `generate`
    # semantics it bounds prompt + generated tokens, not only the answer —
    # confirm against the model's `chat` implementation.
    res, history = model.chat(
        tokenizer,
        query=f"<user>{user_message}<assistant>",
        max_length=max_input_length,
        top_p=0.5,
        temperature=0.8,
    )
    if qid in output_dict:
        n_overwritten += 1
    output_dict[qid] = res

# NOTE(review): if several samples share a question id (e.g. one sample per
# evidence document), only the last answer is kept — confirm this is intended.
if n_overwritten:
    print(f"warning: {n_overwritten} predictions overwritten by a later sample with the same qid")

  0%|          | 0/637 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  0%|          | 1/637 [00:00<07:20,  1.44it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  0%|          | 2/637 [00:01<08:21,  1.27it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  0%|          | 3/637 [00:03<16:14,  1.54s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  1%|          | 4/637 [00:04<11:32,  1.09s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  1%|          | 5/637 [00:05<10:58,  1.04s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  1%|          | 6/637 [00:05<09:20,  1.13it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  1%|          | 7/637 [00:06<07:55,  1.32it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  1%|▏         | 8/637 [00:07<07:48,  1.34it/s]Setting `pad_token_id` to `eos_token_id`:

In [10]:
output_file = 'result_data/baseline_wiki.json'
# Write the wiki predictions (output_dict) to JSON. Create the output directory
# first so a fresh checkout does not fail with FileNotFoundError.
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(output_dict, f, ensure_ascii=False, indent=4)