# Imports and Environment Setup

In [1]:
import json
import os
import torch
from torch.utils.data import DataLoader, default_collate
from transformers import MBartConfig, MBartModel
from sklearn.model_selection import train_test_split
from seqeval.metrics import precision_score, recall_score, f1_score
import glob
from tqdm import tqdm
from khmernltk import word_tokenize

# Suppress tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Set seed for reproducibility
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Verify CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
    torch.cuda.empty_cache()  # Clear CUDA cache

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
CUDA Device: NVIDIA GeForce RTX 4070 Ti SUPER


# Custom Collate Class

In [2]:
def custom_collate_fn(batch):
    batch = [{k: v for k, v in sample.items() if k in ['input_ids', 'attention_mask', 'labels']} for sample in batch]
    return default_collate(batch)

# Data Loading Function

In [3]:
def load_json_files(input_dir, input_files):
    all_tokens = []
    all_tags = []
    unique_tags = set()
    
    print(f"Processing {len(input_files)} files...")
    for input_file in input_files:
        if not os.path.exists(input_file):
            print(f"Warning: File {input_file} not found.")
            continue
        try:
            with open(input_file, 'r', encoding='utf-8') as f:
                obj = json.load(f)
        except json.JSONDecodeError:
            print(f"Error: Invalid JSON in {input_file}")
            continue
        processed_content = obj.get('processed_content', [])
        if not processed_content:
            print(f"No processed_content in {input_file}")
            continue
        for sentence_idx, sentence_data in enumerate(processed_content):
            tokens = sentence_data.get('tokens', [])
            bio_tags = sentence_data.get('bio_tags', [])
            if not tokens or not bio_tags:
                print(f"Warning: Empty tokens or bio_tags in {input_file}, sentence {sentence_idx}")
                continue
            flattened_tags = [tags if isinstance(tags, str) else tags[0] if tags else "O" for tags in bio_tags]
            if len(tokens) != len(flattened_tags):
                print(f"Warning: Mismatch in {input_file}, sentence {sentence_idx}: "
                      f"{len(tokens)} tokens, {len(flattened_tags)} tags")
                continue
            all_tokens.append(tokens)
            all_tags.append(flattened_tags)
            unique_tags.update(flattened_tags)
    
    print(f"Loaded {len(all_tokens)} sentences")
    print(f"Unique BIO tags: {sorted(unique_tags)}")
    if not all_tokens:
        raise ValueError("No valid data loaded. Check file names, JSON structure, or data content.")
    return all_tokens, all_tags

# Define path and Output Directory

In [4]:
# Define paths for Ubuntu
input_dir = "/home/guest/Public/KHEED/KHEED_Data_Collection/Final/bio_tagged"
output_dir = "/home/guest/Public/KHEED/KHEED_Data_Collection/Evaluation/prahokbart_ner_model"
os.makedirs(output_dir, exist_ok=True)

# Verify input directory
print(f"Input directory exists: {os.path.exists(input_dir)}")
print(f"Files in directory: {os.listdir(input_dir)[:5]}")

# Load input files
input_files = sorted(glob.glob(os.path.join(input_dir, "object_*.json")))
if not input_files:
    raise FileNotFoundError(f"No files matching 'object_*.json' found in {input_dir}")
print(f"Found {len(input_files)} files: {input_files[:5]}")

# Save input_files to JSON
input_files_path = os.path.join(output_dir, "input_files.json")
try:
    with open(input_files_path, 'w', encoding='utf-8') as f:
        json.dump(input_files, f)
    print(f"Updated {input_files_path} with {len(input_files)} files")
except PermissionError:
    raise PermissionError(f"Cannot write to {input_files_path}. Check permissions.")

