In [None]:
import os
from getpass import getpass

# Never overwrite a key that is already exported in the environment, and never
# paste the secret into the notebook itself: prompt for it only when missing.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

In [None]:
from langchain_core.prompts import ChatPromptTemplate

# Prompt that asks the model for a SQL query, given the table schema and the
# user's question. Built from adjacent string literals, one per template line.
template = (
    "Based on the table schema below, write a SQL query that would answer the user's question:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query:"
)
prompt = ChatPromptTemplate.from_template(template)

In [None]:
# Render the SQL-generation prompt with placeholder values to inspect the final text.
prompt.format(
    **{"schema": "my schema", "question": "How many users there?"}
)

In [None]:
from langchain_community.utilities import SQLDatabase

from pathlib import Path

# if you are using SQLite
sqlite_uri = 'sqlite:///./Chinook.db'

# if you are using MySQL
# NOTE(review): placeholder credentials for a local test server; never commit real ones.
mysql_uri = 'mysql+mysqlconnector://root:admin@localhost:3306/test_db'

# Pick the backend by pointing db_uri at one of the URIs above.
db_uri = sqlite_uri

# Opening a SQLite path that does not exist creates a new, empty database file
# instead of failing, and the mistake only surfaces later as "no such table"
# errors. Fail early with a clear message when the file is missing.
sqlite_prefix = 'sqlite:///'
if db_uri.startswith(sqlite_prefix):
    sqlite_path = Path(db_uri[len(sqlite_prefix):])
    if not sqlite_path.is_file():
        raise FileNotFoundError(
            f"SQLite database file not found: {sqlite_path.resolve()}. "
            "Put Chinook.db in the working directory or update sqlite_uri."
        )

db = SQLDatabase.from_uri(db_uri)

In [None]:
# Smoke-test the connection by reading a handful of rows from the Album table.
db.run(
    "SELECT * FROM Album LIMIT 5"
)

In [None]:
def get_schema(db):
    """Return the table schema description (``get_table_info()``) for the prompt.

    ``RunnablePassthrough.assign(schema=get_schema)`` calls this function with
    the chain's input mapping (e.g. ``{"question": ...}``), not with a database
    handle. Because the parameter is named ``db`` it shadows the notebook-level
    ``db`` object, so calling ``db.get_table_info()`` unconditionally fails with
    ``AttributeError`` on a dict.

    Args:
        db: Either an object exposing ``get_table_info()`` (used directly), or
            any other value such as the chain's input dict, in which case the
            notebook-level ``db`` is used instead.

    Returns:
        Whatever ``get_table_info()`` returns for the selected database.
    """
    if hasattr(db, "get_table_info"):
        database = db
    else:
        # The parameter shadows the global of the same name, so look the
        # notebook-level database up explicitly.
        database = globals()["db"]
    return database.get_table_info()
   

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

# Stage 1: add a "schema" key to the chain input by calling get_schema.
schema_step = RunnablePassthrough.assign(schema=get_schema)

# Stage 3: the chat model with a stop sequence bound to it.
sql_llm = llm.bind(stop=["\nSQLResult:"])

# schema -> prompt -> model -> plain string
sql_chain = schema_step | prompt | sql_llm | StrOutputParser()

In [None]:
# Try the SQL-writing chain on a sample question; the cell shows the chain's string output.
user_question = "how many albums are there in the database?"
sql_chain.invoke(dict(question=user_question))

In [None]:
# Prompt that turns schema + question + SQL query + SQL result into a
# natural-language reply. Built from adjacent string literals, one per line.
template = (
    "Based on the table schema below, question, sql query, and sql response, write a natural language response:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query: {query}\n"
    "SQL Response: {response}"
)
prompt_response = ChatPromptTemplate.from_template(template)

In [None]:
def run_query(query):
    """Run ``query`` against the notebook-level ``db`` and hand back its result.

    The statement is forwarded to ``db.run`` unchanged, with no validation or
    filtering applied here.
    """
    result = db.run(query)
    return result

In [None]:
# End-to-end chain: question -> SQL query -> query result -> natural-language answer.
full_chain = (
    # First add "query" (the SQL text produced by sql_chain), then add the
    # schema and the result of executing that query.
    RunnablePassthrough.assign(query=sql_chain).assign(
        schema=get_schema,
        response=lambda variables: run_query(variables["query"]),
    )
    # Use the response template, which consumes {query} and {response}.
    # Piping into `prompt` would ignore both and just ask for SQL again.
    | prompt_response
    | llm
    | StrOutputParser()
)

In [None]:
# Run the end-to-end chain on the sample question.
user_question = "how many albums are there in the database?"
full_chain.invoke(dict(question=user_question))