<a href="https://colab.research.google.com/github/VisvaV/Natural-Language-to-SQL-Query-Generation-with-Gemini/blob/main/NL2SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# langchain.chains (create_sql_query_chain, LLM chains) is part of the 0.x API, so stay below 1.0
!pip install -U -q "langchain<1.0" langchain-community langchain-google-genai langsmith chromadb
!pip install -U -q "cloud-sql-python-connector[pymysql]" pymysql sqlalchemy google-auth


In [2]:
import os
from google.colab import userdata

# Secrets are stored in Colab's "Secrets" panel rather than in the notebook.
os.environ["GOOGLE_API_KEY"] = userdata.get("GOOGLE_API_KEY")
os.environ["LANGSMITH_API_KEY"] = userdata.get("LANGSMITH_API_KEY")

# LangSmith tracing configuration
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_PROJECT"] = "gemini-demo"
os.environ["LANGSMITH_TRACING"] = "true"


In [4]:
from langchain_google_genai import GoogleGenerativeAI
model = GoogleGenerativeAI(model="gemini-2.0-flash")
model.invoke("Write a poem on topic love.")

"The whisper first, a feather's light,\nA glance that held a star-filled night.\nA hesitant smile, a touch so brief,\nA nascent hope, a sweet relief.\n\nThe seeds were sown in fertile ground,\nWhere tender feelings could be found.\nAnd slowly, surely, they took root,\nBearing blossoms, a vibrant loot.\n\nLove is a sunrise, gold and bright,\nChasing away the lonely night.\nA gentle rain, a cleansing tear,\nWashing away all doubt and fear.\n\nIt's holding hands in silent grace,\nFinding solace in a warm embrace.\nIt's laughter shared, a joyful sound,\nOn solid, understanding ground.\n\nBut love is also stormy seas,\nAnd branches swaying in the breeze.\nIt's patience tested, words unsaid,\nAnd battles fought inside the head.\n\nYet through the trials, love remains,\nA flickering candle in the rains.\nA constant force, a guiding star,\nLeading us closer, near and far.\n\nSo cherish love in all its forms,\nWeather its sunshine, brave its storms.\nFor in its depths, a truth resides,\nThat lo

In [None]:
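from google.colab import files
import os

# Upload the Cloud SQL service-account JSON key and point Google auth at it.
uploaded = files.upload()
key_file = next(iter(uploaded))  # name of the uploaded file
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = key_file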


In [8]:
from google.cloud.sql.connector import Connector
import sqlalchemy
from google.oauth2 import service_account
import pymysql

key_path = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]  # path to the service-account JSON key

credentials = service_account.Credentials.from_service_account_file(
    key_path
)
connector = Connector(credentials=credentials)

def getconn():
    conn = connector.connect(
        os.getenv("INSTANCE_CONNECTION_NAME", "your-project:your-region:your-instance"),
        "pymysql",
        user=os.getenv("DB_USER", "visva"),
        password=os.getenv("DB_PASSWORD", "your-password"),
        db=os.getenv("DB_NAME", "classicmodels"),
    )
    return conn

engine = sqlalchemy.create_engine(
    "mysql+pymysql://",
    creator=getconn,
)

In [9]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase(engine)

In [10]:
def clean_sql_query(raw_query: str) -> str:
    lines = raw_query.strip().splitlines()
    cleaned_lines = [line for line in lines if not line.strip().startswith("```")]
    return "\n".join(cleaned_lines).strip()
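`clean_sql_query` only drops lines that begin with a code fence. Model output can also carry a leading `SQLQuery:` label (the format the default prompt asks for) or a trailing `SQLResult:` section. A slightly more defensive variant is sketched below, assuming those are the only wrappers you need to handle; `extract_sql` is a hypothetical helper, not part of LangChain.

```python
import re

def extract_sql(raw: str) -> str:
    """Best-effort extraction of one SQL statement from LLM output (hypothetical helper)."""
    text = raw.strip()
    # Prefer the body of a ```sql ... ``` (or bare ```) fence if one is present.
    fenced = re.search(r"```(?:\w+)?\s*\n(.*?)```", text, flags=re.DOTALL)
    if fenced:
        text = fenced.group(1).strip()
    # Drop a leading "SQLQuery:" label and anything from "SQLResult:" onward.
    text = re.sub(r"^\s*SQLQuery:\s*", "", text, flags=re.IGNORECASE)
    text = re.split(r"\n\s*SQLResult:", text, maxsplit=1)[0]
    return text.strip()

print(extract_sql("```sql\nSELECT COUNT(*) FROM customers;\n```"))
print(extract_sql("SQLQuery: SELECT 1;\nSQLResult: [(1,)]"))
```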

In [11]:
from langchain_google_genai import GoogleGenerativeAI
from langchain.chains import create_sql_query_chain
from sqlalchemy import text

llm = GoogleGenerativeAI(model="gemini-2.0-flash")

generate_query = create_sql_query_chain(llm, db)

#question = "insert a record into products table with productCode = S10_6969420, productName = Akshay thar, productLine= Classic cars, productScale = 1:10,productVendor = axtrLabs, productDescription = this car has fine over 40k rs, quantityInStock = 1, buyPrice = 500000, MSRP=95"
question = "give me the top 5 most expensive car"
query = generate_query.invoke({"question": question})
cleaned_query = clean_sql_query(query)

print(cleaned_query)
#result = db.run(text(cleaned_query))



SELECT
  `productName`,
  `buyPrice`
FROM products
ORDER BY
  `buyPrice` DESC
LIMIT 5;
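To see what the generated statement does without a Cloud SQL instance, the same query can be run against a throwaway in-memory SQLite table. This is only an illustration: the rows below are made-up sample data, not the classicmodels contents, and it relies on SQLite accepting MySQL-style backtick-quoted identifiers.

```python
import sqlite3

generated = """
SELECT
  `productName`,
  `buyPrice`
FROM products
ORDER BY
  `buyPrice` DESC
LIMIT 5;
"""

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE products (productName TEXT, buyPrice REAL)")
# Made-up sample rows for illustration only.
conn.executemany(
    "INSERT INTO products VALUES (?, ?)",
    [("Car A", 10.0), ("Car B", 95.5), ("Car C", 42.0),
     ("Car D", 70.25), ("Car E", 5.0), ("Car F", 88.0)],
)
rows = conn.execute(generated).fetchall()
print(rows)  # five most expensive rows, highest buyPrice first
```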


In [12]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool

# Tool that runs a SQL string against `db` and returns the rows (or the error message) as text.
execute_query = QuerySQLDatabaseTool(db=db)
execute_query.invoke({"query": cleaned_query})


"[('Akshay thar', Decimal('500000.00')), ('Akshay thar', Decimal('500000.00')), ('1962 LanciaA Delta 16V', Decimal('103.42')), ('1998 Chrysler Plymouth Prowler', Decimal('101.51')), ('1952 Alpine Renault 1300', Decimal('98.58'))]"

In [13]:
import re
from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser

def strip_sql_fences(sql_string: str) -> str:
    return re.sub(r"```[\w]*\n([\s\S]*?)```", r"\1", sql_string).strip()

clean_sql_chain = RunnableLambda(strip_sql_fences)

chain = generate_query | clean_sql_chain | execute_query

result = chain.invoke({"question": "How many stocks are left in total?"})
print(result)


[(Decimal('555133'),)]
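The `|` operator chains runnables so that each step's output becomes the next step's input. Stripped of LangChain, `generate_query | clean_sql_chain | execute_query` behaves roughly like the plain function composition below; the stub functions are hypothetical stand-ins for the LLM call and the database tool.

```python
from functools import reduce

def pipe(*steps):
    """Compose functions left to right, like `a | b | c` on runnables."""
    return lambda value: reduce(lambda acc, step: step(acc), steps, value)

# Hypothetical stand-ins for the three stages.
def fake_generate_query(inputs):
    return "```sql\nSELECT SUM(quantityInStock) FROM products;\n```"

def fake_clean(sql):
    return sql.replace("```sql", "").replace("```", "").strip()

def fake_execute(sql):
    return f"executed: {sql}"

sketch_chain = pipe(fake_generate_query, fake_clean, fake_execute)
print(sketch_chain({"question": "How many stocks are left in total?"}))
```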


In [14]:
chain.get_prompts()[0].pretty_print()

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the S
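The prompt only *asks* the model to cap results at five rows with `LIMIT`; nothing enforces it. If a hard cap is wanted, a naive post-processing guard could look like the sketch below. `ensure_limit` is hypothetical, and because it uses regex rather than a SQL parser, a `LIMIT` inside a subquery would make it skip the outer query.

```python
import re

def ensure_limit(sql: str, max_rows: int = 5) -> str:
    """Append a LIMIT to a SELECT that lacks one (hypothetical, regex-based guard)."""
    statement = sql.strip().rstrip(";").rstrip()
    is_select = re.match(r"(?is)^\s*select\b", statement) is not None
    has_limit = re.search(r"(?i)\blimit\s+\d+", statement) is not None
    if is_select and not has_limit:
        statement += f" LIMIT {max_rows}"
    return statement + ";"

print(ensure_limit("SELECT productName FROM products ORDER BY buyPrice DESC;"))
```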

In [29]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
import re

def strip_sql_fences(sql_string: str) -> str:
    return re.sub(r"```[\w]*\n([\s\S]*?)```", r"\1", sql_string).strip()

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

rephrase_answer = answer_prompt | llm | StrOutputParser()

chain = (
    RunnablePassthrough
    .assign(query=generate_query)
    .assign(query=itemgetter("query") | RunnableLambda(strip_sql_fences))
    .assign(result=itemgetter("query") | execute_query)
    | rephrase_answer
)

result = chain.invoke({"question": "How many customers are there in total?"})
print(result)


There are 122 customers in total.
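`RunnablePassthrough.assign(...)` keeps every key of the input dict and adds (or overwrites) a computed key, which is why `rephrase_answer` sees `question`, `query` and `result` together. A plain-dict sketch of that data flow, with hypothetical stubs in place of the LLM and the database:

```python
def assign(**computations):
    """Return a step that copies its input dict and adds keys computed from that input."""
    def step(inputs: dict) -> dict:
        out = dict(inputs)
        for key, fn in computations.items():
            out[key] = fn(inputs)  # every fn sees the same, unmodified input
        return out
    return step

# Hypothetical stand-ins for generate_query, strip_sql_fences and execute_query.
steps = [
    assign(query=lambda d: "```sql\nSELECT COUNT(*) FROM customers;\n```"),
    assign(query=lambda d: d["query"].removeprefix("```sql\n").removesuffix("\n```")),
    assign(result=lambda d: "[(122,)]"),
]

state = {"question": "How many customers are there in total?"}
for step in steps:
    state = step(state)
print(state)  # question, cleaned query and result, ready for the answer prompt
```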


In [16]:
examples = [
    {
        "input": "List all customers located in France.",
        "query": "SELECT customerName FROM customers WHERE country = 'France';"
    },
    {
        "input": "Retrieve the names of employees who report to employee number 1143.",
        "query": "SELECT firstName, lastName FROM employees WHERE reportsTo = 1143;"
    },
    {
        "input": "Find the total number of orders placed by customer number 103.",
        "query": "SELECT COUNT(*) FROM orders WHERE customerNumber = 103;"
    },
    {
        "input": "Get the names of products that are in the 'Classic Cars' product line.",
        "query": "SELECT productName FROM products WHERE productLine = 'Classic Cars';"
    },
    {
        "input": "Show the order dates for orders with status 'Shipped'.",
        "query": "SELECT orderDate FROM orders WHERE status = 'Shipped';"
    },
    {
        "input": "List all payments made by customer number 141.",
        "query": "SELECT paymentDate, amount FROM payments WHERE customerNumber = 141;"
    },
    {
        "input": "Find the office codes and cities for offices located in the USA.",
        "query": "SELECT officeCode, city FROM offices WHERE country = 'USA';"
    },
    {
        "input": "Retrieve the names of customers who have not placed any orders.",
        "query": "SELECT customerName FROM customers WHERE customerNumber NOT IN (SELECT customerNumber FROM orders);"
    },
    {
        "input": "Get the product names and quantities ordered for order number 10123.",
        "query": "SELECT productName, quantityOrdered FROM orderdetails JOIN products USING (productCode) WHERE orderNumber = 10123;"
    },
    {
        "input": "List the names of employees working in the San Francisco office.",
        "query": "SELECT firstName, lastName FROM employees WHERE officeCode = (SELECT officeCode FROM offices WHERE city = 'San Francisco');"
    }
]


In [17]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, FewShotChatMessagePromptTemplate, PromptTemplate

# Each example is rendered as a human turn (the question) followed by an AI turn (the SQL).
example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}\nSQLQuery:"),
        ("ai", "{query}"),
    ]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
    input_variables=["input"],
)
print(few_shot_prompt.format(input="How many products are there?"))


Human: List all customers located in France.
SQLQuery:
AI: SELECT customerName FROM customers WHERE country = 'France';
Human: Retrieve the names of employees who report to employee number 1143.
SQLQuery:
AI: SELECT firstName, lastName FROM employees WHERE reportsTo = 1143;
Human: Find the total number of orders placed by customer number 103.
SQLQuery:
AI: SELECT COUNT(*) FROM orders WHERE customerNumber = 103;
Human: Get the names of products that are in the 'Classic Cars' product line.
SQLQuery:
AI: SELECT productName FROM products WHERE productLine = 'Classic Cars';
Human: Show the order dates for orders with status 'Shipped'.
SQLQuery:
AI: SELECT orderDate FROM orders WHERE status = 'Shipped';
Human: List all payments made by customer number 141.
SQLQuery:
AI: SELECT paymentDate, amount FROM payments WHERE customerNumber = 141;
Human: Find the office codes and cities for offices located in the USA.
SQLQuery:
AI: SELECT officeCode, city FROM offices WHERE country = 'USA';
Human: Ret
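`FewShotChatMessagePromptTemplate` renders each example through `example_prompt`, producing one human turn (`{input}\nSQLQuery:`) and one AI turn (`{query}`) per example, as the printed output shows. The same expansion in plain Python, as a rough illustration rather than LangChain's actual implementation (`render_few_shot` is hypothetical):

```python
def render_few_shot(examples, k=None):
    """Expand examples into (role, text) turns the way example_prompt does."""
    turns = []
    for ex in examples[:k]:  # k=None keeps every example
        turns.append(("Human", f"{ex['input']}\nSQLQuery:"))
        turns.append(("AI", ex["query"]))
    return turns

demo_examples = [
    {"input": "List all customers located in France.",
     "query": "SELECT customerName FROM customers WHERE country = 'France';"},
]
for role, text in render_few_shot(demo_examples):
    print(f"{role}: {text}")
```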

In [18]:
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_google_genai import GoogleGenerativeAIEmbeddings

embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

vectorstore = Chroma(embedding_function=embeddings)
vectorstore.delete_collection()

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples=examples,
    embeddings=embeddings,
    vectorstore_cls=Chroma,  # pass the class; from_examples builds and fills the store itself
    k=2,
    input_keys=["input"],
)

print(example_selector.select_examples({"input": "how many employees we have?"}))

few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    example_selector=example_selector,
    input_variables=["input", "top_k"]
)

print(few_shot_prompt.format(input="How many products are there?", top_k=2))



[{'query': "SELECT firstName, lastName FROM employees WHERE officeCode = (SELECT officeCode FROM offices WHERE city = 'San Francisco');", 'input': 'List the names of employees working in the San Francisco office.'}, {'input': 'Retrieve the names of employees who report to employee number 1143.', 'query': 'SELECT firstName, lastName FROM employees WHERE reportsTo = 1143;'}]
Human: Get the names of products that are in the 'Classic Cars' product line.
SQLQuery:
AI: SELECT productName FROM products WHERE productLine = 'Classic Cars';
Human: Get the product names and quantities ordered for order number 10123.
SQLQuery:
AI: SELECT productName, quantityOrdered FROM orderdetails JOIN products USING (productCode) WHERE orderNumber = 10123;
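`SemanticSimilarityExampleSelector` embeds each example's `input`, stores the vectors in Chroma, and returns the `k` examples closest to the new question. The sketch below mimics only the top-k mechanics, using a bag-of-words cosine similarity in place of Gemini embeddings; that is a swapped-in and much cruder technique, and `select_top_k` is hypothetical.

```python
import math
import re
from collections import Counter

def bag_of_words(text):
    return Counter(re.findall(r"[a-z]+", text.lower()))

def cosine(a, b):
    dot = sum(a[w] * b[w] for w in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def select_top_k(examples, question, k=2):
    """Return the k examples whose 'input' is most similar to the question."""
    q = bag_of_words(question)
    ranked = sorted(examples, key=lambda ex: cosine(q, bag_of_words(ex["input"])), reverse=True)
    return ranked[:k]

demo = [
    {"input": "List all customers located in France."},
    {"input": "Retrieve the names of employees who report to employee number 1143."},
    {"input": "List the names of employees working in the San Francisco office."},
]
print([ex["input"] for ex in select_top_k(demo, "how many employees we have?")])
```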


In [19]:
import re
from operator import itemgetter
from langchain.chains import create_sql_query_chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system",
"You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. "
"Only use the table and column names exactly as given in the table info. Do not assume any schema names like 'classicmodels'.\n\n"
"Here is the relevant table info: {table_info}\n\n"
"Below are a number of examples of questions and their corresponding SQL queries."
),
        few_shot_prompt,
        ("human", "{input}"),
    ]
)

print(final_prompt.format(input="How many products are there?", table_info="some table info"))

generate_query = create_sql_query_chain(llm, db, final_prompt)

def strip_sql_fences(sql_string: str) -> str:
    # Remove a ```sql ... ``` fence if the model wrapped the query in one.
    return re.sub(r"```[\w]*\n([\s\S]*?)```", r"\1", sql_string).strip()

chain = (
    RunnablePassthrough
    .assign(query=generate_query)
    .assign(query=itemgetter("query") | RunnableLambda(strip_sql_fences))
    .assign(result=itemgetter("query") | execute_query)
    | rephrase_answer
)

chain.invoke({
    "question": "How many customers with credit limit more than 50000 ?",
    "table_info": db.get_table_info()
})


System: You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Only use the table and column names exactly as given in the table info. Do not assume any schema names like 'classicmodels'.

Here is the relevant table info: some table info

Below are a number of examples of questions and their corresponding SQL queries.
Human: Get the names of products that are in the 'Classic Cars' product line.
SQLQuery:
AI: SELECT productName FROM products WHERE productLine = 'Classic Cars';
Human: Get the product names and quantities ordered for order number 10123.
SQLQuery:
AI: SELECT productName, quantityOrdered FROM orderdetails JOIN products USING (productCode) WHERE orderNumber = 10123;
Human: How many products are there?


'There are 85 customers with credit limit more than 50000.'

In [20]:
from google.colab import files
uploaded = files.upload()

Saving database_table_descriptions.csv to database_table_descriptions (1).csv


In [22]:
from operator import itemgetter
from typing import List

import pandas as pd
from pydantic import BaseModel, Field


def get_table_details(csv_path: str = "database_table_descriptions.csv") -> str:
    """Return one 'Table Name / Table Description' block per row of the descriptions CSV."""
    table_description = pd.read_csv(csv_path)

    table_details = ""
    for _, row in table_description.iterrows():
        table_details += f"Table Name:{row['Table']}\nTable Description:{row['Description']}\n\n"

    return table_details


class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of table in SQL database.")
table_details = get_table_details()
print(table_details)


Table Name:productlines
Table Description:Stores information about the different product lines offered by the company, including a unique name, textual description, HTML description, and image. Categorizes products into different lines.

Table Name:products
Table Description:Contains details of each product sold by the company, including code, name, product line, scale, vendor, description, stock quantity, buy price, and MSRP. Linked to the productlines table.

Table Name:offices
Table Description:Holds data on the company's sales offices, including office code, city, phone number, address, state, country, postal code, and territory. Each office is uniquely identified by its office code.

Table Name:employees
Table Description:Stores information about employees, including number, last name, first name, job title, contact info, and office code. Links to offices and maps organizational structure through the reportsTo attribute.

Table Name:customers
Table Description:Captures data on cus
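`get_table_details` expects a CSV with `Table` and `Description` columns. To check the formatting without pandas, the standard-library `csv` module produces the same text; `table_details_from_csv` is a hypothetical helper and the two rows are abbreviated placeholders, not the real file contents.

```python
import csv
import io

def table_details_from_csv(csv_text: str) -> str:
    """Build the same 'Table Name / Table Description' block from CSV text."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return "".join(
        f"Table Name:{row['Table']}\nTable Description:{row['Description']}\n\n"
        for row in reader
    )

sample_csv = "Table,Description\nproducts,Details of each product.\noffices,Sales office data.\n"
print(table_details_from_csv(sample_csv))
```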

In [23]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

table_details_prompt = PromptTemplate.from_template("""
Return the names of ALL the SQL tables that MIGHT be relevant to the user question.
The tables are:

{table_details}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed.

Question: {input}
Only list the relevant table names, comma separated.
""")

# prompt -> Gemini -> plain string (replaces the deprecated LLMChain)
table_chain = table_details_prompt | llm | StrOutputParser()

tables_str = table_chain.invoke({
    "input": "give me details of customer and their order count",
    "table_details": table_details
})

tables = [t.strip() for t in tables_str.split(",") if t.strip()]

print(tables)



['customers', 'orders']


In [24]:
def get_tables(tables_response: str) -> List[str]:
    return [t.strip() for t in tables_response.split(",") if t.strip()]

select_table = (
    {"input": itemgetter("question"), "table_details": itemgetter("table_details")}
    | table_details_prompt
    | llm
    | StrOutputParser()
    | get_tables
)

result = select_table.invoke({
    "question": "give me details of customer and their order count",
    "table_details": table_details
})

print(result)

['customers', 'orders']
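`get_tables` trusts whatever the model returns. The model may add quotes, change casing, or name a table that does not exist, and `db.get_table_info(table_names=...)` raises a `ValueError` for unknown names. A defensive variant could normalise the answer and filter it against the known table list; `filter_known_tables` is hypothetical.

```python
def filter_known_tables(response: str, known_tables):
    """Split a comma-separated model answer and keep only real, de-duplicated table names."""
    lookup = {name.lower(): name for name in known_tables}
    selected = []
    for part in response.split(","):
        name = part.strip().strip("`'\"").lower()
        if name in lookup and lookup[name] not in selected:
            selected.append(lookup[name])
    return selected

known = ["customers", "orders", "orderdetails", "products"]
print(filter_known_tables(" Customers, `orders`, invoices ", known))
```

In the notebook, the known names could come from `db.get_usable_table_names()`.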


In [34]:
import re
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from operator import itemgetter

def clean_generated_sql(sql_string: str) -> str:
    sql = re.sub(r"```[\w]*\n([\s\S]*?)```", r"\1", sql_string).strip()
    sql = re.sub(r"^(AI:|Answer:)\s*", "", sql, flags=re.IGNORECASE).strip()
    return sql

strip_fences = RunnableLambda(clean_generated_sql)

chain = (
    RunnablePassthrough.assign(table_names_to_use=select_table) |
    RunnablePassthrough.assign(query=generate_query) |
    RunnablePassthrough.assign(result=itemgetter("query") | strip_fences | execute_query) |
    rephrase_answer
)

response = chain.invoke({
    "question": "How many cutomers with order count more than 5",
    "table_details": table_details
})

print(response)


There are 2 customers with more than 5 orders.
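Passing `table_names_to_use` lets `create_sql_query_chain` build `table_info` only for the selected tables (through `db.get_table_info(table_names=...)`), which keeps the prompt smaller. The idea in isolation, with a hypothetical dict of simplified DDL strings standing in for the database (`build_table_info` is not a LangChain function):

```python
def build_table_info(schemas: dict, table_names=None) -> str:
    """Join the DDL of the requested tables (all tables if table_names is None)."""
    names = table_names if table_names is not None else list(schemas)
    unknown = set(names) - set(schemas)
    if unknown:
        raise ValueError(f"unknown tables: {sorted(unknown)}")
    return "\n\n".join(schemas[name] for name in names)

# Simplified, hypothetical DDL for illustration.
schemas = {
    "customers": "CREATE TABLE customers (customerNumber INT, customerName VARCHAR(50))",
    "orders": "CREATE TABLE orders (orderNumber INT, customerNumber INT)",
    "products": "CREATE TABLE products (productCode VARCHAR(15), productName VARCHAR(70))",
}
print(build_table_info(schemas, ["customers", "orders"]))
```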


In [35]:
from langchain_core.chat_history import InMemoryChatMessageHistory

history = InMemoryChatMessageHistory()


In [36]:
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specified.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries. Those examples are just for reference and should be considered while answering follow-up questions"),
        few_shot_prompt,
        MessagesPlaceholder(variable_name="messages"),  # prior chat turns go here
        ("human", "{input}"),
    ]
)
print(final_prompt.format(input="How many products are there?", table_info="some table info", messages=[]))


System: You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specified.

Here is the relevant table info: some table info

Below are a number of examples of questions and their corresponding SQL queries. Those examples are just for reference and should be considered while answering follow-up questions
Human: Get the names of products that are in the 'Classic Cars' product line.
SQLQuery:
AI: SELECT productName FROM products WHERE productLine = 'Classic Cars';
Human: Get the product names and quantities ordered for order number 10123.
SQLQuery:
AI: SELECT productName, quantityOrdered FROM orderdetails JOIN products USING (productCode) WHERE orderNumber = 10123;
Human: How many products are there?


In [40]:
generate_query = create_sql_query_chain(llm, db, final_prompt)

chain = (
    RunnablePassthrough.assign(table_names_to_use=select_table) |
    RunnablePassthrough.assign(query=generate_query) |
    RunnablePassthrough.assign(result=itemgetter("query") | strip_fences | execute_query) |
    rephrase_answer
)

In [41]:
question = "How many cutomers with order count more than 5"
response = chain.invoke({"question": question,"messages":history.messages, "table_details": table_details})
print(response)


There are 2 customers with more than 5 orders.


In [42]:
from langchain_core.messages import HumanMessage, AIMessage

history.add_user_message(question)
history.add_ai_message(response)

for msg in history.messages:
    print(f"{msg.type.upper()}: {msg.content}")


HUMAN: How many cutomers with order count more than 5
AI: There are 2 customers with more than 5 orders.
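Every call passes the full `history.messages` list into the `MessagesPlaceholder`, so the prompt grows with each turn. One simple mitigation is to keep only the most recent turns. The sketch uses plain tuples rather than LangChain message objects, and `last_turns` is hypothetical; LangChain also provides a `trim_messages` utility in `langchain_core.messages` for this purpose.

```python
def last_turns(messages, max_turns=3):
    """Keep the last max_turns human/AI pairs (two messages per turn)."""
    return messages[-2 * max_turns:] if max_turns > 0 else []

history_sketch = [
    ("human", "How many customers have more than 5 orders?"),
    ("ai", "There are 2 customers with more than 5 orders."),
    ("human", "Can you list their names?"),
    ("ai", "Mini Gifts Distributors Ltd. and Euro+ Shopping Channel."),
]
print(last_turns(history_sketch, max_turns=1))
```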


In [43]:
response = chain.invoke({"question": "Can you list their names?","messages":history.messages, "table_details": table_details})
print(response)


Mini Gifts Distributors Ltd. and Euro+ Shopping Channel.


In [44]:
response = chain.invoke({
    "question": "Insert a new product into the products table with productCode 'S72_9999', productName 'Akshay XUV', productLine 'Classic Cars', productScale '1:18', productVendor 'Hot Wheels', productDescription 'On-going case since 2020', quantityInStock 500, buyPrice 750000.00, MSRP 120.00",
    "messages": history.messages,
    "table_details": table_details
})
print(response)

Okay. The SQL query you provided attempts to insert a new product into the `products` table with the specified values. Since there's no error message or output from the SQL result, we can assume the query was executed successfully.

Therefore, the answer is:
```
The product 'Akshay XUV' with product code 'S72_9999' has been successfully added to the products table.
```


In [45]:
response = chain.invoke({
    "question": "update the product name from Akshay XUV to Akshay SUV",
    "messages": history.messages,
    "table_details": table_details
})
print(response)

The SQL query successfully updated the product name from "Akshay XUV" to "Akshay SUV".
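The last two cells show that the chain will run `INSERT` and `UPDATE` statements generated from plain English, and the answer step simply assumes they succeeded. If the notebook is meant for question answering only, a guard in front of `execute_query` can refuse anything that is not a single read-only statement. The sketch below (`assert_read_only` is hypothetical) is regex-based, so it is a convenience check and not a substitute for connecting as a MySQL user that only has `SELECT` privileges.

```python
import re

# Statements allowed through; anything else (INSERT, UPDATE, DELETE, DROP, ...) is rejected.
READ_ONLY_PREFIX = re.compile(r"^\s*(select|show|describe|desc)\b", re.IGNORECASE)

def assert_read_only(sql: str) -> str:
    """Return sql unchanged if it looks like one read-only statement, else raise ValueError.

    Naive by design: a ';' inside a string literal is also rejected, and
    CTE queries that start with WITH are not allowed.
    """
    statement = sql.strip().rstrip(";").strip()
    if ";" in statement:
        raise ValueError("multiple statements are not allowed")
    if not READ_ONLY_PREFIX.match(statement):
        raise ValueError(f"refusing non-read-only statement: {statement[:40]!r}")
    return sql
```

In the chain it would sit just before execution, for example `itemgetter("query") | strip_fences | RunnableLambda(assert_read_only) | execute_query`.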
