## Downloading and Preprocessing the Data

In [1]:
from collections import defaultdict
from urllib import request
import json
import pandas as pd

In [2]:
def parse_conllu_using_pandas(block):
    records = []
    for line in block.splitlines():
        if not line.startswith('#'):
            records.append(line.strip().split('\t'))
    return pd.DataFrame.from_records(
        records,
        columns=['ID', 'FORM', 'TAG', 'Misc1', 'Misc2'])

In [3]:
def tokens_to_labels(df):
    return (
        df.FORM.tolist(),
        df.TAG.tolist()
    )

In [4]:
def simplified_tokens_to_labels(df):
    simplified_form_list = []
    simplified_tag_list = []
    simplified_form_list, tag_list = tokens_to_labels(df)
    for tag in tag_list:
        if tag.startswith('B-'):
            simplified_tag_list.append('B')
        elif tag.startswith('I-'):
            simplified_tag_list.append('I')
        elif tag == 'O':
            simplified_tag_list.append('O')
        else:
            raise ValueError('Unexpected Label')
    return (
        simplified_form_list,
        simplified_tag_list
    )

In [5]:
PREFIX = "https://raw.githubusercontent.com/UniversalNER/"
DATA_URLS = {
    "en_ewt": {
        "train": "UNER_English-EWT/master/en_ewt-ud-train.iob2",
        "dev": "UNER_English-EWT/master/en_ewt-ud-dev.iob2",
        "test": "UNER_English-EWT/master/en_ewt-ud-test.iob2"
    },
    "en_pud": {
        "test": "UNER_English-PUD/master/en_pud-ud-test.iob2"
    }
}

In [6]:
# en_ewt is the main train-dev-test split
# en_pud is the OOD test set
data_dict = defaultdict(dict)
for corpus, split_dict in DATA_URLS.items():
    for split, url_suffix in split_dict.items():
        url = PREFIX + url_suffix
        with request.urlopen(url) as response:
            txt = response.read().decode('utf-8')
            data_frames = map(parse_conllu_using_pandas,
                              txt.strip().split('\n\n'))
            token_label_alignments = list(map(tokens_to_labels,
                                              data_frames))
            data_dict[corpus][split] = token_label_alignments

In [7]:
simplified_data_dict = defaultdict(dict)
for corpus, split_dict in DATA_URLS.items():
    for split, url_suffix in split_dict.items():
        url = PREFIX + url_suffix
        with request.urlopen(url) as response:
            txt = response.read().decode('utf-8')
            data_frames = map(parse_conllu_using_pandas,
                              txt.strip().split('\n\n'))
            simplified_token_label_alignments = list(map(simplified_tokens_to_labels,
                                              data_frames))
            simplified_data_dict[corpus][split] = simplified_token_label_alignments

In [8]:
# Save the data
with open('ner_data_dict.json', 'w', encoding='utf-8') as out:
    json.dump(data_dict, out, indent=2, ensure_ascii=False)

In [9]:
with open('ner_simplified_data_dict.json', 'w', encoding = 'utf-8') as out:
    json.dump(simplified_data_dict, out, indent = 2, ensure_ascii = False)

In [193]:
# Each subset of each corpus is a list of tuples where each tuple
# is a list of tokens with a corresponding list of labels.

# Train on data_dict['en_ewt']['train']; validate on data_dict['en_ewt']['dev']
# and test on data_dict['en_ewt']['test'] and data_dict['en_pud']['test']
data_dict['en_ewt']['train'][0], data_dict['en_pud']['test'][1]

