In [1]:
import gradio as gr
import torch
from transformers import XLMRobertaForTokenClassification, XLMRobertaTokenizerFast, pipeline
import json
import logging
import pandas as pd
from typing import List, Dict, Tuple
import re

# Set up logging: records at INFO and above are written both to the file
# 'ner_interface.log' and to a stream handler, formatted as
# "<timestamp> - <level> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('ner_interface.log'),
        logging.StreamHandler()
    ]
)
# Module-level logger shared by the functions in this cell.
logger = logging.getLogger(__name__)

# Define paths
# Directory handed to from_pretrained() for both the tokenizer and the model.
MODEL_PATH = "/home/guest/Public/KHEED/KHEED_Data_Collection/xlmr_ner_model/final_model"
# JSON file read by load_model_and_tokenizer(); presumably the label -> id
# mapping saved at training time (verify against the training code).
LABEL2ID_PATH = "/home/guest/Public/KHEED/KHEED_Data_Collection/xlmr_ner_model/label2id.json"

# Define colors for entity types
# Passed as `color_map` to the gr.HighlightedText component, so the keys must
# match the entity labels that end up in the highlighted output.
ENTITY_COLORS = {
    "Disease": "red",
    "Organization": "blue",
    "Location": "green",
    "Date": "yellow",
    "AffectedCount": "purple",
    "Medication": "orange",
    "Symptom": "pink",
    "Pathogen": "cyan"
}

def load_model_and_tokenizer() -> Tuple[XLMRobertaForTokenClassification, XLMRobertaTokenizerFast, Dict, Dict]:
    """
    Load the fine-tuned token-classification model, its tokenizer and the
    label mappings.

    The tokenizer and model are both read from ``MODEL_PATH``. ``label2id`` is
    parsed from the JSON file at ``LABEL2ID_PATH``; ``id2label`` is taken from
    the loaded model's config, with its keys converted to ``int``.

    Returns:
        A ``(model, tokenizer, label2id, id2label)`` tuple.

    Raises:
        Exception: any failure while loading is logged and then re-raised
            unchanged.
    """
    try:
        logger.info("Loading tokenizer from %s", MODEL_PATH)
        loaded_tokenizer = XLMRobertaTokenizerFast.from_pretrained(MODEL_PATH)
        logger.info("Tokenizer loaded successfully")

        logger.info("Loading model from %s", MODEL_PATH)
        loaded_model = XLMRobertaForTokenClassification.from_pretrained(MODEL_PATH)
        logger.info("Model loaded successfully")

        logger.info("Loading label mappings from %s", LABEL2ID_PATH)
        with open(LABEL2ID_PATH, 'r', encoding='utf-8') as mapping_file:
            label_to_id = json.load(mapping_file)
        # The config may carry string keys; normalise them to integers.
        id_to_label = dict(
            (int(raw_id), name)
            for raw_id, name in loaded_model.config.id2label.items()
        )
        logger.info("Label mappings loaded successfully")
    except Exception as err:
        logger.error("Error loading model/tokenizer/label mappings: %s", str(err))
        raise

    return loaded_model, loaded_tokenizer, label_to_id, id_to_label

def process_entities(entities: List[Dict], input_text: str) -> Tuple[List[Tuple], List[Dict]]:
    """
    Turn aggregated NER pipeline output into the two structures the UI shows.

    Args:
        entities: One dict per detected entity, each providing the keys
            'entity_group', 'score', 'word', 'start' and 'end'. Spans are
            consumed in the order given.
        input_text: The text that the 'start'/'end' character offsets index.

    Returns:
        ``(highlighted_text, processed_entities)`` where ``highlighted_text``
        is a list of ``(text, label_or_None)`` segments and
        ``processed_entities`` is a list of ``{"text", "type", "score"}``
        dicts, one per entity.
    """
    segments = []
    rows = []
    cursor = 0  # end offset of the last entity already emitted

    for item in entities:
        label = item['entity_group']
        confidence = float(item['score'])
        surface = item['word']
        begin = item['start']
        stop = item['end']

        # Unlabelled text sitting between the previous entity and this one.
        if cursor < begin:
            segments.append((input_text[cursor:begin], None))

        # The entity itself is sliced from the input rather than taken from 'word'.
        segments.append((input_text[begin:stop], label))
        rows.append({"text": surface, "type": label, "score": confidence})

        cursor = stop

    # Whatever follows the final entity is emitted unlabelled.
    if cursor < len(input_text):
        segments.append((input_text[cursor:], None))

    return segments, rows

# NER pipeline built on first use and then shared by every request, so the
# model is not re-read from disk on each button click.
_NER_PIPELINE = None


def _get_ner_pipeline():
    """Return the shared NER pipeline, loading model and tokenizer on the first call only."""
    global _NER_PIPELINE
    if _NER_PIPELINE is None:
        # Load model, tokenizer, and label mappings
        model, tokenizer, _label2id, _id2label = load_model_and_tokenizer()

        # Initialize NER pipeline
        logger.info("Initializing NER pipeline")
        device = 0 if torch.cuda.is_available() else -1
        if device == 0:
            logger.info("Device set to use cuda:0")
        else:
            logger.info("Device set to use CPU")
        _NER_PIPELINE = pipeline(
            "ner",
            model=model,
            tokenizer=tokenizer,
            aggregation_strategy="simple",
            device=device
        )
    return _NER_PIPELINE


def _no_entities_output(text):
    """Return the fallback output: *text* shown unhighlighted plus an empty entity table."""
    return (
        [(text, None)],
        pd.DataFrame({"Entity": [], "Type": [], "Confidence": []})
    )


