# 04_Model_Comparison_Eval.ipynb
### Purpose: Evaluate and compare BERT vs. TAPT-BERT on the FiNER-139 test set
<hr style="height:3px; width:100%; background-color:black; border:none; margin:auto;" />

## 1. Setup & Imports

In [4]:
import numpy as np
from datasets import Dataset, DatasetDict
from transformers import BertTokenizerFast, BertForTokenClassification
from transformers import DataCollatorForTokenClassification, Trainer, TrainingArguments
import evaluate
import matplotlib.pyplot as plt
import torch

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


<hr style="height:3px; width:100%; background-color:black; border:none; margin:auto;" />

## 2. Load the FiNER-139 Dataset (train / validation / test)

In [5]:
# Section 2: load the labeled FiNER-139 JSONL splits and bundle them into one DatasetDict.
# Each record is expected to look like: {"tokens": [...], "ner_tags_str_mapped": [...]}
# Only the test split is evaluated in this notebook; train/validation are loaded so the
# label vocabulary built in section 3 covers every split.
train_path = "./pipeline/data/finer-train.jsonl"
val_path = "./pipeline/data/finer-validation.jsonl"
test_path = "./pipeline/data/finer-test.jsonl"

train_data, val_data, test_data = (
    Dataset.from_json(split_file) for split_file in (train_path, val_path, test_path)
)

dataset = DatasetDict(train=train_data, validation=val_data, test=test_data)
dataset

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags_str_mapped'],
        num_rows: 900384
    })
    validation: Dataset({
        features: ['id', 'tokens', 'ner_tags_str_mapped'],
        num_rows: 112494
    })
    test: Dataset({
        features: ['id', 'tokens', 'ner_tags_str_mapped'],
        num_rows: 108378
    })
})

<hr style="height:3px; width:100%; background-color:black; border:none; margin:auto;" />

## 3. Label Mapping Utilities

In [6]:
# Label vocabulary = union of every tag string that occurs in any split.
# Sorting makes the id assignment deterministic (alphabetical, so "O" comes last here).
# NOTE(review): this mapping must match the one the fine-tuned checkpoints were trained
# with; it is presumably recreated identically by the production script — confirm.
all_labels = {
    tag
    for split_name in ("train", "validation", "test")
    for tag_seq in dataset[split_name]["ner_tags_str_mapped"]
    for tag in tag_seq
}

label_list = sorted(all_labels)
id_to_label = dict(enumerate(label_list))
label_to_id = {tag: idx for idx, tag in id_to_label.items()}
num_labels = len(label_list)
print("num_labels: ", num_labels)
label_list


num_labels:  19


['B-Acquisition',
 'B-Assets',
 'B-Compensation',
 'B-Contingency',
 'B-Debt',
 'B-Equity',
 'B-Lease',
 'B-Other',
 'B-Revenue',
 'B-Tax',
 'I-Acquisition',
 'I-Assets',
 'I-Compensation',
 'I-Contingency',
 'I-Debt',
 'I-Equity',
 'I-Lease',
 'I-Other',
 'O']

In [7]:
# Section 3 (continued): show both directions of the label mapping and the label count.
for heading, mapping in (
    ("Label → ID mapping:", label_to_id),
    ("\nID → Label mapping:", id_to_label),
):
    print(heading)
    print(mapping)

print(f"\nAnzahl Labels: {len(label_list)}")

Label → ID mapping:
{'B-Acquisition': 0, 'B-Assets': 1, 'B-Compensation': 2, 'B-Contingency': 3, 'B-Debt': 4, 'B-Equity': 5, 'B-Lease': 6, 'B-Other': 7, 'B-Revenue': 8, 'B-Tax': 9, 'I-Acquisition': 10, 'I-Assets': 11, 'I-Compensation': 12, 'I-Contingency': 13, 'I-Debt': 14, 'I-Equity': 15, 'I-Lease': 16, 'I-Other': 17, 'O': 18}

