In [1]:
import json
import os
from dotenv import load_dotenv

# Load variables from a local .env file into os.environ. Presumably this supplies
# the OpenAI / Tavily / LangSmith API keys used by later cells — confirm .env contents.
load_dotenv()

True

In [2]:
# Turn on LangSmith tracing and file every run of this notebook under one project.
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_PROJECT": "lg-reflexion-agents",
    }
)

## Actor
- Tools/tool execution
- Initial responder: generate an initial response (and self-reflection)
- Revisor: re-respond (and reflect) based on previous reflections

#### 1. Construct tools

In [3]:
from collections import defaultdict
from typing import List


from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation

from langchain.output_parsers.openai_tools import (
    JsonOutputToolsParser,
    PydanticToolsParser
)

In [4]:
# Tool wiring for the actor.
# Parser for OpenAI-style tool calls; return_id=True keeps each call's id so the
# resulting ToolMessages can be linked back to the call that requested them.
parser = JsonOutputToolsParser(return_id=True)

# Tavily web search, capped at 5 results per query.
search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(max_results=5, api_wrapper=search)

# Runs ToolInvocation objects against the registered tool(s) and returns their outputs.
tool_executor = ToolExecutor([tavily_tool])

# Graph node: run the searches requested by the actor's latest tool call.
def execute_tools(state: List[BaseMessage]) -> List[BaseMessage]:
    """Execute the search queries from the last AI message and return the results.

    ``state[-1]`` is expected to be the actor's AIMessage whose tool call(s)
    carry a ``search_queries`` list in their arguments. Every query is sent to
    the Tavily tool, and the results are grouped by the id of the tool call
    that requested them.

    Args:
        state: The message history; only ``state[-1]`` is read.

    Returns:
        One ToolMessage per parsed tool call. Its content is a JSON object that
        maps each query string to that query's search results (``{}`` when the
        call requested no queries).
    """
    tool_invocation: AIMessage = state[-1]
    parsed_tool_calls = parser.invoke(tool_invocation)

    # Seed an entry for every tool call so each one receives a ToolMessage
    # reply, even if it asked for no searches.
    outputs_map = {parsed_call["id"]: {} for parsed_call in parsed_tool_calls}

    ids = []
    tool_invocations = []
    for parsed_call in parsed_tool_calls:
        for query in parsed_call["args"].get("search_queries", []):
            tool_invocations.append(
                ToolInvocation(
                    # Use the tool's registered name rather than a hand-typed
                    # string, so the executor lookup cannot drift out of sync.
                    tool=tavily_tool.name,
                    tool_input=query,
                )
            )
            ids.append(parsed_call["id"])

    # Batch the list of invocations (not the AIMessage). Runnable.batch returns
    # results in input order, which the zip below relies on.
    outputs = tool_executor.batch(tool_invocations) if tool_invocations else []
    for id_, output, invocation in zip(ids, outputs, tool_invocations):
        outputs_map[id_][invocation.tool_input] = output

    return [
        ToolMessage(content=json.dumps(query_outputs), tool_call_id=id_)
        for id_, query_outputs in outputs_map.items()
    ]

#### 2. Initial Responder

In [5]:
import datetime

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_openai import ChatOpenAI
from langsmith import traceable

In [6]:
# Prompt shared by both actor roles (initial responder and revisor). Each role
# fills {first_instruction} via .partial(); {time} is filled automatically below.
actor_prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are an expert researcher.
Current time: {time}

