# Chef Agent

In [30]:
from langchain_openai import OpenAIEmbeddings
from dotenv import load_dotenv
from langsmith import utils, traceable
import networkx as nx
import matplotlib.pyplot as plt

# Load environment variables from the .env file one directory above the
# notebook's working directory.
load_dotenv(dotenv_path='../.env')

# Ask LangSmith whether tracing is enabled; as the last expression of the cell
# its value is shown as the cell output.
utils.tracing_is_enabled()

True

## State

In [31]:
"""State management.

This module defines the state structures used by the conversational agent.

sources: https://github.com/langchain-ai/chat-langchain/blob/master/backend/retrieval_graph/state.py
"""

from dataclasses import dataclass, field
from typing import Annotated, Literal, Optional

from langchain_core.documents import Document
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from typing_extensions import TypedDict

@dataclass(kw_only=True)
class InputState:
    """Represent the structure of the input state.

    This class defines the inputs that can be passed in externally, either by
    the user or by upstream services. All fields are keyword-only.

    Attributes:
        messages: Messages from the current conversation. The annotation
            attaches the `add_messages` function as metadata for this field.
        ingredients: Optional list of ingredient tags to query on. Defaults
            to a new empty list per instance.
        recipe: Optional recipe document to scope the query to. Defaults to
            None.
    """

    messages: Annotated[list[AnyMessage], add_messages]
    ingredients: list[str] = field(default_factory=list)
    recipe: Optional[Document] = None

# Final response payload: the generated answer plus the documents it cites.
#
# This is a plain TypedDict on purpose. It must not also be decorated with
# @dataclass: the dataclass-generated __init__ would take over construction and
# store the values as instance attributes instead of dict items, so
# `response["answer"]` would no longer work.
class OutputResponse(TypedDict):
    # Answer text returned to the user.
    answer: str
    # Documents the answer is attributed to.
    sources: list[Document]

class QueryRouter(TypedDict):
    """Classify user query"""

    # Route chosen for the query; restricted to the literal values listed here.
    type: Literal["search", "ingredient_check", "ask_user_info", "review_and_reflect"]
    # Search scope, either "recipe" or "all"; may be None.
    # NOTE(review): presumably only meaningful when `type` is "search" — confirm.
    search_scope_type: Optional[Literal["recipe", "all"]]

class QueryRouterOverride(QueryRouter):
    """Additional Route Parameters"""

    # NOTE(review): the original comment here was cut off mid-sentence
    # ("Optional, for nodes that call ..."). As declared, `route_override` is a
    # required key for type checkers; if it is meant to be optional, declare it
    # as `NotRequired[str]` instead.
    route_override: str

@dataclass(kw_only=True)
class ChefState(InputState):
    """State of the chef agent.

    Extends `InputState` with the fields the agent fills in while it runs.
    The class has to be a dataclass itself: without the decorator the
    `field(...)` defaults below are never processed, so the attributes would
    hold raw `Field` objects and the inherited `__init__` would reject them.
    """

    # NOTE(review): "prepare_search_query" is not one of the values declared in
    # QueryRouter's `type` Literal — either add it there or pick a declared
    # route as the default.
    query_router: QueryRouter = field(default_factory=lambda: QueryRouter(type="prepare_search_query", search_scope_type="all"))
    """The router's classification for the query."""

    documents: dict[str, Document] = field(default_factory=dict)
    """Populated documents from retrieval nodes, keyed by document id."""

    answer: Optional[OutputResponse] = None
    """The final response, or None while no answer has been produced."""




## Prompts

In [None]:
# Prompt that instructs the model to answer a question using only the supplied
# context. Template placeholders: {context} and {question}.
SEARCHER_PROMPT = """You are a helpful assistant that helps to gather information about recipes and cooking related topics.

<instructions>
<instruction> Answer the question based only on the context provided. </instruction>
</instructions>

<context> 
{context} 
</context>

<question> {question} </question>
"""

# Prompt that asks the model for one short sentence explaining why a single
# source is relevant to the question. Template placeholders: {question} and
# {context} (the source). The square-bracket tokens inside the explanation
# template are literal text for the model to fill in, not template variables.
# NOTE(review): the spelling "explaination"/"relevent" is part of the runtime
# prompt text and of identifiers used elsewhere, so it is left untouched here.
SOURCE_EXPLAINATION_PROMPT = """You are a helpful assistant that helps to gather information about recipes and cooking related topics.

<instructions>
<instruction> For the provided source, provide an short sentence explaination on why the source is relevent to the question. </instruction>
<instruction> Use the provided explaination template to generate the sentences</instruction>
</instructions>

<explaination_template>
    Found relevent source [source_id] : [explaination]
</explaination_template>

<context> 
    <question> {question} </question>
    <source> {context} </source>
</context>
"""

