#### LangGraph Retrieval Agent

Retrieval agents are useful when we want to make decisions about whether to retrieve from an index.
To implement a retrieval agent, we simply need to give an LLM access to a retriever tool.
We will be incorporating this into LangGraph.

In [1]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores.chroma import Chroma
from langchain.indexes.vectorstore import VectorStoreIndexWrapper
from langchain_openai.embeddings import OpenAIEmbeddings
f

In [2]:
from dotenv import load_dotenv
load_dotenv()

True

In [11]:
# Pages to load into the vector store. Only the first URL is active; the
# other two are left commented out.
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    # "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    # "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

In [12]:
from langchain_core.embeddings import Embeddings
from langchain.pydantic_v1 import BaseModel
from typing import Any

In [13]:
from openai import OpenAI

client = OpenAI()

def get_embeddings(texts, model="togethercomputer/m2-bert-80M-32k-retrieval"):
    """Embed a batch of strings using the module-level ``client``.

    Args:
        texts: Iterable of strings to embed. Newlines in each string are
            replaced with spaces before the request is sent.
        model: Name of the embedding model passed to the embeddings endpoint.

    Returns:
        A list of embeddings taken from the response's ``data`` entries, in
        response order. Empty input returns an empty list without issuing a
        request.
    """
    cleaned = [text.replace("\n", " ") for text in texts]
    if not cleaned:
        # Nothing to embed; skip the network round-trip entirely.
        return []
    outputs = client.embeddings.create(input=cleaned, model=model)
    return [item.embedding for item in outputs.data]

texts=["hello"]
len(get_embeddings(texts))

1

In [14]:
from typing import Coroutine, List


class NewOpenAIEmbeddings(BaseModel, Embeddings):
    """Embeddings adapter that delegates to the ``get_embeddings`` helper.

    The synchronous methods call ``get_embeddings`` directly. The ``a*``
    methods are real coroutines so that callers can ``await`` them.
    """

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return the embeddings for a list of document texts."""
        return get_embeddings(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding for a single query string."""
        return get_embeddings([text])[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Awaitable variant of ``embed_documents``.

        NOTE: the underlying call is synchronous, so it runs inline inside
        the coroutine rather than being offloaded to another thread.
        """
        return self.embed_documents(texts)

    async def aembed_query(self, text: str) -> List[float]:
        """Awaitable variant of ``embed_query`` (runs the sync call inline)."""
        return self.embed_query(text)

In [17]:
# Load every page in `urls`; each load yields a list of documents, which is
# then flattened into one list.
docs = [WebBaseLoader(page_url).load() for page_url in urls]
docs_list = [document for per_page in docs for document in per_page]

# Split with the tiktoken-based splitter and keep only the first 50 chunks.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=100,
    chunk_overlap=50,
)
doc_splits = text_splitter.split_documents(docs_list)[:50]

# Index the chunks in a Chroma collection using the custom embeddings class,
# then expose the store as a retriever.
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=NewOpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()

Then we create a retriever tool.

In [18]:
from langchain.tools.retriever import create_retriever_tool

# Wrap the retriever as a tool. The two string arguments are the tool's name
# and the description of what it searches.
tool = create_retriever_tool(
    retriever, 
    "retrieve_blog_posts",
    "Search and return information about Lilian Weng blog posts on LLM agents"
)

# Single-element tool list used by the tool executor and the agent node.
tools = [tool]

from langgraph.prebuilt import ToolExecutor

tool_executor = ToolExecutor(tools)

#### Agent State
We will define the graph.

A state object is passed around to each node.
Our state will be a list of `messages`.
Each node in our graph will append to it.

In [19]:
import operator
from typing import Annotated, Sequence, TypedDict
from langchain_core.messages import BaseMessage

class AgentState(TypedDict):
    """Graph state passed between nodes: the running sequence of messages."""

    # `operator.add` is attached as annotation metadata; the notebook's notes
    # describe each node as appending its messages to this sequence.
    messages: Annotated[Sequence[BaseMessage], operator.add]


##### Nodes and Edges
We can lay out an agentic RAG graph like this:
- The state is a set of messages
- Each node will update (append to) the state
- Conditional edges decide which node to visit next

In [21]:
import json
from langchain import hub
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from langchain.tools.render import format_tool_to_openai_function
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.messages import BaseMessage, FunctionMessage, HumanMessage
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolInvocation
from langchain_core.output_parsers import StrOutputParser

In [39]:
def should_retrieve(state):
    """
    Route the graph based on the most recent message.

    Inspects the last entry of ``state["messages"]``. When its
    ``additional_kwargs`` contains a ``"tool_calls"`` key the graph should go
    on to retrieval and "continue" is returned; otherwise "end" is returned.
    """
    latest = state["messages"][-1]
    wants_tool = "tool_calls" in latest.additional_kwargs

    if wants_tool:
        print("---DECISION: RETRIEVE ---")
        return "continue"

    print("---DECISION: DO NOT RETRIEVE / DONE ---")
    return "end"
        
        
def grade_documents(state: AgentState):
    """
    Determines whether the retrieved documents are relevant to the question.

    The first message in the state is used as the user question and the last
    message as the retrieved documents. A chat model that is forced to call
    the ``grade`` tool produces a binary relevance score.

    Args:
        state: Current graph state; ``state["messages"]`` must be non-empty.

    Returns:
        "yes" when the model's score is exactly "yes", otherwise "no"
        (including when the parser returns no tool call at all).
    """
    print("---CHECK RELEVANCE---")

    # The class name doubles as the tool name referenced in `tool_choice`.
    class grade(BaseModel):
        """ Binary score for relevance check."""
        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    model = ChatOpenAI(temperature=0, model="mistralai/Mixtral-8x7B-Instruct-v0.1")

    # Convert the schema once; the result is already in tool format and must
    # not be passed through the converter a second time.
    grade_tool_oai = convert_to_openai_tool(grade)

    llm_with_tool = model.bind(
        tools=[grade_tool_oai],
        tool_choice={"type": "function", "function": {"name": "grade"}},
    )

    parser_tool = PydanticToolsParser(tools=[grade])

    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n
                    Here is the retrieved document: \n\n {context} \n\n
                    Here is the user question: {question} \n
                    If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
                    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. \n""",
        input_variables=["context", "question"],
    )

    chain = prompt | llm_with_tool | parser_tool

    messages = state["messages"]
    question = messages[0].content
    docs = messages[-1].content

    score = chain.invoke({"question": question, "context": docs})

    # `score` is indexed as a list of parsed tool calls; guard against an
    # empty result so a missing grade routes to "no" instead of raising.
    llm_grade = score[0].binary_score if score else None
    if llm_grade == "yes":
        print("---DECISION: DOCS RELEVANT ---")
        return "yes"

    print("---DECISION: DOCS NOT RELVANT ---")
    print(llm_grade)
    return "no"
    

