In [1]:
import pandas as pd
import numpy as np
import pickle
from pathlib import Path
from pykeen.pipeline import pipeline
from pykeen.triples import TriplesFactory
import torch

print("Libraries imported successfully")


Libraries imported successfully


<!-- ## How to Extract Skill Subgraph from Raw Data

This cell demonstrates how to extract a subgraph containing skills from raw RDF data. -->


## Loading Data

In [2]:
# Load knowledge graph triples
triples_file = "out_esco/skill_subgraph_from_rdf.tsv"
print(f"Loading triples from {triples_file}...")

# Read triples (tab-separated: subject, predicate, object)
triples_df = pd.read_csv(
    triples_file, 
    sep="\t", 
    header=None, 
    names=["subject", "predicate", "object"],
    encoding="utf-8"
)

print(f"Loaded {len(triples_df):,} triples")
print("\nFirst few triples:")
print(triples_df.head())
print(f"\nUnique subjects: {triples_df['subject'].nunique():,}")
print(f"Unique predicates: {triples_df['predicate'].nunique():,}")
print(f"Unique objects: {triples_df['object'].nunique():,}")

# Extract skill URIs from triples (URIs that start with skill namespace)
skill_uri_prefix = "http://data.europa.eu/esco/skill/"
print(f"\nIdentifying skill URIs using prefix: {skill_uri_prefix}")
print("Method: Check if URI string starts with this prefix")

all_skill_uris = set()

# Get skill URIs from subjects
skill_uris_from_subjects = set(triples_df[triples_df['subject'].str.startswith(skill_uri_prefix)]['subject'].unique())
all_skill_uris.update(skill_uris_from_subjects)

# Get skill URIs from objects (in case some skills appear as objects)
skill_uris_from_objects = set(triples_df[triples_df['object'].str.startswith(skill_uri_prefix)]['object'].unique())
all_skill_uris.update(skill_uris_from_objects)

skill_uris_list = sorted(list(all_skill_uris))

print(f"\nURI Identification Summary:")
print(f"  - Prefix pattern: '{skill_uri_prefix}'")
print(f"  - Identification method: String prefix matching (str.startswith())")
print(f"  - Why this works: ESCO uses consistent namespace for all skill entities")
print(f"\n✓ Extracted {len(skill_uris_list):,} unique skill URIs from triples")
print(f"  From subjects: {len(skill_uris_from_subjects):,}")
print(f"  From objects: {len(skill_uris_from_objects):,}")
print(f"\nFirst few skill URIs:")
for uri in skill_uris_list[:5]:
    print(f"  {uri}")


Loading triples from out_esco/skill_subgraph_from_rdf.tsv...
Loaded 1,816,462 triples

First few triples:
                                             subject  \
0  http://data.europa.eu/esco/skill/0d168770-4d9c...   
1  http://data.europa.eu/esco/skill/918459f1-147d...   
2  http://data.europa.eu/esco/relation/56091B30-1...   
3  http://data.europa.eu/esco/skill/a549dcdf-3771...   
4  http://data.europa.eu/esco/skill/f66cbeb5-2a8e...   

                                           predicate  \
0  http://data.europa.eu/esco/model#isOptionalSki...   
1  http://www.w3.org/2004/02/skos/core#broaderTra...   
2  http://data.europa.eu/esco/model#isAssociationFor   
3   http://www.w3.org/2004/02/skos/core#topConceptOf   
4               http://purl.org/dc/terms/description   

                                              object  
0  http://data.europa.eu/esco/occupation/f2b15a0e...  
1  http://data.europa.eu/esco/skill/496932f1-0b6b...  
2  http://data.europa.eu/esco/skill/c241f5a5-e23f...  


In [3]:
## List All Predicates

print("=" * 70)
print("All Predicates in the Dataset")
print("=" * 70)

# Count all predicates
predicate_counts = triples_df['predicate'].value_counts()

print(f"\nTotal unique predicates: {len(predicate_counts)}")
print(f"Total triples: {len(triples_df):,}")

print("\n" + "-" * 70)
print("Predicate List (sorted by frequency):")
print("-" * 70)
for idx, (predicate, count) in enumerate(predicate_counts.items(), 1):
    print(f"{idx:2d}. {predicate}")
    print(f"    Count: {count:,} triples ({count/len(triples_df)*100:.2f}%)")

print("\n" + "-" * 70)
print("Predicate List (for copy-paste, sorted alphabetically):")
print("-" * 70)
predicates_sorted = sorted(predicate_counts.index)
for idx, predicate in enumerate(predicates_sorted, 1):
    print(f"{idx:2d}. {predicate}")

print("\n" + "=" * 70)
print("Next Step:")
print("=" * 70)
print("Please provide a blacklist of predicates to remove.")
print("Then run the next cell to filter the triples.")


All Predicates in the Dataset

Total unique predicates: 25
Total triples: 1,816,462

----------------------------------------------------------------------
Predicate List (sorted by frequency):
----------------------------------------------------------------------
 1. http://www.w3.org/2008/05/skos-xl#prefLabel
    Count: 423,633 triples (23.32%)
 2. http://purl.org/dc/terms/description
    Count: 396,455 triples (21.83%)
 3. http://www.w3.org/2008/05/skos-xl#altLabel
    Count: 203,535 triples (11.21%)
 4. http://data.europa.eu/esco/model#hasAssociation
    Count: 125,504 triples (6.91%)
 5. http://data.europa.eu/esco/model#target
    Count: 120,374 triples (6.63%)
 6. http://www.w3.org/2004/02/skos/core#broaderTransitive
    Count: 83,586 triples (4.60%)
 7. http://data.europa.eu/esco/model#relatedEssentialSkill
    Count: 67,811 triples (3.73%)
 8. http://data.europa.eu/esco/model#isEssentialSkillFor
    Count: 67,811 triples (3.73%)
 9. http://data.europa.eu/esco/model#isOptionalSk

In [4]:
## Filter Triples by Predicate Blacklist

print("=" * 70)
print("Filter Triples by Predicate Blacklist")
print("=" * 70)

# ============================================
# Enter the predicate blacklist to remove here
# ============================================
# Add the predicates you want to remove to the list below
predicate_blacklist = [
    "http://data.europa.eu/esco/model#isEssentialSkillFor",
    "http://data.europa.eu/esco/model#isOptionalSkillFor",
    "http://www.w3.org/2004/02/skos/core#broader",
    "http://www.w3.org/2004/02/skos/core#topConceptOf",
    "http://purl.org/dc/terms/isReplacedBy"
]

# ============================================

print(f"\nBlacklist contains {len(predicate_blacklist)} predicates:")
if predicate_blacklist:
    for idx, pred in enumerate(predicate_blacklist, 1):
        count = len(triples_df[triples_df['predicate'] == pred])
        print(f"  {idx}. {pred}")
        print(f"     Will remove {count:,} triples")
else:
    print("  (Blacklist is empty - no predicates will be removed)")

