# NER as QA for ACE2005

На основе статьи `A Unified MRC Framework for Named Entity Recognition` (https://arxiv.org/abs/1910.11476) и их реализации `mrc-for-flat-nested-ner` (https://github.com/ShannonAI/mrc-for-flat-nested-ner), но при помощи `PyTorch` (https://pytorch.org).

Данная версия Jupyter Notebook для данных `Ace2005` с использованием `BertLarge`. 

### Загрузка данных для ознакомления

Загружаем данные `Ace2005`:

In [1]:
import json

In [2]:
all_data = json.load(open('datasets/ace2005/mrc-ner.train', encoding="utf-8"))

Смотрим на них:

In [3]:
for data in all_data:
    if data['start_position'] != []:
        print(data)
        break

{'context': "BEGALA Well , we ' ll debate that later on in the show .", 'end_position': [3], 'entity_label': 'ORG', 'impossible': False, 'qas_id': '0.2', 'query': 'organization entities are limited to companies, corporations, agencies, institutions and other groups of people.', 'span_position': ['3;3'], 'start_position': [3]}


In [19]:
all_data[13]

{'context': "We ' ll have a couple of experts come out , so I ' ll withhold my comments until then .",
 'end_position': [],
 'entity_label': 'WEA',
 'impossible': True,
 'qas_id': '1.7',
 'query': 'weapon entities are limited to physical devices such as instruments for physically harming such as guns, arms and gunpowder.',
 'span_position': [],
 'start_position': []}

### Построение модели

Подключаем библиотеки:

In [7]:
import argparse
import os
import json
from collections import namedtuple
from typing import Dict

import torch
from tqdm import tqdm
import torch.nn as nn
from models.classifier import MultiNonLinearClassifier, SingleLinearClassifier
from transformers import BertModel, BertPreTrainedModel, BertTokenizer
from tokenizers import BertWordPieceTokenizer
from torch import Tensor
from torch.nn.modules import CrossEntropyLoss, BCEWithLogitsLoss
from torch.utils.data import DataLoader
from transformers import AdamW
from torch.optim import SGD
from metrics.functional.query_span_f1 import query_span_f1

from datasets.mrc_ner_dataset import MRCNERDataset
from datasets.truncate_dataset import TruncateDataset
from datasets.collate_functions import collate_to_max_length
from models.bert_query_ner import BertQueryNER
from models.query_ner_config import BertQueryNerConfig
from loss import *
from utils.get_parser import get_parser
from utils.radom_seed import set_random_seed
import logging

set_random_seed(0)

Строим саму модель, в данном случае в качестве feature extractor берется `BertLarge` (может браться другой), задаются три головы сети - `начальная позиция`, `конечная позиция`, `спан` -- которые будут предсказываться.

In [None]:
class BertQueryNER(nn.Module):
    def __init__(self):
        super(BertQueryNER, self).__init__()
        self.bert = BertModel.from_pretrained('bert_large/')

        self.start_outputs = nn.Linear(self.bert.config.hidden_size, 1)
        self.end_outputs = nn.Linear(self.bert.config.hidden_size, 1)
        self.span_embedding = MultiNonLinearClassifier(self.bert.config.hidden_size * 2, 1, 0.3)

        self.hidden_size = self.bert.config.hidden_size
        


    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """
        Args:
            input_ids: bert input tokens, tensor of shape [seq_len]
            token_type_ids: 0 for query, 1 for context, tensor of shape [seq_len]
            attention_mask: attention mask, tensor of shape [seq_len]
        Returns:
            start_logits: start/non-start probs of shape [seq_len]
            end_logits: end/non-end probs of shape [seq_len]
            match_logits: start-end-match probs of shape [seq_len, 1]
        """
        bert_outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

        sequence_heatmap = bert_outputs[0]  # [batch, seq_len, hidden]
        batch_size, seq_len, hid_size = sequence_heatmap.size()

        start_logits = self.start_outputs(sequence_heatmap).squeeze(-1)  # [batch, seq_len, 1]
        end_logits = self.end_outputs(sequence_heatmap).squeeze(-1)  # [batch, seq_len, 1]

        # for every position $i$ in sequence, should concate $j$ to
        # predict if $i$ and $j$ are start_pos and end_pos for an entity.
        # [batch, seq_len, seq_len, hidden]
        start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
        # [batch, seq_len, seq_len, hidden]
        end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
        # [batch, seq_len, seq_len, hidden*2]
        span_matrix = torch.cat([start_extend, end_extend], 3)
        # [batch, seq_len, seq_len]
        span_logits = self.span_embedding(span_matrix).squeeze(-1)

        return start_logits, end_logits, span_logits

Функция `dataloader`'а, используем `tokenizer_large` (можно другой):

In [3]:
def get_dataloader(data_dir, prefix="train", limit: int = None) -> DataLoader:
    """get training dataloader"""
    """
    load_mmap_dataset
    """
    json_path = os.path.join(data_dir, f"mrc-ner.{prefix}")
    dataset = MRCNERDataset(json_path=json_path,
                            tokenizer= BertWordPieceTokenizer(BertTokenizer.from_pretrained('./tokenizer_large').vocab),
                            max_length=128,
                            is_chinese=False,
                            pad_to_maxlen=False
                            )

    if limit is not None:
        dataset = TruncateDataset(dataset, limit)

    dataloader = DataLoader(
        dataset=dataset,
        batch_size=16,
        num_workers=16,
        shuffle=True if prefix == "train" else False,
        collate_fn=collate_to_max_length
    )

    return dataloader

Загружаем данные уже через `dataloader`:

In [4]:
data_dir = 'ace2005/'
trainloader = get_dataloader(data_dir, 'train')
devloader = get_dataloader(data_dir, 'dev')
testloader = get_dataloader(data_dir, 'test')



Задаем веса, на основе статьи вес для спана берется в уменьшенном масштабе:

In [5]:
weight_start = 1
weight_end = 1
weight_span = 0.1
weight_sum = weight_start + weight_end + weight_span
weight_start = weight_start / weight_sum
weight_end = weight_end / weight_sum
weight_span = weight_span / weight_sum
print(weight_start, weight_end, weight_span)

0.47619047619047616 0.47619047619047616 0.047619047619047616


Задаем модель и функцию потерь:

In [6]:
model = BertQueryNER().cuda()
bce_loss = BCEWithLogitsLoss(reduction="none")

Задаем оптимизатор:

In [7]:
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters,
                  betas=(0.9, 0.98),  # according to RoBERTa paper
                  lr=3e-5)
