In [None]:
import sys, os, rdflib, pickle, urllib.parse, re, random
import pandas as pd
from pathlib import Path
from sentence_transformers import SentenceTransformer
from pykeen.triples import TriplesFactory
from pykeen.pipeline import pipeline

# Custom EHR Tools 
from EHRPipeline.entity_alignment.invokers import Invoker
from EHRPipeline.entity_alignment.entity_alignement import CrossOntologyAligner
from EHRPipeline.entity_alignment.embedder import SimpleDataEmbedder, ClusterGenerator
from EHRPipeline.entity_linking.linking_validation import LinkingValidator
from EHRPipeline.fact_validation.factValidation import Validator


# General setup of environment and files

In [None]:
# Sentence-embedding model shared by the SNOMED embedder and the ontology aligner.
tokenizer = SentenceTransformer("all-MiniLM-L6-v2")

# Locations of the cached artefacts produced by earlier runs.
snomed_embeddings = Path("data/snomed_embedded.pkl")
clusters = Path("data/cluster.pkl")

if snomed_embeddings.exists():
    # Open exactly the path whose existence was just checked. The previous
    # version tested the relative path but opened the absolute
    # "/data/snomed_embedded.pkl", which is a different file.
    # SECURITY: pickle.load executes arbitrary code; only load trusted caches.
    with snomed_embeddings.open("rb") as file:
        data_embedding = pickle.load(file)
else:
    # No cache available: embed the SNOMED graph from scratch.
    snomed = rdflib.Graph()
    # NOTE(review): this is an absolute path while every other data path in the
    # notebook is relative ("data/...") — confirm the leading slash is intended.
    snomed.parse("/data/snomed-ct-20221231-mini.ttl", format="ttl")
    embedder = SimpleDataEmbedder(embeddingModel=tokenizer)
    data_embedding = embedder.encode(data=snomed)

if clusters.exists():
    # SECURITY: same caveat as above — the cluster cache must come from a trusted source.
    with clusters.open("rb") as file1:
        segmentation = pickle.load(file1)
else:
    # No cached clustering: partition the embeddings into 50 clusters.
    cluster = ClusterGenerator(data_embedding, n_clusters=50)
    segmentation = cluster.generate_clusters()

In [None]:
# Input CSV tables that the schema-mapping cell loads with pandas.
PATIENTS_CSV = "data/mimic-iii/PATIENTS.csv"
DIAGNOSES_ICD_CSV = "data/mimic-iii/DIAGNOSES_ICD.csv"
LABEVENTS_CSV = "data/mimic-iii/LABEVENTS.csv"
DLABITEMS_CSV = "data/mimic-iii/D_LABITEMS.csv"

# Turtle file written by the schema-mapping cell and parsed again before alignment.
OUTPUT_TTL = "data/enhanced_sphn_triples_sample_FINAL.ttl"

# Schema Mapping

In [None]:
FRACTION = 0.05  # share of patients kept in the working sample (5%)
SAMPLE_SEED = 42  # fixed seed so the same patients are sampled on every run

patients_df = pd.read_csv(PATIENTS_CSV)
diagnoses_icd_df = pd.read_csv(DIAGNOSES_ICD_CSV)
lab_events_df = pd.read_csv(LABEVENTS_CSV)
lab_items_df = pd.read_csv(DLABITEMS_CSV)

all_subjects = patients_df['subject_id'].unique().tolist()
# Shuffle with a dedicated, seeded generator: the sample is reproducible after
# "Restart & Run All" and the global `random` state is left untouched.
random.Random(SAMPLE_SEED).shuffle(all_subjects)

num_5pct = int(len(all_subjects) * FRACTION)
keep_subjects = set(all_subjects[:num_5pct])

# Restrict every patient-keyed table to the sampled subjects.
patients_df_small = patients_df[patients_df['subject_id'].isin(keep_subjects)]
diagnoses_icd_df_small = diagnoses_icd_df[diagnoses_icd_df['subject_id'].isin(keep_subjects)]
lab_events_df_small = lab_events_df[lab_events_df['subject_id'].isin(keep_subjects)]

