In [None]:
import sys
sys.path.append('..')

from preference_datasets import get_batch_iterator
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
import datasets
import matplotlib.pyplot as plt
import random

In [None]:
# Checkpoint name shared by the base model and its tokenizer.
base_model_name='huggyllama/llama-7b'
# Placeholder: point this at the directory PeftModel should load from.
lora_dir='PATH TO LORA WEIGHTS HERE'

# Load the base causal LM in fp16, then wrap it with the weights from lora_dir.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.float16,
    device_map='auto',
)
model = PeftModel.from_pretrained(base_model, lora_dir)

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# generate_from_prompt forwards tokenizer.pad_token_id to model.generate,
# so reuse the EOS id when the tokenizer defines no pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
def generate_from_prompt(model,prompt,tokenizer,max_length,temperature):
    """Sample from ``model`` for ``prompt`` and return the decoded sequences.

    Args:
        model: Object exposing ``generate(**kwargs)``; its ``device``
            attribute, when present, decides where the inputs are placed.
        prompt: List of prompt strings. A bare string is treated as a
            one-element list. Prompts are tokenized without padding, so a
            multi-element list only works if every prompt tokenizes to the
            same length.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask``; must also provide ``pad_token_id`` and
            ``batch_decode``.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        temperature: Forwarded to ``model.generate`` as ``temperature``
            (sampling is always enabled).

    Returns:
        The result of ``tokenizer.batch_decode`` on the generated sequences,
        with special tokens skipped.
    """
    # A bare string would tokenize to a flat id list and yield a 1-D tensor;
    # wrap it so the batch dimension is always present.
    if isinstance(prompt,str):
        prompt=[prompt]
    input_tok=tokenizer(prompt,add_special_tokens=False)

    # Place inputs on the model's device instead of hard-coding CUDA. Fall
    # back to 'cuda' (the previous behaviour) when the model does not report
    # a usable device.
    device=getattr(model,'device',None)
    if device is None or torch.device(device).type=='meta':
        device='cuda'
    input_ids=torch.tensor(input_tok['input_ids'],dtype=torch.long,device=device)
    attention_mask=torch.tensor(input_tok['attention_mask'],dtype=torch.long,device=device)

    # Fix: honour the caller's max_length / temperature. These were previously
    # ignored in favour of the hard-coded values 2048 and .6.
    tokenized_samples = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.pad_token_id
    )
    return tokenizer.batch_decode(tokenized_samples,skip_special_tokens=True)

def recursive_generate(model,input_string,tokenizer,max_length,temperature,max_rec=30,current_rec=0):
    """Generate for ``input_string``, resolving every 'Call: ' by recursion.

    After an initial generation, the first decoded sample is scanned for
    'Call: ' markers. For each marker not yet handled, the text on the
    latest call's line becomes a new '<call>\\nSolution: ' prompt that is
    solved by a nested call. The last space-separated token of that nested
    output is appended to the current text as 'Return: <token>\\nAnswer: ',
    and generation resumes from the extended prompt.

    Args:
        model, tokenizer, max_length, temperature: Forwarded unchanged to
            ``generate_from_prompt``.
        input_string: Prompt passed to ``generate_from_prompt``; the
            recursive calls made here pass a one-element list.
        max_rec: Depth at which recursion is abandoned.
        current_rec: Depth of this invocation (0 for the outermost call).

    Returns:
        ``(sample, failed)``. ``sample`` is the most recent
        ``generate_from_prompt`` result held by this invocation. ``failed``
        is True when the depth limit was reached (here or in a nested call)
        or when the latest call was not terminated by a newline; otherwise
        False.
    """
    # The generation happens before the depth check, so even an over-depth
    # invocation returns a freshly generated sample.
    sample=generate_from_prompt(model,input_string,tokenizer,max_length,temperature)
    if current_rec>=max_rec:
        print('exceeded max recursion')
        return sample,True

    handled_calls=0
    # Loop until every 'Call: ' present in the current sample has been handled.
    while handled_calls<sample[0].count('Call: '):
        handled_calls+=1
        text=sample[0]
        pending=text.rpartition('Call: ')[2]  # text after the latest call marker

        if '\n' not in pending:
            # The call line is not newline-terminated, so it is reported as bad.
            print('bad call')
            return sample,True

        sub_prompt=pending.partition('\n')[0]+'\nSolution: '
        sub_sample,sub_failed=recursive_generate(
            model,[sub_prompt],tokenizer,max_length,temperature,max_rec,current_rec+1
        )
        if sub_failed:
            return sample,True

        # Splice the nested result into the text and let the model continue.
        returned_value=sub_sample[0].rpartition(' ')[2]
        resumed_prompt=(text+'Return: '+returned_value+'\nAnswer: ').replace('  ',' ')
        sample=generate_from_prompt(model,[resumed_prompt],tokenizer,max_length,temperature)
    return sample,False

def generate_binary_list(n):
    """Return a list of ``n`` values, each drawn from {0, 1} via ``random.choice``."""
    bits = []
    for _ in range(n):
        bits.append(random.choice([0, 1]))
    return bits


In [None]:
# Parity evaluation: for every list length in eval_lengths, run 5 random
# trials and record how many were answered correctly in res_dict[length].
prompt_template='What is the parity of {}?\nSolution: '
eval_lengths=range(2,30)
res_dict={}
for length in eval_lengths:
    num_right=0
    for _ in range(5):
        arr=generate_binary_list(length)
        # `exceed` (the failure flag) is returned but not used in scoring.
        out,exceed=recursive_generate(
            model,
            [prompt_template.format(arr)],
            tokenizer,
            max_length=512,
            temperature=.01,
            max_rec=2*len(arr),
        )
        # Correct when the last space-separated token of the output equals
        # the number of ones modulo 2.
        predicted=out[0].split(' ')[-1]
        expected=str(arr.count(1)%2)
        if predicted==expected:
            num_right+=1
    res_dict[length]=num_right
print(res_dict)