<a href="https://colab.research.google.com/github/JennyFrost/LLMs/blob/main/NER_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers



In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# !wget "https://raw.githubusercontent.com/patverga/torch-ner-nlp-from-scratch/master/data/conll2003/eng.train" -O train.txt
# !wget "https://github.com/patverga/torch-ner-nlp-from-scratch/blob/master/data/conll2003/eng.testa" -O dev.txt
# !wget "https://github.com/patverga/torch-ner-nlp-from-scratch/blob/master/data/conll2003/eng.testb" -O test.txt

In [None]:
def read_infile(infile):
    answer, sent, labels = [], [], []
    with open(infile, "r", encoding="utf8") as fin:
        for line in fin:
            line = line.strip()
            if line == "":
                if len(sent) > 0:
                    answer.append({"words": sent, "labels": labels})
                sent, labels = [], []
                continue
            splitted = line.split()
            if len(splitted) >= 4:
              tag = splitted[3]
              sent.append(splitted[0])
              labels.append(tag)
    if len(sent) > 0:
        answer.append({"words": sent, "labels": labels})
    return answer

In [None]:
train_data = read_infile("/content/drive/My Drive/eng.train.txt")
dev_data = read_infile("/content/drive/My Drive/eng.testa.txt")
test_data = read_infile("/content/drive/My Drive/eng.testb.txt")
for word, tag in zip(train_data[4]["words"], train_data[4]["labels"]):
    print(word, tag)

The O
European I-ORG
Commission I-ORG
said O
on O
Thursday O
it O
disagreed O
with O
German I-MISC
advice O
to O
consumers O
to O
shun O
British I-MISC
lamb O
until O
scientists O
determine O
whether O
mad O
cow O
disease O
can O
be O
transmitted O
to O
sheep O
. O


In [None]:
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
tokens = tokenizer(train_data[4]["words"], is_split_into_words=True)
print(tokens["input_ids"])
print(tokenizer.convert_ids_to_tokens(tokens["input_ids"]))
print(tokens.word_ids())


[101, 1109, 1735, 2827, 1163, 1113, 9170, 1122, 19786, 1114, 1528, 5566, 1106, 11060, 1106, 188, 17315, 1418, 2495, 12913, 1235, 6479, 4959, 2480, 6340, 13991, 3653, 1169, 1129, 12086, 1106, 8892, 119, 102]
['[CLS]', 'The', 'European', 'Commission', 'said', 'on', 'Thursday', 'it', 'disagreed', 'with', 'German', 'advice', 'to', 'consumers', 'to', 's', '##hun', 'British', 'la', '##mb', 'until', 'scientists', 'determine', 'whether', 'mad', 'cow', 'disease', 'can', 'be', 'transmitted', 'to', 'sheep', '.', '[SEP]']
[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 14, 15, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, None]


In [None]:
def make_first_subtoken_mask(mask):
    mask = mask[1:-1]
    is_first_word = [False] + list((first != second) for first, second in zip(mask[:-1], mask[1:])) + [True, False]
    return is_first_word

print(make_first_subtoken_mask(tokens.word_ids()))

[False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False]


In [None]:
from collections import Counter
import numpy as np
from torch.utils.data.dataset import Dataset

class UDDataset(Dataset):

    def __init__(self, data, tokenizer, min_count=3, tags=None):
        self.data = data
        self.tokenizer = tokenizer
        if tags is None:
            tag_counts = Counter([tag for elem in data for tag in elem["labels"]])
            self.tags_ = ["<PAD>", "<UNK>"] + [x for x, count in tag_counts.items() if count >= min_count]
        else:
            self.tags_ = tags
        self.tag_indexes_ = {tag: i for i, tag in enumerate(self.tags_)}
        self.unk_index = 1
        self.ignore_index = -100

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        tokenization = self.tokenizer(item["words"], is_split_into_words=True)
        first_subtoken_mask = make_first_subtoken_mask(tokenization.word_ids())
        answer = {"input_ids": tokenization["input_ids"], "mask": first_subtoken_mask}
        if "labels" in item:
            labels = [self.tag_indexes_.get(tag, self.unk_index) for tag in item["labels"]]
            zero_labels = np.array([self.ignore_index] * len(tokenization["input_ids"]), dtype=int)
            zero_labels[first_subtoken_mask] = labels
            answer["y"] = zero_labels
        return answer

