# Query Rewriting

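This notebook demonstrates how to use the `ner_functions` module to extract node labels and instances from a natural-language (NL) query, and to rewrite the query so that instance names follow the knowledge graph (KG) schema.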

## Overview

Given the NL input `Identify the gene names where the gene is either bound by Niclosamide (which resembles "tolcapone") or upregulated by Mycophenolic acid.`, the method can:
1. Extract the node labels, `['Gene']`, and the instances, `['Compound(name:"Mycophenolic acid")', 'Compound(name:"Niclosamide")', 'Compound(name:"Tolcapone")']`.
2. Rewrite the instances into the format defined by the KG schema, producing the rewritten NL query `Identify the gene names where the gene is either bound by 'Niclosamide' (which resembles 'Tolcapone') or upregulated by 'Mycophenolic acid'.`
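The rewriting step can be illustrated with a minimal, self-contained sketch. It is not the implementation in `ner_functions`; the helper name `rewrite_instances` and the regex-based matching are assumptions made purely for illustration. The sketch finds case-insensitive mentions of known instance names, strips any surrounding quotes, and substitutes the canonical spelling in single quotes.

```python
import re


def rewrite_instances(nl, canonical_names):
    """Hypothetical helper: quote and normalise known instance names in a query.

    Returns (matched_names, rewritten_query). This is an illustrative stand-in,
    not the behaviour of instance_value_rewrite in ner_functions.
    """
    by_lower = {name.lower(): name for name in canonical_names}
    if not by_lower:
        return [], nl

    # Longest names first so the alternation prefers the longest possible match.
    alternation = "|".join(
        re.escape(name) for name in sorted(by_lower, key=len, reverse=True)
    )
    # Optional quotes around the mention are consumed and replaced.
    pattern = re.compile(
        r"""["']?\b(""" + alternation + r""")\b["']?""", re.IGNORECASE
    )

    matched = []

    def replace(match):
        canonical = by_lower[match.group(1).lower()]
        if canonical not in matched:
            matched.append(canonical)
        return f"'{canonical}'"

    # A single pass avoids re-matching text that was already rewritten.
    return matched, pattern.sub(replace, nl)


query = (
    "Identify the gene names where the gene is either bound by Niclosamide "
    '(which resembles "tolcapone") or upregulated by Mycophenolic acid.'
)
matched, rewritten = rewrite_instances(
    query, ["Niclosamide", "Tolcapone", "Mycophenolic acid"]
)
print(matched)
print(rewritten)
```

On the example query, this sketch yields the same rewritten sentence as shown in the overview. Exact string matching is a deliberate simplification: it will not catch misspellings or synonyms.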

In [None]:
from utils.utilities import *
from utils.ner_functions import *

db_entities = read_json('datas/de_entities.json')
# from sentence_transformers import SentenceTransformer

all_graph_node_entities = ['Pathway', 'Anatomy', 'SideEffect', 'Gene', 'MolecularFunction', 'Symptom', 'BiologicalProcess', 'PharmacologicClass', 'Disease', 'Compound', 'CellularComponent']
all_graph_rela_entities = ['name', 'identifier', 'url', 'license', 'source', 'mesh_id', 'bto_id', 'class_type', 'chromosome', 'inchikey', 'inchi', 'severity', 'description']
all_graph_node_instance = list(set(db_entities["node_name"]))
label_dict = read_json('datas/graph_node_name_instances.json')

nl = """Identify the gene names where the gene is either bound by Niclosamide (which resembles "tolcapone") or upregulated by Mycophenolic acid."""

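# Detect which node labels the query mentions; detect_terms yields (label, score) pairs.
nodes_score = detect_terms(all_graph_node_entities, nl)
nodes = [label for label, _ in nodes_score]
print(nodes)
# Align known instance names with the query text.
instances = align_instances(all_graph_node_instance, nl)
print("instances", instances)
# If "Disease" was detected as a node label, drop the literal string "Disease" from the instances.
if "Disease" in nodes:
    instances = [inst for inst in instances if inst != "Disease"]
print("instances (filtered)", instances)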
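# Rewrite the matched instances in the query into the KG schema format.
matchlist, rewrite_nl = instance_value_rewrite(instances, nl)
# Pair each matched instance with its node label, then format the pairs.
raw_pair = nodelabelextract(label_dict, matchlist)
node_instance_pair = format_instance_list(raw_pair)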

print("matchlist:", matchlist)
print("rewrite_nl:", rewrite_nl)
print("node_instance_pair:", node_instance_pair)
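
The last two steps (label lookup and formatting) can be sketched in isolation as follows. The helper names `label_instances` and `format_pairs` and the assumed dictionary shape (node label mapped to a list of instance names) are hypothetical; the actual structure of `label_dict` and the behaviour of `nodelabelextract` and `format_instance_list` may differ.

```python
def label_instances(label_to_names, matched):
    """Hypothetical helper: pair each matched name with every label that lists it.

    Assumes label_to_names maps a node label to a list of instance names.
    """
    pairs = []
    for name in matched:
        for label, names in label_to_names.items():
            if name in names:
                pairs.append((label, name))
    return pairs


def format_pairs(pairs):
    """Render (label, name) pairs in the Label(name:"value") style used above."""
    return [f'{label}(name:"{name}")' for label, name in pairs]


# Toy lookup table for illustration only; not data from the real graph.
toy_label_dict = {
    "Compound": ["Mycophenolic acid", "Niclosamide", "Tolcapone"],
    "Disease": ["asthma"],
}
pairs = label_instances(
    toy_label_dict, ["Mycophenolic acid", "Niclosamide", "Tolcapone"]
)
print(format_pairs(pairs))
```

With the toy table, each compound name is paired with the `Compound` label, which mirrors the instance list given in the overview.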