In [8]:
import torch
from transformers import (
    XLMRobertaTokenizerFast,
    XLMRobertaForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer
)
from datasets import Dataset
from sklearn.model_selection import train_test_split


In [1]:
import json
import random

import torch
from datasets import Dataset, DatasetDict
from transformers import (
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
    XLMRobertaForTokenClassification,
    XLMRobertaTokenizerFast,
)

# ----------------------------------------------------
# 1. Load JSON dataset
# ----------------------------------------------------
# NOTE(review): record shape assumed from what tokenize_and_align reads:
#   {"text": str, "entities": [{"start": int, "end": int, "label": str}, ...]}
# — confirm against ner_dataset.json.
with open("ner_dataset.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# Train/validation split (90/10).
# Shuffle a copy before slicing: cutting the file as-is yields a biased
# validation set whenever the records are ordered (by source, entity type, ...).
# The fixed seed keeps the split identical across Restart & Run All.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.9

shuffled = list(data)
random.Random(SPLIT_SEED).shuffle(shuffled)
split = int(len(shuffled) * TRAIN_FRACTION)
train_data = shuffled[:split]
val_data = shuffled[split:]

# Fail loudly instead of letting Trainer run on an empty split.
if not train_data or not val_data:
    raise ValueError(
        f"Train/validation split of {len(data)} examples left a split empty"
    )

# Wrap into Dataset
dataset = DatasetDict({
    "train": Dataset.from_list(train_data),
    "validation": Dataset.from_list(val_data)
})

# ----------------------------------------------------
# 2. Create label list (BIO scheme)
# ----------------------------------------------------
# An explicit "O" (outside) class is required: if non-entity tokens are all
# masked with -100, the loss never sees them, the model never learns
# "no entity", and at inference every token is tagged PER/ORG/LOC.
# B-/I- prefixes mark the first vs. following tokens of an entity, so
# multi-subword entities are fully labeled and adjacent same-type entities
# stay distinguishable.
ENTITY_TYPES = ["PER", "ORG", "LOC"]
labels = ["O"] + [
    f"{prefix}-{entity_type}"
    for entity_type in ENTITY_TYPES
    for prefix in ("B", "I")
]
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}

# ----------------------------------------------------
# 3. Load tokenizer
# ----------------------------------------------------
model_name = "xlm-roberta-base"
tokenizer = XLMRobertaTokenizerFast.from_pretrained(model_name)

# ----------------------------------------------------
# 4. Tokenize + align entities to tokens
# ----------------------------------------------------
MAX_LENGTH = 128


def tokenize_and_align(example):
    """Tokenize one example and attach one BIO label id per token.

    A token belongs to an entity when its character span overlaps the
    entity's [start, end) span. The previous exact-match rule
    (token offsets == entity offsets) left every entity that the tokenizer
    split into several subwords completely unlabeled.

    Args:
        example: one record with "text" (str) and "entities", a list of dicts
            with character offsets "start"/"end" (end assumed exclusive, the
            same convention as the tokenizer's offsets) and a "label" that is
            one of ENTITY_TYPES.

    Returns:
        The tokenizer encoding (input_ids, attention_mask) plus "labels":
        -100 for special tokens (ignored by the loss), otherwise the id of
        "O", "B-<type>" or "I-<type>". No padding is applied here;
        DataCollatorForTokenClassification pads each batch.

    Raises:
        KeyError: if an entity's "label" is not in ENTITY_TYPES.
    """
    encoding = tokenizer(
        example["text"],
        truncation=True,
        max_length=MAX_LENGTH,
        return_offsets_mapping=True,
        return_special_tokens_mask=True,
    )
    # Both are only needed for alignment; the model does not accept them.
    offsets = encoding.pop("offset_mapping")
    special_mask = encoding.pop("special_tokens_mask")

    tags = ["O"] * len(offsets)
    for ent in example["entities"]:
        ent_start, ent_end = ent["start"], ent["end"]
        prefix = "B"  # first overlapping token opens the entity
        for i, (tok_start, tok_end) in enumerate(offsets):
            # Skip special and zero-width tokens; a zero-width span would
            # otherwise "overlap" any entity that contains its position.
            if special_mask[i] or tok_start >= tok_end:
                continue
            if tok_start < ent_end and tok_end > ent_start:
                tags[i] = f"{prefix}-{ent['label']}"
                prefix = "I"

    encoding["labels"] = [
        -100 if special_mask[i] else label2id[tag]
        for i, tag in enumerate(tags)
    ]
    return encoding


# Drop the raw "text"/"entities" columns so only model inputs reach the collator.
tokenized = dataset.map(
    tokenize_and_align,
    batched=False,
    remove_columns=dataset["train"].column_names,
)

# ----------------------------------------------------
# 5. Load model
# ----------------------------------------------------
# Size the token-classification head from the label mapping defined above and
# store that mapping on the model config.
label_config = {
    "num_labels": len(labels),
    "id2label": id2label,
    "label2id": label2id,
}
model = XLMRobertaForTokenClassification.from_pretrained(model_name, **label_config)

# ----------------------------------------------------
# 6. Training setup
# ----------------------------------------------------
# Pads each batch dynamically; labels are padded with label_pad_token_id
# (default -100), so padding positions are ignored by the loss.
data_collator = DataCollatorForTokenClassification(tokenizer)

training_args = TrainingArguments(
    output_dir="./ner-xlm-roberta",
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy`; use whichever spelling the installed version accepts.
    evaluation_strategy="epoch",
    save_strategy="epoch",  # must match the eval strategy for load_best_model_at_end
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_steps=20,
    # After training, reload the checkpoint with the lowest validation loss, so
    # the model saved in step 8 is the best epoch rather than simply the last one.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Each checkpoint holds full model + optimizer state; keep only the best
    # and the most recent instead of one per epoch.
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator
)

# ----------------------------------------------------
# 7. Train the model
# ----------------------------------------------------
trainer.train()

# ----------------------------------------------------
# 8. Save final model
# ----------------------------------------------------
# Write the model and its tokenizer into the same directory so the folder can
# be reloaded as a single unit.
final_model_dir = "./ner-xlm-roberta-final"
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)

print("Training complete! Model saved in ner-xlm-roberta-final/")


  from .autonotebook import tqdm as notebook_tqdm
Map: 100%|██████████| 900/900 [00:00<00:00, 1071.72 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 1152.28 examples/s]
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


: 