## parse dataset

In [1]:
from datasets import load_dataset, load_metric # Import dataset import function for hugging face
dataset = load_dataset("surrey-nlp/PLOD-CW") # import the coursework dataset from

from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification, pipeline
from transformers import BatchEncoding
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

import evaluate
seqeval = evaluate.load("seqeval")
metric = load_metric("seqeval")

import numpy as np
import os

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  metric = load_metric("seqeval")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [2]:
task = "ner" # Should be one of "ner", "pos" or "chunk"
model_checkpoint = "distilbert-base-uncased"
batch_size = 16

In [3]:
v = dataset["train"][286:287]["tokens"][0]
input_tokens = tokenizer(v, is_split_into_words=True)
input_ids = input_tokens["input_ids"]
id_tokens = []
for token in input_ids:
    id_tokens.append(tokenizer.convert_ids_to_tokens(token))
print(v)
print(id_tokens)
print(input_ids)
print(len(input_ids))

['(', 'EGF', ',', 'epidermal', 'growth', 'factor', ';', 'TGF', ',', 'transforming', 'growth', 'factor', ';', 'BTC', ',', 'betacellulin', ';', 'HB', '-', 'EGF', ',', 'heparin', '-', 'binding', 'epidermal', 'growth', 'factor', '(', 'EGF)-like', 'growth', 'factor', ';', 'EREG', ',', 'epiregulin', ';', 'NRG1', ',', 'neuregulin-1', ';', 'NRG2', ',', 'neuregulin-2', ';', 'NRG3', ',', 'neuregulin-3', ';', 'NRG4', ',', 'neuregulin-4', ';', 'PLCÎ³', ',', 'phospholipase', 'C', 'type', 'gamma', ';', 'CAMK2B', ',', 'calcium', '/', 'calmodulin', 'dependent', 'protein', 'kinase', ';', 'PRKCB', ',', 'Protein', 'kinase', 'C', '-', 'beta', ';', 'STAT5', ',', 'Signal', 'transducer', 'and', 'activator', 'of', 'transcription', '5', ';', 'src', ',', 'Rous', 'sarcoma', 'virus', 'gene', ';', 'CRK', ',', 'C', 'T10', 'regulator', 'of', 'a', 'tyrosine', 'kinase', ';', 'NCL', ',', 'NCK', 'Adaptor', 'Protein', '2', ';', 'PTK2', ',', 'PTK2', 'protein', 'tyrosine', 'kinase', '2', ';', 'ABL2', ',', 'V', '-', 'Abl', 

In [4]:
train_dict = dataset["train"]
train_tokens = train_dict["tokens"]
train_pos_tags = train_dict["pos_tags"]
train_ner_tags = train_dict["ner_tags"]

validation_dict = dataset["validation"]
validation_tokens = validation_dict["tokens"]
validation_pos_tags = validation_dict["pos_tags"]
validation_ner_tags = validation_dict["ner_tags"]

test_dict = dataset["test"]
test_tokens = test_dict["tokens"]
test_pos_tags = test_dict["pos_tags"]
test_ner_tags = test_dict["ner_tags"]

def data_to_lower(data:list[list[str]]) -> list[list[str]]:
    return [[token.lower() for token in tokens] for tokens in data]

train_tokens = data_to_lower(train_tokens)
validation_tokens = data_to_lower(validation_tokens)
test_tokens = data_to_lower(test_tokens)

class DataItem:
    def __init__(self, tokens, pos, ner, idx=0):
        self.idx=idx
        self.tokens:list[str] = tokens
        self.pos:list[str] = pos
        self.ner:list = ner
        self.tokenised_inputs:BatchEncoding = tokenizer(self.tokens, is_split_into_words=True) # also contains attention mask!

    def get_as_tuple(self) -> tuple:
        return (self.tokens, self.pos, self.ner)
    
    def get_as_tuple_list(self) -> list[tuple]:
        tuple_list = []
        for idx in range(len(self.tokens)-1):
            tuple_list.append((self.tokens[idx], self.pos[idx], self.ner[idx]))
        return tuple_list
    
    def ner_label2idx(self, label2idx_dict):
        if not isinstance(self.ner[0], str):
            print("WARNING - NER not listed as labels! NER Type: ",type(self.ner[0]),", Exiting...")
            return
        for idx, ner in enumerate(self.ner):
            ner[idx] = label2idx_dict[ner]
    
    def ner_idx2label(self, idx2label_dict):
        if not isinstance(self.ner[0], int):
            print("WARNING - NER not listed as indecies! Exiting...")
            return
        for idx, ner in enumerate(self.ner):
            ner[idx] = idx2label_dict[ner]

class DataCollection:
    def __init__(self, data_collection:list[DataItem], max_token_length=512):
        self.max_token_length = max_token_length
        self.data_collection:list[DataItem] = data_collection
        self.unique_tags = self.get_unique_tags()
        self.item_embeddings:dict = self.create_item_embeddings(self.unique_tags)
        self.reverse_embeddings:dict = {v:k for k,v in self.item_embeddings.items()}

    def get_token_list(self) -> list[list[str]]:
        return [data_item.tokens for data_item in self.data_collection]

    def get_pos_list(self) -> list[list[str]]:
        return [data_item.pos for data_item in self.data_collection]

    def get_ner_list(self) -> list[list[str]]:
        return [data_item.ner for data_item in self.data_collection]
    
    def get_ner_idx_list(self) -> list[list[str]]:
        ner_idx_list_collection = []
        for data_item in self.data_collection:
            ner_idx_list = []
            for ner_tag in data_item.ner:
                ner_idx_list.append(self.item_embeddings[ner_tag])
            ner_idx_list_collection.append(ner_idx_list)
        return ner_idx_list_collection

    def get_unique_tags(self) -> list[str]:
        unique_list = []
        ner_tags_list:list = self.get_ner_list()
        for ner_list in ner_tags_list:
            for ner in ner_list:
                if ner not in unique_list:
                    unique_list.append(ner)
        return unique_list
    
    def create_item_embeddings(self, tags:list[str]) -> dict:
        return {label:idx for idx, label in enumerate(tags)}
    
    def get_invalid_token_lengths(self) -> list[int]:
        invalid_lengths = []
        for idx, data_item in enumerate(self.data_collection):
            if len(data_item.tokenised_inputs["input_ids"]) >= self.max_token_length:
                invalid_lengths.append(idx)
                print("Data item idx ",idx," has tokens longer than ",self.max_token_length)
        return invalid_lengths

    def remove_invalid_token_length_items(self) -> None:
        invalid_lengths = self.get_invalid_token_lengths()
        for index in sorted(invalid_lengths, reverse=True):
            del self.data_collection[index]

train_data:list[DataItem] = []
for idx in range(len(train_tokens)):
    train_data.append(DataItem(train_tokens[idx], train_pos_tags[idx], train_ner_tags[idx], idx))
train_collection:DataCollection = DataCollection(train_data)

validation_data:list[DataItem] = []
for idx in range(len(validation_tokens)):
    validation_data.append(DataItem(validation_tokens[idx], validation_pos_tags[idx], validation_ner_tags[idx], idx))
validation_collection:DataCollection = DataCollection(validation_data)

test_data:list[DataItem] = []
for idx in range(len(test_tokens)):
    test_data.append(DataItem(test_tokens[idx], test_pos_tags[idx], test_ner_tags[idx], idx))
test_collection:DataCollection = DataCollection(test_data)

In [6]:
train_collection.remove_invalid_token_length_items()
validation_collection.remove_invalid_token_length_items()
test_collection.remove_invalid_token_length_items()

In [7]:
# Following https://huggingface.co/docs/transformers/main/en/tasks/token_classification

# More on https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification.ipynb#scrollTo=IjyhFKjlEP_B

## Tokenization Example

Adds extra start and end tags (CLS and SEP), as well potentially splits one word into 2. Thus have to realign indecies.

We also have to assign -100 to CLS and SEP so they are ignored by PyTorch loss function (CrossEntropyLoss)

Only label first token of a word, add -100 for subtokens of the same word

In [8]:
example = dataset["train"][0]
tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
print(tokens)
print(tokenized_input)

['[CLS]', 'for', 'this', 'purpose', 'the', 'gothenburg', 'young', 'persons', 'empowerment', 'scale', '(', 'g', '##ype', '##s', ')', 'was', 'developed', '.', '[SEP]']
{'input_ids': [101, 2005, 2023, 3800, 1996, 22836, 2402, 5381, 23011, 4094, 1006, 1043, 18863, 2015, 1007, 2001, 2764, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [8]:
def tokenize_and_align_labels(data_collection:DataCollection):
    tokenized_inputs = tokenizer(data_collection.get_token_list(), truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(data_collection.get_ner_idx_list()):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:  # Set the special tokens to -100.
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

tokenized_train = tokenize_and_align_labels(train_collection)
tokenized_validation = tokenize_and_align_labels(validation_collection)
tokenized_test = tokenize_and_align_labels(test_collection)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [9]:
label_list = train_collection.unique_tags
labels = train_collection.unique_tags


def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

id2label = train_collection.reverse_embeddings
label2id = train_collection.item_embeddings

In [10]:
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(labels), id2label=id2label, label2id=label2id)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
def turn_dict_to_list_of_dict(d):
    new_list = []

    for labels, inputs in zip(d["labels"], d["input_ids"]):
        entry = {"input_ids": inputs, "labels": labels}
        new_list.append(entry)

    return new_list

tokenised_train = turn_dict_to_list_of_dict(tokenized_train)
tokenised_val = turn_dict_to_list_of_dict(tokenized_validation)
tokenised_test = turn_dict_to_list_of_dict(tokenized_test)

In [12]:
model_output_dir = "distilbert_model"

training_args = TrainingArguments(
    output_dir=model_output_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3, # number of epochs to train
    weight_decay=0.01, # The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights
    evaluation_strategy="epoch",
    save_strategy="epoch", # can save by epoch, steps or not at all
    save_total_limit=1, # how many checkpoints to keep before overriding (set to 1, so latest checkpoint is only kept)!
    load_best_model_at_end=True,
    report_to=['none'], # REQUIRED because otherwise keeps asking to log into "wandb"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenised_train,
    eval_dataset=tokenised_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)


  0%|          | 0/3350 [00:00<?, ?it/s]

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


  0%|          | 0/10 [00:00<?, ?it/s]

Checkpoint destination directory distilbert_model\checkpoint-67 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'eval_loss': 0.5762450098991394, 'eval_precision': 0.872043918918919, 'eval_recall': 0.8261652330466093, 'eval_f1': 0.8484848484848484, 'eval_accuracy': 0.8296, 'eval_runtime': 0.1969, 'eval_samples_per_second': 777.179, 'eval_steps_per_second': 50.796, 'epoch': 1.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 0.7484994530677795, 'eval_precision': 0.8774989366227137, 'eval_recall': 0.8253650730146029, 'eval_f1': 0.8506339552623441, 'eval_accuracy': 0.823, 'eval_runtime': 0.1864, 'eval_samples_per_second': 821.002, 'eval_steps_per_second': 53.66, 'epoch': 2.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 0.8061400651931763, 'eval_precision': 0.8869174861051732, 'eval_recall': 0.8299659931986397, 'eval_f1': 0.8574971582101892, 'eval_accuracy': 0.8292, 'eval_runtime': 0.1813, 'eval_samples_per_second': 843.698, 'eval_steps_per_second': 55.144, 'epoch': 3.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 0.8855600953102112, 'eval_precision': 0.879940183721427, 'eval_recall': 0.8239647929585917, 'eval_f1': 0.8510330578512396, 'eval_accuracy': 0.8224, 'eval_runtime': 0.1756, 'eval_samples_per_second': 871.096, 'eval_steps_per_second': 56.934, 'epoch': 4.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 0.877755880355835, 'eval_precision': 0.887350936967632, 'eval_recall': 0.8335667133426685, 'eval_f1': 0.8596183599793709, 'eval_accuracy': 0.8324, 'eval_runtime': 0.1774, 'eval_samples_per_second': 862.404, 'eval_steps_per_second': 56.366, 'epoch': 5.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 0.9753626585006714, 'eval_precision': 0.886648270006368, 'eval_recall': 0.8355671134226845, 'eval_f1': 0.8603501544799176, 'eval_accuracy': 0.834, 'eval_runtime': 0.1795, 'eval_samples_per_second': 852.227, 'eval_steps_per_second': 55.701, 'epoch': 6.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.0449007749557495, 'eval_precision': 0.8811817597944765, 'eval_recall': 0.8233646729345869, 'eval_f1': 0.85129265770424, 'eval_accuracy': 0.8214, 'eval_runtime': 0.1999, 'eval_samples_per_second': 765.232, 'eval_steps_per_second': 50.015, 'epoch': 7.0}
{'loss': 0.1602, 'learning_rate': 1.701492537313433e-05, 'epoch': 7.46}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.0407954454421997, 'eval_precision': 0.8890308839190628, 'eval_recall': 0.8349669933986797, 'eval_f1': 0.8611512275634413, 'eval_accuracy': 0.8332, 'eval_runtime': 0.1804, 'eval_samples_per_second': 848.121, 'eval_steps_per_second': 55.433, 'epoch': 8.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.073683261871338, 'eval_precision': 0.8860786397449522, 'eval_recall': 0.8339667933586717, 'eval_f1': 0.8592333058532563, 'eval_accuracy': 0.8324, 'eval_runtime': 0.1835, 'eval_samples_per_second': 833.665, 'eval_steps_per_second': 54.488, 'epoch': 9.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.1222586631774902, 'eval_precision': 0.8875638841567292, 'eval_recall': 0.8337667533506702, 'eval_f1': 0.8598246518824137, 'eval_accuracy': 0.8318, 'eval_runtime': 0.1786, 'eval_samples_per_second': 856.647, 'eval_steps_per_second': 55.99, 'epoch': 10.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.1650903224945068, 'eval_precision': 0.8838887709615793, 'eval_recall': 0.8329665933186637, 'eval_f1': 0.8576725025746652, 'eval_accuracy': 0.8316, 'eval_runtime': 0.1906, 'eval_samples_per_second': 802.525, 'eval_steps_per_second': 52.453, 'epoch': 11.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.1740355491638184, 'eval_precision': 0.8847299021692897, 'eval_recall': 0.8321664332866573, 'eval_f1': 0.8576435419028966, 'eval_accuracy': 0.83, 'eval_runtime': 0.1816, 'eval_samples_per_second': 842.631, 'eval_steps_per_second': 55.074, 'epoch': 12.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.2134870290756226, 'eval_precision': 0.884974533106961, 'eval_recall': 0.8341668333666733, 'eval_f1': 0.8588198949644734, 'eval_accuracy': 0.8322, 'eval_runtime': 0.1854, 'eval_samples_per_second': 825.069, 'eval_steps_per_second': 53.926, 'epoch': 13.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.2318637371063232, 'eval_precision': 0.8857569933803118, 'eval_recall': 0.8297659531906382, 'eval_f1': 0.8568477587275356, 'eval_accuracy': 0.8284, 'eval_runtime': 0.19, 'eval_samples_per_second': 805.159, 'eval_steps_per_second': 52.625, 'epoch': 14.0}
{'loss': 0.0209, 'learning_rate': 1.4029850746268658e-05, 'epoch': 14.93}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.2416068315505981, 'eval_precision': 0.8864313058273076, 'eval_recall': 0.8337667533506702, 'eval_f1': 0.8592928564065561, 'eval_accuracy': 0.831, 'eval_runtime': 0.1936, 'eval_samples_per_second': 790.091, 'eval_steps_per_second': 51.64, 'epoch': 15.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.278630256652832, 'eval_precision': 0.8848536637470626, 'eval_recall': 0.8285657131426285, 'eval_f1': 0.8557851239669423, 'eval_accuracy': 0.8268, 'eval_runtime': 0.181, 'eval_samples_per_second': 845.077, 'eval_steps_per_second': 55.234, 'epoch': 16.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.273123025894165, 'eval_precision': 0.8851796725494365, 'eval_recall': 0.8327665533106622, 'eval_f1': 0.8581735724592868, 'eval_accuracy': 0.8308, 'eval_runtime': 0.1832, 'eval_samples_per_second': 835.254, 'eval_steps_per_second': 54.592, 'epoch': 17.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.3050628900527954, 'eval_precision': 0.8875689435723377, 'eval_recall': 0.8369673934786958, 'eval_f1': 0.8615257901781119, 'eval_accuracy': 0.8342, 'eval_runtime': 0.1837, 'eval_samples_per_second': 832.74, 'eval_steps_per_second': 54.427, 'epoch': 18.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.317600965499878, 'eval_precision': 0.8857384680490902, 'eval_recall': 0.837367473494699, 'eval_f1': 0.8608740359897172, 'eval_accuracy': 0.8354, 'eval_runtime': 0.1921, 'eval_samples_per_second': 796.541, 'eval_steps_per_second': 52.062, 'epoch': 19.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.347732424736023, 'eval_precision': 0.8874707509040629, 'eval_recall': 0.8345669133826765, 'eval_f1': 0.8602061855670103, 'eval_accuracy': 0.8316, 'eval_runtime': 0.1847, 'eval_samples_per_second': 828.251, 'eval_steps_per_second': 54.134, 'epoch': 20.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.3546124696731567, 'eval_precision': 0.8855075547988934, 'eval_recall': 0.832366473294659, 'eval_f1': 0.8581150752732521, 'eval_accuracy': 0.83, 'eval_runtime': 0.2056, 'eval_samples_per_second': 744.008, 'eval_steps_per_second': 48.628, 'epoch': 21.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.3885867595672607, 'eval_precision': 0.8871413390010627, 'eval_recall': 0.8349669933986797, 'eval_f1': 0.8602638087386645, 'eval_accuracy': 0.8316, 'eval_runtime': 0.1896, 'eval_samples_per_second': 806.973, 'eval_steps_per_second': 52.743, 'epoch': 22.0}
{'loss': 0.0057, 'learning_rate': 1.1044776119402986e-05, 'epoch': 22.39}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.3998149633407593, 'eval_precision': 0.8864313058273076, 'eval_recall': 0.8337667533506702, 'eval_f1': 0.8592928564065561, 'eval_accuracy': 0.831, 'eval_runtime': 0.1938, 'eval_samples_per_second': 789.331, 'eval_steps_per_second': 51.59, 'epoch': 23.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.4089477062225342, 'eval_precision': 0.8851623859053279, 'eval_recall': 0.8341668333666733, 'eval_f1': 0.858908341915551, 'eval_accuracy': 0.831, 'eval_runtime': 0.2054, 'eval_samples_per_second': 745.067, 'eval_steps_per_second': 48.697, 'epoch': 24.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.4438952207565308, 'eval_precision': 0.8841723549488054, 'eval_recall': 0.8291658331666333, 'eval_f1': 0.855786105089295, 'eval_accuracy': 0.827, 'eval_runtime': 0.1947, 'eval_samples_per_second': 785.939, 'eval_steps_per_second': 51.369, 'epoch': 25.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.4539906978607178, 'eval_precision': 0.8846317581949766, 'eval_recall': 0.831366273254651, 'eval_f1': 0.8571723213364959, 'eval_accuracy': 0.8282, 'eval_runtime': 0.1958, 'eval_samples_per_second': 781.241, 'eval_steps_per_second': 51.062, 'epoch': 26.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.4531972408294678, 'eval_precision': 0.8865429729153338, 'eval_recall': 0.8315663132626525, 'eval_f1': 0.8581750619322874, 'eval_accuracy': 0.8292, 'eval_runtime': 0.1917, 'eval_samples_per_second': 798.01, 'eval_steps_per_second': 52.158, 'epoch': 27.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.4620660543441772, 'eval_precision': 0.8850746268656716, 'eval_recall': 0.830366073214643, 'eval_f1': 0.8568479719269273, 'eval_accuracy': 0.828, 'eval_runtime': 0.1873, 'eval_samples_per_second': 816.839, 'eval_steps_per_second': 53.388, 'epoch': 28.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.49458646774292, 'eval_precision': 0.8837855159154027, 'eval_recall': 0.8275655131026205, 'eval_f1': 0.8547520661157025, 'eval_accuracy': 0.8256, 'eval_runtime': 0.2, 'eval_samples_per_second': 764.926, 'eval_steps_per_second': 49.995, 'epoch': 29.0}
{'loss': 0.0026, 'learning_rate': 8.059701492537314e-06, 'epoch': 29.85}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.490684151649475, 'eval_precision': 0.8845580404685836, 'eval_recall': 0.8307661532306462, 'eval_f1': 0.8568186507117804, 'eval_accuracy': 0.8288, 'eval_runtime': 0.197, 'eval_samples_per_second': 776.508, 'eval_steps_per_second': 50.752, 'epoch': 30.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.478588342666626, 'eval_precision': 0.8858116744780571, 'eval_recall': 0.8317663532706542, 'eval_f1': 0.8579387186629527, 'eval_accuracy': 0.8294, 'eval_runtime': 0.2035, 'eval_samples_per_second': 751.881, 'eval_steps_per_second': 49.143, 'epoch': 31.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.4939507246017456, 'eval_precision': 0.8842217484008529, 'eval_recall': 0.8295659131826365, 'eval_f1': 0.8560222933223244, 'eval_accuracy': 0.827, 'eval_runtime': 0.1937, 'eval_samples_per_second': 789.726, 'eval_steps_per_second': 51.616, 'epoch': 32.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.4894871711730957, 'eval_precision': 0.8864023870417732, 'eval_recall': 0.8319663932786557, 'eval_f1': 0.858322154576411, 'eval_accuracy': 0.829, 'eval_runtime': 0.2057, 'eval_samples_per_second': 743.733, 'eval_steps_per_second': 48.61, 'epoch': 33.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5210912227630615, 'eval_precision': 0.8850991259859305, 'eval_recall': 0.8305661132226445, 'eval_f1': 0.8569659442724458, 'eval_accuracy': 0.8274, 'eval_runtime': 0.1936, 'eval_samples_per_second': 790.42, 'eval_steps_per_second': 51.661, 'epoch': 34.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5031026601791382, 'eval_precision': 0.8863587997446265, 'eval_recall': 0.8331666333266653, 'eval_f1': 0.8589399876263146, 'eval_accuracy': 0.8302, 'eval_runtime': 0.1937, 'eval_samples_per_second': 789.789, 'eval_steps_per_second': 51.62, 'epoch': 35.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5021260976791382, 'eval_precision': 0.8870658427445131, 'eval_recall': 0.8327665533106622, 'eval_f1': 0.8590590177465952, 'eval_accuracy': 0.8308, 'eval_runtime': 0.2015, 'eval_samples_per_second': 759.473, 'eval_steps_per_second': 49.639, 'epoch': 36.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.507655382156372, 'eval_precision': 0.887137989778535, 'eval_recall': 0.833366673334667, 'eval_f1': 0.859412068076328, 'eval_accuracy': 0.8312, 'eval_runtime': 0.2, 'eval_samples_per_second': 765.11, 'eval_steps_per_second': 50.007, 'epoch': 37.0}
{'loss': 0.0016, 'learning_rate': 5.074626865671642e-06, 'epoch': 37.31}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5143983364105225, 'eval_precision': 0.8869250425894378, 'eval_recall': 0.8331666333266653, 'eval_f1': 0.8592057761732852, 'eval_accuracy': 0.831, 'eval_runtime': 0.1937, 'eval_samples_per_second': 789.697, 'eval_steps_per_second': 51.614, 'epoch': 38.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5133357048034668, 'eval_precision': 0.885835995740149, 'eval_recall': 0.8319663932786557, 'eval_f1': 0.858056529812255, 'eval_accuracy': 0.8306, 'eval_runtime': 0.2045, 'eval_samples_per_second': 748.22, 'eval_steps_per_second': 48.903, 'epoch': 39.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.517109990119934, 'eval_precision': 0.886615515771526, 'eval_recall': 0.8321664332866573, 'eval_f1': 0.858528531627283, 'eval_accuracy': 0.8304, 'eval_runtime': 0.2, 'eval_samples_per_second': 764.844, 'eval_steps_per_second': 49.99, 'epoch': 40.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5299570560455322, 'eval_precision': 0.8853543979504697, 'eval_recall': 0.8295659131826365, 'eval_f1': 0.856552721264071, 'eval_accuracy': 0.8278, 'eval_runtime': 0.1975, 'eval_samples_per_second': 774.705, 'eval_steps_per_second': 50.634, 'epoch': 41.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5329430103302002, 'eval_precision': 0.8862862010221465, 'eval_recall': 0.8325665133026605, 'eval_f1': 0.8585869004641569, 'eval_accuracy': 0.8302, 'eval_runtime': 0.1857, 'eval_samples_per_second': 824.033, 'eval_steps_per_second': 53.858, 'epoch': 42.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5291471481323242, 'eval_precision': 0.8866396761133604, 'eval_recall': 0.832366473294659, 'eval_f1': 0.8586463062319438, 'eval_accuracy': 0.8304, 'eval_runtime': 0.202, 'eval_samples_per_second': 757.349, 'eval_steps_per_second': 49.5, 'epoch': 43.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5256043672561646, 'eval_precision': 0.8866879659211928, 'eval_recall': 0.8327665533106622, 'eval_f1': 0.8588817825459047, 'eval_accuracy': 0.831, 'eval_runtime': 0.1987, 'eval_samples_per_second': 770.015, 'eval_steps_per_second': 50.328, 'epoch': 44.0}
{'loss': 0.001, 'learning_rate': 2.08955223880597e-06, 'epoch': 44.78}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5262418985366821, 'eval_precision': 0.8869250425894378, 'eval_recall': 0.8331666333266653, 'eval_f1': 0.8592057761732852, 'eval_accuracy': 0.8314, 'eval_runtime': 0.1936, 'eval_samples_per_second': 790.132, 'eval_steps_per_second': 51.643, 'epoch': 45.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5400372743606567, 'eval_precision': 0.886378170965679, 'eval_recall': 0.8317663532706542, 'eval_f1': 0.8582043343653252, 'eval_accuracy': 0.83, 'eval_runtime': 0.2081, 'eval_samples_per_second': 735.059, 'eval_steps_per_second': 48.043, 'epoch': 46.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5360428094863892, 'eval_precision': 0.8864507882403068, 'eval_recall': 0.832366473294659, 'eval_f1': 0.8585577220674714, 'eval_accuracy': 0.8306, 'eval_runtime': 0.1963, 'eval_samples_per_second': 779.587, 'eval_steps_per_second': 50.953, 'epoch': 47.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5318161249160767, 'eval_precision': 0.8860732538330494, 'eval_recall': 0.832366473294659, 'eval_f1': 0.858380608561114, 'eval_accuracy': 0.8308, 'eval_runtime': 0.1961, 'eval_samples_per_second': 780.356, 'eval_steps_per_second': 51.004, 'epoch': 48.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5407421588897705, 'eval_precision': 0.8860004261666311, 'eval_recall': 0.8317663532706542, 'eval_f1': 0.8580272389599669, 'eval_accuracy': 0.83, 'eval_runtime': 0.2007, 'eval_samples_per_second': 762.168, 'eval_steps_per_second': 49.815, 'epoch': 49.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.5382275581359863, 'eval_precision': 0.8864507882403068, 'eval_recall': 0.832366473294659, 'eval_f1': 0.8585577220674714, 'eval_accuracy': 0.8308, 'eval_runtime': 0.2138, 'eval_samples_per_second': 715.614, 'eval_steps_per_second': 46.772, 'epoch': 50.0}
{'train_runtime': 213.2798, 'train_samples_per_second': 251.079, 'train_steps_per_second': 15.707, 'train_loss': 0.028731520594055975, 'epoch': 50.0}