([['Where', 'in', 'the', 'world', 'is', 'Iguazu', '?'],
  ['O', 'O', 'O', 'O', 'O', 'B-LOC', 'O']],
 [['For',
   'those',
   'who',
   'follow',
   'social',
   'media',
   'transitions',
   'on',
   'Capitol',
   'Hill',
   ',',
   'this',
   'will',
   'be',
   'a',
   'little',
   'different',
   '.'],
  ['O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'B-LOC',
   'I-LOC',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O']])

In [194]:
simplified_data_dict['en_ewt']['train'][0], simplified_data_dict['en_pud']['test'][1]

([['Where', 'in', 'the', 'world', 'is', 'Iguazu', '?'],
  ['O', 'O', 'O', 'O', 'O', 'B', 'O']],
 [['For',
   'those',
   'who',
   'follow',
   'social',
   'media',
   'transitions',
   'on',
   'Capitol',
   'Hill',
   ',',
   'this',
   'will',
   'be',
   'a',
   'little',
   'different',
   '.'],
  ['O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'B',
   'I',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O']])

## Using BERT for Named Entity Recognition (NER)

In [12]:
from random import shuffle
from math import ceil

import torch
import torch.nn as nn

from transformers import AutoModel, AutoTokenizer
import datasets

from tqdm.auto import tqdm

We first fine-tune and evaluate the NER performance on original tagsets ('B-LOC', 'B-PER', 'B-ORG', 'I-LOC', 'I-PER', 'I-ORG', 'O')

In [195]:
class ClassificationHead(nn.Module):
    def __init__(self, model_dim = 768, n_classes = 7):
        super().__init__()
        self.linear = nn.Linear(model_dim, n_classes)

    def forward(self, x):
        return self.linear(x)

In [98]:
#A finetuned version (attempted but receive constant zero in the evaluation scores), so the single layer classification head is used in the experiment.
class ClassificationHead(nn.Module):
    def __init__(self, model_dim = 768, n_classes = 7):
        super().__init__()
        self.linear1 = nn.Linear(model_dim, 256)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(256, 128)
        self.linear3 = nn.Linear(128, n_classes)

    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        return self.linear3(x)

In [196]:
labels = set()
for i in range(len(data_dict['en_ewt']['train'])):
    labels.update(data_dict['en_ewt']['train'][i][1])
n_classes = len(labels)
sorted(labels)

['B-LOC', 'B-ORG', 'B-PER', 'I-LOC', 'I-ORG', 'I-PER', 'O']

In [197]:
label_to_i = {
    label: i
    for i, label in enumerate(sorted(labels))
}
i_to_label = {
    i: label
    for label, i in label_to_i.items()
}

In [198]:
label_to_i

{'B-LOC': 0,
 'B-ORG': 1,
 'B-PER': 2,
 'I-LOC': 3,
 'I-ORG': 4,
 'I-PER': 5,
 'O': 6}

In [199]:
model_tag = 'google-bert/bert-base-uncased'

tokeniser = AutoTokenizer.from_pretrained(model_tag)

In [200]:
def process_sentence(sentence, label_to_i, tokeniser, encoder, clf_head,
                     encoder_device, clf_head_device):
    words, labels = sentence
    gold_labels = torch.tensor(
        [label_to_i[label] for label in labels]).to(clf_head_device)
    tokenisation = tokeniser(words, is_split_into_words=True,
                             return_tensors='pt')
    inputs = {k: v.to(encoder_device) for k, v in tokenisation.items()}

    # Remove the embedding of the CLS token and the SEP token.
    outputs = encoder(**inputs).last_hidden_state[0, 1:-1, :]

    # Take embeddings of the first/last subword and ignore the CLS and the SEP tokens
    word_ids = tokenisation.word_ids()[1:-1]
    processed_words = set()
    first_subword_embeddings = []

    for i, word_id in enumerate(word_ids):
        if word_id not in processed_words:
            first_subword_embeddings.append(outputs[i])
            processed_words.add(word_id)

    # Check the alignment of words and labels.
    assert len(first_subword_embeddings) == gold_labels.size(0)

    # Combine subword embeddings into a tensor and copy to the device of the classification head
    clf_head_inputs = torch.vstack(
        first_subword_embeddings).to(clf_head_device)

    # Return the logits and gold labels for subsequent processing
    return clf_head(clf_head_inputs), gold_labels

In [201]:
def train_epoch(data, label_to_i, tokeniser, encoder, clf_head,
                encoder_device, clf_head_device, loss_fn, optimiser):
    encoder.train()
    epoch_losses = torch.empty(len(data))
    for step_n, sentence in tqdm(
        enumerate(data),
        total=len(data),
        desc='Train',
        leave=False
    ):
        optimiser.zero_grad()
        logits, gold_labels = process_sentence(
            sentence, label_to_i, tokeniser,
            encoder, clf_head, encoder_device,
            clf_head_device)
        loss = loss_fn(logits, gold_labels)
        loss.backward()
        optimiser.step()
        epoch_losses[step_n] = loss.item()
    return epoch_losses.mean().item()

In [202]:
def extract_spans(labels):
    spans = []
    start = None # Start index of the span
    current_label = None

    # The span is defined by the boundary and the type of a named entity.
    # We extract the span in the format of (start_index, end_index, type)
    for i, tag in enumerate(labels):
        if tag.startswith('B-'):
            # Close previous span if open
            if start is not None:
                spans.append((start, i-1, current_label))
            start = i
            # Extract the type following the boundary
            current_label = tag[2:]
        elif tag.startswith('I-') and current_label == tag[2:]:
            # Continue current span if the type matches
            continue
        else:
            # Close previous span if open
            if start is not None:
                spans.append((start, i-1, current_label))
                start = None
                current_label = None

    # Close any open span at end
    if start is not None:
        spans.append((start, len(labels)-1, current_label))

    return spans

In [203]:
# Calculate the precision, recall, and f1 score using true positives, false positives, and false negatives
def scores(tp, fp, fn):
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return precision, recall, f1_score


def span_matches(pred_span, gold_span, labelled=True):
    # Check both the boundary and type for labelled span matching scores
    if labelled:
        return pred_span == gold_span
    else:
        return pred_span[:2] == gold_span[:2]  # Check the boundary ('B', 'I', 'O') only


def validate_epoch(data, label_to_i, tokeniser, encoder, clf_head,
                   encoder_device, clf_head_device):
    encoder.eval()
    clf_head.eval()

    labelled_tp = 0
    labelled_fp = 0
    labelled_fn = 0

    unlabelled_tp = 0
    unlabelled_fp = 0
    unlabelled_fn = 0

    for step_n, sentence in tqdm(enumerate(data), total=len(data), desc='Eval', leave=False):
        with torch.no_grad():
            logits, gold_label_ids = process_sentence(
                sentence, label_to_i, tokeniser,
                encoder, clf_head, encoder_device,
                clf_head_device)

        predicted_label_ids = torch.argmax(logits, dim=-1).cpu().tolist()
        gold_label_ids = gold_label_ids.cpu().tolist()

        predicted_labels = [i_to_label[i] for i in predicted_label_ids]
        gold_labels = [i_to_label[i] for i in gold_label_ids]

        pred_spans = extract_spans(predicted_labels)
        gold_spans = extract_spans(gold_labels)

        # Labelled span metrics
        for pred_span in pred_spans:
            if any(span_matches(pred_span, gs, labelled=True) for gs in gold_spans):
                labelled_tp += 1
            else:
                labelled_fp += 1

        for gold_span in gold_spans:
            if not any(span_matches(gs, gold_span, labelled=True) for gs in pred_spans):
                labelled_fn += 1

        # Unlabelled span metrics
        for pred_span in pred_spans:
            if any(span_matches(pred_span, gs, labelled=False) for gs in gold_spans):
                unlabelled_tp += 1
            else:
                unlabelled_fp += 1

        for gold_span in gold_spans:
            if not any(span_matches(gs, gold_span, labelled=False) for gs in pred_spans):
                unlabelled_fn += 1

    labelled_precision, labelled_recall, labelled_f1 = scores(
        labelled_tp, labelled_fp, labelled_fn)

    unlabelled_precision, unlabelled_recall, unlabelled_f1 = scores(
        unlabelled_tp, unlabelled_fp, unlabelled_fn)

    return {
        "labelled span matching score": {
            "precision": labelled_precision,
            "recall": labelled_recall,
            "f1_score": labelled_f1,
        },
        "unlabelled span matching score": {
            "precision": unlabelled_precision,
            "recall": unlabelled_recall,
            "f1_score": unlabelled_f1,
        }
    }

In [204]:
encoder_device = 0
encoder = AutoModel.from_pretrained(
    model_tag).to(encoder_device)
clf_head = ClassificationHead(n_classes= n_classes)
clf_head_device = 0
clf_head.to(clf_head_device);

In [120]:
# baseline
n_epochs = 5
optimiser = torch.optim.AdamW(
    list(encoder.parameters()) + list(clf_head.parameters()), lr=10**(-5))
for epoch_n in tqdm(range(n_epochs)):
    results = validate_epoch(data_dict['en_ewt']['dev'][:250], label_to_i, tokeniser, encoder,
                              clf_head, encoder_device, clf_head_device)
    print(results)

  0%|          | 0/5 [00:00<?, ?it/s]

Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.0037914691943127963, 'recall': 0.032, 'f1_score': 0.006779661016949153}, 'unlabelled span matching score': {'precision': 0.01800947867298578, 'recall': 0.152, 'f1_score': 0.03220338983050847}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.0037914691943127963, 'recall': 0.032, 'f1_score': 0.006779661016949153}, 'unlabelled span matching score': {'precision': 0.01800947867298578, 'recall': 0.152, 'f1_score': 0.03220338983050847}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.0037914691943127963, 'recall': 0.032, 'f1_score': 0.006779661016949153}, 'unlabelled span matching score': {'precision': 0.01800947867298578, 'recall': 0.152, 'f1_score': 0.03220338983050847}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.0037914691943127963, 'recall': 0.032, 'f1_score': 0.006779661016949153}, 'unlabelled span matching score': {'precision': 0.01800947867298578, 'recall': 0.152, 'f1_score': 0.03220338983050847}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.0037914691943127963, 'recall': 0.032, 'f1_score': 0.006779661016949153}, 'unlabelled span matching score': {'precision': 0.01800947867298578, 'recall': 0.152, 'f1_score': 0.03220338983050847}}


