In [1]:
# Candidate multilingual checkpoints compared in this notebook.
# Change MODEL_KEY to choose which one the rest of the notebook fine-tunes;
# every downstream cell reads only `model_name`.
MODEL_CHECKPOINTS = {
    "xlm-r": "xlm-roberta-base",                          # XLM-R
    "distilbert": "distilbert-base-multilingual-cased",   # DistilBERT
    "mbert": "bert-base-multilingual-cased",              # mBERT
}
MODEL_KEY = "mbert"
model_name = MODEL_CHECKPOINTS[MODEL_KEY]


In [None]:
# Sanity-check the label maps before they are stored in the model config.
# The model itself is instantiated together with its tokenizer in the next
# cell; loading it here as well would initialise the weights twice and the
# first copy would be discarded immediately.
assert len(id2tag) == len(tag2id), "id2tag and tag2id have different sizes"
assert all(tag2id[tag] == idx for idx, tag in id2tag.items()), (
    "id2tag and tag2id are not inverse mappings"
)


In [None]:
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Tokenizer plus token-classification head for the selected checkpoint.
# The label maps are passed into the model config so that the saved config
# maps class ids to tag names.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    label2id=tag2id,
    id2label=id2tag,
    num_labels=len(tag2id),
)


In [None]:
def tokenize_and_align(example, tok=None, label_all_tokens=True):
    """Tokenize pre-split words and align word-level labels to sub-word tokens.

    Intended for ``datasets.Dataset.map(..., batched=True)``: ``example`` is a
    batch dict whose ``"tokens"`` entry is a list of word lists and whose
    ``"labels"`` entry is the matching list of per-word label-id lists.

    Args:
        example: Batch with ``"tokens"`` and ``"labels"`` columns.
        tok: Fast tokenizer to use (must support ``word_ids``); defaults to the
            notebook-level ``tokenizer``.
        label_all_tokens: If True (default), every sub-word of a word gets that
            word's label. If False, only the first sub-word is labelled and the
            continuation sub-words get -100 so the loss ignores them.

    Returns:
        The tokenizer output with an added ``"labels"`` entry: one label-id
        list per example, aligned position-by-position with ``input_ids``.
    """
    tok = tokenizer if tok is None else tok
    tokenized = tok(example["tokens"], is_split_into_words=True, truncation=True)

    all_label_ids = []
    for batch_index, word_labels in enumerate(example["labels"]):
        # word_ids() defaults to batch_index=0, so in a batched call it must be
        # told which example to describe or every row would reuse example 0.
        word_ids = tokenized.word_ids(batch_index=batch_index)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens have no source word; -100 excludes them from the loss.
                label_ids.append(-100)
            elif word_idx != previous_word_idx or label_all_tokens:
                label_ids.append(word_labels[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        all_label_ids.append(label_ids)

    tokenized["labels"] = all_label_ids
    return tokenized

# Batched map: tokenize_and_align receives one dict of column lists per call.
tokenized_dataset = dataset.map(function=tokenize_and_align, batched=True)


In [7]:
from transformers import TrainingArguments

# Checkpoints and TensorBoard logs are both namespaced by model so the
# XLM-R / DistilBERT / mBERT runs neither overwrite nor interleave.
training_args = TrainingArguments(
    output_dir=f"./ner_{model_name.replace('/', '_')}",
    # load_best_model_at_end requires evaluation and saving on the same schedule.
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_dir=f"./logs/{model_name.replace('/', '_')}",
    load_best_model_at_end=True,
    # Must match a key returned by compute_metrics (with or without "eval_").
    metric_for_best_model="f1",
)


In [None]:
from transformers import DataCollatorForTokenClassification, Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    # The examples are tokenized without padding, so "labels" are ragged lists.
    # The default collator (DataCollatorWithPadding) pads only the model inputs,
    # not the labels. This collator also pads labels, using -100.
    data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
    # Passed so the tokenizer is saved with the model; recent transformers
    # releases name this argument `processing_class`.
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.save_model(f"ner_model_{model_name.replace('/', '_')}")
