# BMKG - Assignment 3

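This notebook uses an LLM (via LangChain) to translate natural-language questions about a knowledge graph (KG) into SPARQL queries, guided by a schema of the KG, and then runs those queries against the graph.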

In [2]:
import getpass
import os
from operator import itemgetter
from typing import Any

import pandas as pd
from IPython.display import HTML, display
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import SparqlLexer
from rdflib import Graph

from langchain.memory import ConversationBufferMemory
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, format_document
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import OpenAI



In [211]:
# Ask for the OpenAI API key interactively, but only when it has not already
# been provided through the environment.
_KEY_NAME = "OPENAI_API_KEY"
if _KEY_NAME not in os.environ:
    os.environ[_KEY_NAME] = getpass.getpass("Provide your OpenAI API Key")

## 1) Load the RDF graph and its schema

In [212]:
# The schema is kept as raw Turtle text because it is inserted verbatim into the
# SPARQL-generation prompt; the data itself is parsed into an rdflib Graph so the
# generated queries can be executed against it.
# Turtle is UTF-8 by specification, so decode explicitly instead of relying on the
# platform's default locale encoding (which can garble non-ASCII labels).
with open('wikidata_drug_disease_schema.ttl', 'r', encoding='utf-8') as schema_file:
    schema = schema_file.read()

# Graph.parse returns the graph itself, so construction and parsing can be chained.
graph = Graph().parse("medical_graph.ttl", format="turtle")

print("Number of triples: ", len(graph))

Number of triples:  1873


In [360]:
# Completion-style OpenAI LLM shared by every step of the chain built below
# (question reformulation, question breakdown and SPARQL generation).
# temperature=0 requests the least random sampling, to keep generated queries as repeatable as possible.
llm = OpenAI(temperature=0)

In [361]:
# Conversation memory: each (question, answer) pair is saved so follow-up questions
# can be rewritten using earlier turns. `input_key`/`output_key` match the keys that
# `stream_chain` passes to `memory.save_context`.
memory = ConversationBufferMemory(
    return_messages=True, output_key="answer", input_key="question"
)
# Runnable step that passes its input dict through and adds a "chat_history" key
# holding the "history" entry of the memory variables.
loaded_memory = RunnablePassthrough.assign(
    chat_history=RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
)

# Prompt 1: reformulate the (possibly follow-up) question as a standalone question.
reform_template = """Given the following chat history and a follow up question,
rephrase the follow up question to be a standalone straightforward question, in its original language.
Do not answer the question! Just rephrase reusing information from the chat history.

Chat History:
{chat_history}
Follow up input:
{question}

Standalone question:
"""
REFORM_QUESTION_PROMPT = PromptTemplate.from_template(reform_template)

# Prompt 2: break the standalone question down into known variables and the target.
breakdown_template = """ Given the following question, break it down to identify the known variables. Format these variables as follows:
Name --> Variable Type

Lastly, identify what the question is looking for (including the name, if provided). Only use the information presented in the question, nothing else.

Question:
{question}

Breakdown:
"""
BREAKDOWN_PROMPT = PromptTemplate.from_template(breakdown_template)

# Prompt 3: turn the breakdown into a SPARQL query using the KG schema.
# `llm` is a completion model (`OpenAI`), so a plain PromptTemplate is used, like
# the other prompts; a ChatPromptTemplate would be serialized with chat role
# prefixes (e.g. "Human: ") before reaching the model.
answer_template = """Construct a valid SPARQL query based on the provided breakdown, ensuring that labels are compared in lowercase using LCASE and not directly added as triples.

Based on the breakdown, the query should only retrieve the information we are looking for. Use the provided schema, although the query should not include everything!

Schema:
{schema}

Breakdown:
{breakdown}

Query:
"""
ANSWER_PROMPT = PromptTemplate.from_template(answer_template)

In [362]:
def _log_and_pass(label: str, key: str):
    """Build a chain step that prints ``label`` followed by ``x[key]`` and returns ``x[key]``.

    Used to surface intermediate results (reformulated question, breakdown) while
    the chain runs, without altering the value passed to the next step.
    """
    def _step(x):
        value = x[key]
        print(label, value)
        return value
    return _step


# Step 1: rewrite the question as a standalone question using the chat history.
reformulated_question = {
    "reformulated_question": {
        "question": lambda x: x["question"],
        "chat_history": lambda x: get_buffer_string(x["chat_history"]),
    }
    | REFORM_QUESTION_PROMPT
    | llm
    | StrOutputParser(),
}
# Step 2: break the question down to identify what is known and what is being asked.
question_breakdown = {
    "breakdown": {
        "question": _log_and_pass("💭 Reformulated question:", "reformulated_question"),
    }
    | BREAKDOWN_PROMPT
    | llm
    | StrOutputParser(),
}
# Step 3: generate the SPARQL query from the schema text and the breakdown.
final_inputs = {
    "schema": lambda x: schema,
    "breakdown": _log_and_pass("💭 Question breakdown:\n", "breakdown"),
}
answer = {
    # StrOutputParser for consistency with the other steps: the streamed
    # "answer" chunks are plain strings that stream_chain concatenates.
    "answer": final_inputs | ANSWER_PROMPT | llm | StrOutputParser(),
}
# Put the chain together
final_chain = loaded_memory | reformulated_question | question_breakdown | answer

def stream_chain(final_chain, memory: "ConversationBufferMemory", inputs_list: list[dict[str, str]]) -> dict[str, Any]:
    """Ask each question in turn, stream its answer to stdout, and record it in memory.

    Args:
        final_chain: Runnable whose ``stream(inputs)`` yields dict chunks; only the
            pieces found under the ``"answer"`` key are collected and printed.
        memory: Conversation memory; each ``(inputs, {"answer": ...})`` pair is saved
            once its answer has been fully streamed, so later questions can build on it.
        inputs_list: One input dict per question, e.g. ``{"question": "..."}``.

    Returns:
        ``{"answer": [...]}`` holding one complete answer string per input, in order.
    """
    answers: list[str] = []
    for inputs in inputs_list:
        pieces: list[str] = []
        for chunk in final_chain.stream(inputs):
            if "answer" in chunk:
                piece = chunk["answer"]
                pieces.append(piece)
                print(piece, end="", flush=True)
        # Terminate the streamed answer so the next answer (and the chain's own
        # progress messages) start on a fresh line instead of running together.
        print()

        answer_output = "".join(pieces)
        answers.append(answer_output)
        # Add messages to chat history
        memory.save_context(inputs, {"answer": answer_output})

    return {"answer": answers}

In [363]:
# Questions are asked in order; later ones can refer back to earlier answers
# through the conversation memory. Add more strings to ask follow-ups.
questions = [
    {"question": question}
    for question in ["What drug should I use to treat the disease 'vomiting'?"]
]
output = stream_chain(final_chain, memory, questions)

💭 Reformulated question: Which drug is recommended for treating the disease 'vomiting'?
💭 Question breakdown:
 
Known Variables:
- Disease: vomiting

Question is looking for:
- Recommended drug for treating vomiting
PREFIX ns1: <https://www.example.org/>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX wd: <http://www.wikidata.org/entity/>

SELECT ?drug
WHERE {
    ?drug a ns1:Drug ;
        rdfs:label ?drugLabel ;
        ns1:treats ?disease .
    ?disease a ns1:Disease ;
        rdfs:label ?diseaseLabel .
    FILTER(LCASE(?diseaseLabel) = "vomiting")
}

In [364]:
def run_query(graph, query):
    """Show a SPARQL query with syntax highlighting, execute it, and display the results.

    Args:
        graph: rdflib ``Graph`` to query.
        query: SPARQL query text (e.g. one produced by the LLM chain).

    The results are rendered as an HTML table; nothing is returned. Unbound
    variables (e.g. from OPTIONAL) are shown as missing values rather than "None".
    """
    # Display the query *before* executing it, so that a malformed (e.g.
    # LLM-generated) query is still visible when graph.query raises.
    formatted_query = highlight(
        query,
        SparqlLexer(),
        HtmlFormatter(style='solarized-dark', full=True, nobackground=True),
    )
    display(HTML(formatted_query))

    results = graph.query(query)

    # Only SELECT results carry variable names; for other query forms `vars` is
    # None and pandas falls back to positional column labels. Passing the columns
    # even when there are no rows keeps the headers of an empty SELECT result.
    columns = [str(var) for var in results.vars] if results.vars else None
    rows = []
    for row in results:
        # SELECT rows and CONSTRUCT/DESCRIBE triples are tuples; an ASK result
        # iterates as a single bool, so wrap it to form a one-cell row.
        cells = row if isinstance(row, tuple) else (row,)
        rows.append([None if item is None else str(item) for item in cells])
    df = pd.DataFrame(rows, columns=columns)

    # Display the DataFrame as a table in Jupyter Notebook
    display(HTML(df.to_html()))

In [365]:
# Run every generated SPARQL query against the graph. run_query already displays
# a syntax-highlighted copy of each query, so it is not printed a second time here.
for answer_output in output["answer"]:
    run_query(graph, answer_output)

PREFIX ns1: <https://www.example.org/>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX wd: <http://www.wikidata.org/entity/>

SELECT ?drug
WHERE {
    ?drug a ns1:Drug ;
        rdfs:label ?drugLabel ;
        ns1:treats ?disease .
    ?disease a ns1:Disease ;
        rdfs:label ?diseaseLabel .
    FILTER(LCASE(?diseaseLabel) = "vomiting")
}


Unnamed: 0,drug
0,http://www.wikidata.org/entity/Q419079
1,http://www.wikidata.org/entity/Q7263592
