# BMKG - Assignment 3

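This notebook aims to automatically generate a schema for any given knowledge graph (KG). The cells below load an RDF graph together with an existing schema and use an LLM to translate natural-language questions into SPARQL queries, which are then run against the graph.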

In [2]:
from operator import itemgetter
import getpass
import os

from typing import Any

from rdflib import Graph

from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts import PromptTemplate, format_document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import get_buffer_string
from langchain_openai import OpenAI



In [3]:
# Prompt for the key interactively only when the environment does not already
# provide one, so the secret never has to be written into the notebook.
if os.environ.get("OPENAI_API_KEY") is None:
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Provide your OpenAI API Key")

1) Load the schema and the RDF graph

In [170]:
# Read the schema text that is later injected into the SPARQL-generation prompt.
# Turtle documents are UTF-8 by specification, so decode explicitly instead of
# relying on the platform's default encoding.
with open('wikidata_drug_disease_schema.ttl', 'r', encoding='utf-8') as file:
    schema = file.read()

# Parse the knowledge graph that the generated SPARQL queries are run against.
# No explicit format is passed to parse().
graph = Graph()
graph = graph.parse("medical_graph.ttl")

print("Number of triples: ", len(graph))

Number of triples:  1873


In [176]:
# LLM used by both the question-reformulation step and the SPARQL-generation step.
# temperature=0 requests the least random sampling, to keep the output as
# repeatable as possible between runs.
llm = OpenAI(temperature=0)

In [192]:
from langchain_core.prompts import ChatPromptTemplate

# Conversation buffer holding the earlier turns. "question" and "answer" are the
# keys it picks out of the dicts handed to save_context().
memory = ConversationBufferMemory(
    input_key="question",
    output_key="answer",
    return_messages=True,
)

# Extend the incoming input dict with a "chat_history" entry, taken from the
# "history" variable that the memory exposes.
history_loader = RunnableLambda(memory.load_memory_variables) | itemgetter("history")
loaded_memory = RunnablePassthrough.assign(chat_history=history_loader)

# Prompt that rewrites a follow-up question into a standalone question.
# Placeholders: {chat_history} (earlier conversation turns rendered as text)
# and {question} (the new follow-up input).
reform_template = """Given the following chat history and a follow up question,
rephrase the follow up question to be a standalone straightforward question, in its original language.
Do not answer the question! Just rephrase reusing information from the chat history.
Make it short and straight to the point.

Chat History:
{chat_history}
Follow up input:
{question}

Standalone question:
"""
REFORM_QUESTION_PROMPT = PromptTemplate.from_template(reform_template)

# Prompt that asks for a SPARQL query answering the reformulated question.
# Placeholders: {schema} (the schema text the query must be restricted to)
# and {question} (the standalone question).
# Note: this one is built as a ChatPromptTemplate, whereas the reformulation
# prompt above is a plain PromptTemplate.
answer_template = """Write a valid SPARQL query to answer the question based only on the following schema, do not use any information outside this schema:
{schema}

When comparing labels, always convert them to lower case using LCASE.

Question: {question}
"""
ANSWER_PROMPT = ChatPromptTemplate.from_template(answer_template)

In [193]:
# TODO: ask it to generate a schema as well?

def _echo_reformulated(x):
    """Print the reformulated question and pass it on unchanged."""
    print("💭 Reformulated question:", x["reformulated_question"])
    return x["reformulated_question"]


# Step 1 - reformulate the question using the chat history.
# get_buffer_string renders the stored messages as plain text for the prompt.
reformulation_chain = (
    {
        "question": lambda x: x["question"],
        "chat_history": lambda x: get_buffer_string(x["chat_history"]),
    }
    | REFORM_QUESTION_PROMPT
    | llm
    | StrOutputParser()
)
reformulated_question = {
    "reformulated_question": reformulation_chain,
}

# Step 2 - build the inputs of the SPARQL-generation prompt: the schema text
# loaded earlier in the notebook plus the standalone question from step 1.
final_inputs = {
    "schema": lambda x: schema,
    "question": _echo_reformulated,
}

# Step 3 - generate the SPARQL query. The output is parsed to a plain string,
# exactly like the reformulation step, so that consumers concatenating streamed
# chunks keep receiving `str` pieces even if `llm` is replaced by a chat model.
answer = {
    "answer": final_inputs | ANSWER_PROMPT | llm | StrOutputParser(),
}

# Put the chain together: load memory -> reformulate -> answer.
final_chain = loaded_memory | reformulated_question | answer

def stream_chain(final_chain, memory: ConversationBufferMemory, inputs_list: list[dict[str, str]]) -> dict[str, Any]:
    """Run every input through the chain, echoing the answer while it streams.

    Args:
        final_chain: Object whose ``stream(inputs)`` yields dict chunks; only
            chunks carrying an ``"answer"`` key are used.
        memory: Receives ``save_context(inputs, {"answer": text})`` after each
            input so that later inputs can build on earlier turns.
        inputs_list: One input dict per question, processed in order.

    Returns:
        ``{"answer": [...]}`` holding one complete answer string per input, in
        the same order as ``inputs_list``.
    """
    answers = []
    for turn in inputs_list:
        pieces = []
        for chunk in final_chain.stream(turn):
            # Chunks belonging to other keys of the chain output are skipped.
            if "answer" not in chunk:
                continue
            piece = chunk["answer"]
            pieces.append(piece)
            print(piece, end="", flush=True)

        full_answer = "".join(pieces)
        answers.append(full_answer)
        # Record the finished turn in the chat history.
        memory.save_context(turn, {"answer": full_answer})

    return {"answer": answers}

In [194]:
# One dict per conversation turn; the chain reads the "question" key.
questions = [dict(question="What drug treats lymphosarcoma")]

# Stream each answer to the cell output and keep the full texts for later cells.
output = stream_chain(final_chain=final_chain, memory=memory, inputs_list=questions)

💭 Reformulated question: What is the drug used to treat lymphosarcoma?

PREFIX ns1: <https://www.example.org/>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX wd: <http://www.wikidata.org/entity/>
SELECT ?drug
WHERE {
    ?drug a ns1:Drug ;
        rdfs:label ?drugLabel ;
        ns1:treats ?disease .
    ?disease a ns1:Disease ;
        rdfs:label ?diseaseLabel .
    FILTER(LCASE(?diseaseLabel) = "lymphosarcoma")
}

In [187]:
import pandas as pd
from IPython.display import display, HTML
from pygments import highlight
from pygments.lexers import SparqlLexer
from pygments.formatters import HtmlFormatter

def run_query(graph, query):
    """Run a SPARQL query against ``graph`` and render the query and its results.

    The query is executed first, so an exception raised by ``graph.query``
    propagates before anything is displayed. On success the query text is shown
    syntax-highlighted, followed by the result rows as an HTML table in which
    every value has been converted with ``str``.

    Args:
        graph: Object exposing ``query(text)``; in this notebook an rdflib ``Graph``.
        query: SPARQL query text.

    Returns:
        None. All output goes through ``display``.

    Note:
        The result object must provide ``vars`` and be iterable row by row
        (presumably SELECT results only; other query forms are not handled here).
    """
    # Execute the SPARQL query
    results = graph.query(query)

    # Display the SPARQL query
    formatter = HtmlFormatter(style='solarized-dark', full=True, nobackground=True)
    display(HTML(highlight(query, SparqlLexer(), formatter)))

    # Convert results to a Pandas DataFrame. The selected variable names are always
    # used as columns, so an empty result set still renders a table with its headers
    # instead of a bare, column-less frame.
    columns = [str(var) for var in results.vars]
    rows = [[str(item) for item in row] for row in results]
    df = pd.DataFrame(rows, columns=columns)

    # Display the DataFrame as a table in Jupyter Notebook
    display(HTML(df.to_html()))

In [195]:
# Echo each generated query as plain text, then execute it against the graph
# and render the highlighted query plus its result table.
for generated_query in output["answer"]:
    print(generated_query)
    run_query(graph, generated_query)


PREFIX ns1: <https://www.example.org/>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX wd: <http://www.wikidata.org/entity/>

SELECT ?drug
WHERE {
    ?drug a ns1:Drug ;
        rdfs:label ?drugLabel ;
        ns1:treats ?disease .
    ?disease a ns1:Disease ;
        rdfs:label ?diseaseLabel .
    FILTER(LCASE(?diseaseLabel) = "lymphosarcoma")
}


Unnamed: 0,drug
0,http://www.wikidata.org/entity/Q415571
