## Fine-tuning BERT from HuggingFace Transformers on the IMDB Dataset
>Note: In this tutorial we use a smaller version of BERT, called __DistilBert__, that is easier and faster to train.

In [2]:
import pickle

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

from podium.datasets import IMDB, Iterator
from podium.models import Experiment
from podium.models.impl.pytorch import TorchTrainer, TorchModel
from podium.pipeline import Pipeline
from podium.storage import Field, LabelField, Vocab

### Defining the fields

* text - applies `DistilBertTokenizer` to the instance data. To reduce the memory footprint of the dataset we don't store `attention_mask`; instead, we create it ourselves on the fly.
* label - stores binary labels that represent the sentiment of an instance.

In [4]:
def create_fields():
    """Build the Podium fields used to parse the IMDB dataset.

    Returns:
        dict: maps 'text' to a `Field` that tokenizes and numericalizes
        with the pretrained DistilBERT tokenizer, and 'label' to a
        `LabelField` whose vocabulary has no special tokens.
    """
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    def tokenize(raw_text):
        # Encode to ids first so that truncation to 128 positions is applied
        # by the tokenizer itself, then map the ids back to token strings.
        # No attention mask is requested; it is not stored with the dataset.
        encoding = tokenizer(raw_text,
                             max_length=128,
                             padding=False,
                             truncation=True,
                             return_attention_mask=False)
        return tokenizer.convert_ids_to_tokens(encoding['input_ids'])

    def numericalize(token):
        # Look the token up in the tokenizer's own vocabulary.
        return tokenizer.convert_tokens_to_ids(token)

    text_field = Field(name='text',
                       tokenizer=tokenize,
                       custom_numericalize=numericalize,
                       padding_token=0)
    label_field = LabelField(name='label', vocab=Vocab(specials=()))

    return dict(text=text_field, label=label_field)

### Wrapping the model
`DistilBertForSequenceClassification` is a standard PyTorch `Module`, so it can easily be wrapped in another `Module` that exposes the interface Podium expects: the model has to return a dictionary whose `pred` key points to the model predictions. As mentioned earlier, we create the attention mask ourselves, so this wrapper is a good place to do it.

In [5]:
class BertModelWrapper(nn.Module):
    """Wrap `DistilBertForSequenceClassification` so its output carries a `pred` key.

    Keyword Args:
        num_classes (int, optional): size of the classification head.
            Defaults to 2, which keeps the two-logit head this wrapper
            produced before the option existed. Any other keyword
            arguments are accepted and ignored.
            NOTE(review): presumably the caller forwards the whole
            `model_kwargs` config here (lr, clip, ...) -- confirm.
    """

    # Id used for padding in the input batches; positions holding this id
    # are masked out of attention.
    PAD_TOKEN_ID = 0

    def __init__(self, **kwargs):
        super().__init__()
        # Honour the configured number of classes instead of silently
        # relying on the library default.
        num_labels = kwargs.get('num_classes', 2)
        self.model = DistilBertForSequenceClassification.from_pretrained(
            'distilbert-base-uncased',
            num_labels=num_labels,
            return_dict=True)

    def forward(self, x):
        """Run the wrapped model on a batch of input ids.

        Args:
            x: tensor of input ids, padded with `PAD_TOKEN_ID`.

        Returns:
            The wrapped model's output with an extra `pred` entry that
            aliases `logits`.
        """
        # The attention mask is not stored with the data; rebuild it here
        # by marking every non-padding position with 1.
        attention_mask = (x != self.PAD_TOKEN_ID).long()
        return_dict = self.model(x, attention_mask)
        return_dict['pred'] = return_dict['logits']
        return return_dict

### Loading the IMDB dataset

In [None]:
# Build the field definitions, then load the IMDB train and test splits
# parsed with those fields.
fields = create_fields()
imdb_train, imdb_test = IMDB.get_dataset_splits(fields)

#### Check the loaded data 

