In [1]:
from datasets import load_dataset, load_metric
dataset = load_dataset("surrey-nlp/PLOD-CW")
from transformers import AutoTokenizer, AutoModelForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
model = AutoModelForTokenClassification.from_pretrained("distilbert/distilbert-base-uncased", num_labels=4)

short_dataset = dataset["train"][0:200]
val_dataset = dataset["validation"]
test_dataset = dataset["test"]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
tokenized_input = tokenizer(short_dataset["tokens"], is_split_into_words=True)

# Example single sentence example.
for token in tokenized_input["input_ids"]:
    print(tokenizer.convert_ids_to_tokens(token))
    break

['[CLS]', 'for', 'this', 'purpose', 'the', 'gothenburg', 'young', 'persons', 'empowerment', 'scale', '(', 'g', '##ype', '##s', ')', 'was', 'developed', '.', '[SEP]']


In [34]:
label_encoding = {"B-O": 0, "B-AC": 1, "B-LF": 2, "I-LF": 3}
id2label = {v:k for k,v in label_encoding.items()}

label_list = []
for sample in short_dataset["ner_tags"]:
    label_list.append([label_encoding[tag] for tag in sample])

val_label_list = []
for sample in val_dataset["ner_tags"]:
    val_label_list.append([label_encoding[tag] for tag in sample])

test_label_list = []
for sample in test_dataset["ner_tags"]:
    test_label_list.append([label_encoding[tag] for tag in sample])


In [5]:
def tokenize_and_align_labels(short_dataset, list_name):
    tokenized_inputs = tokenizer(short_dataset["tokens"], truncation=True, is_split_into_words=True) ## For some models, you may need to set max_length to approximately 500.

    labels = []
    for i, label in enumerate(list_name):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                label_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                label_ids.append(label[word_idx])
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [6]:
tokenized_datasets = tokenize_and_align_labels(short_dataset, label_list)
tokenized_val_datasets = tokenize_and_align_labels(val_dataset, val_label_list)
tokenized_test_datasets = tokenize_and_align_labels(test_dataset, test_label_list)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [7]:
# BERT's tokenizer returns the dataset in the form of a dictionary of lists (sentences). 
# we have to convert it into a list of dictionaries for training.
def turn_dict_to_list_of_dict(d):
    new_list = []

    for labels, inputs in zip(d["labels"], d["input_ids"]):
        entry = {"input_ids": inputs, "labels": labels}
        new_list.append(entry)

    return new_list

In [8]:
tokenised_train = turn_dict_to_list_of_dict(tokenized_datasets)
tokenised_val = turn_dict_to_list_of_dict(tokenized_val_datasets)
tokenised_test = turn_dict_to_list_of_dict(tokenized_test_datasets)

In [9]:
from transformers import DataCollatorForTokenClassification
data_collator = DataCollatorForTokenClassification(tokenizer)

In [10]:
import numpy as np

metric = load_metric("seqeval")
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

  metric = load_metric("seqeval")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [21]:
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

# Training arguments (feel free to play arround with these values)
# model_name = "bert-base-uncased"
# epochs = 6
# batch_size = 4
# learning_rate = 2e-5

# args = TrainingArguments(
#     output_dir="distilBERT-finetuned-NER",
#     # evaluation_strategy = "epoch", ## Instead of focusing on loss and accuracy, we will focus on the F1 score
#     evaluation_strategy ='steps',
#     eval_steps = 7000,
#     save_total_limit = 3,
#     learning_rate=learning_rate,
#     per_device_train_batch_size=batch_size,
#     per_device_eval_batch_size=batch_size,
#     num_train_epochs=epochs,
#     weight_decay=0.001,
#     save_steps=35000,
#     metric_for_best_model = 'f1',
#     load_best_model_at_end=True,
#     report_to=['none'], # REQUIRED because otherwise keeps asking to log into "wandb"
# )

# trainer = Trainer(
#     model,
#     args,
#     train_dataset=tokenised_train,
#     eval_dataset=tokenised_val,
#     data_collator = data_collator,
#     tokenizer=tokenizer,
#     compute_metrics=compute_metrics,
#     callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
# )

model_output_dir:str = "distilBERT-finetuned-NER"

training_args = TrainingArguments(
    output_dir=model_output_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=50,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    report_to=['none'] # REQUIRED because otherwise keeps asking to log into "wandb"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenised_train,
    eval_dataset=tokenised_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)


