In [148]:
#!g1.1
%pip install sentencepiece

Defaulting to user installation because normal site-packages is not writeable


In [195]:
#!g1.1
%pip install --upgrade transformers

Defaulting to user installation because normal site-packages is not writeable


In [241]:
#!g1.1
from transformers import RemBertTokenizer, RemBertForTokenClassification

In [242]:
#!g1.1
languages = ['bn', 'de', 'en', 'es', 'fa', 'hi', 'ko', 'nl', 'ru', 'tr', 'zh']

# Read data

In [243]:
#!g1.1
data_path = "data/"
SAVE_PATH = "model/"

max_len = 64

In [244]:
#!g1.1
def parse_conll(file):
    texts, labels = [], []
    new_texts, new_labels = [], []
    
    with open(file, 'r') as f:
        for row in f:
            row = row.replace('\n', '')
            if row.startswith("#"):
                new_texts, new_labels = [], []
                continue

            if row == "":
                texts.append(new_texts)
                labels.append(new_labels)

            else:
                parts = row.split()
                new_texts.append(parts[0])
                new_labels.append(parts[-1])

    return texts, labels

In [124]:
#!g1.1
train_texts, train_labels = [], []
dev_texts, dev_labels = [], []

for lang in languages:
    train_path = data_path + lang.upper() + '/' + lang + '_train.conll'
    dev_path = data_path + lang.upper() + '/' + lang + '_dev.conll'
    train_t, train_l = parse_conll(train_path)
    dev_t, dev_l = parse_conll(dev_path)
    train_texts += train_t
    train_labels += train_l
    dev_texts += dev_t
    dev_labels += dev_l

# Prepare data

In [245]:
#!g1.1
import torch
import numpy as np
from tqdm.notebook import tqdm

In [126]:
#!g1.1
O_TAG = 'O'
B_TAG = 'B-'
I_TAG = 'I-'

uniq_labels = set()

for item in train_labels:
    uniq_labels.update(item)
    
uniq_labels = sorted(list(uniq_labels))
uniq_labels.remove(O_TAG)
uniq_labels.insert(0, O_TAG)
    
GLOBAL_LABEL2ID = {label: idx for idx, label in enumerate(uniq_labels)}
GLOBAL_ID2LABEL = {idx: label for label, idx in GLOBAL_LABEL2ID.items()}
GLOBAL_LABEL2ID

{'O': 0,
 'B-CORP': 1,
 'B-CW': 2,
 'B-GRP': 3,
 'B-LOC': 4,
 'B-PER': 5,
 'B-PROD': 6,
 'I-CORP': 7,
 'I-CW': 8,
 'I-GRP': 9,
 'I-LOC': 10,
 'I-PER': 11,
 'I-PROD': 12}

In [246]:
#!g1.1
tokenizer = RemBertTokenizer.from_pretrained('google/rembert')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=4697711.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=156.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=263.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=8714136.0, style=ProgressStyle(descript…




In [247]:
#!g1.1
def pad_sequence(seq, max_len, pad_token):
    if len(seq) >= max_len:
        seq = seq[:max_len]
    else:
        seq = seq + [pad_token] * (max_len - len(seq))
    return seq

In [248]:
#!g1.1
def process_tokens(words, labels):
    bert_tokens, bio_labels = [tokenizer.cls_token], [O_TAG]
    
    for word, label in zip(words, labels):
        tokens = tokenizer.tokenize(word)
        bert_tokens.extend(tokens)
        new_labels = [label] * len(tokens)
        bio_labels.extend(new_labels)
        
    bert_tokens.append(tokenizer.sep_token)
    bio_labels.append(O_TAG)
    
    for i, (token, label) in enumerate(zip(bert_tokens, bio_labels)):
        if token.startswith("##") and label.startswith(B_TAG):
            bio_labels[i] = I_TAG + label[2:]

    encoded_tokens = tokenizer.encode(bert_tokens, add_special_tokens=False)
    
    if len(bio_labels) >= max_len:
        bio_labels[max_len-1] = O_TAG
    
    bio_labels = pad_sequence(bio_labels, max_len, O_TAG)
    encoded_tokens = pad_sequence(encoded_tokens, max_len, tokenizer.pad_token_id)

    return encoded_tokens, bio_labels

In [249]:
#!g1.1
def prepare_data_for_ner(texts, labels):
    result = np.zeros((len(texts), max_len), dtype=np.int32)
    fin_labels, fin_labels_encoded = [], []

    for i, (text, label) in tqdm(enumerate(zip(texts, labels))):
        
        c_words, c_labels = process_tokens(text, label) 
        assert len(c_words) == len(c_labels)
        
        result[i] = c_words
        fin_labels.append(c_labels)
    
    words_ids, labels_ids = [], []
    
    for sentence in fin_labels:
        new_labels = []
        for label in sentence:
            new_labels.append(GLOBAL_LABEL2ID[label])
        fin_labels_encoded.append(new_labels)
    
    return result, fin_labels_encoded

In [131]:
#!g1.1
train_words_enc, train_labels_enc = prepare_data_for_ner(train_texts, train_labels)
dev_words_enc, dev_labels_enc = prepare_data_for_ner(dev_texts, dev_labels)

HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…





# Model

In [250]:
#!g1.1
device = 'cuda'

In [134]:
#!g1.1
model = RemBertForTokenClassification.from_pretrained('google/rembert', num_labels=len(GLOBAL_ID2LABEL))
model.to(device)
print(device)

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=2303882157.0), HTML(value='')))