Input directory exists: True
Files in directory: ['object_d2279a49-8b25-4e4c-b936-62f88487a895.json', 'object_e1a96cb5-7830-4846-a601-0a0357497b28.json', 'object_792495e7-39e6-4902-90b7-5edecd04877b.json', 'object_1fe657cc-df5b-4309-aaf2-e25e485b8ff4.json', 'object_04d28d27-96e9-4d95-b946-c1475b2207d2.json']
Found 525 files: ['/home/guest/Public/KHEED/KHEED_Data_Collection/Final/bio_tagged/object_014f8a42-78bf-4e17-970b-c74bbb61a812.json', '/home/guest/Public/KHEED/KHEED_Data_Collection/Final/bio_tagged/object_017d1530-4fa2-4a7a-a1cb-0621640d579d.json', '/home/guest/Public/KHEED/KHEED_Data_Collection/Final/bio_tagged/object_01c58b61-821c-43e0-a7c6-8a858837fb0f.json', '/home/guest/Public/KHEED/KHEED_Data_Collection/Final/bio_tagged/object_0269da05-9c7b-447c-84c8-e3060217538d.json', '/home/guest/Public/KHEED/KHEED_Data_Collection/Final/bio_tagged/object_0274d277-8a68-4dae-8911-89da4097ba9c.json']
Updated /home/guest/Public/KHEED/KHEED_Data_Collection/Evaluation/prahokbart_ner_model/input

# Load Tag Mapping

In [5]:
# Load tag2idx
tag2idx_path = os.path.join(output_dir, "tag2idx.json")
if not os.path.exists(tag2idx_path):
    raise FileNotFoundError(f"Missing tag2idx.json in {output_dir}")
try:
    with open(tag2idx_path, 'r', encoding='utf-8') as f:
        tag2idx = json.load(f)
except PermissionError:
    raise PermissionError(f"Cannot read {tag2idx_path}. Check permissions.")

idx2tag = {int(idx): tag for tag, idx in tag2idx.items()}
print(f"Loaded {len(tag2idx)} tags: {list(tag2idx.keys())}")

Loaded 17 tags: ['B-Date', 'B-Disease', 'B-HumanCount', 'B-Location', 'B-Medication', 'B-Organization', 'B-Pathogen', 'B-Symptom', 'I-Date', 'I-Disease', 'I-HumanCount', 'I-Location', 'I-Medication', 'I-Organization', 'I-Pathogen', 'I-Symptom', 'O']


# Load and Split Data

In [6]:
# Load data
all_tokens, all_tags = load_json_files(input_dir, input_files)
if not all_tokens:
    raise ValueError("No data loaded. Check input files or directory.")

# Verify tag2idx
unique_tags = set(tag for tags in all_tags for tag in tags)
missing_tags = unique_tags - set(tag2idx.keys())
if missing_tags:
    print(f"Warning: Tags in data not in tag2idx: {missing_tags}")
    tag2idx = {tag: idx for idx, tag in enumerate(sorted(unique_tags))}
    try:
        with open(tag2idx_path, 'w', encoding='utf-8') as f:
            json.dump(tag2idx, f)
        print(f"Updated {tag2idx_path} with {len(tag2idx)} tags")
    except PermissionError:
        raise PermissionError(f"Cannot write to {tag2idx_path}. Check permissions.")

# Split data
train_tokens, temp_tokens, train_tags, temp_tags = train_test_split(
        all_tokens, all_tags, test_size=0.2, random_state=42, stratify=None
    )
    
val_tokens, test_tokens, val_tags, test_tags = train_test_split(
    temp_tokens, temp_tags, test_size=0.5, random_state=42, stratify=None
)

print(f"Train: {len(train_tokens)}, Val: {len(val_tokens)}, Test: {len(test_tokens)}")

Processing 525 files...
Loaded 6221 sentences
Unique BIO tags: ['B-Date', 'B-Disease', 'B-HumanCount', 'B-Location', 'B-Medication', 'B-Organization', 'B-Pathogen', 'B-Symptom', 'I-Date', 'I-Disease', 'I-HumanCount', 'I-Location', 'I-Medication', 'I-Organization', 'I-Pathogen', 'I-Symptom', 'O']
Train: 4976, Val: 622, Test: 623


