# Техническая часть

In [0]:
# !git clone https://github.com/DanilDmitriev1999/degree-project

In [0]:
!pip install pytorch-lightning
!pip install transformers

In [0]:
import pytorch_lightning as pl

from transformers import AdamW, BertModel
from transformers import BertTokenizer
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler, TensorDataset

#import fire

import sys
from collections import defaultdict

import warnings
warnings.filterwarnings('ignore')

# Model

In [0]:
def split_tag(chunk_tag):
    """
    Split a chunk tag into its IOBES prefix and its chunk type.

    e.g.
    B-PER -> ('B', 'PER')
    O     -> ('O', None)

    Only the first hyphen separates prefix from type, so a type that itself
    contains hyphens is kept whole.  A tag without any hyphen is returned as
    (tag, None).

    The result is always a 2-tuple, so callers can unpack it unconditionally.
    """
    prefix, hyphen, chunk_type = chunk_tag.partition('-')
    if not hyphen:
        # No separator at all: 'O' or any other bare tag has no chunk type.
        return (prefix, None)
    return (prefix, chunk_type)

def is_chunk_end(prev_tag, tag):
    """
    Tell whether the chunk that `prev_tag` belongs to is closed before `tag`.

    e.g.
    (B-PER, I-PER) -> False
    (B-LOC, O)     -> True

    Contradicting neighbours such as (B-PER, I-LOC) are read as
    (B-PER, B-LOC), i.e. a change of chunk type always closes the chunk.
    """
    prev_prefix, prev_type = split_tag(prev_tag)
    cur_prefix, cur_type = split_tag(tag)

    # Nothing was open, so nothing can end.
    if prev_prefix == 'O':
        return False

    # A chunk was open: stepping outside or switching type closes it.
    if cur_prefix == 'O' or prev_type != cur_type:
        return True

    # Same type on both sides: only explicit boundary prefixes close it.
    return cur_prefix in ('B', 'S') or prev_prefix in ('E', 'S')

def is_chunk_start(prev_tag, tag):
    """
    Tell whether `tag` opens a new chunk, given the tag of the previous word.

    A change of chunk type between the two tags always counts as a new chunk.
    """
    prev_prefix, prev_type = split_tag(prev_tag)
    cur_prefix, cur_type = split_tag(tag)

    # The current word is outside any chunk, so nothing starts here.
    if cur_prefix == 'O':
        return False

    # Entering a chunk from outside, or switching type, opens a new chunk.
    if prev_prefix == 'O' or prev_type != cur_type:
        return True

    # Same type on both sides: only explicit boundary prefixes open a chunk.
    return cur_prefix in ('B', 'S') or prev_prefix in ('E', 'S')

In [0]:
def calc_metrics(tp, p, t, percent=True):
    """
    Derive precision, recall and FB1 from raw counts.

    tp: number of correct predictions
    p:  number of predictions
    t:  number of gold items
    Any ratio with a zero denominator is reported as 0.
    With percent=True every score is multiplied by 100.
    """
    precision = 0 if not p else tp / p
    recall = 0 if not t else tp / t

    if precision + recall:
        fb1 = 2 * precision * recall / (precision + recall)
    else:
        fb1 = 0

    scores = (precision, recall, fb1)
    if not percent:
        return scores
    return tuple(100 * score for score in scores)