In [122]:
# fine-tuned validation results
n_epochs = 5
loss_fn = nn.CrossEntropyLoss()
optimiser = torch.optim.AdamW(
    list(encoder.parameters()) + list(clf_head.parameters()), lr=10**(-5))
for epoch_n in tqdm(range(n_epochs)):
    loss = train_epoch(data_dict['en_ewt']['train'][:500], label_to_i, tokeniser, encoder, clf_head,
                       encoder_device, clf_head_device, loss_fn, optimiser)
    print(f'Epoch {epoch_n+1} training loss: {loss:.2f}')
    results = validate_epoch(data_dict['en_ewt']['dev'][:250], label_to_i, tokeniser, encoder,
                              clf_head, encoder_device, clf_head_device)
    print(results)

  0%|          | 0/5 [00:00<?, ?it/s]

Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1 training loss: 0.29


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.56, 'recall': 0.336, 'f1_score': 0.42}, 'unlabelled span matching score': {'precision': 0.56, 'recall': 0.336, 'f1_score': 0.42}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2 training loss: 0.13


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.525974025974026, 'recall': 0.648, 'f1_score': 0.5806451612903227}, 'unlabelled span matching score': {'precision': 0.5324675324675324, 'recall': 0.656, 'f1_score': 0.5878136200716846}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3 training loss: 0.08


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.5193798449612403, 'recall': 0.536, 'f1_score': 0.5275590551181102}, 'unlabelled span matching score': {'precision': 0.5271317829457365, 'recall': 0.544, 'f1_score': 0.5354330708661418}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 4 training loss: 0.05


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.7758620689655172, 'recall': 0.72, 'f1_score': 0.7468879668049794}, 'unlabelled span matching score': {'precision': 0.7844827586206896, 'recall': 0.728, 'f1_score': 0.7551867219917011}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 5 training loss: 0.03


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.717741935483871, 'recall': 0.712, 'f1_score': 0.714859437751004}, 'unlabelled span matching score': {'precision': 0.75, 'recall': 0.744, 'f1_score': 0.746987951807229}}


In [205]:
# baseline test results
n_epochs = 5
optimiser = torch.optim.AdamW(
    list(encoder.parameters()) + list(clf_head.parameters()), lr=10**(-5))
for epoch_n in tqdm(range(n_epochs)):
    results = validate_epoch(data_dict['en_ewt']['test'][:250], label_to_i, tokeniser, encoder,
                              clf_head, encoder_device, clf_head_device)
    print(results)

  0%|          | 0/5 [00:00<?, ?it/s]

Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.017335766423357664, 'recall': 0.15079365079365079, 'f1_score': 0.03109656301145663}, 'unlabelled span matching score': {'precision': 0.021897810218978103, 'recall': 0.19047619047619047, 'f1_score': 0.03927986906710311}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.017335766423357664, 'recall': 0.15079365079365079, 'f1_score': 0.03109656301145663}, 'unlabelled span matching score': {'precision': 0.021897810218978103, 'recall': 0.19047619047619047, 'f1_score': 0.03927986906710311}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.017335766423357664, 'recall': 0.15079365079365079, 'f1_score': 0.03109656301145663}, 'unlabelled span matching score': {'precision': 0.021897810218978103, 'recall': 0.19047619047619047, 'f1_score': 0.03927986906710311}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.017335766423357664, 'recall': 0.15079365079365079, 'f1_score': 0.03109656301145663}, 'unlabelled span matching score': {'precision': 0.021897810218978103, 'recall': 0.19047619047619047, 'f1_score': 0.03927986906710311}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.017335766423357664, 'recall': 0.15079365079365079, 'f1_score': 0.03109656301145663}, 'unlabelled span matching score': {'precision': 0.021897810218978103, 'recall': 0.19047619047619047, 'f1_score': 0.03927986906710311}}


In [206]:
# baseline test results
n_epochs = 5
optimiser = torch.optim.AdamW(
    list(encoder.parameters()) + list(clf_head.parameters()), lr=10**(-5))
for epoch_n in tqdm(range(n_epochs)):
    results = validate_epoch(data_dict['en_pud']['test'][:250], label_to_i, tokeniser, encoder,
                              clf_head, encoder_device, clf_head_device)
    print(results)

  0%|          | 0/5 [00:00<?, ?it/s]

Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.00847816871555744, 'recall': 0.09345794392523364, 'f1_score': 0.015546055188495918}, 'unlabelled span matching score': {'precision': 0.022891055532005086, 'recall': 0.2523364485981308, 'f1_score': 0.041974349008938976}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.00847816871555744, 'recall': 0.09345794392523364, 'f1_score': 0.015546055188495918}, 'unlabelled span matching score': {'precision': 0.022891055532005086, 'recall': 0.2523364485981308, 'f1_score': 0.041974349008938976}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.00847816871555744, 'recall': 0.09345794392523364, 'f1_score': 0.015546055188495918}, 'unlabelled span matching score': {'precision': 0.022891055532005086, 'recall': 0.2523364485981308, 'f1_score': 0.041974349008938976}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.00847816871555744, 'recall': 0.09345794392523364, 'f1_score': 0.015546055188495918}, 'unlabelled span matching score': {'precision': 0.022891055532005086, 'recall': 0.2523364485981308, 'f1_score': 0.041974349008938976}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.00847816871555744, 'recall': 0.09345794392523364, 'f1_score': 0.015546055188495918}, 'unlabelled span matching score': {'precision': 0.022891055532005086, 'recall': 0.2523364485981308, 'f1_score': 0.041974349008938976}}


