In [1]:
import csv
import json
import random
import re
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import seaborn as sns

import torch

from datasets import (
    Dataset,
    DatasetDict,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
)
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
    set_seed,
)

import evaluate

from nameparser import HumanName
from names_dataset import NameDataset, NameWrapper
from ethnicseer import EthnicClassifier
import nltk
from nltk.corpus import wordnet as wn

import pycountry_convert as pc
import pycountry
import pickle

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report, cohen_kappa_score

from transformers import BertTokenizerFast, BertForTokenClassification
from datasets import ClassLabel
from evaluate import load as load_metric
from sklearn.metrics import cohen_kappa_score
from itertools import combinations
import krippendorff

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Name of the base checkpoint. It is used below both to load the tokenizer
# and to build the "./saved_model/{model_name}_{train_data_name}" paths.
model_name = "bert-base-cased"

In [3]:
# Load the four pre-built dataset splits from disk.
# NOTE(review): what distinguishes the "main" and "clean" variants is not
# visible in this notebook -- confirm against the code that wrote ./splits.
conll_main = load_from_disk("./splits/conll_main")
conll_clean = load_from_disk("./splits/conll_clean")

ontonotes_main = load_from_disk("./splits/ontonotes_main")
ontonotes_clean = load_from_disk("./splits/ontonotes_clean")

In [4]:
from itertools import chain

# Concatenate the per-sentence token lists of each split into one flat list.
# These flat lists are later zipped against the flattened predictions/gold
# labels in caps_correct, so the split order here must stay aligned with the
# target names on the left-hand side.
flat_conll, flat_onto, flat_conll_clean, flat_onto_clean = (
    list(chain.from_iterable(split['tokens']))
    for split in (conll_main, ontonotes_main, conll_clean, ontonotes_clean)
)

# Select Device (GPU if Available)

In [5]:
# Pick CUDA when it is available, otherwise fall back to the CPU, and report
# the choice.
# NOTE(review): `device` is not passed to anything in the cells visible here;
# it may only be informational -- confirm before relying on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


# Tokenisation & Alignment

In [6]:
# OntoNotes tag inventory: integer id -> BIO tag string.
ontonotes_id_to_label = {
    0: "O", 1: "B-CARDINAL", 2: "B-DATE", 3: "I-DATE", 4: "B-PERSON", 5: "I-PERSON",
    6: "B-NORP", 7: "B-GPE", 8: "I-GPE", 9: "B-LAW", 10: "I-LAW", 11: "B-ORG", 12: "I-ORG",
    13: "B-PERCENT", 14: "I-PERCENT", 15: "B-ORDINAL", 16: "B-MONEY", 17: "I-MONEY",
    18: "B-WORK_OF_ART", 19: "I-WORK_OF_ART", 20: "B-FAC", 21: "B-TIME", 22: "I-CARDINAL",
    23: "B-LOC", 24: "B-QUANTITY", 25: "I-QUANTITY", 26: "I-NORP", 27: "I-LOC",
    28: "B-PRODUCT", 29: "I-TIME", 30: "B-EVENT", 31: "I-EVENT", 32: "I-FAC",
    33: "B-LANGUAGE", 34: "I-PRODUCT", 35: "I-ORDINAL", 36: "I-LANGUAGE"
}

# CoNLL tag inventory: BIO tag string -> integer id.
conll_label_to_id = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3,
                     'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
# Inverse of conll_label_to_id (integer id -> tag string); evaluate_predictions
# uses it to turn predicted and gold ids back into tag strings.
id2label = {v: k for k, v in conll_label_to_id.items()}

# OntoNotes entity type -> CoNLL entity type (without the B-/I- prefix).
# Types mapped to None have no CoNLL target in this table.
# NOTE(review): this mapping is not applied in the cells visible here;
# presumably it was used when the ./splits data was built -- confirm how the
# None entries are handled there.
ontonotes_to_conll_entity = {
    "PERSON": "PER", "ORG": "ORG", "GPE": "LOC", "LOC": "LOC",
    "NORP": "MISC", "FAC": "MISC", "EVENT": "MISC", "WORK_OF_ART": "MISC",
    "LAW": "MISC", "PRODUCT": "MISC", "LANGUAGE": "MISC",
    "DATE": None, "TIME": None, "PERCENT": None, "MONEY": None,
    "QUANTITY": None, "ORDINAL": None, "CARDINAL": None
}