t_total = (len(trainloader)) * 20
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=3e-5, pct_start=float(0/t_total),
    final_div_factor=1e4,
    total_steps=t_total, anneal_strategy='linear'
)

Функция для обучения:

In [8]:
def train(model, dataloder):
    model.train()
    mean_loss = 0
    for k, batch in enumerate(tqdm(dataloder)):
        tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx = batch
        attention_mask = (tokens != 0).long()
        tokens, attention_mask, token_type_ids = tokens.cuda(), attention_mask.cuda(), token_type_ids.cuda()
        start_logits, end_logits, span_logits = model(tokens, attention_mask, token_type_ids)
        start_logits, end_logits, span_logits = start_logits.cpu(), end_logits.cpu(), span_logits.cpu()
        batch_size, seq_len = start_logits.size()

        start_float_label_mask = start_label_mask.view(-1).float()
        end_float_label_mask = end_label_mask.view(-1).float()
        match_label_row_mask = start_label_mask.bool().unsqueeze(-1).expand(-1, -1, seq_len)
        match_label_col_mask = end_label_mask.bool().unsqueeze(-2).expand(-1, seq_len, -1)
        match_label_mask = match_label_row_mask & match_label_col_mask
        match_label_mask = torch.triu(match_label_mask, 0)  # start should be less equal to end

        start_preds = start_logits > 0
        end_preds = end_logits > 0

        match_candidates = torch.logical_or(
            (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
              & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
            (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
              & end_labels.unsqueeze(-2).expand(-1, seq_len, -1))
        )
        match_label_mask = match_label_mask & match_candidates
        float_match_label_mask = match_label_mask.view(batch_size, -1).float()

        start_loss = bce_loss(start_logits.view(-1), start_labels.view(-1).float())
        start_loss = (start_loss * start_float_label_mask).sum() / start_float_label_mask.sum()
        end_loss = bce_loss(end_logits.view(-1), end_labels.view(-1).float())
        end_loss = (end_loss * end_float_label_mask).sum() / end_float_label_mask.sum()
        match_loss = bce_loss(span_logits.view(batch_size, -1), match_labels.view(batch_size, -1).float())
        match_loss = match_loss * float_match_label_mask
        match_loss = match_loss.sum() / (float_match_label_mask.sum() + 1e-10)


        total_loss = weight_start * start_loss + weight_end * end_loss + weight_span * match_loss
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        scheduler.step()
        mean_loss += total_loss.item()
        if k % 500 == 0 and k > 0:
            print(mean_loss / 500)
            mean_loss = 0

Функция для валидации с выводом `recall`, `precision` и `f1` для полученных спанов, а также их сохранение для каждой эпохи:

In [9]:
def val(model, dataloder, epoch):
    model.eval()
    output = []
    for batch in tqdm(dataloder):
        tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx = batch
        attention_mask = (tokens != 0).long()
        tokens, attention_mask, token_type_ids = tokens.cuda(), attention_mask.cuda(), token_type_ids.cuda()
        start_logits, end_logits, span_logits = model(tokens, attention_mask, token_type_ids)
        start_logits, end_logits, span_logits = start_logits.cpu(), end_logits.cpu(), span_logits.cpu()
        start_preds, end_preds = start_logits > 0, end_logits > 0
        span_f1_stats = query_span_f1(start_preds=start_preds, end_preds=end_preds, match_logits=span_logits,
                                      start_label_mask=start_label_mask, end_label_mask=end_label_mask,
                                      match_labels=match_labels)
        output.append(span_f1_stats)
    all_counts = torch.stack(output).sum(0)
    span_tp, span_fp, span_fn = all_counts
    span_recall = span_tp / (span_tp + span_fn + 1e-10)
    span_precision = span_tp / (span_tp + span_fp + 1e-10)
    span_f1 = span_precision * span_recall * 2 / (span_recall + span_precision + 1e-10)
    print(span_recall, span_precision, span_f1)
    with open('./log.txt', 'a') as f:
        f.write(str(epoch) + '\n')
        f.write(str(span_recall.item()) + ' ' + str(span_precision.item()) + ' ' + str(span_f1.item()) + '\n')

### Обучение

Запускаем обучение с сохранением весов:

In [None]:
for i in range(12 * len(trainloader)):
    scheduler.step()

In [None]:
for i in range(20):
    train(model, trainloader)
    torch.save(model.state_dict(), 'weights/BertLargeAce2005/weight' + str(i) + '.pth')
    val(model, testloader, i)

 16%|█▌        | 501/3194 [03:53<19:07,  2.35it/s]

0.12858526457287373


 31%|███▏      | 1001/3194 [07:47<17:01,  2.15it/s]

0.075910789469257


 47%|████▋     | 1501/3194 [11:36<11:44,  2.40it/s]

0.061657362127676606


 63%|██████▎   | 2001/3194 [15:29<08:20,  2.38it/s]

0.05112056450545788


 78%|███████▊  | 2501/3194 [19:19<05:41,  2.03it/s]

0.04574927730485797


 94%|█████████▍| 3001/3194 [23:07<01:33,  2.08it/s]

0.04551873746048659


100%|██████████| 3194/3194 [24:37<00:00,  2.16it/s]
100%|██████████| 464/464 [00:53<00:00,  8.71it/s]
  0%|          | 0/3194 [00:00<?, ?it/s]

tensor(0.6917) tensor(0.7556) tensor(0.7223)


 16%|█▌        | 501/3194 [03:55<21:30,  2.09it/s]

0.03606189742870629


 31%|███▏      | 1001/3194 [07:49<17:17,  2.11it/s]

0.03181121840281412


 47%|████▋     | 1501/3194 [11:37<11:41,  2.41it/s]

0.049799949171952906


 63%|██████▎   | 2001/3194 [15:24<08:31,  2.33it/s]

0.030978964817244558


 78%|███████▊  | 2501/3194 [19:16<05:31,  2.09it/s]

0.03074457576707937


 94%|█████████▍| 3001/3194 [23:06<01:18,  2.47it/s]

0.03055802345322445


100%|██████████| 3194/3194 [24:33<00:00,  2.17it/s]
100%|██████████| 464/464 [00:53<00:00,  8.68it/s]
  0%|          | 0/3194 [00:00<?, ?it/s]

tensor(0.7099) tensor(0.8597) tensor(0.7777)


 16%|█▌        | 501/3194 [03:56<20:30,  2.19it/s]

0.022608816804597153


 31%|███▏      | 1001/3194 [07:45<15:47,  2.31it/s]

0.02286664771893993


 47%|████▋     | 1501/3194 [11:37<13:31,  2.09it/s]

0.023191756658721717


 63%|██████▎   | 2001/3194 [15:29<09:48,  2.03it/s]

0.024187730440113228


 78%|███████▊  | 2501/3194 [19:17<05:07,  2.26it/s]

0.02328848049606313


 94%|█████████▍| 3001/3194 [23:07<01:27,  2.20it/s]

0.022608148890780284


100%|██████████| 3194/3194 [24:37<00:00,  2.16it/s]
100%|██████████| 464/464 [00:52<00:00,  8.79it/s]
  0%|          | 0/3194 [00:00<?, ?it/s]

tensor(0.7581) tensor(0.8365) tensor(0.7954)


 16%|█▌        | 501/3194 [03:52<26:19,  1.71it/s]

0.01784350582587649


 31%|███▏      | 1001/3194 [07:42<15:39,  2.33it/s]

0.017986737339699174


 47%|████▋     | 1501/3194 [11:33<12:31,  2.25it/s]

0.018268012501648626


 63%|██████▎   | 2001/3194 [15:23<08:29,  2.34it/s]

0.01976233862398658


 78%|███████▊  | 2501/3194 [19:12<05:11,  2.22it/s]

0.01847411076858407


 94%|█████████▍| 3001/3194 [23:02<01:26,  2.22it/s]

0.017393200945865828


100%|██████████| 3194/3194 [24:32<00:00,  2.17it/s]
100%|██████████| 464/464 [00:53<00:00,  8.67it/s]
  0%|          | 0/3194 [00:00<?, ?it/s]

tensor(0.8142) tensor(0.8091) tensor(0.8116)


 16%|█▌        | 501/3194 [03:49<18:06,  2.48it/s]

0.016011687345482643


 31%|███▏      | 1001/3194 [07:40<16:25,  2.22it/s]

0.015342441417000374


 47%|████▋     | 1501/3194 [11:30<11:57,  2.36it/s]

0.014979549706156831


 63%|██████▎   | 2001/3194 [15:24<08:59,  2.21it/s]

0.016763105851685396


 78%|███████▊  | 2501/3194 [19:16<05:59,  1.93it/s]

0.015364996974851237


 94%|█████████▍| 3001/3194 [23:06<01:38,  1.96it/s]

0.013775797555004828


100%|██████████| 3194/3194 [24:35<00:00,  2.16it/s]
100%|██████████| 464/464 [00:53<00:00,  8.70it/s]
  0%|          | 0/3194 [00:00<?, ?it/s]

tensor(0.7815) tensor(0.8451) tensor(0.8121)


 16%|█▌        | 501/3194 [03:52<20:09,  2.23it/s]

0.011875225371244595


 31%|███▏      | 1001/3194 [07:47<19:43,  1.85it/s]

0.011067986467969604


 47%|████▋     | 1501/3194 [11:33<13:05,  2.15it/s]

0.011468220981754712


 63%|██████▎   | 2001/3194 [15:28<09:37,  2.06it/s]

0.012283691499389533


 78%|███████▊  | 2501/3194 [19:20<06:53,  1.68it/s]

0.009944946995739883


 94%|█████████▍| 3001/3194 [23:09<01:22,  2.34it/s]

0.012394545834191376


100%|██████████| 3194/3194 [24:39<00:00,  2.16it/s]
100%|██████████| 464/464 [00:53<00:00,  8.61it/s]
  0%|          | 0/3194 [00:00<?, ?it/s]

tensor(0.8277) tensor(0.8280) tensor(0.8279)


 16%|█▌        | 501/3194 [03:49<21:15,  2.11it/s]

0.00912054906013509


 31%|███▏      | 1001/3194 [07:44<17:36,  2.08it/s]

0.008904309441681107


 47%|████▋     | 1501/3194 [11:34<13:26,  2.10it/s]

0.008617623055411968


 63%|██████▎   | 2001/3194 [15:26<08:15,  2.41it/s]

0.00841231902621803


 78%|███████▊  | 2501/3194 [19:18<05:28,  2.11it/s]

0.007972502143325982


 94%|█████████▍| 3001/3194 [23:07<01:34,  2.05it/s]

0.008814028137698188


100%|██████████| 3194/3194 [24:34<00:00,  2.17it/s]
100%|██████████| 464/464 [00:53<00:00,  8.71it/s]
  0%|          | 0/3194 [00:00<?, ?it/s]

tensor(0.8429) tensor(0.8366) tensor(0.8397)


 16%|█▌        | 501/3194 [03:53<24:44,  1.81it/s]

0.007156675949554483


 31%|███▏      | 1001/3194 [07:46<17:37,  2.07it/s]

0.007781360217326437


 47%|████▋     | 1501/3194 [11:37<11:40,  2.42it/s]

0.0072049033869261625


 63%|██████▎   | 2001/3194 [15:28<08:21,  2.38it/s]

0.007768534375463787


 78%|███████▊  | 2501/3194 [19:14<05:50,  1.98it/s]

0.006468321072254184


 94%|█████████▍| 3001/3194 [23:04<01:28,  2.19it/s]

0.0066971955895060095


100%|██████████| 3194/3194 [24:32<00:00,  2.17it/s]
100%|██████████| 464/464 [00:53<00:00,  8.67it/s]
  0%|          | 0/3194 [00:00<?, ?it/s]

tensor(0.8502) tensor(0.8136) tensor(0.8315)


 16%|█▌        | 501/3194 [03:57<21:00,  2.14it/s]

0.005555701369274175


 31%|███▏      | 1001/3194 [07:43<15:20,  2.38it/s]

0.005587242028788751


 47%|████▋     | 1501/3194 [11:31<14:28,  1.95it/s]

0.005305168127291836


 63%|██████▎   | 2001/3194 [15:26<09:24,  2.11it/s]

0.0060723101334697275


 78%|███████▊  | 2501/3194 [19:21<06:06,  1.89it/s]

0.0062871849885050325


 94%|█████████▍| 3001/3194 [23:10<01:26,  2.23it/s]

0.006844878789943323


100%|██████████| 3194/3194 [24:39<00:00,  2.16it/s]
100%|██████████| 464/464 [00:53<00:00,  8.67it/s]
  0%|          | 0/3194 [00:00<?, ?it/s]

tensor(0.8521) tensor(0.8348) tensor(0.8434)


 16%|█▌        | 501/3194 [03:56<21:50,  2.06it/s]

0.004093723350764776


 31%|███▏      | 1001/3194 [07:47<16:15,  2.25it/s]

0.004911231752288586


 47%|████▋     | 1501/3194 [11:37<13:15,  2.13it/s]

0.004951428298409155


 63%|██████▎   | 2001/3194 [15:27<10:02,  1.98it/s]

0.005247203349957999


 78%|███████▊  | 2501/3194 [19:14<05:12,  2.22it/s]

0.004656592710765835


 94%|█████████▍| 3001/3194 [23:04<01:33,  2.07it/s]

0.005232201359729515


100%|██████████| 3194/3194 [24:32<00:00,  2.17it/s]
100%|██████████| 464/464 [00:53<00:00,  8.68it/s]
  0%|          | 0/3194 [00:00<?, ?it/s]

tensor(0.8455) tensor(0.8324) tensor(0.8389)


 16%|█▌        | 501/3194 [03:52<20:03,  2.24it/s]

0.004495534385561768


 31%|███▏      | 1001/3194 [07:41<15:06,  2.42it/s]

0.004730289321611053


 47%|████▋     | 1501/3194 [11:33<11:54,  2.37it/s]

0.003926033476853263


 63%|██████▎   | 2001/3194 [15:26<09:29,  2.09it/s]

0.003665020250267844


 78%|███████▊  | 2501/3194 [19:14<05:01,  2.30it/s]

0.0034495801364973887


 94%|█████████▍| 3001/3194 [23:07<01:22,  2.33it/s]

0.0031303183382024144


100%|██████████| 3194/3194 [24:35<00:00,  2.16it/s]
100%|██████████| 464/464 [00:53<00:00,  8.72it/s]
  0%|          | 0/3194 [00:00<?, ?it/s]

tensor(0.8521) tensor(0.8292) tensor(0.8405)


 16%|█▌        | 501/3194 [03:56<23:11,  1.93it/s]  

0.002992035857952942


 31%|███▏      | 1001/3194 [07:49<18:11,  2.01it/s]

0.0027949979454024287


 47%|████▋     | 1501/3194 [11:42<13:28,  2.09it/s]

0.0033511064438807806


 52%|█████▏    | 1645/3194 [12:48<11:51,  2.18it/s]