In [None]:
import sys
sys.path.append("..")

import copy
import cProfile
from datasets import load_dataset
import math
import random
from sklearn.metrics import classification_report, accuracy_score, f1_score
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, DataCollatorWithPadding, \
                         AutoModelForSequenceClassification, BertForSequenceClassification

from resilient_nlp.mini_roben import Clustering, ClusterRepRecoverer, ClusterRecovererWithPassthrough
from resilient_nlp.models import BertClassifier
from resilient_nlp.perturbers import ToyPerturber, WordScramblerPerturber
from runner import ExperimentRunner
from word_score_attack import BertWordScoreAttack

# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the 'artemis13fowl/imdb' dataset rather than the stock 'imdb' one;
# the cells below read its 'attack_eval_truncated' split.
imdb = load_dataset('artemis13fowl/imdb')

In [None]:
# Seed Python's RNG for reproducibility. Nothing in this cell draws random
# numbers: the evaluation set is taken as-is from the named split.
random.seed(11)
sampled_test_set = imdb['attack_eval_truncated']

In [None]:
# Tokenizer of the base checkpoint; it is shared by every model wrapper below.
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Sequence classifier checkpoint used by all evaluations, moved to `device`.
# NOTE(review): `from_pretrained` is documented to return the model in eval
# mode; confirm for the installed transformers version, since no explicit
# `.eval()` call is made here.
checkpoint_finetuned = "artemis13fowl/bert-base-uncased-imdb"
model_finetuned = BertForSequenceClassification.from_pretrained(checkpoint_finetuned).to(device)

In [None]:
# Perturber used to build the noisy test sets: perturb_prob=0.4 and equal
# weights (1) for the add / drop / swap / split-word / merge-words operations.
wsp = WordScramblerPerturber(perturb_prob=0.4, weight_add=1, weight_drop=1, weight_swap=1, weight_split_word=1, weight_merge_words=1)

In [None]:
# Two clusterings loaded from local pickle files, each wrapped in a
# pass-through recoverer that the "roben" model variants apply before
# tokenization.
# NOTE(review): pickle files execute code on load; only use trusted artifacts.
# NOTE(review): both recoverers receive the same "cache" argument. If that is
# a shared on-disk location the two clusterings could collide; confirm.
roben_clustering = Clustering.from_pickle("../vocab100000_ed1.pkl")
roben_recoverer = ClusterRecovererWithPassthrough("cache", roben_clustering)
roben_clustering2 = Clustering.from_pickle("../vocab100000_ed1_gamma0.3.pkl")
roben_recoverer2 = ClusterRecovererWithPassthrough("cache", roben_clustering2)

In [None]:
# Build one randomly perturbed copy of the evaluation set.
#
# The rows are rebuilt as new dicts instead of being mutated in place:
# iterating a `datasets.Dataset` yields a fresh dict per row, so assigning to
# `row['text']` inside a loop over a (deep-copied) Dataset is silently
# discarded and the "perturbed" set would stay identical to the clean one.
# A plain list of dicts supports everything the evaluation helpers need
# (`len()` and iteration with 'text' / 'label' keys).
random.seed(11)  # reseed so the perturbations are reproducible
sampled_test_set_perturbed = [
    # wsp.perturb takes a list of sentences; [0][0] picks the perturbed text
    # of the single sentence passed in.
    {**row, 'text': wsp.perturb([row['text']])[0][0]}
    for row in sampled_test_set
]

In [None]:
# Build 10 independently perturbed copies of the evaluation set for the
# worst-case ("adversarial") evaluation in evaluate_model_adv.
#
# Each copy is a new list of new row dicts. Mutating `row['text']` while
# iterating a `datasets.Dataset` does not write back to the dataset (every
# iteration step yields a fresh dict), so the previous in-place loop would
# have produced 10 unperturbed copies.
random.seed(11)  # reseed so the 10 perturbation rounds are reproducible
sampled_test_set_adv = [
    [
        # wsp.perturb takes a list of sentences; [0][0] picks the perturbed
        # text of the single sentence passed in.
        {**row, 'text': wsp.perturb([row['text']])[0][0]}
        for row in sampled_test_set
    ]
    for _ in range(10)
]

In [None]:
# Shared settings read as globals by the prediction / evaluation helpers below.
max_sequence_length = 128  # max length passed to the tokenizer / embedder
batch_size = 32  # sentences per model call in the evaluate_* loops
eval_steps = 100  # NOTE(review): not referenced in the visible cells; confirm it is still needed