In [0]:
def count_chunks(true_seqs, pred_seqs):
    """
    Walk the gold and predicted tag sequences in parallel and count chunks
    and tags.

    true_seqs: a list of true tags
    pred_seqs: a list of predicted tags
    (the two are zipped, so iteration stops at the shorter one)

    return:
    correct_chunks: a dict (counter),
                    key = chunk types,
                    value = number of correctly identified chunks per type
    true_chunks:    a dict, number of true chunks per type
    pred_chunks:    a dict, number of identified chunks per type
    correct_counts, true_counts, pred_counts: similar to above, but for tags
    """
    correct_chunks = defaultdict(int)
    true_chunks = defaultdict(int)
    pred_chunks = defaultdict(int)

    correct_counts = defaultdict(int)
    true_counts = defaultdict(int)
    pred_counts = defaultdict(int)

    # Both sequences are treated as if preceded by an 'O' tag.
    prev_true_tag, prev_pred_tag = 'O', 'O'
    # Type of a chunk that started at the same position with the same type in
    # both sequences and has not diverged yet; None when no such chunk is open.
    correct_chunk = None

    for true_tag, pred_tag in zip(true_seqs, pred_seqs):
        # Tag-level (per-token) statistics.
        if true_tag == pred_tag:
            correct_counts[true_tag] += 1
        true_counts[true_tag] += 1
        pred_counts[pred_tag] += 1

        _, true_type = split_tag(true_tag)
        _, pred_type = split_tag(pred_tag)

        # First settle the chunk that was open before this position.
        if correct_chunk is not None:
            true_end = is_chunk_end(prev_true_tag, true_tag)
            pred_end = is_chunk_end(prev_pred_tag, pred_tag)

            if pred_end and true_end:
                # Both chunks close at the same boundary: a full match.
                correct_chunks[correct_chunk] += 1
                correct_chunk = None
            elif pred_end != true_end or true_type != pred_type:
                # Only one side closed, or the types diverged: not a match.
                correct_chunk = None

        # Then look at chunks opening at this position.
        true_start = is_chunk_start(prev_true_tag, true_tag)
        pred_start = is_chunk_start(prev_pred_tag, pred_tag)

        if true_start and pred_start and true_type == pred_type:
            correct_chunk = true_type
        if true_start:
            true_chunks[true_type] += 1
        if pred_start:
            pred_chunks[pred_type] += 1

        prev_true_tag, prev_pred_tag = true_tag, pred_tag
    # A matching chunk still open when the input runs out counts as correct.
    if correct_chunk is not None:
        correct_chunks[correct_chunk] += 1

    return (correct_chunks, true_chunks, pred_chunks, 
        correct_counts, true_counts, pred_counts)

In [0]:
def get_result(correct_chunks, true_chunks, pred_chunks,
    correct_counts, true_counts, pred_counts, verbose=True):
    """
    Turn the counters produced by count_chunks into scores.

    correct_chunks, true_chunks, pred_chunks: chunk counts per chunk type
    correct_counts, true_counts, pred_counts: tag counts per tag
        (pred_counts is accepted for symmetry but not used)
    verbose: if True, print overall performance as well as performance per
        chunk type; otherwise only compute the overall scores

    return: (precision, recall, f1) over all chunks, as percentages

    Accuracies are reported as 0 when there is nothing to divide by
    (no tokens at all, or no non-O gold tokens) instead of raising
    ZeroDivisionError.
    """
    # sum counts
    sum_correct_chunks = sum(correct_chunks.values())
    sum_true_chunks = sum(true_chunks.values())
    sum_pred_chunks = sum(pred_chunks.values())

    sum_correct_counts = sum(correct_counts.values())
    sum_true_counts = sum(true_counts.values())

    nonO_correct_counts = sum(v for k, v in correct_counts.items() if k != 'O')
    nonO_true_counts = sum(v for k, v in true_counts.items() if k != 'O')

    chunk_types = sorted(set(true_chunks) | set(pred_chunks))

    # compute overall precision, recall and FB1 (default values are 0.0)
    prec, rec, f1 = calc_metrics(sum_correct_chunks, sum_pred_chunks, sum_true_chunks)
    res = (prec, rec, f1)
    if not verbose:
        return res

    # Guard the token-level accuracies: a sequence made only of 'O' tags
    # (or an empty one) has a zero denominator.
    nonO_accuracy = 100*nonO_correct_counts/nonO_true_counts if nonO_true_counts else 0
    accuracy = 100*sum_correct_counts/sum_true_counts if sum_true_counts else 0

    # print overall performance, and performance per chunk type
    print("processed %i tokens with %i phrases; " % (sum_true_counts, sum_true_chunks), end='')
    print("found: %i phrases; correct: %i.\n" % (sum_pred_chunks, sum_correct_chunks), end='')

    print("accuracy: %6.2f%%; (non-O)" % nonO_accuracy)
    print("accuracy: %6.2f%%; " % accuracy, end='')
    print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % (prec, rec, f1))

    # for each chunk type, compute precision, recall and FB1 (default values are 0.0)
    for t in chunk_types:
        # .get() keeps this working for plain dicts and avoids inserting
        # missing keys into the caller's defaultdicts.
        n_correct = correct_chunks.get(t, 0)
        n_pred = pred_chunks.get(t, 0)
        n_true = true_chunks.get(t, 0)
        prec, rec, f1 = calc_metrics(n_correct, n_pred, n_true)
        print("%17s: " %t , end='')
        print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" %
                    (prec, rec, f1), end='')
        print("  %d" % n_pred)

    return res

