In [1]:
import copy
import random

import wandb
import time
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
from random import choices
import matplotlib.pyplot as plt
tqdm.pandas()

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import GPT2Tokenizer
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup

from ppo_model_ac import GPT2HeadWithValueModel, respond_to_batch
# from ppo import PPOTrainer
from ppo_utils import build_bert_batch_from_txt, logprobs_from_logits, whiten, clip_by_value, entropy_from_logits, flatten_dict, stats_to_np, stack_dicts

from utils import get_classifier, generate_next, concat_past, expand_past, read_file
from trigger_semi_supervised import penalize_new_line, prep_inputs

In [2]:
# Run configuration, passed to wandb and to PPOTrainer. Key order matches the original literal.
config = dict(
    lm_name="gpt2-medium",
    ref_lm_name="lvwerra/gpt2-imdb",
    cls_model_name="lvwerra/bert-imdb",
    tk_name="gpt2",
    steps=12800,
    batch_size=128,
    forward_batch_size=16,
    ppo_epochs=4,
    txt_in_len=5,
    txt_out_len=20,
    lr=5e-4,
    init_kl_coef=0.2,
    target=6,
    horizon=10000,
    gamma=1,
    lam=0.95,
    cliprange=.2,
    cliprange_value=.2,
    vf_coef=.1,
    seed=1,
    adam_epsilon=1e-8,
    tgt_label=1,  # 0 for negative, 1 for positive
    ppo_mini_batch_size=16,
    padding_token=50256,  # padding token for GPT-2 (same as BOS)
    reset_pos_emb=True,
    num_of_triggers=1,
    trigger_format="key_value",
    TRIGGER_POSITION_ID=0,
    device="cuda",
)

# Single source of truth for the device string.
device = config["device"]

In [3]:
# Generation / trigger settings. Values that also live in `config` are read from it
# so the notebook globals and the PPOTrainer parameters cannot drift apart.
batch_size = config["forward_batch_size"]  # generate_sentence batches must match forward_batch_size
num_of_triggers = config["num_of_triggers"]
trigger_format = config["trigger_format"]
reset_pos_emb = config["reset_pos_emb"]
TRIGGER_POSITION_ID = config["TRIGGER_POSITION_ID"]

# WARNING: GPT2 only
new_line_idx = 198  # '\n'
new_line_idx_1 = 628  # '\n\n'
stop_token = "."
block_list = [new_line_idx, new_line_idx_1]  # token ids passed to penalize_new_line during generation

# Sampling settings used by generate_sentence / generate_next
sample = True
top_k = 10
temperature = 1.0
repetition_penalty = 1.0
length = 40  # maximum number of decoding steps in generate_sentence

In [4]:
# Start the W&B run and hand it the config dict defined above.
wandb.init(project='trigger1', name='neutral_prompt_pos_response', config=config)

In [5]:
# loading pretrained model for sentiment classification
sentiment_model = AutoModelForSequenceClassification.from_pretrained(config["cls_model_name"])
sentiment_tokenizer = AutoTokenizer.from_pretrained(config["cls_model_name"])
sentiment_model.to(device)
sentiment_model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1

In [6]:
# Seed every RNG the notebook uses; python `random` drives the optional data/minibatch shuffle.
random.seed(config['seed'])
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])

model = GPT2HeadWithValueModel.from_pretrained(config['lm_name'])
tokenizer = GPT2Tokenizer.from_pretrained(config['lm_name'])

# WARNING: GPT2 only -- GPT-2 has no pad token, so BOS is used for padding (see config["padding_token"]).
t_pad_token = tokenizer.bos_token

model.to(device)
model.eval()

# Freeze GPT-2 weights: every parameter outside the value head stops receiving gradients.
for name, param in model.named_parameters():
    if not name.startswith("v_head"):
        param.requires_grad = False

num_layers = model.config.n_layer

# BOS forward pass repeated over the generation batch size.
lm_bos_output = model(torch.tensor(tokenizer.encode(tokenizer.bos_token), dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1))  # BOS
# Note: GPT2HeadWithValueModel returns lm_logits, transformer_outputs[1:], value
# transformer_outputs: hidden_states, past_key_values


Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2-medium and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'h.12.attn.masked_bias', 'h.13.attn.masked_bias', 'h.14.attn.masked_bias', 'h.15.attn.masked_bias', 'h.16.attn.masked_bias', 'h.17.attn.masked_bias', 'h.18.attn.masked_bias', 'h.19.attn.masked_bias', 'h.20.attn.masked_bias', 'h.21.attn.masked_bias', 'h.22.attn.masked_bias', 'h.23.attn.masked_bias', 'v_head.summary.bias', 'v_head.summary.weight', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Single forward pass over a lone BOS token (batch of 1). Its logits, past key/values and
# value are cached and later expanded to form the prefix of every batch.
bos_logits, bos_key_values, bos_v = model(
    torch.tensor(tokenizer.encode(tokenizer.bos_token), dtype=torch.long, device=device)[None, :]
)

