## Load Training Data and Extract Entities

In [None]:
import json
import os
from collections import Counter, defaultdict

# Load training data from three quality levels
train_files = [
    "../data/Annotations/Train/gold_quality/json_format/train_gold.json",
    "../data/Annotations/Train/platinum_quality/json_format/train_platinum.json",
    "../data/Annotations/Train/silver_quality/json_format/train_silver.json"
]

# Dictionary to store entities: {text_span: set of labels}
# If same text has multiple labels, we keep all possibilities
entity_dict = defaultdict(set)
# Number of training annotations per (text_span, label); used only to order
# each span's candidate labels by how often they were seen.
label_counts = defaultdict(Counter)

for train_file in train_files:
    if os.path.exists(train_file):
        with open(train_file, "r", encoding="utf-8") as f:
            train_data = json.load(f)

        for pmid, article in train_data.items():
            for entity in article['entities']:
                text_span = entity['text_span']
                label = entity['label']
                # An empty span matches at every position of every document and a
                # whitespace-only one at every space, so neither is a usable entry.
                if not text_span.strip():
                    continue
                entity_dict[text_span].add(label)
                label_counts[text_span][label] += 1

        print(f"Loaded {train_file}")
    else:
        print(f"Warning: {train_file} not found")

# Convert to list of (text, labels) tuples for easier processing.
# Labels are ordered most-frequent first (ties alphabetical). Iterating the set
# directly gives an order that can change between runs (str hash randomization),
# and because overlap removal keeps the first label emitted for a span, the
# predicted label would otherwise be nondeterministic.
entity_list = [
    (text, sorted(labels, key=lambda lbl: (-label_counts[text][lbl], lbl)))
    for text, labels in entity_dict.items()
]

# Sort by length (longest first) to prioritize longer matches
entity_list.sort(key=lambda x: len(x[0]), reverse=True)

print(f"\nTotal unique entity text spans: {len(entity_list)}")
print(f"\nExample entities:")
for text, labels in entity_list[:10]:
    print(f"  '{text}' -> {labels}")

## Load Dev Data

In [None]:
# Load the dev split; later cells iterate it as pmid -> article
dev_data_path = "../data/Annotations/Dev/json_format/dev.json"

with open(dev_data_path, encoding="utf-8") as dev_file:
    dev_data = json.load(dev_file)

document_count = len(dev_data)
print(f"Loaded {document_count} documents from dev set")

## Predict Entities on Dev Set

In [None]:
from tqdm import tqdm

def find_all_occurrences(text, entity_text):
    """
    Find all (possibly overlapping) occurrences of entity_text in text.

    Matching is exact and case-sensitive, and is not restricted to word
    boundaries (e.g. "ion" matches inside "nation").

    Returns a list of (start_idx, end_idx) tuples where end_idx is exclusive,
    i.e. text[start_idx:end_idx] == entity_text. An empty entity_text yields
    no occurrences.
    """
    if not entity_text:
        # str.find("") matches at every position, which would produce
        # len(text) + 1 zero-length "entities".
        return []
    occurrences = []
    width = len(entity_text)
    pos = text.find(entity_text)
    while pos != -1:
        occurrences.append((pos, pos + width))
        # Advance by one character (not by width) so overlapping matches are found too.
        pos = text.find(entity_text, pos + 1)
    return occurrences

def remove_overlapping_entities(entities, inclusive_end=True):
    """
    Resolve overlapping predicted spans, preferring the longest one.

    Spans are only compared against spans in the same ``location``. Candidates
    are considered longest first (ties: earliest start, then input order), and a
    candidate is kept only if none of its characters is already covered by a
    previously kept span. Among identical spans with different labels, the
    first one in input order is kept.

    Args:
        entities: list of dicts with start_idx, end_idx, text_span, label, location.
        inclusive_end: True if end_idx is the index of the span's last character
            (the convention used by the prediction loop, which stores
            ``end_idx - 1``); False if end_idx is exclusive.

    Returns:
        A new list of the kept entity dicts, ordered by start_idx (longer spans
        first on ties). The input list and its dicts are not modified.
    """
    if not entities:
        return []

    end_offset = 1 if inclusive_end else 0

    def span_length(ent):
        return ent['end_idx'] - ent['start_idx'] + end_offset

    # Longest first, so a longer span always wins over any shorter span it
    # overlaps regardless of which starts earlier. sorted() is stable, so
    # equal keys (e.g. the same span with different labels) keep input order.
    candidates = sorted(entities, key=lambda e: (-span_length(e), e['start_idx']))

    occupied = defaultdict(set)  # location -> character positions already covered
    kept = []
    for ent in candidates:
        positions = range(ent['start_idx'], ent['end_idx'] + end_offset)
        taken = occupied[ent['location']]
        if any(p in taken for p in positions):
            continue
        taken.update(positions)
        kept.append(ent)

    kept.sort(key=lambda e: (e['start_idx'], -span_length(e)))
    return kept