# System prompt that asks the model to rewrite the user's most recent request
# as a single, specific search query and to return only that query. It has no
# template placeholders; it is sent as the system message ahead of the
# conversation messages.
GENERATE_SEARCH_QUERY_PROMPT = """You are a helpful assistant that helps to descern what the user is looking to search for.

<instructions>
<instruction> Generate a search query based on the conversation content provided. </instruction>
<instruction> Not all messages may be relevent to the current question the user is trying to find. Try to descern the most recent query request the user made </instruction>
<instruction> re-write the users query to be more specific and concise </instruction>
<instruction> return the exact re-written query as your response and nothing else </instruction>
</instructions>
"""

## Search Node

In [33]:
from langchain_chroma.vectorstores import Chroma
from langchain_graph_retriever.transformers import ShreddingTransformer
from langchain_graph_retriever.adapters.chroma import ChromaAdapter
from graph_retriever.strategies import Eager
from langchain_graph_retriever import GraphRetriever

### Load GraphRAG Retriever

In [34]:
def load_graph_traversal_retriever(
    embedding_model="text-embedding-3-large",
    collection_name="recipe_qa_combined",
    persist_directory="./data/recipe_qa_combined_chroma_db",
    *,
    k=5,
    start_k=5,
    max_depth=3,
):
    """Build a graph-traversal retriever over a persisted Chroma collection.

    Args:
        embedding_model: Name of the OpenAI embedding model used by the store.
        collection_name: Name of the Chroma collection to open.
        persist_directory: Directory the Chroma collection is persisted in.
        k: Forwarded to the `Eager` strategy as `k`.
        start_k: Forwarded to the `Eager` strategy as `start_k`.
        max_depth: Forwarded to the `Eager` strategy as `max_depth`.

    Returns:
        A `GraphRetriever` that traverses "keywords" -> "keywords" and
        "source_id" -> "source_id" edges using the `Eager` strategy.
    """
    embeddings = OpenAIEmbeddings(model=embedding_model)
    chroma = Chroma(
        embedding_function=embeddings,
        collection_name=collection_name,
        persist_directory=persist_directory,
    )
    # Arguments stay positional, exactly as in the original call: the store,
    # the shredding transformer, and the set of metadata fields ({"keywords"}).
    vector_store = ChromaAdapter(chroma, ShreddingTransformer(), {"keywords"})
    return GraphRetriever(
        store=vector_store,
        edges=[("keywords", "keywords"), ("source_id", "source_id")],
        strategy=Eager(k=k, start_k=start_k, max_depth=max_depth),
    )

In [None]:
from typing import Any, Literal, TypedDict, cast

from langchain_core.messages import BaseMessage
from langgraph.graph import END, START, StateGraph

from agent.core.state import QueryRouter, ChefState
from agent.core.prompts import ROUTER_SYSTEM_PROMPT
from langchain.chat_models import init_chat_model

from langchain.chat_models.base import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig, RunnablePassthrough, RunnableLambda, RunnableParallel
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from agent.utils import format_docs



In [None]:
# Agent class

