In [None]:
# download reward model
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# The tokenizer and the reward model are loaded from the same Hugging Face checkpoint.
rm_checkpoint = "usvsnsp/pythia-6.9b-rm-full-hh-rlhf"
rm_tokenizer = AutoTokenizer.from_pretrained(rm_checkpoint)
rm = AutoModelForSequenceClassification.from_pretrained(rm_checkpoint)

# Upper bound on tokens per sequence, passed as `max_length` when tokenizing with truncation.
seq_length = 1024

# Try on a single example

In [None]:
# Score one hand-written (prompt, reply) pair as a smoke test of the reward model.
import torch  # imported here because this cell runs before the later cell that imports torch

prompt = """\n\nHuman: What is the modern day concept of a "soul mate"?\n\nAssistant:"""
selected = """ Ah, soul mates! A lot of romantic relationships depend on the concept of a soul mate, so you can consider soul mates the ultimate expression of romantic love. Soul mates are supposed to be a relationship of profound attraction and deep connection, a meeting of two "soul twins" who are almost one."""
# NOTE(review): this example is scored in the raw "\n\nHuman:/\n\nAssistant:" layout with a
# leading BOS token, whereas the accuracy evaluation further down converts prompts with
# to_vicuna_format and adds no BOS token -- confirm which layout the checkpoint expects.
#prompt_converted, selected_converted, _ = to_vicuna_format({"prompt": prompt, "selected": selected, "rejected": ""}).values()

# The model and its inputs must live on the same device; use the GPU only when one exists.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rm = rm.to(device)
rm.eval()

tokenized = rm_tokenizer(
    rm_tokenizer.bos_token + prompt + selected + rm_tokenizer.eos_token,
    truncation=True,
    max_length=seq_length,
    return_tensors="pt",
).to(device)

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    rm_output = rm(**tokenized)
rm_output[0]

# Compute accuracy

In [None]:
from datasets import load_dataset

# Preference dataset used for the accuracy evaluation.
# ("Anthropic/hh-rlhf" was the previously used alternative.)
dataset_name = "pvduy/rm_oa_hh"
dataset = load_dataset(dataset_name)

In [None]:
import torch
# Configure batching behaviour of the tokenizer.
rm_tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
# Pad on the right; truncate on the left so that over-long conversations lose their
# earliest tokens rather than the reply and EOS at the end.
rm_tokenizer.padding_side = "right"
rm_tokenizer.truncation_side = "left"
# The sequence-classification head locates the last non-padding token of each row through
# `config.pad_token_id`, so it has to match the id the tokenizer pads with. Without this,
# batches larger than one either raise or are scored at a padding position.
rm.config.pad_token_id = rm_tokenizer.pad_token_id
# Upper bound on tokens per sequence, read by `tokenize` below.
seq_length = 1024

def tokenize(prompt, selected, rejected, tokenizer):
    """Tokenize both completions of one preference pair.

    Each completion is appended to ``prompt`` and followed by the tokenizer's
    EOS token; the text is then encoded with truncation to the module-level
    ``seq_length``.

    Args:
        prompt: Conversation text shared by both completions.
        selected: The preferred reply.
        rejected: The dispreferred reply.
        tokenizer: Callable tokenizer exposing ``eos_token`` and returning an
            object with an ``input_ids`` attribute.

    Returns:
        A dict with ``"selected_input_ids"`` and ``"rejected_input_ids"``.
    """
    def encode(reply):
        text = prompt + reply + tokenizer.eos_token
        return tokenizer(text, truncation=True, max_length=seq_length).input_ids

    # Encode the selected reply first, then the rejected one.
    selected_ids = encode(selected)
    rejected_ids = encode(rejected)
    return {"selected_input_ids": selected_ids, "rejected_input_ids": rejected_ids}

def collate_fn(batch):
    """Collate tokenized preference pairs into one padded batch.

    The sequences are interleaved as ``[rejected_0, selected_0, rejected_1,
    selected_1, ...]`` and padded with the module-level ``rm_tokenizer``.

    Args:
        batch: Iterable of dicts holding ``"rejected_input_ids"`` and
            ``"selected_input_ids"``.

    Returns:
        Whatever ``rm_tokenizer.pad`` returns for ``return_tensors="pt"``.
    """
    sequences = []
    for example in batch:
        # Rejected goes first: downstream code relies on this pair ordering.
        sequences.append(example["rejected_input_ids"])
        sequences.append(example["selected_input_ids"])
    return rm_tokenizer.pad({"input_ids": sequences}, padding=True, return_tensors="pt")

# Bring alternative dataset schemas to the common (prompt, selected, rejected) columns.
train_columns = dataset["train"].column_names
if "chosen" in train_columns:
    dataset = dataset.rename_column("chosen", "selected")


def _split_replies(sample):
    """Expose the first two entries of ``replies`` as ``selected`` and ``rejected``."""
    # presumably `replies` is ordered best-first -- verify against the dataset card
    replies = sample["replies"]
    return {"selected": replies[0], "rejected": replies[1]}


if "replies" in dataset["train"].column_names:
    dataset = dataset.map(_split_replies, remove_columns=["replies"])

