In [109]:
from collections import defaultdict, Counter
import os, operator
from progressbar import progressbar as pb
from nltk import word_tokenize
from typing import List, Tuple

from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForTokenClassification
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

# Read data from disk

In [40]:
data_path = "../SemEval2022-Task11_Train-Dev/RU-Russian/"
pretrained_path = "../../pretrained/sbert_large_nlu_ru/"
SAVE_PATH = "models/token_classification/sbert_large_nlu_ru/"
device = 'cuda'

max_len = 64

In [3]:
with open(os.path.join(data_path, "ru_train.conll")) as f:
    train_file = f.read().splitlines()
    
with open(os.path.join(data_path, "ru_dev.conll")) as f:
    dev_file = f.read().splitlines()

In [29]:
def parse_conll(file) -> Tuple[List, List]:
    texts, labels = [], []
    
    for row in file:
        if row.startswith("#"):
            new_texts, new_labels = [], []
            continue

        if row == "":
            texts.append(new_texts)
            labels.append(new_labels)

        else:
            parts = row.split()
            new_texts.append(parts[0])
            new_labels.append(parts[-1])

    return texts, labels

train_texts, train_labels = parse_conll(train_file)
dev_texts, dev_labels = parse_conll(dev_file)

# Prepare data

In [44]:
uniq_labels = set()
for item in train_labels:
    uniq_labels.update(item)
    
uniq_labels = sorted(list(uniq_labels))
uniq_labels.remove(O_TAG)
uniq_labels.insert(0, O_TAG)
    
GLOBAL_LABEL2ID = {label: idx for idx, label in enumerate(uniq_labels)}
GLOBAL_ID2LABEL = {idx: label for label, idx in GLOBAL_LABEL2ID.items()}
GLOBAL_LABEL2ID

{'O': 0,
 'B-CORP': 1,
 'B-CW': 2,
 'B-GRP': 3,
 'B-LOC': 4,
 'B-PER': 5,
 'B-PROD': 6,
 'I-CORP': 7,
 'I-CW': 8,
 'I-GRP': 9,
 'I-LOC': 10,
 'I-PER': 11,
 'I-PROD': 12}

In [31]:
tokenizer = BertTokenizer.from_pretrained(model_path, local_files_only=True)

O_TAG = 'O'
B_TAG = 'B-'
I_TAG = 'I-'

In [32]:
def pad_sequence(seq, max_len, pad_token):
    if len(seq) >= max_len:
        seq = seq[:max_len]
    else:
        seq = seq + [pad_token] * (max_len - len(seq))
    return seq

In [33]:
def process_tokens(words, labels):
    
    bert_tokens, bio_labels = [tokenizer.cls_token], [O_TAG]
    
    for word, label in zip(words, labels):
        tokens = tokenizer.tokenize(word)
        bert_tokens.extend(tokens)
        
        new_labels = [label] * len(tokens)
        bio_labels.extend(new_labels)
        
    bert_tokens.append(tokenizer.sep_token)
    bio_labels.append(O_TAG)
    
    for i, (token, label) in enumerate(zip(bert_tokens, bio_labels)):
        if token.startswith("##") and label.startswith(B_TAG):
            bio_labels[i] = I_TAG + label[2:]

    encoded_tokens = tokenizer.encode(bert_tokens, add_special_tokens=False)
    
    if len(bio_labels) >= max_len:
        bio_labels[max_len-1] = O_TAG
    
    bio_labels = pad_sequence(bio_labels, max_len, O_TAG)
    encoded_tokens = pad_sequence(encoded_tokens, max_len, tokenizer.pad_token_id)

    return encoded_tokens, bio_labels

In [35]:
def prepare_data_for_ner(texts, labels):
    result = np.zeros((len(texts), max_len), dtype=np.int32)
    fin_labels, fin_labels_encoded = [], []

    for i, (text, label) in pb(enumerate(zip(texts, labels))):
        
        c_words, c_labels = process_tokens(text, label) 
        assert len(c_words) == len(c_labels)
        
        result[i] = c_words
        fin_labels.append(c_labels)
    
    words_ids, labels_ids = [], []
    
    for sentence in fin_labels:
        new_labels = []
        for label in sentence:
            new_labels.append(GLOBAL_LABEL2ID[label])
        fin_labels_encoded.append(new_labels)
    
    return result, fin_labels_encoded