# Define Dataset Class

In [7]:
class PrahokBARTNERDataset(torch.utils.data.Dataset):
    def __init__(self, tokens, tags, word2idx, tag2idx, max_len=128):
        self.tokens = tokens
        self.tags = tags
        self.word2idx = word2idx  # Custom word2idx dictionary
        self.tag2idx = tag2idx
        self.max_len = max_len

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        tokens = self.tokens[idx]
        tags = self.tags[idx]

        # Tokenize with khmernltk
        if isinstance(tokens, str):
            khmer_tokens = word_tokenize(tokens)
        else:
            khmer_tokens = tokens  # Assume pre-tokenized

        # Convert tokens to input IDs using word2idx
        input_ids = []
        for token in khmer_tokens:
            # Use word2idx ID or <unk> (ID 3) if token not found
            input_ids.append(self.word2idx.get(token, self.word2idx.get('<unk>', 3)))

        # Truncate or pad to max_len
        if len(input_ids) > self.max_len:
            input_ids = input_ids[:self.max_len]
            aligned_tags = tags[:self.max_len]
        else:
            input_ids += [self.word2idx.get('<pad>', 1)] * (self.max_len - len(input_ids))
            aligned_tags = tags + ['O'] * (self.max_len - len(tags)) if len(tags) < self.max_len else tags[:self.max_len]

        # Create attention mask
        attention_mask = [1 if idx != self.word2idx.get('<pad>', 1) else 0 for idx in input_ids]

        # Align tags with tokens
        labels = [-100] * self.max_len
        for i in range(min(len(khmer_tokens), self.max_len)):
            if i < len(aligned_tags):
                labels[i] = self.tag2idx.get(aligned_tags[i], self.tag2idx.get('O', 0))

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long)
        }

# Define Model Class

In [8]:
class PrahokBARTForNER(torch.nn.Module):
    def __init__(self, model_name, num_labels):
        super().__init__()
        self.mbart = MBartModel.from_pretrained(model_name)
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(self.mbart.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.mbart(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return loss, logits

# Initialize Tokenizer, Dataset, Dataloaders

In [9]:
# Initialize word2idx
word2idx_path = os.path.join(output_dir, "word2idx.json")
if not os.path.exists(word2idx_path):
    raise FileNotFoundError(f"Missing word2idx.json in {output_dir}")
try:
    with open(word2idx_path, 'r', encoding='utf-8') as f:
        word2idx = json.load(f)
except Exception as e:
    raise Exception(f"Failed to load word2idx from {word2idx_path}: {str(e)}")
print(f"Loaded word2idx with {len(word2idx)} tokens")

# Create datasets
train_dataset = PrahokBARTNERDataset(train_tokens, train_tags, word2idx, tag2idx)
val_dataset = PrahokBARTNERDataset(val_tokens, val_tags, word2idx, tag2idx)
test_dataset = PrahokBARTNERDataset(test_tokens, test_tags, word2idx, tag2idx)

# Create data loaders
batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=custom_collate_fn)

print(f"Data loaders created: {len(train_loader)} train batches, {len(val_loader)} val batches, {len(test_loader)} test batches")

Loaded word2idx with 9705 tokens
Data loaders created: 1244 train batches, 156 val batches, 156 test batches


# Training Loop

In [10]:
model_name = "nict-astrec-att/prahokbart_base"

# Initialize model and optimizer
model = PrahokBARTForNER(model_name, num_labels=len(tag2idx)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Early stopping parameters
best_val_loss = float('inf')
patience_counter = 0
patience = 3
num_epochs = 20

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    train_progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", mininterval=1, leave=True)
    for batch_idx, batch in enumerate(train_progress):
        try:
            # Debug: Print batch info
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            optimizer.zero_grad()
            loss, _ = model(input_ids, attention_mask, labels)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            train_progress.set_postfix({'batch_loss': loss.item()})
            
        
        except RuntimeError as e:
            print(f"Error in batch {batch_idx+1}: {str(e)}")
            if "CUDA out of memory" in str(e):
                print("CUDA out of memory. Try reducing batch_size or max_len.")
                torch.cuda.empty_cache()
                raise
            raise
        
    avg_train_loss = total_loss / len(train_loader)
    
    # Validation loop
    model.eval()
    total_val_loss = 0
    val_progress = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]", mininterval=1, leave=True)
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_progress):
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                
                loss, _ = model(input_ids, attention_mask, labels)
                total_val_loss += loss.item()
                val_progress.set_postfix({'batch_loss': loss.item()})
            
            except RuntimeError as e:
                print(f"Error in validation batch {batch_idx+1}: {str(e)}")
                if "CUDA out of memory" in str(e):
                    print("CUDA out of memory in validation. Try reducing batch_size.")
                    torch.cuda.empty_cache()
                    raise
                raise
    
    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
    
    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        try:
            torch.save(model.state_dict(), os.path.join(output_dir, "prahokbart_ner.pt"))
            print(f"Saved best model at epoch {epoch+1}")
        except PermissionError:
            raise PermissionError(f"Cannot write to {os.path.join(output_dir, 'prahokbart_ner.pt')}. Check permissions.")
        patience_counter = 0
    else:
        patience_counter += 1
    
    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch+1}")
        break

