In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
import tqdm

# Hub ids: the base causal LM and the adapter stacked on it
# (presumably DPO-trained, judging by its name — confirm on the Hub card).
model_id, peft_model_id = "MBZUAI/LaMini-GPT-124M", "FXNan/gpt2-124M-DPO"

# Build the base model first; the adapter weights are then attached in place.
model = AutoModelForCausalLM.from_pretrained(model_id)
model.load_adapter(peft_model_id)

In [None]:
# The tokenizer comes from the base checkpoint (assumes the adapter does not
# change the vocabulary — TODO confirm).
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Evaluation data. `load_dataset` is called without `split=`, so `ds` is
# indexed by split name.
dataset_id = "PKU-Alignment/processed-hh-rlhf"
ds = load_dataset(dataset_id)

In [None]:
# Held-out split: print its summary, then the first example's conversation
# context and its 'chosen' response.
ds_test = ds['test']
print(ds_test)
first_example = ds_test[0]
for field in ('context', 'chosen'):
    print(first_example[field])

In [None]:
# Jinja chat template in the LaMini instruction style: a fixed preamble, then
# each message rendered as "### Instruction:" (role 'human') or "### Response:"
# (any other role), using the message's 'text' field.
# NOTE(review): in the template source, instruction text sits on its own line
# after its header, while response text follows its header directly; literal
# newlines between tags may also reach the output depending on Jinja whitespace
# settings. Check the printed rendering below against the prompt format the
# adapter was trained with.
LAMINI_CHAT_TEMPLATE = """{{ "Below is an instruction that describes a task. Write a response that appropriately completes the request."}}
{% for message in messages %}
{% if message['role'] == 'human' %}
{{ "\n\n### Instruction:" }}
{{ message['text'] }}
{% else %}
{{ "\n\n### Response:" }}{{ message['text'] }}
{% endif %}
{% endfor %}"""
tokenizer.chat_template = LAMINI_CHAT_TEMPLATE

# Render (without tokenizing) the first test context to eyeball the prompt text.
rendered_prompt = tokenizer.apply_chat_template(ds_test[0]['context'], tokenize=False)
print(rendered_prompt)

In [None]:
# Split every test row into its conversation context (the prompt messages) and
# its 'chosen' reply, in one pass over the split.
questions, answers = [], []
for row in ds_test:
    questions.append(row['context'])
    answers.append(row['chosen'])

In [None]:
from string import Template

# Evaluate the model as the mean per-example cross-entropy of the 'chosen'
# answer tokens, each conditioned on its rendered chat prompt. The per-example
# loss is the mean over that example's answer tokens; the reported number is
# the unweighted mean of those per-example losses.
model.eval()

N_EVAL = 1000       # number of test examples to score (an earlier run used 2000)
MAX_LENGTH = 1024   # right-truncation limit (tokens) for prompt and prompt+answer

loss_array = []

for i in tqdm.tqdm(range(min(N_EVAL, len(questions)))):
    # Prompt = rendered context plus an open response header for the model to complete.
    q = tokenizer.apply_chat_template(questions[i], tokenize=False) + "\n\n### Response:"
    a = answers[i]['text']

    # The answer length in tokens is inferred as len(prompt + answer) - len(prompt).
    # Both encodings are truncated at MAX_LENGTH, so a long prompt leaves only
    # part of the answer (or none of it) to score.
    q_tokens = tokenizer(q, return_tensors="pt", max_length=MAX_LENGTH, truncation=True)
    q_tokens_len = q_tokens['input_ids'].shape[1]
    tokens = tokenizer(q + a, return_tensors="pt", max_length=MAX_LENGTH, truncation=True)
    qa_tokens_len = tokens['input_ids'].shape[1]
    a_tokens_len = qa_tokens_len - q_tokens_len
    if a_tokens_len < 1:
        print(f"Skipping {i}, qa_tokens_len: {qa_tokens_len}, q_tokens_len: {q_tokens_len}, a_tokens_len: {a_tokens_len}")
        continue

    with torch.no_grad():
        logits = model(**tokens)['logits']

        # Logits at position t predict token t+1, so the predictions for the
        # last a_tokens_len tokens sit one position earlier.
        answer_logits = logits[0, -a_tokens_len - 1:-1, :]
        answer_targets = tokens['input_ids'][0, -a_tokens_len:]

        try:
            loss = torch.nn.functional.cross_entropy(answer_logits, answer_targets)
        except (RuntimeError, ValueError) as err:
            # Report and skip this example. Falling through would append the
            # previous iteration's loss (or raise NameError on the first one).
            print(f"Error at {i}: {err}")
            print(f"q_tokens_len: {q_tokens_len}")
            print(f"qa_tokens_len: {qa_tokens_len}")
            continue
        loss_array.append(loss.item())

if loss_array:
    print(f"Mean loss: {sum(loss_array)/len(loss_array)}")
else:
    print("Mean loss: n/a (no examples were scored)")

# baseline: 2.8425033027483684
# lora 5e-6: 2.8299622096741572
# lora 1e-5: 2.7702952743533142