def agent(state: AgentState):
    """
    Invoke the agent model to generate a response based on the current state.

    Given the conversation so far, the model either requests a call to the
    retriever tool or answers directly.

    Args:
        state: Current graph state; its ``"messages"`` are sent to the model.

    Returns:
        A state update of the form ``{"messages": [response]}`` holding the
        model's reply.
    """
    print("---CALL AGENT---")
    messages = state["messages"]
    model = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
    # Bind through the tool-calling interface so the reply carries
    # `tool_calls`, the key that `should_retrieve` and `retrieve` read.
    model = model.bind_tools(tools)
    response = model.invoke(messages)

    return {"messages": [response]}

def retrieve(state: AgentState):
    """
    Run the tool requested by the most recent message.

    Reads the first entry of the last message's ``tool_calls``, decodes its
    JSON arguments, executes the named tool through ``tool_executor`` and
    returns the stringified result as a ``FunctionMessage`` state update.
    """
    print("--EXECUTE RETRIEVAL")
    requested = state["messages"][-1].additional_kwargs["tool_calls"][0]["function"]
    tool_name = requested["name"]
    tool_args = json.loads(requested["arguments"])

    invocation = ToolInvocation(tool=tool_name, tool_input=tool_args)
    result = tool_executor.invoke(invocation)

    return {"messages": [FunctionMessage(content=str(result), name=invocation.tool)]}