Some weights of the model checkpoint at google/rembert were not used when initializing RemBertForTokenClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing RemBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RemBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RemBertForTokenClassification were not initialized from the model checkpoint at google/rembert and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



cuda


In [251]:
#!g1.1
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, f1_score, accuracy_score

batch_size = 32
patience = 1

In [136]:
#!g1.1
train_data = np.stack((train_words_enc, train_labels_enc), axis=1)
dev_data = np.stack((dev_words_enc, dev_labels_enc), axis=1)

train_batches = DataLoader(train_data, batch_size=batch_size, shuffle=True)
dev_batches = DataLoader(dev_data, batch_size=batch_size, shuffle=False)

# Train

In [137]:
#!g1.1
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=1)

In [138]:
#!g1.1
def _train(model, train_loader, optimizer, epoch_num):
    train_loss, train_acc, train_f1 = [], [], []
    model.train()

    for batch_num, batch in enumerate(train_loader):
        X_batch, y_batch = batch[:, 0, :], batch[:, 1, :]
        X_batch = X_batch.type(torch.LongTensor).to(device)
        y_batch = y_batch.type(torch.LongTensor).to(device)
        optimizer.zero_grad()

        out = model(input_ids=X_batch, labels=y_batch.contiguous(), return_dict=True)
        loss = out.loss
        y_pred = out.logits
        y_pred = torch.argmax(y_pred, dim=2)

        y_pred_flatten = torch.flatten(y_pred).cpu().numpy()
        y_batch_flatten = torch.flatten(y_batch).cpu().numpy()
        f1 = f1_score(y_batch_flatten, y_pred_flatten, average="micro")
        accuracy = accuracy_score(y_batch_flatten, y_pred_flatten)

        train_loss.append(loss.item())
        train_acc.append(accuracy)
        train_f1.append(f1)
        
        X_batch = X_batch.to('cpu')
        y_batch = y_batch.to('cpu')
        
        if batch_num % 50 == 0:
            print(f"TRAIN: Epoch: {epoch_num}, Batch: {batch_num + 1} / {len(train_loader)}, "
                          f"Loss: {loss.item():.3f}, Accuracy: {accuracy:.3f}, F1: {f1:.3f}")

        loss.backward()
        optimizer.step()

    return np.mean(train_loss), np.mean(train_acc), np.mean(train_f1)