print("\n" + "=" * 70)
print("Filtering Triples")
print("=" * 70)

# Filter out predicates in the blacklist
triples_df_filtered = triples_df[~triples_df['predicate'].isin(predicate_blacklist)].copy()

print(f"\nOriginal triples: {len(triples_df):,}")
print(f"Filtered triples: {len(triples_df_filtered):,}")
print(f"Removed triples: {len(triples_df) - len(triples_df_filtered):,}")

if len(triples_df) - len(triples_df_filtered) > 0:
    print(f"Removed percentage: {(len(triples_df) - len(triples_df_filtered))/len(triples_df)*100:.2f}%")

print(f"\nOriginal unique predicates: {triples_df['predicate'].nunique()}")
print(f"Remaining unique predicates: {triples_df_filtered['predicate'].nunique()}")

print("\n" + "=" * 70)
print("Remaining Predicates:")
print("=" * 70)
remaining_predicates = triples_df_filtered['predicate'].value_counts()
for idx, (predicate, count) in enumerate(remaining_predicates.items(), 1):
    print(f"{idx:2d}. {predicate}: {count:,} triples")

print("\n" + "=" * 70)
print("Ready for Training")
print("=" * 70)
print("✓ Filtered triples are ready in 'triples_df_filtered'")
print("  Use 'triples_df_filtered' instead of 'triples_df' for training")


Filter Triples by Predicate Blacklist

Blacklist contains 5 predicates:
  1. http://data.europa.eu/esco/model#isEssentialSkillFor
     Will remove 67,811 triples
  2. http://data.europa.eu/esco/model#isOptionalSkillFor
     Will remove 67,011 triples
  3. http://www.w3.org/2004/02/skos/core#broader
     Will remove 20,614 triples
  4. http://www.w3.org/2004/02/skos/core#topConceptOf
     Will remove 11,131 triples
  5. http://purl.org/dc/terms/isReplacedBy
     Will remove 439 triples

Filtering Triples

Original triples: 1,816,462
Filtered triples: 1,649,456
Removed triples: 167,006
Removed percentage: 9.19%

Original unique predicates: 25
Remaining unique predicates: 20

Remaining Predicates:
 1. http://www.w3.org/2008/05/skos-xl#prefLabel: 423,633 triples
 2. http://purl.org/dc/terms/description: 396,455 triples
 3. http://www.w3.org/2008/05/skos-xl#altLabel: 203,535 triples
 4. http://data.europa.eu/esco/model#hasAssociation: 125,504 triples
 5. http://data.europa.eu/esco/model#tar

## Modifying data format

In [5]:
# Create TriplesFactory for PyKEEN
print("Creating TriplesFactory...")
print("\nHow TriplesFactory works:")
print("1. Input: triples with subject, predicate, object (all are URI strings)")
print("2. TriplesFactory automatically:")
print("   - Extracts all unique entities from subject and object columns")
print("   - Assigns each unique entity a unique integer ID (0, 1, 2, ...)")
print("   - Preserves the original URI string as the 'label'")
print("   - Creates mapping: entity_id_to_label = {entity_id: original_uri_string}")
print("3. The 'label' IS the original URI string from the triples!")

triples_factory = TriplesFactory.from_labeled_triples(
    triples=triples_df_filtered[["subject", "predicate", "object"]].values,
    create_inverse_triples=True  # Set to True if you want inverse relations
)

print(f"\n✓ TriplesFactory created")
print(f"  Number of entities: {triples_factory.num_entities:,}")
print(f"  Number of relations: {triples_factory.num_relations:,}")
print(f"  Number of triples: {triples_factory.num_triples:,}")

# Show how entity_id_to_label works
print("\n" + "=" * 70)
print("Understanding entity_id_to_label:")
print("=" * 70)
entity_id_to_uri = triples_factory.entity_id_to_label
print(f"Type: {type(entity_id_to_uri)}")
print(f"Size: {len(entity_id_to_uri):,} entities")
print("\nExample mappings (first 3):")
for i, (entity_id, uri) in enumerate(list(entity_id_to_uri.items())[:3], 1):
    print(f"  {i}. Entity ID {entity_id} → URI: {uri[:70]}...")
print("\nKey insight: The 'label' in entity_id_to_label IS the original URI string!")
print("The URI comes from the triples we passed in (subject/object columns)")
print("=" * 70)


Creating TriplesFactory...

How TriplesFactory works:
1. Input: triples with subject, predicate, object (all are URI strings)
2. TriplesFactory automatically:
   - Extracts all unique entities from subject and object columns
   - Assigns each unique entity a unique integer ID (0, 1, 2, ...)
   - Preserves the original URI string as the 'label'
   - Creates mapping: entity_id_to_label = {entity_id: original_uri_string}
3. The 'label' IS the original URI string from the triples!

✓ TriplesFactory created
  Number of entities: 1,179,402
  Number of relations: 40
  Number of triples: 1,649,456

Understanding entity_id_to_label:
Type: <class 'dict'>
Size: 1,179,402 entities

Example mappings (first 3):
  1. Entity ID 0 → URI: http://data.europa.eu/esco/concept-scheme/6c930acd-c104-4ece-acf7-f44f...
  2. Entity ID 1 → URI: http://data.europa.eu/esco/concept-scheme/digcomp...
  3. Entity ID 2 → URI: http://data.europa.eu/esco/concept-scheme/green...

Key insight: The 'label' in entity_id_to_l

In [6]:
# # Split triples into training and testing sets
# # In newer PyKEEN versions (1.11+), use ratios instead of ratio
# # try:
# #     # New API: use ratios parameter (tuple of ratios for train, test, validation)
# #     training, testing = triples_factory.split(ratios=(0.8, 0.2), random_state=42)
# # except TypeError:
# #     # Fallback: try train_ratio parameter
# #     try:
# #         training, testing = triples_factory.split(train_ratio=0.8, random_state=42)
# #     except TypeError:
# #         # Old API: use ratio parameter
# #         training, testing = triples_factory.split(ratio=0.8, random_state=42)
# training, validation, testing = triples_factory.split(ratios=(0.8, 0.1, 0.1))
training, validation = triples_factory.split(ratios=(0.8, 0.2))

# print(f"Training triples: {training.num_triples:,}")
# print(f"Validation triples:{validation.num_triples}")
# print(f"Testing triples: {testing.num_triples:,}")


using automatically assigned random_state=1044129335


## Training Model

In [7]:
import torch
from tqdm import tqdm
from pykeen.training.callbacks import TrainingCallback
from pykeen.training.callbacks import GradientNormClippingTrainingCallback

# Custom callback to print loss after each epoch (using tqdm.write() so it doesn't get overwritten)
# class PrintLossCallback(TrainingCallback):
#     """Callback to print loss after each epoch using tqdm.write()"""
    
#     def __init__(self):
#         super().__init__()
#         self.epoch_losses = []
#         self.metrics_history = []  # List of dicts: [{'epoch': 1, 'loss': 0.123, 'loss_change': 0.0}, ...]
        