TrainOutput(global_step=3350, training_loss=0.028731520594055975, metrics={'train_runtime': 213.2798, 'train_samples_per_second': 251.079, 'train_steps_per_second': 15.707, 'train_loss': 0.028731520594055975, 'epoch': 50.0})

In [13]:
trainer.evaluate()

  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 0.5762450098991394,
 'eval_precision': 0.872043918918919,
 'eval_recall': 0.8261652330466093,
 'eval_f1': 0.8484848484848484,
 'eval_accuracy': 0.8296,
 'eval_runtime': 0.1827,
 'eval_samples_per_second': 837.596,
 'eval_steps_per_second': 54.745,
 'epoch': 50.0}

In [14]:
# # Prepare the test data for evaluation in the same format as the training data

# predictions, labels, _ = trainer.predict(tokenised_test)
# predictions = np.argmax(predictions, axis=2)

# # label_list = test_collection.get_ner_idx_list()
# label_list = train_collection.unique_tags

# # Remove the predictions for the [CLS] and [SEP] tokens 
# true_predictions = [
#     [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
#     for prediction, label in zip(predictions, labels)
# ]
# true_labels = [
#     [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
#     for prediction, label in zip(predictions, labels)
# ]

# # Compute multiple metrics on the test restuls
# results = metric.compute(predictions=true_predictions, references=true_labels)
# results
predictions, labels, _ = trainer.predict(tokenised_val) # tokenized validation used instead of validation dataset (as recommended for vectorised)
predictions = np.argmax(predictions, axis=2)

