In [575]:
from datasets import load_dataset, DatasetDict, load_from_disk
from transformers import AutoTokenizer, DataCollatorWithPadding
import torch
import pandas as pd
import numpy as np
from collections import Counter

In [576]:
imdb = load_from_disk("../data/imdb2")

In [577]:
imdb 

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    dev: Dataset({
        features: ['text', 'label'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 24000
    })
    attack_eval_truncated: Dataset({
        features: ['text', 'label'],
        num_rows: 500
    })
})

## 1. TF IDF method

In [4]:

def truncate(example, max_words=110):
    text = example['text']
    label = example['label']
    
    text = text.split()
    if len(text) > max_words:
        text = text[:max_words]
        
    # if label==0:
    #     count_neg.update(text)
    # elif label==1:
    #     count_pos.update(text)
    
    text = " ".join(text)
    return {'text_truncated': text}



In [5]:
pos_corpus = []
neg_corpus= []

for label, text in zip(imdb['train']['label'],imdb['train']['text']):
    text = text.lower().split()
    if label==0:
        neg_corpus.extend(text)
    elif label==1:
        pos_corpus.extend(text) 

In [6]:
len(neg_corpus), len(pos_corpus)

(2885848, 2958832)

In [7]:
neg_corpus = " ".join(neg_corpus)

In [8]:
pos_corpus = " ".join(pos_corpus)

In [9]:
len(pos_corpus)

16851768

In [10]:
from sklearn.feature_extraction import text
from sklearn.feature_extraction.text import TfidfVectorizer

stop_words = text.ENGLISH_STOP_WORDS.union(["book"])
from sklearn.feature_extraction.text import TfidfVectorizer
corpus = [neg_corpus, pos_corpus]
vectorizer = TfidfVectorizer(stop_words=stop_words)
X = vectorizer.fit_transform(corpus)

In [11]:
X.shape

(2, 74537)

#### TF-IDF does not work

In [578]:
neg_word_scores = dict(zip(vectorizer.get_feature_names(), X.toarray()[0]))
sorted(neg_word_scores.items(), key = lambda x:x[1], reverse=True)[:20]



[('br', 0.7372279087289554),
 ('movie', 0.349719652957164),
 ('film', 0.2691842689007864),
 ('like', 0.15744317429178104),
 ('just', 0.14874535281369225),
 ('good', 0.1039676792783463),
 ('bad', 0.10365954389586972),
 ('really', 0.08772054092958141),
 ('time', 0.08699222093463678),
 ('don', 0.0751150025555397),
 ('story', 0.07295805487820367),
 ('people', 0.06732758107113171),
 ('make', 0.0661370580024722),
 ('plot', 0.05819556882682592),
 ('movies', 0.05715911344940472),
 ('acting', 0.05680895960568134),
 ('way', 0.05587054730450267),
 ('think', 0.051010411953622145),
 ('characters', 0.05042215349616687),
 ('watch', 0.049721845808720105)]

In [579]:
pos_word_scores = dict(zip(vectorizer.get_feature_names(), X.toarray()[1]))
sorted(pos_word_scores.items(), key = lambda x:x[1], reverse=True)[:20]

[('br', 0.7403070234791804),
 ('film', 0.314857907416554),
 ('movie', 0.2868605137389216),
 ('like', 0.13592719594296315),
 ('good', 0.11613956432117781),
 ('just', 0.10756893360353523),
 ('story', 0.10193036076298088),
 ('time', 0.09797584167747211),
 ('great', 0.09651733083604873),
 ('really', 0.08232316347209329),
 ('people', 0.06736215020182244),
 ('best', 0.06495635912318592),
 ('love', 0.06468570762683933),
 ('life', 0.06319712439693298),
 ('way', 0.060701116152847595),
 ('films', 0.05733300864275647),
 ('think', 0.05495728995260291),
 ('movies', 0.05393482874418239),
 ('characters', 0.053513815305421),
 ('character', 0.05286725895303744)]

### PreProcessing

In [580]:
import re                                  # library for regular expression operations
import string                              # for string operations
from nltk.corpus import stopwords          # module for stop words that come with NLTK
from nltk.stem import PorterStemmer        # module for stemming
from nltk import tokenize   # module for tokenizing strings

In [585]:
stopwords_english = stopwords.words('english') 

