In [2]:
import argparse
from pathlib import Path
from typing import Union, List
import os
os.environ['TRANSFORMERS_CACHE'] = 'cache/'
from transformers import BartForConditionalGeneration, BartTokenizer
from IPython import embed
# from infilling import *
from training.infilling import *
from utils import *
import nltk.tokenize.casual
import torch
import torch.nn.functional as F
import sys
from rewrite import gen_utils
from rewrite import generation_logits_process
import pandas as pd
from tqdm import tqdm
# from .masking import Masker, method1, method1_list, preprocess
from rewrite.masking_v2 import Masker, preprocess
from rewrite.generation_v2 import Infiller
import re
import time
import itertools

In [104]:
# DEBUG Purpose
# parser = argparse.ArgumentParser()
# args = parser.parse_args()
class MyDict:
    """Attribute-style view over a plain dict of settings.

    Stands in for the ``argparse.Namespace`` that the commented-out parser
    would produce, so the notebook runs without command-line arguments.
    The original mapping stays reachable as ``.data``; every key is also
    exposed as an attribute (a key literally named ``"data"`` would replace
    ``.data``).
    """

    def __init__(self, data):
        self.data = data
        # Mirror each entry as an attribute: args.thresh, args.seed, ...
        for name in data:
            setattr(self, name, data[name])

# Notebook stand-in for the command-line arguments: every setting read by the
# cells below lives here.
args = MyDict({
    # --- Data selection ---
    "data_type": "manual",
    "data_path": "dataset/eval/microagressions/val.csv",
    # Threshold handed to Masker.mask when deciding which tokens to mask
    "thresh": 1.2,
    # Regenerate the masked inputs even when a cached masked_inputs.txt exists.
    # The masking cell reads this flag; it was missing from this dict, so that
    # cell raised AttributeError as soon as the cache file was present.
    "overwrite_mask": False,

    # --- Models / tokenizer ---
    "base_path" : "facebook/bart-base",
    "antiexpert_path" : "CARC/anti-expert-model/bart-base_1e-06_0_32_jigsaw_full_30/checkpoint-40000",
    "expert_path" : "CARC/expert-model/bart-base_2e-06_0_96_jigsaw_full_30/checkpoint-90000",
    "tokenizer" : "facebook/bart-base",
    # Keys into the model map; the loading cell skips a model whose type is "none"
    "base_type": "base",
    "expert_type": "expert",
    "antiexpert_type": "antiexpert",
    "seed": 0,

    # --- Ensemble weights: alpha_b * base + alpha_e * expert - alpha_a * antiexpert ---
    "alpha_a": 1.5,
    "alpha_e": 4.25,
    "alpha_b": 1.0,
    # Divides the base-model logits only (see the decoding loop)
    "temperature": 2.5,
    # Falsy -> greedy decoding; truthy -> top-k / top-p sampling
    "sample": None,
    "top_k_gen": 50,
    "rep_penalty": 1.0,
    # Top-p filter applied to the base logits when < 1.0
    "filter_p": 1.0,
    "max_length": 128,
    "top_p": 1.0,
    "batch_size": 50,
    "output_dir": "data/dexp_outputs",
    "verbose": True,
    # Optional callable applied as rep_proc(outputs, next_token_logits)
    "rep_proc": None,
    # NOTE(review): not read by any cell visible in this notebook -- confirm before removing
    "top_k": 0,
})

