In [1]:
from dotenv import load_dotenv
from pprint import pprint

load_dotenv()

True

# Helper utilities

1. Create a worker agent.
2. Create a supervisor for the sub-graph.

In [17]:
from typing import List, Optional
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

from langchain_groq import ChatGroq

from langgraph.graph import END, StateGraph, START
from langchain_core.messages import HumanMessage, trim_messages

In [18]:
# Shared ChatGroq chat model for this session; the model name is hard-coded here.
llm = ChatGroq(model="llama-3.2-90b-vision-preview")

In [19]:
# trimmer = trim_messages(
#     max_tokens=100000,
#     strategy="last",
#     token_counter=llm,
#     include_system=True,
# )

In [20]:
def agent_node(state, agent, name):
    """Run a worker agent on the graph state and relabel its final reply.

    Args:
        state: Graph state handed to ``agent.invoke`` unchanged.
        agent: Object exposing ``invoke(state)`` whose result has a
            ``"messages"`` list.
        name: Worker name attached to the returned message.

    Returns:
        A dict with a single-element ``"messages"`` list: a ``HumanMessage``
        carrying the content of the agent's last message, tagged with ``name``.
    """
    agent_output = agent.invoke(state)
    # Only the last message of the agent run is forwarded.
    final_message = agent_output["messages"][-1]
    relabelled = HumanMessage(content=final_message.content, name=name)
    return {"messages": [relabelled]}


def create_team_supervisor(llm, system_prompt, members):
    """Build the LLM-based router that picks which team member acts next.

    Args:
        llm: Chat model exposing ``bind_functions``; it is bound so that it
            always calls the ``route`` function.
        system_prompt: System message template for the supervisor. The
            ``{team_members}`` variable is pre-filled with the comma-separated
            member names.
        members: Names of the workers the supervisor can route to. Any
            iterable of strings is accepted (list, tuple, ...).

    Returns:
        A runnable pipeline (prompt | function-calling model | JSON function
        parser) -- not a string. Its output is the parsed arguments of the
        ``route`` call, which the schema requires to contain a ``"next"`` key
        set to ``"FINISH"`` or one of the member names.
    """
    # Materialise once so tuples and one-shot iterables work for both the
    # enum below and the joined member string.
    member_names = list(members)
    options = ["FINISH", *member_names]

    # Function schema the model is forced to call; "next" is restricted to
    # the allowed routing targets.
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                },
            },
            "required": ["next"],
        },
    }

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                "Given the conversation above, who should act next?"
                " Or should we FINISH? Select one of: {options}",
            ),
        ]
    ).partial(options=str(options), team_members=", ".join(member_names))

    router_llm = llm.bind_functions(functions=[function_def], function_call="route")

    # Message trimming is intentionally not applied in this version of the chain.
    return prompt | router_llm | JsonOutputFunctionsParser()

# Research team

In [35]:
import functools
from langchain_core.messages import BaseMessage, HumanMessage
from team_tools import research_supervisor_prompt, tavily_search_tool, retriever_tool, arxiv_search_tool, web_scraper_tool, repl_tool
from langgraph.prebuilt import create_react_agent
from langgraph.graph.message import add_messages
from typing import TypedDict, Annotated, List


# class ResearchTeamState(TypedDict):
#     messages: Annotated[List[BaseMessage], add_messages]
#     team_members: List[str]
#     next: str

#     plan_string: str
#     steps: List
#     results: dict
#     current_task: str

# relm_sup = ChatGroq(model="llama-3.2-90b-vision-preview")

# relm_worker = ChatGroq(model="llama-3.1-70b-versatile")

# relm_sup_prompt = """You are a supervisor tasked with managing a conversation between the following workers: 
# - "TavilyInternetSearch": for general online searches when information is not available in other tools.
# - "QdrantVectorDBRetriever": prioritize this tool for questions about content already stored in the vector database, especially for queries that mention specific papers, abstracts, summaries, or details of documents likely indexed in the vector DB.
# - "ArxivSearch": use this to locate recent academic papers, research articles, or papers that may not be indexed in the vector database.
# - "WebScraper": for extracting detailed information from specific web pages.
# - "PythonREPL": for implementing computations, data processing, or summarizing retrieved results.

