In [68]:
import sqlite3

import requests
from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
from langchain.schema import AIMessage, HumanMessage
from langchain_core.messages import ToolMessage


def get_engine_for_chinook_db():
    """Pull sql file, populate in-memory database, and create engine."""
    url = "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql"
    response = requests.get(url)
    sql_script = response.text

    connection = sqlite3.connect(":memory:", check_same_thread=False)
    connection.executescript(sql_script)
    return create_engine(
        "sqlite://",
        creator=lambda: connection,
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )


engine = get_engine_for_chinook_db()

db = SQLDatabase(engine)

In [3]:
import getpass
import os

os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")

from langchain.chat_models import init_chat_model

llm = init_chat_model("gpt-4o", model_provider="openai")

In [4]:
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

In [5]:
from langchain_community.tools.sql_database.tool import (
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
    QuerySQLDatabaseTool,
)

In [6]:
from langchain import hub

prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

assert len(prompt_template.messages) == 1
print(prompt_template.input_variables)
system_message = prompt_template.format(dialect="SQLite", top_k=5)



['dialect', 'top_k']


In [7]:
from langgraph.prebuilt import create_react_agent

agent_executor = create_react_agent(llm, toolkit.get_tools(), prompt=system_message)

In [62]:
example_query = "Who are the top 3 best selling artists?"

events = agent_executor.stream(
    {"messages": [("user", example_query)]},
    stream_mode="values",
)
test = []
for event in events:
    print(event["messages"][-1])
    test.append(event)
    event["messages"][-1].pretty_print()

content='Who are the top 3 best selling artists?' additional_kwargs={} response_metadata={} id='1d2c13e9-6452-42a8-86f5-afa370e910a4'

Who are the top 3 best selling artists?
content='' additional_kwargs={'tool_calls': [{'id': 'call_U15qMaSnECRJPbOcV1S3lV26', 'function': {'arguments': '{}', 'name': 'sql_db_list_tables'}, 'type': 'function'}], 'refusal': None} response_metadata={'token_usage': {'completion_tokens': 12, 'prompt_tokens': 553, 'total_tokens': 565, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_07871e2ad8', 'id': 'chatcmpl-ByfKCueCnhTdgSx7CEZxcHU4RjAsN', 'service_tier': 'default', 'finish_reason': 'tool_calls', 'logprobs': None} id='run--50922cec-c6f8-46c8-8ba3-be1ac3ca4ab8-0' tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'call_U15qMaSnECRJPbO

In [66]:
for event in test:
    # Assuming the last message is the one we want to print
    if event["messages"]:
        print(event["messages"])
    else:
        print("No messages found in event.")

[HumanMessage(content='Who are the top 3 best selling artists?', additional_kwargs={}, response_metadata={}, id='1d2c13e9-6452-42a8-86f5-afa370e910a4')]
[HumanMessage(content='Who are the top 3 best selling artists?', additional_kwargs={}, response_metadata={}, id='1d2c13e9-6452-42a8-86f5-afa370e910a4'), AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_U15qMaSnECRJPbOcV1S3lV26', 'function': {'arguments': '{}', 'name': 'sql_db_list_tables'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 12, 'prompt_tokens': 553, 'total_tokens': 565, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_07871e2ad8', 'id': 'chatcmpl-ByfKCueCnhTdgSx7CEZxcHU4RjAsN', 'service_tier': 'default', 'finish_reason': 'tool_calls', 'logpr

In [42]:
example_query = "Who are the top 3 best selling artists?"

ai_msg = agent_executor.invoke({"messages": [("user", example_query)]})


In [44]:
ai_msg

{'messages': [HumanMessage(content='Who are the top 3 best selling artists?', additional_kwargs={}, response_metadata={}, id='ed307842-c356-48a0-9913-e12bbf18fd80'),
  AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_gnoy8cdPSqzQaemcrMoiVrBS', 'function': {'arguments': '{}', 'name': 'sql_db_list_tables'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 12, 'prompt_tokens': 553, 'total_tokens': 565, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_07871e2ad8', 'id': 'chatcmpl-ByexBBfsMo8tOXMcmRvWEB31BC4Vw', 'service_tier': 'default', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run--10f584d0-37af-49a5-8d64-8b2847272073-0', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'call_gnoy8cdP

In [71]:
import gradio as gr
from gradio import ChatMessage

def interact_with_agent(user_message, history):
    events = agent_executor.stream(
        {"messages": [("user", user_message)]},
        stream_mode="values",
    )

    for event in events:
        repsonse = event["messages"][-1].text()
        if isinstance(event["messages"][-1], AIMessage):
            history.append(ChatMessage(role="assistant", content=repsonse))
        elif isinstance(event["messages"][-1], ToolMessage):
            history.append(ChatMessage(role="assistant", content= repsonse,metadata={"title": "🛠️ Used tool SQL"}))
        yield history


demo = gr.ChatInterface(
    interact_with_agent,
    type="messages",
    chatbot= gr.Chatbot(
        label="Agent",
        type="messages",
        avatar_images=(
            None,
            "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
        ),
    ),
    examples=[
        ["Who are the top 3 best selling artists?"],
        ["Which country's customers spent the most?"]],
    save_history=True,
)

In [72]:
demo.launch()

* Running on local URL:  http://127.0.0.1:7880
* To create a public link, set `share=True` in `launch()`.


