In [1]:
import os
from tqdm import tqdm

from datasets import load_dataset, Dataset

from transformers import pipeline

from trl import HfPairwiseJudge, OpenAIPairwiseJudge

### Devices

In [2]:
# Visible devices
# -------------------------------------------------------------------------------------------------
VISIBLE_DEVICES = "0"
# -------------------------------------------------------------------------------------------------

# Enumerate GPUs based on their PCI bus IDs and expose only the devices
# selected above to CUDA.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": f"{VISIBLE_DEVICES}",
    }
)

### Models and dataset

In [3]:
# Base model
# -------------------------------------------------------------------------------------------------
BASE_MODEL_PATH = "RLHF-And-Friends/Llama-3.2-3B-Instruct"
# -------------------------------------------------------------------------------------------------

# Fine-tuned model
# -------------------------------------------------------------------------------------------------
FT_MODEL_PATH = "RLHF-And-Friends/Llama-3.2-3B-Instruct-DPO-Math"
# -------------------------------------------------------------------------------------------------

# Dataset
# -------------------------------------------------------------------------------------------------
DATASET_PATH = "HuggingFaceH4/MATH-500"
DATASET_SPLIT = "test"
PROMPT_FIELD = "problem"
# -------------------------------------------------------------------------------------------------

# Short dataset name: the last component of the dataset path. Taking the last
# component (rather than index 1) also works for paths without a namespace
# prefix, which would otherwise raise IndexError.
DATASET_NAME = DATASET_PATH.rsplit('/', 1)[-1]

### Load dataset

In [None]:
# Load the evaluation split of the configured dataset.
test_dataset = load_dataset(
    path=DATASET_PATH,
    split=DATASET_SPLIT,
)

### Inference

In [5]:
def get_responses(
    prompts: list[str],
    model_path: str,
    batch_size: int = 8,
    max_new_tokens: int = 512
) -> list[str]:
    """Generate one chat response per prompt with the model at `model_path`.

    Each prompt is wrapped into a single-turn chat (one "user" message) and
    passed through a `transformers` pipeline in slices of `batch_size`.

    Args:
        prompts: User prompts to answer.
        model_path: Model identifier or path handed to `pipeline`.
        batch_size: Number of chats sent to the pipeline per call; also
            forwarded to the pipeline as its own `batch_size`.
        max_new_tokens: Forwarded to the pipeline as `max_new_tokens`.

    Returns:
        For each prompt, in input order, the `content` of the last message of
        the first generated conversation (presumably the assistant reply --
        this relies on the pipeline returning the full chat history).
    """
    conversations = [
        [{'role': "user", 'content': user_text}] for user_text in prompts
    ]

    generator = pipeline(
        model=model_path,
        device_map='auto',
        batch_size=batch_size,
        max_new_tokens=max_new_tokens
    )

    batch_starts = range(0, len(conversations), batch_size)

    replies = []
    for start in tqdm(batch_starts, desc=f'{model_path} inference'):
        outputs = generator(conversations[start:start + batch_size])
        # Keep only the text of the final message of the first returned sequence.
        replies.extend(
            output[0]['generated_text'][-1]['content'] for output in outputs
        )

    return replies


In [None]:
# Prompts come from the configured field of the evaluation dataset.
prompts = list(test_dataset[PROMPT_FIELD])

# Answer the same prompts with the base model first, then the fine-tuned one.
base_completions = get_responses(prompts=prompts, model_path=BASE_MODEL_PATH)
ft_completions = get_responses(prompts=prompts, model_path=FT_MODEL_PATH)

### Create a dataset with the models' responses and upload it to the Hugging Face Hub

In [None]:
# One row per prompt with the completion of each model.
# NOTE(review): the column names are inconsistent ('base_completion' is
# singular, 'ft_completions' is plural); they are kept as-is because they
# define the schema of the pushed dataset.
responses_dataset = Dataset.from_dict(
    mapping={
        'prompt': prompts,
        'base_completion': base_completions,
        'ft_completions': ft_completions,
    },
    split="test",
)