def preprocess(sent):
    clean_sent = []
    tokenizer = tokenize.TweetTokenizer(preserve_case=False, strip_handles=True, reduce_len=True)
    stemmer = PorterStemmer()
    
    sent = tokenizer.tokenize(sent,)
    #sent = tokenize.wordpunct_tokenize(sent.lower())
    
    for word in sent:
        if (word not in stopwords_english and word not in string.punctuation):  # remove stopwords, punctuation 
            word = stemmer.stem(word) 
            clean_sent.append(word)
    return clean_sent




In [586]:
sent = "Hello! my favorite movie is The Wonderful Life! It is a 10/10 movie. I liked it a lot!!! :) fucking-shit!!"
print(sent)
print(preprocess(sent))

Hello! my favorite movie is The Wonderful Life! It is a 10/10 movie. I liked it a lot!!! :) fucking-shit!!
['hello', 'favorit', 'movi', 'wonder', 'life', '10/10', 'movi', 'like', 'lot', ':)', 'fucking-shit']


In [587]:
print(preprocess(sent))

['hello', 'favorit', 'movi', 'wonder', 'life', '10/10', 'movi', 'like', 'lot', ':)', 'fucking-shit']


## 2. Simple count based method

In [18]:
positive_counts = Counter()
negative_counts = Counter()
total_counts = Counter()

In [19]:
splits = ['train', 'test', 'dev']
for split in splits:
    for label, text in zip(imdb[split]['label'],imdb[split]['text']):
        text = text.lower()
        text = preprocess(text)
        for word in text:
            if label==1:
                positive_counts[word]+=1
            elif label==0:
                negative_counts[word]+=1
            total_counts[word]+=1

In [20]:
positive_counts.most_common()[0:10]

[('br', 97954),
 ('film', 40342),
 ('movie', 36987),
 ('one', 26409),
 ('like', 17269),
 ('good', 14667),
 ('great', 12879),
 ('story', 12675),
 ('time', 12212),
 ('see', 12017)]

In [21]:
negative_counts.most_common()[0:10]

[('br', 103997),
 ('movie', 48870),
 ('film', 36029),
 ('one', 25292),
 ('like', 21975),
 ('even', 15172),
 ('bad', 14493),
 ('good', 14424),
 ('...', 14038),
 ('would', 13666)]

#### compute ratios of +ve to -ve counts

In [22]:
pos_neg_ratios = Counter()

for term,cnt in list(total_counts.most_common()):
    if(cnt > 100):
        pos_neg_ratio = positive_counts[term] / float(negative_counts[term]+1)
        pos_neg_ratios[term] = pos_neg_ratio

In [23]:
pos_neg_ratios.most_common()[0:20]

[('ponyo', 134.0),
 ('7/10', 26.6),
 ('custer', 19.25),
 ('edie', 19.0),
 ('9/10', 17.166666666666668),
 ('felix', 15.8),
 ('matthau', 15.266666666666667),
 ('8/', 15.0),
 ('10/10', 13.08108108108108),
 ('miyazaki', 12.222222222222221),
 ('paulie', 11.666666666666666),
 ('devito', 11.222222222222221),
 ('haines', 9.818181818181818),
 ('flawless', 8.48),
 ('perfection', 7.757575757575758),
 ('superbly', 7.571428571428571),
 ('understated', 7.5),
 ('wonderfully', 7.333333333333333),
 ('lemmon', 7.2),
 ('must-see', 7.068965517241379)]

In [24]:
pos_neg_ratios.most_common()[::-1][0:20]

[('uwe', 0.006578947368421052),
 ('boll', 0.010810810810810811),
 ('2/10', 0.017094017094017096),
 ('3/10', 0.01910828025477707),
 ('1/10', 0.03103448275862069),
 ('4/10', 0.03333333333333333),
 ('stinker', 0.04225352112676056),
 ('3k', 0.048701298701298704),
 ('mst', 0.05309734513274336),
 ('turd', 0.06),
 ('waste', 0.0671527244819647),
 ('unwatchable', 0.07142857142857142),
 ('yawn', 0.072),
 ('seagal', 0.07222222222222222),
 ('incoherent', 0.08444444444444445),
 ('unfunny', 0.08478260869565217),
 ('wasting', 0.09056603773584905),
 ('worst', 0.09135802469135802),
 ('camcorder', 0.09259259259259259),
 ('ugh', 0.096)]

## 3. Conditional Probability Score

In [25]:
# get the sentiment score of the words that have appeared at least 100 times
# V: total words
V = len(total_counts.keys())  
 