# Process each document in dev set: exact-string lookup of every training span
predictions = {}

for pmid, article in tqdm(dev_data.items(), desc="Processing dev data"):
    metadata = article['metadata']
    sections = (('title', metadata['title']), ('abstract', metadata['abstract']))

    candidates = []
    for location, section_text in sections:
        # Try every dictionary entry (longest spans first) against this section
        for span_text, span_labels in entity_list:
            for begin, stop in find_all_occurrences(section_text, span_text):
                # One candidate per label the span was seen with in training;
                # overlap resolution below decides which one survives.
                candidates.extend(
                    {
                        "start_idx": begin,
                        "end_idx": stop - 1,  # inclusive end index as per organizers specification
                        "location": location,
                        "text_span": span_text,
                        "label": span_label,
                    }
                    for span_label in span_labels
                )

    # Drop overlapping candidates (keep longest) and store the survivors
    predictions[pmid] = {"entities": remove_overlapping_entities(candidates)}

print(f"\nProcessed {len(predictions)} documents")
print(f"Total entities predicted: {sum(len(p['entities']) for p in predictions.values())}")

## Save Predictions

In [None]:
# Save predictions to file, creating the output directory if it is missing
output_path = "predictions/vanilla_ner_predictions.json"

output_dir = os.path.dirname(output_path)
if output_dir:
    # open(..., "w") does not create parent directories
    os.makedirs(output_dir, exist_ok=True)

with open(output_path, "w", encoding="utf-8") as f:
    json.dump(predictions, f, ensure_ascii=False, indent=2)

print(f"Predictions saved to {output_path}")

## Evaluate Performance

In [None]:
# Self-contained evaluation function (adapted from official script)
# Avoids importing evaluate.py which has hardcoded paths

def remove_duplicated_entities(predictions):
    """Remove duplicated entities from predictions.

    Two entities are duplicates when they share (start_idx, end_idx, location);
    the label and text are ignored. The first occurrence is kept. ``predictions``
    is modified in place and nothing is returned; the number of removed
    entities is printed if it is non-zero.
    """
    dropped = 0
    for doc_id in list(predictions):
        unique_entities = []
        positions_seen = set()
        for ent in predictions[doc_id]["entities"]:
            span_key = (ent["start_idx"], ent["end_idx"], ent["location"])
            if span_key in positions_seen:
                dropped += 1
                continue
            positions_seen.add(span_key)
            unique_entities.append(ent)
        predictions[doc_id]["entities"] = unique_entities

    if dropped:
        print(f"Removed {dropped} duplicated entities from predictions")

def remove_overlapping_entities_eval(predictions):
    """Remove overlapping entities, keeping longest spans.

    Modifies ``predictions`` in place and returns None; prints how many
    entities were removed (if any).

    For each document, entities are grouped by location ('title' or
    'abstract' -- any other location value raises KeyError), sorted by
    start_idx, and chained into clusters: an entity joins the current cluster
    when its start_idx is < the largest end_idx seen so far in that cluster.
    Only the single longest entity of each cluster survives (ties go to the
    one that comes first after sorting), so entities that only overlap
    transitively through a third entity are dropped as well.

    NOTE(review): the prediction loop in this notebook stores inclusive end
    indices (end_idx - 1), yet the strict ``<`` comparison below treats two
    spans that share only their boundary character as non-overlapping. Left
    as-is since this is presumably adapted verbatim from the official
    evaluation script -- confirm before changing, as it affects scores.
    """
    removed_count = 0

    for pmid in list(predictions.keys()):
        original_len = len(predictions[pmid]['entities'])

        # Group entities by location
        groups = {'title': [], 'abstract': []}
        for ent in predictions[pmid]['entities']:
            loc = ent["location"]
            groups[loc].append(ent)

        # For each location, build overlap clusters and select the longest
        keepers = set()
        for loc in groups:
            group = groups[loc]
            group = sorted(group, key=lambda e: e["start_idx"])

            clusters = []
            cluster = []
            current_end = None  # largest end_idx in the current cluster

            for ent in group:
                if not cluster:
                    cluster = [ent]
                    current_end = ent["end_idx"]
                else:
                    if ent["start_idx"] < current_end:
                        # Chains onto the current cluster; extend its reach
                        cluster.append(ent)
                        if ent["end_idx"] > current_end:
                            current_end = ent["end_idx"]
                    else:
                        # Gap found: close the cluster and start a new one
                        clusters.append(cluster)
                        cluster = [ent]
                        current_end = ent["end_idx"]
            if cluster:
                clusters.append(cluster)

            # Pick the longest entity in each cluster (length = end_idx - start_idx;
            # strict > means the earliest-sorted entity wins ties)
            for clust in clusters:
                longest = clust[0]
                max_len = longest["end_idx"] - longest["start_idx"]
                for ent in clust[1:]:
                    length = ent["end_idx"] - ent["start_idx"]
                    if length > max_len:
                        longest = ent
                        max_len = length
                keepers.add((longest["start_idx"],
                             longest["end_idx"],
                             longest["location"]))

        # Rebuild the entity list in original order, keeping only the first
        # entity that matches each surviving (start, end, location) key; removing
        # the key drops later same-span entities (e.g. other labels). This may
        # keep a different dict than `longest` if several share the same span.
        deduped = []
        for ent in predictions[pmid]['entities']:
            key = (ent["start_idx"], ent["end_idx"], ent["location"])
            if key in keepers:
                deduped.append(ent)
                keepers.remove(key)

        predictions[pmid]["entities"] = deduped
        removed_count += (original_len - len(deduped))

    if removed_count > 0:
        print(f"Removed {removed_count} overlapping entities")