In [123]:
# fine-tuned test results
n_epochs = 5
loss_fn = nn.CrossEntropyLoss()
optimiser = torch.optim.AdamW(
    list(encoder.parameters()) + list(clf_head.parameters()), lr=10**(-5))
for epoch_n in tqdm(range(n_epochs)):
    loss = train_epoch(data_dict['en_ewt']['train'][:500], label_to_i, tokeniser, encoder, clf_head,
                       encoder_device, clf_head_device, loss_fn, optimiser)
    print(f'Epoch {epoch_n+1} training loss: {loss:.2f}')
    results = validate_epoch(data_dict['en_ewt']['test'][:250], label_to_i, tokeniser, encoder,
                              clf_head, encoder_device, clf_head_device)
    print(results)

  0%|          | 0/5 [00:00<?, ?it/s]

Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1 training loss: 0.02


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.574468085106383, 'recall': 0.42857142857142855, 'f1_score': 0.4909090909090909}, 'unlabelled span matching score': {'precision': 0.6382978723404256, 'recall': 0.47619047619047616, 'f1_score': 0.5454545454545455}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2 training loss: 0.01


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.5, 'recall': 0.5396825396825397, 'f1_score': 0.5190839694656489}, 'unlabelled span matching score': {'precision': 0.625, 'recall': 0.6746031746031746, 'f1_score': 0.6488549618320612}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3 training loss: 0.01


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.5462184873949579, 'recall': 0.5158730158730159, 'f1_score': 0.5306122448979592}, 'unlabelled span matching score': {'precision': 0.6302521008403361, 'recall': 0.5952380952380952, 'f1_score': 0.6122448979591836}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 4 training loss: 0.00


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.6198347107438017, 'recall': 0.5952380952380952, 'f1_score': 0.6072874493927126}, 'unlabelled span matching score': {'precision': 0.6528925619834711, 'recall': 0.626984126984127, 'f1_score': 0.6396761133603239}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 5 training loss: 0.00


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.6229508196721312, 'recall': 0.6031746031746031, 'f1_score': 0.6129032258064517}, 'unlabelled span matching score': {'precision': 0.6475409836065574, 'recall': 0.626984126984127, 'f1_score': 0.6370967741935485}}


In [125]:
# fine-tuned test results
n_epochs = 5
loss_fn = nn.CrossEntropyLoss()
optimiser = torch.optim.AdamW(
    list(encoder.parameters()) + list(clf_head.parameters()), lr=10**(-5))
for epoch_n in tqdm(range(n_epochs)):
    loss = train_epoch(data_dict['en_ewt']['train'][:500], label_to_i, tokeniser, encoder, clf_head,
                       encoder_device, clf_head_device, loss_fn, optimiser)
    print(f'Epoch {epoch_n+1} training loss: {loss:.2f}')
    results = validate_epoch(data_dict['en_pud']['test'][:250], label_to_i, tokeniser, encoder,
                              clf_head, encoder_device, clf_head_device)
    print(results)

  0%|          | 0/5 [00:00<?, ?it/s]

Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1 training loss: 0.00


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.6329113924050633, 'recall': 0.4672897196261682, 'f1_score': 0.5376344086021505}, 'unlabelled span matching score': {'precision': 0.7468354430379747, 'recall': 0.5514018691588785, 'f1_score': 0.6344086021505376}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2 training loss: 0.00


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.6474358974358975, 'recall': 0.4719626168224299, 'f1_score': 0.5459459459459459}, 'unlabelled span matching score': {'precision': 0.7628205128205128, 'recall': 0.5560747663551402, 'f1_score': 0.6432432432432432}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3 training loss: 0.00


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.5217391304347826, 'recall': 0.4485981308411215, 'f1_score': 0.4824120603015076}, 'unlabelled span matching score': {'precision': 0.6521739130434783, 'recall': 0.5607476635514018, 'f1_score': 0.6030150753768845}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 4 training loss: 0.00


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.5568181818181818, 'recall': 0.45794392523364486, 'f1_score': 0.5025641025641026}, 'unlabelled span matching score': {'precision': 0.6818181818181818, 'recall': 0.5607476635514018, 'f1_score': 0.6153846153846154}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 5 training loss: 0.00


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.5974842767295597, 'recall': 0.4439252336448598, 'f1_score': 0.5093833780160859}, 'unlabelled span matching score': {'precision': 0.7358490566037735, 'recall': 0.5467289719626168, 'f1_score': 0.6273458445040214}}


Then we switch to the simplified tagset ('B', 'I', 'O')

In [207]:
class ClassificationHead(nn.Module):
    def __init__(self, model_dim = 768, n_classes = 3):
        super().__init__()
        self.linear = nn.Linear(model_dim, n_classes)

    def forward(self, x):
        return self.linear(x)

In [208]:
simplified_labels = set()
for i in range(len(simplified_data_dict['en_ewt']['train'])):
    simplified_labels.update(simplified_data_dict['en_ewt']['train'][i][1])
n_simplified_classes = len(simplified_labels)
sorted(simplified_labels)

['B', 'I', 'O']

In [209]:
simplified_label_to_i = {
    simplified_label: i
    for i, simplified_label in enumerate(sorted(simplified_labels))
}
simplified_i_to_label = {
    i: simplified_label
    for simplified_label, i in simplified_label_to_i.items()
}

In [210]:
simplified_label_to_i

{'B': 0, 'I': 1, 'O': 2}