In [8]:
# initialize trigger
# Note: since we use the same trigger for all inputs in a batch, we only create/register trigger(s) for one and repeat it
def init_trigger(model, tokenizer, num_of_triggers, trigger_format, ref=False):
    """Create trigger parameters, initialised from BOS, and attach them to ``model``.

    One trigger set is created and shared by every example in a batch; callers expand it
    to the batch size at use time (expanding here would make the tensors non-leaf, and
    gradients would no longer reach them).

    Args:
        model: GPT-2 model (with value head) the triggers are attached to.
        tokenizer: GPT-2 tokenizer; its BOS token initialises every trigger.
        num_of_triggers: number of trigger positions; nothing is created when <= 0.
        trigger_format: "token" learns input embeddings stored as
            ``model.ori_trigger_embedding``. "key_value" learns per-layer key/value tensors
            registered as ``l_<layer>_key`` / ``l_<layer>_value`` and collected in
            ``model.ori_trigger_key_values``.
        ref: if True, create frozen (requires_grad=False) copies under ``ref_`` names
            (``model.ref_ori_trigger_embedding`` / ``ref_l_<layer>_*`` /
            ``model.ref_ori_trigger_key_values``) for the reference policy.

    Raises:
        ValueError: if ``trigger_format`` is neither "token" nor "key_value".

    Relies on notebook globals ``device``, ``num_layers`` and ``bos_key_values``.
    """
    if num_of_triggers <= 0:
        return

    # Reference triggers must never be trained. This matters when triggers are concatenated,
    # because nn.Parameter defaults to requires_grad=True.
    trainable = not ref

    if trigger_format == "token":  # learn a continuous embedding
        bos_ids = torch.tensor(tokenizer.encode(tokenizer.bos_token), device=device, dtype=torch.long).unsqueeze(0)
        bos_embedding = model.transformer.wte(bos_ids).detach().clone()
        # concatenate one BOS embedding per trigger along the sequence dim: 1 x n x emb_size
        trigger_embedding = nn.Parameter(torch.cat([bos_embedding] * num_of_triggers, dim=1),
                                         requires_grad=trainable)
        if ref:
            model.ref_ori_trigger_embedding = trigger_embedding
        else:
            model.ori_trigger_embedding = trigger_embedding  # register to the model (optimizer)
    elif trigger_format == "key_value":  # learn key values
        prefix = "ref_" if ref else ""
        trigger_key_values = []
        for layer in range(num_layers):
            # key, value shape: bze, num_heads, seq_len, embed_per_head; each trigger starts as BOS's key/value
            bos_key = bos_key_values[layer][0].detach().clone()
            bos_value = bos_key_values[layer][1].detach().clone()
            trigger_key = nn.Parameter(torch.cat([bos_key] * num_of_triggers, dim=-2), requires_grad=trainable)
            trigger_value = nn.Parameter(torch.cat([bos_value] * num_of_triggers, dim=-2), requires_grad=trainable)

            # register parameter (trainable ones end up in the optimizer)
            model.register_parameter(name="%sl_%d_key" % (prefix, layer), param=trigger_key)
            model.register_parameter(name="%sl_%d_value" % (prefix, layer), param=trigger_value)
            trigger_key_values.append((trigger_key, trigger_value))

        if ref:
            model.ref_ori_trigger_key_values = tuple(trigger_key_values)
        else:
            model.ori_trigger_key_values = tuple(trigger_key_values)
    else:
        raise ValueError("trigger_format: %s not supported" % trigger_format)

In [9]:
# Create the trainable triggers and their frozen reference copies.
init_trigger(model, tokenizer, num_of_triggers, trigger_format)
init_trigger(model, tokenizer, num_of_triggers, trigger_format, ref=True)

# Only parameters still marked trainable (value head + non-ref triggers) are optimised.
param_optimizer = [(pname, param) for pname, param in model.named_parameters() if param.requires_grad]

# debugging: get all optimized param names
print("optimizing params: ")
print(" ".join(pname for pname, _ in param_optimizer))

# Two groups kept for the usual decay/no-decay split; both currently use weight_decay 0.0.
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {
        'params': [param for pname, param in param_optimizer if not any(tag in pname for tag in no_decay)],
        'weight_decay': 0.0,
    },
    {
        'params': [param for pname, param in param_optimizer if any(tag in pname for tag in no_decay)],
        'weight_decay': 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, eps=config["adam_epsilon"], lr=config["lr"])

optimizing params: 
l_0_key l_0_value l_1_key l_1_value l_2_key l_2_value l_3_key l_3_value l_4_key l_4_value l_5_key l_5_value l_6_key l_6_value l_7_key l_7_value l_8_key l_8_value l_9_key l_9_value l_10_key l_10_value l_11_key l_11_value l_12_key l_12_value l_13_key l_13_value l_14_key l_14_value l_15_key l_15_value l_16_key l_16_value l_17_key l_17_value l_18_key l_18_value l_19_key l_19_value l_20_key l_20_value l_21_key l_21_value l_22_key l_22_value l_23_key l_23_value v_head.summary.weight v_head.summary.bias


In [10]:
# Ask W&B to log both gradients and parameters of the model.
wandb.watch(
    model,
    log='all',
)

[<wandb.wandb_torch.TorchGraph at 0x7f0f4ff82588>]

In [11]:
def get_pad_attn_mask(token_tensors, max_length, pad_value=0):
    """Right-pad 1 x L token tensors to ``max_length`` and build matching attention masks.

    Args:
        token_tensors: iterable of tensors shaped 1 x L_i.
        max_length: target length along dim 1.
        pad_value: fill value for padded token positions. The default 0 is what the original
            code used (its comment suggests the output is meant for BERT).

    Returns:
        (padded, masks): both N x max_length. ``masks`` is float, with 1.0 on real tokens
        and 0.0 on padding. Each mask row is created on the device of its token tensor.
    """
    padded_tensors, attention_masks = [], []
    for tensor in token_tensors:
        length_i = tensor.shape[1]
        pad_len = max_length - length_i
        padded_tensors.append(F.pad(tensor, (0, pad_len), 'constant', pad_value))
        # build the mask next to its tokens instead of on a global device
        mask = torch.ones(1, length_i, device=tensor.device)
        attention_masks.append(F.pad(mask, (0, pad_len), 'constant', 0))
    return torch.cat(padded_tensors), torch.cat(attention_masks)
        

In [12]:
class AdaptiveKLController:
    """
    KL coefficient that adapts toward a target KL, following
    https://arxiv.org/pdf/1909.08593.pdf

    After each update the coefficient is scaled by
    ``1 + clip(current / target - 1, -0.2, 0.2) * n_steps / horizon``.
    """
    def __init__(self, init_kl_coef, target, horizon):
        self.value, self.target, self.horizon = init_kl_coef, target, horizon

    def update(self, current, n_steps):
        # relative deviation from the target KL, bounded to +/-20%
        error = np.clip(current / self.target - 1, -0.2, 0.2)
        self.value *= 1 + error * n_steps / self.horizon

