# Reflexion
Is an architecture designed to learn through verbal feedback and self-reflection. The agent explicitly critiques its reponses for tasks to generate a higher quality final response, at the expense of longer execution time.

In [1]:
from dotenv import load_dotenv
load_dotenv()

import re
import os
import json
import base64
import asyncio
import datetime
import platform
import requests
import operator
import playwright
import numpy as np
import pandas as pd
import datetime as dt

from enum import Enum
from typing import List
from typing import Dict
from typing import Tuple
from typing import Union
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Annotated
from typing import TypedDict
from operator import itemgetter
from collections import defaultdict

from IPython import display
from IPython.display import HTML
from IPython.display import Image

from langsmith import traceable

from langgraph.graph import END
from langgraph.graph import StateGraph
from langgraph.graph import MessageGraph
from langgraph.prebuilt import create_react_agent
from langgraph.prebuilt.tool_executor import ToolExecutor
from langgraph.prebuilt.tool_executor import ToolInvocation

from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings

from langchain_core.messages import BaseMessage
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.chat import ChatMessage
from langchain_core.messages.tool import ToolMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.system import SystemMessage
from langchain_core.messages.function import FunctionMessage
from langchain_core.prompts.image import ImagePromptTemplate

from langchain_core.pydantic_v1 import Field
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables import RunnableParallel
from langchain_core.pydantic_v1 import ValidationError
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import JsonOutputParser

from langchain_core.runnables.graph import CurveStyle
from langchain_core.runnables.graph import NodeColors
from langchain_core.runnables.graph import MermaidDrawMethod

from langchain import hub
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.prompts import ChatPromptTemplate
from langchain.prompts import MessagesPlaceholder
from langchain.prompts import HumanMessagePromptTemplate
from langchain.prompts import SystemMessagePromptTemplate
from langchain.agents import create_openai_functions_agent
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain.output_parsers.openai_tools import JsonOutputToolsParser

from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper

from langchain_fireworks.chat_models import ChatFireworks

In [2]:

os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ['LANGCHAIN_PROJECT'] = 'Reflexion'

In [3]:
search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)

The tools are invoked _in context_. Create a function that invokes all the requested tools.

In [4]:
# this is a helper class we have that is useful for running tools
# it takes in an agent action and calls that tool and returns the result
tool_executor = ToolExecutor([tavily_tool])
# parse the tool messages for the execution / invocation
parser = JsonOutputToolsParser(return_id=True)

def execute_tools(state: List[BaseMessage]) -> List[BaseMessage]:
    tool_invocation: AIMessage = state[-1]
    parsed_tool_calls = parser.invoke(tool_invocation)
    ids = []
    tool_invocations = []
    for parsed_call in parsed_tool_calls:
        for query in parsed_call['args']['search_queries']:
            tool_invocations.append(
                ToolInvocation(
                    # we only have this one for now. Would want to map it if we change
                    tool = 'tavily_search_results_json',
                    tool_input='query'
                )
            )
            ids.append(parsed_call['id'])

    outputs = tool_executor.batch(tool_invocations)
    outputs_map = defaultdict(dict)
    for id_, output, invocation in zip(ids, outputs, tool_invocations):
        outputs_map[id_][invocation.tool_input] = output
    
    return [ToolMessage(content=json.dumps(query_outputs), tool_call_id=id_) for id_, query_outputs in outputs_map.items()]

In [5]:
actor_prompt = '''You are an expert researcher. 
Current time: {time}

1. {first_instruction}
2. Reflect and critique your answer. Be severe to maximize improvement.
3. Recommend search queries to research information and improve your answer.'''

actor_prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            'system',
            actor_prompt
        ),
        MessagesPlaceholder(variable_name='messages'),
        ('system', 'Answer the user\'s question above using the required format.')
    ]
).partial(time=lambda: datetime.datetime.now().isoformat())


class Reflection(BaseModel):
    missing: str = Field(description='Critique of what is missing')
    superfluous: str = Field(description='Critique of what is superfluous')


class AnswerQuestion(BaseModel):
    '''Answer the question'''
    answer: str = Field(description='~250 word detailed answer to the question.')
    reflection: Reflection = Field(description='Your reflection on the initial answer.')
    search_queries: List[str] = Field(description='1 - 3 search queries for researching improvements to address the critique of your current answer.')


llm = ChatOpenAI(model='gpt-4-turbo')
initial_answer_chain = actor_prompt_template.partial(first_instruction='Provide a detailed ~250 word answer.') | llm.bind_tools(tools=[AnswerQuestion], tool_choice='AnswerQuestion')
validator = PydanticToolsParser(tools=[AnswerQuestion])