In [18]:
trainer.train()

  0%|          | 0/650 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.33404695987701416, 'eval_precision': 0.5078969243557773, 'eval_recall': 0.5835721107927412, 'eval_f1': 0.5431111111111112, 'eval_accuracy': 0.8891433806688044, 'eval_runtime': 0.2631, 'eval_samples_per_second': 478.963, 'eval_steps_per_second': 30.41, 'epoch': 1.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.34394311904907227, 'eval_precision': 0.5163793103448275, 'eval_recall': 0.5721107927411653, 'eval_f1': 0.5428183053919348, 'eval_accuracy': 0.8915865017559933, 'eval_runtime': 0.2595, 'eval_samples_per_second': 485.64, 'eval_steps_per_second': 30.834, 'epoch': 2.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.3537988066673279, 'eval_precision': 0.5192780968006563, 'eval_recall': 0.6045845272206304, 'eval_f1': 0.558693733451015, 'eval_accuracy': 0.8952511833867766, 'eval_runtime': 0.2615, 'eval_samples_per_second': 481.866, 'eval_steps_per_second': 30.595, 'epoch': 3.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.39186614751815796, 'eval_precision': 0.5340050377833753, 'eval_recall': 0.6074498567335244, 'eval_f1': 0.5683646112600536, 'eval_accuracy': 0.8938769277752329, 'eval_runtime': 0.2693, 'eval_samples_per_second': 467.913, 'eval_steps_per_second': 29.709, 'epoch': 4.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.4126548767089844, 'eval_precision': 0.5223395613322502, 'eval_recall': 0.6141356255969437, 'eval_f1': 0.5645302897278315, 'eval_accuracy': 0.8906703313482974, 'eval_runtime': 0.2518, 'eval_samples_per_second': 500.311, 'eval_steps_per_second': 31.766, 'epoch': 5.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.4348374307155609, 'eval_precision': 0.5171102661596958, 'eval_recall': 0.6494746895893028, 'eval_f1': 0.5757832345469941, 'eval_accuracy': 0.8899068560085509, 'eval_runtime': 0.2561, 'eval_samples_per_second': 492.01, 'eval_steps_per_second': 31.239, 'epoch': 6.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.4392145872116089, 'eval_precision': 0.5195402298850574, 'eval_recall': 0.6475644699140402, 'eval_f1': 0.576530612244898, 'eval_accuracy': 0.8908230264162468, 'eval_runtime': 0.2545, 'eval_samples_per_second': 495.066, 'eval_steps_per_second': 31.433, 'epoch': 7.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.4649152457714081, 'eval_precision': 0.5146506386175808, 'eval_recall': 0.6542502387774594, 'eval_f1': 0.5761143818334735, 'eval_accuracy': 0.8906703313482974, 'eval_runtime': 0.2513, 'eval_samples_per_second': 501.473, 'eval_steps_per_second': 31.84, 'epoch': 8.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.4622561037540436, 'eval_precision': 0.5291300877893057, 'eval_recall': 0.6332378223495702, 'eval_f1': 0.5765217391304347, 'eval_accuracy': 0.8943350129790808, 'eval_runtime': 0.2589, 'eval_samples_per_second': 486.588, 'eval_steps_per_second': 30.894, 'epoch': 9.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.46399936079978943, 'eval_precision': 0.5272277227722773, 'eval_recall': 0.6103151862464183, 'eval_f1': 0.5657370517928287, 'eval_accuracy': 0.8934188425713849, 'eval_runtime': 0.2466, 'eval_samples_per_second': 510.956, 'eval_steps_per_second': 32.442, 'epoch': 10.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.5022879242897034, 'eval_precision': 0.5199398043641836, 'eval_recall': 0.6599808978032474, 'eval_f1': 0.5816498316498318, 'eval_accuracy': 0.8903649412123988, 'eval_runtime': 0.2524, 'eval_samples_per_second': 499.232, 'eval_steps_per_second': 31.697, 'epoch': 11.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.5204578638076782, 'eval_precision': 0.5227103499627699, 'eval_recall': 0.670487106017192, 'eval_f1': 0.5874476987447699, 'eval_accuracy': 0.8926553672316384, 'eval_runtime': 0.2523, 'eval_samples_per_second': 499.337, 'eval_steps_per_second': 31.704, 'epoch': 12.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.5026178359985352, 'eval_precision': 0.5417376490630323, 'eval_recall': 0.6074498567335244, 'eval_f1': 0.5727149932462855, 'eval_accuracy': 0.8961673537944724, 'eval_runtime': 0.2716, 'eval_samples_per_second': 463.838, 'eval_steps_per_second': 29.45, 'epoch': 13.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.5490376353263855, 'eval_precision': 0.5007163323782235, 'eval_recall': 0.667621776504298, 'eval_f1': 0.5722472370036841, 'eval_accuracy': 0.8867002595816155, 'eval_runtime': 0.2542, 'eval_samples_per_second': 495.73, 'eval_steps_per_second': 31.475, 'epoch': 14.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.529660701751709, 'eval_precision': 0.5537948290241869, 'eval_recall': 0.6341929321872015, 'eval_f1': 0.591273374888691, 'eval_accuracy': 0.8976943044739655, 'eval_runtime': 0.2448, 'eval_samples_per_second': 514.761, 'eval_steps_per_second': 32.683, 'epoch': 15.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.5429145693778992, 'eval_precision': 0.5247678018575851, 'eval_recall': 0.6475644699140402, 'eval_f1': 0.579734929457033, 'eval_accuracy': 0.8946404031149794, 'eval_runtime': 0.2747, 'eval_samples_per_second': 458.7, 'eval_steps_per_second': 29.124, 'epoch': 16.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.5537455677986145, 'eval_precision': 0.5213483146067416, 'eval_recall': 0.664756446991404, 'eval_f1': 0.5843828715365238, 'eval_accuracy': 0.8915865017559933, 'eval_runtime': 0.2589, 'eval_samples_per_second': 486.683, 'eval_steps_per_second': 30.9, 'epoch': 17.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.5545938014984131, 'eval_precision': 0.5257966616084977, 'eval_recall': 0.66189111747851, 'eval_f1': 0.586046511627907, 'eval_accuracy': 0.8934188425713849, 'eval_runtime': 0.2585, 'eval_samples_per_second': 487.394, 'eval_steps_per_second': 30.946, 'epoch': 18.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.5749813914299011, 'eval_precision': 0.5119047619047619, 'eval_recall': 0.6571155682903533, 'eval_f1': 0.575491426181514, 'eval_accuracy': 0.8903649412123988, 'eval_runtime': 0.2586, 'eval_samples_per_second': 487.295, 'eval_steps_per_second': 30.939, 'epoch': 19.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.570259690284729, 'eval_precision': 0.5267295597484277, 'eval_recall': 0.6399235912129895, 'eval_f1': 0.5778352738249245, 'eval_accuracy': 0.8947930981829287, 'eval_runtime': 0.2589, 'eval_samples_per_second': 486.657, 'eval_steps_per_second': 30.899, 'epoch': 20.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.5958968997001648, 'eval_precision': 0.5136733185513673, 'eval_recall': 0.6638013371537727, 'eval_f1': 0.5791666666666666, 'eval_accuracy': 0.8867002595816155, 'eval_runtime': 0.2534, 'eval_samples_per_second': 497.301, 'eval_steps_per_second': 31.575, 'epoch': 21.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.5790733098983765, 'eval_precision': 0.533816425120773, 'eval_recall': 0.6332378223495702, 'eval_f1': 0.5792922673656619, 'eval_accuracy': 0.8957092685906245, 'eval_runtime': 0.2555, 'eval_samples_per_second': 493.169, 'eval_steps_per_second': 31.312, 'epoch': 22.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.5985667705535889, 'eval_precision': 0.5132743362831859, 'eval_recall': 0.664756446991404, 'eval_f1': 0.5792759051186018, 'eval_accuracy': 0.8892960757367537, 'eval_runtime': 0.2581, 'eval_samples_per_second': 488.171, 'eval_steps_per_second': 30.995, 'epoch': 23.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.588121235370636, 'eval_precision': 0.5392628205128205, 'eval_recall': 0.6427889207258835, 'eval_f1': 0.5864923747276688, 'eval_accuracy': 0.8946404031149794, 'eval_runtime': 0.2672, 'eval_samples_per_second': 471.599, 'eval_steps_per_second': 29.943, 'epoch': 24.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.5979374647140503, 'eval_precision': 0.5298681148176881, 'eval_recall': 0.6523400191021967, 'eval_f1': 0.5847602739726028, 'eval_accuracy': 0.8921972820277905, 'eval_runtime': 0.2496, 'eval_samples_per_second': 504.752, 'eval_steps_per_second': 32.048, 'epoch': 25.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6018063426017761, 'eval_precision': 0.5273865414710485, 'eval_recall': 0.6437440305635148, 'eval_f1': 0.5797849462365591, 'eval_accuracy': 0.8925026721636891, 'eval_runtime': 0.2509, 'eval_samples_per_second': 502.239, 'eval_steps_per_second': 31.888, 'epoch': 26.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6089754104614258, 'eval_precision': 0.5107169253510717, 'eval_recall': 0.6599808978032474, 'eval_f1': 0.5758333333333334, 'eval_accuracy': 0.891433806688044, 'eval_runtime': 0.2525, 'eval_samples_per_second': 499.094, 'eval_steps_per_second': 31.688, 'epoch': 27.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6103501319885254, 'eval_precision': 0.5518092105263158, 'eval_recall': 0.6408787010506208, 'eval_f1': 0.5930181175430844, 'eval_accuracy': 0.8975416094060162, 'eval_runtime': 0.2518, 'eval_samples_per_second': 500.314, 'eval_steps_per_second': 31.766, 'epoch': 28.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6189752221107483, 'eval_precision': 0.5404323458767014, 'eval_recall': 0.6446991404011462, 'eval_f1': 0.5879790940766552, 'eval_accuracy': 0.894945793250878, 'eval_runtime': 0.2571, 'eval_samples_per_second': 490.023, 'eval_steps_per_second': 31.113, 'epoch': 29.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6225976347923279, 'eval_precision': 0.5237377543330821, 'eval_recall': 0.6638013371537727, 'eval_f1': 0.5855096882898062, 'eval_accuracy': 0.8925026721636891, 'eval_runtime': 0.257, 'eval_samples_per_second': 490.215, 'eval_steps_per_second': 31.125, 'epoch': 30.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6141191720962524, 'eval_precision': 0.5592654424040067, 'eval_recall': 0.6399235912129895, 'eval_f1': 0.5968819599109131, 'eval_accuracy': 0.8987631699496106, 'eval_runtime': 0.2568, 'eval_samples_per_second': 490.647, 'eval_steps_per_second': 31.152, 'epoch': 31.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6158362627029419, 'eval_precision': 0.5416666666666666, 'eval_recall': 0.6580706781279847, 'eval_f1': 0.5942216472617508, 'eval_accuracy': 0.8958619636585738, 'eval_runtime': 0.257, 'eval_samples_per_second': 490.194, 'eval_steps_per_second': 31.123, 'epoch': 32.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6189525127410889, 'eval_precision': 0.5354691075514875, 'eval_recall': 0.670487106017192, 'eval_f1': 0.5954198473282443, 'eval_accuracy': 0.8943350129790808, 'eval_runtime': 0.2496, 'eval_samples_per_second': 504.884, 'eval_steps_per_second': 32.056, 'epoch': 33.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6147043704986572, 'eval_precision': 0.5518341307814992, 'eval_recall': 0.6609360076408787, 'eval_f1': 0.601477618426771, 'eval_accuracy': 0.8987631699496106, 'eval_runtime': 0.2554, 'eval_samples_per_second': 493.296, 'eval_steps_per_second': 31.32, 'epoch': 34.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6173309683799744, 'eval_precision': 0.5458135860979463, 'eval_recall': 0.6599808978032474, 'eval_f1': 0.5974924340683095, 'eval_accuracy': 0.8970835242021683, 'eval_runtime': 0.2651, 'eval_samples_per_second': 475.311, 'eval_steps_per_second': 30.178, 'epoch': 35.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6196343898773193, 'eval_precision': 0.5427450980392157, 'eval_recall': 0.6609360076408787, 'eval_f1': 0.5960378983634796, 'eval_accuracy': 0.8958619636585738, 'eval_runtime': 0.2512, 'eval_samples_per_second': 501.517, 'eval_steps_per_second': 31.842, 'epoch': 36.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.619455099105835, 'eval_precision': 0.5430149960536701, 'eval_recall': 0.6571155682903533, 'eval_f1': 0.5946413137424373, 'eval_accuracy': 0.8972362192701175, 'eval_runtime': 0.2578, 'eval_samples_per_second': 488.834, 'eval_steps_per_second': 31.037, 'epoch': 37.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6276939511299133, 'eval_precision': 0.5385814497272019, 'eval_recall': 0.6599808978032474, 'eval_f1': 0.5931330472103005, 'eval_accuracy': 0.8950984883188273, 'eval_runtime': 0.2585, 'eval_samples_per_second': 487.507, 'eval_steps_per_second': 30.953, 'epoch': 38.0}
{'loss': 0.0239, 'learning_rate': 4.615384615384616e-06, 'epoch': 38.46}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6313808560371399, 'eval_precision': 0.5380989787902593, 'eval_recall': 0.6542502387774594, 'eval_f1': 0.5905172413793103, 'eval_accuracy': 0.8950984883188273, 'eval_runtime': 0.2533, 'eval_samples_per_second': 497.341, 'eval_steps_per_second': 31.577, 'epoch': 39.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.630568265914917, 'eval_precision': 0.5435816164817749, 'eval_recall': 0.6552053486150907, 'eval_f1': 0.5941966219142486, 'eval_accuracy': 0.8960146587265231, 'eval_runtime': 0.2576, 'eval_samples_per_second': 489.209, 'eval_steps_per_second': 31.061, 'epoch': 40.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6396571397781372, 'eval_precision': 0.5342044581091469, 'eval_recall': 0.6638013371537727, 'eval_f1': 0.591993185689949, 'eval_accuracy': 0.8934188425713849, 'eval_runtime': 0.2615, 'eval_samples_per_second': 481.778, 'eval_steps_per_second': 30.589, 'epoch': 41.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6407474279403687, 'eval_precision': 0.5263963274674828, 'eval_recall': 0.6571155682903533, 'eval_f1': 0.584536958368734, 'eval_accuracy': 0.8928080622995878, 'eval_runtime': 0.2515, 'eval_samples_per_second': 500.943, 'eval_steps_per_second': 31.806, 'epoch': 42.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6340999007225037, 'eval_precision': 0.5405616224648986, 'eval_recall': 0.66189111747851, 'eval_f1': 0.5951051953628166, 'eval_accuracy': 0.8940296228431822, 'eval_runtime': 0.2538, 'eval_samples_per_second': 496.544, 'eval_steps_per_second': 31.527, 'epoch': 43.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6354873180389404, 'eval_precision': 0.5375677769171185, 'eval_recall': 0.6628462273161414, 'eval_f1': 0.5936698032506417, 'eval_accuracy': 0.8938769277752329, 'eval_runtime': 0.2596, 'eval_samples_per_second': 485.437, 'eval_steps_per_second': 30.821, 'epoch': 44.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6346672177314758, 'eval_precision': 0.5395348837209303, 'eval_recall': 0.664756446991404, 'eval_f1': 0.5956354300385108, 'eval_accuracy': 0.8950984883188273, 'eval_runtime': 0.2616, 'eval_samples_per_second': 481.624, 'eval_steps_per_second': 30.579, 'epoch': 45.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6352254152297974, 'eval_precision': 0.5391169635941131, 'eval_recall': 0.664756446991404, 'eval_f1': 0.5953806672369547, 'eval_accuracy': 0.8950984883188273, 'eval_runtime': 0.2533, 'eval_samples_per_second': 497.379, 'eval_steps_per_second': 31.58, 'epoch': 46.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6357123851776123, 'eval_precision': 0.5410852713178295, 'eval_recall': 0.6666666666666666, 'eval_f1': 0.59734702610184, 'eval_accuracy': 0.8957092685906245, 'eval_runtime': 0.2556, 'eval_samples_per_second': 492.862, 'eval_steps_per_second': 31.293, 'epoch': 47.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6377902030944824, 'eval_precision': 0.5368663594470046, 'eval_recall': 0.667621776504298, 'eval_f1': 0.5951468710089398, 'eval_accuracy': 0.894945793250878, 'eval_runtime': 0.2666, 'eval_samples_per_second': 472.656, 'eval_steps_per_second': 30.01, 'epoch': 48.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6388519406318665, 'eval_precision': 0.5363984674329502, 'eval_recall': 0.6685768863419294, 'eval_f1': 0.5952380952380952, 'eval_accuracy': 0.8947930981829287, 'eval_runtime': 0.2592, 'eval_samples_per_second': 486.193, 'eval_steps_per_second': 30.869, 'epoch': 49.0}


  0%|          | 0/8 [00:00<?, ?it/s]



