In [None]:
import string
from datasets import Dataset
from pathlib import Path

# Placeholder tag (as written in the annotated text) -> class id.
# Several spellings share one id; '0' is the "outside / not an entity" class.
# NOTE(review): '[ID number]' contains a space, so it cannot survive a
# whitespace split of the annotated line as a single token.
label2id = {
    '0': 0,
    # 1: first names
    '[name]': 1,
    '[name_1]': 1,
    '[name_2]': 1,
    '[scientist-name]': 1,
    # 2: surnames
    '[surname]': 2,
    '[surname_1]': 2,
    '[surname_2]': 2,
    '[age]': 3,
    '[date-of-birth]': 4,
    '[date]': 5,
    '[sex]': 6,
    '[religion]': 7,
    # 8: political views
    '[political-view]': 8,
    '[#political-view]': 8,
    '[ethnicity]': 9,
    '[sexual-orientation]': 10,
    # 11: health and lifestyle
    '[health]': 11,
    '[#health]': 11,
    '[weight-loss/muscle-gain]': 11,
    '[vegetarian/vegan]': 11,
    '[active/sedentary]': 11,
    '[relative]': 12,
    # 13: cities
    '[city]': 13,
    '[#city]': 13,
    # 14: addresses
    '[address]': 14,
    '[street]': 14,
    '[email]': 15,
    '[phone]': 16,
    '[pesel]': 17,
    # 18: document-like identifiers
    '[document-number]': 18,
    '[ID number]': 18,
    '[vin]': 18,
    '[company]': 19,
    '[school-name]': 20,
    # 21: occupations
    '[job-title]': 21,
    '[healthcare-professional]': 21,
    '[bank-account]': 22,
    '[credit-card-number]': 23,
    # 24: account handles
    '[username]': 24,
    '[social-media-username]': 24,
    '[secret]': 25,
}

# Class id -> canonical tag name; the list position is the class id.
id2label = dict(enumerate([
    '0',
    '[name]',
    '[surname]',
    '[age]',
    '[date-of-birth]',
    '[date]',
    '[sex]',
    '[religion]',
    '[political-view]',
    '[ethnicity]',
    '[sexual-orientation]',
    '[health]',
    '[relative]',
    '[city]',
    '[address]',
    '[email]',
    '[phone]',
    '[pesel]',
    '[document-number]',
    '[company]',
    '[school-name]',
    '[job-title]',
    '[bank-account]',
    '[credit-card-number]',
    '[username]',
    '[secret]',
]))

# Input files, read line by line in parallel by load_dataset().
RAW_FILE = Path("anonymized.txt")
ANNOTATED_FILE = Path("orig.txt")

def clean_token(token: str) -> str:
    """Strip leading/trailing punctuation from *token*, keeping square brackets.

    Placeholder tags are written as ``[tag]`` and are looked up by that exact
    spelling. ``string.punctuation`` contains both ``[`` and ``]``, so stripping
    the full set would turn ``"[name],"`` into ``"name"`` and the tag would
    never be recognised. The brackets are therefore excluded from the strip
    set, while all other surrounding punctuation is still removed.

    Args:
        token: A single whitespace-delimited token.

    Returns:
        The token without surrounding punctuation; may be empty when the token
        consisted of punctuation only.
    """
    strippable = string.punctuation.replace("[", "").replace("]", "")
    return token.strip(strippable)