class ResponderWithRetries:
    def __init__(self, runnable, validator):
        self.runnable = runnable
        self.validator = validator

    @traceable
    def respond(self, state: List[BaseMessage]):
        response = []
        for attempt in range(3):
            try:
                response = self.runnable.invoke({'messages': state})
                self.validator.invoke(response)
                return response
            except ValidationError as e:
                print('RETRYING', attempt)
                state = state + [HumanMessage(content=repr(e))]

        return response

In [6]:
first_responder = ResponderWithRetries(runnable=initial_answer_chain, validator=validator)

In [None]:
example_question = 'Why is reflection useful in AI?'
initial = first_responder.respond([HumanMessage(content=example_question)])

In [None]:
parsed = parser.invoke(initial)
parsed

### Revision
The second part of the actor is a revision step

In [7]:
revise_instructions = '''Revise your previous answer using the new information. 
    - You should use the previous critique to add important information to your answer. 
        - You MUST include numerical citations in your revised answer to ensure it can be verified. 
        - Add a "References" section to the bottom of your answer (which does not count towards the word limit). In form of: 
            - [1] https://example.com
            - [2] https://example.com
    - You should use the previous critique to remove superfluous information from your answer and make SURE it is not more than 250 words.'''

# extend the initial answer schema to include references
# forcing citation in the model encourages grounded responses
class ReviseAnswer(AnswerQuestion):
    '''Revise your original answer to the question'''
    references: List[str] = Field(description='Citations motivating your updated answer')

revision_chain = actor_prompt_template.partial(first_instruction=revise_instructions) | llm.bind_tools(tools=[ReviseAnswer], tool_choice='ReviseAnswer')
revision_validator = PydanticToolsParser(tools=[ReviseAnswer])
revisor = ResponderWithRetries(runnable=revision_chain, validator=revision_validator)

In [None]:
revised = revisor.respond(
    [
        HumanMessage(content=''),
        initial,
        ToolMessage(
            tool_call_id=initial.additional_kwargs['tool_calls'][0]['id'],
            content=json.dumps(
                tavily_tool.invoke(str(parsed[0]['args']['search_queries']))
            )
        )
    ]
)

In [None]:
parsed = parser.invoke(revised)
parsed

## Construct Graph

In [8]:
MAX_ITERATIONS = 5
builder = MessageGraph()
builder.add_node('draft', first_responder.respond)
builder.add_node('execute_tools', execute_tools)
builder.add_node('revise', revisor.respond)

builder.add_edge('draft', 'execute_tools') # draft -> execute_tools
builder.add_edge('execute_tools', 'revise') # execute_tools -> revise

In [9]:
# explicit loop
def _get_num_iterations(state: List[BaseMessage]):
    i = 0
    for m in state[::-1]:
        if not isinstance(m, (ToolMessage, AIMessage)):
            break
        i += 1
    return i

def event_loop(state: List[BaseMessage]) -> str:
    # in our case we'll just stop after N plans
    num_iterations = _get_num_iterations(state)
    if num_iterations > MAX_ITERATIONS:
        return END
    return 'execute_tools'

# revise -> execute_tools OR end
builder.add_conditional_edges('revise', event_loop)
builder.set_entry_point('draft')
graph = builder.compile()

In [10]:
events = graph.stream(
    [HumanMessage(content='How should we handle the climate crisis?')]
)

for i, step in enumerate(events):
    node, output = next(iter(step.items()))
    print(f'## {i + 1}. {node}')
    print(str(output)[:100] + ' ...')
    print('---')

RETRYING 0
## 1. draft
content='' additional_kwargs={'tool_calls': [{'id': 'call_5sIG2uSL6qwsAHqeplrr24uW', 'function': {'a ...
---
## 2. execute_tools
[ToolMessage(content='{"query": [{"url": "https://www.dictionary.com/browse/query", "content": "to m ...
---
## 3. revise
content='' additional_kwargs={'tool_calls': [{'id': 'call_dmJ8BKh8JvMz4cjvaCJPI2zL', 'function': {'a ...
---
## 4. execute_tools
[ToolMessage(content='{"query": [{"url": "https://www.dictionary.com/browse/query", "content": "to m ...
---
## 5. revise
content='' additional_kwargs={'tool_calls': [{'id': 'call_QsRKwD36t9KrLH9TzVOmW8H9', 'function': {'a ...
---
## 6. execute_tools
[ToolMessage(content='{"query": [{"url": "https://www.vocabulary.com/dictionary/query", "content": " ...
---
## 7. revise
content='' additional_kwargs={'tool_calls': [{'id': 'call_q5x9PYUdhr3VBttyMzh9Ylms', 'function': {'a ...
---
