In [1]:
import json
import pandas as pd
import re

from collections import defaultdict
from nltk.corpus import stopwords
from rdflib import Graph, URIRef, Literal, RDFS
from tqdm import tqdm

import logging
logging.getLogger("rdflib").setLevel(logging.ERROR)

from hkg.graph_utils import n3_to_RDFLib, save_graph, load_named_graph
from hkg.labels import HKG_PREDICATES, URI2labels, Labels
from hkg.rdf_regex import parse_triples_iter, string_is_URI, parse_triple, split_uri

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# English stopwords ignored when matching terms against abstracts.
# Kept as a frozenset: it is only used for membership tests (see check_term_in_text),
# which are O(1) on a set instead of a linear scan over the NLTK list.
STOPWORDS = frozenset(stopwords.words('english'))

In [3]:
# Output formats written by serialize_graph. They are read at call time, so
# changing them between calls changes what the next call writes.
save_pickle = True
save_turtle = True
save_csv = True


def serialize_graph(graph, graph_name, folder):
    """Save ``graph`` under ``graph_name`` inside ``folder``.

    Thin wrapper around ``save_graph`` that forwards the notebook-level
    ``save_pickle`` / ``save_turtle`` / ``save_csv`` switches.

    Args:
        graph: the graph object to persist.
        graph_name: base name passed to ``save_graph``.
        folder: destination folder passed to ``save_graph``.

    Returns:
        Whatever ``save_graph`` returns.
    """
    return save_graph(
        graph,
        graph_name,
        folder,
        save_pickle=save_pickle,
        save_turtle=save_turtle,
        save_csv=save_csv,
    )

In [4]:
# Switches forwarded to main_labels.verify_triple_two_steps in the normalization cell
# as its allow_inverse / allow_similar_predicates keyword arguments.
ALLOW_INVERSE = True
ALLOW_SIMILAR_PREDICATES = True

# Load graph and dense index

In [5]:
%%time
# Name of the graph to load; reused at the end of the notebook to build the output filename.
graph_name = "full_graph"
# Load the named graph through the project helper; it is handed to Labels in the next cell.
main_wl = load_named_graph(graph_name)

CPU times: user 31.3 s, sys: 593 ms, total: 31.9 s
Wall time: 31.9 s


In [6]:
%%time
# Build the project's Labels helper over the loaded graph, using the dense index
# named "full_graph_labels". It is used below to validate predicates and match terms.
main_labels = Labels(main_wl, dense_index="full_graph_labels")

Processing graph
Loading dense index full_graph_labels
Computing equivalence classes
Collecting entities for which to store labels
Computing labels


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 182452/182452 [00:12<00:00, 14547.22it/s]


Done.
CPU times: user 47.3 s, sys: 525 ms, total: 47.8 s
Wall time: 48.6 s


# Functions

In [7]:
def check_term_in_text(term, text):
    """Tell whether any significant word of ``term`` occurs in ``text``.

    The term is expanded to a set of labels: via ``URI2labels`` when the term
    (or the term with whitespace runs replaced by ``_``) is a URI, otherwise
    the term itself is the only label. Labels are split on whitespace, and
    words of length <= 2 or listed in ``STOPWORDS`` are dropped.

    Matching is a plain substring test of the lower-cased words against
    ``text``; ``text`` is not lower-cased here, so callers should pass
    lower-cased text.

    Returns:
        True if at least one remaining word is a substring of ``text``,
        False otherwise (including when no word remains).
    """
    if string_is_URI(term):
        labels = URI2labels(term)
    else:
        underscored = re.sub(r"\s+", "_", term)
        labels = URI2labels(underscored) if string_is_URI(underscored) else {term}

    significant_words = {
        token.lower()
        for label in labels
        for token in label.split()
        if len(token) > 2 and token.lower() not in STOPWORDS
    }
    return any(word in text for word in significant_words)

In [8]:
def validate_abstract(uri, term, abstract):
    """Decide whether ``term`` can be traced back to the abstract of ``uri``.

    A URI term is accepted immediately when it is ``uri`` itself or shares at
    least one label (per ``URI2labels``) with ``uri``. Every other case falls
    back to ``check_term_in_text(term, abstract)``.
    """
    if string_is_URI(term):
        # The term denotes the article's own entity (same URI or a shared label).
        if term == uri or URI2labels(term).intersection(URI2labels(uri)):
            return True
    return check_term_in_text(term, abstract)

# Load abstracts

In [9]:
# Map each article URI to a record holding its abstract text:
# abstracts[uri] == {"abstract": <text>}.
abstracts = {}
# JSON text is UTF-8; set the encoding explicitly instead of relying on the
# platform's locale default.
with open('data/main_graph_abstracts.jsonl', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing empty line), which would
            # otherwise make json.loads raise.
            continue
        d = json.loads(line)
        abstracts[d['uri']] = {"abstract": d['abstract']}

In [10]:
# Attach the LLM output to each abstract record under the "triples" key.
# A URI that has no abstract record still raises KeyError, so mismatched
# input files are noticed rather than silently ignored.
# JSON text is UTF-8; set the encoding explicitly instead of relying on the
# platform's locale default.
with open('data/main_graph_abstracts_triples.jsonl', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing empty line), which would
            # otherwise make json.loads raise.
            continue
        d = json.loads(line)
        abstracts[d["uri"]]["triples"] = d["llm_output"]

# Filter out triples with terms that do not originate from the abstract

In [11]:
%%time
# Split the generated triples into those whose subject AND object can both be
# traced back to the abstract (`good`) and the rest (`bad`). Triples whose
# predicate does not validate are dropped from both sets.
good, bad = set(), set()
for uri in tqdm(abstracts.keys()):
    entry = abstracts[uri]
    triples = entry['triples'].strip()
    # Lower-cased once here: check_term_in_text compares lower-cased words against it.
    abstract = entry["abstract"].lower()
    for match in parse_triples_iter(triples):
        predicate = match.group("predicate")
        valid_predicate, rdf_predicate = main_labels._validate_predicate(predicate)
        if valid_predicate:
            subject = match.group("subject")
            object_ = match.group("object")
            # Both ends are always checked (no short-circuit), as in a plain two-step validation.
            valid_sub = validate_abstract(uri, subject, abstract)
            valid_obj = validate_abstract(uri, object_, abstract)
            bucket = good if (valid_sub and valid_obj) else bad
            bucket.add((uri, match.group()))
print(len(good))
print(len(bad))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3430/3430 [00:02<00:00, 1630.10it/s]

26936
7843
CPU times: user 2.1 s, sys: 7.79 ms, total: 2.11 s
Wall time: 2.11 s





# Normalize terms in generated triples (match terms to ones in DBpedia if possible)

In [12]:
def uris(terms):
    """Lazily yield only the items of ``terms`` that are ``URIRef`` instances.

    This is a generator: nothing is read from ``terms`` until the result is
    iterated, and the result can be consumed only once.
    """
    yield from filter(lambda candidate: isinstance(candidate, URIRef), terms)

In [13]:
%%time
# Accumulators filled by the normalization loop below.
# good_dbp: flat string records (article uri, truth, subject, predicate, object, "full" | "partial").
good_dbp = set()
# llm_graph: only receives triples the LLM marked TRUE (plus label triples for their unmatched URI objects).
llm_graph = Graph()
# Normalized (s, p, o) -> set of raw LLM triple strings that produced it, split by truth value.
true_triples_dict = defaultdict(set)
false_triples_dict = defaultdict(set)
# Raw strings of the TRUE triples that yielded at least one normalized triple.
true_triples = set()
for uri, trip in tqdm(good):
    match = parse_triple(trip)
    truth = match.group("truth")
    sub = match.group("subject")
    obj = match.group("object")
    pred = match.group("predicate")
    # `pred` is rebound to the second value returned by _validate_predicate
    # (presumably the RDF form of the predicate -- confirm against hkg.labels).
    valid_predicate, pred = main_labels._validate_predicate(pred)
    # Every triple in `good` already passed this same check in the filtering cell.
    assert valid_predicate
    # Subject candidates are restricted to URIRef matches; `uris` returns a
    # generator, which is iterated exactly once by the `for s` loop below.
    sub_normalized = uris(main_labels.match_term(sub))
    # Object candidates are kept as returned by match_term (no URIRef filtering).
    obj_normalized = main_labels.match_term(obj)
    if truth == "FALSE" and not main_labels.verify_triple_two_steps(trip,
                                                        allow_inverse=ALLOW_INVERSE,
                                                        allow_similar_predicates=ALLOW_SIMILAR_PREDICATES):
        # Avoid processing FALSE llm triples that are inside our main graph
        continue
    # If the subject has no URIRef match, this loop body never runs and the triple is dropped.
    for s in sub_normalized:
        if obj_normalized:
            # Both subject and object can be mapped to terms in our main graph
            for o in obj_normalized:
                good_dbp.add((uri, str(truth), str(s), str(pred), str(o), "full"))
                if truth == "TRUE":
                    true_triples_dict[(s, pred, o)].add(trip)
                    true_triples.add(trip)
                    llm_graph.add((s, pred, o))
                else:
                    false_triples_dict[(s, pred, o)].add(trip)
        else:
            # Partial match: only the subject can be mapped to a term in our main graph
            good_dbp.add((uri, str(truth), str(s), str(pred), str(obj), "partial"))
            # The unmatched object is converted with n3_to_RDFLib before being used as an RDF term.
            if truth == "TRUE":
                true_triples_dict[(s, pred, n3_to_RDFLib(obj))].add(trip)
                true_triples.add(trip)
                llm_graph.add((s, pred, n3_to_RDFLib(obj)))
            else:
                false_triples_dict[(s, pred, n3_to_RDFLib(obj))].add(trip)
            if string_is_URI(obj):
                # Also record the labels of the unmatched URI object. The good_dbp
                # record is always tagged "TRUE", whatever the triple's own truth
                # value; the graph only gets the label when the triple is TRUE.
                labels = URI2labels(obj)
                for label in labels:
                    good_dbp.add((uri, "TRUE", str(obj), str(RDFS.label), str(label), "partial"))
                    if truth == "TRUE":
                        llm_graph.add((n3_to_RDFLib(obj), RDFS.label, Literal(label)))

# Cell output: number of flat records collected.
len(good_dbp)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26936/26936 [02:50<00:00, 157.79it/s]

CPU times: user 19min 59s, sys: 1.38 s, total: 20min
Wall time: 2min 50s





165443

In [14]:
len(llm_graph)

150103

- main: 149253,
- full: 150085

# Consistency Check

In [15]:
# Flatten the per-RDF-triple sets into one set of raw TRUE triple strings.
true_triples = {
    str_triple
    for str_triples in true_triples_dict.values()
    for str_triple in str_triples
}
len(true_triples)

24586

In [16]:
# Flatten the per-RDF-triple sets into one set of raw FALSE triple strings.
false_triples = {
    str_triple
    for str_triples in false_triples_dict.values()
    for str_triple in str_triples
}
len(false_triples)

1038

In [17]:
len(true_triples_dict), len(false_triples_dict)

(144689, 4528)

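After normalization, some LLM-generated triples have both a "TRUE" version and a "FALSE" version (different surface terms can map to the same normalized triple, as the example below shows):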

In [18]:
# Normalized triples that were produced by both a TRUE and a FALSE LLM triple.
intersection = set(false_triples_dict) & set(true_triples_dict)
len(intersection)

49

Percentage of contradictory generated triples:

In [19]:
# Share of contradictory triples among all distinct normalized triples (TRUE or FALSE).
all_normalized = set(false_triples_dict) | set(true_triples_dict)
f"{len(intersection) / len(all_normalized) * 100:.4f}%"

'0.0328%'

In [20]:
# Pick the first contradictory triple in sorted order as a deterministic example.
example = sorted(intersection)[0]

In [21]:
example

(rdflib.term.URIRef('http://dbpedia.org/resource/African_trypanosomiasis'),
 rdflib.term.URIRef('https://example.org/health_kg/diagnosis'),
 rdflib.term.URIRef('http://dbpedia.org/resource/Suramin'))

In [22]:
true_triples_dict[example]

{'TRUE ( dbr:African_trypanosomiasis hkg:diagnosis dbr:Suramin )'}

In [23]:
false_triples_dict[example]

{'FALSE ( dbr:African_sleeping_sickness hkg:diagnosis dbr:Suramin )'}

Remove these contradictory triples from the graph and from the set of true triples:

In [24]:
# Drop every contradictory normalized triple from the LLM graph and report the size change.
print(f"Before: {len(llm_graph)}")
for contradictory_triple in intersection:
    llm_graph.remove(contradictory_triple)
print(f"After: {len(llm_graph)}")

Before: 150103
After: 150054


In [25]:
# Drop the raw TRUE strings behind every contradictory normalized triple.
print(f"Before: {len(true_triples)}")
for contradictory_triple in intersection:
    true_triples.difference_update(true_triples_dict[contradictory_triple])
print(f"After: {len(true_triples)}")

Before: 24586
After: 24577


# Count LLM triples

In [26]:
# Group the remaining TRUE triples by predicate name: name -> set of (subject, object) pairs.
predicate_to_trip = defaultdict(set)
for trip in true_triples:
    parsed = parse_triple(trip)
    # Only the second component returned by split_uri (the name) is used as the key.
    _prefix, pred_name = split_uri(parsed.group("predicate"))
    pair = (parsed.group("subject"), parsed.group("object"))
    predicate_to_trip[pred_name].add(pair)

In [27]:
# Report the number of distinct (subject, object) pairs per predicate, in alphabetical order.
for pred_name in sorted(predicate_to_trip):
    pair_count = len(predicate_to_trip[pred_name])
    print(f"# generated triples for hkg:{pred_name}: {pair_count}")

# generated triples for hkg:cause: 5587
# generated triples for hkg:complication: 3179
# generated triples for hkg:diagnosis: 2032
# generated triples for hkg:medication: 323
# generated triples for hkg:prevention: 695
# generated triples for hkg:risk: 1722
# generated triples for hkg:symptom: 7406
# generated triples for hkg:treatment: 3632


# Save filtered triples as a graph

In [28]:
%%time
# Derive the output name from the loaded graph's name, then write the filtered LLM
# graph through serialize_graph, which applies the save_pickle / save_turtle /
# save_csv switches set near the top of the notebook.
out_filename=f"{graph_name}_abstract_triples_filtered"
serialize_graph(llm_graph, out_filename, "./data/graphs/")

Saving ./data/graphs/pickle/full_graph_abstract_triples_filtered.pickle
Saving ./data/graphs/csv/full_graph_abstract_triples_filtered.csv
Saving ./data/graphs/ttl/full_graph_abstract_triples_filtered.ttl
CPU times: user 31.6 s, sys: 132 ms, total: 31.7 s
Wall time: 31.6 s
