In [4]:
# Cell 1: Setup and Imports
import json
import sys
from pathlib import Path
from collections import defaultdict
import pandas as pd
from typing import List, Dict, Any
# Add src to path
sys.path.append('src')

from seqeval.metrics import (
    precision_score, recall_score, f1_score,
    classification_report, accuracy_score
)

from inference import BusNERInference
from constants.entity_labels import ENTITY_LABELS

print("‚úÖ Imports successful!")

‚úÖ Imports successful!


In [2]:
# Cell 2: Configuration (CORRECTED)
from datasets import Dataset

MODEL_PATH = "models/bus_ner_transformer_v5"
USE_ONNX = False
TEST_DATA_PATH = "data/training_data_bio.json"
TEST_SPLIT = 0.1  # Same as training
SEED = 42  # Same seed as training

# Load all data
with open(TEST_DATA_PATH, 'r', encoding='utf-8') as f:
    all_data = json.load(f)

# Use SAME split as training (important!)
# Convert to HF Dataset format
dataset_dict = {
    "tokens": [sample["tokens"] for sample in all_data],
    "ner_tags": [sample["ner_tags"] for sample in all_data],
}

dataset = Dataset.from_dict(dataset_dict)
split = dataset.train_test_split(test_size=TEST_SPLIT, seed=SEED)

# Convert back to our format for evaluation
test_indices = split['test']['__index__'] if hasattr(split['test'], '__index__') else None

# Get test samples using indices
if test_indices:
    test_data = [all_data[i] for i in test_indices]
else:
    # Fallback: use the split directly
    test_data = [all_data[i] for i in range(len(split['test']))]

print(f"üìä Total samples: {len(all_data)}")
print(f"üß™ Test samples (10%): {len(test_data)}")
print(f"‚úÖ Using same split as training (seed=42)")

üìä Total samples: 200000
üß™ Test samples (10%): 20000
‚úÖ Using same split as training (seed=42)


In [5]:
# Cell 3: Helper Functions
def extract_entities_from_spans(text: str, entities: List[List[int]]) -> Dict[str, List[str]]:
    """Convert span-based entities to label-based dictionary."""
    result = {label: [] for label in ENTITY_LABELS}
    
    for start, end, label in entities:
        entity_text = text[start:end].strip()
        if entity_text and label in result:
            result[label].append(entity_text)
    
    return result

def convert_to_bio_format(text: str, entities: Dict[str, List[str]]) -> List[str]:
    """Convert entity dictionary to BIO tag sequence."""
    words = text.split()
    tags = ["O"] * len(words)
    
    # Build character-to-word mapping
    char_to_word = {}
    char_pos = 0
    for word_idx, word in enumerate(words):
        for i in range(len(word)):
            char_to_word[char_pos + i] = word_idx
        char_pos += len(word) + 1
    
    # Mark entities
    for label, values in entities.items():
        for value in values:
            start_idx = text.find(value)
            if start_idx != -1:
                end_idx = start_idx + len(value)
                word_indices = set()
                for char_idx in range(start_idx, end_idx):
                    if char_idx in char_to_word:
                        word_indices.add(char_to_word[char_idx])
                
                if word_indices:
                    sorted_indices = sorted(word_indices)
                    for i, word_idx in enumerate(sorted_indices):
                        if i == 0:
                            tags[word_idx] = f"B-{label}"
                        else:
                            tags[word_idx] = f"I-{label}"
    
    return tags

print("‚úÖ Helper functions defined!")

‚úÖ Helper functions defined!


In [6]:
# Cell 4: Load Model
print("üîÑ Loading model...")
ner = BusNERInference(MODEL_PATH, use_onnx=USE_ONNX)
print("‚úÖ Model loaded!")

üîÑ Loading model...
Loaded PyTorch model from: models/bus_ner_transformer_v5
Device: cpu
Number of labels: 45
‚úÖ Model loaded!


In [8]:
# Cell 5: Run Evaluation (SIMPLER - uses BIO tags directly)
print("üîÑ Running evaluation...")
print("=" * 60)

# Load label mappings
with open('data/id2label.json', 'r') as f:
    id2label_dict = json.load(f)
    id2label = {int(k): v for k, v in id2label_dict.items()}

true_labels_list = []
pred_labels_list = []
entity_metrics = defaultdict(lambda: {"tp": 0, "fp": 0, "fn": 0})