# get the number of unique positive and negative words
N_pos = len(positive_counts.keys()) 
N_neg = len(negative_counts.keys()) 
 
def word_loglikelihood(w):
    if w in total_counts:
        p_w_pos = (positive_counts.get(w,0)+1 / (N_pos+V))
        p_w_neg = (negative_counts.get(w,0)+1 / (N_neg+V)) 
        return np.log(p_w_pos/p_w_neg)
    else:
        return(0)


word_scores = {}
 
for word in total_counts.keys():
    if total_counts[word]>=100:
        word_scores[word] = (word_loglikelihood(word))

In [26]:
dict(sorted(word_scores.items(), key=lambda item: item[1], reverse=True)[0:20])

{'ponyo': 17.32426012139519,
 '7/10': 3.349903810668324,
 'edie': 3.1267597683985526,
 'custer': 3.0910419060005108,
 '9/10': 2.9001277096823697,
 'felix': 2.865370035022651,
 'matthau': 2.794664404687918,
 '8/': 2.75684017726172,
 'miyazaki': 2.6210383586514943,
 '10/10': 2.598565865006212,
 'paulie': 2.574518344727968,
 'devito': 2.5356785129027952,
 'haines': 2.3795457696145554,
 'flawless': 2.178532295838215,
 'perfection': 2.0794414317722754,
 'understated': 2.061422869126818,
 'superbly': 2.0607492787521955,
 'lemmon': 2.0253741367470024,
 'wonderfully': 2.006416357656412,
 'must-see': 1.9908053449902765}

In [286]:
dict(sorted(word_scores.items(), key=lambda item: item[1])[0:20])

{'uwe': -5.017275910130278,
 'boll': -4.521786622222644,
 '2/10': -4.064743121347573,
 '3/10': -3.954443075251533,
 '1/10': -3.4692016854051477,
 '4/10': -3.398415420435532,
 'stinker': -3.159361277005897,
 '3k': -3.0187972960003715,
 'mst': -2.93267392982896,
 'turd': -2.8033597625478,
 'waste': -2.700401959213723,
 'unwatchable': -2.63428380650453,
 'seagal': -2.625225709591731,
 'yawn': -2.623056581366551,
 'incoherent': -2.4672068825296187,
 'unfunny': -2.4654884713043437,
 'wasting': -2.3978951232732917,
 'worst': -2.3927633638954666,
 'camcorder': -2.3702433836301124,
 'awful': -2.3371348865016888}

### Eval attack using imdb

In [28]:
import sys
import math
sys.path.append("..")
from resilient_nlp.models import BertClassifier
from resilient_nlp.perturbers import ToyPerturber, WordScramblerPerturber
from transformers import AutoTokenizer, DataCollatorWithPadding, TextClassificationPipeline, \
                         AutoModelForSequenceClassification, BertForSequenceClassification
from sklearn.metrics import classification_report
import torch
from tqdm import tqdm
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [29]:
checkpoint_finetuned = "../output/huggingface/bert-base-uncased-imdb"
model_finetuned = BertForSequenceClassification.from_pretrained(checkpoint_finetuned)

In [30]:
max_sequence_length = 128
batch_size = 32
eval_steps = 100

In [88]:
tokenizer_checkpoint = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

In [165]:
Counter(np.array(['S', 'F']))

Counter({'S': 1, 'F': 1})

In [573]:
wsp = WordScramblerPerturber(perturb_prob=1, weight_add=1, weight_drop=1, weight_swap=1, weight_split_word=1, weight_merge_words=1)

def bert_tokenize(sent):
    return bert_tokenizer(sent, truncation=True, padding='max_length', max_length=max_sequence_length,
                          return_tensors='pt')

def get_bert_output(text):
    tokenized = bert_tokenize(text)
    output = model_finetuned(**tokenized)
    logits = output['logits']
    pred = torch.argmax(logits, dim=1).item()
    smax = torch.nn.Softmax()
    probs = smax(logits)
    return pred, np.round(probs[0][pred].item(),4)

def perturb_word(word):
    pass