#     def on_epoch_end(self, epoch: int, epoch_loss: float, **kwargs) -> None:
#         """Called at the end of each epoch"""
#         self.epoch_losses.append(epoch_loss)
        
#         # Calculate loss change
#         if epoch == 1:
#             change = 0.0
#         else:
#             change = epoch_loss - self.epoch_losses[-2]
#         change_str = f"{change:+.4f}" if epoch > 1 else "N/A"
        
#         # Initialize metrics dict for this epoch
#         epoch_metrics = {
#             'epoch': epoch,
#             'loss': float(epoch_loss),
#             'loss_change': float(change)
#         }
        
#         # Store metrics
#         self.metrics_history.append(epoch_metrics)
        
#         # Key: Use tqdm.write() instead of print() so the message doesn't get overwritten by progress bar
#         # This ensures the loss information persists and doesn't disappear when tqdm refreshes
#         tqdm.write(f"Epoch {epoch:3d} | Loss: {epoch_loss:.4f} | Change: {change_str}")
    
#     def get_metrics_dataframe(self):
#         """Convert metrics history to pandas DataFrame"""
#         import pandas as pd
#         return pd.DataFrame(self.metrics_history)
    
#     def save_metrics(self, filepath):
#         """Save metrics history to CSV file"""
#         import pandas as pd
#         df = pd.DataFrame(self.metrics_history)
#         df.to_csv(filepath, index=False, encoding='utf-8')
#         print(f"✓ Metrics saved to: {filepath}")
#         return df

# Check device (support Apple Silicon GPU, CUDA, or CPU)
if torch.backends.mps.is_available():
    device = torch.device("mps")  # Apple Silicon GPU
    print("Using Apple Silicon GPU (MPS)")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA GPU")
else:
    device = torch.device("cpu")
    print("Using CPU")


Using Apple Silicon GPU (MPS)


In [8]:
import torch
import random
from pykeen.models import TransE
from pykeen.training import SLCWATrainingLoop
from pykeen.evaluation import SampledRankBasedEvaluator
from pykeen.stoppers import EarlyStopper
# Use tqdm.auto to automatically detect environment and avoid ContextVar errors
# If LookupError: shell_parent is encountered, tqdm.auto will automatically fall back to standard version
try:
    from tqdm.auto import tqdm
    # tqdm.auto automatically detects environment, uses notebook version in Jupyter, standard version in terminal
except ImportError:
    from tqdm import tqdm

# Ensure tqdm displays correctly in Jupyter notebook
tqdm.pandas()  # If using pandas

# Assume these three are already available
# train_factory, valid_factory, test_factory = ...

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set random seeds for reproducibility
# random.seed(42)           # Python random module
# torch.manual_seed(42)      # PyTorch random seed
# np.random.seed(42)         # NumPy random seed
# if torch.cuda.is_available():
#     torch.cuda.manual_seed_all(42)  # CUDA random seed (if using GPU)

# 1️⃣ Define model
print("Creating TransE model...")
model = TransE(
    triples_factory=triples_factory,
    embedding_dim=200,
    scoring_fct_norm=1,
    random_seed=42,  # Set random seed to avoid warning
).to(device)
print("✓ Model creation completed")

# 2️⃣ Define trainer
# print("Creating trainer...")
# trainer = SLCWATrainingLoop(
#     model=model,
#     triples_factory=training,
#     optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
# )
# print("✓ Trainer creation completed")

# 3️⃣ Define evaluator and early stopping
# Note: Do not set additional_filter_triples during creation, as processing large amounts of data will be slow
# Will pass it during evaluate() calls instead, which is more flexible and won't block

# print("Creating evaluator...")
# small_train_idx = torch.randperm(training.num_triples)[:80000]
# train_small = training.clone_and_exchange_triples(
#     mapped_triples=triples_factory.mapped_triples[small_train_idx]
# )

valid_idx = torch.randperm(validation.num_triples)[:10000]
valid_small = validation.clone_and_exchange_triples(
    mapped_triples=validation.mapped_triples[valid_idx]
)


print(f"✓ Created test subset: {valid_small.num_triples:,} triples")

evaluator_small = SampledRankBasedEvaluator(
    evaluation_factory=valid_small,
    num_negatives=100,
    filtered=True,
    additional_filter_triples=[training.mapped_triples],
)
print("✓ Evaluator creation completed")


# # Create subset for test set as well to speed up evaluation
# test_idx = torch.randperm(testing.num_triples)[:10000]
# test_small = testing.clone_and_exchange_triples(
#     mapped_triples=testing.mapped_triples[test_idx]
# )

# # Create new evaluator for test set, because evaluator_small was created for valid_small
# # SampledRankBasedEvaluator requires evaluation_factory to contain the triples to be evaluated
# print("Creating test set evaluator...")
# evaluator_test = SampledRankBasedEvaluator(
#     evaluation_factory=test_small,  # Use test set subset as evaluation_factory
#     num_negatives=100,  # Keep consistent with evaluator_small
#     filtered=True,
#     additional_filter_triples=[training.mapped_triples,validation.mapped_triples],
#     # Do not set additional_filter_triples during creation to avoid blocking when processing large amounts of data
# )
# print("✓ Test set evaluator creation completed")







# evaluator = SampledRankBasedEvaluator(
#     evaluation_factory=validation,
#     num_negatives=50,
#     filtered=False,
#     # additional_filter_triples=[
#     #     training.mapped_triples,
#     #     validation.mapped_triples,
#     # ]
# )

# print("Creating early stopper...")
# stopper = EarlyStopper(
#     model=model,
#     evaluator=evaluator_small,
#     training_triples_factory=training,
#     evaluation_triples_factory=valid_small,
#     frequency=1,           # Check every epoch
#     patience=5,
#     relative_delta=0.001,
#     metric='mean_reciprocal_rank',
# )
# print("✓ Early stopper creation completed")
print("\n" + "=" * 70)
print("All components created successfully, ready to start training!")
print("=" * 70)




Creating TransE model...
✓ Model creation completed
✓ Created test subset: 10,000 triples
✓ Evaluator creation completed

All components created successfully, ready to start training!


In [9]:
best_mrr = -1.0
best_state = None          # Option A: Save in memory
best_ckpt = "best_transE.pt"  # Option B: Save to disk (can use either or both)
patience, waited = 3, 0
rel_delta = 1e-3          # Relative improvement threshold 0.1%
EVAL_EVERY = 1 
import copy

# Create a custom callback for evaluation and early stopping
# from pykeen.training.callbacks import TrainingCallback

# class EvaluationCallback(TrainingCallback):
#     """Callback to evaluate model and handle early stopping after each epoch"""
    
