In [3]:
import os
import re
from getpass import getpass
from urllib.parse import quote_plus

from langchain.utilities import SQLDatabase
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import create_sql_query_chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI

In [6]:
# Connect to the MySQL database.
# Connection settings come from the environment so that credentials are not
# stored in the notebook. Non-secret settings fall back to the local
# development values; the password is prompted for when MYSQL_PASSWORD is unset.
host = os.environ.get("MYSQL_HOST", "localhost")
port = os.environ.get("MYSQL_PORT", "3306")
username = os.environ.get("MYSQL_USER", "root")
password = os.environ.get("MYSQL_PASSWORD") or getpass("MySQL password: ")
database_schema = os.environ.get("MYSQL_DATABASE", "text_to_sql")

# Percent-encode the credentials: characters such as '@', ':' or '/' in a raw
# password would otherwise be parsed as URL delimiters and corrupt the URI.
mysql_uri = (
    f"mysql+pymysql://{quote_plus(username)}:{quote_plus(password)}"
    f"@{host}:{port}/{database_schema}"
)

db = SQLDatabase.from_uri(mysql_uri, sample_rows_in_table_info=2)

In [7]:
# Rebuild the database handle with one sample row per table, then display
# the schema text that `get_table_info()` produces for it.
db = SQLDatabase.from_uri(
    mysql_uri,
    sample_rows_in_table_info=1,
)

db.get_table_info()

'\nCREATE TABLE `2017_budgets` (\n\t`Product_Name` TEXT, \n\t`2017_Budgets` DOUBLE\n)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n\n/*\n1 rows from 2017_budgets table:\nProduct_Name\t2017_Budgets\nProduct 1\t3016489.2089999998\n*/\n\n\nCREATE TABLE customers (\n\t`Customer_Index` INTEGER, \n\t`Customer_Names` TEXT\n)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n\n/*\n1 rows from customers table:\nCustomer_Index\tCustomer_Names\n1\tGeiss Company\n*/\n\n\nCREATE TABLE products (\n\t`Index` INTEGER, \n\t`Product_Name` TEXT\n)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n\n/*\n1 rows from products table:\nIndex\tProduct_Name\n1\tProduct 1\n*/\n\n\nCREATE TABLE regions (\n\tid INTEGER, \n\tname TEXT, \n\tcounty TEXT, \n\tstate_code TEXT, \n\tstate TEXT, \n\ttype TEXT, \n\tlatitude DOUBLE, \n\tlongitude DOUBLE, \n\tarea_code INTEGER, \n\tpopulation INTEGER, \n\thouseholds INTEGER, \n\tmedian_income INTEGER, \n\tland_area INTEGER, \

In [8]:
# Create the LLM prompt template.
# The `{schema}` and `{question}` placeholders are filled in from the input
# mapping when the prompt is formatted. `ChatPromptTemplate` is imported in
# the notebook's top import cell rather than here.
template = """
Based on the table schema below, write a SQL query that would answer the user's question:
Remember: Only provide me the sql query dont include anything else. Provide me sql query in a single line dont add line breaks
Table Schema: {schema}
Question: {question}
SQL Query:
"""

prompt = ChatPromptTemplate.from_template(template)

In [9]:
# Fetch the schema description of the database
def get_schema(db):
    """Return whatever ``db.get_table_info()`` reports for the given handle."""
    return db.get_table_info()

In [None]:
# Instantiate the Gemini chat model used to generate SQL.
# The API key is read from the GOOGLE_API_KEY environment variable, with an
# interactive prompt as fallback, so that no secret is stored in the notebook.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    api_key=os.environ.get("GOOGLE_API_KEY") or getpass("Google API key: "),
)

E0000 00:00:1759686826.521699 10724683 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.


In [15]:
# Assemble the question -> SQL pipeline step by step:
#   1. add a `schema` key (from get_schema(db)) to the input mapping,
#   2. render the prompt,
#   3. call the LLM with a stop sequence of "\nSQLResult:",
#   4. parse the model output into a plain string.
with_schema = RunnablePassthrough.assign(schema=lambda _inputs: get_schema(db))
llm_with_stop = llm.bind(stop=["\nSQLResult:"])
sql_chain = with_schema | prompt | llm_with_stop | StrOutputParser()

In [20]:
# Try the chain on one sample question and show the raw model reply
question = "What is the total 'Line Total' for Geiss Company"
resp = sql_chain.invoke({"question": question})
print(resp)

```sql
SELECT sum(T1.Line_Total) FROM sales_order AS T1 INNER JOIN customers AS T2 ON T1.Customer_Name_Index = T2.Customer_Index WHERE T2.Customer_Names = "Geiss Company"
```


In [24]:
def extract_sql(text):
    """Return the SQL statement contained in an LLM reply.

    If the reply contains a Markdown code fence (```sql ... ``` or a bare
    ``` ... ```), the content of the first fence is returned. Otherwise the
    whole reply is treated as the query, since the prompt asks the model to
    return nothing but SQL.

    Args:
        text: Raw string produced by the model.

    Returns:
        The SQL text with surrounding whitespace removed.

    Raises:
        ValueError: If no non-empty SQL text can be extracted.
    """
    fenced = re.search(r"```(?:sql)?\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
    candidate = (fenced.group(1) if fenced else text).strip()
    if not candidate:
        # Fail here rather than leaving `query` undefined (or stale from an
        # earlier run) for the cell that executes it.
        raise ValueError("No SQL query found in response.")
    return candidate


query = extract_sql(resp)
print(query)

SELECT sum(T1.Line_Total) FROM sales_order AS T1 INNER JOIN customers AS T2 ON T1.Customer_Name_Index = T2.Customer_Index WHERE T2.Customer_Names = "Geiss Company"


In [25]:
# Execute the generated SQL and show the returned rows.
# The statement was written by an LLM, so treat it as untrusted input and
# refuse anything that does not start as a read-only query. This is a coarse
# guard; a database user with read-only privileges is the stronger protection.
if not query.lstrip().lower().startswith(("select", "with")):
    raise ValueError(f"Refusing to run non-SELECT statement: {query!r}")

result = db.run(query)
print(result)

[(5516846.999999994,)]