def to_vicuna_format(sample):
    """Rewrite an HH-style sample into the "USER: ... ASSISTANT: ..." layout.

    Turn markers in the whitespace-stripped prompt are rewritten as follows:
    ``"\\n\\nHuman: "`` becomes ``"</s>USER: "`` and ``"\\n\\nAssistant:"``
    (with or without a trailing space) becomes ``" ASSISTANT:"``. A leading
    ``"Human: "`` tag becomes ``"USER: "`` and a leading ``"</s>"`` is dropped.
    Both replies are stripped and prefixed with a single space.

    Args:
        sample: Mapping with string values under ``"prompt"``, ``"selected"``
            and ``"rejected"``.

    Returns:
        A dict with the converted ``"prompt"``, ``"selected"`` and ``"rejected"``.
    """
    human_tag = "Human: "
    eos_tag = "</s>"

    prompt = sample["prompt"].strip()
    prompt = (
        prompt.replace("\n\nHuman: ", "</s>USER: ")
        .replace("\n\nAssistant: ", " ASSISTANT: ")
        .replace("\n\nAssistant:", " ASSISTANT:")
    )
    if prompt.startswith(human_tag):
        # Rewrite only the leading tag. A plain str.replace would also rewrite any
        # "Human: " that happens to appear inside the message text.
        prompt = "USER: " + prompt[len(human_tag):]
    if prompt.startswith(eos_tag):
        prompt = prompt[len(eos_tag):]

    selected = " " + sample["selected"].strip()
    rejected = " " + sample["rejected"].strip()

    return {"prompt": prompt, "selected": selected, "rejected": rejected}

def to_oa_format(sample):
    """Rewrite an HH-style sample into the ``<|prompter|>`` / ``<|assistant|>`` layout.

    Turn markers in the whitespace-stripped prompt are rewritten as follows:
    ``"\\n\\nHuman: "`` becomes ``"</s><|prompter|>"`` and ``"\\n\\nAssistant:"``
    (with or without a trailing space) becomes ``"</s><|assistant|>"``. A
    leading ``"Human: "`` tag becomes ``"<|prompter|>"``. Both replies are
    stripped of surrounding whitespace.

    Args:
        sample: Mapping with string values under ``"prompt"``, ``"selected"``
            and ``"rejected"``.

    Returns:
        A dict with the converted ``"prompt"``, ``"selected"`` and ``"rejected"``.
    """
    human_tag = "Human: "

    prompt = sample["prompt"].strip()
    prompt = (
        prompt.replace("\n\nHuman: ", "</s><|prompter|>")
        .replace("\n\nAssistant: ", "</s><|assistant|>")
        .replace("\n\nAssistant:", "</s><|assistant|>")
    )
    if prompt.startswith(human_tag):
        # Rewrite only the leading tag. A plain str.replace would also rewrite any
        # "Human: " that happens to appear inside the message text.
        prompt = "<|prompter|>" + prompt[len(human_tag):]

    selected = sample["selected"].strip()
    rejected = sample["rejected"].strip()

    return {"prompt": prompt, "selected": selected, "rejected": rejected}

# Convert every sample to the "USER: / ASSISTANT:" prompt layout before tokenizing.
# NOTE(review): to_oa_format is the alternative converter (the original script appears to
# have chosen between them with an `add_oasst_tokens` flag) -- swap it in here if the
# checkpoint expects <|prompter|>/<|assistant|> tokens:
#dataset = dataset.map(to_oa_format)
dataset = dataset.map(to_vicuna_format)

# Add the selected/rejected token ids to every split.
tokenized = dataset.map(
    tokenize,
    input_columns=["prompt", "selected", "rejected"],
    fn_kwargs={"tokenizer": rm_tokenizer},
    desc="Tokenizing",
)
#dataloader = torch.utils.data.DataLoader(tokenized["train"], shuffle=True, batch_size=32, collate_fn=collate_fn)

In [None]:
# Each dataloader batch holds 2 preference pairs, i.e. 4 sequences after collate_fn
# interleaves the rejected and selected reply of every pair.
eval_split = tokenized["test"]
eval_dataloader = torch.utils.data.DataLoader(
    eval_split,
    shuffle=True,
    batch_size=2,
    collate_fn=collate_fn,
)

In [None]:
# Pairwise accuracy: the fraction of pairs where the reward model scores the selected
# reply above the rejected one.
all_scores, all_delta_scores, all_tokens = [], [], []
import numpy as np

# Use the GPU only when one exists, and put the model and every batch on that same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rm = rm.to(device)
rm.eval()

log_every = 100  # number of batches between running-accuracy reports

for step, batch in enumerate(eval_dataloader, start=1):
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        scores = rm(**batch)[0]

    # collate_fn orders rows as [rejected, selected, rejected, selected, ...], so after
    # reshaping to (pairs, 2) the difference along the last axis is selected - rejected.
    # assumes the model emits a single score per sequence -- TODO confirm
    batch_deltas = scores.reshape(-1, 2).diff().view(-1)
    all_delta_scores.extend(batch_deltas.tolist())
    all_scores.extend(scores.view(-1).tolist())
    all_tokens.extend(batch["input_ids"].tolist())

    if step % log_every == 0:
        running_accuracy = (np.asarray(all_delta_scores) > 0).mean()
        print(f"[{step} batches] running accuracy: {running_accuracy:.2f}")

# Final metric over the whole evaluation split, computed once.
delta_scores = np.asarray(all_delta_scores, dtype=float)
if delta_scores.size:
    accuracy = (delta_scores > 0).mean()
    print(f"Accuracy: {accuracy:.2f}")
else:
    accuracy = float("nan")
    print("No evaluation examples were scored.")
