# Token Classification

- Finetune DistilBERT on the WNUT 17 dataset to detect new entities.
- Use your finetuned model for inference.

Token classification assigns a label to individual tokens in a sentence. One of the most common token classification tasks is Named Entity Recognition (NER). NER attempts to find a label for each entity in a sentence, such as a person, location, or organization.

!pip install transformers datasets evaluate seqeval

In [None]:
from huggingface_hub import notebook_login

# Authenticate with the Hugging Face Hub; the training cells further down push
# the model there (push_to_hub=True and trainer.push_to_hub()).
notebook_login()

In [44]:
from datasets import load_dataset

# Load the WNUT 17 dataset by name.
dataset = load_dataset("wnut_17")

# Display the available splits with their columns and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 3394
    })
    validation: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 1009
    })
    test: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 1287
    })
})

In [45]:
# Inspect one raw training example: word-level tokens plus one integer tag per token.
print(dataset["train"][0])

{'id': '0', 'tokens': ['@paulwalk', 'It', "'s", 'the', 'view', 'from', 'where', 'I', "'m", 'living', 'for', 'two', 'weeks', '.', 'Empire', 'State', 'Building', '=', 'ESB', '.', 'Pretty', 'bad', 'storm', 'here', 'last', 'evening', '.'], 'ner_tags': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 8, 8, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0]}


In [46]:
# Label names in id order: the integer ner_tags index into this list
# (for example 0 -> 'O', 7 -> 'B-location').
label_list = dataset["train"].features["ner_tags"].feature.names
label_list

['O',
 'B-corporation',
 'I-corporation',
 'B-creative-work',
 'I-creative-work',
 'B-group',
 'I-group',
 'B-location',
 'I-location',
 'B-person',
 'I-person',
 'B-product',
 'I-product']

The letter that prefixes each ner_tag indicates the token position of the entity:

- B- indicates the beginning of an entity.
- I- indicates a token is contained inside the same entity (for example, the State token is a part of an entity like Empire State Building).
- O (the letter O, for "outside") indicates the token doesn’t correspond to any entity.

In [47]:
from transformers import AutoTokenizer

# Checkpoint name; reused later when the model itself is loaded.
checkpoint = "distilbert/distilbert-base-uncased"