def rewrite(state: AgentState):
    """
    Transform the query to produce a better question.

    Args:
        state: Current graph state; the first message is taken as the
            original user question.

    Returns:
        A state update of the form ``{"messages": [response]}`` holding the
        model's reformulated question.
    """
    print("---TRANSFORM QUERY---")
    messages = state["messages"]

    question = messages[0].content

    msg = HumanMessage(
        content=f"""\n Look at the input and try to reason about the underlying semantic intent/meaning. 
        \n Here is the initial question: \n ---- \n {question} \n ----- \n Formulate an imporoved question:""")

    # No trailing comma here: it would turn `model` into a one-element tuple.
    model = ChatOpenAI(temperature=0, model="mistralai/Mixtral-8x7B-Instruct-v0.1")
    # Chat models take a list of messages, so wrap the single prompt message.
    response = model.invoke([msg])
    return {"messages": [response]}


def generate(state: AgentState):
    """
    Generate answer

    Uses the first message as the question and the content of the last
    message as the context, then runs the ``rlm/rag-prompt`` hub prompt
    through the chat model.

    Args:
        state: Current graph state; ``state["messages"]`` must be non-empty.

    Returns:
        A state update of the form ``{"messages": [response]}`` where
        ``response`` is the plain string produced by ``StrOutputParser``.
    """
    print("---GENERATE---")
    messages = state["messages"]
    question = messages[0].content
    docs = messages[-1].content

    prompt = hub.pull("rlm/rag-prompt")
    llm = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")

    rag_chain = prompt | llm | StrOutputParser()

    response = rag_chain.invoke({"context": docs, "question": question})
    return {"messages": [response]}

#### Graph
- Start with the agent node, `agent`
- The agent makes a decision about whether to call a function
- If so, then `retrieve` calls the tool (retriever)
- Then call the agent with the tool output added to messages (`state`)


In [40]:
from langgraph.graph import END, StateGraph

# Build a graph whose state schema is AgentState.
workflow = StateGraph(AgentState)

# Register each node function under the name used by the edges.
workflow.add_node("agent", agent)
workflow.add_node("retrieve", retrieve)
workflow.add_node("rewrite", rewrite)
workflow.add_node("generate", generate)



In [41]:
# Every run starts at the agent node.
workflow.set_entry_point("agent")

# After the agent: `should_retrieve` returns "continue" or "end", which map
# to the retrieve node or to the end of the graph.
workflow.add_conditional_edges("agent", should_retrieve, {
    "continue": "retrieve",
    "end": END
})

# After retrieval: `grade_documents` returns "yes" or "no", which map to
# generating an answer or rewriting the question.
workflow.add_conditional_edges(
    "retrieve",
    grade_documents,
    {
        "yes": "generate",
        "no": "rewrite",
    }
)


# Generation finishes the run; a rewritten question goes back to the agent.
workflow.add_edge("generate", END)
workflow.add_edge("rewrite", "agent")

app = workflow.compile()

In [43]:
import pprint
from langchain_core.messages import HumanMessage

# Initial state: a single human message containing the question.
inputs ={
    "messages": [
        HumanMessage(
            content="What does Lilian Weng say about the types of agent memory?"
            # content="Hello there"
        )
    ]
}

# Stream the run and pretty-print each emitted update; each key is reported
# as the name of the node that produced the value.
for output in app.stream(inputs):
    for key, value in output.items():
        pprint.pprint(f"output from node {key}")
        pprint.pprint("----")
        pprint.pprint(value, indent=2, width=80, depth=None)
    pprint.pprint("\n----\n")

---CALL AGENT---
'output from node agent'
'----'
{ 'messages': [ AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_gcdgmqtfcmy7lgvi6wqjkgya', 'function': {'arguments': '{"query":"types of agent memory in Lilian Weng\'s blog"}', 'name': 'retrieve_blog_posts'}, 'type': 'function'}]})]}
'\n----\n'
---DECISION: RETRIEVE ---
--EXECUTE RETRIEVAL
'output from node retrieve'
'----'
{ 'messages': [ FunctionMessage(content='[Document(page_content=\'Building agents with LLM (large language model) as its core controller is a cool concept. Several proof-of-concepts demos, such as AutoGPT, GPT-Engineer and BabyAGI, serve as inspiring examples. The potentiality of LLM extends beyond generating well-written copies, stories, essays and programs; it can be framed as a powerful general problem solver.\\nAgent System Overview#\', metadata={\'description\': \'Building agents with LLM (large language model) as its core controller is a cool concept. Several proof-of-concepts demos, such as