In [32]:
# This is a notebook to do experiment requested by Nathan for checking token-wise length bias of rewards 
# (my guess: we'll see a lot at smaller scales, not so much at bigger)
import pandas as pd
from transformers import AutoTokenizer, pipeline, AutoModelForSequenceClassification
import torch
from datasets import load_dataset, Dataset
from tqdm import tqdm
import pickle
from rlhfutils.data import qaform
import matplotlib.pyplot as plt

In [15]:
from transformers import PreTrainedModel, LlamaConfig, LlamaModel, LlamaTokenizer
import torch.nn as nn
import torch
from typing import Optional, List

# UltraRM format so that we don't need to re-run a ton of times for each unique input
class LlamaRewardModel(PreTrainedModel):
    config_class = LlamaConfig
    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.regression_head = nn.Linear(self.config.hidden_size, 1, bias=False)

    def forward( # args are the same as LlamaForCausalLM
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):

        transformer_outputs = self.model(
                                input_ids,
                                attention_mask=attention_mask,
                                position_ids=position_ids,
                                past_key_values=past_key_values,
                                inputs_embeds=inputs_embeds,                               
                            )

        hidden_states = transformer_outputs[0]
        rewards = self.regression_head(hidden_states).squeeze(-1)
        
        return rewards

In [51]:
def hhrlhf_preproc(inp, out):
    resp = inp
    if "###Human: " not in inp: 
        resp = "###Human: "+inp
    if "###Assistant: " not in out: 
        resp = resp+" ###Assistant: "
    resp = resp+out
    return resp
    
def progress_causalrm(inps, mod, tok, cache):
    cache.clear()
    inps.reverse()
    scores = []
    # TODO test to see if batching is needed, currently this assumes no batching
    for i in tqdm(range(len(inps))):
        #print(inps[i])
        # switch batck to original HHRLHF format (TODO this is weird)
        question, answer = inps[i]
        question = question.replace("###Assistant:", "\n\nAssistant:").replace("###Human:", "\n\nHuman:")
        question = (question + "\n\nAssistant:").strip()
        tokrewards = None
        # used cached full response if we can
        for k in cache.keys():
            if question in k: 
                tokrewards = cache[question]
        modinps = tok(question+answer, return_tensors='pt')
        
        if tokrewards==None: 
            modinps = modinps.to(mod.device)
            # first one in the cache
            cache[question] = mod(**modinps).tolist()[0]
            tokrewards = cache[question]
        if i==0:
            print(question+answer)
            #print(len(modinps))
            #print(len(tokrewards))
            
        # get the token score based on length of stuff
        scores.append(tokrewards[len(modinps.input_ids[0])-1])   
        #print(len(tokrewards))
        #print(len(modinps.input_ids[0])-1)
    scores.reverse()
    return scores

def progress_oasst(inps, mod, tok):
    scores = []
    # TODO test to see if batching is needed
    for i in tqdm(range(len(inps))):
        #print(inps[i])
        # switch batck to original HHRLHF format (TODO this is weird)
        question, answer = inps[i]
        question = question.replace("###Assistant:", "\n\nAssistant:").replace("###Human:", "Human:")
        # question = (question + "\n\nAssistant:").strip()
        question = question.strip()+"\n\nAssistant: "
        if i==0:
            print(question, answer)
        modinps = tokenizer(question, answer, return_tensors='pt').to(mod.device)
        scores.append(mod(**modinps).logits[0].cpu().detach())
    return scores

