In [6]:
# imports
import pandas as pd
import numpy as np

from code_.process_conll import process_file, extract_features
from code_.bert import Tokenizer, convert_to_dataset, compute_metrics, task, batch_size
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification
from datasets import load_metric

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [3]:
# constants
# Checkpoint identifier shared by the later cells: it is handed to the Tokenizer
# wrapper, to AutoModelForTokenClassification.from_pretrained, and its last
# "/"-separated segment becomes the prefix of the training run name.
model_checkpoint = "distilbert-base-uncased"

In [4]:
# load data
# Placeholders: three independent, empty DataFrames standing in for the
# train / validation / test splits until real data is loaded here
# (presumably via process_file / extract_features — TODO confirm).
df_train, df_val, df_test = (pd.DataFrame() for _ in range(3))

In [None]:
# Label inventory, still to be filled in. Later cells index this list with
# integer label ids to recover label names, and use its length as the number of
# output classes of the model, so the order of entries must match the ids used
# in the data.
labels_list = [] # to fill all possible gold labels here

In [5]:
# Bundle the three split frames into one dataset object using the project helper.
# NOTE(review): later cells call dataset.map(...) and index the result with
# "train" / "validation", so this is presumably a datasets.DatasetDict —
# confirm in code_.bert.
dataset = convert_to_dataset(df_train, df_val, df_test)

In [None]:
# Project tokenizer wrapper built from the checkpoint name and the label list.
# Later cells rely on two of its members: tok.tokenize_and_align_labels (mapped
# over the dataset) and tok.tokenizer (given to the data collator and Trainer).
tok = Tokenizer(model_checkpoint, labels_list)

In [None]:
# Apply the wrapper's tokenize-and-align step to the dataset.
# NOTE(review): assuming `dataset` comes from the Hugging Face datasets library,
# batched=True makes map() pass batches of examples to the function rather than
# single examples — confirm tokenize_and_align_labels expects batches.
tokenized_datasets = dataset.map(tok.tokenize_and_align_labels, batched=True)

In [None]:
# initialise model

In [None]:
# Load the pretrained checkpoint with a token-classification head sized to the
# label inventory. The id<->label mappings are stored in the model config as
# well, so that the saved model and any later inference report the real label
# names from labels_list instead of generic LABEL_<n> placeholders.
model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(labels_list),
    id2label=dict(enumerate(labels_list)),
    label2id={label: idx for idx, label in enumerate(labels_list)},
)

In [None]:
# Run name = last path segment of the checkpoint id, suffixed with the task.
model_name = model_checkpoint.rsplit("/", 1)[-1]

# Hyper-parameters for fine-tuning; evaluation runs once per epoch.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — check against the installed version before upgrading.
args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-{task}",
    # schedule
    num_train_epochs=3,
    evaluation_strategy="epoch",
    # optimisation
    learning_rate=2e-5,
    weight_decay=0.01,
    # batching (same size for training and evaluation)
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # keep everything local
    push_to_hub=False,
)

In [None]:
# Collator built around the wrapper's underlying tokenizer; it is passed to the
# Trainer below to assemble batches.
data_collator = DataCollatorForTokenClassification(tok.tokenizer)
# Sequence-labelling metric, used in the final cell to score predictions.
# NOTE(review): datasets.load_metric is deprecated in recent `datasets` releases
# in favour of the separate `evaluate` package — confirm the pinned version
# still provides it.
metric = load_metric("seqeval")

In [None]:

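# Disabled sanity check for the metric: score one example's gold tags against
# themselves. Left commented out because `example` is not defined in the cells shown.
# labels = [labels_list[i] for i in example[f"{task}_tags"]]
# metric.compute(predictions=[labels], references=[labels])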

In [None]:
# Wire the model, its training arguments, the tokenized splits, the batch
# collator and the metric callback into a single Trainer.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tok.tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# train

In [None]:
# Start fine-tuning with the Trainer configured above; the call's return value
# is the last expression of the cell and is therefore displayed as its output.
trainer.train()

In [None]:
# Persist the fine-tuned model through the Trainer.
# NOTE(review): per the Trainer API the argument of save_model is an output
# directory, so this presumably creates a folder literally named
# "some_model_name.pth" rather than a single .pth file — confirm, and consider
# a directory-style name once the placeholder is replaced.
trainer.save_model('model_checkpoints/some_model_name.pth')

In [None]:
# results, plots, reports etc.

In [None]:
# Run evaluation through the Trainer and display the returned metrics as the
# cell output. No dataset is passed, so this presumably falls back to the
# eval_dataset given at construction (the "validation" split) — confirm.
trainer.evaluate()

In [None]:
# Run inference on the validation split; the third element of the returned
# triple is not used.
predictions, labels, _ = trainer.predict(tokenized_datasets["validation"])
# Reduce axis 2 to the index of its largest entry
# (assumed to be the per-token label axis — TODO confirm).
predictions = np.argmax(predictions, axis=2)

# Keep only positions whose gold id is not -100 (the ignored index used for
# special tokens) and translate the remaining ids into names from labels_list.
true_predictions = [
    [
        labels_list[pred_id]
        for pred_id, gold_id in zip(pred_row, gold_row)
        if gold_id != -100
    ]
    for pred_row, gold_row in zip(predictions, labels)
]
true_labels = [
    [
        labels_list[gold_id]
        for _pred_id, gold_id in zip(pred_row, gold_row)
        if gold_id != -100
    ]
    for pred_row, gold_row in zip(predictions, labels)
]

# Score the filtered predictions against the filtered gold labels and show the
# result as the cell output.
results = metric.compute(predictions=true_predictions, references=true_labels)
results