# Environment Setup


In [None]:
import os
import torch
import evaluate
from nlpcw.utils import get_dataset, load_model, show_random_elements, tokenize_dataset
from transformers import (
    Trainer,
    TrainingArguments,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback,
)
import numpy as np
import wandb
from pathlib import Path

In [None]:
# Authenticate with Weights & Biases before any training run is started/logged.
wandb.login()

## Config


In [None]:
# Model identifier handed to load_model when no checkpoint is configured.
MODEL_NAME = "romainlhardy/roberta-large-finetuned-ner"

# Set to an experiment directory to load from it instead of MODEL_NAME, e.g.
# CHECKPOINT_PATH = "experiments/agile-navigator-qn9uu"
CHECKPOINT_PATH = None

# Per-device batch size (used for both training and evaluation) and epoch count.
BATCH_SIZE = 1
NUM_EPOCHS = 10

## Dataset


In [None]:
# Fetch the splits plus the id <-> label mappings that are later passed to load_model.
dataset, id2label, label2id, num_labels = get_dataset()
train_split = dataset["train"]  # type: ignore
# Tag names; compute_metrics maps integer tag ids back to names by indexing this list.
label_list = train_split.features["ner_tags"].feature.names  # type: ignore
show_random_elements(train_split)  # type: ignore

## Model


In [None]:
# Load tokenizer/config/model from CHECKPOINT_PATH when one is configured,
# otherwise from MODEL_NAME. `is None` (identity) is the correct None test.
tokenizer, config_model, model, save_path = load_model(
    exp_or_model_name=MODEL_NAME if CHECKPOINT_PATH is None else CHECKPOINT_PATH,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)
# Uncomment to write the freshly loaded artifacts to save_path before training.
# tokenizer.save_pretrained(save_path)
# model.save_pretrained(save_path)
# config_model.save_pretrained(save_path)
print(f"{save_path=}")

## Dataset Exploration


In [None]:
# Inspect one training example: its word list followed by its per-word tag ids.
example = dataset["train"][4]  # type: ignore
for field in ("tokens", "ner_tags"):
    print(example[field])

In [None]:
# Tokenize the pre-split words and show the sub-word tokens the model actually sees.
tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
input_ids = tokenized_input["input_ids"]  # type: ignore
tokens = tokenizer.convert_ids_to_tokens(input_ids)
print(tokens)

In [None]:
# Word-level tag count vs. sub-word token count; they can differ because the
# tokenizer may split words and add special tokens, so labels need aligning.
len(example["ner_tags"]), len(tokenized_input["input_ids"])  # type: ignore

In [None]:
# Give each token the tag of the word it came from; tokens with no source word
# (word_ids() yields None, e.g. special tokens) get -100, the ignore index that
# compute_metrics filters out.
word_ids = tokenized_input.word_ids()
aligned_labels = [
    -100 if word_idx is None else example["ner_tags"][word_idx] for word_idx in word_ids
]
print(len(aligned_labels), len(tokenized_input["input_ids"]))  # type: ignore

## Training


In [None]:
# Tokenize the whole dataset via the project helper (presumably it also aligns
# labels to sub-words as explored above — confirm in nlpcw.utils). The bare name
# on the last line displays the result in the notebook.
tokenized_dataset = tokenize_dataset(dataset, tokenizer)
tokenized_dataset

In [None]:
# Respect a TOKENIZERS_PARALLELISM value the user already exported; otherwise enable it.
if "TOKENIZERS_PARALLELISM" not in os.environ:
    os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Batch collator for token classification, and the seqeval metric used by compute_metrics.
data_collator = DataCollatorForTokenClassification(tokenizer)
metric = evaluate.load("seqeval")

def compute_metrics(p):
    """Compute overall seqeval metrics for the Trainer's evaluation loop.

    Args:
        p: A ``(predictions, labels)`` pair. ``predictions`` holds per-token
            class scores with labels on axis 2; ``labels`` holds gold tag ids,
            where ``-100`` marks positions to ignore (special tokens, and
            presumably collator padding).

    Returns:
        Dict with the overall ``precision``, ``recall``, ``f1`` and ``accuracy``
        reported by seqeval.

    Raises:
        RuntimeError: if the metric returns no result. An explicit raise is used
            instead of ``assert`` so the check survives ``python -O``.
    """
    predictions, labels = p
    predicted_ids = np.argmax(predictions, axis=2)

    true_predictions = []
    true_labels = []
    for pred_row, label_row in zip(predicted_ids, labels):
        # Remove ignored index (special tokens) in a single pass per sequence.
        kept = [(pred, gold) for pred, gold in zip(pred_row, label_row) if gold != -100]
        true_predictions.append([label_list[pred] for pred, _ in kept])
        true_labels.append([label_list[gold] for _, gold in kept])

    results = metric.compute(predictions=true_predictions, references=true_labels)
    if results is None:
        raise RuntimeError("seqeval metric returned no results")
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }


args = TrainingArguments(
    # Where checkpoints go and what the run is called (run name is used by
    # loggers such as W&B).
    output_dir=str(save_path),
    overwrite_output_dir=True,
    run_name=Path(save_path).name,
    # Optimisation schedule.
    learning_rate=2e-5,
    weight_decay=0.001,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,  # effective train batch per device = BATCH_SIZE * 4
    # Evaluate and checkpoint every epoch; reload the best checkpoint by the
    # "f1" value from compute_metrics when training ends.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    metric_for_best_model="f1",
    load_best_model_at_end=True,
    logging_steps=50,
)


trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_dataset["train"],  # type: ignore
    eval_dataset=tokenized_dataset["validation"],  # type: ignore
    data_collator=data_collator,
    # NOTE(review): newer transformers releases deprecate `tokenizer=` in favour
    # of `processing_class=`; switch once the installed version supports it.
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # Stop when metric_for_best_model ("f1") fails to improve for 3 consecutive
    # evaluations (one per epoch with the arguments above).
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

# Release cached Apple-GPU memory before training. Guarded because torch.mps
# cannot be used on machines/builds without the MPS backend (e.g. Linux + CUDA),
# where the unguarded call fails.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [None]:
# Run fine-tuning; the returned training summary is shown as the cell output.
trainer.train()