<a href="https://colab.research.google.com/github/ClockisTicking/My_Work_at_HF/blob/main/QA_on_knowledge_graph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ! pip install gradio
# ! pip install --upgrade pip
# ! pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,inmemorygraph]

In [None]:
import logging
import gradio as gr
from haystack.utils import fetch_archive_from_http
from pathlib import Path
from haystack.document_stores import InMemoryKnowledgeGraph
from haystack.nodes import Text2SparqlRetriever


# Root logger: only warnings and above, printed as "LEVEL - logger -  message".
logging.basicConfig(
    level=logging.WARNING,
    format="%(levelname)s - %(name)s -  %(message)s",
)
# Haystack's own logger is lowered to INFO so its progress messages are shown.
logging.getLogger("haystack").setLevel(logging.INFO)

In [None]:
# Fetch the triples that will be stored in the knowledge graph
# (exemplary triples from the wizarding world) into `graph_dir`.
# `graph_dir` is reused later to locate "triples.ttl".
graph_dir = "data/tutorial10"
s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/triples_and_config.zip"
fetch_archive_from_http(url = s3_url, output_dir = graph_dir)

# Fetch a pre-trained BART model that translates text queries to SPARQL queries.
# Note: `s3_url` is reassigned here, and `model_dir` points one level above the
# current working directory.
# The call is the last expression of the cell, so its return value is displayed;
# the captured run shows `False` together with a "Found data stored in ..." log
# line, i.e. the download appears to be skipped when the directory already has data.
model_dir = "../saved_models/tutorial10_knowledge_graph/"
s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/saved_models/hp_v3.4.zip"
fetch_archive_from_http(url = s3_url, output_dir = model_dir)

INFO:haystack.utils.import_utils:Found data stored in 'data/tutorial10'. Delete this first if you really want to fetch new data.
INFO:haystack.utils.import_utils:Found data stored in '../saved_models/tutorial10_knowledge_graph/'. Delete this first if you really want to fetch new data.


False

In [None]:
# In-memory knowledge graph that backs the question-answering demo
kg = InMemoryKnowledgeGraph(index = "tutorial_10_index")

# Delete the index as it might have been already created in previous runs
kg.delete_index()

# Create the index
kg.create_index()

# Import triples of subject, predicate, and object statements from a ttl file
kg.import_from_ttl_file(index = "tutorial_10_index", path = Path(graph_dir) / "triples.ttl")

# Query the graph for its triples only once and reuse the result for both
# summary lines, instead of materialising the full triple list twice.
all_triples = kg.get_all_triples()
print(f"The last triple stored in the knowledge graph is: {all_triples[-1]}")
print(f"There are {len(all_triples)} triples stored in the knowledge graph.")

# Retriever that turns a text question into SPARQL (using the downloaded
# "hp_v3.4" model) and runs it against `kg`
kgqa_retriever = Text2SparqlRetriever(knowledge_graph = kg, model_name_or_path = Path(model_dir) / "hp_v3.4")

The last triple stored in the knowledge graph is: {'s': {'type': 'uri', 'value': 'https://deepset.ai/harry_potter/Wizards_chess'}, 'p': {'type': 'uri', 'value': 'https://deepset.ai/harry_potter/owners'}, 'o': {'type': 'uri', 'value': 'https://deepset.ai/harry_potter/Harry_potter'}}
There are 118543 triples stored in the knowledge graph.


In [None]:
# Display the knowledge-graph object to confirm it was created (shows its default repr).
kg

<haystack.document_stores.memory_knowledgegraph.InMemoryKnowledgeGraph at 0x7f734a773af0>

In [None]:
# Display the retriever's bound `retrieve` method (not called here) as a quick sanity check.
kgqa_retriever.retrieve

<bound method Text2SparqlRetriever.retrieve of <haystack.nodes.retriever.text2sparql.Text2SparqlRetriever object at 0x7f736c0a9310>>

In [None]:
from transformers import BartForConditionalGeneration, BartTokenizer


def get_sparql(query, top_k):
  """Translate a natural-language question into candidate SPARQL queries.

  The BART model and tokenizer are loaded from ``Path(model_dir) / "hp_v3.4"``
  on every call.

  Args:
    query: The question text to translate.
    top_k: Number of queries the caller is interested in. ``top_k + 2``
      candidates are generated so that some with wrong syntax can be dismissed
      by the caller.

  Returns:
    A list with the ``top_k + 2`` decoded SPARQL query strings (unfiltered).
  """
  model_path = Path(model_dir) / "hp_v3.4"
  model = BartForConditionalGeneration.from_pretrained(model_path, forced_bos_token_id = 0)
  tok = BartTokenizer.from_pretrained(model_path)
  inputs = tok([query], max_length = 100, truncation = True, return_tensors = "pt")

  # Generate top_k + 2 SPARQL queries so that we can dismiss some queries with wrong syntax
  num_candidates = top_k + 2
  generated = model.generate(
    inputs["input_ids"],
    # Beam search cannot return more sequences than it has beams, so widen the
    # beam beyond the default of 5 whenever more candidates are requested
    # (previously any top_k > 3 was rejected by `generate`).
    num_beams = max(5, num_candidates),
    max_length = 100,
    num_return_sequences = num_candidates,
    early_stopping = True,
  )
  return [
    tok.decode(sequence, skip_special_tokens = True, clean_up_tokenization_spaces = False)
    for sequence in generated
  ]
  


