Based on https://github.com/huggingface/transformers/tree/master/examples/token-classification

Train model with https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner.py

In [135]:
import sys
import numpy as np
import torch
dummy = torch.zeros(1).cuda()
from transformers import *

In [136]:
def load_data(filename, field_num):
    with open(filename) as f:
        data = []
        line_acc = []
        for token_line in f:
            token_line = token_line.split()
            if len(token_line) == 0 and len(line_acc) > 0:
                data.append(" ".join(line_acc))
                line_acc = []
            else:
                line_acc.append(token_line[field_num])
        if len(line_acc) > 0:
            data.append(" ".join(line_acc))
            line_acc = []
        return data

In [137]:
test_inputs = load_data("./test.txt", 0)
test_labels = load_data("./test.txt", 1)
label_list = load_data("./labels.txt", 0)[0].split()

In [138]:
test_inputs[:3]

['1951 bis 1953 wurde der nördliche Teil als Jugendburg des Kolpingwerkes gebaut .',
 'Da Muck das Kriegsschreiben nicht überbracht hat , wird er als Retter des Landes ausgezeichnet und soll zum Schatzmeister ernannt werden .',
 'Mit 1. Jänner 2007 wurde Robert Schörgenhofer , als Nachfolger des ausgeschiedenen Dietmar Drabek , in die Kaderliste der FIFA-Schiedsrichter aufgenommen .']

In [139]:
test_labels[:3]

['O O O O O O O O O O B-OTH O O',
 'O B-PER O O O O O O O O O O O O O O O O O O O O',
 'O O O O O B-PER I-PER O O O O O B-PER I-PER O O O O O B-ORGpart O O']

In [140]:
label_list

['B-LOC',
 'B-LOCderiv',
 'B-LOCpart',
 'B-ORG',
 'B-ORGderiv',
 'B-ORGpart',
 'B-OTH',
 'B-OTHderiv',
 'B-OTHpart',
 'B-PER',
 'B-PERderiv',
 'B-PERpart',
 'I-LOC',
 'I-LOCderiv',
 'I-LOCpart',
 'I-ORG',
 'I-ORGderiv',
 'I-ORGpart',
 'I-OTH',
 'I-OTHderiv',
 'I-OTHpart',
 'I-PER',
 'I-PERderiv',
 'I-PERpart',
 'O']

In [141]:
def combine_tokens_and_labels(tokens, labels):
    rv = []
    for i in range(len(tokens)):
        sentence = tokens[i].split()
        label_sent = labels[i].split()
        for j in range(len(sentence)):
            if j == 0 and i > 0:
                rv.append("")
            rv.append(sentence[j] + '    ' + label_sent[j])
    return rv

In [142]:
combine_tokens_and_labels(test_inputs, test_labels)

