Database connection

In [None]:
from pyprojroot import here
from sqlalchemy import create_engine
from langchain_community.utilities import SQLDatabase

# Resolve the SQLite file through pyprojroot's `here`, so the path does not
# depend on the directory the notebook happens to be launched from.
database_path = here("data/sqldb.db")

connection_string = "sqlite:///" + str(database_path)

# echo=True makes SQLAlchemy log every SQL statement it emits.
engine = create_engine(connection_string, echo=True)

# LangChain wrapper used below for table names, schemas and the SQL dialect.
db = SQLDatabase(engine)

Connection to the desired LLM model

In [None]:
import getpass
import os

from langchain.chat_models import init_chat_model

# Prompt interactively only when the key is absent or empty in the environment.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")

# Chat model shared by every step of the pipeline below.
llm = init_chat_model("gpt-4o-mini", model_provider="openai")

Prompt engineering: pulling the desired SQL-generation prompt from LangChain Hub

In [7]:
from langchain import hub

# Text-to-SQL system prompt published on LangChain Hub. It is filled with
# dialect, input, table_info and top_k when a query is generated.
query_prompt_template = hub.pull(
    "langchain-ai/sql-query-system-prompt"
)



The message passed to the model when running the generated query against the database fails

In [None]:
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.prompts.prompt import PromptTemplate

# Retry prompt, used after a generated query failed to execute. It repeats the
# instructions of the hub's sql-query-system-prompt. It also feeds back the
# failed query ({prev_result}) and the database error ({error}) so the model
# can correct itself.
# The placeholders in `template` must match the keys passed to invoke()
# (dialect, input, table_info, top_k, error, prev_result). A mismatched name is
# not rejected here by default; it only fails when the prompt is invoked.
error_template = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k", "error", "prev_result"],
    template=(
        # Adjacent literals are concatenated verbatim, so each one ends with
        # explicit whitespace to keep the sentences apart.
        "Given an input question, create a syntactically correct {dialect} query to run to help find the answer. "
        "Unless the user specifies in his question a specific number of examples they wish to obtain, "
        "always limit your query to at most {top_k} results. "
        "You can order the results by a relevant column to return the most interesting examples in the database.\n\n"
        "Never query for all the columns from a specific table, only ask for the few relevant columns given the question.\n\n"
        "Pay attention to use only the column names that you can see in the schema description. "
        "Be careful to not query for columns that do not exist. "
        "Also, pay attention to which column is in which table.\n\n"
        "Only use the following tables:\n{table_info}\n\n"
        "Question: {input}\n\n"
        "You have already been asked this question and you produced the following query: {prev_result}\n"
        "However, it generated the following error while trying to execute: {error}\n\n"
        "Based on the question and the previously generated error, please try to fix the query or generate a new one."
    ),
)

sys_msg = SystemMessagePromptTemplate(prompt=error_template)
error_prompt_template = ChatPromptTemplate([sys_msg])

input_variables=['dialect', 'input', 'table_info', 'top_k'] input_types={} partial_variables={} metadata={'lc_hub_owner': 'langchain-ai', 'lc_hub_repo': 'sql-query-system-prompt', 'lc_hub_commit_hash': '5d6c20e97a0a3dc6f955719a185eb8987d9fce8a04ec1df70344ff92497ebcfb'} messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['dialect', 'input', 'table_info', 'top_k'], input_types={}, partial_variables={}, template='Given an input question, create a syntactically correct {dialect} query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n\nNever query for all the columns from a specific table, only ask for a the few relevant columns given the question.\n\nPay attention to use only the column names that you can see in the schema description. 

Creating a state object that tracks the data passed among the tools

In [None]:
from typing_extensions import List, NotRequired, TypedDict


class State(TypedDict):
    """State shared between the text-to-SQL steps.

    ``error``, ``counter`` and ``prev_result`` are only filled in once a
    generated query has failed to execute. The initial request does not set
    them, so they are marked ``NotRequired``. The query generator inspects
    ``error`` to decide whether to use the retry prompt.
    """

    question: str  # the user's natural-language request
    query: str  # SQL generated by the LLM
    result: str  # output of executing ``query``
    answer: str
    error: NotRequired[str]  # error text from the last failed execution
    counter: NotRequired[int]  # number of failed execution attempts so far
    tables: List[str]  # table names selected for this question
    prev_result: NotRequired[str]  # the query that produced ``error``

Creating a set of tools responsible for selecting the most relevant tables for a particular request and then generating the query

In [None]:
from typing_extensions import Annotated
from langchain.tools import tool
import yaml


class QueryOutput(TypedDict):
    """Generated SQL query."""

    # Output schema for `llm.with_structured_output(QueryOutput)`. The class
    # docstring and the Annotated description below presumably become part of
    # the schema sent to the model. Treat them as prompt text: editing them
    # changes what the LLM is told. The `...` marks the field as required.
    query: Annotated[str, ..., "Syntactically valid SQL query."]


