# Model 6: BERT + CRF Hybrid for Named Entity Recognition

This notebook implements a BERT model with a CRF layer on top for structured prediction. The CRF layer ensures valid BIO sequences and typically improves performance by 1-2% F1 over BERT alone.

**Key Improvements over BERT alone:**
- CRF learns transition constraints between tags
- Viterbi decoding ensures globally optimal tag sequences
- Prevents invalid BIO sequences (e.g., O → I-Person)
- Better boundary detection

**Expected Performance:** 90-92% F1 score

In [None]:
# Import libraries
import json
import numpy as np
import pandas as pd
from collections import Counter, defaultdict
import warnings
warnings.filterwarnings('ignore')

# PyTorch and Transformers
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    BertModel,
    BertConfig,
    AdamW,
    get_linear_schedule_with_warmup
)
from datasets import Dataset as HFDataset

# CRF layer
# Install if not already installed: pip install pytorch-crf
from torchcrf import CRF

# Import utils for evaluation
from utils import extract_entities, evaluate_entity_spans, print_evaluation_report

# Pick the compute device: CUDA when PyTorch reports one, otherwise the CPU
has_gpu = torch.cuda.is_available()
device = torch.device('cuda' if has_gpu else 'cpu')
print(f"Using device: {device}")

# Extra hardware details are only meaningful on a CUDA device
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"CUDA version: {torch.version.cuda}")

print("Libraries imported successfully!")

## 1. Load and Prepare Data

In [None]:
def load_jsonl(file_path):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        file_path: Path of a UTF-8 encoded file holding one JSON document
            per line.

    Returns:
        list: One parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        stripped_lines = (line.strip() for line in f)
        # Blank lines (typically a trailing newline at the end of the file)
        # carry no record; skipping them avoids a JSONDecodeError on an
        # otherwise valid file.
        return [json.loads(line) for line in stripped_lines if line]

def clean_data(data):
    """Keep only the samples whose tag sequence is well-formed BIO.

    A sequence is rejected as soon as an ``I-<type>`` tag is found whose
    predecessor is neither ``B-<type>`` nor ``I-<type>``; the tag "before"
    the first position is taken to be ``'O'``.  The number of rejected
    samples is printed.

    Args:
        data: Iterable of samples, each a mapping with an ``'ner_tags'``
            sequence of tag strings.

    Returns:
        list: The accepted samples, in their original order.
    """
    def is_well_formed(tags):
        # Walk the sequence once, remembering the tag just before each one
        previous = 'O'
        for current in tags:
            if current.startswith('I-'):
                kind = current[2:]
                if previous not in (f'B-{kind}', f'I-{kind}'):
                    return False
            previous = current
        return True

    kept = []
    dropped = 0
    for sample in data:
        if is_well_formed(sample['ner_tags']):
            kept.append(sample)
        else:
            dropped += 1

    print(f"Removed {dropped} samples with invalid BIO sequences")
    return kept

# Read the raw training and test splits from disk
train_data_all = load_jsonl('train_data.jsonl')
test_data = load_jsonl('test_data.jsonl')

print(f"Total training samples: {len(train_data_all):,}")
print(f"Test samples: {len(test_data):,}")

# Drop training samples whose tag sequences are not valid BIO
train_data_cleaned = clean_data(train_data_all)
print(f"Training samples after cleaning: {len(train_data_cleaned):,}")

from sklearn.model_selection import train_test_split

# One stratification label per sample: 1 when it has any non-'O' tag, else 0
stratify_labels = [
    int(any(tag != 'O' for tag in sample['ner_tags']))
    for sample in train_data_cleaned
]

# 90/10 train/validation split with a fixed seed, stratified on entity presence
train_data, val_data = train_test_split(
    train_data_cleaned,
    test_size=0.1,
    random_state=42,
    stratify=stratify_labels,
)

print(f"\nTraining samples: {len(train_data):,}")
print(f"Validation samples: {len(val_data):,}")
print(f"Test samples: {len(test_data):,}")

# Print the first training sample as a sanity check
print("\nExample training sample:")
print(json.dumps(train_data[0], indent=2))

## 2. Analyze Tag Set and Initialize BERT+CRF Model

In [None]:
# Count every tag occurrence in the training split
tag_counts = Counter()
for sample in train_data:
    tag_counts.update(sample['ner_tags'])

