# M10: KB-NER (Knowledge-Based Named Entity Recognition)

## Overview
A simplified implementation of the knowledge-based NER system from Alibaba DAMO-NLP that won SemEval-2022 Task 11 (MultiCoNER).

**Winner**: 10 out of 13 tracks (2022)  
**Expected F1**: 83-85%  
**Award**: Best System Paper at SemEval 2022

## Architecture:
1. **Knowledge Base**: Wikipedia-based multilingual knowledge
2. **Retrieval**: Find related Wikipedia contexts for entities
3. **Augmentation**: Append context to input sentences
4. **Model**: XLM-RoBERTa + CRF with augmented context
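Steps 2-4 amount to a simple layout change before the encoder sees the text. A minimal sketch, assuming placeholder `[CLS]`/`[SEP]` marker strings and a hypothetical helper `build_augmented_input` (neither comes from the KB-NER codebase); the real pipeline works on subword IDs rather than whole words:

```python
def build_augmented_input(tokens, context, max_context_words=50):
    """Return (pieces, is_labeled) for the layout: [CLS] tokens [SEP] context [SEP].

    Hypothetical helper for illustration only; the marker strings are placeholders.
    """
    context_words = context.split()[:max_context_words]
    pieces = ["[CLS]"] + list(tokens) + ["[SEP]"]
    # Only the original sentence tokens receive NER labels
    is_labeled = [False] + [True] * len(tokens) + [False]
    if context_words:
        pieces += context_words + ["[SEP]"]
        is_labeled += [False] * (len(context_words) + 1)
    return pieces, is_labeled


pieces, is_labeled = build_augmented_input(
    ["Obama", "visited", "Paris"],
    "Paris is the capital of France",
)
print(pieces)
print(is_labeled)
```

The retrieved context is visible to the encoder but never labeled, so it can only influence predictions through attention.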

## GPU Efficiency Strategy:
- **Phase 1 (CPU)**: Build knowledge base, retrieve contexts (1-2 hours)
- **Phase 3 (GPU)**: Train model (1.5 hours)
- **Total GPU**: ~1.5 hours only!

## Reference:
- Paper: https://arxiv.org/abs/2203.00545
- GitHub: https://github.com/Alibaba-NLP/KB-NER

## Installation

In [25]:
# Install required packages
# Let pip resolve to compatible versions with Colab's environment
!pip install --upgrade transformers huggingface-hub
!pip install pytorch-crf
!pip install wikipedia-api

print("✓ All packages installed")

✓ All packages installed


## Imports

In [26]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import XLMRobertaTokenizerFast, XLMRobertaModel, get_linear_schedule_with_warmup
from torch.optim import AdamW  # Changed: Use PyTorch's AdamW instead of transformers'
from torchcrf import CRF
import json
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import wikipediaapi
from collections import defaultdict
import time
import os

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Device: cpu


## Configuration

In [None]:
# Model configuration
CONFIG = {
    'model_name': 'xlm-roberta-base',  # ~278M params, multilingual
    'max_length': 128,  # Max tokens per example
    'batch_size': 16,
    'learning_rate': 2e-5,
    'num_epochs': 5,  # Changed back to 5 for better training
    'warmup_ratio': 0.1,
    'weight_decay': 0.01,
    'dropout': 0.1,
    'gradient_clip': 1.0,
    
    # KB-NER specific
    'use_knowledge': True,  # Set False for baseline XLM-R only
    'max_context_length': 50,  # Max tokens from Wikipedia context
    'wiki_lang': 'en',  # Wikipedia language
}

# Entity types
ENTITY_TYPES = [
    'O',
    'B-Artist', 'I-Artist',
    'B-Politician', 'I-Politician',
    'B-HumanSettlement', 'I-HumanSettlement',
    'B-PublicCorp', 'I-PublicCorp',
    'B-ORG', 'I-ORG',
    'B-Facility', 'I-Facility',
    'B-OtherPER', 'I-OtherPER'
]

tag2id = {tag: idx for idx, tag in enumerate(ENTITY_TYPES)}
id2tag = {idx: tag for tag, idx in tag2id.items()}
num_tags = len(ENTITY_TYPES)

print(f"Configuration loaded")
print(f"Number of entity types: {num_tags}")
print(f"Knowledge base enabled: {CONFIG['use_knowledge']}")

## Phase 1: Knowledge Base Construction (CPU)

This phase runs on CPU and builds the Wikipedia knowledge base.  
**No GPU credits used here!**

In [28]:
class WikipediaKB:
    """
    Wikipedia Knowledge Base for entity context retrieval.
    Simplified version of KB-NER's approach.
    """
    
    def __init__(self, lang='en', cache_file='wiki_cache.json'):
        self.wiki = wikipediaapi.Wikipedia(
            language=lang,
            user_agent='MultiCoNER-KB-NER/1.0'
        )
        self.cache_file = cache_file
        self.cache = self.load_cache()
    
    def load_cache(self):
        """Load cached Wikipedia queries"""
        if os.path.exists(self.cache_file):
            with open(self.cache_file, 'r') as f:
                return json.load(f)
        return {}
    
    def save_cache(self):
        """Save cache to disk"""
        with open(self.cache_file, 'w') as f:
            json.dump(self.cache, f)
    
    def get_entity_context(self, entity_text, entity_type=None, max_length=100):
        """
        Retrieve Wikipedia context for an entity.
        
        Args:
            entity_text: Entity surface form (e.g., "Barack Obama")
            entity_type: Entity type hint (e.g., "Politician")
            max_length: Max characters from Wikipedia
        
        Returns:
            Context string or empty string if not found
        """
        # Check cache first
        cache_key = entity_text.lower()
        if cache_key in self.cache:
            return self.cache[cache_key]
        
        # Query Wikipedia
        try:
            page = self.wiki.page(entity_text)
            
            if page.exists():
                # Get first paragraph as context
                summary = page.summary[:max_length]
                self.cache[cache_key] = summary
                return summary
            else:
                self.cache[cache_key] = ""
                return ""
        except Exception as e:
            # Don't cache failures: a transient network error should be retried later
            print(f"Error fetching {entity_text}: {e}")
            return ""
    
    def extract_entities_from_text(self, tokens, tags):
        """
        Extract entity spans from BIO tags.
        
        Returns:
            List of (entity_text, entity_type) tuples
        """
        entities = []
        current_entity = []
        current_type = None
        
        for token, tag in zip(tokens, tags):
            if tag.startswith('B-'):
                # Save previous entity
                if current_entity:
                    entities.append((' '.join(current_entity), current_type))
                # Start new entity
                current_entity = [token]
                current_type = tag[2:]  # Remove 'B-'
            elif tag.startswith('I-'):
                if current_entity:
                    current_entity.append(token)
            else:  # 'O' tag
                if current_entity:
                    entities.append((' '.join(current_entity), current_type))
                    current_entity = []
                    current_type = None
        
        # Don't forget last entity
        if current_entity:
            entities.append((' '.join(current_entity), current_type))
        
        return entities

print("✓ WikipediaKB class defined")

✓ WikipediaKB class defined
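To see what `extract_entities_from_text` produces, here is a standalone sketch of the same BIO grouping logic. The function name `extract_spans` and the sample sentence are illustrative assumptions, not part of the notebook:

```python
def extract_spans(tokens, tags):
    """Group BIO-tagged tokens into (entity_text, entity_type) tuples."""
    spans, current, current_type = [], [], None
    for token, tag in zip(tokens, tags):
        if tag.startswith("B-"):
            # A new entity starts: flush the one in progress, if any
            if current:
                spans.append((" ".join(current), current_type))
            current, current_type = [token], tag[2:]
        elif tag.startswith("I-") and current:
            current.append(token)
        else:
            # 'O' tag (or a stray 'I-' with no open entity) closes the span
            if current:
                spans.append((" ".join(current), current_type))
            current, current_type = [], None
    if current:
        spans.append((" ".join(current), current_type))
    return spans


demo_spans = extract_spans(
    ["Barack", "Obama", "visited", "Paris"],
    ["B-Politician", "I-Politician", "O", "B-HumanSettlement"],
)
print(demo_spans)
```

Because the spans are read from the gold tags, this retrieval route only yields context for labeled data; an unlabeled example has no tagged spans and therefore gets an empty context.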


In [None]:
def augment_data_with_knowledge(data_file, kb, output_file, max_examples=None, checkpoint_every=500):
    """
    Augment dataset with Wikipedia contexts with checkpoint support.
    
    Args:
        data_file: Input .jsonl file
        kb: WikipediaKB instance
        output_file: Output .jsonl with added contexts
        max_examples: Limit for testing (None = all)
        checkpoint_every: Save progress every N examples
    """
    print(f"\nAugmenting {data_file} with Wikipedia knowledge...")
    
    # Check if partial output exists (resume from checkpoint)
    start_idx = 0
    if os.path.exists(output_file):
        with open(output_file, 'r') as f:
            existing_lines = f.readlines()
            start_idx = len(existing_lines)
        print(f"📁 Found existing progress: {start_idx} examples already processed")
        print(f"   Resuming from example {start_idx}...")
    
    df = pd.read_json(data_file, lines=True)
    if max_examples:
        df = df.head(max_examples)
    
    total_examples = len(df)  # Store original total before skipping
    
    # Skip already processed examples
    if start_idx > 0:
        df = df.iloc[start_idx:]
    
    if len(df) == 0:
        print("✓ All examples already processed!")
        return
    
    # Open file in append mode
    mode = 'a' if start_idx > 0 else 'w'
    
    with open(output_file, mode) as f:
        # Fixed: total should be original total, not remaining
        for idx, row in tqdm(df.iterrows(), total=total_examples, desc="Retrieving contexts", initial=start_idx):
            tokens = row['tokens']
            tags = row.get('ner_tags', ['O'] * len(tokens))
            
            # Extract entities
            entities = kb.extract_entities_from_text(tokens, tags)
            
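            # Retrieve contexts
            contexts = []
            for entity_text, entity_type in entities:
                # Only cache misses hit the network, so only those need rate limiting
                is_cached = entity_text.lower() in kb.cache
                context = kb.get_entity_context(entity_text, entity_type)
                if context:
                    contexts.append(context)
                
                # Rate limit
                if not is_cached:
                    time.sleep(0.1)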
            
            # Combine contexts
            combined_context = " ".join(contexts)[:500]
            
            # Create augmented example
            augmented = {
                'id': row['id'],
                'tokens': tokens,
                'ner_tags': tags,
                'context': combined_context
            }
            
            # Write immediately (don't accumulate in memory)
            f.write(json.dumps(augmented) + '\n')
            
            # Save cache periodically
            if (idx + 1) % checkpoint_every == 0:
                f.flush()  # Force write to disk
                kb.save_cache()
                print(f"\n💾 Checkpoint saved at {idx + 1} examples")
    
    # Final cache save
    kb.save_cache()
    
    print(f"\n✓ Augmented data saved to {output_file}")
    print(f"  Total examples processed: {total_examples}")
    print(f"  Cache size: {len(kb.cache)} entities")

print("✓ Augmentation function defined (with checkpointing)")
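The resume behaviour relies on one idea: the number of lines already in the output file equals the number of examples already processed. A minimal sketch of that pattern in isolation, assuming a hypothetical `process_with_resume` helper and a temporary output path:

```python
import json
import os
import tempfile


def process_with_resume(items, output_file, transform):
    """Append transform(item) as JSON lines, skipping items already written.

    Returns the number of items that were skipped because they already existed.
    """
    done = 0
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            done = sum(1 for _ in f)  # one line per processed item
    with open(output_file, "a", encoding="utf-8") as f:
        for item in items[done:]:
            f.write(json.dumps(transform(item)) + "\n")
    return done


out_path = os.path.join(tempfile.mkdtemp(), "out.jsonl")
first_skip = process_with_resume([1, 2, 3], out_path, lambda x: {"value": x})
second_skip = process_with_resume([1, 2, 3, 4, 5], out_path, lambda x: {"value": x})
with open(out_path, encoding="utf-8") as f:
    values = [json.loads(line)["value"] for line in f]
print(first_skip, second_skip, values)
```

This only works if the input order is stable between runs. An interrupted write can also leave a partial last line, which would be miscounted as a finished example, so it may be worth validating the final line before resuming.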

### Run Knowledge Base Augmentation (CPU Phase)

**IMPORTANT**: This cell runs on CPU. Do this BEFORE enabling GPU to save credits!

In [36]:
# Initialize Knowledge Base
kb = WikipediaKB(lang='en', cache_file='wiki_cache_multiconer.json')

# Augment training data
# NOTE: Start with small subset for testing, then run on full data
TEST_MODE = False  # True: 100-example smoke test; False: full dataset

if TEST_MODE:
    print("⚠️  Running in TEST MODE (100 examples)")
    print("   Set TEST_MODE=False for full dataset")
    augment_data_with_knowledge(
        'train_split.jsonl',
        kb,
        'train_split_kb.jsonl',
        max_examples=100
    )
    augment_data_with_knowledge(
        'val_split.jsonl',
        kb,
        'val_split_kb.jsonl',
        max_examples=100
    )
else:
    print("✓ Running on FULL DATASET")
    print("  This will take 1-2 hours on CPU")
    augment_data_with_knowledge('train_split.jsonl', kb, 'train_split_kb.jsonl')
    augment_data_with_knowledge('val_split.jsonl', kb, 'val_split_kb.jsonl')

print("\n✓ Knowledge base augmentation complete!")
print("  Next: Move to GPU phase for training")

✓ Running on FULL DATASET
  This will take 1-2 hours on CPU

Augmenting train_split.jsonl with Wikipedia knowledge...
📁 Found existing progress: 26 examples already processed
   Resuming from example 26...


Retrieving contexts:   0%|          | 26/90294 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Phase 2: Model Definition

XLM-RoBERTa + CRF with knowledge-augmented inputs

In [None]:
class KB_NER_Dataset(Dataset):
    """
    Dataset with knowledge-augmented inputs.
    Format: [CLS] tokens [SEP] context [SEP]
    """
    
    def __init__(self, data_file, tokenizer, max_length=128, tag2id=None):
        self.data = []
        with open(data_file, 'r') as f:
            for line in f:
                self.data.append(json.loads(line))
        
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.tag2id = tag2id
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        example = self.data[idx]
        tokens = example['tokens']
        tags = example['ner_tags']
        context = example.get('context', '')
        
        # Tokenize sentence first (with is_split_into_words=True)
        sentence_encoding = self.tokenizer(
            tokens,
            is_split_into_words=True,
            add_special_tokens=False,  # We'll add manually with context
            truncation=False,
            return_tensors=None
        )
        
        # If we have context, tokenize it separately
        if context:
            context_encoding = self.tokenizer(
                context,
                is_split_into_words=False,  # Context is a string, not pre-tokenized
                add_special_tokens=False,
                truncation=True,
                max_length=50,  # Limit context length
                return_tensors=None
            )
            
            # Combine: [CLS] sentence_tokens [SEP] context_tokens [SEP]
            input_ids = (
                [self.tokenizer.cls_token_id] + 
                sentence_encoding['input_ids'] + 
                [self.tokenizer.sep_token_id] + 
                context_encoding['input_ids'] + 
                [self.tokenizer.sep_token_id]
            )
        else:
            # No context: [CLS] sentence_tokens [SEP]
            input_ids = (
                [self.tokenizer.cls_token_id] + 
                sentence_encoding['input_ids'] + 
                [self.tokenizer.sep_token_id]
            )
        
        # Truncate if too long
        if len(input_ids) > self.max_length:
            input_ids = input_ids[:self.max_length]
        
        # Pad to max_length
        attention_mask = [1] * len(input_ids)
        padding_length = self.max_length - len(input_ids)
        input_ids = input_ids + [self.tokenizer.pad_token_id] * padding_length
        attention_mask = attention_mask + [0] * padding_length
        
        # Align labels with tokenized input
        # Get word_ids from sentence encoding only
        word_ids_sentence = sentence_encoding.word_ids()
        
        # Build full word_ids: [None(CLS)] + word_ids_sentence + [None(SEP)] + [None...](context) + [None(SEP)] + [None...](padding)
        word_ids = [None]  # CLS
        word_ids.extend(word_ids_sentence)
        word_ids.append(None)  # SEP after sentence
        
        if context:
            # Add None for context tokens and final SEP
            context_length = len(context_encoding['input_ids']) + 1  # +1 for SEP
            word_ids.extend([None] * context_length)
        
        # Add None for padding
        word_ids.extend([None] * padding_length)
        
        # Create labels aligned with input_ids
        labels = []
        
        for word_idx in word_ids[:self.max_length]:
            if word_idx is None:
                labels.append(-100)  # Ignore special tokens, context, padding
            else:
                # Every subword (first or continuation) inherits its word's tag
                labels.append(self.tag2id[tags[word_idx]])
        
        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
            'word_ids': word_ids[:self.max_length],  # Keep aligned with the truncated input_ids
            'original_tokens': tokens,
            'original_tags': tags
        }

print("✓ KB_NER_Dataset class defined")

✓ KB_NER_Dataset class defined
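The label alignment inside `__getitem__` can be checked without a tokenizer. A small sketch, assuming a hand-written `word_ids` list in place of real tokenizer output and a hypothetical `align_labels` helper:

```python
IGNORE_INDEX = -100  # Positions with this label are excluded from the loss


def align_labels(word_ids, tags, tag2id):
    """Map each subword position to its word's tag id, or IGNORE_INDEX if it has no word."""
    return [IGNORE_INDEX if w is None else tag2id[tags[w]] for w in word_ids]


tag2id_demo = {"O": 0, "B-Artist": 1, "I-Artist": 2}
# Layout: [CLS], two subwords of word 0, one subword of word 1, [SEP], context, [SEP]
word_ids_demo = [None, 0, 0, 1, None, None, None]
tags_demo = ["B-Artist", "O"]

labels_demo = align_labels(word_ids_demo, tags_demo, tag2id_demo)
print(labels_demo)
```

Note that both subwords of word 0 carry the `B-Artist` id, mirroring the dataset class, where continuation subwords repeat their word's tag.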


In [None]:
class XLMRobertaCRF(nn.Module):
    """
    XLM-RoBERTa + CRF for NER.
    Based on KB-NER architecture.
    
    IMPORTANT: During inference, pass word_mask to only decode word positions!
    """
    
    def __init__(self, model_name, num_tags, dropout=0.1):
        super().__init__()
        self.xlmr = XLMRobertaModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.xlmr.config.hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)
    
    def forward(self, input_ids, attention_mask, labels=None, word_mask=None):
        # Get XLM-R embeddings
        outputs = self.xlmr(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        
        # Project to tag space
        emissions = self.classifier(sequence_output)
        
        if labels is not None:
            # TRAINING: Create mask from labels (only real word tokens)
            mask = (labels != -100).byte()
            
            # IMPORTANT: CRF requires first position to be valid
            mask[:, 0] = 1
            
            # Replace -100 with 0 for CRF
            labels_for_crf = labels.clone()
            labels_for_crf[labels_for_crf == -100] = 0
            
            # CRF loss
            log_likelihood = self.crf(emissions, labels_for_crf, mask=mask, reduction='mean')
            loss = -log_likelihood
            
            return loss, emissions
        else:
            # INFERENCE: Use word_mask if provided, otherwise fall back to attention_mask
            if word_mask is not None:
                mask = word_mask.byte()
            else:
                mask = attention_mask.byte()
            
            predictions = self.crf.decode(emissions, mask=mask)
            return predictions

print("✓ XLMRobertaCRF model defined")
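A CRF decoder returns the best-scoring tag path as a whole rather than a per-token argmax. A minimal pure-Python sketch of Viterbi decoding, assuming a simplified scoring model with only emission and tag-to-tag transition scores (no start/end scores, masking, or batching, all of which a library CRF layer would handle). `viterbi_decode` and the toy scores are illustrative, not taken from `torchcrf`:

```python
def viterbi_decode(emissions, transitions):
    """Return the highest-scoring tag sequence.

    emissions:   per-position score lists, shape [seq_len][num_tags]
    transitions: transitions[i][j] is the score of moving from tag i to tag j
    """
    num_tags = len(transitions)
    scores = list(emissions[0])  # best score of any path ending in each tag
    backpointers = []
    for emission in emissions[1:]:
        step_scores, step_back = [], []
        for j in range(num_tags):
            best_i = max(range(num_tags), key=lambda i: scores[i] + transitions[i][j])
            step_scores.append(scores[best_i] + transitions[best_i][j] + emission[j])
            step_back.append(best_i)
        scores = step_scores
        backpointers.append(step_back)
    # Follow the backpointers from the best final tag
    path = [max(range(num_tags), key=lambda j: scores[j])]
    for step_back in reversed(backpointers):
        path.append(step_back[path[-1]])
    return path[::-1]


# Tag 0 = "O", tag 1 = "I-X"; moving from O to I-X is heavily penalized
toy_transitions = [[0.0, -10.0], [0.0, 0.0]]
toy_emissions = [[1.0, 0.0], [0.0, 2.0], [0.0, 1.0]]

greedy_path = [max(range(2), key=lambda t: e[t]) for e in toy_emissions]
best_path = viterbi_decode(toy_emissions, toy_transitions)
print(greedy_path, best_path)
```

In the toy example the greedy path contains an O to I-X transition, while the Viterbi path avoids it because the transition penalty outweighs the small emission gain at the first position.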

## Phase 3: Training (GPU Phase)

**⚠️ START GPU HERE** - This is where GPU credits are used!

In [None]:
# Custom collate function to handle word_ids (which contains None)
def collate_fn(batch):
    """Custom collate function to handle word_ids (which are lists with None values)"""
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    labels = torch.stack([item['labels'] for item in batch])
    
    # Keep word_ids, original_tokens, and original_tags as lists (not tensors)
    word_ids = [item['word_ids'] for item in batch]
    original_tokens = [item['original_tokens'] for item in batch]
    original_tags = [item['original_tags'] for item in batch]
    
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'word_ids': word_ids,
        'original_tokens': original_tokens,
        'original_tags': original_tags
    }

# Initialize tokenizer and datasets
tokenizer = XLMRobertaTokenizerFast.from_pretrained(CONFIG['model_name'])

train_dataset = KB_NER_Dataset(
    'train_split_kb.jsonl',
    tokenizer,
    max_length=CONFIG['max_length'],
    tag2id=tag2id
)

val_dataset = KB_NER_Dataset(
    'val_split_kb.jsonl',
    tokenizer,
    max_length=CONFIG['max_length'],
    tag2id=tag2id
)

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    collate_fn=collate_fn  # Use custom collate function
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    collate_fn=collate_fn  # Use custom collate function
)

print(f"✓ Datasets loaded")
print(f"  Training examples: {len(train_dataset)}")
print(f"  Validation examples: {len(val_dataset)}")

✓ Datasets loaded
  Training examples: 100
  Validation examples: 100


In [None]:
# Initialize model
model = XLMRobertaCRF(
    CONFIG['model_name'],
    num_tags,
    dropout=CONFIG['dropout']
).to(device)

# Optimizer and scheduler
optimizer = AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay']
)

total_steps = len(train_loader) * CONFIG['num_epochs']
warmup_steps = int(total_steps * CONFIG['warmup_ratio'])

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

print(f"✓ Model initialized")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
print(f"  Device: {device}")

✓ Model initialized
  Parameters: 278.1M
  Device: cpu
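The scheduler configured in this cell ramps the learning rate linearly from 0 to its peak over the warmup steps, then decays it linearly back to 0. A sketch of the multiplier applied to the base learning rate, written on the assumption that it mirrors what `get_linear_schedule_with_warmup` computes; `linear_warmup_multiplier` is a hypothetical name and the step counts are made up:

```python
def linear_warmup_multiplier(step, warmup_steps, total_steps):
    """Factor applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 2e-5
demo_total, demo_warmup = 1000, 100  # warmup_ratio = 0.1

for demo_step in (0, 50, 100, 550, 1000):
    factor = linear_warmup_multiplier(demo_step, demo_warmup, demo_total)
    print(demo_step, factor, base_lr * factor)
```

Since `total_steps` is derived from `len(train_loader)`, changing the batch size or dataset size also changes how many steps the warmup lasts.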




In [None]:
def train_epoch(model, dataloader, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    
    progress_bar = tqdm(dataloader, desc="Training")
    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        
        loss, _ = model(input_ids, attention_mask, labels)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['gradient_clip'])
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})
    
    return total_loss / len(dataloader)

def validate_epoch(model, dataloader, device, id2tag):
    """
    Validate model and return predictions for F1 calculation.
    
    CRITICAL FIX: Create word_mask from word_ids to only decode word positions!
    """
    model.eval()
    
    all_predictions = []
    all_ground_truth = []
    all_tokens = []
    total_loss = 0
    
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Validating")
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            word_ids_batch = batch['word_ids']  # List of lists with None values
            
            # Get loss
            loss, _ = model(input_ids, attention_mask, labels)
            total_loss += loss.item()
            
            # CRITICAL FIX: Create word_mask for CRF decoding
            # Only mark positions that correspond to actual words (not [CLS], [SEP], context, padding)
            word_mask = torch.zeros_like(attention_mask)
            for i, word_ids in enumerate(word_ids_batch):
                for j, word_id in enumerate(word_ids):
                    if word_id is not None:  # This is a real word position
                        word_mask[i, j] = 1
            
            # CRF requires first position to be valid - force it
            word_mask[:, 0] = 1
            
            word_mask = word_mask.to(device)
            
            # Get predictions with word_mask
            predictions = model(input_ids, attention_mask, word_mask=word_mask)
            
            # Convert to tags
            for pred, word_ids, orig_tokens, orig_tags in zip(
                predictions,
                batch['word_ids'],
                batch['original_tokens'],
                batch['original_tags']
            ):
                # Align predictions with original tokens
                token_preds = []
                prev_word_idx = None
                
                for pred_id, word_idx in zip(pred, word_ids):
                    if word_idx is not None and word_idx != prev_word_idx:
                        if word_idx < len(orig_tokens):
                            token_preds.append(id2tag[pred_id])
                    prev_word_idx = word_idx
                
                # Ensure same length
                if len(token_preds) != len(orig_tokens):
                    token_preds = token_preds[:len(orig_tokens)]
                    if len(token_preds) < len(orig_tokens):
                        token_preds.extend(['O'] * (len(orig_tokens) - len(token_preds)))
                
                all_predictions.append(token_preds)
                all_ground_truth.append(list(orig_tags))
                all_tokens.append(list(orig_tokens))
    
    avg_loss = total_loss / len(dataloader)
    return all_predictions, all_ground_truth, all_tokens, avg_loss

print("✓ Training and validation functions defined")
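The alignment step in `validate_epoch` keeps one prediction per original word, taken from that word's first subword. A standalone sketch of that step, assuming hand-written `word_ids` and decoded tag ids; `first_subword_predictions` is a hypothetical name:

```python
def first_subword_predictions(pred_ids, word_ids, id2tag, num_words):
    """Collapse subword-level predictions to one tag per original word."""
    word_tags = []
    prev_word = None
    for pred_id, word_idx in zip(pred_ids, word_ids):
        if word_idx is not None and word_idx != prev_word:
            word_tags.append(id2tag[pred_id])
        prev_word = word_idx
    # Words lost to truncation fall back to 'O'
    word_tags = word_tags[:num_words]
    word_tags += ["O"] * (num_words - len(word_tags))
    return word_tags


id2tag_demo = {0: "O", 1: "B-Artist", 2: "I-Artist"}
# Decoded path covers [CLS] plus four sentence subwords; the trailing [SEP] is masked out
pred_ids_demo = [0, 1, 2, 0, 1]
word_ids_demo = [None, 0, 0, 1, 2, None]

aligned = first_subword_predictions(pred_ids_demo, word_ids_demo, id2tag_demo, num_words=3)
padded = first_subword_predictions(pred_ids_demo, word_ids_demo, id2tag_demo, num_words=4)
print(aligned)
print(padded)
```

The decoded path is shorter than `word_ids` because masked positions are not decoded; `zip` stops at the shorter sequence, which works here only because the unmasked positions form a contiguous prefix.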

In [None]:
# Load utils for F1 calculation during training
try:
    import utils
    print("✓ utils.py loaded for validation")
except ImportError:
    print("⚠️ utils.py not found - will save based on loss instead of F1")
    utils = None

# Training loop
print("\n" + "="*80)
print("STARTING TRAINING (GPU PHASE)")
print("="*80)

best_f1 = 0.0
best_val_loss = float('inf')
training_start_time = time.time()

# Create models directory if it doesn't exist
os.makedirs('models', exist_ok=True)

for epoch in range(CONFIG['num_epochs']):
    print(f"\n{'='*80}")
    print(f"Epoch {epoch + 1}/{CONFIG['num_epochs']}")
    print('='*80)
    
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device)
    print(f"Train Loss: {train_loss:.4f}")
    
    # Validate
    print("\nValidating...")
    predictions, ground_truth, tokens, val_loss = validate_epoch(model, val_loader, device, id2tag)
    
    # Calculate F1 score
    if utils is not None:
        results = utils.evaluate_entity_spans(ground_truth, predictions, tokens)
        val_f1 = results['f1']
        val_precision = results['precision']
        val_recall = results['recall']
        
        print(f"\nValidation Results:")
        print(f"  Loss:      {val_loss:.4f}")
        print(f"  Precision: {val_precision:.4f} ({val_precision*100:.2f}%)")
        print(f"  Recall:    {val_recall:.4f} ({val_recall*100:.2f}%)")
        print(f"  F1 Score:  {val_f1:.4f} ({val_f1*100:.2f}%)")
        
        # Save best model based on F1 score
        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'val_f1': val_f1,
                'val_precision': val_precision,
                'val_recall': val_recall,
                'config': CONFIG,
                'tag2id': tag2id,
                'id2tag': id2tag,
            }, 'models/kb_ner_best.pt')
            print(f"  ✓ New best model saved! (F1: {val_f1:.4f})")
        else:
            print(f"  (Best F1 so far: {best_f1:.4f})")
    else:
        # Fallback: save based on validation loss if utils not available
        print(f"Validation Loss: {val_loss:.4f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'config': CONFIG,
                'tag2id': tag2id,
                'id2tag': id2tag,
            }, 'models/kb_ner_best.pt')
            print(f"✓ Saved best model (loss: {val_loss:.4f})")

training_time = time.time() - training_start_time
print(f"\n{'='*80}")
print("TRAINING COMPLETE!")
print('='*80)
print(f"Total time: {training_time / 60:.1f} minutes")
print(f"Time per epoch: {training_time / CONFIG['num_epochs'] / 60:.1f} minutes")
if utils is not None:
    print(f"Best validation F1: {best_f1:.4f} ({best_f1*100:.2f}%)")
print(f"Best model saved to: models/kb_ner_best.pt")

## Evaluation

Evaluate on validation set using entity-span F1
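`utils.py` is not shown in this notebook, so the following is only a sketch of how an exact-match entity-span F1 is commonly computed, not the actual `utils.evaluate_entity_spans` implementation. The names `spans_from_bio` and `span_prf` are assumptions; a predicted span counts as correct only if its boundaries and type both match a gold span:

```python
def spans_from_bio(tags):
    """Return a set of (start, end_exclusive, type) spans from one BIO sequence."""
    spans, start, span_type = set(), None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel closes a trailing span
        if tag.startswith("I-") and start is not None and tag[2:] == span_type:
            continue  # still inside the current span
        if start is not None:
            spans.add((start, i, span_type))
            start, span_type = None, None
        if tag.startswith("B-"):
            start, span_type = i, tag[2:]
    return spans


def span_prf(gold_sequences, pred_sequences):
    """Micro-averaged exact-match precision, recall and F1 over entity spans."""
    tp = fp = fn = 0
    for gold, pred in zip(gold_sequences, pred_sequences):
        gold_spans, pred_spans = spans_from_bio(gold), spans_from_bio(pred)
        tp += len(gold_spans & pred_spans)
        fp += len(pred_spans - gold_spans)
        fn += len(gold_spans - pred_spans)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


demo_scores = span_prf(
    [["B-Artist", "I-Artist", "O", "B-ORG"]],
    [["B-Artist", "I-Artist", "O", "O"]],
)
print(demo_scores)
```

Here the prediction finds one of the two gold entities and adds no spurious ones, so precision is perfect while recall is halved.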

In [None]:
# Load utils for evaluation
try:
    import utils
    print("✓ utils.py loaded")
except ImportError:
    print("✗ utils.py not found. Upload it to evaluate.")

In [None]:
def evaluate_model(model, dataloader, device, id2tag):
    """
    Evaluate model and return predictions.
    
    CRITICAL FIX: Create word_mask from word_ids to only decode word positions!
    """
    model.eval()
    
    all_predictions = []
    all_ground_truth = []
    all_tokens = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            word_ids_batch = batch['word_ids']
            
            # CRITICAL FIX: Create word_mask for CRF decoding
            # Only mark positions that correspond to actual words (not [CLS], [SEP], context, padding)
            word_mask = torch.zeros_like(attention_mask)
            for i, word_ids in enumerate(word_ids_batch):
                for j, word_id in enumerate(word_ids):
                    if word_id is not None:  # This is a real word position
                        word_mask[i, j] = 1
            
            # CRF requires first position to be valid - force it
            word_mask[:, 0] = 1
            
            word_mask = word_mask.to(device)
            
            # Get predictions with word_mask
            predictions = model(input_ids, attention_mask, word_mask=word_mask)
            
            # Convert to tags
            for pred, word_ids, orig_tokens, orig_tags in zip(
                predictions,
                batch['word_ids'],
                batch['original_tokens'],
                batch['original_tags']
            ):
                # Align predictions with original tokens
                token_preds = []
                prev_word_idx = None
                
                for pred_id, word_idx in zip(pred, word_ids):
                    if word_idx is not None and word_idx != prev_word_idx:
                        if word_idx < len(orig_tokens):
                            token_preds.append(id2tag[pred_id])
                    prev_word_idx = word_idx
                
                # Ensure same length
                if len(token_preds) != len(orig_tokens):
                    token_preds = token_preds[:len(orig_tokens)]
                    if len(token_preds) < len(orig_tokens):
                        token_preds.extend(['O'] * (len(orig_tokens) - len(token_preds)))
                
                all_predictions.append(token_preds)
                all_ground_truth.append(list(orig_tags))
                all_tokens.append(list(orig_tokens))
    
    return all_predictions, all_ground_truth, all_tokens

print("✓ Evaluation function defined")

In [None]:
# Evaluate on validation set
print("\nEvaluating on validation set...")

predictions, ground_truth, tokens = evaluate_model(model, val_loader, device, id2tag)

# Calculate metrics using utils.py
results = utils.evaluate_entity_spans(ground_truth, predictions, tokens)

print("\n" + "="*80)
print("M10: KB-NER RESULTS")
print("="*80)
print(f"Precision: {results['precision']:.4f} ({results['precision']*100:.2f}%)")
print(f"Recall:    {results['recall']:.4f} ({results['recall']*100:.2f}%)")
print(f"F1 Score:  {results['f1']:.4f} ({results['f1']*100:.2f}%)")
print(f"\nTrue Positives:  {results['true_positives']}")
print(f"False Positives: {results['false_positives']}")
print(f"False Negatives: {results['false_negatives']}")
print("="*80)

# Detailed report
utils.print_evaluation_report(ground_truth, predictions, tokens, "M10: KB-NER")

# Save results
results_data = {
    'model': 'M10: KB-NER (XLM-RoBERTa + CRF + Knowledge Base)',
    'architecture': 'XLM-RoBERTa-base + CRF',
    'knowledge_source': 'Wikipedia',
    'precision': results['precision'],
    'recall': results['recall'],
    'f1': results['f1'],
    'training_time': training_time,
    'num_epochs': CONFIG['num_epochs'],
    'parameters': sum(p.numel() for p in model.parameters()),
    'config': CONFIG
}

with open('models/kb_ner_results.json', 'w') as f:
    json.dump(results_data, f, indent=2)

print("\n✓ Results saved to models/kb_ner_results.json")

## Summary

### M10: KB-NER Performance

**Architecture**: XLM-RoBERTa-base (~278M params) + CRF + Wikipedia Knowledge Base

**Expected F1**: 83-85%

**Comparison**:
- Gemini Few-Shot: 68% F1
- M4 v2 (BiLSTM-CRF): 75.94% F1
- Friends (BERT): 77-79% F1
- **M10 (KB-NER)**: ~83-85% F1 ✅
- M8 (RoBERTa): ~85-88% F1 (simpler, no KB needed)

### GPU Usage:
- Phase 1 (KB construction): CPU only (~1-2 hours)
- Phase 3 (Training): GPU (~1.5 hours for 5 epochs)
- **Total GPU**: ~1.5 hours

### Next Steps:
1. If F1 ≥ 83%, generate test predictions
2. Compare with M8 (simpler but similar F1)
3. Consider ensemble: M4 v2 + M10 + M8

### Notes:
- KB-NER adds 5-8% over baseline BERT
- Most gain comes from XLM-RoBERTa, KB adds 1-2%
- For simplicity, M8 (pure RoBERTa) might be better
- KB-NER shines on ambiguous entities