In [2]:
from collections import defaultdict, Counter
import os, operator
from progressbar import progressbar as pb
from nltk import word_tokenize
from typing import List, Tuple
from copy import deepcopy

from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForTokenClassification
import torch
import numpy as np
import pandas as pd

from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

import gdown #! pip install gdown

In [29]:
! gdown https://drive.google.com/uc?id=1-9W24SpNbBEL9Z_tVfT52UJdi-tgxh21

Downloading...
From: https://drive.google.com/uc?id=1-9W24SpNbBEL9Z_tVfT52UJdi-tgxh21
To: /Users/alex/Python/multiconer/my_experiments/models/token_classification/token_classification_sbert_large_nlu_ru.zip
100%|██████████████████████████████████████| 1.58G/1.58G [02:23<00:00, 11.0MB/s]


In [30]:
! unzip token_classification_sbert_large_nlu_ru.zip

Archive:  token_classification_sbert_large_nlu_ru.zip
   creating: token_classification_sbert_large_nlu_ru/
  inflating: token_classification_sbert_large_nlu_ru/config.json  
  inflating: __MACOSX/token_classification_sbert_large_nlu_ru/._config.json  
  inflating: token_classification_sbert_large_nlu_ru/pytorch_model.bin  
  inflating: __MACOSX/token_classification_sbert_large_nlu_ru/._pytorch_model.bin  


In [16]:
data_path = "../../../SemEval2022-Task11_Train-Dev/RU-Russian/"
pretrained = "sberbank-ai/sbert_large_nlu_ru"
SAVE_PATH = "token_classification_sbert_large_nlu_ru/"
device = 'cpu'

max_len = 64

In [4]:
with open(os.path.join(data_path, "ru_dev.conll")) as f:
    dev_file = f.read().splitlines()

In [5]:
def parse_conll(file) -> Tuple[List, List]:
    texts, labels = [], []
    
    for row in file:
        if row.startswith("#"):
            new_texts, new_labels = [], []
            continue

        if row == "":
            texts.append(new_texts)
            labels.append(new_labels)

        else:
            parts = row.split()
            new_texts.append(parts[0])
            new_labels.append(parts[-1])

    return texts, labels

dev_texts, dev_labels = parse_conll(dev_file)

In [6]:
len(dev_texts)

800

In [7]:
GLOBAL_LABEL2ID = {
    'O': 0, 'B-CORP': 1, 'B-CW': 2,
    'B-GRP': 3, 'B-LOC': 4, 'B-PER': 5,
    'B-PROD': 6, 'I-CORP': 7, 'I-CW': 8,
    'I-GRP': 9, 'I-LOC': 10, 'I-PER': 11,
    'I-PROD': 12
}

GLOBAL_ID2LABEL = {idx: label for label, idx in GLOBAL_LABEL2ID.items()}

In [10]:
tokenizer = BertTokenizer.from_pretrained(pretrained, local_files_only=True)

O_TAG = 'O'
B_TAG = 'B-'
I_TAG = 'I-'

In [11]:
def pad_sequence(seq, max_len, pad_token):
    if len(seq) >= max_len:
        seq = seq[:max_len]
    else:
        seq = seq + [pad_token] * (max_len - len(seq))
    return seq

In [12]:
def process_tokens(words, labels):
    
    bert_tokens, bio_labels = [tokenizer.cls_token], [O_TAG]
    
    for word, label in zip(words, labels):
        tokens = tokenizer.tokenize(word)
        bert_tokens.extend(tokens)
        
        new_labels = [label] * len(tokens)
        bio_labels.extend(new_labels)
        
    bert_tokens.append(tokenizer.sep_token)
    bio_labels.append(O_TAG)
    
    for i, (token, label) in enumerate(zip(bert_tokens, bio_labels)):
        if token.startswith("##") and label.startswith(B_TAG):
            bio_labels[i] = I_TAG + label[2:]

    encoded_tokens = tokenizer.encode(bert_tokens, add_special_tokens=False)
    
    if len(bio_labels) >= max_len:
        bio_labels[max_len-1] = O_TAG
    
    bio_labels = pad_sequence(bio_labels, max_len, O_TAG)
    encoded_tokens = pad_sequence(encoded_tokens, max_len, tokenizer.pad_token_id)

    return encoded_tokens, bio_labels

In [13]:
def prepare_data_for_ner(texts, labels):
    result = np.zeros((len(texts), max_len), dtype=np.int32)
    fin_labels, fin_labels_encoded = [], []

    for i, (text, label) in pb(enumerate(zip(texts, labels))):
        
        c_words, c_labels = process_tokens(text, label) 
        assert len(c_words) == len(c_labels)
        
        result[i] = c_words
        fin_labels.append(c_labels)
    
    words_ids, labels_ids = [], []
    
    for sentence in fin_labels:
        new_labels = []
        for label in sentence:
            new_labels.append(GLOBAL_LABEL2ID[label])
        fin_labels_encoded.append(new_labels)
    
    return result, fin_labels_encoded

In [14]:
dev_words_enc, dev_labels_enc = prepare_data_for_ner(dev_texts, dev_labels)
dev_data = np.stack((dev_words_enc, dev_labels_enc), axis=1)
dev_batches = DataLoader(dev_data, batch_size=4, shuffle=False)

| |           #                                     | 799 Elapsed Time: 0:00:01


# Load model

In [17]:
model = BertForTokenClassification.from_pretrained(SAVE_PATH, local_files_only=True)
model.to(device)
model.eval()
print()




