## Fine-tuning BERT from HuggingFace Transformers on the IMDB Dataset
> __Note__: In this tutorial we use a smaller version of BERT, called __DistilBert__, that is easier and faster to train.

In [1]:
import copy
import pickle

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

from podium.datasets import IMDB, Iterator
from podium.models import Experiment
from podium.models.impl.pytorch import TorchTrainer, TorchModel
from podium.pipeline import Pipeline
from podium.storage import Field, LabelField, Vocab

### Defining the fields

* text - applies `DistilBertTokenizer` to the instance data. To reduce the memory footprint of the dataset, we don't store the `attention_mask`; instead, we create it on the fly during the model's forward pass.
* label - stores binary labels that represent the sentiment of an instance.
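The `text` field does its work in two steps. First it tokenizes a review into WordPiece tokens, adding `[CLS]` and `[SEP]` and truncating to `max_length=128`. Only later does it map those tokens to ids through `custom_numericalize`. The sketch below mimics this flow without downloading anything, under these assumptions:

* a whitespace split stands in for the WordPiece tokenizer;
* `TOY_VOCAB` is a hypothetical vocabulary;
* only the special-token ids (`[PAD]`=0, `[UNK]`=100, `[CLS]`=101, `[SEP]`=102) follow the BERT uncased vocabulary, and the word ids are made up.

```python
# Hypothetical toy vocabulary; only the special-token ids match BERT's uncased vocab.
TOY_VOCAB = {
    '[PAD]': 0, '[UNK]': 100, '[CLS]': 101, '[SEP]': 102,
    'this': 1, 'movie': 2, 'is': 3, 'great': 4,
}


def text_to_tokens(string, max_length=128):
    """Whitespace split standing in for WordPiece; reserves two slots for [CLS]/[SEP]."""
    words = string.lower().split()[: max_length - 2]
    return ['[CLS]'] + words + ['[SEP]']


def convert_tokens_to_ids(tokens):
    """Numericalize tokens, mapping unknown ones to [UNK]."""
    return [TOY_VOCAB.get(token, TOY_VOCAB['[UNK]']) for token in tokens]


tokens = text_to_tokens('This movie is great')
print(tokens)                              # ['[CLS]', 'this', 'movie', 'is', 'great', '[SEP]']
print(convert_tokens_to_ids(tokens))       # [101, 1, 2, 3, 4, 102]
print(len(text_to_tokens('word ' * 500)))  # 128: special tokens count toward max_length
```

Each example stores only an unpadded token list. Padding (with id 0, as set by `padding_token=0`) and masking are left to batching time.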

In [2]:
def create_fields():
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    
    def text_to_tokens(string):
        input_ids = tokenizer(string,
                              max_length=128,
                              padding=False,
                              truncation=True,
                              return_attention_mask=False
                             )['input_ids']
        
        return tokenizer.convert_ids_to_tokens(input_ids)
        
    text = Field(name='text',
                 tokenizer=text_to_tokens,
                 custom_numericalize=tokenizer.convert_tokens_to_ids,
                 padding_token=0)
    
    label = LabelField(name='label', vocab=Vocab(specials=()))
    
    return {
        'text': text,
        'label': label
    } 

### Wrapping the model

In this tutorial we use `DistilBertForSequenceClassification`. This model adds classification layers on top of the base model, `DistilBertModel`, and these layers are randomly initialized. We therefore define a function that returns a deep copy of a single model instance each time it is called. As a result, every copy starts from the same weights, including the same randomly initialized classification head. We need this later to compare the original pretrained model with the model fine-tuned on a downstream task.

In [3]:
def bert_initializer():
    model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', 
                                                                return_dict=True)
    def get_bert_model():
        return copy.deepcopy(model)
    
    return get_bert_model

get_bert_model = bert_initializer()

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi
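The closure pattern in `bert_initializer` can be illustrated without a pretrained checkpoint. In the sketch below, `TinyClassifier` is a hypothetical stand-in whose `head` weights are drawn at random exactly once. Every call to the returned factory yields an independent deep copy with identical weights, which is what keeps the before/after fine-tuning comparison fair.

```python
import copy
import random


class TinyClassifier:
    """Hypothetical stand-in for a model with a randomly initialized head."""

    def __init__(self, n_weights=4):
        self.head = [random.gauss(0.0, 1.0) for _ in range(n_weights)]


def model_initializer():
    model = TinyClassifier()  # random initialization happens exactly once

    def get_model():
        # deep copy: same weights, but no shared state between copies
        return copy.deepcopy(model)

    return get_model


get_model = model_initializer()
a, b = get_model(), get_model()
print(a.head == b.head)  # True: identical starting weights
a.head[0] += 1.0         # "fine-tune" one copy
print(a.head == b.head)  # False: the other copy is unaffected
```

The alternative would be to call `from_pretrained` twice. Unless the random seed is fixed, that would draw two different random classification heads, and the comparison would no longer start from the same model.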

`DistilBertForSequenceClassification` is a standard PyTorch `Module`, so it can easily be wrapped in another `Module` that exposes the interface Podium expects: the model has to return a dictionary with a `pred` key that holds the model predictions. As mentioned earlier, we create the attention mask ourselves, and the wrapper's `forward` method is a good place to do it.

In [4]:
class BertModelWrapper(nn.Module):
    
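    def __init__(self, **kwargs):
        # any keyword arguments (such as the model config) are accepted and ignored
        super().__init__()
        self.model = get_bert_model()
        
    def forward(self, x):
        # token id 0 is [PAD] in the DistilBERT vocabulary, so padding positions get mask 0
        attention_mask = (x != 0).long()
        return_dict = self.model(input_ids=x, attention_mask=attention_mask)
        # Podium reads the model predictions from the 'pred' key
        return_dict['pred'] = return_dict['logits']
        return return_dict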
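The mask expression in `forward` relies on two facts:

* Podium pads every batch with `padding_token=0`, as set on the `text` field;
* 0 is the id of `[PAD]` in the DistilBERT vocabulary, so no real token ever has id 0.

The NumPy sketch below reproduces the padding and the mask on two toy id sequences. `pad_batch` is a hypothetical helper, not Podium's implementation.

```python
import numpy as np

PAD_ID = 0  # [PAD] id in the BERT/DistilBERT uncased vocabulary


def pad_batch(sequences, pad_id=PAD_ID):
    """Right-pad variable-length id lists into a (batch_size, max_len) matrix."""
    max_len = max(len(seq) for seq in sequences)
    batch = np.full((len(sequences), max_len), pad_id, dtype=np.int64)
    for i, seq in enumerate(sequences):
        batch[i, :len(seq)] = seq
    return batch


def attention_mask(batch, pad_id=PAD_ID):
    """1 for real tokens, 0 for padding: the same rule as (x != 0).long() in the wrapper."""
    return (batch != pad_id).astype(np.int64)


batch = pad_batch([[101, 7, 8, 102], [101, 9, 102]])
print(batch)
print(attention_mask(batch))
```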

### Loading the IMDB dataset

In [5]:
fields = create_fields()
imdb_train, imdb_test = IMDB.get_dataset_splits(fields)

100%|██████████| 84.1M/84.1M [00:08<00:00, 9.56MB/s]


#### Check the loaded data 

> __Note__: `None` values in the output are for caching purposes.

In [6]:
print(imdb_train[0].text)
print(imdb_train[0].label)

(None, ['[CLS]', 'dominic', '##k', '(', 'nicky', ')', 'luciano', 'wears', 'a', "'", 'hulk', "'", 't', '-', 'shirt', 'and', 'tr', '##udge', '##s', 'off', 'everyday', 'to', 'perform', 'his', 'duties', 'as', 'a', 'garbage', 'man', '.', 'he', 'uses', 'his', 'physical', 'power', 'in', 'picking', 'up', 'other', "'", 's', 'trash', 'and', 'hauling', 'it', 'to', 'the', 'town', 'dump', '.', 'he', 'reads', 'comic', '-', 'book', 'hero', 'stories', 'and', 'loves', 'wrestlers', 'and', 'wrestling', ',', 'going', 'to', 'wrestlemania', 'with', 'his', 'twin', 'brother', 'eugene', 'on', 'their', 'birthday', 'is', 'a', 'yearly', 'tradition', '.', 'he', 'talks', 'kindly', 'with', 'the', 'many', 'people', 'he', 'comes', 'in', 'contact', 'with', 'during', 'his', 'day', '.', 'he', 'reads', 'comic', 'books', ',', 'which', 'he', 'finds', 'in', 'the', 'trash', ',', 'with', 'a', 'young', 'boy', 'who', 'he', 'often', 'passes', 'by', 'while', 'on', 'the', 'garbage', 'route', '.', 'unfortunately', ',', 'dominic', '#

### Setting up the Podium Experiment

To fine-tune the model, we define an `Experiment`.

In [7]:
model_config = {
    'lr': 1e-5,
    'clip': float('inf'), # disable gradient clipping
    'num_epochs': 5,
}
model_config['num_classes'] = len(fields['label'].vocab)
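Setting `'clip'` to `float('inf')` disables clipping under one assumption: that the trainer clips by the global gradient norm, essentially as `torch.nn.utils.clip_grad_norm_` does. Under that rule, the rescaling factor `min(1, max_norm / total_norm)` is always 1 when `max_norm` is infinite. Below is a plain-Python sketch of the rule; `clip_by_global_norm` is a hypothetical helper.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale gradients so their global L2 norm is at most max_norm (hypothetical helper)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # inf / finite == inf, so min(...) is 1.0 and gradients pass through unchanged
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


grads = [3.0, 4.0]  # global norm 5
print(clip_by_global_norm(grads, max_norm=1.0))           # scaled down to a norm of about 1
print(clip_by_global_norm(grads, max_norm=float('inf')))  # ([3.0, 4.0], 5.0): unchanged
```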

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

iterator = Iterator(batch_size=32)
# imdb_test doubles as the validation set during training
trainer = TorchTrainer(model_config['num_epochs'], device, iterator, imdb_test)

# Podium's TorchModel swaps the first two feature axes afterwards, which suits sequence-first
# models; BERT is batch-first, so we swap them here first and the two swaps cancel out.
# A batch-first option in Podium would make this workaround unnecessary.
feature_transformer = lambda feature_batch: feature_batch[0].astype(np.int64).swapaxes(0, 1)
label_transformer = lambda label_batch: label_batch[0].astype(np.int64)

experiment = Experiment(TorchModel, 
                        trainer=trainer, 
                        feature_transformer=feature_transformer,
                        label_transform_fn=label_transformer)

experiment.fit(imdb_train,  
               model_kwargs={
                   'model_class': BertModelWrapper,  
                   'criterion': nn.CrossEntropyLoss(),  
                   'optimizer': optim.AdamW,
                   'device': device,
                   **model_config
               })

[Batch]: 781 in 0.10555 seconds, loss=0.43808
Total time for train epoch: 205.91309213638306
[Valid]: 781 in 0.00576 seconds, loss=0.38564
Total time for valid epoch: 63.414066314697266
[Batch]: 781 in 0.10557 seconds, loss=0.39357
Total time for train epoch: 199.31143307685852
[Valid]: 781 in 0.00574 seconds, loss=0.03775
Total time for valid epoch: 60.488966941833496
[Batch]: 781 in 0.08967 seconds, loss=0.17306
Total time for train epoch: 196.89669013023376
[Valid]: 781 in 0.00577 seconds, loss=0.15879
Total time for valid epoch: 60.47404861450195
[Batch]: 781 in 0.10608 seconds, loss=0.04704
Total time for train epoch: 199.86573767662048
[Valid]: 781 in 0.00578 seconds, loss=0.00348
Total time for valid epoch: 60.53778100013733
[Batch]: 781 in 0.10526 seconds, loss=0.00671
Total time for train epoch: 200.03476238250732
[Valid]: 781 in 0.00653 seconds, loss=0.74549
Total time for valid epoch: 60.55412006378174


### Saving and loading the fitted model

#### Utilities for saving/loading the model 

In [9]:
def save_model(model, file_path):
    with open(file_path, 'wb') as f:
        pickle.dump(model, f)

def load_model(file_path):
    with open(file_path, 'rb') as f:
        model = pickle.load(f)
    return model
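`save_model` and `load_model` are a plain `pickle` round trip. The sketch below runs the same two functions on a small stand-in object (a dict, assumed in place of the fitted model) inside a temporary directory.

One caveat: `pickle` stores a class instance by reference to its class. Whatever session calls `load_model` must therefore have `BertModelWrapper` (and Podium) importable under the same module path. In a notebook, that means defining the wrapper class before loading.

```python
import os
import pickle
import tempfile


def save_model(model, file_path):
    with open(file_path, 'wb') as f:
        pickle.dump(model, f)


def load_model(file_path):
    with open(file_path, 'rb') as f:
        return pickle.load(f)


# a plain dict stands in for the fitted Podium model
fake_model = {'weights': [0.1, -0.2, 0.3], 'num_classes': 2}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'model.pkl')
    save_model(fake_model, path)
    restored = load_model(path)

print(restored == fake_model)  # True
```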

In [10]:
fitted_model = experiment.model

model_file = 'bert_model.pt'
save_model(fitted_model, model_file)
loaded_model = load_model(model_file)

### Testing and making predictions on raw data

#### Utilities for making predictions with the raw model on the parsed dataset

In [11]:
cast_to_torch_transformer = lambda t: torch.from_numpy(t[0].astype(np.int64)).to(device)

@torch.no_grad()
def make_predictions(raw_model, dataset, batch_size=32):
    raw_model.eval()
    
    def predict(batch):
        predictions = raw_model(cast_to_torch_transformer(batch))['pred']
        return predictions.cpu().numpy()

    iterator = Iterator(batch_size=batch_size, 
                        shuffle=False)
    
    predictions = []
    for x_batch, _ in iterator(dataset):
        batch_prediction = predict(x_batch)
        predictions.append(batch_prediction)
        
    return np.concatenate(predictions)
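`make_predictions` relies on `shuffle=False`: it ensures that row i of the concatenated output lines up with row i of `y_true`. Below is a NumPy sketch of the same loop. It rests on two assumptions: `toy_model` is a hypothetical stand-in for the wrapped BERT, and plain slicing replaces Podium's `Iterator`.

```python
import numpy as np


def toy_model(x_batch):
    # hypothetical stand-in: two "logits" per row, derived from the row sum
    s = x_batch.sum(axis=1, keepdims=True).astype(np.float64)
    return {'pred': np.hstack([s, -s])}


def make_predictions(raw_model, features, batch_size=32):
    predictions = []
    # iterate in order (no shuffling) so predictions stay aligned with labels
    for start in range(0, len(features), batch_size):
        x_batch = features[start:start + batch_size]
        predictions.append(raw_model(x_batch)['pred'])
    return np.concatenate(predictions)


features = np.arange(100 * 4).reshape(100, 4) - 200
logits = make_predictions(toy_model, features, batch_size=32)
print(logits.shape)  # (100, 2): the final partial batch of 4 rows is kept
y_pred = logits.argmax(axis=1)
```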

#### Model comparison: pretrained BERT vs. pretrained + fine-tuned BERT

In [12]:
_, y_true = imdb_test.batch()
y_true = y_true[0].ravel()

Evaluation of the pretrained BERT

In [13]:
predictions = make_predictions(BertModelWrapper().to(device), imdb_test)
y_pred = predictions.argmax(axis=1)

print('accuracy score:', accuracy_score(y_true, y_pred))
print('precision score:', precision_score(y_true, y_pred, zero_division=0))
print('recall score:', recall_score(y_true, y_pred, zero_division=0))
print('f1 score:', f1_score(y_true, y_pred, zero_division=0))

accuracy score: 0.5
precision score: 0.0
recall score: 0.0
f1 score: 0.0


Evaluation of the pretrained + fine-tuned BERT

In [14]:
loaded_model_raw = loaded_model.model
predictions = make_predictions(loaded_model_raw, imdb_test)
y_pred = predictions.argmax(axis=1)

print('accuracy score:', accuracy_score(y_true, y_pred))
print('precision score:', precision_score(y_true, y_pred))
print('recall score:', recall_score(y_true, y_pred))
print('f1 score:', f1_score(y_true, y_pred))

accuracy score: 0.87036
precision score: 0.8693258875149581
recall score: 0.87176
f1 score: 0.8705412422608348


> __Note__: `loaded_model_raw` is an instance of `BertModelWrapper`.

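Without fine-tuning, the model predicts the same class (index 0) for every review. Because the IMDB test set is balanced, accuracy is exactly 0.5. Precision, recall and F1, which scikit-learn computes for class index 1 by default, are all zero.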
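These scores follow directly from the metric definitions. The sketch below computes them by hand for a constant predictor on a balanced label set. It treats class index 1 as the positive class (scikit-learn's default `pos_label=1`) and returns 0 for an undefined ratio, as `zero_division=0` does. `binary_metrics` is a hypothetical helper, not a scikit-learn function.

```python
def binary_metrics(y_true, y_pred, pos_label=1):
    """Accuracy, precision, recall and F1, with 0 for undefined ratios (zero_division=0)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == pos_label and p == pos_label for t, p in pairs)
    fp = sum(t != pos_label and p == pos_label for t, p in pairs)
    fn = sum(t == pos_label and p != pos_label for t, p in pairs)
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1


# balanced labels and a model that always predicts class index 0
y_true = [0, 1] * 5
y_pred = [0] * 10
print(binary_metrics(y_true, y_pred))  # (0.5, 0.0, 0.0, 0.0)
```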

#### Making predictions on raw data 

In [15]:
pipe = Pipeline(fields=list(fields.values()), 
                example_format='list',
                feature_transformer=cast_to_torch_transformer,
                model=loaded_model)

instances = [
    ['This movie is horrible'],
    ['This movie is great!']
]

for instance in instances:
    predictions = pipe.predict_raw(instance)
    print(f'instance: {instance}, predicted label: '
          f'{fields["label"].vocab.itos[predictions.argmax()]}, '
          f'predictions: {predictions}')

instance: ['This movie is horrible'], predicted label: negative, predictions: [-3.139483   3.0472898]
instance: ['This movie is great!'], predicted label: positive, predictions: [ 2.5500994 -3.031134 ]
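The printed `predictions` are raw logits, one per class. An argmax turns them into a label, and a softmax turns them into class probabilities. The sketch below uses the logits printed above. Its `itos` list is inferred from those outputs (index 0 is `positive`, index 1 is `negative`). That order is an assumption for this run only, since it depends on how the label vocabulary was built.

```python
import numpy as np

# label order inferred from this run's output; read fields['label'].vocab.itos in practice
itos = ['positive', 'negative']


def softmax(logits):
    shifted = logits - logits.max()  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()


for logits in (np.array([-3.139483, 3.0472898]), np.array([2.5500994, -3.031134])):
    probs = softmax(logits)
    print(itos[probs.argmax()], probs.round(4))
```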
