In [None]:
# Auto-reload the edited local modules (replay_buffer, utils, rewards) before executing code.
%load_ext autoreload
%autoreload 2

In [None]:
import random
import json
import numpy as np

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import LoraConfig, get_peft_model

import replay_buffer
import utils
from utils import generate, append_sol_and_remove_eos
from rewards import get_reward

In [None]:
## HPs
# --- optimisation -----------------------------------------------------------
bsz = 64            # rationales sampled (or replayed) per query per accumulation step
grad_acc = 8        # queries / backward passes accumulated per optimizer step

lr = 5e-4
warmup_steps = 100  # length of the linear LR warm-up
total_steps = 1000  # optimizer steps; the LR decays linearly to 0 by this step
pf_temp_high = 2    # upper end of the random sampling temperature used for exploration
pf_temp_low = 0.5   # lower end of that temperature range

# --- loss / reward ----------------------------------------------------------
subtb_lambda = 1.0          # geometric weight per sub-trajectory length in the SubTB loss
reward_temp = 1.0           # initial temperature handed to the reward model
reward_sched_start = 1.2    # reward temperature at the start of training ...
reward_sched_end = 1.0      # ... annealed linearly to this value ...
reward_sched_horizon = 150  # ... over this many steps

# --- rationale length -------------------------------------------------------
max_len = 5
min_len = 1

# --- logging ----------------------------------------------------------------
eval_interval = 100  # steps between printed test-set samples
log_interval = 10    # steps between loss prints

# --- data / replay buffer ---------------------------------------------------
n_rationales = 20      # rationales generated per training query when pre-seeding the buffer
train_samples = 50     # selects the data/subj/train.<N>.jsonl training split
preseed_buffer = True  # pre-fill the replay buffer with label-conditioned rationales

In [None]:
model_to_use = 'gpt-j' # 'gpt2'

# Hugging Face checkpoint name and extra `from_pretrained` kwargs for each supported backbone.
_model_specs = {
    'gpt-j': ('nlpcloud/instruct-gpt-j-fp16', {'torch_dtype': torch.bfloat16}),
    'gpt2': ('gpt2', {}),
}
if model_to_use not in _model_specs:
    # Previously an unknown name fell through both branches and only failed later with a
    # NameError on `model`; fail here with an explicit message instead.
    raise ValueError(f"Unknown model_to_use={model_to_use!r}; expected one of {sorted(_model_specs)}")

_checkpoint, _load_kwargs = _model_specs[model_to_use]
tokenizer = AutoTokenizer.from_pretrained(_checkpoint)
model = AutoModelForCausalLM.from_pretrained(_checkpoint, **_load_kwargs)

model.to('cuda')

In [None]:
np.random.seed(0)
random.seed(0)

answers = [ 'objective', 'subjective' ]

# Vocabulary ids of the 'Ġ'-prefixed label words (presumably their space-preceded forms).
obj_id = tokenizer.vocab['Ġobjective']
subj_id = tokenizer.vocab['Ġsubjective']


def _read_jsonl(path):
    """Return the JSON objects stored one per (non-blank) line of `path`; the file is closed afterwards."""
    with open(path, 'r') as f:
        return [json.loads(line) for line in f if line.strip()]


# BUGFIX: this path used to be a plain string, so '{train_samples}' was never substituted
# and the literal file 'train.{train_samples}.jsonl' was opened.
data_train = _read_jsonl(f'data/subj/train.{train_samples}.jsonl')
data_test = _read_jsonl('data/subj/test.jsonl')

# Query = intro + review text + CoT cue; answer suffix = sol_prompt + ' <label>.'
intro_prompt = 'Classify this movie review as objective or subjective: "'
cot_prompt = '" This review is'
sol_prompt = ', so it is'


def _build_queries_and_sols(samples):
    """Map records with 'text' / 'label_text' keys to (query strings, answer-suffix strings)."""
    queries = [intro_prompt + sample['text'] + cot_prompt for sample in samples]
    sols = [sol_prompt + ' ' + sample['label_text'] + '.' for sample in samples]
    return queries, sols


train_queries, train_sols = _build_queries_and_sols(data_train)
test_queries, test_sols = _build_queries_and_sols(data_test)