def perform_ner(text: str) -> Tuple[List[Tuple], pd.DataFrame]:
    """
    Perform NER on input text and return highlighted text and entity table.

    Args:
        text: Raw input from the text box. ``None``, empty and
            whitespace-only values are treated as empty input.

    Returns:
        ``(highlighted_text, table)``: a list of ``(segment, label_or_None)``
        tuples for the HighlightedText output and a DataFrame with the columns
        "Entity", "Type" and "Confidence" (formatted to four decimals). When
        the input is empty, nothing is detected, or an error occurs, the text
        is returned unhighlighted together with an empty table.

    This function never raises: failures are logged with their traceback and
    reported to the UI as "no entities".
    """
    try:
        if text is None or not text.strip():
            logger.warning("Empty input text provided")
            return _no_entities_output("")

        ner_pipeline = _get_ner_pipeline()

        # Run NER pipeline
        logger.info("Running NER pipeline on input text")
        entities = ner_pipeline(text)
        logger.info("Raw pipeline output: %s", entities)

        if not entities:
            logger.warning("No entities detected in the input text")
            return _no_entities_output(text)

        # Process entities
        highlighted_text, processed_entities = process_entities(entities, text)

        # Prepare DataFrame
        df = pd.DataFrame([
            {
                "Entity": entity['text'],
                "Type": entity['type'],
                "Confidence": f"{entity['score']:.4f}"
            }
            for entity in processed_entities
        ])

        logger.info("Processed %d entities", len(processed_entities))
        return highlighted_text, df

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Error during NER processing: %s", e)
        return _no_entities_output(text)

# Define default example text
DEFAULT_TEXT = """
នៅខេត្តសៀមរាប៖​ ជំងឺរលាកសួតបានរកឃើញនៅថ្ងៃទី៥ ខែមករា ឆ្នាំ២០២៥ ដោយមានអ្នកជំងឺ៥៦នាក់ ប្រើថ្នាំអាស្ពីរីន និងមានរោគសញ្ញាក្អក និងគ្រុនក្តៅ បណ្តាលមកពីមេរោគអេសអិន១១។
"""

# Create Gradio interface
# Layout: a text box pre-filled with DEFAULT_TEXT, a button, and a row holding
# the highlighted text next to the table of detected entities.
with gr.Blocks(title="Khmer NER Interface") as demo:
    gr.Markdown("# Khmer Named Entity Recognition (NER)")
    gr.Markdown("Enter Khmer text to identify entities such as Disease, Location, Date, etc.")
    
    with gr.Row():
        input_text = gr.Textbox(
            label="Input Text",
            value=DEFAULT_TEXT,
            lines=5,
            placeholder="Enter Khmer text here..."
        )
    
    submit_button = gr.Button("Perform NER")
    
    with gr.Row():
        # Segment colours come from ENTITY_COLORS, keyed by entity label.
        highlighted_output = gr.HighlightedText(
            label="Highlighted Entities",
            show_legend=True,
            color_map=ENTITY_COLORS
        )
        # Column headers match the DataFrame columns built in perform_ner().
        entity_table = gr.DataFrame(
            label="Detected Entities",
            headers=["Entity", "Type", "Confidence"]
        )
    
    # perform_ner returns (highlighted segments, DataFrame), in the same order
    # as the `outputs` list below.
    submit_button.click(
        fn=perform_ner,
        inputs=input_text,
        outputs=[highlighted_output, entity_table]
    )

if __name__ == "__main__":
    try:
        logger.info("Launching Gradio interface")
        # share=True additionally requests a public share link besides the
        # local URL.
        demo.launch(share=True)
    except Exception as e:
        # Launch failures are reported both to the log and to stdout.
        logger.error("Failed to launch Gradio interface: %s", str(e))
        print(f"Error launching interface: {str(e)}")

  from .autonotebook import tqdm as notebook_tqdm
2025-08-07 10:56:40,780 - INFO - Launching Gradio interface
2025-08-07 10:56:40,830 - INFO - HTTP Request: GET http://127.0.0.1:7860/gradio_api/startup-events "HTTP/1.1 200 OK"
2025-08-07 10:56:40,839 - INFO - HTTP Request: HEAD http://127.0.0.1:7860/ "HTTP/1.1 200 OK"


* Running on local URL:  http://127.0.0.1:7860


2025-08-07 10:56:41,662 - INFO - HTTP Request: GET https://api.gradio.app/v3/tunnel-request "HTTP/1.1 200 OK"
2025-08-07 10:56:42,153 - INFO - HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"


* Running on public URL: https://d0dcf751121b884bdb.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


2025-08-07 10:56:44,089 - INFO - HTTP Request: HEAD https://d0dcf751121b884bdb.gradio.live "HTTP/1.1 200 OK"


In [20]:
import os
import gradio as gr
import torch
from transformers import BertForTokenClassification, BertTokenizerFast, XLMRobertaForTokenClassification, XLMRobertaTokenizerFast, pipeline, MBartTokenizerFast, MBartModel
import json
import logging
import pandas as pd
import pickle
import numpy as np
import re
from typing import List, Dict, Tuple, Union
from khmernltk import word_tokenize

import torch.nn as nn

# Prefer the stock MBartForTokenClassification when the installed transformers
# exports it; otherwise define a minimal stand-in on top of MBartModel.
# NOTE(review): the stand-in is bound to the name
# MBartForTokenClassificationCustom, so on the fallback path the name
# MBartForTokenClassification stays undefined. Neither name is referenced in
# the visible part of this file — confirm before relying on either.
try:
    from transformers import MBartForTokenClassification
except ImportError:
    from transformers import MBartModel
    import torch.nn as nn

    class MBartForTokenClassificationCustom(MBartModel):
        """MBartModel with a linear layer mapping each position to `config.num_labels` logits."""

        def __init__(self, config):
            super().__init__(config)
            self.num_labels = config.num_labels
            # Per-position classifier over the model's d_model-sized outputs.
            self.classifier = nn.Linear(config.d_model, config.num_labels)
        
        def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
            """
            Run the base model and classify every position.

            Returns a dict with "logits", plus "loss" when `labels` is given.
            """
            outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
            # First element of the base model output; presumably the last
            # hidden state — TODO confirm against the transformers version in use.
            sequence_output = outputs[0]
            logits = self.classifier(sequence_output)
            
            loss = None
            if labels is not None:
                # Positions labelled -100 are excluded from the loss.
                loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            
            return {"logits": logits, "loss": loss} if loss is not None else {"logits": logits}