In [18]:
def get_predictions(c_model, batches):
    pred_labels = []

    for item in pb(batches):
        item = item[:, 0, :]
        out = c_model(item.to(device))
        logits = out.logits
        logits = logits.argmax(axis=-1).tolist()
        pred_labels.extend(logits)
        
    return pred_labels

In [19]:
def prepare_preds_for_calc_metrics(pred_encoded, test_encoded, test_words):
    pred_extended, test_extended, pred_decoded = [], [], []

    for pred, test, words in zip(pred_encoded, test_encoded, test_words):
        words = words.tolist()
        
        words_encoded = tokenizer.convert_ids_to_tokens(words)

        if 0 in words:
            cut_ind = words.index(0)
        else:
            cut_ind = max_len

        pred, test = pred[:cut_ind], test[:cut_ind]
        pred, test = pred[1:-1], test[1:-1]
        
        pred = [GLOBAL_ID2LABEL[num] for i, num in enumerate(pred) 
            if not words_encoded[i].startswith('##')]
        test = [GLOBAL_ID2LABEL[num] for i, num in enumerate(test) 
            if not words_encoded[i].startswith('##')]

        pred_extended.extend(pred)
        pred_decoded.append(pred)
        test_extended.extend(test)
        assert len(pred) == len(test)
    
    return pred_extended, test_extended, pred_decoded

In [20]:
pred_labels_enc = get_predictions(model, dev_batches)
pred_extended, true_extended, pred_labels_decoded = prepare_preds_for_calc_metrics(
    pred_labels_enc, dev_labels_enc, dev_words_enc
)

100% (200 of 200) |######################| Elapsed Time: 0:06:59 Time:  0:06:59


# Examples

### Good

In [21]:
good_idx = [223, 0, 120]

In [22]:
for idx in good_idx:
    print(f"Text: {dev_texts[idx]}")
    print(f"True: {dev_labels[idx]}")
    print(f"Pred: {pred_labels_decoded[idx]}\n")

Text: ['сан', 'педро', 'де', 'атакама', '(', ')', '—', 'посёлок', 'в', 'чили', '.']
True: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O']
Pred: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O']

Text: ['важным', 'традиционным', 'промыслом', 'является', 'производство', 'пальмового', 'масла', '.']
True: ['O', 'O', 'O', 'O', 'O', 'B-PROD', 'I-PROD', 'O']
Pred: ['O', 'O', 'O', 'O', 'O', 'B-PROD', 'I-PROD', 'O']

Text: ['в', '1862', 'году', 'стал', 'членом', 'правления', 'русского', 'общества', 'пароходства', 'и', 'торговли', '.']
True: ['O', 'O', 'O', 'O', 'O', 'O', 'B-CORP', 'I-CORP', 'I-CORP', 'I-CORP', 'I-CORP', 'O']
Pred: ['O', 'O', 'O', 'O', 'O', 'O', 'B-CORP', 'I-CORP', 'I-CORP', 'I-CORP', 'I-CORP', 'O']



### Bad

In [23]:
bad_idx = [9, 673, 722]

In [24]:
for idx in bad_idx:
    print(f"Text: {dev_texts[idx]}")
    print(f"True: {dev_labels[idx]}")
    print(f"Pred: {pred_labels_decoded[idx]}\n")

Text: ['по', 'состоянию', 'на', '1', 'января', '2014', 'года', 'в', 'минской', 'подземке', 'установлено', '996', 'видеокамер', '.']
True: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PROD', 'O']
Pred: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

Text: ['собор', 'святого', 'иоанна', '—', 'протестантская', 'христианская', 'церковь', ',', 'расположена', 'в', 'шанхае', ',', 'китай', ',', 'в', 'районе', 'чаннин', '.']
True: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'O', 'O', 'O', 'B-LOC', 'O']
Pred: ['B-LOC', 'O', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'I-LOC']

Text: ['после', 'распада', 'ссср', 'начались', 'массовые', 'отмены', 'пригородных', 'поездов', 'по', 'всей', 'донецкой', 'железной', 'дороге', '.']
True: ['O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-CORP', 'I-CORP', 'I-CORP', 'O']
Pred: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',

# Calc Metrics

In [25]:
flat_true = [item.replace("I-", "").replace("B-", "") for item in true_extended]
flat_pred = [item.replace("I-", "").replace("B-", "") for item in pred_extended]
assert len(flat_true) == len(flat_pred)

### Micro Avg. on all classes

In [26]:
result = {}
avg = "micro"

result["ALL"] = {
    "precision": precision_score(flat_true, flat_pred, average=avg),
    "recall": recall_score(flat_true, flat_pred, average=avg),
    "f1": f1_score(flat_true, flat_pred, average=avg)
}

### By class

In [27]:
for label in ["O", 'CORP', 'CW', 'GRP', 'LOC', 'PER', 'PROD']:
    precision = precision_score(flat_true, flat_pred, labels=[label], average=avg)
    recall = recall_score(flat_true, flat_pred, labels=[label], average=avg)
    f1 = f1_score(flat_true, flat_pred, labels=[label], average=avg)
    
    result[label] = {
        "precision": precision,
        "recall": recall,
        "f1": f1
    }

In [28]:
pd.DataFrame(result)

Unnamed: 0,ALL,O,CORP,CW,GRP,LOC,PER,PROD
precision,0.923568,0.954357,0.837143,0.789082,0.813008,0.788079,0.82764,0.693215
recall,0.923568,0.966294,0.736181,0.685345,0.719424,0.751579,0.828927,0.804795
f1,0.923568,0.960288,0.783422,0.733564,0.763359,0.769397,0.828283,0.744849
