## Chat With a MySQL Database Using Python and LangChain
- [source](https://alejandro-ao.com/chat-with-mysql-using-python-and-langchain/)

In [1]:
from langchain_community.utilities import SQLDatabase
from dotenv import load_dotenv
from IPython.display import display, Markdown
import os

In [2]:
from urllib.parse import quote

load_dotenv()

# Credentials come from the environment / .env file rather than being hardcoded.
mysql_password = os.getenv('MYSQL_PASSWORD')
if mysql_password is None:
    # Without this check the f-string below would connect with the literal
    # password "None", producing a confusing authentication error.
    raise RuntimeError("MYSQL_PASSWORD is not set; add it to your environment or .env file")

# URL-encode the password: characters such as '@', ':', '/' or '#' would
# otherwise corrupt the SQLAlchemy connection URI. safe="" also encodes '/'.
mysql_uri = f'mysql+mysqlconnector://root:{quote(mysql_password, safe="")}@localhost:3306/Chinook'
# PostgreSQL equivalent: "postgresql://username:password@localhost:5432/mydatabase"

db = SQLDatabase.from_uri(mysql_uri)


In [3]:
# Smoke test: pull a handful of rows to confirm the connection works.
sample_sql = "SELECT * FROM Album LIMIT 5;"
db.run(sample_sql)

"[(1, 'For Those About To Rock We Salute You', 1), (2, 'Balls to the Wall', 2), (3, 'Restless and Wild', 2), (4, 'Let There Be Rock', 1), (5, 'Big Ones', 3)]"

In [5]:
from langchain_core.prompts import ChatPromptTemplate

# Prompt for the SQL-generation step. The model tends to wrap its answer in
# Markdown fences and append commentary (see the recorded sql_chain output),
# so the instructions ask explicitly for the bare query.
template = """Based on the table schema below, write only the SQL query that would answer the user's question.
Return the raw SQL query only, with no explanation and no Markdown formatting.
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_template(template)


In [6]:
def get_schema(_):
    """Return the table definitions of ``db`` for injection into a prompt.

    The single positional argument is the chain input handed over by
    ``RunnablePassthrough.assign``; it is ignored.
    """
    return db.get_table_info()


In [26]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
# `langchain.llms` is a deprecated re-export; import from langchain_community,
# which this notebook already uses for SQLDatabase.
from langchain_community.llms import Ollama

# Low temperature reduces sampling randomness so the same question yields
# the same SQL across runs.
LLM_TEMPERATURE = 0.0
llm = Ollama(temperature=LLM_TEMPERATURE, model="qwen2:1.5b")
# Alternative (larger) model: Ollama(temperature=0.0, model="llama3.1:8b-instruct-q2_K")

# {"question": ...} -> add "schema" -> fill prompt -> LLM -> plain string.
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    # Stop if the model starts inventing a "SQLResult:" section after the query.
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)


In [8]:
# Ask the SQL-generation chain alone; the cell displays the raw model output.
user_question = "how many albums are there in the database?"
sql_chain.invoke(dict(question=user_question))

'```sql\nSELECT COUNT(*) FROM Album;\n```\n\nThis query will return a single row with a count of all albums in the `Album` table, which gives us an idea of how many albums are stored in our database.'

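## Create the Full Chain

The full chain generates a SQL query from the question, runs it against the database, and asks the LLM to phrase the result as a natural-language answer.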

In [9]:
# Prompt for the final step: turn the question, the generated SQL and its
# result into a natural-language answer.
template = (
    "Based on the table schema below, write a natural language response based on the question, SQL query, and SQL response:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query: {query} \n"
    "SQL Response: {response}"
)

prompt_response = ChatPromptTemplate.from_template(template)


In [10]:
import re
# Opening code fence plus an optional language tag. The tag is only consumed
# when it ends the fence line ("```sql\n", "```SQL\n", "```mysql\n") or is a
# bare "sql" word ("```sql SELECT ...```"), so a query that starts on the
# fence line itself ("```SELECT 1```") is not mistaken for a tag.
_FENCE_OPEN = r"```[^\S\n]*(?:[\w+-]*[^\S\n]*\n|sql\b)?"
_FENCED_SQL_RE = re.compile(_FENCE_OPEN + r"(.*?)```", re.DOTALL | re.IGNORECASE)
# Same opening fence with no closing fence, e.g. when generation was cut short.
_UNCLOSED_FENCE_RE = re.compile(_FENCE_OPEN + r"(.*)", re.DOTALL | re.IGNORECASE)


def clean_sql_query(query):
    """
    Removes any non-SQL formatting from the given SQL query string and returns a runnable SQL query.

    Extracts the contents of the first Markdown code fence (any language tag,
    any case, or none), ignoring text before and after it. An opening fence
    that is never closed is also handled. Unfenced input is returned with
    surrounding whitespace stripped.

    :param query: A string containing the SQL query, possibly with non-SQL formatting (e.g., Markdown).
    :return: A string containing the formatted, runnable SQL query.
    """
    match = _FENCED_SQL_RE.search(query)
    if match:
        return match.group(1).strip()

    match = _UNCLOSED_FENCE_RE.search(query)
    # Require some content after the fence so a stray trailing ``` is not
    # turned into an empty query.
    if match and match.group(1).strip():
        return match.group(1).strip()

    return query.strip()

In [40]:
import re

# Statement types the chain may execute. The model's SQL is run as-is over a
# `root` connection, so anything else (DROP, DELETE, UPDATE, ...) is refused.
# This is a coarse first line of defence only; connecting as a read-only
# MySQL user is the real protection.
_READ_ONLY_SQL_RE = re.compile(r"\s*(SELECT|WITH|SHOW|DESCRIBE|DESC|EXPLAIN)\b", re.IGNORECASE)


def run_query(query):
    """Clean, print and execute an LLM-generated SQL query against ``db``.

    :param query: Raw model output, possibly wrapped in Markdown fences.
    :return: The result of ``db.run`` for the cleaned query, or ``None`` if
        the statement was refused as not read-only or raised an error.
    """
    try:
        # Strip Markdown formatting so only the SQL statement remains.
        query = clean_sql_query(query)
        print(query)
        if not _READ_ONLY_SQL_RE.match(query):
            print("Refusing to run a non read-only SQL statement.")
            return None
        # Execute the cleaned SQL query using SQLDatabase
        result = db.run(query)
        print(result)
        return result

    except Exception as e:
        print(f"An error occurred: {e}")
        return None


In [41]:
# question -> SQL -> query result -> natural-language answer.
full_chain = (
    # Clean the generated SQL once, so the statement shown to the response
    # prompt is exactly the one that gets executed (no fences/commentary).
    RunnablePassthrough.assign(query=sql_chain | clean_sql_query).assign(
        # sql_chain only outputs the query string, so the schema is fetched
        # again here for the response prompt.
        schema=get_schema,
        response=lambda inputs: run_query(inputs["query"]),
    )
    | prompt_response
    | llm
    # Same output handling as sql_chain: always yield a plain string.
    | StrOutputParser()
)


In [42]:
# End-to-end run: the printed lines come from run_query, the cell output is the answer.
full_chain.invoke(dict(question="how many albums are there in the database?"))

SELECT COUNT(AlbumId) FROM Album;
[(347,)]


'There are 347 albums in the database.'

In [44]:
res = full_chain.invoke({"question": 'get all track with Rabin'})
# The answer is Markdown (bullet lists etc.); render it rather than printing raw text.
display(Markdown(res))


SELECT * FROM `Track` WHERE Composer = 'Rabin';
[(1, 'For Those About To Rock (We Salute You)', 1, 1, 1, 'Rabin', 343719, 11170334, Decimal('0.99'))]
The query successfully retrieves all the columns and rows from Track table where Composer is equal to 'Rabin'. Here are the details of the track:

- Name: For Those About To Rock (We Salute You)
- AlbumId: 1
- MediaTypeId: 1
- GenreId: 1
- Composer: Rabin
- Milliseconds: 343719
- Bytes: 11170334
- UnitPrice: '0.99'
