# Knowledge Graph Embedding

This notebook demonstrates how to create and evaluate knowledge graph embeddings from the knowledge graph constructed in the previous step. It includes:

1. Data augmentation to address limited entity count
2. Knowledge graph embedding training
3. Evaluation and visualization of embeddings
4. Link prediction

## Setup

First, let's set up the environment and import the necessary libraries.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Import standard libraries
import os
import sys
import json
import logging
from pprint import pprint
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch

# Configure logging
# Records at INFO level and above are emitted with a
# "timestamp - logger name - level - message" layout.
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

# Logger used by the cells of this notebook (named after __name__).
logger = logging.getLogger(__name__)

In [None]:
# Make the parent of the current working directory importable so that the
# `src.*` imports below resolve. Adjust the relative path if the notebook is
# started from a different location.
parent_of_cwd = os.path.join(os.getcwd(), '..')
project_root = os.path.abspath(parent_of_cwd)
already_on_path = project_root in sys.path
if not already_on_path:
    sys.path.append(project_root)

# Import project modules
from src.knowledge_graph.augmentation import KnowledgeGraphAugmenter, augment_knowledge_graph
from src.knowledge_graph.embeddings import KnowledgeGraphEmbedder, create_and_train_embeddings

# Create output directories
# exist_ok=True makes this safe to re-run when the folders already exist.
for output_subdir in ('data', 'models', 'embeddings', 'visualization'):
    os.makedirs(f'output/{output_subdir}', exist_ok=True)

## 1. Load and Examine the Original Knowledge Graph

First, let's load the knowledge graph we constructed in the previous step and examine its properties.

In [None]:
from rdflib import Graph
from rdflib.namespace import RDF

# Path to the original knowledge graph
# Adjust this path to match your file
original_kg_path = './output/data/knowledge_graph.ttl'

# Parse the Turtle file into an in-memory rdflib graph
original_graph = Graph()
original_graph.parse(original_kg_path, format='turtle')

print(f"Loaded knowledge graph with {len(original_graph)} triples")

# Tally rdf:type assertions per type. The key is the last '/'-separated
# segment of the type URI.
entity_types = {}
for subject, predicate, type_uri in original_graph.triples((None, RDF.type, None)):
    type_name = str(type_uri).split('/')[-1]
    entity_types[type_name] = entity_types.get(type_name, 0) + 1

print("\nEntity types and counts:")
for type_name, occurrences in entity_types.items():
    print(f"- {type_name}: {occurrences}")

# Tally every predicate other than rdf:type, again keyed by the last
# '/'-separated segment of the predicate URI.
relations = {}
for subject, predicate, obj in original_graph:
    if predicate == RDF.type:
        continue
    relation_name = str(predicate).split('/')[-1]
    relations[relation_name] = relations.get(relation_name, 0) + 1

print(f"\nNumber of distinct relations: {len(relations)}")
print("Top 10 most common relations:")
# Stable sort by descending count; ties keep their first-seen order.
ranked_relations = sorted(relations.items(), key=lambda item: item[1], reverse=True)
for relation_name, occurrences in ranked_relations[:10]:
    print(f"- {relation_name}: {occurrences}")

## 2. Data Augmentation

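To address the challenge of limited entity count, we'll augment our knowledge graph with external knowledge from DBpedia and Wikidata. (In the code below only the DBpedia lookup is enabled; the Wikidata lookup is currently switched off.)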

In [None]:
import random

# Path for the augmented knowledge graph
augmented_kg_path = './output/data/knowledge_graph_augmented.ttl'

# Build the augmenter and hand it the original knowledge graph
augmenter = KnowledgeGraphAugmenter(namespace="http://example.org/graphify/")
augmenter.load_graph(original_kg_path)

# Extract entities; each item is unpacked below as (entity_uri, entity_type)
entities = augmenter.extract_entities()
print(f"Extracted {len(entities)} entities from the knowledge graph")

# For demonstration, we'll enrich a subset of entities
# In a real scenario, you might want to enrich all entities
random.seed(42)  # For reproducibility

# Bucket the candidates by entity type (second field of each item)
persons = [entity for entity in entities if entity[1] == 'PERSON']
organizations = [entity for entity in entities if entity[1] == 'ORG']
locations = [entity for entity in entities if entity[1] == 'LOC']

# Budget: persons and locations are each capped at a third of the total;
# organizations may use whatever is left of the budget after those two.
max_entities = 800
per_type_quota = max_entities // 3
person_count = min(len(persons), per_type_quota)
location_count = min(len(locations), per_type_quota)
remaining_budget = max_entities - person_count - location_count
org_count = min(len(organizations), remaining_budget)

# Draw the samples (persons, then locations, then organizations so the
# seeded random stream is consumed in a fixed order)
if person_count > 0:
    selected_persons = random.sample(persons, person_count)
else:
    selected_persons = []

if location_count > 0:
    selected_locs = random.sample(locations, location_count)
else:
    selected_locs = []

if org_count > 0:
    selected_orgs = random.sample(organizations, org_count)
else:
    selected_orgs = []

# Combine them
selected_entities = selected_persons + selected_locs + selected_orgs

# Preview the first five selections by label
print(f"Selected {len(selected_entities)} entities for enrichment:")
for entity_uri, entity_type in selected_entities[:5]:
    # First element of the metadata tuple is used as the display text
    entity_text, _ = augmenter.get_entity_metadata(entity_uri)
    print(f"- {entity_text} ({entity_type})")
if len(selected_entities) > 5:
    print(f"  ... and {len(selected_entities) - 5} more")

In [None]:
# Enrich selected entities
# Every iteration asks the augmenter to look the entity up externally. A single
# failing lookup must not abort the whole run and throw away the triples
# collected so far, so failures are logged, remembered and skipped.
# NOTE(review): if enrich_entity already swallows its own errors, the except
# branch is simply never taken.
print("Enriching entities with DBpedia and Wikidata knowledge...")
failed_entities = []
for i, (entity_uri, entity_type) in enumerate(selected_entities):
    print(f"\nProcessing entity {i+1}/{len(selected_entities)}: {entity_uri}")
    try:
        added = augmenter.enrich_entity(
            entity_uri,
            use_dbpedia=True,
            use_wikidata=False,  # Wikidata lookup is disabled: it currently fails (cause not yet investigated)
            connection_depth=1
        )
    except Exception as exc:  # KeyboardInterrupt still stops the loop
        logger.warning("Enrichment failed for %s: %s", entity_uri, exc)
        failed_entities.append(entity_uri)
        continue
    print(f"Added {added} triples for this entity")

# Save the augmented knowledge graph (including partial results if some
# entities failed)
augmenter.save_graph(augmented_kg_path)

# Get statistics about the enrichment
stats = augmenter.compute_enrichment_stats()

print("\nAugmentation complete!")
print(f"Original triples: {len(original_graph)}")
print(f"Added triples: {stats['total_triples'] - len(original_graph)}")
print(f"Total triples in augmented graph: {stats['total_triples']}")
if failed_entities:
    print(f"Entities that could not be enriched: {len(failed_entities)}")

## Visualizing the Augmented Knowledge Graph

In [None]:
from rdflib import Graph, URIRef, RDF, RDFS
from src.knowledge_graph.builder import KnowledgeGraphBuilder
import networkx as nx
from pyvis.network import Network

# Path to your augmented knowledge graph file
augmented_kg_path = './output/data/knowledge_graph_augmented.ttl'

# Parse the augmented graph from disk
g = Graph()
g.parse(augmented_kg_path, format='turtle')

# Lookups keyed by the entity URI string: the type is reduced to the last
# '/'-separated segment of its URI, the label is the rdfs:label text.
# If an entity has several types/labels, the last one iterated wins.
entity_types = {
    str(subject): str(type_uri).split('/')[-1]
    for subject, _pred, type_uri in g.triples((None, RDF.type, None))
}
entity_labels = {
    str(subject): str(label_text)
    for subject, _pred, label_text in g.triples((None, RDFS.label, None))
}

# Create a fresh NetworkX graph
nx_graph = nx.DiGraph()

# One node per typed entity; entities without a label fall back to the last
# segment of their URI.
for node_uri, node_kind in entity_types.items():
    display_name = entity_labels.get(node_uri, node_uri.split('/')[-1])
    nx_graph.add_node(
        node_uri,
        label=display_name,
        title=f"{display_name} ({node_kind})",
        type=node_kind,
    )

# One edge per relationship triple. rdf:type and rdfs:label triples are node
# attributes, not edges, and triples touching an unknown node are dropped.
for s, p, o in g:
    if p == RDF.type or p == RDFS.label:
        continue
    source_uri, target_uri = str(s), str(o)
    if source_uri not in nx_graph or target_uri not in nx_graph:
        continue
    # Shortened predicate name for display
    edge_name = str(p).split('/')[-1]
    nx_graph.add_edge(source_uri, target_uri, label=edge_name, title=edge_name)

# Create visualization with custom settings
output_path = './output/visualization/knowledge_graph_augmented.html'

# Create a PyVis network
net = Network(notebook=True, directed=True, height="750px", width="100%")
net.from_nx(nx_graph)

# Map entity types to colors
color_map = {
    'PERSON': '#a8e6cf',
    'ORG': '#ff8b94',
    'GPE': '#ffd3b6',
    'LOC': '#dcedc1',
    'DATE': '#f9f9f9',
    'MISC': '#d4a5a5'
}

# Style every node: colour by type (grey-blue fallback), fixed size, and
# labels longer than 20 characters cut to 17 characters plus an ellipsis.
for node in net.nodes:
    node['color'] = color_map.get(node.get('type', 'Unknown'), '#b3b3cc')
    node['size'] = 15
    if 'label' in node and len(node['label']) > 20:
        node['label'] = node['label'][:17] + '...'

# Save the visualization
net.save_graph(output_path)
print(f"Visualization created at: {output_path}")
print(f"Nodes: {len(nx_graph.nodes)}, Edges: {len(nx_graph.edges)}")

## 3. Knowledge Graph Embedding

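Now, let's create embeddings for our augmented knowledge graph using two different models, TransE and DistMult. Before training, we first remove entities that take part in no relationship.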

In [None]:
print("Preprocessing: Removing Isolated Nodes")

from rdflib import Graph, RDF, RDFS, URIRef
import networkx as nx

# Always read the *unfiltered* augmented graph. This cell rebinds
# `augmented_kg_path` to the filtered file at its end (the embedding cells rely
# on that), so reading through that variable would make a re-run of this cell
# filter its own output and report misleading "original" counts.
unfiltered_kg_path = './output/data/knowledge_graph_augmented.ttl'

# Load the augmented knowledge graph
kg_graph = Graph()
kg_graph.parse(unfiltered_kg_path, format='turtle')

# Count original triples and entities (an entity is any subject that has an
# rdf:type assertion)
typed_subjects = {s for s, p, o in kg_graph.triples((None, RDF.type, None))}
original_triple_count = len(kg_graph)
original_entity_count = len(typed_subjects)

print(f"Original knowledge graph: {original_triple_count} triples, {original_entity_count} entities")

# Create a NetworkX graph to measure how connected each entity is
nx_graph = nx.Graph()

# Register every typed entity as a node first. add_edge() only creates nodes
# that are endpoints of an edge, so without this step no node could ever have
# degree 0 and the isolated-node report below would always be zero.
nx_graph.add_nodes_from(str(s) for s in typed_subjects)

# Add edges for all non-RDF.type and non-RDFS.label predicates
# These predicates represent actual relationships between entities
for s, p, o in kg_graph:
    if p != RDF.type and p != RDFS.label and isinstance(s, URIRef) and isinstance(o, URIRef):
        nx_graph.add_edge(str(s), str(o))

# Split nodes by degree. `connected_nodes` is a set because it is probed for
# every triple in the loop below; a list would make that loop quadratic.
node_degrees = dict(nx_graph.degree())
isolated_nodes = [node for node, degree in node_degrees.items() if degree == 0]
connected_nodes = {node for node, degree in node_degrees.items() if degree > 0}

print(f"Found {len(isolated_nodes)} isolated nodes and {len(connected_nodes)} connected nodes")

# Create a new graph with only connected nodes
filtered_graph = Graph()

# Copy namespace bindings
for prefix, namespace in kg_graph.namespaces():
    filtered_graph.bind(prefix, namespace)

# Keep a triple only when its subject is a connected node, and then:
#   - type/label triples are always kept,
#   - URI-to-URI triples are kept when the object is connected as well,
#   - triples whose object is not a URI (e.g. literals) are kept.
for s, p, o in kg_graph:
    s_str = str(s)
    if s_str not in connected_nodes:
        continue

    if p == RDF.type or p == RDFS.label:
        filtered_graph.add((s, p, o))
    elif isinstance(s, URIRef) and isinstance(o, URIRef) and str(o) in connected_nodes:
        filtered_graph.add((s, p, o))
    elif not isinstance(o, URIRef):
        filtered_graph.add((s, p, o))

# Save the filtered graph
filtered_kg_path = './output/data/knowledge_graph_filtered.ttl'
filtered_graph.serialize(destination=filtered_kg_path, format='turtle')

# Count filtered triples and entities
filtered_triple_count = len(filtered_graph)
filtered_entity_count = len({s for s, p, o in filtered_graph.triples((None, RDF.type, None))})

print(f"Filtered knowledge graph: {filtered_triple_count} triples, {filtered_entity_count} entities")
print(f"Removed {original_triple_count - filtered_triple_count} triples and {original_entity_count - filtered_entity_count} entities")
print(f"Saved filtered knowledge graph to {filtered_kg_path}")

# Update the path to use the filtered graph for embeddings
augmented_kg_path = filtered_kg_path


# Create visualization with custom settings
output_path = './output/visualization/knowledge_graph_filtered.html'

# Lookups built from the filtered graph, keyed by the entity URI string.
# The type is the last '/'-separated segment of the type URI; when an entity
# has several types or labels, the last one iterated wins.
entity_types = {
    str(subject): str(type_uri).split('/')[-1]
    for subject, _pred, type_uri in filtered_graph.triples((None, RDF.type, None))
}
entity_labels = {
    str(subject): str(label_text)
    for subject, _pred, label_text in filtered_graph.triples((None, RDFS.label, None))
}

# Create a fresh NetworkX graph with proper node attributes
vis_graph = nx.DiGraph()

# One node per typed entity; unlabeled entities show the tail of their URI.
for node_uri, node_kind in entity_types.items():
    display_name = entity_labels.get(node_uri, node_uri.split('/')[-1])
    vis_graph.add_node(
        node_uri,
        label=display_name,
        title=f"{display_name} ({node_kind})",
        type=node_kind,
    )

# One edge per URI-to-URI relationship triple whose endpoints are both nodes.
for s, p, o in filtered_graph:
    if p == RDF.type or p == RDFS.label:
        continue
    if not (isinstance(s, URIRef) and isinstance(o, URIRef)):
        continue
    source_uri, target_uri = str(s), str(o)
    # Readable predicate name for the edge caption
    edge_name = str(p).split('/')[-1]
    if source_uri in vis_graph and target_uri in vis_graph:
        vis_graph.add_edge(source_uri, target_uri, label=edge_name, title=edge_name)

# Create a PyVis network
net = Network(notebook=True, directed=True, height="750px", width="100%")
net.from_nx(vis_graph)

# Map entity types to colors
color_map = {
    'PERSON': '#a8e6cf',
    'ORG': '#ff8b94',
    'GPE': '#ffd3b6',
    'LOC': '#dcedc1',
    'DATE': '#f9f9f9',
    'MISC': '#d4a5a5',
    'DBPEDIA': '#b19cd9',
    'WIKIDATA': '#ffd700'
}

# Style every node: colour by type (grey-blue fallback), fixed size, and
# labels longer than 20 characters cut to 17 characters plus an ellipsis.
for node in net.nodes:
    node['color'] = color_map.get(node.get('type', 'Unknown'), '#b3b3cc')
    node['size'] = 15
    if 'label' in node and len(node['label']) > 20:
        node['label'] = node['label'][:17] + '...'

# Configure physics for better layout
net.set_options("""
{
 "physics": {
   "barnesHut": {
     "gravitationalConstant": -2000,
     "centralGravity": 0.1,
     "springLength": 95,
     "springConstant": 0.04
   },
   "maxVelocity": 50,
   "minVelocity": 0.75,
   "solver": "barnesHut",
   "timestep": 0.5
 },
 "edges": {
   "color": {
     "inherit": true
   },
   "smooth": {
     "enabled": true,
     "type": "dynamic"
   },
   "arrows": {
     "to": {
       "enabled": true,
       "scaleFactor": 0.5
     }
   },
   "font": {
     "size": 8
   }
 },
 "nodes": {
   "font": {
     "size": 12
   }
 }
}
""")

# Save the visualization
net.save_graph(output_path)
print(f"Visualization created at: {output_path}")
print(f"Nodes: {len(vis_graph.nodes)}, Edges: {len(vis_graph.edges)}")

In [None]:
# Create an embedder and feed it the graph at `augmented_kg_path`
embedder = KnowledgeGraphEmbedder(namespace="http://example.org/graphify/")
embedder.load_from_rdf(augmented_kg_path)

# Split the dataset into training, validation, and test sets (80/10/10)
training, validation, testing = embedder.split_dataset(
    train_ratio=0.8,
    validation_ratio=0.1,
    test_ratio=0.1,
)

# Report the size of each split
print("Split dataset:")
for split_name, split in (("Training", training), ("Validation", validation), ("Testing", testing)):
    print(f"- {split_name}: {split.num_triples} triples")

In [None]:
# Train embeddings with TransE model
print("Training TransE model...")

# Hyperparameters for this run, gathered in one place so they are easy to tune
transe_hyperparams = dict(
    epochs=100,
    embedding_dim=50,
    batch_size=32,
    learning_rate=0.01,
    num_negs_per_pos=10,
    early_stopping=True,
    early_stopping_patience=5,
)

transe_result = embedder.train_embedding_model(
    model_name='TransE',
    training=training,
    validation=validation,
    testing=testing,
    **transe_hyperparams,
)

In [None]:
# Train embeddings with DistMult model
print("Training DistMult model...")

# Hyperparameters for this run, gathered in one place so they are easy to tune
distmult_hyperparams = dict(
    epochs=100,
    embedding_dim=50,
    batch_size=32,
    learning_rate=0.01,
    num_negs_per_pos=10,
    early_stopping=True,
    early_stopping_patience=5,
)

distmult_result = embedder.train_embedding_model(
    model_name='DistMult',
    training=training,
    validation=validation,
    testing=testing,
    **distmult_hyperparams,
)

## 4. Evaluation

Now, let's evaluate the performance of our embedding models.

In [None]:
# Compare model performance
# The two names match the model_name values used in the training cells above.
# The result is kept in `comparison_df` because the plotting cell that follows
# indexes it by metric column.
comparison_df = embedder.compare_models(['TransE', 'DistMult'])
# Bare last expression so the notebook renders the comparison inline.
comparison_df

In [None]:
import matplotlib.pyplot as plt

# Two panels side by side: mean rank on the left, Hits@K on the right.
# comparison_df is indexed by column name here, so it must provide
# 'mean_rank', 'hits_at_1', 'hits_at_3' and 'hits_at_10'.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(14, 6))
rank_ax, hits_ax = axes

# Left panel: rank-based metric
rank_frame = comparison_df[['mean_rank']]
rank_frame.plot(kind='bar', ax=rank_ax, legend=False, color=['#1f77b4'])
rank_ax.set_title('Mean Rank Comparison')
rank_ax.set_ylabel('Mean Rank')
rank_ax.set_xlabel('Model')

# Right panel: Hits@K metrics
hits_columns = ['hits_at_1', 'hits_at_3', 'hits_at_10']
comparison_df[hits_columns].plot(kind='bar', ax=hits_ax)
hits_ax.set_title('Hits@K Comparison')
hits_ax.set_ylabel('Score')
hits_ax.set_xlabel('Model')

# Adjust layout
plt.tight_layout()

# Write the figure to disk first, then render it in the notebook
plt.savefig('output/visualization/embedding_model_comparison.png')
plt.show()


## 5. Entity Similarity Analysis

Let's find similar entities based on our learned embeddings.

In [None]:
# Get a few sample entities
from rdflib.namespace import RDFS

# Not used by the cells below; kept so that any code relying on these names
# keeps working.
sample_entity_uris = []
sample_entity_texts = []

# First, get all entities with their labels (plain-string URI -> label text)
entity_labels = {}
for s, p, o in augmenter.graph.triples((None, RDFS.label, None)):
    entity_labels[str(s)] = str(o)

# Choose a few sample entities of different types
samples_per_type = 3
person_entities = []
org_entities = []
loc_entities = []

for entity_uri, entity_type in entities:
    # Normalize to a plain string: the label and embedder lookups are keyed by
    # str, and an rdflib term never compares equal to a plain str key.
    uri = str(entity_uri)

    # Only sample entities that have a label AND are known to the embedder.
    # `entities` was extracted before the isolated-node filter, so some of them
    # are not in the graph the embedding models were trained on.
    if uri in entity_labels and uri in embedder.entity_to_id:
        if entity_type == 'PERSON' and len(person_entities) < samples_per_type:
            person_entities.append((uri, entity_labels[uri]))
        elif entity_type == 'ORG' and len(org_entities) < samples_per_type:
            org_entities.append((uri, entity_labels[uri]))
        elif entity_type in ('GPE', 'LOC') and len(loc_entities) < samples_per_type:
            loc_entities.append((uri, entity_labels[uri]))

    # Stop scanning as soon as every bucket is full
    if (
        len(person_entities) >= samples_per_type
        and len(org_entities) >= samples_per_type
        and len(loc_entities) >= samples_per_type
    ):
        break

# Combine all sample entities
sample_entities = person_entities + org_entities + loc_entities

# Print sample entities
print("Sample entities for similarity analysis:")
for entity_uri, entity_text in sample_entities:
    print(f"- {entity_text} ({entity_uri})")

In [None]:
# Find similar entities using TransE embeddings
# For every sampled entity, list its five closest entities in TransE space.
print("Similar entities using TransE embeddings:")
for query_uri, query_text in sample_entities:
    print(f"\nEntities similar to '{query_text}':")
    similar_entities = embedder.find_similar_entities(query_uri, 'TransE', top_k=5)
    for neighbour_uri, neighbour_score in similar_entities:
        print(f"- {neighbour_uri} (Similarity: {neighbour_score:.4f})")

In [None]:
# Find similar entities using DistMult embeddings
# Only the first three samples are queried, to keep the output short.
print("Similar entities using DistMult embeddings:")
distmult_queries = sample_entities[:3]
for query_uri, query_text in distmult_queries:
    print(f"\nEntities similar to '{query_text}':")
    similar_entities = embedder.find_similar_entities(query_uri, 'DistMult', top_k=5)
    for neighbour_uri, neighbour_score in similar_entities:
        print(f"- {neighbour_uri} (Similarity: {neighbour_score:.4f})")

## 6. Link Prediction

Now, let's test the link prediction capabilities of our embedding models.

In [None]:
# Get some relations from the knowledge graph
# Every distinct predicate except rdf:type and rdfs:label, as plain strings.
relations = {
    str(p)
    for s, p, o in augmenter.graph
    if p != RDF.type and p != RDFS.label
}

# Choose a few sample relations. Sorting first makes the choice reproducible:
# the iteration order of a set of strings changes between kernel restarts
# (string hash randomization), so slicing the raw set would pick a different
# sample on every run. Slicing also copes with fewer than five relations.
sample_relations = sorted(relations)[:5]

print("Sample relations for link prediction:")
for relation in sample_relations:
    print(f"- {relation}")

In [None]:
# Find entity pairs for chosen relations
# For each sampled relation keep the first triple whose head and tail are both
# known to the embedder; relations without such a triple contribute nothing.
relation_triples = []

for relation_uri in sample_relations:
    embeddable_pairs = (
        (str(head_term), str(tail_term))
        for head_term, _pred, tail_term in augmenter.graph.triples((None, URIRef(relation_uri), None))
        if str(head_term) in embedder.entity_to_id and str(tail_term) in embedder.entity_to_id
    )
    first_pair = next(embeddable_pairs, None)  # Just take one example per relation
    if first_pair is None:
        continue

    head_uri, tail_uri = first_pair
    # Labels fall back to the last '/'-separated segment of the URI
    head_label = entity_labels.get(head_uri, head_uri.split('/')[-1])
    tail_label = entity_labels.get(tail_uri, tail_uri.split('/')[-1])
    relation_triples.append((head_uri, head_label, relation_uri, tail_uri, tail_label))

print("Sample triples for link prediction:")
for s, s_label, p, o, o_label in relation_triples:
    p_short = p.split('/')[-1]
    print(f"- {s_label} --[{p_short}]--> {o_label}")

In [None]:
# Predict tail entities
# For each sample triple, ask TransE for the five best tails given the head
# and relation, then report whether the true tail is among them.
print("Predicting tail entities with TransE:")
for head_uri, head_label, relation_uri, true_tail, true_tail_label in relation_triples:
    relation_name = relation_uri.split('/')[-1]
    print(f"\nGiven head '{head_label}' and relation '{relation_name}', predicted tails:")
    predicted_tails = embedder.predict_tail_entities(head_uri, relation_uri, 'TransE', top_k=5)
    for tail_uri, score in predicted_tails:
        print(f"- {tail_uri} (Score: {score:.4f})")

    # Check if the actual tail is in the predictions (rank is 1-based)
    predicted_uris = [candidate for candidate, _ in predicted_tails]
    if true_tail not in predicted_uris:
        print(f"✗ Actual tail '{true_tail_label}' is not in the top 5 predictions")
    else:
        print(f"✓ Actual tail '{true_tail_label}' is in the predictions at rank {predicted_uris.index(true_tail) + 1}")

In [None]:
# Predict head entities
# For each sample triple, ask DistMult for the five best heads given the
# relation and tail, then report whether the true head is among them.
print("Predicting head entities with DistMult:")
for true_head, true_head_label, relation_uri, tail_uri, tail_label in relation_triples:
    relation_name = relation_uri.split('/')[-1]
    print(f"\nGiven relation '{relation_name}' and tail '{tail_label}', predicted heads:")
    predicted_heads = embedder.predict_head_entities(relation_uri, tail_uri, 'DistMult', top_k=5)
    for head_uri, score in predicted_heads:
        print(f"- {head_uri} (Score: {score:.4f})")

    # Check if the actual head is in the predictions (rank is 1-based)
    predicted_uris = [candidate for candidate, _ in predicted_heads]
    if true_head not in predicted_uris:
        print(f"✗ Actual head '{true_head_label}' is not in the top 5 predictions")
    else:
        print(f"✓ Actual head '{true_head_label}' is in the predictions at rank {predicted_uris.index(true_head) + 1}")

## 7. Visualization of Embeddings

Let's visualize the entity embeddings to see how they cluster.

In [None]:
# Group entities by type for visualization
# Every rdf:type assertion of the augmenter's graph becomes a
# (subject, 'rdf:type', type) tuple of plain strings.
entity_type_triples = [
    (str(subject), 'rdf:type', str(type_uri))
    for subject, _pred, type_uri in augmenter.graph.triples((None, RDF.type, None))
]

# Create entity type mappings for visualization
entity_types = embedder.group_entities_by_type(entity_type_triples)

# Visualize TransE embeddings
transe_plot_options = dict(
    entity_types=entity_types,
    sample_size=1000,  # Limit for better visualization
    figsize=(12, 10),
    output_path='output/visualization/transe_embeddings.png',
    title='TransE Entity Embeddings',
)
embedder.visualize_embeddings(model_name='TransE', **transe_plot_options)

In [None]:
# Visualize DistMult embeddings
# Same settings as the TransE plot, written to its own image file.
distmult_plot_options = dict(
    entity_types=entity_types,
    sample_size=1000,  # Limit for better visualization
    figsize=(12, 10),
    output_path='output/visualization/distmult_embeddings.png',
    title='DistMult Entity Embeddings',
)
embedder.visualize_embeddings(model_name='DistMult', **distmult_plot_options)

## 8. Save Models and Embeddings

Finally, let's save our models and embeddings for future use.

In [None]:
# Save models
models_dir = 'output/models'
os.makedirs(models_dir, exist_ok=True)

trained_model_names = ['TransE', 'DistMult']
for model_name in trained_model_names:
    embedder.save_model(model_name, models_dir)

# Save entity and relation embeddings as NumPy arrays
embeddings_dir = 'output/embeddings'
os.makedirs(embeddings_dir, exist_ok=True)

for model_name in trained_model_names:
    # Fetch both embedding matrices for this model
    entity_embeddings = embedder.get_entity_embeddings(model_name)
    relation_embeddings = embedder.get_relation_embeddings(model_name)

    # One .npy file per matrix: <model>_<kind>_embeddings.npy
    for kind, matrix in (("entity", entity_embeddings), ("relation", relation_embeddings)):
        np.save(os.path.join(embeddings_dir, f"{model_name}_{kind}_embeddings.npy"), matrix)

    print(f"Saved {model_name} embeddings:")
    print(f"- Entity embeddings: {entity_embeddings.shape}")
    print(f"- Relation embeddings: {relation_embeddings.shape}")

# Save entity and relation mappings as JSON
# JSON object keys must be strings, so the id -> name maps get str() keys.
mappings = {
    'entity_to_id': dict(embedder.entity_to_id.items()),
    'relation_to_id': dict(embedder.relation_to_id.items()),
    'id_to_entity': {str(idx): name for idx, name in embedder.id_to_entity.items()},
    'id_to_relation': {str(idx): name for idx, name in embedder.id_to_relation.items()}
}

mappings_path = os.path.join(embeddings_dir, 'mappings.json')
with open(mappings_path, 'w') as f:
    json.dump(mappings, f, indent=2)

print("Saved entity and relation mappings")