In [1]:
# Importing libraries
import getpass
import os
import re
from urllib.parse import quote_plus

from langchain.utilities import SQLDatabase
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import create_sql_query_chain
from langchain.schema.runnable import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI

In [3]:
pip install langchain langchain-google-genai openai pymysql

Collecting pymysql
  Downloading pymysql-1.1.2-py3-none-any.whl.metadata (4.3 kB)
Downloading pymysql-1.1.2-py3-none-any.whl (45 kB)
Installing collected packages: pymysql
Successfully installed pymysql-1.1.2
Note: you may need to restart the kernel to use updated packages.


In [4]:
# Connect to the MySQL database.
# Connection settings are read from environment variables so credentials are not
# committed with the notebook. The fallbacks reproduce the original local-dev values;
# override them (especially MYSQL_PASSWORD) outside of a throwaway local setup.
host = os.environ.get("MYSQL_HOST", "localhost")
port = os.environ.get("MYSQL_PORT", "3306")
username = os.environ.get("MYSQL_USER", "root")
password = os.environ.get("MYSQL_PASSWORD", "root")
database_schema = os.environ.get("MYSQL_DATABASE", "text_to_sql")

# URL-encode the credentials: characters such as '@', ':' or '/' in a password
# would otherwise be parsed as part of the URI structure and break the connection.
mysql_uri = (
    f"mysql+pymysql://{quote_plus(username)}:{quote_plus(password)}"
    f"@{host}:{port}/{database_schema}"
)

db = SQLDatabase.from_uri(mysql_uri, sample_rows_in_table_info=2)

In [5]:
# Re-create the connection with a single sample row per table. This `db` is the
# one the chain below reads the schema from, so only one example row per table
# ends up in the prompt.
db = SQLDatabase.from_uri(
    mysql_uri,
    sample_rows_in_table_info=1,
)

# Preview the schema text (CREATE TABLE statements plus sample rows) sent to the LLM.
db.get_table_info()

'\nCREATE TABLE budget (\n\t`Product_Names` TEXT, \n\t`Budget` INTEGER\n)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n\n/*\n1 rows from budget table:\nProduct_Names\tBudget\nProduct 1\t1613979\n*/\n\n\nCREATE TABLE customers (\n\t`Customer_Name_Index` INTEGER, \n\t`Customer_Names` TEXT\n)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n\n/*\n1 rows from customers table:\nCustomer_Name_Index\tCustomer_Names\n1\tGeiss Company\n*/\n\n\nCREATE TABLE orders (\n\t`OrderNumber` TEXT, \n\t`OrderDate` TEXT, \n\t`Customer_Name_Index` INTEGER, \n\t`Channel` TEXT, \n\t`Currency_Code` TEXT, \n\t`Warehouse_Code` TEXT, \n\t`Delivery_Regoin_index` INTEGER, \n\t`Product_Index` INTEGER, \n\t`Product_Quantity` INTEGER, \n\t`Unit_Price` DOUBLE, \n\t`Line_Total` DOUBLE, \n\t`Total_Unit_Cost` DOUBLE\n)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n\n/*\n1 rows from orders table:\nOrderNumber\tOrderDate\tCustomer_Name_Index\tChannel\tCurrency_Code\tWar

In [6]:
# Create the LLM prompt template.
# The prompt names the MySQL dialect explicitly, because the generated query is
# executed against MySQL via `db.run`. It also forbids Markdown code fences so
# the raw response can be passed straight to the database.
template = """Based on the MySQL table schema below, write a MySQL query that answers the user's question.
Rules:
- Reply with only the SQL query; do not include explanations or any other text.
- Do not wrap the query in Markdown code fences.
- Write the query on a single line, without line breaks.

Table Schema: {schema}

Question: {question}

SQL Query:
"""

prompt = ChatPromptTemplate.from_template(template)

In [7]:
# Helper that supplies the database schema text to the prompt.
def get_schema(db):
    """Return the schema description produced by ``db.get_table_info()``.

    For a LangChain ``SQLDatabase`` this is the ``CREATE TABLE`` statements plus
    sample rows; the chain feeds it into the prompt's ``{schema}`` placeholder.
    """
    return db.get_table_info()

In [None]:
# Gemini chat model used to generate the SQL.
# The API key is taken from the GOOGLE_API_KEY environment variable, or prompted
# for interactively, instead of being hard-coded (and committed) in the notebook.
google_api_key = os.environ.get("GOOGLE_API_KEY") or getpass.getpass(
    "Google Generative AI API key: "
)

llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    api_key=google_api_key,
    temperature=0,  # minimise sampling randomness: we want one stable query, not creative variants
)

In [9]:
# Define the chain: {"question": ...} -> prompt (with current schema) -> Gemini -> SQL string.
def _strip_sql_fences(text: str) -> str:
    """Remove surrounding whitespace and a Markdown code fence around the SQL, if any.

    Chat models sometimes answer with ```sql ... ``` even when told not to;
    fenced text would fail when passed to ``db.run``. Unfenced SQL is returned
    unchanged apart from whitespace trimming.
    """
    unfenced = re.sub(
        r"^```(?:my)?sql\s*|^```\s*|\s*```$", "", text.strip(), flags=re.IGNORECASE
    )
    return unfenced.strip()


sql_chain = (
    RunnablePassthrough.assign(schema=lambda _: get_schema(db))
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
    | _strip_sql_fences  # a plain callable is coerced to a RunnableLambda by `|`
)

In [10]:
# Smoke-test the chain with a question that needs both the orders and customers tables.
resp = sql_chain.invoke(dict(question="What is the total 'Line Total' for Geiss Company"))
print(resp)

SELECT SUM(T1.Line_Total) FROM orders AS T1 INNER JOIN customers AS T2 ON T1.Customer_Name_Index = T2.Customer_Name_Index WHERE T2.Customer_Names = 'Geiss Company'


In [11]:
# A second question; the SQL it produces (kept in `resp`) is executed by the next cell.
resp = sql_chain.invoke(dict(question="How many total customers are we have???"))
print(resp)

SELECT COUNT(*) FROM customers


In [12]:
# Execute the generated SQL. It comes from the LLM, so treat it as untrusted and
# only run a single SELECT statement.
# This check is best-effort:
# - it also rejects CTE (WITH ...) queries and literals containing ';';
# - the real safeguard is connecting with a read-only MySQL account rather than root.
generated_sql = resp.strip().rstrip(";").strip()
if ";" in generated_sql or not generated_sql.lower().startswith("select"):
    raise ValueError(f"Refusing to run SQL that is not a single SELECT statement: {resp!r}")
db.run(generated_sql)

'[(25,)]'