## Fine-tuning Transformers for POS tagging

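In this notebook, we fine-tune the pre-trained `bert-base-uncased` BERT model for part-of-speech (POS) tagging on the English Universal Dependencies treebank (torchtext's `UDPOS` dataset), adding a linear classification layer that predicts one tag per token.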

In [None]:
# Pin versions so the notebook keeps working when re-run: the code below uses
# the legacy torchtext API (data.Field, BucketIterator, datasets.UDPOS), which
# later torchtext releases moved to `torchtext.legacy`, and it was last run with
# transformers 4.27.2. `%pip` installs into the environment of the running kernel.
%pip install transformers==4.27.2
%pip install torchtext==0.6.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.2-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m34.0 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m15.4 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m70.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.13.3 tokenizers-0.13.2 transformers-4.27.2
Looking in indexes: https://pypi.org/simple, https://us

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchtext

from torchtext import *
#from torchtext.legacy.data import Field, TabularDataset, BucketIterator, Iterator
from torchtext import datasets

from transformers import BertTokenizer, BertModel

import numpy as np

import time
import random
import functools

In [None]:
# Seed every RNG used in this notebook (Python, NumPy, PyTorch) so that the
# classifier-head initialisation, dropout masks and data shuffling are
# repeatable across runs.
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [None]:
# Load the tokenizer (and its vocabulary) that ships with the pretrained
# 'bert-base-uncased' checkpoint, so token ids match the model's embeddings.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [None]:
# BERT was pretrained with its own special tokens: [CLS] at the start of each
# sequence, [PAD] for padding and [UNK] for unknown pieces. Read them from the
# tokenizer instead of hard-coding them, and print their integer ids to verify
# that our input is formatted the same way the model saw during pretraining.
init_token, pad_token, unk_token = (
    tokenizer.cls_token,
    tokenizer.pad_token,
    tokenizer.unk_token,
)
print(init_token, pad_token, unk_token)

init_token_idx, pad_token_idx, unk_token_idx = (
    tokenizer.convert_tokens_to_ids(special) for special in (init_token, pad_token, unk_token)
)
print(init_token_idx, pad_token_idx, unk_token_idx)

[CLS] [PAD] [UNK]
101 0 100


In [None]:
# Maximum sequence length (in tokens, including [CLS]) the pretrained model
# accepts. `model_max_length` is the stable tokenizer attribute; the legacy
# class-level `max_model_input_sizes` map may be missing in newer transformers
# releases. Cap at 512 (bert-base's number of position embeddings) in case the
# tokenizer config leaves the limit unset (a huge sentinel value).
max_input_length = min(tokenizer.model_max_length, 512)

print(max_input_length)

512


In [None]:
# Helpers used as torchtext `preprocessing` hooks for the text and tag fields.
def cut_and_convert_to_id(tokens, tokenizer, max_input_length):
    """Truncate a token list and map it to vocabulary ids.

    At most ``max_input_length - 1`` tokens are kept, leaving one slot for the
    init ([CLS]) token that the text Field prepends.

    Args:
        tokens: list of token strings.
        tokenizer: object exposing ``convert_tokens_to_ids``.
        max_input_length: maximum sequence length accepted by the model.

    Returns:
        The result of ``tokenizer.convert_tokens_to_ids`` on the truncated list.
    """
    kept_tokens = tokens[:max_input_length - 1]
    return tokenizer.convert_tokens_to_ids(kept_tokens)

def cut_to_max_length(tokens, max_input_length):
    """Truncate a tag sequence with the same ``max_input_length - 1`` limit used for text.

    Using the identical limit keeps tags aligned one-to-one with the truncated
    input tokens.
    """
    return tokens[:max_input_length - 1]

In [None]:
# Pre-bind the fixed arguments with functools.partial so torchtext can call each
# preprocessor with just the token list of one example.
text_preprocessor = functools.partial(
    cut_and_convert_to_id,
    tokenizer=tokenizer,
    max_input_length=max_input_length,
)
tag_preprocessor = functools.partial(
    cut_to_max_length,
    max_input_length=max_input_length,
)

In [None]:
# A torchtext Field describes how one column of the dataset is processed.
# TEXT: tokens are lower-cased (the vocabulary is 'uncased') and converted to
# BERT ids by `text_preprocessor`, so torchtext must not build its own
# vocabulary (use_vocab = False); the special tokens are therefore BERT ids too.
TEXT = data.Field(use_vocab = False,
                  lower = True,
                  preprocessing = text_preprocessor,
                  init_token = init_token_idx,
                  pad_token = pad_token_idx,
                  unk_token = unk_token_idx)

# UD_TAGS: a tag vocabulary is built later from the training split, with no
# <unk> tag (unk_token = None). The tag sequence gets '<pad>' as its init token
# so it stays aligned with the [CLS] position of TEXT; that position is then
# ignored by the loss, which uses the pad index as ignore_index.
UD_TAGS = data.Field(unk_token = None,
                     init_token = '<pad>',
                     preprocessing = tag_preprocessor)

In [None]:
# Map dataset columns to (attribute name on each example, Field that processes it).
# Only words and UD tags get a Field, so examples carry just 'text' and 'udtags'.
fields = (("text", TEXT), ("udtags", UD_TAGS))

In [None]:
# Download (on first use) the English UD treebank and load its three splits.
train_data, valid_data, test_data = datasets.UDPOS.splits(fields)

downloading en-ud-v2.zip


en-ud-v2.zip: 100%|██████████| 688k/688k [00:00<00:00, 43.3MB/s]


extracting


In [None]:
# Inspect one training example: 'text' is already a list of BERT vocabulary ids
# (converted by text_preprocessor), while 'udtags' are still tag strings.
print(vars(train_data.examples[0]))

{'text': [2632, 1011, 100, 1024, 2137, 2749, 2730, 100, 14093, 2632, 1011, 100, 1010, 1996, 14512, 2012, 1996, 8806, 1999, 1996, 2237, 1997, 100, 1010, 2379, 1996, 9042, 3675, 1012], 'udtags': ['PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'ADJ', 'NOUN', 'VERB', 'PROPN', 'PROPN', 'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'PROPN', 'PUNCT', 'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT']}


In [None]:
# Build the tag vocabulary from the training split only. The printed mapping
# shows '<pad>' at index 0; it doubles as the tag for the init ([CLS]) position.
UD_TAGS.build_vocab(train_data)

print(UD_TAGS.vocab.stoi)

defaultdict(None, {'<pad>': 0, 'NOUN': 1, 'PUNCT': 2, 'VERB': 3, 'PRON': 4, 'ADP': 5, 'DET': 6, 'PROPN': 7, 'ADJ': 8, 'AUX': 9, 'ADV': 10, 'CCONJ': 11, 'PART': 12, 'NUM': 13, 'SCONJ': 14, 'X': 15, 'INTJ': 16, 'SYM': 17})


In [None]:
BATCH_SIZE = 32

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# BucketIterator batches examples of similar length to reduce padding and puts
# the batch tensors on `device`. Batches are sequence-first (batch.text is
# [sent len, batch size]); the model permutes them for BERT.
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size = BATCH_SIZE,
    device = device)

In [None]:
class BERTPoSTagger(nn.Module):
    """Token-level POS tagger: a pretrained BERT encoder plus a linear head.

    Args:
        bert: pretrained ``BertModel`` (or compatible module) whose output's
            first element is the last layer's hidden states.
        output_dim: number of tag classes.
        dropout: dropout probability applied once to BERT's hidden states
            before the classifier.
        pad_idx: input id used for padding; positions with this id are masked
            out of BERT's self-attention. Defaults to ``bert.config.pad_token_id``;
            if that is unavailable as well, no attention mask is passed.
    """

    def __init__(self, bert, output_dim, dropout, pad_idx=None):
        super().__init__()
        self.bert = bert
        embedding_dim = bert.config.to_dict()['hidden_size']
        self.fc = nn.Linear(embedding_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        if pad_idx is None:
            pad_idx = getattr(bert.config, 'pad_token_id', None)
        # Plain attribute (not a buffer), so state_dict keys are unchanged.
        self.pad_idx = pad_idx

    def forward(self, text):
        """Predict tag scores for every position.

        Args:
            text: LongTensor of token ids, shape [sent len, batch size]
                (torchtext's default sequence-first layout).

        Returns:
            Unnormalised tag scores, shape [sent len, batch size, output dim].
        """
        # BERT expects batch-first input.
        text = text.permute(1, 0)
        # text = [batch size, sent len]

        # Without a mask BERT attends to [PAD] positions, so a token's
        # representation would depend on how much padding its batch contains.
        attention_mask = None
        if self.pad_idx is not None:
            attention_mask = (text != self.pad_idx).long()

        hidden = self.bert(text, attention_mask=attention_mask)[0]
        # hidden = [batch size, sent len, emb dim]

        hidden = hidden.permute(1, 0, 2)
        # hidden = [sent len, batch size, emb dim]

        # Dropout is applied a single time; two back-to-back dropout layers
        # would compound the rate to 1 - (1 - p)^2.
        predictions = self.fc(self.dropout(hidden))
        # predictions = [sent len, batch size, output dim]

        return predictions

In [None]:
# Load the pretrained bert-base-uncased encoder. The checkpoint's pretraining
# heads (cls.predictions.*, cls.seq_relationship.*) are not used by BertModel,
# so the "weights ... were not used" message is expected here.
bert = BertModel.from_pretrained('bert-base-uncased')

Downloading pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# One output unit per entry in the tag vocabulary (including '<pad>').
OUTPUT_DIM = len(UD_TAGS.vocab)
DROPOUT = 0.25

model = BERTPoSTagger(bert,
                      OUTPUT_DIM,
                      DROPOUT)

In [None]:
def count_parameters(model):
    """Return the number of trainable (``requires_grad``) parameters in ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 109,496,082 trainable parameters


In [None]:
LEARNING_RATE = 5e-5

# All parameters (the BERT encoder and the new linear head) are fine-tuned with Adam.
optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)

In [None]:
# Cross-entropy over tags, ignoring positions whose target is the '<pad>' tag
# (real padding and the [CLS] position, whose tag is '<pad>').
TAG_PAD_IDX = UD_TAGS.vocab.stoi[UD_TAGS.pad_token]

criterion = nn.CrossEntropyLoss(ignore_index = TAG_PAD_IDX)

In [None]:
# Move the model and loss to the selected device (GPU if available).
model = model.to(device)
criterion = criterion.to(device)

In [None]:
# calculate accuracy of predicting tags, ignoring predictions over padding tokens
def categorical_accuracy(preds, y, tag_pad_idx):
    """Fraction of non-padding positions whose arg-max prediction equals the gold tag.

    Args:
        preds: tag scores of shape [N, n_tags] (flattened positions x classes).
        y: gold tag ids of shape [N].
        tag_pad_idx: tag id whose positions are excluded from the count.

    Returns:
        0-d float tensor on ``y``'s device, e.g. 0.8 if 8 of 10 non-pad
        positions are right (NOT 8). Returns 0.0 when every position is padding,
        instead of dividing by zero.
    """
    max_preds = preds.argmax(dim=1)
    non_pad_mask = y != tag_pad_idx
    correct = (max_preds[non_pad_mask] == y[non_pad_mask]).sum()
    n_non_pad = non_pad_mask.sum()
    # Work on the tensors' own device (no global `device`) and guard against an
    # all-padding batch.
    return correct.float() / n_non_pad.clamp(min=1).float()

In [None]:
def train(model, iterator, optimizer, criterion, tag_pad_idx):
    """Run one training epoch over ``iterator``, updating ``model`` in place.

    Returns:
        (mean loss, mean accuracy), each averaged over the batches.
    """
    model.train()
    total_loss, total_acc = 0, 0

    for batch in iterator:
        optimizer.zero_grad()

        # [sent len, batch size, n_tags] -> [sent len * batch size, n_tags]
        logits = model(batch.text)
        logits = logits.view(-1, logits.shape[-1])
        gold = batch.udtags.view(-1)

        loss = criterion(logits, gold)
        acc = categorical_accuracy(logits, gold, tag_pad_idx)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_acc += acc.item()

    n_batches = len(iterator)
    return total_loss / n_batches, total_acc / n_batches

In [None]:
def evaluate(model, iterator, criterion, tag_pad_idx):
    """Score ``model`` on every batch of ``iterator`` without updating it.

    The model is switched to eval mode (dropout off) and gradients are not tracked.

    Returns:
        (mean loss, mean accuracy), each averaged over the batches.
    """
    model.eval()
    total_loss = total_acc = 0

    with torch.no_grad():
        for batch in iterator:
            logits = model(batch.text)
            # [sent len, batch size, n_tags] -> [sent len * batch size, n_tags]
            flat_logits = logits.view(-1, logits.shape[-1])
            flat_gold = batch.udtags.view(-1)

            total_loss += criterion(flat_logits, flat_gold).item()
            total_acc += categorical_accuracy(flat_logits, flat_gold, tag_pad_idx).item()

    n_batches = len(iterator)
    return total_loss / n_batches, total_acc / n_batches

In [None]:
def epoch_time(start_time, end_time):
    """Split the elapsed time between two timestamps (in seconds) into whole minutes and seconds."""
    seconds_total = end_time - start_time
    whole_minutes = int(seconds_total / 60)
    leftover_seconds = int(seconds_total - whole_minutes * 60)
    return whole_minutes, leftover_seconds

In [None]:
# Fine-tune for N_EPOCHS, evaluating on the validation split after every epoch.
N_EPOCHS = 3
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    # Keep only the checkpoint with the lowest validation loss so far; it is
    # reloaded later for the test evaluation.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'pot-model.pt')

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

Epoch: 01 | Epoch Time: 2m 7s
	Train Loss: 0.383 | Train Acc: 89.22%
	 Val. Loss: 0.287 |  Val. Acc: 91.19%
Epoch: 02 | Epoch Time: 2m 8s
	Train Loss: 0.120 | Train Acc: 96.53%
	 Val. Loss: 0.278 |  Val. Acc: 91.71%
Epoch: 03 | Epoch Time: 2m 8s
	Train Loss: 0.078 | Train Acc: 97.75%
	 Val. Loss: 0.271 |  Val. Acc: 91.68%


In [None]:
# Reload the checkpoint with the best validation loss before testing.
# map_location lets a checkpoint saved from GPU load in a CPU-only session.
model.load_state_dict(torch.load('pot-model.pt', map_location=device))

test_loss, test_acc = evaluate(model, test_iterator, criterion, TAG_PAD_IDX)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.279 | Test Acc: 91.28%


In [None]:
# use model to tag actual sentences
def tag_sentence(model, device, sentence, tokenizer, text_field, tag_field):
    """Predict a POS tag for every BERT token of ``sentence``.

    Args:
        model: tagger returning scores of shape [sent len, batch size, n_tags].
        device: device on which to place the input tensor.
        sentence: raw string (split with ``tokenizer.tokenize``) or a list of
            tokens. List tokens are lower-cased when ``text_field.lower`` is set,
            mirroring the preprocessing applied to the training text.
        tokenizer: tokenizer providing ``tokenize`` and ``convert_tokens_to_ids``.
        text_field: Field whose ``init_token`` / ``unk_token`` hold token ids.
        tag_field: Field with a built tag vocabulary (``vocab.itos``).

    Returns:
        (tokens, predicted_tags, unks): the tokens, one predicted tag per token,
        and the tokens whose id equals ``text_field.unk_token``.
    """
    model.eval()

    if isinstance(sentence, str):
        tokens = tokenizer.tokenize(sentence)
    elif getattr(text_field, 'lower', False):
        # Training text was lower-cased before id conversion; with an uncased
        # vocabulary, capitalised tokens would otherwise map to the unknown id.
        tokens = [token.lower() for token in sentence]
    else:
        tokens = list(sentence)

    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    # Collect unknown tokens BEFORE prepending the init token; zipping against
    # the prefixed list would pair each token with its predecessor's id.
    unk_idx = text_field.unk_token
    unks = [token for token, idx in zip(tokens, token_ids) if idx == unk_idx]

    numericalized_tokens = [text_field.init_token] + token_ids

    # Shape [sent len, 1]: a sequence-first batch holding one sentence.
    token_tensor = torch.LongTensor(numericalized_tokens).unsqueeze(-1).to(device)

    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        predictions = model(token_tensor)

    top_predictions = predictions.argmax(-1)

    predicted_tags = [tag_field.vocab.itos[t.item()] for t in top_predictions]
    # Drop the tag predicted for the init ([CLS]) position.
    predicted_tags = predicted_tags[1:]
    assert len(tokens) == len(predicted_tags)

    return tokens, predicted_tags, unks

In [None]:
# example 1: tag a raw string (it is split with the BERT tokenizer first)
sentence = 'The Queen will deliver a speech about the conflict in North Korea at 1pm tomorrow.'

tokens, tags, unks = tag_sentence(model, device, sentence, tokenizer, TEXT, UD_TAGS)

[]


In [None]:
# One "tag<TAB><TAB>token" row per BERT token; sub-word pieces start with '##'.
print("Predicted Tag\tToken\n")

for token, tag in zip(tokens, tags):
    print(f"{tag}\t\t{token}")

Predicted Tag	Token

DET		the
NOUN		queen
AUX		will
VERB		deliver
DET		a
NOUN		speech
ADP		about
DET		the
NOUN		conflict
ADP		in
PROPN		north
PROPN		korea
ADP		at
NUM		1
NOUN		##pm
NOUN		tomorrow
PUNCT		.


In [None]:
# Tokens of example 1 that tag_sentence reported as unknown to the vocabulary.
print(unks)

In [None]:
# example 2: the brand name "google" is used here as a verb
sentence = 'Could you google it now?'

tokens, tags, unks = tag_sentence(model, device, sentence, tokenizer, TEXT, UD_TAGS)

print("Predicted Tag\tToken\n")
for token, tag in zip(tokens, tags):
    print(f"{tag}\t\t{token}")

In [None]:
# example 3: "duck" is used as a verb, although it is more often a noun
sentence = 'The passage was so low I had to duck'

tokens, tags, unks = tag_sentence(model, device, sentence, tokenizer, TEXT, UD_TAGS)

print("Predicted Tag\tToken\n")
for token, tag in zip(tokens, tags):
    print(f"{tag}\t\t{token}")

In [None]:
# example 4: Chomsky's grammatical-but-meaningless sentence
sentence = 'colorless green ideas sleep furiously'

tokens, tags, unks = tag_sentence(model, device, sentence, tokenizer, TEXT, UD_TAGS)

print("Predicted Tag\tToken\n")
for token, tag in zip(tokens, tags):
    print(f"{tag}\t\t{token}")

Reference: [original tutorial notebook — Fine-tuning Pretrained Transformers for PoS Tagging](https://colab.research.google.com/github/sejas/pytorch-pos-tagging/blob/master/2%20-%20Fine-tuning%20Pretrained%20Transformers%20for%20PoS%20Tagging.ipynb)

Related: [fine-tuning Donut for OCR-free document parsing](https://www.philschmid.de/fine-tuning-donut)