ID → Label mapping:
{0: 'B-Acquisition', 1: 'B-Assets', 2: 'B-Compensation', 3: 'B-Contingency', 4: 'B-Debt', 5: 'B-Equity', 6: 'B-Lease', 7: 'B-Other', 8: 'B-Revenue', 9: 'B-Tax', 10: 'I-Acquisition', 11: 'I-Assets', 12: 'I-Compensation', 13: 'I-Contingency', 14: 'I-Debt', 15: 'I-Equity', 16: 'I-Lease', 17: 'I-Other', 18: 'O'}

Anzahl Labels: 19


<hr style="height:3px; width:100%; background-color:black; border:none; margin:auto;" />

## 4. Tokenization with Label Alignment

In [8]:
# Copied from the production fine-tuning script.
# A single tokenizer, loaded from the TAPT checkpoint directory, is shared by BOTH models.
# NOTE(review): this assumes TAPT left the vocabulary identical to the base model's —
# confirm against the pretraining script.
tokenizer = BertTokenizerFast.from_pretrained(
    "./pipeline/bert-tapt"
)

def tokenize_and_align_labels(examples):
    """Tokenize pre-split words and project word-level tags onto sub-word tokens.

    Args:
        examples: a batch (``batched=True`` map) with ``"tokens"`` (list of word lists)
            and ``"ner_tags_str_mapped"`` (list of tag-string lists, one tag per word).

    Returns:
        The tokenizer's batch encoding, padded/truncated to 512 tokens, with an extra
        ``"labels"`` entry. Only the first sub-word of each word carries that word's
        tag id (via the global ``label_to_id``); special tokens, padding and
        continuation sub-words get ``-100`` so they are ignored downstream.
    """
    encoded = tokenizer(
        examples["tokens"],
        is_split_into_words=True,
        truncation=True,
        max_length=512,
        padding="max_length",
    )

    aligned = []
    for row, word_tags in enumerate(examples["ner_tags_str_mapped"]):
        ids = []
        prev = None
        for wid in encoded.word_ids(batch_index=row):
            # wid is None for special/padding tokens; a repeated wid marks a continuation
            # sub-word. Only the first sub-word of a word is labelled.
            is_first_subword = wid is not None and wid != prev
            ids.append(label_to_id[word_tags[wid]] if is_first_subword else -100)
            prev = wid
        aligned.append(ids)

    encoded["labels"] = aligned
    return encoded
    
# Tokenize + align labels for the test split only, batched and spread over 5 worker processes.
tokenized_test = dataset["test"].map(
    tokenize_and_align_labels,
    batched=True,
    num_proc=5,
)


Map (num_proc=5):   0%|          | 0/108378 [00:00<?, ? examples/s]

<hr style="height:3px; width:100%; background-color:black; border:none; margin:auto;" />

## 5. Evaluation Function

In [17]:
import numpy as np


# seqeval provides the span-level (entity-level) scores used in compute_metrics.
metric = evaluate.load("seqeval")

# Rebuild the tag -> id lookup from id_to_label (same mapping as section 3) and remember
# the id of the outside tag "O"; O_ID stays None when the label set has no "O".
label_to_id = dict((tag, idx) for idx, tag in id_to_label.items())
O_ID = label_to_id.get("O")

