In [None]:
# Reload imported modules before each cell executes, so edits to local code
# (such as the `utils` module imported below) take effect without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import random
import json
import numpy as np

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import LoraConfig, get_peft_model

from utils import score_fast, append_sol_and_remove_eos, remove_eos_and_pad_left

In [None]:
# Number of training examples sampled into each micro-batch.
bsz = 32
# Number of micro-batches back-propagated before every optimizer step.
grad_acc = 8

# Base learning rate given to AdamW; the LR schedule multiplies it per step.
lr = 0.001
# Steps over which the LR multiplier ramps linearly from 0 up to 1.
warmup_steps = 20
# Number of optimizer steps to run; the LR multiplier reaches 0 at this step.
total_steps = 100

# Selects the training file data/subj/train.{train_samples}.jsonl.
train_samples = 20
# The training loss is printed every `log_interval` optimizer steps.
log_interval = 10

# Seed applied to the numpy, `random` and torch RNGs.
rng_seed = 3

In [None]:
# Seed every RNG the notebook draws from (numpy, stdlib random, torch) with the
# same value so that repeated runs sample the same batches.
for seed_rng in (np.random.seed, random.seed, torch.manual_seed):
    seed_rng(rng_seed)

# Disable the cuDNN autotuner and request deterministic cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Checkpoint to fine-tune. Supported values: 'instruct-gpt-j-fp16', 'gpt2'.
model_to_use = 'instruct-gpt-j-fp16' # 'gpt2'

if model_to_use == 'instruct-gpt-j-fp16':
    # NOTE(review): the checkpoint name says fp16 but the weights are loaded as
    # bfloat16 here -- confirm this is intentional.
    tokenizer = AutoTokenizer.from_pretrained('nlpcloud/instruct-gpt-j-fp16')
    model = AutoModelForCausalLM.from_pretrained('nlpcloud/instruct-gpt-j-fp16',
                                                torch_dtype=torch.bfloat16)
elif model_to_use == 'gpt2':
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2')
else:
    # Fail right here with a readable message; otherwise `tokenizer`/`model`
    # stay undefined and the next line dies with an unrelated NameError.
    raise ValueError(
        f"unsupported model_to_use: {model_to_use!r}; "
        "expected 'instruct-gpt-j-fp16' or 'gpt2'"
    )

model.to('cuda')

In [None]:
# The two class names the model is asked to choose between.
answers = [ 'objective', 'subjective' ]

# Vocabulary ids of the single tokens used as class labels.
# NOTE(review): 'Ġ' is presumably the tokenizer's marker for a leading space,
# which would make these the ids of ' objective' / ' subjective' -- confirm.
obj_id = tokenizer.vocab['Ġobjective']
subj_id = tokenizer.vocab['Ġsubjective']