# Add the PrahokBARTForNER class from your training code
class PrahokBARTForNER(torch.nn.Module):
    """
    Token tagger made of a pretrained MBart backbone, dropout (p=0.1) and a
    linear classification layer.

    The attribute names ``mbart``, ``dropout`` and ``classifier`` are part of
    the checkpoint format (they prefix the state_dict keys) and must not be
    renamed.
    """

    def __init__(self, model_name, num_labels):
        """
        Args:
            model_name: Identifier or path passed to ``MBartModel.from_pretrained``.
            num_labels: Size of the tag set, i.e. the classifier's output width.
        """
        super().__init__()
        backbone = MBartModel.from_pretrained(model_name)
        self.mbart = backbone
        self.dropout = torch.nn.Dropout(p=0.1)
        self.classifier = torch.nn.Linear(backbone.config.d_model, num_labels)
        self.num_labels = num_labels

    def forward(self, input_ids, attention_mask, labels=None):
        """
        Compute per-position tag logits and, when ``labels`` is given, the loss.

        Returns:
            ``(loss, logits)``; ``loss`` is ``None`` when no labels are passed.
            Label positions equal to -100 are ignored by the loss.
        """
        # Element 0 of the backbone output feeds the classifier (presumably
        # the decoder's last hidden state — confirm for the model in use).
        hidden = self.mbart(input_ids=input_ids, attention_mask=attention_mask)[0]
        logits = self.classifier(self.dropout(hidden))

        if labels is None:
            return None, logits

        criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
        flat_logits = logits.view(-1, logits.size(-1))
        loss = criterion(flat_logits, labels.view(-1))
        return loss, logits

# Set up logging: records at INFO and above are written both to the file
# 'ner_interface.log' and to a stream handler, formatted as
# "<timestamp> - <level> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('ner_interface.log'),
        logging.StreamHandler()
    ]
)
# Module-level logger shared by the functions in this cell.
logger = logging.getLogger(__name__)

# Define paths
# Model locations, one per choice handled in perform_ner():
#   "XLM-RoBERTa-Khmer-Small" -> XLM_FINAL_MODEL_PATH
#   "XLM-RoBERTa"             -> XLM_MODEL_PATH
#   "BERT-Khmer"              -> BERT_MODEL_PATH
#   "PrahokBART"              -> PRAHOKBART_MODEL_PATH (directory that also holds
#                                tag2idx.json, word2idx.json, prahokbart_ner.pt)
XLM_FINAL_MODEL_PATH = "/home/guest/Public/KHEED/KHEED_Data_Collection/Evaluation/xlmr_kh_ner_model/final_model"
XLM_MODEL_PATH = "/home/guest/Public/KHEED/KHEED_Data_Collection/xlmr_ner_model/final_model"
BERT_MODEL_PATH = "/home/guest/Public/KHEED/KHEED_Data_Collection/Evaluation/Models/bert_khmer_ner_model/final_model"
PRAHOKBART_MODEL_PATH = "/home/guest/Public/KHEED/KHEED_Data_Collection/Evaluation/prahokbart_ner_model"
# Label file used for both XLM-RoBERTa variants; BERT-Khmer reads label2id.json
# from its own model directory instead (see load_model_and_tokenizer).
LABEL2ID_PATH = "/home/guest/Public/KHEED/KHEED_Data_Collection/xlmr_ner_model/label2id.json"
# BiLSTM-CRF checkpoint (a dict with a 'model_state_dict' entry) and the
# pickled vocabularies that go with it.
BILSTM_MODEL_PATH = "/home/guest/Public/KHEED/KHEED_Data_Collection/Evaluation/bilstm_crf_ner_model/khmer_ner_best.pt"
BILSTM_VOCAB_PATH = "/home/guest/Public/KHEED/KHEED_Data_Collection/Evaluation/bilstm_crf_ner_model/vocabularies.pkl"

# Define colors for entity types
# The keys are the same eight entity names that BIO_TO_ENTITY maps to.
ENTITY_COLORS = {
    "Disease": "red",
    "Organization": "blue",
    "Location": "green",
    "Date": "yellow",
    "HumanCount": "purple",
    "Medication": "orange",
    "Symptom": "pink",
    "Pathogen": "cyan",
}


# Map BIO tags to plain entity types. Used when post-processing both the
# BiLSTM-CRF and the PrahokBART predictions; "O" (outside any entity) maps to
# None.
BIO_TO_ENTITY = {
    "B-Disease": "Disease",
    "I-Disease": "Disease",
    "B-Organization": "Organization",
    "I-Organization": "Organization",
    "B-Location": "Location",
    "I-Location": "Location",
    "B-Date": "Date",
    "I-Date": "Date",
    "B-HumanCount": "HumanCount",
    "I-HumanCount": "HumanCount",
    "B-Medication": "Medication",
    "I-Medication": "Medication",
    "B-Symptom": "Symptom",
    "I-Symptom": "Symptom",
    "B-Pathogen": "Pathogen",
    "I-Pathogen": "Pathogen",
    "O": None
}