In [None]:
def standard_model_predict(tokenizer, model, sentences, recoverer, return_pred_tensor):
    """Run `model` on raw sentences using the standard tokenizer.

    Args:
        tokenizer: callable tokenizer returning a mapping of tensors.
        model: model called with the tokenized tensors as keyword arguments;
            its output must expose `.logits`.
        sentences: list of input strings.
        recoverer: optional object with a `recover(str)` method; when given,
            each sentence is lower-cased and passed through it first.
        return_pred_tensor: if true, return the raw model output; otherwise
            return the argmax over dim 1 of the logits.

    Inputs are truncated / padded to the global `max_sequence_length` and
    moved to the global `device`.
    """
    if recoverer is not None:
        sentences = list(map(lambda text: recoverer.recover(text.lower()), sentences))

    encoded = tokenizer(
        sentences,
        truncation=True,
        padding='max_length',
        max_length=max_sequence_length,
        return_tensors='pt',
    )
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    outputs = model(**model_inputs)

    return outputs if return_pred_tensor else torch.argmax(outputs.logits, dim=1)

def wrap_standard_model(tokenizer, model, recoverer=None, return_pred_tensor=True):
    """Bind tokenizer, model and options into a `sentences -> prediction` callable.

    The returned callable forwards to `standard_model_predict` with the
    arguments captured here.
    """
    def predict(sentences):
        return standard_model_predict(tokenizer, model, sentences, recoverer, return_pred_tensor)

    return predict

In [None]:
def mltokenizer_model_predict(runner, model, cls_embedding, sep_embedding, pad_embedding, sentences, return_pred_tensor):
    """Run `model` on embeddings produced by `runner` instead of token ids.

    Args:
        runner: object whose `embed(...)` returns a mapping with
            'inputs_embeds' and 'attention_mask'.
        model: model accepting `inputs_embeds` and `attention_mask`; its
            output must expose `.logits`.
        cls_embedding, sep_embedding, pad_embedding: embeddings passed to the
            runner as start, end and pad tokens.
        sentences: list of input strings.
        return_pred_tensor: if true, return the raw model output; otherwise
            return the argmax over dim 1 of the logits.
    """
    # Lower-case, then cap the character count. The cap is for performance only.
    char_limit = 5 * max_sequence_length
    sentences = [text.lower()[:char_limit] for text in sentences]

    embedded = runner.embed(
        sentences=sentences,
        start_token=cls_embedding,
        end_token=sep_embedding,
        pad_token=pad_embedding,
        max_tokens=max_sequence_length,
    )
    outputs = model(
        inputs_embeds=embedded['inputs_embeds'],
        attention_mask=embedded['attention_mask'],
    )

    if not return_pred_tensor:
        return torch.argmax(outputs.logits, dim=1)
    return outputs

def wrap_mltokenizer_model(mltokenizer_prefix, tokenizer, model, return_pred_tensor=True):
    """Build a `sentences -> prediction` callable backed by an ExperimentRunner.

    Args:
        mltokenizer_prefix: path prefix (relative to the parent directory,
            without the ".pth" suffix) of the runner's model file.
        tokenizer: tokenizer whose `vocab` holds the '[CLS]', '[SEP]' and
            '[PAD]' entries.
        model: classifier whose word-embedding table supplies the vectors for
            those three special tokens.
        return_pred_tensor: forwarded to `mltokenizer_model_predict`.
    """
    runner = ExperimentRunner(device, model_filename="../{}.pth".format(mltokenizer_prefix))
    word_embeddings = model.base_model.embeddings.word_embeddings

    def special_token_embedding(token):
        # Look up one vocabulary entry and flatten its embedding to 1-D.
        token_id = tokenizer.vocab[token]
        return word_embeddings(torch.tensor([token_id], device=device)).view(-1)

    cls_embedding = special_token_embedding('[CLS]')
    sep_embedding = special_token_embedding('[SEP]')
    pad_embedding = special_token_embedding('[PAD]')

    def predict(sentences):
        return mltokenizer_model_predict(runner, model, cls_embedding, sep_embedding,
                                         pad_embedding, sentences, return_pred_tensor)

    return predict

In [None]:
@torch.no_grad()
def evaluate_model(model, test_set):
    """Print a classification report for `model` on `test_set`.

    Args:
        model: callable taking a list of sentences and returning an object
            with `.logits`.
        test_set: sized iterable of rows with 'text' and 'label' keys.

    Predictions are computed in chunks of the global `batch_size`.
    """
    sentences = [example['text'] for example in test_set]
    labels = [example['label'] for example in test_set]

    pred_batches = []
    # One step per batch; the range has ceil(len / batch_size) elements.
    for start in tqdm(range(0, len(test_set), batch_size)):
        batch_output = model(sentences[start:start + batch_size])
        batch_preds = torch.argmax(batch_output.logits, dim=1)
        pred_batches.append(batch_preds.detach().cpu())

    all_preds = torch.cat(pred_batches)
    print(classification_report(labels, all_preds, digits=4))

