In [1]:
from pathlib import Path
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForTokenClassification
)
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
# Fine-tuned mBERT checkpoint directory (tokenizer files, config and weights live here).
model_dir = Path("../models/mBERT-finetuned-TRI-L3_Run_2_Epochs_5_45")

# Read the checkpoint's configuration and its fast tokenizer from the same directory.
cfg = AutoConfig.from_pretrained(model_dir)
tok = AutoTokenizer.from_pretrained(model_dir, use_fast=True)

# Sanity-check the checkpoint before loading weights: which head class it was saved
# with, which label names it carries, and how many output classes it has.
# NOTE: this checkpoint only stores generic LABEL_k names, so human-readable tag
# names have to be supplied separately by the notebook.
print(f"Architectures: {cfg.architectures}")
print(f"Labels: {cfg.id2label}")
print(f"Num labels: {cfg.num_labels}")

# The config reports a token-classification head, so load it with the matching
# auto class (a sequence-classification checkpoint would need
# AutoModelForSequenceClassification instead).
model = AutoModelForTokenClassification.from_pretrained(model_dir, config=cfg)

Architectures: ['BertForTokenClassification']
Labels: {0: 'LABEL_0', 1: 'LABEL_1', 2: 'LABEL_2', 3: 'LABEL_3', 4: 'LABEL_4', 5: 'LABEL_5', 6: 'LABEL_6', 7: 'LABEL_7', 8: 'LABEL_8', 9: 'LABEL_9', 10: 'LABEL_10', 11: 'LABEL_11', 12: 'LABEL_12', 13: 'LABEL_13', 14: 'LABEL_14', 15: 'LABEL_15', 16: 'LABEL_16', 17: 'LABEL_17', 18: 'LABEL_18', 19: 'LABEL_19', 20: 'LABEL_20', 21: 'LABEL_21', 22: 'LABEL_22', 23: 'LABEL_23', 24: 'LABEL_24', 25: 'LABEL_25', 26: 'LABEL_26', 27: 'LABEL_27', 28: 'LABEL_28', 29: 'LABEL_29', 30: 'LABEL_30', 31: 'LABEL_31', 32: 'LABEL_32', 33: 'LABEL_33', 34: 'LABEL_34', 35: 'LABEL_35', 36: 'LABEL_36', 37: 'LABEL_37', 38: 'LABEL_38', 39: 'LABEL_39', 40: 'LABEL_40', 41: 'LABEL_41', 42: 'LABEL_42', 43: 'LABEL_43', 44: 'LABEL_44', 45: 'LABEL_45', 46: 'LABEL_46', 47: 'LABEL_47', 48: 'LABEL_48', 49: 'LABEL_49', 50: 'LABEL_50', 51: 'LABEL_51', 52: 'LABEL_52', 53: 'LABEL_53', 54: 'LABEL_54', 55: 'LABEL_55', 56: 'LABEL_56'}
Num labels: 57


In [6]:
labels = ["[PAD]", "[UNK]", "B-EGPOL", "B-EOFINANZ", "B-EOMEDIA", "B-EOMIL", "B-EOMOV", "B-EONGO", "B-EOPOL", "B-EOREL", "B-EOSCI", "B-EOWIRT", "B-EPFINANZ", "B-EPKULT", "B-EPMEDIA", "B-EPMIL", "B-EPMOV", "B-EPNGO", "B-EPPOL", "B-EPREL", "B-EPSCI", "B-EPWIRT", "B-GPE", "B-PAGE", "B-PETH", "B-PFUNK", "B-PGEN", "B-PNAT", "B-PSOZ", "I-EGPOL", "I-EOFINANZ", "I-EOMEDIA", "I-EOMIL", "I-EOMOV", "I-EONGO", "I-EOPOL", "I-EOREL", "I-EOSCI", "I-EOWIRT", "I-EPFINANZ", "I-EPKULT", "I-EPMEDIA", "I-EPMIL", "I-EPMOV", "I-EPNGO", "I-EPPOL", "I-EPREL", "I-EPSCI", "I-EPWIRT", "I-GPE", "I-PAGE", "I-PETH", "I-PFUNK", "I-PGEN", "I-PNAT", "I-PSOZ", "O"]

label2index, index2label = {}, {}
for i, item in enumerate(labels):
    label2index[item] = i
    index2label[i] = item

print(index2label)