In [111]:
def get_data(args):
    """Return the list of input sentences selected by ``args``.

    The dataset is chosen by a substring of ``args.data_path``:

    * ``"dynabench"``: rows of the requested split ("dev" unless
      ``args.data_type`` contains "test" or "train") and label ("hate" unless
      it contains "nothate").  Unless ``args.data_type`` contains "all", the
      last character of ``args.data_type`` is read as the ``round.base``
      number to keep.
    * ``"sbf"``: ``redditMicroagressions`` posts, offensive
      (``offensiveYN >= 0.5``) unless ``args.data_type`` contains "nonoff".
    * ``"microagressions"``: every ``actual_quote`` passed through
      ``preprocess``.

    If ``args.data_path`` is ``None`` or matches none of the above, three
    built-in example sentences are returned.

    Raises:
        ValueError: a dynabench round was requested but ``args.data_type``
            does not end in a digit.
    """
    default_inputs = ["I'm surprised you got it done, seeing as you're all girls!", "You are a human", "You are a genius"]

    data_path = args.data_path
    if data_path is None:
        return default_inputs

    if "dynabench" in data_path:
        df = pd.read_csv(data_path)
        data_type = args.data_type

        split = "dev"
        if "test" in data_type:
            split = "test"
        if "train" in data_type:
            split = "train"
        label = "nothate" if "nothate" in data_type else "hate"

        # Build ONE boolean mask. The previous df[mask_a][mask_b] form indexed an
        # already-filtered frame with a full-length mask, which makes pandas
        # reindex the mask and emit a UserWarning.
        keep = (df["split"] == split) & (df["label"] == label)
        if "all" not in data_type:
            try:
                round_id = int(data_type[-1])
            except (ValueError, IndexError) as err:
                raise ValueError(
                    "data_type must contain 'all' or end with the round number, got %r" % (data_type,)
                ) from err
            keep = keep & (df["round.base"] == round_id)
        return df.loc[keep, "text"].tolist()

    if "sbf" in data_path:
        df = pd.read_csv(data_path)
        from_source = df["dataSource"] == "redditMicroagressions"
        if "nonoff" in args.data_type:
            wanted = df["offensiveYN"] < 0.5
        else:
            wanted = df["offensiveYN"] >= 0.5
        return df.loc[from_source & wanted, "post"].tolist()

    if "microagressions" in data_path:
        df = pd.read_csv(data_path)
        return [preprocess(s) for s in df["actual_quote"].tolist()]

    return default_inputs

In [112]:
# Get the inputs to rewrite: a list of sentences chosen by args.data_path /
# args.data_type (get_data falls back to built-in examples when no dataset matches)
inputs = get_data(args)

In [113]:
# Keep only the first 10 inputs so the masking/generation cells below stay quick
inputs = inputs[:10]

In [114]:
# Folder the masked inputs are saved to:
#   <output_dir>/<data_type>/masked_thresh<thresh>
# Change the folder name here if you prefer a different layout.
mask_path = f"masked_thresh{args.thresh}"
cur_path = os.path.join(args.output_dir, args.data_type, mask_path)
print(cur_path)


data/dexp_outputs/manual/masked_thresh1.2


In [115]:
# Produce `decoded_mask_inputs`: either mask the inputs now, or reload the
# masked inputs cached by an earlier run. Setting args.overwrite_mask to True
# regenerates the masked inputs regardless of the cache.
masked_inputs_file = os.path.join(cur_path, "masked_inputs.txt")
# getattr with a default: args may not define overwrite_mask at all, and a plain
# attribute access would raise AttributeError as soon as the cache file exists.
overwrite_mask = getattr(args, "overwrite_mask", False)

if overwrite_mask or not os.path.exists(masked_inputs_file):
    # Branch: generate new masked versions of the inputs. Feel free to replace this logic
    # exist_ok replaces the former bare `except: pass`, which also hid real
    # failures such as permission errors.
    os.makedirs(cur_path, exist_ok=True)

    # Initialize the Masker object with the parameters from the args
    masker = Masker(
        seed = args.seed, base_path = args.base_path, antiexpert_path = args.antiexpert_path,
        expert_path = args.expert_path, tokenizer = args.tokenizer
    )

    # Use the mask function from masker to mask the inputs with a specified threshold
    decoded_masked_inputs = masker.mask(inputs=inputs, thresh=args.thresh)

    # Hacky way to remove bos and eos token from decoded mask inputs and save to text file
    decoded_mask_inputs = [d.replace("<s>", "").replace("</s>", "") for d in decoded_masked_inputs]
    with open(masked_inputs_file, "w", encoding="utf-8") as f:
        for d in decoded_mask_inputs:
            # One masked input per line, with runs of whitespace collapsed
            f.write(re.sub(r"\s+", " ", d) + "\n")
