In [2]:
!pip install transformers

[0m

In [3]:
colab=False

if colab:
    from google.colab import drive
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.optim as optim

import time
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

In [4]:
MODE_RU = True

In [5]:
model_name = 'DeepPavlov/rubert-base-cased' if MODE_RU else 'bert-base-multilingual-cased'

In [6]:
if colab:
    drive.mount('/content/drive/')

    dir = 'drive/MyDrive/BS/DATA_EXTRACTION/'
    corp_cased = dir + 'corp_cased.csv'
else:
    corp_cased = '/kaggle/input/corp-cased/corp_cased.csv'

In [7]:
df = pd.read_csv(corp_cased, sep='\t', header=None, on_bad_lines='skip')
df.dropna(inplace=True)

df.head()

Unnamed: 0,0,1
0,Школа злословия учит прикусить язык,NOUN NOUN VERB INFN NOUN
1,Сохранится ли градус дискуссии в новом сезоне,VERB PRCL NOUN NOUN PREP ADJF NOUN
2,Великолепная Школа злословия вернулась в эфир ...,ADJF NOUN NOUN VERB PREP NOUN PREP ADJF NOUN P...
3,В истории программы это уже не первый ребрендинг,PREP NOUN NOUN NPRO ADVB PRCL ADJF NOUN
4,Сейчас с трудом можно припомнить что начиналас...,ADVB PREP NOUN PRED INFN CONJ VERB NOUN PREP N...


In [8]:
sentences = df[0].to_numpy()
tags = df[1].to_numpy()

In [9]:
sentences = np.array(list(map(lambda x: str(x).split(), sentences)))
tags = np.array(list(map(lambda x: str(x).split(), tags)))

  """Entry point for launching an IPython kernel.
  


In [10]:
def build_voc_t(ttoi):
    idx = 0
    
    for tags_ in tags:
        for tag in tags_:
            if tag not in ttoi:
                ttoi[tag] = idx
                idx += 1

def creator(x, y, ttoi):
    for i in range(len(sentences)):
        for j in range(len(sentences[i])):
            x_elem = []
            #word before
            if j != 0:
                x_elem.append(sentences[i][j - 1])

            #current word
            x_elem.append(sentences[i][j])

            #word after
            if j != len(sentences[i]) - 1:
                x_elem.append(sentences[i][j + 1])

            x.append(' '.join(x_elem))
            y.append(ttoi[tags[i][j]])

In [11]:
ttoi = {}
x = []
y = []

build_voc_t(ttoi)

creator(x, y, ttoi)

In [12]:
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.7, stratify=y, shuffle=True)

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
tokenizer = BertTokenizer.from_pretrained(model_name, max_length=512, do_lower_case=False)

Downloading:   0%|          | 0.00/1.57M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/642 [00:00<?, ?B/s]

In [15]:
x_train_enc = tokenizer(x_train, truncation=True, padding=True, max_length=512)
x_test_enc = tokenizer(x_test, truncation=True, padding=True, max_length=512)

In [16]:
temp_ids = x_train_enc['input_ids'][0]

print(temp_ids)
print(tokenizer.decode(temp_ids))

[101, 1758, 12548, 11881, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[CLS] за звание Чемпион [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]


In [17]:
class PosTagDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item['labels'] = torch.tensor([self.labels[idx]])
        return item

In [18]:
dataset_train = PosTagDataset(x_train_enc, y_train)
dataset_test = PosTagDataset(x_test_enc, y_test)

In [19]:
print(dataset_train.__getitem__(0))

{'input_ids': tensor([  101,  1758, 12548, 11881,   102,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0]), 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'attention_mask': tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [20]:
cr_labels = []
cr_names = []

for name, label in ttoi.items():
    cr_labels.append(label)
    cr_names.append(name)

print(cr_labels)
print(cr_names)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
['NOUN', 'VERB', 'INFN', 'PRCL', 'PREP', 'ADJF', 'NPRO', 'ADVB', 'PRED', 'CONJ', 'Name', 'Surn', 'PRTF', 'COMP', 'NUMR', 'UNKN', 'Patr', 'INTJ', 'PRTS', 'GRND', 'Geox', 'ADJS']


In [21]:
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=len(ttoi))
model.to(device)

Downloading:   0%|          | 0.00/681M [00:00<?, ?B/s]

Some weights of the model checkpoint at DeepPavlov/rubert-base-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were n

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen

In [22]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'Модель имеет {count_parameters(model):,} обучаемых параметров')

Модель имеет 177,870,358 обучаемых параметров


In [23]:
def compute_metrics(y_pred):
    y_true = y_pred.label_ids
    y_pred = y_pred.predictions.argmax(-1)
    cl_rep = classification_report(y_pred, y_true, labels=cr_labels, target_names=cr_names) # accuracy_score - функция из sklearn.metrics
    return {'classification report': cl_rep}

In [27]:
training_args = TrainingArguments(
    output_dir = 'results/',
    num_train_epochs = 5, # Число эпох
    per_device_train_batch_size = 8, # Размеры пакетов обучения и оценки
    per_device_eval_batch_size = 8,
    warmup_steps = 100, # Шаг выдачи предупреждений
    max_steps = 3000,
    weight_decay = 0.01, # Коэффициент уменьшения весов
    load_best_model_at_end = True, # Флаг загрузки лучшей модели после завершения обучения
    logging_steps = 500, # Шаг сохранения весов (checkpoint)
    evaluation_strategy = 'steps' # Стратегия обучения
)

using `logging_steps` to initialize `eval_steps` to 500
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [28]:
trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = dataset_train,
    eval_dataset = dataset_test,
    compute_metrics = compute_metrics
)

max_steps is given, it will override any value given in num_train_epochs


In [29]:
trainer.train() # Обучение

***** Running training *****
  Num examples = 1089240
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 3000
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss,Classification report
500,0.7482,0.626199,precision recall f1-score support  NOUN 0.93 0.97 0.95 146373  VERB 0.89 0.92 0.91 40145  INFN 0.98 0.96 0.97 11173  PRCL 0.71 0.48 0.57 20933  PREP 0.99 0.92 0.95 60017  ADJF 0.90 0.93 0.91 67693  NPRO 0.87 0.89 0.88 13640  ADVB 0.37 0.75 0.50 8632  PRED 0.00 0.00 0.00 0  CONJ 0.96 0.73 0.83 54352  Name 0.00 0.00 0.00 1  Surn 0.91 0.21 0.34 17843  PRTF 0.87 0.63 0.73 9998  COMP 0.00 0.00 0.00 0  NUMR 0.90 0.66 0.76 3291  UNKN 0.00 0.41 0.00 29  Patr 0.00 0.00 0.00 0  INTJ 0.00 0.00 0.00 0  PRTS 0.90 0.73 0.81 3934  GRND 0.00 0.00 0.00 0  Geox 0.87 0.67 0.76 8764  ADJS 0.00 0.00 0.00 0  accuracy 0.85 466818  macro avg 0.55 0.49 0.49 466818 weighted avg 0.91 0.85 0.87 466818
1000,0.5453,0.474206,precision recall f1-score support  NOUN 0.98 0.92 0.95 161294  VERB 0.93 0.95 0.94 40545  INFN 0.99 0.92 0.95 11852  PRCL 0.73 0.90 0.81 11459  PREP 0.98 0.98 0.98 56062  ADJF 0.91 0.93 0.92 68236  NPRO 0.94 0.71 0.81 18482  ADVB 0.84 0.74 0.79 20062  PRED 0.00 0.00 0.00 0  CONJ 0.94 0.94 0.94 41274  Name 0.57 0.89 0.69 3533  Surn 0.50 0.70 0.58 2903  PRTF 0.79 0.84 0.82 6894  COMP 0.00 0.00 0.00 0  NUMR 0.93 0.82 0.87 2723  UNKN 0.46 0.55 0.50 8525  Patr 0.00 0.00 0.00 0  INTJ 0.00 0.00 0.00 0  PRTS 0.60 0.88 0.71 2188  GRND 0.76 0.90 0.83 1806  Geox 0.82 0.93 0.87 5946  ADJS 0.41 0.50 0.45 3034  accuracy 0.90 466818  macro avg 0.64 0.68 0.66 466818 weighted avg 0.92 0.90 0.91 466818
1500,0.4501,0.366432,precision recall f1-score support  NOUN 0.96 0.97 0.97 151236  VERB 0.96 0.95 0.95 42104  INFN 0.98 0.98 0.98 10905  PRCL 0.91 0.83 0.87 15345  PREP 0.99 0.97 0.98 56683  ADJF 0.96 0.91 0.93 74105  NPRO 0.85 0.95 0.90 12595  ADVB 0.88 0.79 0.83 19559  PRED 0.00 0.00 0.00 0  CONJ 0.96 0.95 0.96 41929  Name 0.79 0.48 0.59 9157  Surn 0.00 0.00 0.00 1  PRTF 0.81 0.87 0.84 6810  COMP 0.00 0.00 0.00 0  NUMR 0.87 0.91 0.89 2309  UNKN 0.58 0.51 0.54 11430  Patr 0.00 0.00 0.00 0  INTJ 0.00 0.00 0.00 0  PRTS 0.97 0.78 0.86 4000  GRND 0.77 0.92 0.84 1783  Geox 0.84 0.94 0.89 6002  ADJS 0.17 0.72 0.27 865  accuracy 0.92 466818  macro avg 0.65 0.66 0.64 466818 weighted avg 0.94 0.92 0.93 466818
2000,0.3775,0.342883,precision recall f1-score support  NOUN 0.97 0.96 0.97 153996  VERB 0.95 0.97 0.96 40960  INFN 0.99 0.98 0.99 11047  PRCL 0.91 0.95 0.93 13481  PREP 0.99 0.98 0.99 56113  ADJF 0.95 0.93 0.94 71410  NPRO 0.92 0.92 0.92 13993  ADVB 0.90 0.76 0.83 20831  PRED 0.54 0.94 0.69 878  CONJ 0.95 0.98 0.97 40134  Name 0.75 0.74 0.75 5628  Surn 0.68 0.68 0.68 4076  PRTF 0.90 0.84 0.87 7879  COMP 0.00 0.00 0.00 0  NUMR 0.85 0.95 0.90 2162  UNKN 0.48 0.60 0.53 8173  Patr 0.00 0.00 0.00 0  INTJ 0.00 0.00 0.00 0  PRTS 0.97 0.77 0.86 4044  GRND 0.78 0.95 0.86 1767  Geox 0.89 0.82 0.86 7401  ADJS 0.45 0.59 0.51 2845  accuracy 0.93 466818  macro avg 0.72 0.74 0.73 466818 weighted avg 0.94 0.93 0.94 466818
2500,0.3761,0.315291,precision recall f1-score support  NOUN 0.97 0.97 0.97 153076  VERB 0.96 0.96 0.96 41751  INFN 0.99 0.97 0.98 11171  PRCL 0.92 0.89 0.91 14379  PREP 0.99 0.98 0.99 56110  ADJF 0.97 0.92 0.94 73033  NPRO 0.92 0.93 0.93 13827  ADVB 0.90 0.81 0.85 19489  PRED 0.89 0.66 0.76 2071  CONJ 0.96 0.97 0.97 41049  Name 0.76 0.79 0.77 5321  Surn 0.76 0.65 0.70 4782  PRTF 0.86 0.88 0.87 7192  COMP 0.00 0.00 0.00 0  NUMR 0.91 0.89 0.90 2440  UNKN 0.44 0.68 0.53 6552  Patr 0.00 0.00 0.00 0  INTJ 0.00 0.00 0.00 0  PRTS 0.96 0.86 0.91 3566  GRND 0.81 0.91 0.86 1896  Geox 0.87 0.89 0.88 6582  ADJS 0.52 0.77 0.62 2531  accuracy 0.94 466818  macro avg 0.74 0.74 0.74 466818 weighted avg 0.94 0.94 0.94 466818
3000,0.3494,0.288676,precision recall f1-score support  NOUN 0.97 0.97 0.97 153406  VERB 0.96 0.96 0.96 41553  INFN 0.99 0.98 0.98 11093  PRCL 0.92 0.91 0.91 14232  PREP 0.99 0.98 0.99 55968  ADJF 0.97 0.93 0.95 72202  NPRO 0.91 0.95 0.93 13328  ADVB 0.91 0.82 0.86 19406  PRED 0.89 0.72 0.80 1883  CONJ 0.97 0.97 0.97 41306  Name 0.74 0.83 0.78 4893  Surn 0.68 0.71 0.69 3898  PRTF 0.89 0.87 0.88 7542  COMP 0.00 0.00 0.00 0  NUMR 0.97 0.87 0.92 2677  UNKN 0.55 0.62 0.58 9020  Patr 0.00 0.00 0.00 0  INTJ 0.00 0.00 0.00 0  PRTS 0.96 0.85 0.90 3641  GRND 0.81 0.91 0.86 1905  Geox 0.86 0.93 0.89 6260  ADJS 0.52 0.74 0.61 2605  accuracy 0.94 466818  macro avg 0.75 0.75 0.75 466818 weighted avg 0.95 0.94 0.94 466818


***** Running Evaluation *****
  Num examples = 466818
  Batch size = 8
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Trainer is attempting to log a value of "              precision    recall  f1-score   support

        NOUN       0.93      0.97      0.95    146373
        VERB       0.89      0.92      0.91     40145
        INFN       0.98      0.96      0.97     11173
        PRCL       0.71      0.48      0.57     20933
        PREP       0.99      0.92      0.95     60017
        ADJF       0.90      0.93      0.91     67693
        NPRO       0.87      0.89      0.88     13640
        ADVB       0.37      0.75      0.50      8632
        PRED       0.00      0.00      0.00         0
        CONJ       0.96      0.73      0.83     54352
        Name       0.00      0.00      0.00         1
        Surn       0.91      0.21      0.34     17843
        PRTF       

TrainOutput(global_step=3000, training_loss=0.4744397430419922, metrics={'train_runtime': 7938.1037, 'train_samples_per_second': 3.023, 'train_steps_per_second': 0.378, 'total_flos': 801810453600000.0, 'train_loss': 0.4744397430419922, 'epoch': 0.02})

In [30]:
trainer.evaluate() # Оценка

***** Running Evaluation *****
  Num examples = 466818
  Batch size = 8


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Trainer is attempting to log a value of "              precision    recall  f1-score   support

        NOUN       0.97      0.97      0.97    153406
        VERB       0.96      0.96      0.96     41553
        INFN       0.99      0.98      0.98     11093
        PRCL       0.92      0.91      0.91     14232
        PREP       0.99      0.98      0.99     55968
        ADJF       0.97      0.93      0.95     72202
        NPRO       0.91      0.95      0.93     13328
        ADVB       0.91      0.82      0.86     19406
        PRED       0.89      0.72      0.80      1883
        CONJ       0.97      0.97      0.97     41306
        Name       0.74      0.83      0.78      4893
        Surn       0.68      0.71      0.69      3898
        PRTF       0.89      0.87      0.88      7542
        COMP       0.00      0.00    

{'eval_loss': 0.2886764109134674,
 'eval_classification report': '              precision    recall  f1-score   support\n\n        NOUN       0.97      0.97      0.97    153406\n        VERB       0.96      0.96      0.96     41553\n        INFN       0.99      0.98      0.98     11093\n        PRCL       0.92      0.91      0.91     14232\n        PREP       0.99      0.98      0.99     55968\n        ADJF       0.97      0.93      0.95     72202\n        NPRO       0.91      0.95      0.93     13328\n        ADVB       0.91      0.82      0.86     19406\n        PRED       0.89      0.72      0.80      1883\n        CONJ       0.97      0.97      0.97     41306\n        Name       0.74      0.83      0.78      4893\n        Surn       0.68      0.71      0.69      3898\n        PRTF       0.89      0.87      0.88      7542\n        COMP       0.00      0.00      0.00         0\n        NUMR       0.97      0.87      0.92      2677\n        UNKN       0.55      0.62      0.58      902