In [36]:
train_words_enc, train_labels_enc = prepare_data_for_ner(train_texts, train_labels)
dev_words_enc, dev_labels_enc = prepare_data_for_ner(dev_texts, dev_labels)

| |                                 #             | 15299 Elapsed Time: 0:00:15
| |        #                                        | 799 Elapsed Time: 0:00:00


# Model

In [45]:
model = BertForTokenClassification.from_pretrained(pretrained_path, num_labels=len(GLOBAL_ID2LABEL))
model.to(device)
print(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at ../../pretrained/sbert_large_nlu_ru/ and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda


In [46]:
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, f1_score, accuracy_score

batch_size = 32
patience = 1

In [52]:
train_data = np.stack((train_words_enc, train_labels_enc), axis=1)
dev_data = np.stack((dev_words_enc, dev_labels_enc), axis=1)

train_batches = DataLoader(train_data, batch_size=batch_size, shuffle=True)
dev_batches = DataLoader(dev_data, batch_size=batch_size, shuffle=False)

# Train

In [49]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [50]:
def _train(model, train_loader, optimizer, epoch_num):
    train_loss, train_acc, train_f1 = [], [], []
    model.train()

    for batch_num, batch in enumerate(train_loader):
        X_batch, y_batch = batch[:, 0, :], batch[:, 1, :]
        X_batch = X_batch.type(torch.LongTensor).to(device)
        y_batch = y_batch.type(torch.LongTensor).to(device)
        optimizer.zero_grad()

        out = model(input_ids=X_batch, labels=y_batch.contiguous(), return_dict=True)
        loss = out.loss
        y_pred = out.logits
        y_pred = torch.argmax(y_pred, dim=2)

        y_pred_flatten = torch.flatten(y_pred).cpu().numpy()
        y_batch_flatten = torch.flatten(y_batch).cpu().numpy()
        f1 = f1_score(y_batch_flatten, y_pred_flatten, average="micro")
        accuracy = accuracy_score(y_batch_flatten, y_pred_flatten)

        train_loss.append(loss.item())
        train_acc.append(accuracy)
        train_f1.append(f1)
        
        if batch_num % 50 == 0:
            print(f"TRAIN: Epoch: {epoch_num}, Batch: {batch_num + 1} / {len(train_loader)}, "
                          f"Loss: {loss.item():.3f}, Accuracy: {accuracy:.3f}, F1: {f1:.3f}")

        loss.backward()
        optimizer.step()

    return np.mean(train_loss), np.mean(train_acc), np.mean(train_f1)

def _val(model, val_loader, epoch_num):
    val_loss, val_acc, val_f1 = [], [], []
    model.eval()

    for batch_num, batch in enumerate(val_loader):
        X_batch, y_batch = batch[:, 0, :], batch[:, 1, :]
        X_batch = X_batch.type(torch.LongTensor).to(device)
        y_batch = y_batch.type(torch.LongTensor).to(device)

        out = model(input_ids=X_batch, labels=y_batch.contiguous())
        loss = out.loss
        y_pred = out.logits
        y_pred = torch.argmax(y_pred, dim=2)

        y_pred_flatten = torch.flatten(y_pred).cpu().numpy()
        y_batch_flatten = torch.flatten(y_batch).cpu().numpy()
        f1 = f1_score(y_batch_flatten, y_pred_flatten, average="micro")
        accuracy = accuracy_score(y_batch_flatten, y_pred_flatten)

        val_loss.append(loss.item())
        val_acc.append(accuracy)
        val_f1.append(f1)
        
        if batch_num % 50 == 0:
            print(f"VAL: Epoch: {epoch_num}, Batch: {batch_num + 1} / {len(val_loader)}, "
                          f"Loss: {loss.item():.3f}, Accuracy: {accuracy:.3f}, F1: {f1:.3f}")

    return np.mean(val_loss), np.mean(val_acc), np.mean(val_f1)

In [53]:
last_epoch = 0
dev_losses = []
patience = 1

for epoch in range(1, 25 + 1):
    train_loss, train_acc, train_f1 = _train(model, train_batches, optimizer, epoch)
    dev_loss, dev_acc, dev_f1 = _val(model, dev_batches, epoch)

    if len(dev_losses) == 0 or dev_loss < dev_losses[-1]:
        model.save_pretrained(SAVE_PATH)

    elif last_epoch == 0:
        last_epoch = epoch + patience

    print(f"After epoch #{epoch}:")
    print(f"Train loss: {train_loss:.3f}, Train Accuracy: {train_acc:.3f}, Train F1: {train_f1:.3f}")
    print(f"Dev loss: {dev_loss:.3f}, Dev Accuracy: {dev_acc:.3f}, Dev F1: {dev_f1:.3f}\n")

    dev_losses.append(dev_loss)
    if epoch == last_epoch:
        break

TRAIN: Epoch: 1, Batch: 1 / 479, Loss: 2.521, Accuracy: 0.062, F1: 0.062
TRAIN: Epoch: 1, Batch: 51 / 479, Loss: 0.193, Accuracy: 0.939, F1: 0.939
TRAIN: Epoch: 1, Batch: 101 / 479, Loss: 0.154, Accuracy: 0.946, F1: 0.946
TRAIN: Epoch: 1, Batch: 151 / 479, Loss: 0.111, Accuracy: 0.958, F1: 0.958
TRAIN: Epoch: 1, Batch: 201 / 479, Loss: 0.126, Accuracy: 0.954, F1: 0.954
TRAIN: Epoch: 1, Batch: 251 / 479, Loss: 0.125, Accuracy: 0.964, F1: 0.964
TRAIN: Epoch: 1, Batch: 301 / 479, Loss: 0.107, Accuracy: 0.960, F1: 0.960
TRAIN: Epoch: 1, Batch: 351 / 479, Loss: 0.131, Accuracy: 0.947, F1: 0.947
TRAIN: Epoch: 1, Batch: 401 / 479, Loss: 0.063, Accuracy: 0.979, F1: 0.979
TRAIN: Epoch: 1, Batch: 451 / 479, Loss: 0.063, Accuracy: 0.979, F1: 0.979
VAL: Epoch: 1, Batch: 1 / 25, Loss: 0.101, Accuracy: 0.954, F1: 0.954
After epoch #1:
Train loss: 0.133, Train Accuracy: 0.957, Train F1: 0.957
Dev loss: 0.091, Dev Accuracy: 0.969, Dev F1: 0.969

TRAIN: Epoch: 2, Batch: 1 / 479, Loss: 0.056, Accuracy: 

# Eval

In [54]:
model = BertForTokenClassification.from_pretrained(SAVE_PATH, local_files_only=True)
model.to(device)
model.eval()
print()




In [61]:
def get_predictions(c_model, batches):
    pred_labels = []

    for item in pb(batches):
        item = item[:, 0, :]
        out = c_model(item.to(device))
        logits = out.logits
        logits = logits.argmax(axis=-1).tolist()
        pred_labels.extend(logits)
        
    return pred_labels

In [74]:
def prepare_preds_for_calc_metrics(pred_encoded, test_encoded, test_words):
    pred_extended, test_extended, pred_decoded = [], [], []

    for pred, test, words in zip(pred_encoded, test_encoded, test_words):
        words = words.tolist()
        
        words_encoded = tokenizer.convert_ids_to_tokens(words)

        if 0 in words:
            cut_ind = words.index(0)
        else:
            cut_ind = max_len

        pred, test = pred[:cut_ind], test[:cut_ind]
        pred, test = pred[1:-1], test[1:-1]
        
        pred = [GLOBAL_ID2LABEL[num] for i, num in enumerate(pred) 
            if not words_encoded[i].startswith('##')]
        test = [GLOBAL_ID2LABEL[num] for i, num in enumerate(test) 
            if not words_encoded[i].startswith('##')]

        pred_extended.extend(pred)
        pred_decoded.append(pred)
        test_extended.extend(test)
        assert len(pred) == len(test)
    
    return pred_extended, test_extended, pred_decoded

In [75]:
pred_labels_enc = get_predictions(model, dev_batches)
pred_extended, true_extended, pred_labels_decoded = prepare_preds_for_calc_metrics(
    pred_labels_enc, dev_labels_enc, dev_words_enc
)

100% (25 of 25) |########################| Elapsed Time: 0:00:02 Time:  0:00:02


# Examples

### Good

In [95]:
good_idx = [223, 0, 120]

In [96]:
for idx in good_idx:
    print(f"Text: {dev_texts[idx]}")
    print(f"True: {dev_labels[idx]}")
    print(f"Pred: {pred_labels_decoded[idx]}\n")

Text: ['сан', 'педро', 'де', 'атакама', '(', ')', '—', 'посёлок', 'в', 'чили', '.']
True: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O']
Pred: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O']

Text: ['важным', 'традиционным', 'промыслом', 'является', 'производство', 'пальмового', 'масла', '.']
True: ['O', 'O', 'O', 'O', 'O', 'B-PROD', 'I-PROD', 'O']
Pred: ['O', 'O', 'O', 'O', 'O', 'B-PROD', 'I-PROD', 'O']

Text: ['в', '1862', 'году', 'стал', 'членом', 'правления', 'русского', 'общества', 'пароходства', 'и', 'торговли', '.']
True: ['O', 'O', 'O', 'O', 'O', 'O', 'B-CORP', 'I-CORP', 'I-CORP', 'I-CORP', 'I-CORP', 'O']
Pred: ['O', 'O', 'O', 'O', 'O', 'O', 'B-CORP', 'I-CORP', 'I-CORP', 'I-CORP', 'I-CORP', 'O']



### Bad

In [99]:
bad_idx = [9, 673, 722]

In [100]:
for idx in bad_idx:
    print(f"Text: {dev_texts[idx]}")
    print(f"True: {dev_labels[idx]}")
    print(f"Pred: {pred_labels_decoded[idx]}\n")

Text: ['по', 'состоянию', 'на', '1', 'января', '2014', 'года', 'в', 'минской', 'подземке', 'установлено', '996', 'видеокамер', '.']
True: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PROD', 'O']
Pred: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

Text: ['собор', 'святого', 'иоанна', '—', 'протестантская', 'христианская', 'церковь', ',', 'расположена', 'в', 'шанхае', ',', 'китай', ',', 'в', 'районе', 'чаннин', '.']
True: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'O', 'O', 'O', 'B-LOC', 'O']
Pred: ['B-LOC', 'O', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'I-LOC']