# BiLSTM-CRF Model Definition
class BiLSTM_CRF(torch.nn.Module):
    """
    Embedding -> single-layer bidirectional LSTM -> dropout -> linear tagger.

    A ``transitions`` parameter of shape (tagset_size, tagset_size) is
    registered so checkpoints containing it load, but decoding in this class
    is a per-position argmax over the emission scores and does not consult it.

    Attribute names (``embedding``, ``lstm``, ``dropout``, ``hidden2tag``,
    ``transitions``) are state_dict keys and must stay as they are.
    """

    def __init__(self, vocab_size, tagset_size, embedding_dim, hidden_dim, dropout=0.5,
                 tag_to_idx=None, pretrained_embeddings=None):
        """
        Args:
            vocab_size: Number of rows in the embedding table (index 0 is padding).
            tagset_size: Number of output tags.
            embedding_dim: Width of each word embedding.
            hidden_dim: Combined width of both LSTM directions (each gets hidden_dim // 2).
            dropout: Dropout probability applied to the LSTM output.
            tag_to_idx: Stored on the instance as-is; not used by the methods here.
            pretrained_embeddings: Optional matrix that replaces the embedding
                weights and is frozen (requires_grad=False).
        """
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            frozen_weights = torch.tensor(pretrained_embeddings, dtype=torch.float32)
            self.embedding.weight = torch.nn.Parameter(frozen_weights, requires_grad=False)
        self.lstm = torch.nn.LSTM(
            embedding_dim,
            hidden_dim // 2,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = torch.nn.Dropout(p=dropout)
        self.hidden2tag = torch.nn.Linear(hidden_dim, tagset_size)
        self.transitions = torch.nn.Parameter(torch.randn(tagset_size, tagset_size))
        self.tagset_size = tagset_size
        self.tag_to_idx = tag_to_idx

    def forward(self, sentences, mask):
        """Return ``(None, tag_ids)`` where ``tag_ids`` is a NumPy array of predicted tag indices."""
        return self._viterbi_decode(self._get_lstm_features(sentences, mask), mask)

    def _get_lstm_features(self, sentences, mask):
        """Compute emission scores; sequence lengths are taken as ``mask.sum(1)``."""
        rnn_utils = torch.nn.utils.rnn
        embedded = self.embedding(sentences)
        lengths = mask.sum(1).cpu()
        packed_in = rnn_utils.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed_in)
        # Pad back to the full input width so the output lines up with `sentences`.
        unpacked, _ = rnn_utils.pad_packed_sequence(
            packed_out, batch_first=True, total_length=embedded.size(1)
        )
        return self.hidden2tag(self.dropout(unpacked))

    def _viterbi_decode(self, emissions, mask):
        """Pick the highest-scoring tag at every position (``mask`` is accepted but unused)."""
        best_tags = torch.argmax(emissions, dim=-1)
        return None, best_tags.cpu().numpy()

def load_bilstm_model_and_vocab() -> Tuple[BiLSTM_CRF, Dict[str, int], Dict[str, int], Dict[int, str]]:
    """
    Load the BiLSTM-CRF model and vocabularies.

    The vocabularies come from the pickle at ``BILSTM_VOCAB_PATH`` (keys
    'word_to_idx', 'tag_to_idx', 'idx_to_tag'). The network is built with
    embedding_dim=300 and hidden_dim=256, filled from the 'model_state_dict'
    entry of the checkpoint at ``BILSTM_MODEL_PATH`` (loaded onto the CPU),
    and switched to eval mode.

    Returns:
        ``(model, word_to_idx, tag_to_idx, idx_to_tag)``.

    Raises:
        Exception: any failure is logged and re-raised unchanged.
    """
    try:
        logger.info("Loading vocabularies from %s", BILSTM_VOCAB_PATH)
        # NOTE: unpickling executes code embedded in the file; only point
        # BILSTM_VOCAB_PATH at files you trust.
        with open(BILSTM_VOCAB_PATH, 'rb') as vocab_file:
            vocab_data = pickle.load(vocab_file)
        word_to_idx = vocab_data['word_to_idx']
        tag_to_idx = vocab_data['tag_to_idx']
        idx_to_tag = vocab_data['idx_to_tag']

        architecture = dict(
            vocab_size=len(word_to_idx),
            tagset_size=len(tag_to_idx),
            embedding_dim=300,
            hidden_dim=256,
            dropout=0.5,
            tag_to_idx=tag_to_idx,
            pretrained_embeddings=None,
        )
        tagger = BiLSTM_CRF(**architecture)

        logger.info("Loading BiLSTM-CRF model from %s", BILSTM_MODEL_PATH)
        checkpoint = torch.load(BILSTM_MODEL_PATH, map_location=torch.device('cpu'))
        tagger.load_state_dict(checkpoint['model_state_dict'])
        tagger.eval()

        logger.info("BiLSTM-CRF model and vocabularies loaded successfully")
    except Exception as err:
        logger.error("Error loading BiLSTM-CRF model/vocab: %s", str(err))
        raise

    return tagger, word_to_idx, tag_to_idx, idx_to_tag

def load_prahokbart_model(model_path: str) -> Tuple[PrahokBARTForNER, Dict[str, int], Dict[int, str], Dict[str, int]]:
    """
    Load PrahokBART model and vocabularies.

    Args:
        model_path: Directory containing ``tag2idx.json``, ``word2idx.json``
            and the weights file ``prahokbart_ner.pt``.

    Returns:
        ``(model, tag2idx, idx2tag, word2idx)``; ``idx2tag`` is the inverse of
        ``tag2idx`` with integer keys, and the model is in eval mode.

    Raises:
        Exception: any failure is logged and re-raised unchanged.
    """
    try:
        def read_json(filename):
            # Small reader for the JSON vocabularies stored next to the weights.
            with open(os.path.join(model_path, filename), 'r', encoding='utf-8') as handle:
                return json.load(handle)

        tag2idx = read_json("tag2idx.json")
        word2idx = read_json("word2idx.json")
        idx2tag = {int(index): label for label, index in tag2idx.items()}

        # Backbone weights come from the hub checkpoint; the fine-tuned
        # weights are loaded over them below.
        tagger = PrahokBARTForNER(
            model_name="nict-astrec-att/prahokbart_base",
            num_labels=len(tag2idx)
        )

        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        weights_file = os.path.join(model_path, "prahokbart_ner.pt")
        state = torch.load(weights_file, map_location=target_device)
        tagger.load_state_dict(state)
        tagger.eval()

        logger.info("PrahokBART model loaded successfully")
    except Exception as err:
        logger.error("Error loading PrahokBART model: %s", str(err))
        raise

    return tagger, tag2idx, idx2tag, word2idx

def tokenize_khmer_text(text: str) -> List[str]:
    """
    Simple whitespace-based tokenizer for Khmer text.

    Splits ``text`` on runs of whitespace and returns the non-empty pieces;
    leading and trailing whitespace never produces empty tokens.

    ``str.split()`` with no separator already strips the ends and treats any
    run of whitespace as a single delimiter, so no regex normalisation pass is
    needed beforehand.

    Args:
        text: The raw input string.

    Returns:
        The whitespace-delimited tokens in order; ``[]`` for empty or
        whitespace-only input.
    """
    return text.split()

