Skip to content

Commit

Permalink
docs(langchain): add langchain sql agent example (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomas2D committed Feb 7, 2024
1 parent 0812280 commit a7333f4
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 1 deletion.
156 changes: 156 additions & 0 deletions examples/extensions/langchain/langchain_sql_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""LangChain SQL Agent
In this example, we first create a SQLite database with a ‘countries’ table and then use a LangChain
agent to run queries against it.
"""

import contextlib
from tempfile import TemporaryFile

from dotenv import load_dotenv
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import JSONAgentOutputParser
from langchain.memory import ConversationBufferMemory
from langchain.tools.render import render_text_description_and_args
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tools import BaseTool

from genai import Client, Credentials
from genai.extensions.langchain import LangChainChatInterface
from genai.schema import TextGenerationParameters

load_dotenv()

try:
import pandas as pd
import sqlalchemy
from langchain.sql_database import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
except ImportError:
print("Please install 'pandas' / 'sqlalchemy' to run this example.")
raise


@contextlib.contextmanager
def get_countries_db():
    """Yield a LangChain ``SQLDatabase`` backed by a throwaway SQLite file.

    The database contains a single ``countries`` table with two columns,
    ``country`` and ``gdp``, holding ten rows.

    The SQLite file lives inside a temporary directory so that it has a real
    filesystem path. (A ``tempfile.TemporaryFile`` does not have one: on POSIX
    its ``name`` is just the integer file descriptor, so building a SQLite URL
    from it creates a stray file in the current working directory.)

    Yields:
        SQLDatabase: wrapper around the engine connected to the temporary file.
        It is only valid inside the ``with`` block; the engine is disposed and
        the file is removed on exit, even when the body raises.
    """
    # Function-scope imports: only this helper needs them.
    from pathlib import Path
    from tempfile import TemporaryDirectory

    df = pd.DataFrame(
        {
            "country": [
                "United States",
                "United Kingdom",
                "France",
                "Germany",
                "Italy",
                "Spain",
                "Canada",
                "Australia",
                "Japan",
                "China",
            ],
            "gdp": [
                19294482071552,
                2891615567872,
                2411255037952,
                3435817336832,
                1745433788416,
                1181205135360,
                1607402389504,
                1490967855104,
                4380756541440,
                14631844184064,
            ],
        }
    )

    with TemporaryDirectory() as tmp_dir:
        db_path = Path(tmp_dir) / "countries.db"
        engine = sqlalchemy.create_engine(f"sqlite:///{db_path}")
        try:
            df.to_sql("countries", con=engine, index=False, if_exists="replace")
            # Share the one engine with LangChain instead of opening a second
            # one via from_uri(), so disposing it below releases every
            # connection before the temporary directory is deleted.
            yield SQLDatabase(engine)
        finally:
            engine.dispose(close=True)


def create_llm():
    """Build the chat model used by the SQL agent.

    Credentials are read from the environment via ``Credentials.from_env()``.

    Returns:
        LangChainChatInterface: chat wrapper around the
        ``meta-llama/llama-2-70b-chat`` model with temperature 0, between 20
        and 250 new tokens, and ``"\\nObservation"`` as a stop sequence.
    """
    api_client = Client(credentials=Credentials.from_env())

    # NOTE(review): the stop sequence presumably keeps the model from writing
    # its own "Observation" so the agent loop can supply the real tool output.
    generation_params = TextGenerationParameters(
        max_new_tokens=250,
        min_new_tokens=20,
        temperature=0,
        stop_sequences=["\nObservation"],
    )

    return LangChainChatInterface(
        client=api_client,
        model_id="meta-llama/llama-2-70b-chat",
        parameters=generation_params,
    )


def create_agent(tools: list[BaseTool], llm: LangChainChatInterface) -> AgentExecutor:
    """Assemble a JSON-action chat agent over ``tools`` and wrap it in an executor.

    The agent is a runnable pipeline: it formats the intermediate steps into a
    text scratchpad, injects the conversation history, renders the chat prompt,
    calls ``llm`` and parses the reply with ``JSONAgentOutputParser``.

    Args:
        tools: Tools the agent may call; their names and rendered descriptions
            are embedded in the system prompt.
        llm: Chat model that produces the agent's replies.

    Returns:
        AgentExecutor: verbose executor that tolerates parsing errors and
        shares one ``ConversationBufferMemory`` with the prompt, so earlier
        exchanges are visible to later ``invoke`` calls.
    """
    # Doubled braces ({{ }}) are literal braces in the template; single-brace
    # names ({tools}, {tool_names}) are filled in via .partial() below.
    # Every example blob is wrapped in a complete (opened AND closed) code
    # fence so the model sees a consistent format to imitate.
    system_prompt = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:
{tools}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{
"action": "Final Answer",
"action_input": "Final response to human"
}}
```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action.
Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation"""  # noqa

    human_prompt = """{input}
{agent_scratchpad}
(reminder to respond in a JSON blob no matter what)"""

    # One memory object is shared by the prompt (read side, via the lambda
    # below) and by the executor (passed as ``memory=``).
    memory = ConversationBufferMemory()

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder("chat_history", optional=True),
            ("human", human_prompt),
        ]
    ).partial(
        tools=render_text_description_and_args(list(tools)),
        tool_names=", ".join([t.name for t in tools]),
    )

    agent = (
        RunnablePassthrough.assign(
            # format the agent's scratchpad to a string
            agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
            # pass the memory as the chat history
            chat_history=lambda x: memory.chat_memory.messages,
        )
        | prompt
        | llm
        | JSONAgentOutputParser()
    )

    return AgentExecutor(agent=agent, tools=tools, handle_parsing_errors=True, verbose=True, memory=memory)


# Example driver: build the temporary database, wire the SQL toolkit into the
# agent, then ask the questions one after another on the same executor.
with get_countries_db() as db:
    llm = create_llm()

    sql_toolkit = SQLDatabaseToolkit(db=db, llm=llm)
    tools = sql_toolkit.get_tools()

    agent_executor = create_agent(tools, llm)
    for question in (
        "How many rows are in the countries table?",
        "Which are the countries with GDP greater than 3000000000000?",
    ):
        agent_executor.invoke({"input": question})
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ sphinx-multiversion = { "git" = "https://github.com/samtygier-stfc/sphinx-multiv
sphinx-copybutton = "^0.5.2"
sphinx-notfound-page = "^1.0.0"
langchain = "^0.1.4"
sqlalchemy = "^2.0.25"


[tool.poetry.extras]
Expand Down

0 comments on commit a7333f4

Please sign in to comment.