# Token Classification with Transformers

First, we will load the *WNUT 17 dataset*, which we will use for our Named Entity Recognition (NER) token classification task.

## Loading the data

In [None]:
from datasets import load_dataset
# Load the WNUT 17 dataset by its `datasets` identifier.
wnut = load_dataset("wnut_17")
# Reference: https://huggingface.co/docs/transformers/tasks/token_classification#load-wnut-17-dataset

In [None]:
# Inspect the first record of the train split.
wnut["train"][0]

In [None]:
# Names of the NER tag classes, read from the train split's feature metadata.
label_list = wnut["train"].features["ner_tags"].feature.names
label_list

## Preprocess step

In [None]:
from transformers import AutoTokenizer
# Tokenizer from the "distilbert-base-uncased" checkpoint; the preprocessing
# cells that follow all go through this object.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [None]:
# Tokenize the first training example; its "tokens" field is already a list of
# words, hence is_split_into_words=True.
example = wnut["train"][0]
tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
# Map the ids back to token strings to see what the tokenizer produced.
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
tokens

However, this adds some special tokens [CLS] and [SEP] and the subword tokenization creates a mismatch between the input and labels. A single word corresponding to a single label may now be split into two subwords. You’ll need to realign the tokens and labels by:

- Mapping all tokens to their corresponding word with the `word_ids` method.
- Assigning the label `-100` to the special tokens `[CLS]` and `[SEP]` so they’re ignored by the PyTorch loss function.
- Only labeling the first token of a given word. Assign `-100` to other subtokens from the same word.

Here is how you can create a function to realign the tokens and labels, and truncate sequences to be no longer than DistilBERT’s maximum input length:

In [None]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align NER tags to the tokens.

    Uses the module-level ``tokenizer``. For every sentence, the first token of
    each word receives that word's tag; special tokens (word id ``None``) and
    the remaining sub-tokens of a word receive ``-100``.

    Args:
        examples: Batch mapping with a ``"tokens"`` entry (one word list per
            sentence) and a parallel ``"ner_tags"`` entry (one tag id per
            word).

    Returns:
        The tokenizer output for the batch with an added ``"labels"`` entry
        holding one aligned label list per sentence.
    """

    def align(word_ids, word_tags):
        # Label only the first token of each word; mask everything else.
        aligned = []
        last_word = None
        for current_word in word_ids:
            if current_word is not None and current_word != last_word:
                aligned.append(word_tags[current_word])
            else:
                aligned.append(-100)
            last_word = current_word
        return aligned

    encoded = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    encoded["labels"] = [
        align(encoded.word_ids(batch_index=row), word_tags)
        for row, word_tags in enumerate(examples["ner_tags"])
    ]
    return encoded

In [None]:
# Run the tokenize-and-align function over the dataset in batches.
tokenized_wnut = wnut.map(tokenize_and_align_labels, batched=True)

## Data collator creation

In [None]:
from transformers import DataCollatorForTokenClassification
# Collator built around the tokenizer; it is handed to the Trainer later on.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

## Evaluate

In [None]:
import evaluate
# Load the "seqeval" metric; compute_metrics calls its compute() method.
seqeval = evaluate.load("seqeval")

In [None]:
import numpy as np
# Tag names of the example sentence, looked up from its integer tag ids.
labels = [label_list[tag_id] for tag_id in example["ner_tags"]]
print(labels)

In [None]:
def compute_metrics(p):
    """Compute seqeval precision, recall, F1 and accuracy for an evaluation pass.

    Relies on the module-level ``label_list`` and ``seqeval`` objects.

    Args:
        p: Pair ``(predictions, labels)``. ``predictions`` is arg-maxed over
            axis 2 to obtain class ids; ``labels`` holds the reference ids,
            and positions whose reference id is ``-100`` are skipped.

    Returns:
        Dict with the keys ``"precision"``, ``"recall"``, ``"f1"`` and
        ``"accuracy"``, taken from the seqeval ``overall_*`` results.
    """
    scores, gold_ids = p
    pred_ids = np.argmax(scores, axis=2)

    true_preds = []
    true_labels = []
    for pred_row, gold_row in zip(pred_ids, gold_ids):
        # Keep only the positions that carry a real reference label.
        kept = [(pred, gold) for pred, gold in zip(pred_row, gold_row) if gold != -100]
        true_preds.append([label_list[pred] for pred, _ in kept])
        true_labels.append([label_list[gold] for _, gold in kept])

    results = seqeval.compute(predictions=true_preds, references=true_labels)

    summary = {}
    summary["precision"] = results["overall_precision"]
    summary["recall"] = results["overall_recall"]
    summary["f1"] = results["overall_f1"]
    summary["accuracy"] = results["overall_accuracy"]
    return summary

## View training labels

In [None]:
# Tag id -> tag name; each id is the position of the name in this list.
id2label = dict(enumerate([
    "O",
    "B-corporation",
    "I-corporation",
    "B-creative-work",
    "I-creative-work",
    "B-group",
    "I-group",
    "B-location",
    "I-location",
    "B-person",
    "I-person",
    "B-product",
    "I-product",
]))

print(id2label)

# Inverse lookup (tag name -> tag id), derived from id2label so the two
# mappings cannot drift apart.
label2id = {name: idx for idx, name in id2label.items()}

print(label2id)
print(len(label2id))

## Train our BERT-large classifier

In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

# Checkpoint identifiers; only `large_model_name` is loaded in this cell.
model_name = "distilbert-base-uncased"
large_model_name = "bert-large-uncased"

# Token-classification model with one output class per entry in the label maps.
# NOTE(review): the tokenizer was created from "distilbert-base-uncased" while
# this model comes from "bert-large-uncased" — confirm the pairing is intended.
model = AutoModelForTokenClassification.from_pretrained(
    large_model_name,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)

Create the training arguments

In [None]:
# Hyperparameters consumed by the TrainingArguments cell.
batch_size, epochs = 16, 2
learning_rate, weight_dec = 5e-5, 1e-2

# Output directory name for training; also the repo name used under
# "StatsGary/" by the inference cells.
model_name = "bert-large-ner-wnut-17"

In [None]:
# Training configuration: evaluate and checkpoint once per epoch, reload the
# best checkpoint when training ends, and push the result to the Hub.
# NOTE(review): recent transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version.
train_args = TrainingArguments(
    output_dir=model_name,
    num_train_epochs=epochs,
    learning_rate=learning_rate,
    weight_decay=weight_dec,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=True,
)

In [None]:
# Wire model, data, collator and metrics into a Trainer.
# NOTE(review): evaluation runs on the "test" split while train_args sets
# load_best_model_at_end=True, so checkpoint selection presumably uses the test
# data too — consider a separate validation split if the dataset provides one.
trainer = Trainer(
    model=model,
    args=train_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=tokenized_wnut["train"],
    eval_dataset=tokenized_wnut["test"],
    compute_metrics=compute_metrics,
)

Train the model. 

In [None]:
# Start training with the Trainer configured in the previous cell.
trainer.train()

In [None]:
# Push the training output to the Hub through the Trainer.
trainer.push_to_hub()

## Inference

In [None]:
# Sample sentence fed to the inference cells.
text = (
    "The Golden State Warriors are an American professional "
    "basketball team based in San Francisco."
)

In [None]:
from transformers import pipeline
# Build an "ner" pipeline from the "StatsGary/<model_name>" repository.
# `model_name` must still hold the fine-tuned model's name at this point.
classifier = pipeline("ner", model=f"StatsGary/{model_name}")

In [None]:
# Run the NER pipeline on the sample sentence.
classifier(text)

## Inference with PyTorch

In [None]:
from transformers import AutoTokenizer
# Reload the tokenizer from the "StatsGary/<model_name>" repository; this
# rebinds the global `tokenizer`.
tokenizer = AutoTokenizer.from_pretrained(f"StatsGary/{model_name}")
# Encode the sample sentence, requesting PyTorch tensors ("pt").
inputs = tokenizer(text, return_tensors="pt")

In [None]:
from transformers import AutoModelForTokenClassification
import torch
# Reload the model from the same repository; this rebinds the global `model`.
model = AutoModelForTokenClassification.from_pretrained(f"StatsGary/{model_name}")
# Forward pass with gradient tracking disabled; only the logits are kept.
with torch.no_grad():
    logits = model(**inputs).logits

In [None]:
# Highest-scoring class id along dim 2 for every token position.
predictions = logits.argmax(dim=2)
# Translate the class ids of the first sequence in the batch into label names.
predicted_token_class = [
    model.config.id2label[class_id.item()] for class_id in predictions[0]
]
predicted_token_class