# Publish the responses to the Hub under a name derived from the dataset.
responses_dataset.push_to_hub(
    repo_id=f"RLHF-And-Friends/{DATASET_NAME}-Completions"
)

### Load dataset with responses and prepare to judge

In [4]:
# Upper bound on the number of examples passed to the judges.
NUM_JUDGED_EXAMPLES = 50

# The responses dataset is created with split="test" before being pushed, so
# read that split back (the previous "train" key does not match it).
completions_test_split = load_dataset(
    f"RLHF-And-Friends/{DATASET_NAME}-Completions",
    split="test",
)

# Keep only the first NUM_JUDGED_EXAMPLES rows; `min` avoids an out-of-range
# selection when the dataset holds fewer rows than that.
responses_dataset = completions_test_split.select(
    range(min(NUM_JUDGED_EXAMPLES, len(completions_test_split)))
)

In [5]:
prompts = list(responses_dataset['prompt'])

# One [base, fine-tuned] pair per prompt: index 0 is the base model's
# completion and index 1 is the fine-tuned model's completion.
completions = [
    list(pair)
    for pair in zip(
        responses_dataset['base_completion'],
        responses_dataset['ft_completions'],
    )
]

### Judge with OpenAI API

In [6]:
# The OpenAI API key must be provided through the environment (e.g. exported
# before the kernel is started); credentials must never be stored in the notebook.
# NOTE(review): a key was previously hardcoded in this cell -- it should be
# considered leaked and revoked.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it in the environment before "
        "running the OpenAI judge cells."
    )

In [8]:
# Pairwise judge backed by the OpenAI API.
gpt_judge = OpenAIPairwiseJudge(model="gpt-4o-mini")

# The completion order is not shuffled: the base completion is always shown
# first and the fine-tuned completion second.
gpt_judgements = gpt_judge.judge(
    prompts=prompts,
    completions=completions,
    shuffle_order=False,
)

In [9]:
# Each judgement is the index of the preferred completion; completions are
# ordered [base, fine-tuned], so 1 means the fine-tuned model won.
# Per the TRL judge documentation a judgement of -1 marks a failed judgement;
# such entries are excluded so they do not count as negative wins.
gpt_valid_judgements = [
    judgement for judgement in gpt_judgements if judgement in (0, 1)
]

# Fraction of valid comparisons won by the fine-tuned model (NaN if there is
# no valid judgement at all, instead of a ZeroDivisionError).
gpt_winrate = (
    sum(gpt_valid_judgements) / len(gpt_valid_judgements)
    if gpt_valid_judgements
    else float("nan")
)

In [None]:
# Report the winrate computed from the GPT judge's verdicts.
print(f"GPT-judged winrate: {gpt_winrate}")

Winrate: 0.58


### Judge with Hugging Face API

In [12]:
# Pairwise judge backed by a model served through the Hugging Face API.
hf_judge = HfPairwiseJudge(model="meta-llama/Meta-Llama-3-8B-Instruct")

# The completion order is not shuffled: the base completion is always shown
# first and the fine-tuned completion second.
hf_judgements = hf_judge.judge(
    prompts=prompts,
    completions=completions,
    shuffle_order=False,
)

In [None]:
# Each judgement is the index of the preferred completion; completions are
# ordered [base, fine-tuned], so 1 means the fine-tuned model won.
# Per the TRL judge documentation a judgement of -1 marks a failed judgement;
# such entries are excluded so they do not count as negative wins.
hf_valid_judgements = [
    judgement for judgement in hf_judgements if judgement in (0, 1)
]

# Fraction of valid comparisons won by the fine-tuned model (NaN if there is
# no valid judgement at all, instead of a ZeroDivisionError).
hf_winrate = (
    sum(hf_valid_judgements) / len(hf_valid_judgements)
    if hf_valid_judgements
    else float("nan")
)
print(f"HF-judged winrate: {hf_winrate}")