In [211]:
def process_sentence(sentence, label_to_i, tokeniser, encoder, clf_head,
                     encoder_device, clf_head_device):
    words, labels = sentence
    gold_labels = torch.tensor(
        [simplified_label_to_i[label] for label in labels]).to(clf_head_device)
    tokenisation = tokeniser(words, is_split_into_words=True,
                             return_tensors='pt')
    inputs = {k: v.to(encoder_device) for k, v in tokenisation.items()}

    # Remove the embedding of the CLS token and the SEP token.
    outputs = encoder(**inputs).last_hidden_state[0, 1:-1, :]

    # Take embeddings of the first/last subword and ignore the CLS and the SEP tokens
    word_ids = tokenisation.word_ids()[1:-1]
    processed_words = set()
    first_subword_embeddings = []

    for i, word_id in enumerate(word_ids):
        if word_id not in processed_words:
            first_subword_embeddings.append(outputs[i])
            processed_words.add(word_id)

    # Check the alignment of words and labels.
    assert len(first_subword_embeddings) == gold_labels.size(0)

    # Combine subword embeddings into a tensor and copy to the device of the classification head
    clf_head_inputs = torch.vstack(
        first_subword_embeddings).to(clf_head_device)

    # Return the logits and gold labels for subsequent processing
    return clf_head(clf_head_inputs), gold_labels

In [212]:
def train_epoch(data, label_to_i, tokeniser, encoder, clf_head,
                encoder_device, clf_head_device, loss_fn, optimiser):
    encoder.train()
    epoch_losses = torch.empty(len(data))
    for step_n, sentence in tqdm(
        enumerate(data),
        total=len(data),
        desc='Train',
        leave=False
    ):
        optimiser.zero_grad()
        logits, gold_labels = process_sentence(
            sentence, simplified_label_to_i, tokeniser,
            encoder, clf_head, encoder_device,
            clf_head_device)
        loss = loss_fn(logits, gold_labels)
        loss.backward()
        optimiser.step()
        epoch_losses[step_n] = loss.item()
    return epoch_losses.mean().item()

In [213]:
def extract_spans(labels):
    spans = []
    start = None # Start index of the span
    current_label = None

    # The span is defined by the boundary and the type of a named entity.
    # We extract the span in the format of (start_index, end_index, type)
    for i, tag in enumerate(labels):
        if tag.startswith('B-'):
            # Close previous span if open
            if start is not None:
                spans.append((start, i-1, current_label))
            start = i
            # Extract the type following the boundary
            current_label = tag[2:]
        elif tag.startswith('I-') and current_label == tag[2:]:
            # Continue current span if the type matches
            continue
        else:
            # Close previous span if open
            if start is not None:
                spans.append((start, i-1, current_label))
                start = None
                current_label = None

    # Close any open span at end
    if start is not None:
        spans.append((start, len(labels)-1, current_label))

    return spans

In [214]:
# Calculate the precision, recall, and f1 score using true positives, false positives, and false negatives
def scores(tp, fp, fn):
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return precision, recall, f1_score


def span_matches(pred_span, gold_span, labelled=True):
    # Check both the boundary and type for labelled span matching scores
    if labelled:
        return pred_span == gold_span
    else:
        return pred_span[:2] == gold_span[:2]  # Check the boundary ('B', 'I', 'O') only


def validate_epoch(data, label_to_i, tokeniser, encoder, clf_head,
                   encoder_device, clf_head_device):
    encoder.eval()
    clf_head.eval()

    labelled_tp = 0
    labelled_fp = 0
    labelled_fn = 0

    unlabelled_tp = 0
    unlabelled_fp = 0
    unlabelled_fn = 0

    for step_n, sentence in tqdm(enumerate(data), total=len(data), desc='Eval', leave=False):
        with torch.no_grad():
            logits, gold_label_ids = process_sentence(
                sentence, label_to_i, tokeniser,
                encoder, clf_head, encoder_device,
                clf_head_device)

        predicted_label_ids = torch.argmax(logits, dim=-1).cpu().tolist()
        gold_label_ids = gold_label_ids.cpu().tolist()

        predicted_labels = [i_to_label[i] for i in predicted_label_ids]
        gold_labels = [i_to_label[i] for i in gold_label_ids]

        pred_spans = extract_spans(predicted_labels)
        gold_spans = extract_spans(gold_labels)

        # Labelled span metrics
        for pred_span in pred_spans:
            if any(span_matches(pred_span, gs, labelled=True) for gs in gold_spans):
                labelled_tp += 1
            else:
                labelled_fp += 1

        for gold_span in gold_spans:
            if not any(span_matches(gs, gold_span, labelled=True) for gs in pred_spans):
                labelled_fn += 1

        # Unlabelled span metrics
        for pred_span in pred_spans:
            if any(span_matches(pred_span, gs, labelled=False) for gs in gold_spans):
                unlabelled_tp += 1
            else:
                unlabelled_fp += 1

        for gold_span in gold_spans:
            if not any(span_matches(gs, gold_span, labelled=False) for gs in pred_spans):
                unlabelled_fn += 1

    labelled_precision, labelled_recall, labelled_f1 = scores(
        labelled_tp, labelled_fp, labelled_fn)

    unlabelled_precision, unlabelled_recall, unlabelled_f1 = scores(
        unlabelled_tp, unlabelled_fp, unlabelled_fn)

    return {
        "labelled span matching score": {
            "precision": labelled_precision,
            "recall": labelled_recall,
            "f1_score": labelled_f1,
        },
        "unlabelled span matching score": {
            "precision": unlabelled_precision,
            "recall": unlabelled_recall,
            "f1_score": unlabelled_f1,
        }
    }

In [215]:
encoder_device = 0
encoder = AutoModel.from_pretrained(
    model_tag).to(encoder_device)
clf_head = ClassificationHead(n_classes= n_classes)
clf_head_device = 0
clf_head.to(clf_head_device);

In [216]:
# baseline test results
n_epochs = 5
optimiser = torch.optim.AdamW(
    list(encoder.parameters()) + list(clf_head.parameters()), lr=10**(-5))