else:
    # Branch: reuse the previously masked inputs. Without this branch
    # `decoded_mask_inputs` stayed undefined whenever the cache was used.
    with open(masked_inputs_file, encoding="utf-8") as f:
        decoded_mask_inputs = [line.rstrip("\n") for line in f]

No GPUs found!
Checking anti-expert local file name: CARC/anti-expert-model/bart-base_1e-06_0_32_jigsaw_full_30/checkpoint-40000


Identifying masks: 100%|██████████| 10/10 [00:13<00:00,  1.38s/it]


In [116]:
# Build the Infiller (our rewriter) from the checkpoints, model types and
# tokenizer named in args
infiller_settings = dict(
    seed=args.seed,
    base_path=args.base_path,
    antiexpert_path=args.antiexpert_path,
    expert_path=args.expert_path,
    base_type=args.base_type,
    antiexpert_type=args.antiexpert_type,
    expert_type=args.expert_type,
    tokenizer=args.tokenizer,
)
rewriter = Infiller(**infiller_settings)

No GPUs found!
Checking anti-expert local file name: CARC/anti-expert-model/bart-base_1e-06_0_32_jigsaw_full_30/checkpoint-40000


In [117]:
# Folder name that encodes every generation hyper-parameter, so runs with
# different settings never overwrite each other. Model types are cut to their
# first five characters.
gen_path = (
    f"aa{args.alpha_a}_ae{args.alpha_e}_ab{args.alpha_b}"
    f"_base{args.base_type[:5]}_anti{args.antiexpert_type[:5]}_expert{args.expert_type[:5]}"
    f"_temp{args.temperature}_sample{bool2str(args.sample)}"
    f"_topk{args.top_k_gen}_reppenalty{args.rep_penalty}"
    f"_filterp{args.filter_p}_maxlength{args.max_length}_topp{args.top_p}"
)

In [118]:
# Generation output folder: the masking folder plus the hyper-parameter folder name
final_path = os.path.join(cur_path, gen_path)

In [119]:
# Show the resolved output folder for this run
print(final_path)

data/dexp_outputs/manual/masked_thresh1.2/aa1.5_ae4.25_ab1.0_basebase_antiantie_expertexper_temp2.5_sampleF_topk50_reppenalty1.0_filterp1.0_maxlength128_topp1.0


In [121]:
# Replace the loaded data with two hard-coded example sentences and their
# hard-coded masked counterparts. Both lists must stay the same length and in
# the same order.
# NOTE(review): presumably <mask> marks the span the rewriter should replace -- confirm against Masker.
inputs =  ["I'm surprised you got it done, seeing as you're all girls!", "You are a human"]
decoded_mask_inputs = ["I'm surprised you got it done, seeing as you're all<mask>!", 'You are a<mask>']

In [122]:
# Rewrite the masked inputs: gather the generation settings from args, then
# call the Infiller's generate method with them
generation_settings = dict(
    alpha_a=args.alpha_a,
    alpha_e=args.alpha_e,
    alpha_b=args.alpha_b,
    temperature=args.temperature,
    verbose=args.verbose,
    max_length=args.max_length,
    repetition_penalty=args.rep_penalty,
    p=args.top_p,
    filter_p=args.filter_p,
    k=args.top_k_gen,
    batch_size=args.batch_size,
    sample=args.sample,
)
outputs, decoded_outputs = rewriter.generate(inputs, decoded_mask_inputs, **generation_settings)

ORIGINAL 	
["<s>I'm surprised you got it done, seeing as you're all girls!</s>", '<s>You are a human</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>']
	
["<s>I'm surprised you got it done, seeing as you're all<mask>!</s>", '<s>You are a<mask></s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>']
Current outputs
	 ['</s><s>', '</s><s>']
Base

	 ['I', 'It', '"', 'She', 'You']
	 ['You', 'you', 'YOU', 'I', 'Your']
Anti

	 ['I', 'It', '\n', 'You', 'As']
	 ['You', 'you', 'YOU', 'She', 'Your']
Expert

	 ['I', 'It', '\n', 'You', ' ']
	 ['You', 'you', 'She', 'They', 'He']