In [None]:
train_dataset = UDDataset(train_data, tokenizer)
for key, value in train_dataset[4].items():
    print(key, value)

input_ids [101, 1109, 1735, 2827, 1163, 1113, 9170, 1122, 19786, 1114, 1528, 5566, 1106, 11060, 1106, 188, 17315, 1418, 2495, 12913, 1235, 6479, 4959, 2480, 6340, 13991, 3653, 1169, 1129, 12086, 1106, 8892, 119, 102]
mask [False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False]
y [-100    2    3    3    2    2    2    2    2    2    4    2    2    2
    2 -100    2    4 -100    2    2    2    2    2    2    2    2    2
    2    2    2    2    2 -100]


In [None]:
train_dataset.tag_indexes_

{'<PAD>': 0,
 '<UNK>': 1,
 'B-LOC': 7,
 'B-MISC': 8,
 'B-ORG': 9,
 'I-LOC': 6,
 'I-MISC': 4,
 'I-ORG': 3,
 'I-PER': 5,
 'O': 2}

In [None]:
train_dataset.tags_

['<PAD>',
 '<UNK>',
 'O',
 'I-ORG',
 'I-MISC',
 'I-PER',
 'I-LOC',
 'B-LOC',
 'B-MISC',
 'B-ORG']

In [None]:
import torch
import numpy as np
import itertools

def pad_tensor(vec, length, dim, pad_symbol):
    # vec.shape = [3, 4, 5]
    # length=7, dim=1 -> pad_size = (3, 7-4, 5)
    pad_size = list(vec.shape)
    pad_size[dim] = length - vec.shape[dim]
    answer = torch.cat([vec, torch.ones(*pad_size, dtype=torch.long) * pad_symbol], axis=dim)
    return answer

def pad_tensors(tensors, pad=0):
    # дополняет тензоры из tensors до общей максимальной длины символом pad
    if isinstance(tensors[0], (int, np.integer)):
        return torch.LongTensor(tensors)
    elif isinstance(tensors[0], (float, np.float)):
        return torch.Tensor(tensors)
    tensors = [torch.LongTensor(tensor) for tensor in tensors]
    L = max(tensor.shape[0] for tensor in tensors)
    tensors = [pad_tensor(tensor, L, dim=0, pad_symbol=pad) for tensor in tensors]
    return torch.stack(tensors, axis=0)

class FieldBatchDataLoader:

    def __init__(self, X, batch_size=32, sort_by_length=True,
                 length_field=None, state=115, device="cpu"):
        self.X = X
        self.batch_size = batch_size
        self.sort_by_length = sort_by_length
        self.length_field = length_field  ## добавилось
        self.device = device
        np.random.seed(state)

    def __len__(self):
        return (len(self.X)-1) // self.batch_size + 1

    def __iter__(self):
        if self.sort_by_length:
            # отсортировать индексы по длине объектов [1, ..., 32] -> [7, 4, 15, ...]
            # изменилось взятие длины из поля
            if self.length_field is not None:
                lengths = [len(x[self.length_field]) for x in self.X]
            else:
                lengths = [len(list(x.values())[0]) for x in self.X]
            order = np.argsort(lengths)
            # сгруппировать в батчи [7, 4, 15, 31, 3, ...] -> [[7, 4, 15, 31], [3, ...], ...]
            batched_order = np.array([order[start:start+self.batch_size]
                                      for start in range(0, len(self.X), self.batch_size)])
            # переупорядочить батчи случайно: [[3, 11, 21, 19], [27, ...], ..., [7, ...], ...]
            np.random.shuffle(batched_order[:-1])
            # собрать посл-ть индексов: -> [3, 11, 21, 19, 27, ...]
            self.order = np.fromiter(itertools.chain.from_iterable(batched_order), dtype=int)
        else:
            self.order = np.arange(len(self.X))
            np.random.shuffle(self.order)
        self.idx = 0
        return self

    def __next__(self):
        if self.idx >= len(self.X):
            raise StopIteration()
        end = min(self.idx + self.batch_size, len(self.X))
        indexes = [self.order[i] for i in range(self.idx, end)]
        batch = dict()
        # перебираем все поля
        for field in self.X[indexes[0]]:
            batch[field] = pad_tensors([self.X[i][field] for i in indexes]).to(self.device)
        batch["indexes"] = indexes
        self.idx = end
        return batch

