[BERT](https://arxiv.org/abs/1810.04805) is known to be good at sequence tagging tasks like named entity recognition. Let's see if that also holds for POS tagging.

In [None]:
! pip install pytorch_pretrained_bert

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch_pretrained_bert
  Downloading pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123 kB)
[K     |████████████████████████████████| 123 kB 4.9 MB/s 
Collecting boto3
  Downloading boto3-1.24.89-py3-none-any.whl (132 kB)
[K     |████████████████████████████████| 132 kB 40.6 MB/s 
Collecting s3transfer<0.7.0,>=0.6.0
  Downloading s3transfer-0.6.0-py3-none-any.whl (79 kB)
[K     |████████████████████████████████| 79 kB 5.1 MB/s 
[?25hCollecting jmespath<2.0.0,>=0.7.1
  Downloading jmespath-1.0.1-py3-none-any.whl (20 kB)
Collecting botocore<1.28.0,>=1.27.89
  Downloading botocore-1.27.89-py3-none-any.whl (9.2 MB)
[K     |████████████████████████████████| 9.2 MB 7.0 MB/s 
[?25hCollecting urllib3<1.27,>=1.25.4
  Downloading urllib3-1.26.12-py2.py3-none-any.whl (140 kB)
[K     |████████████████████████████████| 140 kB 50.6 MB/s 
  Downloading urllib3-1.25.11-py2.py3-none-an

In [None]:
import os
from tqdm import tqdm_notebook as tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
import torch.optim as optim
from pytorch_pretrained_bert import BertTokenizer

In [None]:
# Show the installed PyTorch version so the run can be reproduced later.
torch.__version__

'1.12.1+cu113'

# Data preparation

Thanks to the great NLTK, we don't have to worry about datasets. A sample of the Penn Treebank is included in it, which is enough for our purpose.

In [None]:
import nltk
# Fetch the Penn Treebank sample distributed with NLTK.
nltk.download('treebank')
# Each sentence is a sequence of (word, POS tag) pairs.
tagged_sents = nltk.corpus.treebank.tagged_sents()
# Number of tagged sentences in the corpus.
len(tagged_sents)

[nltk_data] Downloading package treebank to /root/nltk_data...
[nltk_data]   Unzipping corpora/treebank.zip.


3914

In [None]:
# Look at the first sentence to see the (word, tag) pair format.
tagged_sents[0]

[('Pierre', 'NNP'),
 ('Vinken', 'NNP'),
 (',', ','),
 ('61', 'CD'),
 ('years', 'NNS'),
 ('old', 'JJ'),
 (',', ','),
 ('will', 'MD'),
 ('join', 'VB'),
 ('the', 'DT'),
 ('board', 'NN'),
 ('as', 'IN'),
 ('a', 'DT'),
 ('nonexecutive', 'JJ'),
 ('director', 'NN'),
 ('Nov.', 'NNP'),
 ('29', 'CD'),
 ('.', '.')]

In [None]:
# Collect the distinct POS tags. Sorting makes the tag order (and therefore any
# tag <-> index mapping built from it) identical on every run; iterating a bare
# set of strings can yield a different order in each Python process.
tags = sorted({word_pos[1] for sent in tagged_sents for word_pos in sent})

In [None]:
",".join(tags)

"VB,SYM,PDT,EX,WP$,JJ,VBG,-RRB-,RBR,#,MD,:,LS,RB,RBS,PRP$,NN,$,NNP,UH,CC,NNS,WDT,JJR,-LRB-,VBP,POS,VBZ,CD,JJS,IN,WP,'',RP,DT,.,PRP,``,-NONE-,VBD,VBN,NNPS,WRB,,,FW,TO"

In [None]:
# By convention, the 0'th slot is reserved for padding.
# Dropping any "<pad>" that is already present keeps this cell idempotent:
# re-running it no longer stacks a second "<pad>" entry in front of the tags.
tags = ["<pad>"] + [tag for tag in tags if tag != "<pad>"]

In [None]:
# Lookup tables between tag strings and their integer ids (the position in `tags`).
idx2tag = dict(enumerate(tags))
tag2idx = {tag: idx for idx, tag in idx2tag.items()}

In [None]:
# Let's split the data into train and test (or eval)
from sklearn.model_selection import train_test_split
# A fixed random_state makes the 90/10 split identical on every run, so the
# accuracy reported further down is comparable between runs.
train_data, test_data = train_test_split(tagged_sents, test_size=.1, random_state=42)
len(train_data), len(test_data)

(3522, 392)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Data loader


In [None]:
# WordPiece tokenizer of the cased BERT-base model; lower-casing is switched off.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

100%|██████████| 213450/213450 [00:00<00:00, 2727731.43B/s]


In [None]:
# Display the corpus object (sentences of (word, tag) pairs).
tagged_sents

[[('Pierre', 'NNP'), ('Vinken', 'NNP'), (',', ','), ('61', 'CD'), ('years', 'NNS'), ('old', 'JJ'), (',', ','), ('will', 'MD'), ('join', 'VB'), ('the', 'DT'), ('board', 'NN'), ('as', 'IN'), ('a', 'DT'), ('nonexecutive', 'JJ'), ('director', 'NN'), ('Nov.', 'NNP'), ('29', 'CD'), ('.', '.')], [('Mr.', 'NNP'), ('Vinken', 'NNP'), ('is', 'VBZ'), ('chairman', 'NN'), ('of', 'IN'), ('Elsevier', 'NNP'), ('N.V.', 'NNP'), (',', ','), ('the', 'DT'), ('Dutch', 'NNP'), ('publishing', 'VBG'), ('group', 'NN'), ('.', '.')], ...]

In [None]:
# sents, tags_li = [], [] # list of lists
# for sent in tagged_sents:
#     words = [word_pos[0] for word_pos in sent]
#     tags = [word_pos[1] for word_pos in sent]
#     sents.append(["[CLS]"] + words + ["[SEP]"])
#     tags_li.append(["<pad>"] + tags + ["<pad>"])
#     break

In [None]:
class PosDataset(data.Dataset):
    """POS-tagged sentences prepared for a BERT token classifier.

    Every sentence is wrapped in "[CLS]" ... "[SEP]" and the two markers are
    given the "<pad>" tag. Tokenization into word pieces happens lazily in
    ``__getitem__`` and relies on the module-level ``tokenizer`` and
    ``tag2idx`` objects.
    """

    def __init__(self, tagged_sents):
        """Split each sentence of (word, tag) pairs into parallel word and tag lists.

        Args:
            tagged_sents: iterable of sentences, each a sequence of
                (word, tag) pairs.
        """
        sents, tags_li = [], []  # list of lists
        for sent in tagged_sents:
            words = [word_pos[0] for word_pos in sent]
            tags = [word_pos[1] for word_pos in sent]
            sents.append(["[CLS]"] + words + ["[SEP]"])
            tags_li.append(["<pad>"] + tags + ["<pad>"])
        self.sents, self.tags_li = sents, tags_li

    def __len__(self):
        """Return the number of sentences."""
        return len(self.sents)

    def __getitem__(self, idx):
        """Encode sentence ``idx`` as word-piece ids and tag ids.

        Returns:
            A tuple ``(words, x, is_heads, tags, y, seqlen)``:
            words    -- the sentence as one space-joined string
            x        -- word-piece ids
            is_heads -- 1 where the piece is the first piece of a word, else 0
            tags     -- the tags as one space-joined string
            y        -- tag ids aligned with ``x``; non-head pieces get "<pad>"
            seqlen   -- number of word pieces, i.e. ``len(x)``
        """
        words, tags = self.sents[idx], self.tags_li[idx]  # words, tags: string list

        # We give credits only to the first piece.
        x, y = [], []  # list of ids
        is_heads = []  # list. 1: the token is the first piece of a word
        for w, t in zip(words, tags):
            if w in ("[CLS]", "[SEP]"):
                tokens = [w]
            else:
                # A word the tokenizer reduces to nothing would still add one
                # entry to `is_heads` and `y` but none to `x`, breaking the
                # alignment checked below. Represent it with the unknown token.
                tokens = tokenizer.tokenize(w) or ["[UNK]"]
            xx = tokenizer.convert_tokens_to_ids(tokens)

            n_continuations = len(tokens) - 1
            is_head = [1] + [0] * n_continuations

            t = [t] + ["<pad>"] * n_continuations  # <PAD>: no decision
            yy = [tag2idx[each] for each in t]  # (T,)

            x.extend(xx)
            is_heads.extend(is_head)
            y.extend(yy)

        assert len(x)==len(y)==len(is_heads), "len(x)={}, len(y)={}, len(is_heads)={}".format(len(x), len(y), len(is_heads))

        # seqlen
        seqlen = len(y)

        # to string
        words = " ".join(words)
        tags = " ".join(tags)
        return words, x, is_heads, tags, y, seqlen


In [None]:
def pad(batch):
    '''Collate samples into a batch, zero-padding ids to the longest sample.

    Each sample is a tuple (words, x, is_heads, tags, y, seqlen). Only the id
    lists `x` and `y` are padded (with 0, the <pad> id) and turned into
    LongTensors; words, is_heads, tags and seqlens come back as plain lists.
    '''
    words = [sample[0] for sample in batch]
    is_heads = [sample[2] for sample in batch]
    tags = [sample[3] for sample in batch]
    seqlens = [sample[-1] for sample in batch]
    longest = np.max(seqlens)

    def padded_field(position):
        # Right-pad the id list stored at `position` of every sample with 0s.
        rows = []
        for sample in batch:
            ids = sample[position]
            rows.append(ids + [0] * (longest - len(ids)))
        return rows

    x = torch.LongTensor(padded_field(1))
    y = torch.LongTensor(padded_field(-2))

    return words, x, is_heads, tags, y, seqlens

# Model

In [None]:
from pytorch_pretrained_bert import BertModel

In [None]:
class Net(nn.Module):
    """BERT encoder followed by a linear layer that scores every tag for every token."""

    def __init__(self, vocab_size=None):
        """Build the tagger.

        Args:
            vocab_size: number of output classes (tags, including "<pad>").

        Raises:
            ValueError: if ``vocab_size`` is not given. Checked up front so the
                mistake surfaces before the pretrained weights are loaded.
        """
        super().__init__()
        if vocab_size is None:
            raise ValueError("vocab_size (the number of tags) must be provided")
        self.bert = BertModel.from_pretrained('bert-base-cased')

        # 768 is the hidden size of the BERT-base encoder output.
        self.fc = nn.Linear(768, vocab_size)
        self.device = device

    def forward(self, x, y):
        '''
        x: (N, T). int64
        y: (N, T). int64

        Returns (logits, y, y_hat): per-token tag scores, the labels moved to
        the model's device, and the argmax tag id per token.
        '''
        # Follow the device the weights actually live on instead of the global
        # `device`: under nn.DataParallel each replica sits on its own GPU, and
        # sending every input to the default one causes a device mismatch.
        target_device = self.fc.weight.device
        x = x.to(target_device)
        y = y.to(target_device)

        # Keep the encoder's mode in sync with this module's mode.
        self.bert.train(self.training)
        if self.training:
            encoded_layers, _ = self.bert(x)
        else:
            # No gradients are needed for inference.
            with torch.no_grad():
                encoded_layers, _ = self.bert(x)
        enc = encoded_layers[-1]  # output of the last encoder layer

        logits = self.fc(enc)
        y_hat = logits.argmax(-1)
        return logits, y, y_hat

# Train and evaluate

In [None]:
def train(model, iterator, optimizer, criterion, log_every=10):
    """Run one pass over ``iterator``, updating ``model`` after every batch.

    Args:
        model: callable as ``model(x, y)`` returning ``(logits, y, y_hat)``
            with logits of shape (N, T, VOCAB) and y of shape (N, T).
        iterator: yields batches ``(words, x, is_heads, tags, y, seqlens)``.
        optimizer: optimizer over the model's parameters.
        criterion: loss taking flattened logits (N*T, VOCAB) and labels (N*T,).
        log_every: print the loss every ``log_every`` steps; 0 or None
            disables the printing.
    """
    model.train()
    for i, batch in enumerate(iterator):
        _words, x, _is_heads, _tags, y, _seqlens = batch
        optimizer.zero_grad()
        logits, y, _ = model(x, y) # logits: (N, T, VOCAB), y: (N, T)

        # Flatten so that every token position is one classification example.
        logits = logits.view(-1, logits.shape[-1]) # (N*T, VOCAB)
        y = y.view(-1)  # (N*T,)

        loss = criterion(logits, y)
        loss.backward()

        optimizer.step()

        if log_every and i % log_every == 0: # monitoring
            print("step: {}, loss: {}".format(i, loss.item()))

In [None]:
def eval(model, iterator):
    """Tag every batch of ``iterator``, write the predictions to the file
    "result" and print the word-level accuracy.

    Only the first word piece of each word counts; the leading "[CLS]" and
    trailing "[SEP]" positions are skipped. The file holds one
    "word gold_tag predicted_tag" line per word and an empty line after each
    sentence. Relies on the module-level ``idx2tag`` mapping.

    Note: the name shadows the built-in ``eval``; it is kept for existing callers.

    Args:
        model: callable as ``model(x, y)`` returning ``(logits, y, y_hat)``.
        iterator: yields batches ``(words, x, is_heads, tags, y, seqlens)``.
    """
    model.eval()

    Words, Is_heads, Tags, Y_hat = [], [], [], []
    with torch.no_grad():
        for batch in iterator:
            words, x, is_heads, tags, y, _seqlens = batch

            _, _, y_hat = model(x, y)  # y_hat: (N, T)

            Words.extend(words)
            Is_heads.extend(is_heads)
            Tags.extend(tags)
            Y_hat.extend(y_hat.cpu().numpy().tolist())

    ## gets results and save
    # Accuracy is counted while writing, so the file does not have to be
    # re-opened and parsed again afterwards.
    n_correct, n_total = 0, 0
    with open("result", 'w') as fout:
        for words, is_heads, tags, y_hat in zip(Words, Is_heads, Tags, Y_hat):
            # is_heads is unpadded, so zip() also drops the padded tail of y_hat.
            y_hat = [hat for head, hat in zip(is_heads, y_hat) if head == 1]
            preds = [idx2tag[hat] for hat in y_hat]
            word_list, tag_list = words.split(), tags.split()
            assert len(preds)==len(word_list)==len(tag_list)
            # [1:-1] skips the [CLS] / [SEP] positions.
            for w, t, p in zip(word_list[1:-1], tag_list[1:-1], preds[1:-1]):
                fout.write("{} {} {}\n".format(w, t, p))
                n_total += 1
                n_correct += int(t == p)
            fout.write("\n")

    ## calc metric
    acc = n_correct / n_total if n_total else float("nan")

    print("acc=%.2f"%acc)


## Load model and train

In [None]:
# One output class per entry of tag2idx (this includes the "<pad>" slot).
model = Net(vocab_size=len(tag2idx))
model.to(device)
# Wrap the model in DataParallel (multi-GPU data parallelism).
model = nn.DataParallel(model)

100%|██████████| 404400730/404400730 [00:06<00:00, 59561937.41B/s]


In [None]:
# Wrap both splits; tokenization into word pieces happens when an item is fetched.
train_dataset, eval_dataset = PosDataset(train_data), PosDataset(test_data)

# Same batching for both loaders; only the training loader shuffles.
train_iter = data.DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=pad)
test_iter = data.DataLoader(eval_dataset, batch_size=8, shuffle=False, num_workers=1, collate_fn=pad)

optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Label 0 is "<pad>": padded positions and non-head word pieces are left out of the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [None]:
# One pass over the training data, then report accuracy on the held-out sentences.
train(model, train_iter, optimizer, criterion)
eval(model, test_iter)


step: 0, loss: 3.964611291885376
step: 10, loss: 1.657906174659729
step: 20, loss: 0.4662945568561554
step: 30, loss: 0.3822211027145386
step: 40, loss: 0.2032579630613327
step: 50, loss: 0.1271943747997284
step: 60, loss: 0.08578399568796158
step: 70, loss: 0.25627240538597107
step: 80, loss: 0.09846189618110657
step: 90, loss: 0.12252505868673325
step: 100, loss: 0.1703256219625473
step: 110, loss: 0.08751732856035233
step: 120, loss: 0.10018518567085266
step: 130, loss: 0.1406237781047821
step: 140, loss: 0.07370074093341827
step: 150, loss: 0.09682586044073105
step: 160, loss: 0.07088088989257812
step: 170, loss: 0.06322076916694641
step: 180, loss: 0.09748023748397827
step: 190, loss: 0.18131068348884583
step: 200, loss: 0.07456100732088089
step: 210, loss: 0.1336268037557602
step: 220, loss: 0.19304615259170532
step: 230, loss: 0.0816136971116066
step: 240, loss: 0.12133247405290604
step: 250, loss: 0.067226342856884
step: 260, loss: 0.07155530154705048
step: 270, loss: 0.1195413

Check the result.

In [None]:
# Preview the first 100 lines of the "result" file; the `with` block closes the
# file handle instead of leaving it open until garbage collection.
with open('result', 'r') as fin:
    result_lines = fin.read().splitlines()[:100]
result_lines

['But CC CC',
 'the DT DT',
 'legislation NN NN',
 'reflected VBD VBD',
 'a DT DT',
 'compromise NN NN',
 'agreed VBN VBN',
 'to TO TO',
 '* -NONE- -NONE-',
 'on IN IN',
 'Tuesday NNP NNP',
 'by IN IN',
 'President NNP NNP',
 'Bush NNP NNP',
 'and CC CC',
 'Democratic JJ JJ',
 'leaders NNS NNS',
 'in IN IN',
 'Congress NNP NNP',
 ', , ,',
 'after IN IN',
 'congressional JJ JJ',
 'Republicans NNPS NNS',
 'urged VBD VBD',
 'the DT DT',
 'White NNP NNP',
 'House NNP NNP',
 '*-2 -NONE- -NONE-',
 'to TO TO',
 'bend VB VB',
 'a DT DT',
 'bit NN NN',
 'from IN IN',
 'its PRP$ PRP$',
 'previous JJ JJ',
 'resistance NN NN',
 '* -NONE- -NONE-',
 'to TO TO',
 'compromise VB VB',
 '. . .',
 '',
 'The DT DT',
 'firm NN NN',
 'and CC CC',
 'Mr. NNP NNP',
 'Whelen NNP NNP',
 'allegedly RB RB',
 'sold VBD VBD',
 'securities NNS NNS',
 'to TO TO',
 'the DT DT',
 'public NN NN',
 'at IN IN',
 'unfair JJ JJ',
 'prices NNS NNS',
 ', , ,',
 'among IN IN',
 'other JJ JJ',
 'alleged JJ VBN',
 'violations NNS