In [None]:
import json
import os
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser, pipeline
from accelerate import Accelerator

# Register tqdm's pandas integration (nothing in the visible cells relies on it).
tqdm.pandas()


def _load_eval_split(json_path):
    """Load one evaluation JSON file and return its "train" split."""
    return load_dataset("json", data_files=json_path, split="train")


# Evaluation files, one per method that is later compared with the baseline.
ds_popo = _load_eval_split("../data/alphca_eval_popo.json")
ds_xpo = _load_eval_split("../data/alphca_eval_xpo.json")
ds_dpo = _load_eval_split("../data/alphca_eval_dpo.json")

# Reference generations; the win-rate cell compares every method against these.
baseline = _load_eval_split("../data/alphca_eval_LLaMA3_SFT.json")





In [None]:
# The reward model is placed on whatever device accelerate selects.
accelerator = Accelerator()
device = accelerator.device

reward_model = "Ray2333/GRM-Llama3-8B-rewardmodel-ft"
rm_tokenizer = AutoTokenizer.from_pretrained(reward_model)

# Text-classification pipeline around the reward model, loaded in bfloat16
# with input truncation enabled.
rm_pipe = pipeline(
    task="sentiment-analysis",
    model=reward_model,
    tokenizer=rm_tokenizer,
    device=device,
    truncation=True,
    model_kwargs={"torch_dtype": torch.bfloat16},
)

# Call-time options that `get_reward` forwards to `rm_pipe`:
# all label scores are returned and no activation function is applied to them.
pipe_kwargs = dict(
    return_all_scores=True,
    function_to_apply="none",
    batch_size=4,
)

def get_reward(test_texts):
    """Score texts with the reward-model pipeline.

    Args:
        test_texts: Input handed straight to ``rm_pipe`` together with
            ``pipe_kwargs`` (callers in this notebook pass a list of
            chat-formatted strings).

    Returns:
        A list with one value per pipeline output: the ``"score"`` field of
        the first entry of that output.
    """
    return [scored[0]["score"] for scored in rm_pipe(test_texts, **pipe_kwargs)]


def change_of_format(prom, resp):
    """Render one prompt/response pair as a chat-template string.

    Args:
        prom: Content of the user turn.
        resp: Content of the assistant turn.

    Returns:
        The untokenized output of ``rm_tokenizer.apply_chat_template`` for the
        two-turn conversation, with every occurrence of the tokenizer's BOS
        token text removed.
    """
    conversation = [
        {"role": "user", "content": prom},
        {"role": "assistant", "content": resp},
    ]
    templated = rm_tokenizer.apply_chat_template(conversation, tokenize=False)
    # NOTE(review): presumably the BOS text is dropped because the pipeline's
    # tokenizer adds it again when encoding — confirm.
    return templated.replace(rm_tokenizer.bos_token, "")

In [None]:

def _score_dataset(dataset):
    """Score every sample of ``dataset`` with the reward model.

    Each sample's ``"instruction"``/``"output"`` pair is formatted with
    ``change_of_format`` and scored on its own by ``get_reward``. No sample is
    skipped, so entry ``i`` of the result corresponds to sample ``i`` of
    ``dataset``; the win-rate cell relies on this when it pairs results with
    the baseline by index.

    Args:
        dataset: Iterable of samples exposing ``"instruction"`` and
            ``"output"``.

    Returns:
        A list of dicts with keys ``"prompt"``, ``"responses"`` and
        ``"rewards"``, where ``"rewards"`` is the list returned by
        ``get_reward`` for that single formatted text.
    """
    scored = []
    with torch.no_grad():
        for sample in tqdm(dataset):
            test_texts = [change_of_format(sample["instruction"], sample["output"])]
            scored.append(
                {
                    "prompt": sample["instruction"],
                    "responses": sample["output"],
                    "rewards": get_reward(test_texts),
                }
            )
    return scored


data_dpo = _score_dataset(ds_dpo)
data_xpo = _score_dataset(ds_xpo)
data_popo = _score_dataset(ds_popo)

In [None]:
data_bl = []
with torch.no_grad():
    for sample in tqdm(baseline):
        prompt, response = sample["instruction"], sample["output"]
        # Every baseline sample is scored on its own; none are skipped.
        rewards = get_reward([change_of_format(prompt, response)])
        data_bl.append(dict(prompt=prompt, responses=response, rewards=rewards))

In [None]:
def _win_flags(candidate, reference):
    """Per-sample win indicators of ``candidate`` against ``reference``.

    Samples are paired by index over ``len(reference)`` and compared on their
    first reward. The result is the list of strict comparisons (``>``)
    followed by the list of non-strict ones (``>=``), so its mean equals
    ``(wins + 0.5 * ties) / len(reference)``: a tie counts as half a win.
    """
    pairs = [
        (candidate[i]["rewards"][0], reference[i]["rewards"][0])
        for i in range(len(reference))
    ]
    strict = [ours > theirs for ours, theirs in pairs]
    non_strict = [ours >= theirs for ours, theirs in pairs]
    return strict + non_strict


winrate_dpo = _win_flags(data_dpo, data_bl)
winrate_xpo = _win_flags(data_xpo, data_bl)
winrate_popo = _win_flags(data_popo, data_bl)


print("WR:")
print("DPO:", np.mean(winrate_dpo))
print("XPO:", np.mean(winrate_xpo))
# Must use `winrate_popo`; no `winrate_wpo` is defined in this notebook.
print("POPO:", np.mean(winrate_popo))

In [None]:
print("AvgV:")
# Mean of the first reward of every scored sample, reported per method.
for label, scored in (("DPO:", data_dpo), ("XPO:", data_xpo), ("POPO:", data_popo)):
    print(label, np.mean([entry["rewards"][0] for entry in scored]))