In [None]:
train_dataloader = iter(FieldBatchDataLoader(train_dataset, batch_size=8, device="cuda"))
for i in range(10):
    batch = next(train_dataloader)
    for field, data in batch.items():
        print(f"{field}:{np.shape(data)}", end="\t")
    print("")



input_ids:torch.Size([8, 14])	mask:torch.Size([8, 14])	y:torch.Size([8, 14])	indexes:(8,)	
input_ids:torch.Size([8, 16])	mask:torch.Size([8, 16])	y:torch.Size([8, 16])	indexes:(8,)	
input_ids:torch.Size([8, 25])	mask:torch.Size([8, 25])	y:torch.Size([8, 25])	indexes:(8,)	
input_ids:torch.Size([8, 39])	mask:torch.Size([8, 39])	y:torch.Size([8, 39])	indexes:(8,)	
input_ids:torch.Size([8, 42])	mask:torch.Size([8, 42])	y:torch.Size([8, 42])	indexes:(8,)	
input_ids:torch.Size([8, 9])	mask:torch.Size([8, 9])	y:torch.Size([8, 9])	indexes:(8,)	
input_ids:torch.Size([8, 29])	mask:torch.Size([8, 29])	y:torch.Size([8, 29])	indexes:(8,)	
input_ids:torch.Size([8, 15])	mask:torch.Size([8, 15])	y:torch.Size([8, 15])	indexes:(8,)	
input_ids:torch.Size([8, 11])	mask:torch.Size([8, 11])	y:torch.Size([8, 11])	indexes:(8,)	
input_ids:torch.Size([8, 10])	mask:torch.Size([8, 10])	y:torch.Size([8, 10])	indexes:(8,)	


In [None]:
import torch.nn as nn
from transformers.optimization import AdamW

class BasicTransformersTaggingModel(nn.Module):

    def __init__(self, model, labels_number, lr=1e-5, device="cpu", **kwargs):
        super(BasicTransformersTaggingModel, self).__init__()
        self.model = model
        self.labels_number = labels_number
        self.build_network(labels_number)
        # определяем функцию потерь
        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.criterion = nn.NLLLoss(reduction="mean", ignore_index=-100)
        self.device = device
        if self.device is not None:
            self.to(self.device)
        self.optimizer = AdamW(self.parameters(), lr=lr, weight_decay=0.01)

    @property
    def hidden_size(self):
        return self.model.config.hidden_size

    def forward(self, input_ids, **kwargs):
        raise NotImplementedError("You should implement forward pass in your derived class.")

    def train_on_batch(self, x, y, mask=None):
        self.train()
        self.optimizer.zero_grad()
        loss = self._validate(x, y, mask=mask)
        loss["loss"].backward()
        self.optimizer.step()
        return loss

    def validate_on_batch(self, x, y, mask=None):
        self.eval()
        with torch.no_grad():
            return self._validate(x, y, mask=mask)

    def _validate(self, x, y, mask=None):
        if self.device is not None:
            y = y.to(self.device)
        log_probs = self(**x) #   self.forward(x) = self.__call__(x)
        loss = self.criterion(log_probs.permute(0, 2, 1), y)
        _, labels = torch.max(log_probs, dim=-1)
        # собираем ответы (пока не работает??)
        batch_labels = [None] * len(labels)
        for i, elem in enumerate(labels):
            if mask is None:
                curr_mask = [True] * len(elem)
            else:
                curr_mask = mask[i]
            batch_labels[i] = elem[curr_mask]
        return {"loss": loss, "labels": batch_labels}

