# Reflection
https://blog.langchain.dev/reflection-agents/

In the context of LLM agent building, reflection refers to the process of prompting an LLM to observe its past steps (along with potential observations from tools/the environment) to assess the quality of the chosen actions.
This is then used downstream for things like re-planning, search, or evaluation.

![Reflection](./img/reflection.png)

This notebook demonstrates a very simple form of reflection in LangGraph.

#### Prerequisites

We only need a chat model for this example. Install the packages below and set your OpenAI API key.

In [1]:
%pip install -U --quiet  langchain langgraph langchain_openai langchain-experimental



In [2]:
import getpass
import os


def _set_if_undefined(var: str) -> None:
    if os.environ.get(var):
        return
    os.environ[var] = getpass.getpass(var)


# Optional: Configure tracing to visualize and debug the agent
# _set_if_undefined("LANGCHAIN_API_KEY")
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Reflection"
_set_if_undefined("OPENAI_API_KEY")
# _set_if_undefined("FIREWORKS_API_KEY")



## Generate

For our example, we will create a SQL query generator that writes queries against a fixed schema. First, create the generator:


In [20]:
from langchain_openai import ChatOpenAI
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            " You are a professional SQL expert. Given the following SQL tables, your job is to write queries given a user’s request.\n  Table orders \n  id int primary key \n  user_id int not null, unique \n  status varchar \n  created_at varchar \n  \n Table order_items   \n  order_id int \n  product_id int \n  quantity int \n \n Table products   \n  id int primary key \n  name varchar \n  merchant_id int not null \n  price int \n  status varchar \n  created_at varchar \n  category_id int \n \n Table users  \n  id int primary key \n  full_name varchar \n  email varchar unique \n  gender varchar \n  date_of_birth varchar \n  created_at varchar \n  country_code int \n "
            " Generate the best SQL query possible for the user's request."
            " If the user provides critique, respond with a revised version of your previous attempts.",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
generate = prompt | llm

In [21]:
sqlquery = ""
request = HumanMessage(
    content="Write a SQL query which gives Top 5 Customers based on Sales in current year"
)
for chunk in generate.stream({"messages": [request]}):
    print(chunk.content, end="")
    sqlquery += chunk.content

To find the top 5 customers based on sales in the current year, you can use the following SQL query:

```sql
WITH customer_sales AS (
    SELECT u.id AS user_id, u.full_name AS customer_name, SUM(p.price * oi.quantity) AS total_sales
    FROM users u
    JOIN orders o ON u.id = o.user_id
    JOIN order_items oi ON o.id = oi.order_id
    JOIN products p ON oi.product_id = p.id
    WHERE YEAR(o.created_at) = YEAR(CURRENT_DATE())
    GROUP BY u.id, u.full_name
)

SELECT customer_name, total_sales
FROM customer_sales
ORDER BY total_sales DESC
LIMIT 5;
```

In this query:
1. We first calculate the total sales for each customer by joining the `users`, `orders`, `order_items`, and `products` tables.
2. We filter the orders based on the current year using the `YEAR()` function and `CURRENT_DATE()` function.
3. We group the results by user_id and full_name and calculate the total sales.
4. Finally, we select the customer_name and total_sales from the subquery and order them by total_sales in de

### Reflect

In [22]:
reflection_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a SQL SME who is grading an SQL submission. Generate critique and recommendations for the user's submission."
            " Provide detailed recommendations on sql query optimization, including number of columns in select query, joins, group by etc.",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
reflect = reflection_prompt | llm

In [23]:
reflection = ""
for chunk in reflect.stream({"messages": [request, HumanMessage(content=sqlquery)]}):
    print(chunk.content, end="")
    reflection += chunk.content

Overall, the SQL query you provided is on the right track for finding the top 5 customers based on sales in the current year. Here are some critique and recommendations for optimization:

1. **Column Selection**: Since we are only interested in the top 5 customers based on sales, you could optimize the query by selecting only the necessary columns in the final output. In this case, selecting just the `customer_name` would be sufficient as we are ranking customers based on sales.

2. **Indexing**: To improve query performance, ensure that the columns used in join conditions and filtering (such as `user_id`, `order_id`, `product_id`, `created_at`) are properly indexed in the respective tables.

3. **Subquery vs. Common Table Expression (CTE)**: While using a CTE is a good practice for readability and maintainability, in some cases, using a subquery might perform better depending on the database engine's query optimizer. You can test both approaches to see which one works more efficiently

### Repeat

And... that's all there is to it! You can repeat in a loop for a fixed number of steps, or use an LLM (or other check) to decide when the finished product is good enough.

In [24]:
for chunk in generate.stream(
    {"messages": [request, AIMessage(content=sqlquery), HumanMessage(content=reflection)]}
):
    print(chunk.content, end="")

Thank you for the valuable feedback and suggestions. Here is the revised SQL query incorporating the optimizations:

```sql
WITH customer_sales AS (
    SELECT u.full_name AS customer_name, SUM(p.price * oi.quantity) AS total_sales
    FROM users u
    JOIN orders o ON u.id = o.user_id
    JOIN order_items oi ON o.id = oi.order_id
    JOIN products p ON oi.product_id = p.id
    WHERE YEAR(o.created_at) = YEAR(CURRENT_DATE())
    GROUP BY u.full_name
)

SELECT customer_name
FROM customer_sales
ORDER BY total_sales DESC
LIMIT 5;
```

This query now selects only the necessary column `customer_name` for the final output, ensuring optimal performance. The join conditions, grouping, and ordering remain intact to accurately identify the top 5 customers based on sales in the current year. Additionally, implementing proper indexing on relevant columns can further enhance the query's efficiency.
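
The two stopping strategies mentioned above (a fixed number of steps, or a check that decides when the output is good enough) can be combined in one loop. The following is a minimal sketch, not part of the notebook: `reflect_loop`, `fake_generate`, `fake_critique`, and `accept` are hypothetical names, and the stand-in functions replace the real `generate` and `reflect` chains so the control flow can run without an API key.

```python
from typing import Callable, List, Optional, Tuple


def reflect_loop(
    generate: Callable[[str, Optional[str], Optional[str]], str],
    critique: Callable[[str, str], str],
    is_good_enough: Callable[[str, str], bool],
    request: str,
    max_rounds: int = 3,
) -> Tuple[str, List[Tuple[str, str]]]:
    """Alternate generation and critique until accepted or max_rounds is hit."""
    draft = generate(request, None, None)  # first attempt, no feedback yet
    history: List[Tuple[str, str]] = []
    for _ in range(max_rounds):
        feedback = critique(request, draft)
        history.append((draft, feedback))
        if is_good_enough(draft, feedback):
            break  # the check accepted this draft
        draft = generate(request, draft, feedback)  # revise using the critique
    return draft, history


# Hypothetical stand-ins for the LLM calls (assumptions for illustration only).
def fake_generate(request: str, previous: Optional[str], feedback: Optional[str]) -> str:
    if previous is None:
        return "draft 1"
    return f"draft {int(previous.split()[-1]) + 1}"


def fake_critique(request: str, draft: str) -> str:
    return "LGTM" if draft == "draft 3" else "needs work"


def accept(draft: str, feedback: str) -> bool:
    return feedback == "LGTM"


final, history = reflect_loop(
    fake_generate, fake_critique, accept, "top 5 customers", max_rounds=5
)
print(final, len(history))
```

In a real run, `generate` and `critique` would wrap the `generate` and `reflect` chains defined earlier, and `is_good_enough` could be another LLM call or a deterministic check.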

## Define graph

Now that we've shown each step in isolation, we can wire it up in a graph.

In [25]:
from typing import List, Sequence

from langgraph.graph import END, MessageGraph


async def generation_node(state: Sequence[BaseMessage]):
    return await generate.ainvoke({"messages": state})


async def reflection_node(messages: Sequence[BaseMessage]) -> BaseMessage:
    # Swap roles so the reflector sees the generator's drafts as human submissions
    cls_map = {"ai": HumanMessage, "human": AIMessage}
    # First message is the original user request. We hold it the same for all nodes
    translated = [messages[0]] + [
        cls_map[msg.type](content=msg.content) for msg in messages[1:]
    ]
    res = await reflect.ainvoke({"messages": translated})
    # We treat the output of this as human feedback for the generator
    return HumanMessage(content=res.content)


builder = MessageGraph()
builder.add_node("generate", generation_node)
builder.add_node("reflect", reflection_node)
builder.set_entry_point("generate")


def should_continue(state: List[BaseMessage]):
    if len(state) > 6:
        # End once the history exceeds 6 messages, i.e. after 3 reflection rounds
        return END
    return "reflect"


builder.add_conditional_edges("generate", should_continue)
builder.add_edge("reflect", "generate")
graph = builder.compile()
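
The role swap inside `reflection_node` is easy to misread, so here is a small sketch of what it does to the history. It uses a hypothetical `Msg` dataclass as a stand-in for LangChain's message classes (an assumption made so the example runs without any dependencies); only the `type` values `"human"` and `"ai"` mirror the real classes.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Msg:
    """Hypothetical stand-in for a chat message; not a LangChain class."""
    type: str  # "human" or "ai"
    content: str


SWAP = {"ai": "human", "human": "ai"}


def swap_roles(messages: List[Msg]) -> List[Msg]:
    """Keep the original request as-is and flip the role of every later message."""
    if not messages:
        return []
    return [messages[0]] + [Msg(SWAP[m.type], m.content) for m in messages[1:]]


history = [
    Msg("human", "Write a query"),       # original user request
    Msg("ai", "SELECT 1"),               # generator's first draft
    Msg("human", "Add a LIMIT"),         # earlier critique, fed back as human input
    Msg("ai", "SELECT 1 LIMIT 1"),       # generator's revision
]
swapped = swap_roles(history)
print([m.type for m in swapped])
```

From the reflector's point of view, the drafts now arrive as human submissions to grade, and its own earlier critiques appear as its previous AI turns.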

In [26]:
async for event in graph.astream(
    [
        HumanMessage(
            content="Write a SQL query which gives Top 5 Customers based on Sales in current year"
        )
    ],
):
    print(event)
    print("---")

{'generate': AIMessage(content='To find the top 5 customers based on sales in the current year, we need to join the `users`, `orders`, `order_items`, and `products` tables. We will need to calculate the total sales amount for each user in the current year and then rank them accordingly. Here is the SQL query to achieve this:\n\n```sql\nSELECT u.id AS user_id, u.full_name AS customer_name, SUM(oi.quantity * p.price) AS total_sales\nFROM users u\nJOIN orders o ON u.id = o.user_id\nJOIN order_items oi ON o.id = oi.order_id\nJOIN products p ON oi.product_id = p.id\nWHERE YEAR(o.created_at) = YEAR(CURDATE())\nGROUP BY u.id, u.full_name\nORDER BY total_sales DESC\nLIMIT 5;\n```\n\nThis query will give you the top 5 customers based on sales in the current year, displaying their user ID, full name, and total sales amount.')}
---
{'reflect': HumanMessage(content="Your SQL query is on the right track to find the top 5 customers based on sales in the current year. Here are some recommendations fo

In [27]:
ChatPromptTemplate.from_messages(event[END]).pretty_print()


Write a SQL query which gives Top 5 Customers based on Sales in current year


To find the top 5 customers based on sales in the current year, we need to join the `users`, `orders`, `order_items`, and `products` tables. We will need to calculate the total sales amount for each user in the current year and then rank them accordingly. Here is the SQL query to achieve this:

```sql
SELECT u.id AS user_id, u.full_name AS customer_name, SUM(oi.quantity * p.price) AS total_sales
FROM users u
JOIN orders o ON u.id = o.user_id
JOIN order_items oi ON o.id = oi.order_id
JOIN products p ON oi.product_id = p.id
WHERE YEAR(o.created_at) = YEAR(CURDATE())
GROUP BY u.id, u.full_name
ORDER BY total_sales DESC
LIMIT 5;
```

This query will give you the top 5 customers based on sales in the current year, displaying their user ID, full name, and total sales amount.


Your SQL query is on the right track to find the top 5 customers based on sales in the current year. Here are some recommendations for opt
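
The stopping rule in `should_continue` can be checked without calling a model. The sketch below only counts messages: it assumes, as in the graph above, that the state starts with one request, that each node appends exactly one message, and that the length check runs after every `generate` step. The function name `simulate` is hypothetical.

```python
from typing import Tuple


def simulate(max_len: int = 6) -> Tuple[int, int, int]:
    """Count how often each node runs before the length check ends the graph."""
    n_messages = 1  # the original user request
    generations = 0
    reflections = 0
    while True:
        n_messages += 1  # "generate" appends one AI message
        generations += 1
        if n_messages > max_len:  # same condition as should_continue
            break
        n_messages += 1  # "reflect" appends one critique
        reflections += 1
    return generations, reflections, n_messages


generations, reflections, total = simulate(6)
print(generations, reflections, total)
```

Under these assumptions the graph produces four drafts and three critiques, so the final state holds eight messages: the request, four generator outputs, and three reflections.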

## Conclusion

Now that you've applied reflection to an LLM agent, one caveat is worth noting. Self-reflection is inherently cyclic, and it is much more effective when the reflection step has additional context or feedback (from tool observations, checks, etc.). If, as in the scenario above, the reflection step simply prompts the LLM to critique its own output, it can still improve quality (since the LLM then has multiple "shots" at getting a good output), but the improvement is less reliable.
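
One way to give the reflection step that kind of external feedback, for the SQL task in this notebook, is to try the generated query against an empty copy of the schema and pass any database error back as part of the critique. The following is a rough sketch under several assumptions: it uses Python's built-in `sqlite3` as a stand-in database, a simplified version of the schema from the system prompt, and a hypothetical helper named `check_sql`. SQLite's dialect differs from other engines (for example, it has no `YEAR()` function), so a rejection here may reflect a dialect difference rather than a real bug.

```python
import sqlite3
from typing import Optional

# Simplified, assumed version of the schema described in the system prompt.
SCHEMA = """
CREATE TABLE users (id INTEGER PRIMARY KEY, full_name TEXT, email TEXT UNIQUE);
CREATE TABLE orders (id INTEGER PRIMARY KEY, user_id INTEGER NOT NULL,
                     status TEXT, created_at TEXT);
CREATE TABLE order_items (order_id INTEGER, product_id INTEGER, quantity INTEGER);
CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT, price INTEGER);
"""


def check_sql(query: str) -> Optional[str]:
    """Return None if SQLite can compile the query, else an error description."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(SCHEMA)
        # EXPLAIN compiles the statement without running it against real data.
        conn.execute("EXPLAIN " + query)
        return None
    except (sqlite3.Error, sqlite3.Warning) as exc:
        return f"SQLite rejected the query: {exc}"
    finally:
        conn.close()


print(check_sql("SELECT full_name FROM users"))
print(check_sql("SELECT nickname FROM users"))
```

The returned string could be appended to the messages passed to `reflect`, so the critique is grounded in an actual tool observation rather than the model's own reading of the query.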