class FixedKLController:
    """KL controller whose coefficient never changes."""
    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current, n_steps):
        """No-op; accepts the same arguments as AdaptiveKLController.update."""
        return None


In [13]:
# Data settings for this run.
mode = "train"
shuffle_data = False
train_bz = 16  # alternative: config["batch_size"]

# One context per line of the persona training file.
context_list = read_file("persona_train.txt")


In [14]:
def generate_sentence(all_input_ids, all_lengths, batch_min_length, stop_token, length, num_of_triggers):
    """Continue every context in a batch with trigger-conditioned decoding until ``stop_token``.

    The model first consumes [BOS past] + [trigger past] + ``all_input_ids[:, 1:batch_min_length - 1]``.
    It then decodes one token per step for at most ``length`` steps. Rows whose context is longer
    than ``batch_min_length`` are fed their own context tokens until that context is used up.
    At the last step, any unfinished row is forced to emit ``stop_token``.

    Relies on notebook globals: model, tokenizer, bos_key_values, num_layers, device,
    reset_pos_emb, trigger_format, TRIGGER_POSITION_ID, block_list, top_k, temperature,
    repetition_penalty, sample, expand_past, concat_past, penalize_new_line, generate_next.

    Args:
        all_input_ids: LongTensor bze x max_len of token ids; column 0 is skipped because
            the cached BOS past stands in for it.
        all_lengths: list of int, real length of each row.
        batch_min_length: smallest entry of ``all_lengths``.
        stop_token: string whose token id ends a sentence.
        length: maximum number of decoding steps.
        num_of_triggers: number of trigger positions prepended after BOS.

    Returns:
        (batch_sentence_tokens, output_so_far, batch_sentence_texts, batch_sentence_lengths).
        The per-row token lists keep the context prefix plus generated tokens up to and
        including the stop token. ``output_so_far`` is the full bze x T token tensor.
    """
    # Use the real batch dimension so a short final chunk still lines up with all_lengths.
    bsz = all_input_ids.shape[0]

    if reset_pos_emb:
        c_position_ids = torch.arange(1, batch_min_length - 1, dtype=torch.long, device=device)
        c_position_ids = c_position_ids.unsqueeze(0).repeat(bsz, 1)
        t_position_ids = torch.ones(bsz, num_of_triggers).to(torch.long).to(device) * TRIGGER_POSITION_ID
    else:
        c_position_ids = None
        t_position_ids = None

    # WARNING: Need to double check
    past = expand_past(bos_key_values, num_layers, bsz)  # deep copy? shouldn't be modifed

    if num_of_triggers > 0:
        if trigger_format == "token":
            trigger_embedding = model.ori_trigger_embedding.repeat(bsz, 1, 1)
            lm_trigger_output = model(inputs_embeds=trigger_embedding, position_ids=t_position_ids)
            # GPT2HeadWithValueModel returns a tuple; past_key_values sits at index 1 (as used below)
            trigger_key_values = lm_trigger_output[1]
        else:
            trigger_key_values = expand_past(model.ori_trigger_key_values, num_layers, bsz)
        past = concat_past(past, trigger_key_values, num_layers)

    output_so_far = all_input_ids[:, :batch_min_length]  # bze x batch_min_length

    context_lm_output = model(all_input_ids[:, 1: batch_min_length - 1], past_key_values=past, position_ids=c_position_ids)
    past = context_lm_output[1]  # for gpt2 with value head model

    last = output_so_far[:, batch_min_length - 1: batch_min_length]

    # Loop-invariant tensors, built once.
    lengths_col = torch.tensor(all_lengths, device=device).unsqueeze(-1)  # bze x 1
    stop_ids = torch.tensor(tokenizer.encode(stop_token), device=device)
    forced_stop_id = tokenizer.encode(stop_token)[0]  # encode outputs a list (1 element)

    sentence_not_done = torch.ones(bsz, 1, dtype=torch.uint8, device=device)
    generated_sentence_length = torch.zeros(bsz, 1, dtype=torch.long, device=device)

    for i in range(length):
        if reset_pos_emb:
            # Triggers share TRIGGER_POSITION_ID, so they occupy past slots without advancing positions.
            past_length = past[0][0].size(-2)
            p_position_ids = torch.arange(past_length - num_of_triggers, past_length - num_of_triggers + 1,
                                          dtype=torch.long, device=device)
            p_position_ids = p_position_ids.unsqueeze(0).repeat(bsz, 1)
        else:
            p_position_ids = None

        lm_output = model(last, past_key_values=past, position_ids=p_position_ids)
        logits, past = lm_output[0], lm_output[1]  # bze x cur_seq_len x vocab_size, accumulated past

        logits = penalize_new_line(logits, block_list)

        # last: bze x 1
        last, _ = generate_next(logits, output_so_far, top_k=top_k, temperature=temperature,
                                repetition_penalty=repetition_penalty, sample=sample,
                                gumbel_softmax=False, gumbel_temperature=1.0, detach=False)

        # Force an end token on rows that are still unfinished at the final step.
        # (generated_sentence_length == 0 would also match rows that stopped at step 0.)
        if i == length - 1:
            last = last.masked_fill(sentence_not_done.bool(), forced_stop_id)

        is_generated = lengths_col <= (i + batch_min_length)  # bze x 1: generated vs still in the context
        is_end_token = last == stop_ids  # bze x 1
        is_actually_ending = is_generated * is_end_token

        # keep track of generated sentence length (offset of the stop token from batch_min_length)
        generated_sentence_length = generated_sentence_length + sentence_not_done * is_actually_ending * i

        # Rows still inside their context copy the original token instead of the sampled one.
        if batch_min_length + i < all_input_ids.shape[1]:
            last = last * is_generated + all_input_ids[:, batch_min_length + i].unsqueeze(1) * (
                ~is_generated)  # is_generated is bool. need to use "~" instead of (1-is_generated)

        output_so_far = torch.cat((output_so_far, last), dim=1)  # bze x length

        sentence_not_done = sentence_not_done * (~is_actually_ending)

        if torch.sum(sentence_not_done) == 0:
            break

    batch_sentence_tokens, batch_sentence_texts, batch_sentence_lengths = [], [], []
    for s_i in range(bsz):
        end = generated_sentence_length[s_i].item() + 1 + batch_min_length
        si_tokens = output_so_far[s_i][:end].tolist()
        batch_sentence_tokens.append(si_tokens)
        batch_sentence_texts.append(tokenizer.decode(si_tokens))
        batch_sentence_lengths.append(len(si_tokens))

    return batch_sentence_tokens, output_so_far, batch_sentence_texts, batch_sentence_lengths


