<a href="https://colab.research.google.com/github/VisvaV/aXtr-Zerame-NL2SQL-demo/blob/main/nl2sqldemo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U langchain langchain-community langchain-openai langchain-google-genai
!pip install sqlalchemy chromadb -q


In [1]:
import os
import re
import sqlite3
import pandas as pd
from operator import itemgetter
from langchain_google_genai import GoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain, LLMChain
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder, FewShotChatMessagePromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain.memory import ChatMessageHistory
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage


In [6]:
import os

# Use a GOOGLE_API_KEY that is already in the environment if present; otherwise read it from
# Colab's secret store. Importing google.colab only when needed keeps this cell usable outside Colab.
if not os.environ.get("GOOGLE_API_KEY"):
    from google.colab import userdata
    os.environ["GOOGLE_API_KEY"] = userdata.get("GOOGLE_API_KEY")

# temperature=0 keeps SQL generation as repeatable as the model allows.
llm = GoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)


In [2]:
# Demo SQLite database file (expected to already exist in the Colab working directory).
db_path = "/content/nl2sql_demo.sqlite"

# Connecting to a missing SQLite file silently creates an empty database, which would leave
# the LLM with no schema to work from. Fail loudly instead.
if not os.path.isfile(db_path):
    raise FileNotFoundError(
        f"SQLite database not found at {db_path!r}; upload it before running this cell."
    )

db = SQLDatabase.from_uri(f"sqlite:///{db_path}")


In [3]:

# Hand-written question -> SQL pairs that serve as few-shot demonstrations.
# Each entry becomes a dict with an "input" (question) and a "query" (SQL) key.
examples = [
    dict(input=question, query=sql)
    for question, sql in (
        (
            "List all customers located in France.",
            "SELECT customerName FROM customers WHERE country = 'France';",
        ),
        (
            "Find the total number of orders placed by customer number 101.",
            "SELECT COUNT(*) FROM orders WHERE customerNumber = 101;",
        ),
        (
            "Get the names of products that are in the 'Classic Cars' product line.",
            "SELECT productName FROM products WHERE productLine = 'Classic Cars';",
        ),
        (
            "What are the order dates for orders with status 'Shipped'?",
            "SELECT orderDate FROM orders WHERE status = 'Shipped';",
        ),
    )
]


In [7]:

# Template that renders one few-shot example as a (human question, AI SQL) message pair.
example_prompt = ChatPromptTemplate.from_messages([
    ("human", "{input}\nSQLQuery:"),
    ("ai", "{query}"),
])

# Google embedding model used to embed example questions for similarity search.
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
# NOTE(review): this instance appears to exist only to call delete_collection(), presumably to
# clear Chroma's default collection so re-running the cell does not accumulate duplicate example
# embeddings. Confirm that from_examples below reuses the same default collection.
vectorstore = Chroma(embedding_function=embeddings)
vectorstore.delete_collection()

# Embeds the "input" field of every example; for a new input it returns the k=2 most similar examples.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples=examples,
    embeddings=embeddings,
    vectorstore_cls=Chroma,
    k=2,
    input_keys=["input"]
)

# Chat-prompt fragment that expands to the selected examples, each rendered with example_prompt.
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    example_selector=example_selector,
    input_variables=["input"]
)


In [9]:

# Schema of `db` as text: CREATE TABLE statements plus a few sample rows per table.
# NOTE(review): create_sql_query_chain builds its own table_info from `db`, so this string is
# mainly for inspection. Confirm against the installed LangChain version.
table_details = db.get_table_info()
# print() renders the newlines/tabs; a bare repr would show them as escape sequences.
print(table_details)

'\nCREATE TABLE customers (\n\t"customerNumber" INTEGER, \n\t"customerName" TEXT, \n\tcountry TEXT, \n\tPRIMARY KEY ("customerNumber")\n)\n\n/*\n3 rows from customers table:\ncustomerNumber\tcustomerName\tcountry\n101\tAlice\tUSA\n102\tBob\tFrance\n103\tCharlie\tIndia\n*/\n\n\nCREATE TABLE orders (\n\t"orderNumber" INTEGER, \n\t"customerNumber" INTEGER, \n\t"orderDate" TEXT, \n\tstatus TEXT, \n\tPRIMARY KEY ("orderNumber"), \n\tFOREIGN KEY("customerNumber") REFERENCES customers ("customerNumber")\n)\n\n/*\n3 rows from orders table:\norderNumber\tcustomerNumber\torderDate\tstatus\n1001\t101\t2023-01-01\tShipped\n1002\t102\t2023-01-02\tPending\n1003\t101\t2023-01-03\tShipped\n*/\n\n\nCREATE TABLE products (\n\t"productCode" TEXT, \n\t"productName" TEXT, \n\t"productLine" TEXT, \n\t"quantityInStock" INTEGER, \n\t"buyPrice" REAL, \n\t"MSRP" REAL, \n\tPRIMARY KEY ("productCode")\n)\n\n/*\n3 rows from products table:\nproductCode\tproductName\tproductLine\tquantityInStock\tbuyPrice\tMSRP\nS1

In [15]:
from langchain.prompts import PromptTemplate

