In [None]:
# TODO: replace run() with invoke() in agent calling (run() is going to be deprecated in langchain 0.2.0).
from dotenv import load_dotenv
import os

# Load the environment variables from the .env file.
# NOTE(review): presumably this is where the OpenAI / LangSmith API keys come from — confirm.
load_dotenv()

In [None]:
# LangSmith tracing setup.
# For now, all runs are stored in the "KGBot Testing - GPT4" project; change
# LANGCHAIN_PROJECT below to create a new project and separate the traces.
# Metadata such as the LLM version and temperature can be obtained from the traces.
# To derive a unique project suffix instead:
#   from uuid import uuid4
#   unique_id = uuid4().hex[0:8]
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_PROJECT": "KGBot Testing - GPT4",
        "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
    }
)
from langsmith import Client

client = Client()

# Check that the client was initialized.
print(client)

In [None]:
from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser
from langchain.prompts import BaseChatPromptTemplate
from langchain.utilities import SerpAPIWrapper
from langchain.chains.llm import LLMChain
# from langchain.chat_models import ChatOpenAI
from langchain_openai import ChatOpenAI
from typing import Annotated, List, Tuple, Union

from langchain_core.messages import BaseMessage, HumanMessage
from langchain.schema import AgentAction, AgentFinish, HumanMessage
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain import hub
from langgraph.prebuilt import ToolInvocation
import json
from langchain_core.messages import FunctionMessage
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable

import re

from RdfGraphCustom import RdfGraph
from smile_resolver import smiles_to_inchikey
from chemical_resolver import ChemicalResolver
from target_resolver import target_name_to_target_id
from taxon_resolver import TaxonResolver

from sparql import GraphSparqlQAChain

from langchain.agents import  Tool
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool, StructuredTool, tool

from typing import (
    List,
    Tuple
)

import pickle

In [None]:
####################### Instantiate the graph #######################

# SPARQL endpoint used to initialise the graph.
endpoint_url = "https://enpkg.commons-lab.org/graphdb/repositories/ENPKG"

# Initialise with the endpoint URL only.
graph = RdfGraph(query_endpoint=endpoint_url, standard="rdf")

# Alternative: initialise with a local schema file as well.
# graph = RdfGraph(
#     query_endpoint=endpoint_url,
#     standard="rdf",
#     schema_file='../graphs/schema.ttl')

# Save the graph object on disk.
with open("../graphs/graph.pkl", mode="wb") as output_file:
    pickle.dump(graph, output_file)

####################### Load the graph from disk #######################

# Alternative: load a previously saved graph instead of building it.
# with open('../graphs/graph.pkl', 'rb') as input_file:
#     graph = pickle.load(input_file)

# NOTE(review): get_schema is referenced without a call — presumably a property; confirm.
print(graph.get_schema)


In [None]:
# Sampling temperature shared by both chat models.
temperature = 0
# Model used for the supervisor.
model_id_gpt4 = "gpt-4"
# Model used for the worker agents, the SPARQL chain and the chemical resolver.
model_id = "gpt-4-0125-preview"

# ChatOpenAI reference:
# https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.openai.ChatOpenAI.html?highlight=chatopenai#
# (model_kwargs such as {"top_p": 0.95} can be passed as well; currently not set.)
llm_gpt4 = ChatOpenAI(
    model=model_id_gpt4,  # default is 'gpt-3.5-turbo'
    temperature=temperature,
    max_retries=3,
    verbose=True,
)
llm = ChatOpenAI(
    model=model_id,  # default is 'gpt-3.5-turbo'
    temperature=temperature,
    max_retries=3,
    verbose=True,
)

# Question -> SPARQL -> answer chain over the graph.
# https://api.python.langchain.com/en/latest/_modules/langchain/chains/graph_qa/sparql.html#
sparql_chain = GraphSparqlQAChain.from_llm(llm, graph=graph, verbose=True)

# Entity resolvers exposed as tools further down.
chem_res = ChemicalResolver.from_llm(llm=llm, verbose=True)
taxon_res = TaxonResolver()

# Argument schema of the CHEMICAL_RESOLVER tool: one free-text field, `query`.
# (A `#` comment is used on purpose; a class docstring would become part of the pydantic schema.)
class ChemicalInput(BaseModel):
    query: str = Field(description="natural product compound string")



# Argument schema of the SPARQL_QUERY_RUNNER tool: the user question plus the resolved entities.
# (A `#` comment is used on purpose; a class docstring would become part of the pydantic schema.)
class SparqlInput(BaseModel):
    question: str = Field(description="the original question from the user")
    entities: str = Field(description="strings containing for all entities, entity name and the corresponding entity identifier")