for epoch_n in tqdm(range(n_epochs)):
    results = validate_epoch(simplified_data_dict['en_ewt']['test'][:250], simplified_label_to_i, tokeniser, encoder,
                              clf_head, encoder_device, clf_head_device)
    print(results)

  0%|          | 0/5 [00:00<?, ?it/s]

Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.40099833610648916, 'recall': 0.1644489935175708, 'f1_score': 0.23324461650133074}, 'unlabelled span matching score': {'precision': 0.9292845257903494, 'recall': 0.38109860116001365, 'f1_score': 0.5405274618920881}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.40099833610648916, 'recall': 0.1644489935175708, 'f1_score': 0.23324461650133074}, 'unlabelled span matching score': {'precision': 0.9292845257903494, 'recall': 0.38109860116001365, 'f1_score': 0.5405274618920881}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.40099833610648916, 'recall': 0.1644489935175708, 'f1_score': 0.23324461650133074}, 'unlabelled span matching score': {'precision': 0.9292845257903494, 'recall': 0.38109860116001365, 'f1_score': 0.5405274618920881}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.40099833610648916, 'recall': 0.1644489935175708, 'f1_score': 0.23324461650133074}, 'unlabelled span matching score': {'precision': 0.9292845257903494, 'recall': 0.38109860116001365, 'f1_score': 0.5405274618920881}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.40099833610648916, 'recall': 0.1644489935175708, 'f1_score': 0.23324461650133074}, 'unlabelled span matching score': {'precision': 0.9292845257903494, 'recall': 0.38109860116001365, 'f1_score': 0.5405274618920881}}


In [217]:
# baseline test results
n_epochs = 5
optimiser = torch.optim.AdamW(
    list(encoder.parameters()) + list(clf_head.parameters()), lr=10**(-5))
for epoch_n in tqdm(range(n_epochs)):
    results = validate_epoch(simplified_data_dict['en_pud']['test'][:250], simplified_label_to_i, tokeniser, encoder,
                              clf_head, encoder_device, clf_head_device)
    print(results)

  0%|          | 0/5 [00:00<?, ?it/s]

Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.25482787573467675, 'recall': 0.1154650941601674, 'f1_score': 0.15892132478073046}, 'unlabelled span matching score': {'precision': 0.9126784214945424, 'recall': 0.4135438463001712, 'f1_score': 0.5691844482262076}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.25482787573467675, 'recall': 0.1154650941601674, 'f1_score': 0.15892132478073046}, 'unlabelled span matching score': {'precision': 0.9126784214945424, 'recall': 0.4135438463001712, 'f1_score': 0.5691844482262076}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.25482787573467675, 'recall': 0.1154650941601674, 'f1_score': 0.15892132478073046}, 'unlabelled span matching score': {'precision': 0.9126784214945424, 'recall': 0.4135438463001712, 'f1_score': 0.5691844482262076}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.25482787573467675, 'recall': 0.1154650941601674, 'f1_score': 0.15892132478073046}, 'unlabelled span matching score': {'precision': 0.9126784214945424, 'recall': 0.4135438463001712, 'f1_score': 0.5691844482262076}}


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.25482787573467675, 'recall': 0.1154650941601674, 'f1_score': 0.15892132478073046}, 'unlabelled span matching score': {'precision': 0.9126784214945424, 'recall': 0.4135438463001712, 'f1_score': 0.5691844482262076}}


In [218]:
# fine-tuned test results
n_epochs = 5
loss_fn = nn.CrossEntropyLoss()
optimiser = torch.optim.AdamW(
    list(encoder.parameters()) + list(clf_head.parameters()), lr=10**(-5))
for epoch_n in tqdm(range(n_epochs)):
    loss = train_epoch(simplified_data_dict['en_ewt']['train'][:500], simplified_label_to_i, tokeniser, encoder, clf_head,
                       encoder_device, clf_head_device, loss_fn, optimiser)
    print(f'Epoch {epoch_n+1} training loss: {loss:.2f}')
    results = validate_epoch(simplified_data_dict['en_ewt']['test'][:250], simplified_label_to_i, tokeniser, encoder,
                              clf_head, encoder_device, clf_head_device)
    print(results)

  0%|          | 0/5 [00:00<?, ?it/s]

Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1 training loss: 0.26


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.9423404981235073, 'recall': 0.9423404981235073, 'f1_score': 0.9423404981235073}, 'unlabelled span matching score': {'precision': 1.0, 'recall': 1.0, 'f1_score': 1.0}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2 training loss: 0.08


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.962128966223132, 'recall': 0.962128966223132, 'f1_score': 0.962128966223132}, 'unlabelled span matching score': {'precision': 1.0, 'recall': 1.0, 'f1_score': 1.0}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3 training loss: 0.04


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.9658819515523712, 'recall': 0.9658819515523712, 'f1_score': 0.9658819515523712}, 'unlabelled span matching score': {'precision': 1.0, 'recall': 1.0, 'f1_score': 1.0}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 4 training loss: 0.02


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.9689525759126578, 'recall': 0.9689525759126578, 'f1_score': 0.9689525759126578}, 'unlabelled span matching score': {'precision': 1.0, 'recall': 1.0, 'f1_score': 1.0}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 5 training loss: 0.01


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.9645172296144661, 'recall': 0.9645172296144661, 'f1_score': 0.9645172296144661}, 'unlabelled span matching score': {'precision': 1.0, 'recall': 1.0, 'f1_score': 1.0}}


In [219]:
# fine-tuned test results
n_epochs = 5
loss_fn = nn.CrossEntropyLoss()
optimiser = torch.optim.AdamW(
    list(encoder.parameters()) + list(clf_head.parameters()), lr=10**(-5))
for epoch_n in tqdm(range(n_epochs)):
    loss = train_epoch(simplified_data_dict['en_ewt']['train'][:500], simplified_label_to_i, tokeniser, encoder, clf_head,
                       encoder_device, clf_head_device, loss_fn, optimiser)
    print(f'Epoch {epoch_n+1} training loss: {loss:.2f}')
    results = validate_epoch(simplified_data_dict['en_pud']['test'][:250], simplified_label_to_i, tokeniser, encoder,
                              clf_head, encoder_device, clf_head_device)
    print(results)

  0%|          | 0/5 [00:00<?, ?it/s]

Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1 training loss: 0.00


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.9722275061822332, 'recall': 0.9722275061822332, 'f1_score': 0.9722275061822332}, 'unlabelled span matching score': {'precision': 1.0, 'recall': 1.0, 'f1_score': 1.0}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2 training loss: 0.01


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.9714666159406505, 'recall': 0.9714666159406505, 'f1_score': 0.9714666159406505}, 'unlabelled span matching score': {'precision': 1.0, 'recall': 1.0, 'f1_score': 1.0}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3 training loss: 0.00


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.9619554879208674, 'recall': 0.9619554879208674, 'f1_score': 0.9619554879208674}, 'unlabelled span matching score': {'precision': 1.0, 'recall': 1.0, 'f1_score': 1.0}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 4 training loss: 0.00


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.9638577135248241, 'recall': 0.9638577135248241, 'f1_score': 0.9638577135248241}, 'unlabelled span matching score': {'precision': 1.0, 'recall': 1.0, 'f1_score': 1.0}}


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 5 training loss: 0.01


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