In [0]:
def evaluate(true_seqs, pred_seqs, verbose=True):
    """
    Score predicted tags against gold tags at chunk level.

    true_seqs: flat list of gold tags
    pred_seqs: flat list of predicted tags
    verbose:   forwarded to get_result (controls printing of the report)
    return:    whatever get_result returns (overall precision, recall, f1)
    """
    counters = count_chunks(true_seqs, pred_seqs)
    return get_result(*counters, verbose=verbose)

In [0]:
class BertNER(pl.LightningModule):
    """Token-level tagger: multilingual BERT encoder + dropout + linear head.

    Parameters
    ----------
    num_labels: number of classes the linear head predicts.
    lr: learning rate passed to the Adam optimizer.
    train_dataloader, val_dataloader: loaders returned by the
        train_dataloader() / val_dataloader() hooks.
    ids_to_labels, labels_to_ids: mappings between label ids and label
        strings; labels_to_ids must contain the key "X".

    NOTE(review): `validation_end` and `@pl.data_loader` look like hooks of an
    older pytorch-lightning release -- confirm against the installed version.
    """
    def __init__(self, num_labels, lr, train_dataloader, val_dataloader, ids_to_labels, labels_to_ids):
        super().__init__()
        self.num_labels = num_labels
        self.lr = lr

        self.model = BertModel.from_pretrained("bert-base-multilingual-cased")
        # Take the input width of the head from the encoder config rather than
        # hard-coding 768, so a different checkpoint size keeps working.
        self.classifier = torch.nn.Linear(self.model.config.hidden_size, num_labels)
        self.dropout = torch.nn.Dropout(p=0.1)

        # Stored under other names: train_dataloader / val_dataloader are
        # methods of this class and must not be overwritten.
        self.traindl, self.valdl = train_dataloader, val_dataloader
        self.ids_to_labels = ids_to_labels
        self.labels_to_ids = labels_to_ids

    def f1(self, y_true, y_pred):
        """Print and return chunk-level (precision, recall, f1).

        y_true, y_pred: lists of per-sentence lists of label ids.
        Positions whose gold label is "X" are dropped from both sequences
        before scoring.
        """
        y_true = [self.ids_to_labels[l] for sentence in y_true for l in sentence]
        y_pred = [self.ids_to_labels[l] for sentence in y_pred for l in sentence]
        assert len(y_pred) == len(y_true)

        ids = [i for i, label in enumerate(y_true) if label != "X"]
        y_true_cleaned = [y_true[i] for i in ids]
        y_pred_cleaned = [y_pred[i] for i in ids]

        precision, recall, f1 = evaluate(y_true_cleaned, y_pred_cleaned)#, verbose=False)
        print(f"micro average precision: {precision}, recall: {recall}, f1: {f1}")
        return precision, recall, f1

    def forward(self, input_ids, attention_mask, labels=None):
        """Return (logits,) or, when labels are given, (loss, logits).

        The loss is cross-entropy over all positions whose label is not "X".
        """
        # Call the module (not .forward) so that nn.Module hooks are honoured.
        sequence_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,)
        if labels is not None:
            # reference:
            # https://github.com/huggingface/transformers/blob/d5d7d886128732091e92afff7fcb3e094c71a7ec/src/transformers/modeling_bert.py#L1380-L1394
            loss_fct = torch.nn.CrossEntropyLoss()
            # X label described in Bert Paper section 4.3
            X = self.labels_to_ids["X"]
            not_X_mask = labels != X # since label of PAD is "X", attention mask is not needed

            # Only keep active parts of the loss: every "X" position is
            # rewritten to the loss function's ignore_index.
            active_loss = not_X_mask.view(-1)
            active_logits = logits.view(-1, self.num_labels)
            active_labels = torch.where(
                active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
            )
            loss = loss_fct(active_logits, active_labels)
            outputs = (loss,) + outputs

        return outputs  # (loss), scores

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one (input_ids, mask, labels) batch."""
        loss, score = self(*batch)
        tqdm_dict = {"train_loss": loss}
        return {"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict}

    def validation_step(self, batch, batch_idx):
        """Return loss, gold labels, arg-max predictions and mask for a batch."""
        _, mask, labels = batch
        loss, score = self(*batch)
        labels_pred = torch.argmax(score, dim=-1)
        return {
            "val_loss": loss,
            "y_true": labels,
            "y_pred": labels_pred,
            "mask": mask,
        }

    def validation_end(self, outputs):
        """Average the validation loss and compute chunk-level F1.

        outputs: list of dicts returned by validation_step.
        """
        val_loss = sum([out["val_loss"] for out in outputs]) / len(outputs)

        y_true, y_pred = [], []
        for out in outputs:
            batch_y_true = out["y_true"].cpu().numpy().tolist()
            batch_y_pred = out["y_pred"].cpu().numpy().tolist()
            # The mask row sum is the unpadded length of each sequence;
            # cut the padding off before scoring.
            batch_seq_lens = out["mask"].cpu().numpy().sum(-1).tolist()
            for i, length in enumerate(batch_seq_lens):
                batch_y_true[i] = batch_y_true[i][:length]
                batch_y_pred[i] = batch_y_pred[i][:length]
            y_true += batch_y_true
            y_pred += batch_y_pred

        precision, recall, f1 = self.f1(y_true, y_pred)
        tqdm_dict = {
            "val_loss": val_loss,
            "f1": f1
        }
        result = {"progress_bar": tqdm_dict, "log": tqdm_dict, "val_loss": val_loss}
        return result

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    @pl.data_loader
    def train_dataloader(self):
        """Return the training loader given to the constructor."""
        return self.traindl

    @pl.data_loader
    def val_dataloader(self):
        """Return the validation loader given to the constructor."""
        return self.valdl

In [0]:
def convert_tokens_to_ids(tokens, pad=True):
    """Map subword strings to vocabulary ids with the module-level tokenizer.

    tokens: sequence of subword strings.
    pad: if True, return (padded_ids, mask), both 1-D long tensors of length
        max_len; padded_ids is zero-filled past the sequence and mask is 1 on
        real positions and 0 on padding.  If False, return the ids as a
        long tensor of shape (1, len(tokens)).

    Raises AssertionError when the sequence is not shorter than max_len.
    """
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    ids = torch.LongTensor([token_ids])
    seq_len = ids.size(1)
    # A real message instead of `print(...)`, which evaluates to None.
    assert seq_len < max_len, f"sequence of {seq_len} subwords does not fit max_len={max_len}"
    if not pad:
        return ids

    padded_ids = torch.zeros(max_len).long()
    # ids has a leading batch dimension of 1; copy its single row explicitly.
    padded_ids[:seq_len] = ids[0]
    mask = torch.zeros(max_len).long()
    mask[:seq_len] = 1
    return padded_ids, mask


In [0]:
def subword_tokenize(tokens, labels):
    """
    Helper function
    Segment each token into subwords while keeping track of
    token boundaries and spreading the token labels over the subwords.

    Parameters
    ----------
    tokens: A sequence of strings, representing input tokens.
    labels: A sequence of label strings, one per token.

    Returns
    -------
    A tuple consisting of:
        - A list of subwords, flanked by the special symbols required
            by Bert (CLS and SEP).
        - An array of indices into the list of subwords, indicating
          that the corresponding subword is the start of a new
            token. For example, [1, 3, 4, 7] means that the subwords
            1, 3, 4, 7 are token starts, while all other subwords
            (0, 2, 5, 6, 8...) are in or at the end of tokens.
            This list allows selecting Bert hidden states that
            represent tokens, which is necessary in sequence
            labeling.
        - A list of labels aligned with the subwords: the first subword of a
            token carries the token label, the remaining ones carry "X", and
            CLS / SEP are labelled "O".
    """
    # A token the tokenizer reduces to nothing would still receive a label
    # below and break the subword/label alignment, so substitute the
    # tokenizer's unknown token for it.
    subwords = [tokenizer.tokenize(token) or [tokenizer.unk_token] for token in tokens]
    subword_lengths = [len(pieces) for pieces in subwords]
    subwords = [CLS] + [piece for pieces in subwords for piece in pieces] + [SEP]
    # Offset 1 accounts for CLS.  Dropping the last cumulative sum (instead of
    # the last length) yields an empty array for an empty sentence rather
    # than pointing at SEP.
    token_start_idxs = 1 + np.cumsum([0] + subword_lengths)[:-1]
    # X label described in Bert Paper section 4.3
    bert_labels = ["O"]
    for sublen, label in zip(subword_lengths, labels):
        bert_labels.append(label)
        bert_labels.extend(["X"] * (sublen - 1))
    bert_labels.append("O")

    assert len(subwords) == len(bert_labels)
    assert len(subwords) <= 512
    return subwords, token_start_idxs, bert_labels

In [0]:
def subword_tokenize_to_ids(tokens, labels):
    """Segment each token into subwords while keeping track of token boundaries and convert subwords into IDs.

    Parameters
    ----------
    tokens: A sequence of strings, representing input tokens.
    labels: A sequence of label strings, one per token.

    Returns
    -------
    A dict of tensors, each of length max_len:
        - "input_ids": subword IDs, including the IDs of the special
            symbols (CLS and SEP) required by Bert, zero-padded.
        - "attention_mask": 1 on real subwords, 0 on padding.
        - "bert_token_starts": 1 at subwords that start a token, else 0.
            See doc of subword_tokenize.
        - "labels": label ids per subword; padding carries the id of "X".
    """
    assert len(tokens) == len(labels)
    subwords, token_start_idxs, bert_labels = subword_tokenize(tokens, labels)
    subword_ids, mask = convert_tokens_to_ids(subwords)
    token_starts = torch.zeros(max_len)
    token_starts[token_start_idxs] = 1
    bert_labels = [labels_to_ids[label] for label in bert_labels]
    # X label described in Bert Paper section 4.3 is used for padding
    padded_bert_labels = torch.full((max_len,), labels_to_ids["X"], dtype=torch.long)
    padded_bert_labels[:len(bert_labels)] = torch.LongTensor(bert_labels)

    # The former `mask.require_grad = False` only set a misspelled attribute
    # that nothing reads, so it was dropped.
    return {
        "input_ids": subword_ids,
        "attention_mask": mask,
        "bert_token_starts": token_starts,
        "labels": padded_bert_labels
    }


# train

In [13]:
# Tokenizer loaded from the same checkpoint name as the BertModel in BertNER;
# read as a module-level global by the tokenization helpers.
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
# Special-token strings. CLS and SEP are wrapped around every sentence by
# subword_tokenize; MASK is not referenced in the visible part of the notebook.
SEP = "[SEP]"
MASK = "[MASK]"
CLS = "[CLS]"
# Fixed length every featurized sentence is padded to; convert_tokens_to_ids
# asserts that a sentence (with CLS/SEP) is strictly shorter than this.
max_len = 402

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=995526.0, style=ProgressStyle(descripti…




In [0]:
# Tag inventory; "X" marks non-initial subwords and padding and comes last so
# that the real classes occupy ids 0 .. num_labels - 1.
LABELS = ['O', 'B-Person', 'I-Person', 'B-Org', 'I-Org', 'B-Loc', 'I-Loc', "X"]
ids_to_labels = dict(enumerate(LABELS))
labels_to_ids = {label: idx for idx, label in ids_to_labels.items()}
num_labels = len(LABELS) - 1 # model can't output "X"

In [0]:
class CoNLL(Dataset):
    """Simple class to read a raw CoNLL-style dataset.

    The file holds one word per line with whitespace-separated columns; the
    first column is the word and the last column is its tag.  Sentences are
    separated by blank lines.  Item i is the pair
    (list of words, list of normalized tags) of sentence i.
    """
    # Tags rewritten on load; every other tag is kept unchanged.
    # 'B-Facility' is dropped together with 'I-Facility' so that a Facility
    # chunk never leaves a dangling begin tag behind.
    _TAG_ALIASES = {
        'B-Location': 'B-Loc',
        'I-Location': 'I-Loc',
        'B-LocOrg': 'B-Loc',
        'I-LocOrg': 'I-Loc',
        'B-Facility': 'O',
        'I-Facility': 'O',
    }

    def __init__(self, path="degree-project/data"):
        # Read through a context manager so the handle is closed, and decode
        # explicitly as UTF-8 rather than with the platform default.
        with open(path, "r", encoding="utf-8") as handle:
            entries = handle.read().strip().split("\n\n")

        self.sentence, self.label = [], []  # list of lists
        for entry in entries:
            # Skip whitespace-only lines, which have no columns to index.
            rows = [line.split() for line in entry.splitlines() if line.strip()]
            if not rows:
                continue
            self.sentence.append([row[0] for row in rows])
            self.label.append([self._check(row[-1]) for row in rows])

    def __len__(self):
        return len(self.sentence)

    @staticmethod
    def _check(tag):
        """Map a raw tag onto the tag set used by the model."""
        return CoNLL._TAG_ALIASES.get(tag, tag)

    def __getitem__(self, i):
        return self.sentence[i], self.label[i]

In [0]:
def prepare_dataset(path='degree-project/data/dev.txt'):
    """Read a CoNLL file and featurize it into a TensorDataset.

    Every sentence is converted with subword_tokenize_to_ids; the per-sentence
    tensors are then stacked along a new first dimension.

    Returns a TensorDataset whose items are
    (input_ids, attention_mask, labels).
    """
    corpus = CoNLL(path)
    examples = [subword_tokenize_to_ids(tokens, labels) for tokens, labels in corpus]

    # All four feature fields are stacked, although only three are returned.
    field_names = ("input_ids", "attention_mask", "bert_token_starts", "labels")
    stacked = {}
    for name in field_names:
        stacked[name] = torch.stack([example[name] for example in examples], dim=0)

    return TensorDataset(stacked["input_ids"], stacked["attention_mask"], stacked["labels"])

In [0]:
# Featurize both splits up front: prepare_dataset reads the whole file and
# returns padded tensors for every sentence.
val_dataset = prepare_dataset('degree-project/data/dev.txt')
train_dataset = prepare_dataset('degree-project/data/train.txt')
batch_size=16

# Training batches are drawn through a RandomSampler; the validation loader
# is given no sampler.
sampler = RandomSampler(train_dataset)
train_dataloader= DataLoader(train_dataset, sampler=sampler, batch_size=batch_size, pin_memory=True)
val_dataloader= DataLoader(val_dataset, batch_size=batch_size, pin_memory=True)

In [0]:
# Hyper-parameters of the run.
batch_size=16
lr=5e-5
epoch=6

# Hardware detection.
N_GPUs = torch.cuda.device_count()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = BertNER(
    num_labels,
    lr,
    train_dataloader,
    val_dataloader,
    ids_to_labels,
    labels_to_ids,
).to(device)

# Without a GPU only a quick smoke run is done (fast_dev_run); with more than
# one GPU the "dp" backend is selected.
trainer = pl.Trainer(
        fast_dev_run=not N_GPUs > 0,
        gpus=N_GPUs,
        distributed_backend="dp" if N_GPUs > 1 else None,
        max_epochs=epoch,
        #overfit_pct=0.01
)
trainer.fit(model)