# Tools available to the Resolver agent: one per kind of entity to resolve.
tools_resolver = [
    # Compound name -> InChIKey (or NPC class / pathway / superclass).
    StructuredTool.from_function(
        func=chem_res.run,
        name="CHEMICAL_RESOLVER",
        description="The function takes a natural product compound string as input and returns a InChIKey, if InChIKey not found, it returns the NPCClass, NPCPathway or NPCSuperClass.",
        args_schema=ChemicalInput,
    ),
    # Taxon name -> Wikidata ID.
    StructuredTool.from_function(
        func=taxon_res.query_wikidata,
        name="TAXON_RESOLVER",
        description="The function takes a taxon string as input and returns the wikidata ID.",
    ),
    # Target name -> ChEMBLTarget IRI.
    StructuredTool.from_function(
        func=target_name_to_target_id,
        name="TARGET_RESOLVER",
        description="The function takes a target string as input and returns the ChEMBLTarget IRI.",
    ),
    # SMILES -> InChIKey.
    StructuredTool.from_function(
        func=smiles_to_inchikey,
        name="SMILE_CONVERTER",
        description="The function takes a SMILES string as input and returns the InChIKey notation of the molecule.",
    ),
]
    
# Single tool available to the Sparql_query_runner agent.
tool_sparql = [
    StructuredTool.from_function(
        func=sparql_chain.run,
        name="SPARQL_QUERY_RUNNER",
        description="The agent resolve the user's question by querying the knowledge graph database. Input should be a question and the resolved entities in the question.",
        args_schema=SparqlInput,
        # return_direct=True,
    )
]

# Names of the resolver tools, interpolated into the resolver system prompt.
tool_names = [resolver_tool.name for resolver_tool in tools_resolver]


# chain.run(q3)

#  Helper Utilities to create agent

In [None]:
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    """Build the executor for one worker agent.

    Args:
        llm: Chat model driving the agent.
        tools: Tools the agent is allowed to call.
        system_prompt: System message placed first in the agent's prompt.

    Returns:
        An ``AgentExecutor`` wrapping an OpenAI-tools agent created from
        ``llm``, ``tools`` and the prompt assembled here.
    """
    message_templates = [
        ("system", system_prompt),
        # Filled from the "messages" key of the input.
        MessagesPlaceholder(variable_name="messages"),
        # Placeholder required by the tools agent (presumably its intermediate steps).
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    worker_prompt = ChatPromptTemplate.from_messages(message_templates)
    worker = create_openai_tools_agent(llm, tools, worker_prompt)
    return AgentExecutor(agent=worker, tools=tools)

# Create Agent Supervisor


In [None]:
# Names of the worker agents the supervisor can route to; the same names are
# used for the worker nodes of the workflow graph.
members = ["Resolver", "Sparql_query_runner"]
# System prompt of the supervisor, built from adjacent string literals.
# The {members} placeholder is filled in by create_team_supervisor through
# ChatPromptTemplate.partial, so no other curly braces may appear in this text.
system_prompt = ( "You are a supervisor. As the supervisor, your primary role is to coordinate the flow of information between agents and ensure the appropriate processing of the user question based on its content. You have access to a team of specialized agents: {members}"
" Here is a list of steps to help you accomplish your role:"
" - Analyse the user question and delegate functions to the specialized agents below if needed:"
" If the question mentions any of the following entities: natural product compound, chemical name, taxon name, target, SMILES structure, or numerical value delegate the question to the Resolver agent. Resolver would provide resolved entities needed to generate SPARQL query. For example if the question mentions either caffeine, or Desmodium heterophyllum call Resolver agent."
" If you have answers from the agent mentioned above, you provide those answers with the user question to the Sparql_query_runner."
" If the question does not mention chemical name, taxon name, target name, nor SMILES structure, delegate the question to the agent Sparql_query_runner."
" Once the Sparql_query_runner has completed its task and provided the answer, mark the process as FINISH. Do not call the Sparql_query_runner again."
" For example, the user provides the following question: For features from Melochia umbellata in PI mode with SIRIUS annotations, get the ones for which a feature in NI mode with the same retention time  has the same SIRIUS annotation. Since the question mentions Melochia umbellata you should firstly delegate it to the Resolver which would provide wikidata IRI with TAXON_RESOLVER tool, and then, you should delegate the question together with the output generated by Resolver agent to the Sparql_query_runner agent."
" Avoid calling the same agent if this agent has already been called previously and provided the answer. For example, if you have called Resolver and it provided InChIKey for chemical compound do not call this agent again. "
" - Collect the answer from Sparql_query_runner and provide the final assembled response back to the user."
" Always tell the user the SPARQL query that have been returned by the Sparql_query_runner."
" If the agent does not provide the expected output mark the process as FINISH."
" Remember, your efficiency in routing the questions accurately and collecting responses is crucial for the seamless operation of our system. If you don't know the answer to any of the steps, please say explicitly and help the user by providing a query that you think will be better interpreted.")

def create_team_supervisor(llm: ChatOpenAI, system_prompt, members) -> Runnable:
    """Build the LLM-based router that decides which agent acts next.

    Args:
        llm: Chat model used for routing; it is forced to call the ``route``
            function defined below.
        system_prompt: System message template for the supervisor. It may
            contain a ``{members}`` placeholder, which is filled with the
            comma-separated member names.
        members: Names of the worker agents the supervisor can choose from
            (any iterable of strings).

    Returns:
        A runnable chain (prompt | function-calling LLM | JSON function-output
        parser). It takes a mapping with a ``messages`` entry; its output is
        expected to carry a ``next`` key holding one of the member names or
        ``"FINISH"``.
    """
    # Materialise once so that any iterable (list, tuple, generator) can be
    # reused for both the enum and the prompt.
    member_names = list(members)
    options = ["FINISH", *member_names]
    # OpenAI function schema restricting the model's answer to the allowed options.
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                },
            },
            "required": ["next"],
        },
    }
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                "Given the conversation above, who should act next?"
                " Or should we FINISH? Select one of: {options}",
            ),
        ]
    ).partial(options=str(options), members=", ".join(member_names))
    return (
        prompt
        # function_call="route" forces the model to answer through the route function.
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )

## Prompt for Resolver agent

In [None]:


# System prompt of the Resolver agent.
# {tool_names} is filled in right away with str.format, one tool name per line.
# The text must not contain any other curly braces: str.format and the chat
# prompt templates this string is later passed to both treat them as placeholders.
# Fix: a stray `")` left over after the TAXON_RESOLVER sentence was removed from the prompt text.
system_message_resolver = """You are an entity resolution agent for the Sparql_query_runner.
You have access to the following tools:
{tool_names}
You should analyze the question and provide resolved entities to the supervisor. Here is a list of steps to help you accomplish your role:
If the question ask anything about any entities that could be natural product compound, find the relevant IRI to this chemical class using CHEMICAL_RESOLVER. Input is the chemical class name. For example, if salicin is mentioned in the question, provide its IRI using CHEMICAL_RESOLVER, input is salicin. 

If a taxon is mentionned, find what is its wikidata IRI with TAXON_RESOLVER. Input is the taxon name. For example, if the question mentions acer saccharum, you should provide it's wikidata IRI using TAXON_RESOLVER tool.

If a target is mentionned, find the ChEMBLTarget IRI of the target with TARGET_RESOLVER. Input is the target name.

If a SMILE structure is mentionned, find what is the InChIKey notation of the molecule with SMILE_CONVERTER. Input is the SMILE structure. For example, if there is a string with similar structure to CCC12CCCN3C1C4(CC3) in the question, provide it to SMILE_CONVERTER.
        
Give me units relevant to numerical values in this question. Return nothing if units for value is not provided.
Be sure to say that these are the units of the quantities found in the knowledge graph.
Here is the list of units to find:
    "retention time": "minutes",
    "activity value": null, 
    "feature area": "absolute count or intensity", 
    "relative feature area": "normalized area in percentage", 
    "parent mass": "ppm (parts-per-million) for m/z",
    "mass difference": "delta m/z", 
    "cosine": "score from 0 to 1. 1 = identical spectra. 0 = completely different spectra"


 You are required to submit only the final answer to the supervisor.
        
""".format(tool_names="\n".join(tool_names))