def evaluate_ner(predictions, ground_truth):
    """
    Evaluate NER predictions against ground truth.
    Based on the official evaluation script.

    A prediction is a true positive only on an exact match of
    (start_idx, end_idx, location, text_span, label) with a gold entity of the
    same document.

    Note: ``predictions`` is modified in place -- duplicated and overlapping
    entities are removed before scoring. Predictions with an illegal label are
    skipped with a warning; predictions whose (legal) label never occurs in the
    ground truth are not counted at all.

    Args:
        predictions: {pmid: {"entities": [entity dict, ...]}}
        ground_truth: {pmid: {"entities": [entity dict, ...], ...}}

    Returns:
        (macro_precision, macro_recall, macro_f1,
         micro_precision, micro_recall, micro_f1). Macro scores average over the
        labels present in ground_truth; the macro scores are 0.0 when the
        ground truth has no annotated entities.
    """
    # Remove duplicated and overlapping entities (mutates predictions)
    remove_duplicated_entities(predictions)
    remove_overlapping_entities_eval(predictions)

    LEGAL_ENTITY_LABELS = {
        "anatomical location",
        "animal",
        "bacteria",
        "biomedical technique",
        "chemical",
        "DDF",
        "dietary supplement",
        "drug",
        "food",
        "gene",
        "human",
        "microbiome",
        "statistical technique",
    }

    def to_entry(entity):
        # Normalize types so values loaded from JSON compare consistently.
        return (
            int(entity["start_idx"]),
            int(entity["end_idx"]),
            str(entity["location"]),
            str(entity["text_span"]),
            str(entity["label"]),
        )

    # pmid -> set of gold entry tuples (set gives O(1) membership tests)
    ground_truth_NER = {}
    count_annotated_entities_per_label = {}

    for pmid, article in ground_truth.items():
        gold_entries = ground_truth_NER.setdefault(pmid, set())
        for entity in article['entities']:
            entry = to_entry(entity)
            gold_entries.add(entry)
            label = entry[4]
            count_annotated_entities_per_label[label] = count_annotated_entities_per_label.get(label, 0) + 1

    labels = list(count_annotated_entities_per_label)
    count_predicted_entities_per_label = dict.fromkeys(labels, 0)
    count_true_positives_per_label = dict.fromkeys(labels, 0)

    for pmid, prediction in predictions.items():
        # A document missing from the ground truth has no gold entities, so its
        # predictions can only be false positives.
        gold_entries = ground_truth_NER.get(pmid, set())
        for entity in prediction['entities']:
            entry = to_entry(entity)
            label = entry[4]

            if label not in LEGAL_ENTITY_LABELS:
                print(f'Warning: Illegal label {label} for entity: {entity}')
                continue

            if label in count_predicted_entities_per_label:
                count_predicted_entities_per_label[label] += 1

            if entry in gold_entries:
                count_true_positives_per_label[label] += 1

    count_annotated_entities = sum(count_annotated_entities_per_label.values())
    count_predicted_entities = sum(count_predicted_entities_per_label.values())
    count_true_positives = sum(count_true_positives_per_label.values())

    # The 1e-10 terms avoid division by zero (as in the official script)
    micro_precision = count_true_positives / (count_predicted_entities + 1e-10)
    micro_recall = count_true_positives / (count_annotated_entities + 1e-10)
    micro_f1 = 2 * ((micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-10))

    n = len(labels)
    if n == 0:
        # No annotated entities: nothing to macro-average over.
        return 0.0, 0.0, 0.0, micro_precision, micro_recall, micro_f1

    precision, recall, f1 = 0.0, 0.0, 0.0
    for label in labels:
        true_positives = count_true_positives_per_label[label]
        current_precision = true_positives / (count_predicted_entities_per_label[label] + 1e-10)
        current_recall = true_positives / (count_annotated_entities_per_label[label] + 1e-10)

        precision += current_precision
        recall += current_recall
        f1 += 2 * ((current_precision * current_recall) / (current_precision + current_recall + 1e-10))

    return precision / n, recall / n, f1 / n, micro_precision, micro_recall, micro_f1

