# Import Libraries

In [1]:
import json
import os
from glob import glob
import random
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import BertForTokenClassification, BertTokenizerFast
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report
from sklearn.preprocessing import MultiLabelBinarizer
from collections import Counter
from tabulate import tabulate
import pandas as pd

import os
import json
import random
import numpy as np
from glob import glob
from collections import defaultdict, Counter

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from tqdm import tqdm

from transformers import Trainer, TrainingArguments, EarlyStoppingCallback, DataCollatorForTokenClassification

from sklearn.metrics import precision_score, recall_score, f1_score, classification_report
from sklearn.model_selection import train_test_split

from transformers import (
    XLMRobertaForTokenClassification, 
    XLMRobertaTokenizerFast, 
    Trainer, 
    TrainingArguments, 
    DataCollatorForTokenClassification
)

from seqeval.metrics import (
    precision_score as seq_precision,
    recall_score as seq_recall,
    f1_score as seq_f1,
    accuracy_score,
    classification_report as seq_classification_report
)

# Suppress warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.filterwarnings("ignore", message="Some weights of.*were not initialized from the model checkpoint.*")

print("✓ Libraries imported successfully")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

  from .autonotebook import tqdm as notebook_tqdm


✓ Libraries imported successfully
PyTorch version: 2.6.0+cu124
CUDA available: True


# Define Path

In [2]:
# Define paths
input_dir = "/home/guest/Public/KHEED/KHEED_Data_Collection/Final/bio_tagged"
output_dir = "/home/guest/Public/KHEED/KHEED_Data_Collection/Evaluation/Models/bert_khmer_ner_model"

# Load Files

In [3]:
def load_json_files_robust(input_dir, verbose=True):
    """
    Load JSON files and extract tokens and BIO tags with robust error handling.
    """
    all_tokens = []
    all_tags = []
    input_files = glob.glob(os.path.join(input_dir, "*.json"))
    
    if not input_files:
        print(f"❌ No JSON files found in '{input_dir}'. Please check the directory.")
        return [], [], []
    
    print(f"📂 Processing {len(input_files)} files...")
    
    errors = []
    skipped_sentences = 0
    processed_sentences = 0
    
    for file_idx, input_file in enumerate(input_files):
        if verbose and file_idx < 5:
            print(f"  📄 Processing: {os.path.basename(input_file)}")
        
        try:
            with open(input_file, 'r', encoding='utf-8') as f:
                obj = json.load(f)
            
            processed_content = obj.get('processed_content', [])
            
            for sent_idx, sentence_data in enumerate(processed_content):
                try:
                    tokens = sentence_data.get('tokens', [])
                    bio_tags = sentence_data.get('bio_tags', [])
                    
                    if not tokens or not bio_tags:
                        skipped_sentences += 1
                        continue
                    
                    flattened_tags = [tag[0] if isinstance(tag, list) and tag else tag for tag in bio_tags]
                    if len(tokens) != len(flattened_tags):
                        error_msg = f"Length mismatch in {os.path.basename(input_file)}, sentence {sent_idx}: {len(tokens)} tokens vs {len(flattened_tags)} tags"
                        errors.append(error_msg)
                        if verbose and len(errors) <= 3:
                            print(f"    ⚠️  {error_msg}")
                        skipped_sentences += 1
                        continue
                    
                    validated_tags = [tag if tag and isinstance(tag, str) else "O" for tag in flattened_tags]
                    all_tokens.append(tokens)
                    all_tags.append(validated_tags)
                    processed_sentences += 1
                    
                except Exception as e:
                    error_msg = f"Error processing sentence {sent_idx} in {os.path.basename(input_file)}: {e}"
                    errors.append(error_msg)
                    if verbose and len(errors) <= 3:
                        print(f"    ❌ {error_msg}")
                    skipped_sentences += 1
                    continue
        
        except Exception as e:
            error_msg = f"Error reading file {input_file}: {e}"
            errors.append(error_msg)
            if verbose:
                print(f"  ❌ {error_msg}")
            continue
    
    print(f"\n📊 Data loading summary:")
    print(f"  ✅ Processed sentences: {processed_sentences}")
    print(f"  ⚠️  Skipped sentences: {skipped_sentences}")
    print(f"  ❌ Total errors: {len(errors)}")
    
    return all_tokens, all_tags, input_files

In [4]:
import glob
import os
import json