class TransformersTaggingModel(BasicTransformersTaggingModel):

    def build_network(self, labels_number):
        self.proj_layer = torch.nn.Linear(self.hidden_size, self.labels_number)
        return self

    def forward(self, input_ids, **kwargs):
        input_ids = input_ids.to(self.device)
        cls_output = self.model(input_ids)["last_hidden_state"]
        logits = self.proj_layer(cls_output)
        log_probs = self.log_softmax(logits)
        return log_probs

In [None]:
from transformers import BertModel

bert_model = BertModel.from_pretrained("bert-base-cased")

model = TransformersTaggingModel(bert_model, labels_number=len(train_dataset.tags_), device="cuda")
train_dataloader = iter(FieldBatchDataLoader(train_dataset, batch_size=8, device="cuda"))
batch = next(train_dataloader)
labels = batch["y"]
for i in range(100):
    loss = model.train_on_batch(batch, labels, batch["mask"])["loss"].item()
    if i < 5 or (i+1) % 10 == 0:
        print(i, loss)
print(model.validate_on_batch(batch, labels, batch["mask"])["loss"].item())

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


0 2.412137985229492
1 2.203413963317871
2 2.045546293258667
3 1.8634511232376099
4 1.7420639991760254
9 1.1849768161773682
19 0.5305710434913635
29 0.26225581765174866
39 0.11692790687084198
49 0.06172864884138107
59 0.03255518525838852
69 0.023915767669677734
79 0.01773018203675747
89 0.012969790026545525
99 0.012493778020143509
0.004089432302862406


In [None]:
from collections import defaultdict

def extract_groups(labels):
    """
    Извлекает группы из предсказанных списков меток для предложений.
    """
    groups = []
    for i, sent_labels in enumerate(labels):
        start, group_label = None, None
        for j, label in enumerate(sent_labels):
            if "-" in label:  ## внутри группы
                label_pos, label_type = label.split("-")   # label_pos \in [B, I], label_type \in [PER, ORG, ...]
                if label_pos == "B" or label_type != group_label:
                    ## началась новая группа
                    if group_label is not None:
                        groups.append((i, start, j, group_label))
                    start, group_label = j, label_type
            elif group_label is not None:
                groups.append((i, start, j, group_label))
                group_label = None
        if group_label is not None:
            groups.append((i, start, len(sent_labels), group_label))
    answer = defaultdict(list)
    for sent_index, start, end, group_label in groups:
        key = (sent_index, start, end)
        answer[group_label].append(key)
        answer["all"].append(key)
    return answer

