In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
import numpy as np
import evaluate
from transformers import DataCollatorForTokenClassification
from matplotlib import pyplot as plt 

In [None]:
# Workout NER dataset from the Hugging Face Hub; its "train" and "test" splits are used below.
DATASET_ID = 'alfarruggia/wmout'
# Tokenizer for the same checkpoint that is fine-tuned later in the notebook.
TOKENIZER_CHECKPOINT = "distilbert-base-uncased"

wmout = load_dataset(DATASET_ID)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

In [None]:
# Adapted from the Hugging Face token-classification guide:
# https://huggingface.co/docs/transformers/en/tasks/token_classification
def tokenize_and_align_labels(examples, tok=None):
    """Tokenize pre-split words and align word-level NER tags to sub-word tokens.

    Args:
        examples: a batch as passed by ``Dataset.map(..., batched=True)``, with
            ``"tokens"`` (one list of words per example) and ``"ner_tags"``
            (one list of per-word integer tag ids per example).
        tok: tokenizer to apply; defaults to the module-level ``tokenizer``.
            It must expose ``word_ids`` on its output (i.e. a "fast" tokenizer).

    Returns:
        The tokenizer's batch encoding with an added ``"labels"`` entry: one list
        per example, carrying the word's tag id on the first sub-token of each
        word and -100 on special tokens and on continuation sub-tokens, so that
        only one prediction per word is trained on and scored.
    """
    if tok is None:
        tok = tokenizer
    tokenized_inputs = tok(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, word_labels in enumerate(examples["ner_tags"]):
        # Maps each token position to the index of the word it came from (None for special tokens).
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # Special tokens and every sub-token after a word's first one are masked with -100.
            if word_idx is None or word_idx == previous_word_idx:
                label_ids.append(-100)
            else:
                label_ids.append(word_labels[word_idx])
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [None]:
# Tokenize every split batch-wise and attach the word-aligned "labels" column.
tokenized_mout = wmout.map(function=tokenize_and_align_labels, batched=True)

In [None]:

# IOB tag names, indexed by the integer tag ids stored in the dataset's "ner_tags".
# NOTE(review): this order must match the ids used by 'alfarruggia/wmout' — confirm
# against the dataset's features before reordering anything.
# The outside tag is the letter "O", not the digit "0": seqeval only treats "O" as
# "no entity" and would otherwise score runs of "0" as pseudo-entities, inflating
# the overall precision/recall/F1 returned by compute_metrics.
label_list = ["O", "B-Workout", "I-Workout", "I-Frequency", "B-Frequency", "I-Duration", "B-Number", "B-Duration"]

# Lookup tables stored in the model config (id -> tag name and tag name -> id).
id2label = {idx: label for idx, label in enumerate(label_list)}
label2id = {label: idx for idx, label in enumerate(label_list)}

In [None]:
seqeval = evaluate.load("seqeval")

def compute_metrics(p):
    """Convert Trainer evaluation output into entity-level seqeval scores.

    ``p`` unpacks into (model scores, gold label ids); the scores are
    arg-maxed over axis 2 (presumably (batch, seq, num_labels) logits — TODO
    confirm). Positions whose gold id is -100 are dropped before the ids are
    mapped to tag names via ``label_list``.

    Returns a dict with overall precision, recall, F1 and token accuracy.
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for pred_row, gold_row in zip(predicted_ids, gold_ids):
        # Keep only positions that carry a real label (skip the -100 padding/sub-token mask).
        kept = [(pred, gold) for pred, gold in zip(pred_row, gold_row) if gold != -100]
        true_predictions.append([label_list[pred] for pred, _ in kept])
        true_labels.append([label_list[gold] for _, gold in kept])

    report = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": report["overall_precision"],
        "recall": report["overall_recall"],
        "f1": report["overall_f1"],
        "accuracy": report["overall_accuracy"],
    }

In [None]:
# Pads each batch (inputs and labels) to a common length at collation time.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# Token-classification head on DistilBERT; the tag maps are stored in the model
# config so saved/exported models carry readable tag names.
model = AutoModelForTokenClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)

# Evaluate and checkpoint once per epoch, keep at most 3 checkpoints, and reload the
# best checkpoint at the end (ranked by eval loss, since metric_for_best_model is unset).
# NOTE(review): the "test" split drives model selection here, so its scores are
# optimistic — consider holding out a separate validation split.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to `eval_strategy`.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=11,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=tokenized_mout["train"],
    eval_dataset=tokenized_mout["test"],
    compute_metrics=compute_metrics,
)

trainer.train()
# Persist the final (best, per load_best_model_at_end) weights for export.
trainer.save_model('./model_output')

In [None]:
# trainer.state.log_history holds the per-epoch evaluation records, but also the
# training-step logs (loss, learning rate) and a final training summary. Only the
# evaluation records carry eval_* keys, so keep just those: otherwise every series is
# padded with None (drawn by matplotlib as gaps that can hide isolated points) and the
# x axis counts log entries instead of epochs.
eval_logs = [entry for entry in trainer.state.log_history if 'eval_loss' in entry]
eval_epochs = [entry.get('epoch', n) for n, entry in enumerate(eval_logs, start=1)]

eval_loss = [entry.get('eval_loss') for entry in eval_logs]
eval_accuracy = [entry.get('eval_accuracy') for entry in eval_logs]
eval_precision = [entry.get('eval_precision') for entry in eval_logs]
eval_f1 = [entry.get('eval_f1') for entry in eval_logs]

figs, ax = plt.subplots(2, 2, figsize=(10, 10))

# (axes, series, title, colour) for each panel of the 2x2 grid.
panels = [
    (ax[0, 0], eval_accuracy, 'Accuracy', 'blue'),
    (ax[0, 1], eval_loss, 'Loss', 'orange'),
    (ax[1, 0], eval_precision, 'Precision', 'green'),
    (ax[1, 1], eval_f1, 'F1', 'red'),
]
for panel_ax, values, title, color in panels:
    # Markers keep single-epoch runs visible (a one-point line draws nothing).
    panel_ax.plot(eval_epochs, values, color=color, marker='o')
    panel_ax.set_title(title)
    panel_ax.set_xlabel('Epoch')

figs.tight_layout()


In [None]:
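# Export the fine-tuned model saved in ./model_output to TFLite (run in a shell; requires Hugging Face Optimum):
# optimum-cli export tflite --model ./model_output --sequence_length 512 --task token-classification ./tflite_mobile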