In [None]:
@torch.no_grad()
def evaluate_model_adv(model, test_sets):
    """Worst-case evaluation over several perturbed copies of one test set.

    `adv_preds` starts out equal to the labels taken from the first set.
    Whenever the model misclassifies an example in any set, that example's
    entry is overwritten with the wrong label and stays wrong afterwards.
    Only the 0 / 1 label values are handled.

    Args:
        model: callable taking a list of sentences and returning an object
            with `.logits`.
        test_sets: list of test sets with rows exposing 'text' and 'label'.

    Returns:
        (accuracy_list, f1_list): cumulative accuracy and macro-F1 after each
        test set has been processed. The final classification report is
        printed as well.
    """
    labels = [example['label'] for example in test_sets[0]]
    adv_preds = copy.copy(labels)
    accuracy_list, f1_list = [], []

    for test_set in tqdm(test_sets):
        sentences = [example['text'] for example in test_set]

        pred_batches = []
        for start in range(0, len(test_set), batch_size):
            batch_output = model(sentences[start:start + batch_size])
            batch_preds = torch.argmax(batch_output.logits, dim=1)
            pred_batches.append(batch_preds.detach().cpu())
        preds = torch.cat(pred_batches)

        # Record every error: once an example is misclassified it stays so.
        for pos in range(len(adv_preds)):
            if labels[pos] == 1.0 and preds[pos] == 0.0:
                adv_preds[pos] = 0.0
            elif labels[pos] == 0.0 and preds[pos] == 1.0:
                adv_preds[pos] = 1.0

        accuracy_list.append(accuracy_score(labels, adv_preds))
        f1_list.append(f1_score(labels, adv_preds, average='macro'))

    print(classification_report(labels, adv_preds, digits=4))

    return accuracy_list, f1_list

In [None]:
@torch.no_grad()
def evaluate_model_word_score(model, test_set):
    """Attack `model` with BertWordScoreAttack and print the resulting report.

    A perturber with perturb_prob=1 and equal operation weights is used, with
    word scores read from "../output/imdb_word_scores.json". The report
    compares the attack result's 'ground_truth' against its 'perturbed_preds'.
    """
    perturber = WordScramblerPerturber(
        perturb_prob=1,
        weight_add=1,
        weight_drop=1,
        weight_swap=1,
        weight_split_word=1,
        weight_merge_words=1,
    )
    attacker = BertWordScoreAttack(
        perturber,
        "../output/imdb_word_scores.json",
        model,
        tokenizer=None,
        max_sequence_length=max_sequence_length,
    )

    attack_result = attacker.attack(
        test_set,
        max_tokens_to_query=10,
        max_tries_per_token=2,
        mode=0,
        print_summary=False,
    )

    print(classification_report(attack_result['ground_truth'], attack_result['perturbed_preds'], digits=4))

In [None]:
# Baseline: fine-tuned classifier with the standard tokenizer and no recoverer.
baseline_model = wrap_standard_model(tokenizer, model_finetuned)

In [None]:
# Same classifier, fed embeddings from the runner loaded from
# ../output/64k_lstm_all_pert_finetuned.pth instead of token ids.
mltok_model = wrap_mltokenizer_model('output/64k_lstm_all_pert_finetuned', tokenizer, model_finetuned)

In [None]:
# Baseline plus the first recoverer (clustering from vocab100000_ed1.pkl).
baseline_roben_model = wrap_standard_model(tokenizer, model_finetuned, roben_recoverer)

In [None]:
# Baseline plus the second recoverer (clustering from vocab100000_ed1_gamma0.3.pkl).
baseline_roben_model2 = wrap_standard_model(tokenizer, model_finetuned, roben_recoverer2)

In [None]:
# Unperturbed test set: baseline model.
evaluate_model(baseline_model, sampled_test_set)

In [None]:
# Unperturbed test set: baseline with the first recoverer.
evaluate_model(baseline_roben_model, sampled_test_set)

In [None]:
# Unperturbed test set: baseline with the second recoverer.
evaluate_model(baseline_roben_model2, sampled_test_set)

In [None]:
# Unperturbed test set: runner-embedding model.
evaluate_model(mltok_model, sampled_test_set)

In [None]:
# sampled_test_set_perturbed (built with wsp): baseline model.
evaluate_model(baseline_model, sampled_test_set_perturbed)

In [None]:
# sampled_test_set_perturbed (built with wsp): baseline with the first recoverer.
evaluate_model(baseline_roben_model, sampled_test_set_perturbed)

In [None]:
# sampled_test_set_perturbed (built with wsp): baseline with the second recoverer.
evaluate_model(baseline_roben_model2, sampled_test_set_perturbed)

In [None]:
# sampled_test_set_perturbed (built with wsp): runner-embedding model.
evaluate_model(mltok_model, sampled_test_set_perturbed)

In [None]:
# Worst case over the sets in sampled_test_set_adv: baseline model.
# The returned (accuracy_list, f1_list) tuple is shown as the cell output.
evaluate_model_adv(baseline_model, sampled_test_set_adv)

In [None]:
# Worst case over the sets in sampled_test_set_adv: baseline with the first recoverer.
# The returned (accuracy_list, f1_list) tuple is shown as the cell output.
evaluate_model_adv(baseline_roben_model, sampled_test_set_adv)

In [None]:
# Worst case over the sets in sampled_test_set_adv: runner-embedding model.
# The returned (accuracy_list, f1_list) tuple is shown as the cell output.
evaluate_model_adv(mltok_model, sampled_test_set_adv)