In [1]:
import sys
sys.path.append("..")
from resilient_nlp.models import BertClassifier
from datasets import load_from_disk
from transformers import BertModel, AutoTokenizer, DataCollatorWithPadding
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from datasets import load_metric

In [2]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### Load dataset
To create the same splits that were used in training, run `create_imdb_data.ipynb`.

In [3]:
# Load the pre-split IMDB DatasetDict that was saved to disk (relative path).
imdb = load_from_disk("../data/imdb")
# Bare last expression so the notebook displays the splits and their sizes.
imdb

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    dev: Dataset({
        features: ['text', 'label'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 24000
    })
})

#### Tokenizer and Model Settings 

In [4]:
## Model P
# Settings shared by the tokenization, dataloader and model cells below.
max_sequence_length = 128  # token length every review is truncated/padded to
batch_size = 32  # number of examples per evaluation batch
model_dir = "../models/"  # directory holding the saved model weights
checkpoint = "bert-base-cased"  # pretrained checkpoint name used to load the tokenizer

#### Tokenize data

In [5]:
# Load the tokenizer that belongs to the configured `checkpoint`.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [6]:
def tokenize_function(example):
    """Tokenize the "text" field, truncating and padding to `max_sequence_length`."""
    texts = example["text"]
    return tokenizer(
        texts,
        max_length=max_sequence_length,
        padding="max_length",
        truncation=True,
    )


# Collator used by the DataLoader to assemble (and pad) each batch.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Tokenize every split, processing the examples in batches.
tokenized_datasets = imdb.map(tokenize_function, batched=True)

Loading cached processed dataset at ../data/imdb/train/cache-0eb7ba904cdeca8b.arrow
Loading cached processed dataset at ../data/imdb/dev/cache-d60ebe42d5b57c80.arrow
Loading cached processed dataset at ../data/imdb/test/cache-3cdf482d29e393f0.arrow


In [7]:
# Drop the raw text and expose the target column under the name "labels".
tokenized_datasets = (
    tokenized_datasets
    .remove_columns(["text"])
    .rename_column("label", "labels")
)
# Have the datasets return torch tensors.
tokenized_datasets.set_format("torch")
# Display the columns that remain in the training split.
tokenized_datasets["train"].column_names

['labels', 'input_ids', 'token_type_ids', 'attention_mask']

#### Create inference dataloader

In [8]:
# Evaluate on a random 1000-example subset of the test split. The shuffle is
# seeded so the same subset (and therefore the same metrics) is produced on
# every run of the notebook.
test_subset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
test_dataloader = DataLoader(
    test_subset, batch_size=batch_size, collate_fn=data_collator
)

#### Load finetuned model
Load the fine-tuned weights and set the model to eval mode.

In [9]:
# Build the classifier from the same checkpoint name the tokenizer was loaded with.
model = BertClassifier(checkpoint=checkpoint, n_classes=2)
# map_location remaps the saved tensors onto `device`, so weights that were
# saved on a GPU still load on a CPU-only machine.
state_dict = torch.load(f"{model_dir}{checkpoint}-imdb.pt", map_location=device)
model.load_state_dict(state_dict)
model.to(device)
# Switch to evaluation mode; eval() returns the module, so the cell displays it.
model.eval()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertClassifier(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

#### Run eval

In [10]:
# Metric accumulators: predictions are added batch by batch and the final
# scores are computed once after the loop.
test_accuracy = load_metric("accuracy")
test_f1 = load_metric("f1")
# Wrapping the dataloader lets tqdm advance the bar on every batch and close it
# when iteration finishes, instead of updating a counter by hand.
test_progress_bar = tqdm(test_dataloader)

for batch in test_progress_bar:
    # Move every tensor of the batch onto the same device as the model.
    batch = {k: v.to(device) for k, v in batch.items()}
    # No gradients are needed for evaluation.
    with torch.no_grad():
        outputs = model(**batch)

    # The predicted class is the index of the largest logit.
    predictions = torch.argmax(outputs.logits, dim=-1)
    labels = batch['labels']
    test_f1.add_batch(predictions=predictions, references=labels)
    test_accuracy.add_batch(predictions=predictions, references=labels)

print(f"Test Accuracy:{test_accuracy.compute()}. Test F1:{test_f1.compute()}.")

  0%|          | 0/32 [00:00<?, ?it/s]

Test Accuracy:{'accuracy': 0.876}. Test F1:{'f1': 0.8807692307692307}.


### Common Sense Attacks

In [28]:
# Print a few random test reviews together with their labels. The shuffle is
# seeded so the same reviews are shown on every run.
sample = imdb['test'].shuffle(seed=42).select(range(5))
for row in sample:
    print('-'*20)
    print(row['label'], row['text'])

--------------------
1 This movie is great. Stylish, fun, good acting. I'd seen it described variously as 'Lock, Stock and Two Smoking Muskets' and 'Reservoir Fops', both of which are excellent descriptions. The plot is simple, but it does not detract from the enjoyment. Carlyle is a brilliant ruffian and Miller is an excellent drunken gentleman. The sets and costumes are stunning, and the music and camerawork are refleshingly unusual for a 'costume drama'. Sense and Sensibility it definitely is not!!!!! My recommendation? Go see it, sit back with a huge tub of popcorn and have a damn good time.
--------------------
0 First of all, let me make it clear. This movie is a real piece of garbage, but although it is a real piece of garbage, it is an better piece of garbage than it could have been. It could have sucked big-time, but it doesn't...<br /><br />What this movie didn't have, was for example scary moments, good acting and a good script. It wasn't very entertaining either. But the mo