In [None]:
def _encode_cuda(text):
    """Tokenize one string and return its batch-of-one `input_ids` tensor on the GPU."""
    return tokenizer(text, return_tensors='pt')['input_ids'].cuda()


# One tensor per example: prompts (review + CoT cue) and gold answer suffixes.
encoded_train_queries = list(map(_encode_cuda, train_queries))
encoded_train_sols = list(map(_encode_cuda, train_sols))
# Both candidate answer suffixes: index 0 -> objective, index 1 -> subjective.
encoded_train_all_sols = [_encode_cuda(sol_prompt+' objective.'),
                          _encode_cuda(sol_prompt+' subjective.')]
encoded_test_queries = list(map(_encode_cuda, test_queries))
encoded_sol_prompt = _encode_cuda(sol_prompt)

# The tokenizer's EOS id doubles as the padding id.
eos_token_id = tokenizer.eos_token_id
pad_token_id = eos_token_id

In [None]:
# Peek at the first few gold answer suffixes.
train_sols[0:10]

In [None]:
lora_config = LoraConfig(
    r=256,
    lora_alpha=16,
    # Adapters go on the attention projections of the selected backbone.
    target_modules=["k_proj", "v_proj"] if model_to_use == 'gpt-j' else ["c_attn"],
    lora_dropout=0.,
    bias="none",
    # NOTE(review): a causal-LM backbone likely has no module named "classifier"; confirm this is needed.
    modules_to_save=["classifier"],
)
inference_model = get_peft_model(model, lora_config)

# Loss types implemented by the training loop in this notebook: 'modified_subtb', 'modified_db'.
loss_type = 'modified_subtb'

# Register only parameters that require gradients. Per the PEFT docs the wrapped backbone is frozen,
# so these are the adapter weights. Previously every backbone parameter was handed to the optimizer.
trainable_params = [p for p in inference_model.parameters() if p.requires_grad]
opt = torch.optim.AdamW([{'params': trainable_params, 'lr': lr}], betas=(0.9, 0.99))

# learning rate schedule
def get_lr_mult_at_step(step, warmup=None, total=None):
    """Multiplier applied to the base LR at scheduler step `step` (used as a LambdaLR lambda).

    The multiplier ramps linearly from 0 to 1 over `warmup` steps. It then decays linearly to 0 at
    `total` steps and is clamped at 0 afterwards.

    Args:
        step: scheduler step index (LambdaLR calls this with 0 at construction).
        warmup: warm-up length; defaults to the global `warmup_steps`.
        total: schedule length; defaults to the global `total_steps`.

    Returns:
        A number in [0, 1].
    """
    if warmup is None:
        warmup = warmup_steps
    if total is None:
        total = total_steps
    # `warmup > 0` guard: with warmup == 0 the old code divided by zero at step 0.
    if warmup > 0 and step <= warmup:
        return min(step / warmup, 1.)
    decay_span = total - warmup
    if decay_span <= 0:
        # No decay phase (total <= warmup). The old code divided two negatives here
        # and returned a multiplier above 1.
        return 0.
    return max((total - step) / decay_span, 0)
# Per-step LR multiplier; sched.step() is called once per optimizer step in the training loop.
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=get_lr_mult_at_step)

In [None]:
# Settings for the "FrozenModel" reward.
# NOTE(review): `model` is the same LoRA-wrapped `inference_model` that is being trained; confirm the
# FrozenModel reward disables the adapters (or that a moving reward model is intended).
_frozen_model_reward_cfg = {"model": inference_model,
                            "eos_token_id": eos_token_id,
                            "temperature": reward_temp,
                            "solution_beta": 1.,
                            "cot_beta": 1.0,
                            "len_beta": 0,
                            "min_len": 0}
rew = get_reward(["FrozenModel"], [_frozen_model_reward_cfg])
def get_reward_temp(x, start=None, end=None, horizon=None):
    """Reward temperature at training step `x`.

    The temperature is linearly annealed from `start` to `end` over `horizon` steps and held at
    `end` afterwards. Unset arguments default to the globals `reward_sched_start`,
    `reward_sched_end` and `reward_sched_horizon`.
    """
    if start is None:
        start = reward_sched_start
    if end is None:
        end = reward_sched_end
    if horizon is None:
        horizon = reward_sched_horizon
    if horizon <= 0:
        # Degenerate schedule with no annealing phase (the lambda divided by zero here).
        return end
    return start + (end - start) * min(1, x / horizon)