In [7]:
def process_data(data_list):
    """Collapse sub-token level tags to one tag per word by majority vote.

    Args:
        data_list: iterable of mappings, each providing
            - 'word_ids': per-position word index, with None for positions
              that belong to no word (those positions are dropped first);
            - 'predictions': predicted tags, one per non-None word id;
            - 'gold': gold tags, aligned with 'predictions'.
            Any other keys are ignored.

    Returns:
        (predictions_per_example, gold_per_example): two lists with one inner
        list per input example; each inner list holds one tag per word, in
        order of first appearance of the word id.

    Raises:
        IndexError: if 'predictions' or 'gold' is shorter than the number of
            non-None word ids.
    """

    def majority(tags):
        # most_common(1) returns the first-seen tag when counts are tied.
        return Counter(tags).most_common(1)[0][0]

    def process_single(data):
        word_ids = [w for w in data['word_ids'] if w is not None]
        predictions = data['predictions']
        gold = data['gold']

        # One (predicted_tags, gold_tags) pair per run of equal word ids.
        groups = []
        current_word_id = None
        for idx, word_id in enumerate(word_ids):
            if word_id != current_word_id:
                groups.append(([], []))
                current_word_id = word_id
            word_preds, word_gold = groups[-1]
            word_preds.append(predictions[idx])
            word_gold.append(gold[idx])

        processed_predictions = [majority(word_preds) for word_preds, _ in groups]
        processed_gold = [majority(word_gold) for _, word_gold in groups]
        return processed_predictions, processed_gold

    processed_predictions_list = []
    processed_gold_list = []

    for data in data_list:
        processed_predictions, processed_gold = process_single(data)
        processed_predictions_list.append(processed_predictions)
        processed_gold_list.append(processed_gold)

    return processed_predictions_list, processed_gold_list


def evaluate_predictions(p, test_data):
    """Print a word-level classification report and return the flat tag lists.

    Args:
        p: 3-item sequence unpacked as (predictions, labels, _); in this
            notebook it is the result of trainer.predict(test_data).
            predictions are per-position class scores (argmax is taken over
            the last axis) and labels use -100 for positions to ignore.
        test_data: dataset with a 'word_ids' column and a
            .map(fn, with_indices=True) method; it must be row-aligned with p.

    Returns:
        (flat_pred, flat_gold): word-level predicted and gold tag strings,
        concatenated over all examples.
    """
    predictions, labels, _ = p

    pred_indices = [np.argmax(seq_scores, axis=-1) for seq_scores in predictions]

    # Keep only positions whose gold label is not the ignore index (-100).
    pred_tags = [[id2label[pred_id] for pred_id, label_id in zip(p_seq, l_seq) if label_id != -100]
                 for p_seq, l_seq in zip(pred_indices, labels)]
    gold_tags = [[id2label[label_id] for label_id in l_seq if label_id != -100]
                 for l_seq in labels]

    def add_preds(example, idx):
        # Cap the tag lists at the number of word_ids positions of the row.
        length = len(example['word_ids'])
        example['predictions'] = pred_tags[idx][:length]
        example['gold'] = gold_tags[idx][:length]
        return example

    test_data = test_data.map(add_preds, with_indices=True)

    # Majority-vote the sub-token tags down to one tag per word.
    pred, gold = process_data(test_data)

    flat_pred = [label for seq in pred for label in seq]
    flat_gold = [label for seq in gold for label in seq]

    print(classification_report(flat_gold, flat_pred, zero_division=0))

    return (flat_pred, flat_gold)

In [8]:
# Fast tokenizer for the base checkpoint; tokenize_and_align_labels relies on
# its word_ids() mapping.
tokenizer = BertTokenizerFast.from_pretrained(model_name)

# Tag strings indexed by label id; the order matches conll_label_to_id
# ('O' -> 0, 'B-PER' -> 1, ..., 'I-MISC' -> 8). Used by compute_metrics.
label_list = ['O', 'B-PER', 'I-PER', 'B-ORG',
              'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']