In [15]:
class PPOTrainer:
    """
    Proximal Policy Optimization trainer for trigger-conditioned GPT-2.

    Policy and reference share one model. The policy pass uses the trainable triggers
    (``model.ori_trigger_*``) and the reference pass uses the frozen copies
    (``model.ref_ori_trigger_*``). Samples have different lengths, so rollout data is
    kept as Python lists of 1 x L tensors rather than padded batches.

    Relies on notebook-level names: AdaptiveKLController, FixedKLController, expand_past,
    concat_past, bos_key_values, bos_logits, bos_v, num_layers, shuffle_data, and the
    ppo_utils helpers (logprobs_from_logits, whiten, clip_by_value, flatten_dict,
    stats_to_np, stack_dicts).
    """

    default_params = {
        "lr": 1.41e-5,
        "adap_kl_ctrl": True,
        "init_kl_coef": 0.2,
        "target": 6,
        "horizon": 10000,
        "gamma": 1,
        "lam": 0.95,
        "cliprange": .2,
        "cliprange_value": .2,
        "vf_coef": .1,
        "batch_size": 256,
        "forward_batch_size": 16,
        "ppo_epochs": 4,
        # previously misspelled "ppo_mini_batch-size", so this default was never found
        "ppo_mini_batch_size": 4,
    }

    def __init__(self, model, optimizer, **ppo_params):
        """
        Initialize PPOTrainer.

        Args:
            model (torch.nn.Module): GPT-2 with value head, with policy and reference trigger
                parameters already attached.
            optimizer (torch.optim.Optimizer): optimizer over the trainable parameters.
            ppo_params: overrides for ``default_params``:
                'lr' (float): not used by this class (the optimizer is supplied); default 1.41e-5
                'batch_size' (int): nominal samples per step; default 256
                'forward_batch_size' (int): samples forward-passed at a time; default 16
                'ppo_epochs' (int): optimisation epochs per batch of samples; default 4
                'ppo_mini_batch_size' (int): samples per optimisation minibatch; default 4
                'gamma' (float): discount for advantage calculation; default 1
                'lam' (float): GAE lambda; default 0.95
                'cliprange_value' (float): value clipping range; default 0.2
                'cliprange' (float): policy ratio clipping range; default 0.2
                'vf_coef' (float): value loss scale; default 0.1
                'adap_kl_ctrl' (bool): adaptive KL control, otherwise a fixed coefficient; default True
                'init_kl_coef' (float): initial KL penalty coefficient; default 0.2
                'target' (float): target KL for adaptive control; default 6
                'horizon' (float): horizon for adaptive control; default 10000
            The forward pass also reads 'reset_pos_emb', 'num_of_triggers', 'trigger_format',
            'TRIGGER_POSITION_ID', 'device' and 'padding_token', which have no defaults.
        """
        # Copy so per-instance overrides never leak into the shared class-level dict.
        self.ppo_params = dict(self.default_params)
        self.ppo_params.update(ppo_params)

        self.model = model
        self.optimizer = optimizer

        if self.ppo_params['adap_kl_ctrl']:
            self.kl_ctl = AdaptiveKLController(self.ppo_params['init_kl_coef'],
                                               self.ppo_params['target'],
                                               self.ppo_params['horizon'])
        else:
            self.kl_ctl = FixedKLController(self.ppo_params['init_kl_coef'])

    def step(self, all_c_p_tensors, all_c_p_lengths, all_c_lengths, all_scores):
        """
        Run a PPO optimisation step.

        Args:
            all_c_p_tensors (list[torch.Tensor]): per forward-minibatch, bze x seq_len ids of
                context + prompt (column 0 is BOS).
            all_c_p_lengths (list[list[int]]): per forward-minibatch, context+prompt length of each row.
            all_c_lengths (list[list[int]]): per forward-minibatch, context length of each row.
            all_scores (list): per forward-minibatch, one scalar score per row.

        Returns:
            dict: training statistics and timings.
        """
        mini_bs = self.ppo_params["ppo_mini_batch_size"]
        timing = dict()
        t0 = time.time()

        t = time.time()
        (logprobs, ref_logprobs, values, rewards, non_score_reward, kl_coef,
         real_p_tensors, real_c_p_tensors, real_c_p_lengths, real_c_lengths) = self.batched_trigger_forward_pass(
            all_c_p_tensors, all_c_p_lengths, all_c_lengths, all_scores)
        timing['time/ppo/batched_trigger_forward'] = time.time() - t

        t = time.time()
        # Iterate over the samples actually produced rather than trusting ppo_params['batch_size'].
        n_samples = len(logprobs)
        idxs = list(range(n_samples))
        all_stats = []
        for _ in range(self.ppo_params['ppo_epochs']):
            if shuffle_data:
                random.shuffle(idxs)
            # A trailing partial minibatch is trained instead of silently dropped.
            for mb_start in range(0, n_samples, mini_bs):
                b_idx = idxs[mb_start:mb_start + mini_bs]
                train_stats = self.train_minibatch(
                    [logprobs[k] for k in b_idx],
                    [values[k] for k in b_idx],
                    [rewards[k] for k in b_idx],
                    [real_p_tensors[k] for k in b_idx],
                    [real_c_p_tensors[k] for k in b_idx],
                    [real_c_p_lengths[k] for k in b_idx],
                    [real_c_lengths[k] for k in b_idx],
                )
                all_stats.append(train_stats)
        timing['time/ppo/optimize_step'] = time.time() - t

        t = time.time()
        train_stats = stack_dicts(all_stats)
        # Per-token advantages/ratios are not logged: their lengths differ between samples.
        stats = self.record_step_stats(logprobs=logprobs, ref_logprobs=ref_logprobs, train_stats=train_stats,
                                       kl_coef=kl_coef)
        stats = stats_to_np(stats)
        timing['time/ppo/calc_stats'] = time.time() - t

        self.kl_ctl.update(stats['objective/kl'], n_samples)

        timing['time/ppo/total'] = time.time() - t0
        stats.update(timing)
        return stats

    def get_trigger_forward_pass(self, model_input, is_ref=False):
        """Forward ``model_input`` through the model behind the cached [BOS + triggers] past.

        Args:
            model_input (torch.LongTensor): bze x seq_len ids whose column 0 is BOS.
            is_ref (bool): use the frozen reference triggers instead of the trainable ones.

        Returns:
            (logits, past, values): logits and values are left-padded with the cached BOS
            logits/value, so index t predicts token t + 1 of ``model_input``.
        """
        reset_pos_emb = self.ppo_params["reset_pos_emb"]
        num_of_triggers = self.ppo_params["num_of_triggers"]
        trigger_format = self.ppo_params["trigger_format"]
        trigger_position_id = self.ppo_params["TRIGGER_POSITION_ID"]
        device = self.ppo_params["device"]
        batch_size, batch_max_length = model_input.shape

        # WARNING: bos_key_values is read from the notebook globals
        past = expand_past(bos_key_values, num_layers, batch_size)

        if reset_pos_emb:
            t_position_ids = torch.ones(batch_size, num_of_triggers).to(torch.long).to(device) * trigger_position_id
        else:
            t_position_ids = None

        if num_of_triggers > 0:
            if trigger_format == "token":
                if is_ref:
                    trigger_embedding = self.model.ref_ori_trigger_embedding.repeat(batch_size, 1, 1)
                else:
                    trigger_embedding = self.model.ori_trigger_embedding.repeat(batch_size, 1, 1)
                trigger_key_values = self.model(inputs_embeds=trigger_embedding, position_ids=t_position_ids)[1]
            else:
                # use self.model (not the notebook-global `model`) so the trainer works with any model
                if is_ref:
                    trigger_key_values = expand_past(self.model.ref_ori_trigger_key_values, num_layers, batch_size)
                else:
                    trigger_key_values = expand_past(self.model.ori_trigger_key_values, num_layers, batch_size)

            past = concat_past(past, trigger_key_values, num_layers)

        if reset_pos_emb:
            p_position_ids = torch.arange(1, batch_max_length, dtype=torch.long, device=device)
            p_position_ids = p_position_ids.unsqueeze(0).repeat(batch_size, 1)
        else:
            p_position_ids = None

        logits, past_out, v = self.model(model_input[:, 1:], past_key_values=past, position_ids=p_position_ids)  # model_input[:, 0] is BOS

        # logits and v are one shorter than model_input (no prediction is made *for* BOS here);
        # prepend the cached BOS outputs so real lengths can be used for indexing.
        logits = torch.cat((bos_logits.repeat(batch_size, 1, 1), logits), dim=1)
        # bos_v is detached: it carries a graph through the value head, and reusing it in every
        # minibatch would otherwise raise "trying to backward through the graph a second time".
        v = torch.cat((bos_v.detach().repeat(batch_size, 1), v), dim=1)

        return logits, past_out, v

    def batched_trigger_forward_pass(self, all_c_p_tensors, all_c_p_lengths, all_c_lengths, all_scores):
        """Roll out policy/reference log-probs and values, then compute per-token rewards.

        Combines trl's batched_forward_pass and compute_rewards. Everything returned is detached
        and sliced to the prompt span [context_len - 1, context_prompt_len - 1) of each row.

        Returns:
            (logprobs, ref_logprobs, values, rewards, non_score_rewards, kl_coef,
             real_p_tensors, real_c_p_tensors, real_c_p_lengths, real_c_lengths),
            each list holding one entry per sample.
        """
        logprobs, ref_logprobs, values = [], [], []
        rewards, non_score_rewards = [], []
        real_p_tensors, real_c_p_tensors, real_c_p_lengths, real_c_lengths = [], [], [], []
        # defined even when there are no minibatches
        kl_coef = self.kl_ctl.value

        # All rollout quantities are detached, so skip building autograd graphs (saves memory).
        with torch.no_grad():
            for i in range(len(all_c_lengths)):  # number of forward minibatches
                model_input = all_c_p_tensors[i]  # bze x seq_len
                logits, _, v = self.get_trigger_forward_pass(model_input)
                ref_logits, _, _ = self.get_trigger_forward_pass(model_input, is_ref=True)
                lp = logprobs_from_logits(logits[:, :-1, :], model_input[:, 1:])
                ref_lp = logprobs_from_logits(ref_logits[:, :-1, :], model_input[:, 1:])

                for j in range(len(all_c_lengths[i])):  # loop through the minibatch to get the real indices
                    start = all_c_lengths[i][j] - 1
                    end = all_c_p_lengths[i][j] - 1
                    values.append(v[j:j+1, start:end].detach())
                    ij_logprob = lp[j:j+1, start:end].detach()
                    ij_ref_logprob = ref_lp[j:j+1, start:end].detach()
                    logprobs.append(ij_logprob)
                    ref_logprobs.append(ij_ref_logprob)
                    real_p_tensors.append(all_c_p_tensors[i][j:j+1, start+1:end+1])
                    real_c_p_tensors.append(all_c_p_tensors[i][j:j+1, :end+1])
                    real_c_p_lengths.append(end + 1)
                    real_c_lengths.append(start + 1)

                    ij_reward, ij_non_score_reward, kl_coef = self.compute_rewards(all_scores[i][j], ij_logprob, ij_ref_logprob)
                    rewards.append(ij_reward)
                    non_score_rewards.append(ij_non_score_reward)

        return logprobs, ref_logprobs, values, rewards, non_score_rewards, kl_coef, real_p_tensors, real_c_p_tensors, real_c_p_lengths, real_c_lengths

    def train_minibatch(self, b_logprobs, b_values, b_rewards, b_p_tensors, b_c_p_tensors, b_c_p_lengths, b_c_lengths):
        """Train one PPO minibatch and return its flattened statistics."""
        loss_p, loss_v, train_stats = self.loss(b_logprobs, b_values, b_rewards, b_p_tensors, b_c_p_tensors, b_c_p_lengths, b_c_lengths)
        loss = loss_p + loss_v
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return train_stats

    def compute_rewards(self, scores, logprobs, ref_logprobs):
        """Compute per-token rewards from scores and the KL penalty.

        Returns:
            (rewards, non_score_reward, kl_coef): the KL penalty ``-kl_coef * (logprobs - ref_logprobs)``
            at every token, with ``scores`` added to the last token of ``rewards``.
        """
        kl = logprobs - ref_logprobs
        non_score_reward = -self.kl_ctl.value * kl
        rewards = non_score_reward.clone().detach()
        rewards[:, -1] += scores
        return rewards, non_score_reward, self.kl_ctl.value

    def loss(self, old_b_logprobs, b_values, b_rewards, b_p_tensors, b_c_p_tensors, b_c_p_lengths, b_c_lengths):
        """Calculate policy and value losses for one minibatch.

        Args (per-sample lists, one entry per row):
            old_b_logprobs, b_values, b_rewards: 1 x gen_len rollout tensors over the prompt span.
            b_p_tensors: prompt-only token tensors (not used here).
            b_c_p_tensors: 1 x L context+prompt token tensors.
            b_c_p_lengths, b_c_lengths: context+prompt and context lengths.

        Returns:
            (policy_loss, vf_coef * value_loss, flat_stats), losses averaged over the minibatch.
        """
        n = len(b_c_lengths)
        pad_token = self.ppo_params["padding_token"]

        # Right-pad the minibatch for a single forward pass. Padding sits after every real token,
        # so with causal attention it cannot affect the sliced positions.
        batch_max_length = max(c_p_t.shape[1] for c_p_t in b_c_p_tensors)
        padded_tensors = torch.cat([
            torch.cat((c_p_t,
                       torch.ones(1, batch_max_length - c_p_t.shape[1], dtype=torch.long,
                                  device=c_p_t.device) * pad_token),
                      dim=1)
            for c_p_t in b_c_p_tensors
        ])  # n x batch_max_length
        b_logits, _, b_vpred = self.get_trigger_forward_pass(padded_tensors)
        b_logprob = logprobs_from_logits(b_logits[:, :-1, :], padded_tensors[:, 1:])  # n x (batch_max_length - 1)

        b_pg_loss, b_vf_loss, b_loss, b_entropy, b_approxkl, b_policykl, b_pg_clipfrac, \
        b_advantages_mean, b_return_mean, b_return_var, b_mean_vpred, b_error, b_vf_clipfrac, b_value_mean, b_value_var = \
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

        for j in range(n):
            start = b_c_lengths[j] - 1
            end = b_c_p_lengths[j] - 1

            logprob = b_logprob[j:j+1, start:end]
            vpred = b_vpred[j:j+1, start:end]
            gen_len = end - start

            old_logprobs = old_b_logprobs[j]
            rewards = b_rewards[j]
            values = b_values[j]

            # Generalized advantage estimation over this sample's prompt tokens.
            lastgaelam = 0
            advantages_reversed = []
            for t in reversed(range(gen_len)):
                nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
                delta = rewards[:, t] + self.ppo_params['gamma'] * nextvalues - values[:, t]
                lastgaelam = delta + self.ppo_params['gamma'] * self.ppo_params['lam'] * lastgaelam
                advantages_reversed.append(lastgaelam)
            advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

            returns = advantages + values
            advantages = whiten(advantages)
            advantages = advantages.detach()

            vpredclipped = clip_by_value(vpred,
                                         values - self.ppo_params["cliprange_value"],
                                         values + self.ppo_params["cliprange_value"])

            vf_losses1 = (vpred - returns) ** 2
            vf_losses2 = (vpredclipped - returns) ** 2
            vf_loss = .5 * torch.mean(torch.max(vf_losses1, vf_losses2))
            vf_clipfrac = torch.mean(torch.gt(vf_losses2, vf_losses1).double())

            ratio = torch.exp(logprob - old_logprobs)

            pg_losses = -advantages * ratio
            pg_losses2 = -advantages * torch.clamp(ratio,
                                                   1.0 - self.ppo_params['cliprange'],
                                                   1.0 + self.ppo_params['cliprange'])

            pg_loss = torch.mean(torch.max(pg_losses, pg_losses2))
            pg_clipfrac = torch.mean(torch.gt(pg_losses2, pg_losses).double())

            loss = pg_loss + self.ppo_params['vf_coef'] * vf_loss

            approxkl = .5 * torch.mean((logprob - old_logprobs) ** 2)
            policykl = torch.mean(logprob - old_logprobs)
            return_mean, return_var = torch.mean(returns), torch.var(returns)
            value_mean, value_var = torch.mean(values), torch.var(values)

            b_pg_loss += pg_loss
            b_vf_loss += vf_loss
            b_loss += loss
            b_approxkl += approxkl
            b_policykl += policykl
            b_pg_clipfrac += pg_clipfrac
            b_advantages_mean += torch.mean(advantages)
            b_return_mean += return_mean
            b_return_var += return_var
            b_mean_vpred += torch.mean(vpred)
            b_error += torch.mean((vpred - returns) ** 2)
            b_vf_clipfrac += vf_clipfrac
            b_value_mean += value_mean
            b_value_var += value_var

        stats = dict(
            loss=dict(policy=b_pg_loss/n, value=b_vf_loss/n, total=b_loss/n),
            policy=dict(approxkl=b_approxkl/n, policykl=b_policykl/n, clipfrac=b_pg_clipfrac/n,
                        advantages_mean=b_advantages_mean/n),
            returns=dict(mean=b_return_mean/n, var=b_return_var/n),
            val=dict(vpred=b_mean_vpred/n, error=b_error/n,
                     clipfrac=b_vf_clipfrac/n, mean=b_value_mean/n, var=b_value_var/n),
        )
        return b_pg_loss/n, self.ppo_params['vf_coef'] * b_vf_loss/n, flatten_dict(stats)

    def record_step_stats(self, kl_coef, **data):
        """Record training step statistics.

        ``objective/kl`` is the mean over samples of the summed per-token KL; the adaptive
        KL controller consumes it.
        """
        n = len(data["logprobs"])
        total_kl = 0
        for lp, ref_lp in zip(data["logprobs"], data["ref_logprobs"]):
            total_kl += torch.mean(torch.sum(lp - ref_lp, dim=-1))

        stats = {
            'objective/kl': total_kl / n,  # need this for adaptive kl controller
            'objective/kl_coef': kl_coef,
        }

        for k, v in data['train_stats'].items():
            stats[f'ppo/{k}'] = torch.mean(v, dim=0)
        stats['ppo/val/var_explained'] = 1 - stats['ppo/val/error'] / stats['ppo/returns/var']
        return stats
    

