In [37]:
import argparse
from pathlib import Path
from typing import Union, List
import os
from transformers import BartForConditionalGeneration, BartTokenizer
from IPython import embed
from training.infilling import *
from utils import preprocess, detokenize, seed_everything
import nltk.tokenize.casual
import torch
import torch.nn.functional as F
import sys
from rewrite import gen_utils
from rewrite import generation_logits_process
import pandas as pd
import functools
import operator
from tqdm import tqdm
import re
import html
import string

In [38]:
# DEBUG Purpose
# parser = argparse.ArgumentParser()
# args = parser.parse_args()
class MyDict:
    def __init__(self, data):
        self.data = data
        for key, value in data.items():
            setattr(self, key, value)

args = MyDict({
    "thresh": 1.5,
    "base_path" : "facebook/bart-base", 
    "antiexpert_path" : "CARC/anti-expert-model/bart-base_1e-06_0_32_jigsaw_full_30/checkpoint-40000",
    "expert_path" : "CARC/expert-model/bart-base_2e-06_0_96_jigsaw_full_30/checkpoint-90000", 
    "tokenizer" : "facebook/bart-base",
    "seed": 0
})

In [39]:
seed = args.seed
thresh = args.thresh
base_path = args.base_path
antiexpert_path = args.antiexpert_path
expert_path = args.expert_path
tokenizer = args.tokenizer

In [40]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if not torch.cuda.is_available():
    print("No GPUs found!")
else:
    print("Found", str(torch.cuda.device_count()), "GPUS!")

seed = seed
seed_everything(seed)

# Initalize self.tokenizer
tokenizer = BartTokenizer.from_pretrained(tokenizer)

# Initialize models
model = BartForConditionalGeneration.from_pretrained(base_path, forced_bos_token_id = tokenizer.bos_token_id).to(device)
antiexpert = BartForConditionalGeneration.from_pretrained(antiexpert_path, forced_bos_token_id = tokenizer.bos_token_id).to(device)
expert = BartForConditionalGeneration.from_pretrained(expert_path, forced_bos_token_id = tokenizer.bos_token_id).to(device)
model.eval()
antiexpert.eval()
expert.eval()

No GPUs found!


BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=

In [41]:
inputs =  ["I'm surprised you got it done, seeing as you're all girls!", "You are a human"]

In [42]:
batch = tokenizer(inputs, return_tensors='pt', padding = True).to(device)
cur_labels = ["KL(base || exp)","KL(base || anti)","JS(exp || anti)"]

In [93]:
# Jensen divergence
def js_div(a,b, reduction):
    return 0.5 * F.kl_div(F.log_softmax(a, dim=-1), F.softmax(b,dim=-1), reduction=reduction) + \
         0.5 * F.kl_div(F.log_softmax(b, dim=-1), F.softmax(a,dim=-1), reduction=reduction) 

In [94]:
use_base_model_for_divergence = False
outputs = []

for i in tqdm(range(len(inputs)), desc = "Identifying masks"):

    # Filter text until pad_token
    cur_tok = batch["input_ids"][i]    
    pad = torch.where(cur_tok == tokenizer.pad_token_id)[0]
    if len(pad) > 0:
        pad = pad[0]
        cur_tok = cur_tok[:pad]

    cur_seq = inputs[i]
    tok_map = {}
    casual = nltk.tokenize.casual.casual_tokenize(cur_seq)
    
    tok_map = {}
    old_idx = 1
    cur_idx = 0
    cur_word = casual[0]
    for new_idx, c in enumerate(cur_tok):
        d = tokenizer.decode(c).strip()
        if cur_word.startswith(d):
            cur_word = cur_word.replace(d, "", 1)
            if cur_word == "":
                tok_map[cur_idx] = list(np.arange(old_idx, new_idx+1))
                old_idx = new_idx+1
                cur_idx += 1
                try:
                    cur_word = casual[cur_idx]
                except:
                    break 
    #     break
    # break
    # Default MaRCO implementation: use only the expert and anti-expert and find divergence of prob. distributions on each token in the input
    if not use_base_model_for_divergence:
        # ignore start and end idxs
        ignore_idxs = []

        for c_idx, c in enumerate(casual):
            punc_only = True
            for k in c:
                if k not in string.punctuation:
                    punc_only = False
                    break
            if punc_only:
                ignore_idxs.append(c_idx)

        sum_divs_ea = []
        for j in range(len(casual)):
            new_seq = casual.copy()
            new_seq[j] = tokenizer.mask_token
            new_full_seq = detokenize(new_seq)
            new_full_seq = re.sub(r"\s*<mask>", "<mask>", new_full_seq)

            new_tok = tokenizer(new_full_seq,return_tensors="pt").input_ids.to(device)
            mask_idx = torch.nonzero(new_tok[0] == tokenizer.mask_token_id)

            expert_logits = expert.forward(input_ids = new_tok).logits
            antiexpert_logits = antiexpert.forward(input_ids = new_tok).logits
            divs_ea = js_div(expert_logits,antiexpert_logits, reduction='none').sum(dim = -1)
            all_divs = []
            for cor_idx in mask_idx:
                all_divs.append(divs_ea[0][cor_idx.item()].item())
            sum_divs_ea.append(np.mean(all_divs))

        # delete the ignore idxs
        mean_norm_ea = np.delete(sum_divs_ea, ignore_idxs)
        mean_norm_ea = np.array(mean_norm_ea) / mean_norm_ea.mean()
        above_thresh = np.nonzero(mean_norm_ea >= thresh)[0]

        new_casual=casual.copy()
        for a in above_thresh:
            num_below = (np.array(ignore_idxs <= a)).sum()
            new_casual[a + num_below] = tokenizer.mask_token

        outputs.append(re.sub(r"\s*<mask>", "<mask>",detokenize(new_casual)))

Identifying masks: 100%|██████████| 2/2 [00:01<00:00,  1.02it/s]


In [111]:
if use_base_model_for_divergence:
    outputs = tokenizer.batch_decode(outputs,skip_special_tokens=False)

In [112]:
outputs

["I'm surprised you got it done,<mask> as you're all<mask>!",
 'You are a<mask>']

In [105]:
np.array(ignore_idxs <= a)

array([], dtype=bool)

In [107]:
casual

['You', 'are', 'a', 'human']

In [109]:
new_seq

['You', 'are', 'a', '<mask>']

In [81]:
new_seq = casual.copy()
new_seq[0] = tokenizer.mask_token

In [82]:
new_seq

['<mask>', 'are', 'a', 'human']

In [83]:
new_full_seq = detokenize(new_seq)


In [84]:
new_tok = tokenizer(new_full_seq,return_tensors="pt").input_ids

In [85]:
mask_idx = torch.nonzero(new_tok[0] == tokenizer.mask_token_id)


In [86]:
new_tok

tensor([[    0, 50264,    32,    10,  1050,     2]])

In [87]:
mask_idx

tensor([[1]])

In [88]:
expert_logits = expert.forward(input_ids = new_tok).logits
antiexpert_logits = antiexpert.forward(input_ids = new_tok).logits