# Tokenizer that matches the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [48]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align the NER tags to sub-tokens.

    Args:
        examples: Batch mapping with a "tokens" entry (one list of words per
            sentence) and an "ner_tags" entry (one list of tag ids per sentence).

    Returns:
        The tokenizer output with an added "labels" entry. For every sentence
        the first sub-token of each word carries that word's tag id; special
        tokens and the remaining sub-tokens of a word get -100.
    """
    encodings = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    aligned = []
    for row, word_tags in enumerate(examples["ner_tags"]):
        row_labels = []
        last_word = None
        for current_word in encodings.word_ids(batch_index=row):
            # Only the first sub-token of a real word is labelled; everything
            # else (special tokens, word continuations) is marked -100.
            starts_new_word = current_word is not None and current_word != last_word
            row_labels.append(word_tags[current_word] if starts_new_word else -100)
            last_word = current_word
        aligned.append(row_labels)

    encodings["labels"] = aligned
    return encodings

In [50]:
# Run the alignment function over the whole dataset; batched=True hands it a
# batch of examples per call, which is the input shape the function iterates over.
tokenized_wnut = dataset.map(tokenize_and_align_labels, batched=True)

In [52]:
# Show every field of the first tokenized training example, one per line.
first_example = tokenized_wnut["train"][0]
for field in first_example:
    print(field, first_example[field])

id 0
tokens ['@paulwalk', 'It', "'s", 'the', 'view', 'from', 'where', 'I', "'m", 'living', 'for', 'two', 'weeks', '.', 'Empire', 'State', 'Building', '=', 'ESB', '.', 'Pretty', 'bad', 'storm', 'here', 'last', 'evening', '.']
ner_tags [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 8, 8, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0]
input_ids [101, 1030, 2703, 17122, 2009, 1005, 1055, 1996, 3193, 2013, 2073, 1045, 1005, 1049, 2542, 2005, 2048, 3134, 1012, 3400, 2110, 2311, 1027, 9686, 2497, 1012, 3492, 2919, 4040, 2182, 2197, 3944, 1012, 102]
attention_mask [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
labels [-100, 0, -100, -100, 0, 0, -100, 0, 0, 0, 0, 0, 0, -100, 0, 0, 0, 0, 0, 7, 8, 8, 0, 7, -100, 0, 0, 0, 0, 0, 0, 0, 0, -100]


In [43]:
# Sanity check: compare one raw example with its tokenized, label-aligned form.
raw_sample = dataset["train"][0]
aligned_sample = tokenized_wnut["train"][0]
train_features = tokenized_wnut["train"].features

raw_tags = raw_sample["ner_tags"][:5]
aligned_labels = aligned_sample["labels"][:10]

print("=== LABEL VERIFICATION ===")
print("Original dataset sample:")
print(f"Tokens: {raw_sample['tokens'][:5]}")
print(f"NER tags: {raw_tags}")
print(f"Label names: {[label_list[tag] for tag in raw_tags]}")

print("\nAfter tokenization and alignment:")
print(f"Input IDs: {aligned_sample['input_ids'][:10]}")
print(f"Labels: {aligned_labels}")
# -100 marks positions excluded from the labels (special tokens, word continuations).
print(f"Label names: {['IGNORE' if tag == -100 else label_list[tag] for tag in aligned_labels]}")

print("\nDataset features after tokenization:")
print(f"Features: {train_features}")
print(f"Labels field exists: {'labels' in train_features}")
print(f"Sample label shape: {len(aligned_sample['labels'])}")
print(f"Sample input_ids shape: {len(aligned_sample['input_ids'])}")
print("=== END VERIFICATION ===")


=== LABEL VERIFICATION ===
Original dataset sample:
Tokens: ['@paulwalk', 'It', "'s", 'the', 'view']
NER tags: [0, 0, 0, 0, 0]
Label names: ['O', 'O', 'O', 'O', 'O']

After tokenization and alignment:
Input IDs: [101, 1030, 2703, 17122, 2009, 1005, 1055, 1996, 3193, 2013]
Labels: [-100, 0, -100, -100, 0, 0, -100, 0, 0, 0]
Label names: ['IGNORE', 'O', 'IGNORE', 'IGNORE', 'O', 'O', 'IGNORE', 'O', 'O', 'O']

Dataset features after tokenization:
Features: {'id': Value(dtype='string', id=None), 'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'ner_tags': Sequence(feature=ClassLabel(names=['O', 'B-corporation', 'I-corporation', 'B-creative-work', 'I-creative-work', 'B-group', 'I-group', 'B-location', 'I-location', 'B-person', 'I-person', 'B-product', 'I-product'], id=None), length=-1, id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), '

In [53]:
from transformers import DataCollatorForTokenClassification

# Batch collator handed to the Trainer below, built from the same tokenizer.
# NOTE(review): presumably pads inputs and labels per batch — confirm in the
# DataCollatorForTokenClassification docs.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [54]:
import evaluate

# seqeval metric; compute_metrics below calls metric.compute on label-name sequences.
metric = evaluate.load("seqeval")

In [55]:
import numpy as np

# For reference, one example's tag ids map to label names via:
#     [label_list[i] for i in example["ner_tags"]]


def compute_metrics(p):
    """Turn raw model predictions into overall seqeval scores.

    Args:
        p: Pair ``(predictions, labels)``. ``predictions`` holds per-token class
            scores (argmax is taken over axis 2); ``labels`` holds the gold
            label ids, with -100 at positions that must be ignored.

    Returns:
        Dict with the overall "precision", "recall", "f1" and "accuracy"
        reported by the seqeval metric.
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        kept_predictions = []
        kept_labels = []
        for predicted_id, gold_id in zip(predicted_row, gold_row):
            # Positions labelled -100 are skipped on both sides.
            if gold_id == -100:
                continue
            kept_predictions.append(label_list[predicted_id])
            kept_labels.append(label_list[gold_id])
        true_predictions.append(kept_predictions)
        true_labels.append(kept_labels)

    results = metric.compute(predictions=true_predictions, references=true_labels)

    summary = {}
    summary["precision"] = results["overall_precision"]
    summary["recall"] = results["overall_recall"]
    summary["f1"] = results["overall_f1"]
    summary["accuracy"] = results["overall_accuracy"]
    return summary

In [56]:
# Label maps handed to the model config. The names are written once, in id
# order, and both directions of the mapping are derived from that one list.
id2label = dict(
    enumerate(
        [
            "O",
            "B-corporation",
            "I-corporation",
            "B-creative-work",
            "I-creative-work",
            "B-group",
            "I-group",
            "B-location",
            "I-location",
            "B-person",
            "I-person",
            "B-product",
            "I-product",
        ]
    )
)
label2id = {name: idx for idx, name in id2label.items()}

In [57]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