In [None]:
rbuffer = replay_buffer.ReplayBuffer(100, eos_token_id=eos_token_id, sim_tolerance=0.1)

# Ids of vocabulary entries that are not purely alphanumeric once the 'Ġ' marker and surrounding
# whitespace are stripped (punctuation, symbols, special tokens, ...).
list_of_symbols = [v for k, v in tokenizer.vocab.items() if not k.strip('Ġ').strip().isalnum()]
# Plain-int set for O(1) membership. The old check `encoded_rationale[j, k] in list_of_symbols`
# compared a GPU tensor against every id in the list, i.e. ~vocab-size tensor ops per token.
_symbol_id_set = set(list_of_symbols)


def _truncate_at_first_symbol(rationales):
    """In place: overwrite each row of `rationales` with EOS from its first symbol token onward."""
    for row_idx, row in enumerate(rationales.tolist()):
        for tok_idx, tok in enumerate(row):
            if tok in _symbol_id_set:
                rationales[row_idx, tok_idx:] = eos_token_id
                break
    return rationales


# add plausible rationales to the replay buffer
if preseed_buffer:
    # Nothing in this cell is back-propagated, so skip building autograd graphs.
    with torch.no_grad():
        for query_ind in range(len(train_sols)):
            encoded_input = encoded_train_queries[query_ind]
            encoded_result = encoded_train_sols[query_ind]
            # "Wishful" prompt: append the gold label so sampled continuations tend to justify it.
            if 'objective' in train_sols[query_ind]:
                wishful_prompt = train_queries[query_ind] + ' objective because it is'
            else:
                wishful_prompt = train_queries[query_ind] + ' subjective because it is'
            encoded_wishful_input = tokenizer(wishful_prompt, return_tensors='pt')['input_ids'].cuda()
            # Keep only the newly generated tokens (drop the wishful prompt).
            encoded_rationale = generate(inference_model,
                                         encoded_wishful_input.repeat(n_rationales, 1),
                                         eos_token_id=eos_token_id,
                                         max_len=max_len,
                                         temperature=1)[0][:, encoded_wishful_input.size(-1):]
            # find the first non-letter symbol and replace it (and the rest of the row) with EOS
            _truncate_at_first_symbol(encoded_rationale)
            print(tokenizer.batch_decode(encoded_rationale)[0])

            def reward_fn(x):
                """Base log-reward of rationales `x`, lowered by 50 where the model's
                objective-vs-subjective preference after `x` disagrees with the gold label."""
                n = x.size(0)
                # One scoring batch: x + [gold answer | ' objective.' | ' subjective.'].
                results = rew.score(append_sol_and_remove_eos(x.repeat(3, 1),
                                                              torch.cat([encoded_result.repeat(n, 1),
                                                                         encoded_train_all_sols[0].repeat(n, 1),
                                                                         encoded_train_all_sols[1].repeat(n, 1)]),
                                                              eos_token_id,
                                                              pad_token_id),
                                    skip_first=encoded_input.size(-1),
                                    solution_len=0)
                base_reward = results[:n]
                obj_score = results[n:2 * n]
                sub_score = results[2 * n:]
                pred_obj = obj_score > sub_score
                if 'obj' in train_sols[query_ind]:
                    return torch.where(pred_obj, base_reward, base_reward - 50)
                if 'sub' in train_sols[query_ind]:
                    return torch.where(pred_obj, base_reward - 50, base_reward)
                raise NotImplementedError

            # Score the seeded rationales as actions taken from the plain (non-wishful) query.
            logrewards = utils.generate_and_return_eos_logprob(inference_model,
                                                               encoded_input.repeat(n_rationales, 1),
                                                               eos_token_id=eos_token_id,
                                                               reward_fn=reward_fn,
                                                               max_len=max_len,
                                                               min_len=min_len,
                                                               temperature=1,
                                                               action_seq=encoded_rationale)[3]
            rbuffer.add_batch(query=encoded_input,
                              answer=encoded_result,
                              rationales=encoded_rationale,
                              logrewards=logrewards,
                              tokenizer=tokenizer)

