In [6]:
import sys
sys.path.append("..")

import copy
from datasets import load_dataset
import math
import random
from sklearn.metrics import classification_report, accuracy_score, f1_score
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, DataCollatorWithPadding, \
                         AutoModelForSequenceClassification, BertForSequenceClassification

from resilient_nlp.mini_roben import Clustering, ClusterRepRecoverer, ClusterRecovererWithPassthrough
from resilient_nlp.models import BertClassifier
from resilient_nlp.perturbers import ToyPerturber, WordScramblerPerturber
from runner import ExperimentRunner

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# imdb = load_dataset('imdb')
imdb = load_dataset('artemis13fowl/imdb')

Downloading:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Using custom data configuration artemis13fowl--imdb-f77fd77a6b2e946b


Downloading and preparing dataset imdb/plain_text (download: 40.09 MiB, generated: 63.08 MiB, post-processed: Unknown size, total: 103.17 MiB) to C:\Users\Jasko\.cache\huggingface\datasets\parquet\artemis13fowl--imdb-f77fd77a6b2e946b\0.0.0\0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901...


  0%|          | 0/4 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/43.6k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/816k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/20.9M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/20.3M [00:00<?, ?B/s]

  0%|          | 0/4 [00:00<?, ?it/s]

Dataset parquet downloaded and prepared to C:\Users\Jasko\.cache\huggingface\datasets\parquet\artemis13fowl--imdb-f77fd77a6b2e946b\0.0.0\0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901. Subsequent calls will reuse this data.


  0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
random.seed(11)
sampled_test_set = random.choices(imdb['test'], k=200)

In [10]:
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [11]:
checkpoint_finetuned = "artemis13fowl/bert-base-uncased-imdb"
model_finetuned = BertForSequenceClassification.from_pretrained(checkpoint_finetuned)

In [29]:
wsp = WordScramblerPerturber(perturb_prob=0.4, weight_add=1, weight_drop=1, weight_swap=1, weight_split_word=1, weight_merge_words=1)

In [30]:
roben_clustering = Clustering.from_pickle("../vocab100000_ed1.pkl")
roben_recoverer = ClusterRecovererWithPassthrough("cache", roben_clustering)
roben_clustering2 = Clustering.from_pickle("../vocab100000_ed1_gamma0.3.pkl")
roben_recoverer2 = ClusterRecovererWithPassthrough("cache", roben_clustering2)

In [31]:
random.seed(11)
sampled_test_set_perturbed = copy.deepcopy(sampled_test_set)

for row in sampled_test_set_perturbed:
    row['text'] = wsp.perturb([row['text']])[0][0]

sampled_test_set_perturbed_roben = copy.deepcopy(sampled_test_set_perturbed)
for row in sampled_test_set_perturbed_roben:
    row['text'] = roben_recoverer.recover(row['text'])

In [32]:
random.seed(11)
sampled_test_set_adv = []

for i in range(10):
    test_item = copy.deepcopy(sampled_test_set)

    for row in test_item:
        row['text'] = wsp.perturb([row['text']])[0][0]
    sampled_test_set_adv.append(test_item)

In [33]:
max_sequence_length = 128
batch_size = 32
eval_steps = 100

In [34]:
def standard_model_predict(tokenizer, model, sentences, recoverer):
    if recoverer is not None:
        sentences = [ recoverer.recover(s.lower()) for s in sentences ]
    tokenized = tokenizer(sentences, truncation=True, padding='max_length', max_length=max_sequence_length,
                          return_tensors='pt')
    preds = model(**tokenized)
    return torch.argmax(preds.logits, dim=1)

def wrap_standard_model(tokenizer, model, recoverer=None):
    return lambda sentences: standard_model_predict(tokenizer, model, sentences, recoverer)