for i, sample in enumerate(test_data):
    if (i + 1) % 50 == 0:
        print(f"  Processed {i+1}/{len(test_data)} samples...")
    
    # Reconstruct text from tokens
    tokens = sample["tokens"]
    text = " ".join(tokens)
    
    # Get ground truth BIO tags (convert IDs to labels)
    true_ner_tag_ids = sample["ner_tags"]
    true_bio = [id2label.get(tag_id, "O") for tag_id in true_ner_tag_ids]
    
    # Get predictions from model
    pred_entities = ner.extract(text)
    
    # Convert predictions to BIO format
    # Simple approach: match tokens to predicted entities
    pred_bio = ["O"] * len(tokens)
    
    for label, values in pred_entities.items():
        for value in values:
            value_tokens = value.split()
            # Try to find this entity in the token sequence
            for j in range(len(tokens) - len(value_tokens) + 1):
                if tokens[j:j+len(value_tokens)] == value_tokens:
                    # Mark as entity
                    for k, token in enumerate(value_tokens):
                        if k == 0:
                            pred_bio[j + k] = f"B-{label}"
                        else:
                            pred_bio[j + k] = f"I-{label}"
                    break
    
    true_labels_list.append(true_bio)
    pred_labels_list.append(pred_bio)
    
    # Extract entities from BIO for per-entity metrics
    def extract_from_bio(tokens, bio_tags):
        entities = {label: [] for label in ENTITY_LABELS}
        current_entity = []
        current_label = None
        
        for token, tag in zip(tokens, bio_tags):
            if tag.startswith("B-"):
                if current_label and current_entity:
                    entities[current_label].append(" ".join(current_entity))
                current_label = tag[2:]
                current_entity = [token]
            elif tag.startswith("I-"):
                label = tag[2:]
                if label == current_label:
                    current_entity.append(token)
                else:
                    if current_label and current_entity:
                        entities[current_label].append(" ".join(current_entity))
                    current_label = label
                    current_entity = [token]
            else:
                if current_label and current_entity:
                    entities[current_label].append(" ".join(current_entity))
                current_label = None
                current_entity = []
        
        if current_label and current_entity:
            entities[current_label].append(" ".join(current_entity))
        
        return entities
    
    true_entities = extract_from_bio(tokens, true_bio)
    pred_entities_dict = extract_from_bio(tokens, pred_bio)
    
    # Per-entity metrics
    for label in ENTITY_LABELS:
        true_set = set(true_entities.get(label, []))
        pred_set = set(pred_entities_dict.get(label, []))
        
        entity_metrics[label]["tp"] += len(true_set & pred_set)
        entity_metrics[label]["fp"] += len(pred_set - true_set)
        entity_metrics[label]["fn"] += len(true_set - pred_set)

print(f"‚úÖ Processed all {len(test_data)} samples!")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


üîÑ Running evaluation...
  Processed 50/20000 samples...
  Processed 100/20000 samples...
  Processed 150/20000 samples...
  Processed 200/20000 samples...
  Processed 250/20000 samples...
  Processed 300/20000 samples...
  Processed 350/20000 samples...
  Processed 400/20000 samples...
  Processed 450/20000 samples...
  Processed 500/20000 samples...
  Processed 550/20000 samples...
  Processed 600/20000 samples...
  Processed 650/20000 samples...
  Processed 700/20000 samples...
  Processed 750/20000 samples...
  Processed 800/20000 samples...
  Processed 850/20000 samples...
  Processed 900/20000 samples...
  Processed 950/20000 samples...
  Processed 1000/20000 samples...
  Processed 1050/20000 samples...
  Processed 1100/20000 samples...
  Processed 1150/20000 samples...
  Processed 1200/20000 samples...
  Processed 1250/20000 samples...
  Processed 1300/20000 samples...
  Processed 1350/20000 samples...
  Processed 1400/20000 samples...
  Processed 1450/20000 samples...
  Proce

In [9]:
# Cell 6: Calculate Overall Metrics
precision = precision_score(true_labels_list, pred_labels_list)
recall = recall_score(true_labels_list, pred_labels_list)
f1 = f1_score(true_labels_list, pred_labels_list)
accuracy = accuracy_score(true_labels_list, pred_labels_list)

print("=" * 60)
print("üìä OVERALL METRICS")
print("=" * 60)
print(f"Precision: {precision:.4f} ({precision*100:.2f}%)")
print(f"Recall:    {recall:.4f} ({recall*100:.2f}%)")
print(f"F1 Score:  {f1:.4f} ({f1*100:.2f}%)")
print(f"Accuracy:  {accuracy:.4f} ({accuracy*100:.2f}%)")