# Keep only the lab-item definitions that the sampled lab events refer to.
keep_itemids = set(lab_events_df_small['itemid'].dropna().unique())
lab_items_df_small = lab_items_df[lab_items_df['itemid'].isin(keep_itemids)]

# Map each lab item id to its LOINC code, or to None when the code is missing
# (a missing value stringifies to 'nan') or blank.
itemid_to_loinc = {}
for itemid, raw_loinc in zip(lab_items_df_small['itemid'], lab_items_df_small['loinc_code']):
    loinc = str(raw_loinc).strip()
    itemid_to_loinc[itemid] = None if loinc in ('nan', '') else loinc

def sanitize_value_for_iri(value):
    """Return `value` as a string that is safe to embed in an IRI.

    Missing values (as detected by `pd.isna`) become the literal "NA".
    Anything else is stringified and percent-encoded, leaving only letters,
    digits and the characters ``-._~`` unescaped.
    """
    return "NA" if pd.isna(value) else urllib.parse.quote(str(value), safe="-._~")

# Every statement below uses `sphn:` prefixed names, so the Turtle file must
# declare that prefix or it cannot be parsed back. The namespace matches the
# predicate IRI used in the alignment step (".../sphn#hasCode").
SPHN_PREFIX_DECL = "@prefix sphn: <https://biomedit.ch/rdf/sphn-schema/sphn#> ."

triples_ttl = []

# --- Diagnoses: one sphn:Diagnosis per row, linked to its patient and ICD-9 code.
for _, row in diagnoses_icd_df_small.iterrows():
    row_id = row['row_id']
    subj_id = row['subject_id']
    icd9_raw = row['icd9_code']

    # Skip incomplete rows, as the lab-event loop does; otherwise a missing
    # code would be emitted as the bogus IRI ".../icd9#nan".
    if pd.isna(row_id) or pd.isna(subj_id) or pd.isna(icd9_raw):
        continue

    # Percent-encode the code so unexpected characters cannot break the IRI.
    icd9_code = sanitize_value_for_iri(str(icd9_raw).strip())
    if not icd9_code:
        continue

    # int() keeps these IRIs identical in form to those built for lab events,
    # even if pandas loaded an id column as float.
    diagnosis_iri = f"<http://example.org/Diagnosis/{int(subj_id)}/PATIENTS/{int(row_id)}>"
    subject_pseudo_iri = f"<http://example.org/PATIENTS/{int(subj_id)}>"
    icd9_iri = f"<http://example.org/Code/icd9#{icd9_code}>"

    triples_ttl.append(f"{diagnosis_iri} a sphn:Diagnosis .")
    triples_ttl.append(f"{subject_pseudo_iri} a sphn:SubjectPseudoIdentifier .")
    triples_ttl.append(f"{diagnosis_iri} sphn:hasSubjectPseudoIdentifier {subject_pseudo_iri} .")
    triples_ttl.append(f"{icd9_iri} a sphn:Code .")
    triples_ttl.append(f"{diagnosis_iri} sphn:hasCode {icd9_iri} .")