def preprocess_bilstm_input(text: str, word_to_idx: Dict[str, int], max_len: int = 128) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    """
    Preprocess input text for BiLSTM-CRF model.

    Args:
        text: Raw input; tokenized with ``tokenize_khmer_text``.
        word_to_idx: Vocabulary; must contain '<UNK>' (used for unknown
            tokens) and '<PAD>' (used to fill up to ``max_len``).
        max_len: Fixed sequence length; longer inputs are truncated.

    Returns:
        ``(word_tensor, mask_tensor, tokens)``: a (1, max_len) long tensor of
        word ids, a (1, max_len) bool tensor that is True on real tokens, and
        the list of tokens that were kept.
    """
    tokens = tokenize_khmer_text(text)

    ids = []
    for tok in tokens:
        ids.append(word_to_idx.get(tok, word_to_idx['<UNK>']))

    kept = min(len(ids), max_len)
    mask = [True] * kept + [False] * (max_len - kept)

    if len(ids) < max_len:
        # Short input: fill the tail with the padding id.
        ids = ids + [word_to_idx['<PAD>']] * (max_len - len(ids))
    else:
        # Long (or exactly full) input: keep the first max_len ids.
        ids = ids[:max_len]

    return (
        torch.tensor([ids], dtype=torch.long),
        torch.tensor([mask], dtype=torch.bool),
        tokens[:kept],
    )

def preprocess_prahokbart_input(text: str, word2idx: Dict[str, int], max_len: int = 128) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    """
    Preprocess input text for PrahokBART model.

    Args:
        text: Raw input; segmented with khmernltk's ``word_tokenize``.
        word2idx: Token -> id vocabulary. Unknown tokens use the '<unk>'
            entry (falling back to 3); padding uses '<pad>' (falling back to 1).
        max_len: Fixed sequence length; longer inputs are truncated.

    Returns:
        ``(input_tensor, mask_tensor, tokens)``: (1, max_len) long tensors of
        ids and of the attention mask, plus the tokens that were kept. The
        mask is 0 wherever the id equals the padding id and 1 elsewhere.
    """
    # Tokenize with khmernltk to match training
    words = word_tokenize(text)

    unk_id = word2idx.get('<unk>', 3)
    pad_id = word2idx.get('<pad>', 1)
    ids = [word2idx.get(word, unk_id) for word in words]

    # Truncate or pad to max_len
    if len(ids) > max_len:
        ids = ids[:max_len]
        tokens = words[:max_len]
    else:
        ids.extend([pad_id] * (max_len - len(ids)))
        tokens = words

    # The mask is derived from the ids themselves, not from the token count.
    mask = [int(token_id != pad_id) for token_id in ids]

    return (
        torch.tensor([ids], dtype=torch.long),
        torch.tensor([mask], dtype=torch.long),
        tokens,
    )

def fix_bilstm_tags(pred_tags: List[str]) -> List[str]:
    """
    Repair invalid BIO sequences.

    An ``I-X`` tag is kept only when the tag right before it in the input is
    ``B-X`` or ``I-X``; otherwise (including at position 0) it is rewritten to
    ``B-X`` and a warning is logged. All other tags pass through untouched.

    Args:
        pred_tags: Predicted tags, one per token.

    Returns:
        A new list of the same length; the input list is not modified.
    """
    repaired = []
    preceding = None  # the input tag just before the current one
    for i, tag in enumerate(pred_tags):
        if not tag.startswith('I-'):
            repaired.append(tag)
        else:
            entity_type = tag[2:]  # e.g., 'Disease' from 'I-Disease'
            if preceding in (f'B-{entity_type}', f'I-{entity_type}'):
                repaired.append(tag)
            else:
                logger.warning(f"Invalid I- tag found at position {i}: {tag}. Converting to B-{entity_type}.")
                repaired.append(f'B-{entity_type}')
        preceding = tag
    return repaired

def process_bilstm_entities(tokens: List[str], pred_tags: List[str], input_text: str) -> Tuple[List[Tuple], List[Dict]]:
    """
    Process BiLSTM-CRF predictions for Gradio output, mapping BIO tags to entity types.

    Tokens are located in ``input_text`` left to right and consecutive
    ``B-X``/``I-X`` tokens are merged into one entity covering the character
    span from the first token's start to the last token's end. Segments are
    emitted strictly in document order, and both entity text and plain text
    are sliced from ``input_text``, so concatenating the segments reproduces
    the input and multi-token entities keep the characters between their
    tokens.

    Args:
        tokens: Tokens in the order they occur in ``input_text``.
        pred_tags: One BIO tag per token; invalid ``I-`` tags are repaired
            with ``fix_bilstm_tags`` first.
        input_text: The text the tokens were taken from.

    Returns:
        ``(highlighted_text, processed_entities)``: a list of
        ``(text, entity_type_or_None)`` segments and a list of
        ``{"text", "type", "score"}`` dicts. The score is always 1.0 because
        this model yields no per-tag confidence.
    """
    # Fix invalid BIO tags
    pred_tags = fix_bilstm_tags(pred_tags)

    highlighted_text = []
    processed_entities = []
    emitted_upto = 0   # input_text[:emitted_upto] is already in highlighted_text
    search_from = 0    # tokens are searched for from here onwards
    span_start = None  # character span and type of the entity being built
    span_end = None
    span_type = None

    def close_span():
        """Emit the open entity (preceded by any plain text before it), if there is one."""
        nonlocal emitted_upto, span_start, span_end, span_type
        if span_start is None:
            return
        if span_start > emitted_upto:
            highlighted_text.append((input_text[emitted_upto:span_start], None))
        entity_text = input_text[span_start:span_end]
        highlighted_text.append((entity_text, span_type))
        processed_entities.append({
            "text": entity_text,
            "type": span_type,
            "score": 1.0
        })
        emitted_upto = span_end
        span_start = span_end = span_type = None

    for token, tag in zip(tokens, pred_tags):
        # Map BIO tag to entity type
        entity_type = BIO_TO_ENTITY.get(tag, None)

        # Find token in input_text
        token_start = input_text.find(token, search_from)
        if token_start == -1:
            logger.warning(f"Token '{token}' not found in input text from position {search_from}")
            continue
        token_end = token_start + len(token)
        search_from = token_end

        # Continuation of the open entity: just extend its span.
        if tag.startswith('I-') and span_start is not None and entity_type == span_type:
            span_end = token_end
            continue

        # Anything else ends the open entity; a B- tag then starts a new one.
        # Plain tokens are not emitted here: they are picked up as part of the
        # gap before the next entity or of the trailing text.
        close_span()
        if tag.startswith('B-'):
            span_start, span_end, span_type = token_start, token_end, entity_type

    # Save last entity if exists
    close_span()

    # Add remaining text
    if emitted_upto < len(input_text):
        highlighted_text.append((input_text[emitted_upto:], None))

    return highlighted_text, processed_entities