def compute_metrics(p):
    """Score token-classification predictions at span level (seqeval) and token level.

    Args:
        p: ``(logits, labels)`` pair, e.g. ``(output.predictions, output.label_ids)``
            from ``Trainer.predict``. Predictions are ``argmax`` over the last axis of
            ``logits``. ``labels`` holds label ids with ``-100`` at positions to ignore
            (special tokens, padding, non-first sub-words, as produced by
            ``tokenize_and_align_labels``). It must be rectangular and have the
            shape of ``logits`` without its last axis.

    Returns:
        A single flat ``dict`` (NOT a tuple) with:
          * ``precision``, ``recall``, ``f1``, ``accuracy``: seqeval "overall" (micro) scores;
          * ``type-<ENT>_precision|_recall|_f1|_support`` for each entity type seqeval reports;
          * ``precision_macro`` / ``recall_macro`` / ``f1_macro``: unweighted means over types;
          * ``precision_weighted`` / ``recall_weighted`` / ``f1_weighted``: support-weighted means;
          * ``token_acc`` and, when an ``"O"`` label exists, ``token_acc_wo_O``
            (NaN when there is no position to score).

    Uses the notebook globals ``metric`` (seqeval), ``id_to_label`` and ``O_ID``.
    """
    logits, labels = p
    # Coerce to an array so the vectorized masks below also work for (rectangular) lists.
    labels = np.asarray(labels)
    preds = np.argmax(logits, axis=-1)
    keep = labels != -100

    # IDs -> tag strings. -100 positions are dropped from BOTH sequences so gold and
    # predicted tags stay aligned word-for-word.
    true_labels = [
        [id_to_label[lab_id] for lab_id in lab_row if lab_id != -100]
        for lab_row in labels
    ]
    true_preds = [
        [id_to_label[pred_id] for pred_id, lab_id in zip(pred_row, lab_row) if lab_id != -100]
        for pred_row, lab_row in zip(preds, labels)
    ]

    # Span-level (entity-level) scores via seqeval; the "overall_*" entries are micro averages.
    res = metric.compute(predictions=true_preds, references=true_labels)

    out = {
        "precision": res["overall_precision"],
        "recall": res["overall_recall"],
        "f1": res["overall_f1"],
        "accuracy": res["overall_accuracy"],
    }

    # Per-entity-type scores, flattened to scalar keys so Trainer/loggers accept them.
    per_type = {k: v for k, v in res.items() if not k.startswith("overall_")}
    for ent, m in per_type.items():
        out[f"type-{ent}_precision"] = m["precision"]
        out[f"type-{ent}_recall"] = m["recall"]
        out[f"type-{ent}_f1"] = m["f1"]
        out[f"type-{ent}_support"] = m["number"]  # number of gold entities of this type

    # Macro (unweighted) and support-weighted averages across entity types.
    if per_type:
        type_scores = list(per_type.values())
        supports = np.array([m["number"] for m in type_scores], dtype=float)
        total_support = supports.sum()
        # Fall back to uniform weights if no type has any gold entity.
        weights = supports / total_support if total_support > 0 else np.ones_like(supports) / len(supports)

        precs = np.array([m["precision"] for m in type_scores], dtype=float)
        recs = np.array([m["recall"] for m in type_scores], dtype=float)
        f1s = np.array([m["f1"] for m in type_scores], dtype=float)

        out["precision_macro"] = float(np.nanmean(precs))
        out["recall_macro"] = float(np.nanmean(recs))
        out["f1_macro"] = float(np.nanmean(f1s))

        out["precision_weighted"] = float(np.nansum(precs * weights))
        out["recall_weighted"] = float(np.nansum(recs * weights))
        out["f1_weighted"] = float(np.nansum(f1s * weights))

    # Token accuracy over non-ignored positions (-100 masked), with and without "O".
    correct = (preds == labels) & keep
    n_scored = keep.sum()
    # Explicit 0/0 guard: report NaN (like the wo_O branch) instead of relying on a
    # numpy divide-by-zero warning.
    out["token_acc"] = float(correct.sum() / n_scored) if n_scored > 0 else float("nan")

    if O_ID is not None:
        keep_wo_O = keep & (labels != O_ID)
        n_scored_wo_O = keep_wo_O.sum()
        if n_scored_wo_O > 0:
            out["token_acc_wo_O"] = float((correct & keep_wo_O).sum() / n_scored_wo_O)
        else:
            out["token_acc_wo_O"] = float("nan")

    return out


<hr style="height:3px; width:100%; background-color:black; border:none; margin:auto;" />

## 6. Data Collator

In [10]:
# Batches the tokenized test examples for both Trainers below.
# NOTE(review): section 4 already pads every example to max_length=512, so the
# collator's dynamic padding is presumably a no-op here.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

## 7. Evaluate Both Models

In [11]:
# Inference-only arguments shared by both Trainers.
# bf16 autocast is enabled only when the visible GPU supports it. Hard-coding bf16=True
# can make TrainingArguments raise on CPU-only machines or on GPUs without bf16 support.
USE_BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