Ensemble

	 ['I', 'It', 'You', 'Im', 'In']
	 ['You', 'you', 'She', 'They', 'He']
Next token: ['I', 'You']
Current outputs
	 ['</s><s>I', '</s><s>You']
Base

	 ["'m", ' was', "'re", "'d", "'s"]
	 [' are', '.', ' Are', "'re", ' know']
Anti

	 ["'m", "'re", "'s", ' ', "'t"]
	 [' are', ' ', "'re", '</s>', ' is']
Expert

	 ["'m", "'re", "'s", ' ', "'t"]
	 [' are', ' ', "'re", '.', ' is']
Ensemble

	 ["'m", "'re", 

**Rewrite Generate**

In [128]:
# Run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed = args.seed
seed_everything(seed)

# Initialize tokenizer
tokenizer = BartTokenizer.from_pretrained(args.tokenizer)

# Save mask info
mask = tokenizer.mask_token
mask_id = tokenizer.mask_token_id

# Model type -> checkpoint path
model_map = {"base": args.base_path, "antiexpert": args.antiexpert_path, "expert": args.expert_path}

base_type = args.base_type
antiexpert_type = args.antiexpert_type
expert_type = args.expert_type


def _load_model(model_type):
    """Load the BART checkpoint registered under ``model_type`` onto ``device``.

    Returns ``None`` for the type ``"none"``, so a disabled model is still a
    bound name. Previously such a model was never assigned, and the later
    ``if base_model:``-style checks raised NameError.
    """
    if model_type == "none":
        return None
    model = BartForConditionalGeneration.from_pretrained(
        model_map[model_type], forced_bos_token_id = tokenizer.bos_token_id
    )
    return model.to(device)


# Initialize models (each is None when its type is "none")
base_model = _load_model(base_type)
antiexpert = _load_model(antiexpert_type)
expert = _load_model(expert_type)

In [129]:
# Put every loaded model into evaluation mode
for loaded_model in (base_model, expert, antiexpert):
    if loaded_model:
        loaded_model.eval()

In [130]:
# Convert inputs to list if they aren't already
if not isinstance(inputs, list):
    inputs = [inputs]
if not isinstance(decoded_mask_inputs, list):
    # Wrap the masked input itself. This line used to assign to `inputs`,
    # which replaced the original sentences with the masked ones and left
    # `decoded_mask_inputs` unwrapped.
    decoded_mask_inputs = [decoded_mask_inputs]

# Every original sentence needs exactly one masked counterpart
assert len(inputs) == len(decoded_mask_inputs), \
    "inputs and decoded_mask_inputs must have the same length"

# Tokenize - the regular inputs, and the masked inputs (each padded to its own longest sequence)
batch = tokenizer(inputs, return_tensors='pt', padding = True).to(device)
batch_masked = tokenizer(decoded_mask_inputs, return_tensors='pt', padding = True).to(device)

In [131]:
num_inputs = len(inputs)

# 1 while a sentence is still being generated, 0 once it has produced EOS
unfinished_sents = torch.ones(num_inputs, dtype=torch.int32, device=device)

# Every row of the decoder input starts as [eos, bos], matching how BART generates
start_ids = torch.tensor(
    [[tokenizer.eos_token_id, tokenizer.bos_token_id]], dtype=torch.long, device=device
)
outputs = start_ids.repeat(num_inputs, 1)
start_length = 2

In [132]:
# Inspect the initial decoder inputs: one [eos, bos] row per input sentence
outputs

tensor([[2, 0],
        [2, 0]])

In [134]:
# Token-by-token decoding with the ensemble
#     alpha_b * base + alpha_e * expert - alpha_a * antiexpert
# The loop extends `outputs` and updates `unfinished_sents` in place of the
# values created by the initialization cell, so re-run that cell before
# re-running this one.

def _show_top_tokens(title, logits, k = 5):
    """Print, for each sequence in the batch, the k highest-scoring next tokens under ``logits``."""
    print(title + "\n")
    for idxs in torch.topk(logits[:,-1,:], k, dim=-1).indices:
        print("\t", tokenizer.batch_decode(idxs))


loop_idx = 0
# Inference only: without no_grad every step would record an autograd graph
# that is never used, which wastes memory as the sequences grow.
with torch.no_grad():
    # Subtract start length from max length, since we start with 2 tokens
    while loop_idx < (args.max_length - start_length):

        # Compute the logits for base, antiexpert, and expert
        # Base model sees the nonmasked inputs, expert and antiexpert see the masked inputs
        # (models are called directly rather than through .forward so registered hooks run)
        base_logits = base_model(input_ids = batch["input_ids"], attention_mask = batch["attention_mask"], decoder_input_ids = outputs).logits
        antiexpert_logits = antiexpert(input_ids = batch_masked["input_ids"], attention_mask = batch_masked["attention_mask"], decoder_input_ids = outputs).logits
        expert_logits = expert(input_ids = batch_masked["input_ids"], attention_mask = batch_masked["attention_mask"], decoder_input_ids = outputs).logits

        if args.verbose:
            print("Current outputs\n\t", tokenizer.batch_decode(outputs))
            _show_top_tokens("Base", base_logits)
            _show_top_tokens("Anti", antiexpert_logits)
            _show_top_tokens("Expert", expert_logits)

        # top_p filtering on the base logits
        if args.filter_p < 1.0:
            base_logits = gen_utils.top_k_top_p_filtering(base_logits, top_p=args.filter_p)

        # Change values of the logits with the temperature
        # Temperature (higher temperature => more likely to sample low probability tokens)
        # Only the base logits are scaled; expert and antiexpert logits are used as-is.
        if args.temperature != 1.0:
            base_logits = base_logits / args.temperature

        # Ensemble the logits and get the next token logits
        ensemble_logits = args.alpha_b * base_logits + args.alpha_e * expert_logits - args.alpha_a * antiexpert_logits
        next_token_logits = ensemble_logits[:, -1, :]

        # Add repetition penalty
        if args.rep_proc is not None:
            next_token_logits = args.rep_proc(outputs, next_token_logits)

        # Sample or greedily decode from the next_token_logits
        if args.sample:
            if args.top_k_gen > 0 or args.top_p < 1.0:
                next_token_logits = gen_utils.top_k_top_p_filtering(next_token_logits, top_k=args.top_k_gen, top_p=args.top_p)
            # Sample from distribution
            probs = F.softmax(next_token_logits, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            # Greedy decoding
            next_tokens = torch.argmax(next_token_logits, dim=-1)

        # Get the tokens to add and identify sentences that are done generating:
        # finished sentences (unfinished_sents == 0) receive the pad token instead
        tokens_to_add = next_tokens * unfinished_sents + tokenizer.pad_token_id * (1 - unfinished_sents)
        eos_in_sents = tokens_to_add == tokenizer.eos_token_id
        unfinished_sents.mul_((~eos_in_sents).int())

        # Update the outputs and the loop index
        outputs = torch.cat((outputs, tokens_to_add.unsqueeze(-1)), dim=-1)
        loop_idx += 1

        if args.verbose:
            _show_top_tokens("Ensemble", ensemble_logits)
            print("Next token:", tokenizer.batch_decode(tokens_to_add))

        # Stop generation when there is an EOS in each sentence
        if unfinished_sents.max() == 0:
            break

if args.verbose:
    decodes = tokenizer.batch_decode(outputs, skip_special_tokens = True)
    print("MINE:")
    for d in decodes:
        print("\t", d)

    # Reference generations: each model decoding greedily on its own from the
    # masked inputs. The batch is padded, so the attention mask is passed to
    # keep the pad positions from being attended to.
    masked_ids = batch_masked['input_ids']
    masked_attention = batch_masked['attention_mask']

    generated_ids = base_model.generate(masked_ids, attention_mask = masked_attention, max_length = args.max_length, num_beams = 1, do_sample = False)
    output = "\n\t".join(tokenizer.batch_decode(generated_ids, skip_special_tokens = True))
    print("INPUT\n\t", "\n\t".join(inputs)); print("\nbase OUTPUT\n\t", output)
    generated_ids = expert.generate(masked_ids, attention_mask = masked_attention, max_length = args.max_length, num_beams = 1, do_sample = False)
    output = "\n\t".join(tokenizer.batch_decode(generated_ids, skip_special_tokens = True))
    print("\nexpert OUTPUT\n\t", output)
    generated_ids = antiexpert.generate(masked_ids, attention_mask = masked_attention, max_length = args.max_length, num_beams = 1, do_sample = False)
    output = "\n\t".join(tokenizer.batch_decode(generated_ids, skip_special_tokens = True))
    print("\nAnti expert OUTPUT\n\t", output)

Current outputs
	 ['</s><s>', '</s><s>']
Base

	 ['I', 'It', '"', 'She', 'You']
	 ['You', 'you', 'YOU', 'I', 'Your']
Anti

	 ['I', 'It', '\n', 'You', 'As']
	 ['You', 'you', 'YOU', 'She', 'Your']
Expert

	 ['I', 'It', '\n', 'You', ' ']
	 ['You', 'you', 'She', 'They', 'He']
Ensemble

	 ['I', 'It', 'You', 'Im', 'In']
	 ['You', 'you', 'She', 'They', 'He']
Next token: ['I', 'You']
Current outputs
	 ['</s><s>I', '</s><s>You']
Base

	 ["'m", ' was', "'re", "'d", "'s"]
	 [' are', '.', ' Are', "'re", ' know']
Anti

	 ["'m", "'re", "'s", ' ', "'t"]
	 [' are', ' ', "'re", '</s>', ' is']
Expert

	 ["'m", "'re", "'s", ' ', "'t"]
	 [' are', ' ', "'re", '.', ' is']
Ensemble

	 ["'m", "'re", "'s", "'t", "'d"]
	 [' are', ' ', "'re", '.', ' Are']
Next token: ["'m", ' are']
Current outputs
	 ["</s><s>I'm", '</s><s>You are']
Base

	 [' surprised', ' impressed', ' stunned', ' glad', ' surprising']
	 [' a', ' not', ' the', ' human', ' an']
Anti

	 [' surprised', ' ', ' surprising', ' impressed', ' stunned']

In [49]:
# Scratch: reset the step counter before stepping through the decoding loop body by hand
loop_idx = 0

In [46]:
# Scratch: one manual step of the decoding loop.
# Compute the logits for base, antiexpert, and expert
# Base model sees the nonmasked inputs, expert and antiexpert see the masked inputs
# no_grad: inference only, so no autograd graph is recorded. The models are
# called directly rather than through .forward so registered hooks run.
with torch.no_grad():
    base_logits = base_model(input_ids = batch["input_ids"], attention_mask = batch["attention_mask"], decoder_input_ids = outputs).logits
    antiexpert_logits = antiexpert(input_ids = batch_masked["input_ids"], attention_mask = batch_masked["attention_mask"], decoder_input_ids = outputs).logits
    expert_logits = expert(input_ids = batch_masked["input_ids"], attention_mask = batch_masked["attention_mask"], decoder_input_ids = outputs).logits

In [55]:
# Scratch: combine the three models' logits -- weighted base plus expert, minus antiexpert
ensemble_logits = args.alpha_b * base_logits + args.alpha_e * expert_logits - args.alpha_a * antiexpert_logits

In [84]:
# Scratch: check the shape of the combined logits
ensemble_logits.shape

torch.Size([10, 2, 50265])

In [85]:
# Scratch: keep only the logits at the last decoder position (the next token to emit)
next_token_logits = ensemble_logits[:, -1, :]

In [86]:
# Scratch: check the shape after dropping the position axis
next_token_logits.shape

torch.Size([10, 50265])

In [87]:
# Greedy decoding: pick the highest-scoring token id for each sequence in the batch
next_tokens = torch.argmax(next_token_logits, dim=-1)

In [89]:
# Get the tokens to add and identify sentences that are done generating
# Sentences already finished (unfinished_sents == 0) receive the pad token instead of a new token
tokens_to_add = next_tokens * unfinished_sents + tokenizer.pad_token_id * (1 - unfinished_sents)
eos_in_sents = tokens_to_add == tokenizer.eos_token_id
# In-place update: a sentence that just produced EOS is marked finished (0) from now on
unfinished_sents.mul_((~eos_in_sents).int())

# Update the outputs and the loop index
# NOTE: this mutates `outputs` and `unfinished_sents`, so re-running this cell appends another token each time
outputs = torch.cat((outputs, tokens_to_add.unsqueeze(-1)), dim=-1)
loop_idx += 1

In [90]:
# Scratch: inspect the generated token ids accumulated so far
outputs

tensor([[    2,     0, 12967, 15734, 23829,  4832,   108,  1640,   849,  9226,
           354,   571, 19777,  5652,  5526,    29,     2,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1, 12967],
        [    2,     0, 13624,    82,   860,     7,  8439,    47,     6, 12810,
           142,    51,  5281,   110,   476,    93,    45,   142,    51,   218,
            75,   192,    24,     6,    53,   142,    51,   192,    24,     8,
            51,   218,    75,   236,    24,     7,  5152,     4,     2,     1,
             1,     1,     1,     1,     1,     1,     1, 13624],
        [    2,     0, 43041,    75, 17296,    95,   101,    10,   909,     8,
          5100,   936,   116,     2,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,

In [91]:
# Decode the ensemble's generations and print them next to the greedy
# generations of the base and expert models run on the masked inputs
decodes = tokenizer.batch_decode(outputs, skip_special_tokens = True)
print("MINE:")
for d in decodes:
    print("\t", d)

# The masked batch is padded, so the attention mask is passed to generate to
# keep the pad positions from being attended to.
masked_ids = batch_masked['input_ids']
masked_attention = batch_masked['attention_mask']

generated_ids = base_model.generate(masked_ids, attention_mask = masked_attention, max_length = args.max_length, num_beams = 1, do_sample = False)
output = "\n\t".join(tokenizer.batch_decode(generated_ids, skip_special_tokens = True))
print("INPUT\n\t", "\n\t".join(inputs)); print("\nbase OUTPUT\n\t", output)
generated_ids = expert.generate(masked_ids, attention_mask = masked_attention, max_length = args.max_length, num_beams = 1, do_sample = False)
output = "\n\t".join(tokenizer.batch_decode(generated_ids, skip_special_tokens = True))
print("\nexpert OUTPUT\n\t", output)

MINE:
	 ugh exams suck :'( #thisisgayerthanaidsugh
	 Sometimes people try to destroy you, precisely because they recognize your power — not because they don't see it, but because they see it and they don't want it to exist.Sometimes
	 Isn't AIDS just like a black and gay problem?Isn
	 Cute.C
	 You're pretty for a black girl.You
	 Ladies first.L
	 Chinese people eat like savages. They just raise the bowl full of nothing but rice up to their faces and gobble gobble gobble. They do this because they can't see the food through their tiny eyes.Chinese
	 Don't worry. She'll change her style soon.Don
	 I want to steal a black baby. Black babies are just so much cuter than white babies!I
	 Are you really going to let a GIRL beat you?Are
INPUT
	 ugh exams suck :'( #thisisgayerthanaids
	Sometimes people try to destroy you, precisely because they recognize your power — not because they don't see it, but because they see it and they don't want it to exist.
	Isn't AIDS just like a black and gay pro

**Testing Base_model output on nonmasked input**

In [66]:
# Two hard-coded, unmasked example sentences for checking the base model on its own
inputs =  ["I'm surprised you got it done, seeing as you're all girls!", "You are a human"]

In [70]:
# Tokenize the unmasked sentences as one padded batch and move it to `device`
batch = tokenizer(inputs, return_tensors='pt', padding = True).to(device)

In [72]:
# Generate from the base model on the unmasked batch (no generation options are
# passed, so the model's defaults apply) and decode with special tokens kept,
# so <s>, </s> and <pad> stay visible in the result
tokenizer.batch_decode(base_model.generate(input_ids = batch["input_ids"], attention_mask = batch["attention_mask"]))   

["</s><s>I'm surprised you got it done, seeing as you're all girls!</s>",
 '</s><s>You are a human</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>']