In [None]:
from typing import List
import sqlalchemy
from sqlalchemy.engine.base import Engine
from sqlalchemy import text, create_engine
import pandas as pd
from langchain_core.tools import tool
from langchain_core.runnables.config import RunnableConfig
from langchain_openai import ChatOpenAI

In [None]:
# Chat model used with the tools below. temperature=0.0 minimises sampling
# randomness; base_url points the client at the openai.vocareum.com endpoint
# instead of the library's default OpenAI URL.
llm = ChatOpenAI(model="gpt-4o-mini", base_url="https://openai.vocareum.com/v1", temperature=0.0)

In [None]:
# SQLAlchemy engine for the SQLite file `sales.db`; the three-slash form makes
# the path relative to the notebook's working directory.
# Plain string: the URL has no placeholders, so an f-string is unnecessary.
db_engine = create_engine("sqlite:///sales.db")

In [None]:
# Schema-reflection helper (tables, columns) bound to `db_engine`.
inspector = sqlalchemy.inspect(db_engine, raiseerr=True)

In [None]:
# Names of every table in the default schema.
inspector.get_table_names(schema=None)

In [None]:
# Table explored in the following cells; display its column metadata.
table_name = "sales"
inspector.get_columns(table_name=table_name)

In [None]:
# Keep the column metadata and extract just the names, in reflection order,
# to label the raw query rows below.
schema = inspector.get_columns(table_name)
column_names = [col_info["name"] for col_info in schema]
column_names

In [None]:
# Preview query: up to 10 rows of the table (no ORDER BY, so which rows is unspecified).
sql = "SELECT * FROM {} LIMIT {}".format(table_name, 10)

In [None]:
# Run the preview query and materialise every row while the connection is
# still open; the `with` block releases the connection on exit.
with db_engine.begin() as connection:
  cursor_result = connection.execute(text(sql))
  answer = cursor_result.fetchall()

answer

In [None]:
# The same rows as a DataFrame, labelled with the reflected column names.
pd.DataFrame(data=answer, columns=column_names)

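## Tools

LangChain tools that let the model inspect and query the database. Each tool reads its SQLAlchemy engine from `config["configurable"]["db_engine"]` rather than from a global, so the same tools can be pointed at any database at call time.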

In [10]:
@tool
def list_table_tool(config: RunnableConfig) -> List[str]:
  """
  List all tables in database
  """
  # NOTE: the docstring above is the tool description shown to the model.
  # The engine is injected per call via config["configurable"]["db_engine"];
  # `or {}` also covers an explicit None value for "configurable".
  db_engine: Engine = (config.get("configurable") or {}).get("db_engine")
  if db_engine is None:
    # Fail with an actionable message instead of the generic error
    # sqlalchemy.inspect(None) would raise (e.g. when the config key is mis-cased).
    raise ValueError(
      "list_table_tool needs an SQLAlchemy engine in config['configurable']['db_engine']"
    )

  inspector = sqlalchemy.inspect(db_engine)
  return inspector.get_table_names()

In [11]:
@tool
def get_table_schema_tool(table_name:str, config: RunnableConfig) -> List[dict]:
  """
  Get schema information about a table. Returns a list of dictionaries.
  - name is the column name
  - type is the column type
  - nullable is whether the column is nullable or not
  - default is the default value of the column
  - primary_key is whether the column is a primary key or not

  Args:
    table_name (str): Table name
  """
  # NOTE: the docstring above is the tool description shown to the model.
  # Return annotation is List[dict]: Inspector.get_columns returns one dict per
  # column (the previous List[str] annotation was wrong).
  db_engine: Engine = (config.get("configurable") or {}).get("db_engine")
  if db_engine is None:
    # Same explicit failure as the other tools when the engine is not injected.
    raise ValueError(
      "get_table_schema_tool needs an SQLAlchemy engine in config['configurable']['db_engine']"
    )

  inspector = sqlalchemy.inspect(db_engine)
  return inspector.get_columns(table_name)


In [12]:
@tool
def execute_sql_tool(query:str, config: RunnableConfig) -> List[sqlalchemy.engine.Row]:
  """
  Execute SQL query and return result.
  This will automatically connect to the database and execute the query.
  However, if the query is not valid, an error will be raised
  
  Args:
    query (str): SQL query
  """
  # NOTE: the docstring above is the tool description shown to the model.
  # Returns the fetched result rows (the previous `-> int` annotation was wrong).
  db_engine: Engine = (config.get("configurable") or {}).get("db_engine")
  if db_engine is None:
    # Same explicit failure as the other tools when the engine is not injected.
    raise ValueError(
      "execute_sql_tool needs an SQLAlchemy engine in config['configurable']['db_engine']"
    )

  # SECURITY(review): `query` is typically model-generated and is executed
  # verbatim. engine.begin() commits on success, so a row-returning write
  # (e.g. INSERT ... RETURNING) would be persisted. Consider a read-only
  # connection or a SELECT-only allow-list if the model is not fully trusted.
  with db_engine.begin() as connection:
    # fetchall() materialises the rows before the connection is released.
    rows = connection.execute(text(query)).fetchall()

  return rows


In [None]:
# Same SQLite URL as the exploration section; redefined here so this section
# does not depend on that earlier cell. Plain string: no placeholders, so no
# f-string is needed.
db_engine = create_engine("sqlite:///sales.db")

In [None]:
# Runtime config passed to every tool call. The key must be the lowercase
# "configurable": the tools read config["configurable"]["db_engine"], so the
# previous "Configurable" key left db_engine as None and every call failed.
config: RunnableConfig = {"configurable": {"db_engine": db_engine}}

In [None]:
# Exercise the tool directly (outside any agent), passing the runtime config.
tables = list_table_tool.invoke(input={}, config=config)
tables

In [None]:
# Column metadata for every table, keyed by table name.
# Use .invoke so `config` (and the db_engine it carries) reaches the tool.
# Calling the tool object directly goes through the deprecated
# BaseTool.__call__, whose second parameter is `callbacks`, not a RunnableConfig.
schemas = {
  table: get_table_schema_tool.invoke({"table_name": table}, config)
  for table in tables
}

schemas

In [None]:
# Smoke-test the SQL tool on the first table reported by list_table_tool.
sql = "SELECT * FROM " + tables[0] + " LIMIT 10"
result = execute_sql_tool.invoke(input={"query": sql}, config=config)
result