## Imports

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
from tqdm import tqdm
from datasets import load_dataset
from torchmetrics.text.rouge import ROUGEScore
from torchmetrics import SacreBLEUScore
import nltk
nltk.download("punkt")

import pandas as pd

from collections import defaultdict
import gc

## Define metrics

In [None]:
# Text-overlap metrics (ROUGE and SacreBLEU) used when scoring generated answers.
rouge, bleu = ROUGEScore(), SacreBLEUScore()

## Define models

In [None]:
# Pick the GPU when one is usable, otherwise fall back to the CPU.
# `torch.cuda.is_available` must be *called*: the bare function object is always
# truthy, which would select CUDA even on CPU-only machines. Both branches yield a
# `torch.device` so downstream code sees a consistent type.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# A single tokenizer (loaded from the reward-model repo) is shared by both models.
tokenizer = AutoTokenizer.from_pretrained("Myashka/125M_GPTneo_reward_gen")
# Reuse EOS as the pad token so `padding=` / `truncation=` tokenizer calls work.
tokenizer.pad_token = tokenizer.eos_token

# `.eval()` returns the module itself, so it can be chained after `.to(device)`.
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "Myashka/125M_GPTneo_reward_gen").to(device).eval()
sft_model = AutoModelForCausalLM.from_pretrained(
    "Myashka/125M_GPTneo_sft_tuned").to(device).eval()

# Only the reward model is compiled, because it is invoked directly (`reward_model(...)`).
# The SFT model is only used through `.generate(...)`. On a `torch.compile` wrapper,
# that attribute is forwarded to the original module, whose internal forward calls
# are not compiled, so wrapping it had no effect.
reward_model = torch.compile(reward_model)

## Config

In [None]:
# Settings consumed by `build_dataset`.
data_config = {
    # JSON file whose top-level fields are the dataset splits.
    'data_file_path': '/content/1.0-data-div-ans-sep-api-usage.json',
    # Forwarded as the tokenizer's `padding=` argument when encoding prompts.
    "padding": False,
    # Maximum prompt length in tokens (prompt = "Question: ...\nAnswer:").
    "max_length_promt": 256,
    # If True, the question itself is pre-truncated to fit within the limit.
    "truncate_promt": True,
}

# Generation settings are defined once, in the sample-generation cell further down.
# A second, earlier definition here was always overwritten before use.

## Data

In [None]:
def build_dataset(
    tokenizer,
    data_config,
    splits,
):
    """Load, prompt-format and tokenize one dataset per requested split.

    Each split name selects a top-level field of the JSON file at
    ``data_config['data_file_path']``. Every example is turned into a
    ``"Question: <question>\\nAnswer:"`` prompt and tokenized.

    Args:
        tokenizer: Hugging Face tokenizer used to build ``input_ids``.
        data_config: dict with keys ``data_file_path``, ``padding``,
            ``max_length_promt`` and ``truncate_promt``.
        splits: iterable of split names, or a single split name as a string.

    Returns:
        List of datasets, one per split in the given order, formatted as torch
        with columns ``input_ids``, ``Question_promt`` and ``Original_answer``.
    """
    # A bare string would otherwise be iterated character by character
    # (e.g. 'val' -> 'v', 'a', 'l'), silently loading the wrong JSON fields.
    if isinstance(splits, str):
        splits = [splits]

    def promt_tokenize(examples):
        max_len = data_config['max_length_promt']
        if data_config['truncate_promt']:
            # Pre-truncate the question, keeping 7 tokens of headroom. This is
            # presumably meant to fit the "Question: " / "\nAnswer:" template, so
            # the final truncation below does not cut the "Answer:" cue.
            q_toks = tokenizer.encode(examples['Question'])
            q_toks = q_toks[:max_len - 7]
            question = tokenizer.decode(q_toks).strip()
        else:
            question = examples['Question']

        sample = 'Question: ' + question + "\nAnswer:"

        # Encoded as a batch of one prompt, so `input_ids` is a nested list
        # (one row), which `generate` can consume directly after torch formatting.
        tokenized_dict = tokenizer(
            [sample], padding=data_config['padding'], max_length=max_len, truncation=True)

        tokenized_dict['Question_promt'] = sample
        tokenized_dict['Original_answer'] = examples['Answer']

        return tokenized_dict

    loaded = []
    for split in splits:
        dataset = load_dataset(
            "json", data_files=str(data_config['data_file_path']), field=str(split))['train']
        dataset = dataset.map(promt_tokenize)
        dataset.set_format(type="torch", columns=["input_ids", "Question_promt", 'Original_answer'])
        loaded.append(dataset)
    return loaded