def analyze_data_structure(input_dir, max_files_to_check=5):
    """
    Analyze the structure of JSON files to understand the data format.
    """
    input_files = glob.glob(os.path.join(input_dir, "*.json"))
    
    if not input_files:
        print(f"❌ No JSON files found in '{input_dir}'")
        return None
    
    print(f"📁 Found {len(input_files)} JSON files")
    print(f"📋 Analyzing first {min(max_files_to_check, len(input_files))} files...\n")
    
    all_tags = set()
    sample_data = []
    
    for i, file_path in enumerate(input_files[:max_files_to_check]):
        print(f"File {i+1}: {os.path.basename(file_path)}")
        
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            
            # Show file structure
            print(f"  📝 Keys: {list(data.keys())}")
            
            if 'processed_content' in data:
                processed_content = data['processed_content']
                print(f"  📊 Sentences in file: {len(processed_content)}")
                
                # Check first few sentences
                for j, sentence in enumerate(processed_content[:2]):
                    tokens = sentence.get('tokens', [])
                    bio_tags = sentence.get('bio_tags', [])
                    
                    print(f"    Sentence {j+1}:")
                    print(f"      🔤 Tokens ({len(tokens)}): {tokens[:20]}{'...' if len(tokens) > 5 else ''}")
                    print(f"      🏷️  Tags ({len(bio_tags)}): {bio_tags[:20]}{'...' if len(bio_tags) > 5 else ''}")
                    
                    # Flatten tags and collect unique ones
                    flat_tags = [tag[0] if isinstance(tag, list) and tag else tag for tag in bio_tags]
                    all_tags.update(flat_tags)
                    
                    if j == 0:  # Save first sentence as sample
                        sample_data.append({
                            'file': os.path.basename(file_path),
                            'tokens': tokens,
                            'bio_tags': flat_tags
                        })
            
            print()
        
        except Exception as e:
            print(f"  ❌ Error reading file: {e}\n")
    
    print(f"🏷️  Unique tags found: {sorted(all_tags)}")
    print(f"📈 Total unique tags: {len(all_tags)}")
    
    return all_tags, sample_data

# %%

print("🔍 Analyzing data structure...")
unique_tags, sample_data = analyze_data_structure(input_dir)



🔍 Analyzing data structure...
📁 Found 525 JSON files
📋 Analyzing first 5 files...