eval_args = TrainingArguments(
    # Passed explicitly: older transformers versions require output_dir.
    output_dir="./pipeline/eval-output",
    per_device_eval_batch_size=64,
    dataloader_num_workers=24,
    dataloader_pin_memory=True,
    bf16=USE_BF16,
    report_to="none",
    seed=42,
    data_seed=42,
)

### 🔵 Evaluate bert-base-finetuned

In [12]:
# Load the fine-tuned plain-BERT checkpoint and run inference on the tokenized test split.
model_base = BertForTokenClassification.from_pretrained("./pipeline/bert-base-finetuned").to(device)

# Predictions are decoded with this notebook's id_to_label, so the classification head
# must have exactly as many outputs as that mapping has labels.
assert model_base.config.num_labels == num_labels, (
    f"checkpoint has {model_base.config.num_labels} labels, notebook mapping has {num_labels}"
)

trainer_base = Trainer(
    model=model_base,
    tokenizer=tokenizer,
    data_collator=data_collator,
    args=eval_args,
)

# Drop the raw word/tag columns; the tokenized inputs and aligned "labels" remain.
output_base = trainer_base.predict(tokenized_test.remove_columns(["tokens", "ner_tags_str_mapped"]))

# compute_metrics returns ONE flat dict, so the F1 is read from it. Tuple-unpacking
# (`f1, metrics = compute_metrics(...)`) raises ValueError on a fresh kernel.
metrics_base = compute_metrics((output_base.predictions, output_base.label_ids))
f1_base = metrics_base["f1"]

