In [1]:
import pandas as pd
import torch
from datasets import Dataset
# Held-out reward-model test split: one row per prompt with chosen/rejected responses and a flag.
test_data_df = pd.read_csv(filepath_or_buffer="data/reward_test_data-new.csv")

def prepare_sample_text(prompt, response):
    """Render a prompt/response pair in the Question/Answer template.

    The result is "Question: <prompt>", a blank line, then "Answer: <response>".
    Passing an empty response yields the generation prompt ending in "Answer: ".
    """
    question_part = f"Question: {prompt}"
    answer_part = f"Answer: {response}"
    return "\n\n".join((question_part, answer_part))

# Value of the "flag" column that marks a row as harmful; shared by both predicates
# below so they cannot drift apart.
UNSAFE_FLAG = "unsafe-unsafe"

def harmful(sample):
    """Return True when the row's "flag" field equals UNSAFE_FLAG."""
    return sample["flag"] == UNSAFE_FLAG

def harmless(sample):
    """Return True for exactly the rows `harmful` rejects (its complement)."""
    return not harmful(sample)

In [2]:
# Wrap the pandas frame as a Hugging Face Dataset so filter/shuffle/map are available.
test_data = Dataset.from_pandas(df=test_data_df)

In [3]:
# Display the Dataset summary (feature names and row count) as the cell output.
test_data

Dataset({
    features: ['prompt', 'chosen', 'rejected', 'flag'],
    num_rows: 1723
})

In [4]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)

# Model whose generations are reported under "sft_gpt_2_out".
model_name = "avinashreddy/gpt-2-harmful"
# Reference model whose generations are reported under "rlhf_gpt2_out".
ref_model_name = "avinashreddy/gpt-2-rlhf-finetuned"

# NOTE(review): trust_remote_code=True allows the Hub repo to run arbitrary Python at
# load time. The load log reports a stock GPT2LMHeadModel, so this flag is likely
# unnecessary — confirm and drop it.
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Reuse EOS as the padding id (assumes the tokenizer defines no pad token of its own).
tokenizer.pad_token_id = tokenizer.eos_token_id
# NOTE(review): left padding is the usual choice for batched decoder-only generation.
# Right padding is harmless here because prompts are tokenized one at a time without padding.
tokenizer.padding_side = "right"

# The reference model is driven by `tokenizer` above. This assumes both checkpoints
# share the same vocabulary — TODO confirm.
ref_model = AutoModelForCausalLM.from_pretrained(ref_model_name, trust_remote_code=True)

Some weights of the model checkpoint at avinashreddy/gpt-2-rlhf-finetuned were not used when initializing GPT2LMHeadModel: ['v_head.summary.bias', 'v_head.summary.weight']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:

# Tokenizer options for generation prompts: truncate each prompt to at most 256 tokens.
kwargs = {
    "truncation": True,
    "max_length": 256,
    "return_tensors": "pt",
}

# Sampling options shared by both models so their outputs are comparable.
gen_kwargs = {
    "do_sample": True,
    # Budget for *generated* tokens only. The previous `max_length` also counted the
    # prompt, so a prompt truncated to 256 tokens left no room for any answer.
    "max_new_tokens": 256,
    "temperature": 0.7,
    "top_p": 0.92,
    "top_k": 0,  # 0 disables top-k filtering: pure nucleus (top_p) sampling
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.pad_token_id,
}

In [6]:
def compare(sample):
    """Sample one completion from each model for the same prompt.

    The row's prompt is formatted with `prepare_sample_text` (empty answer),
    tokenized with `tokenizer`/`kwargs`, and sampled from `model` and then
    `ref_model` using identical `gen_kwargs`.

    Args:
        sample: a dataset row; only its "prompt" field is read.

    Returns:
        dict with decoded text under "sft_gpt_2_out" (from `model`) and
        "rlhf_gpt2_out" (from `ref_model`). The decoded text includes the prompt.
    """
    # Drop the trailing space after "Answer:". In GPT-2's byte-level BPE a leading
    # space is part of the following word's token, so a dangling space makes the model
    # continue from a lone-space token. Presumably the fine-tuning text used the same
    # template with a non-empty answer, so it never contained that token there.
    input_text = prepare_sample_text(sample["prompt"], "").rstrip()
    input_tokens = tokenizer(input_text, **kwargs)

    def _sample(lm):
        # Generate one sequence with `lm` and decode it without special tokens.
        with torch.inference_mode():
            output_ids = lm.generate(**input_tokens, **gen_kwargs)
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Dict values are evaluated in order, so `model` samples before `ref_model`.
    return {
        "sft_gpt_2_out": _sample(model),
        "rlhf_gpt2_out": _sample(ref_model),
    }

In [8]:
# How many harmful prompts to generate completions for; capped at however many
# rows survive the filter, so a smaller split does not raise in `.select`.
N_SAMPLES = 500
SEED = 42

harmful_test_data = test_data.filter(harmful).shuffle(seed=SEED)
# `compare` samples with do_sample=True; seed torch so re-running this cell
# reproduces the same generations.
torch.manual_seed(SEED)
test_data_gen = harmful_test_data.select(
    range(min(N_SAMPLES, len(harmful_test_data)))
).map(compare)

Filter:   0%|          | 0/1723 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [8]:
# Persist every generated row (original columns plus both model outputs) without the index.
test_data_gen.to_pandas().to_csv(path_or_buf="harmful_generated_data.csv", index=False)

In [12]:
# Rows to print for a side-by-side read. The full set is already saved to
# harmful_generated_data.csv. Set to None to print every row.
N_PREVIEW = 10

generated_outputs = test_data_gen.to_pandas()[["sft_gpt_2_out", "rlhf_gpt2_out"]]
outputs_preview = (
    generated_outputs if N_PREVIEW is None else generated_outputs.head(N_PREVIEW)
)
print(outputs_preview.to_markdown())

|     | sft_gpt_2_out                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   