def load_rm(rmname): 
    # TODO maybe offload input formatting into here as well
    if rmname=="weqweasdas/hh_rlhf_rm_open_llama_3b":
        rm_tokenizer = AutoTokenizer.from_pretrained("weqweasdas/hh_rlhf_rm_open_llama_3b")
        rm_pipe = pipeline(
          "sentiment-analysis",
          model="weqweasdas/hh_rlhf_rm_open_llama_3b",
          device=1,
          tokenizer=rm_tokenizer,
          model_kwargs={"torch_dtype": torch.bfloat16, 'attn_implementation':"flash_attention_2"}
        )
        pipe_kwargs = {
          "return_all_scores": True,
          "function_to_apply": "none",
          "batch_size": 4
        }
        def progress_pipe(inps, kwargs, pipe): 
            chunk=16
            results = []
            print(inps[0])
            for i in tqdm(range(0, len(inps), chunk)):
                results.extend(pipe(inps, **kwargs))
            return results
        return lambda texts: [output[0]["score"] for output in progress_pipe([hhrlhf_preproc(i, o) for i, o in texts], pipe_kwargs, rm_pipe)]
    if rmname=="OpenAssistant/reward-model-deberta-v3-large-v2":
        rank_model, toker = AutoModelForSequenceClassification.from_pretrained(rmname, device_map=1), AutoTokenizer.from_pretrained(rmname)
        return lambda texts: progress_oasst(texts, rank_model, toker)
    if rmname=="openbmb/UltraRM-13b":
        # load in models
        rtoker = LlamaTokenizer.from_pretrained("openbmb/UltraRM-13b")
        rmodel = LlamaRewardModel.from_pretrained("openbmb/UltraRM-13b", device_map=1, torch_dtype=torch.bfloat16)
        rmodel.eval()
        # cache full rewards (tokenwise) for the longest seqeunce (reverse the passing order), then we can pull from that 
        reward_cache = {}
        
        return lambda texts: progress_causalrm(texts, rmodel, rtoker, reward_cache)
    if "/u/prasanns/research/rlhf-length-biases/models/rewards" in rmname:
        rm_tokenizer = AutoTokenizer.from_pretrained(rmname)
        rm_pipe = pipeline(
          "sentiment-analysis",
          model=rmname,
          device=1,
          tokenizer=rm_tokenizer,
          model_kwargs={"torch_dtype": torch.bfloat16, 'attn_implementation':"flash_attention_2"}
        )
        pipe_kwargs = {
          "return_all_scores": True,
          "function_to_apply": "none",
          "batch_size": 32
        }
        def progress_pipe(inps, kwargs, pipe): 
            chunk=16
            results = []
            print(inps[0])
            for i in tqdm(range(0, len(inps), chunk)):
                results.extend(pipe(inps, **kwargs))
            return results
        return lambda texts: [output[0]["score"] for output in progress_pipe([qaform(i, o) for i, o in texts], pipe_kwargs, rm_pipe)]

def get_token_inps(inputs, toker, gap=1, flat=True):
    oldinps = None
    # only work with output part of input tuples (TODO this may need adaptation for sanity check)
    if len(inputs[0])==2: 
        oldinps = inputs
        inputs = [inputs[i][1] for i in range(len(inputs))]
    finlist = []
    toklists = [toker(inp).input_ids for inp in inputs]
    for toks in toklists: 
        tmp = []
        for ind in range(gap, len(toks), gap): 
            tmp.append(toker.decode(toks[:ind+1], skip_special_tokens=True))
        if (len(toks)%gap)!=0:
            tmp.append(toker.decode(toks, skip_special_tokens=True))
        finlist.append(tmp)
    if oldinps: 
        # we want something we can use directly
        if flat: 
            rlens = [len(finlist[i]) for i in range(len(finlist))]
            inps = []
            outs = []
            for i in range(len(oldinps)): 
                inps.extend([oldinps[i][0]]*rlens[i])
                outs.extend(finlist[i])
            rlens = [sum(rlens[:i+1]) for i in range(len(rlens))]
            return [0]+rlens, [(inps[i], outs[i]) for i in range(len(outs))]
        # get the corrected reprocessed thing
        return [(oldinps[i][0], finlist[i]) for i in range(len(finlist))]
    return finlist