In [16]:
def cut_text(whole_text_list, prefix_text_list):
    """Remove each conditioning prefix from its full text and return the stripped remainders.

    Used to recover the generated continuation (prompt or response) from a decoded text
    that contains the text it was conditioned on.

    Args:
        whole_text_list: iterable of full decoded strings.
        prefix_text_list: iterable of prefix strings, paired element-wise with
            ``whole_text_list`` (pairing stops at the shorter iterable, as with ``zip``).

    Returns:
        list of str: for each pair, everything after the FIRST occurrence of the prefix,
        with surrounding whitespace stripped. If the prefix is empty or does not occur in
        the full text (e.g. detokenization did not reproduce it exactly), the whole text is
        returned stripped instead of raising.
    """
    return_cut_text = list()
    for whole_text, prefix_text in zip(whole_text_list, prefix_text_list):
        # partition splits only at the first occurrence, so a prefix that reappears later
        # in the generated continuation no longer truncates it (split()[1] did).
        if prefix_text and prefix_text in whole_text:
            remainder = whole_text.partition(prefix_text)[2]
        else:
            remainder = whole_text
        return_cut_text.append(remainder.strip())
    return return_cut_text

In [None]:
ppo_trainer = PPOTrainer(model, optimizer, **config)
fbs = config['forward_batch_size']
n_epochs = int(np.ceil(config["steps"] / config["batch_size"]))

