# KG Classes Extractor for RAG context creation

This notebook guides you through extracting the classes of the [IDSM KG](https://idsm.elixir-czech.cz/chemweb) and their metadata, which are then used in the notebook `Similar_Query_IDSM_v1.ipynb`.

IDSM is too large to query over all of its triples to obtain all unique predicates. Instead, we use the ontologies listed in the [void.ttl](https://ftp.ncbi.nlm.nih.gov/pubchem/RDF/void.ttl) file of PubChem RDF.

## Prerequisites
- An endpoint serving all the ontologies used (all of them are available at [this MyBox URL](https://mybox.inria.fr/d/24d9423c67d64f8284fa/); you can download them for searching / preprocessing. The password is: `Kc8(-8aE`)

If you use a Corese server, an instance can be launched with a command similar to the following:

`docker run --name my-corese -p 8080:8080 -v /path/to/IDSM/ontologies:/usr/local/corese/data -d wimmics/corese`

## We import the required modules

In [2]:
# general python libs
import os
from SPARQLWrapper import SPARQLWrapper, JSON, TURTLE
import glob
import rdflib
from rdflib import Graph
import tiktoken
from tqdm import tqdm
from typing import Tuple, List
from pathlib import Path
import pickle
from rdflib import RDFS, BNode, Namespace, URIRef

## We prepare the variables and helper functions

In [None]:
# SPARQL endpoints: a local Corese server hosting the ontologies (see the
# prerequisites above) and the remote public IDSM endpoint.
endpoint_url_corese = 'http://localhost:8080/sparql'
endpoint_url_idsm = 'https://idsm.elixir-czech.cz/sparql/endpoint/idsm'

# Output folder for the pickles and the per-class Turtle files, located two
# levels above the current working directory.
directory = Path(os.getcwd()).parent.parent / 'data' / 'saved_pkls' / 'idsm'

# Create `directory` and its `schema_ttl` sub-folder if missing. `parents=True`
# also creates `directory`; `exist_ok=True` makes the cell safe to re-run and
# avoids the check-then-create race.
(directory / 'schema_ttl').mkdir(parents=True, exist_ok=True)

In [None]:
# Template query listing the properties used by (a sample of) the instances of
# one class, together with one sample type for the values of each property.
# `{class_uri}` is substituted with `str.replace` (not `str.format`), so the
# SPARQL braces are written as plain single braces.
#  - the sub-query samples at most 100 instances of the class;
#  - the value type is the rdf:type of the value if it has one, else the
#    datatype of the literal, else the string "Untyped";
#  - at most 300 (property, type) rows are returned.
query_cls_rel = """
SELECT ?property (SAMPLE(COALESCE(?type, STR(DATATYPE(?value)), "Untyped")) AS ?valueType) WHERE {
        {
        SELECT ?instance WHERE {
            ?instance a <{class_uri}> .
        } LIMIT 100
        }
        ?instance ?property ?value .
        OPTIONAL {
        ?value a ?type .
        }
    }
    GROUP BY ?property ?type
    LIMIT 300
"""

In [None]:
# HELPER FUNCTIONS

def run_sparql(query,
               url=endpoint_url_corese,
               timeout=600):
    """Run a SPARQL SELECT query and return its result bindings.

    Args:
        query: SPARQL SELECT query text.
        url: endpoint URL. The default is the local Corese endpoint, captured
            when this function is defined.
        timeout: request timeout in seconds, passed to SPARQLWrapper.

    Returns:
        The list found under ``results.bindings`` of the JSON response (one
        dict per row, keyed by variable name). If the response has no such
        entry, an empty list is returned so callers can always iterate.
    """
    sparql = SPARQLWrapper(url)
    sparql.setQuery(query)
    sparql.setReturnFormat(JSON)
    sparql.setTimeout(timeout)
    results = sparql.query().convert()
    bindings = nested_value(results, ['results', 'bindings'])
    # Callers iterate the result directly; never hand them None.
    return bindings if bindings is not None else []

def run_sparql_construct(query, filename, url=endpoint_url_corese):
    """Run a graph-returning query, save the result as Turtle and return the raw response.

    The Turtle response is parsed with rdflib and re-serialized to
    ``filename``, so the saved file uses rdflib's formatting; the value
    returned is the unmodified response from SPARQLWrapper.
    """
    client = SPARQLWrapper(url)
    client.setQuery(query)
    client.setReturnFormat(TURTLE)
    client.setTimeout(600)
    turtle_payload = client.queryAndConvert()

    parsed_graph = rdflib.Graph()
    parsed_graph.parse(data=turtle_payload, format='turtle')
    parsed_graph.serialize(destination=filename, format='turtle')

    return turtle_payload
    
def nested_value(data: dict, path: list):
    """Follow ``path`` into nested containers and return the value found, or None.

    Each element of ``path`` is applied in turn with ``[]`` (dict key or list
    index). If a step is missing (KeyError / IndexError) or the current value
    cannot be subscripted that way (TypeError, e.g. it is None), None is
    returned instead of raising.

    Only lookup errors are swallowed: a bare ``except`` would also trap
    KeyboardInterrupt / SystemExit and make long loops hard to interrupt.
    """
    current = data
    for key in path:
        try:
            current = current[key]
        except (KeyError, IndexError, TypeError):
            return None
    return current

def get_prop_and_val_types(cls: str) -> List[Tuple[str, str]]:
    """List (property IRI, value type) pairs observed on instances of a class in IDSM.

    Fills ``query_cls_rel`` with the class IRI and runs it against the remote
    IDSM endpoint. Rows without a bound ``?property`` are dropped (presumably
    only the all-unbound row an endpoint may return when nothing matches), so
    an empty list means no property was found. The value type of a pair may be
    None when ``?valueType`` is unbound.
    """
    query = query_cls_rel.replace("{class_uri}", cls)
    rows = run_sparql(query, endpoint_url_idsm) or []

    pairs = []
    for row in rows:
        prop = nested_value(row, ['property', 'value'])
        if prop is None:
            # Without a predicate there is nothing to add to the class context.
            continue
        pairs.append((prop, nested_value(row, ['valueType', 'value'])))
    return pairs

def format_class_graph_file(class_uri: str) -> str:
    """Return the path of the Turtle file holding the context of one class.

    The file name is the last '/'-separated segment of the IRI, so a '#'
    fragment such as ``vocabulary#Compound`` is kept as-is. A trailing '/'
    is ignored so that such IRIs do not map to an empty ``.ttl`` name, and
    characters that are invalid in Windows file names are replaced by '_'.

    NOTE(review): classes from different namespaces that share the same last
    segment map to the same file and overwrite each other.
    """
    class_name = class_uri.rstrip('/').split('/')[-1]
    class_name = ''.join('_' if ch in '<>:"\\|?*' else ch for ch in class_name)
    if not class_name:
        class_name = '_'
    return f"{directory}/schema_ttl/{class_name}.ttl"


## Get all classes

In [None]:
# All named OWL classes known to the ontology endpoint, with their optional
# rdfs:comment and rdfs:label. A class with several labels/comments yields one
# row per combination. PREFIX declarations are spelled out so the query does not
# depend on prefixes predeclared by a particular endpoint.
query_class = """
PREFIX owl: <http://www.w3.org/2002/07/owl#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT DISTINCT ?cls ?comment ?label
WHERE {
?cls a owl:Class .
OPTIONAL { ?cls rdfs:comment ?comment }
OPTIONAL { ?cls rdfs:label ?label }
FILTER (isIRI(?cls))
}
"""

In [None]:
# Query the ontology endpoint for every class as (IRI, label, comment) tuples.
class_rows = [
    (nested_value(x, ['cls', 'value']),
     nested_value(x, ['label', 'value']),
     nested_value(x, ['comment', 'value']))
    for x in (run_sparql(query_class, url=endpoint_url_corese) or [])
]

# The query returns one row per (label, comment) combination, but each class's
# context file is named from its IRI only, so several tuples for the same IRI
# would overwrite each other's file (possibly concurrently in the threaded run).
# Keep one tuple per IRI: the first non-empty label and comment in sorted row
# order, which is deterministic, unlike iterating over a set.
rows_by_iri = {}
for iri, label, comment in sorted(
        class_rows, key=lambda row: tuple('' if v is None else v for v in row)):
    prev_label, prev_comment = rows_by_iri.get(iri, (None, None))
    rows_by_iri[iri] = (prev_label or label, prev_comment or comment)

classes = [(iri, label, comment) for iri, (label, comment) in rows_by_iri.items()]
len(classes)

In [None]:

def ilc_tuple2str(res: Tuple[str, str, str]) -> str:
    """Render an (IRI, label, comment) tuple as ``(<IRI> , label, comment)``.

    Missing values are rendered through normal string formatting, i.e. as ``None``.
    """
    iri, label, comment = res[0], res[1], res[2]
    return "(<{}> , {}, {})".format(iri, label, comment)

In [None]:
# Text listing of the classes, one "(<IRI> , label, comment)" line per class.
# The join is computed outside the f-string: a backslash such as "\n" inside an
# f-string expression is a SyntaxError before Python 3.12.
class_lines = "\n".join(ilc_tuple2str(c) for c in classes)
class_str: str = ("In the following, each IRI is followed by the local name and optionally its description in parentheses.\n"
                  + "The RDF graph supports the following node types:\n"
                  + class_lines)

In [None]:
# Show the full class listing built above.
print(class_str, end="\n")

## Save the classes in a pickle for later use

In [None]:
# Cache the class tuples so the endpoint query does not have to be repeated.
with open(directory / "classes.pkl", "wb") as pkl_out:
    pickle.dump(classes, pkl_out, protocol=pickle.HIGHEST_PROTOCOL)

## Load the classes from the pickle

In [None]:
# Reload the cached class tuples. Only unpickle files you produced yourself:
# unpickling can execute arbitrary code.
with open(directory / "classes.pkl", "rb") as pkl_in:
    classes = pickle.load(pkl_in)
print(len(classes))

## Get the context of the classes

### Single thread call

This would take almost 1400 hours.

In [None]:

def get_context_class(cl: tuple):
    """Build and save the schema context (a Turtle file) of one class.

    ``cl`` is an ``(IRI, label, comment)`` tuple as stored in ``classes``.
    The graph contains the class's rdfs:label and rdfs:comment (when present)
    plus one triple ``<class> <property> <value type>`` per pair returned by
    ``get_prop_and_val_types``. Values whose type is missing or "Untyped"
    become blank nodes. The graph is written as Turtle to
    ``format_class_graph_file(IRI)``.
    """
    class_uri, label, comment = cl[0], cl[1], cl[2]

    graph = rdflib.Graph()
    graph.bind('obo', Namespace('http://purl.obolibrary.org/obo/'), override=True)
    graph.bind('cito', Namespace('http://purl.org/spar/cito/'), override=True)
    graph.bind('pubchem', Namespace('http://rdf.ncbi.nlm.nih.gov/pubchem/vocabulary#'), override=True)
    graph.bind('sio', Namespace('http://semanticscience.org/resource/'), override=True)
    graph.bind('bao', Namespace('http://www.bioassayontology.org/bao#'), override=True)

    class_ref = URIRef(class_uri)

    # Label and comment describe the class itself: add them once, outside the
    # property loop, so a class with no observed properties still keeps them.
    if label:
        graph.add((class_ref, RDFS.label, rdflib.term.Literal(label)))
    if comment:
        graph.add((class_ref, RDFS.comment, rdflib.term.Literal(comment)))

    for property_uri, prop_type in get_prop_and_val_types(class_uri):
        if property_uri is None:
            # No predicate, nothing meaningful to assert.
            continue
        value_ref = BNode() if prop_type in (None, "Untyped") else URIRef(prop_type)
        graph.add((class_ref, URIRef(property_uri), value_ref))

    # Explicit format: older rdflib versions default to RDF/XML, which the
    # later cells would then fail to parse as Turtle.
    graph.serialize(destination=format_class_graph_file(class_uri), format='turtle')

In [None]:

# Sequential run: one IDSM query per class (very slow, see the estimate above).
for class_tuple in tqdm(classes, desc="Adding classes to graph"):
    get_context_class(class_tuple)

### Multi-threading solution

If we use more than 4 threads, the IDSM server could crash.

In [None]:
from concurrent.futures import ThreadPoolExecutor, as_completed

# More than 4 concurrent requests risk crashing the IDSM server (see above).
MAX_WORKERS = 4

# (class tuple, exception) for every class whose context could not be built.
failed_classes = []

with tqdm(total=len(classes)) as pbar:
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
        future_to_class = {ex.submit(get_context_class, cl): cl for cl in classes}
        for future in as_completed(future_to_class):
            try:
                future.result()
            except Exception as exc:
                # Record and continue. Letting the error escape would make the
                # executor's shutdown wait for every remaining class with no
                # progress shown, and only the first failure would be reported.
                failed_classes.append((future_to_class[future], exc))
            pbar.update(1)

if failed_classes:
    print(f"{len(failed_classes)} class(es) failed:")
    for failed_cl, failed_exc in failed_classes:
        print(f"  {failed_cl[0]}: {failed_exc!r}")

### Merge the different contexts into a single file

Enforce the use of some prefixes

In [None]:
# Preferred prefix for each namespace IRI; bound on every per-class file so all
# files use the same short names.
prefix_map = {'http://schema.org/':'schema',
              'https://enpkg.commons-lab.org/module/':'enpkg_module',
              'http://purl.org/pav/':'pav',
              'http://example.org/':'example',
              'https://enpkg.commons-lab.org/kg/':'enpkg',
              'http://purl.obolibrary.org/obo/':'obo',
              'http://purl.org/spar/cito/':'cito',
              'http://rdf.ncbi.nlm.nih.gov/pubchem/vocabulary#':'pubchem',
              'http://semanticscience.org/resource/':'sio',
              'http://www.bioassayontology.org/bao#':'bao',
              'http://purl.obolibrary.org/obo/CHEBI_':'chebi',
              'http://semanticscience.org/resource/CHEMINF_':'cheminf',
              'http://rdf.ebi.ac.uk/terms/chembl#':'chembl'
              }

# Enforce the prefixes: parse each per-class file, rebind, and rewrite it in place.
for ttl_path in glob.glob(str(directory) + '/schema_ttl/*.ttl'):
    class_graph = Graph()
    class_graph.parse(ttl_path, format='turtle')

    for ns_iri, short_prefix in prefix_map.items():
        class_graph.bind(short_prefix, ns_iri, override=True)

    class_graph.serialize(destination=ttl_path, format='turtle')

Merge the ttl files

In [None]:
# Merge every per-class file into a single schema graph.
g = Graph()

# Sorted so the parse order is reproducible across runs (glob order is not
# guaranteed), which also makes any prefix picked up from the files stable.
for filename in sorted(glob.glob(str(directory) + '/schema_ttl/*.ttl')):
    g.parse(filename, format='turtle')

# Re-apply the preferred prefixes on the merged graph itself, so merged.ttl
# uses prefix_map even if the individual files disagree.
for namespace, prefix in prefix_map.items():
    g.bind(prefix, namespace, override=True)

# Save the merged graph (outside schema_ttl, so it is not re-globbed above).
g.serialize(destination=str(directory) + '/merged.ttl', format='turtle')

### Counting the token size of the context

In [None]:
# Load the merged Turtle schema. The encoding is explicit so decoding does not
# depend on the platform default (rdflib writes UTF-8 by default).
schema_file = str(directory) + '/merged.ttl'
with open(schema_file, 'r', encoding='utf-8') as schema_fh:
    content = schema_fh.read()

# Tokenizer matching the target LLM.
encoding = tiktoken.encoding_for_model("gpt-4o")

# Number of tokens the schema would occupy in the prompt.
token_count = len(encoding.encode(content))

print(f"The Schema file '{schema_file}' contains {token_count} tokens.")