def process_prahokbart_entities(tokens: List[str], pred_tags: List[str], scores: List[float], input_text: str) -> Tuple[List[Tuple], List[Dict]]:
    """
    Process PrahokBART predictions for Gradio output.

    Tokens are located in ``input_text`` left to right and consecutive
    ``B-X``/``I-X`` tokens are merged into one entity covering the character
    span from the first token's start to the last token's end. Segments are
    emitted strictly in document order, and both entity text and plain text
    are sliced from ``input_text``, so concatenating the segments reproduces
    the input.

    Args:
        tokens: Tokens in the order they occur in ``input_text``.
        pred_tags: One BIO tag per token; invalid ``I-`` tags are repaired
            with ``fix_bilstm_tags`` first.
        scores: One confidence value per token.
        input_text: The text the tokens were taken from.

    Returns:
        ``(highlighted_text, processed_entities)``: a list of
        ``(text, entity_type_or_None)`` segments and a list of
        ``{"text", "type", "score"}`` dicts, where the score is the mean of
        the entity's token scores.
    """
    # Fix invalid BIO tags
    pred_tags = fix_bilstm_tags(pred_tags)

    highlighted_text = []
    processed_entities = []
    emitted_upto = 0   # input_text[:emitted_upto] is already in highlighted_text
    search_from = 0    # tokens are searched for from here onwards
    span_start = None  # character span, type and token scores of the open entity
    span_end = None
    span_type = None
    span_scores = []

    def close_span():
        """Emit the open entity (preceded by any plain text before it), if there is one."""
        nonlocal emitted_upto, span_start, span_end, span_type, span_scores
        if span_start is None:
            return
        if span_start > emitted_upto:
            highlighted_text.append((input_text[emitted_upto:span_start], None))
        entity_text = input_text[span_start:span_end]
        avg_score = float(np.mean(span_scores)) if span_scores else 0.0
        highlighted_text.append((entity_text, span_type))
        processed_entities.append({
            "text": entity_text,
            "type": span_type,
            "score": avg_score
        })
        emitted_upto = span_end
        span_start = span_end = span_type = None
        span_scores = []

    for token, tag, score in zip(tokens, pred_tags, scores):
        # Map BIO tag to entity type
        entity_type = BIO_TO_ENTITY.get(tag, None)

        # Find token in input_text. A second search for the token wrapped in
        # spaces can never succeed when this one fails, so there is none.
        token_start = input_text.find(token, search_from)
        if token_start == -1:
            logger.warning(f"Token '{token}' not found in input text from position {search_from}")
            continue
        token_end = token_start + len(token)
        search_from = token_end

        # Continuation of the open entity: extend its span and collect the score.
        if tag.startswith('I-') and span_start is not None and entity_type == span_type:
            span_end = token_end
            span_scores.append(score)
            continue

        # Anything else ends the open entity; a B- tag then starts a new one.
        # Plain tokens are not emitted here: they are picked up as part of the
        # gap before the next entity or of the trailing text.
        close_span()
        if tag.startswith('B-'):
            span_start, span_end, span_type = token_start, token_end, entity_type
            span_scores = [score]

    # Save last entity if exists
    close_span()

    # Add remaining text
    if emitted_upto < len(input_text):
        highlighted_text.append((input_text[emitted_upto:], None))

    return highlighted_text, processed_entities

def _empty_ner_result(text):
    """Return the fallback output: *text* shown unhighlighted plus an empty entity table."""
    return (
        [(text, None)],
        pd.DataFrame({"Entity": [], "Type": [], "Confidence": []})
    )


def _ner_with_prahokbart(text: str) -> Tuple[List[Tuple], List[Dict]]:
    """Tag *text* with the PrahokBART model and return (highlighted_text, entities)."""
    # Load model and vocabularies
    model, _tag2idx, idx2tag, word2idx = load_prahokbart_model(PRAHOKBART_MODEL_PATH)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Preprocess input
    input_tensor, mask_tensor, tokens = preprocess_prahokbart_input(text, word2idx)
    input_tensor, mask_tensor = input_tensor.to(device), mask_tensor.to(device)

    # Run inference
    with torch.no_grad():
        _, logits = model(input_tensor, mask_tensor)

    # Best tag per position and the softmax probability of that tag.
    tag_ids = torch.argmax(logits, dim=-1)[0].tolist()
    probabilities = torch.softmax(logits, dim=-1).max(dim=-1)[0][0].tolist()

    # Real tokens occupy the leading positions and padding follows, so taking
    # the first len(tokens) positions keeps tags aligned with tokens even if
    # the attention mask happens to be 0 somewhere in the middle.
    count = min(len(tokens), len(tag_ids))
    tokens = tokens[:count]
    # .get() falls back to 'O' for any id missing from the mapping.
    pred_tags = [idx2tag.get(tag_id, 'O') for tag_id in tag_ids[:count]]
    confidence_scores = probabilities[:count]

    # Process entities
    return process_prahokbart_entities(tokens, pred_tags, confidence_scores, text)


def _ner_with_transformer(text: str, model_type: str) -> Tuple[List[Tuple], List[Dict]]:
    """Tag *text* with one of the Hugging Face pipeline models and return (highlighted_text, entities)."""
    # Load model and tokenizer
    model_path = {
        "XLM-RoBERTa": XLM_MODEL_PATH,
        "XLM-RoBERTa-Khmer-Small": XLM_FINAL_MODEL_PATH,
        "BERT-Khmer": BERT_MODEL_PATH,
    }[model_type]
    model, tokenizer, _label2id, _id2label = load_model_and_tokenizer(model_path, model_type)

    # Initialize NER pipeline
    device = 0 if torch.cuda.is_available() else -1
    logger.info("Device set to %s", "cuda:0" if device == 0 else "CPU")
    ner_pipeline = pipeline(
        "ner",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
        device=device
    )

    # Run NER pipeline
    logger.info("Running %s NER pipeline", model_type)
    entities = ner_pipeline(text)
    if not entities:
        return [(text, None)], []

    # Process entities
    return process_entities(entities, text)


