# BMKG - Assignment 3
## Group 6

In [1]:
from operator import itemgetter
import getpass
import os

from typing import Any

from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import get_buffer_string
from langchain_openai import OpenAI

from rdflib import Graph, URIRef, Literal, Namespace
from rdflib.namespace import RDF

from SPARQLWrapper import SPARQLWrapper, JSON



In [2]:
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Provide your OpenAI API Key")

1) Get the RDF graph

In [3]:
WD = Namespace("http://www.wikidata.org/entity/")
RDFS = Namespace("http://www.w3.org/2000/01/rdf-schema#")
ex = Namespace("https://www.example.org/")
wdt = Namespace("https://www.wikidata.org/prop/direct/")

In [4]:
sparql_query = """
SELECT DISTINCT ?drug ?drugLabel ?disease ?diseaseLabel ?cause ?causeLabel ?symptom ?symptomLabel ?sideEffect ?sideEffectLabel ?gene ?geneLabel
WHERE {
  ?drug wdt:P31 wd:Q12140 .  # medicine
  ?drug wdt:P2175 ?disease .    # treats disease
  OPTIONAL { ?disease wdt:P828 ?cause } .  # cause of disease
  OPTIONAL { ?disease wdt:P780 ?symptom } .  # symptom/complaint of disease
  OPTIONAL { ?drug wdt:P780 ?sideEffect } .  # potential side effect
  OPTIONAL { ?disease wdt:P1057 ?gene } .  # gene associated with disease
  SERVICE wikibase:label {
    bd:serviceParam wikibase:language "en" .
  }
}
"""

