In [1]:
import os
import random
import numpy as np

import functools

import torch
import torch.nn as nn

from torchtext import datasets


from torchtext.data import Field
from torchtext.data import BucketIterator

from transformers import BertTokenizer, BertModel


BERT_VERSION = 'bert-base-uncased'

SEED = 241

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def seed_everything(seed_value):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    
    if torch.cuda.is_available(): 
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True

seed_everything(SEED)

In [2]:
bert = BertModel.from_pretrained(BERT_VERSION, output_hidden_states=True)
tokenizer = BertTokenizer.from_pretrained(BERT_VERSION)



In [3]:
tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id, tokenizer.unk_token_id

(101, 102, 0, 100)

In [5]:
def tokenize_and_convert_text(tokens, tokenizer, max_len):
  tokens = tokens[:max_len]
  token_ids = tokenizer.convert_tokens_to_ids(tokens)
  return token_ids

def tokenize_and_convert_labels(labels, max_len):
  return labels[:max_len]


text_preprocess = functools.partial(tokenize_and_convert_text, 
                                    tokenizer=tokenizer, 
                                    max_len=tokenizer.max_len_single_sentence)
tag_preprocess = functools.partial(tokenize_and_convert_labels,
                                   max_len=tokenizer.max_len_single_sentence)

In [17]:
TEXT = Field(lower=True,
             sequential=True,
             use_vocab=False,
             batch_first=True,
             include_lengths=True,
            #  preprocessing=text_preprocess,
             init_token=tokenizer.cls_token_id,
             pad_token=tokenizer.pad_token_id,
             unk_token=tokenizer.unk_token_id)
             # eos_token=tokenizer.sep_token_id)

LABEL = Field(batch_first=True,
              unk_token = None,
              init_token='<pad>',
              preprocessing=tag_preprocess)

In [18]:
from torchtext import datasets

In [19]:
fields = [('text', TEXT), ('udtags', LABEL)]

train_data, valid_data, test_data = datasets.UDPOS.splits(fields=fields)

In [23]:
valid_data[106].udtags

['ADV',
 'PUNCT',
 'PRON',
 'AUX',
 'VERB',
 'SCONJ',
 'PRON',
 'AUX',
 'AUX',
 'ADV',
 'VERB',
 'SCONJ',
 'DET',
 'NOUN',
 'ADP',
 'ADJ',
 'NOUN',
 'AUX',
 'ADJ',
 'PART',
 'VERB',
 'ADP',
 'PROPN',
 'PUNCT',
 'VERB',
 'SCONJ',
 'DET',
 'NOUN',
 'AUX',
 'AUX',
 'VERB',
 'DET',
 'NOUN',
 'ADP',
 'NOUN',
 'ADP',
 'ADV',
 'DET',
 'ADJ',
 'NUM',
 'NOUN',
 'PUNCT',
 'PUNCT']

In [10]:
LABEL.build_vocab(train_data)

In [11]:
batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator = BucketIterator.splits((train_data, ),
                                       batch_size=batch_size,
                                       device=device)[0]

valid_iterator = BucketIterator.splits((valid_data, ),
                                       batch_size=batch_size,
                                       device=device)[0]

test_iterator = BucketIterator.splits((test_data, ),
                                      batch_size=batch_size,
                                      device=device)[0]

In [12]:
from torchcrf import CRF


class BertCRFTagger(nn.Module):

  def __init__(self, bert, hidden_size, num_tags, dropout):
    super().__init__()
    self.bert = bert
    self.crf = CRF(num_tags, batch_first=True)
    self.fc = nn.Linear(hidden_size, num_tags)
    self.dropout = nn.Dropout(dropout)

  def generate_mask(self, input_temaplte):
    bs = input_temaplte.size(0)
    seq_len = torch.max(input_temaplte)
    mask = torch.ByteTensor(bs, seq_len).fill_(0)
    for i in range(bs):
      mask[i, :input_temaplte[i]] = 1
    return mask

  def forward(self, input_ids, text_lens, tags=None):
    bert_output = self.bert(input_ids)
    last_hidden_state = bert_output['hidden_states'][-1]
    
    emission = self.fc(last_hidden_state)
    mask = self.generate_mask(text_lens).to(device)

    if tags is not None:
      loss = -self.crf(torch.log_softmax(emission, dim=2), tags, mask=mask, reduction='mean')
      return loss
    else:
      prediction = self.crf.decode(emission, mask=mask)
      return prediction

In [13]:
dropout = 0.3
num_tags = len(LABEL.vocab.itos)
hidden_size = bert.config.to_dict()['hidden_size']


bert_crf_tagger = BertCRFTagger(bert, hidden_size, num_tags, dropout).to(device)
optimizer = torch.optim.Adam(bert_crf_tagger.parameters(), lr=2e-5)

In [14]:
from tqdm import tqdm

In [15]:
for i in range(2):
  bert_crf_tagger.train()
  error = 0.
  for batch in tqdm(train_iterator):
      optimizer.zero_grad()

      text, lens = batch.text
      labels = batch.udtags

      loss = bert_crf_tagger(text, lens, labels)

      loss.backward()
      optimizer.step()

      error += loss.detach().cpu().item()
  print('train error', error / len(train_iterator))

  error = 0.
  bert_crf_tagger.eval()
  with torch.no_grad():
    for batch in tqdm(valid_iterator):

        text, lens = batch.text
        labels = batch.udtags
        loss = bert_crf_tagger(text, lens, labels)

        error += loss.detach().cpu().item()
  print('valid error', error / len(valid_iterator))