> Note: `None` values in the output are for caching purposes

In [8]:
# Inspect the first training example: its tokenized text, then its label.
first_example = imdb_train[0]
print(first_example.text)
print(first_example.label)

(None, ['[CLS]', 'bro', '##m', '##well', 'high', 'is', 'a', 'cartoon', 'comedy', '.', 'it', 'ran', 'at', 'the', 'same', 'time', 'as', 'some', 'other', 'programs', 'about', 'school', 'life', ',', 'such', 'as', '"', 'teachers', '"', '.', 'my', '35', 'years', 'in', 'the', 'teaching', 'profession', 'lead', 'me', 'to', 'believe', 'that', 'bro', '##m', '##well', 'high', "'", 's', 'satire', 'is', 'much', 'closer', 'to', 'reality', 'than', 'is', '"', 'teachers', '"', '.', 'the', 'scramble', 'to', 'survive', 'financially', ',', 'the', 'insight', '##ful', 'students', 'who', 'can', 'see', 'right', 'through', 'their', 'pathetic', 'teachers', "'", 'po', '##mp', ',', 'the', 'pet', '##tine', '##ss', 'of', 'the', 'whole', 'situation', ',', 'all', 'remind', 'me', 'of', 'the', 'schools', 'i', 'knew', 'and', 'their', 'students', '.', 'when', 'i', 'saw', 'the', 'episode', 'in', 'which', 'a', 'student', 'repeatedly', 'tried', 'to', 'burn', 'down', 'the', 'school', ',', 'i', 'immediately', 'recalled', '.', 

### Setting up the Podium Experiment

To fine-tune our model, we define an `Experiment`.

In [35]:
# Hyperparameters for fine-tuning, plus the number of target classes taken
# from the label vocabulary.
model_config = {
    'lr': 1e-5,
    # an infinite threshold disables gradient clipping
    'clip': float('inf'),
    'num_epochs': 3,
    'num_classes': len(fields['label'].vocab),
}

In [63]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def feature_transformer(feature_batch):
    """Extract the first feature array as int64 with its first two axes swapped.

    The swap cancels out an axis swap that is applied afterwards, because
    the model used here is batch-first (ideally this would be a Podium
    option).
    """
    return feature_batch[0].astype(np.int64).swapaxes(0, 1)


def label_transformer(label_batch):
    """Extract the first label array as int64."""
    return label_batch[0].astype(np.int64)


iterator = Iterator(batch_size=32)
trainer = TorchTrainer(model_config['num_epochs'], device, iterator, imdb_test)

experiment = Experiment(TorchModel,
                        trainer=trainer,
                        feature_transformer=feature_transformer,
                        label_transform_fn=label_transformer)

# Everything needed to build the model, followed by the hyperparameters.
model_kwargs = {
    'model_class': BertModelWrapper,
    'criterion': nn.CrossEntropyLoss(),
    'optimizer': optim.AdamW,
    'device': device,
    **model_config
}
experiment.fit(imdb_train, model_kwargs=model_kwargs)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi

[Batch]: 781 in 2.44000 seconds, loss=0.404058
Total time for train epoch: 5957.261180639267
[Valid]: 781 in 0.62000 seconds, loss=0.08201
Total time for valid epoch: 1807.9500014781952
[Batch]: 781 in 4.37339 seconds, loss=0.124552
Total time for train epoch: 7614.580070734024
[Valid]: 781 in 1.03415 seconds, loss=0.09244
Total time for valid epoch: 2411.316132545471
[Batch]: 781 in 2.63686 seconds, loss=0.112098
Total time for train epoch: 8038.268185377121
[Valid]: 781 in 0.83022 seconds, loss=0.06469
Total time for valid epoch: 2486.869591474533


### Saving and loading the fitted model

#### Utilities for saving/loading the model 

In [41]:
def save_model(model, file_path):
    """Serialize `model` to `file_path` using pickle.

    Args:
        model: the object to persist; it must be picklable.
        file_path: destination path; an existing file is overwritten.
    """
    with open(file_path, 'wb') as f:
        # Dump the argument itself, not a notebook-level global, so the
        # function saves whatever the caller passes in.
        pickle.dump(model, f)

def load_model(file_path):
    """Read back an object that was pickled to `file_path`.

    Unpickling can execute arbitrary code, so only load files that come
    from a trusted source.

    Args:
        file_path: path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    with open(file_path, 'rb') as source:
        return pickle.load(source)

In [None]:
# Take the fitted model from the experiment so it can be persisted.
fitted_model = experiment.model

# NOTE: despite the `.pt` extension, this file is written with `pickle`
# (see `save_model`), not with `torch.save`.
model_file = 'bert_model.pt'
save_model(fitted_model, model_file)
# Round-trip through disk; the reloaded copy is the one used from here on.
loaded_model = load_model(model_file)

### Testing and making predictions on raw data

#### Utilities for making predictions with the raw model on the parsed dataset

In [60]:
def cast_to_torch_transformer(t):
    """Turn the first array of `t` into an int64 torch tensor on `device`."""
    as_int64 = t[0].astype(np.int64)
    return torch.from_numpy(as_int64).to(device)

def make_predictions(raw_model, dataset):
    """Run `raw_model` on every example of `dataset` and return its predictions.

    Args:
        raw_model: callable model exposing `eval()` whose output has a
            `pred` entry.
        dataset: object whose `batch()` returns a (features, labels) pair
            of numericalized examples.

    Returns:
        numpy.ndarray: the `pred` tensor moved to the CPU.
    """
    # Put the model in evaluation mode before inference.
    raw_model.eval()

    # `.batch()` gives the numericalized examples; the labels are not needed.
    features, _ = dataset.batch()

    with torch.no_grad():
        model_input = cast_to_torch_transformer(features)
        output = raw_model(model_input)
        scores = output['pred']
        return scores.cpu().numpy()

`raw_model` is an instance of `BertModelWrapper`

In [61]:
# Unwrap the underlying `BertModelWrapper` from the reloaded model.
raw_model = loaded_model.model
# Score only the first four test examples.
predictions = make_predictions(raw_model, imdb_test[:4])
predictions

array([[ 1.90126  , -1.9712397],
       [ 0.9901674, -1.2427605],
       [ 2.4357972, -2.7238822],
       [ 2.5466118, -2.854033 ]], dtype=float32)

Check if the predictions are valid

In [62]:
# `.batch()` yields a (features, labels) pair; only the labels are used here.
_, y = imdb_test[:4].batch()
# Predicted class = index of the largest value in each row of `predictions`.
y_pred = predictions.argmax(axis=1)
# `y[0]` is flattened so it compares element-wise against `y_pred`.
print('y_pred == y_true:', (y_pred == y[0].ravel()).all())

y_pred == y_true: True


#### Making predictions on raw data

In [59]:
# Bundle the fields, the tensor conversion and the reloaded model so that
# raw (unparsed) instances can be scored directly.
pipe = Pipeline(fields=list(fields.values()),
                example_format='list',
                feature_transformer=cast_to_torch_transformer,
                model=loaded_model)

instances = [
    ['This movie is horrible'],
    ['This movie is great!']
]

label_vocab = fields['label'].vocab
for instance in instances:
    # A dedicated name keeps the loop from overwriting the `predictions`
    # array that the validity-check cell reads, so cells can be re-run
    # in any order.
    instance_scores = pipe.predict_raw(instance)
    predicted_label = label_vocab.itos[instance_scores.argmax()]
    print(f'instance: {instance}, predicted label: '
          f'{predicted_label}, '
          f'predictions: {instance_scores}')

instance: ['This movie is horrible'], predicted label: negative, predictions: [-2.3370621  2.6273308]
instance: ['This movie is great!'], predicted label: positive, predictions: [ 2.3128583 -2.7027676]
