In [27]:
import psycopg2

import requests
from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
from langchain.schema import AIMessage, HumanMessage
from langchain_core.messages import ToolMessage


def get_engine_for_postgres_db():
    """Build a SQLAlchemy engine for the PostgreSQL database used by the agent.

    Connection settings are read from the environment variables
    ``POSTGRES_DB``, ``POSTGRES_USER``, ``POSTGRES_PASSWORD``,
    ``POSTGRES_HOST`` and ``POSTGRES_PORT``. Each one falls back to the
    local-development value this notebook originally hard-coded, so calling
    the function without any of them set behaves as before.

    Returns:
        sqlalchemy.engine.Engine: engine using the psycopg2 driver and
        ``StaticPool``.

    Raises:
        ValueError: if ``POSTGRES_PORT`` is set to a non-integer value.
    """
    import os

    from sqlalchemy.engine import URL

    # Single source of truth for the connection settings. Previously they were
    # written twice (URL string and psycopg2.connect kwargs) and could drift.
    # NOTE(review): the password default is a local-dev placeholder; set
    # POSTGRES_PASSWORD instead of committing real credentials.
    conn_params = {
        "dbname": os.getenv("POSTGRES_DB", "langchain_db"),
        "user": os.getenv("POSTGRES_USER", "langchain_user"),
        "password": os.getenv("POSTGRES_PASSWORD", "supersecurepassword"),
        "host": os.getenv("POSTGRES_HOST", "localhost"),
        "port": int(os.getenv("POSTGRES_PORT", "5432")),
    }

    # The URL selects the dialect/driver. Because a ``creator`` is supplied
    # below, SQLAlchemy does not use the URL's host/credentials to connect,
    # but it is built from the same values so that it stays accurate.
    url = URL.create(
        "postgresql+psycopg2",
        username=conn_params["user"],
        password=conn_params["password"],
        host=conn_params["host"],
        port=conn_params["port"],
        database=conn_params["dbname"],
    )

    def connect():
        """Open a raw psycopg2 connection with the resolved settings."""
        return psycopg2.connect(**conn_params)

    # StaticPool keeps exactly one connection open and hands that same
    # connection to every caller; it does not manage a real pool.
    engine = create_engine(
        url,
        creator=connect,
        poolclass=StaticPool,
    )
    return engine


# Build the engine once for the whole notebook.
engine = get_engine_for_postgres_db()

# Wrap the engine in LangChain's SQLDatabase helper; the toolkit and tools
# created in later cells all operate on this `db` object.
db = SQLDatabase(engine)

In [28]:
import getpass
import os

# Make sure an OpenAI API key is available before the model is created.
# If the environment variable is missing, ask for it interactively instead of
# failing later inside the OpenAI client with a less obvious error.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")

openai_key = os.environ["OPENAI_API_KEY"]

from langchain.chat_models import init_chat_model

# Chat model shared by the SQL toolkit and the agent built in later cells.
llm = init_chat_model("gpt-4o", model_provider="openai")

In [29]:
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

# Bundle the database and the chat model into a toolkit; its tools are listed
# in a later cell and handed to the agent.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

In [30]:
from langchain_community.tools.sql_database.tool import (
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
    QuerySQLDatabaseTool,
)

In [31]:
# Show the tools the toolkit provides (the cell output is their repr).
toolkit.get_tools()

[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000001BC91064C50>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000001BC91064C50>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000001BC91064C50>),
 QuerySQLCheckerTool(description='Use this tool to 

In [32]:
from langchain import hub

# Fetch the shared SQL-agent system prompt from the LangChain hub.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

# The prompt is expected to consist of a single (system) message.
assert len(prompt_template.messages) == 1
print(prompt_template.input_variables)

# Use the dialect of the connected database rather than a hard-coded name:
# the engine above is PostgreSQL, so telling the model to write "SQLite"
# would steer it toward SQL that may not run against this database.
system_message = prompt_template.format(dialect=db.dialect, top_k=5)



['dialect', 'top_k']


In [33]:
from langgraph.prebuilt import create_react_agent

# Build the prebuilt LangGraph agent from the chat model, the toolkit's tools
# and the formatted system prompt.
agent_executor = create_react_agent(llm, toolkit.get_tools(), prompt=system_message)

In [34]:
import gradio as gr
from gradio import ChatMessage

def interact_with_agent(user_message, history):
    """Stream the agent's reply for one user turn to ``gr.ChatInterface``.

    The agent is streamed with ``stream_mode="values"`` and the last message
    of each streamed state is inspected. Assistant messages and tool results
    are collected as Gradio ``ChatMessage`` objects and yielded incrementally.

    Only the messages produced during *this* turn are yielded.
    ``gr.ChatInterface`` appends the yielded response to the conversation
    itself, so mutating and yielding ``history`` would repeat the earlier
    conversation in every later turn.

    Args:
        user_message: Text the user just submitted.
        history: Previous conversation as provided by ``gr.ChatInterface``.
            It is neither modified nor forwarded to the agent, so each turn
            is answered without the earlier conversation as context.

    Yields:
        list[ChatMessage]: Snapshot of the messages generated so far in this
        turn (possibly empty before the agent has produced anything).
    """
    events = agent_executor.stream(
        {"messages": [("user", user_message)]},
        stream_mode="values",
    )

    turn_messages = []
    for event in events:
        last_msg = event["messages"][-1]

        # `text` may be a plain attribute or a method depending on the message
        # object; support both and fall back to the string representation.
        if hasattr(last_msg, "text"):
            response = last_msg.text() if callable(last_msg.text) else last_msg.text
        else:
            response = str(last_msg)

        if isinstance(last_msg, AIMessage):
            # AI messages that carry no text (e.g. ones that only request a
            # tool call) would render as an empty bubble, so skip them.
            if response:
                turn_messages.append(ChatMessage(role="assistant", content=response))
        elif isinstance(last_msg, ToolMessage):
            turn_messages.append(
                ChatMessage(
                    role="assistant",
                    content=response,
                    metadata={"title": "🛠️ Used tool SQL"},
                )
            )

        # Yield a copy so later appends do not alter what was already emitted.
        yield list(turn_messages)

# Chat pane: no avatar for the user, a robot icon for the assistant.
agent_chatbot = gr.Chatbot(
    label="Agent",
    type="messages",
    avatar_images=(
        None,
        "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
    ),
)

# Example questions offered in the UI.
example_prompts = [
    ["Who are the top 3 best selling artists?"],
    ["Which country's customers spent the most?"],
]

# Chat UI that routes every user message through `interact_with_agent`.
demo = gr.ChatInterface(
    interact_with_agent,
    type="messages",
    chatbot=agent_chatbot,
    examples=example_prompts,
    save_history=True,
)


In [35]:
# Start the Gradio app; the cell output below shows the local URL it serves on.
demo.launch()

* Running on local URL:  http://127.0.0.1:7863
* To create a public link, set `share=True` in `launch()`.