{'eval_loss': 0.6389782428741455, 'eval_precision': 0.5363984674329502, 'eval_recall': 0.6685768863419294, 'eval_f1': 0.5952380952380952, 'eval_accuracy': 0.8946404031149794, 'eval_runtime': 0.2598, 'eval_samples_per_second': 485.013, 'eval_steps_per_second': 30.795, 'epoch': 50.0}
{'train_runtime': 77.2572, 'train_samples_per_second': 129.438, 'train_steps_per_second': 8.413, 'train_loss': 0.018852789126909696, 'epoch': 50.0}


TrainOutput(global_step=650, training_loss=0.018852789126909696, metrics={'train_runtime': 77.2572, 'train_samples_per_second': 129.438, 'train_steps_per_second': 8.413, 'train_loss': 0.018852789126909696, 'epoch': 50.0})

In [13]:
# Prepare the test data for evaluation in the same format as the training data

predictions, labels, _ = trainer.predict(tokenised_test)
predictions = np.argmax(predictions, axis=2)

# Remove the predictions for the [CLS] and [SEP] tokens 
true_predictions = [
    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]
true_labels = [
    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]

# Compute multiple metrics on the test restuls
results = metric.compute(predictions=true_predictions, references=true_labels)
results

  0%|          | 0/10 [00:00<?, ?it/s]