# The batch is processed in forward batches of `fbs` samples; if batch_size were not a
# multiple of fbs, the integer division below would silently drop the remainder samples.
assert config["batch_size"] % fbs == 0, "batch_size must be a multiple of forward_batch_size"
n_forward_batches = config["batch_size"] // fbs

# WARNING: set hyperparameters here.
# They are loop-invariant, and the logging code after the generation loop reads
# `prompt_reward`, so both are defined once up front rather than inside that loop.
prompt_reward = True      # also score context+prompt and add a "stay neutral" reward
c_p_reward_weight = 0.3   # weight of the adjusted context+prompt reward in the combined reward


def _target_label_logits(texts):
    """Score `texts` with the sentiment classifier; return the target-label logit per text.

    WARNING: the classifier uses a different tokenizer from the policy, so the texts are
    re-encoded and re-padded here. The forward pass runs under ``torch.no_grad()``: the
    score is only used as a reward (it was detached anyway), so no autograd graph needs
    to be built through the classifier.
    """
    cls_tensors = [sentiment_tokenizer.encode(txt, return_tensors="pt").to(device) for txt in texts]
    max_cls_len = max(t.size()[1] for t in cls_tensors)
    cls_padded, cls_attn_mask = get_pad_attn_mask(cls_tensors, max_cls_len)
    with torch.no_grad():
        # element 0 of the classifier output is the logits
        logits = sentiment_model.forward(cls_padded, cls_attn_mask)[0]
    return logits[:, config["tgt_label"]].detach()