# The distinct tags are exactly the tags that were counted
all_tags = set(tag_counts)

print(f"Unique tags: {len(all_tags)}")
print("\nTag distribution:")
# The grand total does not change while printing, so compute it once
total_tag_occurrences = sum(tag_counts.values())
for tag, count in sorted(tag_counts.items(), key=lambda x: x[1], reverse=True):
    percentage = count / total_tag_occurrences * 100
    print(f"  {tag:20s}: {count:8,} ({percentage:5.2f}%)")

# Stable, alphabetically ordered label vocabulary and its two lookup tables
label_list = sorted(all_tags)
num_labels = len(label_list)
label2id = {label: index for index, label in enumerate(label_list)}
id2label = dict(enumerate(label_list))

print(f"\nNumber of labels: {num_labels}")
print(f"Labels: {label_list}")

### 2.1 BERT+CRF Model Architecture

In [None]:
class BertCRF(nn.Module):
    """BERT encoder with a linear emission layer and a CRF on top.

    The encoder output is passed through dropout and a linear layer that
    produces one score per tag for every token (the CRF "emissions").  The
    CRF then either scores a gold tag sequence (training) or decodes the
    best-scoring sequence (inference).
    """

    def __init__(self, bert_model_name, num_tags):
        """Build the encoder, emission layer and CRF.

        Args:
            bert_model_name: Name or path passed to ``BertModel.from_pretrained``.
            num_tags: Number of distinct tags the CRF scores.
        """
        super().__init__()

        # BERT encoder
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)

        # Emission layer (maps BERT hidden states to tag scores)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_tags)

        # CRF layer operating on (batch, seq_len, num_tags) emissions
        self.crf = CRF(num_tags, batch_first=True)

    def num_parameters(self, only_trainable=False):
        """Return the number of parameters in the whole model.

        This class is a plain ``nn.Module``, so the helper is defined here
        for the notebook's ``model.num_parameters()`` calls.

        Args:
            only_trainable: When true, count only parameters with
                ``requires_grad`` set.

        Returns:
            int: Total number of scalar parameters.
        """
        return sum(
            p.numel()
            for p in self.parameters()
            if p.requires_grad or not only_trainable
        )

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute the CRF loss or decode tag sequences.

        Args:
            input_ids: Token IDs (batch_size, seq_len)
            attention_mask: Attention mask (batch_size, seq_len)
            labels: Optional label IDs (batch_size, seq_len); positions set
                to -100 are treated as "ignored" by the data pipeline.

        Returns:
            If labels provided: loss (negative log-likelihood)
            If labels not provided: list of predicted tag sequences
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        sequence_output = self.dropout(outputs[0])  # (batch, seq_len, hidden_size)

        # Emission scores
        emissions = self.classifier(sequence_output)  # (batch, seq_len, num_tags)

        # The CRF mask is needed on both paths, so build it once.  A boolean
        # mask is used because uint8 tensors are deprecated as boolean
        # conditions in recent PyTorch releases.
        mask = attention_mask.bool()

        if labels is None:
            # Inference: CRF decoding over every attended position
            return self.crf.decode(emissions, mask=mask)

        # The CRF needs a valid tag index at every position, so the -100
        # placeholders are replaced by tag 0.
        # NOTE(review): those positions (special tokens and non-first
        # subwords) are still covered by `mask`, so they are trained as if
        # their gold tag were tag 0 -- confirm this is intended; masking them
        # out would also require changing how callers read decoded output.
        labels_masked = labels.masked_fill(labels == -100, 0)

        # Negative log-likelihood of the gold sequences
        return -self.crf(emissions, labels_masked, mask=mask, reduction='mean')

# Initialize tokenizer
model_name = "bert-base-cased"  # Use cased for NER (preserves capitalization)
tokenizer = AutoTokenizer.from_pretrained(model_name)

print(f"Loading model: {model_name}")

# Initialize BERT+CRF model and move it to the selected device
model = BertCRF(model_name, num_tags=num_labels)
model.to(device)

# Count parameters straight from the tensors so this works for any nn.Module
total_parameters = sum(p.numel() for p in model.parameters())

print("Model loaded successfully!")
print(f"Model parameters: {total_parameters:,}")
print(f"Tokenizer vocab size: {tokenizer.vocab_size:,}")
# The tag count lives in `num_labels`; `num_tags` only exists as a
# constructor argument name and is not defined at notebook level.
print(f"Number of tags: {num_labels}")