In [None]:
from numpy import result_type
def get_res(query, top_k = 3):
  """Answer a question with links to the matching Harry Potter wiki pages.

  The question is passed to ``kgqa_retriever.retrieve``. For every result, the
  last path segment of its first answer is turned into a
  ``https://harrypotter.fandom.com/wiki/`` link.

  Args:
    query: The question text.
    top_k: Maximum number of results requested from the retriever.

  Returns:
    A list of unique wiki URLs, in the order the retriever returned them.
    Results without any answer are skipped, so the list may be empty.
  """
  wiki_base = "https://harrypotter.fandom.com/wiki/"
  links = []
  for res in kgqa_retriever.retrieve(query = query, top_k = top_k):
    answers = res["answer"]
    # A result can carry an empty answer (nothing matched); indexing it with
    # [0] would raise IndexError, so skip it instead.
    if not answers:
      continue
    entity = answers[0].split("/")[-1]
    # Fix the capitalisation of entity names so they match the wiki page
    # titles (e.g. "Harry_potter" -> "Harry_Potter").
    page = entity.replace("potter", "Potter").replace("r_i", "r_I")
    links.append(wiki_base + page)
  # dict.fromkeys removes duplicates while keeping the retriever's ranking;
  # a plain set() would return the links in arbitrary order.
  return list(dict.fromkeys(links))

def get_query(query, top_k = 3):
  """Return the SPARQL queries behind the retriever's results for a question.

  Args:
    query: The question text.
    top_k: Maximum number of results requested from ``kgqa_retriever``.

  Returns:
    One string holding each result's ``sparql_query`` entry, joined by newlines.
  """
  hits = kgqa_retriever.retrieve(query = query, top_k = top_k)
  return "\n".join(hit["prediction_meta"]["sparql_query"] for hit in hits)
  

def get_file(file_loc = None):
  """Return the contents of the triples file, without its first 10 lines.

  Args:
    file_loc: Path of the file to read. Defaults to ``triples.ttl`` inside
      ``graph_dir``, i.e. the same file that is imported into the knowledge
      graph, instead of a hard-coded Colab-only absolute path.

  Returns:
    All lines from the 11th onwards joined with a single space (each line
    keeps its own trailing newline). The first 10 lines are presumably the
    header/prefix section of the ttl file -- confirm against the data.
  """
  if file_loc is None:
    file_loc = Path(graph_dir) / "triples.ttl"
  # The context manager closes the handle; it used to be left open.
  with open(file_loc, "r", encoding = "utf-8") as f:
    lines = f.readlines()
  return " ".join(lines[10:])

In [None]:
# Quick manual check of get_res before wiring it into the UI.
get_res("In which house is Harry Potter?", 3)

['https://harrypotter.fandom.com/wiki/Slytherin',
 'https://harrypotter.fandom.com/wiki/Gryffindor']

In [None]:
# Gradio UI: one tab to ask questions (answers + generated SPARQL) and one tab
# to show the raw triples of the knowledge graph.
demo = gr.Blocks()

with demo:
    gr.Markdown("""# Q/A with Knowledge graph
                Querying knowledge graphs with the help of pre-trained models that translate text queries to SPARQL queries""")
    # NOTE(review): "graph_sanple.png" looks like a typo for "graph_sample.png" --
    # confirm against the actual file name before changing it. The path is an
    # absolute Colab path (/content/...).
    gr.HTML("<img src = 'file=/content/graph_sanple.png'/>")
    with gr.Tabs():
        with gr.TabItem("The Demo"):
            # Question on the left, the two result boxes stacked on the right.
            with gr.Row():
                text_input = gr.Textbox(label = "Input question")
                with gr.Column():
                  Answer = gr.Textbox(label = "Answers")
                  Sparql = gr.Textbox(label = "Sparql")
            
            answer_button = gr.Button("Get Answer")
            sparql_button = gr.Button("Get Sparql")
            # Example questions for `text_input`. No `fn` is passed here, so the
            # examples only provide input text; `outputs` is given but nothing
            # in this call computes a value for it.
            gr.Examples(examples = [["in which house is Harry Potter"], ["who's harry potter's grandfather"], ["What is the patronus of Hermione"]], inputs = text_input, outputs = Answer)
        
        with gr.TabItem("The knowledge Triplet"):  
          triplet = gr.Textbox(label = "The knowledge Graph Triplets", lines = 10)
          KG_button = gr.Button("Get Triplets")

    # Wire the buttons to the helper functions defined in the earlier cells.
    # Both question handlers are called with the textbox value only, so they
    # run with their default top_k.
    answer_button.click(get_res, inputs = text_input, outputs = Answer)
    sparql_button.click(get_query, inputs = text_input, outputs = Sparql)
    # get_file takes no input; its return value fills the triplet textbox.
    KG_button.click(get_file, outputs = triplet)
# Start the app (default launch options; see the captured output for Colab hints).
demo.launch()

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Note: opening Chrome Inspector may crash demo inside Colab notebooks.

To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>

