In [48]:
from pathlib import Path
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForTokenClassification
)
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

In [30]:
model_dir = Path("../models/bert-base-german-cased-finetuned-MOPE-L3_Run_2_Epochs_29")

# Load tokenizer and config
tok   = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
cfg   = AutoConfig.from_pretrained(model_dir)

# Inspect to verify the head type and label set
print("Architectures:", cfg.architectures)   # e.g. ['BertForTokenClassification']
print("Labels:", cfg.id2label)               # e.g. {0:'O', 1:'PEOPLE', 2:'ELITE'}
print("Num labels:", cfg.num_labels)

# Load the model **that matches the head in cfg.architectures**
model = AutoModelForTokenClassification.from_pretrained(model_dir, config=cfg)
# If the line above raises an error, switch to AutoModelForSequenceClassification

Architectures: ['BertForTokenClassification']
Labels: {0: 'LABEL_0', 1: 'LABEL_1', 2: 'LABEL_2', 3: 'LABEL_3', 4: 'LABEL_4', 5: 'LABEL_5', 6: 'LABEL_6', 7: 'LABEL_7', 8: 'LABEL_8', 9: 'LABEL_9', 10: 'LABEL_10', 11: 'LABEL_11', 12: 'LABEL_12', 13: 'LABEL_13', 14: 'LABEL_14', 15: 'LABEL_15', 16: 'LABEL_16', 17: 'LABEL_17', 18: 'LABEL_18', 19: 'LABEL_19', 20: 'LABEL_20', 21: 'LABEL_21', 22: 'LABEL_22', 23: 'LABEL_23', 24: 'LABEL_24', 25: 'LABEL_25', 26: 'LABEL_26', 27: 'LABEL_27', 28: 'LABEL_28', 29: 'LABEL_29', 30: 'LABEL_30', 31: 'LABEL_31', 32: 'LABEL_32', 33: 'LABEL_33', 34: 'LABEL_34', 35: 'LABEL_35', 36: 'LABEL_36', 37: 'LABEL_37', 38: 'LABEL_38', 39: 'LABEL_39', 40: 'LABEL_40', 41: 'LABEL_41', 42: 'LABEL_42', 43: 'LABEL_43', 44: 'LABEL_44', 45: 'LABEL_45', 46: 'LABEL_46', 47: 'LABEL_47', 48: 'LABEL_48', 49: 'LABEL_49', 50: 'LABEL_50', 51: 'LABEL_51', 52: 'LABEL_52', 53: 'LABEL_53', 54: 'LABEL_54', 55: 'LABEL_55', 56: 'LABEL_56'}
Num labels: 57


In [31]:
labels = ["[PAD]", "[UNK]", "B-EGPOL", "B-EOFINANZ", "B-EOMEDIA", "B-EOMIL", "B-EOMOV", "B-EONGO", "B-EOPOL", "B-EOREL", "B-EOSCI", "B-EOWIRT", "B-EPFINANZ", "B-EPKULT", "B-EPMEDIA", "B-EPMIL", "B-EPMOV", "B-EPNGO", "B-EPPOL", "B-EPREL", "B-EPSCI", "B-EPWIRT", "B-GPE", "B-PAGE", "B-PETH", "B-PFUNK", "B-PGEN", "B-PNAT", "B-PSOZ", "I-EGPOL", "I-EOFINANZ", "I-EOMEDIA", "I-EOMIL", "I-EOMOV", "I-EONGO", "I-EOPOL", "I-EOREL", "I-EOSCI", "I-EOWIRT", "I-EPFINANZ", "I-EPKULT", "I-EPMEDIA", "I-EPMIL", "I-EPMOV", "I-EPNGO", "I-EPPOL", "I-EPREL", "I-EPSCI", "I-EPWIRT", "I-GPE", "I-PAGE", "I-PETH", "I-PFUNK", "I-PGEN", "I-PNAT", "I-PSOZ", "O"]

label2index, index2label = {}, {}
for i, item in enumerate(labels):
    label2index[item] = i
    index2label[i] = item

print(index2label)

{0: '[PAD]', 1: '[UNK]', 2: 'B-EGPOL', 3: 'B-EOFINANZ', 4: 'B-EOMEDIA', 5: 'B-EOMIL', 6: 'B-EOMOV', 7: 'B-EONGO', 8: 'B-EOPOL', 9: 'B-EOREL', 10: 'B-EOSCI', 11: 'B-EOWIRT', 12: 'B-EPFINANZ', 13: 'B-EPKULT', 14: 'B-EPMEDIA', 15: 'B-EPMIL', 16: 'B-EPMOV', 17: 'B-EPNGO', 18: 'B-EPPOL', 19: 'B-EPREL', 20: 'B-EPSCI', 21: 'B-EPWIRT', 22: 'B-GPE', 23: 'B-PAGE', 24: 'B-PETH', 25: 'B-PFUNK', 26: 'B-PGEN', 27: 'B-PNAT', 28: 'B-PSOZ', 29: 'I-EGPOL', 30: 'I-EOFINANZ', 31: 'I-EOMEDIA', 32: 'I-EOMIL', 33: 'I-EOMOV', 34: 'I-EONGO', 35: 'I-EOPOL', 36: 'I-EOREL', 37: 'I-EOSCI', 38: 'I-EOWIRT', 39: 'I-EPFINANZ', 40: 'I-EPKULT', 41: 'I-EPMEDIA', 42: 'I-EPMIL', 43: 'I-EPMOV', 44: 'I-EPNGO', 45: 'I-EPPOL', 46: 'I-EPREL', 47: 'I-EPSCI', 48: 'I-EPWIRT', 49: 'I-GPE', 50: 'I-PAGE', 51: 'I-PETH', 52: 'I-PFUNK', 53: 'I-PGEN', 54: 'I-PNAT', 55: 'I-PSOZ', 56: 'O'}


