In [None]:
import sys
sys.path.append("..")

import copy
from datasets import load_dataset
import math
import random
from sklearn.metrics import classification_report
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, DataCollatorWithPadding, \
                         AutoModelForSequenceClassification, BertForSequenceClassification

from resilient_nlp.models import BertClassifier
from resilient_nlp.perturbers import ToyPerturber, WordScramblerPerturber
from lstm import ExperimentRunner

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the 'imdb' dataset; only its 'test' split is used in the cells visible here.
imdb = load_dataset('imdb')

In [None]:
# Fixed seed so the same 200 test rows are drawn on every run.
random.seed(11)
# NOTE(review): random.choices draws WITH replacement, so the sample may contain
# duplicate rows — confirm that this is intended for an evaluation set.
sampled_test_set = random.choices(imdb['test'], k=200)

In [None]:
# Tokenizer loaded from the "bert-base-uncased" checkpoint. It is passed to both
# wrap_standard_model and wrap_mltokenizer_model further down.
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Sequence-classification model loaded from a hub checkpoint (by its name, BERT
# finetuned on IMDB). Every wrapped model evaluated below shares this instance.
checkpoint_finetuned = "artemis13fowl/bert-base-uncased-imdb"
model_finetuned = BertForSequenceClassification.from_pretrained(checkpoint_finetuned)

In [None]:
# Two perturbers, used below to build perturbed copies of the sampled test set.
tp = ToyPerturber()
wsp = WordScramblerPerturber(
    perturb_prob=1,
    weight_add=1,
    weight_drop=1,
    weight_swap=1,
    weight_split_word=1,
    weight_merge_words=1,
)

In [None]:
# Re-seed the global `random` generator before perturbing
# (presumably the perturber draws from it — TODO confirm).
random.seed(11)
# Deep copy so the unperturbed sample is left untouched.
sampled_test_set_perturbed1 = copy.deepcopy(sampled_test_set)

# Perturb one text at a time with the word scrambler. perturb() is given a
# one-element list and the perturbed string is read from position [0][0].
for row in sampled_test_set_perturbed1:
    perturbed_batch = wsp.perturb([row['text']])
    row['text'] = perturbed_batch[0][0]

In [None]:
# Re-seed the global `random` generator before perturbing
# (presumably the perturber draws from it — TODO confirm).
random.seed(11)
# Deep copy so the unperturbed sample is left untouched.
sampled_test_set_perturbed2 = copy.deepcopy(sampled_test_set)

# Perturb one text at a time with the toy perturber. perturb() is given a
# one-element list and the perturbed string is read from position [0][0].
for row in sampled_test_set_perturbed2:
    perturbed_batch = tp.perturb([row['text']])
    row['text'] = perturbed_batch[0][0]

In [None]:
# Token budget per input: used as `max_length` (with truncation and padding) by the
# standard tokenizer and passed as `max_tokens` to runner.embed.
max_sequence_length = 128
# Number of sentences handed to a wrapped model per call in evaluate_model.
batch_size = 32
# NOTE(review): not referenced by any cell visible here.
eval_steps = 100

In [None]:
def standard_model_predict(tokenizer, model, sentences):
    """Predict a class index for each sentence using a regular tokenizer.

    Args:
        tokenizer: Tokenizer called with the sentences; its output is unpacked
            as keyword arguments into `model`.
        model: Classification model whose output exposes `.logits`.
        sentences: List of raw input strings.

    Returns:
        Tensor holding the argmax over dimension 1 of the logits, i.e. one
        predicted class index per sentence.
    """
    # Every sentence is truncated/padded to exactly max_sequence_length tokens.
    tokenized = tokenizer(sentences, truncation=True, padding='max_length',
                          max_length=max_sequence_length, return_tensors='pt')
    # The tokenizer output is not placed on the model's device automatically;
    # align it explicitly (a no-op when both are already on the same device).
    tokenized = tokenized.to(model.device)
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        preds = model(**tokenized)
    return torch.argmax(preds.logits, dim=1)

def wrap_standard_model(tokenizer, model):
    """Bind a tokenizer/model pair into a predictor that takes only sentences.

    Returns:
        A callable `f(sentences)` that forwards to
        `standard_model_predict(tokenizer, model, sentences)`.
    """
    def predict(sentences):
        return standard_model_predict(tokenizer, model, sentences)

    return predict

In [None]:
def mltokenizer_model_predict(runner, model, cls_embedding, sep_embedding, pad_embedding, sentences):
    """Predict a class index for each sentence using embeddings from `runner`.

    The sentences are lower-cased, turned into input embeddings by
    `runner.embed`, and fed to `model` through `inputs_embeds`, bypassing the
    model's own token-id lookup.

    Args:
        runner: Object whose `embed(...)` returns a mapping with
            'inputs_embeds' and 'attention_mask' entries.
        model: Classification model accepting `inputs_embeds` and
            `attention_mask`; its output exposes `.logits`.
        cls_embedding: Embedding passed to `runner.embed` as `start_token`.
        sep_embedding: Embedding passed to `runner.embed` as `end_token`.
        pad_embedding: Embedding passed to `runner.embed` as `pad_token`.
        sentences: List of raw input strings.

    Returns:
        Tensor holding the argmax over dimension 1 of the logits, i.e. one
        predicted class index per sentence.
    """
    # Lower-case only. Character-level truncation (to 1000 chars) was tried for
    # speed, but it gave only a small speedup and tanked accuracy (still to be
    # investigated), so it is turned off. This is not unfair, since the input
    # is limited to max_sequence_length anyway.
    sentences = [s.lower() for s in sentences]
    # Pure inference: skip autograd bookkeeping to save memory and time.
    # NOTE(review): the runner output is fed to the model as-is — confirm that
    # the runner and the model live on the same device.
    with torch.no_grad():
        embedding = runner.embed(sentences=sentences,
            start_token=cls_embedding, end_token=sep_embedding, pad_token=pad_embedding,
            max_tokens=max_sequence_length)
        preds = model(inputs_embeds=embedding['inputs_embeds'],
                      attention_mask=embedding['attention_mask'])
    return torch.argmax(preds.logits, dim=1)

def wrap_mltokenizer_model(mltokenizer_prefix, tokenizer, model):
    """Build a predictor that feeds `model` with embeddings from an ML tokenizer.

    Args:
        mltokenizer_prefix: File prefix; weights are loaded from
            "../<prefix>.pth" and the character vocabulary from
            "../<prefix>_vocab.json".
        tokenizer: Tokenizer whose `vocab` maps '[CLS]', '[SEP]' and '[PAD]'
            to token ids.
        model: Classification model; its word-embedding layer supplies the
            embeddings of the three special tokens.

    Returns:
        A callable `f(sentences)` that forwards to `mltokenizer_model_predict`.
    """
    runner = ExperimentRunner(device)
    runner.model.load("../{}.pth".format(mltokenizer_prefix), device)
    runner.char_tokenizer.load_vocab("../{}_vocab.json".format(mltokenizer_prefix))

    cf_embedding = model.base_model.embeddings.word_embeddings

    def special_token_embedding(token):
        # Create the index on the embedding layer's own device so the lookup
        # also works when the model is not on the CPU.
        token_id = torch.tensor([tokenizer.vocab[token]], device=cf_embedding.weight.device)
        return cf_embedding(token_id).view(-1)

    # These vectors are constants for inference; computing them without autograd
    # keeps the closure from holding tensors attached to the model's graph.
    with torch.no_grad():
        cls_embedding = special_token_embedding('[CLS]')
        sep_embedding = special_token_embedding('[SEP]')
        pad_embedding = special_token_embedding('[PAD]')

    return lambda sentences: mltokenizer_model_predict(runner, model, cls_embedding, sep_embedding,
                                                       pad_embedding, sentences)

In [None]:
def evaluate_model(model, test_set):
    """Run a wrapped model over a test set in batches and print a report.

    Args:
        model: Callable taking a list of sentences and returning a tensor of
            predicted class indices (as produced by the wrap_* helpers).
        test_set: Sized iterable of rows with 'text' and 'label' entries.

    Raises:
        ValueError: If `test_set` is empty.

    Prints sklearn's classification report (4 digits); returns nothing.
    """
    # Fail early with a clear message instead of an opaque torch.cat error.
    if len(test_set) == 0:
        raise ValueError("evaluate_model requires a non-empty test_set")

    sentences = [x['text'] for x in test_set]
    labels = [x['label'] for x in test_set]

    # Step through the sentences batch_size at a time; the last batch may be shorter.
    pred_batches = [
        model(sentences[start:start + batch_size])
        for start in tqdm(range(0, len(sentences), batch_size))
    ]
    # sklearn expects host-side array data, so move predictions off any
    # accelerator and convert them to NumPy before scoring.
    preds = torch.cat(pred_batches).detach().cpu().numpy()

    print(classification_report(labels, preds, digits=4))

In [None]:
# Baseline predictor: the finetuned model fed by the standard tokenizer.
baseline_model = wrap_standard_model(tokenizer, model_finetuned)

In [None]:
# Same finetuned model, but fed by ML tokenizers loaded from the 'model4' and
# 'model5' weight/vocab files respectively.
mltok_model1 = wrap_mltokenizer_model('model4', tokenizer, model_finetuned)
mltok_model2 = wrap_mltokenizer_model('model5', tokenizer, model_finetuned)

In [None]:
# Baseline on the unperturbed sample.
evaluate_model(baseline_model, sampled_test_set)

In [None]:
# ML tokenizer 'model4' on the unperturbed sample.
evaluate_model(mltok_model1, sampled_test_set)

In [None]:
# ML tokenizer 'model5' on the unperturbed sample.
evaluate_model(mltok_model2, sampled_test_set)

In [None]:
# Baseline on the sample perturbed with WordScramblerPerturber.
evaluate_model(baseline_model, sampled_test_set_perturbed1)

In [None]:
# ML tokenizer 'model4' on the sample perturbed with WordScramblerPerturber.
evaluate_model(mltok_model1, sampled_test_set_perturbed1)

In [None]:
# ML tokenizer 'model5' on the sample perturbed with WordScramblerPerturber.
evaluate_model(mltok_model2, sampled_test_set_perturbed1)

In [None]:
# Baseline on the sample perturbed with ToyPerturber.
evaluate_model(baseline_model, sampled_test_set_perturbed2)

In [None]:
# ML tokenizer 'model4' on the sample perturbed with ToyPerturber.
evaluate_model(mltok_model1, sampled_test_set_perturbed2)

In [None]:
# ML tokenizer 'model5' on the sample perturbed with ToyPerturber.
evaluate_model(mltok_model2, sampled_test_set_perturbed2)