File 1: object_d2279a49-8b25-4e4c-b936-62f88487a895.json
  📝 Keys: ['id', 'annotations', 'file_upload', 'drafts', 'predictions', 'data', 'meta', 'created_at', 'updated_at', 'inner_id', 'total_annotations', 'cancelled_annotations', 'total_predictions', 'comment_count', 'unresolved_comment_count', 'last_comment_updated_at', 'project', 'updated_by', 'comment_authors', 'processed_content']
  📊 Sentences in file: 9
    Sentence 1:
      🔤 Tokens (81): ['ឧប', 'នាយក', 'រដ្ឋមន្ត្រី', ' ', 'នេត', ' ', 'សាវឿន', ' ', 'តំណាង', 'ដ៏', 'ខ្ពង់ខ្ពស់', 'របស់', 'សម្តេច', 'មហា', 'បវរធិបតី', ' ', 'ហ៊ុន', ' ', 'ម៉ាណែត', ' ']...
      🏷️  Tags (81): [['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O'], ['O']]...
    Sentence 2:
      🔤 Tokens (83): ['ថ្លែង', 'ក្នុង', 'ពិធី', 'ប្រកាស', 'ជា', 'ផ្លូវការដាក់', 'ឱ្យ', 'អនុវត្ត', 'ការ', ' ', 'កែទម្រង់', 

# Prepare Dataset

In [5]:
def prepare_dataset_robust(tokens, tags, tokenizer, label2id, max_length=512):
    """
    Prepare dataset for BERT with robust error handling.
    """
    encodings = []
    skipped = 0
    
    print(f"🔄 Preparing dataset...")
    
    for i, (token_list, tag_list) in enumerate(zip(tokens, tags)):
        try:
            if not token_list or not tag_list or len(token_list) != len(tag_list):
                skipped += 1
                continue
            
            missing_tags = [tag for tag in tag_list if tag not in label2id]
            if missing_tags:
                print(f"⚠️  Sentence {i}: Missing tags {set(missing_tags)} - replacing with 'O'")
                tag_list = [tag if tag in label2id else 'O' for tag in tag_list]
            
            encoding = tokenizer(
                token_list,
                is_split_into_words=True,
                return_offsets_mapping=True,
                truncation=True,
                max_length=max_length,
                padding=False
            )
            
            aligned_labels = []
            word_ids = encoding.word_ids()
            prev_word_id = None
            
            for word_id in word_ids:
                if word_id is None:
                    aligned_labels.append(-100)
                elif word_id != prev_word_id:
                    if word_id < len(tag_list):
                        aligned_labels.append(label2id[tag_list[word_id]])
                    else:
                        aligned_labels.append(label2id['O'])
                else:
                    if word_id < len(tag_list):
                        tag = tag_list[word_id]
                        if tag.startswith("B-"):
                            i_tag = f"I-{tag[2:]}"
                            aligned_labels.append(label2id.get(i_tag, label2id[tag]))
                        else:
                            aligned_labels.append(label2id[tag])
                    else:
                        aligned_labels.append(label2id['O'])
                
                prev_word_id = word_id
            
            encoding["labels"] = aligned_labels
            del encoding["offset_mapping"]
            encodings.append(encoding)
            
        except Exception as e:
            print(f"❌ Error processing sentence {i}: {e}")
            skipped += 1
            continue
    
    print(f"✅ Prepared {len(encodings)} examples, skipped {skipped}")
    return encodings

# Custom Dataset Class

In [6]:
class NERDataset(torch.utils.data.Dataset):
    """Custom Dataset for NER."""
    
    def __init__(self, encodings):
        self.encodings = encodings
    
    def __getitem__(self, idx):
        item = {key: torch.tensor(val) for key, val in self.encodings[idx].items()}
        return item
    
    def __len__(self):
        return len(self.encodings)

# Training

In [7]:
def train_model_trainer(model, train_dataset, val_dataset, output_dir, tokenizer,
                       num_epochs=10, batch_size=4, learning_rate=2e-5, 
                       seed=42, patience=3):
    """
    Train a Khmer NER model using Hugging Face Trainer with early stopping and dynamic padding.
    
    Args:
        model: The NER model to train (e.g., BERT-based model for Khmer NER)
        train_dataset: Training dataset with input_ids, attention_mask, and labels
        val_dataset: Validation dataset with input_ids, attention_mask, and labels
        output_dir: Directory to save model checkpoints and logs
        tokenizer: Tokenizer used for the model (required for data collator)
        num_epochs: Maximum number of epochs to train
        batch_size: Per-device batch size
        learning_rate: Initial learning rate
        seed: Random seed for reproducibility
        patience: Number of epochs to wait for improvement before stopping
    """
    # Set random seed for reproducibility
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    
    # Check device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🖥️  Using device: {device}")
    if device.type == "cuda":
        print(f"   GPU: {torch.cuda.get_device_name()}")
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        torch.cuda.empty_cache()
    
    model.to(device)
    
    # Define data collator for dynamic padding
    data_collator = DataCollatorForTokenClassification(
        tokenizer=tokenizer,
        padding=True,
        label_pad_token_id=-100  # Standard for ignoring padding tokens in NER loss
    )
    
    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_steps=min(500, len(train_dataset) // batch_size // 10),
        eval_strategy="epoch",  
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        save_total_limit=2,  # Keep only the last 2 checkpoints
        seed=seed,
        logging_dir=os.path.join(output_dir, "logs"),
        logging_strategy="epoch",
        report_to="none",  # Disable TensorBoard to avoid dependency error
        fp16=torch.cuda.is_available(),  # Enable mixed precision on GPU
        gradient_accumulation_steps=2,  # Add gradient accumulation for stability
        max_grad_norm=1.0,  # Gradient clipping to prevent exploding gradients
    )
    
    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,  # Add collator for dynamic padding
        callbacks=[EarlyStoppingCallback(early_stopping_patience=patience, early_stopping_threshold=0.001)],
    )
    
    # Start training
    print(f"🚀 Starting training for Khmer NER...")
    print(f"   Epochs: {num_epochs}")
    print(f"   Per-device batch size: {batch_size}")
    print(f"   Gradient accumulation steps: {training_args.gradient_accumulation_steps}")
    print(f"   Effective batch size: {batch_size * training_args.gradient_accumulation_steps}")
    print(f"   Learning rate: {learning_rate}")
    print(f"   Early stopping patience: {patience} epochs")
    
    try:
        trainer.train()
        print("✅ Training completed!")
    except RuntimeError as e:
        print(f"❌ Training failed: {e}")
        if "out of memory" in str(e).lower() and device.type == "cuda":
            print("🔄 GPU memory error. Try reducing batch_size or increasing gradient_accumulation_steps.")
            print(f"   Current batch_size: {batch_size}")
            print(f"   Current gradient_accumulation_steps: {training_args.gradient_accumulation_steps}")
            print("   Suggested: batch_size=2, gradient_accumulation_steps=4")
        raise e
    
    # Save final model
    final_model_path = os.path.join(output_dir, "final_model")
    model.save_pretrained(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"💾 Final model saved to {final_model_path}")
    
    return model