# Remove ignored index (special tokens)
true_predictions = [
    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]
true_labels = [
    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]

results = metric.compute(predictions=true_predictions, references=true_labels)
results

  0%|          | 0/8 [00:00<?, ?it/s]

{'AC': {'precision': 0.7269372693726938,
  'recall': 0.7490494296577946,
  'f1': 0.7378277153558053,
  'number': 263},
 'LF': {'precision': 0.4508670520231214,
  'recall': 0.5131578947368421,
  'f1': 0.48000000000000004,
  'number': 152},
 'O': {'precision': 0.9624645892351275,
  'recall': 0.9568176484393335,
  'f1': 0.9596328115805578,
  'number': 4261},
 'overall_precision': 0.9299145299145299,
 'overall_recall': 0.93071000855432,
 'overall_f1': 0.9303120991876871,
 'overall_accuracy': 0.9224}

In [15]:
for data_item in validation_collection.data_collection:
    t = data_item.tokens
    print(len(t))
    print(t)
    tok = tokenizer(t, is_split_into_words=True)
    print(len(tok["input_ids"]))
    print(tok) # tokenizer adds the CLS and SEP! So should be safe to rip out first and last character
    toks = []
    for to in tok["input_ids"]:
        toks.append(tokenizer.convert_ids_to_tokens(to))
    print(len(toks))
    print(toks)
    break