# --- Lab events: event -> lab test -> lab result, plus the LOINC code when known.
for _, row in lab_events_df_small.iterrows():
    row_id = row['row_id']
    subj_id = row['subject_id']
    itemid = row['itemid']
    val = row['value']

    if pd.isna(row_id) or pd.isna(subj_id) or pd.isna(itemid):
        continue

    lab_event_iri = f"<http://example.org/LabTestEvent/{int(subj_id)}/PATIENTS/{int(row_id)}>"
    subject_pseudo_iri = f"<http://example.org/PATIENTS/{int(subj_id)}>"
    lab_test_iri = f"<http://example.org/LabTest/{int(subj_id)}/PATIENTS/{int(itemid)}>"

    # The measured value becomes part of the result IRI, so it must be escaped.
    value_part = sanitize_value_for_iri(val)
    lab_result_iri = f"<http://example.org/LabResult/{int(subj_id)}/PATIENTS/{int(itemid)}/{value_part}>"

    triples_ttl.append(f"{lab_event_iri} a sphn:LabTestEvent .")
    triples_ttl.append(f"{subject_pseudo_iri} a sphn:SubjectPseudoIdentifier .")
    triples_ttl.append(f"{lab_event_iri} sphn:hasSubjectPseudoIdentifier {subject_pseudo_iri} .")

    triples_ttl.append(f"{lab_test_iri} a sphn:LabTest .")
    triples_ttl.append(f"{lab_event_iri} sphn:hasLabTest {lab_test_iri} .")

    triples_ttl.append(f"{lab_result_iri} a sphn:LabResult .")
    triples_ttl.append(f"{lab_test_iri} sphn:hasResult {lab_result_iri} .")

    loinc_code = itemid_to_loinc.get(itemid, None)
    if loinc_code is not None:
        loinc_iri = f"<http://example.org/Code/loinc#{loinc_code}>"
        triples_ttl.append(f"{loinc_iri} a sphn:Code .")
        triples_ttl.append(f"{lab_test_iri} sphn:hasCode {loinc_iri} .")
        triples_ttl.append(f"{lab_result_iri} sphn:hasCode {loinc_iri} .")

with open(OUTPUT_TTL, "w", encoding="utf-8") as f:
    # The prefix declaration goes first; it is kept out of `triples_ttl` so the
    # statement count reported below is unchanged.
    f.write(SPHN_PREFIX_DECL + "\n\n")
    for line in triples_ttl:
        f.write(line)
        if not line.endswith("\n"):
            f.write("\n")

print(f"Sample of the data has been transformed into '{OUTPUT_TTL}' with {len(triples_ttl)} RDF statements.")

# Cross-Ontology Entity Alignment

In [None]:
# Load the Turtle file written by the schema-mapping step into a fresh graph;
# this graph is what the alignment step receives as its `query` argument.
query = rdflib.Graph()
query.parse(source=OUTPUT_TTL, format="ttl")

In [None]:
# Build the aligner from the SNOMED embeddings, their clustering and the
# sentence-embedding model prepared in the setup cell.
ontologyaligner = CrossOntologyAligner(
    dataGraph=data_embedding,
    clusters=segmentation,
    embeddingModel=tokenizer,
)

# Merge the mapped EHR graph using the "icd9tosnomed" invoker.
# NOTE(review): despite its name, `Namespace` receives the full sphn:hasCode
# predicate IRI — presumably the predicate whose objects get aligned; confirm
# against CrossOntologyAligner.merge.
CrossOntologyAlignedKG = ontologyaligner.merge(
    query=query,
    Invoker="icd9tosnomed",
    Namespace=rdflib.URIRef("https://biomedit.ch/rdf/sphn-schema/sphn#hasCode"),
)

# Entity Linking Validation Step

In [None]:
# Entity-linking validation on the aligned knowledge graph.
# NOTE(review): the instance is not bound to a name, so any useful work
# presumably happens inside the constructor (or in the displayed repr) —
# confirm against LinkingValidator.
LinkingValidator(CrossOntologyAlignedKG)

# TransE Embedding

In [None]:
# One seed for both the data split and the training run, so the embedding
# experiment gives the same result after "Restart & Run All".
TRANSE_SEED = 42

# NOTE(review): this triples file is not written by any cell visible in this
# notebook — confirm how it is derived from the aligned knowledge graph.
triples_factory = TriplesFactory.from_path('data/formatted_triples_FINAL_2.txt')

# 80/10/10 train/validation/test split, made deterministic through `random_state`.
training, validation, testing = triples_factory.split(
    [0.8, 0.1, 0.1],
    random_state=TRANSE_SEED,
)

# Train and evaluate a 20-dimensional TransE model.
result = pipeline(
    training=training,
    validation=validation,
    testing=testing,
    model='TransE',
    model_kwargs={
        'embedding_dim': 20,
    },
    optimizer='Adam',
    optimizer_kwargs={
        'lr': 1e-3,
        'weight_decay': 1e-5
    },
    negative_sampler='basic',
    loss='SoftplusLoss',
    training_loop='sLCWA',
    training_kwargs={
        'num_epochs': 100,
        'batch_size': 32,
        'label_smoothing': 0.0
    },
    evaluator_kwargs=  {
        "filtered": True
    },
    filter_validation_when_testing = True,
    # Without an explicit seed PyKEEN picks a random one on every run.
    random_seed=TRANSE_SEED,
)