Text: ['после', 'распада', 'ссср', 'начались', 'массовые', 'отмены', 'пригородных', 'поездов', 'по', 'всей', 'донецкой', 'железной', 'дороге', '.']
True: ['O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-CORP', 'I-CORP', 'I-CORP', 'O']
Pred: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',

# Metrics

In [101]:
flat_true = [item.replace("I-", "").replace("B-", "") for item in true_extended]
flat_pred = [item.replace("I-", "").replace("B-", "") for item in pred_extended]
assert len(flat_true) == len(flat_pred)

### Micro Avg. on all classes

In [104]:
result = {}
avg = "micro"

result["ALL"] = {
    "precision": precision_score(flat_true, flat_pred, average=avg),
    "recall": recall_score(flat_true, flat_pred, average=avg),
    "f1": f1_score(flat_true, flat_pred, average=avg)
}

### By class

In [106]:
for label in ["O", 'CORP', 'CW', 'GRP', 'LOC', 'PER', 'PROD']:
    precision = precision_score(flat_true, flat_pred, labels=[label], average=avg)
    recall = recall_score(flat_true, flat_pred, labels=[label], average=avg)
    f1 = f1_score(flat_true, flat_pred, labels=[label], average=avg)
    
    result[label] = {
        "precision": precision,
        "recall": recall,
        "f1": f1
    }

In [110]:
pd.DataFrame(result)

Unnamed: 0,ALL,O,CORP,CW,GRP,LOC,PER,PROD
precision,0.923568,0.954357,0.837143,0.789082,0.813008,0.788079,0.82764,0.693215
recall,0.923568,0.966294,0.736181,0.685345,0.719424,0.751579,0.828927,0.804795
f1,0.923568,0.960288,0.783422,0.733564,0.763359,0.769397,0.828283,0.744849