#     def __init__(self, evaluator, valid_small, training, best_mrr_ref, best_state_ref, 
#                  best_ckpt_ref, waited_ref, patience, rel_delta, eval_every):
#         super().__init__()
#         self.evaluator = evaluator
#         self.valid_small = valid_small
#         self.training = training
#         self.best_mrr_ref = best_mrr_ref  # Reference to global variable
#         self.best_state_ref = best_state_ref
#         self.best_ckpt_ref = best_ckpt_ref
#         self.waited_ref = waited_ref
#         self.patience = patience
#         self.rel_delta = rel_delta
#         self.eval_every = eval_every
#         self.should_stop = False
    
#     def post_epoch(self, epoch: int, epoch_loss: float, **kwargs) -> None:
#         """Called after each epoch"""
#         model = self.model
        
#         # Print epoch info
#         loss_value = epoch_loss if isinstance(epoch_loss, (int, float)) else sum(epoch_loss) / len(epoch_loss)
#         print(f"\n[Epoch {epoch}] Training completed | Loss: {loss_value:.4f}")
        
#         # Evaluate if needed
#         if epoch % self.eval_every == 0:
#             print(f"\n[Evaluation] Starting validation evaluation (this may take a few minutes)...")
            
#             # Ensure model is in eval mode for evaluation
#             model.eval()
#             torch.set_grad_enabled(False)
            
#             val_result = self.evaluator.evaluate(
#                 model=model,
#                 mapped_triples=self.valid_small.mapped_triples,
#                 batch_size=512, 
#                 use_tqdm=True,
#                 additional_filter_triples=[self.training.mapped_triples],
#             )
            
#             val_mrr = val_result.get_metric('mean_reciprocal_rank')
#             print(f"[Evaluation] ✓ Evaluation completed | Validation MRR: {val_mrr:.4f}")
            
#             # Restore training mode
#             model.train()
#             torch.set_grad_enabled(True)
            
#             # Early stopping logic
#             if val_mrr > self.best_mrr_ref[0] * (1 + self.rel_delta):
#                 self.best_mrr_ref[0] = val_mrr
#                 self.best_state_ref[0] = copy.deepcopy(model.state_dict())
#                 torch.save(self.best_state_ref[0], self.best_ckpt_ref)
#                 self.waited_ref[0] = 0
#                 print(f"[Early Stopping] ✓ New best MRR! (waited: {self.waited_ref[0]}/{self.patience})")
#             else:
#                 self.waited_ref[0] += 1
#                 print(f"[Early Stopping] No improvement (waited: {self.waited_ref[0]}/{self.patience})")
#                 if self.waited_ref[0] >= self.patience:
#                     print(f"\n[Early Stopping] ⚠️  Early stopping triggered at epoch {epoch}")
#                     print(f"  Best MRR: {self.best_mrr_ref[0]:.4f}")
#                     self.should_stop = True
#         else:
#             print(f"[Evaluation] Skipped (evaluates every {self.eval_every} epochs)")
    
#     def should_stop_training(self) -> bool:
#         """Check if training should stop"""
#         return self.should_stop

In [10]:
# === Manual Evaluation Function ===
# Custom evaluation function for overfitting scenario (no filtering needed)
# def manual_evaluate(model, test_triples, batch_size=512, k_values=[1, 3, 10], device=None, use_tqdm=True):
#     """
#     Manually evaluate model on test triples, computing MRR and Hits@K.
#     For overfitting scenario: no filtering of training triples.
    
#     Args:
#         model: PyKEEN model (e.g., TransE)
#         test_triples: torch.Tensor of shape (n, 3) with (head, relation, tail) triples
#         batch_size: Batch size for evaluation
#         k_values: List of K values for Hits@K metric
#         device: Device to run evaluation on (default: model's device)
#         use_tqdm: Whether to show progress bar
    
#     Returns:
#         EvaluationResult object with get_metric() method for compatibility
#     """
#     if device is None:
#         device = next(model.parameters()).device
    
#     model.eval()
#     torch.set_grad_enabled(False)
    
#     num_entities = model.num_entities
#     num_triples = test_triples.shape[0]
    
#     # Initialize metrics
#     head_ranks = []
#     tail_ranks = []
#     head_hits = {k: 0 for k in k_values}
#     tail_hits = {k: 0 for k in k_values}
    
#     # Process in batches
#     if use_tqdm:
#         from tqdm.auto import tqdm
#         pbar = tqdm(range(0, num_triples, batch_size), desc="Evaluating")
#     else:
#         pbar = range(0, num_triples, batch_size)
    
#     for i in pbar:
#         batch = test_triples[i:i+batch_size].to(device)
#         batch_size_actual = batch.shape[0]
        
#         h_batch = batch[:, 0]  # heads
#         r_batch = batch[:, 1]  # relations
#         t_batch = batch[:, 2]  # tails
        
#         # === Head prediction: (?, r, t) ===
#         for h_idx in range(batch_size_actual):
#             h = h_batch[h_idx]
#             r = r_batch[h_idx]
#             t = t_batch[h_idx]
            
#             # Get scores for all possible heads
#             h_candidates = torch.arange(num_entities, device=device)
#             r_expanded = r.unsqueeze(0).expand(num_entities)
#             t_expanded = t.unsqueeze(0).expand(num_entities)
            
#             # Use position arguments, not keyword arguments
#             scores = model.score_hrt(h_candidates, r_expanded, t_expanded)
            
#             # Get rank of true head (higher score = better)
#             true_h = h.item()
#             sorted_indices = torch.argsort(scores, descending=True)
#             rank = (sorted_indices == true_h).nonzero(as_tuple=True)[0].item() + 1
#             head_ranks.append(rank)
            
#             # Update Hits@K
#             for k in k_values:
#                 if rank <= k:
#                     head_hits[k] += 1
        
#         # === Tail prediction: (h, r, ?) ===
#         for t_idx in range(batch_size_actual):
#             h = h_batch[t_idx]
#             r = r_batch[t_idx]
#             t = t_batch[t_idx]
            
#             # Get scores for all possible tails
#             t_candidates = torch.arange(num_entities, device=device)
#             h_expanded = h.unsqueeze(0).expand(num_entities)
#             r_expanded = r.unsqueeze(0).expand(num_entities)
            
#             # Use position arguments, not keyword arguments
#             scores = model.score_hrt(h_expanded, r_expanded, t_candidates)
            
#             # Get rank of true tail
#             true_t = t.item()
#             sorted_indices = torch.argsort(scores, descending=True)
#             rank = (sorted_indices == true_t).nonzero(as_tuple=True)[0].item() + 1
#             tail_ranks.append(rank)
            
#             # Update Hits@K
#             for k in k_values:
#                 if rank <= k:
#                     tail_hits[k] += 1
    
#     # Calculate final metrics
#     all_ranks = head_ranks + tail_ranks
#     mrr = sum(1.0 / rank for rank in all_ranks) / len(all_ranks) if all_ranks else 0.0
    
#     results = {'mean_reciprocal_rank': mrr}
#     for k in k_values:
#         total_hits = head_hits[k] + tail_hits[k]
#         results[f'hits_at_{k}'] = total_hits / (2 * num_triples) if num_triples > 0 else 0.0
    
