In [1]:
import argparse
import logging
import random
import time
import os
from typing import List, Optional, Tuple, Union
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from transformers import AdamW, get_linear_schedule_with_warmup

from utils import get_classifier, generate_next, concat_past, expand_past


In [2]:
def read_file(filename):
    """Read a text file and return its lines with surrounding whitespace stripped.

    Args:
        filename: Path of the text file to read.

    Returns:
        list[str]: One entry per line, in file order. Leading/trailing whitespace
        (including the newline) is removed, so blank lines become empty strings.
    """
    # The context manager guarantees the handle is closed, even if reading raises.
    with open(filename) as handle:
        return [line.strip() for line in handle]

In [3]:
# def main():  # the following n cells

# ---- reproducibility / hardware ----
seed = 0  # fed to the torch and numpy seeding calls
device = "cuda"

# ---- language model and attribute classifier ----
pretrained_model = "gpt2-medium"
# sentiment: class_label = ["positive", "negative", "very positive", "very negative", "neutral"]
discrim = "sentiment"
class_label = 3  # forwarded to get_classifier and used as the cross-entropy target
num_of_triggers = 1

# ---- trigger representation ----
trigger_format = "key_value"  # "token" or "key_value" (the two branches of init_trigger)
TRIGGER_POSITION_ID = 0  # position id given to the trigger(s) when positions are reset

# ---- optimisation schedule ----
num_epochs = 1
num_iterations = 2  # iterations run on each data batch
learning_rate = 5e-3
# learning_rate = 0  # baseline (no real trigger)
gradient_accumulation_steps = 1

# ---- sampling behaviour (forwarded to generate_next) ----
sample = True
gumbel_softmax = True
detach = True
# ----------------------
reset_pos_emb = True  # supply explicit position ids instead of the model defaults
not_mask_trigger = False
gumbel_temperature = 1.0

# ---- batching ----
batch_size = 4
multiple_input = True

# ---- generation / optimiser settings that rarely change ----
top_k = 10
temperature = 1.0
repetition_penalty = 1.0
adam_epsilon = 1e-8
max_grad_norm = 1.0
length = 40  # upper bound on generated tokens per generation loop

# detach is only valid together with gumbel softmax
assert gumbel_softmax or not detach, "require gumbel softmax when using detach"

# ---- logging ----
verbose = True
check_real_loss = True

# ---- data paths ----
train_filename = "persona_train.txt"
eval_filename = "persona_eval.txt"

# ---- tokenizer constants (WARNING: GPT2 only) ----
new_line_idx = 198  # '\n'
new_line_idx_1 = 628  # '\n\n'
stop_token = "."


In [4]:
# Seed both RNG sources used by this notebook before anything stochastic runs.
np.random.seed(seed)
torch.manual_seed(seed)

# Load the pretrained LM on the target device. Hidden states are requested on every
# forward pass; eval mode is used for both training and evaluation of the trigger.
model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
model.to(device)
model.eval()

# Matching tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)

# Freeze every GPT-2 weight: parameters registered after this point are the only
# ones that will still require gradients.
model.requires_grad_(False)

classifier, class_id = get_classifier(discrim, class_label, device)

num_layers = model.config.n_layer

ce_loss = nn.CrossEntropyLoss()

# Forward pass over a batch in which every row holds just the BOS token ids
bos_ids = tokenizer.encode(tokenizer.bos_token)
bos_batch = torch.tensor([bos_ids] * batch_size, dtype=torch.long, device=device)
lm_bos_output = model(bos_batch)  # BOS

# WARNING: GPT2 only -- BOS doubles as the padding token
t_pad_token = tokenizer.bos_token


In [5]:
# cond_text_list = ["This is a terrible restaurant.", "I like drinking water.", "It is raining today.", "I'm doing my homework."]
# cond_text_list = ["I don't like this restaurant."]

In [6]:
# initialize trigger
# Note: since we use the same trigger for all inputs in a batch, we only create/register trigger(s) for one and repeat it
def init_trigger(model, tokenizer, num_of_triggers, trigger_format):
    """Create learnable trigger parameters initialised from the BOS token and attach them to ``model``.

    A single set of triggers is created (leading dimension 1); callers are expected to
    repeat/expand it per batch at use time so that the registered tensors stay leaf
    parameters and receive gradients.

    Args:
        model: Language model. It must expose ``config.n_layer`` and ``transformer.wte``
            and, when called on token ids, return a mapping with ``"past_key_values"``.
        tokenizer: Tokenizer exposing ``bos_token`` and ``encode``.
        num_of_triggers: Number of trigger slots. Nothing is created when this is <= 0.
        trigger_format: ``"token"`` learns one embedding per trigger, stored as
            ``model.ori_trigger_embedding``. ``"key_value"`` learns one key and one value
            tensor per layer, registered as ``l_<layer>_key`` / ``l_<layer>_value`` and
            also exposed as the tuple ``model.ori_trigger_key_values``.

    Raises:
        AssertionError: If ``trigger_format`` is neither ``"token"`` nor ``"key_value"``
            (only checked when ``num_of_triggers > 0``).
    """
    if num_of_triggers <= 0:
        return

    # Taken from the model itself so the function does not depend on notebook globals.
    num_layers = model.config.n_layer
    device = next(model.parameters()).device

    bos_ids = torch.tensor(tokenizer.encode(tokenizer.bos_token), dtype=torch.long, device=device).unsqueeze(0)

    if trigger_format == "token":  # learn a continuous embedding
        with torch.no_grad():
            bos_embedding = model.transformer.wte(bos_ids)
        # Independent copies of the BOS embedding, one per trigger, joined along the sequence axis.
        trigger_embedding_list = [bos_embedding.detach().clone() for _ in range(num_of_triggers)]
        # Assigning an nn.Parameter as a module attribute registers it, so it shows up
        # in model.named_parameters() for the optimizer.
        model.ori_trigger_embedding = nn.Parameter(torch.cat(trigger_embedding_list, dim=1))
    elif trigger_format == "key_value":  # learn key values
        with torch.no_grad():
            bos_key_values = model(bos_ids)["past_key_values"]

        ori_trigger_key_values = []
        for layer in range(num_layers):
            bos_key, bos_value = bos_key_values[layer][0], bos_key_values[layer][1]
            # Only this layer's tensors are copied (once per trigger); multiple triggers
            # are concatenated along the sequence axis (dim=-2).
            trigger_key = nn.Parameter(
                torch.cat([bos_key.detach().clone() for _ in range(num_of_triggers)], dim=-2))
            trigger_value = nn.Parameter(
                torch.cat([bos_value.detach().clone() for _ in range(num_of_triggers)], dim=-2))

            # register parameter into the model so the optimizer can pick it up
            model.register_parameter(name="l_%d_key" % layer, param=trigger_key)
            model.register_parameter(name="l_%d_value" % layer, param=trigger_value)

            ori_trigger_key_values.append((trigger_key, trigger_value))

        model.ori_trigger_key_values = tuple(ori_trigger_key_values)
    else:
        # Raised explicitly (not `assert False`) so the check survives `python -O`.
        raise AssertionError("trigger_format: %s not supported" % trigger_format)