üìä OVERALL METRICS
Precision: 0.9788 (97.88%)
Recall:    0.9867 (98.67%)
F1 Score:  0.9827 (98.27%)
Accuracy:  0.9971 (99.71%)


In [11]:
# Cell 7: Per-Entity Metrics Table (NO MATPLOTLIB VERSION)
per_entity_data = []

for label in ENTITY_LABELS:
    tp = entity_metrics[label]["tp"]
    fp = entity_metrics[label]["fp"]
    fn = entity_metrics[label]["fn"]
    
    precision_entity = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall_entity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1_entity = 2 * (precision_entity * recall_entity) / (precision_entity + recall_entity) if (precision_entity + recall_entity) > 0 else 0.0
    
    per_entity_data.append({
        "Entity": label,
        "Precision": precision_entity,
        "Recall": recall_entity,
        "F1": f1_entity,
        "Support": tp + fn,
        "TP": tp,
        "FP": fp,
        "FN": fn
    })

# Create DataFrame
df = pd.DataFrame(per_entity_data)
df = df.sort_values("F1", ascending=False)

# Format numbers
df_formatted = df.copy()
for col in ['Precision', 'Recall', 'F1']:
    df_formatted[col] = df_formatted[col].apply(lambda x: f"{x:.4f}")
for col in ['Support', 'TP', 'FP', 'FN']:
    df_formatted[col] = df_formatted[col].apply(lambda x: f"{int(x)}")

print("üìà PER-ENTITY METRICS")
print("=" * 80)
display(df_formatted)

# Also show the numeric version for sorting/filtering
print("\nüìä Numeric version (for analysis):")
display(df)

üìà PER-ENTITY METRICS


Unnamed: 0,Entity,Precision,Recall,F1,Support,TP,FP,FN
21,TRAVELER,1.0,1.0,1.0,347,347,0,0
13,AMENITIES,1.0,1.0,1.0,349,349,0,0
19,PRICE,1.0,1.0,1.0,1147,1147,0,0
17,DEALS,1.0,1.0,1.0,448,448,0,0
4,DEPARTURE_DATE,1.0,1.0,1.0,4850,4850,0,0
5,ARRIVAL_DATE,1.0,1.0,1.0,11,11,0,0
16,COUPON_CODE,1.0,1.0,1.0,537,537,0,0
7,ARRIVAL_TIME,1.0,1.0,1.0,209,209,0,0
11,BUS_TYPE,0.9993,0.9996,0.9995,2802,2801,2,1
6,DEPARTURE_TIME,0.9993,0.9993,0.9993,1499,1498,1,1



üìä Numeric version (for analysis):


Unnamed: 0,Entity,Precision,Recall,F1,Support,TP,FP,FN
21,TRAVELER,1.0,1.0,1.0,347,347,0,0
13,AMENITIES,1.0,1.0,1.0,349,349,0,0
19,PRICE,1.0,1.0,1.0,1147,1147,0,0
17,DEALS,1.0,1.0,1.0,448,448,0,0
4,DEPARTURE_DATE,1.0,1.0,1.0,4850,4850,0,0
5,ARRIVAL_DATE,1.0,1.0,1.0,11,11,0,0
16,COUPON_CODE,1.0,1.0,1.0,537,537,0,0
7,ARRIVAL_TIME,1.0,1.0,1.0,209,209,0,0
11,BUS_TYPE,0.999286,0.999643,0.999465,2802,2801,2,1
6,DEPARTURE_TIME,0.999333,0.999333,0.999333,1499,1498,1,1


In [12]:
# Cell 8: Detailed Classification Report
print("=" * 60)
print("üìã DETAILED CLASSIFICATION REPORT")
print("=" * 60)
print(classification_report(true_labels_list, pred_labels_list))

üìã DETAILED CLASSIFICATION REPORT
                       precision    recall  f1-score   support

              AC_TYPE       0.96      0.96      0.96       698
              ADD_ONS       0.99      0.98      0.98      3967
            AMENITIES       1.00      1.00      1.00       349
         ARRIVAL_DATE       1.00      1.00      1.00        11
         ARRIVAL_TIME       1.00      1.00      1.00       210
         BUS_FEATURES       0.99      0.99      0.99       462
             BUS_TYPE       1.00      1.00      1.00      2803
          COUPON_CODE       1.00      1.00      1.00       538
                DEALS       1.00      1.00      1.00       449
       DEPARTURE_DATE       1.00      1.00      1.00      4850
       DEPARTURE_TIME       1.00      1.00      1.00      1505