#     # Return object with get_metric method for compatibility
#     class EvaluationResult:
#         def __init__(self, results):
#             self.results = results
        
#         def get_metric(self, metric):
#             return self.results.get(metric, 0.0)
        
#         def to_dict(self):
#             return self.results
    
#     return EvaluationResult(results)

# print("✓ Manual evaluation function defined")


In [None]:


# 4️⃣ Start training loop
# Recreate trainer each epoch to avoid internal state causing training time to be 0
# But preserve optimizer state to avoid losing training momentum

max_epochs = 1  # Define maximum number of training epochs
train_set = triples_factory

# Ensure model is in training mode
model.train()
torch.set_grad_enabled(True)

# Create optimizer only once to preserve optimizer state (Adam's momentum, etc.)
# This way, even if we recreate the trainer, the optimizer state will be preserved
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print("=" * 70)
print(f"Starting training for {max_epochs} epochs")
print("=" * 70)
print(f"Evaluation every {EVAL_EVERY} epochs")
print(f"Early stopping patience: {patience}")
print("=" * 70)

for epoch in range(1, max_epochs + 1):
    print("\n" + "=" * 70)
    print(f"Epoch {epoch} of {max_epochs}")
    print("=" * 70)
    
    # === Recreate trainer each epoch, but reuse the same optimizer ===
    # This ensures each training is fresh and won't have training time of 0 due to internal state
    # At the same time, preserve optimizer state to avoid losing training momentum
    trainer = SLCWATrainingLoop(
        model=model,
        triples_factory=train_set,
        optimizer=optimizer,  # Reuse the same optimizer to preserve state
    )
    
    # Ensure model is in training mode
    model.train()
    torch.set_grad_enabled(True)
    
    # === Training phase ===
    print(f"[Training] Starting training for epoch {epoch}...")
    loss = trainer.train(
        triples_factory=train_set,
        num_epochs=1,  # Train only 1 epoch each time
        batch_size=1024,
        # Note: Removed gradient_clipping_max_norm to avoid callback initialization error
        # Gradient clipping can be done manually if needed using torch.nn.utils.clip_grad_norm_
        use_tqdm=True,
        use_tqdm_batch=True,
        tqdm_kwargs={'leave': True, 'ncols': 100},
        continue_training=False,  # Explicitly specify not to continue previous training
        pin_memory=False  # Disable pin_memory to avoid MPS warning
``    )
    
    loss_value = loss if isinstance(loss, (int, float)) else sum(loss) / len(loss) if loss else 0.0
    print(f"[Training] ✓ Training completed | Loss: {loss_value:.4f}")

    # === Evaluate on validation set ===
    if epoch % EVAL_EVERY == 0:
        print(f"\n[Evaluation] Starting validation evaluation (this may take a few minutes)...")
        
        # Ensure model is in eval mode
        model.eval()
        torch.set_grad_enabled(False)
        
        # Use manual evaluation function (no filtering for overfitting scenario)
        val_result = evaluator_small.evaluate(
            model=model,
            mapped_triples=valid_small.mapped_triples,
            batch_size=512,
            # k_values=[1, 3, 10],
            # device=device,
            additional_filter_triples=[training.mapped_triples],
            use_tqdm=True
        )
        
        val_mrr = val_result.get_metric('mean_reciprocal_rank')
        # val_hits1 = val_result.get_metric('hits_at_1')
        # val_hits3 = val_result.get_metric('hits_at_3')
        # val_hits10 = val_result.get_metric('hits_at_10')
        print(f"[Evaluation] ✓ Evaluation completed")
        print(f"  MRR: {val_mrr:.4f} ")
        
        # Restore training mode
        model.train()
        torch.set_grad_enabled(True)

        # Early stopping logic
        # First evaluation: directly update best_mrr, no improvement check, don't reset waited
        # Subsequent evaluations: check for relative improvement, reset waited if improved, otherwise increment waited
        is_first_eval = (best_mrr < 0)
        
        if is_first_eval:
            # First evaluation: directly update, don't reset waited
            best_mrr = val_mrr
            print(best_mrr)
            best_state = copy.deepcopy(model.state_dict())
            torch.save(best_state, best_ckpt)
            # waited keeps original value (usually 0), don't reset
            print(f"[Early Stopping] ✓ First evaluation | MRR: {val_mrr:.4f} (waited: {waited}/{patience})")
        else:
            print(best_mrr)
            # Subsequent evaluations: check for relative improvement
            if val_mrr > best_mrr * (1 + rel_delta):
                # Improvement: update best_mrr, reset waited
                best_mrr = val_mrr
                best_state = copy.deepcopy(model.state_dict())
                torch.save(best_state, best_ckpt)
                waited = 0
                print(f"[Early Stopping] ✓ New best MRR! (waited: {waited}/{patience})")
            else:
                # No improvement: increment waited
                waited += 1
                print(f"[Early Stopping] No improvement (waited: {waited}/{patience})")
                if waited >= patience:
                    print(f"\n[Early Stopping] ⚠️  Early stopping triggered at epoch {epoch}")
                    print(f"  Best MRR: {best_mrr:.4f}")
                    break
    else:
        print(f"[Evaluation] Skipped (evaluates every {EVAL_EVERY} epochs)")
    
    print(f"[Epoch {epoch}] ✓ Completed")

print("\n" + "=" * 70)
print("Training completed!")
print("=" * 70)
    

    

    # === Print model parameters (example: print an embedding) ===
    # ent_emb = model.entity_representations[0]  # First embedding module
    # rel_emb = model.relation_representations[0]
    # print("Sample entity vector:", ent_emb(torch.tensor([0])).detach().cpu().numpy())
    # print("Sample relation vector:", rel_emb(torch.tensor([0])).detach().cpu().numpy())
 
    

# === Load and save best model ===
if best_state is not None:
    model.load_state_dict(best_state)
    model.eval(); torch.set_grad_enabled(False)
    
    # Save the best model
    print("\n" + "=" * 70)
    print("Saving Best Model")
    print("=" * 70)
    
    # 1. Save complete model (includes architecture)
    model_path = "best_transE_model.pt"
    torch.save(model, model_path)
    print(f"✓ Saved complete model: {model_path}")
    print(f"  Usage: model = torch.load('{model_path}')")
    
    # 2. Save model state dict (already saved during training, but save again with metadata)
    # state_dict_path = "best_transE_state_dict.pt"
    # torch.save({
    #     'model_state_dict': best_state,
    #     'best_mrr': best_mrr,
    #     'model_config': {
    #         'model_type': 'TransE',
    #         'embedding_dim': model.embedding_dim,
    #         'scoring_fct_norm': model.scoring_fct_norm,
    #         'num_entities': train_set.num_entities,
    #         'num_relations': train_set.num_relations,
    #     },
    #     # 'training_info': {
    #     #     'num_training_triples': training.num_triples,
    #     #     'num_validation_triples': validation.num_triples,
    #     #     'num_testing_triples': testing.num_triples,
    #     # }
    # }, state_dict_path)
    # print(f"✓ Saved model state dict with metadata: {state_dict_path}")
    # print(f"  Usage: checkpoint = torch.load('{state_dict_path}')")
    # print(f"        model.load_state_dict(checkpoint['model_state_dict'])")
    
    # print(f"\nBest validation MRR: {best_mrr:.4f}")
    # print("=" * 70)
