# Decoding Methods

Decoding determines how per-token model scores are turned into a final sequence of entity labels.

### 🧾 Introduction

After the model has scored every token, the decoding step turns those scores into entity labels or spans. Different decoding methods offer varying levels of structural consistency and accuracy.

In [None]:
!pip install -q transformers torch accelerate scikit-learn

In [None]:
!pip install -U -q datasets 

In [4]:
!pip install -q seqeval

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone


In [57]:
from datasets import load_dataset
from seqeval.metrics.sequence_labeling import get_entities

# Load the CoNLL-2003 dataset.
dataset = load_dataset("conll2003")

# Take two consecutive training rows (indices 5 and 6). Slicing the split
# returns a dict that maps each column name to a list with one entry per row.
train_split = dataset["train"]
samples = train_split[5:7]

In [58]:
# Print every column of the sampled rows, one column per line.
for column in ("id", "tokens", "pos_tags", "chunk_tags", "ner_tags"):
    print(samples[column])

['5', '6']
[['"', 'We', 'do', "n't", 'support', 'any', 'such', 'recommendation', 'because', 'we', 'do', "n't", 'see', 'any', 'grounds', 'for', 'it', ',', '"', 'the', 'Commission', "'s", 'chief', 'spokesman', 'Nikolaus', 'van', 'der', 'Pas', 'told', 'a', 'news', 'briefing', '.'], ['He', 'said', 'further', 'scientific', 'study', 'was', 'required', 'and', 'if', 'it', 'was', 'found', 'that', 'action', 'was', 'needed', 'it', 'should', 'be', 'taken', 'by', 'the', 'European', 'Union', '.']]
[[0, 28, 41, 30, 37, 12, 16, 21, 15, 28, 41, 30, 37, 12, 24, 15, 28, 6, 0, 12, 22, 27, 16, 21, 22, 22, 14, 22, 38, 12, 21, 21, 7], [28, 38, 16, 16, 21, 38, 40, 10, 15, 28, 38, 40, 15, 21, 38, 40, 28, 20, 37, 40, 15, 12, 22, 22, 7]]
[[0, 11, 21, 22, 22, 11, 12, 12, 17, 11, 21, 22, 22, 11, 12, 13, 11, 0, 0, 11, 12, 11, 12, 12, 12, 12, 12, 12, 21, 11, 12, 12, 0], [11, 21, 11, 12, 12, 21, 22, 0, 17, 11, 21, 22, 17, 11, 21, 22, 11, 21, 22, 22, 13, 11, 12, 12, 0]]
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [None]:
from transformers import BertTokenizerFast

# Use the first training sentence as the running example for every decoder.
example = dataset["train"][0]
tokens, labels = example["tokens"], example["ner_tags"]

# The `ner_tags` feature exposes the tag names, indexed by integer tag id;
# use it to turn the gold ids into readable tag strings.
ner_feature = dataset["train"].features["ner_tags"].feature
label_list = ner_feature.names
true_tags = [label_list[tag_id] for tag_id in labels]

# Encode the already word-split sentence as PyTorch tensors.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
inputs = tokenizer(
    tokens,
    is_split_into_words=True,
    return_tensors="pt",
    truncation=True,
    padding=True,
)

print("🔹 Tokens:", tokens)
print("🔹 True NER Tags (IOB2):", true_tags)

🔹 Tokens: ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']
🔹 True NER Tags (IOB2): ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']


In [65]:
# Show the encoded example. It holds more input ids than there are words,
# because special tokens are added and a word may be split into several
# sub-word pieces.
print(inputs)

{'input_ids': tensor([[  101,  7270, 22961,  1528,  1840,  1106, 21423,  1418,  2495, 12913,
           119,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


### 🔹 Softmax Decoder

* Simplest approach: predicts label for each token independently
* Can lead to illegal sequences (e.g., `I-ORG` without a preceding `B-ORG`)

In [66]:
from transformers import BertForTokenClassification
import torch

# NOTE: the classification head is freshly initialised here and never trained
# in this notebook, so the predicted tags only illustrate the decoding
# mechanics; they are not meaningful NER output.
model = BertForTokenClassification.from_pretrained("bert-base-cased", num_labels=len(label_list))
model.eval()  # make sure dropout is disabled for inference

with torch.no_grad():
    outputs = model(**inputs).logits
# Softmax decoding: pick the highest-scoring label for every position independently.
predictions = torch.argmax(outputs, dim=-1)

# Convert predicted IDs to tags (one per sub-token, special tokens included)
pred_tags = [label_list[p] for p in predictions[0].tolist()]

# Map sub-token predictions back onto the original words. A word may be split
# into several sub-tokens, so a fixed slice such as pred_tags[1:len(tokens)+1]
# drifts out of alignment after the first split word. word_ids() gives, for
# every position, the index of the word it came from (None for special
# tokens); the prediction of each word's first sub-token is kept.
word_ids = inputs.word_ids(batch_index=0)
word_pred_tags = []
previous_word = None
for position, word_id in enumerate(word_ids):
    if word_id is not None and word_id != previous_word:
        word_pred_tags.append(pred_tags[position])
    previous_word = word_id

print("\n🔸 Softmax Decoder Prediction:")
for tok, gold, pred in zip(tokens, true_tags, word_pred_tags):
    print(f"{tok:15} | True: {gold:10} | Pred: {pred}")


model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



🔸 Softmax Decoder Prediction:
EU              | True: B-ORG      | Pred: B-LOC
rejects         | True: O          | Pred: B-PER
German          | True: B-MISC     | Pred: B-LOC
call            | True: O          | Pred: B-PER
to              | True: O          | Pred: I-LOC
boycott         | True: O          | Pred: I-MISC
British         | True: B-MISC     | Pred: B-LOC
lamb            | True: O          | Pred: I-LOC
.               | True: O          | Pred: I-LOC


### 🔹 CRF (Conditional Random Field)

* Adds structural constraints during decoding
* Learns legal transitions between tags (e.g., `B-LOC` → `I-LOC`)
* Significantly improves label consistency and overall performance


In [68]:
from torch import nn
from transformers import BertModel
from torchcrf import CRF  # install with `pip install pytorch-crf`

class BERT_CRF(nn.Module):
    """BERT encoder with a linear emission layer and a CRF on top.

    The linear layer maps every hidden state to ``num_labels`` emission
    scores; the CRF layer scores or decodes whole tag sequences from them.
    """

    def __init__(self, num_labels):
        """Build the encoder, dropout, emission layer and CRF.

        Args:
            num_labels: Number of distinct tags the CRF chooses between.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-cased")
        self.dropout = nn.Dropout(0.1)
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return the CRF loss when ``labels`` is given, else decoded tag ids.

        Args:
            input_ids: Token ids passed to the BERT encoder.
            attention_mask: Attention mask for the encoder; its boolean form
                is also used as the CRF mask.
            labels: Optional gold tag ids. When provided, the negated CRF
                score (``reduction='mean'``) is returned as the loss.

        Returns:
            The loss if ``labels`` is not ``None``; otherwise the result of
            ``CRF.decode`` for the batch.
        """
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_out.last_hidden_state
        emissions = self.classifier(self.dropout(hidden_states))
        mask = attention_mask.bool()

        # Inference path: no gold labels, so just decode the best sequence.
        if labels is None:
            return self.crf.decode(emissions, mask=mask)

        # Training path: the CRF returns a score to maximise; negate it for a loss.
        log_likelihood = self.crf(emissions, labels, mask=mask, reduction='mean')
        return -log_likelihood

In [69]:
# NOTE: the emission layer and the CRF are freshly initialised and never
# trained in this notebook, so the decoded tags only illustrate the mechanics.
crf_model = BERT_CRF(num_labels=len(label_list))
crf_model.eval()

with torch.no_grad():
    pred_crf = crf_model(inputs["input_ids"], inputs["attention_mask"])  # list of predicted sequences

# One tag per encoded position of the first (only) sequence, special tokens included.
pred_crf_tags = [label_list[i] for i in pred_crf[0]]

# Align the decoded tags with the original words. The decoded sequence starts
# at the [CLS] position and a word may span several sub-tokens, so slicing
# pred_crf_tags[:len(tokens)] pairs every word with the wrong position.
# word_ids() gives the source word index for each position (None for special
# tokens); the tag of each word's first sub-token is kept.
word_ids = inputs.word_ids(batch_index=0)
word_crf_tags = []
previous_word = None
for word_id, tag in zip(word_ids, pred_crf_tags):
    if word_id is not None and word_id != previous_word:
        word_crf_tags.append(tag)
    previous_word = word_id

print("\n🔸 CRF Decoder Prediction:")
for tok, gold, pred in zip(tokens, true_tags, word_crf_tags):
    print(f"{tok:15} | True: {gold:10} | Pred: {pred}")



🔸 CRF Decoder Prediction:
EU              | True: B-ORG      | Pred: I-PER
rejects         | True: O          | Pred: I-PER
German          | True: B-MISC     | Pred: B-ORG
call            | True: O          | Pred: B-LOC
to              | True: O          | Pred: I-MISC
boycott         | True: O          | Pred: I-MISC
British         | True: B-MISC     | Pred: B-LOC
lamb            | True: O          | Pred: I-PER
.               | True: O          | Pred: B-ORG


### 🔹 Span-based Classification

* Predicts start and end positions of entities
* Suitable for overlapping or nested entities
* Common in QA-style or span extraction tasks

In [None]:
class SpanNER(nn.Module):
    """BERT encoder with two linear heads that score span starts and span ends."""

    def __init__(self, num_labels):
        """Build the encoder and the start/end heads.

        Args:
            num_labels: Output size of each head (scores per position).
        """
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-cased")
        hidden_size = self.bert.config.hidden_size
        self.start_fc = nn.Linear(hidden_size, num_labels)
        self.end_fc = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Encode the input and return ``(start_logits, end_logits)``.

        Both heads are applied to the encoder's last hidden state.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        token_states = encoded.last_hidden_state
        return self.start_fc(token_states), self.end_fc(token_states)


In [70]:
# Hard-coded (start, end, label) triples standing in for real model output.
# Both indices refer to positions in `tokens`, and `end` is inclusive.
pred_spans = [
    (1, 2, "PER"),
    (4, 5, "LOC"),
]

print("\n🔸 Span-based Decoder Prediction:")
for span in pred_spans:
    start, end, label = span
    span_tokens = tokens[start:end + 1]  # +1 because `end` is inclusive
    print(f"Span: {' '.join(span_tokens):20} | Label: {label}")



🔸 Span-based Decoder Prediction:
Span: rejects German       | Label: PER
Span: to boycott           | Label: LOC


### 🔹 Pointer Networks

* Used for identifying arbitrary spans using attention mechanisms
* Complex but powerful for flexible span boundaries

In [71]:
# Hand-written (start, end, label) triples that stand in for pointer-network
# output. Indices refer to positions in `tokens`, and `end` is inclusive.
pred_pointer_spans = [
    (0, 0, "ORG"),
    (6, 7, "MISC"),
]

print("\n🔸 Pointer / Multi-head Span Decoder Prediction:")
for span in pred_pointer_spans:
    start, end, label = span
    print(f"Span: {' '.join(tokens[start:end+1]):20} | Label: {label}")



🔸 Pointer / Multi-head Span Decoder Prediction:
Span: EU                   | Label: ORG
Span: British lamb         | Label: MISC


### Summary

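Softmax decoding is easy to implement but cannot guarantee valid tag sequences, because each token is labelled independently. CRF and span-based methods are often preferred in production-grade NER systems because they produce more coherent, and typically more accurate, predictions.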