['1951    O',
 'bis    O',
 '1953    O',
 'wurde    O',
 'der    O',
 'nördliche    O',
 'Teil    O',
 'als    O',
 'Jugendburg    O',
 'des    O',
 'Kolpingwerkes    B-OTH',
 'gebaut    O',
 '.    O',
 '',
 'Da    O',
 'Muck    B-PER',
 'das    O',
 'Kriegsschreiben    O',
 'nicht    O',
 'überbracht    O',
 'hat    O',
 ',    O',
 'wird    O',
 'er    O',
 'als    O',
 'Retter    O',
 'des    O',
 'Landes    O',
 'ausgezeichnet    O',
 'und    O',
 'soll    O',
 'zum    O',
 'Schatzmeister    O',
 'ernannt    O',
 'werden    O',
 '.    O',
 '',
 'Mit    O',
 '1.    O',
 'Jänner    O',
 '2007    O',
 'wurde    O',
 'Robert    B-PER',
 'Schörgenhofer    I-PER',
 ',    O',
 'als    O',
 'Nachfolger    O',
 'des    O',
 'ausgeschiedenen    O',
 'Dietmar    B-PER',
 'Drabek    I-PER',
 ',    O',
 'in    O',
 'die    O',
 'Kaderliste    O',
 'der    O',
 'FIFA-Schiedsrichter    B-ORGpart',
 'aufgenommen    O',
 '.    O',
 '',
 'Die    O',
 'These    O',
 ',    O',
 'Schlatter    B-PER',
 '

In [34]:
num_labels = len(label_list)
id2label, label2id = dict(), dict()
for i in range(num_labels):
    id2label[i] = label_list[i]
    label2id[label_list[i]] = i

In [144]:
label2id

{'B-LOC': 0,
 'B-LOCderiv': 1,
 'B-LOCpart': 2,
 'B-ORG': 3,
 'B-ORGderiv': 4,
 'B-ORGpart': 5,
 'B-OTH': 6,
 'B-OTHderiv': 7,
 'B-OTHpart': 8,
 'B-PER': 9,
 'B-PERderiv': 10,
 'B-PERpart': 11,
 'I-LOC': 12,
 'I-LOCderiv': 13,
 'I-LOCpart': 14,
 'I-ORG': 15,
 'I-ORGderiv': 16,
 'I-ORGpart': 17,
 'I-OTH': 18,
 'I-OTHderiv': 19,
 'I-OTHpart': 20,
 'I-PER': 21,
 'I-PERderiv': 22,
 'I-PERpart': 23,
 'O': 24}

In [145]:
model_name = 'bert-base-multilingual-cased'
#config = AutoConfig.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name, num_labels=num_labels, id2label=id2label, label2id=label2id)
tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
model = AutoModelForTokenClassification.from_pretrained(model_name, config=config)

In [177]:
model_name = './germeval-model/'
config = AutoConfig.from_pretrained(model_name)
#config = AutoConfig.from_pretrained(model_name, num_labels=num_labels, id2label=id2label, label2id=label2id)
tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
model = AutoModelForTokenClassification.from_pretrained(model_name, config=config)

TypeError: __init__() got an unexpected keyword argument 'output_hidden_states'

In [166]:
#tokens = tokenizer.encode("Hallo Welt!", add_special_tokens=True, return_tensors='pt')
tokens = tokenizer.encode("Die Universität Zürich (kurz UZH) ist eine Universität in Zürich, in der Schweiz.", add_special_tokens=True, return_tensors='pt')

In [167]:
tokens

tensor([[  101, 10236, 13071, 19985,   113, 21375,   158, 13966, 12396,   114,
         10298, 10359, 13071, 10106, 19985,   117, 10106, 10118, 18804,   119,
           102]])

In [168]:
tokens.shape

torch.Size([1, 21])

In [169]:
tokenizer.decode(tokens[0])

'[CLS] Die Universität Zürich ( kurz UZH ) ist eine Universität in Zürich, in der Schweiz. [SEP]'

In [170]:
converted_tokens = tokenizer.convert_ids_to_tokens(tokens[0])
converted_tokens

['[CLS]',
 'Die',
 'Universität',
 'Zürich',
 '(',
 'kurz',
 'U',
 '##Z',
 '##H',
 ')',
 'ist',
 'eine',
 'Universität',
 'in',
 'Zürich',
 ',',
 'in',
 'der',
 'Schweiz',
 '.',
 '[SEP]']

In [171]:
outputs = model(tokens)[0]
#outputs

In [172]:
outputs

tensor([[[-3.5906e-01, -1.0638e+00, -1.6967e+00,  8.5544e-01, -1.1273e+00,
          -4.0354e-01, -1.6005e-01, -1.4345e+00, -1.8019e+00, -1.4220e+00,
          -1.5346e+00, -2.0497e+00, -1.0582e+00, -2.6558e+00, -3.3880e+00,
           4.6456e-02, -2.3043e+00, -1.2796e+00, -3.4760e-01, -1.8886e+00,
          -1.9711e+00, -1.2106e+00, -2.4590e+00, -2.2884e+00,  1.0665e+01],
         [-1.9638e-01, -8.7915e-01, -2.1614e+00,  3.3196e+00, -1.2437e+00,
          -6.5107e-01,  1.3831e+00, -1.4261e+00, -2.2791e+00, -1.0152e+00,
          -1.6993e+00, -2.6719e+00, -1.4586e+00, -3.2869e+00, -3.3570e+00,
           4.5006e-02, -3.0197e+00, -1.7428e+00, -1.6272e-01, -2.1788e+00,
          -2.3177e+00, -1.1708e+00, -2.8690e+00, -2.5177e+00,  9.3267e+00],
         [ 2.1120e+00, -8.3915e-01, -1.1120e+00,  8.4668e+00, -9.3635e-01,
          -3.9917e-01,  1.3020e+00, -1.7527e+00, -2.1602e+00,  3.9323e-01,
          -1.6782e+00, -1.7720e+00, -1.2818e+00, -2.4580e+00, -2.4322e+00,
           1.2277e+00, 

In [153]:
outputs.shape

torch.Size([1, 6, 25])

In [173]:
predictions = torch.argmax(outputs, axis=-1)

In [174]:
def label_ids_to_labels(ids):
    ids = ids.numpy()
    sent_labels_acc = []
    for j in range(ids.shape[1]):
            sent_labels_acc.append(model.config.id2label[ids[0, j]])
    return sent_labels_acc

In [175]:
predicted_labels = label_ids_to_labels(predictions)

In [176]:
combine_tokens_and_labels([" ".join(converted_tokens)],[" ".join(predicted_labels)])

['[CLS]    O',
 'Die    O',
 'Universität    B-ORG',
 'Zürich    I-ORG',
 '(    O',
 'kurz    O',
 'U    B-ORG',
 '##Z    I-ORG',
 '##H    I-ORG',
 ')    O',
 'ist    O',
 'eine    O',
 'Universität    O',
 'in    O',
 'Zürich    B-LOC',
 ',    O',
 'in    O',
 'der    O',
 'Schweiz    B-LOC',
 '.    O',
 '[SEP]    O']

In [133]:
" ".join(predicted_labels)

'O O O O O O'

In [134]:
" ".join(converted_tokens)

'[CLS] Hall ##o Welt ! [SEP]'