100%|██████████| 392/392 [01:31<00:00,  4.29it/s]
  3%|▎         | 2/63 [00:00<00:03, 15.52it/s]

train error 9.26150271174859


100%|██████████| 63/63 [00:03<00:00, 16.80it/s]
  0%|          | 0/392 [00:00<?, ?it/s]

valid error 2.744781899073767


100%|██████████| 392/392 [01:31<00:00,  4.27it/s]
  3%|▎         | 2/63 [00:00<00:04, 13.25it/s]

train error 2.547208889558607


100%|██████████| 63/63 [00:03<00:00, 16.62it/s]

valid error 2.1571167858820113





In [68]:
import numpy as np


def calculate_accuracy(y_true, y_pred):
  assert y_true.shape == y_pred.shape
  assert len(y_true.shape) == 1
  y_true = y_true[1:]
  y_pred = y_pred[1:]
  return (y_true == y_pred).sum() / y_true.shape[0]

y_true_test = np.array([-1, 1, 2, 3])
y_pred_test = np.array([-1, 1, 0, 3])

calculate_accuracy(y_true_test, y_pred_test)

0.6666666666666666

In [92]:
total_true_labels = []
total_pred_labels = []

for index in range(len(test_data.examples)):

  text = '\t'.join(tokenizer.convert_ids_to_tokens(test_data.examples[index].text))
  true_labels = '\t'.join(['<pad>'] + test_data.examples[index].udtags)

  with torch.no_grad():
    tokens = ['[CLS]'] + text.split()[:tokenizer.max_len_single_sentence]
    ids = tokenizer.convert_tokens_to_ids(tokens)
    ids_tensor = torch.tensor([ids], device=device)
    lens = torch.tensor([len(ids)]).to(device)
    prediction = bert_crf_tagger(ids_tensor, lens)
    
  print('\t'.join(tokens))
  print(true_labels)
  print('\t'.join([LABEL.vocab.itos[p] for p in prediction[0]]))

  total_true_labels.extend(np.array([LABEL.vocab.itos[p] for p in prediction[0]]))
  total_pred_labels.extend(np.array(true_labels.split('\t')))

[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
<pad>	PROPN	PROPN	PROPN	ADJ	NOUN	NUM	PROPN	PROPN	PUNCT	PROPN	NUM	PROPN	PUNCT	PROPN	NUM	NOUN	PUNCT	PUNCT	NUM	PUNCT	NUM	NOUN	PUNCT	PUNCT	NUM	PUNCT	NUM
<pad>	PROPN	PROPN	PROPN	NOUN	NOUN	NUM	PROPN	PROPN	PUNCT	NUM	NUM	PROPN	PUNCT	PROPN	NUM	NOUN	PUNCT	PUNCT	NUM	PUNCT	NUM	NOUN	PUNCT	PUNCT	NUM	PUNCT	NUM
[CLS]	thanks	.
<pad>	NOUN	PUNCT
<pad>	NOUN	PUNCT
[CLS]	ss
<pad>	PROPN
<pad>	<pad>
[CLS]	sara	,
<pad>	PROPN	PUNCT
<pad>	<pad>	PUNCT
[CLS]	currently	we	have	a	blank	"	sample	"	for	our	paragraph	[UNK]	which	are	attached	to	our	sample	[UNK]	for	(	a	)	us	corporate	,	(	b	)	hedge	funds	,	(	c	)	municipal	.
<pad>	ADV	PRON	VERB	DET	ADJ	PUNCT	NOUN	PUNCT	ADP	PRON	NOUN	NOUN	PRON	AUX	VERB	ADP	PRON	ADJ	NOUN	ADP	PUNCT	X	PUNCT	PROPN	ADJ	PUNCT	PUNCT	X	PUNCT	NOUN	NOUN	PUNCT	PUNCT	X	PUNCT	ADJ	PUNCT
<pad>	ADV	PRON	VERB	DET	ADJ	PUNCT	NOUN	PUNCT	ADP	PRON	NOUN	NOUN	PRON	AUX	VERB	ADP	PRON	NOUN	NOUN	ADP	PUNCT	X	PUNCT	PROPN	ADJ	PUNCT	PUNCT	

In [95]:
calculate_accuracy(np.array(total_true_labels), np.array(total_pred_labels))

0.9374378979133698

In [101]:
from sklearn.metrics import classification_report


print(classification_report(np.array(total_true_labels), np.array(total_pred_labels)))

              precision    recall  f1-score   support

       <pad>       1.00      0.87      0.93      2395
         ADJ       0.89      0.93      0.91      1617
         ADP       0.99      0.96      0.97      2062
         ADV       0.92      0.94      0.93      1207
         AUX       0.98      0.98      0.98      1498
       CCONJ       0.98      1.00      0.99       725
         DET       0.99      1.00      0.99      1891
        INTJ       0.70      0.97      0.81        87
        NOUN       0.92      0.93      0.92      4093
         NUM       0.83      0.81      0.82       548
        PART       0.95      0.93      0.94       643
        PRON       0.99      1.00      0.99      2142
       PROPN       0.78      0.82      0.80      1970
       PUNCT       0.98      0.96      0.97      3161
       SCONJ       0.94      0.96      0.95       377
         SYM       0.54      0.89      0.68        56
        VERB       0.96      0.97      0.97      2649
           X       0.19    