{'labelled span matching score': {'precision': 0.9747003994673769, 'recall': 0.9747003994673769, 'f1_score': 0.9747003994673769}, 'unlabelled span matching score': {'precision': 1.0, 'recall': 1.0, 'f1_score': 1.0}}


Next we will fine-tune T5 on NER.

In [139]:
from random import shuffle
from math import ceil

import torch
import torch.nn as nn

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from tqdm.auto import tqdm

In [140]:
import json

with open('ner_data_dict.json', 'r', encoding='utf-8') as f:
    data_dict = json.load(f)
with open('ner_simplified_data_dict.json', 'r', encoding = 'utf-8') as f:
    simplified_data_dict = json.load(f)

In [141]:
model_tag = 'google-t5/t5-base'

device = 0 if torch.cuda.is_available() else 'cpu'
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_tag, cache_dir='./hf_cache').to(device)
tokeniser = AutoTokenizer.from_pretrained(model_tag)

optim = torch.optim.AdamW(
    model.parameters(),
    lr=10**(-4))

In [142]:
labels = set()
for i in range(len(data_dict['en_ewt']['train'])):
    labels.update(data_dict['en_ewt']['train'][i][1])
n_classes = len(labels)
sorted(labels)

['B-LOC', 'B-ORG', 'B-PER', 'I-LOC', 'I-ORG', 'I-PER', 'O']

In [143]:
simplified_labels = set()
for i in range(len(simplified_data_dict['en_ewt']['train'])):
    simplified_labels.update(simplified_data_dict['en_ewt']['train'][i][1])
n_simplified_classes = len(simplified_labels)
sorted(simplified_labels)

['B', 'I', 'O']

In [144]:
def process_batch(batch_inputs, batch_labels,
                  tokeniser, model, device,
                  optimiser, max_len=512):
    optimiser.zero_grad()
    tokenisation = tokeniser(
        batch_inputs,
        return_tensors='pt',
        max_length=max_len,
        padding='longest',
        truncation=True
    )
    input_ids = tokenisation.input_ids.to(device)
    attention_mask = tokenisation.attention_mask.to(device)
    labels = tokeniser(
        batch_labels,
        return_tensors='pt',
        max_length=max_len,
        padding='longest',
        truncation=True
    ).input_ids.to(device)

    labels[labels == tokeniser.pad_token_id] = -100
    inputs = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }
    loss = model(**inputs).loss
    loss.backward()
    optimiser.step()
    return loss.item()

In [162]:
def prepare_sentence(sentence_array):
    words, labels = sentence_array
    prepared_inputs = []
    for i in range(len(words)):
        tmp = words[:i] + ['~', words[i], '~'] + words[i+1:]
        prepared_inputs.append(' '.join(tmp))
    return prepared_inputs, labels

In [181]:
prepare_sentence(data_dict['en_ewt']['train'][0])

(['~ Where ~ in the world is Iguazu ?',
  'Where ~ in ~ the world is Iguazu ?',
  'Where in ~ the ~ world is Iguazu ?',
  'Where in the ~ world ~ is Iguazu ?',
  'Where in the world ~ is ~ Iguazu ?',
  'Where in the world is ~ Iguazu ~ ?',
  'Where in the world is Iguazu ~ ? ~'],
 ['O', 'O', 'O', 'O', 'O', 'B-LOC', 'O'])

In [163]:
def train_epoch(train_inputs, batch_size,
                tokeniser, model, device, optimizer):
    model.train()

    n_steps = len(train_inputs)
    epoch_losses = torch.zeros(n_steps)
    for step_n in tqdm(range(n_steps), leave=False, desc='Train'):
        prepared_inputs, labels = prepare_sentence(train_inputs[step_n])
        n_batches = ceil(len(prepared_inputs) / batch_size)
        sentence_losses_accum = 0.0
        for step_n in range(n_batches):
            lo = step_n * batch_size
            hi = lo + batch_size
            batch_texts = prepared_inputs[lo:hi]
            batch_labels = labels[lo:hi]
            loss = process_batch(batch_texts, batch_labels,
                                 tokeniser, model, device,
                                 optimizer)
            sentence_losses_accum += loss
        epoch_losses[step_n] = sentence_losses_accum / n_batches
    return epoch_losses.mean().item()

In [164]:
def get_class_prediction(prompt, tokeniser, model, device, max_len=512):
    tokenisation = tokeniser(
        prompt,
        return_tensors='pt',
        max_length=max_len,
        truncation=True
    )
    input_ids = tokenisation.input_ids.to(device)
    attention_mask = tokenisation.attention_mask.to(device)
    output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
    max_new_tokens=10).squeeze()
    output_string = tokeniser.decode(
        output,
        skip_special_tokens=True
    ).strip()
    if not output_string:
        return None
    return output_string.split()[0]

In [170]:
def validate_epoch(dev_inputs, tokeniser, model, device, max_len=512):
    model.eval()
    n_steps = len(dev_inputs)
    epoch_hits = []
    for step_n in tqdm(range(n_steps), leave=False, desc='Validate'):
        prepared_inputs, labels = prepare_sentence(dev_inputs[step_n])
        with torch.no_grad():
            for input_sentence, gold_label in zip(prepared_inputs, labels):
                predicted_label = get_class_prediction(
                    input_sentence, tokeniser, model, device,
                    max_len=max_len)
                epoch_hits.append(int(predicted_label == gold_label))
    return sum(epoch_hits) / len(epoch_hits)

In [173]:
n_epochs = 3
batch_size = 16

n_train_exx = 500
n_dev_exx = 100