Epoch 1/20 [Train]:   0%|          | 0/1244 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Epoch 1/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.68it/s, batch_loss=1.33]  
Epoch 1/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 232.93it/s, batch_loss=0.546]


Epoch 1/20, Train Loss: 0.7207, Val Loss: 0.4989
Saved best model at epoch 1


Epoch 2/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.84it/s, batch_loss=0.309] 
Epoch 2/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 235.58it/s, batch_loss=0.45]


Epoch 2/20, Train Loss: 0.4680, Val Loss: 0.3723
Saved best model at epoch 2


Epoch 3/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.83it/s, batch_loss=0.444] 
Epoch 3/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 236.00it/s, batch_loss=0.324]


Epoch 3/20, Train Loss: 0.3683, Val Loss: 0.3043
Saved best model at epoch 3


Epoch 4/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.85it/s, batch_loss=0.164] 
Epoch 4/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 234.87it/s, batch_loss=0.25]


Epoch 4/20, Train Loss: 0.3032, Val Loss: 0.2561
Saved best model at epoch 4


Epoch 5/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.78it/s, batch_loss=0.274]  
Epoch 5/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 235.61it/s, batch_loss=0.204]


Epoch 5/20, Train Loss: 0.2492, Val Loss: 0.2311
Saved best model at epoch 5


Epoch 6/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.15it/s, batch_loss=0.156]  
Epoch 6/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 235.00it/s, batch_loss=0.163]


Epoch 6/20, Train Loss: 0.2189, Val Loss: 0.2213
Saved best model at epoch 6


Epoch 7/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.78it/s, batch_loss=0.441]  
Epoch 7/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 234.76it/s, batch_loss=0.174]


Epoch 7/20, Train Loss: 0.1950, Val Loss: 0.2059
Saved best model at epoch 7


Epoch 8/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.58it/s, batch_loss=0.0695] 
Epoch 8/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 234.65it/s, batch_loss=0.148]


Epoch 8/20, Train Loss: 0.1757, Val Loss: 0.1967
Saved best model at epoch 8


Epoch 9/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.44it/s, batch_loss=0.468]  
Epoch 9/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 233.71it/s, batch_loss=0.151]


Epoch 9/20, Train Loss: 0.1561, Val Loss: 0.1937
Saved best model at epoch 9


Epoch 10/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.18it/s, batch_loss=0.121]  
Epoch 10/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 231.14it/s, batch_loss=0.134]