# Stand-alone chat prompt for the resolver: system message, optional chat
# history, the user input and the agent scratchpad.
# NOTE(review): create_agent builds its own prompt from system_message_resolver,
# so this template is only displayed here and is not passed to the Resolver agent.
prompt_resolver = ChatPromptTemplate.from_messages([
    ("system", system_message_resolver),
    MessagesPlaceholder(variable_name="chat_history", optional=True),
    ("human", "{input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])

# Show the assembled prompt for inspection.
prompt_resolver.pretty_print()



## Defining separate agents

In [None]:
# Worker that resolves the entities mentioned in the question.
resolver_agent = create_agent(
    llm=llm,
    tools=tools_resolver,
    system_prompt=system_message_resolver,
)
# Worker that generates and runs the SPARQL query.
sparql_query_agent = create_agent(
    llm=llm,
    tools=tool_sparql,
    system_prompt="You are sparql query runner, you take as input the user request and resolved entities provided by other agents, generate a SPARQL query, run it on the knowledge graph and answer the question using SPARQL_QUERY_RUNNER tool. You are required to submit only the final answer to the supervisor. If you could not get the answer, provide the SPARQL query generated.",
)

# Construct Graph

In [None]:
import operator
from typing import Annotated, Any, Dict, List, Optional, Sequence, TypedDict
import functools

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, END

# The agent state is the input to each node in the graph
class AgentState(TypedDict):
    """State passed to every node of the workflow graph."""
    # Conversation messages. The operator.add annotation tells the graph that
    # messages returned by a node are always added to the ones already in the state.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Routing decision: where to go next. It is read by the conditional edges
    # leaving the supervisor node.
    next: str

# Generic node function shared by all worker agents.
def agent_node(state, agent, name):
    """Run ``agent`` on ``state`` and wrap its answer as a named message.

    Args:
        state: Current graph state, passed unchanged to ``agent.invoke``.
        agent: Object exposing ``invoke``; its result must have an ``"output"`` entry.
        name: Name attached to the produced message (the agent's name).

    Returns:
        A partial state update: ``{"messages": [HumanMessage]}`` where the
        message content is the agent's output.
    """
    agent_response = agent.invoke(state)
    reply = HumanMessage(content=agent_response["output"], name=name)
    return {"messages": [reply]}

# Bind each worker agent and its name to the generic node function.
resolver_node = functools.partial(agent_node, agent=resolver_agent, name="Resolver")
sparql_query_node = functools.partial(
    agent_node, agent=sparql_query_agent, name="Sparql_query_runner"
)
# The supervisor uses the GPT-4 model.
supervisor_agent = create_team_supervisor(llm_gpt4, system_prompt, members)

# Create the workflow and register one node per worker plus the supervisor.
workflow = StateGraph(AgentState)
workflow.add_node("Resolver", resolver_node)
workflow.add_node("Sparql_query_runner", sparql_query_node)
workflow.add_node("supervisor", supervisor_agent)

# Workers ALWAYS report back to the supervisor when done.
for member in members:
    workflow.add_edge(member, "supervisor")

# The supervisor's "next" value selects the following node; "FINISH" ends the run.
workflow.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {
        "Resolver": "Resolver",
        "Sparql_query_runner": "Sparql_query_runner",
        "FINISH": END,
    },
)

workflow.set_entry_point("supervisor")
app = workflow.compile()

In [None]:
# Benchmark questions used to exercise the workflow.
# The *_bis variants rephrase the question just above them.
q1 = "How many features (pos ionization and neg ionization modes) have the same SIRIUS/CSI:FingerID and ISDB annotation by comparing the InCHIKey of the annotations?"
q1_bis = "How many features (pos ionization and neg ionization modes) have the same SIRIUS/CSI:FingerID and ISDB annotation by comparing the InCHIKey2D of the annotations?"
q2 = "Which extracts have features (pos ionization mode) annotated as the class, aspidosperma-type alkaloids, by CANOPUS with a probability score above 0.5, ordered by the decresing count of features as aspidosperma-type alkaloids? Group by extract."
q3 = "Among the structural annotations from the Tabernaemontana coffeoides (Apocynaceae) seeds extract taxon , which ones contain an aspidospermidine substructure, CCC12CCCN3C1C4(CC3)C(CC2)NC5=CC=CC=C45?"
q4 = "Among the SIRIUS structural annotations from the Tabernaemontana coffeoides (Apocynaceae) seeds extract taxon, which ones are reported in the Tabernaemontana genus in Wikidata?"
q5 = "Which compounds have annotations with chembl assay results indicating reported activity against T. cruzi by looking at the cosmic, zodiac and taxo scores?"
q5_bis = "Which compounds have annotations with chembl assay results indicating reported activity against Trypanosoma cruzi by looking at the cosmic, zodiac and taxo scores?"
q6 = "Filter the pos ionization mode features of the Melochia umbellata taxon annotated as [M+H]+ by SIRIUS to keep the ones for which a feature in neg ionization mode is detected with the same retention time (+/- 3 seconds) and a mass corresponding to the [M-H]- adduct (+/- 5ppm)."
q7 = "For features from the Melochia umbellata taxon in pos ionization mode with SIRIUS annotations, get the ones for which a feature in neg ionization mode with the same retention time (+/- 3 seconds) has the same SIRIUS annotation by comparing the InCHIKey 2D. Return the features, retention times, and InChIKey2D"
q8 = "Which features were annotated as 'Tetraketide meroterpenoids' by SIRIUS, and how many such features were found for each species and plant part?"
q9 = "What are all distinct submitted taxons for the extracts in the knowledge graph?"
q10 = "What are the taxons, lab process and label (if one exists) for each sample? Sort by sample and then lab process"
q11 = "Count all the species per family in the collection"

# q3 prefixed with a hint about where taxons are stored.
q12 = "Taxons can be found in enpkg:LabExtract. Find the best URI of the Taxon in the context of this question : \n Among the structural annotations from the Tabernaemontana coffeoides (Apocynaceae) seeds extract taxon , which ones contain an aspidospermidine substructure, CCC12CCCN3C1C4(CC3)C(CC2)NC5=CC=CC=C45?"


In [None]:
# Stream the workflow on q2 and print every intermediate step,
# skipping the final "__end__" chunk.
for step in app.stream(
    {"messages": [HumanMessage(content=q2)]},
    {"recursion_limit": 100},
):
    if "__end__" in step:
        continue
    print(step)
    print("----")

In [None]:

# Run the multi-agent workflow on q1_bis.
# `agent_executor` is never defined in this notebook (NameError on a fresh
# kernel), so the question goes through the compiled graph `app` instead.
result = app.invoke(
    {"messages": [HumanMessage(content=q1_bis)]},
    {"recursion_limit": 100},
)

In [None]:
# Run the multi-agent workflow on q2.
# `agent_executor` is never defined in this notebook (NameError on a fresh
# kernel), so the question goes through the compiled graph `app` instead.
app.invoke(
    {"messages": [HumanMessage(content=q2)]},
    {"recursion_limit": 100},
)

In [None]:
# Run the multi-agent workflow on q3.
# `agent_executor` is never defined in this notebook (NameError on a fresh
# kernel), so the question goes through the compiled graph `app` instead.
app.invoke(
    {"messages": [HumanMessage(content=q3)]},
    {"recursion_limit": 100},
)

In [None]:
# Run the multi-agent workflow on q5.
# `agent_executor` is never defined in this notebook (NameError on a fresh
# kernel), so the question goes through the compiled graph `app` instead.
app.invoke(
    {"messages": [HumanMessage(content=q5)]},
    {"recursion_limit": 100},
)

In [None]:
# Run the multi-agent workflow on q5_bis.
# `agent_executor` is never defined in this notebook (NameError on a fresh
# kernel), so the question goes through the compiled graph `app` instead.
app.invoke(
    {"messages": [HumanMessage(content=q5_bis)]},
    {"recursion_limit": 100},
)

In [None]:
# Run the multi-agent workflow on q6.
# `agent_executor` is never defined in this notebook (NameError on a fresh
# kernel), so the question goes through the compiled graph `app` instead.
app.invoke(
    {"messages": [HumanMessage(content=q6)]},
    {"recursion_limit": 100},
)

In [None]:
# Run the multi-agent workflow on q7.
# `agent_executor` is never defined in this notebook (NameError on a fresh
# kernel), so the question goes through the compiled graph `app` instead.
app.invoke(
    {"messages": [HumanMessage(content=q7)]},
    {"recursion_limit": 100},
)

In [None]:
# Run the multi-agent workflow on q8.
# `agent_executor` is never defined in this notebook (NameError on a fresh
# kernel), so the question goes through the compiled graph `app` instead.
app.invoke(
    {"messages": [HumanMessage(content=q8)]},
    {"recursion_limit": 100},
)

In [None]:
# Run the multi-agent workflow on q9.
# `agent_executor` is never defined in this notebook (NameError on a fresh
# kernel), so the question goes through the compiled graph `app` instead.
app.invoke(
    {"messages": [HumanMessage(content=q9)]},
    {"recursion_limit": 100},
)

In [None]:
# Run the multi-agent workflow on q10.
# `agent_executor` is never defined in this notebook (NameError on a fresh
# kernel), so the question goes through the compiled graph `app` instead.
app.invoke(
    {"messages": [HumanMessage(content=q10)]},
    {"recursion_limit": 100},
)

In [None]:
# Run the multi-agent workflow on q11.
# `agent_executor` is never defined in this notebook (NameError on a fresh
# kernel), so the question goes through the compiled graph `app` instead.
app.invoke(
    {"messages": [HumanMessage(content=q11)]},
    {"recursion_limit": 100},
)

In [None]:
# Run the multi-agent workflow on q12.
# `agent_executor` is never defined in this notebook (NameError on a fresh
# kernel), so the question goes through the compiled graph `app` instead.
app.invoke(
    {"messages": [HumanMessage(content=q12)]},
    {"recursion_limit": 100},
)