# Load DistilBERT with a token-classification head (the head weights are newly
# initialised, see the warning below). The number of labels is taken from the
# label map so it cannot drift from id2label/label2id.
model = AutoModelForTokenClassification.from_pretrained(
    checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id
)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [59]:
# Training configuration: evaluate and save a checkpoint once per epoch, and
# select the best checkpoint by the "f1" value that compute_metrics returns.
training_args = TrainingArguments(
    output_dir="token_classification_DistilBERT_WNUT17",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,  # Raised from the earlier setting of 2
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=True,
    logging_steps=50,  # Log every 50 steps
    metric_for_best_model="f1",  # Rank checkpoints by F1
    greater_is_better=True,  # A higher F1 is better
    # save_total_limit=2,  # Save only best 2 checkpoints
)

# The Trainer ties together the model, the tokenized train/validation splits,
# the batch collator and the metric function defined in the cells above.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_wnut["train"],
    eval_dataset=tokenized_wnut["validation"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Run fine-tuning; the per-epoch metrics table is shown in the output below.
trainer.train()



Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,0.0319,0.254779,0.643777,0.538278,0.586319,0.953589
2,0.0235,0.269519,0.652798,0.544258,0.593607,0.954225
3,0.0148,0.294628,0.700326,0.514354,0.593103,0.954352
4,0.0131,0.29311,0.678516,0.52512,0.592043,0.95486
5,0.0134,0.291374,0.666667,0.5311,0.591212,0.954288




TrainOutput(global_step=1065, training_loss=0.020616970767437574, metrics={'train_runtime': 397.9129, 'train_samples_per_second': 42.648, 'train_steps_per_second': 2.676, 'total_flos': 229914027537180.0, 'train_loss': 0.020616970767437574, 'epoch': 5.0})

In [60]:
# Upload the trained model to the Hugging Face Hub; the output below shows the
# resulting commit.
trainer.push_to_hub()

model.safetensors:   0%|          | 0.00/266M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/Swagam/token_classification_DistilBERT_WNUT17/commit/c60896d541fc48fe2518f0fa91bef692a2e7e809', commit_message='End of training', commit_description='', oid='c60896d541fc48fe2518f0fa91bef692a2e7e809', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Swagam/token_classification_DistilBERT_WNUT17', endpoint='https://huggingface.co', repo_type='model', repo_id='Swagam/token_classification_DistilBERT_WNUT17'), pr_revision=None, pr_num=None)

In [61]:
# Sample sentence used by the inference cells below.
text = "The Golden State Warriors are an American professional basketball team based in San Francisco."

In [None]:
# IMPROVED INFERENCE WITH BETTER AGGREGATION
from transformers import pipeline

hub_model_id = "Swagam/token_classification_DistilBERT_WNUT17"


def print_entities(results):
    """Print one ``word | label | score`` row per prediction.

    When an aggregation strategy is set, the pipeline reports the label under
    ``entity_group``; un-aggregated token predictions use ``entity``. Both
    shapes are accepted so the same helper works for either kind of result.
    """
    for item in results:
        label = item.get("entity_group", item.get("entity"))
        print(f"{item['word']:15} | {label:15} | {item['score']:.3f}")


# Build one pipeline per aggregation strategy so their groupings can be compared.
# "first": each word takes the tag of its first sub-token.
classifier_first = pipeline("ner", model=hub_model_id, aggregation_strategy="first")

# "max": each word takes the tag of its highest-scoring sub-token.
classifier_max = pipeline("ner", model=hub_model_id, aggregation_strategy="max")

# "simple": adjacent tokens with the same entity are merged; no score threshold
# is applied by any of these strategies.
classifier_simple = pipeline("ner", model=hub_model_id, aggregation_strategy="simple")

text = "The Golden State Warriors are an American professional basketball team based in San Francisco."

print("=== FIRST AGGREGATION ===")
result_first = classifier_first(text)
print_entities(result_first)

print("\n=== MAX AGGREGATION ===")
result_max = classifier_max(text)
print_entities(result_max)

print("\n=== SIMPLE AGGREGATION ===")
result_simple = classifier_simple(text)
print_entities(result_simple)


In [None]:
# MANUAL FILTERING WITH CONFIDENCE THRESHOLD
def filter_low_confidence(results, threshold=0.7):
    """Keep only the predictions scored at or above ``threshold``.

    Args:
        results: Iterable of prediction dicts, each with a 'score' entry.
        threshold: Minimum score (inclusive) a prediction needs to be kept.

    Returns:
        A new list with the surviving predictions in their original order.
    """
    confident = []
    for prediction in results:
        if prediction['score'] >= threshold:
            confident.append(prediction)
    return confident

# Immutable default stop-word set: a frozenset cannot be modified by accident
# between calls the way a mutable set default could.
_DEFAULT_COMMON_WORDS = frozenset(
    {'the', 'a', 'an', 'are', 'is', 'was', 'were', 'and', 'or', 'but'}
)


