In [None]:
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

dataset = load_dataset("/Users/zejun7/python/llm-training/reward_model/original_data/Skywork-Reward-Preference-80K-v0.1/data", split="train")
dataset[0]

{'chosen': [{'content': 'c#', 'role': 'user'},
  {'content': 'C# (pronounced "C sharp") is a modern, object-oriented programming language developed by Microsoft. It is widely used for building various types of applications, including web applications, desktop applications, mobile applications, and games. C# is similar to other programming languages such as Java and C++, and it is known for its simplicity and ease of use. C# is a powerful language that provides a rich set of libraries and frameworks that make it easy to build robust and scalable applications.\n\nHere is a brief overview of some key features of C#:\n\n1. Object-oriented: C# is an object-oriented language, which means it uses the concept of objects to represent real-world entities and their behavior.\n\n2. Cross-platform: C# can be used to build applications for multiple platforms, including Windows, macOS, and Linux.\n\n3. Strongly typed: C# is a strongly typed language, which means that variables must be declared with a

## 对话模板 apply_chat_template

In [28]:
# 不能这么赋值，dataset是只读的，需要用map来修改
for i in tqdm(range(len(dataset))):
    dataset[i]['new_chosen'] = tokenizer.apply_chat_template(dataset[i]['chosen'], tokenize=False)
    dataset[i]['new_rejected'] = tokenizer.apply_chat_template(dataset[i]['rejected'], tokenize=False)
    break

  0%|          | 0/81973 [00:00<?, ?it/s]


In [26]:
# 方式一：标准奖励模型格式
def apply_chat_template(examples):
    """
    输入格式:
    examples = {
        'chosen': [
            {'role': 'user', 'content': '用一句话介绍Python'},
            {'role': 'assistant','content': 'Python是一种简洁优雅的语言'}
        ],
        'rejected': [
            {'role': 'user','content': '用一句话介绍Python'},
            {'role': 'assistant','content': 'Python是条蛇'}
        ]
    }
    """
    examples['chosen'] = tokenizer.apply_chat_template(
        examples['chosen'],
        tokenize=False
    )
    examples['rejected'] = tokenizer.apply_chat_template(
        examples['rejected'],
        tokenize=False
    )
    return examples

# 使用 map 替换原字段
dataset_template = dataset.map(
    apply_chat_template,
    batched=True,
    batch_size=1000,  # 每批处理1000条
    num_proc=16,       # 多进程加速
    desc="Applying chat template"
)
dataset_template.to_parquet("data/skywork.parquet")
dataset_template[0]

Creating parquet from Arrow format: 100%|██████████| 82/82 [00:01<00:00, 53.44ba/s]


{'chosen': '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nc#<|im_end|>\n<|im_start|>assistant\nC# (pronounced "C sharp") is a modern, object-oriented programming language developed by Microsoft. It is widely used for building various types of applications, including web applications, desktop applications, mobile applications, and games. C# is similar to other programming languages such as Java and C++, and it is known for its simplicity and ease of use. C# is a powerful language that provides a rich set of libraries and frameworks that make it easy to build robust and scalable applications.\n\nHere is a brief overview of some key features of C#:\n\n1. Object-oriented: C# is an object-oriented language, which means it uses the concept of objects to represent real-world entities and their behavior.\n\n2. Cross-platform: C# can be used to build applications for multiple platforms, including Windows, macOS, and Linux.\

In [42]:
# 方式二：不推荐，它会将prompt转化成模板，再将prompt+chosen转化成模板，再提取prompt和chosen单独部分，很多冗余操作，降低数据处理效率
# 相比方式一，实测速度对比 15s -> 5min
# 特别注意：这种方法在后续的_tokenzie中没有用到prompt，会导致前缀丢失，所以**奖励模型是不能用这种格式的**
dataset_name = "data/skywork2.parquet"
def apply_chat_template(examples):
    """
    输入格式:
    examples = {
        'prompt': [
            {'role': 'user', 'content': '解释一下什么是机器学习'}
        ],
        'chosen': [
            {'role': 'assistant', 'content': '机器学习是...（正确解释）'}
        ],
        'rejected': [
            {'role': 'assistant', 'content': '机器学习就是机器人学习走路'}
        ]
    }
    """
    prompt_text = tokenizer.apply_chat_template(
        examples['prompt'],
        tokenize=False,
        add_generation_prompt=True
    )
    chosen_text = tokenizer.apply_chat_template(
        examples['prompt'] + examples['chosen'],
        tokenize=False
    )
    rejected_text = tokenizer.apply_chat_template(
        examples['prompt'] + examples['rejected'],
        tokenize=False
    )

    examples['prompt'] = prompt_text
    examples['chosen'] = chosen_text[len(prompt_text):]
    examples['rejected'] = rejected_text[len(prompt_text):]
    return examples


def template_transform(batch):
    new_examples = {
        'prompt': [],
        'chosen': [],
        'rejected': []
    }
    for chosen, rejected in zip(batch["chosen"], batch["rejected"]):
        new_examples['prompt'].append([chosen[0]])
        new_examples['chosen'].append([chosen[1]])
        new_examples['rejected'].append([rejected[1]])
    return new_examples


# 使用 map 替换原字段
dataset_prompt_format = dataset.map(
    template_transform,
    batched=True,
    batch_size=1000,  # 每批处理1000条
    num_proc=16,       # 多进程加速
    desc="Template transform"
)
dataset_template = dataset_prompt_format.map(
    apply_chat_template,
    batched=True,
    batch_size=1000,  # 每批处理1000条
    num_proc=16,       # 多进程加速
    desc="Applying chat template"
)
dataset_template.to_parquet("data/skywork2.parquet")
dataset_template[0]

Applying chat template (num_proc=16): 100%|██████████| 81973/81973 [00:11<00:00, 7115.62 examples/s] 
Creating parquet from Arrow format: 100%|██████████| 82/82 [00:01<00:00, 58.53ba/s]


{'chosen': '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>assistant\nC# (pronounced "C sharp") is a modern, object-oriented programming language developed by Microsoft. It is widely used for building various types of applications, including web applications, desktop applications, mobile applications, and games. C# is similar to other programming languages such as Java and C++, and it is known for its simplicity and ease of use. C# is a powerful language that provides a rich set of libraries and frameworks that make it easy to build robust and scalable applications.\n\nHere is a brief overview of some key features of C#:\n\n1. Object-oriented: C# is an object-oriented language, which means it uses the concept of objects to represent real-world entities and their behavior.\n\n2. Cross-platform: C# can be used to build applications for multiple platforms, including Windows, macOS, and Linux.\n\n3. Strongly typed: C# is a st

In [None]:
# 语言建模（messages）
def apply_chat_template(examples):
    """
    输入格式:
    examples = {
        'messages': [
            {'role': 'user', 'content': '你好'},
            {'role': 'assistant', 'content': '你好呀！'}
        ]
    }
    """
    examples['text'] = tokenizer.apply_chat_template(
        examples['messages'],
        tokenize=False
    )
    return examples


# Prompt-only
def apply_chat_template(examples):
    """
    输入格式:
    examples = {
        'prompt': [
            {'role': 'user', 'content': '写一个Python冒泡排序'}
        ]
    }
    """
    examples['prompt'] = tokenizer.apply_chat_template(
        examples['prompt'],
        tokenize=False,
        add_generation_prompt=True  # 最后是user，需要生成
    )
    return examples


# Prompt + Completion
def apply_chat_template(examples):
    """
    输入格式:
    examples = {
        'prompt': [
            {'role': 'user', 'content': '讲个冷笑话'}
        ],
        'completion': [
            {'role': 'assistant', 'content': '程序员的Bug不会飞~'}
        ]
    }
    """
    prompt_text = tokenizer.apply_chat_template(
        examples['prompt'],
        tokenize=False,
        add_generation_prompt=True
    )
    completion_text = tokenizer.apply_chat_template(
        examples['prompt'] + examples['completion'],
        tokenize=False
    )
    examples['prompt'] = prompt_text
    examples['completion'] = completion_text[len(prompt_text):]
    return examples


# Unpaired Preference (Prompt + Completion + Label)
def apply_chat_template(examples):
    """
    输入格式:
    examples = {
        'prompt': [
            {'role': 'user', 'content': '写一个1+1的答案'}
        ],
        'completion': [
            {'role': 'assistant', 'content': '答案是2'}
        ],
        'label': 1  # 比如 1表示偏好，0表示不偏好
    }
    """
    prompt_text = tokenizer.apply_chat_template(
        examples['prompt'],
        tokenize=False,
        add_generation_prompt=True
    )
    completion_text = tokenizer.apply_chat_template(
        examples['prompt'] + examples['completion'],
        tokenize=False
    )
    examples['prompt'] = prompt_text
    examples['completion'] = completion_text[len(prompt_text):]
    return examples


# vllm部署及测试

In [None]:
'''
CUDA_VISIBLE_DEVICES=0,1 vllm serve save/base_reward_0926/checkpoint-50 --served-model-name base_reward --tensor-parallel-size 2 --port 5001 --override-pooler-config '{"pooling_type": "LAST", "normalize": false, "softmax": false}'
'''

In [2]:
"""纯判别式奖励模型测试"""
import requests
def get_score_api(query: str, response: str):
    api_url = "http://0.0.0.0:5001/pooling"
    # Input like Chat API
    prompt = {
        "model": "base_reward",
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}],
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": query}],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": response}],
            },
        ],
    }
    headers = {"Content-Type": "application/json", "Authorization": "Bearer EMPTY"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return float(response.json()["data"][0]["data"][0])

user_prompt = "给我写一个谜语"
response_str = """我就不写"""

get_score_api(user_prompt, response_str)

-7.0625