class ChefAgent:
    """The chef agent class.

    Holds the chat model, the graph retriever and the state-graph builder, and
    provides the `search` node callable.
    """

    def __init__(self, llm: BaseChatModel, graph_retriever: GraphRetriever, tools=None):
        """Store the collaborators and create the state-graph builder.

        Args:
            llm: Chat model used for query rewriting and source explanations.
            graph_retriever: Retriever that returns documents for a query.
            tools: Currently unused. Defaults to None rather than a shared
                mutable list.
        """
        self.llm = llm
        self.graph_retriever = graph_retriever
        self.graph_builder = StateGraph(ChefState)

    @traceable(run_type="llm")
    async def _source_explaination(self, question, docs: list[Document], config: Optional[RunnableConfig] = None):
        """Explain, per document, why it is relevant to the question.

        Args:
            question: The question the documents were retrieved for.
            docs: Retrieved documents to explain.
            config: Optional runnable config forwarded to the batch call.

        Returns:
            One dict per input document with the keys "question", "docs"
            (the document itself) and "explainations" (the model's text).
        """
        se_prompt = ChatPromptTemplate.from_template(SOURCE_EXPLAINATION_PROMPT)
        explaination_chain = se_prompt | self.llm | StrOutputParser()

        se_chain = RunnableParallel(
            question=RunnableLambda(lambda x: x["question"]),
            docs=RunnableLambda(lambda x: x["source"]),
            explainations=explaination_chain,
        )

        # NOTE(review): `format_docs` here is the module-level helper imported
        # from agent.utils, called with a single document. Its result is
        # unpacked into the prompt input, so it must be a mapping that supplies
        # the prompt's `context` variable — confirm against agent.utils.
        formatted_docs = [{"question": question, "source": doc, **format_docs(doc)} for doc in docs]
        # Use the async batch API so the event loop is not blocked while the
        # model is called once per document.
        return await se_chain.abatch(formatted_docs, config=config)

    async def search(self, state: ChefState, config: RunnableConfig) -> dict[str, Any]:
        """Search for recipe and cooking related documents.

        Rewrites the conversation into a search query, retrieves documents for
        it, asks the model to explain each source, and records every document
        that is not yet present in `state.documents`.

        Args:
            state: Current agent state; `messages` and `documents` are read.
            config: Runnable config forwarded to the model and chain calls.

        Returns:
            A state update of the form `{"documents": ...}` holding the
            documents keyed by document id.
        """

        @traceable(run_type="chain")
        def format_docs(docs: list[Document], config=None):
            # Not wired into the chain yet; kept for the answer step that will
            # use `search_prompt` below.
            # TODO: format per document type (e.g. metadata "type" == "recipe").
            return {"context": "\n\n".join(
                f"text: {doc.page_content} metadata: {doc.metadata}" for doc in docs
            )}

        ### Generate Search Query ###
        # The chat model returns a message object, so parse it to a plain
        # string instead of subscripting it.
        query_chain = self.llm | StrOutputParser()
        generated_query = await query_chain.ainvoke(
            [{"role": "system", "content": GENERATE_SEARCH_QUERY_PROMPT}, *state.messages],
            config,
        )

        ### Search On Graph ###
        # Built with `from_template`: the plain constructor expects a sequence
        # of messages, not a template string. Not used yet (answer step is
        # still to be written).
        search_prompt = ChatPromptTemplate.from_template(SEARCHER_PROMPT)

        ss_seperator = f"\n{'#'*20}\n"

        async def explain_sources(inputs: dict[str, Any]) -> list[dict[str, Any]]:
            # Async wrapper so the coroutine is actually awaited by the chain.
            return await self._source_explaination(inputs["question"], inputs["sources"], config)

        def summarize(inputs: dict[str, Any]) -> str:
            # Each item carries the model's text under the "explainations" key.
            lines = "\n".join(f"> {e['explainations']}" for e in inputs["explainations"])
            return f"{ss_seperator}Search Results:\n" + lines + ss_seperator

        search_chain = (
            RunnableParallel(sources=self.graph_retriever, question=RunnablePassthrough())
            | RunnablePassthrough.assign(explainations=RunnableLambda(explain_sources))
            | RunnablePassthrough.assign(search_summary=RunnableLambda(summarize))
        )

        search_response = await search_chain.ainvoke(generated_query, config)

        # Load sources into state, skipping documents that are already known.
        sources_added = 0
        for doc in search_response["sources"]:
            if doc.id not in state.documents:
                sources_added += 1
                state.documents[doc.id] = doc

        # TODO: if no new sources were added, return a message to the user

        # Return the update explicitly so the declared return type holds.
        return {"documents": state.documents}

        


        

In [28]:
from langchain.chat_models import init_chat_model
from langchain_core.output_parsers import StrOutputParser

llm = init_chat_model("gpt-4o-mini", model_provider="openai")

test_messages = [
    {"role": "system", "content": GENERATE_SEARCH_QUERY_PROMPT},
    {"role": "user", "content": "What football games are on today?"},
    {"role": "assistant", "content": "The Cowboys play the Giants today at 7:pm."},
    {
        "role": "user", "content": "I am watching a football game and I don't know who what city the Cowboys play for.",
    },
]
# generate_search_query = await llm.ainvoke([
#             {"role": "system", "content": GENERATE_SEARCH_QUERY_PROMPT},
#             {"role": "user", "content": "What football games are on today?"},
#             {"role": "assistant", "content": "The Cowboys play the Giants today at 7:pm."},
#             {"role": "user", "content": "I am watching a football game and I don't know who what city the Cowboys play for."}
#         ])

gsq_prompt = ChatPromptTemplate.from_template(GENERATE_SEARCH_QUERY_PROMPT)

gsq_chain = llm | StrOutputParser()

generate_search_query = await gsq_chain.ainvoke(test_messages)

In [29]:
generate_search_query

'What city do the Dallas Cowboys play in?'