In [None]:
def extract_batch_metrics(true_labels, pred_labels):
    answer = defaultdict(float)
    labels = train_dataset.tags_
    # print(true_labels, type(true_labels))
    # print(pred_labels)
    if isinstance(true_labels, torch.Tensor):
        true_labels = true_labels.detach().cpu().numpy()
        true_labels = true_labels.tolist()
        # has_dummy = any([-100 in elem for elem in true_labels])
        # if has_dummy:
        #   print(true_labels)
        #   print(pred_labels)
        true_labels = [[label for label in elem if label not in [0, -100]] for elem in true_labels]
        # if has_dummy:
          # print(true_labels)
          # print(pred_labels)
    if isinstance(true_labels[0][0], (int, np.integer)):
        true_labels = [list(map(lambda label: train_dataset.tags_[label], elem)) for elem in true_labels]
        print(true_labels)
        pred_labels =[list(map(lambda label: train_dataset.tags_[label], elem)) for elem in pred_labels]
        print(pred_labels)
    if isinstance(pred_labels[0], torch.Tensor):
        pred_labels = [elem.tolist() for elem in pred_labels]
    else:
        L = max(len(x) for x in true_labels)
        true_labels = np.array([elem + ['<PAD>'] * (L-len(elem)) for elem in true_labels])
        pred_labels = np.array([elem + ['<PAD>'] * (L-len(elem)) for elem in pred_labels])
    # маска для слов
    word_mask = (true_labels != '<BEGIN>') * (true_labels != '<END>') * (true_labels != '<PAD>')
    # маска для сущностей в эталоне
    label_mask = word_mask * (true_labels != 'O')
    equal_mask = (true_labels == pred_labels)
    # число ненулевых ответов
    answer["total"] = np.count_nonzero(word_mask)
    # число правильных ответов
    answer["correct"] = np.count_nonzero(equal_mask * word_mask)
    answer["TP_token"] = np.count_nonzero(equal_mask * label_mask)
    # FP_token: было слово (word_mask), правильный ответ негативный (true_labels == "O"), была ошибка (~equal_mask)
    answer["FP_token"] = np.count_nonzero((true_labels == "O") * ~equal_mask * word_mask)
    answer["FN_token"] = np.count_nonzero((true_labels != "O") * ~equal_mask * word_mask)
    # правильные группы
    true_groups = extract_groups(true_labels)
    pred_groups = extract_groups(pred_labels)
    for label, label_data in pred_groups.items():
        #  перебираем предсказанные группы каждого типа
        if label == "all":
            continue
        for key in label_data:
            answer["TP" if key in true_groups[label] else "FP"] += 1  # проверяем, была ли эта группа для данного типа
            answer["TP_bound" if key in true_groups["all"] else "FP_bound"] += 1 # проверяем, была ли эта группа вообще
    # повторяем для исходных данных
    for label, label_data in true_groups.items():
        if label == "all":
            continue
        for key in label_data:
            if key not in pred_groups[label]:
                answer["FN"] += 1  # группа была, но не предсказалась для данного типа
            if key not in pred_groups["all"]:
                answer["FN_bound"] += 1  # группа была, но не предсказалась вообще

    return answer

In [None]:
def update_metrics(metrics, batch_output, batch_labels):
    n_batches = metrics["n_batches"]
    metrics["loss"] = (metrics["loss"] * n_batches + batch_output["loss"].item()) / (n_batches + 1)
    metrics["n_batches"] += 1
    batch_metrics = extract_batch_metrics(batch_labels, batch_output["labels"])
    for key, value in batch_metrics.items():
        metrics[key] += value
    for suffix in ["", "_bound", "_token"]:
        metrics["F1"+suffix] = metrics["TP"+suffix] / max(metrics["TP"+suffix] + 0.5*(metrics["FN"+suffix]+metrics["FP"+suffix]), 1.0)
    metrics["accuracy"] = metrics["correct"] / max(metrics["total"], 1)

In [None]:
import tqdm

def initialize_metrics():
    metrics = {"total": 0, "correct": 0, "n_batches": 0, "loss": 0.0}
    for metric in ["TP", "FP", "FN"]:
        for suffix in ["", "_bound", "_token"]:
            metrics[metric+suffix] = 0.0
    return metrics