{'0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 3, 3, 0, 1, 0, 0, 0, 0, 0, 0, 2, 3, 3, 1, 0, 1, 0, 0, 0, 0, 0, 0]': {'precision': 0.6310679611650486,
  'recall': 0.7303370786516854,
  'f1': 0.6770833333333335,
  'number': 267},
 '0, 0, 0, 0, 2, 3, 3, 3, 3, 0, 1, 0, 0, 0, 0]': {'precision': 0.5272727272727272,
  'recall': 0.5951492537313433,
  'f1': 0.5591586327782646,
  'number': 536},
 '0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 3, 0, 1, 0, 0, 0, 0, 2, 3, 3, 0, 1, 0, 0, 0, 0, 0]': {'precision': 0.34234234234234234,
  'recall': 0.2550335570469799,
  'f1': 0.2923076923076923,
  'number': 149},
 '1, 0, 2, 3, 3, 0]': {'precision': 0.2532467532467532,
  'recall': 0.3023255813953488,
  'f1': 0.2756183745583039,
  'number': 129},
 'overall_precision': 0.5012722646310432,
 'overall_recall': 0.5467160037002775,
 'overall_f1': 0.523008849557522,
 'overall_accuracy': 0.8998646820027063}

In [40]:
text = "For this purpose the Gothenburg Young Persons Empowerment Scale (GYPES) was developed."
# model = AutoTokenizer.from_pretrained("./distilBERT-finetuned-NER/checkpoint-26/")
import os
from transformers import pipeline