print(f"📊 Base Model F1: {f1_base:.4f}")

  trainer_base = Trainer(


📊 Base Model F1: 0.8502


### 🟢 Evaluate bert-tapt-finetuned

In [13]:
# Load the fine-tuned TAPT-BERT checkpoint and run inference on the tokenized test split.
model_tapt = BertForTokenClassification.from_pretrained("./pipeline/bert-tapt-finetuned").to(device)

# Predictions are decoded with this notebook's id_to_label, so the classification head
# must have exactly as many outputs as that mapping has labels.
assert model_tapt.config.num_labels == num_labels, (
    f"checkpoint has {model_tapt.config.num_labels} labels, notebook mapping has {num_labels}"
)

trainer_tapt = Trainer(
    model=model_tapt,
    tokenizer=tokenizer,
    data_collator=data_collator,
    args=eval_args,
)

# Drop the raw word/tag columns; the tokenized inputs and aligned "labels" remain.
output_tapt = trainer_tapt.predict(tokenized_test.remove_columns(["tokens", "ner_tags_str_mapped"]))

# compute_metrics returns ONE flat dict, so the F1 is read from it. Tuple-unpacking
# (`f1, metrics = compute_metrics(...)`) raises ValueError on a fresh kernel.
metrics_tapt = compute_metrics((output_tapt.predictions, output_tapt.label_ids))
f1_tapt = metrics_tapt["f1"]

print(f"📈 TAPT Model F1: {f1_tapt:.4f}")

  trainer_tapt = Trainer(


📈 TAPT Model F1: 0.8596


<hr style="height:3px; width:100%; background-color:black; border:none; margin:auto;" />

## 8. Analysing the Metrics

In [21]:
import pandas as pd
from IPython.display import display

# 1) Score both models' finished prediction outputs and key the metric dicts by display name.
metrics_base, metrics_tapt = (
    compute_metrics((pred_out.predictions, pred_out.label_ids))
    for pred_out in (output_base, output_tapt)
)

results = dict(Base=metrics_base, TAPT=metrics_tapt)

# 2) Build summary table (overall/macro/weighted/token metrics; excludes per-type keys)
def make_summary_df(results):
    """Build a one-row-per-model table of the aggregate (non per-type) metrics.

    Args:
        results: mapping ``{model_name: metrics_dict}``; keys starting with ``"type-"``
            (per-entity metrics) are excluded.

    Returns:
        DataFrame indexed by ``"model"``. The well-known aggregate metrics come first, in a
        fixed order; any other remaining metric columns follow in their original order.
    """
    preferred = (
        "precision", "recall", "f1", "accuracy",
        "token_acc", "token_acc_wo_O",
        "precision_macro", "recall_macro", "f1_macro",
        "precision_weighted", "recall_weighted", "f1_weighted",
    )

    records = [
        {
            **{key: val for key, val in metrics.items() if not key.startswith("type-")},
            "model": model_name,
        }
        for model_name, metrics in results.items()
    ]
    table = pd.DataFrame(records).set_index("model")

    leading = [col for col in preferred if col in table.columns]
    trailing = [col for col in table.columns if col not in preferred]
    return table[leading + trailing]

# Aggregate metrics table: one row per model.
summary_df = make_summary_df(results=results)

# 3) Build per-entity table (precision/recall/F1/support per entity type)
def make_per_type_df(results):
    """Build a long-format table with one row per (model, entity type).

    Args:
        results: mapping ``{model_name: metrics_dict}``. Each metrics dict is the flat
            ``compute_metrics`` output, whose per-type keys look like
            ``"type-<ENT>_precision"``, ``"type-<ENT>_recall"``, ``"type-<ENT>_f1"`` and
            ``"type-<ENT>_support"``.

    Returns:
        DataFrame with columns ``model, entity, precision, recall, f1, support``,
        sorted by entity then model. Missing per-type values are NaN. An empty
        ``results`` yields an empty frame with those columns.
    """
    prefix = "type-"
    columns = ["model", "entity", "precision", "recall", "f1", "support"]
    rows = []
    for name, m in results.items():
        # Strip the prefix and split off only the LAST "_<metric>" suffix, so entity
        # names that themselves contain underscores stay intact.
        ents = sorted({k[len(prefix):].rsplit("_", 1)[0] for k in m if k.startswith(prefix)})
        for ent in ents:
            rows.append({
                "model": name,
                "entity": ent,
                "precision": m.get(f"type-{ent}_precision", float("nan")),
                "recall": m.get(f"type-{ent}_recall", float("nan")),
                "f1": m.get(f"type-{ent}_f1", float("nan")),
                "support": m.get(f"type-{ent}_support", float("nan")),
            })
    # Explicit columns keep sort_values working even when there are no rows.
    return (
        pd.DataFrame(rows, columns=columns)
        .sort_values(["entity", "model"])
        .reset_index(drop=True)
    )

per_type_df = make_per_type_df(results)

# 4) Render both tables, rounded to 4 decimals for readability.
display(summary_df.round(4).rename_axis("model"), per_type_df.round(4))

# Optional: CSV export (uncomment to write 6-decimal copies of both tables)
# summary_df.round(6).to_csv("ner_metrics_summary.csv")
# per_type_df.round(6).to_csv("ner_metrics_per_type.csv")


Unnamed: 0_level_0,precision,recall,f1,accuracy,token_acc,token_acc_wo_O,precision_macro,recall_macro,f1_macro,precision_weighted,recall_weighted,f1_weighted
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
Base,0.8116,0.8928,0.8502,0.9981,0.9981,0.887,0.8142,0.8853,0.8482,0.8117,0.8928,0.8502
TAPT,0.8223,0.9004,0.8596,0.9982,0.9982,0.895,0.8262,0.8931,0.8582,0.8226,0.9004,0.8596


Unnamed: 0,model,entity,precision,recall,f1,support
0,Base,Acquisition,0.7821,0.8531,0.816,1851
1,TAPT,Acquisition,0.7731,0.8504,0.8099,1851
2,Base,Assets,0.8088,0.9008,0.8523,2691
3,TAPT,Assets,0.8216,0.909,0.8631,2691
4,Base,Compensation,0.8419,0.9237,0.8809,4405
5,TAPT,Compensation,0.8508,0.933,0.89,4405
6,Base,Contingency,0.8122,0.8577,0.8343,1729
7,TAPT,Contingency,0.8306,0.8647,0.8473,1729
8,Base,Debt,0.7988,0.9017,0.8471,10394
9,TAPT,Debt,0.8095,0.9085,0.8562,10394