def compute_attack_stats(results):
    attack_stats = {}
    attack_stats['total_attacks'] = len(results)
    attack_stats['avg_n_queries'] = np.round(results.num_queries.mean(),2)
    attack_stats['successful_attacks'] = results.loc[results['attack_status']=="Successful"].shape[0]
    attack_stats['failed_attacks'] = results.loc[results['attack_status']=="Failed"].shape[0]
    attack_stats['skipped_attacks'] = results.loc[results['attack_status']=="Skipped"].shape[0]
    attack_stats['attack_success_rate'] = 100*np.round(attack_stats['successful_attacks']/(attack_stats['successful_attacks'] + attack_stats['failed_attacks']),2)
    attack_stats['orig_accuracy'] = (attack_stats['total_attacks']  - attack_stats['skipped_attacks']) * 100.0 / (attack_stats['total_attacks'])
    attack_stats['attack_accuracy'] = (attack_stats['failed_attacks']) * 100.0 / (attack_stats['total_attacks'])
    return attack_stats

def word_score_attack( dataset, max_tokens_to_query=-1, max_tries_per_token=1, mode=0  ):
    """
    mode 0: Perserve best unsuccessful perturbation per token. Final attack can perturb utpo max_tokens_to_query tokens. 
    mode 1: Forgets unccessful perturbations. Final Attacks perturbs only 1 token per sample.
    """
    actuals = dataset['label']
    orig_texts = dataset['text']
    n_samples = len(actuals)
    
    orig_preds = np.zeros(n_samples)
    attack_status = np.empty(n_samples, dtype='object')
    perturbed_texts = np.empty(n_samples, dtype='object')
    orig_tokens=  np.empty(n_samples, dtype='object')
    perturbed_tokens =  np.empty(n_samples, dtype='object')
    n_queries =  np.empty(n_samples, dtype='object')
    
    for sample_idx, (orig_text, ground_truth) in enumerate(zip(orig_texts, actuals)):
        print(f"------------- Sample: {sample_idx} ---------------------------------")
        # print(orig_text)
        orig_pred, orig_score = get_bert_output(orig_text)
        orig_preds[sample_idx] = orig_pred
        
        
        if ground_truth!=orig_pred:  #Model has an error. skip_attack
            #print(f'Sample {sample_idx}. Attack Skipped')
            attack_status[sample_idx]='Skipped'
            continue
            
        orig_text = orig_text.lower()
        tokens = preprocess(orig_text)
        token_scores = {token: word_scores[token] if token in word_scores else 0 for token in tokens  }
        
        if max_tokens_to_query==-1:
            max_tokens_to_query = len(token_scores)
        else:
            max_tokens_to_query = min(max_tokens_to_query, len(token_scores))
       
        if orig_pred == 0: # fetch -ve sentiment tokens
            attack_tokens = sorted(token_scores.items(), key=lambda item: item[1])[:max_tokens_to_query]
        else: # fetch +ve sentiment tokens
            attack_tokens = sorted(token_scores.items(), key=lambda item: item[1], reverse=True)[:max_tokens_to_query]
        
        attack_passed = False
        token_idx=0
        sample_query_counter=0    
        text=orig_text
        worst_score = orig_score
        worst_text=orig_text
        
        while token_idx<max_tokens_to_query and not attack_passed:
            #print(f"----- token_idx: {token_idx} --------------")
            #token_idx = np.random.choice(top_n_tokens)
            attack_token = attack_tokens[token_idx][0]
            token_tries_counter = 0
            
            candidates = []
            
            
            for n_try in range(max_tries_per_token):
                perturbed_token = wsp.perturb([attack_token])[0][0]
                perturbed_text = text.replace(attack_token, perturbed_token, 1)
                perturbed_pred, perturbed_score = get_bert_output(perturbed_text)
                
                
                # print(f"----- n_try: {n_try}----")
                print(sample_idx,sample_query_counter, token_tries_counter, 
                      attack_token, perturbed_token, orig_pred, perturbed_pred,  
                      worst_score, perturbed_score)
                print(perturbed_text)
                
                
                sample_query_counter+=1 # increment sample_query_counter
                token_tries_counter+=1 ## increment token_tries_counter
                
                if perturbed_pred != orig_pred: # success
                    attack_passed = True
                    attack_status[sample_idx]='Successful'
                    perturbed_texts[sample_idx] = perturbed_text
                    orig_tokens[sample_idx] = attack_token
                    perturbed_tokens[sample_idx] = perturbed_token    
                    break     
                
                # track best attack (worse_score/worse_text) so far.
                if perturbed_score < worst_score:
                    worst_score = perturbed_score
                    worst_text = perturbed_text
                    
                
                    
            if mode==0:  ## if tries exhausted, update text to worst text. Worst perturbation per toekn are maintained.
                text= worst_text
            token_idx+=1 ## move to next token
                
        n_queries[sample_idx] = sample_query_counter
        
        if attack_passed == False: # attack failed
            #print(f'Sample {sample_idx}. Max tries exhausted')
            attack_status[sample_idx]='Failed'
        

            
    print(classification_report(actuals, orig_preds))
    status_counts = Counter(attack_status)
    print(status_counts)

    results = {'attack_status':attack_status,
               'ground_truth': actuals,
               'orig_prediction':orig_preds,
               'attacked_token':orig_tokens,
               'perturbed_token':perturbed_tokens,
               'num_queries':n_queries,
               'original_text':orig_texts,
               'perturbed_text':perturbed_texts,
              }
    
    results = pd.DataFrame.from_dict(results)
    
    success_rate = np.round(100*status_counts['Successful']/(status_counts['Successful'] + status_counts['Failed']) , 2)
    print(f'Success Rate {success_rate}')
    print(f'Avg Queries: {results.num_queries.mean()}')
    
    return results

