Token classification assigns a label to individual tokens in a sentence. One of the most common token classification tasks is Named Entity Recognition (NER). NER attempts to find a label for each entity in a sentence, such as a person, location, or organization.

This guide will show you how to:

1.   Finetune DistilBERT on the WNUT 17 dataset to detect new entities.
2.   Use your finetuned model for inference.

Start by loading the WNUT 17 dataset from the 🤗 Datasets library:

In [1]:
from datasets import load_dataset
dataset = load_dataset("wnut_17")

In [2]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 3394
    })
    validation: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 1009
    })
    test: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 1287
    })
})

In [3]:
dataset["train"][8]

{'id': '8',
 'tokens': ['Friday', 'Night', 'Eats', 'http://twitpic.com/2pdvtr'],
 'ner_tags': [0, 0, 0, 0]}
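
Each number in ner_tags is a label id. The label names are stored in the ClassLabel inside the dataset's ner_tags feature:
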
In [4]:
label_list = dataset["train"].features["ner_tags"].feature.names
label_list

['O',
 'B-corporation',
 'I-corporation',
 'B-creative-work',
 'I-creative-work',
 'B-group',
 'I-group',
 'B-location',
 'I-location',
 'B-person',
 'I-person',
 'B-product',
 'I-product']


The letter that prefixes each ner_tag indicates the token position of the entity:

*   B- indicates the beginning of an entity.
*   I- indicates a token is contained inside the same entity (for example, the State token is part of an entity like Empire State Building).
*   O indicates the token doesn't correspond to any entity.
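
The scheme can be decoded back into entity spans. The minimal sketch below does this for a list of tag strings. bio_to_spans and the sample words and tags are illustrative, not part of 🤗 Datasets. The sketch treats an I- tag that doesn't continue an entity of the same type as the start of a new entity, which is one common convention among several.

```python
def bio_to_spans(tags):
    """Group BIO tags into (entity_type, start, end) spans; end is exclusive."""
    spans = []
    current = None  # [entity_type, start] of the span being built
    for i, tag in enumerate(tags):
        if tag == "O":
            # An O tag closes any open span
            if current is not None:
                spans.append((current[0], current[1], i))
                current = None
            continue
        prefix, entity_type = tag.split("-", 1)  # "B-creative-work" -> ("B", "creative-work")
        if prefix == "B" or current is None or current[0] != entity_type:
            # B- always opens a new span; so does an I- that can't continue the open one
            if current is not None:
                spans.append((current[0], current[1], i))
            current = [entity_type, i]
        # Otherwise this is an I- tag continuing the open span
    if current is not None:
        spans.append((current[0], current[1], len(tags)))
    return spans


# Illustrative (hypothetical) words and tags
words = ["Empire", "State", "Building", "=", "ESB", "."]
tags = ["B-location", "I-location", "I-location", "O", "B-location", "O"]
for entity_type, start, end in bio_to_spans(tags):
    print(entity_type, " ".join(words[start:end]))
# location Empire State Building
# location ESB
```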

Next, load a DistilBERT tokenizer to preprocess the tokens field:

In [5]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [6]:
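example = dataset["train"][0]
# The tokens are already split into words, so pass is_split_into_words=True to tokenize them into subwords
tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
print(tokenized_input["input_ids"])
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])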

[101, 1030, 2703, 17122, 2009, 1005, 1055, 1996, 3193, 2013, 2073, 1045, 1005, 1049, 2542, 2005, 2048, 3134, 1012, 3400, 2110, 2311, 1027, 9686, 2497, 1012, 3492, 2919, 4040, 2182, 2197, 3944, 1012, 102]


In [7]:
print(tokens)

['[CLS]', '@', 'paul', '##walk', 'it', "'", 's', 'the', 'view', 'from', 'where', 'i', "'", 'm', 'living', 'for', 'two', 'weeks', '.', 'empire', 'state', 'building', '=', 'es', '##b', '.', 'pretty', 'bad', 'storm', 'here', 'last', 'evening', '.', '[SEP]']


In [8]:
print(tokenized_input)

{'input_ids': [101, 1030, 2703, 17122, 2009, 1005, 1055, 1996, 3193, 2013, 2073, 1045, 1005, 1049, 2542, 2005, 2048, 3134, 1012, 3400, 2110, 2311, 1027, 9686, 2497, 1012, 3492, 2919, 4040, 2182, 2197, 3944, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


However, this adds the special tokens [CLS] and [SEP], and subword tokenization creates a mismatch between the input and the labels. A single word with a single label may now be split into several subwords (for example, es and ##b above both come from one word). To realign the tokens and labels:

1.   Map each token to the word it came from with the word_ids method.
2.   Assign the label -100 to the special tokens [CLS] and [SEP] so they're ignored by the PyTorch loss function (-100 is the default ignore_index of torch.nn.CrossEntropyLoss).
3.   Label only the first token of each word, and assign -100 to the remaining subtokens of that word.
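
You can try the three steps without a tokenizer by using a hand-written word_ids list. The minimal sketch below uses align_labels, a hypothetical helper that mirrors the inner loop of tokenize_and_align_labels. The sentence, its subword split, and its label ids are made up for illustration (7 is B-location in label_list).

```python
def align_labels(word_ids, word_labels, ignore_index=-100):
    """Spread word-level labels over subword tokens.

    word_ids has one entry per token: the index of the word it came from,
    or None for special tokens such as [CLS] and [SEP].
    """
    aligned = []
    previous_word = None
    for word_id in word_ids:
        if word_id is None:
            aligned.append(ignore_index)          # special token: ignored by the loss
        elif word_id != previous_word:
            aligned.append(word_labels[word_id])  # first subword of a word: keep its label
        else:
            aligned.append(ignore_index)          # later subwords of the same word: ignored
        previous_word = word_id
    return aligned


# Hypothetical 3-word sentence ["ESB", "is", "tall"] tokenized as [CLS] es ##b is tall [SEP]
word_ids = [None, 0, 0, 1, 2, None]
word_labels = [7, 0, 0]  # B-location, O, O
print(align_labels(word_ids, word_labels))
# [-100, 7, -100, 0, 0, -100]
```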





In [9]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:  # Special tokens ([CLS], [SEP]) get -100.
                label_ids.append(-100)
            elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                label_ids.append(label[word_idx])
            else:  # Remaining subtokens of the same word get -100.
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

To apply the preprocessing function over the entire dataset, use the 🤗 Datasets map function. You can speed it up by setting batched=True to process multiple elements of the dataset at once:

In [10]:
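tokenized_dataset = dataset.map(tokenize_and_align_labels, batched=True)

Now create a batch of examples with DataCollatorForTokenClassification. Dynamically padding the sentences to the longest length in each batch during collation is more efficient than padding the whole dataset to the maximum length:
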
In [11]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
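
The pure-Python sketch below shows the padding the collator applies to a batch. Every sequence is right-padded to the longest one in the batch. input_ids are padded with the pad token id, attention_mask with 0, and labels with -100 (the collator's default label_pad_token_id), so padded positions are ignored by the loss. pad_batch is a hypothetical helper. The pad id 0 is an assumption: it is the [PAD] id in BERT-style vocabularies such as DistilBERT's. The real collator also returns tensors and supports options like pad_to_multiple_of.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Right-pad input_ids, attention_mask and labels to the longest sequence in the batch."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)  # 0 = don't attend
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_pad)  # ignored by the loss
    return batch


# Two hypothetical tokenized examples of different lengths (9686, 2497 = "es", "##b" above)
features = [
    {"input_ids": [101, 9686, 2497, 102], "attention_mask": [1, 1, 1, 1], "labels": [-100, 7, -100, -100]},
    {"input_ids": [101, 102], "attention_mask": [1, 1], "labels": [-100, -100]},
]
batch = pad_batch(features)
print(batch["input_ids"][1])  # [101, 102, 0, 0]
print(batch["labels"][1])     # [-100, -100, -100, -100]
```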
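
Include a metric during training to track the model's performance. Load the seqeval framework with the 🤗 Evaluate library. It reports entity-level precision, recall, and F1, plus token-level accuracy:
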
In [12]:
import evaluate

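seqeval = evaluate.load("seqeval")

Then write a compute_metrics function. It turns the logits into label ids with argmax, drops the positions labeled -100, converts the remaining ids to label names, and passes them to seqeval:
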
In [13]:
import numpy as np

# Label names for the example tokenized earlier (not used by compute_metrics)
labels = [label_list[i] for i in example["ner_tags"]]


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Highest-scoring label id per token: (batch, seq_len, num_labels) -> (batch, seq_len)
    predictions = np.argmax(predictions, axis=2)

    # Keep only positions with a real label (drop -100) and convert ids to label names
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
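
The two list comprehensions keep only the labeled positions. The example below applies the same filtering to one hand-made sequence, using plain Python lists in place of NumPy arrays. The three-label LABELS list, the logits, and filter_ignored are hypothetical, chosen only for illustration.

```python
# Hypothetical 3-label subset, standing in for label_list
LABELS = ["O", "B-location", "I-location"]


def filter_ignored(logits, label_ids, ignore_index=-100):
    """Argmax each token's logits and drop positions whose label is ignore_index."""
    preds, refs = [], []
    for token_logits, label_id in zip(logits, label_ids):
        if label_id == ignore_index:
            continue  # special token or non-first subword
        # argmax over label scores (ties go to the lowest index, like np.argmax)
        best = max(range(len(token_logits)), key=token_logits.__getitem__)
        preds.append(LABELS[best])
        refs.append(LABELS[label_id])
    return preds, refs


logits = [
    [0.1, 0.2, 0.3],  # [CLS]
    [0.1, 2.0, 0.3],  # first subword of a B-location word
    [0.5, 0.1, 0.9],  # second subword of the same word
    [1.2, 0.4, 0.5],  # an I-location word, mispredicted as O
    [1.0, 0.0, 0.0],  # [SEP]
]
label_ids = [-100, 1, -100, 2, -100]
print(filter_ignored(logits, label_ids))
# (['B-location', 'O'], ['B-location', 'I-location'])
```

Only two of the five positions reach seqeval. The special tokens and the second subword never affect the metrics.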

Before you start training your model, create a map of the expected ids to their labels with id2label and label2id:

In [14]:
id2label = {
    0: "O",
    1: "B-corporation",
    2: "I-corporation",
    3: "B-creative-work",
    4: "I-creative-work",
    5: "B-group",
    6: "I-group",
    7: "B-location",
    8: "I-location",
    9: "B-person",
    10: "I-person",
    11: "B-product",
    12: "I-product",
}
label2id = {
    "O": 0,
    "B-corporation": 1,
    "I-corporation": 2,
    "B-creative-work": 3,
    "I-creative-work": 4,
    "B-group": 5,
    "I-group": 6,
    "B-location": 7,
    "I-location": 8,
    "B-person": 9,
    "I-person": 10,
    "B-product": 11,
    "I-product": 12,
}
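
You can also derive the two dictionaries from label_list instead of typing them by hand, which keeps them in sync with the dataset's label order. In the minimal sketch below, label_list is repeated only so the snippet runs on its own; num_labels could likewise be written as len(label_list).

```python
# Label names as printed from dataset["train"].features["ner_tags"].feature.names above
label_list = [
    "O",
    "B-corporation", "I-corporation",
    "B-creative-work", "I-creative-work",
    "B-group", "I-group",
    "B-location", "I-location",
    "B-person", "I-person",
    "B-product", "I-product",
]

# A label's position in label_list is its id, so enumerate gives id -> label directly
id2label = dict(enumerate(label_list))
# Invert the mapping to get label -> id
label2id = {label: i for i, label in id2label.items()}

print(id2label[7], label2id["I-product"])  # B-location 12
```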
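
Now load DistilBERT with AutoModelForTokenClassification, passing the number of labels and the label mappings:
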
In [15]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

model = AutoModelForTokenClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=13, id2label=id2label, label2id=label2id
)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
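
The warning is expected: the token classification head on top of DistilBERT is newly initialized and is trained during finetuning. Define the training hyperparameters in TrainingArguments. Then pass them to Trainer along with the model, datasets, tokenizer, data collator, and compute_metrics function:
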
In [16]:
training_args = TrainingArguments(
    output_dir="my_wnut_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    # This guide evaluates on the test split; the validation split is the usual choice for checkpoint selection
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,  # newer transformers releases name this argument processing_class
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

Evaluating before training gives a baseline for the newly initialized classification head:

In [17]:
trainer.evaluate()

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'eval_loss': 2.68152117729187,
 'eval_precision': 0.005356186395286556,
 'eval_recall': 0.07414272474513438,
 'eval_f1': 0.009990633780830472,
 'eval_accuracy': 0.013423966482835278,
 'eval_runtime': 5.7743,
 'eval_samples_per_second': 222.884,
 'eval_steps_per_second': 14.028}

In [18]:
trainer.train()

Epoch  Training Loss  Validation Loss  Precision  Recall    F1        Accuracy
1      No log         0.278258         0.624746   0.285449  0.391858  0.940618
2      No log         0.256084         0.576052   0.329935  0.419564  0.943226
3      0.184700       0.270109         0.540613   0.376274  0.443716  0.946219
4      0.184700       0.287609         0.578561   0.365153  0.447727  0.946518
5      0.053700       0.292569         0.573034   0.378128  0.455611  0.947116


TrainOutput(global_step=1065, training_loss=0.11445152233464058, metrics={'train_runtime': 137.8751, 'train_samples_per_second': 123.082, 'train_steps_per_second': 7.724, 'total_flos': 229344858861480.0, 'train_loss': 0.11445152233464058, 'epoch': 5.0})
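
The "No log" entries appear because the training loss is logged every 500 steps by default. With 213 steps per epoch (1065 in total), the first log happens during epoch 3 and the second during epoch 5. Because load_best_model_at_end=True and metric_for_best_model defaults to the evaluation loss, the Trainer reloads the epoch 2 checkpoint, which has the lowest validation loss, at the end of training. The final evaluation therefore matches epoch 2 rather than epoch 5:
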
In [19]:
trainer.evaluate()

{'eval_loss': 0.2560841143131256,
 'eval_precision': 0.5760517799352751,
 'eval_recall': 0.329935125115848,
 'eval_f1': 0.4195639363582793,
 'eval_accuracy': 0.9432260271044419,
 'eval_runtime': 3.6608,
 'eval_samples_per_second': 351.561,
 'eval_steps_per_second': 22.126,
 'epoch': 5.0}
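
With load_best_model_at_end=True, the checkpoint that is kept depends on metric_for_best_model. The sketch below replays that choice on the (epoch, validation loss, F1) values copied from the training log. It only compares logged numbers and doesn't call the Trainer. The default criterion, lowest evaluation loss, picks epoch 2, matching the metrics above. Setting metric_for_best_model="f1" in TrainingArguments would instead keep the checkpoint with the highest F1.

```python
# (epoch, validation loss, F1) copied from the training log above
log = [
    (1, 0.278258, 0.391858),
    (2, 0.256084, 0.419564),
    (3, 0.270109, 0.443716),
    (4, 0.287609, 0.447727),
    (5, 0.292569, 0.455611),
]

best_by_loss = min(log, key=lambda row: row[1])[0]  # default: lower evaluation loss is better
best_by_f1 = max(log, key=lambda row: row[2])[0]    # e.g. metric_for_best_model="f1"
print(best_by_loss, best_by_f1)  # 2 5
```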

Now use the finetuned model for inference. Define some text to label:

In [20]:
text = "The Golden State Warriors are an American professional basketball team based in San Francisco."

In [21]:
import torch

# Run on the GPU if one is available, otherwise on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move the finetuned model to that device
model = model.to(device)

In [22]:
from transformers import pipeline
# device=0 selects the first GPU; -1 runs on the CPU
pipe = pipeline("token-classification", model=model, tokenizer=tokenizer, device=0 if device == "cuda" else -1)

In [23]:
pipe(text)

[{'entity': 'B-location',
  'score': 0.29789796,
  'index': 1,
  'word': 'the',
  'start': 0,
  'end': 3},
 {'entity': 'B-location',
  'score': 0.6854044,
  'index': 2,
  'word': 'golden',
  'start': 4,
  'end': 10},
 {'entity': 'I-location',
  'score': 0.68216825,
  'index': 3,
  'word': 'state',
  'start': 11,
  'end': 16},
 {'entity': 'I-group',
  'score': 0.28247955,
  'index': 4,
  'word': 'warriors',
  'start': 17,
  'end': 25},
 {'entity': 'B-location',
  'score': 0.28391644,
  'index': 7,
  'word': 'american',
  'start': 33,
  'end': 41},
 {'entity': 'B-location',
  'score': 0.8134942,
  'index': 13,
  'word': 'san',
  'start': 80,
  'end': 83},
 {'entity': 'I-location',
  'score': 0.53298944,
  'index': 14,
  'word': 'francisco',
  'start': 84,
  'end': 93},
 {'entity': 'I-group',
  'score': 0.1500106,
  'index': 15,
  'word': '.',
  'start': 93,
  'end': 94}]
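
The pipeline returns one prediction per token, so a multi-token entity such as San Francisco arrives as separate B- and I- entries. The pipeline can merge them itself through its aggregation_strategy argument (for example, aggregation_strategy="simple"). The plain-Python sketch below illustrates the idea: it merges adjacent B-/I- entries of the same type and averages their scores. group_entities is a hypothetical helper whose rule is simpler than the pipeline's, and its input is a subset of the output above.

```python
def group_entities(tokens, text):
    """Merge adjacent B-/I- token predictions of the same type into (type, span, mean score)."""
    groups = []
    for tok in tokens:
        prefix, entity_type = tok["entity"].split("-", 1)
        last = groups[-1] if groups else None
        # An I- tag continues the previous group only if the type matches and the tokens are adjacent
        continues = (
            prefix == "I"
            and last is not None
            and last["type"] == entity_type
            and tok["index"] == last["last_index"] + 1
        )
        if continues:
            last["end"] = tok["end"]
            last["last_index"] = tok["index"]
            last["scores"].append(tok["score"])
        else:
            groups.append({"type": entity_type, "start": tok["start"], "end": tok["end"],
                           "last_index": tok["index"], "scores": [tok["score"]]})
    # Slice the original text with the character offsets to recover the entity's surface form
    return [
        (g["type"], text[g["start"]:g["end"]], sum(g["scores"]) / len(g["scores"]))
        for g in groups
    ]


text = "The Golden State Warriors are an American professional basketball team based in San Francisco."
# A subset of the pipeline output shown above
tokens = [
    {"entity": "B-location", "score": 0.6854044, "index": 2, "word": "golden", "start": 4, "end": 10},
    {"entity": "I-location", "score": 0.68216825, "index": 3, "word": "state", "start": 11, "end": 16},
    {"entity": "B-location", "score": 0.8134942, "index": 13, "word": "san", "start": 80, "end": 83},
    {"entity": "I-location", "score": 0.53298944, "index": 14, "word": "francisco", "start": 84, "end": 93},
]
for entity_type, span, score in group_entities(tokens, text):
    print(entity_type, span, round(score, 3))
# location Golden State 0.684
# location San Francisco 0.673
```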