In [1]:
import copy
import random
import math

import wandb
import time
import os
import numpy as np
from random import choices
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration, BlenderbotConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers.generation_beam_search import BeamSearchScorer

from ppo_model_ac import BlenderWithValueModel

from ppo import AdaptiveKLController, FixedKLController
from ppo_utils import build_bert_batch_from_txt, logprobs_from_logits, whiten, clip_by_value, entropy_from_logits, flatten_dict, stats_to_np, stack_dicts
from utils import get_classifier, generate_next, concat_past, expand_past, read_file
from trigger_semi_supervised import penalize_new_line, prep_inputs


In [2]:
mname = 'facebook/blenderbot-400M-distill'
model = BlenderbotForConditionalGeneration.from_pretrained(mname)
tokenizer = BlenderbotTokenizer.from_pretrained(mname)
model.to("cuda")
model.eval()

BlenderbotForConditionalGeneration(
  (model): BlenderbotModel(
    (shared): Embedding(8008, 1280, padding_idx=0)
    (encoder): BlenderbotEncoder(
      (embed_tokens): Embedding(8008, 1280, padding_idx=0)
      (embed_positions): BlenderbotLearnedPositionalEmbedding(128, 1280, padding_idx=0)
      (layers): ModuleList(
        (0): BlenderbotEncoderLayer(
          (self_attn): BlenderbotAttention(
            (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1280, out_features=5120, bias=True)
          (fc2): Linear(in_features=5120, out_features=1280, bias=True)
          (final_layer_norm): LayerNorm(

In [3]:
# sampling
model.config.do_sample = True
model.config.num_beams = 1

In [4]:
device = "cuda"

# for blender:
num_beam_groups = 1
do_sample = False

pad_token_id = 0
bos_token_id = 1
eos_token_id = 2

tgt_label = 0,  # 0 for not_ok, 1 for ok

batch_size = 16  # should be the same as forward_batch_size
num_of_triggers = 1
trigger_format = "key_value"
reset_pos_emb = True
TRIGGER_POSITION_ID = 0

adam_epsilon = 1e-8
learning_rate = 2e-4

model.model.encoder.reset_pos_emb = reset_pos_emb
model.model.encoder.num_of_triggers = num_of_triggers

# WARNING: need to change 
sample = True
top_k = 10
temperature = 1.0
repetition_penalty = 1.0
length = 40
gradient_accumulation_steps = 1
max_grad_norm = 1.0

seed = 2

if num_of_triggers > 1:
    assert False, "currently not supported! This is hard coded in BlenderbotEncoder for now!"
if not reset_pos_emb:
    assert False, "currently not supported! This is hard coded in BlenderbotEncoder for now!"


In [5]:
save_path = "/mnt//trigger_experiments/semi_safety_2_ft"
exp_name = "semi_safety"
proj_name = "final_2_ft"
cls_model_name = "/mnt//trigger_experiments/roberta_bbf_bad_ctx"
discrim_name = "safety"
cls_max_length = 128
# WARNING: should change input
training_data = "data/trigger_bad_train.txt"

total_steps = 15360 # changed from 30720 for finetune
epoch_batch_size = 256  # changed from 512 for finetune

finetune_dev = True
if finetune_dev:
    training_data = "data/trigger_decode_human-bot.txt"
    finetune_init_ckeckpoint = "/mnt//trigger_experiments/semi_safety_2/e36.pt"

context_list = read_file(training_data)


prompt_reward = False
c_p_reward_weight = 0.2

use_wandb = True

mode = "train"
shuffle_data = True

print("WARNING: Training mode with shuffle_data = %s" % shuffle_data)




In [6]:
classifier, _ = get_classifier(discrim_name, class_label=0, device=device)

ce_loss = nn.CrossEntropyLoss()


In [7]:
if use_wandb:
    wandb.init(name=exp_name, project=proj_name)


[34m[1mwandb[0m: Currently logged in as: [33m[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [8]:
# loading pretrained model for classification
cls_model = AutoModelForSequenceClassification.from_pretrained(cls_model_name)
cls_tokenizer = AutoTokenizer.from_pretrained(cls_model_name)
cls_model.to(device)
cls_model.eval()

RobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerN

In [9]:
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)


# Freeze GPT-2 weights
for name, param in model.named_parameters():
    param.requires_grad = False

num_enc_layers = model.config.encoder_layers
num_dec_layers = model.config.decoder_layers
        
    
# lm_bos_output = model(torch.tensor(tokenizer.encode(tokenizer.bos_token), dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1))  # BOS
# # Note: GPT2HeadWithValueModel returns lm_logits, transformer_outputs[1:], value
# # transformer_outputs: hidden_states, past_key_values



In [10]:
# prepare for triggers

# get_bos_embeddings
bos_embeddings = model.model.encoder.embed_tokens(torch.tensor([bos_token_id], dtype=torch.long, device=device)).unsqueeze(0)  # 1, 1, hid_size

# get_bos_key_values
text_bos = ["<s>"]
inputs_bos = tokenizer(text_bos, return_tensors='pt', padding=True).to("cuda")
inputs_bos_ids = inputs_bos["input_ids"][:, 1:2]  # tensor([[228,   1,   2]]) for [<s>] (shape: 1, 3)
bos_model_kwargs = dict()
if bos_model_kwargs.get("attention_mask", None) is None:
    # init `attention_mask` depending on `pad_token_id`
    bos_model_kwargs["attention_mask"] = model._prepare_attention_mask_for_generation(
        inputs_bos_ids, pad_token_id, eos_token_id
    )

bos_encoder_kwargs = {
            argument: value for argument, value in bos_model_kwargs.items() if not argument.startswith("decoder_")
        }
bos_output = model.model.encoder(inputs_bos_ids, return_dict=True, **bos_encoder_kwargs, use_cache=True)
bos_key_values = bos_output["past_key_values"]
bos_hidden = bos_output["last_hidden_state"]  # 1, 1, 1280
print(bos_hidden.shape)
print(bos_key_values[0][0].shape)

torch.Size([1, 1, 1280])
torch.Size([1, 32, 1, 40])


In [11]:
# initialize trigger
# Note: since we use the same trigger for all inputs in a batch, we only create/register trigger(s) for one and repeat it
def init_trigger(model, tokenizer, num_of_triggers, trigger_format, ref=False):
    if num_of_triggers > 0:
        
        # create hidden states for decoder
        trigger_hidden_list = []
        for _ in range(num_of_triggers):
            trigger_hidden_i = nn.Parameter(copy.deepcopy(bos_hidden))
            trigger_hidden_list.append(trigger_hidden_i)
        if not ref:
            ori_trigger_hidden = nn.Parameter(torch.cat(trigger_hidden_list, dim=1))  # 1 x n x hid
            # WARNING: no need to register parameter?
            model.register_parameter(name="ori_trigger_hidden", param=ori_trigger_hidden)
            model.ori_trigger_hidden = ori_trigger_hidden
        else:
            ref_ori_trigger_hidden = nn.Parameter(torch.cat(trigger_hidden_list, dim=1))  # 1 x n x hid
            ref_ori_trigger_hidden.requires_grad = False
            model.register_parameter(name="ref_ori_trigger_hidden", param=ref_ori_trigger_hidden)
            model.ref_ori_trigger_hidden = ref_ori_trigger_hidden
            
        if trigger_format == "token":  # learn a continuous embedding
            trigger_embedding_list = []
            for _ in range(num_of_triggers):
                trigger_embedding_i = copy.deepcopy(bos_embeddings)
                trigger_embedding_list.append(trigger_embedding_i)
            if not ref:
                ori_trigger_embedding = nn.Parameter(torch.cat(trigger_embedding_list, dim=1))  # bze x n x emb_size
                model.ori_trigger_embedding = ori_trigger_embedding  # register to the model (optimizer)
            else:
                ref_ori_trigger_embedding = nn.Parameter(torch.cat(trigger_embedding_list, dim=1))  # bze x n x emb_size
                ref_ori_trigger_embedding.requires_grad = False
                model.ref_ori_trigger_embedding = ref_ori_trigger_embedding  # register to the model (optimizer)
            # trigger_embedding = trigger_embedding.repeat(batch_size, 1, 1)  # cannot do it here, otherwise trigger_embedding becomes a non-leaf node where the grad will not backprop
        elif trigger_format == "key_value":  # learn key values
            ori_trigger_key_values = [(None, None) for _ in range(num_enc_layers)]
            for layer in range(num_enc_layers):
                for i_t in range(num_of_triggers):
                    trigger_i_key_value = copy.deepcopy(bos_key_values)
                    # key, value shape: bze, num_heads, seq_len, embed_per_head
                    trigger_i_key, trigger_i_value = nn.Parameter(trigger_i_key_value[layer][0]), \
                                                     nn.Parameter(trigger_i_key_value[layer][1])

                    if not ref:
                        trigger_i_key.requires_grad = True
                        trigger_i_value.requires_grad = True
                    else:
                        trigger_i_key.requires_grad = False
                        trigger_i_value.requires_grad = False
                        
                    if ori_trigger_key_values[layer][0] is None:
                        ori_trigger_key_values[layer] = (trigger_i_key, trigger_i_value)
                    else:
                        # if multiple triggers
                        trigger_key = nn.Parameter(torch.cat((ori_trigger_key_values[layer][0], trigger_i_key), dim=-2))
                        trigger_value = nn.Parameter(torch.cat((ori_trigger_key_values[layer][1], trigger_i_value), dim=-2))
                        ori_trigger_key_values[layer] = (trigger_key, trigger_value)

                if not ref:
                    # register parameter into optimizer
                    key_name = "l_%d_key" % layer
                    value_name = "l_%d_value" % layer
                else:
                    key_name = "ref_l_%d_key" % layer
                    value_name = "ref_l_%d_value" % layer
                    
                if num_of_triggers == 1:
                    model.register_parameter(name=key_name, param=trigger_i_key)
                    model.register_parameter(name=value_name, param=trigger_i_value)
                else:
                    model.register_parameter(name=key_name, param=trigger_key)
                    model.register_parameter(name=value_name, param=trigger_value)
                    
            if not ref:
                ori_trigger_key_values = tuple(ori_trigger_key_values)
                model.ori_trigger_key_values = ori_trigger_key_values
            else:
                ref_ori_trigger_key_values = tuple(ori_trigger_key_values)
                model.ref_ori_trigger_key_values = ori_trigger_key_values
            # trigger_key_values = expand_past(trigger_key_values, num_layers, batch_size)  # similar to trigger_embedding, need leaf level grad
        else:
            assert False, "trigger_format: %s not supported" % trigger_format

In [12]:
init_trigger(model, tokenizer, num_of_triggers, trigger_format)

# optimizer
param_optimizer = list(filter(lambda p: p[1].requires_grad, list(model.named_parameters())))

# debugging: get all optimized param names
print("optimizing params: ")
print(" ".join(o[0] for o in param_optimizer))

no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {
        'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.0,
    },
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters,
                  lr=learning_rate,
                  eps=adam_epsilon)

optimizing params: 
ori_trigger_hidden l_0_key l_0_value l_1_key l_1_value


In [13]:
if finetune_dev:
    ft_saved_dict = torch.load(finetune_init_ckeckpoint)
    model.ori_trigger_hidden = ft_saved_dict["ori_trigger_hidden"]
    model.ori_trigger_key_values = ft_saved_dict["ori_trigger_key_values"]
    print("WARNING: fine-tuning from %s" % finetune_init_ckeckpoint)
    print("total training steps: %d" % total_steps)

total training steps: 15360


In [14]:
if use_wandb:
    wandb.watch(model, log='all')

In [15]:
# get probability distribution warper
logits_warper = model._get_logits_warper(
    top_k=model.config.top_k, top_p=model.config.top_p, temperature=model.config.temperature, num_beams=model.config.num_beams
)

# WARNING: use hyperparameters from the config instead of the following
logits_processor = model._get_logits_processor(
    repetition_penalty=model.config.repetition_penalty,
    no_repeat_ngram_size=model.config.no_repeat_ngram_size,
    bad_words_ids=None,
    min_length=model.config.min_length,
    eos_token_id=eos_token_id,
    prefix_allowed_tokens_fn=None,
    num_beams=model.config.num_beams,
    num_beam_groups=model.config.num_beam_groups,
    diversity_penalty=model.config.diversity_penalty,
)


if model.config.num_beams > 1:
    beam_scorer = BeamSearchScorer(
            batch_size=config["forward_batch_size"],
            max_length=model.config.max_length,
            num_beams=model.config.num_beams,
            device=device,
            length_penalty=model.config.length_penalty,
            do_early_stopping=model.config.early_stopping,
            num_beam_hyps_to_keep=1,
        )

In [16]:
def generate_sentence_with_trigger(text_list, num_layers, cur_num_of_triggers, is_response=False, use_gumbel=False, get_ppl=False):
    # cur_num_of_triggers: different from "num_of_triggers" in the config, can be 0 if is ref or num_of_triggers
    batch_size = len(text_list)
    
    # prepare past
    past = expand_past(bos_key_values, num_layers, batch_size)
    if cur_num_of_triggers > 0:
        if trigger_format == "token":
            trigger_embedding = model.ori_trigger_embedding.repeat(batch_size, 1, 1)
            lm_trigger_output = model.model.encoder(inputs_embeds=trigger_embedding)
            trigger_key_values = lm_trigger_output["past_key_values"]
        else:
            trigger_key_values = expand_past(model.ori_trigger_key_values, num_layers, batch_size)
        past = concat_past(past, trigger_key_values, num_layers)
        
    # prepare hidden
    prev_hidden = bos_hidden.repeat(batch_size, 1, 1)
    if cur_num_of_triggers > 0:
        trigger_hidden = model.ori_trigger_hidden
        trigger_hidden = trigger_hidden.repeat(batch_size, 1, 1)
        prev_hidden = torch.cat((prev_hidden, trigger_hidden), dim=1)  # bze, seq_len, hid
    
    # prepare context
    prev_length = prev_hidden.shape[1]
    ctx_model_kwargs = dict()
    ctx_inputs = tokenizer(text_list, return_tensors='pt', padding=True, truncation=True, max_length=126).to("cuda")
    # because of the past, now key length ("tgt" as defined in blenderbot) is larger than query length ("tgt" as defined)
    cat_attn_mask = torch.cat((torch.ones(ctx_inputs["attention_mask"].shape[0], prev_length, device="cuda", dtype=torch.long), ctx_inputs["attention_mask"]), dim=-1)
    ctx_model_kwargs["attention_mask"] = cat_attn_mask
    
    # get encoder output
    trigger_encoder_kwargs = {
            argument: value for argument, value in ctx_model_kwargs.items() if not argument.startswith("decoder_")
        }
    trigger_encoder_kwargs["past_key_values"] = past
    try:
        ctx_output = model.model.encoder(ctx_inputs["input_ids"], return_dict=True, **trigger_encoder_kwargs, is_trigger=True)
    except:
        print(ctx_inputs["input_ids"].shape)
        assert False
        
    ctx_output["last_hidden_state"] = torch.cat((prev_hidden, ctx_output["last_hidden_state"]), dim=1)

    ctx_model_kwargs["encoder_outputs"] = ctx_output
    
    # generate one sentence with trigger
    ctx_input_ids = ctx_inputs['input_ids']
    dec_input_ids = model._prepare_decoder_input_ids_for_generation(
                    ctx_input_ids, decoder_start_token_id=bos_token_id, bos_token_id=bos_token_id)
     
    is_greedy_gen_mode = (model.config.num_beams == 1) and (model.config.num_beam_groups == 1) and model.config.do_sample is False
    is_sample_gen_mode = (model.config.num_beams == 1) and (model.config.num_beam_groups == 1) and model.config.do_sample is True
    is_beam_gen_mode = (model.config.num_beams > 1) and (model.config.num_beam_groups == 1) and model.config.do_sample is False
    is_beam_sample_gen_mode = (model.config.num_beams > 1) and (model.config.num_beam_groups == 1) and model.config.do_sample is True
    
    return_dict_in_generate = model.config.return_dict_in_generate
    output_hidden_states = False
    if is_response:
        return_dict_in_generate = True
        output_hidden_states = True
    if use_gumbel:
        return_dict_in_generate = True
    
    output_scores = False
    if get_ppl:
        output_scores = True
        return_dict_in_generate = True
        
    if is_greedy_gen_mode:
        res = model.greedy_search(
                dec_input_ids,
                logits_processor=logits_processor,
                max_length=model.config.max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=False,
                return_dict_in_generate=return_dict_in_generate,
                output_hidden_states=output_hidden_states,
                **ctx_model_kwargs,
            )
        
    elif is_sample_gen_mode:

        # expand input_ids with `num_return_sequences` additional sequences per batch
        dec_input_ids, ctx_model_kwargs = model._expand_inputs_for_generation(
            dec_input_ids,
            expand_size=model.config.num_return_sequences,
            is_encoder_decoder=True,
            **ctx_model_kwargs,
        )
        

        # sample
        res = model.sample(
            dec_input_ids,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            max_length=model.config.max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            output_hidden_states=output_hidden_states,
            use_gumbel=use_gumbel,
            **ctx_model_kwargs,
        )
    elif is_beam_gen_mode:
        # interleave with `num_beams`
        dec_input_ids, ctx_model_kwargs = model._expand_inputs_for_generation(
            dec_input_ids, expand_size=model.config.num_beams, is_encoder_decoder=True, **ctx_model_kwargs
        )
        res = model.beam_search(
            dec_input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            max_length=model.config.max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=False,
            return_dict_in_generate=return_dict_in_generate,
            output_hidden_states=output_hidden_states,
            **ctx_model_kwargs,
        ) 
    elif is_beam_sample_gen_mode:
        # interleave with `num_beams * num_return_sequences`
        dec_input_ids, ctx_model_kwargs = model._expand_inputs_for_generation(
            dec_input_ids, expand_size=model.config.num_beams * model.config.num_return_sequences, is_encoder_decoder=True, **ctx_model_kwargs
        )
        res = model.beam_sample(
                dec_input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                max_length=model.config.max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                output_hidden_states=output_hidden_states,
                **ctx_model_kwargs,
            )
    
    if use_gumbel:
        generated_sentence_raw = tokenizer.batch_decode(res.sequences)  
        generated_sentence_clean = clean_blender_generation(generated_sentence_raw)
        gumbel_vectors = res.gumbel_vectors
        return generated_sentence_clean, gumbel_vectors
    elif not is_response:  # generating prompts
        if get_ppl:
            generated_sentence_raw = tokenizer.batch_decode(res.sequences)  
            generated_sentence_clean = clean_blender_generation(generated_sentence_raw)
            
            generated_sentence_mask = res.sequences.ne(pad_token_id).long()[:, :-2]  # the first one is bos
            logits_tensor = torch.cat([raw_logits.unsqueeze(1) for raw_logits in res.scores], dim=1)  # bze x len x vocab
            shift_logits = logits_tensor[..., :-1, :].contiguous()
            shift_labels = res.sequences[..., 1:-1].contiguous()
            loss_fct = nn.CrossEntropyLoss(reduction="none")

            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).detach()
            loss_reshape = loss.view(generated_sentence_mask.shape)  # bze x seq_len
            # if loss is inf, then the masked loss (after multipling mask) will be nan
            loss_reshape = torch.where(loss_reshape > 1e10, torch.ones_like(loss_reshape) * 0, loss_reshape)
            masked_loss_sum = torch.sum(loss_reshape * generated_sentence_mask, dim=-1)  # [bze]
            real_length = torch.sum(generated_sentence_mask, dim=-1)
            masked_loss = torch.mean(masked_loss_sum / real_length).item()
            ppl = math.exp(masked_loss)
            return generated_sentence_clean, ppl
        else:   
            generated_sentence_raw = tokenizer.batch_decode(res)  
            generated_sentence_clean = clean_blender_generation(generated_sentence_raw)
            return generated_sentence_clean
    else:
        generated_sentence_raw = tokenizer.batch_decode(res.sequences)  
        generated_sentence_clean = clean_blender_generation(generated_sentence_raw)
        all_hidden_states = res.decoder_hidden_states  # tuple of hidden states (bze, 1, hid)
        all_last_hidden_states_list = [i[-1] for i in all_hidden_states]
        hidden_states = torch.cat(all_last_hidden_states_list, dim=1)  # bze, seq_len, hid
        
        generated_sentence_mask = res.sequences.ne(pad_token_id).long()[:, :-1]
        
        return generated_sentence_clean, hidden_states, generated_sentence_mask
        

In [17]:
def clean_blender_generation(raw_texts):
    clean_texts = list()
    for sentence_i in raw_texts:
        sentence_i_0 = sentence_i.split("<s>")[-1]
        sentence_i_1 = sentence_i_0.split("</s>")[0]
        clean_texts.append(sentence_i_1.strip())
    return clean_texts

In [18]:
def convert_cls_examples_to_features(texts_a, texts_b, max_length):
    all_cls_input_ids, all_cls_attention_mask = list(), list()
    for text_a, text_b in zip(texts_a, texts_b):
        cls_inputs = cls_tokenizer.encode_plus(text_a, text_b, add_special_tokens=True, max_length=max_length, truncation=True)
        cls_input_ids = cls_inputs["input_ids"]
        cls_attention_mask = [1] * len(cls_input_ids)
        
        padding_length = max_length - len(cls_input_ids)
        
        cls_input_ids = cls_input_ids + ([cls_tokenizer.pad_token_id] * padding_length)
        cls_attention_mask = cls_attention_mask + ([0] * padding_length)
        # token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)  # not used in RoBERTa
        
        all_cls_input_ids.append(cls_input_ids)
        all_cls_attention_mask.append(cls_attention_mask)
    
    all_cls_input_tensors = torch.tensor(all_cls_input_ids, dtype=torch.long, device=device)
    all_cls_attention_mask_tensors = torch.tensor(all_cls_attention_mask, dtype=torch.long, device=device)
    
    return all_cls_input_tensors, all_cls_attention_mask_tensors
    

In [19]:
# # implement gumbel softmax here
# def generate_with_gumbel(c_texts, gumbel_vectors_tuple, prompt_texts, c_p_texts):
#     past = expand_past(bos_key_values, num_enc_layers, mini_batch_size)  # deep copy? shouldn't be modifed
#     # assume 1 trigger, key_value
#     past = concat_past(past, past, num_enc_layers)
    
#     # prepare hidden
#     prev_hidden = bos_hidden.repeat(mini_batch_size, 1, 1)
#     prev_hidden = torch.cat((prev_hidden, prev_hidden), dim=1)
    
#     # prepare context
#     prev_length = prev_hidden.shape[1]
#     ctx_model_kwargs = dict()
#     ctx_inputs = tokenizer(c_texts, return_tensors='pt', padding=True, truncation=True, max_length=126).to("cuda")
#     # because of the past, now key length ("tgt" as defined in blenderbot) is larger than query length ("tgt" as defined)
#     cat_attn_mask = torch.cat((torch.ones(ctx_inputs["attention_mask"].shape[0], prev_length, device="cuda", dtype=torch.long), ctx_inputs["attention_mask"]), dim=-1)
#     ctx_model_kwargs["attention_mask"] = cat_attn_mask
    
#     # get encoder output
#     trigger_encoder_kwargs = {
#             argument: value for argument, value in ctx_model_kwargs.items() if not argument.startswith("decoder_")
#         }
#     trigger_encoder_kwargs["past_key_values"] = past
#     ctx_output = model.model.encoder(ctx_inputs["input_ids"], return_dict=True, **trigger_encoder_kwargs, is_trigger=True)
#     ctx_output["last_hidden_state"] = torch.cat((prev_hidden, ctx_output["last_hidden_state"]), dim=1)
#     ctx_model_kwargs["encoder_outputs"] = ctx_output
    
#     # prepare decoder
#     prompt_inputs = tokenizer(p_texts, return_tensors='pt', padding=True, truncation=True).to("cuda")
#     prompt_inputs_ids = prompt_inputs["input_ids"]
#     prompt_attn_mask = prompt_inputs["attention_mask"]
    
#     gumbel_vectors_tensor = torch.cat(gumbel_vectors_tuple, dim=1)  # should be bze x seq_len x vocab_size
#     assert gumbel_vectors_tensor.shape[1] == prompt_inputs_ids.shape[1], "gumbel vector shape: %s; prompt inputs shape: %s" % (str(gumbel_vectors_tensor.shape), str(prompt_inputs_ids.shape))
#     prompt_inputs_emb = torch.matmul(gumbel_vectors_tensor, model.model.decoder.embed_tokens.weight.shape)  # bze x seq_len x hid
    
    
#     # add bos
#     dec_bos_ids = torch.ones((prompt_inputs_ids.shape[0], 1), dtype=torch.long, device=device) * bos_token_id
#     dec_bos_mask = torch.ones((prompt_inputs_ids.shape[0], 1), dtype=torch.long, device=device)
#     dec_inputs_ids = torch.cat((dec_bos_ids, prompt_inputs_ids), dim=1)
#     dec_attn_mask = torch.cat((dec_bos_mask, prompt_attn_mask), dim=1)
#     prompt_length = torch.sum(dec_attn_mask, dim=-1)  # including bos and eos. shape: [bze]
    
#     model_inputs = {"decoder_input_ids": dec_inputs_ids, "encoder_outputs": ctx_model_kwargs["encoder_outputs"],
#                     "attention_mask": ctx_model_kwargs["attention_mask"]}
#     outputs = model(**model_inputs, return_dict=True)

#     hidden_states = outputs[]
    
    
    

    
    

In [20]:
EPSILON = 1e-10
def get_representation(r_hidden, r_mask):  
    cur_batch_size, hidden_size = r_hidden.shape[0], r_hidden.shape[-1]
    r_mask = r_mask.unsqueeze(2).repeat(1, 1, hidden_size)
    masked_hidden = r_hidden * r_mask
    
    avg_hidden = torch.sum(masked_hidden, dim=1) / (torch.sum(r_mask, dim=1).detach() + EPSILON)
    return avg_hidden

In [21]:
for epoch in tqdm(range(total_steps // epoch_batch_size)):
    print("***********Epoch: %d/%d*************" % (epoch + 1, int(np.ceil(total_steps / epoch_batch_size))))
    torch.cuda.empty_cache()
    logs = dict()
    game_data = dict()
    timing = dict()
    t0 = time.time()
    
    if mode == "train" and shuffle_data:
        random.shuffle(context_list)
    cond_list = context_list[:epoch_batch_size]
    
    log_context, log_prompt, log_response = list(), list(), list()
    all_rewards = list()
    all_ppl = list()
    all_c_p_rewards = list()
    
    # check real reward:
    for i in range(int(epoch_batch_size / batch_size)):
        ctx_i = cond_list[i*batch_size:(i+1)*batch_size]
        log_context += ctx_i
        
        p_texts, p_ppl = generate_sentence_with_trigger(ctx_i, num_enc_layers, num_of_triggers, get_ppl=True)
        log_prompt += p_texts
        all_ppl.append(p_ppl)
        
        c_p_texts = list()
        for c, p in zip(ctx_i, p_texts):
            c_p_texts.append("%s   %s" % (c, p))
        
        c_p_inputs = tokenizer(c_p_texts, return_tensors='pt', padding=True, truncation=True).to(device)
        try:
            r_tensor = model.generate(c_p_inputs['input_ids'], num_beams=model.config.num_beams, do_sample=model.config.do_sample)
        except Exception as e:
            print(c_p_inputs["input_ids"].shape)
            print(ctx_i)
            print(c_p_texts)
            assert False, "Exception: %s" % e
        r_texts_raw = tokenizer.batch_decode(r_tensor)
        r_texts = clean_blender_generation(r_texts_raw)
        log_response += r_texts
        
         # run classifier for rewards        
        cls_c_p_r_inputs, cls_c_p_r_mask = convert_cls_examples_to_features(r_texts, c_p_texts, cls_max_length)
        cls_c_p_inputs, cls_c_p_mask = convert_cls_examples_to_features(p_texts, ctx_i, cls_max_length)

        with torch.no_grad():
            res = cls_model(cls_c_p_r_inputs, cls_c_p_r_mask)["logits"][:, tgt_label].detach() 
            c_p_res = cls_model(cls_c_p_inputs, cls_c_p_mask)["logits"][:, tgt_label].detach() 

        all_c_p_rewards.append(c_p_res)
        all_rewards.append(res)  # [bze]
        
    # logging
    log_name = "game_log_e%d" % (epoch + 1)
    log_rewards = torch.cat(all_rewards)
    log_c_p_rewards = torch.cat(all_c_p_rewards)
    log_ppl = sum(all_ppl) / len(all_ppl)
    table_rows = [list(r) for r in zip(log_context, log_prompt, log_response, log_rewards.cpu().tolist())]
    logs.update({log_name:wandb.Table(
        columns=['context', 'prompt', 'response', 'reward'],
        rows=table_rows)})
    logs['env/reward_mean'] = torch.mean(log_rewards).cpu().numpy()
    logs['env/reward_std'] = torch.std(log_rewards).cpu().numpy()
    logs['env/c_p_reward_mean'] = torch.mean(log_c_p_rewards).cpu().numpy()
    logs['env/p_ppl'] = log_ppl
    
    if use_wandb:
        wandb.log(logs)
    
    model.zero_grad()
     
    # real training
    for i in range(int(epoch_batch_size / batch_size)):
        loss_per_update = 0
        total_loss = 0
        
        ctx_i = cond_list[i*batch_size:(i+1)*batch_size]
        
#         p_texts, p_gumbel_vectors = generate_sentence_with_trigger(ctx_i, num_enc_layers, num_of_triggers, use_gumbel=True)
        p_texts = generate_sentence_with_trigger(ctx_i, num_enc_layers, num_of_triggers)
        log_prompt += p_texts
        
        c_p_texts = list()
        for c, p in zip(ctx_i, p_texts):
            c_p_texts.append("%s   %s" % (c, p))
        
        # without gumbel softmax. Need to run generate_with_gumbel to use gumbel softmax
        r_texts, r_hidden, r_mask = generate_sentence_with_trigger(c_p_texts, num_enc_layers, num_of_triggers, is_response=True)    
        
        # r_hidden_mask?????????
#         with torch.no_grad():
        avg_hidden = get_representation(r_hidden, r_mask)
        prediction = classifier(avg_hidden)
        # WARNING: There is a weird bug: used to be tgt_label, which somehow changed to (0,). so use [0] directly here
        label = torch.tensor([0], device=device, dtype=torch.long).repeat(batch_size)
        discrim_loss = ce_loss(prediction, label)
        
        loss_per_update += discrim_loss.item()
        
        discrim_loss.backward()
        
#         print("Debuggin current key_value")
#         print(model.l_1_value[:, 0, :, :10])
#         print(model.ori_trigger_hidden[:, :, :10])
#         print("++++++++++++\n\n\n")
        
        if (i + 1) % gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()
            model.zero_grad()
#             print("\n=======update loss: %.6f=======" % (loss_per_update / gradient_accumulation_steps))
            total_loss += loss_per_update
            loss_per_update = 0
    
    
    
    # save trigger
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    save_filename = "%s/e%d.pt" % (save_path, epoch + 1)
    save_data = dict()
    save_data["ori_trigger_hidden"] = model.ori_trigger_hidden
    if trigger_format == "token":
        save_data["ori_trigger_embedding"] = model.ori_trigger_embedding
    else:
        save_data["ori_trigger_key_values"] = model.ori_trigger_key_values
    torch.save(save_data, save_filename)    
        

  0%|          | 0/60 [00:00<?, ?it/s]

***********Epoch: 1/60*************


  2%|▏         | 1/60 [02:10<2:07:55, 130.10s/it]

***********Epoch: 2/60*************


  3%|▎         | 2/60 [04:24<2:07:56, 132.36s/it]

***********Epoch: 3/60*************


  5%|▌         | 3/60 [06:33<2:04:37, 131.19s/it]

***********Epoch: 4/60*************


  7%|▋         | 4/60 [08:49<2:04:02, 132.90s/it]

***********Epoch: 5/60*************


  8%|▊         | 5/60 [10:59<2:00:45, 131.74s/it]

***********Epoch: 6/60*************


 10%|█         | 6/60 [13:10<1:58:30, 131.68s/it]

***********Epoch: 7/60*************


 12%|█▏        | 7/60 [15:23<1:56:38, 132.05s/it]

***********Epoch: 8/60*************


 13%|█▎        | 8/60 [17:36<1:54:49, 132.50s/it]

***********Epoch: 9/60*************


 15%|█▌        | 9/60 [19:49<1:52:32, 132.39s/it]

***********Epoch: 10/60*************


 17%|█▋        | 10/60 [22:00<1:50:04, 132.09s/it]

***********Epoch: 11/60*************


 18%|█▊        | 11/60 [24:14<1:48:17, 132.60s/it]

***********Epoch: 12/60*************


 20%|██        | 12/60 [26:30<1:47:03, 133.81s/it]

***********Epoch: 13/60*************


 22%|██▏       | 13/60 [28:39<1:43:37, 132.28s/it]

***********Epoch: 14/60*************


 23%|██▎       | 14/60 [30:57<1:42:40, 133.92s/it]

***********Epoch: 15/60*************


 25%|██▌       | 15/60 [33:10<1:40:20, 133.78s/it]

***********Epoch: 16/60*************


 27%|██▋       | 16/60 [35:20<1:37:18, 132.69s/it]

***********Epoch: 17/60*************


 28%|██▊       | 17/60 [37:25<1:33:26, 130.39s/it]

***********Epoch: 18/60*************


 30%|███       | 18/60 [39:47<1:33:33, 133.64s/it]

***********Epoch: 19/60*************


 32%|███▏      | 19/60 [42:12<1:33:45, 137.20s/it]

***********Epoch: 20/60*************


 33%|███▎      | 20/60 [44:24<1:30:29, 135.73s/it]

***********Epoch: 21/60*************


 35%|███▌      | 21/60 [46:38<1:27:44, 134.98s/it]

***********Epoch: 22/60*************


 37%|███▋      | 22/60 [48:49<1:24:44, 133.82s/it]

***********Epoch: 23/60*************


 38%|███▊      | 23/60 [51:01<1:22:11, 133.27s/it]

***********Epoch: 24/60*************


 40%|████      | 24/60 [53:09<1:19:07, 131.88s/it]

***********Epoch: 25/60*************


 42%|████▏     | 25/60 [55:19<1:16:35, 131.30s/it]

***********Epoch: 26/60*************


 43%|████▎     | 26/60 [57:31<1:14:23, 131.27s/it]

***********Epoch: 27/60*************


 45%|████▌     | 27/60 [59:42<1:12:13, 131.32s/it]

***********Epoch: 28/60*************


 47%|████▋     | 28/60 [1:01:48<1:09:09, 129.67s/it]

***********Epoch: 29/60*************


 48%|████▊     | 29/60 [1:03:53<1:06:14, 128.21s/it]

***********Epoch: 30/60*************


 50%|█████     | 30/60 [1:06:02<1:04:12, 128.43s/it]

***********Epoch: 31/60*************


 52%|█████▏    | 31/60 [1:08:03<1:01:06, 126.43s/it]

***********Epoch: 32/60*************


 53%|█████▎    | 32/60 [1:10:19<1:00:16, 129.15s/it]

***********Epoch: 33/60*************


 55%|█████▌    | 33/60 [1:12:24<57:37, 128.05s/it]  

***********Epoch: 34/60*************


 57%|█████▋    | 34/60 [1:14:28<54:58, 126.88s/it]

***********Epoch: 35/60*************


 58%|█████▊    | 35/60 [1:16:36<52:53, 126.94s/it]

***********Epoch: 36/60*************


 60%|██████    | 36/60 [1:18:40<50:30, 126.26s/it]

***********Epoch: 37/60*************


 62%|██████▏   | 37/60 [1:20:42<47:50, 124.79s/it]

***********Epoch: 38/60*************


 63%|██████▎   | 38/60 [1:22:41<45:08, 123.12s/it]

***********Epoch: 39/60*************


 65%|██████▌   | 39/60 [1:24:44<43:03, 123.04s/it]

***********Epoch: 40/60*************


 67%|██████▋   | 40/60 [1:26:55<41:47, 125.40s/it]

***********Epoch: 41/60*************


 68%|██████▊   | 41/60 [1:28:56<39:19, 124.19s/it]

***********Epoch: 42/60*************


 70%|███████   | 42/60 [1:31:16<38:39, 128.87s/it]

***********Epoch: 43/60*************


 72%|███████▏  | 43/60 [1:33:21<36:11, 127.71s/it]

***********Epoch: 44/60*************


 73%|███████▎  | 44/60 [1:35:14<32:54, 123.43s/it]

***********Epoch: 45/60*************


 75%|███████▌  | 45/60 [1:37:18<30:55, 123.69s/it]

***********Epoch: 46/60*************


 77%|███████▋  | 46/60 [1:39:21<28:48, 123.45s/it]

***********Epoch: 47/60*************


 78%|███████▊  | 47/60 [1:41:27<26:55, 124.24s/it]

***********Epoch: 48/60*************


 80%|████████  | 48/60 [1:43:45<25:38, 128.23s/it]

***********Epoch: 49/60*************


 82%|████████▏ | 49/60 [1:45:49<23:15, 126.83s/it]

***********Epoch: 50/60*************


 83%|████████▎ | 50/60 [1:47:50<20:51, 125.17s/it]

***********Epoch: 51/60*************


 85%|████████▌ | 51/60 [1:49:54<18:45, 125.01s/it]

***********Epoch: 52/60*************


 87%|████████▋ | 52/60 [1:51:58<16:36, 124.58s/it]

***********Epoch: 53/60*************


 88%|████████▊ | 53/60 [1:53:57<14:20, 122.96s/it]

***********Epoch: 54/60*************


 90%|█████████ | 54/60 [1:55:53<12:05, 120.85s/it]

***********Epoch: 55/60*************


 92%|█████████▏| 55/60 [1:58:00<10:13, 122.64s/it]

***********Epoch: 56/60*************


 93%|█████████▎| 56/60 [2:00:13<08:23, 125.89s/it]

***********Epoch: 57/60*************


 95%|█████████▌| 57/60 [2:02:25<06:22, 127.54s/it]

***********Epoch: 58/60*************


 97%|█████████▋| 58/60 [2:04:40<04:19, 129.86s/it]

***********Epoch: 59/60*************


 98%|█████████▊| 59/60 [2:06:47<02:08, 128.89s/it]

***********Epoch: 60/60*************


100%|██████████| 60/60 [2:08:58<00:00, 128.97s/it]


In [22]:
assert False

AssertionError: 

In [25]:
# load pre-trained model
saved_model_path = "/mnt//trigger_experiments/semi_safety/e60.pt"
saved_dict = torch.load(saved_model_path)
model.ori_trigger_hidden = saved_dict["ori_trigger_hidden"]
model.ori_trigger_key_values = saved_dict["ori_trigger_key_values"]
print("WARNING: Evaluating a saved model")




In [26]:
# Evaluation
softmax_fn = nn.Softmax(dim=-1)

import csv

csv_file = open("data/safety_train_all_semi_final_e60_1.csv", "w")



eval_context_filename = "data/trigger_bad_valid.txt"
eval_context_list = read_file(eval_context_filename)
print("evaluating %s" % eval_context_filename)
print("***********Evaluation at Epoch: %d/%d*************" % (epoch + 1, int(np.ceil(total_steps / epoch_batch_size))))



torch.cuda.empty_cache()
logs = dict()
game_data = dict()
timing = dict()
t0 = time.time()

#### get everything from the dataset
cond_list = eval_context_list

all_rewards, all_c_p_r_rewards, all_c_p_rewards, all_c_p_rewards_adjusted = list(), list(), list(), list()
all_probs, all_c_p_probs = list(), list()
log_context, log_prompt, log_response = list(), list(), list()
all_ppl = list()

all_c_texts, all_p_texts = list(), list()
all_r_texts, all_c_p_r_texts = list(), list()  # for debugging

#### get prompt from model
for i in tqdm(range(int(len(eval_context_list) / batch_size))):
    ctx_i = cond_list[i*batch_size:(i+1)*batch_size]

    log_context += ctx_i

    p_texts, p_ppl = generate_sentence_with_trigger(ctx_i, num_enc_layers, num_of_triggers, get_ppl=True)
    log_prompt += p_texts
    all_ppl.append(p_ppl)

    c_p_texts = list()
    for c, p in zip(ctx_i, p_texts):
        c_p_texts.append("%s   %s" % (c, p))

    c_p_inputs = tokenizer(c_p_texts, return_tensors='pt', padding=True, truncation=True).to(device)
    try:
        r_tensor = model.generate(c_p_inputs['input_ids'], num_beams=model.config.num_beams, do_sample=model.config.do_sample)
    except Exception as e:
        print(c_p_inputs["input_ids"].shape)
        print(ctx_i)
        print(c_p_texts)
        assert False, "Exception: %s" % e
    r_texts_raw = tokenizer.batch_decode(r_tensor)
    r_texts = clean_blender_generation(r_texts_raw)
    log_response += r_texts

    c_p_r_texts = list()
    for c_p, r in zip(c_p_texts, r_texts):
        c_p_r_texts.append("%s   %s" % (c_p, r))


    all_c_texts.append(ctx_i)
    all_p_texts.append(p_texts)
    all_r_texts.append(r_texts)
    all_c_p_r_texts.append(c_p_r_texts)


    # run classifier for rewards        
    cls_c_p_r_inputs, cls_c_p_r_mask = convert_cls_examples_to_features(r_texts, c_p_texts, cls_max_length)
    cls_c_p_inputs, cls_c_p_mask = convert_cls_examples_to_features(p_texts, ctx_i, cls_max_length)

    with torch.no_grad():
        all_logits = cls_model(cls_c_p_r_inputs, cls_c_p_r_mask)["logits"]
        res = all_logits[:, tgt_label].detach() 
        res_probs = softmax_fn(all_logits)[:, tgt_label].detach() 
        
        c_p_logits = cls_model(cls_c_p_inputs, cls_c_p_mask)["logits"]
        c_p_res = c_p_logits[:, tgt_label].detach() 
        c_p_res_probs = softmax_fn(c_p_logits)[:, tgt_label].detach()



    all_rewards.append(res)  # [bze]
    # if prompt_reward, all_probs is actually for c_p_r
    all_probs.append(res_probs)
    all_c_p_rewards.append(c_p_res)
    all_c_p_probs.append(c_p_res_probs)
    
log_name = "evaluation %s @e%d" % (eval_context_filename, epoch + 1)
log_rewards = torch.cat(all_rewards)
log_probs = torch.cat(all_probs)
log_c_p_rewards = torch.cat(all_c_p_rewards)
log_c_p_probs = torch.cat(all_c_p_probs)

log_ppl = sum(all_ppl) / len(all_ppl)   

table_rows = [list(r) for r in zip(log_context, log_prompt, log_response, log_rewards.cpu().tolist(), log_probs.cpu().tolist(), log_c_p_rewards.cpu().tolist(), log_c_p_probs.cpu().tolist(),)]

fieldnames = ['context', 'prompt', 'response', 'reward', 'probs', 'c_p_reward', 'c_p_probs']

logs['env/reward_mean'] = torch.mean(log_rewards).cpu().numpy()
logs['env/reward_std'] = torch.std(log_rewards).cpu().numpy()

logs['env/c_p_probs_mean'] = torch.mean(log_c_p_probs).cpu().numpy()
logs['env/c_p_probs_std'] = torch.std(log_c_p_probs).cpu().numpy()
    
logs['env/reward_prob_mean'] = torch.mean(log_probs).cpu().numpy()
logs['env/reward_prob_std'] = torch.std(log_probs).cpu().numpy()
logs['env/p_ppl'] = log_ppl


writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
writer.writeheader()
for row_list in table_rows:
    row_dict = dict()
    for row_name, row_item in zip(fieldnames, row_list):
        row_dict[row_name] = row_item
    writer.writerow(row_dict)

print(logs)
    



  0%|          | 0/186 [00:00<?, ?it/s]

evaluating data/trigger_bad_valid.txt
***********Evaluation at Epoch: 60/60*************


100%|██████████| 186/186 [10:56<00:00,  3.53s/it]


{'env/reward_mean': array(-0.7983979, dtype=float32), 'env/reward_std': array(1.3177015, dtype=float32), 'env/c_p_probs_mean': array(0.67617565, dtype=float32), 'env/c_p_probs_std': array(0.38747743, dtype=float32), 'env/reward_prob_mean': array(0.31896862, dtype=float32), 'env/reward_prob_std': array(0.3358653, dtype=float32), 'env/p_ppl': 17.783813842696706}