In [None]:
# Sanity check on the seeded buffer: mean stored score per buffer entry, then the mean across entries.
# NOTE(review): assumes element 0 of each stored rationale record is its log-reward; confirm in replay_buffer.
rewards_to_date = [np.mean([rationale[0] for rationale in entry['rationales']])
                   for entry in rbuffer._buffer.values()]
print(np.mean(rewards_to_date))

In [None]:
# Main training loop. Each optimizer step accumulates gradients over `grad_acc` queries. For each query,
# a batch of rationales comes from the current policy (possibly tempered) or from the replay buffer,
# and the modified DB / SubTB loss is back-propagated.
if loss_type not in ('modified_db', 'modified_subtb'):
    # Fail before training instead of mid-step (an unlisted 'modified_*' name used to leave
    # `batch_loss` undefined).
    raise NotImplementedError(f'loss_type={loss_type!r} is not implemented')

for step in range(1, total_steps+1):
    opt.zero_grad()
    # Detached running sum of per-query losses, for logging only (keeps no autograd graph alive).
    loss = 0.
    # change reward temperature
    rew.temperature = get_reward_temp(step)
    for _ in range(grad_acc):
        # select an example
        query_ind = np.random.choice(np.arange(len(encoded_train_queries)))
        encoded_input = encoded_train_queries[query_ind]
        encoded_result = encoded_train_sols[query_ind]
        prompt_len = encoded_input.size(-1)

        def reward_fn(x):
            """Base log-reward of rationales `x`, lowered by 50 where the model's
            objective-vs-subjective preference after `x` disagrees with the gold label."""
            n = x.size(0)
            # One scoring batch: x + [gold answer | ' objective.' | ' subjective.'].
            results = rew.score(append_sol_and_remove_eos(x.repeat(3, 1),
                                                          torch.cat([encoded_result.repeat(n, 1),
                                                                     encoded_train_all_sols[0].repeat(n, 1),
                                                                     encoded_train_all_sols[1].repeat(n, 1)]),
                                                          eos_token_id,
                                                          pad_token_id),
                                skip_first=prompt_len,
                                solution_len=0)
            base_reward = results[:n]
            obj_score = results[n:2 * n]
            sub_score = results[2 * n:]
            pred_obj = obj_score > sub_score
            if 'obj' in train_sols[query_ind]:
                return torch.where(pred_obj, base_reward, base_reward - 50)
            if 'sub' in train_sols[query_ind]:
                return torch.where(pred_obj, base_reward - 50, base_reward)
            raise NotImplementedError

        # choose a behavior policy: 0 -> policy at temperature 1,
        # 1 -> policy at a random temperature in [pf_temp_low, pf_temp_high), 2/3 -> replay buffer
        b_policy_choice = random.randint(0, 3)
        if b_policy_choice in [0, 1]:
            generated_text, logPF, eos_logprob, logrewards = \
                utils.generate_and_return_eos_logprob(inference_model,
                                                      encoded_input.repeat(bsz, 1),
                                                      eos_token_id=eos_token_id,
                                                      reward_fn=reward_fn,
                                                      max_len=max_len,
                                                      min_len=min_len,
                                                      use_tools=False,
                                                      temperature=1 if b_policy_choice == 0 else random.random()*(pf_temp_high-pf_temp_low)+pf_temp_low)
            rbuffer.add_batch(query=encoded_input,
                              answer=encoded_result,
                              rationales=generated_text[:, prompt_len:],
                              logrewards=logrewards * rew.temperature, # undo the effect of reward tempering
                              tokenizer=tokenizer)
        else:
            # using samples from the replay buffer
            action_seq, logrewards = rbuffer.sample(bsz, query=encoded_input, answer=encoded_result)
            if action_seq is None:
                continue
            # Redo the effect of reward tempering. Out of place, so a tensor possibly shared with the
            # buffer's storage is not modified.
            logrewards = logrewards / rew.temperature
            # logrewards_2 is unused: rewards are skipped and the buffer's values are used instead.
            generated_text, logPF, eos_logprob, logrewards_2 = \
                utils.generate_and_return_eos_logprob(inference_model,
                                                      encoded_input.repeat(action_seq.size(0), 1),
                                                      eos_token_id=eos_token_id,
                                                      reward_fn=reward_fn,
                                                      max_len=max_len,
                                                      min_len=min_len,
                                                      use_tools=False,
                                                      action_seq=action_seq,
                                                      skip_rewards=True)

        # Mask of generated positions at/after the first EOS, cropped / padded (with True) to max_len.
        # Both losses index per-step tensors with max_len columns using this mask.
        mask = (generated_text[:, prompt_len:] == eos_token_id).cumsum(dim=-1) >= 1
        mask = mask[:, :max_len]
        if mask.size(-1) < max_len:
            # BUGFIX (modified_db): this used to pad to max_len-1 columns, mismatching db_loss.
            mask = torch.cat([mask, torch.ones(mask.size(0), max_len - mask.size(-1), dtype=torch.bool, device=mask.device)], dim=-1)

        if loss_type == 'modified_db':
            # modified db loss with logpb=0
            db_loss = (logrewards[:, :-1] + logPF[:, :-1] + eos_logprob[:, 1:] - logrewards[:, 1:] - eos_logprob[:, :-1])**2
            # trajectory lengths = number of pre-EOS positions (at least 1, so an empty one gives 0, not 0/0)
            traj_len = (~mask).sum(dim=-1).clamp(min=1)
            # get rid of the loss for the terminating step
            db_loss[mask] = 0
            batch_loss = db_loss.sum(-1) / traj_len
        else:  # 'modified_subtb' (validated before the loop)
            # modified subTB loss with logpb=0
            delta = (logrewards[:, :-1] - eos_logprob[:, :-1] + logPF[:, :-1] - (logrewards[:, 1:] - eos_logprob[:, 1:]))
            delta_cumsum = torch.cat( [ torch.zeros_like(delta[:, :1]), delta ], 1).cumsum(1)
            batch_loss = 0.
            total_lambda = 0.
            # assumes max_len >= 1, so total_lambda becomes a tensor below
            for subtraj_len in range(1, max_len+1):
                subtb_term = (delta_cumsum[:, subtraj_len:] - delta_cumsum[:, :-subtraj_len])**2
                subtb_term[mask[:, subtraj_len - 1:]] = 0
                batch_loss += subtb_lambda ** (subtraj_len - 1) * subtb_term.sum()
                total_lambda += subtb_lambda ** (subtraj_len - 1) * (~mask[:, subtraj_len - 1:]).sum()
            # Normalise by the weighted count of unmasked terms. The clamp only matters when every term
            # is masked (then the sum is 0), turning the old 0/0 NaN into 0.
            batch_loss = batch_loss / total_lambda.clamp(min=1)

        batch_loss = batch_loss.mean()
        batch_loss.backward()
        loss += batch_loss.detach()
    opt.step()
    sched.step()
    if step % log_interval == 0:
        # float() also works when every accumulation step was skipped and `loss` is still 0.
        print(f'loss: {float(loss)}')
    if step % eval_interval == 0:
        print(f'Step: {step}')
        # pick a random example from the test set
        query_ind = random.randint(0, len(encoded_test_queries)-1)
        encoded_input = encoded_test_queries[query_ind]
        with torch.no_grad():
            generated_text = generate(inference_model,
                                      encoded_input.repeat(3, 1),
                                      eos_token_id=eos_token_id,
                                      max_len=max_len,
                                      temperature=1)[0]
        print("Test example:")
        print('\n'.join(tokenizer.batch_decode(append_sol_and_remove_eos(generated_text, [None,] * generated_text.size(0), eos_token_id, pad_token_id))))