In [565]:
text = "rivalry between brothers leads to main story line. navy commander chuck prescott(marshall thompson)has developed the y12 aircraft to test how far man can go up in the atmosphere. his brother, lt. dan prescott(bill edwards), seems to be the best test pilot around and is chosen to go up in the y12. dan of course has a problem with taking orders and is also an over confident dare devil. <br /><br />on dan's second flight, he hits over the 300 miles up comfort zone and his craft passes through a meteor dust storm. returning to earth, dan becomes a moster that resembles 200 pounds of bad asphalt. he also has a"
tokenized = bert_tokenize(text)
output = model_finetuned(**tokenized)
logits = output['logits']
pred = torch.argmax(logits, dim=1).item()
smax= torch.nn.Softmax()
probs = smax(logits)
print(pred, logits)
print(probs[0][pred].item())

1 tensor([[-0.5820,  0.9800]], grad_fn=<AddmmBackward0>)
0.8266438245773315


  import sys


In [574]:
results_20_4_0 = word_score_attack( imdb['attack_eval_truncated'][:3], max_tokens_to_query=10, max_tries_per_token=3, mode=0,)

------------- Sample: 0 ---------------------------------


  del sys.path[0]


0 0 0 simplicity simp licity 1 1 0.9955 0.9955
there must have been some interesting conversations on the set of eagle's wing, with martin sheen straight off apocalypse now co-starred with the actor he replaced on coppola's film, harvey keitel. a real unloved child of a movie, dating back to the last major batch of westerns in 1979-80, it was much reviled at the time for being made by a british studio and director (conveniently ignoring the fact that many of the classic american westerns were directed by european émigrés), which seems a bit of an over-reaction.<br /><br />the plot is simp licity itself, as martin sheen's inexperienced trapper finds himself fighting with sam waterston's nonosyllabic kiowa warrior over the
0 1 1 simplicity simpicity 1 1 0.9955 0.9955
there must have been some interesting conversations on the set of eagle's wing, with martin sheen straight off apocalypse now co-starred with the actor he replaced on coppola's film, harvey keitel. a real unloved child of a 

In [568]:
compute_attack_stats(results_20_4_0)

{'total_attacks': 100,
 'avg_n_queries': 51.6,
 'successful_attacks': 43,
 'failed_attacks': 47,
 'skipped_attacks': 10,
 'attack_success_rate': 48.0,
 'orig_accuracy': 90.0,
 'attack_accuracy': 47.0}

In [496]:
results_20_4_1 = word_score_attack( imdb['attack_eval_truncated'][:20],max_tokens_to_query=10, max_tries_per_token=2, mode=1)


20it [01:23,  4.16s/it]

              precision    recall  f1-score   support

           0       0.91      0.83      0.87        12
           1       0.78      0.88      0.82         8

    accuracy                           0.85        20
   macro avg       0.84      0.85      0.85        20
weighted avg       0.86      0.85      0.85        20

Counter({'Failed': 15, 'Skipped': 3, 'Successful': 2})
Success Rate 11.76
Avg Queries: 18.176470588235293





In [503]:
compute_attack_stats(results_20_4_1)

{'total_attacks': 20,
 'avg_n_queries': 18.18,
 'successful_attacks': 2,
 'failed_attacks': 15,
 'skipped_attacks': 3,
 'attack_success_rate': 12.0,
 'orig_accuracy': 85.0,
 'attack_accuracy': 75.0}