# **Guidance:**
# 1. **Prioritize QdrantVectorDBRetriever** for questions mentioning:
#    - Known paper titles (e.g., “Attention is All You Need”) or document references.
#    - Specific document sections (e.g., “abstract,” “conclusion,” “summary”) when the content is likely stored in the vector database.
#    - Scientific or technical concepts that may have been pre-indexed.

# 2. Use **ArxivSearch** for general academic paper searches when the document isn’t referenced by name and is unlikely to be in the vector DB.

# 3. Select **TavilyInternetSearch** for broad or general knowledge searches that extend beyond academic content, or when initial attempts to retrieve data from the Qdrant DB or Arxiv fail.

# 4. Choose **WebScraper** only when the query requires extracting data from specific URLs or known online resources not covered by other tools.

# 5. Use **PythonREPL** for summarization, data analysis, or any custom processing of results retrieved from other tools.

# For each user request, evaluate the task against these criteria, select the most suitable worker to handle it, and respond with **FINISH** once the task is complete.

# **Example Scenarios:**

# - **Query**: “What is the abstract of ‘Attention is All You Need’?”
#   - **Action**: First attempt to retrieve this from QdrantVectorDBRetriever as the abstract may already be stored.
  
# - **Query**: “Find the latest papers on reinforcement learning.”
#   - **Action**: Use ArxivSearch as this is a broad academic query with unspecified document references.

# - **Query**: “Summarize the conclusion of ‘Deep Learning’ by LeCun et al.”
#   - **Action**: Prioritize QdrantVectorDBRetriever if this document is indexed in the vector database.

# Begin handling tasks based on these instructions.


# Based on the given user request, decide which worker should act next. Each worker will complete a specific task and respond with their results and status. Once all tasks are completed and if you feel that the answer is sufficient, then respond with 'FINISH'."""



# # Creating agents for each tool
# internet_search_agent = create_react_agent(relm_worker, tools=[tavily_search_tool])
# internet_search_node = functools.partial(agent_node, agent=internet_search_agent, name="TavilyInternetSearch")

# retrieval_agent = create_react_agent(relm_worker, tools=[retriever_tool])
# retrieval_node = functools.partial(agent_node, agent=retrieval_agent, name="QdrantVectorDBRetriever")

# arxiv_agent = create_react_agent(relm_worker, tools=[arxiv_search_tool])
# arxiv_node = functools.partial(agent_node, agent=arxiv_agent, name="ArxivSearch")

# web_scraper_agent = create_react_agent(relm_worker, tools=[web_scraper_tool])
# web_scraper_node = functools.partial(agent_node, agent=web_scraper_agent, name="WebScraper")

# repl_agent = create_react_agent(relm_worker, tools=[repl_tool])
# repl_node = functools.partial(agent_node, agent=repl_agent, name="PythonREPL")




# supervisor_agent = create_team_supervisor(
#     llm,
#     relm_sup_prompt,
#     ["TavilyInternetSearch", "QdrantVectorDBRetriever", "ArxivSearch", "WebScraper", "PythonREPL"],
# )

# from langgraph.graph import StateGraph, END, START

# # Initialize the research graph
# research_graph = StateGraph(ResearchTeamState)

# # Add nodes for each agent in the research graph
# research_graph.add_node("TavilyInternetSearch", internet_search_node)
# research_graph.add_node("QdrantVectorDBRetriever", retrieval_node)
# research_graph.add_node("ArxivSearch", arxiv_node)
# research_graph.add_node("WebScraper", web_scraper_node)
# research_graph.add_node("PythonREPL", repl_node)
# research_graph.add_node("supervisor", supervisor_agent)

# # Define the control flow from each worker node back to the supervisor node
# research_graph.add_edge("TavilyInternetSearch", "supervisor")
# research_graph.add_edge("QdrantVectorDBRetriever", "supervisor")
# research_graph.add_edge("ArxivSearch", "supervisor")
# research_graph.add_edge("WebScraper", "supervisor")
# research_graph.add_edge("PythonREPL", "supervisor")

# # Define the supervisor's conditional edges based on the response
# research_graph.add_conditional_edges(
#     "supervisor",
#     lambda state: state["next"],
#     {
#         "TavilyInternetSearch": "TavilyInternetSearch",
#         "QdrantVectorDBRetriever": "QdrantVectorDBRetriever",
#         "ArxivSearch": "ArxivSearch",
#         "WebScraper": "WebScraper",
#         "PythonREPL": "PythonREPL",
#         "FINISH": END,
#     },
# )