def _val(model, val_loader, epoch_num):
    val_loss, val_acc, val_f1 = [], [], []
    model.eval()

    for batch_num, batch in enumerate(val_loader):
        X_batch, y_batch = batch[:, 0, :], batch[:, 1, :]
        X_batch = X_batch.type(torch.LongTensor).to(device)
        y_batch = y_batch.type(torch.LongTensor).to(device)

        out = model(input_ids=X_batch, labels=y_batch.contiguous())
        loss = out.loss
        y_pred = out.logits
        y_pred = torch.argmax(y_pred, dim=2)

        y_pred_flatten = torch.flatten(y_pred).cpu().numpy()
        y_batch_flatten = torch.flatten(y_batch).cpu().numpy()
        f1 = f1_score(y_batch_flatten, y_pred_flatten, average="micro")
        accuracy = accuracy_score(y_batch_flatten, y_pred_flatten)

        val_loss.append(loss.item())
        val_acc.append(accuracy)
        val_f1.append(f1)
        
        X_batch = X_batch.to('cpu')
        y_batch = y_batch.to('cpu')
        
        if batch_num % 50 == 0:
            print(f"VAL: Epoch: {epoch_num}, Batch: {batch_num + 1} / {len(val_loader)}, "
                          f"Loss: {loss.item():.3f}, Accuracy: {accuracy:.3f}, F1: {f1:.3f}")

    return np.mean(val_loss), np.mean(val_acc), np.mean(val_f1)

In [139]:
#!g1.1
last_epoch = 0
dev_losses = []
patience = 1

for epoch in range(1, 20):
    train_loss, train_acc, train_f1 = _train(model, train_batches, optimizer, epoch)
    dev_loss, dev_acc, dev_f1 = _val(model, dev_batches, epoch)
    scheduler.step(dev_loss)

    if len(dev_losses) == 0 or dev_loss < dev_losses[-1]:
        model.save_pretrained(SAVE_PATH + str(epoch))

    elif last_epoch == 0:
        last_epoch = epoch + patience

    print(f"After epoch #{epoch}:")
    print(f"Train loss: {train_loss:.3f}, Train Accuracy: {train_acc:.3f}, Train F1: {train_f1:.3f}")
    print(f"Dev loss: {dev_loss:.3f}, Dev Accuracy: {dev_acc:.3f}, Dev F1: {dev_f1:.3f}\n")

    dev_losses.append(dev_loss)
    if epoch == last_epoch:
        model.save_pretrained(SAVE_PATH + str(epoch))
        break

TRAIN: Epoch: 1, Batch: 1 / 5260, Loss: 2.723, Accuracy: 0.035, F1: 0.035
TRAIN: Epoch: 1, Batch: 51 / 5260, Loss: 0.309, Accuracy: 0.912, F1: 0.912
TRAIN: Epoch: 1, Batch: 101 / 5260, Loss: 0.220, Accuracy: 0.935, F1: 0.935
TRAIN: Epoch: 1, Batch: 151 / 5260, Loss: 0.205, Accuracy: 0.930, F1: 0.930
TRAIN: Epoch: 1, Batch: 201 / 5260, Loss: 0.123, Accuracy: 0.964, F1: 0.964
TRAIN: Epoch: 1, Batch: 251 / 5260, Loss: 0.091, Accuracy: 0.969, F1: 0.969
TRAIN: Epoch: 1, Batch: 301 / 5260, Loss: 0.154, Accuracy: 0.949, F1: 0.949
TRAIN: Epoch: 1, Batch: 351 / 5260, Loss: 0.104, Accuracy: 0.969, F1: 0.969
TRAIN: Epoch: 1, Batch: 401 / 5260, Loss: 0.141, Accuracy: 0.963, F1: 0.963
TRAIN: Epoch: 1, Batch: 451 / 5260, Loss: 0.084, Accuracy: 0.974, F1: 0.974
TRAIN: Epoch: 1, Batch: 501 / 5260, Loss: 0.118, Accuracy: 0.966, F1: 0.966
TRAIN: Epoch: 1, Batch: 551 / 5260, Loss: 0.106, Accuracy: 0.960, F1: 0.960
TRAIN: Epoch: 1, Batch: 601 / 5260, Loss: 0.116, Accuracy: 0.962, F1: 0.962
TRAIN: Epoch: 1

# Evaluation

In [140]:
#!g1.1
%pip install allennlp

Defaulting to user installation because normal site-packages is not writeable
Collecting allennlp
  Downloading allennlp-2.8.0-py3-none-any.whl (738 kB)
     |████████████████████████████████| 738 kB 1.3 MB/s            
Collecting lmdb
  Downloading lmdb-1.3.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (305 kB)
     |████████████████████████████████| 305 kB 27.7 MB/s            