# Prompt for the text-to-SQL step. The chain later cleans this model output and executes it
# as SQL, so the model must return the query only: no explanation and no answer.
# {top_k} is used as a row limit, which is how LangChain's default SQL prompts use it.
# The few-shot examples selected for the question sit between the instructions and the
# question, so the model sees similar question/SQL pairs first.
final_prompt = ChatPromptTemplate.from_messages([
    (
        "system",
        "You are a SQLite expert. Given an input question, write one syntactically "
        "correct SQLite query that answers it.\n"
        "Return only the SQL query, with no explanation, markdown, or answer.\n"
        "Unless the question asks for a specific number of results, limit the query "
        "to at most {top_k} rows using LIMIT.\n"
        "Only use the tables listed below:\n\n{table_info}",
    ),
    few_shot_prompt,
    ("human", "{input}"),
])


In [16]:

from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool

# Text-to-SQL step: renders final_prompt for the schema of `db` and returns the generated SQL text.
generate_query = create_sql_query_chain(llm, db, final_prompt)
# Runs a SQL string against `db`. This replaces the deprecated QuerySQLDataBaseTool spelling.
execute_query = QuerySQLDatabaseTool(db=db)


  execute_query = QuerySQLDataBaseTool(db=db)


In [24]:

def clean_sql_query(sql_string: str) -> str:
    """Extract a bare SQL statement from raw LLM output.

    - If the text contains a markdown code fence (optionally tagged, e.g. ```sql), only the
      fenced content is kept. Surrounding prose is dropped, and a missing closing fence is tolerated.
    - A leading speaker/label such as "AI:", "Answer:", "SQLQuery:" or "SQL Query:" is removed
      (case-insensitive).

    Args:
        sql_string: Raw text produced by the SQL-generation step.

    Returns:
        The SQL text with surrounding whitespace stripped ("" for empty input).
    """
    text = sql_string.strip()
    # The language tag is only consumed when it is followed by a newline, so ```SELECT 1```
    # keeps "SELECT". `$` lets an unterminated fence run to the end of the text.
    fenced = re.search(r"```(?:[\w-]*[ \t]*\n)?([\s\S]*?)(?:```|$)", text)
    if fenced:
        text = fenced.group(1).strip()
    text = re.sub(
        r"^\s*(?:AI|Answer|SQLQuery|SQL Query)\s*:\s*", "", text, flags=re.IGNORECASE
    )
    return text.strip()

# Runnable wrapper so clean_sql_query can be piped into LCEL chains.
# (The original referenced the undefined name clean_generated_sql, which raised NameError.)
strip_fences = RunnableLambda(clean_sql_query)


In [25]:

# Turns the question, the executed SQL and its raw result into a natural-language answer.
# The model is told to rely only on the result, so an empty result or an error string is
# reported instead of guessed around.
rephrase_prompt = PromptTemplate.from_template("""
Given the following user question, SQL query, and SQL result, answer the question in natural language.
Base the answer only on the SQL result. If the result is empty, say that no matching data was found;
if the result is an error message, say that the query failed instead of guessing an answer.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer:
""".strip())
rephrase_answer = rephrase_prompt | llm | StrOutputParser()


In [26]:

# In-memory store for the conversation turns of this session; it starts with no messages.
history = ChatMessageHistory(messages=[])


In [27]:
from operator import itemgetter
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# End-to-end NL2SQL pipeline. Input: a dict with "question" (and optionally "top_k").
#   1. generate_query turns the question into SQL text;
#   2. clean_sql_query strips markdown fences / speaker labels from it;
#   3. execute_query runs the SQL against `db`;
#   4. rephrase_answer turns question + SQL + result into a natural-language answer.
# generate_query is composed as a runnable rather than .invoke()d inside a lambda, so
# batching, streaming and callback/config propagation flow through it. It receives the whole
# input dict and reads "question" and "top_k" from it.
# NOTE(review): create_sql_query_chain recomputes "table_info" from `db`, so a caller-supplied
# "table_info" appears to be ignored. Confirm for the installed LangChain version.
chain = (
    RunnablePassthrough.assign(
        query=generate_query | RunnableLambda(clean_sql_query)
    )
    .assign(result=itemgetter("query") | execute_query)
    | rephrase_answer
)


In [28]:
first_question = "List the product names in Classic Cars"
response = chain.invoke({
    "question": first_question,
    "table_info": table_details,
    "top_k": 5,
})
# Record this turn so that later follow-up questions can refer back to it.
history.add_user_message(first_question)
history.add_ai_message(response)
print(response)


In [30]:
followup = "What is their stock count?"

# The chain is stateless, so earlier turns are folded into the question text. Without them,
# "their" cannot be resolved to the products from the previous answer. If history is
# empty, the follow-up is sent as-is.
prior_turns = "\n".join(f"{msg.type}: {msg.content}" for msg in history.messages)
contextual_question = (
    f"Previous conversation:\n{prior_turns}\n\nFollow-up question: {followup}"
    if prior_turns
    else followup
)

response2 = chain.invoke({
    "question": contextual_question,
    "table_info": table_details,
    "top_k": 5,
})
history.add_user_message(followup)
history.add_ai_message(response2)
print(response2)


The stock count for Akshay Thar is 10, for Ferrari F8 it is 5, and for Jeep Wrangler it is 8.