In [None]:
from pykeen import predict  # or pykeen.models.predict, depending on version

# Score every candidate tail for (head, hasCode, ?) with the trained model,
# using the training triples factory to resolve the labels.
target_predictions = predict.predict_target(
    model=result.model,
    head="Diagnosis/10033/PATIENTS/112578",
    relation="hasCode",
    triples_factory=result.training,
)
df_predictions = target_predictions.df

# Show the first ten rows of the prediction table.
df_predictions.head(10)

### For the Fact Validation part: generate a text file with the predictions from the cell above

In [None]:
output_file = "predictions.txt"

# Subject and predicate are the same for every predicted triple, so they are
# built once, outside the loop.
subject_uri = "<http://example.org/Diagnosis/10033/PATIENTS/112578>"
predicate_uri = "<https://biomedit.ch/rdf/sphn-schema/sphn#hasCode>"

# One line per prediction: three whitespace-separated, angle-bracketed terms.
# NOTE(review): if `tail_label` already starts with "Code/" (the head label
# used for prediction is relative to http://example.org/), the object IRI below
# would contain "Code/Code/" — confirm the label format.
with open(output_file, "w", encoding="utf-8") as f:
    for predicted_code in df_predictions.head(10)["tail_label"]:
        object_uri = f"<http://example.org/Code/{predicted_code}>"
        f.write(f"{subject_uri}  {predicate_uri}  {object_uri}" + "\n")

print(f"Wrote top-10 predictions to {output_file}")


# Fact Validation

In [None]:
def main(
    sparql_endpoint="http://localhost:7200/repositories/integrationhealthcare",
    predictions_file="predictions.txt",
    output_file="validated_facts.txt",
    threshold=0.5,
    max_length=3,
):
    """Validate predicted triples and keep those scoring at least `threshold`.

    Each non-empty line of `predictions_file` must hold exactly three
    whitespace-separated terms (subject, predicate, object), optionally wrapped
    in angle brackets. Lines with any other number of terms are reported and
    skipped. Every accepted line is scored with `Validator.validate_fact` and
    the result is printed.

    Args:
        sparql_endpoint: URL passed to `Validator`. The default points at a
            localhost repository, so it has to be configured per machine.
        predictions_file: Path of the text file with the candidate triples.
        output_file: Path of the file that receives the retained triples
            (overwritten on every call).
        threshold: Minimum score a fact needs to be written to `output_file`.
        max_length: Forwarded unchanged as `max_length` to `validate_fact`.

    Note:
        The score is assumed to be a number comparable with `threshold` —
        confirm against `Validator.validate_fact`.
    """
    validator = Validator(sparql_endpoint)

    with open(predictions_file, "r", encoding="utf-8") as f_in, \
            open(output_file, "w", encoding="utf-8") as f_out:
        for line in f_in:
            line = line.strip()
            if not line:
                continue

            # Expect exactly 3 parts: subject, predicate, object
            parts = line.split()
            if len(parts) != 3:
                print(f"Skipping malformed line: {line}")
                continue

            subj, pred, obj = parts

            # The validator receives bare IRIs, without the enclosing <...>.
            subj_uri, pred_uri, obj_uri = (term.strip("<>") for term in parts)

            score = validator.validate_fact(subj_uri, pred_uri, obj_uri, max_length=max_length)

            print(f"Fact: {subj} {pred} {obj} => Score: {score}")

            # Only facts that reach the threshold are written out.
            if score >= threshold:
                f_out.write(f"{subj} {pred} {obj}\n")

    print(f"Validation complete. Facts with score >= {threshold} are in '{output_file}'.")

# Run the fact-validation step when this code is executed as the main module.
if __name__ == "__main__":
    main()