# # Start the graph from the supervisor node
# research_graph.add_edge(START, "supervisor")
# chain = research_graph.compile()

# # Functions to initialize and handle research chain input/output
# def enter_chain(message: str):
#     results = {
#         "messages": [HumanMessage(content=message)],
#     }
#     return results

# # Define the research chain process
# research_chain = enter_chain | chain

# for s in research_chain.stream(
#     "what is MAGVIT?", {'recursion_limit': 100}
# ):
#     if "__end__" not in s:
#         print(s)
#         print("---")

{'supervisor': {'next': 'TavilyInternetSearch'}}
---
{'TavilyInternetSearch': {'messages': [HumanMessage(content='MAGVIT is a video generation model that uses masked token modeling and multi-task learning. It outperforms existing methods in quality, efficiency, and flexibility for various tasks such as frame interpolation, outpainting, and class-conditional generation.', additional_kwargs={}, response_metadata={}, name='TavilyInternetSearch', id='da458fa2-bcc7-4c28-9475-b869b20bb291')]}}
---
{'supervisor': {'next': 'QdrantVectorDBRetriever'}}
---
{'QdrantVectorDBRetriever': {'messages': [HumanMessage(content='<function=qdrant_retriever{"query": "MAGVIT"}</function>', additional_kwargs={}, response_metadata={}, name='QdrantVectorDBRetriever', id='0bf4a9ed-c3f2-4b3d-97a7-b10e9e4fcb8e')]}}
---
{'supervisor': {'next': 'TavilyInternetSearch'}}
---
{'TavilyInternetSearch': {'messages': [HumanMessage(content='<function=tavily_search_results_json{"query": "MAGVIT video generation model explain

# New teams

## Research team

In [1]:
from typing import List, Optional
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, StateGraph, START
from langchain_core.messages import HumanMessage, trim_messages

from langchain_groq import ChatGroq

In [2]:
# Shared ChatGroq chat model for this session: used as the token counter of the
# message trimmer and as the generation model of the RAG chain in later cells.
llm = ChatGroq(model='llama-3.2-90b-vision-preview')

In [3]:
# Message trimmer that create_team_supervisor pipes between the prompt and the
# model. Called without messages, trim_messages is used here as a pipeline step.
# Per the trim_messages documentation: strategy="last" keeps the most recent
# messages within max_tokens (counted with `llm`), and include_system=True
# retains the leading system message — TODO confirm against the installed version.
trimmer = trim_messages(
    max_tokens=120000,
    strategy="last",
    token_counter=llm,
    include_system=True,
)

In [4]:
def agent_node(state, agent, name):
    """Invoke ``agent`` with the graph state and re-emit its last reply.

    The content of the final message produced by the agent is wrapped in a
    new ``HumanMessage`` tagged with ``name`` and returned as the only entry
    of the ``"messages"`` list.
    """
    outcome = agent.invoke(state)
    last_reply = outcome["messages"][-1]
    tagged_reply = HumanMessage(content=last_reply.content, name=name)
    return {"messages": [tagged_reply]}

def create_team_supervisor(llm: "ChatGroq", system_prompt, members):
    """Build the LLM-based router that picks which team member acts next.

    Note: this function reads the module-level ``trimmer`` defined in an
    earlier cell; that cell must have been executed before the returned
    chain is built.

    Args:
        llm: Chat model exposing ``bind_functions``; it is bound so that it
            always calls the ``route`` function.
        system_prompt: System message template for the supervisor. The
            ``{team_members}`` variable is pre-filled with the comma-separated
            member names.
        members: Names of the workers the supervisor can route to. Any
            iterable of strings is accepted (list, tuple, ...).

    Returns:
        A runnable pipeline (prompt | trimmer | function-calling model |
        JSON function parser) -- not a string. Its output is the parsed
        arguments of the ``route`` call, which the schema requires to contain
        a ``"next"`` key set to ``"FINISH"`` or one of the member names.
    """
    # Materialise once so tuples and one-shot iterables work for both the
    # enum below and the joined member string.
    member_names = list(members)
    options = ["FINISH", *member_names]

    # Function schema the model is forced to call; "next" is restricted to
    # the allowed routing targets.
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                },
            },
            "required": ["next"],
        },
    }

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                "Given the conversation above, who should act next?"
                " Or should we FINISH? Select one of: {options}",
            ),
        ]
    ).partial(options=str(options), team_members=", ".join(member_names))

    router_llm = llm.bind_functions(functions=[function_def], function_call="route")

    # The rendered prompt is trimmed before it reaches the model.
    return prompt | trimmer | router_llm | JsonOutputFunctionsParser()