# Evaluate Model 

In [8]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import json
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report

def evaluate_model_standard(model, tokenizer, test_dataset, id2label, output_dir, batch_size=8):
    """
    Standard evaluation for NER models following best practices.
    Uses both token-level and entity-level (seqeval) metrics.
    """
    print("🔍 Evaluating model on test set...")
    
    device = next(model.parameters()).device
    model.eval()
    
    # Create data loader with proper collate function
    def collate_fn_eval(batch):
        """Custom collate function for evaluation DataLoader with padding"""
        max_length = max(len(item['input_ids']) for item in batch)
        
        input_ids_list = []
        attention_mask_list = []
        labels_list = []
        
        for item in batch:
            seq_len = len(item['input_ids'])
            padding_length = max_length - seq_len
            
            # Use tokenizer's pad_token_id, fallback to 0
            pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
            
            # Pad sequences
            padded_input_ids = torch.cat([
                item['input_ids'], 
                torch.full((padding_length,), pad_token_id, dtype=item['input_ids'].dtype)
            ])
            
            padded_attention_mask = torch.cat([
                item['attention_mask'],
                torch.zeros(padding_length, dtype=item['attention_mask'].dtype)
            ])
            
            # Pad labels with -100 (ignore index for loss calculation)
            padded_labels = torch.cat([
                item['labels'],
                torch.full((padding_length,), -100, dtype=item['labels'].dtype)
            ])
            
            input_ids_list.append(padded_input_ids)
            attention_mask_list.append(padded_attention_mask)
            labels_list.append(padded_labels)
        
        return {
            'input_ids': torch.stack(input_ids_list),
            'attention_mask': torch.stack(attention_mask_list),
            'labels': torch.stack(labels_list)
        }
    
    test_loader = DataLoader(
        test_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        collate_fn=collate_fn_eval,
        num_workers=0
    )
    
    # Collections for different evaluation approaches
    all_predictions = []
    seqeval_true_labels = []  # For entity-level evaluation
    seqeval_pred_labels = []  # For entity-level evaluation
    token_true_labels = []    # For token-level evaluation  
    token_pred_labels = []    # For token-level evaluation
    
    print("Running inference...")
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_loader, desc="Evaluating")):
            # Move to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels']
            
            # Get model predictions
            outputs = model(input_ids, attention_mask=attention_mask)
            predictions = outputs.logits.argmax(dim=-1).cpu().numpy()
            true_labels = labels.cpu().numpy()
            
            # Process each sample in the batch
            for sample_idx in range(input_ids.size(0)):
                tokens = tokenizer.convert_ids_to_tokens(input_ids[sample_idx].cpu())
                sample_predictions = predictions[sample_idx]
                sample_labels = true_labels[sample_idx]
                
                # Extract valid predictions (ignore padding and special tokens)
                seq_pred_tags = []
                seq_true_tags = []
                valid_tokens = []
                
                for pred_id, true_id, token in zip(sample_predictions, sample_labels, tokens):
                    # Skip padded tokens and special tokens
                    if true_id != -100 and not token.startswith('[') and token not in ['<s>', '</s>', '<pad>', '<unk>']:
                        pred_tag = id2label[pred_id]
                        true_tag = id2label[true_id]
                        
                        seq_pred_tags.append(pred_tag)
                        seq_true_tags.append(true_tag)
                        valid_tokens.append(token)
                        
                        # For token-level metrics
                        token_pred_labels.append(pred_tag)
                        token_true_labels.append(true_tag)
                
                # Add to seqeval collections (sequence-level)
                if seq_pred_tags and seq_true_tags:
                    seqeval_pred_labels.append(seq_pred_tags)
                    seqeval_true_labels.append(seq_true_tags)
                    
                    # Store detailed predictions
                    all_predictions.append({
                        'tokens': valid_tokens,
                        'true_tags': seq_true_tags,
                        'pred_tags': seq_pred_tags
                    })
    
    print(f"✅ Evaluation completed on {len(all_predictions)} samples")
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Save detailed predictions
    save_predictions(all_predictions, output_dir)
    
    # Calculate and display metrics
    metrics = calculate_comprehensive_metrics(
        seqeval_true_labels, seqeval_pred_labels,
        token_true_labels, token_pred_labels,
        output_dir
    )
    
    return all_predictions, seqeval_true_labels, seqeval_pred_labels, metrics