51
['=', 'manual', 'ability', 'classification', 'system', ';', 'quest', '=', 'quest', '-', 'quality', 'of', 'upper', 'extremity', 'skills', 'test', ';', 'cont', '=', 'control', ';', 'm', '=', 'male', ',', 'f', '=', 'female', ',', 'v', '=', 'verbal', ',', 'nonv', '=', 'non', '-', 'verbal', ',', '|quad', '=', 'quadriplegia', ',', 'di', '=', 'diplegia', ',', 'hemi', '=', 'hemiplegia', '.']
67
{'input_ids': [101, 1027, 6410, 3754, 5579, 2291, 1025, 8795, 1027, 8795, 1011, 3737, 1997, 3356, 4654, 7913, 16383, 4813, 3231, 1025, 9530, 2102, 1027, 2491, 1025, 1049, 1027, 3287, 1010, 1042, 1027, 2931, 1010, 1058, 1027, 12064, 1010, 2512, 2615, 1027, 2512, 1011, 12064, 1010, 1064, 17718, 1027, 17718, 29443, 23115, 2401, 1010, 4487, 1027, 16510, 23115, 2401, 1010, 19610, 2072, 1027, 19610, 11514, 23115, 2401, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [16]:
id2label[-100] = ""
validation_labels =  [label for label in tokenized_validation["labels"][0]]
print(validation_labels)
validation_label_ids = [id2label[label] for label in tokenized_validation["labels"][0]]
print(validation_label_ids)
print(len(validation_label_ids))

[-100, 0, 1, 2, 2, 2, 0, 3, 0, 1, 2, 2, 2, 2, 2, -100, -100, 2, 2, 0, 0, -100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -100, 0, 0, 0, 0, 0, 0, -100, 0, 0, -100, -100, -100, 0, 0, 0, 0, -100, -100, 0, 0, -100, 0, 0, -100, -100, -100, 0, -100]
['', 'B-O', 'B-LF', 'I-LF', 'I-LF', 'I-LF', 'B-O', 'B-AC', 'B-O', 'B-LF', 'I-LF', 'I-LF', 'I-LF', 'I-LF', 'I-LF', '', '', 'I-LF', 'I-LF', 'B-O', 'B-O', '', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', '', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', 'B-O', '', 'B-O', 'B-O', '', '', '', 'B-O', 'B-O', 'B-O', 'B-O', '', '', 'B-O', 'B-O', '', 'B-O', 'B-O', '', '', '', 'B-O', '']
67


In [36]:
idx_num = 1
sentances = []
for data_item in validation_collection.data_collection:
    sentances.append({"sentance":" ".join(data_item.tokens),"token_len":len(data_item.tokens)})
print("sentance:", sentances[idx_num]["sentance"])
print("token length:", sentances[idx_num]["token_len"])

tokenized_input = [data.tokenised_inputs for data in validation_collection.data_collection]
print("original length:", len(tokenized_input[idx_num]["input_ids"]))
token_values = []
for token in tokenized_input[idx_num]["input_ids"]:
    token_value = tokenizer.convert_ids_to_tokens(token)
    if token_value != "[CLS]" and token_value != "[SEP]":
        token_values.append(token_value)
print(token_values)
print("after removing cls and sep length:", len(token_values))

checkpoint_list:list[str] = os.listdir(model_output_dir)
last_checkpoint:str = checkpoint_list[-1:][0]
last_checkpoint_path:str = os.path.join(model_output_dir, last_checkpoint)
classifier = pipeline("ner", model=last_checkpoint_path)
result:list[dict] = classifier(sentances[idx_num]["sentance"])
print("predicted tokens length:",len(result))
print([value["entity"] for value in result])


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


sentance: electro - oculography ( eog ) ( retiport32 , roland consult , wiesbaden , germany ) was performed in all patients according to the guidelines of the international society for clinical electrophysiology of vision ( iscev).[12 ] arden ratios below 1.8 were rated as pathologic .
token length: 46
original length: 71
['electro', '-', 'o', '##cu', '##log', '##raphy', '(', 'e', '##og', ')', '(', 're', '##tip', '##ort', '##32', ',', 'roland', 'consult', ',', 'wi', '##es', '##bad', '##en', ',', 'germany', ')', 'was', 'performed', 'in', 'all', 'patients', 'according', 'to', 'the', 'guidelines', 'of', 'the', 'international', 'society', 'for', 'clinical', 'electro', '##phy', '##sio', '##logy', 'of', 'vision', '(', 'is', '##ce', '##v', ')', '.', '[', '12', ']', 'arden', 'ratios', 'below', '1', '.', '8', 'were', 'rated', 'as', 'path', '##olo', '##gic', '.']
after removing cls and sep length: 69
predicted tokens length: 69
['B-LF', 'I-LF', 'I-LF', 'I-LF', 'I-LF', 'I-LF', 'B-O', 'B-AC', 'B-A

In [74]:
sentances = []
for data_item in validation_collection.data_collection:
    sentances.append({"sentance":" ".join(data_item.tokens),"token_len":len(data_item.tokens)})

measure_dict = []
id2label = validation_collection.reverse_embeddings
tokenized_validation
for idx, tokenised in enumerate(tokenised_val):
    # d = {"tags":[], "sentance":"","classifier_res":[]}
    tokenised_l = tokenised["labels"][1:-1]
    tags = []
    prev_id = list(validation_collection.item_embeddings.keys())[0]
    for label_id in tokenised_l:
        if label_id != -100:
            tags.append(id2label[label_id])
            prev_id = id2label[label_id]
        else:
            tags.append(prev_id)
    tokenised_in = tokenised["input_ids"][1:-1]
    tokens = []
    for token in tokenised_in:
        tokens.append(tokenizer.convert_ids_to_tokens(token))
    result:list[dict] = classifier(sentances[idx]["sentance"])
    classifier_entities = [value["entity"] for value in result]
    measure_dict.append({"tags":tags, "tokens":tokens, "sentance":sentances[idx]["sentance"], "predicted_tags": classifier_entities})


In [77]:
idx_val = 3
print(measure_dict[idx_val]["sentance"])
print(len(measure_dict[idx_val]["tags"]))
print(len(measure_dict[idx_val]["predicted_tags"]))

t -snares syn-1a and snap25 through their interactions with pm - bound voltage - gated calcium channels ( cav ) , l - type in Î² - cells and n - type in neurons , position the predocked sgs to the site of maximum ca2 + influx for efficient exocytosis [ 6â€“12 ] .
71
71


In [80]:
tags = validation_collection.item_embeddings.keys()

class TagMetrics:
    def __init__(self, tag:str, tag_dict:dict):
        self.tag = tag
        self.correct = tag_dict["correct"]
        self.incorrect = tag_dict["incorrect"]
        self.total = tag_dict["total"]

class MetricItem:
    def __init__(self, total:int, correct:int, incorrect:int, tags_dict:dict, idx=0):
        self.idx:int = idx
        self.total:int = total
        self.correct:int = correct
        self.incorrect:int = incorrect
        self.tag_metric:list[TagMetrics] = [TagMetrics(tag, tags_dict[tag]) for tag in tags_dict.keys()]

class MetricCollection:
    def __init__(self, metric_items:list[MetricItem]):
        self.data_collection:list[MetricItem] = metric_items
        self.total_correct:int = sum([item.correct for item in metric_items]) 
        self.total_incorrect:int = sum([item.incorrect for item in metric_items])
        self.total_label_measurement:dict = self.__items_to_label_measure__()

    def __items_to_label_measure__(self) -> dict:
        tag_measure = {}
        for item in metric_items:
            for tag_metric in item.tag_metric:
                if tag_metric.tag not in tag_measure.keys():
                    tag_measure[tag_metric.tag] = {"correct":tag_metric.correct, "incorrect":tag_metric.incorrect, "total":tag_metric.total}
                else:
                    tag_measure[tag_metric.tag]["correct"] += tag_metric.correct
                    tag_measure[tag_metric.tag]["incorrect"] += tag_metric.incorrect
                    tag_measure[tag_metric.tag]["total"] += tag_metric.total
        return tag_measure

print("VALIDATION COLLECTION SIZE: ", len(validation_collection.data_collection))
total_tokens = 0
for idx in range(len(measure_dict)):
    total_tokens += len(measure_dict[idx]["tokens"])
print("TOTAL TOKENS IN VALIDATION SET:", total_tokens)


metric_items:list[MetricItem] = []
for idx in range(len(measure_dict)):
    correct = 0
    incorrect = 0
    tags_dict = {tag:{"correct":0, "incorrect":0, "total":0} for tag in tags}
    current_total = len(measure_dict[idx]["tokens"])
    for d_idx in range(current_total):
        tag = measure_dict[idx]["tags"][d_idx]
        predicted_tag = measure_dict[idx]["predicted_tags"][d_idx]
        if tag == predicted_tag:
            correct += 1
            tags_dict[tag]["correct"] += 1
            tags_dict[tag]["total"] += 1
        else:
            incorrect += 1
            tags_dict[tag]["incorrect"] += 1
            tags_dict[tag]["total"] += 1
    metric_items.append(MetricItem(current_total, correct, incorrect, tags_dict, idx=idx))
    # break # for only 1 item

metric_collection:MetricCollection = MetricCollection(metric_items)
print("TOTAL VALUES:", metric_collection.total_label_measurement)
# print("IDX:1")
# print("TOTAL CORRECT:",correct,"/",current_total)
# print("TOTAL INCORRECT:",incorrect,"/",current_total)
# for k in tags_dict.keys():
#     print(k, ":", tags_dict[k])

VALIDATION COLLECTION SIZE:  126
TOTAL TOKENS IN VALIDATION SET: 6549
TOTAL VALUES: {'B-O': {'correct': 4969, 'incorrect': 239, 'total': 5208}, 'B-LF': {'correct': 56, 'incorrect': 233, 'total': 289}, 'I-LF': {'correct': 410, 'incorrect': 79, 'total': 489}, 'B-AC': {'correct': 324, 'incorrect': 239, 'total': 563}}


In [90]:
import pandas as pd
tags = validation_collection.get_unique_tags()
df = pd.DataFrame(0, columns=tags, index=tags)

for idx in range(len(measure_dict)):
    for d_idx in range(len(measure_dict[idx]["tokens"])):
        tag = measure_dict[idx]["tags"][d_idx]
        predicted_tag = measure_dict[idx]["predicted_tags"][d_idx]
        df.at[tag, predicted_tag] += 1 # first is row, second is column (meaning rows are true tags and columns are predicted tags)
print(df) # center diagonally is TP (and TN) and everything else is FP (and FN)!

def calc_precision(TP, FP):
    return TP / (TP + FP)

def calc_recall(TP, FN):
    return TP / (TP + FN)

def calc_f1(precision, recall):
    return (2 * precision * recall) / (precision + recall)

total_correct = 0
for tag in tags:
    total_correct += df.at[tag,tag]

total_incorrect = total_tokens - total_correct

TP = 56 + 410 + 324
TN = 4969
FP = 104 + 1 + 7 + 77 + 0 + 232
FN = 19 + 126 + 1 + 132 + 3 + 88
# precision = calc_precision(total_correct, total_incorrect)
# recall = calc_recall(total_correct, total_incorrect)
# f1_score = calc_f1(precision, recall)

precision = calc_precision(TP, FP)
recall = calc_recall(TP, FN)
f1_score = calc_f1(precision, recall)

print("Precision:", precision)
print("Recall:", recall)
print("F1_Score:", f1_score)


       B-O  B-LF  I-LF  B-AC
B-O   4969    19   132    88
B-LF   104    56   126     3
I-LF    77     1   410     1
B-AC   232     0     7   324
Precision: 0.652353426919901
Recall: 0.6816220880069025
F1_Score: 0.6666666666666666


In [20]:
# Dataframe the first item in validation set!

data = []
for idx in range(len(measure_dict)):
    for d_idx in range(len(measure_dict[idx]["tokens"])):
        data.append([measure_dict[idx]["tokens"][d_idx], measure_dict[idx]["tags"][d_idx], measure_dict[idx]["predicted_tags"][d_idx]])
    break
df = pd.DataFrame(data, columns=['Tokens', 'Tags', "Predicted Tags"])
df.to_csv('result.csv', index=False)  
df

Unnamed: 0,Tokens,Tags,Predicted Tags
0,=,B-O,B-O
1,manual,B-LF,B-LF
2,ability,I-LF,I-LF
3,classification,I-LF,I-LF
4,system,I-LF,I-LF
...,...,...,...
60,hem,B-O,B-O
61,##ip,B-O,B-O
62,##leg,B-O,B-O
63,##ia,B-O,B-O