In [None]:
def eval_acc(encoded_test_queries, test_sols, top_n = 2000):
    """Accuracy of the objective/subjective decision obtained by comparing reward scores.

    For each of the first `top_n` examples, one rationale is generated at temperature 0.1. That same
    rationale is scored with the ', so it is objective.' and ', so it is subjective.' suffixes via
    `rew.score`, and the higher mean score is the prediction.

    Args:
        encoded_test_queries: encoded prompts, presumably one batch-of-one CUDA tensor each.
        test_sols: gold answer strings aligned with the queries (containing 'objective' or 'subjective').
        top_n: evaluate only the first `top_n` examples.

    Returns:
        Accuracy as a float; 0.0 if there is nothing to evaluate (previously ZeroDivisionError).
    """
    encoded_obj = tokenizer(', so it is objective.',
                                return_tensors='pt').to('cuda')['input_ids']
    encoded_sub = tokenizer(', so it is subjective.',
                                return_tensors='pt').to('cuda')['input_ids']
    correct, total = 0, 0
    with torch.no_grad():
        for encoded_input, sol in zip(encoded_test_queries[:top_n], test_sols[:top_n]):
            # Generate the rationale ONCE so both labels are scored on the same chain of thought.
            # The previous version sampled a separate rationale per label, adding noise to the comparison.
            generated_text = generate(inference_model,
                                      encoded_input,
                                      eos_token_id=eos_token_id,
                                      max_len=max_len,
                                      temperature=.1)[0]
            lls = []
            for encoded_result in [encoded_obj, encoded_sub]:
                # clone(): guard in case append_sol_and_remove_eos modifies its input (not visible here).
                mean_reward = rew.score(
                                append_sol_and_remove_eos(generated_text.clone(), encoded_result, eos_token_id, pad_token_id),
                                skip_first=encoded_input.size(-1),
                                solution_len=0).mean().item()
                lls.append(mean_reward)
            pred_objective = lls[0] > lls[1]
            if (pred_objective and 'objective' in sol) or (not pred_objective and 'subjective' in sol):
                correct += 1
            total += 1
    return correct / total if total else 0.0

