In [None]:
import sys, os, rdflib, pickle
from pathlib import Path

import numpy as np
from sentence_transformers import SentenceTransformer
from pykeen.triples import TriplesFactory
from pykeen.pipeline import pipeline

# Custom EHR Tools 
from EHRPipeline.entity_alignment.invokers import Invoker
from EHRPipeline.entity_alignment.entity_alignement import CrossOntologyAligner
from EHRPipeline.entity_alignment.embedder import SimpleDataEmbedder, ClusterGenerator
from EHRPipeline.entity_linking.linking_validation import LinkingValidator
from EHRPipeline.fact_validation.factValidation import Validator


# General setup of environment and files

In [None]:
# Merged ontology (Turtle) that the alignment step later queries.
basePath = "merged_ontology.ttl"
# Optional pickle caches: when present, the embedding / clustering steps below are skipped.
# TODO(review): nothing in this notebook writes these caches; consider pickling the results
# after computing them. The cluster cache also lives under a personal directory (../sven).
snomed_embeddings = Path("../data/snomed_embedded.pkl")
clusters = Path("../sven/data/cluster.pkl")
SNOMED_TTL = "../data/snomed-ct-20221231-mini.ttl"
N_CLUSTERS = 50

# Sentence encoder shared by the SNOMED embedder and the cross-ontology aligner.
tokenizer = SentenceTransformer("all-MiniLM-L6-v2")

baseKG = rdflib.Graph()
baseKG.parse(basePath, format="ttl")

# Open each cache through the same Path object used in its exists() check, so the
# checked path and the loaded path cannot drift apart.
# Security: pickle.load can execute arbitrary code — only load caches you created yourself.
if snomed_embeddings.exists():
    with snomed_embeddings.open("rb") as file:
        data_embedding = pickle.load(file)
else:
    snomed = rdflib.Graph()
    snomed.parse(SNOMED_TTL, format="ttl")
    embedder = SimpleDataEmbedder(embeddingModel=tokenizer)
    data_embedding = embedder.encode(data=snomed)

if clusters.exists():
    with clusters.open("rb") as file1:
        segmentation = pickle.load(file1)
else:
    cluster = ClusterGenerator(data_embedding, n_clusters=N_CLUSTERS)
    segmentation = cluster.generate_clusters()

# Schema Mapping

In [None]:
# TODO(Fluvio): insert the schema-mapping step here.

# Cross-Ontology Entity Alignment

In [None]:
# Build the aligner from the clustered SNOMED embeddings and the shared sentence encoder,
# then merge the base KG through the "icd9tosnomed" invoker
# (presumably an ICD-9 -> SNOMED-CT mapping — confirm against EHRPipeline).
ontologyaligner = CrossOntologyAligner(
    dataGraph=data_embedding,
    clusters=segmentation,
    embeddingModel=tokenizer,
)
# SPHN `hasCode` IRI, passed to merge() as its Namespace argument.
HAS_CODE_PREDICATE = rdflib.URIRef("https://biomedit.ch/rdf/sphn-schema/sphn#hasCode")
CrossOntologyAlignedKG = ontologyaligner.merge(
    query=baseKG,
    Invoker="icd9tosnomed",
    Namespace=HAS_CODE_PREDICATE,
)

# Entity Linking Validation Step

In [None]:
# Run the entity-linking validator over the aligned graph; keep a handle to it and echo it
# as the cell's output.
linking_validator = LinkingValidator(CrossOntologyAlignedKG)
linking_validator

# TransE Embedding

In [None]:
# Fixed seed so the train/validation/test split and TransE training are reproducible.
RANDOM_SEED = 42

# Convert the aligned RDF graph (bound to `CrossOntologyAlignedKG` by the alignment cell)
# into string (head, relation, tail) triples.
# Assumes merge() returns an iterable of (s, p, o) terms such as an rdflib.Graph — confirm.
triples = [(str(subj), str(pred), str(obj)) for subj, pred, obj in CrossOntologyAlignedKG]
assert triples, "The aligned graph contains no triples; nothing to embed."

# TransE training (contributed by Andy).
# A plain list has no .split(); pykeen needs a TriplesFactory built from an (n, 3) str array.
triples_factory = TriplesFactory.from_labeled_triples(np.array(triples, dtype=str))
training, validation, testing = triples_factory.split([0.8, 0.1, 0.1], random_state=RANDOM_SEED)

result = pipeline(
    training=training,
    validation=validation,
    testing=testing,
    model='TransE',
    model_kwargs={
        'embedding_dim': 20,
    },
    optimizer='Adam',
    optimizer_kwargs={
        'lr': 1e-3,
        'weight_decay': 1e-5
    },
    negative_sampler='basic',
    loss='SoftplusLoss',
    training_loop='sLCWA',
    training_kwargs={
        'num_epochs': 100,
        'batch_size': 32,
        'label_smoothing': 0.0
    },
    evaluator_kwargs={
        "filtered": True
    },
    filter_validation_when_testing=True,
    random_seed=RANDOM_SEED,
)

# Fact Validation

In [None]:
def _parse_triple_line(line):
    """Split one prediction line into ``(subj, pred, obj)`` tokens.

    Accepts exactly three whitespace-separated tokens, optionally followed by a
    lone ``.`` statement terminator (N-Triples style). Returns ``None`` for any
    other shape so the caller can report the line as malformed.
    """
    parts = line.split()
    # Tolerate the N-Triples terminator: "<s> <p> <o> ."
    if len(parts) == 4 and parts[3] == ".":
        parts = parts[:3]
    if len(parts) != 3:
        return None
    return parts[0], parts[1], parts[2]


def main(
    sparql_endpoint="http://localhost:7200/repositories/finalrepohealthcare",
    predictions_file="predictions.txt",
    output_file="validated_facts.txt",
    threshold=0.5,
    max_length=3,
):
    """Score predicted facts against a SPARQL endpoint and keep the plausible ones.

    Reads ``predictions_file`` line by line, expecting ``<subj> <pred> <obj>``
    (an optional trailing ``.`` is accepted). Each fact is scored with
    ``Validator.validate_fact``; facts scoring ``>= threshold`` are written to
    ``output_file`` in their original bracketed form, one per line.

    Args:
        sparql_endpoint: Repository URL. The default is a localhost GraphDB
            repository, so it has to be configured per machine.
        predictions_file: Input file of predicted triples.
        output_file: Destination for facts that pass the threshold (overwritten).
        threshold: Minimum score for a fact to be written.
        max_length: Forwarded to ``validate_fact`` as ``max_length``.
    """
    validator = Validator(sparql_endpoint)

    with open(predictions_file, "r", encoding="utf-8") as f_in, open(output_file, "w", encoding="utf-8") as f_out:
        for raw_line in f_in:
            line = raw_line.strip()
            if not line:
                continue

            triple = _parse_triple_line(line)
            if triple is None:
                print(f"Skipping malformed line: {line}")
                continue
            subj, pred, obj = triple

            # Strip the angle brackets before handing the terms to the validator.
            score = validator.validate_fact(
                subj.strip("<>"), pred.strip("<>"), obj.strip("<>"), max_length=max_length
            )

            print(f"Fact: {subj} {pred} {obj} => Score: {score}")

            # NOTE(review): assumes validate_fact returns a number; a None score would raise here.
            if score >= threshold:
                f_out.write(f"{subj} {pred} {obj}\n")

    print(f"Validation complete. Facts with score >= {threshold} are in '{output_file}'.")


if __name__ == "__main__":
    main()