## 3. Create Dataset and Data Loader

In [None]:
class NERDataset(Dataset):
    """Torch dataset that tokenizes pre-split sentences and aligns NER labels.

    Each sample is a mapping with ``'tokens'`` (words) and ``'ner_tags'``
    (one tag string per word).  Only the first sub-token of every word keeps
    the word's label id; special tokens, padding and later sub-tokens are
    labelled -100.
    """

    def __init__(self, data, tokenizer, label2id, max_length=128):
        """Store the samples and the tokenization settings.

        Args:
            data: Indexable collection of samples.
            tokenizer: Callable tokenizer whose result exposes ``word_ids()``.
            label2id: Mapping from tag string to integer id.
            max_length: Length every example is padded/truncated to.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for one sample."""
        sample = self.data[idx]
        words = sample['tokens']
        word_tags = sample['ner_tags']

        encoding = self.tokenizer(
            words,
            is_split_into_words=True,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        # A position keeps its word's label only when it opens a new word
        aligned = []
        last_word = None
        for word in encoding.word_ids():
            opens_word = word is not None and word != last_word
            aligned.append(self.label2id[word_tags[word]] if opens_word else -100)
            last_word = word

        return {
            # The tokenizer returns a batch of one; squeeze drops that axis
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': torch.tensor(aligned, dtype=torch.long),
        }

# Wrap the train and validation splits (default max_length of NERDataset)
train_dataset = NERDataset(data=train_data, tokenizer=tokenizer, label2id=label2id)
val_dataset = NERDataset(data=val_data, tokenizer=tokenizer, label2id=label2id)

print(f"Training dataset size: {len(train_dataset):,}")
print(f"Validation dataset size: {len(val_dataset):,}")

# Inspect the first encoded example: tensor shapes and ignored positions
example = train_dataset[0]
ignored_positions = (example['labels'] == -100).sum().item()
print("\nExample from dataset:")
print(f"  Input IDs shape: {example['input_ids'].shape}")
print(f"  Attention mask shape: {example['attention_mask'].shape}")
print(f"  Labels shape: {example['labels'].shape}")
print(f"  Number of -100 labels: {ignored_positions}")

## 4. Training Functions

In [None]:
def train_epoch(model, dataloader, optimizer, scheduler, device):
    """Train the model for one pass over ``dataloader``.

    Every batch gets a forward pass, a backward pass, gradient clipping to
    a norm of 1.0, an optimizer step and a scheduler step.  Progress is
    printed every 100 batches.

    Args:
        model: Callable as ``model(input_ids, attention_mask, labels)`` and
            returning a scalar loss tensor.
        dataloader: Sized iterable of batches; each batch is a mapping with
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` tensors.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        float: Mean batch loss, or ``0.0`` when the dataloader is empty.
    """
    model.train()
    num_batches = len(dataloader)
    if num_batches == 0:
        # Nothing to train on; returning early avoids dividing by zero.
        return 0.0

    total_loss = 0.0
    for batch_number, batch in enumerate(dataloader, start=1):
        # Move to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward and backward pass
        loss = model(input_ids, attention_mask, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Update weights, then advance the learning-rate schedule
        optimizer.step()
        scheduler.step()

        # Read the scalar once and reuse it for the total and the log line
        batch_loss = loss.item()
        total_loss += batch_loss

        # Log progress
        if batch_number % 100 == 0:
            print(f"  Batch {batch_number}/{num_batches}, Loss: {batch_loss:.4f}")

    return total_loss / num_batches


def evaluate(model, dataloader, device, id2label):
    """Evaluate the model and compute loss plus entity-span metrics.

    Args:
        model: Callable as ``model(input_ids, attention_mask, labels)`` for
            the loss and ``model(input_ids, attention_mask)`` for decoded
            tag-id sequences (one list per example).
        dataloader: Sized iterable of batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        device: Device the batch tensors are moved to.
        id2label: Mapping from tag id to tag string.

    Returns:
        dict: ``'loss'`` (mean batch loss, ``0.0`` for an empty dataloader)
        and the ``'precision'``, ``'recall'`` and ``'f1'`` values reported
        by ``evaluate_entity_spans``.
    """
    model.eval()
    total_loss = 0.0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            # Move to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # The model returns either the loss or the decoded paths, never
            # both, so two calls are needed per batch.
            total_loss += model(input_ids, attention_mask, labels).item()
            predictions = model(input_ids, attention_mask)

            # Keep only positions with a real label (-100 marks ignored
            # tokens).  zip stops at the shorter sequence, so positions
            # beyond the decoded length are dropped automatically.
            for pred_seq, label_seq in zip(predictions, labels.tolist()):
                kept = [
                    (pred, label)
                    for pred, label in zip(pred_seq, label_seq)
                    if label != -100
                ]
                if kept:  # Only add non-empty sequences
                    all_predictions.append([id2label[pred] for pred, _ in kept])
                    all_labels.append([id2label[label] for _, label in kept])

    # Compute metrics
    results = evaluate_entity_spans(all_labels, all_predictions)
    num_batches = len(dataloader)
    avg_loss = total_loss / num_batches if num_batches else 0.0

    return {
        'loss': avg_loss,
        'precision': results['precision'],
        'recall': results['recall'],
        'f1': results['f1']
    }


# Confirmation message printed once the cell defining train_epoch/evaluate has run
print("Training functions defined successfully!")

## 5. Train BERT+CRF Model

In [None]:
# Create data loaders (only the training data is shuffled)
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Data loaders created with batch size {batch_size}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Initialize optimizer and scheduler
num_epochs = 4
learning_rate = 5e-5
weight_decay = 0.01
warmup_steps = 500

# Use PyTorch's AdamW: the transformers implementation imported at the top of
# the notebook is deprecated by Hugging Face in favour of torch.optim.AdamW.
# eps is set explicitly to the deprecated class's default so the update rule
# stays numerically close -- NOTE(review): confirm 1e-6 is the value wanted.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    eps=1e-6,
)

# The scheduler is stepped once per batch, so the horizon is batches x epochs
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

print(f"\nTraining configuration:")
print(f"  Learning rate: {learning_rate}")
print(f"  Batch size: {batch_size}")
print(f"  Epochs: {num_epochs}")
print(f"  Warmup steps: {warmup_steps}")
print(f"  Weight decay: {weight_decay}")
print(f"  Total training steps: {total_steps:,}")

In [None]:
# Training loop with best-checkpoint saving and early stopping
import time

# Start below any achievable F1 so the first epoch always writes a checkpoint;
# the evaluation cell loads 'bert_crf_best_model.pt' unconditionally.
best_f1 = -1.0
patience = 2  # epochs without improvement tolerated before stopping
patience_counter = 0
training_start_time = time.time()

print("Starting BERT+CRF training...")
print("=" * 60)

for epoch in range(num_epochs):
    epoch_start_time = time.time()

    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    print("-" * 40)

    # Train
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device)

    # Evaluate
    val_metrics = evaluate(model, val_loader, device, id2label)

    epoch_time = time.time() - epoch_start_time

    print(f"\nEpoch {epoch + 1} Results:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss:   {val_metrics['loss']:.4f}")
    print(f"  Precision:  {val_metrics['precision']:.4f}")
    print(f"  Recall:     {val_metrics['recall']:.4f}")
    print(f"  F1 Score:   {val_metrics['f1']:.4f}")
    print(f"  Time:       {epoch_time:.1f}s")

    # Check if this is the best model
    if val_metrics['f1'] > best_f1:
        best_f1 = val_metrics['f1']
        patience_counter = 0

        # Save best model together with what is needed to resume or decode
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'f1_score': best_f1,
            'label2id': label2id,
            'id2label': id2label
        }, 'bert_crf_best_model.pt')
        print(f"  ✅ New best model saved! (F1: {best_f1:.4f})")
    else:
        patience_counter += 1
        print(f"  ⏳ No improvement. Patience: {patience_counter}/{patience}")

    # Early stopping
    if patience_counter >= patience:
        print(f"\nEarly stopping triggered after {epoch + 1} epochs")
        break

