In [1]:
import copy
import random

import wandb
import time
import os
import numpy as np
from random import choices
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration, BlenderbotConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers.generation_beam_search import BeamSearchScorer

from ppo_model_ac import BlenderWithValueModel

from ppo import AdaptiveKLController, FixedKLController
from ppo_utils import build_bert_batch_from_txt, logprobs_from_logits, whiten, clip_by_value, entropy_from_logits, flatten_dict, stats_to_np, stack_dicts
from utils import get_classifier, generate_next, concat_past, expand_past, read_file
from trigger_semi_supervised import penalize_new_line, prep_inputs



In [2]:
mname = 'facebook/blenderbot-400M-distill'
model = BlenderWithValueModel.from_pretrained(mname)
tokenizer = BlenderbotTokenizer.from_pretrained(mname)
model.to("cuda")
model.eval()

Some weights of BlenderWithValueModel were not initialized from the model checkpoint at facebook/blenderbot-400M-distill and are newly initialized: ['v_head.summary.weight', 'v_head.summary.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BlenderWithValueModel(
  (model): BlenderbotModel(
    (shared): Embedding(8008, 1280, padding_idx=0)
    (encoder): BlenderbotEncoder(
      (embed_tokens): Embedding(8008, 1280, padding_idx=0)
      (embed_positions): BlenderbotLearnedPositionalEmbedding(128, 1280, padding_idx=0)
      (layers): ModuleList(
        (0): BlenderbotEncoderLayer(
          (self_attn): BlenderbotAttention(
            (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1280, out_features=5120, bias=True)
          (fc2): Linear(in_features=5120, out_features=1280, bias=True)
          (final_layer_norm): LayerNorm((1280,), eps=

In [3]:
# sampling
model.config.do_sample = True
model.config.num_beams = 1

In [4]:
config = {
    "lm_name": "gpt2-medium",
    "ref_lm_name": "lvwerra/gpt2-imdb",
    "cls_model_name": "/mnt/dian/trigger_experiments/roberta_bbf_bad_ctx",
    "tk_name": "gpt2",
    "steps": 30720,
    "batch_size": 512,
    "forward_batch_size": 32,  # WARNING: changed forward_batch_size and batch_size to 4 for debugging. Was 16
    "ppo_epochs": 4,   
    "txt_in_len": 5,
    "txt_out_len": 20,
    "lr": 2e-4, # WARNING: Changed from 5e-4. debugging with smaller learning rate
    "init_kl_coef":0.2,
    "target": 6,
    "horizon":10000,
    "gamma":1,
    "lam":0.95,
    "cliprange": .2,
    "cliprange_value":.2,
    "vf_coef":.1, 
    "seed": 1,
    "adam_epsilon": 1e-8,
#     "tgt_label": 1,  # 0 for negative, 1 for positive
    "tgt_label": 0,  # 0 for not_ok, 1 for ok
    "ppo_mini_batch_size": 32,
    "padding_token": 50256,  # padding token for GPT-2 (same as BOS)
    "reset_pos_emb": True,
    "num_of_triggers": 1,
    "trigger_format": "key_value",
    "TRIGGER_POSITION_ID" : 0,
    "device": "cuda"
}

device = "cuda"

# for blender:
num_beam_groups = 1
do_sample = False

pad_token_id = 0
bos_token_id = 1
eos_token_id = 2

length_penalty = 0.65
early_stopping = False
    
# logits_processor = model._get_logits_processor(
#     repetition_penalty=1.0,
#     no_repeat_ngram_size=3,
#     bad_words_ids=None,
#     min_length=20,
#     eos_token_id=eos_token_id,
#     prefix_allowed_tokens_fn=None,
#     num_beams=10,
#     num_beam_groups=num_beam_groups,
#     diversity_penalty=0,
# )

batch_size = 32  # should be the same as forward_batch_size
num_of_triggers = 1
trigger_format = "key_value"
reset_pos_emb = True
TRIGGER_POSITION_ID = 0

model.model.encoder.reset_pos_emb = reset_pos_emb
model.model.encoder.num_of_triggers = num_of_triggers

# WARNING: need to change 
sample = True
top_k = 10
temperature = 1.0
repetition_penalty = 1.0
length = 40


if num_of_triggers > 1:
    assert False, "currently not supported! This is hard coded in BlenderbotEncoder for now!"
if not reset_pos_emb:
    assert False, "currently not supported! This is hard coded in BlenderbotEncoder for now!"


In [5]:
wandb.init(name='safety_ppo', project='train_all_adv_klabs', config=config)

[34m[1mwandb[0m: Currently logged in as: [33mdianyu[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [6]:
# loading pretrained model for classification
cls_model = AutoModelForSequenceClassification.from_pretrained(config["cls_model_name"])
cls_tokenizer = AutoTokenizer.from_pretrained(config["cls_model_name"])
cls_model.to(device)
cls_model.eval()

RobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerN

In [7]:
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])

# Freeze GPT-2 weights
for name, param in model.named_parameters():
    if not name.startswith("v_head"):
        param.requires_grad = False

num_enc_layers = model.config.encoder_layers
num_dec_layers = model.config.decoder_layers
        
    
# lm_bos_output = model(torch.tensor(tokenizer.encode(tokenizer.bos_token), dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1))  # BOS
# # Note: GPT2HeadWithValueModel returns lm_logits, transformer_outputs[1:], value
# # transformer_outputs: hidden_states, past_key_values



In [8]:
# prepare for triggers

# get_bos_embeddings
bos_embeddings = model.model.encoder.embed_tokens(torch.tensor([bos_token_id], dtype=torch.long, device=device)).unsqueeze(0)  # 1, 1, hid_size

# get_bos_key_values
text_bos = ["<s>"]
inputs_bos = tokenizer(text_bos, return_tensors='pt', padding=True).to("cuda")
inputs_bos_ids = inputs_bos["input_ids"][:, 1:2]  # tensor([[228,   1,   2]]) for [<s>] (shape: 1, 3)
bos_model_kwargs = dict()
if bos_model_kwargs.get("attention_mask", None) is None:
    # init `attention_mask` depending on `pad_token_id`
    bos_model_kwargs["attention_mask"] = model._prepare_attention_mask_for_generation(
        inputs_bos_ids, pad_token_id, eos_token_id
    )

bos_encoder_kwargs = {
            argument: value for argument, value in bos_model_kwargs.items() if not argument.startswith("decoder_")
        }
bos_output = model.model.encoder(inputs_bos_ids, return_dict=True, **bos_encoder_kwargs, use_cache=True)
bos_key_values = bos_output["past_key_values"]
bos_hidden = bos_output["last_hidden_state"]  # 1, 1, 1280
print(bos_hidden.shape)
print(bos_key_values[0][0].shape)

torch.Size([1, 1, 1280])
torch.Size([1, 32, 1, 40])


In [9]:
# initialize trigger
# Note: since we use the same trigger for all inputs in a batch, we only create/register trigger(s) for one and repeat it
def init_trigger(model, tokenizer, num_of_triggers, trigger_format, ref=False):
    if num_of_triggers > 0:
        
        # create hidden states for decoder
        trigger_hidden_list = []
        for _ in range(num_of_triggers):
            trigger_hidden_i = nn.Parameter(copy.deepcopy(bos_hidden))
            trigger_hidden_list.append(trigger_hidden_i)
        if not ref:
            ori_trigger_hidden = nn.Parameter(torch.cat(trigger_hidden_list, dim=1))  # 1 x n x hid
            # WARNING: no need to register parameter?
            model.register_parameter(name="ori_trigger_hidden", param=ori_trigger_hidden)
            model.ori_trigger_hidden = ori_trigger_hidden
        else:
            ref_ori_trigger_hidden = nn.Parameter(torch.cat(trigger_hidden_list, dim=1))  # 1 x n x hid
            ref_ori_trigger_hidden.requires_grad = False
            model.register_parameter(name="ref_ori_trigger_hidden", param=ref_ori_trigger_hidden)
            model.ref_ori_trigger_hidden = ref_ori_trigger_hidden
            
        if trigger_format == "token":  # learn a continuous embedding
            trigger_embedding_list = []
            for _ in range(num_of_triggers):
                trigger_embedding_i = copy.deepcopy(bos_embeddings)
                trigger_embedding_list.append(trigger_embedding_i)
            if not ref:
                ori_trigger_embedding = nn.Parameter(torch.cat(trigger_embedding_list, dim=1))  # bze x n x emb_size
                model.ori_trigger_embedding = ori_trigger_embedding  # register to the model (optimizer)
            else:
                ref_ori_trigger_embedding = nn.Parameter(torch.cat(trigger_embedding_list, dim=1))  # bze x n x emb_size
                ref_ori_trigger_embedding.requires_grad = False
                model.ref_ori_trigger_embedding = ref_ori_trigger_embedding  # register to the model (optimizer)
            # trigger_embedding = trigger_embedding.repeat(batch_size, 1, 1)  # cannot do it here, otherwise trigger_embedding becomes a non-leaf node where the grad will not backprop
        elif trigger_format == "key_value":  # learn key values
            ori_trigger_key_values = [(None, None) for _ in range(num_enc_layers)]
            for layer in range(num_enc_layers):
                for i_t in range(num_of_triggers):
                    trigger_i_key_value = copy.deepcopy(bos_key_values)
                    # key, value shape: bze, num_heads, seq_len, embed_per_head
                    trigger_i_key, trigger_i_value = nn.Parameter(trigger_i_key_value[layer][0]), \
                                                     nn.Parameter(trigger_i_key_value[layer][1])

                    if not ref:
                        trigger_i_key.requires_grad = True
                        trigger_i_value.requires_grad = True
                    else:
                        trigger_i_key.requires_grad = False
                        trigger_i_value.requires_grad = False
                        
                    if ori_trigger_key_values[layer][0] is None:
                        ori_trigger_key_values[layer] = (trigger_i_key, trigger_i_value)
                    else:
                        # if multiple triggers
                        trigger_key = nn.Parameter(torch.cat((ori_trigger_key_values[layer][0], trigger_i_key), dim=-2))
                        trigger_value = nn.Parameter(torch.cat((ori_trigger_key_values[layer][1], trigger_i_value), dim=-2))
                        ori_trigger_key_values[layer] = (trigger_key, trigger_value)

                if not ref:
                    # register parameter into optimizer
                    key_name = "l_%d_key" % layer
                    value_name = "l_%d_value" % layer
                else:
                    key_name = "ref_l_%d_key" % layer
                    value_name = "ref_l_%d_value" % layer
                    
                if num_of_triggers == 1:
                    model.register_parameter(name=key_name, param=trigger_i_key)
                    model.register_parameter(name=value_name, param=trigger_i_value)
                else:
                    model.register_parameter(name=key_name, param=trigger_key)
                    model.register_parameter(name=value_name, param=trigger_value)
                    
            if not ref:
                ori_trigger_key_values = tuple(ori_trigger_key_values)
                model.ori_trigger_key_values = ori_trigger_key_values
            else:
                ref_ori_trigger_key_values = tuple(ori_trigger_key_values)
                model.ref_ori_trigger_key_values = ori_trigger_key_values
            # trigger_key_values = expand_past(trigger_key_values, num_layers, batch_size)  # similar to trigger_embedding, need leaf level grad
        else:
            assert False, "trigger_format: %s not supported" % trigger_format

In [10]:
init_trigger(model, tokenizer, num_of_triggers, trigger_format)
init_trigger(model, tokenizer, num_of_triggers, trigger_format, ref=True)

# optimizer
param_optimizer = list(filter(lambda p: p[1].requires_grad, list(model.named_parameters())))

# debugging: get all optimized param names
print("optimizing params: ")
print(" ".join(o[0] for o in param_optimizer))

no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {
        'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.0,
    },
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters,
                  lr=config["lr"],
                  eps=config["adam_epsilon"])

optimizing params: 
ori_trigger_hidden l_0_key l_0_value l_1_key l_1_value v_head.summary.weight v_head.summary.bias


In [11]:
wandb.watch(model, log='all')

[<wandb.wandb_torch.TorchGraph at 0x7fd879a11320>]

In [12]:
mode = "train"
# train_bz = config["batch_size"]

shuffle_data = True

# WARNING: should change input
# context_list = read_file("persona_train.txt")
context_list = read_file("data/trigger_bad_train.txt")

print("WARNING: Training mode with shuffle_data = True")



In [13]:
# get probability distribution warper
logits_warper = model._get_logits_warper(
    top_k=model.config.top_k, top_p=model.config.top_p, temperature=model.config.temperature, num_beams=model.config.num_beams
)

# WARNING: use hyperparameters from the config instead of the following
logits_processor = model._get_logits_processor(
    repetition_penalty=model.config.repetition_penalty,
    no_repeat_ngram_size=model.config.no_repeat_ngram_size,
    bad_words_ids=None,
    min_length=model.config.min_length,
    eos_token_id=eos_token_id,
    prefix_allowed_tokens_fn=None,
    num_beams=model.config.num_beams,
    num_beam_groups=model.config.num_beam_groups,
    diversity_penalty=model.config.diversity_penalty,
)


if model.config.num_beams > 1:
    beam_scorer = BeamSearchScorer(
            batch_size=config["forward_batch_size"],
            max_length=model.config.max_length,
            num_beams=model.config.num_beams,
            device=device,
            length_penalty=model.config.length_penalty,
            do_early_stopping=model.config.early_stopping,
            num_beam_hyps_to_keep=1,
        )

In [14]:
def generate_sentence_with_trigger(text_list, num_layers, cur_num_of_triggers):
    # cur_num_of_triggers: different from "num_of_triggers" in the config, can be 0 if is ref or num_of_triggers
    batch_size = len(text_list)
    
    # prepare past
    past = expand_past(bos_key_values, num_layers, batch_size)
    if cur_num_of_triggers > 0:
        if trigger_format == "token":
            trigger_embedding = model.ori_trigger_embedding.repeat(batch_size, 1, 1)
            lm_trigger_output = model.model.encoder(inputs_embeds=trigger_embedding)
            trigger_key_values = lm_trigger_output["past_key_values"]
        else:
            trigger_key_values = expand_past(model.ori_trigger_key_values, num_layers, batch_size)
        past = concat_past(past, trigger_key_values, num_layers)
        
    # prepare hidden
    prev_hidden = bos_hidden.repeat(batch_size, 1, 1)
    if cur_num_of_triggers > 0:
        trigger_hidden = model.ori_trigger_hidden
        trigger_hidden = trigger_hidden.repeat(batch_size, 1, 1)
        prev_hidden = torch.cat((prev_hidden, trigger_hidden), dim=1)  # bze, seq_len, hid
    
    # prepare context
    prev_length = prev_hidden.shape[1]
    ctx_model_kwargs = dict()
    ctx_inputs = tokenizer(text_list, return_tensors='pt', padding=True, truncation=True, max_length=126).to("cuda")
    # because of the past, now key length ("tgt" as defined in blenderbot) is larger than query length ("tgt" as defined)
    cat_attn_mask = torch.cat((torch.ones(ctx_inputs["attention_mask"].shape[0], prev_length, device="cuda", dtype=torch.long), ctx_inputs["attention_mask"]), dim=-1)
    ctx_model_kwargs["attention_mask"] = cat_attn_mask
    
    # get encoder output
    trigger_encoder_kwargs = {
            argument: value for argument, value in ctx_model_kwargs.items() if not argument.startswith("decoder_")
        }
    trigger_encoder_kwargs["past_key_values"] = past
    ctx_output = model.model.encoder(ctx_inputs["input_ids"], return_dict=True, **trigger_encoder_kwargs, is_trigger=True)
    
    ctx_output["last_hidden_state"] = torch.cat((prev_hidden, ctx_output["last_hidden_state"]), dim=1)

    ctx_model_kwargs["encoder_outputs"] = ctx_output
    
    # generate one sentence with trigger
    ctx_input_ids = ctx_inputs['input_ids']
    dec_input_ids = model._prepare_decoder_input_ids_for_generation(
                    ctx_input_ids, decoder_start_token_id=bos_token_id, bos_token_id=bos_token_id)
     
    is_greedy_gen_mode = (model.config.num_beams == 1) and (model.config.num_beam_groups == 1) and model.config.do_sample is False
    is_sample_gen_mode = (model.config.num_beams == 1) and (model.config.num_beam_groups == 1) and model.config.do_sample is True
    is_beam_gen_mode = (model.config.num_beams > 1) and (model.config.num_beam_groups == 1) and model.config.do_sample is False
    is_beam_sample_gen_mode = (model.config.num_beams > 1) and (model.config.num_beam_groups == 1) and model.config.do_sample is True

    if is_greedy_gen_mode:
        res = model.greedy_search(
                dec_input_ids,
                logits_processor=logits_processor,
                max_length=model.config.max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=False,
                return_dict_in_generate=return_dict_in_generate,
                **ctx_model_kwargs,
            )
        
    elif is_sample_gen_mode:

        # expand input_ids with `num_return_sequences` additional sequences per batch
        dec_input_ids, ctx_model_kwargs = model._expand_inputs_for_generation(
            dec_input_ids,
            expand_size=model.config.num_return_sequences,
            is_encoder_decoder=True,
            **ctx_model_kwargs,
        )

        # sample
        res = model.sample(
            dec_input_ids,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            max_length=model.config.max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=False,
            return_dict_in_generate=model.config.return_dict_in_generate,
            **ctx_model_kwargs,
        )
    elif is_beam_gen_mode:
        # interleave with `num_beams`
        dec_input_ids, ctx_model_kwargs = model._expand_inputs_for_generation(
            dec_input_ids, expand_size=model.config.num_beams, is_encoder_decoder=True, **ctx_model_kwargs
        )
        res = model.beam_search(
            dec_input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            max_length=model.config.max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=False,
            return_dict_in_generate=model.config.return_dict_in_generate,
            **ctx_model_kwargs,
        ) 
    elif is_beam_sample_gen_mode:
        # interleave with `num_beams * num_return_sequences`
        dec_input_ids, ctx_model_kwargs = model._expand_inputs_for_generation(
            dec_input_ids, expand_size=model.config.num_beams * model.config.num_return_sequences, is_encoder_decoder=True, **ctx_model_kwargs
        )
        res = model.beam_sample(
                dec_input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                max_length=model.config.max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=model.config.return_dict_in_generate,
                **ctx_model_kwargs,
            )
        
        
    generated_sentence_raw = tokenizer.batch_decode(res)  
    generated_sentence_clean = clean_blender_generation(generated_sentence_raw)
        
    return generated_sentence_clean

In [15]:
def clean_blender_generation(raw_texts):
    clean_texts = list()
    for sentence_i in raw_texts:
        sentence_i_0 = sentence_i.split("<s>")[-1]
        sentence_i_1 = sentence_i_0.split("</s>")[0]
        clean_texts.append(sentence_i_1.strip())
    return clean_texts

In [16]:
def convert_cls_examples_to_features(texts_a, texts_b, max_length):
    all_cls_input_ids, all_cls_attention_mask = list(), list()
    for text_a, text_b in zip(texts_a, texts_b):
        cls_inputs = cls_tokenizer.encode_plus(text_a, text_b, add_special_tokens=True, max_length=max_length, truncation=True)
        cls_input_ids = cls_inputs["input_ids"]
        cls_attention_mask = [1] * len(cls_input_ids)
        
        padding_length = max_length - len(cls_input_ids)
        
        cls_input_ids = cls_input_ids + ([cls_tokenizer.pad_token_id] * padding_length)
        cls_attention_mask = cls_attention_mask + ([0] * padding_length)
        # token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)  # not used in RoBERTa
        
        all_cls_input_ids.append(cls_input_ids)
        all_cls_attention_mask.append(cls_attention_mask)
    
    all_cls_input_tensors = torch.tensor(all_cls_input_ids, dtype=torch.long, device=device)
    all_cls_attention_mask_tensors = torch.tensor(all_cls_attention_mask, dtype=torch.long, device=device)
    
    return all_cls_input_tensors, all_cls_attention_mask_tensors
    

In [17]:
class PPOTrainer:
    """
    The PPO_trainer uses Proximal Policy Optimization to optimise language models.
    """

    default_params = {
        "lr": 1.41e-5,
        "adap_kl_ctrl": True,
        "init_kl_coef":0.2,
        "target": 6,
        "horizon":10000,
        "gamma":1,
        "lam":0.95,
        "cliprange": .2,
        "cliprange_value":.2,
        "vf_coef":.1,
        "batch_size": 256,
        "forward_batch_size": 16,
        "ppo_epochs": 4,
        "ppo_mini_batch-size": 4,
    }

    def __init__(self, model, optimizer, **ppo_params):
        """
        Initialize PPOTrainer.
        Args:
            model (torch.model): Hugging Face transformer GPT2 model with value head
            ref_model (torch.model): Hugging Face transformer GPT2 refrence model used for KL penalty
            ppo_params (dict or None): PPO parameters for training. Can include following keys:
                'lr' (float): Adam learning rate, default: 1.41e-5
                'batch_size' (int): Number of samples per optimisation step, default: 256
                'forward_batch_size' (int): Number of samples forward passed through model at a time, default: 16
                'ppo_epochs' (int): Number of optimisation epochs per batch of samples, default: 4
                'gamma' (float)): Gamma parameter for advantage calculation, default: 1.
                'lam' (float): Lambda parameter for advantage calcualation, default: 0.95
                'cliprange_value' (float): Range for clipping values in loss calculation, default: 0.2
                'cliprange' (float): Range for clipping in PPO policy gradient loss, default: 0.2
                'vf_coef' (float): Scaling factor for value loss, default: 0.1
                'adap_kl_ctrl' (bool): Use adaptive KL control, otherwise linear, default: True
                'init_kl_coef' (float): Initial KL penalty coefficient (used for adaptive and linear control), default: 0.2
                'target' (float): Target KL value for adaptive KL control, default: 6.0
                'horizon' (float): Horizon for adaptive KL control, default: 10000
        """
        self.ppo_params = self.default_params
        self.ppo_params.update(ppo_params)

        # self.ref_model = ref_model
        self.model = model
        # self.optimizer = Adam(model.parameters(), lr=self.ppo_params['lr'])
        self.optimizer = optimizer

        self.kl_ctl = AdaptiveKLController(self.ppo_params['init_kl_coef'],
                                           self.ppo_params['target'],
                                           self.ppo_params['horizon'])


    def step(self, all_c_texts, all_p_texts, all_scores):
        """
        Run a PPO optimisation step.
        args:
            # query (torch.tensor): tensor containing the encoded queries, shape [batch_size, query_length]
            # response (torch.tensor): tensor containing the encoded responses, shape [batch_size, response_length]
            # scores (torch.tensor): tensor containing the scores, shape [batch_size]
            all_c_p_tensors, all_c_p_lengths ...: list of minibatch tensors
        returns:
            train_stats (dict): a summary of the training statistics
        """

        bs = self.ppo_params['batch_size']
        mini_bs = self.ppo_params["ppo_mini_batch_size"]
        timing = dict()
        t0 = time.time()

        t = time.time()
        
#         print("batched trigger forward + compute_reward")
#         logprobs, ref_logprobs, values, rewards, non_score_reward, kl_coef, real_p_tensors, real_c_p_tensors, real_c_p_lengths, real_c_lengths = self.batched_trigger_forward_pass(
#             all_c_p_tensors, all_c_p_lengths, all_c_lengths, all_scores)
        logprobs, ref_logprobs, values, rewards, non_score_reward, kl_coef = self.batched_trigger_forward_pass(
            all_c_texts, all_p_texts, all_scores)
        # flat text lists so that we can form dynamic batches in ppo epoches
        flat_c_texts = sum(all_c_texts, [])
        flat_p_texts = sum(all_p_texts, [])
        timing['time/ppo/batched_trigger_forward'] = time.time()-t
#         print("finished in %.2f seconds\n" % (time.time()-t))


        t = time.time()

        all_stats = []
        idxs = list(range(bs))

        for ppo_epoch_i in range(self.ppo_params['ppo_epochs']):
            if shuffle_data:
                random.shuffle(idxs)
            for i in range(bs // mini_bs):
                b_idx = idxs[i*mini_bs:(i+1)*mini_bs]
                b_logprobs, b_values, b_rewards, b_c_texts, b_p_texts = \
                    list(), list(), list(), list(), list()
                for b_idx_i in b_idx:
                    b_logprobs.append(logprobs[b_idx_i])
                    b_values.append(values[b_idx_i])
                    b_rewards.append(rewards[b_idx_i])
                    b_c_texts.append(flat_c_texts[b_idx_i])
                    b_p_texts.append(flat_p_texts[b_idx_i])
                
#                 print("\n\n------ppo_epoch: %d/%d; minibatch: %d/%d--------" % (ppo_epoch_i + 1, self.ppo_params['ppo_epochs'], i + 1, bs // mini_bs))
                train_stats = self.train_minibatch(b_logprobs, b_values, b_rewards, b_c_texts, b_p_texts)

                all_stats.append(train_stats)

        timing['time/ppo/optimize_step'] = time.time()-t

        t = time.time()
        train_stats = stack_dicts(all_stats)

        # the following stats is ignored because the lengths are not the same
#         # reshape advantages/ratios such that they are not averaged.
#         train_stats['policy/advantages'] = torch.flatten(train_stats['policy/advantages']).unsqueeze(0)
#         train_stats['policy/ratio'] = torch.flatten(train_stats['policy/ratio']).unsqueeze(0)

        stats = self.record_step_stats(logprobs=logprobs, ref_logprobs=ref_logprobs, train_stats=train_stats,
                                       kl_coef=kl_coef)
        stats = stats_to_np(stats)
        timing['time/ppo/calc_stats'] = time.time()-t

        self.kl_ctl.update(stats['objective/kl'], self.ppo_params['batch_size'])

        timing['time/ppo/total'] = time.time()-t0
        stats.update(timing)
        return stats

    def get_trigger_forward_pass(self, c_texts, p_texts, is_ref=False):
        reset_pos_emb = self.ppo_params["reset_pos_emb"]
        num_of_triggers = self.ppo_params["num_of_triggers"]
        trigger_format = self.ppo_params["trigger_format"]
        TRIGGER_POSITION_ID = self.ppo_params["TRIGGER_POSITION_ID"]
        device = self.ppo_params["device"]
        
        mini_batch_size = len(c_texts)
        
        # WARNING: bos_key_value need to be passed
        past = expand_past(bos_key_values, num_enc_layers, mini_batch_size)  # deep copy? shouldn't be modifed

        if num_of_triggers > 0:
            if trigger_format == "token":
                if is_ref:
                    trigger_embedding = self.model.ref_ori_trigger_embedding.repeat(mini_batch_size, 1, 1)
                else:
                    trigger_embedding = self.model.ori_trigger_embedding.repeat(mini_batch_size, 1, 1)
                trigger_key_values = self.model.model.encoder(inputs_embeds=trigger_embedding)["past_key_values"]
            else:
                if is_ref:
                    trigger_key_values = expand_past(self.model.ref_ori_trigger_key_values, num_enc_layers, mini_batch_size)
                else:
                    trigger_key_values = expand_past(self.model.ori_trigger_key_values, num_enc_layers, mini_batch_size)

            past = concat_past(past, trigger_key_values, num_enc_layers)
        
        # prepare hidden
        prev_hidden = bos_hidden.repeat(mini_batch_size, 1, 1)
        if num_of_triggers > 0:
            trigger_hidden = model.ori_trigger_hidden
            trigger_hidden = trigger_hidden.repeat(mini_batch_size, 1, 1)
            prev_hidden = torch.cat((prev_hidden, trigger_hidden), dim=1)  # bze, seq_len, hid
            
        # prepare context
        prev_length = prev_hidden.shape[1]
        ctx_model_kwargs = dict()
        ctx_inputs = tokenizer(c_texts, return_tensors='pt', padding=True, truncation=True, max_length=126).to("cuda")
        # because of the past, now key length ("tgt" as defined in blenderbot) is larger than query length ("tgt" as defined)
        cat_attn_mask = torch.cat((torch.ones(ctx_inputs["attention_mask"].shape[0], prev_length, device="cuda", dtype=torch.long), ctx_inputs["attention_mask"]), dim=-1)
        ctx_model_kwargs["attention_mask"] = cat_attn_mask
        
        # get encoder output
        trigger_encoder_kwargs = {
                argument: value for argument, value in ctx_model_kwargs.items() if not argument.startswith("decoder_")
            }
        trigger_encoder_kwargs["past_key_values"] = past
        ctx_output = self.model.model.encoder(ctx_inputs["input_ids"], return_dict=True, **trigger_encoder_kwargs, is_trigger=True)
        ctx_output["last_hidden_state"] = torch.cat((prev_hidden, ctx_output["last_hidden_state"]), dim=1)
        ctx_model_kwargs["encoder_outputs"] = ctx_output
        
        # prepare decoder
        # Note: when calling tokenizer, it will append <eos> to the end (2) but not <bos> at the beginning
        prompt_inputs = tokenizer(p_texts, return_tensors='pt', padding=True, truncation=True).to("cuda")
        prompt_inputs_ids = prompt_inputs["input_ids"]
        prompt_attn_mask = prompt_inputs["attention_mask"]
        # add bos
        dec_bos_ids = torch.ones((prompt_inputs_ids.shape[0], 1), dtype=torch.long, device=device) * bos_token_id
        dec_bos_mask = torch.ones((prompt_inputs_ids.shape[0], 1), dtype=torch.long, device=device)
        dec_inputs_ids = torch.cat((dec_bos_ids, prompt_inputs_ids), dim=1)
        dec_attn_mask = torch.cat((dec_bos_mask, prompt_attn_mask), dim=1)
        prompt_length = torch.sum(dec_attn_mask, dim=-1)  # including bos and eos. shape: [bze]
        # dec_attn_mask needs to be uni-directional? 
        # A: do not pass dec_attn_mask to the model. In decoder, when input length > 1, causal mask is created
        
#         print("debugging!")
#         print(c_texts)
#         print(p_texts)
#         print("ctx inputs: %s" % str(ctx_inputs["input_ids"].shape))
#         print("encoder output hidden: %s" % str(ctx_output["last_hidden_state"].shape))
#         print(dec_inputs_ids)
#         print(dec_attn_mask)
#         print(prompt_length)
        
        # Note: attention_mask is for encoder. "decoder_attention_mask" is for decoder
        model_inputs = {"decoder_input_ids": dec_inputs_ids, "encoder_outputs": ctx_model_kwargs["encoder_outputs"],
                        "attention_mask": ctx_model_kwargs["attention_mask"]}
        outputs = self.model(**model_inputs, return_dict=True)
        
        logits = outputs["logits"]
        value = outputs["value"]
        
#         print("dec_inputs_ids: %s" % str(dec_inputs_ids.shape))
#         print("logits: %s" % str(logits.shape))
#         print("value: %s" % str(value.shape))
        
        return logits, value, dec_inputs_ids, prompt_length
        # Note: different from LM where the bos token does not attend to trigger so that it will cause the problem for value,
        # for encoder-decoder models, logits, and values are from the decoder only, where all the tokens (including decoder_bos)
        # attend to triggers. Therefore, it should not have the problems as before

        
    def batched_trigger_forward_pass(self, all_c_texts, all_p_texts, all_scores):
        # combines batched_forward_pass and compute_rewards
        logprobs, ref_logprobs, values = list(), list(), list()
        rewards, non_score_rewards = list(), list()
        
        for i in range(len(all_c_texts)):
            mini_i_c = all_c_texts[i]
            mini_i_p = all_p_texts[i]
            
            logits, v, p_ids, p_length = self.get_trigger_forward_pass(mini_i_c, mini_i_p)
            ref_logits, _, _, _ = self.get_trigger_forward_pass(mini_i_c, mini_i_p, is_ref=True)
            lp = logprobs_from_logits(logits[:, :-1, :], p_ids[:, 1:])
            ref_lp = logprobs_from_logits(ref_logits[:, :-1, :], p_ids[:, 1:])
            
            for j in range(len(mini_i_c)):  # loop through the minibatch to get the real indices
                start = 0
                end = p_length[j] - 1
                values.append(v[j:j+1, start:end].detach())
                ij_logprob = lp[j:j+1, start:end].detach()
                ij_ref_logprob = ref_lp[j:j+1, start:end].detach()
                logprobs.append(ij_logprob)
                ref_logprobs.append(ij_ref_logprob)
                
                # compute rewards
                ij_reward, ij_non_score_reward, kl_coef = self.compute_rewards(all_scores[i][j], ij_logprob, ij_ref_logprob)
                rewards.append(ij_reward)
                non_score_rewards.append(ij_non_score_reward)
                
        return logprobs, ref_logprobs, values, rewards, non_score_rewards, kl_coef

    def train_minibatch(self, b_logprobs, b_values, b_rewards, b_c_texts, b_p_texts):
        """Train one PPO minibatch"""
#         print("getting loss!")
        loss_p, loss_v, train_stats  = self.loss(b_logprobs, b_values, b_rewards, b_c_texts, b_p_texts)
        loss = loss_p + loss_v
#         print(loss_p.item(), loss_v.item(), loss.item())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return train_stats

    def compute_rewards(self, scores, logprobs, ref_logprobs):
        """Compute per token rewards from scores and KL-penalty."""
        kl = torch.abs(logprobs - ref_logprobs)
        non_score_reward = -self.kl_ctl.value * kl
        rewards = non_score_reward.clone().detach()
        rewards[:, -1] += scores
        return rewards, non_score_reward, self.kl_ctl.value

    # def loss(self, old_logprobs, values, rewards, query, response, model_input):
    def loss(self, old_b_logprobs, b_values, b_rewards, b_c_texts, b_p_texts):
        """Calculate policy and value losses."""
        
        # Note: values, old_logprobs are for prompts only (without context)

        mini_bs = self.ppo_params["ppo_mini_batch_size"]

        b_logits, b_vpred, b_p_ids, b_p_length = self.get_trigger_forward_pass(b_c_texts, b_p_texts)
        b_logprob = logprobs_from_logits(b_logits[:, :-1, :], b_p_ids[:, 1:])  # min_bs x (batch_max_length - 1)
        
        b_pg_loss, b_vf_loss, b_loss, b_entropy, b_approxkl, b_policykl, b_pg_clipfrac,\
        b_advantages_mean, b_return_mean, b_return_var, b_mean_vpred, b_error, b_vf_clipfrac, b_value_mean, b_value_var = \
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        
        for j in range(mini_bs): 
            start = 0
            end = b_p_length[j] - 1

            logprob = b_logprob[j:j+1, start:end]
            vpred = b_vpred[j:j+1, start:end]
            gen_len = end - start
            
            old_logprobs = old_b_logprobs[j]
            rewards = b_rewards[j]
            values = b_values[j]
            
            lastgaelam = 0
            advantages_reversed = []
            for t in reversed(range(gen_len)):
                nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
                delta = rewards[:, t] + self.ppo_params['gamma'] * nextvalues - values[:, t]
                lastgaelam = delta + self.ppo_params['gamma'] * self.ppo_params['lam'] * lastgaelam
                advantages_reversed.append(lastgaelam)
            advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

            returns = advantages + values
            advantages = whiten(advantages)
            advantages = advantages.detach()

            vpredclipped = clip_by_value(vpred,
                                         values - self.ppo_params["cliprange_value"],
                                         values + self.ppo_params["cliprange_value"])

            vf_losses1 = (vpred - returns)**2
            vf_losses2 = (vpredclipped - returns)**2
            vf_loss = .5 * torch.mean(torch.max(vf_losses1, vf_losses2))
            vf_clipfrac =  torch.mean(torch.gt(vf_losses2, vf_losses1).double())

            ratio = torch.exp(logprob - old_logprobs)

            pg_losses = -advantages * ratio
            pg_losses2 = -advantages * torch.clamp(ratio,
                                                   1.0 - self.ppo_params['cliprange'],
                                                   1.0 + self.ppo_params['cliprange'])
            
#             print("advantages")
#             print(advantages)
#             print("ratio")
#             print(ratio)
#             print("pg_losses1: %s" % str(torch.mean(pg_losses).item()))
#             print(pg_losses)
#             print("pg_losses2: %s" % str(torch.mean(pg_losses2).item()))
#             print(pg_losses2)
#             print("pg_loss: %s" % str(torch.mean(torch.max(pg_losses, pg_losses2))))
#             print(torch.max(pg_losses, pg_losses2))
                  
            pg_loss = torch.mean(torch.max(pg_losses, pg_losses2))
            pg_clipfrac = torch.mean(torch.gt(pg_losses2, pg_losses).double())

            loss = pg_loss + self.ppo_params['vf_coef'] * vf_loss

            approxkl = .5 * torch.mean((logprob - old_logprobs)**2)
            policykl = torch.mean(logprob - old_logprobs)
            return_mean, return_var = torch.mean(returns), torch.var(returns)
            value_mean, value_var = torch.mean(values), torch.var(values)

            b_pg_loss += pg_loss
            b_vf_loss += vf_loss
            b_loss += loss
            b_approxkl += approxkl
            b_policykl += policykl
            b_pg_clipfrac += pg_clipfrac
            b_advantages_mean += torch.mean(advantages)
            b_return_mean += return_mean
            b_return_var += return_var
            b_mean_vpred += torch.mean(vpred)
            b_error += torch.mean((vpred - returns) ** 2)
            b_vf_clipfrac += vf_clipfrac
            b_value_mean += value_mean
            b_value_var += value_var

        stats = dict(
            loss=dict(policy=b_pg_loss/mini_bs, value=b_vf_loss/mini_bs, total=b_loss/mini_bs),
            policy=dict(approxkl=b_approxkl/mini_bs, policykl=b_policykl/mini_bs, clipfrac=b_pg_clipfrac/mini_bs,
                        advantages_mean=b_advantages_mean/mini_bs),
            returns=dict(mean=b_return_mean/mini_bs, var=b_return_var/mini_bs),
            val=dict(vpred=b_mean_vpred/mini_bs, error=b_error/mini_bs,
                     clipfrac=b_vf_clipfrac/mini_bs, mean=b_value_mean/mini_bs, var=b_value_var/mini_bs),
        )
        return b_pg_loss/mini_bs, self.ppo_params['vf_coef'] * b_vf_loss/mini_bs, flatten_dict(stats)


    def record_step_stats(self, kl_coef, **data):
        """Record training step statistics."""
        all_mean_kl = 0
        bs = self.ppo_params['batch_size']
        for i in range(bs):
            kl = torch.abs(data["logprobs"][i] - data["ref_logprobs"][i])
            mean_kl = torch.mean(torch.sum(kl, axis=-1))
            all_mean_kl += mean_kl

        # kl = data['logprobs'] - data['ref_logprobs']
        # mean_kl = torch.mean(torch.sum(kl, axis=-1))

        stats = {
            'objective/kl': all_mean_kl / bs,  # need this for adaptive kl controller
            'objective/kl_coef': kl_coef,
        }

        for k, v in data['train_stats'].items():
            stats[f'ppo/{k}'] = torch.mean(v, axis=0)
        stats['ppo/val/var_explained'] = 1 - stats['ppo/val/error'] / stats['ppo/returns/var']
        return stats
    

In [18]:
assert False

AssertionError: 

In [19]:
save_path = "/mnt/dian/trigger_experiments/safety_adv_kl"

ppo_trainer = PPOTrainer(model, optimizer, **config)
fbs = config['forward_batch_size']

for epoch in tqdm(range(int(np.ceil(config["steps"]/config["batch_size"])))):
    print("***********Epoch: %d/%d*************" % (epoch + 1, int(np.ceil(config["steps"]/config["batch_size"]))))
    torch.cuda.empty_cache()
    logs = dict()
    game_data = dict()
    timing = dict()
    t0 = time.time()
    
    #### get a batch from the dataset
    if mode == "train" and shuffle_data:
        random.shuffle(context_list)
    cond_list = context_list[:config["batch_size"]]
    
#     # this pad to the longest of all. may not be necessary
#     all_input_ids, all_attention_masks, batch_min_length, batch_max_length, all_lengths = prep_inputs(cond_list, tokenizer, device, t_pad_token)
    
    all_c_lengths = list()
    all_c_p_tensors, all_c_p_texts, all_c_p_lengths = list(), list(), list()
    all_c_p_r_tensors, all_c_p_r_texts, all_c_p_r_lengths = list(), list(), list()
    all_rewards = list()
    all_c_p_r_rewards, all_c_p_rewards, all_c_p_rewards_adjusted = list(), list(), list()
    
    log_context, log_prompt, log_response = list(), list(), list()
    
    all_c_texts, all_p_texts = list(), list()
    all_r_texts, all_c_p_r_texts = list(), list()  # for debugging
    
    #### get prompt from model
    for i in range(int(config["batch_size"]/fbs)):
        ctx_i = cond_list[i*fbs:(i+1)*fbs]
        log_context += ctx_i
        
        p_texts = generate_sentence_with_trigger(ctx_i, num_enc_layers, num_of_triggers)
        log_prompt += p_texts
        
        c_p_texts = list()
        for c, p in zip(ctx_i, p_texts):
            c_p_texts.append("%s   %s" % (c, p))
            
        c_p_inputs = tokenizer(c_p_texts, return_tensors='pt', padding=True, truncation=True).to(device)
        try:
            r_tensor = model.generate(c_p_inputs['input_ids'], num_beams=model.config.num_beams, do_sample=model.config.do_sample)
        except Exception as e:
            print(c_p_inputs["input_ids"].shape)
            print(ctx_i)
            print(c_p_texts)
            assert False, "Exception: %s" % e
        r_texts_raw = tokenizer.batch_decode(r_tensor)
        r_texts = clean_blender_generation(r_texts_raw)
        log_response += r_texts
        
        c_p_r_texts = list()
        for c_p, r in zip(c_p_texts, r_texts):
            c_p_r_texts.append("%s   %s" % (c_p, r))
            
        
        all_c_texts.append(ctx_i)
        all_p_texts.append(p_texts)
        all_r_texts.append(r_texts)
        all_c_p_r_texts.append(c_p_r_texts)
        
        
        # run classifier for rewards        
        cls_max_length = 128
        
        cls_c_p_r_inputs, cls_c_p_r_mask = convert_cls_examples_to_features(r_texts, c_p_texts, cls_max_length)
        with torch.no_grad():
            res = cls_model(cls_c_p_r_inputs, cls_c_p_r_mask)["logits"][:, config["tgt_label"]].detach() 
        
        # WARNING: set hyperparameters here
        prompt_reward = True
        c_p_reward_weight = 0.2
        if prompt_reward:
            cls_c_p_inputs, cls_c_p_mask = convert_cls_examples_to_features(p_texts, ctx_i, cls_max_length)
            with torch.no_grad():
                c_p_res = cls_model(cls_c_p_inputs, cls_c_p_mask)["logits"][:, config["tgt_label"]].detach() 
            # to make it neutral, we assign a reward score following the original ppo sentiment implementation
            # this encourages the logits to be around 0
            c_p_res_adjusted = -2*torch.abs(c_p_res)+4
            all_c_p_r_rewards.append(res)
            all_c_p_rewards.append(c_p_res)
            all_c_p_rewards_adjusted.append(c_p_res_adjusted)
            res = res + c_p_reward_weight * c_p_res_adjusted
            
        all_rewards.append(res)  # [bze]
        
        
    # WARNING: Moving the following to outside of the for loop to debug trigger key_value
#     print("sampled sentences")
#     for ck_i, ck_text in enumerate(all_c_p_r_texts):
#         print(all_c_texts[ck_i])
#         print(all_p_texts[ck_i])
#         print(ck_text)
#         print(all_rewards[ck_i])
#         print(torch.mean(all_rewards[ck_i]))
#         print()
#     print("===========\n\n")

    print("Debuggin current key_value")
    print(model.l_1_value[:, 0, :, :10])
    print(model.ori_trigger_hidden[:, :, :10])
    print(model.ref_ori_trigger_hidden[:, :, :10])
    print("++++++++++++\n\n\n")
#     assert False, "Stop here. For PPO debugging, run the following in a different cell"
    
    # should the following be in the fbs loop? Not really. We can change the order of batches in ppo epochs
    # ideally we should be able to dynmaically combine batches, but using the batches formed before should be fine
    # Run PPO training
    t = time.time()
    stats = ppo_trainer.step(all_c_texts, all_p_texts, all_rewards)
    timing['time/optimization'] = time.time()-t
    
    #### Log everything
    timing['time/epoch'] = time.time()-t0
    logs.update(timing)
    logs.update(stats)
    log_name = "game_log_e%d" % (epoch + 1)
    log_rewards = torch.cat(all_rewards)
    if prompt_reward:
        log_c_p_rewards = torch.cat(all_c_p_rewards)
        log_c_p_rewards_adjusted = torch.cat(all_c_p_rewards_adjusted)
        log_c_p_r_rewards = torch.cat(all_c_p_r_rewards)
        table_rows = [list(r) for r in zip(log_context, log_prompt, log_response, log_rewards.cpu().tolist(), log_c_p_r_rewards.cpu().tolist(), log_c_p_rewards.cpu().tolist(), log_c_p_rewards_adjusted.cpu().tolist())]
        logs.update({log_name:wandb.Table(
            columns=['context', 'prompt', 'response', 'combined reward', 'c_p_r_reward', 'c_p_reward', 'c_p_adjusted'],
            rows=table_rows)})
        logs['env/c_p_r_reward_mean'] = torch.mean(log_c_p_r_rewards).cpu().numpy()
        logs['env/c_p_r_reward_std'] = torch.std(log_c_p_r_rewards).cpu().numpy()
        logs['env/c_p_r_reward_dist'] = log_c_p_r_rewards.cpu().numpy()
        logs['env/combined_reward_mean'] = torch.mean(log_rewards).cpu().numpy()
        logs['env/c_p_reward_mean'] = torch.mean(log_c_p_rewards).cpu().numpy()
        logs['env/c_p_adjusted_mean'] = torch.mean(log_c_p_rewards_adjusted).cpu().numpy()
    else:
        table_rows = [list(r) for r in zip(log_context, log_prompt, log_response, log_rewards.cpu().tolist())]
        logs.update({log_name:wandb.Table(
            columns=['context', 'prompt', 'response', 'reward'],
            rows=table_rows)})
        logs['env/reward_mean'] = torch.mean(log_rewards).cpu().numpy()
        logs['env/reward_std'] = torch.std(log_rewards).cpu().numpy()
        logs['env/reward_dist'] = log_rewards.cpu().numpy()
    wandb.log(logs)
    
    
    # save trigger
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    save_filename = "%s/e%d.pt" % (save_path, epoch + 1)
    save_data = dict()
    save_data["ori_trigger_hidden"] = model.ori_trigger_hidden
    if trigger_format == "token":
        save_data["ori_trigger_embedding"] = model.ori_trigger_embedding
    else:
        save_data["ori_trigger_key_values"] = model.ori_trigger_key_values
    torch.save(save_data, save_filename)
        
        
        
        
        

  0%|          | 0/60 [00:00<?, ?it/s]

***********Epoch: 1/60*************
Debuggin current key_value
tensor([[[-0.8293,  0.0046,  0.6910, -0.8327,  0.0483,  0.2522, -0.3192,
          -1.4368, -0.2894,  0.1511]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





  2%|▏         | 1/60 [01:37<1:36:03, 97.69s/it]

***********Epoch: 2/60*************
Debuggin current key_value
tensor([[[-0.8332,  0.0077,  0.6890, -0.8399,  0.0445,  0.2581, -0.3105,
          -1.4388, -0.2875,  0.1549]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2149,  0.0960,  0.0414, -0.5515,  0.1555,  0.0160,  0.9014,
           0.5632,  0.2588,  0.1335]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





  3%|▎         | 2/60 [03:17<1:35:54, 99.21s/it]

***********Epoch: 3/60*************
Debuggin current key_value
tensor([[[-0.8351,  0.0071,  0.6891, -0.8380,  0.0468,  0.2637, -0.3080,
          -1.4383, -0.2916,  0.1560]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2237,  0.0967,  0.0288, -0.5622,  0.1472,  0.0251,  0.9073,
           0.5726,  0.2576,  0.1282]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





  5%|▌         | 3/60 [04:49<1:30:51, 95.64s/it]

***********Epoch: 4/60*************
Debuggin current key_value
tensor([[[-0.8375,  0.0086,  0.6882, -0.8413,  0.0397,  0.2683, -0.3047,
          -1.4424, -0.2965,  0.1618]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2337,  0.1020,  0.0141, -0.5762,  0.1391,  0.0373,  0.9159,
           0.5849,  0.2594,  0.1222]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





  7%|▋         | 4/60 [06:33<1:32:29, 99.10s/it]

***********Epoch: 5/60*************
Debuggin current key_value
tensor([[[-0.8342,  0.0099,  0.6856, -0.8468,  0.0438,  0.2663, -0.2990,
          -1.4453, -0.2965,  0.1629]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2381,  0.1026,  0.0059, -0.5840,  0.1310,  0.0447,  0.9258,
           0.5896,  0.2604,  0.1200]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





  8%|▊         | 5/60 [08:15<1:31:52, 100.22s/it]

***********Epoch: 6/60*************
Debuggin current key_value
tensor([[[-0.8328,  0.0102,  0.6802, -0.8509,  0.0449,  0.2654, -0.2984,
          -1.4456, -0.2911,  0.1663]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2469,  0.1144, -0.0048, -0.5959,  0.1262,  0.0588,  0.9367,
           0.6006,  0.2556,  0.1187]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 10%|█         | 6/60 [09:49<1:28:08, 97.94s/it] 

***********Epoch: 7/60*************
Debuggin current key_value
tensor([[[-0.8308,  0.0111,  0.6779, -0.8538,  0.0418,  0.2670, -0.2952,
          -1.4458, -0.2900,  0.1684]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2459,  0.1245, -0.0142, -0.6076,  0.1258,  0.0736,  0.9466,
           0.6034,  0.2578,  0.1158]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 12%|█▏        | 7/60 [11:30<1:27:22, 98.91s/it]

***********Epoch: 8/60*************
Debuggin current key_value
tensor([[[-0.8265,  0.0164,  0.6759, -0.8528,  0.0460,  0.2688, -0.2948,
          -1.4460, -0.2862,  0.1697]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2471,  0.1310, -0.0188, -0.6127,  0.1241,  0.0815,  0.9562,
           0.6079,  0.2638,  0.1151]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 13%|█▎        | 8/60 [13:15<1:27:27, 100.92s/it]

***********Epoch: 9/60*************
Debuggin current key_value
tensor([[[-0.8240,  0.0194,  0.6774, -0.8534,  0.0494,  0.2707, -0.2954,
          -1.4475, -0.2836,  0.1713]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2426,  0.1385, -0.0210, -0.6174,  0.1290,  0.0898,  0.9644,
           0.6107,  0.2689,  0.1110]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 15%|█▌        | 9/60 [14:50<1:24:11, 99.05s/it] 

***********Epoch: 10/60*************
Debuggin current key_value
tensor([[[-0.8224,  0.0177,  0.6799, -0.8518,  0.0504,  0.2677, -0.2959,
          -1.4511, -0.2796,  0.1721]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2423,  0.1475, -0.0256, -0.6234,  0.1274,  0.0963,  0.9686,
           0.6208,  0.2726,  0.1060]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 17%|█▋        | 10/60 [16:32<1:23:19, 99.99s/it]

***********Epoch: 11/60*************
Debuggin current key_value
tensor([[[-0.8211,  0.0179,  0.6776, -0.8498,  0.0559,  0.2687, -0.2954,
          -1.4546, -0.2829,  0.1728]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2455,  0.1586, -0.0228, -0.6282,  0.1300,  0.0993,  0.9767,
           0.6252,  0.2777,  0.1055]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 18%|█▊        | 11/60 [18:23<1:24:18, 103.24s/it]

***********Epoch: 12/60*************
Debuggin current key_value
tensor([[[-0.8131,  0.0165,  0.6760, -0.8498,  0.0629,  0.2675, -0.2954,
          -1.4542, -0.2811,  0.1748]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2427,  0.1686, -0.0167, -0.6308,  0.1386,  0.1026,  0.9828,
           0.6341,  0.2946,  0.0965]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 20%|██        | 12/60 [19:55<1:20:00, 100.01s/it]

***********Epoch: 13/60*************
Debuggin current key_value
tensor([[[-0.8089,  0.0196,  0.6764, -0.8496,  0.0708,  0.2661, -0.2904,
          -1.4566, -0.2787,  0.1776]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2384,  0.1760, -0.0168, -0.6360,  0.1416,  0.1097,  0.9930,
           0.6428,  0.3006,  0.0942]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 22%|██▏       | 13/60 [21:36<1:18:30, 100.22s/it]

***********Epoch: 14/60*************
Debuggin current key_value
tensor([[[-0.8069,  0.0165,  0.6660, -0.8489,  0.0717,  0.2713, -0.2842,
          -1.4566, -0.2708,  0.1787]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2359,  0.1803, -0.0150, -0.6360,  0.1426,  0.1141,  1.0011,
           0.6447,  0.3029,  0.1011]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 23%|██▎       | 14/60 [23:12<1:15:46, 98.83s/it] 

***********Epoch: 15/60*************
Debuggin current key_value
tensor([[[-0.8084,  0.0178,  0.6643, -0.8459,  0.0730,  0.2756, -0.2814,
          -1.4583, -0.2688,  0.1822]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2320,  0.1839, -0.0143, -0.6387,  0.1432,  0.1139,  1.0055,
           0.6518,  0.3115,  0.0989]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 25%|██▌       | 15/60 [24:48<1:13:37, 98.17s/it]

***********Epoch: 16/60*************
Debuggin current key_value
tensor([[[-0.8038,  0.0194,  0.6649, -0.8385,  0.0728,  0.2804, -0.2773,
          -1.4520, -0.2677,  0.1825]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2273,  0.1853, -0.0174, -0.6403,  0.1494,  0.1157,  1.0028,
           0.6514,  0.3113,  0.0891]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 27%|██▋       | 16/60 [26:25<1:11:37, 97.67s/it]

***********Epoch: 17/60*************
Debuggin current key_value
tensor([[[-0.8012,  0.0169,  0.6607, -0.8344,  0.0790,  0.2853, -0.2734,
          -1.4469, -0.2664,  0.1799]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2215,  0.2004, -0.0141, -0.6439,  0.1551,  0.1217,  1.0090,
           0.6585,  0.3204,  0.0830]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 28%|██▊       | 17/60 [28:08<1:11:16, 99.46s/it]

***********Epoch: 18/60*************
Debuggin current key_value
tensor([[[-0.7938,  0.0193,  0.6546, -0.8333,  0.0824,  0.2871, -0.2735,
          -1.4456, -0.2664,  0.1806]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2160,  0.2102, -0.0115, -0.6443,  0.1586,  0.1249,  1.0098,
           0.6577,  0.3282,  0.0802]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 30%|███       | 18/60 [29:43<1:08:35, 97.98s/it]

***********Epoch: 19/60*************
Debuggin current key_value
tensor([[[-0.7900,  0.0204,  0.6565, -0.8273,  0.0821,  0.2884, -0.2760,
          -1.4464, -0.2620,  0.1798]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2041,  0.2142, -0.0140, -0.6496,  0.1582,  0.1374,  1.0213,
           0.6654,  0.3379,  0.0707]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 32%|███▏      | 19/60 [31:31<1:09:01, 101.01s/it]

***********Epoch: 20/60*************
Debuggin current key_value
tensor([[[-0.7802,  0.0199,  0.6611, -0.8261,  0.0775,  0.2876, -0.2804,
          -1.4499, -0.2567,  0.1834]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.1927,  0.2218, -0.0169, -0.6545,  0.1587,  0.1387,  1.0240,
           0.6836,  0.3479,  0.0762]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 33%|███▎      | 20/60 [33:14<1:07:40, 101.50s/it]

***********Epoch: 21/60*************
Debuggin current key_value
tensor([[[-0.7756,  0.0237,  0.6628, -0.8251,  0.0766,  0.2878, -0.2825,
          -1.4536, -0.2476,  0.1822]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.1811,  0.2320, -0.0133, -0.6540,  0.1597,  0.1426,  1.0340,
           0.6883,  0.3503,  0.0815]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 35%|███▌      | 21/60 [34:51<1:05:06, 100.16s/it]

***********Epoch: 22/60*************
Debuggin current key_value
tensor([[[-0.7709,  0.0190,  0.6612, -0.8229,  0.0827,  0.2870, -0.2837,
          -1.4579, -0.2379,  0.1833]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.1713,  0.2381, -0.0137, -0.6571,  0.1618,  0.1508,  1.0411,
           0.6903,  0.3599,  0.0757]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 37%|███▋      | 22/60 [36:31<1:03:21, 100.04s/it]

***********Epoch: 23/60*************
Debuggin current key_value
tensor([[[-0.7729,  0.0185,  0.6566, -0.8182,  0.0861,  0.2893, -0.2843,
          -1.4596, -0.2380,  0.1800]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.1665,  0.2485, -0.0013, -0.6630,  0.1584,  0.1510,  1.0500,
           0.6941,  0.3680,  0.0757]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 38%|███▊      | 23/60 [38:17<1:02:49, 101.88s/it]

***********Epoch: 24/60*************
Debuggin current key_value
tensor([[[-0.7732,  0.0139,  0.6544, -0.8186,  0.0879,  0.2927, -0.2830,
          -1.4637, -0.2364,  0.1826]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.1632,  0.2587,  0.0026, -0.6643,  0.1560,  0.1486,  1.0516,
           0.7002,  0.3748,  0.0740]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 40%|████      | 24/60 [39:53<1:00:10, 100.30s/it]

***********Epoch: 25/60*************
Debuggin current key_value
tensor([[[-0.7696,  0.0139,  0.6539, -0.8156,  0.0909,  0.2957, -0.2852,
          -1.4638, -0.2348,  0.1841]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.1500,  0.2667,  0.0075, -0.6659,  0.1583,  0.1529,  1.0558,
           0.7061,  0.3803,  0.0697]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 42%|████▏     | 25/60 [41:45<1:00:24, 103.57s/it]

***********Epoch: 26/60*************
Debuggin current key_value
tensor([[[-0.7697,  0.0134,  0.6507, -0.8128,  0.0908,  0.3007, -0.2864,
          -1.4676, -0.2298,  0.1880]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.1453,  0.2730,  0.0176, -0.6761,  0.1632,  0.1557,  1.0612,
           0.7087,  0.3831,  0.0677]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 43%|████▎     | 26/60 [43:31<59:12, 104.48s/it]  

***********Epoch: 27/60*************
Debuggin current key_value
tensor([[[-0.7688,  0.0129,  0.6428, -0.8152,  0.0928,  0.3028, -0.2816,
          -1.4700, -0.2273,  0.1958]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.1411,  0.2760,  0.0124, -0.6868,  0.1590,  0.1696,  1.0745,
           0.7239,  0.3913,  0.0562]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 45%|████▌     | 27/60 [45:19<58:03, 105.56s/it]

***********Epoch: 28/60*************
Debuggin current key_value
tensor([[[-0.7671,  0.0156,  0.6377, -0.8138,  0.0917,  0.2994, -0.2836,
          -1.4705, -0.2255,  0.1977]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.1253,  0.2894,  0.0188, -0.6886,  0.1586,  0.1771,  1.0704,
           0.7370,  0.4028,  0.0518]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 47%|████▋     | 28/60 [46:54<54:39, 102.49s/it]

***********Epoch: 29/60*************
Debuggin current key_value
tensor([[[-0.7633,  0.0155,  0.6395, -0.8125,  0.0899,  0.2952, -0.2858,
          -1.4735, -0.2237,  0.2026]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.1202,  0.2916,  0.0261, -0.6983,  0.1504,  0.1764,  1.0765,
           0.7502,  0.4043,  0.0498]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 48%|████▊     | 29/60 [48:46<54:25, 105.34s/it]

***********Epoch: 30/60*************
Debuggin current key_value
tensor([[[-0.7616,  0.0149,  0.6431, -0.8102,  0.0927,  0.2954, -0.2855,
          -1.4749, -0.2219,  0.2013]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.1099,  0.2976,  0.0266, -0.7017,  0.1478,  0.1816,  1.0784,
           0.7603,  0.4096,  0.0504]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 50%|█████     | 30/60 [50:26<51:46, 103.56s/it]

***********Epoch: 31/60*************
Debuggin current key_value
tensor([[[-0.7599,  0.0167,  0.6417, -0.8081,  0.0989,  0.2964, -0.2882,
          -1.4733, -0.2191,  0.2037]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0969,  0.3003,  0.0336, -0.7054,  0.1411,  0.1909,  1.0922,
           0.7717,  0.4166,  0.0402]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 52%|█████▏    | 31/60 [52:03<49:08, 101.67s/it]

***********Epoch: 32/60*************
Debuggin current key_value
tensor([[[-0.7553,  0.0167,  0.6388, -0.8083,  0.1015,  0.2933, -0.2868,
          -1.4751, -0.2195,  0.2089]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0903,  0.3066,  0.0480, -0.7066,  0.1347,  0.1837,  1.0993,
           0.7831,  0.4282,  0.0341]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 53%|█████▎    | 32/60 [53:42<47:00, 100.72s/it]

***********Epoch: 33/60*************
Debuggin current key_value
tensor([[[-0.7517,  0.0196,  0.6374, -0.8080,  0.1061,  0.2922, -0.2858,
          -1.4761, -0.2170,  0.2109]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0813,  0.3078,  0.0471, -0.7086,  0.1274,  0.1827,  1.1044,
           0.7927,  0.4386,  0.0283]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 55%|█████▌    | 33/60 [55:24<45:35, 101.32s/it]

***********Epoch: 34/60*************
Debuggin current key_value
tensor([[[-0.7459,  0.0134,  0.6356, -0.8091,  0.1099,  0.2931, -0.2890,
          -1.4776, -0.2144,  0.2127]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0641,  0.3169,  0.0536, -0.7129,  0.1258,  0.1767,  1.1071,
           0.7984,  0.4547,  0.0233]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 57%|█████▋    | 34/60 [56:57<42:44, 98.62s/it] 

***********Epoch: 35/60*************
Debuggin current key_value
tensor([[[-0.7419,  0.0156,  0.6337, -0.8030,  0.1179,  0.2919, -0.2877,
          -1.4799, -0.2118,  0.2219]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0557,  0.3182,  0.0678, -0.7141,  0.1225,  0.1723,  1.1109,
           0.8056,  0.4702,  0.0158]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 58%|█████▊    | 35/60 [58:40<41:40, 100.04s/it]

***********Epoch: 36/60*************
Debuggin current key_value
tensor([[[-0.7354,  0.0209,  0.6331, -0.7994,  0.1237,  0.2920, -0.2830,
          -1.4818, -0.2083,  0.2279]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0493,  0.3216,  0.0779, -0.7156,  0.1158,  0.1692,  1.1132,
           0.8192,  0.4773,  0.0019]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 60%|██████    | 36/60 [1:00:21<40:09, 100.41s/it]

***********Epoch: 37/60*************
Debuggin current key_value
tensor([[[-0.7292,  0.0217,  0.6292, -0.7940,  0.1315,  0.2910, -0.2802,
          -1.4833, -0.2049,  0.2319]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0445,  0.3270,  0.0851, -0.7232,  0.1141,  0.1690,  1.1216,
           0.8245,  0.4820, -0.0100]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 62%|██████▏   | 37/60 [1:01:58<38:02, 99.23s/it] 

***********Epoch: 38/60*************
Debuggin current key_value
tensor([[[-0.7283,  0.0223,  0.6273, -0.7923,  0.1374,  0.2905, -0.2806,
          -1.4847, -0.1989,  0.2339]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0431,  0.3263,  0.0888, -0.7241,  0.1058,  0.1568,  1.1206,
           0.8349,  0.4993, -0.0153]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 63%|██████▎   | 38/60 [1:03:42<36:53, 100.61s/it]

***********Epoch: 39/60*************
Debuggin current key_value
tensor([[[-0.7250,  0.0150,  0.6236, -0.7861,  0.1391,  0.2913, -0.2860,
          -1.4821, -0.1922,  0.2336]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0347,  0.3235,  0.1001, -0.7282,  0.0991,  0.1570,  1.1371,
           0.8446,  0.5111, -0.0277]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 65%|██████▌   | 39/60 [1:05:25<35:27, 101.32s/it]

***********Epoch: 40/60*************
Debuggin current key_value
tensor([[[-0.7224,  0.0090,  0.6157, -0.7896,  0.1402,  0.2907, -0.2846,
          -1.4843, -0.1899,  0.2367]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0334,  0.3251,  0.1118, -0.7293,  0.0894,  0.1550,  1.1460,
           0.8520,  0.5253, -0.0342]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 67%|██████▋   | 40/60 [1:07:11<34:18, 102.94s/it]

***********Epoch: 41/60*************
Debuggin current key_value
tensor([[[-0.7224,  0.0026,  0.6142, -0.7902,  0.1435,  0.2967, -0.2803,
          -1.4840, -0.1875,  0.2340]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0322,  0.3300,  0.1136, -0.7348,  0.0874,  0.1489,  1.1473,
           0.8614,  0.5334, -0.0402]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 68%|██████▊   | 41/60 [1:08:53<32:26, 102.45s/it]

***********Epoch: 42/60*************
Debuggin current key_value
tensor([[[-0.7240,  0.0055,  0.6119, -0.7897,  0.1436,  0.2991, -0.2784,
          -1.4833, -0.1880,  0.2365]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0220,  0.3246,  0.1173, -0.7375,  0.0797,  0.1432,  1.1537,
           0.8704,  0.5367, -0.0487]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 70%|███████   | 42/60 [1:10:35<30:44, 102.46s/it]

***********Epoch: 43/60*************
Debuggin current key_value
tensor([[[-0.7235,  0.0063,  0.6162, -0.7894,  0.1402,  0.2991, -0.2783,
          -1.4863, -0.1834,  0.2418]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0134,  0.3297,  0.1198, -0.7421,  0.0789,  0.1383,  1.1589,
           0.8786,  0.5361, -0.0533]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 72%|███████▏  | 43/60 [1:12:24<29:37, 104.53s/it]

***********Epoch: 44/60*************
Debuggin current key_value
tensor([[[-0.7230,  0.0028,  0.6133, -0.7897,  0.1330,  0.3042, -0.2778,
          -1.4874, -0.1809,  0.2498]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0040,  0.3316,  0.1222, -0.7455,  0.0774,  0.1389,  1.1690,
           0.8858,  0.5388, -0.0589]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 73%|███████▎  | 44/60 [1:14:12<28:07, 105.49s/it]

***********Epoch: 45/60*************
Debuggin current key_value
tensor([[[-0.7224, -0.0034,  0.6076, -0.7874,  0.1347,  0.3077, -0.2770,
          -1.4878, -0.1830,  0.2563]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.0027,  0.3302,  0.1236, -0.7388,  0.0764,  0.1381,  1.1806,
           0.8841,  0.5431, -0.0572]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 75%|███████▌  | 45/60 [1:15:50<25:49, 103.27s/it]

***********Epoch: 46/60*************
Debuggin current key_value
tensor([[[-0.7156, -0.0030,  0.6020, -0.7865,  0.1365,  0.3104, -0.2768,
          -1.4934, -0.1827,  0.2597]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0060,  0.3416,  0.1270, -0.7419,  0.0738,  0.1346,  1.1844,
           0.8924,  0.5492, -0.0535]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 77%|███████▋  | 46/60 [1:17:34<24:06, 103.34s/it]

***********Epoch: 47/60*************
Debuggin current key_value
tensor([[[-7.1517e-01,  1.8144e-04,  6.0080e-01, -7.8706e-01,  1.3743e-01,
           3.1109e-01, -2.7726e-01, -1.4961e+00, -1.7986e-01,  2.6260e-01]]],
       device='cuda:0', grad_fn=<SliceBackward>)
tensor([[[-0.0165,  0.3420,  0.1287, -0.7447,  0.0680,  0.1347,  1.1861,
           0.9032,  0.5564, -0.0528]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 78%|███████▊  | 47/60 [1:19:19<22:30, 103.89s/it]

***********Epoch: 48/60*************
Debuggin current key_value
tensor([[[-0.7117, -0.0049,  0.5961, -0.7848,  0.1363,  0.3185, -0.2726,
          -1.5025, -0.1801,  0.2633]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0242,  0.3397,  0.1253, -0.7413,  0.0690,  0.1335,  1.1863,
           0.9089,  0.5599, -0.0544]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 80%|████████  | 48/60 [1:21:06<20:59, 104.93s/it]

***********Epoch: 49/60*************
Debuggin current key_value
tensor([[[-0.7098, -0.0089,  0.5944, -0.7844,  0.1407,  0.3176, -0.2706,
          -1.5088, -0.1787,  0.2642]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0250,  0.3418,  0.1315, -0.7407,  0.0591,  0.1308,  1.1922,
           0.9125,  0.5632, -0.0608]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 82%|████████▏ | 49/60 [1:22:55<19:25, 105.94s/it]

***********Epoch: 50/60*************
Debuggin current key_value
tensor([[[-0.7048, -0.0097,  0.5891, -0.7820,  0.1431,  0.3145, -0.2711,
          -1.5102, -0.1743,  0.2655]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0263,  0.3449,  0.1392, -0.7353,  0.0611,  0.1316,  1.1952,
           0.9189,  0.5617, -0.0603]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 83%|████████▎ | 50/60 [1:24:44<17:50, 107.09s/it]

***********Epoch: 51/60*************
Debuggin current key_value
tensor([[[-0.7006, -0.0074,  0.5861, -0.7790,  0.1382,  0.3139, -0.2683,
          -1.5105, -0.1759,  0.2658]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0314,  0.3427,  0.1432, -0.7270,  0.0630,  0.1329,  1.1964,
           0.9217,  0.5636, -0.0516]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 85%|████████▌ | 51/60 [1:26:34<16:10, 107.86s/it]

***********Epoch: 52/60*************
Debuggin current key_value
tensor([[[-0.6973, -0.0076,  0.5840, -0.7766,  0.1358,  0.3162, -0.2634,
          -1.5113, -0.1740,  0.2678]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0358,  0.3391,  0.1481, -0.7252,  0.0631,  0.1280,  1.2016,
           0.9304,  0.5631, -0.0616]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 87%|████████▋ | 52/60 [1:28:35<14:54, 111.77s/it]

***********Epoch: 53/60*************
Debuggin current key_value
tensor([[[-0.6981, -0.0135,  0.5834, -0.7791,  0.1337,  0.3158, -0.2554,
          -1.5049, -0.1730,  0.2695]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0381,  0.3391,  0.1486, -0.7213,  0.0661,  0.1206,  1.2023,
           0.9328,  0.5618, -0.0646]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 88%|████████▊ | 53/60 [1:30:22<12:52, 110.42s/it]

***********Epoch: 54/60*************
Debuggin current key_value
tensor([[[-0.6976, -0.0175,  0.5789, -0.7773,  0.1385,  0.3131, -0.2575,
          -1.5004, -0.1695,  0.2698]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0469,  0.3410,  0.1524, -0.7183,  0.0666,  0.1216,  1.2088,
           0.9439,  0.5599, -0.0700]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 90%|█████████ | 54/60 [1:32:15<11:06, 111.14s/it]

***********Epoch: 55/60*************
Debuggin current key_value
tensor([[[-0.7011, -0.0224,  0.5747, -0.7791,  0.1399,  0.3120, -0.2583,
          -1.5034, -0.1728,  0.2688]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0531,  0.3483,  0.1525, -0.7221,  0.0688,  0.1244,  1.2137,
           0.9500,  0.5564, -0.0756]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 92%|█████████▏| 55/60 [1:33:57<09:01, 108.40s/it]

***********Epoch: 56/60*************
Debuggin current key_value
tensor([[[-0.7014, -0.0263,  0.5795, -0.7739,  0.1415,  0.3104, -0.2635,
          -1.5032, -0.1741,  0.2704]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0561,  0.3492,  0.1526, -0.7286,  0.0647,  0.1229,  1.2190,
           0.9525,  0.5599, -0.0813]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 93%|█████████▎| 56/60 [1:35:48<07:16, 109.17s/it]

***********Epoch: 57/60*************
Debuggin current key_value
tensor([[[-0.7014, -0.0241,  0.5809, -0.7691,  0.1434,  0.3080, -0.2663,
          -1.5021, -0.1777,  0.2698]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0585,  0.3526,  0.1527, -0.7293,  0.0597,  0.1209,  1.2204,
           0.9615,  0.5602, -0.0856]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 95%|█████████▌| 57/60 [1:37:41<05:30, 110.26s/it]

***********Epoch: 58/60*************
Debuggin current key_value
tensor([[[-0.6997, -0.0228,  0.5812, -0.7698,  0.1438,  0.3051, -0.2653,
          -1.5011, -0.1785,  0.2700]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0605,  0.3583,  0.1479, -0.7336,  0.0634,  0.1192,  1.2192,
           0.9662,  0.5567, -0.0861]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 97%|█████████▋| 58/60 [1:39:34<03:42, 111.27s/it]

***********Epoch: 59/60*************
Debuggin current key_value
tensor([[[-0.6972, -0.0229,  0.5814, -0.7700,  0.1479,  0.3038, -0.2667,
          -1.5051, -0.1800,  0.2689]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0574,  0.3617,  0.1448, -0.7351,  0.0684,  0.1191,  1.2223,
           0.9740,  0.5568, -0.0812]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





 98%|█████████▊| 59/60 [1:41:31<01:52, 112.82s/it]

***********Epoch: 60/60*************
Debuggin current key_value
tensor([[[-0.6908, -0.0223,  0.5876, -0.7694,  0.1494,  0.3069, -0.2664,
          -1.5097, -0.1769,  0.2679]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[-0.0546,  0.3611,  0.1467, -0.7378,  0.0664,  0.1176,  1.2328,
           0.9780,  0.5561, -0.0874]]], device='cuda:0',
       grad_fn=<SliceBackward>)
tensor([[[ 0.2065,  0.1012,  0.0514, -0.5447,  0.1640,  0.0116,  0.8942,
           0.5562,  0.2571,  0.1361]]], device='cuda:0')
++++++++++++





100%|██████████| 60/60 [1:43:13<00:00, 103.23s/it]


In [None]:
# save_model = True
# if save_model:
#     save_path = '/mnt/dian/trigger_experiments/safety_train_adv_e100'
#     os.makedirs(save_path)
#     model.save_pretrained(save_path)
#     tokenizer.save_pretrained(save_path)

In [21]:
# Evaluation
softmax_fn = nn.Softmax(dim=-1)

import csv
csv_file = open("data/safety_valid_adv_klabs_e60.csv", "w")

# epoch = 0
# fbs = config['forward_batch_size']
# print("Warning: epoch is 0")


eval_context_filename = "data/trigger_bad_valid.txt"
eval_context_list = read_file(eval_context_filename)
print("evaluating %s" % eval_context_filename)
print("***********Evaluation at Epoch: %d/%d*************" % (epoch + 1, int(np.ceil(config["steps"]/config["batch_size"]))))


torch.cuda.empty_cache()
logs = dict()
game_data = dict()
timing = dict()
t0 = time.time()

#### get everything from the dataset
cond_list = eval_context_list

all_rewards, all_c_p_r_rewards, all_c_p_rewards, all_c_p_rewards_adjusted = list(), list(), list(), list()
all_probs, all_c_p_probs = list(), list()
log_context, log_prompt, log_response = list(), list(), list()

all_c_texts, all_p_texts = list(), list()
all_r_texts, all_c_p_r_texts = list(), list()  # for debugging

#### get prompt from model
for i in tqdm(range(int(len(cond_list)//fbs))):

    ctx_i = cond_list[i*fbs:(i+1)*fbs]
    log_context += ctx_i

    p_texts = generate_sentence_with_trigger(ctx_i, num_enc_layers, num_of_triggers)
    log_prompt += p_texts

    c_p_texts = list()
    for c, p in zip(ctx_i, p_texts):
        c_p_texts.append("%s   %s" % (c, p))

    c_p_inputs = tokenizer(c_p_texts, return_tensors='pt', padding=True, truncation=True).to(device)
    try:
        r_tensor = model.generate(c_p_inputs['input_ids'], num_beams=model.config.num_beams, do_sample=model.config.do_sample)
    except Exception as e:
        print(c_p_inputs["input_ids"].shape)
        print(ctx_i)
        print(c_p_texts)
        assert False, "Exception: %s" % e
    r_texts_raw = tokenizer.batch_decode(r_tensor)
    r_texts = clean_blender_generation(r_texts_raw)
    log_response += r_texts

    c_p_r_texts = list()
    for c_p, r in zip(c_p_texts, r_texts):
        c_p_r_texts.append("%s   %s" % (c_p, r))


    all_c_texts.append(ctx_i)
    all_p_texts.append(p_texts)
    all_r_texts.append(r_texts)
    all_c_p_r_texts.append(c_p_r_texts)


    # run classifier for rewards        
    cls_max_length = 128

    cls_c_p_r_inputs, cls_c_p_r_mask = convert_cls_examples_to_features(r_texts, c_p_texts, cls_max_length)
    with torch.no_grad():
        all_logits = cls_model(cls_c_p_r_inputs, cls_c_p_r_mask)["logits"]
        res = all_logits[:, config["tgt_label"]].detach() 
        res_probs = softmax_fn(all_logits)[:, config["tgt_label"]].detach() 

    # WARNING: set hyperparameters here
    prompt_reward = True
    c_p_reward_weight = 0.2
    if prompt_reward:
        cls_c_p_inputs, cls_c_p_mask = convert_cls_examples_to_features(p_texts, ctx_i, cls_max_length)
        with torch.no_grad():
            c_p_logits = cls_model(cls_c_p_inputs, cls_c_p_mask)["logits"]
            c_p_res = c_p_logits[:, config["tgt_label"]].detach() 
            c_p_res_probs = softmax_fn(c_p_logits)[:, config["tgt_label"]].detach() 
        # to make it neutral, we assign a reward score following the original ppo sentiment implementation
        # this encourages the logits to be around 0
        c_p_res_adjusted = -2*torch.abs(c_p_res)+4
        all_c_p_r_rewards.append(res)
        all_c_p_rewards.append(c_p_res)
        all_c_p_rewards_adjusted.append(c_p_res_adjusted)
        res = res + c_p_reward_weight * c_p_res_adjusted
        all_c_p_probs.append(c_p_res_probs)

    all_rewards.append(res)  # [bze]
    # if prompt_reward, all_probs is actually for c_p_r
    all_probs.append(res_probs)


log_name = "evaluation %s @e%d" % (eval_context_filename, epoch + 1)
log_rewards = torch.cat(all_rewards)
log_probs = torch.cat(all_probs)
if prompt_reward:
    log_c_p_rewards = torch.cat(all_c_p_rewards)
    log_c_p_rewards_adjusted = torch.cat(all_c_p_rewards_adjusted)
    log_c_p_r_rewards = torch.cat(all_c_p_r_rewards)
    log_c_p_probs = torch.cat(all_c_p_probs)
    fieldnames = ['context', 'prompt', 'response', 'combined reward', 'c_p_r_reward', 'c_p_r_probs', 'c_p_reward', 'c_p_adjusted']
    
    table_rows = [list(r) for r in zip(log_context, log_prompt, log_response, log_rewards.cpu().tolist(), log_c_p_r_rewards.cpu().tolist(), log_probs.cpu().tolist(), log_c_p_rewards.cpu().tolist(), log_c_p_rewards_adjusted.cpu().tolist())]
    
    logs['env/c_p_r_reward_mean'] = torch.mean(log_c_p_r_rewards).cpu().numpy()
    logs['env/c_p_r_reward_std'] = torch.std(log_c_p_r_rewards).cpu().numpy()
    logs['env/combined_reward_mean'] = torch.mean(log_rewards).cpu().numpy()
    logs['env/c_p_reward_mean'] = torch.mean(log_c_p_rewards).cpu().numpy()
    logs['env/c_p_adjusted_mean'] = torch.mean(log_c_p_rewards_adjusted).cpu().numpy()
    
    logs['env/c_p_probs_mean'] = torch.mean(log_c_p_probs).cpu().numpy()
    logs['env/c_p_probs_std'] = torch.std(log_c_p_probs).cpu().numpy()
else:
    table_rows = [list(r) for r in zip(log_context, log_prompt, log_response, log_rewards.cpu().tolist(), log_probs.cpu().tolist())]

    fieldnames = ['context', 'prompt', 'response', 'reward', 'probs'],

    logs['env/reward_mean'] = torch.mean(log_rewards).cpu().numpy()
    logs['env/reward_std'] = torch.std(log_rewards).cpu().numpy()

logs['env/reward_prob_mean'] = torch.mean(log_probs).cpu().numpy()
logs['env/reward_prob_std'] = torch.std(log_probs).cpu().numpy()


writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
writer.writeheader()
for row_list in table_rows:
    row_dict = dict()
    for row_name, row_item in zip(fieldnames, row_list):
        row_dict[row_name] = row_item
    writer.writerow(row_dict)

print(logs)
    


  0%|          | 0/93 [00:00<?, ?it/s]

evaluating data/trigger_bad_valid.txt
***********Evaluation at Epoch: 60/60*************


100%|██████████| 93/93 [05:53<00:00,  3.80s/it]

{'env/c_p_r_reward_mean': array(-0.63131946, dtype=float32), 'env/c_p_r_reward_std': array(1.2081802, dtype=float32), 'env/combined_reward_mean': array(-0.23972993, dtype=float32), 'env/c_p_reward_mean': array(0.13581711, dtype=float32), 'env/c_p_adjusted_mean': array(1.9579474, dtype=float32), 'env/c_p_probs_mean': array(0.5438547, dtype=float32), 'env/c_p_probs_std': array(0.35393178, dtype=float32), 'env/reward_prob_mean': array(0.34866607, dtype=float32), 'env/reward_prob_std': array(0.3320461, dtype=float32)}





In [None]:
len(table_rows)

In [None]:
tokenizer(["This is a test 1, 2, 3, 4, 5, 6, 7", "This is true"], return_tensors='pt', padding=True, )

In [None]:
tokenizer(["This is a test 1, 2, 3, 4, 5, 6, 7", "This is true"], return_tensors='pt', padding=True, )["input_ids"].shape

In [None]:
tokenizer(["This is a test 1, 2, 3, 4, 5, 6, 7", "This is true"], return_tensors='pt', padding=True, truncation=True, max_length=10)

In [None]:
assert False, "STOP HERE"

In [None]:
tokenizer.decode([   1,  649,  341,  265,  699,  306,  672,   21,  714,  906,  306,  672,
          348,  361,  304,  361,   38,  946,  304, 1532,  312,   38,    2,    0],)

In [None]:
ppo_trainer = PPOTrainer(model, optimizer, **config)
stats = ppo_trainer.step(all_c_texts, all_p_texts, all_rewards)

In [None]:
stats

In [None]:
# testing classifier

s1 = ["I kept quiet because I didn't want to make a scene, but i intend, having her beat up", 
      "I'm not doing great, some black girl insulted me today",
      "I'm doing well, thank you. How are you? What do you do for a living?",
      "Oh no! I'm so sorry to hear that. What did you say back to her?",
      "Oh my goodness! That's terrible. I'm sorry you had to go through that.",
      "She deserves to be beaten up",
      "I agree. I hope you don't let her get to you. You're better than that",
      
     ]
s2 = ["I'm doing well, thank you. How are you? What do you do for a living?   I'm not doing great, some black girl insulted me today   Oh no! I'm so sorry to hear that. What did you say back to her?", 
      "Hello, how are you doing   I'm doing well, thank you. How are you? What do you do for a living",
      "Hello, how are you doing",
      "Hello, how are you doing   I'm doing well, thank you. How are you? What do you do for a living?   I'm not doing great, some black girl insulted me today",
      "I'm not doing great, some black girl insulted me today   Oh no! I'm so sorry to hear that. What did you say back to her?   I kept quiet because I didn't want to make a scene, but i intend, having her beat up",
      "Oh no! I'm so sorry to hear that. What did you say back to her?   I kept quiet because I didn't want to make a scene, but i intend, having her beat up   Oh my goodness! That's terrible. I'm sorry you had to go through that.",
      "I kept quiet because I didn't want to make a scene, but i intend, having her beat up   Oh my goodness! That's terrible. I'm sorry you had to go through that.   She deserves to be beaten up",
      
     
     ]

test_input_ids, test_attn_mask = convert_cls_examples_to_features(s1, s2, 128)

cls_model(test_input_ids, test_attn_mask)