Epoch 10/20, Train Loss: 0.1445, Val Loss: 0.1871
Saved best model at epoch 10


Epoch 11/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.62it/s, batch_loss=0.0263] 
Epoch 11/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 234.87it/s, batch_loss=0.142]


Epoch 11/20, Train Loss: 0.1308, Val Loss: 0.1901


Epoch 12/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.61it/s, batch_loss=0.212]  
Epoch 12/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 235.43it/s, batch_loss=0.154]


Epoch 12/20, Train Loss: 0.1207, Val Loss: 0.1785
Saved best model at epoch 12


Epoch 13/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.58it/s, batch_loss=0.0498]  
Epoch 13/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 235.06it/s, batch_loss=0.145]


Epoch 13/20, Train Loss: 0.1129, Val Loss: 0.1801


Epoch 14/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.55it/s, batch_loss=0.107]  
Epoch 14/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 233.60it/s, batch_loss=0.163]


Epoch 14/20, Train Loss: 0.1029, Val Loss: 0.1827


Epoch 15/20 [Train]: 100%|██████████| 1244/1244 [00:26<00:00, 47.57it/s, batch_loss=0.141]  
Epoch 15/20 [Val]: 100%|██████████| 156/156 [00:00<00:00, 231.07it/s, batch_loss=0.167]

Epoch 15/20, Train Loss: 0.0941, Val Loss: 0.1838
Early stopping at epoch 15





# Evaluate and Save

In [11]:
# Create DataFrame for tabular formatting
import pandas as pd
import numpy as np
from seqeval.metrics import classification_report, precision_score, recall_score, f1_score
from collections import defaultdict
import torch
import json
import os
from tqdm import tqdm

# Load best model
try:
    model.load_state_dict(torch.load(os.path.join(output_dir, "prahokbart_ner.pt"), map_location=device))
    print("Loaded best model for evaluation")
except PermissionError:
    raise PermissionError(f"Cannot read {os.path.join(output_dir, 'prahokbart_ner.pt')}. Check permissions.")

# Evaluate on test set
model.eval()
all_true_tags = []
all_pred_tags = []
all_predictions = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating Test Set"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels']
        
        _, logits = model(input_ids, attention_mask)
        pred_tag_ids = torch.argmax(logits, dim=-1)
        
        for pred_ids, true_ids in zip(pred_tag_ids, labels):
            true_tags = []
            for idx in true_ids:
                idx_val = idx.item()
                if idx_val != -100 and 0 <= idx_val < len(idx2tag):
                    true_tags.append(idx2tag[idx_val])
            
            pred_tags = []
            for idx in pred_ids[:len(true_tags)]:
                idx_val = idx.item()
                if 0 <= idx_val < len(idx2tag):
                    pred_tags.append(idx2tag[idx_val])
                else:
                    pred_tags.append('O')  # fallback to 'O' if out of range
            
            all_true_tags.append(true_tags)
            all_pred_tags.append(pred_tags)
            all_predictions.append({
                'tokens': [],
                'true_tags': true_tags,
                'pred_tags': pred_tags
            })

# Save predictions to JSON
json_path = os.path.join(output_dir, "predictions.json")
try:
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(all_predictions, f, ensure_ascii=False, indent=2)
    print(f"Saved predictions to {json_path}")
except PermissionError:
    raise PermissionError(f"Cannot write to {json_path}. Check permissions.")

# Save predictions to CoNLL format
conll_path = os.path.join(output_dir, "predictions.conll")
try:
    with open(conll_path, 'w', encoding='utf-8') as f:
        for pred in all_predictions:
            for true_tag, pred_tag in zip(pred['true_tags'], pred['pred_tags']):
                f.write(f"_\t{true_tag}\t{pred_tag}\n")
            f.write("\n")
    print(f"Saved predictions to {conll_path}")
except PermissionError:
    raise PermissionError(f"Cannot write to {conll_path}. Check permissions.")