for epoch in tqdm(range(n_epochs)):
    # tqdm.write keeps the progress bar intact instead of interleaving with print()
    tqdm.write("***********Epoch: %d/%d*************" % (epoch + 1, n_epochs))
    torch.cuda.empty_cache()
    logs = dict()
    timing = dict()
    t0 = time.time()

    #### get a batch from the dataset
    if mode == "train" and shuffle_data:
        random.shuffle(context_list)
    cond_list = context_list[:config["batch_size"]]

    all_c_lengths = list()
    all_c_p_tensors, all_c_p_texts, all_c_p_lengths = list(), list(), list()
    all_c_p_r_tensors, all_c_p_r_texts, all_c_p_r_lengths = list(), list(), list()
    all_rewards = list()
    all_c_p_r_rewards, all_c_p_rewards, all_c_p_rewards_adjusted = list(), list(), list()

    log_context, log_prompt, log_response = list(), list(), list()

    #### generate prompts and responses, then score them, one forward batch at a time
    for i in range(n_forward_batches):
        ctx_i = cond_list[i*fbs:(i+1)*fbs]
        log_context += ctx_i
        input_ids_i, _, _, _, lengths_i = prep_inputs(ctx_i, tokenizer, device, t_pad_token)
        # input_ids_i: bze x max_len
        # lengths_i: list of int
        all_c_lengths.append(lengths_i)

        # context -> context + prompt (generated with num_of_triggers triggers)
        min_c_length = min(lengths_i)
        c_p_tokens, c_p_tensors_padded, c_p_texts, c_p_lengths = generate_sentence(input_ids_i, lengths_i, min_c_length, stop_token, length, num_of_triggers=num_of_triggers)
        # Note: c_p_lengths is the whole length (including context and prompt). list of int
        # c_p_tensors_padded: not exactly padded with padding token, should be fine (can be removed by length in classification)
        log_prompt += cut_text(c_p_texts, ctx_i)

        # context + prompt -> context + prompt + response (num_of_triggers=0)
        min_c_p_length = min(c_p_lengths)
        c_p_r_tokens, c_p_r_tensors_padded, c_p_r_texts, c_p_r_lengths = generate_sentence(c_p_tensors_padded, c_p_lengths, min_c_p_length, stop_token, length, num_of_triggers=0)
        log_response += cut_text(c_p_r_texts, c_p_texts)

        all_c_p_tensors.append(c_p_tensors_padded)
        all_c_p_texts.append(c_p_texts)
        all_c_p_lengths.append(c_p_lengths)
        all_c_p_r_tensors.append(c_p_r_tensors_padded)
        all_c_p_r_texts.append(c_p_r_texts)
        all_c_p_r_lengths.append(c_p_r_lengths)

        # Reward: classifier logit of the full context+prompt+response text.
        # TODO: change later to only consider hidden states of the response.
        res = _target_label_logits(c_p_r_texts)
        if prompt_reward:
            c_p_res = _target_label_logits(c_p_texts)
            # To keep the prompt neutral, assign a reward following the original ppo
            # sentiment implementation: -2*|logit| + 4 peaks at 4 when the logit is 0
            # and decreases as |logit| grows, encouraging logits around 0.
            c_p_res_adjusted = -2*torch.abs(c_p_res)+4
            all_c_p_r_rewards.append(res)
            all_c_p_rewards.append(c_p_res)
            all_c_p_rewards_adjusted.append(c_p_res_adjusted)
            res = res + c_p_reward_weight * c_p_res_adjusted

        all_rewards.append(res)  # [bze]

    # Run PPO training on the forward batches formed above. The order of batches can be
    # changed within the ppo epochs, so this does not need to live inside the fbs loop.
    t = time.time()
    stats = ppo_trainer.step(all_c_p_tensors, all_c_p_lengths, all_c_lengths, all_rewards)
    timing['time/optimization'] = time.time()-t

    #### Log everything
    timing['time/epoch'] = time.time()-t0
    logs.update(timing)
    logs.update(stats)
    log_name = "game_log_e%d" % (epoch + 1)
    log_rewards = torch.cat(all_rewards)
    if prompt_reward:
        log_c_p_rewards = torch.cat(all_c_p_rewards)
        log_c_p_rewards_adjusted = torch.cat(all_c_p_rewards_adjusted)
        log_c_p_r_rewards = torch.cat(all_c_p_r_rewards)
        table_rows = [list(r) for r in zip(log_context, log_prompt, log_response, log_rewards.cpu().tolist(), log_c_p_r_rewards.cpu().tolist(), log_c_p_rewards.cpu().tolist(), log_c_p_rewards_adjusted.cpu().tolist())]
        logs.update({log_name:wandb.Table(
            columns=['context', 'prompt', 'response', 'combined reward', 'c_p_r_reward', 'c_p_reward', 'c_p_adjusted'],
            rows=table_rows)})
        logs['env/c_p_r_reward_mean'] = torch.mean(log_c_p_r_rewards).cpu().numpy()
        logs['env/c_p_r_reward_std'] = torch.std(log_c_p_r_rewards).cpu().numpy()
        logs['env/c_p_r_reward_dist'] = log_c_p_r_rewards.cpu().numpy()
        logs['env/combined_reward_mean'] = torch.mean(log_rewards).cpu().numpy()
        logs['env/c_p_reward_mean'] = torch.mean(log_c_p_rewards).cpu().numpy()
        logs['env/c_p_adjusted_mean'] = torch.mean(log_c_p_rewards_adjusted).cpu().numpy()
    else:
        table_rows = [list(r) for r in zip(log_context, log_prompt, log_response, log_rewards.cpu().tolist())]
        logs.update({log_name:wandb.Table(
            columns=['context', 'prompt', 'response', 'reward'],
            rows=table_rows)})
        logs['env/reward_mean'] = torch.mean(log_rewards).cpu().numpy()
        logs['env/reward_std'] = torch.std(log_rewards).cpu().numpy()
        logs['env/reward_dist'] = log_rewards.cpu().numpy()
    wandb.log(logs)
        
        
        
        
        

  0%|          | 0/100 [00:00<?, ?it/s]

***********Epoch: 1/100*************


  1%|          | 1/100 [00:34<57:33, 34.88s/it]

***********Epoch: 2/100*************


  2%|▏         | 2/100 [01:15<1:02:40, 38.38s/it]

***********Epoch: 3/100*************


  3%|▎         | 3/100 [01:56<1:04:09, 39.69s/it]

***********Epoch: 4/100*************


  4%|▍         | 4/100 [02:37<1:04:00, 40.01s/it]

***********Epoch: 5/100*************


  5%|▌         | 5/100 [03:22<1:06:23, 41.93s/it]

***********Epoch: 6/100*************


In [None]:
# ppo_trainer = PPOTrainer(model, optimizer, **config)
# # ppo_trainer.model.detach_value_head()
# print(ppo_trainer.model.v_head.detach_head)
# stats = ppo_trainer.step(all_c_p_tensors, all_c_p_lengths, all_c_lengths, all_rewards)