Collecting datasets<2.0,>=1.2.1
  Downloading datasets-1.17.0-py3-none-any.whl (306 kB)
     |████████████████████████████████| 306 kB 29.2 MB/s            
Collecting base58
  Downloading base58-2.1.1-py3-none-any.whl (5.6 kB)
Collecting checklist==0.0.11
  Downloading checklist-0.0.11.tar.gz (12.1 MB)
     |████████████████████████████████| 12.1 MB 15.3 MB/s            
[?25h  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting jsonnet>=0.10.0
  Downloading jsonnet-0.18.0.tar.gz (592 kB)
     |████████████████████████████████| 592 kB 38.2 MB/s            
[?25h  Preparing metadata (

In [162]:
#!g1.1
model = RemBertForTokenClassification.from_pretrained(SAVE_PATH + '2', local_files_only=True)
model.to(device)
model.eval()
print()




In [177]:
#!g1.1
def get_predictions(c_model, batches):
    pred_labels = []

    for item in tqdm(batches):
        item = item[:, 0, :]
        out = c_model(item.to(device))
        logits = out.logits
        logits = logits.argmax(axis=-1).tolist()
        pred_labels.extend(logits)
        
    return pred_labels

def prepare_preds_for_calc_metrics(pred_encoded, test_encoded, test_words):
    pred_extended, test_extended, pred_decoded = [], [], []

    for pred, test, words in tqdm(zip(pred_encoded, test_encoded, test_words)):
        words = words.tolist()
        
        words_encoded = tokenizer.convert_ids_to_tokens(words)

        if 0 in words:
            cut_ind = words.index(0)
        else:
            cut_ind = max_len

        pred, test = pred[:cut_ind], test[:cut_ind]
        pred, test = pred[1:-1], test[1:-1]
        
        pred = [GLOBAL_ID2LABEL[num] for i, num in enumerate(pred) 
            if not words_encoded[i].startswith('##')]
        test = [GLOBAL_ID2LABEL[num] for i, num in enumerate(test) 
            if not words_encoded[i].startswith('##')]

        pred_extended.extend(pred)
        pred_decoded.append(pred)
        test_extended.append(test)
        assert len(pred) == len(test)
    
    return pred_extended, test_extended, pred_decoded

In [190]:
#!g1.1
true = {}
preds = {}
for l in languages:
    dev_path = data_path + l.upper() + '/' + l + '_dev.conll'
    dev_texts, dev_labels = parse_conll(dev_path)
    dev_words_enc, dev_labels_enc = prepare_data_for_ner(dev_texts, dev_labels)
    dev_data = np.stack((dev_words_enc, dev_labels_enc), axis=1)
    dev_batches = DataLoader(dev_data, batch_size=batch_size, shuffle=False)
    pred_labels_enc = get_predictions(model, dev_batches)
    pred_extended, true_extended, pred_labels_decoded = prepare_preds_for_calc_metrics(
        pred_labels_enc, dev_labels_enc, dev_words_enc
    )
    true[l] = true_extended
    preds[l] = pred_labels_decoded

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




































In [191]:
#!g1.1
print(true['ru'][3])
print(preds['ru'][3])

['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'B-PER', 'B-PER', 'I-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'B-CW', 'B-CW', 'B-CW', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'B-PER', 'B-PER', 'I-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'B-CW', 'B-CW', 'B-CW', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [183]:
#!g1.1
def get_spans(labels):
    fin_spans = []
    for item_ in labels:

        item = deepcopy(item_)
        item.insert(0, "O")
        item.append("O")

        new_spans = {}
        for i, label in enumerate(item[1:-1], 1):

            if item[i] == "O":
                new_spans[(i-1, i-1)] = "O"
            else:
                if item[i-1] == 'O':
                    start_i = i
                if item[i+1] == 'O':
                    new_spans[(start_i-1, i-1)] = item[i].split('-')[1]
                    
        fin_spans.append(new_spans)
                
    return fin_spans

In [168]:
#!g1.1
from copy import deepcopy

In [166]:
#!g1.1
from collections import defaultdict
from typing import Set
from allennlp.training.metrics.metric import Metric


class SpanF1(Metric):
    def __init__(self, non_entity_labels=['O']) -> None:
        self._num_gold_mentions = 0
        self._num_recalled_mentions = 0
        self._num_predicted_mentions = 0
        self._TP, self._FP, self._GT = defaultdict(int), defaultdict(int), defaultdict(int)
        self.non_entity_labels = set(non_entity_labels)

    def __call__(self, batched_predicted_spans, batched_gold_spans, sentences=None):
        non_entity_labels = self.non_entity_labels
        for predicted_spans, gold_spans in zip(batched_predicted_spans, batched_gold_spans):
            gold_spans_set = set([x for x, y in gold_spans.items() if y not in non_entity_labels])
            pred_spans_set = set([x for x, y in predicted_spans.items() if y not in non_entity_labels])

            self._num_gold_mentions += len(gold_spans_set)
            self._num_recalled_mentions += len(gold_spans_set & pred_spans_set)
            self._num_predicted_mentions += len(pred_spans_set)

            for ky, val in gold_spans.items():
                if val not in non_entity_labels:
                    self._GT[val] += 1

            for ky, val in predicted_spans.items():
                if val in non_entity_labels:
                    continue
                if ky in gold_spans and val == gold_spans[ky]:
                    self._TP[val] += 1
                else:
                    self._FP[val] += 1
    
    def get_metric(self, reset: bool = False) -> float:
        all_tags: Set[str] = set()
        all_tags.update(self._TP.keys())
        all_tags.update(self._FP.keys())
        all_tags.update(self._GT.keys())
        all_metrics = {}

        for tag in all_tags:
            precision, recall, f1_measure = self.compute_prf_metrics(true_positives=self._TP[tag],
                                                                     false_negatives=self._GT[tag] - self._TP[tag],
                                                                     false_positives=self._FP[tag])
            all_metrics['P@{}'.format(tag)] = precision
            all_metrics['R@{}'.format(tag)] = recall
            all_metrics['F1@{}'.format(tag)] = f1_measure

        # Compute the precision, recall and f1 for all spans jointly.
        precision, recall, f1_measure = self.compute_prf_metrics(true_positives=sum(self._TP.values()),
                                                                 false_positives=sum(self._FP.values()),
                                                                 false_negatives=sum(self._GT.values())-sum(self._TP.values()))
        all_metrics["micro@P"] = precision
        all_metrics["micro@R"] = recall
        all_metrics["micro@F1"] = f1_measure

        if self._num_gold_mentions == 0:
            entity_recall = 0.0
        else:
            entity_recall = self._num_recalled_mentions / float(self._num_gold_mentions)

        if self._num_predicted_mentions == 0:
            entity_precision = 0.0
        else:
            entity_precision = self._num_recalled_mentions / float(self._num_predicted_mentions)

        all_metrics['MD@R'] = entity_recall
        all_metrics['MD@P'] = entity_precision
        all_metrics['MD@F1'] = 2. * ((entity_precision * entity_recall) / (entity_precision + entity_recall + 1e-13))
        all_metrics['ALLTRUE'] = self._num_gold_mentions
        all_metrics['ALLRECALLED'] = self._num_recalled_mentions
        all_metrics['ALLPRED'] = self._num_predicted_mentions
        if reset:
            self.reset()
        return all_metrics

    @staticmethod
    def compute_prf_metrics(true_positives: int, false_positives: int, false_negatives: int):
        precision = float(true_positives) / float(true_positives + false_positives + 1e-13)
        recall = float(true_positives) / float(true_positives + false_negatives + 1e-13)
        f1_measure = 2. * ((precision * recall) / (precision + recall + 1e-13))
        return precision, recall, f1_measure
    
    def reset(self):
        self._num_gold_mentions = 0
        self._num_recalled_mentions = 0
        self._num_predicted_mentions = 0
        self._TP.clear()
        self._FP.clear()
        self._GT.clear()

In [185]:
#!g1.1
import pandas as pd

In [192]:
#!g1.1
metrics = {}

for lang in languages:
    true_spans = get_spans(true[lang])
    pred_spans = get_spans(preds[lang])
    
    span_f1 = SpanF1()
    span_f1(pred_spans, true_spans)
    cur_metric = span_f1.get_metric()
    metrics[lang] = cur_metric

In [193]:
#!g1.1
pd.options.display.float_format = '{:.3f}'.format
df = pd.DataFrame(index=list(metrics["ru"].keys()))

for lang, metric in metrics.items():
    df[lang] = list(metric.values())

In [194]:
#!g1.1
df

Unnamed: 0,bn,de,en,es,fa,hi,ko,nl,ru,tr,zh
P@PER,0.919,0.937,0.946,0.931,0.731,0.909,0.759,0.966,0.796,0.826,0.838
R@PER,0.944,0.957,0.979,0.927,0.87,0.846,0.824,0.938,0.833,0.901,0.927
F1@PER,0.932,0.947,0.962,0.929,0.795,0.876,0.79,0.951,0.814,0.862,0.88
P@LOC,0.763,0.926,0.935,0.849,0.811,0.786,0.769,0.925,0.697,0.84,0.893
R@LOC,0.861,0.906,0.939,0.881,0.808,0.853,0.775,0.899,0.748,0.896,0.918
F1@LOC,0.809,0.916,0.937,0.865,0.81,0.818,0.772,0.912,0.721,0.867,0.905
P@CW,0.696,0.845,0.746,0.796,0.747,0.62,0.679,0.846,0.765,0.753,0.684
R@CW,0.65,0.849,0.759,0.771,0.673,0.727,0.759,0.799,0.825,0.785,0.761
F1@CW,0.672,0.847,0.752,0.783,0.708,0.669,0.717,0.822,0.794,0.768,0.72
P@PROD,0.63,0.833,0.771,0.721,0.652,0.686,0.743,0.745,0.777,0.714,0.764


# Run and record predictions for train set


In [252]:
#!g1.1
model = RemBertForTokenClassification.from_pretrained(SAVE_PATH + '2', local_files_only=True)
model.to(device)
model.eval()
print()




In [260]:
#!g1.1
def get_prediction_probs(c_model, batches, dev_words_enc):
    pred_probs = []
    for i, item in tqdm(enumerate(dev_batches)):
        # get prediction probabilities
        item = item[:, 0, :]
        out = model(item.to(device))
        logits = torch.nn.functional.softmax(out.logits, dim=2).cpu().detach().numpy()
        # get tokens
        words = dev_words_enc[i].tolist()
        words_encoded = tokenizer.convert_ids_to_tokens(words)
        if 0 in words:
            cut_ind = words.index(0)
        else:
            cut_ind = max_len
        # record needed predictions
        pred = [logits[0, i, :] for i in range(cut_ind) if words_encoded[i].startswith('▁')]
        pred_probs.append(pred)
    return pred_probs

In [264]:
#!g1.1
prediction_probs = {'label2id': GLOBAL_LABEL2ID}
for l in tqdm(languages):
    # read data
    dev_path = data_path + l.upper() + '/' + l + '_train.conll'
    dev_texts, dev_labels = parse_conll(dev_path)
    dev_words_enc, dev_labels_enc = prepare_data_for_ner(dev_texts, dev_labels)
    dev_data = np.stack((dev_words_enc, dev_labels_enc), axis=1)
    dev_batches = DataLoader(dev_data, batch_size=1, shuffle=False)
    # get prediction probabilities
    probs = get_prediction_probs(model, dev_batches, dev_words_enc)
    # add to dictionary
    sentences = []
    for sent, sent_probs, labels in zip(dev_texts, probs, dev_labels):
        sentence = []
        for token, prob, label in zip(sent, sent_probs, labels):
            sentence.append({'token': token, 'output_probs': prob, 'true_label': label})
        sentences.append(sentence)
    prediction_probs[l] = sentences

HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…





In [263]:
#!g1.1
GLOBAL_LABEL2ID

{'O': 0,
 'B-CORP': 1,
 'B-CW': 2,
 'B-GRP': 3,
 'B-LOC': 4,
 'B-PER': 5,
 'B-PROD': 6,
 'I-CORP': 7,
 'I-CW': 8,
 'I-GRP': 9,
 'I-LOC': 10,
 'I-PER': 11,
 'I-PROD': 12}

In [265]:
#!g1.1
prediction_probs['en'][0]

[{'token': 'his',
  'output_probs': array([0.9999131 , 0.00001098, 0.00003176, 0.00000448, 0.00000055,
         0.00000759, 0.00000263, 0.0000053 , 0.00001694, 0.00000272,
         0.00000034, 0.00000269, 0.00000095], dtype=float32),
  'true_label': 'O'},
 {'token': 'playlist',
  'output_probs': array([0.9449786 , 0.00615626, 0.04285581, 0.00146176, 0.00009282,
         0.0001821 , 0.00148606, 0.00087573, 0.00157215, 0.00012599,
         0.00004352, 0.00004381, 0.00012537], dtype=float32),
  'true_label': 'O'},
 {'token': 'includes',
  'output_probs': array([0.99995995, 0.00000028, 0.0000008 , 0.00000044, 0.00000007,
         0.00000017, 0.00000006, 0.00000588, 0.0000193 , 0.00001099,
         0.00000034, 0.00000107, 0.00000064], dtype=float32),
  'true_label': 'O'},
 {'token': 'sonny',
  'output_probs': array([0.00104131, 0.00101925, 0.00520316, 0.00461244, 0.0001945 ,
         0.98684025, 0.00011484, 0.00005835, 0.00026113, 0.00016429,
         0.00001456, 0.00044428, 0.0000315 ], dt

In [266]:
#!g1.1
import pickle

with open('train_probs.pickle', 'wb') as f:
    pickle.dump(prediction_probs, f)

In [267]:
#!g1.1
with open('train_probs.pickle', 'rb') as f:
    b = pickle.load(f)

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()



In [268]:
#!g1.1
prediction_probs['en'][1]

[{'token': 'it',
  'output_probs': array([0.99997985, 0.00000084, 0.00000811, 0.00000375, 0.00000045,
         0.00000123, 0.0000003 , 0.00000026, 0.00000246, 0.00000203,
         0.00000018, 0.00000037, 0.0000001 ], dtype=float32),
  'true_label': 'O'},
 {'token': 'is',
  'output_probs': array([0.99999785, 0.00000004, 0.00000009, 0.00000006, 0.00000002,
         0.00000005, 0.00000006, 0.00000007, 0.00000054, 0.00000066,
         0.00000011, 0.00000031, 0.00000014], dtype=float32),
  'true_label': 'O'},
 {'token': 'a',
  'output_probs': array([0.9999962 , 0.0000001 , 0.00000119, 0.00000021, 0.00000004,
         0.00000016, 0.00000032, 0.00000004, 0.00000115, 0.00000024,
         0.00000009, 0.00000009, 0.00000017], dtype=float32),
  'true_label': 'O'},
 {'token': 'series',
  'output_probs': array([0.9999081 , 0.00000929, 0.00001826, 0.00002181, 0.00000149,
         0.00000084, 0.00000108, 0.0000031 , 0.00002243, 0.00001206,
         0.00000084, 0.0000005 , 0.00000036], dtype=float32),

In [269]:
#!g1.1
b['en'][1]

[{'token': 'it',
  'output_probs': array([0.99997985, 0.00000084, 0.00000811, 0.00000375, 0.00000045,
         0.00000123, 0.0000003 , 0.00000026, 0.00000246, 0.00000203,
         0.00000018, 0.00000037, 0.0000001 ], dtype=float32),
  'true_label': 'O'},
 {'token': 'is',
  'output_probs': array([0.99999785, 0.00000004, 0.00000009, 0.00000006, 0.00000002,
         0.00000005, 0.00000006, 0.00000007, 0.00000054, 0.00000066,
         0.00000011, 0.00000031, 0.00000014], dtype=float32),
  'true_label': 'O'},
 {'token': 'a',
  'output_probs': array([0.9999962 , 0.0000001 , 0.00000119, 0.00000021, 0.00000004,
         0.00000016, 0.00000032, 0.00000004, 0.00000115, 0.00000024,
         0.00000009, 0.00000009, 0.00000017], dtype=float32),
  'true_label': 'O'},
 {'token': 'series',
  'output_probs': array([0.9999081 , 0.00000929, 0.00001826, 0.00002181, 0.00000149,
         0.00000084, 0.00000108, 0.0000031 , 0.00002243, 0.00001206,
         0.00000084, 0.0000005 , 0.00000036], dtype=float32),

In [None]:
#!g1.1
