### Adversarial sample generation with code from the TextFooler paper

#### Requirements

1. TextFooler required resources: 

In [1]:
counter_fitting_embeddings_path = 'resources/embeddings/counter-fitted-vectors.txt'
counter_fitting_cos_sim_path = 'resources/cos_sim_counter_fitting.npy'
USE_cache_path = 'scratch/tf_cache'

2. Path of the tuned model to fool and its training task:

In [2]:
task = 'imdb' # can be imdb or mnli
#model_path = 'resources/models/IMDB/pytorch_model.bin'
#model_path = 'resources/models/IMDB_on_lightning/pytorch_model.bin'
model_path = 'resources/models/co-tuned_IMDB_on_lightning_final_filter/pytorch_model.bin'

#task = 'mnli' # can be imdb or mnli
#model_path = 'resources/models/MNLI/pytorch_model.bin'
#model_path = 'resources/models/MNLI_on_lightning/pytorch_model.bin'
#model_path = 'resources/models/co-tuned_MNLI_on_lightning_final_filter/pytorch_model.bin'

3. Path of the dataset to generate samples from and the name of the output file:

In [3]:
# IMDB:

#dataset_path = 'data/IMDB/imdb_train.txt'
#output_path = 'data/IMDB/generated/imdb_adversarial_samples_for_train'
#dataset_path = 'data/IMDB/imdb_dev.txt'
#output_path = 'data/IMDB/generated/imdb_adversarial_samples_for_dev'
#dataset_path = 'data/IMDB/imdb_test.txt'
#output_path = 'data/IMDB/generated/imdb_adversarial_samples_for_test'

# MNLI:

#dataset_path = 'data/MNLI/original/multinli_1.0_train.txt'
#output_path = 'data/MNLI/generated/mnli_adversarial_samples_for_train'
#dataset_path = 'data/MNLI/original/multinli_1.0_dev_matched.txt'
#output_path = 'data/MNLI/generated/mnli_adversarial_samples_for_dev'
#dataset_path = 'data/MNLI/original/multinli_1.0_dev_mismatched.txt'
#output_path = 'data/MNLI/generated/mnli_adversarial_samples_for_test'

4. The number of samples to process from the dataset and batch size:

In [4]:
# dataset size is 40000 for imdb train and 390702 for MNLI train
# dev and test dataset sizes are 5K for IMDB and 10K for MNLI
#data_size = 390703 #add one for the header row that is skipped in the logic
#batch_size = 32
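
The cosine-similarity file listed under the TextFooler resources is only a cache: when no path is given, `PaperFooler` computes the matrix from the counter-fitted vectors itself. The sketch below reproduces that computation on three made-up 2-d vectors (the words and values are illustrative assumptions, not entries from the real embedding file) and then picks nearest neighbours above a threshold, in the spirit of `pick_most_similar_words_batch`. The helper name `nearest_words` is hypothetical.

```python
import numpy as np

# Toy stand-in for counter-fitted-vectors.txt (hypothetical words and values).
words = ["good", "great", "bad"]
embeddings = np.array([[1.0, 0.0],
                       [0.8, 0.6],
                       [-1.0, 0.0]])

# Cosine similarity: dot products divided by the product of the row norms.
norm = np.linalg.norm(embeddings, axis=1, keepdims=True)
cos_sim = embeddings @ embeddings.T / (norm @ norm.T)

def nearest_words(word_idx, ret_count=1, threshold=0.5):
    # Sort by descending similarity; rank 0 is the word itself, so skip it.
    order = np.argsort(-cos_sim[word_idx])[1:1 + ret_count]
    # Keep only neighbours that are similar enough to count as synonyms.
    return [words[i] for i in order if cos_sim[word_idx, i] >= threshold]

print(nearest_words(0))  # neighbours of "good"
print(nearest_words(2))  # neighbours of "bad"
```

With the real vocabulary the matrix is large, which is presumably why the notebook loads a precomputed `.npy` file instead of recomputing it on every run.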

If we just want to run against the TextFooler sample data (for example, to compare baselines), use these dataset paths instead:

In [5]:
# run with just the textfooler sample data
dataset_path = 'data/TextFooler/imdb'
#dataset_path = 'data/TextFooler/mnli_matched'
data_size = 1000
batch_size = 16
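
For the IMDB task the dataset file is parsed by `read_corpus` (defined further down) with its default `MR=True`, which expects one example per line: an integer label, a space, then the text. A minimal sketch of that line format follows; the sample line is made up, the function name `parse_labelled_line` is hypothetical, and the `clean_str` step is left out for brevity.

```python
def parse_labelled_line(line):
    # Split off the leading integer label at the first space,
    # as read_corpus does when MR=True.
    label, _, text = line.partition(' ')
    # Simplification: read_corpus also runs clean_str on the text
    # before lowercasing and splitting it into tokens.
    return int(label), text.strip().lower().split()

label, tokens = parse_labelled_line("1 A hypothetical review line\n")
print(label, tokens)
```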

We can also run TextFooler against one of our evaluation perturbation models.

Imports

In [6]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, Dataset, SequentialSampler, DataLoader
import pytorch_lightning as pl

from processors import MnliProcessor, ImdbProcessor
from bert_base_model import LightningBertForSequenceClassification
from firebert_fse import FireBERT_FSE
from firebert_fve import FireBERT_FVE

import string
import switch
import numpy as np
import os
import random  # needed by read_corpus when shuffle=True
import re

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import tensorflow_hub as hub

from nltk.tokenize.treebank import TreebankWordDetokenizer

from transformers.data.processors.utils import DataProcessor, InputExample, InputFeatures



The TextFooler logic, generalized and encapsulated in a class:

In [7]:
"""
The TextFooler algorithm in a class for generating adversarial texts. Adapted to work with a BERT
classifier tuned as a PyTorch Lightning model.

This class was adapted from code by TextFooler at https://github.com/jind11/TextFooler,
a code repository in support of the paper:

Jin, Di, et al. "Is BERT Really Robust? Natural Language Attack on Text Classification and Entailment."
arXiv preprint arXiv:1907.11932 (2019).
"""


class PaperFooler(object):
    def __init__(self,
                 tokenizer,
                 lightning_model,
                 USE_cache_path,
                 counter_fitting_embeddings_path,
                 counter_fitting_cos_sim_path=None,
                 max_seq_length=128):
        super(PaperFooler, self).__init__()
        
        self.tokenizer = tokenizer
        self.model = lightning_model.cuda()
        self.max_seq_length = max_seq_length
        
        # prepare synonym extractor
        # build dictionary via the embedding file
        print("Building vocab...")
        self.idx2word = {}
        self.word2idx = {}

        with open(counter_fitting_embeddings_path, 'r', encoding="utf-8") as ifile:
            for line in ifile:
                word = line.split()[0]
                if word not in self.idx2word:
                    self.idx2word[len(self.idx2word)] = word
                    self.word2idx[word] = len(self.idx2word) - 1

        # for cosine similarity matrix
        print("Building cos sim matrix...")
        if counter_fitting_cos_sim_path:
            # load pre-computed cosine similarity matrix if provided
            print('Load pre-computed cosine similarity matrix from {}'.format(counter_fitting_cos_sim_path))
            self.cos_sim = np.load(counter_fitting_cos_sim_path)
        else:
            # calculate the cosine similarity matrix
            print('Start computing the cosine similarity matrix!')
            embeddings = []
            with open(counter_fitting_embeddings_path, 'r', encoding="utf-8") as ifile:
                for line in ifile:
                    embedding = [float(num) for num in line.strip().split()[1:]]
                    embeddings.append(embedding)
            embeddings = np.array(embeddings)
            product = np.dot(embeddings, embeddings.T)
            norm = np.linalg.norm(embeddings, axis=1, keepdims=True)
            self.cos_sim = product / np.dot(norm, norm.T)
        print("Cos sim import finished!")

        # build the semantic similarity module
        self.sim_predictor = USE(USE_cache_path)

        self.stop_words_set = switch.get_stopwords()
  

    def text_pred(self, text_data, batch_size):
        # Switch the model to eval mode.
        self.model.eval()

        # transform text data into a batch of indices
        batch = self.transform_text(text_data, batch_size)

        probs_all = []
        for input_ids, attention_mask, token_type_ids, ex_idx in batch:
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()
            ex_idx = ex_idx.cuda()

            with torch.no_grad():
                logits = self.model(input_ids=input_ids, attention_mask=attention_mask,
                                    token_type_ids=token_type_ids, example_idx=ex_idx)
                probs = nn.functional.softmax(logits, dim=-1)
                probs_all.append(probs)

        return torch.cat(probs_all, dim=0)


    def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
        """Truncates a sequence pair in place to the maximum length."""

        # This is a simple heuristic which will always truncate the longer sequence
        # one token at a time. This makes more sense than truncating an equal percent
        # of tokens from each, since if one sequence is very short then each token
        # that's truncated likely contains more information than a longer sequence.
        while True:
            total_length = len(tokens_a) + len(tokens_b)
            if total_length <= max_length:
                break
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

    def convert_examples_to_features(self, examples, max_seq_length, tokenizer):
        """Loads a data file into a list of `InputFeature`s."""

        features = []
        for (ex_index, (text_a, text_b)) in enumerate(examples):
            tokens_a = tokenizer.tokenize(' '.join(text_a))

            tokens_b = None
            if text_b:
                tokens_b = tokenizer.tokenize(' '.join(text_b))
                # Modifies `tokens_a` and `tokens_b` in place so that the total
                # length is less than the specified length.
                # Account for [CLS], [SEP], [SEP] with "- 3"
                self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
            else:
                # Account for [CLS] and [SEP] with "- 2"
                if len(tokens_a) > max_seq_length - 2:
                    tokens_a = tokens_a[:(max_seq_length - 2)]

            tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
            token_type_ids = [0] * len(tokens)

            if tokens_b:
                tokens += tokens_b + ["[SEP]"]
                token_type_ids += [1] * (len(tokens_b) + 1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            attention_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding = [0] * (max_seq_length - len(input_ids))
            input_ids += padding
            attention_mask += padding
            token_type_ids += padding

            assert len(input_ids) == max_seq_length
            assert len(attention_mask) == max_seq_length
            assert len(token_type_ids) == max_seq_length

            features.append(
                InputFeatures(input_ids=input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids))
        return features

    def transform_text(self, data, batch_size):
        # transform data into seq of embeddings
        eval_features = self.convert_examples_to_features(list(zip(data['text_a'], data['text_b'])),
                                                          self.max_seq_length, self.tokenizer)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in eval_features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in eval_features], dtype=torch.long)
        all_idxs = torch.tensor([i for i in range(len(all_input_ids))], dtype=torch.long)
        
        eval_data = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_idxs)

        # Run prediction for data sequentially
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=batch_size)

        return eval_dataloader
    
    def pick_most_similar_words_batch(self, src_words, sim_mat, idx2word, ret_count=10, threshold=0.):
        """
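        sim_mat is the (vocab_size, vocab_size) cosine similarity matrix and src_words is a
        list of row indices into it. For each source word, returns up to ret_count neighbours
        whose similarity is at least threshold (rank 0, the word itself, is skipped).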
        """
        sim_order = np.argsort(-sim_mat[src_words, :])[:, 1:1 + ret_count]
        sim_words, sim_values = [], []
        for idx, src_word in enumerate(src_words):
            sim_value = sim_mat[src_word][sim_order[idx]]
            mask = sim_value >= threshold
            sim_word, sim_value = sim_order[idx][mask], sim_value[mask]
            sim_word = [idx2word[id] for id in sim_word]
            sim_words.append(sim_word)
            sim_values.append(sim_value)
        return sim_words, sim_values

    def pos_filter(self, ori_pos, new_pos_list):
        # a synonym is compatible if its POS tag matches, or if both tags are NOUN/VERB
        same = [ori_pos == new_pos or {ori_pos, new_pos} <= {'NOUN', 'VERB'}
                for new_pos in new_pos_list]
        return same
         
    def generate_adversarial(self, task, text_a, text_b, true_label, batch_size,
           import_score_threshold=-1., sim_score_threshold=0.7, sim_score_window=15, synonym_num=50):
        # first check the prediction of the original text
        orig_probs = self.text_pred({'text_a': [text_a], 'text_b': [text_b]}, batch_size).squeeze()
        orig_label = torch.argmax(orig_probs)
        orig_prob = orig_probs.max()
        text_ls = text_b if task == 'mnli' else text_a
        if true_label != orig_label:
            return '', 0, orig_label, orig_label, 0
        else:
            len_text = len(text_ls)
            if len_text < sim_score_window:
                sim_score_threshold = 0.1  # shut down the similarity thresholding function
            half_sim_score_window = (sim_score_window - 1) // 2
            num_queries = 1

            # get the pos and verb tense info
            pos_ls = switch.get_pos(text_ls)

            # get importance score
            leave_1_texts = [text_ls[:ii]+['<oov>']+text_ls[min(ii+1, len_text):] for ii in range(len_text)]
            if task == 'mnli':
                leave_1_probs = self.text_pred({'text_a':[text_a]*len_text, 'text_b': leave_1_texts}, batch_size)
            else:
                leave_1_probs = self.text_pred({'text_a':leave_1_texts, 'text_b': [text_b]*len_text}, batch_size)                      
            num_queries += len(leave_1_texts)
            leave_1_probs_argmax = torch.argmax(leave_1_probs, dim=-1)
            import_scores = (orig_prob - leave_1_probs[:, orig_label] + (leave_1_probs_argmax != orig_label).float() * (
                        leave_1_probs.max(dim=-1)[0] - torch.index_select(orig_probs, 0,
                                                                          leave_1_probs_argmax))).data.cpu().numpy()

            # get the words to perturb, ranked by importance score (stop words excluded)
            words_perturb = []
            for idx, score in sorted(enumerate(import_scores), key=lambda x: x[1], reverse=True):
                if score > import_score_threshold and text_ls[idx] not in self.stop_words_set:
                    words_perturb.append((idx, text_ls[idx]))

            # find synonyms
            words_perturb_idx = [self.word2idx[word] for idx, word in words_perturb if word in self.word2idx]
            #src_words, sim_mat, idx2word, ret_count=10, threshold=0.
            synonym_words, _ = self.pick_most_similar_words_batch(words_perturb_idx, self.cos_sim, 
                                                                  self.idx2word, synonym_num, 0.5)
            synonyms_all = []
            for idx, word in words_perturb:
                if word in self.word2idx:
                    synonyms = synonym_words.pop(0)
                    if synonyms:
                        synonyms_all.append((idx, synonyms))

            # start replacing and attacking
            text_prime = text_ls[:]
            text_cache = text_prime[:]
            num_changed = 0
            for idx, synonyms in synonyms_all:
                new_texts = [text_prime[:idx] + [synonym] + text_prime[min(idx + 1, len_text):] for synonym in synonyms]
                if task == 'mnli':
                    new_probs = self.text_pred({'text_a': [text_a] * len(synonyms), 'text_b': new_texts}, batch_size)
                else:
                    new_probs = self.text_pred({'text_a': new_texts, 'text_b': [text_b] * len(synonyms)}, batch_size)
                
                # compute semantic similarity
                if idx >= half_sim_score_window and len_text - idx - 1 >= half_sim_score_window:
                    text_range_min = idx - half_sim_score_window
                    text_range_max = idx + half_sim_score_window + 1
                elif idx < half_sim_score_window and len_text - idx - 1 >= half_sim_score_window:
                    text_range_min = 0
                    text_range_max = sim_score_window
                elif idx >= half_sim_score_window and len_text - idx - 1 < half_sim_score_window:
                    text_range_min = len_text - sim_score_window
                    text_range_max = len_text
                else:
                    text_range_min = 0
                    text_range_max = len_text
                semantic_sims = \
                    self.sim_predictor.semantic_sim([' '.join(text_cache[text_range_min:text_range_max])] * len(new_texts),
                                           list(map(lambda x: ' '.join(x[text_range_min:text_range_max]), new_texts)))[0]
                
                num_queries += len(new_texts)
                if len(new_probs.shape) < 2:
                    new_probs = new_probs.unsqueeze(0)
                new_probs_mask = (orig_label != torch.argmax(new_probs, dim=-1)).data.cpu().numpy()
                # prevent bad synonyms
                new_probs_mask *= (semantic_sims >= sim_score_threshold)
                # prevent incompatible pos
                synonyms_pos_ls = [switch.get_pos(new_text[max(idx - 4, 0):idx + 5])[min(4, idx)]
                                   if len(new_text) > 10 else switch.get_pos(new_text)[idx] for new_text in new_texts]
                pos_mask = np.array(self.pos_filter(pos_ls[idx], synonyms_pos_ls))
                new_probs_mask *= pos_mask

                if np.sum(new_probs_mask) > 0:
                    text_prime[idx] = synonyms[(new_probs_mask * semantic_sims).argmax()]
                    num_changed += 1
                    break
                else:
                    new_label_probs = new_probs[:, orig_label] + torch.from_numpy(
                        (semantic_sims < sim_score_threshold) + (1 - pos_mask).astype(float)).float().cuda()
                    new_label_prob_min, new_label_prob_argmin = torch.min(new_label_probs, dim=-1)
                    if new_label_prob_min < orig_prob:
                        text_prime[idx] = synonyms[new_label_prob_argmin]
                        num_changed += 1
                text_cache = text_prime[:]
            
            if task == 'mnli':
                new_label = torch.argmax(self.text_pred({'text_a':[text_a], 'text_b': [text_prime]}, batch_size))
            else:
                new_label = torch.argmax(self.text_pred({'text_a':[text_prime], 'text_b': [text_b]}, batch_size))

            if true_label != new_label:
                return TreebankWordDetokenizer().detokenize(text_prime), num_changed, orig_label, new_label, num_queries
            else:
                return '', num_changed, orig_label, new_label, num_queries
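
The importance score in `generate_adversarial` can be easier to follow in isolation. The sketch below is a NumPy restatement of that expression under simplifying assumptions (plain arrays instead of CUDA tensors, made-up probabilities): each word's score is the drop in the original label's probability when the word is masked, plus a bonus when masking flips the prediction. The function name `importance_scores` is hypothetical.

```python
import numpy as np

def importance_scores(orig_probs, leave_1_probs):
    """orig_probs: (num_classes,); leave_1_probs: (num_words, num_classes)."""
    orig_label = int(np.argmax(orig_probs))
    new_labels = np.argmax(leave_1_probs, axis=-1)
    # Drop in the original label's probability when word i is masked.
    drop = orig_probs[orig_label] - leave_1_probs[:, orig_label]
    # Bonus, applied only where masking changes the predicted label:
    # how much the new label gained compared to the unmasked text.
    flipped = (new_labels != orig_label).astype(float)
    gain = leave_1_probs.max(axis=-1) - orig_probs[new_labels]
    return drop + flipped * gain

orig = np.array([0.9, 0.1])
masked = np.array([[0.8, 0.2],   # masking word 0 barely matters
                   [0.3, 0.7]])  # masking word 1 flips the label
scores = importance_scores(orig, masked)
print(scores)
```

Words are then attacked in descending score order, so in this toy case word 1 would be replaced first.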

Universal Sentence Encoder encapsulated in a class

In [8]:
"""
USE (Universal Sentence Encoder) in a class for determining semantic similarities.

This class was adapted from code by TextFooler at https://github.com/jind11/TextFooler,
a code repository in support of the paper:

Jin, Di, et al. "Is BERT Really Robust? Natural Language Attack on Text Classification and Entailment."
arXiv preprint arXiv:1907.11932 (2019).
"""

class USE(object):
    def __init__(self, cache_path):
        super(USE, self).__init__()
        #config =  tf.compat.v1.ConfigProto()
        #config.gpu_options.allow_growth = True
        #session =  tf.compat.v1.Session(config=config)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        #tf.compat.v1.disable_eager_execution()
        
        os.environ['TFHUB_CACHE_DIR'] = cache_path
        module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/3"
        self.embed = hub.Module(module_url)

        self.build_graph()
        self.sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

    def build_graph(self):
        self.sts_input1 = tf.placeholder(tf.string, shape=(None))
        self.sts_input2 = tf.placeholder(tf.string, shape=(None))

        sts_encode1 = tf.nn.l2_normalize(self.embed(self.sts_input1), axis=1)
        sts_encode2 = tf.nn.l2_normalize(self.embed(self.sts_input2), axis=1)
        self.cosine_similarities = tf.reduce_sum(tf.multiply(sts_encode1, sts_encode2), axis=1)
        clip_cosine_similarities = tf.clip_by_value(self.cosine_similarities, -1.0, 1.0)
        self.sim_scores = 1.0 - tf.acos(clip_cosine_similarities)

    def semantic_sim(self, sents1, sents2):
        scores = self.sess.run(
            [self.sim_scores],
            feed_dict={
                self.sts_input1: sents1,
                self.sts_input2: sents2,
            })
        return scores
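
The semantic check has two parts that are visible in the code above: a window of `sim_score_window` tokens around the replaced word, and an angular score `1 - arccos(cosine)` computed on the sentence embeddings. The sketch below restates both in NumPy; the 2-d vectors are made-up stand-ins for USE embeddings, and the names `window_bounds` and `angular_sim` are hypothetical.

```python
import numpy as np

def window_bounds(idx, len_text, window=15):
    """Token range compared around position idx, following generate_adversarial."""
    half = (window - 1) // 2
    if idx >= half and len_text - idx - 1 >= half:
        return idx - half, idx + half + 1   # centred on the replaced word
    if idx < half and len_text - idx - 1 >= half:
        return 0, window                    # clipped at the start
    if idx >= half and len_text - idx - 1 < half:
        return len_text - window, len_text  # clipped at the end
    return 0, len_text                      # text shorter than the window

def angular_sim(u, v):
    # L2-normalise, take the cosine, clip for numerical safety,
    # then convert the angle into a score with a maximum of 1.0.
    u = u / np.linalg.norm(u)
    v = v / np.linalg.norm(v)
    cos = np.clip(np.dot(u, v), -1.0, 1.0)
    return 1.0 - np.arccos(cos)

same = angular_sim(np.array([1.0, 0.0]), np.array([2.0, 0.0]))
orth = angular_sim(np.array([1.0, 0.0]), np.array([0.0, 3.0]))
min_cosine = np.cos(1.0 - 0.7)  # cosine implied by sim_score_threshold=0.7
print(window_bounds(10, 30), same, orth, min_cosine)
```

Under this score the default `sim_score_threshold=0.7` corresponds to an angle of at most 0.3 radians, i.e. a cosine of roughly 0.955. Note also that `semantic_sim` passes a one-element list to `sess.run`, so it returns a list holding one array, which is why the caller indexes the result with `[0]`.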

Routines to read and scrub datasets (mnli and imdb)

In [9]:
"""
Utilities for working with local datasets for processing.

These methods were adapted from code by TextFooler at https://github.com/jind11/TextFooler,
a code repository in support of the paper:

Jin, Di, et al. "Is BERT Really Robust? Natural Language Attack on Text Classification and Entailment."
arXiv preprint arXiv:1907.11932 (2019).
"""

def clean_str(string, TREC=False):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Every dataset is lower cased except for TREC
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip() if TREC else string.strip().lower()

def read_corpus(path, data_size, clean=True, MR=True, encoding='utf8', shuffle=False, lower=True):
    data = []
    labels = []
    empty = []
    with open(path, encoding=encoding) as fin:
        for idx, line in enumerate(fin):
            if idx >= data_size:
                break
            if MR:
                label, sep, text = line.partition(' ')
                label = int(label)
            else:
                label, sep, text = line.partition(',')
                label = int(label) - 1
            if clean:
                text = clean_str(text.strip())
            if lower:
                text = text.lower()
            labels.append(label)
            data.append(text.split())
            empty.append(None)

    if shuffle:
        perm = list(range(len(data)))
        random.shuffle(perm)
        data = [data[i] for i in perm]
        labels = [labels[i] for i in perm]

    return {"text_a": data,
            "text_b": empty,
            "label": labels}

def read_data(filepath, data_size, lowercase=False, ignore_punctuation=False, stopwords=[]):
    """
    Read the premises, hypotheses and labels from some NLI dataset's
    file and return them in a dictionary. The file should be in the same
    form as SNLI's .txt files.

    Args:
        filepath: The path to a file containing some premises, hypotheses
            and labels that must be read. The file should be formatted in
            the same way as the SNLI (and MultiNLI) dataset.

    Returns:
        A dictionary containing three lists, one for the premises, one for
        the hypotheses, and one for the labels in the input data.
    """
    
    labeldict = {"contradiction": 0,
                  "entailment": 1,
                  "neutral": 2}

    with open(filepath, 'r', encoding='utf8') as input_data:
        premises, hypotheses, labels = [], [], []

        # Translation tables to remove punctuation from strings.
        punct_table = str.maketrans({key: ' '
                                     for key in string.punctuation})

        for idx, line in enumerate(input_data):
            if idx >= data_size:
                break

            line = line.strip().split('\t')

            # Ignore sentences that have no gold label.
            if line[0] == '-':
                continue
            
            # skip the header row (if there is one)
            if line[0] == 'gold_label':
                continue

            premise = line[1]
            hypothesis = line[2]

            if lowercase:
                premise = premise.lower()
                hypothesis = hypothesis.lower()

            if ignore_punctuation:
                premise = premise.translate(punct_table)
                hypothesis = hypothesis.translate(punct_table)
                
            # strip ('s and )'s
            premise = premise.translate({ord(i):None for i in '()'})
            hypothesis = hypothesis.translate({ord(i):None for i in '()'})
            
            # Each premise and hypothesis is split into a list of words.
            premises.append([w for w in premise.rstrip().split()
                             if w not in stopwords])
            hypotheses.append([w for w in hypothesis.rstrip().split()
                             if w not in stopwords])
            labels.append(labeldict[line[0]])

        return {"text_a": premises,
                "text_b": hypotheses,
                "label": labels}

Routine to generate samples from the fooler and dataset path

In [10]:
import time

def elapsed_time():
    global t_start

    t_now = time.time()
    t = t_now-t_start
    t_start = t_now
    return t

def generate_samples(task, fooler, dataset_path, data_size, batch_size):
    
    # get data to attack, fetch first [data_size] data samples for adversarial attacking
    
    if task == 'mnli':
        dataloader = read_data
        labeldict = {0: "contradiction",
                     1: "entailment",
                     2:  "neutral"}
    else:
        #imdb
        dataloader = read_corpus
        labeldict = {0: 0, 1: 1}

    data = dataloader(dataset_path, data_size)
    print("Data import finished!")
        
    test_examples = [InputExample(i, TreebankWordDetokenizer().detokenize(a), 
                                  TreebankWordDetokenizer().detokenize(b) if b is not None else None,
                                  labeldict[label]) \
                     for i,(a,b,label) in \
                     enumerate(zip(data['text_a'], data['text_b'], data['label']))]
    

    fooler.model.set_test_dataset(None, examples=test_examples)
    
    orig_failures = 0.
    adv_failures = 0.
    changed_rates = []
    nums_queries = []
    
    adv_examples=[]
    
    for idx, text_a in enumerate(data['text_a']):
        if idx % 10 == 0:
            print('elapsed time: {}s - {} samples out of {} have been finished!'.format(
                round(elapsed_time(),2), idx, data_size))

            message = 'accuracy: {:.3f}%, adv accuracy: {:.3f}%, ' \
              'avg changed rate: {:.3f}%, num of queries: {:.1f}\n'.format((1-orig_failures/(idx+1))*100,
                                                                 (1-adv_failures/(idx+1))*100,
                                                                 np.mean(changed_rates)*100,
                                                                 np.mean(nums_queries))
            print(message)

        text_b, true_label = data['text_b'][idx], data['label'][idx]
                    
        new_text, num_changed, orig_label, \
            new_label, num_queries  = fooler.generate_adversarial(task, text_a, text_b, true_label,
                                                                  batch_size=batch_size,)
        if true_label != orig_label:
            orig_failures += 1
        else:
            nums_queries.append(num_queries)
        if true_label != new_label:
            adv_failures += 1
            if new_text != '':
                adv_examples.append(
                    InputExample(guid=idx,
                                 text_a=TreebankWordDetokenizer().detokenize(text_a) if task == 'mnli' else new_text,
                                 text_b=new_text if task == 'mnli' else text_b,
                                 label=labeldict[true_label] if task == 'mnli' else true_label))
            
        changed_rate = 1.0 * num_changed / (len(text_b) if task == 'mnli' else len(text_a))
        if true_label == orig_label and true_label != new_label:
            changed_rates.append(changed_rate)

            
    print('elapsed time: {}s - {} samples out of {} have been finished!'.format(
        round(elapsed_time(),2), idx+1, data_size))            
            
    message = 'For target model {}: original accuracy: {:.3f}%, adv accuracy: {:.3f}%, ' \
              'avg changed rate: {:.3f}%, num of queries: {:.1f}\n'.format(task,
                                                                 (1-orig_failures/(idx+1))*100,
                                                                 (1-adv_failures/(idx+1))*100,
                                                                 np.mean(changed_rates)*100,
                                                                 np.mean(nums_queries))
    print(message)

    return adv_examples
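
The summary figures printed by `generate_samples` can be restated compactly. The sketch below assumes one tuple per processed sample (a hypothetical layout, not the notebook's own data structure) and mirrors the bookkeeping above: accuracy counts samples the model already got wrong, adversarial accuracy counts samples that are wrong after the attack, and the changed rate is averaged only over successful attacks on correctly classified samples.

```python
import numpy as np

def summarize(results):
    """results: list of (true_label, orig_label, new_label, num_changed, num_words)."""
    n = len(results)
    orig_failures = sum(t != o for t, o, _, _, _ in results)
    adv_failures = sum(t != a for t, _, a, _, _ in results)
    # Only successful attacks on correctly classified samples contribute.
    rates = [c / w for t, o, a, c, w in results if t == o and t != a]
    return {
        "accuracy": 1 - orig_failures / n,
        "adv_accuracy": 1 - adv_failures / n,
        # Guard against the empty case, which would otherwise produce
        # a NumPy "mean of empty slice" warning.
        "avg_changed_rate": float(np.mean(rates)) if rates else float("nan"),
    }

stats = summarize([
    (1, 1, 0, 2, 20),  # correct, then fooled by changing 2 of 20 words
    (0, 0, 0, 3, 30),  # correct, attack failed
    (1, 0, 0, 0, 10),  # already misclassified, so no attack was attempted
])
print(stats)
```

One detail worth knowing when reading the progress lines: they are printed before sample `idx` is processed but divide by `idx + 1`, so the running percentages are slightly off until the final summary, where `idx + 1` equals the number of processed samples.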

Routines to create and save the samples

In [11]:
use_eval_model = False

eval_model_type = 'FUSE' # can be FUSE or FIVE


#FIVE best MNLI params as of 4/14/2020
# eval_model_hparams =  {'batch_size': 8, 'use_USE': False, 'stop_words': True, 'perturb_words': 1, 
#                          'verbose': False, 'vote_avg_logits': True, 'std': 8.139999999999995, 'vector_count': 8}

eval_model_hparams =  {'use_USE':True, 'USE_method':"filter", 'USE_multiplier':17, 'stop_words':True, 'perturb_words':3,
            'candidates_per_word':13, 'total_alternatives':12, 'match_pos':True, 'batch_size':1,'verbose':False, 
            'vote_avg_logits':True}


In [12]:
def create_examples(bert_model, task, dataset_path, data_size, batch_size, max_seq_length,
                    counter_fitting_embeddings_path, counter_fitting_cos_sim_path, USE_cache_path):
    
    
    print("Building TextFooler...")
    fooler = PaperFooler(bert_model.tokenizer,
                         bert_model,
                         USE_cache_path,
                         counter_fitting_embeddings_path,
                         counter_fitting_cos_sim_path,
                         max_seq_length = max_seq_length)
    print("TextFooler built!")

    
    return generate_samples(task, fooler, dataset_path, data_size, batch_size)


In [13]:
def save_examples(bert_model, adv_examples, output_path):
    
    features = bert_model.get_processor()._create_features(adv_examples)
    
    torch.save(features, output_path)
    
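    with open(output_path + '.txt', 'w') as output:
        for row in adv_examples:
            # `task` is the notebook-level setting from the configuration cells
            if task == 'mnli':
                # MNLI: label, premise and hypothesis separated by tabs
                output.write(row.label + '\t' + row.text_a + '\t' + row.text_b + '\n')
            else:
                # IMDB: integer label, a space, then the text
                output.write(str(row.label) + ' ' + row.text_a + '\n')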
            
    import pickle
    with open(output_path + '.pkl', "wb") as f:
        pickle.dump(adv_examples, f)
    
    print('\nPyTorch-ready InputFeature file saved in {}'.format(output_path))
    print('Pickled InputExample file saved in {}.pkl'.format(output_path))
    print('Raw text saved in {}.txt'.format(output_path))
    

In [14]:
processor = MnliProcessor() if task == 'mnli' else ImdbProcessor()

if use_eval_model:
    bert_model = (FireBERT_FSE(load_from=model_path, 
                               processor=processor, 
                               hparams=eval_model_hparams) if eval_model_type == 'FUSE' else
                  FireBERT_FVE(load_from=model_path, 
                               processor=processor, 
                               hparams=eval_model_hparams))
else:    
    bert_model = LightningBertForSequenceClassification(
        load_from=model_path, 
        processor=processor)

In [15]:
t_start = time.time()

In [16]:
adv_examples = create_examples(bert_model, task, dataset_path, data_size, batch_size, 128 if task =='mnli' else 256, 
                               counter_fitting_embeddings_path, counter_fitting_cos_sim_path, USE_cache_path)

Building TextFooler...
Building vocab...
Building cos sim matrix...
Load pre-computed cosine similarity matrix from resources/cos_sim_counter_fitting.npy


INFO:absl:Using scratch/tf_cache to cache modules.


Cos sim import finished!
INFO:tensorflow:Saver not created because there are no variables in the graph to restore




TextFooler built!
Data import finished!
elapsed time: 14.77s - 0 samples out of 1000 have been finished!
accuracy: 100.000%, adv accuracy: 100.000%, avg changed rate: nan%, num of queries: nan





elapsed time: 115.93s - 10 samples out of 1000 have been finished!
accuracy: 90.909%, adv accuracy: 18.182%, avg changed rate: 13.232%, num of queries: 1305.8

elapsed time: 90.54s - 20 samples out of 1000 have been finished!
accuracy: 95.238%, adv accuracy: 9.524%, avg changed rate: 9.344%, num of queries: 1067.3

elapsed time: 157.43s - 30 samples out of 1000 have been finished!
accuracy: 96.774%, adv accuracy: 6.452%, avg changed rate: 8.907%, num of queries: 1153.0

elapsed time: 79.96s - 40 samples out of 1000 have been finished!
accuracy: 97.561%, adv accuracy: 4.878%, avg changed rate: 8.754%, num of queries: 1060.2

elapsed time: 123.51s - 50 samples out of 1000 have been finished!
accuracy: 96.078%, adv accuracy: 5.882%, avg changed rate: 8.147%, num of queries: 1075.1

elapsed time: 84.3s - 60 samples out of 1000 have been finished!
accuracy: 93.443%, adv accuracy: 4.918%, avg changed rate: 7.693%, num of queries: 1046.3

elapsed time: 82.34s - 70 samples out of 1000 have bee

elapsed time: 85.52s - 530 samples out of 1000 have been finished!
accuracy: 94.727%, adv accuracy: 2.825%, avg changed rate: 9.597%, num of queries: 1121.9

elapsed time: 126.56s - 540 samples out of 1000 have been finished!
accuracy: 94.824%, adv accuracy: 2.773%, avg changed rate: 9.566%, num of queries: 1121.1

elapsed time: 64.37s - 550 samples out of 1000 have been finished!
accuracy: 94.737%, adv accuracy: 2.722%, avg changed rate: 9.570%, num of queries: 1115.0

elapsed time: 163.43s - 560 samples out of 1000 have been finished!
accuracy: 94.831%, adv accuracy: 2.852%, avg changed rate: 9.583%, num of queries: 1120.7

elapsed time: 71.1s - 570 samples out of 1000 have been finished!
accuracy: 94.746%, adv accuracy: 2.802%, avg changed rate: 9.619%, num of queries: 1115.8

elapsed time: 81.1s - 580 samples out of 1000 have been finished!
accuracy: 94.836%, adv accuracy: 2.754%, avg changed rate: 9.631%, num of queries: 1110.8

elapsed time: 69.23s - 590 samples out of 1000 have 

Uncomment if you want to see the samples

In [17]:
#adv_examples

Comment or uncomment depending on whether you want to save them

In [18]:
#save_examples(bert_model, adv_examples, output_path)