def _ner_with_bilstm(text: str) -> Tuple[List[Tuple], List[Dict]]:
    """Tag *text* with the BiLSTM-CRF model and return (highlighted_text, entities)."""
    # Load BiLSTM-CRF model and vocabularies
    model, word_to_idx, _tag_to_idx, idx_to_tag = load_bilstm_model_and_vocab()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Preprocess input
    word_tensor, mask_tensor, tokens = preprocess_bilstm_input(text, word_to_idx)
    word_tensor, mask_tensor = word_tensor.to(device), mask_tensor.to(device)

    # Run inference
    with torch.no_grad():
        _, tag_ids = model(word_tensor, mask_tensor)

    # The mask is True exactly on the first len(tokens) positions, so slicing
    # selects the same predictions as filtering by the mask.
    pred_tags = [idx_to_tag.get(int(tag_id), 'O') for tag_id in tag_ids[0][:len(tokens)]]

    # Process entities
    return process_bilstm_entities(tokens, pred_tags, text)


def perform_ner(text: str, model_type: str) -> Tuple[List[Tuple], pd.DataFrame]:
    """
    Perform NER on input text using the selected model type.

    Args:
        text: Raw input from the text box. ``None``, empty and
            whitespace-only values are treated as empty input.
        model_type: "PrahokBART", "XLM-RoBERTa", "XLM-RoBERTa-Khmer-Small" or
            "BERT-Khmer"; any other value selects the BiLSTM-CRF model.

    Returns:
        ``(highlighted_text, table)``: a list of ``(segment, label_or_None)``
        tuples and a DataFrame with the columns "Entity", "Type" and
        "Confidence" (formatted to four decimals). When the input is empty,
        nothing is detected, or an error occurs, the text is returned
        unhighlighted together with an empty table.

    This function never raises: failures are logged with their traceback and
    reported to the UI as "no entities".
    """
    try:
        if text is None or not text.strip():
            logger.warning("Empty input text provided")
            return _empty_ner_result("")

        if model_type == "PrahokBART":
            logger.info("Using PrahokBART model")
            model_label = "PrahokBART"
            highlighted_text, processed_entities = _ner_with_prahokbart(text)
        elif model_type in ["XLM-RoBERTa", "XLM-RoBERTa-Khmer-Small", "BERT-Khmer"]:
            logger.info("Using %s model", model_type)
            model_label = model_type
            highlighted_text, processed_entities = _ner_with_transformer(text, model_type)
        else:  # BiLSTM-CRF
            logger.info("Using BiLSTM-CRF model")
            model_label = "BiLSTM-CRF"
            highlighted_text, processed_entities = _ner_with_bilstm(text)

        if not processed_entities:
            logger.warning("No entities detected by %s", model_label)
            return _empty_ner_result(text)

        # Prepare DataFrame
        df = pd.DataFrame([
            {
                "Entity": entity['text'],
                "Type": entity['type'],
                "Confidence": f"{entity['score']:.4f}"
            }
            for entity in processed_entities
        ])

        logger.info("Processed %d entities", len(processed_entities))
        return highlighted_text, df

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Error during NER processing: %s", e)
        return _empty_ner_result(text)

def load_model_and_tokenizer(model_path: str, model_type: str) -> Tuple[Union[XLMRobertaForTokenClassification, BertForTokenClassification], Union[XLMRobertaTokenizerFast, BertTokenizerFast, MBartTokenizerFast], Dict, Dict]:
    """
    Load a fine-tuned token-classification model, its tokenizer and label maps.

    Args:
        model_path: Location passed to ``from_pretrained`` for both the
            tokenizer and the model.
        model_type: One of "XLM-RoBERTa", "XLM-RoBERTa-Khmer-Small" or
            "BERT-Khmer".

    Returns:
        ``(model, tokenizer, label2id, id2label)``. ``label2id`` is read from
        a JSON file; ``id2label`` comes from ``model.config.id2label`` with
        its keys converted to ``int``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
        Exception: Any other loading failure is logged and re-raised as-is.
    """
    try:
        uses_bert = model_type == "BERT-Khmer"
        if not uses_bert and model_type not in ["XLM-RoBERTa", "XLM-RoBERTa-Khmer-Small"]:
            raise ValueError(f"Unsupported model type: {model_type}")

        if uses_bert:
            from transformers import BertForTokenClassification, BertTokenizerFast
            logger.info("Loading BERT-Khmer tokenizer from %s", model_path)
            tokenizer = BertTokenizerFast.from_pretrained(model_path)
            logger.info("Loading BERT-Khmer model from %s", model_path)
            model = BertForTokenClassification.from_pretrained(model_path)
        else:
            from transformers import XLMRobertaForTokenClassification, XLMRobertaTokenizerFast
            logger.info("Loading XLM-RoBERTa tokenizer from %s", model_path)
            tokenizer = XLMRobertaTokenizerFast.from_pretrained(model_path)
            logger.info("Loading XLM-RoBERTa model from %s", model_path)
            model = XLMRobertaForTokenClassification.from_pretrained(model_path)

        logger.info("Tokenizer and model loaded successfully")

        # BERT-Khmer reads label2id.json from inside the model directory;
        # the XLM-RoBERTa variants use the module-level LABEL2ID_PATH.
        if uses_bert:
            label2id_path = os.path.join(model_path, "label2id.json")
        else:
            label2id_path = LABEL2ID_PATH
        logger.info("Loading label mappings from %s", label2id_path)
        with open(label2id_path, 'r', encoding='utf-8') as mapping_file:
            label2id = json.load(mapping_file)

        # Convert the ids in the model config to int keys for the caller.
        id2label = {}
        for raw_id, label_name in model.config.id2label.items():
            id2label[int(raw_id)] = label_name
        logger.info("Label mappings loaded successfully")

        return model, tokenizer, label2id, id2label
    except Exception as load_error:
        logger.error("Error loading model/tokenizer/label mappings: %s", str(load_error))
        raise

