In [23]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")

nlp = pipeline("ner", model=model, tokenizer=tokenizer)
example = "My name is Wolfgang and I live in Berlin"

ner_results = nlp(example)
print(ner_results)

[{'entity': 'B-PER', 'score': 0.99901396, 'index': 4, 'word': 'Wolfgang', 'start': 11, 'end': 19}, {'entity': 'B-LOC', 'score': 0.999645, 'index': 9, 'word': 'Berlin', 'start': 34, 'end': 40}]


In [24]:
model

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

In [25]:
import torch

In [26]:
torch.nn.modules.linear.Linear(768, 100)

Linear(in_features=768, out_features=100, bias=True)

In [27]:
type(model.classifier)

torch.nn.modules.linear.Linear

In [28]:
ner_results = nlp(example)
ner_results

[{'entity': 'B-PER',
  'score': 0.99901396,
  'index': 4,
  'word': 'Wolfgang',
  'start': 11,
  'end': 19},
 {'entity': 'B-LOC',
  'score': 0.999645,
  'index': 9,
  'word': 'Berlin',
  'start': 34,
  'end': 40}]

In [None]:
# Change Last classifier Layer to custom

In [29]:
model.classifier = torch.nn.modules.linear.Linear(768, 100)

In [30]:
model

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

In [None]:
# will get error because labels is not mapping yet

In [31]:
ner_results = nlp(example)

KeyError: 17

In [None]:
# Custom Labels

In [33]:
test_dict = {}
for i in range(100):
    test_dict[i] = str(i)

In [None]:
# Map the labels

In [34]:
model.config.id2label = test_dict

In [35]:
ner_results = nlp(example)
ner_results

[{'entity': '17',
  'score': 0.026227705,
  'index': 1,
  'word': 'My',
  'start': 0,
  'end': 2},
 {'entity': '74',
  'score': 0.026183637,
  'index': 2,
  'word': 'name',
  'start': 3,
  'end': 7},
 {'entity': '61',
  'score': 0.024977818,
  'index': 3,
  'word': 'is',
  'start': 8,
  'end': 10},
 {'entity': '91',
  'score': 0.02574957,
  'index': 4,
  'word': 'Wolfgang',
  'start': 11,
  'end': 19},
 {'entity': '17',
  'score': 0.024696257,
  'index': 5,
  'word': 'and',
  'start': 20,
  'end': 23},
 {'entity': '17',
  'score': 0.026242768,
  'index': 6,
  'word': 'I',
  'start': 24,
  'end': 25},
 {'entity': '17',
  'score': 0.023752943,
  'index': 7,
  'word': 'live',
  'start': 26,
  'end': 30},
 {'entity': '27',
  'score': 0.020842379,
  'index': 8,
  'word': 'in',
  'start': 31,
  'end': 33},
 {'entity': '75',
  'score': 0.02482841,
  'index': 9,
  'word': 'Berlin',
  'start': 34,
  'end': 40}]