In [5]:
sparql = SPARQLWrapper("https://query.wikidata.org/sparql", agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.11 (KHTML, like Gecko) Chrome/23.0.1271.64 Safari/537.11")
sparql.setQuery(sparql_query)
sparql.setReturnFormat(JSON)
results = sparql.query().convert()

# Create an RDF graph
g = Graph()

# Bind namespaces
g.bind("wd", WD)
g.bind("wdt", wdt)
g.bind("rdfs", RDFS)

# Process query results
for result in results["results"]["bindings"]:
    drug = URIRef(result["drug"]["value"])
    disease = URIRef(result["disease"]["value"])

    # Add RDF types
    g.add((drug, RDF.type, ex.Drug))
    g.add((disease, RDF.type, ex.Disease))
    if "cause" in result:
        cause = URIRef(result["cause"]["value"])
        g.add((cause, RDF.type, ex.Cause))
    if "symptom" in result:
        symptom = URIRef(result["symptom"]["value"])
        g.add((symptom, RDF.type, ex.Symptom))
    if "sideEffect" in result:
        side_effect = URIRef(result["sideEffect"]["value"])
        g.add((side_effect, RDF.type, ex.SideEffect))
    if "gene" in result:
        gene = URIRef(result["gene"]["value"])
        g.add((gene, RDF.type, ex.Gene))

    # Add labels
    g.add((drug, RDFS.label, Literal(result["drugLabel"]["value"])))
    g.add((disease, RDFS.label, Literal(result["diseaseLabel"]["value"])))
    if "cause" in result:
        g.add((cause, RDFS.label, Literal(result["causeLabel"]["value"])))
    if "symptom" in result:
        g.add((symptom, RDFS.label, Literal(result["symptomLabel"]["value"])))
    if "sideEffect" in result:
        g.add((side_effect, RDFS.label, Literal(result["sideEffectLabel"]["value"])))
    if "gene" in result:
        g.add((gene, RDFS.label, Literal(result["geneLabel"]["value"])))

    # Add relationships
    g.add((drug, ex.treats, disease))
    if "cause" in result:
        g.add((disease, ex.hasCause, cause))
    if "symptom" in result:
        g.add((disease, ex.hasSymptom, symptom))
    if "sideEffect" in result:
        g.add((drug, ex.hasSideEffect, side_effect))
    if "gene" in result:
        g.add((disease, ex.associatedGene, gene))

# Serialize the graph to a TTL file
g.serialize(destination='../data/medical_graph.ttl', format='ttl')

<Graph identifier=N803db38b79c24fb196dbf0583adea123 (<class 'rdflib.graph.Graph'>)>

In [6]:
#opening the files to use with the llm
with open('../data/wikidata_drug_disease_schema.ttl', 'r') as file:
    schema = file.read()
    
graph = Graph()
graph = graph.parse("../data/medical_graph.ttl")

print("Number of triples: ", len(graph))

Number of triples:  1873


In [7]:
llm = OpenAI(temperature=0)

In [8]:
from langchain_core.prompts import ChatPromptTemplate

# Create the memory object that is used to add messages
memory = ConversationBufferMemory(
    return_messages=True, output_key="answer", input_key="question"
)
# Add a "memory" key to the input object
loaded_memory = RunnablePassthrough.assign(
    chat_history=RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
)

# Prompt to reformulate the question using the chat history
reform_template = """Given the following chat history and a follow up question,
rephrase the follow up question to be a standalone straightforward question, in its original language.
Do not answer the question! Just rephrase reusing information from the chat history.

Chat History:
{chat_history}
Follow up input:
{question}

Standalone question:
"""
REFORM_QUESTION_PROMPT = PromptTemplate.from_template(reform_template)

breakdown_template = """ Given the following question, break it down to identify the known variables. Format it as follows:
Variable Type --> Variable Name

Lastly, identify what the question is looking for (including the name, if provided). Only use the information presented in the question, nothing else.

Question:
{question}

Breakdown:
"""
BREAKDOWN_PROMPT = PromptTemplate.from_template(breakdown_template)

# Prompt to ask to answer the reformulated question
answer_template = """Construct a valid SPARQL query based on the provided breakdown, ensuring that labels are compared in lowercase using LCASE and not directly added as triples.

Based on the breakdown, the query should only retrieve the information we are looking for. Use the provided schema, although the query should not include everything!

Schema:
{schema}

Breakdown:
{breakdown}

Query:
"""
ANSWER_PROMPT = ChatPromptTemplate.from_template(answer_template)

In [9]:
# Reformulate the question using chat history
reformulated_question = {
    "reformulated_question": {
        "question": lambda x: x["question"],
        "chat_history": lambda x: get_buffer_string(x["chat_history"]),
    }
    | REFORM_QUESTION_PROMPT
    | llm
    | StrOutputParser(),
}
# Breakdown the question to identify what we are looking for and what we know
question_breakdown = {
    "breakdown" : {
        "question": lambda x: print("💭 Reformulated question:", x["reformulated_question"]) or x["reformulated_question"],
    }
    | BREAKDOWN_PROMPT
    | llm
    | StrOutputParser(),
}
final_inputs = {
    "schema": lambda x: schema,
    "breakdown": lambda x: print("💭 Question breakdown:\n", x["breakdown"]) or x["breakdown"],
}
answer = {
    "answer": final_inputs | ANSWER_PROMPT | llm,
}
# Put the chain together
final_chain = loaded_memory | reformulated_question | question_breakdown | answer

def stream_chain(final_chain, memory: ConversationBufferMemory, inputs_list: list[dict[str, str]]) -> dict[str, Any]:
    """Ask questions, stream the answer output, and return the answers."""
    output = {"answer": []}
    for inputs in inputs_list:
        answer_output = ""
        for chunk in final_chain.stream(inputs):
            if "answer" in chunk:
                answer_output += chunk["answer"]
                print(chunk["answer"], end="", flush=True)
        
        output["answer"].append(answer_output)
        # Add messages to chat history
        memory.save_context(inputs, {"answer": answer_output})
    
    return output

In [10]:
import pandas as pd

from IPython.display import display, HTML
from pygments import highlight
from pygments.lexers import SparqlLexer
from pygments.formatters import HtmlFormatter

def run_query(graph, query, entity=None):
    # Execute the SPARQL query
    results = graph.query(query, initBindings={'entity': entity})
    
    # Display the SPARQL query
    formatted_query = highlight(query, SparqlLexer(), HtmlFormatter(style='solarized-dark', full=True, nobackground=True))
    display(HTML(formatted_query))
    
    # Convert results to a Pandas DataFrame
    res_list = []
    for row in results:
        res_list.append([str(item) for item in row])
    df = pd.DataFrame(res_list, columns=[str(var) for var in results.vars]) if len(res_list) > 0 else pd.DataFrame()

    # Display the DataFrame as a table in Jupyter Notebook
    display(HTML(df.to_html()))
    
    if len(df) > 1:
        return df.iloc[0, 0]

In [11]:
def get_uri_from_label(label, graph):
    """
    Given a human-readable label and a graph, return the URI corresponding to the label.
    If multiple URIs exist for a given label, this function returns the first match.
    """
    sparql_query = """
    SELECT ?entity WHERE {
        ?entity rdfs:label ?label .
        FILTER (LCASE(STR(?label)) = LCASE(?input_label))
    }
    LIMIT 1
    """
    query_result = graph.query(sparql_query, initBindings={'input_label': rdflib.Literal(label)})
    for row in query_result:
        return row[0]  # Return the first (and hopefully only) URI that matches the label
    return None  # Return None if no matching URI is found

In [14]:
#To combine our multi-hop reasoning with our llm, we ask the llm two questions, and we try to find the link between them.

import rdflib
from rdflib.extras.external_graph_libs import rdflib_to_networkx_digraph

def multi_hop_questioning(question1, question2, graph, final_chain, memory):
    # now we have our questions ready, we find our answers
    question1 = [
            {"question" : question1},
    ]
    output1 = stream_chain(final_chain, memory, question1)
    question2 = [
            {"question": question2},
    ]
    output2 = stream_chain(final_chain, memory, question2)
    
    for answer_output in output1["answer"]:
        print(answer_output)
        answer1 = run_query(graph, answer_output)
        
    for answer_output in output2["answer"]:
        print(answer_output)
        answer2 = run_query(graph, answer_output)

    dg = rdflib_to_networkx_digraph(graph)
    
    query_entity = get_uri_from_label(answer1, graph)
    target_relation = get_uri_from_label(answer2, graph)
    
    visited = set()
    queue = [(query_entity, [])]  # Start with the query entity and an empty relation path
    while queue:
        entity, relation_path = queue.pop(0)
        if entity == target_relation:
            return relation_path
        if entity not in visited:
            visited.add(entity)
            for neighbor in dg.neighbors(entity):
                new_path = relation_path + [(entity, neighbor)]
                queue.append((neighbor, new_path))
    return None

In [15]:
q1 = "What disease has as a symptom 'vomiting'?"
q2 = "What disease is caused by 'alcoholism'?"

x = multi_hop_questioning(q1, q2, graph, final_chain, memory)

💭 Reformulated question: What disease is associated with the symptom 'vomiting'?
💭 Question breakdown:
 
Known Variables:
- Symptom --> 'vomiting'

Question Looking For:
- Disease
PREFIX ns1: <https://www.example.org/>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX wd: <http://www.wikidata.org/entity/>

SELECT ?disease_label
WHERE {
    ?drug_uri a ns1:Drug ;
        rdfs:label ?drug_label ;
        ns1:treats ?disease_uri .
    ?disease_uri a ns1:Disease ;
        rdfs:label ?disease_label ;
        ns1:hasSymptom ?symptom_uri ;
        ns1:hasCause ?cause_uri .
    ?symptom_uri a ns1:Symptom ;
        rdfs:label ?symptom_label ;
        ns1:hasCause ?cause_uri .
    ?cause_uri a ns1:Cause ;
        rdfs:label ?cause_label .
    FILTER (LCASE(?symptom_label) = "vomiting")
}💭 Reformulated question: What disease is caused by alcoholism?
💭 Question breakdown:
 
Known Variables:
- Variable Type: Disease
- Variable Name: Unknown

Question is looking for:
- Name of the disease 

KeyboardInterrupt: 

In [None]:
print(x)