# Score the baseline on the dev set (evaluate_ner also prunes `predictions` in place)
scores = evaluate_ner(predictions, dev_data)
precision, recall, f1, micro_precision, micro_recall, micro_f1 = scores

separator = "=" * 60
report_lines = [
    separator,
    "VANILLA NER BASELINE RESULTS",
    separator,
    "\nMacro-averaged Metrics:",
    f"  Macro-Precision: {precision:.4f}",
    f"  Macro-Recall:    {recall:.4f}",
    f"  Macro-F1 Score:  {f1:.4f}",
    "\nMicro-averaged Metrics:",
    f"  Micro-Precision: {micro_precision:.4f}",
    f"  Micro-Recall:    {micro_recall:.4f}",
    f"  Micro-F1 Score:  {micro_f1:.4f}",
    separator,
]
print("\n".join(report_lines))

## Example Predictions

In [None]:
# Show example predictions
print("Example Predictions:\n")

TITLE_PREVIEW_CHARS = 100


def _entity_key(entity):
    """Return the exact-match key of an entity, with the same type coercion as evaluate_ner."""
    return (
        int(entity['start_idx']),
        int(entity['end_idx']),
        str(entity['location']),
        str(entity['text_span']),
        str(entity['label']),
    )


# Get first few documents
sample_pmids = list(dev_data.keys())[:5]

for pmid in sample_pmids:
    article = dev_data[pmid]
    pred = predictions.get(pmid, {"entities": []})

    title_preview = article['metadata']['title']
    # Only mark the title as truncated when it actually is.
    if len(title_preview) > TITLE_PREVIEW_CHARS:
        title_preview = title_preview[:TITLE_PREVIEW_CHARS] + "..."

    print(f"Document PMID: {pmid}")
    print(f"Title: {title_preview}")
    print(f"\nGold entities: {len(article['entities'])}")
    print(f"Predicted entities: {len(pred['entities'])}")

    # Show first few predicted entities
    print("\nSample predictions:")
    for entity in pred['entities'][:5]:
        print(f"  - '{entity['text_span']}' [{entity['label']}] in {entity['location']}")

    # Calculate match statistics for this document (exact match on every field)
    gold_set = {_entity_key(entity) for entity in article['entities']}
    pred_set = {_entity_key(entity) for entity in pred['entities']}

    correct = len(gold_set & pred_set)
    missed = len(gold_set - pred_set)
    wrong = len(pred_set - gold_set)

    print(f"\n✓ Correct: {correct}")
    print(f"✗ Missed: {missed}")
    print(f"✗ Wrong: {wrong}")
    print("-" * 80)
    print()

## Analysis: Entity Distribution by Label

In [None]:
from collections import Counter

# Tally entity labels in the predictions and in the gold annotations
pred_label_counts = Counter(
    entity['label']
    for pred in predictions.values()
    for entity in pred['entities']
)
gold_label_counts = Counter(
    entity['label']
    for article in dev_data.values()
    for entity in article['entities']
)

rule = "-" * 60
print("Entity Distribution by Label:")
print("=" * 60)
print(f"{'Label':<25} {'Gold':<10} {'Predicted':<10}")
print(rule)

# Every label seen on either side, alphabetically; Counter yields 0 for absent labels
for label in sorted(gold_label_counts.keys() | pred_label_counts.keys()):
    print(f"{label:<25} {gold_label_counts[label]:<10} {pred_label_counts[label]:<10}")

print(rule)
print(f"{'TOTAL':<25} {sum(gold_label_counts.values()):<10} {sum(pred_label_counts.values()):<10}")