In [35]:
def mltokenizer_model_predict(runner, model, cls_embedding, sep_embedding, pad_embedding, sentences):
    # Truncate and lower case. Truncation is for performance only
    # sentences = [ s.lower()[:1000] for s in sentences]
    # To investigate - truncation gives only a small speedup and tanks accuracy.
    # So for now turning off truncation. This is not unfair, since we limit
    # ourselves to max_sequence_length anyway
    sentences = [ s.lower() for s in sentences]
    embedding = runner.embed(sentences=sentences,
        start_token=cls_embedding, end_token=sep_embedding, pad_token=pad_embedding,
        max_tokens=max_sequence_length)
    preds = model(inputs_embeds=embedding['inputs_embeds'], attention_mask=embedding['attention_mask'])
    return torch.argmax(preds.logits, dim=1)

def wrap_mltokenizer_model(mltokenizer_prefix, tokenizer, model):
    filename = "../{}.pth".format(mltokenizer_prefix)
    runner = ExperimentRunner(device, model_filename=filename)
    cf_embedding = model.base_model.embeddings.word_embeddings
    cls_token_id = tokenizer.vocab['[CLS]']
    sep_token_id = tokenizer.vocab['[SEP]']
    pad_token_id = tokenizer.vocab['[PAD]']
    cls_embedding = cf_embedding(torch.tensor([cls_token_id])).view(-1)
    sep_embedding = cf_embedding(torch.tensor([sep_token_id])).view(-1)
    pad_embedding = cf_embedding(torch.tensor([pad_token_id])).view(-1)
    
    return lambda sentences: mltokenizer_model_predict(runner, model, cls_embedding, sep_embedding,
                                                      pad_embedding, sentences)

In [36]:
def evaluate_model(model, test_set):
    num_batches = math.ceil(len(test_set) / batch_size)
    
    sentences = [ x['text'] for x in test_set ]
    labels = [ x['label'] for x in test_set ]
    pred_batches = []
    
    for i in tqdm(range(num_batches)):
        bs = i * batch_size
        be = bs + batch_size
        
        pred_batches.append(model(sentences[bs:be]))
    preds = torch.cat(pred_batches)
    
    print(classification_report(labels, preds, digits=4))    

In [37]:
def evaluate_model_adv(model, test_sets):
    labels = [ x['label'] for x in test_sets[0] ]
    adv_preds = copy.copy(labels)
    accuracy_list = []
    f1_list = []
    
    for idx in tqdm(range(len(test_sets))):
        test_set = test_sets[idx]
        num_batches = math.ceil(len(test_set) / batch_size)
    
        sentences = [ x['text'] for x in test_set ]
        pred_batches = []
    
        for i in range(num_batches):
            bs = i * batch_size
            be = bs + batch_size
        
            pred_batches.append(model(sentences[bs:be]))
        preds = torch.cat(pred_batches)
        
        for i in range(len(adv_preds)):
            if labels[i] == 1.0 and preds[i] == 0.0:
                adv_preds[i] = 0.0
            elif labels[i] == 0.0 and preds[i] == 1.0:
                adv_preds[i] = 1.0

        accuracy_list.append(accuracy_score(labels, adv_preds))
        f1_list.append(f1_score(labels, adv_preds, average='macro'))
    
    print(classification_report(labels, adv_preds, digits=4))    
    
    return accuracy_list, f1_list

In [38]:
baseline_model = wrap_standard_model(tokenizer, model_finetuned)

In [39]:
mltok_model = wrap_mltokenizer_model('output/64k_lstm_all_pert_finetuned', tokenizer, model_finetuned)

In [40]:
baseline_roben_model = wrap_standard_model(tokenizer, model_finetuned, roben_recoverer)

In [41]:
baseline_roben_model2 = wrap_standard_model(tokenizer, model_finetuned, roben_recoverer2)

In [None]:
evaluate_model(baseline_model, sampled_test_set)

In [None]:
evaluate_model(baseline_roben_model, sampled_test_set)

In [None]:
evaluate_model(baseline_roben_model2, sampled_test_set)

In [None]:
evaluate_model(mltok_model, sampled_test_set)