In [None]:
text = ["Herr", "Präsident", "!", "Liebe", "Kolleginnen", "und", "Kollegen", "!", "Bereits", "im", "Koalitionsvertrag", "haben", "Union", "und", "SPD", "festgehalten", ",", "dass", "wir", "den", "Rechtsstaat", "stärken", "möchten", ",", "indem", "wir", "den", "Strafprozess", "modernisieren", "und", "die", "Strafverfahren", "beschleunigen", ".", "Und", "der", "Staatssekretär", "hat", "darauf", "hingewiesen", ":", "Im", "Grunde", "genommen", "ist", "das", ",", "was", "wir", "heute", "einleiten", ",", "nur", "eine", "weitere", "Säule", "des", "Pakts", "für", "den", "Rechtsstaat", ";", "denn", "wir", "setzen", "eben", "auf", "viele", "verschiedene", "Instrumente", "."]
encoded = tok(text,
              truncation=True,
              padding='max_length',
              max_length=150,
              is_split_into_words=True,
              return_tensors="pt")

with torch.inference_mode():
    output = model(**encoded)

# Token-level classification
if hasattr(output, "logits"):
    logits = output.logits                    # [batch, seq_len, num_labels]
    preds  = torch.argmax(logits, dim=-1)     # take the best label for each token
    labels = [index2label.get(int(i)) for i in preds[0]]
    tokens = tok.convert_ids_to_tokens(encoded["input_ids"][0])
    print(list(zip(tokens, labels)))

In [42]:
examples = {"words": [["Herr", "Präsident", "!", "Liebe", "Kolleginnen", "und", "Kollegen", "!", "Bereits", "im", "Koalitionsvertrag", "haben", "Union", "und", "SPD", "festgehalten", ",", "dass", "wir", "den", "Rechtsstaat", "stärken", "möchten", ",", "indem", "wir", "den", "Strafprozess", "modernisieren", "und", "die", "Strafverfahren", "beschleunigen", ".", "Und", "der", "Staatssekretär", "hat", "darauf", "hingewiesen", ":", "Im", "Grunde", "genommen", "ist", "das", ",", "was", "wir", "heute", "einleiten", ",", "nur", "eine", "weitere", "Säule", "des", "Pakts", "für", "den", "Rechtsstaat", ";", "denn", "wir", "setzen", "eben", "auf", "viele", "verschiedene", "Instrumente", "."]],
"tags": [["B-EPPOL", "I-EPPOL", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "B-EOPOL", "I-EOPOL", "I-EOPOL", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "B-EPPOL", "I-EPPOL", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"]]}
test_data = load_dataset('json', data_files={'test':'test.json'})

def tokenize_and_align_labels(examples):
    tokenized_inputs = bert_tokenizer(examples["words"],
                                      truncation=True,
                                      padding='max_length',
                                      max_length=150,
                                      is_split_into_words=True)
    labels = []

    for idx, label in enumerate(examples["tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=idx)
        previous_word_idx = None
        label_ids = [];
        for word_idx in word_ids:
            if word_idx is None or word_idx == previous_word_idx: # Since the tokenizer may split words into subwords, we need to handle this case, so we skip the label for subwords
                label_ids.append(-100) # -100 will be ignored by the loss function
            else:
                label_ids.append(label2index[label[word_idx]])
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

def encode_dataset(corpus):
    """
    This function tokenizes the input dataset and aligns the labels with the tokens. It wil return a DatasetDict with the tokenized inputs and labels.

    Args:
        corpus (DatasetDict): The input dataset to be tokenized and aligned.

    Returns:
        DatasetDict: A dictionary containing the tokenized inputs and aligned labels.
    """
    return corpus.map(tokenize_and_align_labels, batched=True, remove_columns=['words', 'tags'])

    # ── corpus fields explained ────────────────────────────────────────────────────
    # input_ids      : word-piece token IDs incl. special tokens; 0-padding to max_len
    # token_type_ids : segment markers (all 0 here because we pass only one sentence)
    # attention_mask : 1 = real token, 0 = pad ⇒ tells BERT which positions to ignore
    # labels         : gold tag ID per token; -100 on [CLS]/[SEP]/padding & sub-tokens
    #                 (-100 is PyTorch’s ignore_index, so loss is computed only where
    #                  the label ≥ 0)


data_encoded = encode_dataset(test_data)

# Load data
test_dataset = data_encoded['test'].with_format("torch")
# Create a DataLoader for the test dataset
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [47]:
print(result['test'][0])

{'input_ids': [101, 38000, 24571, 106, 25808, 30186, 11270, 18823, 11216, 10130, 90958, 106, 26345, 10211, 30186, 13133, 15024, 58831, 13289, 11457, 10130, 23327, 34519, 64036, 117, 11064, 33963, 10140, 79037, 58922, 28780, 81609, 10115, 93487, 10115, 117, 35417, 33963, 10140, 10838, 29552, 73099, 19579, 10107, 84984, 59719, 10136, 10130, 10128, 10838, 29552, 64234, 10347, 12044, 101304, 28713, 10136, 119, 41523, 10118, 28435, 12818, 90877, 11250, 20345, 19911, 80597, 131, 10796, 23191, 10112, 38023, 10298, 10242, 117, 10134, 33963, 13025, 10290, 36777, 10681, 117, 11354, 10359, 15133, 156, 91982, 10284, 10139, 48465, 10806, 10307, 10140, 79037, 58922, 132, 20882, 33963, 85635, 173, 10965, 10329, 18602, 22668, 62988, 10112, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,