# Custom function to compute NER metrics
def compute_ner_metrics(true_tags_list, pred_tags_list):
    # Extract unique categories (excluding 'O' and IOB2 prefixes)
    categories = set()
    for tags in true_tags_list + pred_tags_list:
        for tag in tags:
            if tag != 'O':
                category = tag.split('-')[-1]  # Get category (e.g., 'Disease' from 'B-Disease')
                categories.add(category)
    categories = sorted(categories)

    # Initialize counters
    metrics = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0, 'support': 0})

    def extract_entities(tags):
        entities = []
        current_entity = None
        for i, tag in enumerate(tags):
            if tag == 'O':
                if current_entity:
                    entities.append((current_entity[0], current_entity[1], current_entity[2]))
                    current_entity = None
                continue
            
            if '-' in tag:
                prefix, category = tag.split('-', 1)
            else:
                prefix, category = 'B', tag  # Assume B- if no prefix
            
            if prefix == 'B' or (prefix == 'I' and not current_entity):
                if current_entity:
                    entities.append((current_entity[0], current_entity[1], current_entity[2]))
                current_entity = (category, i, i)
            elif prefix == 'I' and current_entity and current_entity[0] == category:
                current_entity = (current_entity[0], current_entity[1], i)
            else:
                if current_entity:
                    entities.append((current_entity[0], current_entity[1], current_entity[2]))
                    current_entity = None
                if prefix == 'B':
                    current_entity = (category, i, i)
        
        if current_entity:
            entities.append((current_entity[0], current_entity[1], current_entity[2]))
        return entities

    # Process each sequence
    for true_tags, pred_tags in zip(true_tags_list, pred_tags_list):
        true_entities = extract_entities(true_tags)
        pred_entities = extract_entities(pred_tags)

        # Group entities by category
        true_by_category = defaultdict(set)
        pred_by_category = defaultdict(set)
        
        for category, start, end in true_entities:
            true_by_category[category].add((start, end))
        
        for category, start, end in pred_entities:
            pred_by_category[category].add((start, end))

        # Calculate metrics for each category
        for category in categories:
            true_cat_entities = true_by_category[category]
            pred_cat_entities = pred_by_category[category]

            metrics[category]['support'] += len(true_cat_entities)
            metrics[category]['tp'] += len(true_cat_entities & pred_cat_entities)
            metrics[category]['fp'] += len(pred_cat_entities - true_cat_entities)
            metrics[category]['fn'] += len(true_cat_entities - pred_cat_entities)

    # Compute precision, recall, F1-score
    results = {}
    total_tp, total_fp, total_fn, total_support = 0, 0, 0, 0
    
    for category in categories:
        tp = metrics[category]['tp']
        fp = metrics[category]['fp']
        fn = metrics[category]['fn']
        support = metrics[category]['support']
        
        total_tp += tp
        total_fp += fp
        total_fn += fn
        total_support += support

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

        results[category] = {
            'precision': precision,
            'recall': recall,
            'f1-score': f1,
            'support': support
        }

    # Compute micro, macro, weighted averages
    micro_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
    micro_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
    micro_f1 = 2 * (micro_precision * micro_recall) / (micro_precision + micro_recall) if (micro_precision + micro_recall) > 0 else 0.0

    macro_precision = np.mean([results[cat]['precision'] for cat in categories]) if categories else 0.0
    macro_recall = np.mean([results[cat]['recall'] for cat in categories]) if categories else 0.0
    macro_f1 = np.mean([results[cat]['f1-score'] for cat in categories]) if categories else 0.0

    weighted_precision = sum(results[cat]['precision'] * results[cat]['support'] for cat in categories) / total_support if total_support > 0 else 0.0
    weighted_recall = sum(results[cat]['recall'] * results[cat]['support'] for cat in categories) / total_support if total_support > 0 else 0.0
    weighted_f1 = sum(results[cat]['f1-score'] * results[cat]['support'] for cat in categories) / total_support if total_support > 0 else 0.0

    results['micro avg'] = {'precision': micro_precision, 'recall': micro_recall, 'f1-score': micro_f1, 'support': total_support}
    results['macro avg'] = {'precision': macro_precision, 'recall': macro_recall, 'f1-score': macro_f1, 'support': total_support}
    results['weighted avg'] = {'precision': weighted_precision, 'recall': weighted_recall, 'f1-score': weighted_f1, 'support': total_support}

    return results, micro_precision, micro_recall, micro_f1