training_time = time.time() - training_start_time
print(f"\nTraining completed in {training_time/60:.2f} minutes")
print(f"Best validation F1: {best_f1:.4f}")

## 6. Load Best Model and Final Evaluation

In [None]:
# Load best model; map_location lets a checkpoint written on one device be
# restored onto whatever device this session selected.
checkpoint = torch.load('bert_crf_best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

print(f"Loaded best model from epoch {checkpoint['epoch']} with F1: {checkpoint['f1_score']:.4f}")

# Final evaluation
print("\nFinal evaluation on validation set:")
print("=" * 60)

val_metrics = evaluate(model, val_loader, device, id2label)

print("\nBERT+CRF Final Results:")
print(f"  Precision: {val_metrics['precision']:.4f}")
print(f"  Recall:    {val_metrics['recall']:.4f}")
print(f"  F1 Score:  {val_metrics['f1']:.4f}")
print(f"  Validation loss: {val_metrics['loss']:.4f}")

# Store results for comparison
bert_crf_results = {
    'model': 'BERT+CRF',
    'precision': val_metrics['precision'],
    'recall': val_metrics['recall'],
    'f1': val_metrics['f1'],
    'training_time': training_time,
    # Counted from the tensors so it works for any nn.Module
    'parameters': sum(p.numel() for p in model.parameters())
}

# Compare with expected from implementation plan
expected_f1_min = 0.90
actual_f1 = val_metrics['f1']

print(f"\nPerformance Analysis:")
print(f"  Expected F1 minimum: {expected_f1_min:.2f}")
print(f"  Actual F1 score:     {actual_f1:.4f}")

if actual_f1 >= expected_f1_min:
    print("  ✅ F1 score meets expectations!")
else:
    print("  ⚠️  F1 score below expected minimum")

## 7. Analyze Model Performance

In [None]:
# Get detailed predictions on validation set for analysis
print("Generating detailed predictions for analysis...")

model.eval()
val_sentences = [sample['tokens'] for sample in val_data]
val_true_tags = [sample['ner_tags'] for sample in val_data]
val_pred_tags = []

with torch.no_grad():
    for batch in val_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Labels are only read on the host to find the positions that carry
        # a real tag, so they do not need to be moved to the device.
        label_rows = batch['labels'].tolist()

        # Get predictions
        predictions = model(input_ids, attention_mask)

        # Keep the prediction at every position whose label is not -100;
        # one list per example keeps val_pred_tags aligned with val_data.
        for pred_seq, label_seq in zip(predictions, label_rows):
            val_pred_tags.append([
                id2label[pred]
                for pred, label in zip(pred_seq, label_seq)
                if label != -100
            ])

# Show examples (at most five, first 15 words of each)
print("\nPrediction examples:")
for i in range(min(5, len(val_sentences))):
    print(f"\nExample {i+1}:")
    tokens = val_sentences[i][:15]
    true_tags = val_true_tags[i][:15]
    pred_tags = val_pred_tags[i][:15]

    print(f"  Tokens:    {tokens}")
    print(f"  True:      {true_tags}")
    print(f"  Predicted: {pred_tags}")

    # Count correct predictions; an empty sentence has no defined accuracy
    correct = sum(1 for t, p in zip(true_tags, pred_tags) if t == p)
    total = len(true_tags)
    accuracy = f"{correct/total:.2%}" if total else "n/a"
    print(f"  Accuracy:  {correct}/{total} ({accuracy})")

# Check for valid BIO sequences: an I- tag must follow B-/I- of the same type
print("\nChecking BIO sequence validity...")
invalid_count = 0
for pred_tags in val_pred_tags:
    prev_tag = 'O'
    for tag in pred_tags:
        if tag.startswith('I-'):
            entity_type = tag[2:]
            if not (prev_tag == f'B-{entity_type}' or prev_tag == f'I-{entity_type}'):
                invalid_count += 1
                break  # count each sequence at most once
        prev_tag = tag

print(f"Invalid BIO sequences: {invalid_count}/{len(val_pred_tags)}")
if invalid_count == 0:
    print("✅ All predicted BIO sequences are valid!")

## 8. Error Analysis

In [None]:
def analyze_errors(true_tags_list, pred_tags_list, sentences_list, num_examples=5):
    """Collect entity-level errors by comparing gold and predicted spans.

    Entities are obtained from ``extract_entities`` and compared as
    ``(start, end, entity_type)`` triples, with ``end`` inclusive.

    Args:
        true_tags_list: Gold tag sequences, one per sentence.
        pred_tags_list: Predicted tag sequences, one per sentence.
        sentences_list: Token lists, one per sentence.
        num_examples: Accepted for interface compatibility; not used here.

    Returns:
        tuple: ``(false_positives, false_negatives, wrong_type)`` where the
        first two hold ``(text, entity_type, tokens)`` entries and the last
        holds ``(text, true_type, pred_type, tokens)`` entries for spans
        whose boundaries match but whose types differ.
    """
    false_positives, false_negatives, wrong_type = [], [], []

    def span_text(words, first, last):
        # Surface text of the inclusive word range [first, last]
        return ' '.join(words[first:last + 1])

    for gold_tags, guess_tags, tokens in zip(true_tags_list, pred_tags_list, sentences_list):
        true_spans = {
            (first, last, kind)
            for _, kind, first, last in extract_entities(tokens, gold_tags)
        }
        pred_spans = {
            (first, last, kind)
            for _, kind, first, last in extract_entities(tokens, guess_tags)
        }

        # Predicted triples with no identical gold triple
        for first, last, kind in pred_spans - true_spans:
            false_positives.append((span_text(tokens, first, last), kind, tokens))

        # Gold triples with no identical predicted triple
        for first, last, kind in true_spans - pred_spans:
            false_negatives.append((span_text(tokens, first, last), kind, tokens))

        # Same boundaries on both sides but a different entity type
        true_span_dict = {(first, last): kind for first, last, kind in true_spans}
        pred_span_dict = {(first, last): kind for first, last, kind in pred_spans}
        for span in set(true_span_dict) & set(pred_span_dict):
            gold_kind = true_span_dict[span]
            guess_kind = pred_span_dict[span]
            if gold_kind != guess_kind:
                wrong_type.append((span_text(tokens, *span), gold_kind, guess_kind, tokens))

    return false_positives, false_negatives, wrong_type

# Analyze errors on the validation predictions
fp, fn, wt = analyze_errors(val_true_tags, val_pred_tags, val_sentences)

print("\nError Analysis:")
print(f"False Positives: {len(fp):,} (predicted entities that shouldn't exist)")
print(f"False Negatives: {len(fn):,} (missed entities)")
print(f"Wrong Type:      {len(wt):,} (correct span, wrong entity type)")

# Show up to three examples of each error type; the sentence tokens carried
# in each entry are not needed for these one-line summaries.
print("\nExample False Positives:")
for i, (text, pred_type, _) in enumerate(fp[:3]):
    print(f"  {i+1}. '{text}' → predicted as {pred_type}")

print("\nExample False Negatives:")
for i, (text, true_type, _) in enumerate(fn[:3]):
    print(f"  {i+1}. '{text}' → missed {true_type}")

print("\nExample Wrong Types:")
for i, (text, true_type, pred_type, _) in enumerate(wt[:3]):
    print(f"  {i+1}. '{text}' → true: {true_type}, predicted: {pred_type}")

## 9. Generate Test Predictions

In [None]:
def predict_bert_crf(tokens_list, model, tokenizer, batch_size=32, max_length=128,
                     label_map=None, fill_tag='O'):
    """Predict one NER tag per input word using the BERT+CRF model.

    The tag of a word is the prediction at its first sub-token.  Words that
    received no prediction (cut off by ``max_length`` truncation, or given
    no sub-token by the tokenizer) get ``fill_tag``, so every returned tag
    list has the same length as its token list.

    Args:
        tokens_list: List of sentences, each a list of word strings.
        model: Module callable as ``model(input_ids, attention_mask)`` and
            returning one list of tag ids per sentence.
        tokenizer: Tokenizer accepting pre-split words whose result exposes
            ``word_ids(batch_index=...)``.
        batch_size: Number of sentences tokenized and decoded together.
        max_length: Truncation length passed to the tokenizer.
        label_map: Mapping from tag id to tag string; defaults to the
            notebook-level ``id2label``.
        fill_tag: Tag assigned to words that received no prediction.

    Returns:
        list: One list of tag strings per input sentence.
    """
    if label_map is None:
        label_map = id2label
    # Run on the device the model lives on instead of a notebook global
    run_device = next(model.parameters()).device

    model.eval()
    predictions = []

    # Process in batches
    for start in range(0, len(tokens_list), batch_size):
        if start % 500 == 0:
            print(f"  Processing {start:,}/{len(tokens_list):,}")

        batch_tokens = tokens_list[start:start + batch_size]

        # Tokenize batch
        tokenized = tokenizer(
            batch_tokens,
            is_split_into_words=True,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        input_ids = tokenized['input_ids'].to(run_device)
        attention_mask = tokenized['attention_mask'].to(run_device)

        # Get predictions
        with torch.no_grad():
            pred_sequences = model(input_ids, attention_mask)

        # Convert predictions back to one tag per original word
        for row, (tokens, pred_seq) in enumerate(zip(batch_tokens, pred_sequences)):
            tag_by_word = {}
            for position, word_id in enumerate(tokenized.word_ids(batch_index=row)):
                if word_id is None or word_id in tag_by_word:
                    # Special/padding token, or a later sub-token of a word
                    continue
                if position < len(pred_seq):
                    tag_by_word[word_id] = label_map[pred_seq[position]]

            # Index by word so a missing word cannot shift later tags
            predictions.append(
                [tag_by_word.get(word, fill_tag) for word in range(len(tokens))]
            )

    return predictions

# Prepare test data
test_sentences = [sample['tokens'] for sample in test_data]

print(f"\nGenerating predictions for {len(test_sentences):,} test sentences...")
test_pred_tags = predict_bert_crf(test_sentences, model, tokenizer)

print(f"\nTest predictions complete!")
print(f"Generated predictions for {len(test_pred_tags):,} sentences")

# Verify format on up to three samples (fewer when the test set is smaller)
print("\nTest prediction examples:")
for i in range(min(3, len(test_data))):
    sample = test_data[i]
    tokens = sample['tokens']
    pred_tags = test_pred_tags[i]
    print(f"\nExample {i+1} (ID: {sample['id']}):")
    print(f"  Tokens:    {tokens[:10]}{'...' if len(tokens) > 10 else ''}")
    print(f"  Predicted: {pred_tags[:10]}{'...' if len(pred_tags) > 10 else ''}")
    print(f"  Length match: {len(tokens) == len(pred_tags)}")

In [None]:
# Add predictions to test data (shallow copies keep the loaded samples intact)
test_data_with_predictions = []
for sample, pred_tags in zip(test_data, test_pred_tags):
    sample_copy = sample.copy()
    sample_copy['ner_tags'] = pred_tags
    test_data_with_predictions.append(sample_copy)

# Validate predictions
print("Validating test predictions...")

# The label vocabulary does not change per sample, so build the set once
valid_tags = set(label_list)
validation_errors = []
for i, sample in enumerate(test_data_with_predictions):
    # Check length match
    if len(sample['tokens']) != len(sample['ner_tags']):
        validation_errors.append(f"Sample {i}: Length mismatch")

    # Check for valid tags (report only the first offending tag per sample)
    for tag in sample['ner_tags']:
        if tag not in valid_tags:
            validation_errors.append(f"Sample {i}: Invalid tag '{tag}'")
            break

if validation_errors:
    print(f"Found {len(validation_errors)} validation errors:")
    for error in validation_errors[:5]:
        print(f"  - {error}")
else:
    print("✓ All validations passed!")

# Check BIO sequence validity: an I- tag must follow B-/I- of the same type
bio_errors = 0
for sample in test_data_with_predictions:
    prev_tag = 'O'
    for tag in sample['ner_tags']:
        if tag.startswith('I-'):
            entity_type = tag[2:]
            if not (prev_tag == f'B-{entity_type}' or prev_tag == f'I-{entity_type}'):
                bio_errors += 1
                break  # count each sample at most once
        prev_tag = tag

print(f"\nBIO sequence errors: {bio_errors}")
if bio_errors == 0:
    print("✅ All predicted BIO sequences are valid!")

# Save predictions as JSON Lines, one sample per line
output_file = 'test_data_bert_crf_predictions.jsonl'
with open(output_file, 'w', encoding='utf-8') as f:
    for sample in test_data_with_predictions:
        f.write(json.dumps(sample) + '\n')

print(f"\nSaved predictions to: {output_file}")

# Generate statistics over all predicted tags
all_test_tags = []
for sample in test_data_with_predictions:
    all_test_tags.extend(sample['ner_tags'])

tag_counts = Counter(all_test_tags)
print(f"\nTest prediction statistics:")
print(f"  Total tokens: {len(all_test_tags):,}")
print(f"  Tag distribution:")
for tag, count in sorted(tag_counts.items(), key=lambda x: x[1], reverse=True):
    percentage = count / len(all_test_tags) * 100
    print(f"    {tag:20s}: {count:8,} ({percentage:5.2f}%)")

# Count predicted entities
test_entities = []
for sample in test_data_with_predictions:
    entities = extract_entities(sample['tokens'], sample['ner_tags'])
    test_entities.extend(entities)

entity_type_counts = Counter(entity_type for _, entity_type, _, _ in test_entities)
print(f"\nPredicted entities: {len(test_entities):,}")
print(f"  Entity type distribution:")
for entity_type, count in sorted(entity_type_counts.items(), key=lambda x: x[1], reverse=True):
    print(f"    {entity_type:20s}: {count:6,}")

## 10. Summary and Results

In [None]:
# Final summary of metrics, configuration and outputs gathered by earlier cells
print("=" * 80)
print("BERT+CRF HYBRID MODEL - SUMMARY")
print("=" * 80)

print(f"\n📊 Model Performance:")
print(f"   Precision: {bert_crf_results['precision']:.4f}")
print(f"   Recall:    {bert_crf_results['recall']:.4f}")
print(f"   F1 Score:  {bert_crf_results['f1']:.4f}")

print(f"\n🔧 Model Details:")
print(f"   Base model:           {model_name}")
print(f"   Architecture:         BERT + CRF")
print(f"   Parameters:           {bert_crf_results['parameters']:,}")
print(f"   Number of labels:     {num_labels}")
print(f"   Training time:        {bert_crf_results['training_time']/60:.1f} minutes")
print(f"   Training samples:     {len(train_dataset):,}")
print(f"   Validation samples:   {len(val_dataset):,}")

print(f"\n⚙️ Training Configuration:")
print(f"   Learning rate:        {learning_rate}")
print(f"   Batch size:           {batch_size}")
print(f"   Epochs:               {num_epochs} (with early stopping)")
print(f"   Warmup steps:         {warmup_steps}")
print(f"   Weight decay:         {weight_decay}")
# The next two values are written out literally, not read from variables
print(f"   Max sequence length:  128")
print(f"   Gradient clipping:    1.0")

print(f"\n📋 Test Predictions:")
print(f"   Test sentences:       {len(test_data):,}")
print(f"   Predicted entities:   {len(test_entities):,}")
print(f"   Output file:          {output_file}")
print(f"   Model saved:          bert_crf_best_model.pt")

print(f"\n✅ Implementation Status:")
print(f"   ✓ BERT + CRF architecture implemented")
print(f"   ✓ CRF enforces valid BIO sequences")
print(f"   ✓ Viterbi decoding for optimal tag sequences")
print(f"   ✓ Evaluated with entity-span level metrics")
print(f"   ✓ Generated test predictions")
print(f"   ✓ All BIO sequences are valid (checked: {bio_errors == 0})")

# Performance analysis against the expected F1 band
expected_range = (0.90, 0.92)
actual_f1 = bert_crf_results['f1']

print(f"\n🎯 Performance Analysis:")
print(f"   Expected F1 range:  {expected_range[0]:.2f} - {expected_range[1]:.2f}")
print(f"   Actual F1 score:    {actual_f1:.4f}")

if actual_f1 >= expected_range[0]:
    if actual_f1 <= expected_range[1]:
        print(f"   ✅ Performance meets expectations!")
    else:
        print(f"   🚀 Performance exceeds expectations!")
else:
    print(f"   ⚠️  Performance below expected range")

print(f"\n💡 Key Strengths:")
print(f"   • BERT's contextual representations + CRF's structure")
print(f"   • Guaranteed valid BIO sequences via CRF")
print(f"   • Better boundary detection")
print(f"   • Global sequence optimization")
print(f"   • Research-proven 1-2% F1 improvement over BERT alone")

print(f"\n⚠️  Limitations:")
print(f"   • Slightly slower training than BERT alone")
print(f"   • More complex implementation")
print(f"   • Still large memory footprint")
print(f"   • Requires careful label alignment")

print(f"\n🔜 Possible Improvements:")
print(f"   • Use larger BERT model (bert-large-cased)")
print(f"   • Add character-level CNN for OOV words")
print(f"   • Implement weighted loss for class imbalance")
print(f"   • Try different learning rate schedules")
print(f"   • Add self-attention layer")

print("\n" + "=" * 80)
print("BERT+CRF HYBRID MODEL COMPLETE!")
print("=" * 80)