def save_predictions(all_predictions, output_dir):
    """Save predictions in multiple formats"""
    
    # Save as JSON
    json_path = os.path.join(output_dir, "predictions.json")
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(all_predictions, f, ensure_ascii=False, indent=2)
    print(f"💾 Detailed predictions saved to {json_path}")
    
    # Save as CoNLL format
    conll_path = os.path.join(output_dir, "predictions.conll")
    with open(conll_path, 'w', encoding='utf-8') as f:
        for pred in all_predictions:
            for token, true_tag, pred_tag in zip(pred['tokens'], pred['true_tags'], pred['pred_tags']):
                f.write(f"{token}\t{true_tag}\t{pred_tag}\n")
            f.write("\n")  # Empty line between sentences
    print(f"💾 CoNLL format saved to {conll_path}")

def serialize_classification_report(report_dict):
    """Convert classification report to JSON-serializable format"""
    serializable = {}
    for key, value in report_dict.items():
        if isinstance(key, tuple):
            # Convert tuple keys to string representation
            key = str(key)
        elif key is None:
            key = "none"
        
        if isinstance(value, dict):
            serializable[key] = serialize_classification_report(value)
        else:
            serializable[key] = value
    return serializable

def analyze_confusion_patterns(true_labels, pred_labels):
    """Analyze confusion patterns between predictions and true labels"""
    confusion_patterns = {}
    
    # Count misclassifications
    misclassifications = Counter()
    correct_predictions = Counter()
    
    for true_label, pred_label in zip(true_labels, pred_labels):
        if true_label == pred_label:
            correct_predictions[true_label] += 1
        else:
            misclassifications[(true_label, pred_label)] += 1
    
    # Convert to serializable format
    confusion_patterns['misclassifications'] = {f"{true_label}→{pred_label}": count 
                                                for (true_label, pred_label), count in misclassifications.most_common(10)}
    confusion_patterns['correct_predictions'] = dict(correct_predictions)
    
    return confusion_patterns