In [49]:
def pp_hh(ex): 
    ex['chosen'] = ex['chosen'].replace("\n\n", " ###")
    ex['rejected'] = ex['rejected'].replace("\n\n", " ###")
    return ex

def pp_uf(ex):
    ex['chosen'] = "###Human: "+ex['question'].strip()+" ###Assistant: "+ex['response_j']
    ex['rejected'] = "###Human: "+ex['question'].strip()+" ###Assistant: "+ex['response_k']
    return ex

In [5]:
# Code for loading in data
hh_train = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base", split="test").map(pp_hh, num_proc=10)

Found cached dataset json (/home/prasann/.cache/huggingface/datasets/Anthropic___json/Anthropic--hh-rlhf-37c6f75e35564d2a/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)
Loading cached processed dataset at /home/prasann/.cache/huggingface/datasets/Anthropic___json/Anthropic--hh-rlhf-37c6f75e35564d2a/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96/cache-a0c4ef1f168e4156_*_of_00010.arrow


In [6]:
ultrafeediff = Dataset.load_from_disk("../../data/ultra/ultrafeeddiff/").shuffle(seed=0).select(range(1000)).filter(lambda ex: ex['tokj']<150).map(pp_uf)

Loading cached shuffled indices for dataset at /scratch/cluster/prasanns/research/rlhf-length-biases/data/ultra/ultrafeeddiff/cache-5898f0ebaad14ffe.arrow
Loading cached processed dataset at /scratch/cluster/prasanns/research/rlhf-length-biases/data/ultra/ultrafeeddiff/cache-45b182770748902d.arrow
Loading cached processed dataset at /scratch/cluster/prasanns/research/rlhf-length-biases/data/ultra/ultrafeeddiff/cache-472f00fcfc1cacab.arrow


In [None]:
# ultrafeediff['question']

In [52]:
rewardmodels = ["OpenAssistant/reward-model-deberta-v3-large-v2", "weqweasdas/hh_rlhf_rm_open_llama_3b", "openbmb/UltraRM-13b", "/u/prasanns/research/rlhf-length-biases/models/rewards/bow/expbow50", 
               "/u/prasanns/research/rlhf-length-biases/models/rewards/bow/lenevenrm"]
ind = -1
rmstr = rewardmodels[ind]
# TODO maybe we should use the same tokenizer across all of these? Aligning between models is gonna be a bit weird
# TODO set this back to 0
tokenizer = AutoTokenizer.from_pretrained(rewardmodels[ind])
model = load_rm(rmstr)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [53]:
def scoretexts(texts, mod, gap=5): 
    test_inps = [x.rsplit("###Assistant: ", 1) for x in texts]
    print(test_inps[0])
    values, data = get_token_inps(test_inps, tokenizer, gap)
    scos = model(data)
    nscos = [scos[values[i-1]:values[i]] for i in range(1, len(values))]
    return nscos

In [54]:
with torch.no_grad():
    chosenscos = scoretexts(ultrafeediff['chosen'][:50], model, 5)
    rejscos = scoretexts(ultrafeediff['rejected'][:50], model, 5)

['###Human: In this task you will be given a list of integers. You should remove all of the integers that are divisible by 3 from the list. If every integer in the input list is divisible by 3 then an empty list should be returned. Zero is divisible by 3.\nQ: [61, 35, -86, -38, 58, -9, 78]\nA: ', 'The filtered list is: [61, 35, -86, -38, 58, -9]']
Question: ###Human: In this task you will be given a list of integers. You should remove all of the integers that are divisible by 3 from the list. If every integer in the input list is divisible by 3 then an empty list should be returned. Zero is divisible by 3.
Q: [61, 35, -86, -38, 58, -9, 78]
A: 

Answer: The filtered list is:


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:32<00:00,  1.07s/it]