1. {first_instruction}
2. Reflect and critique your answer. Be severe to maximize improvement.
3. Recommend search queries to research information and improve your answer.""",
        ),
        MessagesPlaceholder(variable_name="messages"),
        ("system", "Answer the user's question above using the required format."),
    ]
).partial(
    # A callable partial is evaluated when the prompt is formatted, so each
    # invocation sees the current timestamp rather than the notebook start time.
    time=lambda: datetime.datetime.now().isoformat(),
)



class Reflection(BaseModel):
    """Explicitly prompting the criticism to generate both missing and superfluous 
    aspects of its response.
    """
    # Self-critique attached to every answer (see AnswerQuestion.reflection).
    # NOTE: as a nested part of the AnswerQuestion tool schema, this model's schema
    # (likely including the docstring and Field descriptions) is sent to the LLM via
    # bind_tools, so edit this text as prompt text, not just developer documentation.

    missing: str = Field(description="Critique of what is missing.")
    superfluous: str = Field(description="Critique of what is superfluous.")

class AnswerQuestion(BaseModel):
    """Answer the question."""
    # Tool schema forced on the initial responder via
    # llm.bind_tools(tools=[AnswerQuestion], tool_choice="AnswerQuestion").
    # The docstring and Field descriptions are sent to the model as the tool's
    # description, so changing them changes the prompt.

    answer: str = Field(description="~250 word detailed answer to the question.")
    reflection: Reflection = Field(description="Your reflection on the initial answer.")
    # Read by execute_tools, which runs one Tavily search per query.
    search_queries: List[str] = Field(
        description="1-3 search queries for researching improvements to address the critique of your current answer."
    )

In [7]:
llm = ChatOpenAI(model="gpt-4-turbo-preview")

# Initial responder: the shared actor prompt with a "draft ~250 words" instruction,
# piped into the LLM with the AnswerQuestion tool forced so every reply is a
# structured tool call.
initial_answer_chain = (
    actor_prompt_template.partial(first_instruction="Provide a detailed ~250 word answer.")
    | llm.bind_tools(tools=[AnswerQuestion], tool_choice="AnswerQuestion")
)

# Validates the tool-call arguments against the AnswerQuestion schema.
validator = PydanticToolsParser(tools=[AnswerQuestion])


class ResponderWithRetries:
    """Invoke a tool-calling chain and re-prompt it until its output validates.

    Args:
        runnable: Chain invoked as ``runnable.invoke({"messages": state})``.
        validator: Parser whose ``invoke`` raises ``ValidationError`` when the
            response's tool-call arguments do not match the expected schema.
        max_attempts: Total number of calls to ``runnable`` before giving up
            (defaults to 3, the previous hard-coded value).

    Raises:
        ValueError: If ``max_attempts`` is less than 1.
    """

    def __init__(self, runnable, validator, max_attempts: int = 3):
        if max_attempts < 1:
            raise ValueError("max_attempts must be at least 1")
        self.runnable = runnable
        self.validator = validator
        self.max_attempts = max_attempts

    @traceable
    def respond(self, state: List[BaseMessage]):
        """Return the first response that passes validation.

        After each validation failure, the error text is appended to the
        conversation as a HumanMessage so the next attempt can correct itself.
        If every attempt fails, the last (invalid) response is returned, so
        callers are not guaranteed a valid tool call.

        Args:
            state: Message history passed to the chain as ``messages``.

        Returns:
            The chain's response. NOTE(review): this stays ``[]`` only if the
            chain itself raised ValidationError on every attempt.
        """
        response = []
        for attempt in range(1, self.max_attempts + 1):
            try:
                response = self.runnable.invoke({"messages": state})
                self.validator.invoke(response)
                return response
            except ValidationError as e:
                print(f"Validation error: {e}")
                if attempt < self.max_attempts:
                    # 1-based attempt numbers, so the log matches what the reader counts.
                    print(f"Retrying (attempt {attempt + 1} of {self.max_attempts})")
                    # Copy rather than mutate, so the caller's list is untouched.
                    state = state + [HumanMessage(content=repr(e))]
        return response

In [8]:
# Initial responder wrapped with validation-and-retry on the AnswerQuestion schema.
first_responder = ResponderWithRetries(runnable=initial_answer_chain, validator=validator)

In [9]:
example_question = "What is the importance of continuous batching in LLMs serving?"
# Draft a first answer (with reflection and search queries) for the sample question.
initial = first_responder.respond(
    [HumanMessage(content=example_question)]
)

In [10]:
# Decode the draft's tool call into {'type', 'args', 'id'} dicts for inspection;
# `parsed` is reused by the revision cell below.
parsed = parser.invoke(initial)
parsed

[{'type': 'AnswerQuestion',
  'args': {'answer': 'Continuous batching in the context of serving Large Language Models (LLMs) refers to the process of dynamically grouping incoming requests into batches for parallel processing. This technique is crucial for several reasons:\n\n1. **Efficiency**: By processing requests in batches, LLMs can leverage hardware acceleration (e.g., GPUs or TPUs) more effectively, leading to faster response times and lower per-request processing costs. It optimizes resource utilization by ensuring that the computational power is not idling between single requests.\n\n2. **Scalability**: Continuous batching allows for the system to adjust dynamically to varying loads, making it easier to handle peak traffic periods. This scalability is essential for maintaining performance and availability without over-provisioning resources.\n\n3. **Cost-effectiveness**: Since LLMs can be resource-intensive, optimizing their operation through batching reduces operational costs

#### 3. Revision
- The second part of the Actor

In [11]:
# Instruction inserted into {first_instruction} of the actor prompt for the revisor.
# The bullet indentation inside the string is part of the text sent to the model.
revise_instructions = """Revise your previous answer using the new information.
    - You should use the previous critique to add important information to your answer.
        - You MUST include numerical citations in your revised answer to ensure it can be verified.
        - Add a "References" section to the bottom of your answer (which does not count towards the word limit). In form of:
                - [1] https://www.example.com
                - [2] https://www.example.com
    - You should use the previous critique to remove superfluous information from your answer and make SURE it is not more than 250 words.
