In [None]:
import os
# Never paste the key into the notebook. Reuse the value already present in the
# environment; only when it is missing or empty, prompt for it without echoing
# so the secret does not end up in the saved file or its outputs.
if not os.environ.get("OPENAI_API_KEY"):
    from getpass import getpass

    os.environ["OPENAI_API_KEY"] = getpass("Enter your OpenAI API key: ")

In [2]:
from langchain_core.prompts import ChatPromptTemplate

# Prompt for the text-to-SQL step. `{schema}` and `{question}` are filled in
# when the prompt is formatted or invoked.
template = "\n".join(
    [
        "Based on the table schema below, write a SQL query that would answer the user's question:",
        "{schema}",
        "",
        "Question: {question}",
        "SQL Query:",
    ]
)
prompt = ChatPromptTemplate.from_template(template)

In [3]:
# Render the prompt with placeholder values to inspect the exact text produced.
prompt.format(
    schema="my schema",
    question="how many users are there?",
)

"Human: Based on the table schema below, write a SQL query that would answer the user's question:\nmy schema\n\nQuestion: how many users are there?\nSQL Query:"

In [4]:
from langchain_community.utilities import SQLDatabase

# Connection settings come from the environment so that no credential is stored
# in the notebook. The non-secret parts default to the local Chinook setup.
import os
from getpass import getpass
from urllib.parse import quote_plus

db_user = os.environ.get("MYSQL_USER", "root")
db_host = os.environ.get("MYSQL_HOST", "localhost")
db_port = os.environ.get("MYSQL_PORT", "3306")
db_name = os.environ.get("MYSQL_DATABASE", "Chinook")
# Prompt (without echo) only when the password is not provided via MYSQL_PASSWORD.
db_password = os.environ.get("MYSQL_PASSWORD") or getpass("MySQL password: ")

# quote_plus escapes characters such as "@", ":" or "/" that would otherwise
# corrupt the URL when they appear in the user name or password.
db_uri = (
    f"mysql+mysqlconnector://{quote_plus(db_user)}:{quote_plus(db_password)}"
    f"@{db_host}:{db_port}/{db_name}"
)
# NOTE(review): this database later executes model-generated SQL; prefer a
# dedicated read-only MySQL account over root.
db = SQLDatabase.from_uri(db_uri)

In [5]:
# Connectivity check: fetch a few rows from the album table.
db.run(
    "SELECT * from album LIMIT 5"
)

"[(1, 'For Those About To Rock We Salute You', 1), (2, 'Balls to the Wall', 2), (3, 'Restless and Wild', 2), (4, 'Let There Be Rock', 1), (5, 'Big Ones', 3)]"

In [6]:
def get_schema(_):
    """Return the table information reported by the connected database.

    The single positional argument is ignored. It is accepted so the function
    can be used as a step in a runnable chain, where it is called with the
    chain's input.
    """
    table_info = db.get_table_info()
    return table_info

In [None]:
# Creating the SQL chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

In [8]:
# temperature=0 makes the generated SQL as repeatable as the API allows, so
# re-running the notebook does not produce a different query each time.
llm = ChatOpenAI(temperature=0)

# Text-to-SQL chain: question -> (question + schema) -> prompt -> model -> str.
sql_chain = (
    # Add the live table info under the "schema" key expected by the prompt.
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    # bind() attaches a fixed call argument to the model. Stop sequences are
    # given as a list; generation is cut at "\nSQL Result:" so the model
    # returns only the query and does not invent a result for it.
    | llm.bind(stop=["\nSQL Result:"])
    | StrOutputParser()
)

In [None]:
# Trial run: the chain takes a user question and returns a SQL query string.
# Note: $5 of credit was added to the OpenAI plan at
# https://platform.openai.com/settings/organization/billing/overview
# (presumably required for the API call to go through).
sql_chain.invoke(
    {"question": "how many artists are there?"}
)

'SELECT COUNT(*) AS total_artists\nFROM artist;'

In [11]:
# Prompt for the final step of the full chain: given the schema, the question,
# the generated SQL and the raw SQL result, ask for a natural-language answer.
template = "\n".join(
    [
        "Based on the table schema below, question, sql query, and sql response, write a natural language response:",
        "{schema}",
        "",
        "Question: {question}",
        "SQL Query: {query}",
        "SQL Response: {response}",
    ]
)

prompt_response = ChatPromptTemplate.from_template(template)

In [12]:
def run_query(query):
    """Execute a SQL query against the connected database and return the result.

    The query usually comes from a language model, so it is normalized first:
    surrounding whitespace is removed and, when the whole text is wrapped in a
    Markdown code fence (```sql ... ```), only the fenced content is executed.
    Without this, a fenced answer reaches the database verbatim and fails as a
    syntax error.

    Args:
        query: SQL text, optionally wrapped in a Markdown code fence.

    Returns:
        Whatever ``db.run`` returns for the cleaned query.
    """
    import re

    sql = query.strip()
    # Match an opening fence with an optional language tag on its own line,
    # the statement itself, and the closing fence.
    fenced = re.fullmatch(r"```[A-Za-z]*[ \t]*\n(.*?)\s*```", sql, flags=re.DOTALL)
    if fenced:
        sql = fenced.group(1).strip()
    # NOTE(review): model-generated SQL is untrusted input and is executed as-is;
    # connect with a read-only database account to limit what it can do.
    return db.run(sql)

In [13]:
# Run the SQL produced by the trial above directly against the database.
run_query(
    "SELECT COUNT(*) AS total_artists\nFROM artist;"
)

'[(275,)]'

In [29]:
def _run_generated_query(variables):
    """Execute the SQL stored under the "query" key of the chain state."""
    return run_query(variables["query"])


# Step 1: generate the SQL for the question and keep it under "query".
_with_query = RunnablePassthrough.assign(query=sql_chain)
# Step 2: add the schema and the raw database response for that query.
_with_context = _with_query.assign(
    schema=get_schema,
    response=_run_generated_query,
)
# Step 3: have the model phrase the result as a natural-language answer.
full_chain = _with_context | prompt_response | llm | StrOutputParser()

In [30]:
# End-to-end check: question in, natural-language answer out.
full_chain.invoke(
    {"question": "how many artists are there?"}
)

'There are 275 artists in the database.'

In [31]:
# A looser question: "popular" is not defined by the question itself, so the
# model chooses how to measure it when writing the SQL.
full_chain.invoke(
    {"question": "top three most popular artists"}
)

'The top three most popular artists in the database based on the number of albums they have are Iron Maiden with 21 albums, Led Zeppelin with 14 albums, and Deep Purple with 11 albums.'