In [42]:
evaluate_model(baseline_model, sampled_test_set_perturbed)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:22<00:00,  3.15s/it]

              precision    recall  f1-score   support

           0     0.7525    0.7917    0.7716        96
           1     0.7980    0.7596    0.7783       104

    accuracy                         0.7750       200
   macro avg     0.7752    0.7756    0.7749       200
weighted avg     0.7761    0.7750    0.7751       200






In [43]:
evaluate_model(baseline_roben_model, sampled_test_set_perturbed)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:17<00:00,  2.54s/it]

              precision    recall  f1-score   support

           0     0.7449    0.7604    0.7526        96
           1     0.7745    0.7596    0.7670       104

    accuracy                         0.7600       200
   macro avg     0.7597    0.7600    0.7598       200
weighted avg     0.7603    0.7600    0.7601       200






In [44]:
evaluate_model(baseline_roben_model2, sampled_test_set_perturbed)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:17<00:00,  2.51s/it]

              precision    recall  f1-score   support

           0     0.7917    0.7917    0.7917        96
           1     0.8077    0.8077    0.8077       104

    accuracy                         0.8000       200
   macro avg     0.7997    0.7997    0.7997       200
weighted avg     0.8000    0.8000    0.8000       200






In [45]:
evaluate_model(mltok_model, sampled_test_set_perturbed)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [03:38<00:00, 31.22s/it]

              precision    recall  f1-score   support

           0     0.8400    0.8750    0.8571        96
           1     0.8800    0.8462    0.8627       104

    accuracy                         0.8600       200
   macro avg     0.8600    0.8606    0.8599       200
weighted avg     0.8608    0.8600    0.8601       200






In [46]:
evaluate_model_adv(baseline_model, sampled_test_set_adv)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:48<00:00, 16.90s/it]

              precision    recall  f1-score   support

           0     0.4107    0.4792    0.4423        96
           1     0.4318    0.3654    0.3958       104

    accuracy                         0.4200       200
   macro avg     0.4213    0.4223    0.4191       200
weighted avg     0.4217    0.4200    0.4181       200






([0.775, 0.625, 0.555, 0.51, 0.47, 0.465, 0.455, 0.445, 0.445, 0.42],
 [0.7749493636068114,
  0.62454006157543,
  0.5531118977680702,
  0.5082296266559615,
  0.46808510638297873,
  0.46273003439532023,
  0.45268760513168127,
  0.4433160309937561,
  0.4433160309937561,
  0.41907051282051283])

In [47]:
evaluate_model_adv(baseline_roben_model, sampled_test_set_adv)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:55<00:00, 17.60s/it]

              precision    recall  f1-score   support

           0     0.2935    0.2812    0.2872        96
           1     0.3611    0.3750    0.3679       104

    accuracy                         0.3300       200
   macro avg     0.3273    0.3281    0.3276       200
weighted avg     0.3286    0.3300    0.3292       200






([0.76, 0.655, 0.575, 0.505, 0.45, 0.405, 0.375, 0.37, 0.35, 0.33],
 [0.7597838054248824,
  0.6542999574137628,
  0.5741376286981137,
  0.5046904315196998,
  0.45,
  0.4042702310330154,
  0.37460913070669166,
  0.36943248924031635,
  0.34895833333333337,
  0.32757928542753917])

In [50]:
evaluate_model_adv(mltok_model, sampled_test_set_adv)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [34:03<00:00, 204.32s/it]

              precision    recall  f1-score   support

           0     0.6667    0.7500    0.7059        96
           1     0.7391    0.6538    0.6939       104

    accuracy                         0.7000       200
   macro avg     0.7029    0.7019    0.6999       200
weighted avg     0.7043    0.7000    0.6996       200






([0.86, 0.82, 0.775, 0.75, 0.74, 0.74, 0.735, 0.72, 0.715, 0.7],
 [0.8599439775910365,
  0.8199279711884755,
  0.7749493636068114,
  0.74997499749975,
  0.7399739973997399,
  0.7399739973997399,
  0.7349403615813558,
  0.7199999999999999,
  0.7149928748218706,
  0.6998799519807923])