def do_epoch(model, dataloader, mode="validate", epoch=1):
    metrics = initialize_metrics()
    func = model.train_on_batch if mode == "train" else model.validate_on_batch
    progress_bar = tqdm.notebook.tqdm(dataloader, leave=True)
    progress_bar.set_description(f"{mode}, epoch={epoch}")
    for batch in progress_bar:
        batch_answers, mask = batch["y"], batch.get("mask")
        if mask is not None:
            mask = mask.bool()
        batch_output = func(batch, batch_answers, mask=mask)
        update_metrics(metrics, batch_output, batch_answers)
        postfix = {"loss": round(metrics["loss"], 4), "acc": round(100 * metrics["accuracy"], 2)}
        for key, value in metrics.items():
            if key[:2] == "F1":
                postfix[key] = round(100 * value, 2)
        progress_bar.set_postfix(postfix)
    return metrics

In [None]:
train_dataset = UDDataset(train_data, tokenizer)
dev_dataset = UDDataset(dev_data, tokenizer, tags=train_dataset.tags_)
train_dataloader = iter(FieldBatchDataLoader(train_dataset, batch_size=8, device="cuda"))
dev_dataloader = iter(FieldBatchDataLoader(dev_dataset, batch_size=8, device="cuda"))

bert_model = BertModel.from_pretrained("bert-base-cased")
model = TransformersTaggingModel(bert_model, labels_number=len(train_dataset.tags_), device="cuda")
best_val_acc = 0.0
checkpoint = "checkpoint_best.pt"
for epoch in range(3):
    do_epoch(model, train_dataloader, mode="train", epoch=epoch+1)
    epoch_metrics = do_epoch(model, dev_dataloader, mode="validate", epoch=epoch+1)
    if epoch_metrics["accuracy"] > best_val_acc:
        best_val_acc = epoch_metrics["accuracy"]
        torch.save(model.state_dict(), checkpoint)
        # print("Saving ")
model.load_state_dict(torch.load(checkpoint))
do_epoch(model, dev_dataloader, mode="validate", epoch="evaluate")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


HBox(children=(FloatProgress(value=0.0, max=1874.0), HTML(value='')))