def calculate_comprehensive_metrics(seqeval_true_labels, seqeval_pred_labels, 
                                  token_true_labels, token_pred_labels, output_dir):
    """Calculate comprehensive NER evaluation metrics"""
    
    print("\n📊 Comprehensive NER Evaluation Results:")
    print("=" * 80)
    
    metrics = {}
    
    # 1. Entity-level metrics (Standard for NER - using seqeval)
    print("🎯 ENTITY-LEVEL METRICS (Primary - Standard for NER)")
    print("-" * 50)
    
    try:
        entity_precision = seq_precision(seqeval_true_labels, seqeval_pred_labels)
        entity_recall = seq_recall(seqeval_true_labels, seqeval_pred_labels) 
        entity_f1 = seq_f1(seqeval_true_labels, seqeval_pred_labels)
        entity_accuracy = accuracy_score(seqeval_true_labels, seqeval_pred_labels)
        
        print(f"Entity Precision: {entity_precision:.4f}")
        print(f"Entity Recall:    {entity_recall:.4f}")
        print(f"Entity F1-Score:  {entity_f1:.4f}")
        print(f"Entity Accuracy:  {entity_accuracy:.4f}")
        
        metrics.update({
            'entity_precision': entity_precision,
            'entity_recall': entity_recall,
            'entity_f1_score': entity_f1,
            'entity_accuracy': entity_accuracy
        })
        
        # Detailed entity-level classification report
        # Note: seqeval classification_report returns a string, not a dict
        entity_report_str = seq_classification_report(seqeval_true_labels, seqeval_pred_labels)
        print(f"\n📋 Entity-level Classification Report:")
        print(entity_report_str)
        
        # Store the string report (seqeval doesn't provide dict output)
        metrics['entity_classification_report'] = entity_report_str
        
    except Exception as e:
        print(f"❌ Error calculating entity-level metrics: {e}")
        metrics['entity_error'] = str(e)
    
    print("\n" + "="*80)
    
    # 2. Token-level metrics
    print("🔤 TOKEN-LEVEL METRICS (Secondary)")
    print("-" * 50)
    
    try:
        # Overall token accuracy
        token_accuracy = sum(1 for t, p in zip(token_true_labels, token_pred_labels) if t == p) / len(token_true_labels)
        print(f"Token Accuracy: {token_accuracy:.4f}")
        
        # Get unique labels for token-level metrics
        unique_labels = sorted(list(set(token_true_labels + token_pred_labels)))
        
        # Token-level metrics (micro and macro averages)
        token_precision_micro = precision_score(token_true_labels, token_pred_labels, 
                                               labels=unique_labels, average='micro', zero_division=0)
        token_recall_micro = recall_score(token_true_labels, token_pred_labels, 
                                         labels=unique_labels, average='micro', zero_division=0)
        token_f1_micro = f1_score(token_true_labels, token_pred_labels, 
                                 labels=unique_labels, average='micro', zero_division=0)
        
        token_precision_macro = precision_score(token_true_labels, token_pred_labels, 
                                               labels=unique_labels, average='macro', zero_division=0)
        token_recall_macro = recall_score(token_true_labels, token_pred_labels, 
                                         labels=unique_labels, average='macro', zero_division=0)
        token_f1_macro = f1_score(token_true_labels, token_pred_labels, 
                                 labels=unique_labels, average='macro', zero_division=0)
        
        print(f"Token Precision (Micro): {token_precision_micro:.4f}")
        print(f"Token Recall (Micro):    {token_recall_micro:.4f}")
        print(f"Token F1-Score (Micro):  {token_f1_micro:.4f}")
        print(f"Token Precision (Macro): {token_precision_macro:.4f}")
        print(f"Token Recall (Macro):    {token_recall_macro:.4f}")
        print(f"Token F1-Score (Macro):  {token_f1_macro:.4f}")
        
        metrics.update({
            'token_accuracy': token_accuracy,
            'token_precision_micro': token_precision_micro,
            'token_recall_micro': token_recall_micro,
            'token_f1_micro': token_f1_micro,
            'token_precision_macro': token_precision_macro,
            'token_recall_macro': token_recall_macro,
            'token_f1_macro': token_f1_macro
        })
        
        # Token-level classification report (using sklearn)
        token_report = classification_report(token_true_labels, token_pred_labels, 
                                           labels=unique_labels, zero_division=0, output_dict=True)
        print(f"\n📋 Token-level Classification Report:")
        print(classification_report(token_true_labels, token_pred_labels, 
                                   labels=unique_labels, zero_division=0))
        
        # Serialize the classification report for JSON storage
        metrics['token_classification_report'] = serialize_classification_report(token_report)
        
    except Exception as e:
        print(f"❌ Error calculating token-level metrics: {e}")
        metrics['token_error'] = str(e)
    
    # 3. Additional analysis
    print("\n" + "="*80)
    print("📈 ADDITIONAL ANALYSIS")
    print("-" * 50)
    
    # Count statistics
    total_tokens = len(token_true_labels)
    entity_tokens = sum(1 for label in token_true_labels if label != 'O')
    o_tokens = total_tokens - entity_tokens
    
    print(f"Total tokens: {total_tokens:,}")
    print(f"Entity tokens: {entity_tokens:,} ({entity_tokens/total_tokens*100:.1f}%)")
    print(f"O tokens: {o_tokens:,} ({o_tokens/total_tokens*100:.1f}%)")
    print(f"Total sequences: {len(seqeval_true_labels):,}")
    
    # Entity type distribution
    entity_counts = Counter([label for label in token_true_labels if label != 'O'])
    print(f"\n🏷️ Entity Type Distribution:")
    for entity_type, count in entity_counts.most_common():
        print(f"  {entity_type}: {count:,}")
    
    metrics.update({
        'total_tokens': total_tokens,
        'entity_tokens': entity_tokens,
        'o_tokens': o_tokens,
        'total_sequences': len(seqeval_true_labels),
        'entity_distribution': dict(entity_counts)
    })
    
    # 4. Error analysis
    confusion_analysis = analyze_confusion_patterns(token_true_labels, token_pred_labels)
    metrics['confusion_analysis'] = confusion_analysis
    
    # Save comprehensive metrics
    metrics_path = os.path.join(output_dir, "evaluation_metrics.json")
    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(metrics, f, ensure_ascii=False, indent=2, default=str)
    print(f"\n💾 Comprehensive metrics saved to {metrics_path}")
    
    return metrics


# Main Execution

In [9]:
print("📥 Loading data...")
# If load_json_files_robust uses glob, make sure to use glob.glob, not glob()
all_tokens, all_tags, input_files = load_json_files_robust(input_dir)

if not all_tokens:
    print("❌ No data loaded. Please check your input directory and file format.")
    exit()

print(f"✅ Successfully loaded {len(all_tokens)} sentences from {len(input_files)} files")

