# Lab 07: Named Entity Recognition (NER) with BiLSTM-CRF

**Task:** Named Entity Recognition (NER)  
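**Dataset:** A small hand-written sample in CoNLL-2003 BIO format (swap in the full CoNLL-2003 corpus or a custom NER dataset for real experiments)  
**Tags:** PER (Person), ORG (Organization), LOC (Location), MISC (Miscellaneous; part of CoNLL-2003 but absent from the sample data)  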

## Part 0: Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
import re
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
from seqeval.metrics import f1_score, precision_score, recall_score, classification_report as seq_report

# Seed the RNGs used in this notebook so runs are reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU whenever PyTorch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device: {}".format(device))
print("PyTorch version: {}".format(torch.__version__))

# seqeval provides the entity-level metrics imported above; if that import fails, install it with: pip install seqeval

## Part 1: BiLSTM-CRF

### 1.1. What is CRF (Conditional Random Field)?

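- Learns transition scores between adjacent tags, so invalid transitions (e.g., O → I-PER) receive low scores
- Scores the entire tag sequence jointly instead of each position independently
- Decodes the single best tag sequence rather than picking the best tag at each position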

### 1.2. BiLSTM-CRF Architecture

```
Input:      [Apple, is, a, company]
            ↓
Embedding:  [e₁, e₂, e₃, e₄]
            ↓
BiLSTM:     [h₁, h₂, h₃, h₄] ← contextual features
            ↓
Emission:   [scores for each tag at each position]
            ↓
CRF Layer:  Transition matrix + Viterbi decoding
            ↓
Output:     [B-ORG, O, O, O] ← best valid sequence
```

### 1.3. Components

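1. **Emission Scores**: an unnormalized score for each tag at each word, produced by the BiLSTM (not a probability)
2. **Transition Scores**: a learned score for moving from tagᵢ to tagᵢ₊₁, stored as a matrix in the CRF layer (also unnormalized)
3. **Viterbi Decoding**: finds the tag sequence with the highest total (emission + transition) score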

## Part 2: Data Preparation

In [None]:
# Sample NER dataset in BIO format:
#   B-X: first token of an entity of type X
#   I-X: subsequent token inside an entity of type X
#   O:   token outside any entity

sample_data = [
    (["Apple", "is", "looking", "at", "buying", "U.K.", "startup", "for", "$1", "billion"],
     ["B-ORG", "O", "O", "O", "O", "B-LOC", "O", "O", "O", "O"]),
    
    (["Tim", "Cook", "is", "the", "CEO", "of", "Apple", "Inc."],
     ["B-PER", "I-PER", "O", "O", "O", "O", "B-ORG", "I-ORG"]),
    
    (["Google", "was", "founded", "in", "California"],
     ["B-ORG", "O", "O", "O", "B-LOC"]),
    
    (["John", "Smith", "works", "at", "Microsoft", "in", "Seattle"],
     ["B-PER", "I-PER", "O", "O", "B-ORG", "O", "B-LOC"]),
    
    (["The", "meeting", "will", "be", "in", "New", "York", "City"],
     ["O", "O", "O", "O", "O", "B-LOC", "I-LOC", "I-LOC"]),
]

# A few more hand-written sentences
additional_data = [
    (["Barack", "Obama", "was", "born", "in", "Hawaii"],
     ["B-PER", "I-PER", "O", "O", "O", "B-LOC"]),
    
    (["Amazon", "delivers", "packages", "worldwide"],
     ["B-ORG", "O", "O", "O"]),
    
    (["Paris", "is", "the", "capital", "of", "France"],
     ["B-LOC", "O", "O", "O", "O", "B-LOC"]),
    
    (["Elon", "Musk", "founded", "Tesla", "and", "SpaceX"],
     ["B-PER", "I-PER", "O", "B-ORG", "O", "B-ORG"]),
    
    (["The", "United", "Nations", "is", "based", "in", "Geneva"],
     ["O", "B-ORG", "I-ORG", "O", "O", "O", "B-LOC"]),
]

# Combine and split positionally (no shuffling). additional_data is repeated
# 10x to give the model more gradient steps.
# NOTE(review): with this split, train = sample_data + 39 copies from
# additional_data and test = the remaining 11 copies, so every test sentence
# also appears in training. Test scores therefore measure memorisation, not
# generalisation; use held-out sentences for a meaningful evaluation.
all_data = sample_data + additional_data * 10
train_size = int(0.8 * len(all_data))
train_data = all_data[:train_size]
test_data = all_data[train_size:]

print(f"Train samples: {len(train_data)}")
print(f"Test samples: {len(test_data)}")
print("\nSample:")
for words, tags in train_data[:2]:
    print(f"  Words: {' '.join(words)}")
    print(f"  Tags:  {' '.join(tags)}")

In [None]:
# Build vocabularies
class Vocab:
    """Word/tag <-> index mappings for the NER data.

    Word ids 0 and 1 are reserved for "<PAD>" and "<UNK>"; tag id 0 is
    reserved for "<PAD>". ``build`` numbers the remaining entries in sorted
    order right after the reserved ids.
    """

    def __init__(self):
        self.word2idx = {"<PAD>": 0, "<UNK>": 1}
        self.idx2word = {0: "<PAD>", 1: "<UNK>"}
        self.tag2idx = {"<PAD>": 0}
        self.idx2tag = {0: "<PAD>"}

    def build(self, data):
        """Index every distinct word and tag in ``data`` and print a summary.

        Args:
            data: iterable of ``(words, tags)`` pairs of token / tag lists.
        """
        seen_words = set()
        seen_tags = set()
        for sentence, labels in data:
            seen_words |= set(sentence)
            seen_tags |= set(labels)

        first_word_id = 2  # ids 0/1 are <PAD>/<UNK>
        for offset, token in enumerate(sorted(seen_words)):
            self.word2idx[token] = first_word_id + offset
            self.idx2word[first_word_id + offset] = token

        first_tag_id = 1  # id 0 is <PAD>
        for offset, label in enumerate(sorted(seen_tags)):
            self.tag2idx[label] = first_tag_id + offset
            self.idx2tag[first_tag_id + offset] = label

        print(f"Vocabulary size: {len(self.word2idx)}")
        print(f"Tag set size: {len(self.tag2idx)}")
        print(f"Tags: {list(self.tag2idx.keys())}")

vocab = Vocab()
vocab.build(train_data)

# Append the CRF's virtual START/STOP tags after the real tags so they take
# the last two tag ids. They never occur in the gold data; the CRF only uses
# them as the implicit first/last state of every sequence.
START_TAG = "<START>"
STOP_TAG = "<STOP>"
vocab.tag2idx[START_TAG] = len(vocab.tag2idx)
vocab.tag2idx[STOP_TAG] = len(vocab.tag2idx)
vocab.idx2tag[vocab.tag2idx[START_TAG]] = START_TAG
vocab.idx2tag[vocab.tag2idx[STOP_TAG]] = STOP_TAG

print(f"\nFinal tag set (with START/STOP): {list(vocab.tag2idx.keys())}")

In [None]:
# Dataset class
class NERDataset(Dataset):
    """Map ``(words, tags)`` sentence pairs to index tensors.

    Unknown words fall back to the ``<UNK>`` id. Every tag must be present in
    ``vocab.tag2idx``; otherwise a ``KeyError`` is raised.

    Args:
        data: list of ``(words, tags)`` pairs of parallel string lists.
        vocab: object exposing ``word2idx`` and ``tag2idx`` dictionaries.
    """

    def __init__(self, data, vocab):
        self.data = data
        self.vocab = vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(word_ids, tag_ids)`` as 1-D ``torch.long`` tensors."""
        words, tags = self.data[idx]
        unk_id = self.vocab.word2idx["<UNK>"]

        word_ids = [self.vocab.word2idx.get(w, unk_id) for w in words]
        tag_ids = [self.vocab.tag2idx[t] for t in tags]

        # Explicit dtype: torch.tensor([]) would otherwise default to float32,
        # which nn.Embedding cannot index with.
        return (
            torch.tensor(word_ids, dtype=torch.long),
            torch.tensor(tag_ids, dtype=torch.long),
        )

def collate_fn(batch):
    """Pad a batch of ``(word_ids, tag_ids)`` tensor pairs to a common length.

    Returns:
        ``(words, tags, lengths)``: two ``[batch, max_len]`` tensors
        right-padded with 0, plus the original length of each sequence.
    """
    sentences, labels = zip(*batch)
    seq_lengths = torch.tensor(list(map(len, sentences)))

    padded_words, padded_tags = (
        pad_sequence(group, batch_first=True, padding_value=0)
        for group in (sentences, labels)
    )
    return padded_words, padded_tags, seq_lengths

# Wrap each split in a Dataset and batch it with the padding collate_fn.
# Only the training batches are shuffled; evaluation order stays fixed.
train_dataset = NERDataset(train_data, vocab)
test_dataset = NERDataset(test_data, vocab)

train_loader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn
)
test_loader = DataLoader(
    test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn
)

print("Train batches: %d" % len(train_loader))
print("Test batches: %d" % len(test_loader))

## Part 3: CRF Layer Implementation

In [None]:
class CRF(nn.Module):
    """
    Linear-chain Conditional Random Field layer.

    ``transitions[i, j]`` is the learned score of moving from tag ``i`` to
    tag ``j``. Every sequence implicitly begins in the virtual START tag and
    ends in the virtual STOP tag; both are ordinary indices in the tag space.

    Args:
        num_tags: number of tags, including START and STOP.
        start_tag_idx: index of the virtual START tag.
        stop_tag_idx: index of the virtual STOP tag.
    """

    # Score used to (effectively) forbid a transition or an initial state.
    IMPOSSIBLE = -10000.0

    def __init__(self, num_tags, start_tag_idx, stop_tag_idx):
        super(CRF, self).__init__()
        self.num_tags = num_tags
        self.start_tag_idx = start_tag_idx
        self.stop_tag_idx = stop_tag_idx

        # Transition parameters: transitions[i, j] = score of transitioning from tag i to tag j
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))

        # Nothing may transition INTO START (its column) and nothing may leave
        # STOP (its row). Penalising START's row / STOP's column instead would
        # hit every legal path's first and last transition while leaving the
        # forbidden ones unconstrained.
        self.transitions.data[:, start_tag_idx] = self.IMPOSSIBLE
        self.transitions.data[stop_tag_idx, :] = self.IMPOSSIBLE

    def forward(self, emissions, tags, mask):
        """
        Compute the mean negative log-likelihood of the gold tags.

        Args:
            emissions: [batch_size, seq_len, num_tags] emission scores.
            tags: [batch_size, seq_len] gold tag indices (values at padded
                positions are ignored).
            mask: [batch_size, seq_len] 1 for real tokens, 0 for padding
                (float, int or bool). Padding must be at the end and every
                sequence needs at least one real token.

        Returns:
            Scalar tensor: batch mean of ``log Z - score(gold path)``.
        """
        # The helpers blend with `mask` and `1 - mask`, which needs a float mask.
        mask = mask.to(emissions.dtype)
        gold_score = self._score_sentence(emissions, tags, mask)
        forward_score = self._forward_algorithm(emissions, mask)
        nll = forward_score - gold_score
        return nll.mean()

    def _score_sentence(self, emissions, tags, mask):
        """
        Score of the gold path per sequence, shape [batch_size].

        score = trans(START -> y_0) + sum_t emit_t(y_t)
                + sum_t trans(y_t -> y_{t+1}) + trans(y_last -> STOP),
        with contributions at padded positions zeroed by ``mask``.
        """
        batch_size, seq_len = tags.shape
        score = emissions.new_zeros(batch_size)

        score += self.transitions[self.start_tag_idx, tags[:, 0]]

        for i in range(seq_len):
            current_tags = tags[:, i]
            emission_score = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze(1)
            score += emission_score * mask[:, i]

            if i < seq_len - 1:
                next_tags = tags[:, i + 1]
                # Count the transition only if the next position is a real token.
                score += self.transitions[current_tags, next_tags] * mask[:, i + 1]

        # STOP transition from the last real token of each sequence.
        last_tag_indices = mask.sum(1).long() - 1
        last_tags = tags.gather(1, last_tag_indices.unsqueeze(1)).squeeze(1)
        score += self.transitions[last_tags, self.stop_tag_idx]
        return score

    def _forward_algorithm(self, emissions, mask):
        """
        Log partition function log Z over all tag paths, shape [batch_size].
        """
        batch_size, seq_len, num_tags = emissions.shape

        # alpha[b, j]: log-sum of scores of all partial paths ending in tag j.
        # Only START is a valid initial state.
        alpha = emissions.new_full((batch_size, num_tags), self.IMPOSSIBLE)
        alpha[:, self.start_tag_idx] = 0.0

        trans = self.transitions.unsqueeze(0)  # [1, from_tag, to_tag]
        for i in range(seq_len):
            # [batch, from, 1] + [1, from, to] + [batch, 1, to] -> [batch, from, to]
            scores = alpha.unsqueeze(2) + trans + emissions[:, i].unsqueeze(1)
            next_alpha = torch.logsumexp(scores, dim=1)  # [batch, to_tag]

            # Padded positions carry the previous alpha forward unchanged.
            m = mask[:, i].unsqueeze(1)
            alpha = next_alpha * m + alpha * (1 - m)

        alpha = alpha + self.transitions[:, self.stop_tag_idx].unsqueeze(0)
        return torch.logsumexp(alpha, dim=1)

    def decode(self, emissions, mask):
        """
        Viterbi decoding of the highest-scoring tag path.

        Args:
            emissions: [batch_size, seq_len, num_tags]
            mask: [batch_size, seq_len] 1 for real tokens, 0 for trailing
                padding (float, int or bool).

        Returns:
            best_paths: [batch_size, seq_len] long tensor; positions past each
            sequence's length are filled with 0.
        """
        mask = mask.to(emissions.dtype)
        batch_size, seq_len, num_tags = emissions.shape

        viterbi = emissions.new_full((batch_size, num_tags), self.IMPOSSIBLE)
        viterbi[:, self.start_tag_idx] = 0.0

        # backpointers[i][b, j]: best tag at position i-1 given tag j at i.
        backpointers = []
        trans = self.transitions.unsqueeze(0)  # [1, from_tag, to_tag]
        for i in range(seq_len):
            scores = viterbi.unsqueeze(2) + trans + emissions[:, i].unsqueeze(1)
            next_viterbi, best_prev = scores.max(dim=1)  # [batch, to_tag]
            backpointers.append(best_prev)

            # Padded positions keep the previous scores.
            m = mask[:, i].unsqueeze(1)
            viterbi = next_viterbi * m + viterbi * (1 - m)

        viterbi = viterbi + self.transitions[:, self.stop_tag_idx].unsqueeze(0)
        _, best_last_tags = viterbi.max(dim=1)  # [batch]

        best_paths = []
        for b in range(batch_size):
            path = [best_last_tags[b].item()]
            for i in range(seq_len - 1, 0, -1):
                if mask[b, i] == 0:
                    continue  # padded position: nothing to recover
                path.append(backpointers[i][b, path[-1]].item())
            path.reverse()

            # Pad to seq_len with tag id 0.
            path.extend([0] * (seq_len - len(path)))
            best_paths.append(path)

        return torch.tensor(best_paths, device=emissions.device)

## Part 4: BiLSTM-CRF Model

In [None]:
class BiLSTM_CRF(nn.Module):
    """
    BiLSTM-CRF model for sequence labeling.

    Pipeline: word ids -> embeddings -> bidirectional LSTM -> linear emission
    scores -> CRF (NLL loss for training, Viterbi decoding for prediction).

    Args:
        vocab_size: number of word ids (id 0 is the padding id).
        tag_size: number of tags, including the CRF START/STOP tags.
        embedding_dim: word embedding size.
        hidden_dim: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        dropout: dropout applied to embeddings and LSTM outputs (and between
            LSTM layers when ``num_layers > 1``).
        start_tag_idx, stop_tag_idx: indices of the CRF START/STOP tags.
            Required despite the ``None`` defaults.

    Raises:
        ValueError: if ``start_tag_idx`` or ``stop_tag_idx`` is None.
    """

    def __init__(self, vocab_size, tag_size, embedding_dim=100, hidden_dim=128, 
                 num_layers=1, dropout=0.5, start_tag_idx=None, stop_tag_idx=None):
        super(BiLSTM_CRF, self).__init__()
        if start_tag_idx is None or stop_tag_idx is None:
            # Indexing the transition matrix with None would silently overwrite
            # every entry during the CRF's initialisation.
            raise ValueError("start_tag_idx and stop_tag_idx must be provided")

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        self.bilstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM applies its dropout only between stacked layers.
            dropout=dropout if num_layers > 1 else 0
        )

        self.dropout = nn.Dropout(dropout)

        # Emission layer: concatenated forward/backward states -> tag scores
        self.hidden2tag = nn.Linear(hidden_dim * 2, tag_size)

        self.crf = CRF(tag_size, start_tag_idx, stop_tag_idx)

    def forward(self, sentences, tags, lengths):
        """
        Training forward pass.

        Args:
            sentences: [batch_size, seq_len] word ids.
            tags: [batch_size, seq_len] gold tag ids.
            lengths: [batch_size] true (unpadded) lengths.

        Returns:
            Scalar CRF negative log-likelihood (batch mean).
        """
        emissions = self._get_emissions(sentences, lengths)
        mask = self._create_mask(sentences, lengths)
        return self.crf(emissions, tags, mask)

    def _get_emissions(self, sentences, lengths):
        """Return emission scores of shape [batch, seq_len, tag_size]."""
        embeds = self.dropout(self.embedding(sentences))  # [batch, seq_len, embed_dim]

        # Packing lets the LSTM skip padded positions in both directions.
        packed_embeds = pack_padded_sequence(
            embeds, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_lstm_out, _ = self.bilstm(packed_embeds)

        # total_length keeps the output aligned with the padded input (and the
        # mask) even when the input is padded past the longest sequence.
        lstm_out, _ = pad_packed_sequence(
            packed_lstm_out, batch_first=True, total_length=sentences.size(1)
        )

        lstm_out = self.dropout(lstm_out)
        return self.hidden2tag(lstm_out)  # [batch, seq_len, num_tags]

    def _create_mask(self, sentences, lengths):
        """Return a float mask [batch, seq_len]: 1.0 for real tokens, 0.0 for padding."""
        seq_len = sentences.size(1)
        positions = torch.arange(seq_len, device=sentences.device)
        mask = positions.unsqueeze(0) < lengths.to(sentences.device).unsqueeze(1)
        return mask.float()

    def predict(self, sentences, lengths):
        """
        Predict tags with Viterbi decoding.

        Args:
            sentences: [batch_size, seq_len] word ids.
            lengths: [batch_size] true lengths.

        Returns:
            [batch_size, seq_len] predicted tag ids (0 past each length).
        """
        emissions = self._get_emissions(sentences, lengths)
        mask = self._create_mask(sentences, lengths)
        return self.crf.decode(emissions, mask)

# Create model. TAG_SIZE counts the CRF START/STOP tags appended to the vocab.
VOCAB_SIZE = len(vocab.word2idx)
TAG_SIZE = len(vocab.tag2idx)
EMBEDDING_DIM = 100
HIDDEN_DIM = 128
NUM_LAYERS = 1
DROPOUT = 0.3

START_TAG_IDX = vocab.tag2idx[START_TAG]
STOP_TAG_IDX = vocab.tag2idx[STOP_TAG]

model = BiLSTM_CRF(
    VOCAB_SIZE, TAG_SIZE,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
    start_tag_idx=START_TAG_IDX,
    stop_tag_idx=STOP_TAG_IDX,
).to(device)

print("BiLSTM-CRF Model")
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

## Part 5: Training

In [None]:
# Training functions
def train_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Each batch is ``(sentences, tags, lengths)``, and ``model(sentences, tags,
    lengths)`` must return a scalar loss. The global gradient norm is clipped
    to 5.0 before every optimizer step.
    """
    model.train()
    running = 0

    for batch in tqdm(dataloader, desc="Training"):
        sentences, tags, lengths = (part.to(device) for part in batch)

        optimizer.zero_grad()
        batch_loss = model(sentences, tags, lengths)
        batch_loss.backward()

        # Clip the global gradient norm before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        running += batch_loss.item()

    return running / len(dataloader)

def evaluate(model, dataloader, vocab, device):
    """Decode every batch and compute entity-level precision/recall/F1 (seqeval).

    Predicted tags are compared position by position with the gold tags over
    each sentence's true length. If Viterbi outputs one of the helper tags
    (``<PAD>``, START, STOP), it is scored as ``"O"`` rather than dropped.
    Dropping it would shift later predictions and break alignment with the
    gold sequence.

    Returns:
        ``(precision, recall, f1, all_true, all_preds)``. The last two are
        lists of tag-name lists, one per non-empty sentence, with equal
        lengths pairwise.
    """
    model.eval()
    special_tags = {'<PAD>', START_TAG, STOP_TAG}
    all_preds = []
    all_true = []

    with torch.no_grad():
        for sentences, tags, lengths in tqdm(dataloader, desc="Evaluating"):
            sentences = sentences.to(device)
            lengths = lengths.to(device)

            pred_tags = model.predict(sentences, lengths).cpu()
            tags = tags.cpu()

            for i, length in enumerate(lengths.tolist()):
                true_names = [vocab.idx2tag[t] for t in tags[i, :length].tolist()]
                pred_names = [vocab.idx2tag[t] for t in pred_tags[i, :length].tolist()]

                # Keep gold/pred aligned: skip positions whose gold tag is a
                # helper tag, and map helper-tag predictions to "O".
                pairs = [
                    (gold, 'O' if pred in special_tags else pred)
                    for gold, pred in zip(true_names, pred_names)
                    if gold not in special_tags
                ]

                if pairs:
                    all_true.append([gold for gold, _ in pairs])
                    all_preds.append([pred for _, pred in pairs])

    precision = precision_score(all_true, all_preds)
    recall = recall_score(all_true, all_preds)
    f1 = f1_score(all_true, all_preds)

    return precision, recall, f1, all_true, all_preds

# Training loop
N_EPOCHS = 20
LEARNING_RATE = 0.01

optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

history = {'train_loss': [], 'precision': [], 'recall': [], 'f1': []}
# Start below any real F1 (seqeval scores lie in [0, 1]) so the first epoch
# always writes a checkpoint. With 0 here, a run whose F1 never rises above 0
# would never create 'bilstm_crf_best.pt', and reloading it later would fail.
best_f1 = -1.0

print("Training BiLSTM-CRF")

for epoch in range(N_EPOCHS):
    print(f"\nEpoch {epoch+1}/{N_EPOCHS}")

    train_loss = train_epoch(model, train_loader, optimizer, device)
    precision, recall, f1, _, _ = evaluate(model, test_loader, vocab, device)

    history['train_loss'].append(train_loss)
    history['precision'].append(precision)
    history['recall'].append(recall)
    history['f1'].append(f1)

    print(f"\nTrain Loss: {train_loss:.4f}")
    print(f"Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f}")

    # Keep the checkpoint with the best F1 so far.
    # NOTE(review): selection uses the test set; use a separate dev split for
    # an unbiased final score.
    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), 'bilstm_crf_best.pt')
        print(f" Saved best model (F1: {best_f1:.4f})")

print(f"Training Complete! Best F1: {best_f1:.4f}")

## Part 6: Visualization and Analysis

In [None]:
# Plot training history: loss (left) and evaluation metrics (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Use the number of recorded epochs rather than N_EPOCHS, so the plot still
# works if training was interrupted or N_EPOCHS was edited afterwards.
epochs = range(1, len(history['train_loss']) + 1)

# Loss plot
ax = axes[0]
ax.plot(epochs, history['train_loss'], 'b-', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('Training Loss', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)

# Metrics plot
ax = axes[1]
ax.plot(epochs, history['precision'], 'r-', label='Precision', linewidth=2)
ax.plot(epochs, history['recall'], 'g-', label='Recall', linewidth=2)
ax.plot(epochs, history['f1'], 'b-', label='F1-Score', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_title('Evaluation Metrics', fontsize=14, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
# Uncomment to also save the figure to disk:
# plt.savefig('bilstm_crf_history.png', dpi=150)
plt.show()

In [None]:
# Detailed classification report for the best checkpoint.
# map_location lets a checkpoint written on a GPU be reloaded on a CPU-only
# machine (and vice versa).
model.load_state_dict(torch.load('bilstm_crf_best.pt', map_location=device))
_, _, _, all_true, all_preds = evaluate(model, test_loader, vocab, device)

print("Classification Report")
print(seq_report(all_true, all_preds))

# Sample predictions
def predict_sentence(model, sentence, vocab, device):
    """Predict NER tags for a whitespace-tokenised sentence.

    Args:
        model: object with ``eval()`` and ``predict(sentences, lengths)``,
            where ``predict`` returns tag ids of shape [batch, seq_len].
        sentence: raw text, split into tokens on whitespace.
        vocab: object with ``word2idx`` (containing "<UNK>") and ``idx2tag``.
        device: device for the input tensors.

    Returns:
        List of ``(word, tag_name)`` pairs; empty for a blank sentence.
    """
    model.eval()

    words = sentence.split()
    if not words:
        # A zero-length sequence cannot be packed for the LSTM; nothing to tag.
        return []

    unk_id = vocab.word2idx["<UNK>"]
    word_ids = [vocab.word2idx.get(w, unk_id) for w in words]

    # Batch of one sentence; explicit long dtype for the embedding lookup.
    sentence_tensor = torch.tensor([word_ids], dtype=torch.long, device=device)
    length_tensor = torch.tensor([len(word_ids)], device=device)

    with torch.no_grad():
        pred_tags = model.predict(sentence_tensor, length_tensor)

    pred_tag_names = [vocab.idx2tag[t] for t in pred_tags[0, :len(words)].tolist()]
    return list(zip(words, pred_tag_names))

# Tag a few example sentences with the trained model and print the entities found.
print("Sample Predictions")

test_sentences = [
    "Apple is looking at buying U.K. startup",
    "Tim Cook is the CEO of Apple Inc.",
    "Google was founded in California",
    "John Smith works at Microsoft in Seattle"
]

for sent in test_sentences:
    result = predict_sentence(model, sent, vocab, device)
    print(f"\nSentence: {sent}")
    print("Predictions:")
    entities = [(word, tag) for word, tag in result if tag != 'O']
    for word, tag in entities:
        print(f"  {word:<15} -> {tag}")
    if not entities:
        print("  (No entities detected)")