In [5]:
import functools
import operator

from langchain_core.messages import BaseMessage, HumanMessage
from langchain_openai.chat_models import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from typing import TypedDict, Annotated, Literal, Union, List

# ResearchTeam graph state
class ResearchTeamState(TypedDict):
    """State schema for the research-team graph."""
    # Conversation history. The operator.add annotation is presumably used by
    # LangGraph as the reducer, so message lists returned by nodes are
    # concatenated onto the existing list — TODO confirm.
    messages: Annotated[List[BaseMessage], operator.add]
    # Names of the team's workers; not read by any code visible in this notebook.
    team_members: List[str]
    # Routing decision: presumably filled from the "next" field of the
    # supervisor's route call (a worker name or "FINISH") — verify against the graph.
    next: str

In [20]:
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from team_tools import Qretriever

# Retriever imported from the project-local team_tools module; it supplies the
# documents placed under "context" by the RAG chain built below.
retriever = Qretriever


# Prompt
# Pulled by name from the LangChain Hub when this cell runs.
# NOTE(review): presumably a template with `context` and `question` variables,
# matching the keys the chain below provides — confirm.
prompt = hub.pull("rlm/rag-prompt")
# Post-processing
def format_docs(docs):
    """Concatenate the ``page_content`` of each document, separated by a blank line."""
    separator = "\n\n"
    bodies = [document.page_content for document in docs]
    return separator.join(bodies)

from langchain_core.runnables import RunnableParallel
from langchain_core.runnables import RunnablePassthrough

# Answer-generation half of the RAG pipeline. It receives a dict whose
# "context" entry holds the retrieved documents, swaps that entry for their
# concatenated text (format_docs) for the steps that follow, fills the prompt,
# calls the model, and parses the model output to a plain string.
rag_chain_from_docs = (
    RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
    | prompt
    | llm
    | StrOutputParser()
)

# Full pipeline. The input question is sent to the retriever ("context") and
# passed through unchanged ("question"); the generated text is then added under
# "answer". The result keeps the retrieved documents, so their metadata can be
# used to report sources.
rag_chain_with_source = RunnableParallel(
    {"context": retriever, "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)

# Smoke test: the cell output is the full result dict.
rag_chain_with_source.invoke("What is attention")

{'context': [Document(metadata={'pdf_id': 'a.pdf', 'score': 0.61511683}, page_content='The decoder is also composed of a stack of N = 6 identical layers. In addition to the two\nsub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head\nattention over the output of the encoder stack. Similar to the encoder, we employ residual connections\naround each of the sub-layers, followed by layer normalization. We also modify the self-attention\nsub-layer in the decoder stack to prevent positions from attending to subsequent positions. This\nmasking, combined with fact that the output embeddings are offset by one position, ensures that the\npredictions for position i can depend only on the known outputs at positions less than i.\n3.2\nAttention\nAn attention function can be described as mapping a query and a set of key-value pairs to an output,\nwhere the query, keys, values, and output are all vectors. The output is computed as a weighted sum\n3\nScaled 

In [21]:
# Run the full RAG pipeline; the result is a dict with "context", "question"
# and "answer" entries.
# NOTE(review): the query text contains the typo "realted"; it is sent to the
# retriever and the model as-is.
response = rag_chain_with_source.invoke("What is attention, and how is it realted to SAN pdf")

# Access the answer and sources
answer = response['answer']
# pdf_id of every retrieved document, in the order they appear in response['context'].
sources = [doc.metadata['pdf_id'] for doc in response['context']]

In [22]:
# Show the generated answer next to the pdf_id of the first retrieved document.
answer, sources[0]

('Attention is a mapping of a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key. In the context of SAN (Scaled Dot-Product Attention Network), attention is used in the decoder to perform multi-head attention over the output of the encoder stack.',
 'SAN.pdf')