@tool
def handle_request(user_input: str) -> State:
    """Build the initial state for a new natural-language question.

    Args:
        user_input: The user's question.

    Returns:
        A fresh state with empty query/result/answer/tables and a zero retry
        counter. ``error`` is deliberately left unset: the query generator
        switches to its retry prompt once an error has been recorded.
    """
    return {
        "question": user_input,
        "query": "",
        "result": "",
        "answer": "",
        "tables": [],
        # Initialised here so a failed execution can increment it.
        "counter": 0,
    }

@tool
def select_tables(state: State):
    """Ask the LLM which database tables are relevant to the user's question.

    The prompt text is read from ``llm_model.table_prompt`` in
    ``../config/config.yml`` (relative to the current working directory). It is
    formatted with ``question`` and ``table_names``. The model's reply is
    treated as a comma-separated list of table names.

    Returns:
        ``{"tables": [...]}`` with the names the model chose. Names are
        stripped of whitespace/quotes, de-duplicated, and restricted to tables
        that exist in the database. If none of them match, every table is
        returned so the next step still has a schema to work with.

    Raises:
        KeyError: if the config has no ``llm_model.table_prompt`` entry.
    """
    question = state["question"]
    with open("../config/config.yml", "r", encoding="utf-8") as file:
        config = yaml.safe_load(file) or {}
    prompt = (config.get("llm_model") or {}).get("table_prompt")
    if prompt is None:
        raise KeyError("'llm_model.table_prompt' is missing from ../config/config.yml")

    tables = db.get_table_names()
    filled_prompt = prompt.format(question=question, table_names=tables)
    # `predict` is deprecated; `invoke` returns a message whose `.content` is the text.
    reply = llm.invoke(filled_prompt).content

    # A reply like "orders, customers" would otherwise yield " customers",
    # which get_table_info rejects as an unknown table.
    known = set(tables)
    selected = []
    for raw_name in reply.split(","):
        name = raw_name.strip().strip("`'\"")
        if name in known and name not in selected:
            selected.append(name)
    return {"tables": selected or list(tables)}

@tool
def confirm_table_selection(state: State):
    """Let the user confirm the selected tables or type a corrected list.

    Answering "yes" (or "y", case-insensitive) keeps ``state["tables"]``.
    Any other answer prompts for a comma-separated list of table names. An
    empty reply keeps the original selection rather than leaving no tables.

    Returns:
        ``{"tables": [...]}`` with the confirmed or corrected table names.
    """
    tables = state["tables"]
    confirm = input(f"Are the next table correct? {tables} (yes/no)")
    if confirm.strip().lower() in ("yes", "y"):
        return {"tables": tables}

    corrected = input("Enter the correct table names, separated by commas: ")
    new_tables = [name.strip() for name in corrected.split(",") if name.strip()]
    return {"tables": new_tables or tables}

@tool
def schema_retriever_query_generator(state: State):
    """Generate SQL query to fetch information.

    Looks up the schema of ``state["tables"]`` and asks the LLM for a query
    that answers ``state["question"]``. If an earlier attempt recorded a
    non-empty ``state["error"]``, the retry prompt is used instead. That prompt
    also shows the model the failed query (``prev_result``) and the error text.

    Returns:
        ``{"query": <SQL string>}``.
    """
    schemas = db.get_table_info(state["tables"])
    prompt_args = {
        "dialect": db.dialect,
        "input": state["question"],
        "table_info": schemas,
        "top_k": 10,
    }

    # Only take the retry path when there is an actual error to feed back;
    # an empty string carries no information for the model.
    if state.get("error"):
        prompt = error_prompt_template.invoke(
            {
                **prompt_args,
                "error": state["error"],
                "prev_result": state.get("prev_result", ""),
            }
        )
    else:
        prompt = query_prompt_template.invoke(prompt_args)

    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return {"query": result["query"]}

Block of code responsible for running the generated query against the database

In [None]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool

execute_query_tool = QuerySQLDatabaseTool(db=db)


@tool
def execute_query(state: State):
    """Execute SQL query.

    Runs ``state["query"]`` against the database.

    Returns:
        ``{"result": ...}`` on success. On failure, returns a state update with
        three keys:

        - ``error``: the error text.
        - ``prev_result``: the failed query.
        - ``counter``: the incremented retry counter.

        The query generator can then retry with that context.
    """
    query = state["query"]
    try:
        result = execute_query_tool.invoke(query)
    except Exception as exc:  # any failure is fed back to the LLM for a retry
        error_text = str(exc)
    else:
        # QuerySQLDatabaseTool reports database errors by *returning* a string
        # that starts with "Error:" instead of raising, so detect that too.
        if not (isinstance(result, str) and result.startswith("Error:")):
            return {"result": result}
        error_text = result

    # Return the update rather than mutating `state` in place; mutations to
    # the input are not reliably propagated back to the caller.
    return {
        "error": error_text,
        "prev_result": query,
        "counter": state.get("counter", 0) + 1,
    }