def tokenize_and_align_labels(examples):
    """Tokenize pre-split words and spread word-level labels over sub-tokens.

    Intended for use with Dataset.map(..., batched=True).

    Args:
        examples: batch mapping with
            - "tokens": list of word lists (one list per sentence);
            - "labels": list of per-word label id lists, aligned with "tokens".

    Returns:
        The tokenizer output for the batch, extended with
            - "labels": per-position label ids, where every sub-token of a
              word carries that word's label and positions that map to no
              word get -100;
            - "word_ids": per-position word index (None where there is none).
    """
    tokenized_inputs = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
        padding=True,
        return_special_tokens_mask=True,
        return_offsets_mapping=True,
    )
    all_word_ids = []
    all_labels = []
    for i, labels in enumerate(examples["labels"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        all_word_ids.append(word_ids)
        # -100 marks positions that the evaluation code filters out.
        all_labels.append(
            [-100 if word_idx is None else labels[word_idx] for word_idx in word_ids]
        )

    tokenized_inputs["labels"] = all_labels
    tokenized_inputs["word_ids"] = all_word_ids
    return tokenized_inputs

In [9]:
# Tokenize every split and attach the aligned "labels" / "word_ids" columns.
# Each name is rebound to its tokenized version, so this cell is meant to be
# run once per kernel session after the splits are loaded.
conll_main = conll_main.map(tokenize_and_align_labels, batched=True)
conll_clean = conll_clean.map(tokenize_and_align_labels, batched=True)
ontonotes_main = ontonotes_main.map(tokenize_and_align_labels, batched=True)
ontonotes_clean = ontonotes_clean.map(tokenize_and_align_labels, batched=True)

In [10]:
def caps_correct(flat_tokens, x, y, max_examples=30):
    """Print how often tokens starting with an uppercase letter are tagged correctly.

    Args:
        flat_tokens: flat sequence of token strings.
        x: predicted tags, aligned one-to-one with flat_tokens.
        y: gold tags, aligned one-to-one with flat_tokens.
        max_examples: maximum number of mismatching capitalized tokens that
            are printed as "token pred gold" lines (default 30). All
            mismatches are still counted.

    Returns:
        None. Prints the sample mismatches followed by the two counts, or
        'not equal' when the three sequences differ in length.
    """
    if not (len(flat_tokens) == len(x) == len(y)):
        print('not equal')
        return

    correct_caps = 0
    incorrect_caps = 0

    for token, pred, gold in zip(flat_tokens, x, y):
        # Only non-empty tokens whose first character is uppercase count.
        if not token or not token[0].isupper():
            continue
        if pred == gold:
            correct_caps += 1
        else:
            if incorrect_caps < max_examples:
                print(token, pred, gold)
            incorrect_caps += 1

    print(f"Capitalized Tokens - Correct Predictions: {correct_caps}")
    print(f"Capitalized Tokens - Incorrect Predictions: {incorrect_caps}")

In [11]:
# seqeval metric object (loaded through evaluate.load); compute_metrics calls
# its .compute() with lists of tag strings.
metric = load_metric("seqeval")


def compute_metrics(p):
    """Score predictions with the module-level seqeval metric.

    Args:
        p: pair (predictions, labels). predictions holds per-position class
            scores (argmax is taken over axis 2); labels holds label ids with
            -100 at positions to skip.

    Returns:
        Whatever metric.compute returns for the tag-string sequences.
    """
    scores, gold_ids = p
    pred_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for pred_row, gold_row in zip(pred_ids, gold_ids):
        kept_preds = []
        kept_golds = []
        for pred_id, gold_id in zip(pred_row, gold_row):
            # Skip positions whose gold label is the ignore index.
            if gold_id != -100:
                kept_preds.append(label_list[pred_id])
                kept_golds.append(label_list[gold_id])
        true_predictions.append(kept_preds)
        true_labels.append(kept_golds)

    return metric.compute(predictions=true_predictions, references=true_labels)

In [12]:
# Evaluate the checkpoint saved under the 'onto' name on conll_main.
# NOTE(review): the checkpoint is presumably the model fine-tuned on
# ontonotes_main -- confirm against the training notebook.
train_data = ontonotes_main
test_data = conll_main

train_data_name = 'onto'

model_path = f"./saved_model/{model_name}_{train_data_name}"
mod = AutoModelForTokenClassification.from_pretrained(model_path)

data_collator = DataCollatorForTokenClassification(tokenizer)

# Only trainer.predict is called below, so the optimisation hyper-parameters
# are not exercised in this cell. load_best_model_at_end was dropped: with
# evaluation and saving both disabled there is no "best" checkpoint to load.
training_args = TrainingArguments(
    output_dir="./output/random",
    eval_strategy="no",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    weight_decay=0.01,
    report_to="none",
    # fp16 needs a GPU; fall back to full precision on CPU-only machines.
    fp16=torch.cuda.is_available(),
    save_strategy="no",
)

trainer = Trainer(
    model=mod,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=test_data,
    # `tokenizer=` is deprecated in Trainer; `processing_class` replaces it.
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

predictions = trainer.predict(test_data)

# x = flat word-level predictions, y = flat word-level gold tags.
x, y = evaluate_predictions(predictions, test_data)

caps_correct(flat_conll, x, y)

  trainer = Trainer(


              precision    recall  f1-score   support

       B-LOC       0.80      0.80      0.80      8535
      B-MISC       0.76      0.68      0.71      4062
       B-ORG       0.75      0.50      0.60      7398
       B-PER       0.86      0.93      0.89      7975
       I-LOC       0.68      0.62      0.65      1356
      I-MISC       0.53      0.58      0.56      1380
       I-ORG       0.73      0.76      0.74      4251
       I-PER       0.89      0.99      0.94      5503
           O       0.98      0.99      0.99    201398

    accuracy                           0.95    241858
   macro avg       0.78      0.76      0.77    241858
weighted avg       0.95      0.95      0.95    241858

Hong B-LOC B-MISC
Kong-based O I-MISC
Wuxi I-ORG B-ORG
Jiangsu I-LOC B-LOC
ISLAMABAD O B-LOC
BARCELONA B-MISC B-ORG
ATLETICO O B-ORG
SUPERCUP O B-MISC
Portsmouth B-LOC B-ORG
Corser O B-PER
Vatican B-ORG B-LOC
Owen-Jones I-PER B-PER
Air O B-ORG
Cargo O I-ORG
Newsroom O I-ORG
The B-ORG O
Lebanese

In [13]:
# Evaluate the checkpoint saved under the 'conll' name on ontonotes_main.
# NOTE(review): the checkpoint is presumably the model fine-tuned on
# conll_main -- confirm against the training notebook.
train_data = conll_main
test_data = ontonotes_main

train_data_name = 'conll'

model_path = f"./saved_model/{model_name}_{train_data_name}"
mod = AutoModelForTokenClassification.from_pretrained(model_path)

data_collator = DataCollatorForTokenClassification(tokenizer)

# Only trainer.predict is called below, so the optimisation hyper-parameters
# are not exercised in this cell. load_best_model_at_end was dropped: with
# evaluation and saving both disabled there is no "best" checkpoint to load.
training_args = TrainingArguments(
    output_dir="./output/random",
    eval_strategy="no",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    weight_decay=0.01,
    report_to="none",
    # fp16 needs a GPU; fall back to full precision on CPU-only machines.
    fp16=torch.cuda.is_available(),
    save_strategy="no",
)

trainer = Trainer(
    model=mod,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=test_data,
    # `tokenizer=` is deprecated in Trainer; `processing_class` replaces it.
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

predictions = trainer.predict(test_data)

# x = flat word-level predictions, y = flat word-level gold tags.
x, y = evaluate_predictions(predictions, test_data)

caps_correct(flat_onto, x, y)

  trainer = Trainer(


              precision    recall  f1-score   support

       B-LOC       0.83      0.89      0.86     17495
      B-MISC       0.64      0.73      0.68     10657
       B-ORG       0.63      0.68      0.65     13041
       B-PER       0.90      0.93      0.91     15547
       I-LOC       0.71      0.66      0.68      5367
      I-MISC       0.66      0.37      0.47      7305
       I-ORG       0.90      0.71      0.79     18313
       I-PER       0.95      0.88      0.91     11086
           O       0.99      0.99      0.99   1011383

    accuracy                           0.97   1110194
   macro avg       0.80      0.76      0.77   1110194
weighted avg       0.97      0.97      0.97   1110194

Iraq B-LOC I-MISC
A B-LOC I-LOC
Happy B-MISC I-MISC
D. B-ORG B-MISC
Wash. B-ORG B-LOC
Capital B-ORG I-ORG
Smirnoff B-MISC B-ORG
Jack B-PER B-ORG
Daniel I-PER I-ORG
NYSE B-MISC B-ORG
White B-LOC I-ORG
House I-LOC I-ORG
Kate O B-PER
Mr. B-MISC O
October B-MISC B-PER
Grinch O B-PER
A. B-PER B-ORG


In [14]:
# Evaluate the checkpoint saved under the 'conll' name on conll_clean.
# NOTE(review): the checkpoint is presumably the model fine-tuned on
# conll_main -- confirm against the training notebook.
train_data = conll_main
test_data = conll_clean

train_data_name = 'conll'

model_path = f"./saved_model/{model_name}_{train_data_name}"
mod = AutoModelForTokenClassification.from_pretrained(model_path)

data_collator = DataCollatorForTokenClassification(tokenizer)

# Only trainer.predict is called below, so the optimisation hyper-parameters
# are not exercised in this cell. load_best_model_at_end was dropped: with
# evaluation and saving both disabled there is no "best" checkpoint to load.
training_args = TrainingArguments(
    output_dir="./output/random",
    eval_strategy="no",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    weight_decay=0.01,
    report_to="none",
    # fp16 needs a GPU; fall back to full precision on CPU-only machines.
    fp16=torch.cuda.is_available(),
    save_strategy="no",
)

trainer = Trainer(
    model=mod,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=test_data,
    # `tokenizer=` is deprecated in Trainer; `processing_class` replaces it.
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

predictions = trainer.predict(test_data)

# x = flat word-level predictions, y = flat word-level gold tags.
x, y = evaluate_predictions(predictions, test_data)

caps_correct(flat_conll_clean, x, y)

  trainer = Trainer(


              precision    recall  f1-score   support

       B-LOC       0.96      0.97      0.96      2110
      B-MISC       0.92      0.91      0.92      1000
       B-ORG       0.94      0.95      0.95      1925
       B-PER       0.98      0.98      0.98      2084
       I-LOC       0.91      0.90      0.91       315
      I-MISC       0.86      0.85      0.86       337
       I-ORG       0.93      0.96      0.94      1039
       I-PER       1.00      0.99      0.99      1488
           O       1.00      1.00      1.00     49262

    accuracy                           0.99     59560
   macro avg       0.95      0.95      0.95     59560
weighted avg       0.99      0.99      0.99     59560

MRI O B-MISC
Santander B-ORG B-LOC
BAYERN B-PER B-ORG
BUNDESLIGA B-LOC B-MISC
Nicol B-ORG B-PER
Eyles B-ORG B-PER
Leicester B-ORG B-LOC
New B-MISC O
AHOLD B-LOC B-ORG
Eurograde B-MISC O
Far B-LOC O
North I-LOC O
Czech B-LOC B-ORG
Maronite B-MISC B-ORG
UK-US B-LOC B-MISC
BTPs B-ORG O
CAMPESE B-L

In [15]:
# Evaluate the checkpoint saved under the 'onto' name on ontonotes_clean.
# NOTE(review): the checkpoint is presumably the model fine-tuned on
# ontonotes_main -- confirm against the training notebook.
train_data = ontonotes_main
test_data = ontonotes_clean

train_data_name = 'onto'

model_path = f"./saved_model/{model_name}_{train_data_name}"
mod = AutoModelForTokenClassification.from_pretrained(model_path)

data_collator = DataCollatorForTokenClassification(tokenizer)

# Only trainer.predict is called below, so the optimisation hyper-parameters
# are not exercised in this cell. load_best_model_at_end was dropped: with
# evaluation and saving both disabled there is no "best" checkpoint to load.
training_args = TrainingArguments(
    output_dir="./output/random",
    eval_strategy="no",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    weight_decay=0.01,
    report_to="none",
    # fp16 needs a GPU; fall back to full precision on CPU-only machines.
    fp16=torch.cuda.is_available(),
    save_strategy="no",
)

trainer = Trainer(
    model=mod,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=test_data,
    # `tokenizer=` is deprecated in Trainer; `processing_class` replaces it.
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

predictions = trainer.predict(test_data)

# x = flat word-level predictions, y = flat word-level gold tags.
x, y = evaluate_predictions(predictions, test_data)

caps_correct(flat_onto_clean, x, y)

  trainer = Trainer(


              precision    recall  f1-score   support

       B-LOC       0.97      0.97      0.97      4315
      B-MISC       0.91      0.91      0.91      2722
       B-ORG       0.92      0.94      0.93      3314
       B-PER       0.96      0.97      0.96      3890
       I-LOC       0.93      0.92      0.93      1258
      I-MISC       0.87      0.84      0.85      2067
       I-ORG       0.94      0.96      0.95      4675
       I-PER       0.97      0.97      0.97      2868
           O       1.00      1.00      1.00    253580

    accuracy                           0.99    278689
   macro avg       0.94      0.94      0.94    278689
weighted avg       0.99      0.99      0.99    278689

Constitutional O B-MISC
CCTV O B-ORG
Albo B-ORG B-PER
THE B-ORG O
CHILDREN B-ORG O
Child O B-MISC
Game O I-MISC
I O B-PER
Arab B-MISC I-PER
Contra B-MISC B-ORG
Pan B-ORG B-PER
American I-ORG I-PER
Thanksgiving B-MISC O
Iowa B-LOC I-LOC
WW2 B-ORG B-MISC
PhD B-MISC O
Suez I-MISC O
Canal I-MISC O