# Check for duplicate sentences to prevent data leakage
sentence_strings = [" ".join(tokens) for tokens in all_tokens]
duplicate_counts = Counter(sentence_strings)
duplicates = {sent: count for sent, count in duplicate_counts.items() if count > 1}
if duplicates:
    print(f"⚠️ Warning: {len(duplicates)} duplicate sentences found:")
    for sent, count in list(duplicates.items())[:5]:  # Show up to 5 duplicates
        print(f"  Sentence: '{sent[:50]}...' (repeated {count} times)")
else:
    print("✅ No duplicate sentences found.")

# Set random seed
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Create label mappings
unique_labels = set(tag for tags in all_tags for tag in tags)
unique_labels.add('O')
sorted_labels = sorted(unique_labels)
label2id = {label: idx for idx, label in enumerate(sorted_labels)}
id2label = {idx: label for label, idx in label2id.items()}

print(f"🏷️ Label mapping created:")
print(f"  Total labels: {len(label2id)}")
print(f"  Labels: {sorted_labels}")

# Function to print entity distribution
def print_entity_distribution(tags, name):
    entity_counts = Counter(tag for sentence in tags for tag in sentence if tag != 'O')
    total_entities = sum(entity_counts.values())
    print(f"\n📊 {name} entity distribution:")
    table_data = [[tag, count, f"{(count/total_entities)*100:.2f}%"] for tag, count in sorted(entity_counts.items())]
    headers = ['Entity', 'Count', 'Percentage']
    print(tabulate(table_data, headers=headers, tablefmt='grid'))


# Split data
train_tokens, temp_tokens, train_tags, temp_tags = train_test_split(
    all_tokens, all_tags, test_size=0.2, random_state=seed, stratify=None
)

val_tokens, test_tokens, val_tags, test_tags = train_test_split(
    temp_tokens, temp_tags, test_size=0.5, random_state=seed, stratify=None
)

print(f"\n📊 Data split:")
print(f"  Training: {len(train_tokens):,} sentences")
print(f"  Validation: {len(val_tokens):,} sentences")
print(f"  Test: {len(test_tokens):,} sentences")

# Verify entity distribution in each split
print_entity_distribution(train_tags, "Training")
print_entity_distribution(val_tags, "Validation")
print_entity_distribution(test_tags, "Test")

# Save split information
split_info = {
    "train_sentences": len(train_tokens),
    "val_sentences": len(val_tokens),
    "test_sentences": len(test_tokens),
    "train_entity_distribution": dict(Counter(tag for sentence in train_tags for tag in sentence if tag != 'O')),
    "val_entity_distribution": dict(Counter(tag for sentence in val_tags for tag in sentence if tag != 'O')),
    "test_entity_distribution": dict(Counter(tag for sentence in test_tags for tag in sentence if tag != 'O'))
}
split_info_path = os.path.join(output_dir, "split_info.json")
os.makedirs(output_dir, exist_ok=True)
with open(split_info_path, 'w', encoding='utf-8') as f:
    json.dump(split_info, f, ensure_ascii=False, indent=2)
print(f"💾 Split information saved to {split_info_path}")

# Load model and tokenizer
model_name = "GKLMIP/bert-khmer-small-uncased"
print(f"🤖 Loading model: {model_name}")

try:
    tokenizer = BertTokenizerFast.from_pretrained(model_name)
    model = BertForTokenClassification.from_pretrained(
        model_name,
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True
    )
    print("✅ Model and tokenizer loaded successfully")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    exit()

# Prepare datasets
print("🔄 Preparing datasets...")
train_encodings = prepare_dataset_robust(train_tokens, train_tags, tokenizer, label2id)
val_encodings = prepare_dataset_robust(val_tokens, val_tags, tokenizer, label2id)
test_encodings = prepare_dataset_robust(test_tokens, test_tags, tokenizer, label2id)

train_dataset = NERDataset(train_encodings)
val_dataset = NERDataset(val_encodings)
test_dataset = NERDataset(test_encodings)

# Fine-tune model
print("🎓 Starting model training...")
trained_model = train_model_trainer(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    output_dir=output_dir,
    tokenizer=tokenizer,
    num_epochs=10,
    batch_size=8,
    learning_rate=2e-5
)

# Evaluate model
print("📊 Starting model evaluation...")
predictions, true_tags, pred_tags, metrics = evaluate_model_standard(
    model=trained_model,
    tokenizer=tokenizer,
    test_dataset=test_dataset,
    id2label=id2label,
    output_dir=output_dir
)