"""

# Extend the initial answer schema to include the references.
# ReviseAnswer inherits answer, reflection and search_queries from AnswerQuestion,
# so each revision also proposes queries for the next search round. As with
# AnswerQuestion, the docstring and Field descriptions are sent to the model
# as the tool schema.
class ReviseAnswer(AnswerQuestion):
    """Revise your original answer to the question."""
    
    references: List[str] = Field(description="Citations motivating your updated answer.")

In [12]:
# Revisor: the same actor prompt, driven by the revision instructions, with the
# ReviseAnswer tool (AnswerQuestion fields plus references) forced.
revision_chain = (
    actor_prompt_template.partial(first_instruction=revise_instructions)
    | llm.bind_tools(tools=[ReviseAnswer], tool_choice="ReviseAnswer")
)
revision_validator = PydanticToolsParser(tools=[ReviseAnswer])
revisor = ResponderWithRetries(
    runnable=revision_chain,
    validator=revision_validator,
)

In [13]:
# Exercise the revisor by hand: replay the question, the initial draft, and the
# search results for the draft's queries, then ask for a revised answer.
revised = revisor.respond(
    [
        # The revisor must see the original question; an empty HumanMessage
        # leaves it revising an answer to nothing.
        HumanMessage(content=example_question),
        initial,
        ToolMessage(
            # Link the search results to the draft's tool call.
            tool_call_id=initial.additional_kwargs["tool_calls"][0]["id"],
            content=json.dumps(
                # NOTE(review): the query list is stringified into a single search
                # string here, unlike execute_tools, which searches per query.
                tavily_tool.invoke(str(parsed[0]["args"]["search_queries"]))
            ),
        ),
    ]
)

In [14]:
# Decode the revision's ReviseAnswer tool call for inspection (rebinds `parsed`).
parsed = parser.invoke(revised)
parsed

[{'type': 'ReviseAnswer',
  'args': {'answer': "Continuous batching in the context of serving Large Language Models (LLMs) significantly enhances their efficiency, scalability, cost-effectiveness, and Quality of Service (QoS). This process dynamically groups incoming requests into batches for parallel processing, leveraging hardware acceleration more effectively [1]. A real-world example of continuous batching's impact is observed in vLLM, where it was shown to reduce latency by immediately injecting new requests when possible and enabling advanced memory optimizations. These optimizations increased the Queries Per Second (QPS) that the serving system could handle before becoming saturated, significantly improving over static batching [5].\n\nEfficient resource utilization is crucial for maintaining performance and availability without over-provisioning resources, especially during peak traffic periods. By optimizing throughput and reducing latency, organizations can serve more users w

#### 4. Graph

In [None]:
from langgraph.graph import END, MessageGraph

# Upper bound on the trailing run of AI/tool messages before the loop stops
# (compared in event_loop); it counts messages, not draft/revise rounds.
MAX_ITERATIONS = 5
builder = MessageGraph()

# Nodes: draft an answer, run its search queries, revise using the results.
builder.add_node("draft", first_responder.respond)
builder.add_node("execute_tools", execute_tools)
builder.add_node("revise", revisor.respond)

# Fixed edges: draft -> execute_tools -> revise.
builder.add_edge("draft", "execute_tools")
builder.add_edge("execute_tools", "revise")

# Define looping logic:


def _get_num_iterations(state: List[BaseMessage]):
    """Length of the trailing run of AIMessage/ToolMessage entries in ``state``.

    Walks backwards from the newest message and stops at the first message of
    any other type (e.g. the user's HumanMessage). Despite the name, this counts
    messages rather than draft/revise rounds; each round typically adds one
    tool message and one AI message.
    """
    count = 0
    for message in reversed(state):
        if isinstance(message, (ToolMessage, AIMessage)):
            count += 1
        else:
            break
    return count


def event_loop(state: List[BaseMessage]) -> str:
    """Route after ``revise``: search again via ``execute_tools`` or finish.

    Ends the graph once the trailing AI/tool message count from
    ``_get_num_iterations`` exceeds MAX_ITERATIONS.
    """
    keep_going = _get_num_iterations(state) <= MAX_ITERATIONS
    return "execute_tools" if keep_going else END


# Every run starts by drafting an answer.
builder.set_entry_point("draft")
# After each revision, event_loop chooses: execute_tools again, or END.
builder.add_conditional_edges("revise", event_loop)
graph = builder.compile()

In [None]:
# Stream the full draft -> search -> revise loop, printing a short preview of
# each node's output. (`step` is read again by the next cell.)
PREVIEW_CHARS = 100  # tool outputs are large JSON blobs; keep the log readable

events = graph.stream(
    [HumanMessage(content="What is the importance of continuous batching in LLMs serving?")]
)
for i, step in enumerate(events):
    # Presumably each streamed event maps a single node name to its output;
    # only the first entry is shown.
    node, output = next(iter(step.items()))
    print(f"## {i+1}. {node}")
    text = str(output)
    # Only append the ellipsis when something was actually cut off.
    print(text[:PREVIEW_CHARS] + (" ..." if len(text) > PREVIEW_CHARS else ""))
    print("---")

In [None]:
# Print the final revised answer from the last streamed event.
# NOTE(review): relies on the loop variable `step` leaking from the previous cell
# (hidden state: breaks if that cell is skipped), and on the final event being keyed
# by END — newer LangGraph versions may not emit an END event; confirm for the
# installed version.
print(parser.invoke(step[END][-1])[0]["args"]["answer"])