In [1]:
from langchain.prompts import ChatPromptTemplate
import constants.constants as constants
from langchain.utilities import SQLDatabase
from langchain.chat_models import ChatOpenAI
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_template(template)

In [2]:
db = SQLDatabase.from_uri("sqlite:///./raika.db")

In [3]:
def get_schema(_):
    return db.get_table_info()

In [4]:
def run_query(query):
    return db.run(query)

In [5]:
model = ChatOpenAI(openai_api_key=constants.OPENAI_API_KEY)

sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | model.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

In [6]:
sql_response.invoke({"question": "Wie viele Filalien gibt es"})

'SELECT COUNT(*) FROM raika'

In [7]:
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)

In [8]:
full_chain = (
    RunnablePassthrough.assign(query=sql_response)
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: db.run(x["query"]),
    )
    | prompt_response
    | model
)

In [9]:
full_chain.invoke({"question": "Wie viele Filalien gibt es"})

AIMessage(content='Es gibt eine Filiale in der Tabelle "raika".')

In [10]:
db.run("""select count(*) from raika""")

'[(1,)]'

In [14]:
print(get_schema())

TypeError: get_schema() missing 1 required positional argument: '_'

In [12]:
db._all_tables

{'raika'}