else:
    print("\n⚠️  No best model found. Model may not have improved during training.")
    # Save current model anyway
    model_path = "final_transE_model.pt"
    torch.save(model, model_path)
    print(f"✓ Saved final model: {model_path}")


# test_result = evaluator_test.evaluate(
#     model=model,
#     mapped_triples=test_small.mapped_triples,
#     batch_size=512, use_tqdm=True, 
#     # num_workers=0,
#     additional_filter_triples=[
#         training.mapped_triples,
#         validation.mapped_triples,
#     ],
        
#     )
# print(f"Test MRR: {test_result.get_metric('mean_reciprocal_rank'):.4f}")

print("\n✅ Training completed.")





Starting training for 1 epochs
Evaluation every 1 epochs
Early stopping patience: 3

Epoch 1 of 1
[Training] Starting training for epoch 1...




Training epochs on mps:0:   0%|                                            | 0/1 [00:00<?, ?epoch/s]

Training batches on mps:0:   0%|          | 0.00/3.22k [00:00<?, ?batch/s]

[Training] ✓ Training completed | Loss: 0.4384

[Evaluation] Starting validation evaluation (this may take a few minutes)...


Evaluating on mps:0:   0%|          | 0.00/10.0k [00:00<?, ?triple/s]

Encountered tensors on device_types={'mps'} while only ['cuda'] are considered safe for automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, too).


## Embedding and mapping URI to embedding
I don't know why the numbers don't totally match for 'skill' set from the graph and from the pykeen-processed dataset, but I suppose using the overlap is ok...

In [None]:
# Get entity embeddings
print("Extracting entity embeddings...")

# Use the manually trained model from Cell 17
# The model variable is available globally after training
if 'model' not in globals():
    raise ValueError("Model not found. Please run the training cell (Cell 17) first.")

print("Using manually trained model...")

entity_embeddings = model.entity_representations[0](indices=None).detach().cpu().numpy()

print(f"✓ Entity embeddings extracted")
print(f"  Shape: {entity_embeddings.shape}")
print(f"  Embedding dimension: {entity_embeddings.shape[1]}")


Extracting entity embeddings...
Using manually trained model...
✓ Entity embeddings extracted
  Shape: (1179402, 200)
  Embedding dimension: 200


In [None]:
# Load Best Model from File and Extract Embeddings
# This cell loads the saved best model and extracts entity embeddings

import os
import torch

print("=" * 70)
print("Loading Best Model from File")
print("=" * 70)

# Load the best model from file
model_path = "best_transE_model.pt"
print(f"Loading model from {model_path}...")

if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model file not found: {model_path}. Please run training first.")

# Load the complete model (includes architecture)
# Note: weights_only=False is required for loading complete models (not just weights)
# This is safe since we trust our own saved model files
model = torch.load(model_path, map_location=device, weights_only=False)
model.eval()
torch.set_grad_enabled(False)
print(f"✓ Model loaded from {model_path}")

# Verify it's the best model by checking metadata
state_dict_path = "best_transE_state_dict.pt"
if os.path.exists(state_dict_path):
    # Loading checkpoint with metadata also requires weights_only=False
    checkpoint = torch.load(state_dict_path, map_location=device, weights_only=False)
    if 'best_mrr' in checkpoint:
        print(f"\nModel Information:")
        print(f"  Best validation MRR: {checkpoint['best_mrr']:.4f}")
        if 'model_config' in checkpoint:
            config = checkpoint['model_config']
            print(f"  Model type: {config.get('model_type', 'Unknown')}")
            print(f"  Embedding dimension: {config.get('embedding_dim', 'Unknown')}")
            print(f"  Number of entities: {config.get('num_entities', 'Unknown'):,}")
            print(f"  Number of relations: {config.get('num_relations', 'Unknown'):,}")
        if 'training_info' in checkpoint:
            info = checkpoint['training_info']
            print(f"\nTraining Information:")
            print(f"  Training triples: {info.get('num_training_triples', 'Unknown'):,}")
            print(f"  Validation triples: {info.get('num_validation_triples', 'Unknown'):,}")
            print(f"  Testing triples: {info.get('num_testing_triples', 'Unknown'):,}")

print("\n" + "=" * 70)
print("Extracting Entity Embeddings")
print("=" * 70)

# Extract all entity embeddings
entity_embeddings = model.entity_representations[0](indices=None).detach().cpu().numpy()

print(f"\n✓ Entity embeddings extracted")
print(f"  Shape: {entity_embeddings.shape}")
print(f"  Embedding dimension: {entity_embeddings.shape[1]}")
print(f"  Total entities: {entity_embeddings.shape[0]:,}")

# Store in global variable for use in subsequent cells
print(f"\n✓ Model and embeddings ready for use")
print("=" * 70)