DESTINATION_CITY_CODE       0.12      1.00      0.21        34
     DESTINATION_NAME       0.99      0.99      0.99     18413
           DROP_POINT       1.00      1.00      1.00       478
             OPERA

In [14]:
# Cell 9: Sample Predictions (Visual Inspection) - CORRECTED
print("=" * 60)
print("üîç SAMPLE PREDICTIONS")
print("=" * 60)

# Show first 5 samples
for i, sample in enumerate(test_data[:5]):
    # Reconstruct text from tokens
    tokens = sample["tokens"]
    text = " ".join(tokens)
    
    # Get predictions
    pred = ner.extract(text)
    
    # Get ground truth entities from BIO tags
    true_ner_tag_ids = sample["ner_tags"]
    true_bio = [id2label.get(tag_id, "O") for tag_id in true_ner_tag_ids]
    
    # Extract ground truth entities
    def extract_from_bio(tokens, bio_tags):
        entities = {label: [] for label in ENTITY_LABELS}
        current_entity = []
        current_label = None
        
        for token, tag in zip(tokens, bio_tags):
            if tag.startswith("B-"):
                if current_label and current_entity:
                    entities[current_label].append(" ".join(current_entity))
                current_label = tag[2:]
                current_entity = [token]
            elif tag.startswith("I-"):
                label = tag[2:]
                if label == current_label:
                    current_entity.append(token)
                else:
                    if current_label and current_entity:
                        entities[current_label].append(" ".join(current_entity))
                    current_label = label
                    current_entity = [token]
            else:
                if current_label and current_entity:
                    entities[current_label].append(" ".join(current_entity))
                current_label = None
                current_entity = []
        
        if current_label and current_entity:
            entities[current_label].append(" ".join(current_entity))
        
        return entities
    
    true_entities = extract_from_bio(tokens, true_bio)
    
    print(f"\n[{i+1}] Query: {text}")
    print("   Ground Truth:")
    for label, values in true_entities.items():
        if values:
            print(f"     {label}: {values}")
    
    print("   Predicted:")
    for label, values in pred.items():
        if values:
            print(f"     {label}: {values}")
    
    # Show differences
    print("   Comparison:")
    all_labels = set(list(true_entities.keys()) + list(pred.keys()))
    for label in all_labels:
        true_vals = set(true_entities.get(label, []))
        pred_vals = set(pred.get(label, []))
        if true_vals != pred_vals:
            missing = true_vals - pred_vals
            extra = pred_vals - true_vals
            if missing:
                print(f"     ‚ùå Missing {label}: {list(missing)}")
            if extra:
                print(f"     ‚ö†Ô∏è  Extra {label}: {list(extra)}")
        else:
            if true_vals:
                print(f"     ‚úÖ {label}: Correct")

üîç SAMPLE PREDICTIONS

[1] Query: I want to know about the Multi axle bus from NCR Delhi to Imphal , specifically from Near Silk Board Bus Stop .
   Ground Truth:
     SOURCE_NAME: ['NCR Delhi']
     DESTINATION_NAME: ['Imphal']
     PICKUP_POINT: ['Near Silk Board Bus Stop']
     BUS_TYPE: ['Multi axle']
   Predicted:
     SOURCE_NAME: ['NCR Delhi']
     DESTINATION_NAME: ['Imphal']
     PICKUP_POINT: ['Near Silk Board Bus Stop']
     BUS_TYPE: ['Multi axle']
   Comparison:
     ‚úÖ DESTINATION_NAME: Correct
     ‚úÖ BUS_TYPE: Correct
     ‚úÖ PICKUP_POINT: Correct
     ‚úÖ SOURCE_NAME: Correct

[2] Query: find me the bus options from Kullu to Thiruvananthapuram .
   Ground Truth:
     SOURCE_NAME: ['Kullu']
     DESTINATION_NAME: ['Thiruvananthapuram']
   Predicted:
     SOURCE_NAME: ['Kullu']
     DESTINATION_NAME: ['Thiruvananthapuram']
   Comparison:
     ‚úÖ DESTINATION_NAME: Correct
     ‚úÖ SOURCE_NAME: Correct

[3] Query: can you find me bus services from Gachibowli to Impha