# Save label mappings
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, "label2id.json"), 'w', encoding='utf-8') as f:
    json.dump(label2id, f, ensure_ascii=False, indent=2)
with open(os.path.join(output_dir, "id2label.json"), 'w', encoding='utf-8') as f:
    json.dump(id2label, f, ensure_ascii=False, indent=2)
print(f"💾 Label mappings saved to {output_dir}")

print("🎉 Fine-tuning and evaluation completed!")

📥 Loading data...
📂 Processing 525 files...
  📄 Processing: object_d2279a49-8b25-4e4c-b936-62f88487a895.json
  📄 Processing: object_e1a96cb5-7830-4846-a601-0a0357497b28.json
  📄 Processing: object_792495e7-39e6-4902-90b7-5edecd04877b.json
  📄 Processing: object_1fe657cc-df5b-4309-aaf2-e25e485b8ff4.json
  📄 Processing: object_04d28d27-96e9-4d95-b946-c1475b2207d2.json

📊 Data loading summary:
  ✅ Processed sentences: 6221
  ⚠️  Skipped sentences: 12
  ❌ Total errors: 0
✅ Successfully loaded 6221 sentences from 525 files
  Sentence: 'លោក ថ្លែង ថា៖...' (repeated 8 times)
  Sentence: 'លោក ថ្លែង ដោយ ប្រែ សម្រួល ជា ភាសា ខ្មែរ ថា៖...' (repeated 2 times)
  Sentence: '១៖...' (repeated 11 times)
  Sentence: '( ១) ៖...' (repeated 2 times)
  Sentence: 'ក្រសួង សុខាភិបាល   បាន អះអាង ថា   ការផ្សព្វផ្សាយ ផ...' (repeated 2 times)
🏷️ Label mapping created:
  Total labels: 17
  Labels: ['B-Date', 'B-Disease', 'B-HumanCount', 'B-Location', 'B-Medication', 'B-Organization', 'B-Pathogen', 'B-Symptom', 'I-Dat

Some weights of BertForTokenClassification were not initialized from the model checkpoint at GKLMIP/bert-khmer-small-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ Model and tokenizer loaded successfully
🔄 Preparing datasets...
🔄 Preparing dataset...
✅ Prepared 4976 examples, skipped 0
🔄 Preparing dataset...
✅ Prepared 622 examples, skipped 0
🔄 Preparing dataset...
✅ Prepared 623 examples, skipped 0
🎓 Starting model training...
🖥️  Using device: cuda
   GPU: NVIDIA GeForce RTX 4070 Ti SUPER
   Memory: 16.9 GB
🚀 Starting training for Khmer NER...
   Epochs: 10
   Per-device batch size: 8
   Gradient accumulation steps: 2
   Effective batch size: 16
   Learning rate: 2e-05
   Early stopping patience: 3 epochs


Epoch,Training Loss,Validation Loss
1,0.61,0.207286
2,0.1894,0.172495
3,0.1453,0.153441
4,0.1244,0.148791
5,0.1092,0.150308
6,0.0988,0.151319
7,0.0927,0.153856


✅ Training completed!
💾 Final model saved to /home/guest/Public/KHEED/KHEED_Data_Collection/Evaluation/Models/bert_khmer_ner_model/final_model
📊 Starting model evaluation...
🔍 Evaluating model on test set...
Running inference...


Evaluating: 100%|██████████| 78/78 [00:00<00:00, 455.37it/s]


✅ Evaluation completed on 623 samples
💾 Detailed predictions saved to /home/guest/Public/KHEED/KHEED_Data_Collection/Evaluation/Models/bert_khmer_ner_model/predictions.json
💾 CoNLL format saved to /home/guest/Public/KHEED/KHEED_Data_Collection/Evaluation/Models/bert_khmer_ner_model/predictions.conll

📊 Comprehensive NER Evaluation Results:
🎯 ENTITY-LEVEL METRICS (Primary - Standard for NER)
--------------------------------------------------
Entity Precision: 0.6575
Entity Recall:    0.7595
Entity F1-Score:  0.7048
Entity Accuracy:  0.9531

📋 Entity-level Classification Report:
              precision    recall  f1-score   support

     Symptom       0.37      0.59      0.46        32
Organization       0.53      0.66      0.59       270
    Location       0.72      0.82      0.77       260
        Date       0.82      0.84      0.83       154
     Disease       0.73      0.80      0.77       287
  HumanCount       0.71      0.84      0.77        91
  Medication       0.00      0.00    