In [1]:
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

In [2]:

label_list=["O", "B-HUSTYP", "B-KONSTRUKTIONSDETALJ", "B-LST_DNR", "B-SR_SYSTEM", "B-SR_KOORDINATER", "B-INTRASIS", "I-HUSTYP", "I-KONSTRUKTIONSDETALJ", "I-LST_DNR", "I-SR_SYSTEM", "I-SR_KOORDINATER", "I-INTRASIS"]
main_path = "./kbtraining/checkpoint-15000/"
# Load pre-trained model and tokenizer
model_path = main_path + "model.safetensors"
tokenizer_path = main_path + "tokenizer.json"
print("model path: ", model_path)


model = AutoModelForTokenClassification.from_pretrained(main_path)
tokenizer = AutoTokenizer.from_pretrained(main_path)


model path:  ./kbtraining/checkpoint-15000/model.safetensors


In [3]:
# Perform inference or testing

# Example input text
input_text = "Utredningen har utförts enligt beslut av Länsstyrelsen i Västra Götalands\nlän (dnr 220-39195-99) och har bekostats av Alvereds golf. Länsstyrelsens dnr: 220-39195-99. Koordinater för undersökningsytans sydvästra hörn:\nx 6395,00  y 1272,25y."


# Tokenize input text
tokens = tokenizer.tokenize(input_text)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
input_ids = tokenizer.build_inputs_with_special_tokens(input_ids)


In [4]:
# Perform inference
with torch.no_grad():
    outputs = model(torch.tensor([input_ids]))

# Get predicted labels
predicted_labels = torch.argmax(outputs.logits, dim=2).squeeze()

# Decode labels
decoded_labels = [tokenizer.decode([label]) for label in predicted_labels]
res_labels = [label_list[label] for label in predicted_labels]

# Display results
print("Input Text:", input_text)
print("Tokens:", tokens)
print("Result", res_labels)

Input Text: Utredningen har utförts enligt beslut av Länsstyrelsen i Västra Götalands
län (dnr 220-39195-99) och har bekostats av Alvereds golf. Länsstyrelsens dnr: 220-39195-99. Koordinater för undersökningsytans sydvästra hörn:
x 6395,00  y 1272,25y.
Tokens: ['Utredningen', 'har', 'utförts', 'enligt', 'beslut', 'av', 'Länsstyrelsen', 'i', 'Västra', 'Götalands', 'län', '(', 'dn', '##r', '220', '-', '39', '##195', '-', '99', ')', 'och', 'har', 'bekosta', '##ts', 'av', 'Alv', '##ere', '##ds', 'golf', '.', 'Länsstyrelsens', 'dn', '##r', ':', '220', '-', '39', '##195', '-', '99', '.', 'Ko', '##ordin', '##ater', 'för', 'undersöknings', '##ytan', '##s', 'sydvästra', 'hörn', ':', 'x', '63', '##95', ',', '00', 'y', '127', '##2', ',', '25', '##y', '.']
Result ['O', 'O', 'O', 'O', 'O', 'O', 'B-LST_DNR', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',

In [5]:
nlp =  pipeline('ner', model=main_path, tokenizer=main_path)
res = nlp(input_text)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [6]:
label_mapping = {
    f'LABEL_{i+1}': label_list[i] for i in range(len(label_list))
}
for item in res:
    if 'entity' in item and item['entity'] in label_mapping:
        item['entity'] = label_mapping[item['entity']]

In [7]:
# shows only realted labels 
[obj for obj in res if obj['entity'] != 'LABEL_0']

[{'entity': 'B-KONSTRUKTIONSDETALJ',
  'score': 0.55468464,
  'index': 6,
  'word': 'av',
  'start': 38,
  'end': 40},
 {'entity': 'B-SR_SYSTEM',
  'score': 0.7604707,
  'index': 47,
  'word': 'undersöknings',
  'start': 183,
  'end': 196},
 {'entity': 'I-SR_SYSTEM',
  'score': 0.9219862,
  'index': 48,
  'word': '##ytan',
  'start': 196,
  'end': 200}]

In [9]:
outputs

TokenClassifierOutput(loss=None, logits=tensor([[[  7.8506,  -3.7024,  -2.5528,  -5.6608,  -5.2152,  -6.0950, -10.3023,
           -5.8574,  -3.8255,  -6.7282,  -6.8973,  -7.0194, -10.4475],
         [  8.0818,  -4.2533,  -1.9392,  -6.1646,  -6.0892,  -6.1780, -10.4984,
           -5.8919,  -3.9401,  -6.5693,  -7.3575,  -7.0979, -10.3494],
         [  7.5309,  -4.1506,  -2.4790,  -6.3486,  -5.2030,  -5.6947, -10.1134,
           -5.5178,  -2.2674,  -6.9776,  -6.3065,  -6.9590, -10.1728],
         [  8.0139,  -4.6513,  -3.2390,  -6.4580,  -6.0157,  -6.1669, -10.3612,
           -5.3175,  -3.0026,  -6.3667,  -6.6653,  -6.6171, -10.1486],
         [  7.7853,  -5.0786,  -2.8571,  -6.1904,  -5.6895,  -5.9711, -10.2706,
           -4.9086,  -2.4196,  -6.6003,  -6.6581,  -6.6527, -10.0297],
         [  8.0698,  -4.2847,  -1.9188,  -6.0269,  -6.1773,  -6.3392, -10.5249,
           -5.7950,  -3.7393,  -6.7570,  -7.5077,  -7.2636, -10.3827],
         [  4.8294,  -3.1107,  -2.9952,   5.0515,  -3.