def _read_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    The file is opened inside a ``with`` block so the handle is closed as soon
    as parsing finishes (or fails). Blank lines are skipped rather than being
    passed to ``json.loads``.
    """
    with open(path, 'r', encoding='utf-8') as jsonl_file:
        return [json.loads(line) for line in jsonl_file if line.strip()]


data_train = _read_jsonl(f'data/subj/train.{train_samples}.jsonl')
data_test = _read_jsonl('data/subj/test.jsonl')

# Keep only training reviews with fewer than 25 whitespace-separated words.
# The test set is used unfiltered.
data_train = [sample for sample in data_train if len(sample['text'].split()) < 25]

# Prompt template: the review text is placed between these two pieces.
intro_prompt = 'Classify this movie review as objective or subjective: "'
cot_prompt = '" It is'

train_queries, train_sols = [], []
test_queries, test_sols = [], []

# Fill the query / solution lists of both splits with one shared loop body.
for split_samples, split_queries, split_sols in (
    (data_train, train_queries, train_sols),
    (data_test, test_queries, test_sols),
):
    for sample in split_samples:
        split_queries.append(intro_prompt + sample['text'] + cot_prompt)
        # The solution is the label text preceded by a single space.
        split_sols.append(' ' + sample['label_text'])

In [None]:
def _encode_ids(text):
    """Tokenize `text` and return its `input_ids` tensor moved to the GPU."""
    return tokenizer(text, return_tensors='pt')['input_ids'].cuda()


encoded_train_queries = list(map(_encode_ids, train_queries))
encoded_train_sols = list(map(_encode_ids, train_sols))
# Both candidate completions, each followed by a period.
encoded_train_all_sols = [_encode_ids(' objective.'), _encode_ids(' subjective.')]
encoded_test_queries = list(map(_encode_ids, test_queries))

# The tokenizer's EOS id is used both as the EOS id and as the padding id.
pad_token_id = eos_token_id = tokenizer.eos_token_id

In [None]:
# Display the first ten training solutions as a quick sanity check.
train_sols[:10]

In [None]:
# Names of the modules that receive LoRA adapters; they depend on which
# checkpoint was loaded.
if model_to_use == 'instruct-gpt-j-fp16':
    lora_target_modules = ["k_proj", "v_proj"]
else:
    lora_target_modules = ["c_attn"]

# NOTE(review): modules_to_save lists "classifier" -- confirm that the
# causal-LM checkpoints used here actually contain a module with that name.
lora_config = LoraConfig(
    target_modules=lora_target_modules,
    r=256,
    lora_alpha=16,
    lora_dropout=0.,
    bias="none",
    modules_to_save=["classifier"],
)

# Wrap the base model with the adapters; this wrapped model is the one that is
# trained and evaluated below.
knowledge_model = get_peft_model(model, lora_config)

In [None]:
# One parameter group containing every parameter of the wrapped model, trained
# with AdamW at the base learning rate.
adamw_param_groups = [{'params': knowledge_model.parameters(), 'lr': lr}]
opt = torch.optim.AdamW(adamw_param_groups, betas=(0.9, 0.99))

# learning rate schedule
def get_lr_mult_at_step(step):
    """Return the learning-rate multiplier for optimizer step `step`.

    The multiplier ramps linearly from 0 to 1 over the first `warmup_steps`
    steps, then decays linearly to 0 at `total_steps` and stays there. Both
    limits are read from the notebook-level configuration.

    Degenerate configurations are handled instead of dividing by zero:
    with no warmup phase (`warmup_steps <= 0`) the schedule starts at the full
    rate, and with no decay phase (`total_steps <= warmup_steps`) the
    multiplier is 0 once warmup is over.
    """
    if step <= warmup_steps:
        if warmup_steps <= 0:
            # No warmup requested: step 0 already runs at the full rate.
            return 1.
        return min(step / warmup_steps, 1.)
    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0:
        # Nothing left to decay over; the schedule is finished.
        return 0
    return max((total_steps - step) / decay_steps, 0)
# LambdaLR scales the optimizer's base lr by get_lr_mult_at_step(step).
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=get_lr_mult_at_step)


def get_lr_at_step(x):
    """Return the absolute lr of the linear warmup ramp at step `x`, capped at `lr`.

    Only the warmup ramp is modelled; the decay that get_lr_mult_at_step
    applies after warmup is not reflected here.
    """
    return min(x/warmup_steps*lr, lr)

In [None]:
# Fine-tune the wrapped model: every optimizer step accumulates gradients over
# `grad_acc` micro-batches of `bsz` randomly drawn training queries.
for step in range(total_steps):
    opt.zero_grad()
    # Mean loss over all micro-batches of this step. Accumulated as a detached
    # value so it is only moved to the host when it is actually printed.
    step_loss = 0.
    for _ in range(grad_acc):
        # build a batch by sampling training examples with replacement
        batch_input = []
        batch_labels = []
        for _ in range(bsz):
            # select an example
            query_ind = np.random.choice(np.arange(len(encoded_train_queries)))
            encoded_input = encoded_train_queries[query_ind]
            batch_input.append(encoded_input[0]) # index 0: first row of the encoded query
            sol = train_sols[query_ind]
            if 'objective' in sol:
                batch_labels.append(True)
            elif 'subjective' in sol:
                batch_labels.append(False)
            else:
                # Appending nothing would leave the labels out of step with the inputs.
                raise ValueError(f'unrecognised label {sol!r} for training example {query_ind}')
        # NOTE(review): judging by its name the helper left-pads the batch with
        # the EOS id; the mask below relies on padding using that id -- confirm.
        batch_input, position_ids, _ = \
            remove_eos_and_pad_left(batch_input, eos_token_id=eos_token_id, pad_token_id=eos_token_id)
        position_ids = position_ids.cuda()
        batch_labels = torch.tensor(batch_labels, device='cuda', dtype=torch.bool)

        # Log-probabilities over the vocabulary at the last position of each
        # sequence; positions holding the EOS/pad id are masked out.
        last_logprob = knowledge_model(batch_input,
                                       attention_mask=batch_input!=eos_token_id,
                                       position_ids=position_ids)['logits'][:, -1].log_softmax(dim=-1)
        obj_logprob = last_logprob[:, obj_id]
        subj_logprob = last_logprob[:, subj_id]
        # Renormalise over the two label tokens only, giving a two-way
        # cross-entropy between 'objective' and 'subjective'.
        partition_fn = torch.logsumexp(torch.stack([obj_logprob, subj_logprob], dim=-1), dim=-1)
        loss = torch.where(batch_labels, -(obj_logprob - partition_fn), -(subj_logprob - partition_fn))
        micro_loss = loss.mean()
        # Scale by 1/grad_acc so the accumulated gradient is the mean (not the
        # sum) over the micro-batches of this step.
        (micro_loss / grad_acc).backward()
        step_loss = step_loss + micro_loss.detach().float() / grad_acc

    opt.step()
    sched.step()
    if step % log_interval == 0:
        # Report the loss averaged over the whole step, not just the last micro-batch.
        print(f'loss: {float(step_loss)}')

In [None]:
def get_preds(model, encoded_queries, top_n = 999999, bsz = 1):
    """Predict for each query whether ' objective' outscores ' subjective'.

    Args:
        model: model handed through to `score_fast`.
        encoded_queries: sequence of encoded queries; element `[0]` of each
            entry is used as its 1-D token-id tensor.
        top_n: only `encoded_queries[:top_n]` is scored.
        bsz: number of queries scored per call to `score_fast`.

    Returns:
        A list of bools, one per query in `encoded_queries[:top_n]`, True where
        the ' objective' completion received the higher score.

    The last batch may hold fewer than `bsz` queries; it is scored as well, so
    no trailing queries are dropped when the count is not a multiple of `bsz`.
    """
    encoded_obj = tokenizer(' objective',
                            return_tensors='pt').to('cuda')['input_ids'][0]
    encoded_sub = tokenizer(' subjective',
                            return_tensors='pt').to('cuda')['input_ids'][0]
    # Row 0 = objective, row 1 = subjective, padded to a common length.
    encoded_results = torch.nn.utils.rnn.pad_sequence([encoded_obj, encoded_sub],
                                                      batch_first=True,
                                                      padding_value=eos_token_id)
    encoded_queries_to_use = encoded_queries[:top_n]

    preds = []
    for start in range(0, len(encoded_queries_to_use), bsz):
        chunk = encoded_queries_to_use[start:start + bsz]
        n_queries = len(chunk)  # smaller than bsz for a trailing partial batch
        batch_input = torch.nn.utils.rnn.pad_sequence([x[0] for x in chunk],
                                                      batch_first=True,
                                                      padding_value=eos_token_id)
        with torch.no_grad():
            # Every query appears twice in a row: once paired with the
            # objective completion, once with the subjective one.
            mean_reward = score_fast(model,
                            append_sol_and_remove_eos(batch_input.repeat_interleave(2, dim=0),
                                                      encoded_results.repeat(n_queries, 1), eos_token_id, pad_token_id),
                            eos_token_id=eos_token_id)
        pred = mean_reward.reshape(n_queries, 2)
        preds += (pred[:, 0] > pred[:, 1]).tolist()
    return preds

In [None]:
# Ground truth: True where the solution text contains 'objective'.
true_preds_train = torch.tensor(['objective' in sol for sol in train_sols])
true_preds = torch.tensor(['objective' in sol for sol in test_sols])

# Switch to eval mode before scoring.
knowledge_model.eval()

# Accuracy = matching predictions / number of labelled examples, per split.
train_preds = get_preds(knowledge_model, encoded_train_queries, bsz = 10)
train_hits = (torch.tensor(train_preds) == true_preds_train).sum()
print(f'Train Acc : {train_hits / len(true_preds_train)}')

test_preds = get_preds(knowledge_model, encoded_test_queries, bsz = 100)
test_hits = (torch.tensor(test_preds) == true_preds).sum()
print(f'Test Acc : {test_hits / len(true_preds)}')