In [7]:
init_trigger(model, tokenizer, num_of_triggers, trigger_format)

# optimizer: keep only the (name, parameter) pairs that still require gradients
param_optimizer = [
    (name, param) for name, param in model.named_parameters() if param.requires_grad
]

# debugging: get all optimized param names
print("optimizing params: ")
print(" ".join(name for name, _ in param_optimizer))

# Parameters are split by name into two groups; both groups currently use a
# weight decay of 0.0.
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {
        # names containing none of the no_decay markers
        'params': [p for n, p in param_optimizer if all(nd not in n for nd in no_decay)],
        'weight_decay': 0.0,
    },
    {
        # names containing at least one no_decay marker
        'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)


optimizing params: 
l_0_key l_0_value l_1_key l_1_value l_2_key l_2_value l_3_key l_3_value l_4_key l_4_value l_5_key l_5_value l_6_key l_6_value l_7_key l_7_value l_8_key l_8_value l_9_key l_9_value l_10_key l_10_value l_11_key l_11_value l_12_key l_12_value l_13_key l_13_value l_14_key l_14_value l_15_key l_15_value l_16_key l_16_value l_17_key l_17_value l_18_key l_18_value l_19_key l_19_value l_20_key l_20_value l_21_key l_21_value l_22_key l_22_value l_23_key l_23_value


In [8]:
def prep_inputs(batch_cond_text_list, tokenizer, device, t_pad_token):
    """Tokenise a batch of context strings and right-pad them to a common length.

    Each text is prefixed with ``tokenizer.bos_token`` before encoding. Shorter
    sequences are padded on the right with the first token id of ``t_pad_token``.

    Args:
        batch_cond_text_list: Iterable of context strings.
        tokenizer: Object exposing ``bos_token`` and ``encode(str) -> list[int]``.
        device: Device on which the returned tensors are created.
        t_pad_token: Token string whose first encoded id is used for padding.

    Returns:
        tuple: ``(all_input_ids, all_attention_masks, batch_min_length,
        batch_max_length, all_lengths)`` where the first two are long tensors of
        shape (batch, batch_max_length), the mask holds 1 for real tokens and 0 for
        padding, and ``all_lengths`` lists the unpadded length of every sequence.
    """
    pad_id = tokenizer.encode(t_pad_token)[0]  # WARNING: BOS is for GPT2 only. Should use padding token

    encoded = [tokenizer.encode(tokenizer.bos_token + text) for text in batch_cond_text_list]
    all_lengths = [len(ids) for ids in encoded]

    # Extremes are seeded with 0 and 10000 respectively, so an empty batch yields (10000, 0).
    batch_max_length = max([0] + all_lengths)
    batch_min_length = min([10000] + all_lengths)

    padded_rows, mask_rows = [], []
    for ids, n_tokens in zip(encoded, all_lengths):
        n_pad = batch_max_length - n_tokens
        padded_rows.append(ids + [pad_id] * n_pad)
        mask_rows.append([1] * n_tokens + [0] * n_pad)

    all_input_ids = torch.tensor(padded_rows, dtype=torch.long, device=device)
    all_attention_masks = torch.tensor(mask_rows, dtype=torch.long, device=device)

    return all_input_ids, all_attention_masks, batch_min_length, batch_max_length, all_lengths

In [9]:
def penalize_new_line(logits, new_line_tokens=None, factor=5):
    """Lower the last-step logits of new-line tokens, in place.

    For every batch row, the logit of each listed token at the last sequence position
    is multiplied by ``factor`` when it is negative and divided by ``factor`` otherwise,
    so the token becomes less likely in both cases.

    Args:
        logits: Float tensor indexed as (batch, sequence, vocabulary). Modified in place.
        new_line_tokens: Iterable of distinct vocabulary ids to penalise. Defaults to the
            notebook-level ``new_line_idx`` and ``new_line_idx_1``.
        factor: Penalty factor (default 5).

    Returns:
        The same ``logits`` tensor that was passed in.
    """
    if new_line_tokens is None:
        new_line_tokens = (new_line_idx, new_line_idx_1)

    token_idx = torch.as_tensor(list(new_line_tokens), dtype=torch.long, device=logits.device)

    # One vectorised update for the whole batch instead of a Python loop that
    # compares each element individually.
    selected = logits[:, -1, token_idx]  # batch x n_tokens (a copy)
    logits[:, -1, token_idx] = torch.where(selected < 0, selected * factor, selected / factor)
    return logits


In [10]:
def generate_prompt_response(model, tokenizer, mode, lm_bos_output, batch_size, device, class_label,
                             num_iterations, learning_rate, gradient_accumulation_steps, sample, gumbel_softmax,
                             detach, reset_pos_emb, not_mask_trigger, gumbel_temperature, top_k, temperature,
                             repetition_penalty, adam_epsilon, max_grad_norm, length, t_pad_token, stop_token,
                             num_epochs, context_list, verbose, check_real_loss,
                             seed,
                            ):
    """Generate a prompt and a response for each context batch and, in train mode, optimise the trigger.

    For every full batch of ``context_list`` and every iteration, the function:

    1. generates a "prompt" that continues each context until ``stop_token`` is produced
       (or ``length`` steps have elapsed, in which case the stop token is forced);
    2. in any mode other than ``"eval"``, generates a "response" after the prompt,
       averages the last-layer hidden states of the response, classifies them with the
       notebook-level ``classifier`` and back-propagates the cross-entropy loss against
       ``class_label``;
    3. clips gradients and steps the notebook-level ``optimizer`` every
       ``gradient_accumulation_steps`` iterations.

    Besides its arguments, the function reads these notebook-level names:
    ``num_of_triggers``, ``trigger_format``, ``TRIGGER_POSITION_ID``, ``num_layers``,
    ``optimizer``, ``classifier``, ``ce_loss``, ``prep_inputs``, ``penalize_new_line``,
    ``sample_gpt2`` and the ``utils`` helpers ``expand_past``, ``concat_past``,
    ``generate_next``.

    Args:
        model: Language model with the trigger parameters already attached.
        tokenizer: Tokenizer matching ``model``.
        mode: ``"train"`` or ``"eval"``. In eval mode a single epoch is run and only the
            context + prompt texts are produced (no response, no optimisation).
        lm_bos_output: Model output for a batch of BOS tokens; its ``"past_key_values"``
            seed every iteration.
        batch_size: Number of contexts per batch. Trailing contexts that do not fill a
            whole batch are ignored.
        device: Device for all tensors created here.
        class_label: Target class index for the discriminator loss.
        num_iterations: Iterations run on each batch.
        learning_rate: Not read in the body (the optimizer is created elsewhere).
        gradient_accumulation_steps: Iterations between optimizer steps.
        sample, gumbel_softmax, detach, gumbel_temperature, top_k, temperature,
        repetition_penalty: Forwarded to ``generate_next``. ``detach`` additionally
            selects the path that re-encodes context and prompt before the response.
        reset_pos_emb: When True, explicit position ids are passed to the model.
        not_mask_trigger: When False (and not ``detach``), the trigger positions are
            masked out of attention during response generation.
        adam_epsilon: Not read in the body.
        max_grad_norm: Gradient-clipping norm.
        length: Maximum number of steps in each generation loop.
        t_pad_token: Padding token string forwarded to ``prep_inputs``/``sample_gpt2``.
        stop_token: Token string that ends a prompt or a response.
        num_epochs: Number of passes over ``context_list`` (forced to 1 in eval mode).
        context_list: List of context strings.
        verbose: Print generated texts and losses.
        check_real_loss: When True, call ``sample_gpt2`` per batch and report the
            epoch average.
        seed: Seed applied to torch and numpy at entry.

    Returns:
        None. Results are printed; the trigger parameters are updated in place.
    """

    torch.manual_seed(seed)
    np.random.seed(seed)

    if mode == "eval":
        num_epochs = 1

    total_data_batches = len(context_list) // batch_size

    for epoch in range(num_epochs):
        print("&&&&& epoch: %d &&&&&" % (epoch + 1))

#         if mode == "train":
#             random.shuffle(context_list)

        epoch_loss = 0

        for cond_list_idx in range(total_data_batches):
            cond_list = context_list[cond_list_idx * batch_size: (cond_list_idx+1) * batch_size]
            all_input_ids, all_attention_masks, batch_min_length, batch_max_length, all_lengths = prep_inputs(cond_list, tokenizer, device, t_pad_token)

            model.zero_grad()

            loss_per_update = 0
            total_loss = 0

            if reset_pos_emb:
                # context tokens 1 .. batch_min_length-2 keep their natural positions
                c_position_ids = torch.arange(1, batch_min_length - 1, dtype=torch.long, device=device)
                c_position_ids = c_position_ids.unsqueeze(0).repeat(batch_size, 1)
            else:
                c_position_ids = None

            if reset_pos_emb:
                t_position_ids = torch.ones(batch_size, num_of_triggers).to(torch.long).to(device) * TRIGGER_POSITION_ID
            else:
                t_position_ids = None

            all_context_prompts = []
            all_generated_prompt_length = []

            for i in range(num_iterations):

                past = lm_bos_output["past_key_values"]

                if num_of_triggers > 0:
                    if trigger_format == "token":
                        trigger_embedding = model.ori_trigger_embedding.repeat(batch_size, 1, 1)
                        lm_trigger_output = model(inputs_embeds=trigger_embedding, position_ids=t_position_ids)
                        trigger_key_values = lm_trigger_output["past_key_values"]
                    else:
                        trigger_key_values = expand_past(model.ori_trigger_key_values, num_layers, batch_size)
                    # cache layout becomes: BOS, trigger(s)
                    past = concat_past(past, trigger_key_values, num_layers)
                else:
                    trigger_key_values = None

                output_so_far = all_input_ids[:, :batch_min_length]  # bze x batch_min_length

                context_lm_output = model(all_input_ids[:, 1: batch_min_length - 1], past_key_values=past, position_ids=c_position_ids, )

                past = context_lm_output["past_key_values"]
                last = output_so_far[:, batch_min_length - 1: batch_min_length]

                gumbel_vector = None
                if detach:
                    all_gumbel_vectors = None

                prompt_not_done = torch.ones(batch_size, 1, dtype=torch.uint8, device=device)
                generated_prompt_length = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
                prompt_stop_first = True

                # generate conditional prompt
                if verbose:
                    print("=====Epoch: %d; data_batch: %d; Iteration: %d=====" % (epoch + 1, cond_list_idx + 1, i + 1))

                for p_i in range(length):
                    if reset_pos_emb:
                        # the trigger slots in the cache do not advance the position counter
                        past_length = past[0][0].size(-2)
                        p_position_ids = torch.arange(past_length - num_of_triggers, past_length - num_of_triggers + 1, dtype=torch.long, device=device)
                        p_position_ids = p_position_ids.unsqueeze(0).repeat(batch_size, 1)
                    else:
                        p_position_ids = None
                    if gumbel_softmax and gumbel_vector is not None:
                        last_emb = torch.mm(gumbel_vector, model.transformer.wte.weight).unsqueeze(1)  # needs to be bze, n, emb
                        lm_output = model(inputs_embeds=last_emb, past_key_values=past, position_ids=p_position_ids)
                    else:
                        lm_output = model(last, past_key_values=past, position_ids=p_position_ids)

                    logits, past, all_hidden = (
                        lm_output["logits"],  # bze, cur_seq_len, vocab_size
                        lm_output["past_key_values"],  # acc_seq_len
                        lm_output["hidden_states"],  # num_layers + 1, tuple of (bze, cur_seq_len, hid_sze)
                    )

                    vocab_size = logits.shape[-1]

                    logits = penalize_new_line(logits)

                    # last: bze x 1; gumbel_vector: bze x vocab_size
                    last, gumbel_vector = generate_next(logits, output_so_far, top_k=top_k, temperature=temperature,
                                                        repetition_penalty=repetition_penalty, sample=sample,
                                                        gumbel_softmax=gumbel_softmax, gumbel_temperature=gumbel_temperature, detach=detach)
                    # manually assign the end token if the prompt is too long
                    if p_i == length - 1:
                        for m_b_i in range(batch_size):
                            if generated_prompt_length[m_b_i] == 0:
                                last[m_b_i] = tokenizer.encode(stop_token)[0]  # encode outputs a list (1 element)
                                if gumbel_softmax:
                                    gumbel_vector[m_b_i] = F.one_hot(torch.tensor(tokenizer.encode(stop_token), dtype=torch.long, device=device), num_classes=vocab_size)

                    # double check the length (p_i vs. lengths below)
                    is_generated = torch.tensor(all_lengths, device=device).unsqueeze(-1) <= (p_i + batch_min_length)  # bze x 1. is generated or still in the context
                    is_end_token = last == torch.tensor(tokenizer.encode(stop_token), device=device)  # bze x 1
                    is_actually_ending = is_generated * is_end_token

                    # keep track of prompt length (recorded once, the first time a row ends)
                    generated_prompt_length = generated_prompt_length + prompt_not_done * is_actually_ending * p_i

                    # if generated, use the generated token as last; otherwise (from the original), copy the original token/gumbel_vector
                    if batch_min_length + p_i < all_input_ids.shape[1]:
                        last = last * is_generated + all_input_ids[:, batch_min_length + p_i].unsqueeze(1) * (~is_generated)  # is_generated is bool. need to use "~" instead of (1-is_generated)

                    if gumbel_softmax:
                        if batch_min_length + p_i < all_input_ids.shape[1]:
                            ori_one_hot = F.one_hot(all_input_ids[:, batch_min_length + p_i], num_classes=vocab_size)  # bze x vocab_size
                            gumbel_vector = gumbel_vector * is_generated + ori_one_hot * (~is_generated)

                    # snapshot the state at the first step where any row finishes its prompt
                    if prompt_stop_first and torch.sum(is_actually_ending) > 0:
                        prompt_stop_first = False
                        min_p_past = past
                        min_p_last = last
                        min_gumbel_vector = gumbel_vector

                    if detach:
                        # WARNING: This can be quite large
                        if all_gumbel_vectors is None:
                            all_gumbel_vectors = gumbel_vector.unsqueeze(1)  # bze x 1 x vocab_size
                        else:
                            all_gumbel_vectors = torch.cat((all_gumbel_vectors, gumbel_vector.unsqueeze(1)), dim=1)  # bze x n x vocab_size

                    output_so_far = torch.cat((output_so_far, last), dim=1)  # bze x length

                    prompt_not_done = prompt_not_done * (~is_actually_ending)  # to check if we need to stop by summing

                    if torch.sum(prompt_not_done) == 0:
                        break

                # Decode context + prompt for every row. The texts are consumed by sample_gpt2
                # regardless of verbosity, so only the printing is gated on `verbose`.
                batch_context_prompts = []
                for cp_i in range(batch_size):
                    cp_i_text = tokenizer.decode(output_so_far[cp_i][:(generated_prompt_length[cp_i].item() + 1 + batch_min_length)].tolist())
                    batch_context_prompts.append(cp_i_text)

                if verbose:
                    print("context + prompt")
                    for cp_i_text in batch_context_prompts:
                        print(cp_i_text)
                    print("***" * 20)

                all_context_prompts.append(batch_context_prompts)
                all_generated_prompt_length.append(generated_prompt_length + batch_min_length + 1)

                if mode == "eval":  # only need context and prompts for evaluation
                    continue

                # length of the shortest context + prompt in the batch
                cp_min_length = torch.min(generated_prompt_length.squeeze(1)) + batch_min_length + 1

                if detach:
                    # Re-encode context and prompt without the trigger in the cache; the prompt
                    # is fed as embeddings built from the stored gumbel vectors.
                    detach_context_output = model(all_input_ids[:, :batch_min_length])
                    past = detach_context_output["past_key_values"]
                    all_gumbel_embeddings = torch.matmul(all_gumbel_vectors[:, :cp_min_length - batch_min_length - 1, :], model.transformer.wte.weight)
                    last_detach_output = model(inputs_embeds=all_gumbel_embeddings, past_key_values=past)
                    past = last_detach_output["past_key_values"]

                    gumbel_vector = min_gumbel_vector
                else:
                    # adjust past and last
                    # NOTE(review): gumbel_vector is not reset to min_gumbel_vector on this path, and
                    # all_gumbel_vectors (read below when gumbel_softmax is set) is only built when
                    # detach is True -- confirm whether gumbel_softmax without detach is supported.
                    past = min_p_past
                    last = min_p_last

                everything_so_far = output_so_far[:, :cp_min_length]

                # generate response
                response_hidden = None
                to_break = False

                response_not_done = torch.ones(batch_size, 1, dtype=torch.uint8, device=device)
                generated_response_length = torch.zeros(batch_size, 1, dtype=torch.long, device=device)

                for r_i in range(length):
                    # TODO: mask trigger key and value
                    if num_of_triggers > 0 and not not_mask_trigger and not detach:
                        # create attention mask
                        past_length = past[0][0].shape[-2]
                        attention_mask = torch.ones(batch_size, past_length + 1)  # add current 1 to length
                        attention_mask[:, 1: 1 + num_of_triggers] = 0  # bze=1, the first element is BOS
                        attention_mask = attention_mask.to(device)
                    else:
                        attention_mask = None

                    if num_of_triggers > 0 and reset_pos_emb and not detach:
                        past_length = past[0][0].size(-2)
                        r_position_ids = torch.arange(past_length - num_of_triggers, past_length - num_of_triggers + 1, dtype=torch.long, device=device)
                        r_position_ids = r_position_ids.unsqueeze(0).repeat(batch_size, 1)
                    else:
                        r_position_ids = None

                    if gumbel_softmax:
                        last_emb = torch.mm(gumbel_vector, model.transformer.wte.weight).unsqueeze(1)  # bze x 1 x hidden
                        lm_rep_output = model(inputs_embeds=last_emb, past_key_values=past, attention_mask=attention_mask,
                                          output_attentions=True, position_ids=r_position_ids)
                    else:
                        lm_rep_output = model(last, past_key_values=past, attention_mask=attention_mask, output_attentions=True,
                                          position_ids=r_position_ids)

                    rep_logits, past, rep_all_hidden = (
                        lm_rep_output["logits"],  # bze, cur_seq_len, vocab_size
                        lm_rep_output["past_key_values"],  # acc_seq_len
                        lm_rep_output["hidden_states"],  # num_layers + 1, tuple of (bze, cur_seq_len, hid_sze)
                    )

                    rep_logits = penalize_new_line(rep_logits)

                    last, gumbel_vector = generate_next(rep_logits, everything_so_far, top_k=top_k, temperature=temperature,
                                                        repetition_penalty=repetition_penalty, sample=sample,
                                                        gumbel_softmax=gumbel_softmax, gumbel_temperature=gumbel_temperature, detach=detach)

                    last_hidden = rep_all_hidden[-1]

                    if response_hidden is None:
                        response_hidden = last_hidden
                    else:
                        response_hidden = torch.cat((response_hidden, last_hidden), dim=1)  # bze, n, hid_size

                    # One extra forward pass is made after every row has finished, so the hidden
                    # state produced by feeding the final token is collected before leaving.
                    if to_break:
                        break

                    # manually assign the end token if the response is too long
                    if r_i == length - 1:
                        for r_m_b_i in range(batch_size):
                            if generated_response_length[r_m_b_i] == 0:
                                last[r_m_b_i] = tokenizer.encode(stop_token)[0]  # encode outputs a list (1 element)
                                if gumbel_softmax:  # same guard as in the prompt loop
                                    gumbel_vector[r_m_b_i] = F.one_hot(torch.tensor(tokenizer.encode(stop_token), dtype=torch.long, device=device), num_classes=vocab_size)

                    # adjust
                    r_is_generated = generated_prompt_length + 1 + batch_min_length <= (r_i + cp_min_length)  # bze x 1. is generated or still in the context+prompt
                    r_is_end_token = last == torch.tensor(tokenizer.encode(stop_token), device=device)  # bze x 1
                    r_is_actually_ending = r_is_generated * r_is_end_token

                    # keep track of response length
                    generated_response_length = generated_response_length + response_not_done * r_is_actually_ending * r_i

                    # if generated, use the generated token as last; otherwise (from context+prompt), copy the original token/gumbel_vector
                    if cp_min_length + r_i < output_so_far.shape[1]:
                        last = last * r_is_generated + output_so_far[:, cp_min_length + r_i].unsqueeze(1) * (~r_is_generated)

                    if gumbel_softmax:
                        if cp_min_length - batch_min_length + r_i < all_gumbel_vectors.shape[1]:
                            gumbel_vector = gumbel_vector * r_is_generated + all_gumbel_vectors[:, cp_min_length - batch_min_length + r_i, :] * (~r_is_generated)

                    everything_so_far = torch.cat((everything_so_far, last), dim=1)

                    response_not_done = response_not_done * (~r_is_actually_ending)  # to check if we need to stop by summing
                    if torch.sum(response_not_done) == 0:
                        to_break = True

                if verbose:
                    print("response: ")
                    for pr_i in range(batch_size):
                        if generated_response_length[pr_i] == 0:
                            print(tokenizer.decode(everything_so_far[pr_i][generated_prompt_length[pr_i].item() + 1 + batch_min_length:].tolist()))  # exceeds the max length set (so generated_response_length is 0)
                        else:
                            print(tokenizer.decode(everything_so_far[pr_i][generated_prompt_length[pr_i].item() + 1 + batch_min_length:(generated_response_length[pr_i].item() + 1 + cp_min_length)].tolist()))
                    print("***" * 20)

                # Average the last-layer hidden states over each row's own response span.
                extracted_hidden = None
                # hidden: bze, 1, hid_size
                for hb_i in range(batch_size):
                    hb_i_start = generated_prompt_length[hb_i] + 1 + batch_min_length - cp_min_length + 1
                    hb_i_end = generated_response_length[hb_i] + 1 + 1
                    hb_i_hidden = torch.mean(response_hidden[hb_i:hb_i+1, hb_i_start:hb_i_end, :], dim=1)  # 1, hid_size
                    if extracted_hidden is None:
                        extracted_hidden = hb_i_hidden
                    else:
                        extracted_hidden = torch.cat((extracted_hidden, hb_i_hidden), dim=0)

                prediction = classifier(extracted_hidden)
                label = torch.tensor([class_label], device=device, dtype=torch.long).repeat(batch_size)
                discrim_loss = ce_loss(prediction, label)

                if verbose:
                    print("discrim loss: %.6f" % discrim_loss.data.cpu().numpy())
                # Accumulated regardless of verbosity so the reported averages are correct.
                loss_per_update += discrim_loss.item()

                if num_of_triggers > 0:
                    # compute gradients
                    discrim_loss.backward()

                    if (i + 1) % gradient_accumulation_steps == 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                        optimizer.step()
                        model.zero_grad()
                        if verbose:
                            print("\n=======update loss: %.6f=======" % (loss_per_update / gradient_accumulation_steps))
                        total_loss += loss_per_update
                        loss_per_update = 0

                if verbose:
                    print("\n\n")

            if mode == "train":
                print("\n\nepoch: %d; data batch-%d/%d; total average loss: %.6f" % (epoch + 1, cond_list_idx + 1, total_data_batches, total_loss / num_iterations))

            if check_real_loss:
                data_batch_loss = sample_gpt2(model, tokenizer, all_context_prompts, all_generated_prompt_length, t_pad_token, classifier, class_label)
                epoch_loss += data_batch_loss

        # With fewer contexts than one batch there is nothing to average.
        if check_real_loss and total_data_batches > 0:
            print("epoch %d: loss: %.6f\n\n" % (epoch + 1, epoch_loss / total_data_batches))
            print("<<<<<<>>>>>>\n\n")

In [11]:
def sample_gpt2(model, tokenizer, all_context_prompts, all_generated_prompt_length, 
                t_pad_token, classifier, class_label):
    """Generate responses with the plain LM and score them with the classifier.

    For every batch of context+prompt strings, the texts are re-encoded and
    right-padded, the model generates up to ``length`` further tokens (samples
    whose prompt is longer than the shortest one keep copying their own prompt
    tokens until it is exhausted), and the last-layer hidden states over each
    sample's generated response are averaged and fed to ``classifier``. The
    cross-entropy against ``class_label`` is printed per batch and averaged.

    Relies on notebook-level names: ``device``, ``length``, ``top_k``,
    ``temperature``, ``repetition_penalty``, ``sample``, ``gumbel_temperature``,
    ``stop_token``, ``penalize_new_line``, ``generate_next`` and ``ce_loss``.

    Args:
        model: called as ``model(ids)`` / ``model(ids, past_key_values=...)``;
            its output is indexed by "logits", "past_key_values" and
            "hidden_states".
        tokenizer: provides ``encode`` and ``decode``.
        all_context_prompts: sequence of batches; each batch is a sequence of
            context+prompt strings.
        all_generated_prompt_length: indexable by batch; each entry is an
            integer tensor of per-sample context+prompt token lengths (the code
            treats it as bze x 1 -- TODO confirm against the caller). Entries
            may be updated in place when re-encoding yields a longer sequence.
        t_pad_token: text whose first encoded id is used as the padding id.
        classifier: callable applied to the (bze, hid_size) mean hidden states.
        class_label: target class index used for every sample.

    Returns:
        Mean classifier loss over the batches; ``0.0`` when there are none.
    """
    num_data_batches = len(all_context_prompts)
    if num_data_batches == 0:
        # nothing to score; avoids a ZeroDivisionError in the final average
        return 0.0

    real_loss_all = 0
    padding_token = tokenizer.encode(t_pad_token)[0]

    # loop-invariant: encode the stop token once instead of at every step
    stop_token_ids = tokenizer.encode(stop_token)  # encode outputs a list (1 element)
    stop_token_id = stop_token_ids[0]
    stop_token_tensor = torch.tensor(stop_token_ids, device=device)

    for real_i in range(num_data_batches):
        cp_length_i = all_generated_prompt_length[real_i]
        # plain ints so that slicing / list repetition do not depend on tensor __index__
        min_cp_length_i = int(torch.min(cp_length_i).item())
        max_cp_length_i = int(torch.max(cp_length_i).item())
        cp_i_text = all_context_prompts[real_i]

        cp_i_input_ids = list()
        for cp_i_j in cp_i_text:
            input_ids = tokenizer.encode(cp_i_j)
            padding_len = max_cp_length_i - len(input_ids)
            cp_i_input_ids.append(input_ids + [padding_token] * padding_len)

        # note: for instance, '\n\n' will be decoded as 628 in gpt2, but it may be encoded as '198 198'
        # if it's connected to another token (e.g. '\n\n-'), which creates a problem in length
        ck_max_cp_length_i = max(len(cp_i_inp) for cp_i_inp in cp_i_input_ids)
        if ck_max_cp_length_i > max_cp_length_i:
            max_cp_length_i = ck_max_cp_length_i
            for cp_i_inp_idx, cp_i_inp in enumerate(cp_i_input_ids):
                if len(cp_i_inp) < ck_max_cp_length_i:
                    # in-place extend, so the entry stored in cp_i_input_ids is padded
                    cp_i_inp += (max_cp_length_i - len(cp_i_inp)) * [padding_token]
                elif len(cp_i_inp) == ck_max_cp_length_i:
                    # NOTE: this writes into the caller's length tensor
                    cp_length_i[cp_i_inp_idx] = ck_max_cp_length_i

        cp_i_inputs = torch.tensor(cp_i_input_ids, dtype=torch.long, device=device)

        cp_i_batch_size = len(cp_i_text)
        step_hiddens = []  # last-layer hidden state of every step, concatenated once after the loop
        to_break = False
        # bool mask (instead of uint8) so multiplying by the step index cannot overflow
        cp_i_response_not_done = torch.ones(cp_i_batch_size, 1, dtype=torch.bool, device=device)
        cp_i_generated_response_length = torch.zeros(cp_i_batch_size, 1, dtype=torch.long, device=device)

        with torch.no_grad():  # Need this?
            cp_i_so_far = cp_i_inputs[:, :min_cp_length_i]  # bze x min_cp_length_i
            cp_i_context_lm_output = model(cp_i_inputs[:, :min_cp_length_i - 1])
            past = cp_i_context_lm_output["past_key_values"]
            last = cp_i_inputs[:, min_cp_length_i - 1: min_cp_length_i]

            for rr_i in range(length):  # length + max_cp_length_i - min_cp_length_i?
                lm_cp_i_output = model(last, past_key_values=past)
                cp_i_logits = lm_cp_i_output["logits"]  # bze, cur_seq_len, vocab_size
                past = lm_cp_i_output["past_key_values"]  # acc_seq_len
                cp_i_hidden = lm_cp_i_output["hidden_states"]  # num_layers + 1, tuple of (bze, cur_seq_len, hid_sze)

                cp_i_logits = penalize_new_line(cp_i_logits)

                last, _ = generate_next(cp_i_logits, cp_i_so_far, top_k=top_k, temperature=temperature, 
                                        repetition_penalty=repetition_penalty, sample=sample,
                                        gumbel_softmax=False, gumbel_temperature=gumbel_temperature, detach=False)

                step_hiddens.append(cp_i_hidden[-1])
                # to_break is set one step late on purpose: the hidden state of the
                # final (stop) token is only produced by the following forward pass
                if to_break:
                    break

                # manually assign end token if too long
                if rr_i == length - 1:
                    for rr_b_i in range(cp_i_batch_size):
                        if cp_i_generated_response_length[rr_b_i] == 0:
                            last[rr_b_i] = stop_token_id

                # adjust
                rr_i_is_generated = cp_length_i <= (rr_i + min_cp_length_i)  # bze x 1. is generated or still in the context+prompt
                rr_i_is_end_token = last == stop_token_tensor  # bze x 1
                rr_i_is_actually_ending = rr_i_is_generated & rr_i_is_end_token

                # keep track of response length (recorded once, at the step where the sample finishes)
                rr_i_newly_done = cp_i_response_not_done & rr_i_is_actually_ending
                cp_i_generated_response_length = cp_i_generated_response_length + rr_i_newly_done.long() * rr_i

                # if generated, use the generated token as last; otherwise (from context+prompt), copy the original token
                if min_cp_length_i + rr_i < max_cp_length_i:
                    last = last * rr_i_is_generated + cp_i_inputs[:, min_cp_length_i + rr_i].unsqueeze(1) * (~rr_i_is_generated)

                cp_i_so_far = torch.cat((cp_i_so_far, last), dim=1)

                cp_i_response_not_done = cp_i_response_not_done & (~rr_i_is_actually_ending)
                if not torch.any(cp_i_response_not_done):
                    to_break = True

            real_response_hidden = torch.cat(step_hiddens, dim=1)  # bze, n, hid_size

            print("Real responses: ")
            for r_pr_i in range(cp_i_batch_size):
                if cp_i_generated_response_length[r_pr_i] == 0:
                    print(tokenizer.decode(cp_i_so_far[r_pr_i, :].tolist()))  # exceeds the max length set (so generated_prompt_length is 0)
                else:
                    print(tokenizer.decode(cp_i_so_far[r_pr_i, :(cp_i_generated_response_length[r_pr_i].item() + 1 + min_cp_length_i)].tolist()))
            print("***" * 20)

            # mean-pool the hidden states over each sample's generated response
            extracted_hiddens = []
            for r_hb_i in range(cp_i_batch_size):
                r_hb_i_start = int(cp_length_i[r_hb_i]) - min_cp_length_i + 1
                r_hb_i_end = int(cp_i_generated_response_length[r_hb_i]) + 1 + 1
                extracted_hiddens.append(
                    torch.mean(real_response_hidden[r_hb_i:r_hb_i+1, r_hb_i_start:r_hb_i_end, :], dim=1))
            cp_i_extracted_hidden = torch.cat(extracted_hiddens, dim=0)  # bze, hid_size

            prediction = classifier(cp_i_extracted_hidden)
            label = torch.tensor([class_label], device=device, dtype=torch.long).repeat(cp_i_batch_size)
            loss = ce_loss(prediction, label)

            real_loss_all += loss.item()

            print("loss: %.6f" % loss.item())
            print()

    average_loss = real_loss_all / num_data_batches
    print("all_loss")
    print(average_loss)
    print("\n\n")

    return average_loss


In [12]:
# run
# Build the conditioning lists consumed by the training and evaluation calls below.
if not multiple_input:
    # Single-context mode: repeat one context for the whole batch.
    # `context_text` is not defined in this cell; it must be set earlier in the notebook.
    cond_list = [context_text] * batch_size
    # Both runs below read these names, so define them here as well
    # (previously this branch left them undefined -> NameError).
    train_cond_list = cond_list
    eval_cond_list = cond_list
else:
    train_cond_list = read_file(train_filename)
    eval_cond_list = read_file(eval_filename)

# passed to generate_prompt_response in the slot that receives num_iterations for training
eval_num_per_cond = 2

# baseline:
# print("=======getting baselines=======")
# generate_prompt_response(model, tokenizer, "eval", lm_bos_output, batch_size, device, class_label,
#                          eval_num_per_cond, 0, gradient_accumulation_steps, sample, False,
#                          False, reset_pos_emb, not_mask_trigger, gumbel_temperature, top_k, temperature, 
#                          repetition_penalty, adam_epsilon, max_grad_norm, length, t_pad_token, stop_token,
#                          1, eval_cond_list, verbose, True,
#                          seed,
#                         )
# print("\n\n\n")
    
# train:
print("=======training=======")
generate_prompt_response(model, tokenizer, "train", lm_bos_output, batch_size, device, class_label,
                         num_iterations, learning_rate, gradient_accumulation_steps, sample, gumbel_softmax,
                         detach, reset_pos_emb, not_mask_trigger, gumbel_temperature, top_k, temperature, 
                         repetition_penalty, adam_epsilon, max_grad_norm, length, t_pad_token, stop_token,
                         num_epochs, train_cond_list, verbose, check_real_loss,
                         seed,
                        )
print("\n\n\n")

# eval: learning rate 0, gumbel softmax / detach disabled, a single epoch,
# and the real (sampled) loss always checked
print("=======evaluation=======")
generate_prompt_response(model, tokenizer, "eval", lm_bos_output, batch_size, device, class_label,
                         eval_num_per_cond, 0, gradient_accumulation_steps, sample, False,
                         False, reset_pos_emb, not_mask_trigger, gumbel_temperature, top_k, temperature, 
                         repetition_penalty, adam_epsilon, max_grad_norm, length, t_pad_token, stop_token,
                         1, eval_cond_list, verbose, True,
                         seed,
                        )
    

&&&&& epoch: 1 &&&&&
=====Epoch: 1; data_batch: 1; Iteration: 1=====
context + prompt
<|endoftext|>My favorite music genre is death metal. In my mind, death metal is all about the heavy death metal.
<|endoftext|>I listen to rap music. I love it.
<|endoftext|>I workout four hours a day. When I get into the gym, I'm not going to lie, it's not a lot.
<|endoftext|>I'm a christian. So I'm not going to pretend to know what the Bible is about, but I do know that if it's about Jesus as the Son of God, then I have a lot of respect for.
************************************************************
response: 
 I love that they are constantly pushing their music forward and trying new things with different types of music.
 And I love it because it is so much more than rap music.
 I usually start in the morning, usually in my car, but I do it in my own apartment.
 And I believe.
************************************************************
discrim loss: 3.162913




=====Epoch: 1; data_batch: 1; Iter

response: 
 I love tortillas and am so excited to be back.
 I'm six.
 I know what i like to cook and I am not good at it.
6" wheels and a 5-speed automatic.
************************************************************
discrim loss: 2.227957






epoch: 1; data batch-3/33; total average loss: 2.654085
Real responses: 
<|endoftext|>I like tacos. But when they're tacos, they're good. And sometimes you need something a little different.
<|endoftext|>I'm four. This is a good time to be alive, but I think I should start by saying that there are still a few things I want to talk about. I'll leave the rest of the story up to you to decide.
<|endoftext|>I also like to cook but i am not very good at it. I have always been an amateur chef. I am not even good at making anything.
<|endoftext|>I own two vintage mustangs. I'm not sure I've ever owned a Mustang with a 4 door engine. I've owned many Mustangs over the years, but this was my first and only.
**********************************************

Real responses: 
<|endoftext|>I am an elementary school teacher. I have been teaching English in a district that is located in the northern part of the state and I have taught a lot of English students, mostly middle schoolers. My students are mostly from the south of the state, and most of them speak English as their native language.
<|endoftext|>I live in a house. There is a lot of furniture around the house, but no furniture is really needed for it. The furniture is a nice addition to the space, and I'm happy with what it is.
<|endoftext|>I'm high maintenance. I want to have a great time in this town and I want to meet all the people I want to meet. I like to do things on my own and it seems like every single person there has a different style of dress and hair, even the girls have different hairstyles.
<|endoftext|>I spent a decade working in the human services field. I was responsible for the creation and management of all human resources departments, and the hiring and training o

KeyboardInterrupt: 