def parse_alignment(raw_str: str, anno_str: str) -> tuple[list[str], list[int]]:
    """Align a raw sentence with its annotated twin and tag every raw token.

    Both strings are split on whitespace. The annotated tokens are walked in
    parallel with the raw tokens:

    * an annotated token that is not a tag consumes one raw token, which is
      labelled as outside;
    * an annotated tag followed by a non-tag token (the "anchor") labels raw
      tokens with the tag's id until a raw token equal to the anchor (after
      ``clean_token``) is reached; if the anchor never appears, only one raw
      token receives the tag;
    * an annotated tag followed by another tag, or by nothing, labels exactly
      one raw token;
    * raw tokens left over after the annotation is exhausted inherit the id of
      the last annotated token when that token is a tag, otherwise outside.

    The literal token ``'0'`` is never treated as a tag, neither at the current
    position nor when deciding whether the following token is a tag.

    Tag detection looks up ``clean_token(token)`` in ``label2id``, so it only
    works if ``clean_token`` leaves the square brackets of a tag in place.

    Args:
        raw_str: The sentence containing the real values.
        anno_str: The same sentence with values replaced by ``[tag]`` markers.

    Returns:
        ``(raw_tokens, tags)`` where ``tags[i]`` is the class id of
        ``raw_tokens[i]``.
    """
    raw_tokens = raw_str.split()
    anno_tokens = anno_str.split()
    outside_id = label2id['0']

    def tag_id_of(token: str):
        """Return the class id when *token* is a placeholder tag, else None."""
        key = clean_token(token)
        if key == '0':
            # '0' is a label2id key but denotes "outside", never a placeholder.
            return None
        return label2id.get(key)

    # Label used for raw tokens that remain once the annotation runs out.
    tail_id = tag_id_of(anno_tokens[-1]) if anno_tokens else None
    if tail_id is None:
        tail_id = outside_id

    tags: list[int] = []
    r_idx = a_idx = 0

    while r_idx < len(raw_tokens):
        if a_idx >= len(anno_tokens):
            tags.append(tail_id)
            r_idx += 1
            continue

        tag_id = tag_id_of(anno_tokens[a_idx])

        if tag_id is None:
            # Ordinary word, whether or not it matches the raw token exactly.
            tags.append(outside_id)
            r_idx += 1
            a_idx += 1
            continue

        has_next = a_idx + 1 < len(anno_tokens)
        next_anchor = anno_tokens[a_idx + 1] if has_next else None
        a_idx += 1

        if next_anchor is None or tag_id_of(next_anchor) is not None:
            # No usable anchor: the tag covers a single raw token.
            tags.append(tag_id)
            r_idx += 1
            continue

        # Look ahead for the first raw token that matches the anchor.
        anchor_key = clean_token(next_anchor)
        span_end = r_idx
        while span_end < len(raw_tokens) and clean_token(raw_tokens[span_end]) != anchor_key:
            span_end += 1

        if span_end < len(raw_tokens):
            # Everything before the anchor belongs to the tag (possibly nothing).
            tags.extend([tag_id] * (span_end - r_idx))
            r_idx = span_end
        else:
            # Anchor not found: fall back to tagging one token only.
            tags.append(tag_id)
            r_idx += 1

    return raw_tokens, tags

def load_dataset(raw_path: Path, anno_path: Path) -> Dataset:
    """Build a token/tag dataset from two line-aligned UTF-8 text files.

    Line *i* of ``raw_path`` is paired with line *i* of ``anno_path``; pairs in
    which either side is blank are skipped. Each remaining pair is passed to
    ``parse_alignment``. If the returned token and tag lists differ in length,
    both are cut to the shorter one and the row is counted as a mismatch.

    A warning is printed when the files differ in line count (surplus lines of
    the longer file are not paired) and when any rows had to be truncated.

    Args:
        raw_path: File with the original sentences.
        anno_path: File with the annotated sentences.

    Returns:
        A ``Dataset`` with the columns ``tokens`` and ``ner_tags``.
    """
    def read_lines(path):
        with open(path, encoding="utf-8") as handle:
            return handle.read().splitlines()

    raw_lines = read_lines(raw_path)
    annotated_lines = read_lines(anno_path)

    if len(raw_lines) != len(annotated_lines):
        print(f"WARNING: Line count mismatch - raw: {len(raw_lines)}, annotated: {len(annotated_lines)}")

    mismatches = 0
    formatted_data = []

    for raw, anno in zip(raw_lines, annotated_lines):
        if not (raw.strip() and anno.strip()):
            continue

        tokens, ner_tags = parse_alignment(raw, anno)

        if len(tokens) != len(ner_tags):
            mismatches += 1
            shortest = min(len(tokens), len(ner_tags))
            tokens = tokens[:shortest]
            ner_tags = ner_tags[:shortest]

        formatted_data.append({"tokens": tokens, "ner_tags": ner_tags})

    if mismatches:
        print(f"WARNING: {mismatches} rows had length mismatches and were truncated")

    return Dataset.from_list(formatted_data)

# Build the dataset from the two line-aligned input files and report its size.
dataset = load_dataset(raw_path=RAW_FILE, anno_path=ANNOTATED_FILE)
print(f"Dataset created with {len(dataset)} rows.")


In [None]:
import numpy as np
import evaluate
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification
)
import torch
import torch.nn as nn

class WeightedLossTrainer(Trainer):
    """Trainer that scores token predictions with a class-weighted cross entropy.

    Class index 0 is weighted 1.0 and every other class 5.0. The criterion is
    created on first use and rebuilt whenever the number of classes in the
    logits changes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cached criterion and the class count it was built for.
        self._num_labels = None
        self._loss_fct = None

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model on *inputs* and return the weighted loss.

        When ``inputs`` carries labels, positions whose attention mask is not 1
        have their label replaced by the criterion's ``ignore_index`` before
        the loss is taken over the flattened logits. Without labels, the loss
        reported by the model itself is returned.

        Args:
            model: The model being trained; called as ``model(**inputs)``.
            inputs: Batch mapping; ``labels`` is optional, ``attention_mask``
                is required whenever ``labels`` is present.
            return_outputs: When true, return ``(loss, outputs)``.
            **kwargs: Accepted and ignored.

        Returns:
            The loss, or ``(loss, outputs)`` if ``return_outputs`` is true.
        """
        outputs = model(**inputs)
        logits = outputs.get("logits")
        num_labels = logits.shape[-1]

        needs_rebuild = self._loss_fct is None or self._num_labels != num_labels
        if needs_rebuild:
            class_weights = torch.tensor(
                [1.0, *([5.0] * (num_labels - 1))],
                device=model.device,
                dtype=torch.float32,
            )
            self._loss_fct = nn.CrossEntropyLoss(weight=class_weights)
            self._num_labels = num_labels

        labels = inputs.get("labels")
        if labels is None:
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        else:
            attended = inputs["attention_mask"].view(-1) == 1
            ignore_value = torch.tensor(self._loss_fct.ignore_index).type_as(labels)
            flat_labels = torch.where(attended, labels.view(-1), ignore_value)
            flat_logits = logits.view(-1, num_labels)
            loss = self._loss_fct(flat_logits, flat_labels)

        if return_outputs:
            return loss, outputs
        return loss