{0: '[PAD]', 1: '[UNK]', 2: 'B-EGPOL', 3: 'B-EOFINANZ', 4: 'B-EOMEDIA', 5: 'B-EOMIL', 6: 'B-EOMOV', 7: 'B-EONGO', 8: 'B-EOPOL', 9: 'B-EOREL', 10: 'B-EOSCI', 11: 'B-EOWIRT', 12: 'B-EPFINANZ', 13: 'B-EPKULT', 14: 'B-EPMEDIA', 15: 'B-EPMIL', 16: 'B-EPMOV', 17: 'B-EPNGO', 18: 'B-EPPOL', 19: 'B-EPREL', 20: 'B-EPSCI', 21: 'B-EPWIRT', 22: 'B-GPE', 23: 'B-PAGE', 24: 'B-PETH', 25: 'B-PFUNK', 26: 'B-PGEN', 27: 'B-PNAT', 28: 'B-PSOZ', 29: 'I-EGPOL', 30: 'I-EOFINANZ', 31: 'I-EOMEDIA', 32: 'I-EOMIL', 33: 'I-EOMOV', 34: 'I-EONGO', 35: 'I-EOPOL', 36: 'I-EOREL', 37: 'I-EOSCI', 38: 'I-EOWIRT', 39: 'I-EPFINANZ', 40: 'I-EPKULT', 41: 'I-EPMEDIA', 42: 'I-EPMIL', 43: 'I-EPMOV', 44: 'I-EPNGO', 45: 'I-EPPOL', 46: 'I-EPREL', 47: 'I-EPSCI', 48: 'I-EPWIRT', 49: 'I-GPE', 50: 'I-PAGE', 51: 'I-PETH', 52: 'I-PFUNK', 53: 'I-PGEN', 54: 'I-PNAT', 55: 'I-PSOZ', 56: 'O'}


In [26]:
# Tag one pre-tokenised German sentence with the fine-tuned model and show the
# predicted tag for every (non-padding) sub-word token.
text = ["Herr", "Präsident", "!", "Liebe", "Kolleginnen", "und", "Kollegen", "!", "Bereits", "im", "Koalitionsvertrag", "haben", "Union", "und", "SPD", "festgehalten", ",", "dass", "wir", "den", "Rechtsstaat", "stärken", "möchten", ",", "indem", "wir", "den", "Strafprozess", "modernisieren", "und", "die", "Strafverfahren", "beschleunigen", ".", "Und", "der", "Staatssekretär", "hat", "darauf", "hingewiesen", ":", "Im", "Grunde", "genommen", "ist", "das", ",", "was", "wir", "heute", "einleiten", ",", "nur", "eine", "weitere", "Säule", "des", "Pakts", "für", "den", "Rechtsstaat", ";", "denn", "wir", "setzen", "eben", "auf", "viele", "verschiedene", "Instrumente", "."]
encoded = tok(text,
              truncation=True,
              padding='max_length',
              max_length=150,
              is_split_into_words=True,
              return_tensors="pt")

# `index2label` is a hand-written vocabulary, while the checkpoint config only
# carries generic LABEL_k names. If their sizes disagree, argmax indices would be
# mapped to the wrong tag (or to None via .get), so fail loudly instead.
assert len(index2label) == model.config.num_labels, (
    f"index2label has {len(index2label)} entries but the model predicts "
    f"{model.config.num_labels} classes"
)

with torch.inference_mode():
    output = model(**encoded)

# Token-level classification
if hasattr(output, "logits"):
    logits = output.logits                    # [batch, seq_len, num_labels]
    preds  = torch.argmax(logits, dim=-1)     # best class index for each token
    # NOTE(review): this rebinds the notebook-global `labels`, which the label cell
    # uses for the tag vocabulary; re-run that cell before relying on it again.
    labels = [index2label.get(int(i)) for i in preds[0]]
    tokens = tok.convert_ids_to_tokens(encoded["input_ids"][0])
    # padding='max_length' fills the sequence up to 150 positions; predictions at
    # padded positions (attention_mask == 0) carry no information, so skip them.
    keep_mask = encoded["attention_mask"][0].tolist()
    print([(token, tag) for token, tag, keep in zip(tokens, labels, keep_mask) if keep])

[('[CLS]', 'O'), ('Herr', 'B-EPPOL'), ('Präsident', 'I-EPPOL'), ('!', 'O'), ('Liebe', 'O'), ('Ko', 'O'), ('##lle', 'O'), ('##gin', 'O'), ('##nen', 'O'), ('und', 'O'), ('Kollegen', 'O'), ('!', 'O'), ('Bereits', 'O'), ('im', 'O'), ('Ko', 'O'), ('##ali', 'O'), ('##tions', 'O'), ('##vertrag', 'O'), ('haben', 'O'), ('Union', 'B-EOPOL'), ('und', 'I-EOPOL'), ('SPD', 'I-EOPOL'), ('fest', 'O'), ('##gehalten', 'O'), (',', 'O'), ('dass', 'O'), ('wir', 'O'), ('den', 'O'), ('Rechts', 'O'), ('##staat', 'O'), ('st', 'O'), ('##ärke', 'O'), ('##n', 'O'), ('möchte', 'O'), ('##n', 'O'), (',', 'O'), ('indem', 'O'), ('wir', 'O'), ('den', 'O'), ('St', 'O'), ('##raf', 'O'), ('##pro', 'O'), ('##zes', 'O'), ('##s', 'O'), ('moderni', 'O'), ('##sier', 'O'), ('##en', 'O'), ('und', 'O'), ('die', 'O'), ('St', 'O'), ('##raf', 'O'), ('##verfahren', 'O'), ('be', 'O'), ('##sch', 'O'), ('##leu', 'O'), ('##nig', 'O'), ('##en', 'O'), ('.', 'O'), ('Und', 'O'), ('der', 'B-EPPOL'), ('Staat', 'I-EPPOL'), ('##sse', 'I-EPPOL'),