[['I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-LOC', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-LOC', 'O'], ['O', 'I-PER', 'I-PER', 'O'], ['I-LOC', 'O']]
[['I-LOC', '<UNK>', 'I-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'I-ORG', 'I-MISC'], ['I-ORG', 'I-LOC'], ['I-LOC', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-MISC', 'I-ORG', 'I-ORG'], ['I-ORG', 'I-MISC', 'I-MISC', 'I-MISC', 'I-MISC', '<UNK>', 'I-ORG', 'I-MISC'], ['I-MISC', 'I-LOC', 'I-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'I-MISC'], ['I-ORG', 'I-ORG'], ['I-MISC', 'I-MISC', 'I-ORG', 'I-ORG'], ['I-ORG', 'I-ORG']]
[['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O']]
[['I-ORG'], ['I-ORG'], ['I-ORG'], ['I-ORG'], ['I-ORG'], ['I-ORG'], ['I-ORG'], ['I-MISC']]
[['O', 'I-ORG', 'I-ORG', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'I-PER', 'I-PER', 'I-PER', 

HBox(children=(FloatProgress(value=0.0, max=434.0), HTML(value='')))

[['O'], ['O'], ['O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O'], ['O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O']]
[['O'], ['O'], ['O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O'], ['O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O']]
[['I-LOC', 'I-LOC', 'O', 'O', 'O', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'I-PER', 'I-PER', 'O'], ['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-ORG', 'O', 'O', 'O', 'O'], ['I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'I-MISC', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-ORG', 'O', 'I-LOC', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 

HBox(children=(FloatProgress(value=0.0, max=1874.0), HTML(value='')))

[['I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O'], ['I-ORG', 'O', 'I-ORG', 'O'], ['O', 'O', 'O', 'O', 'O', 'O']]
[['I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O'], ['I-ORG', 'O', 'I-ORG', 'O'], ['O', 'O', 'O', 'O', 'O', 'O']]
[['I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'O', 'O', 'O', 'O', 'O'], ['I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O'], ['O', 'I-PER', 'I-PER', 'O', 'I-LOC', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'I-ORG', 'I-ORG', 'O', 'O'], ['I-ORG', 'O', 'I-ORG', 'O', 'I

HBox(children=(FloatProgress(value=0.0, max=434.0), HTML(value='')))

[['O', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'I-LOC', 'O'], ['I-ORG', 'I-ORG', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'I-PER', 'I-PER', 'O', 'I-LOC', 'O', 'I-ORG', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O'], ['O', 'I-PER', 'I-PER', 'O', 'I-PER', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']]
[['O', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'I-LOC', 'O'], ['I-ORG', 'I-ORG', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'I-PER', 'I-PER', 'O', 'I-LOC', 'O', 'I-ORG', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O'], ['O', 'I-PER', 'I-PER', 'O', 'I-PER', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']]
[['O', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'I-PER', 'O', 'I-LOC', 'I-LOC', 'O'], ['I-ORG', 'O', 'O', 'O', 'O', 'I-ORG', 'I-ORG', 'O', 'O', 'I-MISC'

HBox(children=(FloatProgress(value=0.0, max=1874.0), HTML(value='')))

[['O', 'O', 'O', 'O'], ['I-LOC', 'O', 'O', 'O'], ['O', 'I-LOC', 'O', 'O'], ['O', 'O', 'O', 'O'], ['O', 'O'], ['O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O'], ['O']]
[['O', 'O', 'O', 'O'], ['I-LOC', 'O', 'O', 'O'], ['O', 'I-LOC', 'O', 'O'], ['O', 'O', 'O', 'O'], ['O', 'O'], ['O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O'], ['O']]
[['I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'I-MISC', 'I-MISC', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-PER', 'I-PER', 'O', 'O'], ['O', 'O', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 

HBox(children=(FloatProgress(value=0.0, max=434.0), HTML(value='')))

[['O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O'], ['O'], ['O'], ['O'], ['I-PER', 'I-PER', 'O', 'O', 'O'], ['O', 'O', 'O']]
[['O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O'], ['O'], ['O'], ['O'], ['I-PER', 'I-PER', 'O', 'O', 'O'], ['O', 'O', 'O']]
[['I-LOC', 'O', 'I-MISC', 'O', 'O', 'O'], ['O'], ['O'], ['O'], ['O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'O', 'I-ORG', 'I-ORG', 'O'], ['O']]
[['I-LOC', 'O', 'I-MISC', 'O', 'O', 'O'], ['O'], ['O'], ['O'], ['O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'O', 'I-ORG', 'I-ORG', 'O'], ['O']]
[['I-LOC', 'I-LOC', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-LOC', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-ORG', 'O', 'I-PER', 'O', 'O', 'I-LOC', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-PER', 'O', 'I-MISC', 'O', 'O', 'O', 'O', 'O']]
[['I-LOC', 'I-LOC', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['I-LOC', 'O'], 

HBox(children=(FloatProgress(value=0.0, max=434.0), HTML(value='')))

[['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-ORG', 'O', 'I-PER', 'I-PER', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'I-PER', 'I-PER', 'O', 'I-LOC', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'I-PER', 'I-PER', 'O', 'I-LOC', 'O', 'O', 'I-ORG', 'I-ORG', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'O', 'O', 'O']]
[['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-ORG', 'O', 'I-PER', 'I-PER', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'I-PER', 'I-PER', 'O', 'I-LOC', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 

{'F1': 0.9341687552213868,
 'F1_bound': 0.9637426900584796,
 'F1_token': 0.969600377781713,
 'FN': 351.0,
 'FN_bound': 174.0,
 'FN_token': 390.0,
 'FP': 437.0,
 'FP_bound': 260.0,
 'FP_token': 125.0,
 'TP': 5591.0,
 'TP_bound': 5768.0,
 'TP_token': 8213.0,
 'accuracy': 0.990015122726744,
 'correct': 51063,
 'loss': 0.036309429807801656,
 'n_batches': 434,
 'total': 51578}

In [None]:
def predict_with_model(model, dataset):
    model.eval()
    dataloader = FieldBatchDataLoader(dataset, device=model.device)
    answer = [None] * len(dataset)
    for batch in dataloader:
        with torch.no_grad():
            batch_answer = model(**batch)
        _, labels = torch.max(batch_answer, dim=-1)
        for i, sent_labels in zip(batch["indexes"], labels):
            mask = dataset[i]["mask"]
            answer[i] = np.take(dataset.tags_, sent_labels.cpu().numpy()[:len(mask)][mask])
    return answer

In [None]:
test_dataset = UDDataset(test_data, tokenizer, tags=train_dataset.tags_)
pred_labels = predict_with_model(model, test_dataset)



In [None]:
pred_labels = [elem.tolist() for elem in pred_labels]
corr_labels = [elem["labels"] for elem in test_data]
print(corr_labels[5], type(corr_labels[5]))
print(pred_labels[5], type(pred_labels[5]))
# print(y_pred[0])
metrics = extract_batch_metrics(corr_labels, pred_labels)
for suffix in ["", "_bound", "_token"]:
    metrics["F1"+suffix] = metrics["TP"+suffix] / max(metrics["TP"+suffix] + 0.5*(metrics["FN"+suffix]+metrics["FP"+suffix]), 1.0)
    print("F1{}\t{}".format(suffix, metrics["F1"+suffix] ))
metrics["accuracy"] = metrics["correct"] / max(metrics["total"], 1)
print("accuracy\t{:.2f}".format(100*metrics["accuracy"]))
print(metrics)

['O', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O'] <class 'list'>
['O', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O'] <class 'list'>
F1	0.9017341040462428
F1_bound	0.9462252583639867
F1_token	0.9459983730680183
accuracy	98.15
defaultdict(<class 'float'>, {'total': 46666, 'correct': 45803, 'TP_token': 7559, 'FP_token': 310, 'FN_token': 553, 'FP': 622.0, 'TP_bound': 5402.0, 'FP_bound': 368.0, 'TP': 5148.0, 'FN': 500.0, 'FN_bound': 246.0, 'F1': 0.9017341040462428, 'F1_bound': 0.9462252583639867, 'F1_token': 0.9459983730680183, 'accuracy': 0.9815068786696953})


In [None]:
sent = test_data[220]
answer = corr_labels[220]
pred = pred_labels[220]
print(sent['words'], f'correct labels: {answer}', f'predicted labels: {pred}', sep='\n')

['Botes', '72', '68', ',', 'Greg', 'Reid', '72', '68', ',', 'Clinton', 'Whitelaw', '70']
correct labels: ['I-PER', 'O', 'O', 'O', 'I-PER', 'I-PER', 'O', 'O', 'O', 'I-PER', 'I-PER', 'O']
predicted labels: ['I-PER', 'O', 'O', 'O', 'I-PER', 'I-PER', 'O', 'O', 'O', 'I-PER', 'I-PER', 'O']


In [None]:
sent = test_data[152]
answer = corr_labels[152]
pred = pred_labels[152]
print(sent['words'], f'correct labels: {answer}', f'predicted labels: {pred}', sep='\n')

['Dutch', 'forward', 'Reggie', 'Blinker', 'had', 'his', 'indefinite', 'suspension', 'lifted', 'by', 'FIFA', 'on', 'Friday', 'and', 'was', 'set', 'to', 'make', 'his', 'Sheffield', 'Wednesday', 'comeback', 'against', 'Liverpool', 'on', 'Saturday', '.']
correct labels: ['I-MISC', 'O', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-ORG', 'I-ORG', 'O', 'O', 'I-ORG', 'O', 'O', 'O']
predicted labels: ['I-MISC', 'O', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-ORG', 'I-ORG', 'O', 'O', 'I-ORG', 'O', 'O', 'O']