checkpoint_list:list[str] = os.listdir(model_output_dir)
last_checkpoint:str = checkpoint_list[-1:][0]
last_checkpoint_path:str = os.path.join(model_output_dir, last_checkpoint)
classifier = pipeline("ner", model=last_checkpoint_path)
result:list[dict] = classifier(text)
for r in result:
    entity:str = r["entity"] 
    r["entity"] = id2label[int(entity.replace("LABEL_", ""))]
result



Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


{0: 'B-O', 1: 'B-AC', 2: 'B-LF', 3: 'I-LF'}


[{'entity': 'B-O',
  'score': 0.9985917,
  'index': 1,
  'word': 'for',
  'start': 0,
  'end': 3},
 {'entity': 'B-O',
  'score': 0.9988061,
  'index': 2,
  'word': 'this',
  'start': 4,
  'end': 8},
 {'entity': 'B-O',
  'score': 0.99888355,
  'index': 3,
  'word': 'purpose',
  'start': 9,
  'end': 16},
 {'entity': 'B-O',
  'score': 0.99921954,
  'index': 4,
  'word': 'the',
  'start': 17,
  'end': 20},
 {'entity': 'B-LF',
  'score': 0.9956285,
  'index': 5,
  'word': 'gothenburg',
  'start': 21,
  'end': 31},
 {'entity': 'I-LF',
  'score': 0.99947685,
  'index': 6,
  'word': 'young',
  'start': 32,
  'end': 37},
 {'entity': 'I-LF',
  'score': 0.9995701,
  'index': 7,
  'word': 'persons',
  'start': 38,
  'end': 45},
 {'entity': 'I-LF',
  'score': 0.9996407,
  'index': 8,
  'word': 'empowerment',
  'start': 46,
  'end': 57},
 {'entity': 'I-LF',
  'score': 0.9995134,
  'index': 9,
  'word': 'scale',
  'start': 58,
  'end': 63},
 {'entity': 'B-O',
  'score': 0.99941266,
  'index': 10,
  '