def process_entities(entities: List[Dict], input_text: str) -> Tuple[List[Tuple], List[Dict]]:
    """
    Convert XLM-RoBERTa NER pipeline results into Gradio-ready structures.

    Each item in ``entities`` must provide the keys ``entity_group``,
    ``score``, ``word``, ``start`` and ``end``. Items are consumed in the
    order given, and ``start``/``end`` are used as character offsets into
    ``input_text``.

    Returns:
        A pair ``(highlighted_text, processed_entities)``:
        - ``highlighted_text``: ``(substring, label)`` tuples, where the label
          is the entity type for entity spans and ``None`` for the text
          between and after them.
        - ``processed_entities``: one dict per entity with the keys ``text``
          (the pipeline's ``word``), ``type`` and ``score`` (as ``float``).
    """
    segments = []
    records = []
    cursor = 0  # end offset of the most recently emitted entity

    for item in entities:
        label = item['entity_group']
        confidence = float(item['score'])
        surface = item['word']
        begin, finish = item['start'], item['end']

        # Emit the unlabelled gap between the previous entity and this one.
        if cursor < begin:
            segments.append((input_text[cursor:begin], None))

        segments.append((input_text[begin:finish], label))
        records.append({"text": surface, "type": label, "score": confidence})
        cursor = finish

    # Anything left after the last entity is plain, unlabelled text.
    if len(input_text) > cursor:
        segments.append((input_text[cursor:], None))

    return segments, records

# Default example text: a Khmer sample passage used as the initial value of the
# input textbox below. Because the triple quotes sit on their own lines, the
# string begins and ends with a newline character.
DEFAULT_TEXT = """
នៅខេត្តសៀមរាប៖ ជំងឺរលាកសួតបានរកឃើញនៅថ្ងៃទី៥ ខែមករា ឆ្នាំ២០២៥ ដោយមានអ្នកជំងឺ៥៦នាក់ ប្រើថ្នាំអាស្ពីរីន និងមានរោគសញ្ញាក្អក និងគ្រុនក្តៅ បណ្តាលមកពីមេរោគអេសអិន១១។
"""

# Create Gradio interface
# Layout, top to bottom: title and description, an input row (text box and
# model picker), the submit button, and an output row (highlighted text and
# entity table). The resulting `demo` object is launched by the __main__ guard.
with gr.Blocks(title="Khmer NER Interface") as demo:
    gr.Markdown("# Khmer Named Entity Recognition (NER)")
    gr.Markdown("Enter Khmer text and select a model to identify entities such as Disease, Location, Date, etc.")
    
    # Input row: the text to analyse and the model used to analyse it.
    with gr.Row():
        input_text = gr.Textbox(
            label="Input Text",
            value=DEFAULT_TEXT,  # pre-filled with the sample passage
            lines=5,
            placeholder="Enter Khmer text here..."
        )
        # The selected string is passed to perform_ner as its second input.
        model_choice = gr.Dropdown(
            label="Model Type",
            choices=["XLM-RoBERTa", "XLM-RoBERTa-Khmer-Small", "BERT-Khmer", "PrahokBART", "BiLSTM-CRF"],
            value="XLM-RoBERTa"
        )
    
    submit_button = gr.Button("Perform NER")
    
    # Output row: receives the two values returned by perform_ner.
    with gr.Row():
        highlighted_output = gr.HighlightedText(
            label="Highlighted Entities",
            show_legend=True,
            color_map=ENTITY_COLORS  # entity type name -> color name
        )
        # Headers use the same column names perform_ner puts in its DataFrame.
        entity_table = gr.DataFrame(
            label="Detected Entities",
            headers=["Entity", "Type", "Confidence"]
        )
    
    # Wire the button: perform_ner(text, model type) -> (highlighted spans, table).
    submit_button.click(
        fn=perform_ner,
        inputs=[input_text, model_choice],
        outputs=[highlighted_output, entity_table]
    )

# Entry point: start the Gradio app when this file is executed directly.
if __name__ == "__main__":
    try:
        logger.info("Launching Gradio interface")
        # share=True also requests a public link; set share=False for local testing.
        demo.launch(share=True)
    except Exception as launch_error:
        # Report the failure both to the log and to stdout.
        failure_reason = str(launch_error)
        logger.error("Failed to launch Gradio interface: %s", failure_reason)
        print(f"Error launching interface: {failure_reason}")

2025-08-07 13:14:20,843 - INFO - Launching Gradio interface
2025-08-07 13:14:20,884 - INFO - HTTP Request: GET http://127.0.0.1:7875/gradio_api/startup-events "HTTP/1.1 200 OK"
2025-08-07 13:14:20,887 - INFO - HTTP Request: HEAD http://127.0.0.1:7875/ "HTTP/1.1 200 OK"


* Running on local URL:  http://127.0.0.1:7875


2025-08-07 13:14:21,894 - INFO - HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
2025-08-07 13:14:21,907 - INFO - HTTP Request: GET https://api.gradio.app/v3/tunnel-request "HTTP/1.1 200 OK"


* Running on public URL: https://c0eb0ee942141099a9.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


2025-08-07 13:14:25,869 - INFO - HTTP Request: HEAD https://c0eb0ee942141099a9.gradio.live "HTTP/1.1 200 OK"


2025-08-07 13:14:42,521 - INFO - Using PrahokBART model
Some weights of MBartModel were not initialized from the model checkpoint at nict-astrec-att/prahokbart_base and are newly initialized: ['decoder.embed_tokens.weight', 'encoder.embed_tokens.weight', 'shared.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2025-08-07 13:14:44,581 - INFO - PrahokBART model loaded successfully
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
2025-08-07 13:14:44,608 - INFO - Processed 10 entities
2025-08-07 13:15:12,200 - INFO - Using PrahokBART model
Some weights of MBartModel were not initialized from the model checkpoint at nict-astrec-att/prahokbart_base and are newly initialized: ['decoder.embed_tokens.weight', 'encoder.embed_tokens.weight', 