# Pretrained checkpoint used for both the tokenizer and the model.
MODEL_CHECKPOINT = "allegro/herbert-large-cased"
# Optimiser learning rate and number of passes over the training split.
LEARNING_RATE = 2e-5
EPOCHS = 1

# Module-level tokenizer; tokenize_and_align_labels() reads it as a global.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)


def tokenize_and_align_labels(examples):
    """Tokenize pre-split words and spread word-level tags over sub-tokens.

    For every example, the first sub-token of each word receives that word's
    tag. Every other position (sub-tokens continuing a word and positions that
    belong to no word) receives -100, the value compute_metrics() filters out.

    Args:
        examples: A batch with the columns ``tokens`` (lists of words) and
            ``ner_tags`` (lists of class ids, one per word).

    Returns:
        The tokenizer output with an added ``labels`` entry.
    """
    encoded = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
    )

    aligned_labels = []
    for row, word_tags in enumerate(examples["ner_tags"]):
        row_labels = []
        last_word = None
        for word in encoded.word_ids(batch_index=row):
            starts_word = word is not None and word != last_word
            row_labels.append(word_tags[word] if starts_word else -100)
            last_word = word
        aligned_labels.append(row_labels)

    encoded["labels"] = aligned_labels
    return encoded


# Tokenize the whole dataset in batches and attach the aligned labels.
tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=True)

# Hold out 10% of the rows for evaluation; the seed is fixed at 42.
train_test_split = tokenized_datasets.train_test_split(test_size=0.1, seed=42)
train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]

# Metric object used by compute_metrics().
seqeval = evaluate.load("seqeval")


def compute_metrics(p):
    """Compute entity-level precision, recall, F1 and accuracy with seqeval.

    Predictions are the arg-max over the last logits axis. Positions whose
    gold label is -100 are dropped from both predictions and references.

    seqeval parses labels as IOB-style tags: only the letter ``'O'`` counts as
    "outside", and the entity type is read from the text after the first
    ``'-'``. The raw names in ``id2label`` do not follow that scheme: the
    outside class is the digit ``'0'`` and e.g. ``'[school-name]'`` and
    ``'[name]'`` would reduce to the same type. Class ids are therefore
    translated to ``'O'`` or ``'I-<type>'`` (brackets removed) before scoring,
    so the outside class is not counted as an entity and each class keeps a
    distinct type.

    Args:
        p: A ``(predictions, labels)`` pair; predictions are indexed as
            (example, position, class) and labels as (example, position).

    Returns:
        A dict with the keys ``precision``, ``recall``, ``f1`` and
        ``accuracy``.
    """
    logits, gold = p
    predicted = np.argmax(logits, axis=2)

    # Translate every class id once instead of per token.
    seqeval_tags = {
        class_id: "O" if name == '0' else f"I-{name.strip('[]')}"
        for class_id, name in id2label.items()
    }

    true_predictions = []
    true_labels = []
    for pred_row, gold_row in zip(predicted, gold):
        kept = [(pred, ref) for pred, ref in zip(pred_row, gold_row) if ref != -100]
        true_predictions.append([seqeval_tags[int(pred)] for pred, _ in kept])
        true_labels.append([seqeval_tags[int(ref)] for _, ref in kept])

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }


# Token-classification model with one output class per entry of id2label.
# label2id is passed as-is, including the alias spellings that share an id.
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id
)

# Collator built from the tokenizer; used by the trainer to assemble batches.
data_collator = DataCollatorForTokenClassification(tokenizer)

# Evaluation and checkpointing both happen once per epoch; the checkpoint with
# the best "f1" (as returned by compute_metrics) is reloaded at the end.
# NOTE(review): fp16=True presumably requires a CUDA-capable device - confirm
# before running on CPU-only machines.
args = TrainingArguments(
    output_dir="nerbert",
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    fp16=True,
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    logging_steps=10,
    load_best_model_at_end=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    metric_for_best_model="f1"
)

# WeightedLossTrainer swaps in the class-weighted loss defined above.
trainer = WeightedLossTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)
print("Starting training")
trainer.train()

# Persist the final model and tokenizer into the same directory as output_dir.
trainer.save_model("nerbert")
tokenizer.save_pretrained("nerbert")
print("Model saved to 'nerbert'")