['###Human: In this task you will be given a list of integers. You should remove all of the integers that are divisible by 3 from the list. If every integer in the input list is divisible by 3 then an empty list should be returned. Zero is divisible by 3.\nQ: [61, 35, -86, -38, 58, -9, 78]\nA: ', '[35, -86, -38, 58, -9, 78]\nConfidence: 95%']
Question: ###Human: In this task you will be given a list of integers. You should remove all of the integers that are divisible by 3 from the list. If every integer in the input list is divisible by 3 then an empty list should be returned. Zero is divisible by 3.
Q: [61, 35, -86, -38, 58, -9, 78]
A: 

Answer: [35, -86


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 68/68 [03:22<00:00,  2.98s/it]


In [37]:
def fin_accs(chos, rejs):
    rlast = [r[-1] for r in rejs]
    clast = [c[-1] for c in chos]
    return sum([clast[i]>rlast[i] for i in range(len(rejs))])

In [38]:
fin_accs(chosenscos, rejscos)

tensor([31])

In [13]:
sum(len(chosenscos[i])>len(rejscos[i]) for i in range(len(rejscos)))

17

In [55]:
chosenscos = [[float(f) for f in chose] for chose in chosenscos]
rejscos = [[float(f) for f in chose] for chose in rejscos]

with open('boweven_ultrascos.pkl', 'wb') as file:
    pickle.dump((chosenscos, rejscos), file)
# with open('oasst_deberta_hhscores.pkl', 'wb') as file:
#     pickle.dump((chosenscos, rejscos), file)

In [11]:
# with open('llama3b_hhscores.pkl', 'rb') as file:
#     chosen, rejs = pickle.load((chosenscos, rejscos), file)

In [None]:
chosenscos

In [32]:
with open('llama3b_hhscores.pkl', 'rb') as file:
    l3bchos, l3rejs = pickle.load(file)

In [34]:
with open('ultra13b.pkl', 'rb') as file:
    ultrachos, ultrarejs = pickle.load(file)

In [24]:
with open('bow_ultrascos.pkl', 'rb') as file:
    ultrachos, ultrarejs = pickle.load(file)

FileNotFoundError: [Errno 2] No such file or directory: 'bow_ultrascos.pkl'

In [41]:
fin_accs(l3bchos, l3rejs)

13

In [56]:
boweq = Dataset.load_from_disk("../../data/bagofwords/bowsynth100k/")

In [28]:
otok = AutoTokenizer.from_pretrained("facebook/opt-125m")

In [57]:
from rlhfutils.rewards import get_synth_rewards

def proc_eqlen(ex):
    def cutstr(s, lim=1000):
        ts = otok(s).input_ids[:lim]
        return len(ts), otok.decode(ts, skip_special_tokens=True)
    ex['stoks'], ex["question"] = cutstr(ex['question'])
    ex['tokj'], ex['response_j'] = cutstr(ex['response_j'])
    ex['tok'], ex['response_k'] = cutstr(ex['response_k'])
    resps = [ex['response_j'], ex['response_k']]
    scos = get_synth_rewards(resps, "bagofwords")
    jind = 0 if scos[0]>scos[1] else 0
    ex['response_j'] = resps[jind]
    ex['response_k'] = resps[1-jind]
    ex['score_j'] = scos[jind]
    ex['score_k'] = scos[1-jind]
    ex['magnitude'] = ex['score_j'] - ex['score_k']
    return ex

In [58]:
boweq = boweq.map(proc_eqlen, num_proc=10)
# boweq = boweq.filter(lambda ex: (ex['stoks']==30) and (ex['tok']==30) and (ex['tokj']==30))
print(len(boweq))

                                                                                                                                                                                                                      

100000




In [59]:
len(boweq.filter(lambda ex: ex['tokj']>ex['tok']))

                                                                                                                                                                                                                      

51432

In [47]:
boweq.save_to_disk("../../data/bagofwords/bowprefseqlenprefs")

                                                                                                                                                                                                                      