Loading Best Model from File
Loading model from best_transE_model.pt...


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL pykeen.models.unimodal.trans_e.TransE was not an allowed global by default. Please use `torch.serialization.add_safe_globals([pykeen.models.unimodal.trans_e.TransE])` or the `torch.serialization.safe_globals([pykeen.models.unimodal.trans_e.TransE])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [None]:
# Create mapping from skill URI to embedding
# Extract skill embeddings from entity_embeddings (created in Cell 7)

print("=" * 70)
print("Creating Skill URI to Embedding Mapping")
print("=" * 70)

# Get entity ID to URI mapping
# 
# Why entity_id_to_label contains URIs (explanation):
# ===================================================
# When we created TriplesFactory in Cell 3, we passed:
#   triples_df[["subject", "predicate", "object"]].values
# 
# The subject and object columns contain URI strings like:
#   "http://data.europa.eu/esco/skill/66c88f8d-b73e-4f9d-8cf3-15d9c68b5384"
# 
# TriplesFactory.from_labeled_triples() automatically:
#   1. Extracts all unique entities from subject and object columns
#   2. Assigns each unique entity a unique integer ID (0, 1, 2, ...)
#   3. Preserves the original string (the URI) as the "label"
#   4. Creates: entity_id_to_label = {entity_id: original_string}
# 
# So entity_id_to_label[0] = "http://data.europa.eu/esco/skill/..."
# The "label" IS the URI because that's what was in the triples!
#
# No explicit URI specification needed - the URI comes from the input data itself.

entity_id_to_uri = triples_factory.entity_id_to_label

print(f"\nentity_id_to_label explanation:")
print(f"  Type: {type(entity_id_to_uri)}")
print(f"  Size: {len(entity_id_to_uri):,} entities")
print(f"  Format: {{entity_id: original_uri_string}}")
print(f"  The 'label' is the URI string from the triples we passed to TriplesFactory")

# Print first few items
print(f"\n" + "=" * 70)
print("First 10 items in entity_id_to_uri:")
print("=" * 70)
for i, (entity_id, uri) in enumerate(list(entity_id_to_uri.items())[:10], 1):
    print(f"\n[{i}] Entity ID: {entity_id}")
    print(f"    URI: {uri}")
    # Check if it's a skill URI
    is_skill = uri.startswith("http://data.europa.eu/esco/skill/")
    print(f"    Is skill URI? {is_skill}")
print("=" * 70)

# Create reverse mapping: URI -> entity_id
uri_to_entity_id = {uri: entity_id for entity_id, uri in entity_id_to_uri.items()}

print(f"\nTotal entities in model: {len(entity_id_to_uri):,}")
print(f"Total skills extracted from triples: {len(skill_uris_list):,}")

# Create dictionary: skill_uri -> embedding
skill_embeddings_dict = {}
skill_uris_set = set(skill_uris_list)

print("\nExtracting skill embeddings from entity_embeddings...")
for entity_id, uri in entity_id_to_uri.items():
    if uri in skill_uris_set:
        # entity_id is the index into entity_embeddings array
        skill_embeddings_dict[uri] = entity_embeddings[entity_id]

print(f"\n" + "=" * 70)
print("FINAL STATISTICS:")
print("=" * 70)
print(f"  Total skills extracted from triples: {len(skill_uris_list):,}")
print(f"  Skills with embeddings: {len(skill_embeddings_dict):,}")
print(f"  Skills without embeddings: {len(skill_uris_list) - len(skill_embeddings_dict):,}")
if len(skill_uris_list) > 0:
    coverage = len(skill_embeddings_dict) / len(skill_uris_list) * 100
    print(f"  Coverage: {coverage:.2f}%")
print("=" * 70)

# Check if all skills have embeddings
missing_skills = skill_uris_set - set(skill_embeddings_dict.keys())
if missing_skills:
    print(f"\n⚠️  Warning: {len(missing_skills)} skills don't have embeddings")
    print(f"  First few missing: {list(missing_skills)[:5]}")
    print(f"  (This might happen if a skill URI appears in triples but is not an entity in the model)")
else:
    print(f"\n✓ All skills have embeddings!")

print(f"\n✓ Mapping created: skill_embeddings_dict contains {len(skill_embeddings_dict):,} skill embeddings")


Creating Skill URI to Embedding Mapping

entity_id_to_label explanation:
  Type: <class 'dict'>
  Size: 1,150,092 entities
  Format: {entity_id: original_uri_string}
  The 'label' is the URI string from the triples we passed to TriplesFactory

First 10 items in entity_id_to_uri:

[1] Entity ID: 0
    URI: http://data.europa.eu/esco/concept-scheme/6c930acd-c104-4ece-acf7-f44fd7333036
    Is skill URI? False

[2] Entity ID: 1
    URI: http://data.europa.eu/esco/concept-scheme/digcomp
    Is skill URI? False

[3] Entity ID: 2
    URI: http://data.europa.eu/esco/concept-scheme/green
    Is skill URI? False

[4] Entity ID: 3
    URI: http://data.europa.eu/esco/concept-scheme/member-skills
    Is skill URI? False

[5] Entity ID: 4
    URI: http://data.europa.eu/esco/concept-scheme/research
    Is skill URI? False

[6] Entity ID: 5
    URI: http://data.europa.eu/esco/concept-scheme/skill-language-groups
    Is skill URI? False

[7] Entity ID: 6
    URI: http://data.europa.eu/esco/concept-sche

In [None]:
# Verify URI to Embedding Mapping
# Verify that skill URIs are correctly mapped to embeddings

print("=" * 70)
print("Verification: URI to Embedding Mapping")
print("=" * 70)

# Check if skill_embeddings_dict exists
if 'skill_embeddings_dict' not in globals() or not skill_embeddings_dict:
    print("⚠️  ERROR: skill_embeddings_dict not found!")
    print("  Please run Cell 8 (Create mapping) first.")
    raise ValueError("skill_embeddings_dict not found. Run Cell 8 first.")

print(f"\nVerifying {len(skill_embeddings_dict):,} skill embeddings...")

# Get entity ID to URI mapping (if not already available)
if 'entity_id_to_uri' not in globals():
    entity_id_to_uri = triples_factory.entity_id_to_label
if 'uri_to_entity_id' not in globals():
    uri_to_entity_id = {uri: entity_id for entity_id, uri in entity_id_to_uri.items()}

# Verify consistency
print("\nStep 1: Consistency check")
print("-" * 70)
assert len(entity_id_to_uri) == entity_embeddings.shape[0], \
    "ERROR: Number of entities doesn't match embedding matrix size!"
print("✓ Entity count matches embedding matrix size")

# Verify a few examples
print("\nStep 2: Verifying sample skill URIs")
print("-" * 70)
test_uris = list(skill_embeddings_dict.keys())[:5]
for i, test_uri in enumerate(test_uris, 1):
    if test_uri in uri_to_entity_id:
        entity_id = uri_to_entity_id[test_uri]
        # Get embedding from dictionary
        embedding_from_dict = skill_embeddings_dict[test_uri]
        # Get embedding directly using entity_id
        embedding_direct = entity_embeddings[entity_id]
        # Check if they match
        match = np.allclose(embedding_from_dict, embedding_direct)
        
        print(f"\nExample {i}:")
        print(f"  Skill URI: {test_uri[:60]}...")
        print(f"  → Entity ID: {entity_id}")
        print(f"  → Embedding shape: {embedding_from_dict.shape}")
        print(f"  → Embedding (first 5): {embedding_from_dict[:5]}")
        print(f"  → Verification: {'✓ Match' if match else '✗ Mismatch'}")

print("\n" + "=" * 70)
print("Verification complete!")
print("=" * 70)
print(f"✓ All {len(skill_embeddings_dict):,} skill embeddings are correctly mapped")


Verification: URI to Embedding Mapping

Verifying 14,671 skill embeddings...

Step 1: Consistency check
----------------------------------------------------------------------
✓ Entity count matches embedding matrix size

Step 2: Verifying sample skill URIs
----------------------------------------------------------------------

Example 1:
  Skill URI: http://data.europa.eu/esco/skill/0005c151-5b5a-4a66-8aac-60e...
  → Entity ID: 1135420
  → Embedding shape: (128,)
  → Embedding (first 5): [-0.09477121 -0.02935514 -0.10610551  0.11285554 -0.06000264]
  → Verification: ✓ Match

Example 2:
  Skill URI: http://data.europa.eu/esco/skill/00064735-8fad-454b-90c7-ed8...
  → Entity ID: 1135421
  → Embedding shape: (128,)
  → Embedding (first 5): [ 0.05098764  0.13925022  0.05058711  0.10238589 -0.08057278]
  → Verification: ✓ Match

Example 3:
  Skill URI: http://data.europa.eu/esco/skill/000709ed-2be5-4193-b056-45a...
  → Entity ID: 1135422
  → Embedding shape: (128,)
  → Embedding (first 5): [

In [None]:
# Save embeddings and URI mappings in multiple formats for easy use

print("=" * 70)
print("Saving Skill Embeddings and URI Mappings")
print("=" * 70)

# Check if skill_embeddings_dict exists
if 'skill_embeddings_dict' not in globals() or not skill_embeddings_dict:
    print("⚠️  ERROR: skill_embeddings_dict not found!")
    print("  Please run Cell 9 (Create mapping) first.")
    raise ValueError("skill_embeddings_dict not found. Run Cell 9 first.")

print(f"\nSaving {len(skill_embeddings_dict):,} skill embeddings...")

# 1. Save as pickle (Python native, preserves data types, easiest to use)
print("\n[1] Saving as pickle...")
with open("skill_embeddings_pykeen.pkl", "wb") as f:
    pickle.dump(skill_embeddings_dict, f)
print("  ✓ Saved: skill_embeddings_pykeen.pkl")
print("  Format: {uri: embedding_array}")
print("  Usage: pickle.load(open('skill_embeddings_pykeen.pkl', 'rb'))")

# 2. Save as numpy file (efficient for numerical operations)
print("\n[2] Saving as numpy...")
skill_uris_list_sorted = sorted(list(skill_embeddings_dict.keys()))
skill_embeddings_matrix = np.array([skill_embeddings_dict[uri] for uri in skill_uris_list_sorted])

np.savez(
    "skill_embeddings_pykeen.npz",
    uris=skill_uris_list_sorted,
    embeddings=skill_embeddings_matrix
)
print("  ✓ Saved: skill_embeddings_pykeen.npz")
print("  Format: {'uris': array, 'embeddings': matrix}")
print("  Usage: np.load('skill_embeddings_pykeen.npz')")

# 3. Save as CSV with URI and embedding (for easy inspection)
# Note: This creates a wide CSV, may be large
print("\n[3] Saving as CSV...")
embeddings_df = pd.DataFrame(
    skill_embeddings_matrix,
    index=skill_uris_list_sorted
)
embeddings_df.index.name = "uri"
embeddings_df.to_csv("skill_embeddings_pykeen.csv")
print("  ✓ Saved: skill_embeddings_pykeen.csv")
print("  Format: CSV with URI as index, embedding dimensions as columns")
print("  Usage: pd.read_csv('skill_embeddings_pykeen.csv', index_col='uri')")

# 4. Save as JSON (human-readable, but large)
print("\n[4] Saving as JSON (URI -> embedding as list)...")
# Convert numpy arrays to lists for JSON
skill_embeddings_json = {
    uri: embedding.tolist() 
    for uri, embedding in skill_embeddings_dict.items()
}
import json
with open("skill_embeddings_pykeen.json", "w", encoding="utf-8") as f:
    json.dump(skill_embeddings_json, f, indent=2)
print("  ✓ Saved: skill_embeddings_pykeen.json")
print("  Format: {uri: [embedding_values...]}")
print("  Usage: json.load(open('skill_embeddings_pykeen.json'))")

# 5. Save metadata
print("\n[5] Saving metadata...")
metadata = {
    "model": "TransE",
    "embedding_dim": skill_embeddings_matrix.shape[1],
    "num_skills": len(skill_embeddings_dict),
    "num_triples": triples_factory.num_triples,
    "num_entities": triples_factory.num_entities,
    "num_relations": triples_factory.num_relations,
    "skill_uri_prefix": "http://data.europa.eu/esco/skill/",
    "description": "Skill embeddings from PyKEEN knowledge graph embedding model",
    "created_date": pd.Timestamp.now().isoformat(),
}

with open("skill_embeddings_pykeen_metadata.pkl", "wb") as f:
    pickle.dump(metadata, f)
print("  ✓ Saved: skill_embeddings_pykeen_metadata.pkl")
print("  Contains: model info, dimensions, counts, etc.")

# 6. Save as TSV with URI and embedding (alternative CSV format)
print("\n[6] Saving URI-embedding pairs as TSV...")
# Create a TSV with URI and embedding values (one row per skill)
tsv_data = []
for uri in skill_uris_list_sorted:
    embedding = skill_embeddings_dict[uri]
    row = {"uri": uri, "embedding_dim": len(embedding)}
    # Add first few embedding values as example
    for i, val in enumerate(embedding[:5]):
        row[f"emb_{i}"] = val
    row["embedding_norm"] = np.linalg.norm(embedding)
    tsv_data.append(row)

tsv_df = pd.DataFrame(tsv_data)
tsv_df.to_csv("skill_embeddings_pykeen_summary.tsv", sep="\t", index=False)
print("  ✓ Saved: skill_embeddings_pykeen_summary.tsv")
print("  Format: TSV with URI, embedding_dim, first 5 values, norm")

print("\n" + "=" * 70)
print("All files saved successfully!")
print("=" * 70)
print("\nSaved files:")
print("  1. skill_embeddings_pykeen.pkl - Python pickle (recommended for Python)")
print("  2. skill_embeddings_pykeen.npz - NumPy format (efficient)")
print("  3. skill_embeddings_pykeen.csv - CSV format (human-readable)")
print("  4. skill_embeddings_pykeen.json - JSON format (human-readable)")
print("  5. skill_embeddings_pykeen_metadata.pkl - Metadata")
print("  6. skill_embeddings_pykeen_summary.tsv - Summary with sample values")
print("\n" + "=" * 70)


Saving Skill Embeddings and URI Mappings

Saving 14,671 skill embeddings...

[1] Saving as pickle...
  ✓ Saved: skill_embeddings_pykeen.pkl
  Format: {uri: embedding_array}
  Usage: pickle.load(open('skill_embeddings_pykeen.pkl', 'rb'))

[2] Saving as numpy...
  ✓ Saved: skill_embeddings_pykeen.npz
  Format: {'uris': array, 'embeddings': matrix}
  Usage: np.load('skill_embeddings_pykeen.npz')

[3] Saving as CSV...
  ✓ Saved: skill_embeddings_pykeen.csv
  Format: CSV with URI as index, embedding dimensions as columns
  Usage: pd.read_csv('skill_embeddings_pykeen.csv', index_col='uri')

[4] Saving as JSON (URI -> embedding as list)...
  ✓ Saved: skill_embeddings_pykeen.json
  Format: {uri: [embedding_values...]}
  Usage: json.load(open('skill_embeddings_pykeen.json'))

[5] Saving metadata...
  ✓ Saved: skill_embeddings_pykeen_metadata.pkl
  Contains: model info, dimensions, counts, etc.

[6] Saving URI-embedding pairs as TSV...
  ✓ Saved: skill_embeddings_pykeen_summary.tsv
  Format: TSV