In [None]:
print('Train Acc:', eval_acc(encoded_train_queries, train_sols))
print('Test Acc @ 100:', eval_acc(encoded_test_queries, test_sols, top_n=100))
# Full test-set evaluation (slow); uncomment to run:
#print('Test Acc @ 2000:', eval_acc(encoded_test_queries, test_sols, 2000))

## Save checkpoint and generate CoT data

In [None]:
# Checkpoint directory name encodes the main run settings; the generated CoT tensors are saved there too.
ckpt_name = f'subj_obj_{model_to_use}_{train_samples}samples_len{max_len}_{total_steps}steps_rewtemp{reward_temp}_seed_{preseed_buffer}'
inference_model.save_pretrained(save_directory=f'ckpts/{ckpt_name}')

In [None]:
# sample CoTs for the entire training set
n_samples = 100
# One tensor of `n_samples` temperature-1 generations per training query, in query order.
encoded_train_queries_w_cot_sample = [
    generate(inference_model,
             train_query.repeat(n_samples, 1),
             eos_token_id=eos_token_id,
             max_len=max_len,
             temperature=1)[0]
    for train_query in encoded_train_queries
]

In [None]:
# greedily generate CoTs for the entire test set
# (temperature 0.01 approximates greedy decoding; one generation per test query, in order)
encoded_test_queries_w_cot_greedy = [
    generate(inference_model,
             test_query,
             eos_token_id=eos_token_id,
             max_len=max_len,
             temperature=.01)[0]
    for test_query in encoded_test_queries
]

In [None]:
# sample CoTs for the entire test set
n_samples = 10
# One tensor of `n_samples` temperature-1 generations per test query, in query order.
encoded_test_queries_w_cot_sample = [
    generate(inference_model,
             test_query.repeat(n_samples, 1),
             eos_token_id=eos_token_id,
             max_len=max_len,
             temperature=1)[0]
    for test_query in encoded_test_queries
]

In [None]:
# Persist the generated CoT tensors next to the saved adapter.
for _cot_tensors, _out_path in (
    (encoded_train_queries_w_cot_sample, f'ckpts/{ckpt_name}/encoded_train_queries_w_cot_sample.pt'),
    (encoded_test_queries_w_cot_greedy, f'ckpts/{ckpt_name}/encoded_test_queries_w_cot_greedy.pt'),
    (encoded_test_queries_w_cot_sample, f'ckpts/{ckpt_name}/encoded_test_queries_w_cot_sample.pt'),
):
    torch.save(_cot_tensors, _out_path)