def filter_common_words(results, common_words=_DEFAULT_COMMON_WORDS):
    """Drop predictions whose word is a common function word.

    Args:
        results: Iterable of prediction dicts, each with a 'word' entry.
        common_words: Lower-case words to exclude. The prediction's word is
            lower-cased before the membership test, so matching ignores case.

    Returns:
        A new list with the remaining predictions in their original order.
    """
    return [item for item in results if item['word'].lower() not in common_words]

# Token-level predictions copied from the un-aggregated pipeline run, used here
# to demonstrate the two filters without calling the model again.
original_results = [
    {'entity': 'B-group', 'score': 0.62402034, 'index': 1, 'word': 'the', 'start': 0, 'end': 3},
    {'entity': 'B-location', 'score': 0.8082413, 'index': 2, 'word': 'golden', 'start': 4, 'end': 10},
    {'entity': 'I-group', 'score': 0.57333845, 'index': 3, 'word': 'state', 'start': 11, 'end': 16},
    {'entity': 'I-group', 'score': 0.90311676, 'index': 4, 'word': 'warriors', 'start': 17, 'end': 25},
    {'entity': 'I-group', 'score': 0.3521599, 'index': 5, 'word': 'are', 'start': 26, 'end': 29},
    {'entity': 'B-group', 'score': 0.44896773, 'index': 7, 'word': 'american', 'start': 33, 'end': 41},
    {'entity': 'I-group', 'score': 0.3984678, 'index': 8, 'word': 'professional', 'start': 42, 'end': 54},
    {'entity': 'I-group', 'score': 0.3239563, 'index': 9, 'word': 'basketball', 'start': 55, 'end': 65},
    {'entity': 'I-group', 'score': 0.48978537, 'index': 10, 'word': 'team', 'start': 66, 'end': 70},
    {'entity': 'B-location', 'score': 0.9849925, 'index': 13, 'word': 'san', 'start': 80, 'end': 83},
    {'entity': 'I-location', 'score': 0.9699109, 'index': 14, 'word': 'francisco', 'start': 84, 'end': 93}
]

# Compute each filtered view up front, then print them all with a single loop.
filtered_confidence = filter_low_confidence(original_results, 0.7)
filtered_common = filter_common_words(original_results)
filtered_both = filter_common_words(filter_low_confidence(original_results, 0.7))

report_sections = (
    ("=== ORIGINAL RESULTS ===", original_results),
    ("\n=== AFTER CONFIDENCE FILTER (>= 0.7) ===", filtered_confidence),
    ("\n=== AFTER COMMON WORDS FILTER ===", filtered_common),
    ("\n=== AFTER BOTH FILTERS ===", filtered_both),
)
for heading, rows in report_sections:
    print(heading)
    for item in rows:
        print(f"{item['word']:15} | {item['entity']:15} | {item['score']:.3f}")


In [62]:
from transformers import pipeline

# Baseline inference with no aggregation strategy: the output below lists
# token-level predictions, each labelled under the 'entity' key.
classifier = pipeline("ner", model="Swagam/token_classification_DistilBERT_WNUT17")
classifier(text)

model.safetensors:   0%|          | 0.00/266M [00:00<?, ?B/s]

Device set to use mps:0


[{'entity': 'B-group',
  'score': 0.62402034,
  'index': 1,
  'word': 'the',
  'start': 0,
  'end': 3},
 {'entity': 'B-location',
  'score': 0.8082413,
  'index': 2,
  'word': 'golden',
  'start': 4,
  'end': 10},
 {'entity': 'I-group',
  'score': 0.57333845,
  'index': 3,
  'word': 'state',
  'start': 11,
  'end': 16},
 {'entity': 'I-group',
  'score': 0.90311676,
  'index': 4,
  'word': 'warriors',
  'start': 17,
  'end': 25},
 {'entity': 'I-group',
  'score': 0.3521599,
  'index': 5,
  'word': 'are',
  'start': 26,
  'end': 29},
 {'entity': 'B-group',
  'score': 0.44896773,
  'index': 7,
  'word': 'american',
  'start': 33,
  'end': 41},
 {'entity': 'I-group',
  'score': 0.3984678,
  'index': 8,
  'word': 'professional',
  'start': 42,
  'end': 54},
 {'entity': 'I-group',
  'score': 0.3239563,
  'index': 9,
  'word': 'basketball',
  'start': 55,
  'end': 65},
 {'entity': 'I-group',
  'score': 0.48978537,
  'index': 10,
  'word': 'team',
  'start': 66,
  'end': 70},
 {'entity': 'B-lo