# Named Entity Linking (NEL) Pipeline

This notebook implements a Named Entity Linking (NEL) pipeline for the GutBrainIE dataset. It links each predicted entity to a URI using the two-stage strategy below.

## **Linking Strategy**

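1. **Exact Matching**: Case-insensitive matching of predicted text spans against spans in the manually annotated training data
2. **Similarity Matching**: For entities without an exact match, semantic similarity between the entity text and pre-generated URI definitions, computed with PubMedBERT embeddings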

## **Pipeline Workflow**

1. **Extract Training Knowledge**: Build text-span-to-URI mappings from gold-standard annotations
2. **Analyze Coverage**: Measure how many predicted entities exact matching can link
3. **Semantic Similarity**: Run an embedding-based similarity search for the entities exact matching misses
4. **Analyze Similarity Results**: Inspect the distribution of similarity scores and extract the top match per entity
5. **Final Linking**: Combine exact and similarity matches to assign URIs

## **Output**
- Linked entities with assigned URIs, each tagged with the source of its URI (`exact_match`, `similarity_match`, or `no_match`)
- Statistics on exact matches, similarity matches, and unlinked entities
- Final predictions in evaluation format
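The linking statistics can also be recomputed from the final output file. A minimal sketch, assuming the output layout produced in Step 5 (`{doc_id: {"entities": [{..., "uri": ..., "uri_source": ...}]}}`); `summarize_linking` is a hypothetical helper, and the toy documents and `example.org` URIs are made up:

```python
from collections import Counter

def summarize_linking(predictions: dict) -> dict:
    """Count entities per `uri_source` tag and compute the overall linking rate.

    Hypothetical helper; assumes the Step 5 output layout:
    {doc_id: {"entities": [{"text_span": ..., "uri": ..., "uri_source": ...}, ...]}}
    """
    counts = Counter(
        entity.get("uri_source", "missing")
        for document in predictions.values()
        for entity in document["entities"]
    )
    total = sum(counts.values())
    # Counter returns 0 for absent keys, so this works even if a source never occurs
    linked = counts["exact_match"] + counts["similarity_match"]
    return {
        "counts": dict(counts),
        "total": total,
        "linking_rate": linked / total if total else 0.0,
    }

# Toy output with two documents (made-up spans and URIs)
toy_output = {
    "doc1": {"entities": [
        {"text_span": "IBS", "uri": "http://example.org/A", "uri_source": "exact_match"},
        {"text_span": "gut flora", "uri": "http://example.org/B", "uri_source": "similarity_match"},
    ]},
    "doc2": {"entities": [
        {"text_span": "xyz", "uri": "NA", "uri_source": "no_match"},
    ]},
}
summary = summarize_linking(toy_output)
print(summary["counts"])  # {'exact_match': 1, 'similarity_match': 1, 'no_match': 1}
```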

## **Requirements**
- GPU recommended for similarity matching (txtai + PubMedBERT)
- Pre-generated URI definitions from `generate_definitions.ipynb`

## Step 1: Extract Entity-to-URI Mappings from Training Data

Build exact-matching dictionaries from the manually annotated data (the train_platinum and train_gold sets; the dev set is commented out by default to avoid data leakage when running inference on dev). Each mapping goes from a normalized (lowercased) text span to the URIs annotated for it.

**Strategy**: When identical text spans map to different URIs, we use the first URI found; better disambiguation is left for future work.
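"First URI found" is only well-defined if the mapping remembers insertion order; a Python `set` does not, so `list(some_set)[0]` returns an arbitrary element. A minimal sketch of an order-preserving alternative, assuming the annotation layout read by the cells below; `build_first_uri_map`, the toy documents, and the `uri:` values are hypothetical:

```python
def build_first_uri_map(annotation_sets):
    """Map each lowercased text span to its URIs in first-seen order.

    `annotation_sets` is an iterable of dicts shaped like the annotation JSON
    read in this notebook: {doc_id: {"entities": [{"text_span", "label", "uri"}, ...]}}.
    A list (not a set) keeps insertion order, so index 0 is the first URI seen.
    """
    span_to_uris = {}
    for annotation_data in annotation_sets:
        for document in annotation_data.values():
            for entity in document["entities"]:
                uris = span_to_uris.setdefault(entity["text_span"].lower(), [])
                if entity["uri"] not in uris:  # lists per span are short, so this check is cheap
                    uris.append(entity["uri"])
    return span_to_uris

# Toy annotations (made up): the same span annotated with two different URIs
toy_train = {
    "d1": {"entities": [{"text_span": "Probiotics", "label": "toy_label", "uri": "uri:B"}]},
    "d2": {"entities": [{"text_span": "probiotics", "label": "toy_label", "uri": "uri:A"},
                        {"text_span": "probiotics", "label": "toy_label", "uri": "uri:B"}]},
}
span_map = build_first_uri_map([toy_train])
print(span_map["probiotics"])  # ['uri:B', 'uri:A']
```

Because `json.load` preserves the key order of the file, "first" here means first in file order, which makes the choice reproducible across runs.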

### Text Span to URI Mapping (Case-Insensitive)

Extract all entity linkages from the training data for exact matching. If a predicted entity's lowercased text span exactly matches a lowercased training span, we assign the corresponding URI.

**Note**: Some text spans may have multiple URIs in the training data. In such cases, we use the first URI found.
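To see how much information "first URI found" discards for ambiguous spans, it can help to count how often each URI was annotated per span. A small inspection sketch using `collections.Counter`, assuming the same annotation layout; `uri_frequencies` and the toy data are hypothetical. Choosing the most frequent URI is shown only as one possible future refinement, not what this notebook does:

```python
from collections import Counter, defaultdict

def uri_frequencies(annotation_sets):
    """Count how often each URI is annotated for each lowercased text span."""
    freqs = defaultdict(Counter)
    for annotation_data in annotation_sets:
        for document in annotation_data.values():
            for entity in document["entities"]:
                freqs[entity["text_span"].lower()][entity["uri"]] += 1
    return freqs

# Toy annotations (made up)
toy_train = {
    "d1": {"entities": [{"text_span": "Stress", "uri": "uri:A"},
                        {"text_span": "stress", "uri": "uri:B"}]},
    "d2": {"entities": [{"text_span": "stress", "uri": "uri:B"}]},
}
freqs = uri_frequencies([toy_train])
ambiguous = {span: dict(counter) for span, counter in freqs.items() if len(counter) > 1}
print(ambiguous)                          # {'stress': {'uri:A': 1, 'uri:B': 2}}
print(freqs["stress"].most_common(1)[0])  # ('uri:B', 2)
```

Here an order-preserving "first URI found" would pick `uri:A`, although `uri:B` is annotated twice as often.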

In [None]:
import json

# Define paths to training annotation files
# The dev set is commented out to avoid data leakage when running inference on the dev set.
# When evaluating on the test set, uncomment both dev lines below (the path and its list entry) to add the dev annotations to the knowledge base.
#dev_annotations_path = '../../Annotations/Dev/json_format/dev.json'
train_platinum_annotations_path = '../../Annotations/Train/platinum_quality/json_format/train_platinum.json'
train_gold_annotations_path = '../../Annotations/Train/gold_quality/json_format/train_gold.json'

training_annotation_files = [
    #dev_annotations_path, 
    train_platinum_annotations_path,
    train_gold_annotations_path
]

# Build mapping from text spans to URIs (case-insensitive)
text_span_to_uris = {}

print("Processing training annotations to build text span -> URI mappings...")

for annotation_file_path in training_annotation_files:
    with open(annotation_file_path, 'r', encoding='utf-8') as input_file:
        annotation_data = json.load(input_file)

    # Process each document's entities
    for document_id, document_content in annotation_data.items():
        for entity in document_content['entities']:
            normalized_text_span = entity['text_span'].lower()
            entity_label = entity['label']
            entity_uri = entity['uri']
            
            # Initialize set for this text span if not exists
            if normalized_text_span not in text_span_to_uris:
                text_span_to_uris[normalized_text_span] = set()
            
            # Add URI to the set for this text span
            text_span_to_uris[normalized_text_span].add(entity_uri)

# Analyze text spans with multiple URIs
ambiguous_text_spans_count = 0
for text_span, uri_set in text_span_to_uris.items():
    if len(uri_set) > 1:
        print(f'Ambiguous text span: "{text_span}" has multiple URIs: {uri_set}')
        ambiguous_text_spans_count += 1

print(f'\n=== Text Span Analysis ===')
print(f'Total unique text spans: {len(text_span_to_uris)}')
print(f'Text spans with multiple URIs: {ambiguous_text_spans_count}')

### Text Span + Label to URI Mapping

Create a more precise mapping that considers both text span AND entity label. This provides additional specificity for disambiguation when the same text can represent different entity types.
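A toy illustration of why the composite key helps: a span annotated under two labels is ambiguous with the text-only key but not with the (text, label) key. The mappings, labels, and `uri:` values are made up, and `lookup` is a hypothetical helper mirroring the two dictionaries built in this step:

```python
def lookup(span_to_uris, span_label_to_uris, text_span, label):
    """Return the URI sets found by each exact-match key for one predicted entity."""
    span = text_span.lower()
    return {
        "text_only": span_to_uris.get(span, set()),
        "text_and_label": span_label_to_uris.get((span, label), set()),
    }

# Toy mappings: the same span annotated under two different labels
span_to_uris = {"serotonin": {"uri:1", "uri:2"}}
span_label_to_uris = {
    ("serotonin", "label_A"): {"uri:1"},
    ("serotonin", "label_B"): {"uri:2"},
}
result = lookup(span_to_uris, span_label_to_uris, "Serotonin", "label_A")
print(sorted(result["text_only"]))  # ['uri:1', 'uri:2']  (ambiguous)
print(result["text_and_label"])     # {'uri:1'}  (disambiguated by the label)
```

The trade-off is coverage: the composite key only matches when the NER model predicted the same label that was annotated in training.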

In [None]:
# Build mapping from (text_span, label) tuples to URIs
text_span_label_to_uris = {}

print("Building (text_span, label) -> URI mappings...")

for annotation_file_path in training_annotation_files:
    with open(annotation_file_path, 'r', encoding='utf-8') as input_file:
        annotation_data = json.load(input_file)

    # Process each document's entities
    for document_id, document_content in annotation_data.items():
        for entity in document_content['entities']:
            normalized_text_span = entity['text_span'].lower()
            entity_label = entity['label']
            entity_uri = entity['uri']
            
            # Create composite key from text span and label
            composite_key = (normalized_text_span, entity_label)
            
            # Initialize set for this composite key if not exists
            if composite_key not in text_span_label_to_uris:
                text_span_label_to_uris[composite_key] = set()
            
            # Add URI to the set for this composite key
            text_span_label_to_uris[composite_key].add(entity_uri)

# Analyze composite keys with multiple URIs
ambiguous_composite_keys_count = 0
for (text_span, label), uri_set in text_span_label_to_uris.items():
    if len(uri_set) > 1:
        print(f'Ambiguous: "{text_span}" (label: {label}) has multiple URIs: {uri_set}')
        ambiguous_composite_keys_count += 1

print(f'\n=== Text Span + Label Analysis ===')
print(f'Total unique (text_span, label) pairs: {len(text_span_label_to_uris)}')
print(f'Pairs with multiple URIs: {ambiguous_composite_keys_count}')

## Step 2: Analyze Exact Match Coverage on Predictions

Evaluate how many predicted entities can be linked using exact matches from our training data mappings. We test both text-only and text+label matching strategies.
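Both coverage analyses below follow the same pattern and differ only in the lookup key. A minimal sketch that factors out the key function, assuming the prediction layout used in this notebook; `exact_match_coverage`, `text_key`, `text_label_key`, and the toy data are hypothetical. Since both mappings are built from the same annotations, every known (text, label) key implies a known text key, so text-only coverage can never be lower:

```python
def exact_match_coverage(predictions, known_keys, key_fn):
    """Return (matched, missed, coverage) for one exact-match key function.

    `known_keys` can be any container supporting `in`, e.g. the dictionaries
    text_span_to_uris or text_span_label_to_uris built in Step 1.
    """
    matched = missed = 0
    for document in predictions.values():
        for entity in document["entities"]:
            if key_fn(entity) in known_keys:
                matched += 1
            else:
                missed += 1
    total = matched + missed
    return matched, missed, (matched / total if total else 0.0)

def text_key(entity):
    return entity["text_span"].lower()

def text_label_key(entity):
    return (entity["text_span"].lower(), entity["label"])

# Toy predictions and mappings (made up)
toy_predictions = {"d1": {"entities": [
    {"text_span": "Gut", "label": "label_A"},
    {"text_span": "gut", "label": "label_B"},
    {"text_span": "brain", "label": "label_A"},
]}}
known_text = {"gut"}
known_text_label = {("gut", "label_A")}

print(exact_match_coverage(toy_predictions, known_text, text_key)[:2])              # (2, 1)
print(exact_match_coverage(toy_predictions, known_text_label, text_label_key)[:2])  # (1, 2)
```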

### Coverage Analysis: Text + Label Matching

Test exact match coverage using the more precise (text_span, label) mapping approach.

In [None]:
# Load predicted entities for coverage analysis
predictions_file_path = '../../Predictions/NER/baseline_predicted_entities_eval_format.json'
with open(predictions_file_path, 'r', encoding='utf-8') as input_file:
    predicted_entities_data = json.load(input_file)

# Analyze coverage using (text_span, label) matching
exact_matches_with_label = 0
missed_entities_with_label = 0
missed_entities_for_similarity = {}
missed_entity_index = 0

print("Analyzing exact match coverage with (text_span, label) approach...")

for document_id, document_content in predicted_entities_data.items():
    for predicted_entity in document_content['entities']:
        normalized_text_span = predicted_entity['text_span'].lower()
        entity_label = predicted_entity['label']
        composite_key = (normalized_text_span, entity_label)
        
        if composite_key in text_span_label_to_uris:
            exact_matches_with_label += 1
        else:
            missed_entities_with_label += 1
            # Store missed entity for similarity matching
            missed_entities_for_similarity[missed_entity_index] = (
                predicted_entity['text_span'], 
                entity_label
            )
            missed_entity_index += 1

print(f'\n=== Coverage Analysis (Text + Label) ===')
print(f'Exact matches: {exact_matches_with_label}')
print(f'Missed entities: {missed_entities_with_label}')
print(f'Coverage rate: {(exact_matches_with_label / (exact_matches_with_label + missed_entities_with_label) * 100):.1f}%')

# Save missed entities for similarity matching
MISSED_ENTITIES_FILE = 'missed_ents_text_and_label.json'
with open(MISSED_ENTITIES_FILE, 'w', encoding='utf-8') as output_file:
    json.dump(missed_entities_for_similarity, output_file, indent=4)

print(f'\nSaved {len(missed_entities_for_similarity)} missed entities to {MISSED_ENTITIES_FILE}')

### Coverage Analysis: Text-Only Matching + Missed Entity Collection

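Test coverage using text-only matching (more permissive) and collect the entities it cannot match. These missed entities (saved to `missed_ents_text_only.json`) are the ones passed to similarity matching in Step 3, because final linking also uses text-only exact matching.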
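Saving the missed entities to JSON changes their types: integer keys become strings and `(text_span, label)` tuples become lists. Step 3 still works because it unpacks each value as a two-element sequence. A quick round-trip check with made-up entities and labels:

```python
import json

# Toy missed-entity dict in the same shape as missed_entities_for_similarity
missed = {0: ("Lactobacillus reuteri", "toy_label"), 1: ("vagus nerve", "toy_label")}
round_tripped = json.loads(json.dumps(missed))

print(round_tripped)
# {'0': ['Lactobacillus reuteri', 'toy_label'], '1': ['vagus nerve', 'toy_label']}

# Unpacking as in Step 3 still works on the list values
for entity_id, (text_span, entity_label) in round_tripped.items():
    assert isinstance(entity_id, str)
```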

In [None]:
# Analyze coverage using text-only matching and collect missed entities
exact_matches_text_only = 0
missed_entities_text_only = 0
missed_entities_for_similarity = {}
missed_entity_index = 0

print("Analyzing exact match coverage with text-only approach...")

for document_id, document_content in predicted_entities_data.items():
    for predicted_entity in document_content['entities']:
        normalized_text_span = predicted_entity['text_span'].lower()
        entity_label = predicted_entity['label']
        
        if normalized_text_span in text_span_to_uris:
            exact_matches_text_only += 1
        else:
            missed_entities_text_only += 1
            # Store missed entity for similarity matching
            missed_entities_for_similarity[missed_entity_index] = (
                predicted_entity['text_span'], 
                entity_label
            )
            missed_entity_index += 1

print(f'\n=== Coverage Analysis (Text Only) ===')
print(f'Exact matches: {exact_matches_text_only}')
print(f'Missed entities: {missed_entities_text_only}')
print(f'Coverage rate: {(exact_matches_text_only / (exact_matches_text_only + missed_entities_text_only) * 100):.1f}%')

# Save missed entities for similarity matching
MISSED_ENTITIES_FILE = 'missed_ents_text_only.json'
with open(MISSED_ENTITIES_FILE, 'w', encoding='utf-8') as output_file:
    json.dump(missed_entities_for_similarity, output_file, indent=4)

print(f'\nSaved {len(missed_entities_for_similarity)} missed entities to {MISSED_ENTITIES_FILE}')

## Step 3: Semantic Similarity Matching

For entities that couldn't be matched exactly, use semantic similarity with PubMedBERT embeddings to find the most similar URI definitions from our knowledge base.

**Process**:
1. Load pre-generated URI definitions 
2. Build embedding index using txtai + PubMedBERT
3. Query missed entities against the definition corpus
4. Return top-10 most similar definitions for each missed entity
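The search step ranks definitions by the similarity between the query embedding and each definition embedding. The brute-force sketch below illustrates that ranking with cosine similarity over made-up 3-dimensional vectors; it is not txtai's implementation (which uses an approximate nearest-neighbor index over PubMedBERT vectors), and treating txtai's scores as cosine similarity is an assumption here. `cosine`, `top_k`, and the vectors are hypothetical:

```python
import heapq
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def top_k(query_vec, definition_vecs, k=10):
    """Return up to k (definition_id, score) pairs, highest similarity first."""
    scored = ((def_id, cosine(query_vec, vec)) for def_id, vec in definition_vecs.items())
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])

# Made-up vectors standing in for definition embeddings
definition_vecs = {
    "def_1": [1.0, 0.0, 0.0],
    "def_2": [0.7, 0.7, 0.0],
    "def_3": [0.0, 0.0, 1.0],
}
query = [0.9, 0.1, 0.0]
for def_id, score in top_k(query, definition_vecs, k=2):
    print(def_id, round(score, 3))
```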

### Build Embedding Index and Perform Similarity Search

⚠️ **GPU Recommended**: This step uses PubMedBERT embeddings which benefit significantly from GPU acceleration.

In [None]:
import txtai
from tqdm import tqdm
import os

print("Initializing PubMedBERT embeddings model...")
# Initialize txtai embeddings with PubMedBERT (biomedical domain-specific)
embeddings_model = txtai.Embeddings(path="neuml/pubmedbert-base-embeddings", content=True)

# Load URI definitions for embedding indexing
URI_DEFINITIONS_FILE = 'definitions/split_uri_definitions.json'
with open(URI_DEFINITIONS_FILE, 'r', encoding='utf-8') as input_file:
    uri_definitions_data = json.load(input_file)

print(f"Loaded {len(uri_definitions_data)} URI definitions")

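# Reuse a previously saved index if present; delete the 'embeddings_index' directory to rebuild it after the definitions change
if os.path.exists("embeddings_index"):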
    print("Embeddings index already exists. Loading existing index...")
    embeddings_model.load("embeddings_index")
else:
    # Prepare data for indexing: (id, definition_text) tuples
    definitions_for_indexing = []
    for definition_id, definition_text in uri_definitions_data.items():
        definitions_for_indexing.append((definition_id, definition_text))

    print("Building embedding index... (this may take several minutes)")
    # Build the embedding index
    embeddings_model.index(definitions_for_indexing)

    # Save the index for future use
    print("Saving embedding index...")
    embeddings_model.save("embeddings_index") # Save in directory format
    embeddings_model.save("embeddings_index.tar.gz") # Also save as tar.gz for compatibility

# Load missed entities for similarity matching
MISSED_ENTITIES_FILE = 'missed_ents_text_only.json'
with open(MISSED_ENTITIES_FILE, 'r', encoding='utf-8') as input_file:
    missed_entities_data = json.load(input_file)

print(f"Performing similarity search for {len(missed_entities_data)} missed entities...")

# Perform similarity search for each missed entity
similarity_search_results = {}
for entity_id, (text_span, entity_label) in tqdm(missed_entities_data.items(), desc="Similarity matching"):
    # Search for top 10 most similar definitions
    search_results = embeddings_model.search(text_span, 10)
    
    similarity_search_results[entity_id] = {
        'text_span': text_span,
        'label': entity_label,
        'similarity_results': search_results
    }

# Save similarity results
SIMILARITY_RESULTS_FILE = 'similarity_matching_results.json'
with open(SIMILARITY_RESULTS_FILE, 'w', encoding='utf-8') as output_file:
    json.dump(similarity_search_results, output_file, ensure_ascii=False, indent=4)

print(f"Similarity matching completed. Results saved to {SIMILARITY_RESULTS_FILE}")

## Step 4: Analyze Similarity Matching Results

Process and analyze the similarity matching results to understand the quality and distribution of matches.

### Data Format Conversion and Statistical Analysis

Convert the similarity results to pickle format and analyze the distribution of each entity's top-1 similarity score.
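The statistics computed below can also be obtained with the standard library. A sketch using `statistics` and `bisect` on a made-up list of top-1 scores; `score_summary` and the threshold values are hypothetical. Note that `statistics.median` averages the two middle values for an even count, and `bisect_left` on the sorted list counts scores strictly below a threshold, matching the `score < threshold_value` test used here:

```python
import bisect
import statistics

def score_summary(top_scores, thresholds=(0.5, 0.7, 0.9)):
    """Mean, median, and fraction of scores strictly below each threshold."""
    ordered = sorted(top_scores)
    n = len(ordered)
    # bisect_left returns how many elements are strictly smaller than t
    below = {t: bisect.bisect_left(ordered, t) / n for t in thresholds} if n else {}
    return {
        "mean": statistics.fmean(ordered) if n else 0.0,
        "median": statistics.median(ordered) if n else 0.0,
        "below": below,
    }

summary = score_summary([0.875, 0.5, 0.75, 0.25])  # made-up scores
print(summary["median"])  # 0.625
print(summary["below"])   # {0.5: 0.25, 0.7: 0.5, 0.9: 1.0}
```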

In [None]:
import json
import pickle

def convert_json_to_pickle(json_file_path, pickle_file_path):
    """Convert JSON file to pickle format for faster loading."""
    with open(json_file_path, 'r', encoding='utf-8') as input_file:
        data = json.load(input_file)
    
    with open(pickle_file_path, 'wb') as output_file:
        pickle.dump(data, output_file)
    
    print(f"Converted {json_file_path} to {pickle_file_path}")

def load_pickle_data(pickle_file_path):
    """Load data from pickle file."""
    with open(pickle_file_path, 'rb') as input_file:
        data = pickle.load(input_file)
    return data

# Convert similarity results to pickle format
SIMILARITY_RESULTS_JSON = 'similarity_matching_results.json'
SIMILARITY_RESULTS_PICKLE = 'similarity_matching_results.pkl'

convert_json_to_pickle(SIMILARITY_RESULTS_JSON, SIMILARITY_RESULTS_PICKLE)
similarity_results_dict = load_pickle_data(SIMILARITY_RESULTS_PICKLE)

In [None]:
from pprint import pprint

# Analyze similarity score distributions
total_similarity_score = 0
min_similarity_score = float('inf')
max_similarity_score = float('-inf')
total_entities_count = 0
all_top_scores = []

print("Analyzing similarity score distributions...")

# Collect top similarity scores for each entity
for entity_id, entity_data in similarity_results_dict.items():
    # Get the top similarity result (index 0)
    top_similarity_score = entity_data['similarity_results'][0]['score']
    
    # Update statistics
    total_similarity_score += top_similarity_score
    min_similarity_score = min(min_similarity_score, top_similarity_score)
    max_similarity_score = max(max_similarity_score, top_similarity_score)
    all_top_scores.append(top_similarity_score)
    total_entities_count += 1

# Sort scores for percentile analysis
all_top_scores.sort()

# Calculate average
average_similarity_score = total_similarity_score / total_entities_count if total_entities_count > 0 else 0

print(f"\n=== Similarity Score Analysis ===")
print(f"Total entities analyzed: {total_entities_count}")

# Show distribution of scores at different thresholds
print(f"\n=== Score Distribution Analysis ===")
for threshold_percent in range(10, 100, 10):
    threshold_value = threshold_percent / 100
    entities_below_threshold = sum(1 for score in all_top_scores if score < threshold_value)
    percentage_below = (entities_below_threshold / total_entities_count) * 100 if total_entities_count > 0 else 0
    
    print(f"Entities with similarity < {threshold_value:.2f}: {entities_below_threshold} ({percentage_below:.2f}%)")

print(f"\n=== Summary Statistics ===")
pprint({
    "average_similarity_score": round(average_similarity_score, 4),
    "min_similarity_score": round(min_similarity_score, 4),
    "max_similarity_score": round(max_similarity_score, 4),
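    # Median: average the two middle scores when the count is even
    "median_score": round((all_top_scores[(len(all_top_scores) - 1) // 2]
                           + all_top_scores[len(all_top_scores) // 2]) / 2, 4) if all_top_scores else 0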
})

### Process Similarity Results for Entity Linking

Extract the top similarity match for each entity and prepare the mapping for final linking.
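The same text span can occur in several documents, so it is searched several times in Step 3; with a fixed index, identical queries return the same ranking, so one entry per span is enough. A sketch that also resolves the winning definition ID straight to a URI and skips empty result lists, assuming the layout written in Step 3 and an ID-to-URI dictionary like `definitions/id_to_uri.json`; `top_match_uris` and the toy data are hypothetical. Using `max` is purely defensive, since the search results are normally already ordered best-first:

```python
def top_match_uris(similarity_results, definition_id_to_uri):
    """Map each distinct text span to the URI of its highest-scoring definition.

    Assumed layout (as written in Step 3):
    {entity_id: {"text_span", "label", "similarity_results": [{"id", "text", "score"}, ...]}}
    Entities whose result list is empty are skipped.
    """
    span_to_uri = {}
    for entity in similarity_results.values():
        hits = entity["similarity_results"]
        if not hits or entity["text_span"] in span_to_uri:
            continue
        best = max(hits, key=lambda hit: hit["score"])
        span_to_uri[entity["text_span"]] = definition_id_to_uri[best["id"]]
    return span_to_uri

# Toy results (made up): a repeated span and one entity with no hits
toy_results = {
    "0": {"text_span": "SCFA", "label": "toy_label",
          "similarity_results": [{"id": "12", "text": "definition text", "score": 0.91},
                                 {"id": "7", "text": "definition text", "score": 0.88}]},
    "1": {"text_span": "SCFA", "label": "toy_label",
          "similarity_results": [{"id": "12", "text": "definition text", "score": 0.91}]},
    "2": {"text_span": "unknown thing", "label": "toy_label", "similarity_results": []},
}
toy_id_to_uri = {"12": "uri:12", "7": "uri:7"}
mapping = top_match_uris(toy_results, toy_id_to_uri)
print(mapping)  # {'SCFA': 'uri:12'}
```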

In [None]:
# Load similarity results
SIMILARITY_RESULTS_FILE = 'similarity_matching_results.json'
with open(SIMILARITY_RESULTS_FILE, 'r', encoding='utf-8') as input_file:
    similarity_results = json.load(input_file)

# Create mapping from text spans to top similarity match definition IDs
text_span_to_definition_id = {}

print("Processing similarity results to extract top matches...")

for entity_id, entity_data in similarity_results.items():
    entity_text_span = entity_data['text_span']
    top_similarity_result = entity_data['similarity_results'][0]  # Get top match
    top_definition_id = top_similarity_result['id']
    
    # Map text span to the ID of the most similar definition
    text_span_to_definition_id[entity_text_span] = top_definition_id

# Save processed similarity mappings
PROCESSED_SIMILARITY_FILE = "processed_similarity_res.json"
with open(PROCESSED_SIMILARITY_FILE, "w", encoding="utf-8") as output_file:
    json.dump(text_span_to_definition_id, output_file, indent=2)

print(f"Processed similarity mappings for {len(text_span_to_definition_id)} unique text spans")
print(f"Saved to {PROCESSED_SIMILARITY_FILE}")

## Step 5: Final Entity Linking

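Combine exact matching and similarity matching to assign a URI to every predicted entity, in this priority order:

1. **Exact Match**: Case-insensitive text span match against the training data
2. **Similarity Match**: URI of the top-ranked definition from the embedding search
3. **No Match**: Assign `'NA'` if neither method yields a URI

No similarity-score threshold is applied, so an entity falls through to `'NA'` only when its text span was not part of the similarity search (for example, if the predictions or training files changed after Step 3 was run).

Each entity is tagged with its URI source (`uri_source`) for transparency.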
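The priority order can be expressed as a single function per entity. A minimal sketch with made-up mappings; `link_entity` is hypothetical, and its `min_score` parameter is a hypothetical extension: the notebook applies no threshold, which corresponds to `min_score=None`. The score distribution from Step 4 is the kind of evidence one could use to choose such a cut-off:

```python
def link_entity(text_span, exact_span_to_uris, similarity_span_to_hit, min_score=None):
    """Return (uri, uri_source) following the exact -> similarity -> NA priority.

    exact_span_to_uris: lowercased span -> ordered list of URIs (first one wins)
    similarity_span_to_hit: original span -> (uri, score) of the top similarity match
    min_score: hypothetical cut-off; None reproduces the notebook (no threshold)
    """
    uris = exact_span_to_uris.get(text_span.lower())
    if uris:
        return uris[0], "exact_match"
    hit = similarity_span_to_hit.get(text_span)
    if hit is not None and (min_score is None or hit[1] >= min_score):
        return hit[0], "similarity_match"
    return "NA", "no_match"

# Toy mappings (made up)
exact = {"ibs": ["uri:1"]}
similar = {"leaky gut": ("uri:2", 0.62)}
print(link_entity("IBS", exact, similar))                       # ('uri:1', 'exact_match')
print(link_entity("leaky gut", exact, similar))                 # ('uri:2', 'similarity_match')
print(link_entity("leaky gut", exact, similar, min_score=0.8))  # ('NA', 'no_match')
```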

### Apply Linking Strategy to All Predicted Entities

Execute the complete linking pipeline on all predicted entities and generate the final output with linking statistics.

In [None]:
# Initialize linking statistics counters
exact_match_count = 0
similarity_match_count = 0
no_match_count = 0

# Load processed similarity mappings
PROCESSED_SIMILARITY_FILE = 'processed_similarity_res.json'
with open(PROCESSED_SIMILARITY_FILE, 'r', encoding='utf-8') as input_file:
    text_span_to_similarity_definition_id = json.load(input_file)

# Load definition ID to URI mapping
DEFINITION_ID_TO_URI_FILE = "definitions/id_to_uri.json"
with open(DEFINITION_ID_TO_URI_FILE, "r", encoding="utf-8") as input_file:
    definition_id_to_uri = json.load(input_file)

print("Loading training data for exact matching...")

# Rebuild exact matching dictionary (text span -> URI) from training data
# The dev set is commented out to avoid data leakage when running inference on the dev set.
# When evaluating on the test set, uncomment both dev lines below (the path and its list entry) to add the dev annotations to the knowledge base.
#dev_annotations_path = '../../Annotations/Dev/json_format/dev.json'
train_platinum_annotations_path = '../../Annotations/Train/platinum_quality/json_format/train_platinum.json'
train_gold_annotations_path = '../../Annotations/Train/gold_quality/json_format/train_gold.json'

training_annotation_files = [
    #dev_annotations_path, 
    train_platinum_annotations_path, 
    train_gold_annotations_path
]

# Rebuild exact matching mappings
exact_text_span_to_uris = {}
for annotation_file_path in training_annotation_files:
    with open(annotation_file_path, 'r', encoding='utf-8') as input_file:
        annotation_data = json.load(input_file)

    for document_id, document_content in annotation_data.items():
        for entity in document_content['entities']:
            normalized_text_span = entity['text_span'].lower()
            entity_uri = entity['uri']
            
            # Keep URIs in first-seen order (a list, not a set) so index 0 really is the first URI found
            uris_for_span = exact_text_span_to_uris.setdefault(normalized_text_span, [])
            if entity_uri not in uris_for_span:
                uris_for_span.append(entity_uri)

# Load predicted entities for final linking
predictions_file_path = "../../Predictions/NER/baseline_predicted_entities_eval_format.json"
with open(predictions_file_path, "r", encoding="utf-8") as input_file:
    final_predictions = json.load(input_file)

print("Applying entity linking to all predicted entities...")

# Apply linking strategy to each predicted entity
for document_id, document_content in final_predictions.items():
    for predicted_entity in document_content['entities']:
        entity_text_span = predicted_entity['text_span']
        normalized_text_span = entity_text_span.lower()
        
        # Strategy 1: Try exact matching first
        if normalized_text_span in exact_text_span_to_uris:
            exact_match_count += 1
            # Use the first URI found for this span (see the strategy note in Step 1)
            assigned_uri = exact_text_span_to_uris[normalized_text_span][0]
            predicted_entity['uri'] = assigned_uri
            predicted_entity['uri_source'] = 'exact_match'
            
        # Strategy 2: Try similarity matching
        elif entity_text_span in text_span_to_similarity_definition_id:
            similarity_match_count += 1
            definition_id = text_span_to_similarity_definition_id[entity_text_span]
            assigned_uri = definition_id_to_uri[definition_id]
            predicted_entity['uri'] = assigned_uri
            predicted_entity['uri_source'] = 'similarity_match'
            
        # Strategy 3: No match found
        else:
            no_match_count += 1
            predicted_entity['uri'] = 'NA'
            predicted_entity['uri_source'] = 'no_match'

# Display final linking statistics
total_entities = exact_match_count + similarity_match_count + no_match_count
print(f'\n=== Final Entity Linking Results ===')
print(f'Total entities processed: {total_entities}')
print(f'Exact matches: {exact_match_count} ({(exact_match_count/total_entities*100):.1f}%)')
print(f'Similarity matches: {similarity_match_count} ({(similarity_match_count/total_entities*100):.1f}%)')
print(f'No matches (NA): {no_match_count} ({(no_match_count/total_entities*100):.1f}%)')
print(f'Overall linking rate: {((exact_match_count + similarity_match_count)/total_entities*100):.1f}%')

# Save final linked predictions
FINAL_PREDICTIONS_FILE = '../../Predictions/NEL/baseline_predicted_entities_eval_format.json'
with open(FINAL_PREDICTIONS_FILE, 'w', encoding='utf-8') as output_file:
    json.dump(final_predictions, output_file, indent=4)

print(f'\nFinal linked predictions saved to: {FINAL_PREDICTIONS_FILE}')