# Compute metrics
report_dict, overall_precision, overall_recall, overall_f1 = compute_ner_metrics(all_true_tags, all_pred_tags)

# Create DataFrame for display
categories = []
precisions = []
recalls = []
f1_scores = []
supports = []

# Add individual categories
for label, scores in report_dict.items():
    if label not in ['micro avg', 'macro avg', 'weighted avg']:
        categories.append(label)
        precisions.append(scores['precision'])
        recalls.append(scores['recall'])
        f1_scores.append(scores['f1-score'])
        supports.append(scores['support'])

# Add averages
for avg_type in ['micro avg', 'macro avg', 'weighted avg']:
    if avg_type in report_dict:
        categories.append(avg_type.title())
        precisions.append(report_dict[avg_type]['precision'])
        recalls.append(report_dict[avg_type]['recall'])
        f1_scores.append(report_dict[avg_type]['f1-score'])
        supports.append(report_dict[avg_type]['support'])

# Create DataFrame
df = pd.DataFrame({
    'Category': categories,
    'Precision': precisions,
    'Recall': recalls,
    'F1-Score': f1_scores,
    'Support': supports
})

# Format numbers to two decimal places
df['Precision'] = df['Precision'].round(4)
df['Recall'] = df['Recall'].round(4)
df['F1-Score'] = df['F1-Score'].round(4)
df['Support'] = df['Support'].astype(int)

# Print formatted table
print("\nEvaluation Metrics:")
print(df.to_string(index=False))

# Print overall metrics
print("\nOverall Metrics:")
print(f"Precision: {overall_precision:.4f}")
print(f"Recall: {overall_recall:.4f}")
print(f"F1-Score: {overall_f1:.4f}")

# Save metrics to file
metrics_path = os.path.join(output_dir, "evaluation_metrics.json")
try:
    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(report_dict, f, ensure_ascii=False, indent=2)
    print(f"Saved evaluation metrics to {metrics_path}")
except PermissionError:
    raise PermissionError(f"Cannot write to {metrics_path}. Check permissions.")

Loaded best model for evaluation


Evaluating Test Set: 100%|██████████| 156/156 [00:00<00:00, 200.82it/s]

Saved predictions to /home/guest/Public/KHEED/KHEED_Data_Collection/Evaluation/prahokbart_ner_model/predictions.json
Saved predictions to /home/guest/Public/KHEED/KHEED_Data_Collection/Evaluation/prahokbart_ner_model/predictions.conll

Evaluation Metrics:
    Category  Precision  Recall  F1-Score  Support
        Date     0.6458  0.6458    0.6458      144
     Disease     0.6834  0.8163    0.7440      283
  HumanCount     0.4865  0.4932    0.4898       73
    Location     0.6317  0.6934    0.6611      287
  Medication     0.0968  0.1500    0.1176       20
Organization     0.4301  0.5845    0.4955      284
    Pathogen     0.7955  0.7609    0.7778       46
     Symptom     0.5833  0.5185    0.5490       54
   Micro Avg     0.5732  0.6641    0.6153     1191
   Macro Avg     0.5441  0.5828    0.5601     1191
Weighted Avg     0.5839  0.6641    0.6193     1191

Overall Metrics:
Precision: 0.5732
Recall: 0.6641
F1-Score: 0.6153
Saved evaluation metrics to /home/guest/Public/KHEED/KHEED_Data_