best_accuracy = 0.0
for epoch_n in tqdm(range(n_epochs)):
    epoch_loss = train_epoch(data_dict['en_ewt']['train'][: n_train_exx], batch_size,
                             tokeniser, model, device, optim)
    print(f'Epoch {epoch_n+1} loss:', round(epoch_loss, 2))
    epoch_dev_accuracy = validate_epoch(
        data_dict['en_ewt']['dev'][: n_dev_exx], tokeniser, model, device)

    print(f'Epoch {epoch_n+1} dev accuracy: {epoch_dev_accuracy:.2f}')

  0%|          | 0/3 [00:00<?, ?it/s]

Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 dev accuracy: 0.96


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2 dev accuracy: 0.98


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3 dev accuracy: 0.98


In [176]:
n_epochs = 3
batch_size = 16

n_train_exx = 500
n_test_exx = 100

best_accuracy = 0.0
for epoch_n in tqdm(range(n_epochs)):
    epoch_test_accuracy = validate_epoch(
        data_dict['en_ewt']['test'][: n_test_exx], tokeniser, model, device)

    print(f'Epoch {epoch_n+1} test accuracy: {epoch_test_accuracy:.2f}')

  0%|          | 0/3 [00:00<?, ?it/s]

Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 test accuracy: 0.97


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2 test accuracy: 0.97


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3 test accuracy: 0.97


In [178]:
n_epochs = 3
batch_size = 16

n_train_exx = 500
n_test_exx = 100

best_accuracy = 0.0
for epoch_n in tqdm(range(n_epochs)):
    epoch_test_accuracy = validate_epoch(
        data_dict['en_pud']['test'][: n_test_exx], tokeniser, model, device)

    print(f'Epoch {epoch_n+1} test accuracy: {epoch_test_accuracy:.2f}')

  0%|          | 0/3 [00:00<?, ?it/s]

Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 test accuracy: 0.97


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2 test accuracy: 0.97


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3 test accuracy: 0.97


In [174]:
n_epochs = 3
batch_size = 16

n_train_exx = 500
n_test_exx = 100

best_accuracy = 0.0
for epoch_n in tqdm(range(n_epochs)):
    epoch_loss = train_epoch(data_dict['en_ewt']['train'][: n_train_exx], batch_size,
                             tokeniser, model, device, optim)
    print(f'Epoch {epoch_n+1} loss:', round(epoch_loss, 2))
    epoch_test_accuracy = validate_epoch(
        data_dict['en_ewt']['test'][: n_test_exx], tokeniser, model, device)

    print(f'Epoch {epoch_n+1} test accuracy: {epoch_test_accuracy:.2f}')

  0%|          | 0/3 [00:00<?, ?it/s]

Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 test accuracy: 0.97


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2 test accuracy: 0.97


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3 test accuracy: 0.97


In [175]:
n_epochs = 3
batch_size = 16

n_train_exx = 500
n_test_exx = 100

best_accuracy = 0.0
for epoch_n in tqdm(range(n_epochs)):
    epoch_loss = train_epoch(data_dict['en_ewt']['train'][: n_train_exx], batch_size,
                             tokeniser, model, device, optim)
    print(f'Epoch {epoch_n+1} loss:', round(epoch_loss, 2))
    epoch_test_accuracy = validate_epoch(
        data_dict['en_pud']['test'][: n_test_exx], tokeniser, model, device)

    print(f'Epoch {epoch_n+1} test accuracy: {epoch_test_accuracy:.2f}')

  0%|          | 0/3 [00:00<?, ?it/s]

Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 test accuracy: 0.98


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2 test accuracy: 0.98


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3 test accuracy: 0.97


In [182]:
n_epochs = 3
batch_size = 16

n_train_exx = 500
n_test_exx = 100

best_accuracy = 0.0
for epoch_n in tqdm(range(n_epochs)):
    epoch_test_accuracy = validate_epoch(
        simplified_data_dict['en_ewt']['test'][: n_test_exx], tokeniser, model, device)

    print(f'Epoch {epoch_n+1} test accuracy: {epoch_test_accuracy:.2f}')

  0%|          | 0/3 [00:00<?, ?it/s]

Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 test accuracy: 0.94


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2 test accuracy: 0.94


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3 test accuracy: 0.94


In [183]:
n_epochs = 3
batch_size = 16

n_train_exx = 500
n_test_exx = 100

best_accuracy = 0.0
for epoch_n in tqdm(range(n_epochs)):
    epoch_test_accuracy = validate_epoch(
        simplified_data_dict['en_pud']['test'][: n_test_exx], tokeniser, model, device)

    print(f'Epoch {epoch_n+1} test accuracy: {epoch_test_accuracy:.2f}')

  0%|          | 0/3 [00:00<?, ?it/s]

Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 test accuracy: 0.93


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2 test accuracy: 0.93


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3 test accuracy: 0.93


In [184]:
n_epochs = 3
batch_size = 16

n_train_exx = 500
n_test_exx = 100

best_accuracy = 0.0
for epoch_n in tqdm(range(n_epochs)):
    epoch_loss = train_epoch(simplified_data_dict['en_ewt']['train'][: n_train_exx], batch_size,
                             tokeniser, model, device, optim)
    print(f'Epoch {epoch_n+1} loss:', round(epoch_loss, 2))
    epoch_test_accuracy = validate_epoch(
        simplified_data_dict['en_ewt']['test'][: n_test_exx], tokeniser, model, device)

    print(f'Epoch {epoch_n+1} test accuracy: {epoch_test_accuracy:.2f}')

  0%|          | 0/3 [00:00<?, ?it/s]

Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 test accuracy: 0.97


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2 test accuracy: 0.98


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3 test accuracy: 0.98


In [185]:
n_epochs = 3
batch_size = 16

n_train_exx = 500
n_test_exx = 100

best_accuracy = 0.0
for epoch_n in tqdm(range(n_epochs)):
    epoch_loss = train_epoch(simplified_data_dict['en_ewt']['train'][: n_train_exx], batch_size,
                             tokeniser, model, device, optim)
    print(f'Epoch {epoch_n+1} loss:', round(epoch_loss, 2))
    epoch_test_accuracy = validate_epoch(
        simplified_data_dict['en_ewt']['test'][: n_test_exx], tokeniser, model, device)

    print(f'Epoch {epoch_n+1} test accuracy: {epoch_test_accuracy:.2f}')

  0%|          | 0/3 [00:00<?, ?it/s]

Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 test accuracy: 0.98


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2 test accuracy: 0.98


Train:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3 loss: 0.0


Validate:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3 test accuracy: 0.97