In [None]:
# `splits` must be a list: a bare 'val' would be iterated as 'v', 'a', 'l',
# and `[0]` would then pick the dataset loaded from the JSON field 'v'.
val_dataset = build_dataset(tokenizer, data_config, ['val'])[0]

## Generate samples for evaluation

In [None]:
# Sampling configuration for `sft_model.generate`: top-k / nucleus sampling,
# ten candidate answers per prompt, at most 256 newly generated tokens.
generation_kwargs = dict(
    min_length=-1,
    top_k=50,
    num_return_sequences=10,
    top_p=0.9,
    do_sample=True,
    max_new_tokens=256,
)

In [None]:
# Generate several candidate answers per validation prompt. Collect them as
# parallel columns: prompt, reference answer, prompt index, generated answer.
val_dict = defaultdict(list)
# Wrapping the dataset (not `enumerate`) lets tqdm show a total.
for i, sample in enumerate(tqdm(val_dataset)):
    input_ids = sample["input_ids"].to(device)
    # Prompts are unpadded, so the attention mask is all ones. Passing it
    # explicitly avoids relying on generate's inference from pad == eos.
    generated_samples = sft_model.generate(
        input_ids, attention_mask=torch.ones_like(input_ids), **generation_kwargs)
    n_generated = len(generated_samples)
    prompt_len = input_ids.shape[-1]

    val_dict['Question'].extend([sample['Question_promt']] * n_generated)
    val_dict['Answer_orig'].extend([sample['Original_answer']] * n_generated)
    val_dict['Q_Id'].extend([i] * n_generated)

    # The returned sequences start with the prompt tokens. Drop them by slicing
    # at the prompt length instead of repeating the prompt a hard-coded 10 times,
    # which misaligned the columns for any other `num_return_sequences`.
    val_dict["Answer_gen"].extend(
        tokenizer.batch_decode(generated_samples[:, prompt_len:], skip_special_tokens=True))

    # Release cached GPU blocks between prompts (no-op when CUDA is unused).
    torch.cuda.empty_cache()

In [None]:
# Score every generated answer against its reference answer with ROUGE-1/2/L
# (F-measure) and SacreBLEU, and score it with the reward model.
val_rouge1, val_rouge2, val_rougeL, val_bleu = [], [], [], []
val_rewards = []

# Pure evaluation: no autograd graph should be built or kept alive in the lists.
with torch.no_grad():
    for question, original_answer, generated_answer in zip(
            val_dict["Question"], val_dict["Answer_orig"], val_dict["Answer_gen"]):
        # torchmetrics metrics are called like modules with (preds, target).
        # `.compute()` takes no arguments. ROUGEScore returns
        # `<key>_fmeasure` / `_precision` / `_recall` entries, not bare `rouge1`.
        scores = rouge(generated_answer, original_answer)
        val_rouge1.append(scores['rouge1_fmeasure'].item())
        val_rouge2.append(scores['rouge2_fmeasure'].item())
        val_rougeL.append(scores['rougeL_fmeasure'].item())
        # SacreBLEU takes a list of predictions and a list of reference lists.
        val_bleu.append(bleu([generated_answer], [[original_answer]]).item())

        # The reward model consumes token ids, not raw strings. Score the prompt
        # followed by the generated continuation.
        # NOTE(review): assumes the reward model was trained on
        # "Question: ...\nAnswer: ..." text and has a single-logit head. Confirm.
        reward_inputs = tokenizer(
            question + generated_answer, return_tensors="pt", truncation=True).to(device)
        val_rewards.append(reward_model(**reward_inputs).logits.squeeze().item())