# Topics
### Building a basic NL2SQL model
### Adding few-shot examples
### Dynamic few-shot example selection
### Dynamic relevant table selection (Large DB Isssue)
### Customizing prompts
### Adding memory to the chatbot so that it answers follow-up questions related to the database.

In [1]:
#DB conection details

import os
from dotenv import load_dotenv
load_dotenv()

db_user = os.getenv("db_user")
db_password = os.getenv("db_password")
db_host = os.getenv("db_host")
db_name = os.getenv("db_name")

# # Building a basic NL2SQL model

In [None]:
# DB Connection 

from langchain_community.utilities.sql_database import SQLDatabase

db = SQLDatabase.from_uri(
    database_uri="mssql+pyodbc://**?driver=ODBC+Driver+17+for+SQL+Server&trusted_connection=yes"  ## update your conn str here
)

#SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")

In [5]:
#Check DB info

print(db.dialect)
print("-"*100)
print(db.get_usable_table_names())
print("-"*100)
print(db.table_info)

mssql
----------------------------------------------------------------------------------------------------
['M_ACCESS', 'M_ACCESS_OBJ_ROLE_MAP', 'M_ACCESS_PANEL', 'M_ACCESS_ROLES', 'M_ACCESS_USR_ROLE_MAP', 'M_AUDITTRAIL', 'M_AUTHORITY_MATRIX1', 'M_BUSINESS_PLACE', 'M_COMPANY', 'M_COST_CENTER', 'M_DEPARTMENT', 'M_DESIGNATION', 'M_DOCUMENTS', 'M_EMPLOYEE', 'M_GL_CODES', 'M_GRADE', 'M_INVDATE_CAL', 'M_INVOICE_APPROVAL_MATRIX', 'M_INVOICE_DOCUMENTS', 'M_INVOICE_FI_MATRIX', 'M_INVOICE_MM_MATRIX', 'M_LOCATION', 'M_NATURE_OF_EXPENSE', 'M_PRIODICITY_OF_EXP', 'M_PROJECT_PURCHORG', 'M_REJECTION_REASON', 'M_SAP_FIN_DTL', 'M_TAX_CODE', 'M_TDS_TAXCODE', 'M_UNIQUENUMBER', 'M_USER', 'M_VENDOR', 'TEMP_INVOICE_HDR', 'TEMP_INVOICE_HDR1', 'T_ADVANCE_HDR', 'T_INVOICE_PAYMET_DTL', 'T_INVOICE_SUBMISSION_DTL', 'T_INVOICE_SUBMISSION_HDR', 'T_INVOICE_SUBMISSION_PODTL', 'T_LOGIN', 'T_MAILLOG', 'T_UPLOADINVOICE_MODIFICATIONDATA_HDR_ORG']
---------------------------------------------------------------------------

In [None]:
#Import Keys

os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
os.environ["LANGCHAIN_TRACING_V2"] = os.getenv("LANGCHAIN_TRACING_V2")
os.environ["LANGCHAIN_PROJECT"] = os.getenv("LANGCHAIN_PROJECT")
os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY")

In [7]:
# Function to clean SQL Query generated by OpenAI LLM by removing extr chars

import re

def clean_sql_query(text: str) -> str:
    """
    Clean SQL query by removing code block syntax, various SQL tags, backticks,
    prefixes, and unnecessary whitespace while preserving the core SQL query.

    Args:
        text (str): Raw SQL query text that may contain code blocks, tags, and backticks

    Returns:
        str: Cleaned SQL query
    """
    # Step 1: Remove code block syntax and any SQL-related tags
    # This handles variations like ```sql, ```SQL, ```SQLQuery, etc.
    block_pattern = r"```(?:sql|SQL|SQLQuery|mysql|postgresql)?\s*(.*?)\s*```"
    text = re.sub(block_pattern, r"\1", text, flags=re.DOTALL)

    # Step 2: Handle "SQLQuery:" prefix and similar variations
    # This will match patterns like "SQLQuery:", "SQL Query:", "MySQL:", etc.
    prefix_pattern = r"^(?:SQL\s*Query|SQLQuery|MySQL|PostgreSQL|SQL)\s*:\s*"
    text = re.sub(prefix_pattern, "", text, flags=re.IGNORECASE)

    # Step 3: Extract the first SQL statement if there's random text after it
    # Look for a complete SQL statement ending with semicolon
    sql_statement_pattern = r"(SELECT.*?;)"
    sql_match = re.search(sql_statement_pattern, text, flags=re.IGNORECASE | re.DOTALL)
    if sql_match:
        text = sql_match.group(1)

    # Step 4: Remove backticks around identifiers
    text = re.sub(r'`([^`]*)`', r'\1', text)

    # Step 5: Normalize whitespace
    # Replace multiple spaces with single space
    text = re.sub(r'\s+', ' ', text)

    # Step 6: Preserve newlines for main SQL keywords to maintain readability
    keywords = ['SELECT', 'FROM', 'WHERE', 'GROUP BY', 'HAVING', 'ORDER BY',
               'LIMIT', 'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN',
               'OUTER JOIN', 'UNION', 'VALUES', 'INSERT', 'UPDATE', 'DELETE']

    # Case-insensitive replacement for keywords
    pattern = '|'.join(r'\b{}\b'.format(k) for k in keywords)
    text = re.sub(f'({pattern})', r'\n\1', text, flags=re.IGNORECASE)

    # Step 7: Final cleanup
    # Remove leading/trailing whitespace and extra newlines
    text = text.strip()
    text = re.sub(r'\n\s*\n', '\n', text)

    # Step 8: 
    if text.startswith("`sql SQLQuery:"):
        text = text[len("`sql SQLQuery:"):].lstrip()

    return text

In [8]:
# Generate Query using DB info 

from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model=os.getenv("OPENAI_MODEL"), temperature=0)
generate_query = create_sql_query_chain(llm, db)

query = generate_query.invoke({"question": "list all unique vendors present"})
print(query)


```sql
SQLQuery: SELECT DISTINCT [VENDOR_NAME] FROM [T_INVOICE_SUBMISSION_HDR] WHERE [VENDOR_NAME] IS NOT NULL


In [9]:
# Call Function to clean SQL Query generated by OpenAI LLM to removing extr chars so that we can execute without syntax error

print("-"*150)
query = clean_sql_query(query)
print(query)

------------------------------------------------------------------------------------------------------------------------------------------------------
SELECT DISTINCT [VENDOR_NAME] 
FROM [T_INVOICE_SUBMISSION_HDR] 
WHERE [VENDOR_NAME] IS NOT NULL


In [10]:
# Execute Query 

from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

execute_query = QuerySQLDataBaseTool(db=db)
execute_query.invoke(query)



  execute_query = QuerySQLDataBaseTool(db=db)


'[(\'Sadguru Broadband Service\',), (\'BSB ELECTRONICS\',), (\'SWITCHING_CHANNELS\',), (\'Dhruva Advisors LLP\',), (\'Star Link Cable Network\',), (\'RAHUL AGARWAL \',), (\'MARK CERTIFICATION CONSULTANTS\',), (\'Aarushi Gupta /sanjayg_kn            \',), (\'Rajasthan Cable Network Service Sunel /5012968\',), (\'5013307_Fypoint_Coco_Eaeg_Shreeenterprises_New\',), (\'Sushil Sundaram/ss_sushils\',), (\'GUPTA PRINTERS\',), (\'S. K. CONSULTANTS/ss_santoshikamble\',), (\'sumitb_ws / Rutuja Pradip Ambekar\',), (\'Fibra Netway Private Limited\',), (\'Umesh Cable Network\',), (\'5001768_Ftth_Coco_W1_Wc12_Navsonarbala\',), (\'5015805 - Coco_Honesty Telecom - Bhandup Maharashtra Nagar\',), (\'RADHEY CABLE NETWORK\',), (\'Sway Computer\',), (\'Jayram J Palav\',), (\'Ajinkya Modak/ss_ajinkya09\',), (\'Ram_Cable NEtwork_Delhi\',), (\'Virendra Gunvantrai Bhatt\',), (\'SSV Broadband Pvt.Ltd.\',), (\'ABC/TEST\',), (\'SHIVA SAI CABLE NETWORK\',), (\'5000061_Delhi-Sargam_Cable\',), ("Prashant James D\'so

In [11]:
# Create chain -> to generate query -> to clean quert -> to execute query & get result

from langchain_core.runnables import RunnablePassthrough, RunnableLambda

chain = generate_query | RunnableLambda(clean_sql_query) | execute_query

chain.invoke({"question": "how many invoices are completed"})
     

'[(29930,)]'

In [12]:
#Prompt

chain.get_prompts()[0].pretty_print()

You are an MS SQL expert. Given an input question, first create a syntactically correct MS SQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the TOP clause as per MS SQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in square brackets ([]) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CAST(GETDATE() as date) function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQL

In [13]:
# result Chain -> to generate query -> to clean quert -> to execute query & get result -> Rephrase answer using Question asked, query & Sql execution reult

from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

answer_prompt = PromptTemplate.from_template(
       """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

rephrase_answer = answer_prompt | llm | StrOutputParser()

chain = (
    RunnablePassthrough.assign(query=generate_query | RunnableLambda(clean_sql_query)).assign(
        result=itemgetter("query") | execute_query
    )
    | rephrase_answer
)

chain.invoke({"question": "for how many companies invoices are completed"})

'There are 3,214 companies for which invoices are completed.'

# Adding few-shot examples

In [14]:
# Establish a set of questions and their associated queries to serve as reference inputs for the language model.

examples = [
    {
        "input": "For how many companies have invoices been completed",
        "query": "SELECT COUNT(DISTINCT [VENDOR_NAME]) AS CompletedInvoicesCount FROM [T_INVOICE_SUBMISSION_HDR] WHERE [STATUS] = 'Completed';"
    },
    {
        "input": "What is the total number of completed invoices",
        "query": "SELECT COUNT(*) AS CompletedInvoices FROM T_INVOICE_SUBMISSION_HDR WHERE STATUS = 'Completed'"
    },
    {
        "input": "How many orders are there",
        "query": "SELECT COUNT(*) AS [TotalOrders] FROM [T_INVOICE_SUBMISSION_HDR]"
    },
    {
        "input": "list all unique vendors present",
        "query": "SQLQuery: SELECT DISTINCT [VENDOR_NAME] FROM [T_INVOICE_SUBMISSION_HDR] WHERE [VENDOR_NAME] IS NOT NULL"
    }
    
]
     

In [15]:
# Convert above example into Human statements (questions) & AI(LLM) statements (response) 

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate

example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}\nSQLQuery:"),
        ("ai", "{query}"),
    ]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
    # input_variables=["input","top_k"],
    input_variables=["input"],
)
print(few_shot_prompt.format(input1="What is the count of invoices"))

Human: For how many companies have invoices been completed
SQLQuery:
AI: SELECT COUNT(DISTINCT [VENDOR_NAME]) AS CompletedInvoicesCount FROM [T_INVOICE_SUBMISSION_HDR] WHERE [STATUS] = 'Completed';
Human: What is the total number of completed invoices
SQLQuery:
AI: SELECT COUNT(*) AS CompletedInvoices FROM T_INVOICE_SUBMISSION_HDR WHERE STATUS = 'Completed'
Human: How many orders are there
SQLQuery:
AI: SELECT COUNT(*) AS [TotalOrders] FROM [T_INVOICE_SUBMISSION_HDR]
Human: list all unique vendors present
SQLQuery:
AI: SQLQuery: SELECT DISTINCT [VENDOR_NAME] FROM [T_INVOICE_SUBMISSION_HDR] WHERE [VENDOR_NAME] IS NOT NULL


# Dynamic few-shot example selection

In [16]:
#Using Semantic search, choose only the relavent set of examples to pass as reference  

from langchain_chroma import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

vectorstore = Chroma()
vectorstore.delete_collection()
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    vectorstore,
    k=2,
    input_keys=["input"],
)
example_selector.select_examples({"input": "What is the count of invoices?"})

[{'input': 'What is the total number of completed invoices',
  'query': "SELECT COUNT(*) AS CompletedInvoices FROM T_INVOICE_SUBMISSION_HDR WHERE STATUS = 'Completed'"},
 {'input': 'For how many companies have invoices been completed',
  'query': "SELECT COUNT(DISTINCT [VENDOR_NAME]) AS CompletedInvoicesCount FROM [T_INVOICE_SUBMISSION_HDR] WHERE [STATUS] = 'Completed';"}]

In [17]:
# Convert relavent set of examples into Human statements (questions) & AI(LLM) statements (response) 

few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    example_selector=example_selector,
    input_variables=["input","top_k"],
)
print(few_shot_prompt.format(input="What is the count of invoices?"))

Human: What is the total number of completed invoices
SQLQuery:
AI: SELECT COUNT(*) AS CompletedInvoices FROM T_INVOICE_SUBMISSION_HDR WHERE STATUS = 'Completed'
Human: For how many companies have invoices been completed
SQLQuery:
AI: SELECT COUNT(DISTINCT [VENDOR_NAME]) AS CompletedInvoicesCount FROM [T_INVOICE_SUBMISSION_HDR] WHERE [STATUS] = 'Completed';


In [18]:
# Combine above steps with user Question

final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
        few_shot_prompt,
        ("human", "{input}"),
    ]
)
print(final_prompt.format(input="What is the count of invoices?",table_info="some table info"))

System: You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.

Here is the relevant table info: some table info

Below are a number of examples of questions and their corresponding SQL queries.
Human: What is the total number of completed invoices
SQLQuery:
AI: SELECT COUNT(*) AS CompletedInvoices FROM T_INVOICE_SUBMISSION_HDR WHERE STATUS = 'Completed'
Human: For how many companies have invoices been completed
SQLQuery:
AI: SELECT COUNT(DISTINCT [VENDOR_NAME]) AS CompletedInvoicesCount FROM [T_INVOICE_SUBMISSION_HDR] WHERE [STATUS] = 'Completed';
Human: What is the count of invoices?


In [19]:
# result Chain -> to generate query with relavent set of Human question & AI response of query-> to clean query -> to execute query & get result -> Rephrase answer using Question asked, query & Sql execution reult


generate_query = create_sql_query_chain(llm, db,final_prompt)

chain = (
    RunnablePassthrough.assign(query=generate_query | RunnableLambda(clean_sql_query)).assign(
        result=itemgetter("query") | execute_query
    )
    | rephrase_answer
)

chain.invoke({"question": "What is the count of invoices?"})

'The count of invoices is 36,090.'

# Dynamic relevant table selection

In [20]:
# Select the most relevant database tables and include their details to improve the accuracy of query generation by the LLM

from operator import itemgetter
#from langchain.chains.openai_tools import create_extraction_chain_pydantic
from pydantic import BaseModel, Field
from typing import List
import pandas as pd

def get_table_details():
    # Read the CSV file into a DataFrame
    table_description = pd.read_excel("TableInfo.xlsx")
    table_docs = []

    # Iterate over the DataFrame rows to create Document objects
    table_details = ""
    for index, row in table_description.iterrows():
        table_details = table_details + "Table Name:" + row['table_name'] + "\n" + "Table Description:" + row['description'] + "\n\n"

    return table_details


class Table(BaseModel):
    """Table in SQL database."""

    name: List[str] = Field(description="List of Name of tables in SQL database.")

# table_names = "\n".join(db.get_usable_table_names())
table_details = get_table_details()
print(table_details)

Table Name:M_DEPARTMENT
Table Description:Holds the information of Depart ment name department description department head and creation date

Table Name:T_INVOICE_SUBMISSION_HDR
Table Description:Holds the information of Invoice type, vendor name, PO number, Invoice number, invoice date, invoice value, Foreign key of department, project name, project location




In [21]:
# The input consists of a table name and a short description. The output will be a list of relevant database tables based on this context

table_details_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", """Return the names of ALL the SQL tables that MIGHT be relevant to the user question.
                          The tables are:

                          {table_details}

                          Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""),
            ("human", "{question}")
        ]
    )

structured_llm = llm.with_structured_output(Table)

table_chain = table_details_prompt | structured_llm
tables = table_chain.invoke({"question": "which department has submitted the highest number of invoices?", "table_details":table_details})
tables

Table(name=['M_DEPARTMENT', 'T_INVOICE_SUBMISSION_HDR'])

In [22]:
# function to defines and a processing pipeline to extract relevant database table names based on a user's question and associated table metadata.

def get_tables(table_response: Table) -> List[str]:
    """
    Extracts the list of table names from a Table object.

    Args:
        table_response (Table): A Pydantic Table object containing table names.

    Returns:
        List[str]: A list of table names.
    """
    return table_response.name

select_table = {"question": itemgetter("question"), "table_details": itemgetter("table_details")} | table_chain | get_tables

select_table.invoke({"question": "give me details of customer and their order count", "table_details":table_details})
     

['M_DEPARTMENT', 'T_INVOICE_SUBMISSION_HDR']

In [23]:
# result Chain -> to generate query with (relavent set of Human question & AI response of query) & (relavent set of DB tables & its details) -> to clean query -> to execute query & get result -> Rephrase answer using Question asked, query & Sql execution reult

chain = (
RunnablePassthrough.assign(table_names_to_use=select_table) |
RunnablePassthrough.assign(query=generate_query | RunnableLambda(clean_sql_query)).assign(
    result=itemgetter("query") | execute_query
)
| rephrase_answer
)
chain.invoke({"question": "what is the total invoice count in last 3 months ?", "table_details":table_details})
     
     

'The total invoice count in the last 3 months is 1,653.'

In [25]:
chain.invoke({"question": "what is the toatl count in last 2 months ?", "table_details":table_details})

'The total count in the last 2 months is 579.'

# Adding memory to the chatbot so that it answers follow-up questions

In [26]:
# Adding Conversation history 

final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries. Those examples are just for referecne and hsould be considered while answering follow up questions"),
        few_shot_prompt,
        MessagesPlaceholder(variable_name="messages"),
        ("human", "{input}"),
    ]
)
print(final_prompt.format(input="what is the total invoice count in last 3 months ?",table_info="some table info",messages=[]))

System: You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.

Here is the relevant table info: some table info

Below are a number of examples of questions and their corresponding SQL queries. Those examples are just for referecne and hsould be considered while answering follow up questions
Human: What is the total number of completed invoices
SQLQuery:
AI: SELECT COUNT(*) AS CompletedInvoices FROM T_INVOICE_SUBMISSION_HDR WHERE STATUS = 'Completed'
Human: For how many companies have invoices been completed
SQLQuery:
AI: SELECT COUNT(DISTINCT [VENDOR_NAME]) AS CompletedInvoicesCount FROM [T_INVOICE_SUBMISSION_HDR] WHERE [STATUS] = 'Completed';
Human: what is the total invoice count in last 3 months ?


In [27]:
# Chain -> above steps + conversation history 

from langchain.memory import ChatMessageHistory
history = ChatMessageHistory()

generate_query = create_sql_query_chain(llm, db,final_prompt)

chain = (
RunnablePassthrough.assign(table_names_to_use=select_table) |
RunnablePassthrough.assign(query=generate_query | RunnableLambda(clean_sql_query)).assign(
    result=itemgetter("query") | execute_query
)
| rephrase_answer
)

In [38]:
# Testing above code

question = "what is the total invoice count in last 3 months ?"

response = chain.invoke({"question": question,"messages":history.messages, "table_details":table_details})
response

'The total invoice count in the last 3 months is 1,653.'

In [39]:
# add conversation history 

history.add_user_message(question)
history.add_ai_message(response)

In [40]:
# Print conversation history 

history.messages

[HumanMessage(content='what is the total invoice count in last 3 months ?', additional_kwargs={}, response_metadata={}),
 AIMessage(content='The total invoice count in the last 3 months is 1,653.', additional_kwargs={}, response_metadata={})]

In [41]:
#test 2

question = "what is the total invoice count in last 2 months ?"

response = chain.invoke({"question": question,"messages":history.messages, "table_details":table_details})
response
     

'The total invoice count in the last 2 months is 1018.'

In [42]:
# Print conversation history 

history.messages

[HumanMessage(content='what is the total invoice count in last 3 months ?', additional_kwargs={}, response_metadata={}),
 AIMessage(content='The total invoice count in the last 3 months is 1,653.', additional_kwargs={}, response_metadata={})]

In [43]:
# add conversation history 

history.add_user_message(question)
history.add_ai_message(response)


In [44]:
# Print conversation history 

history.messages

[HumanMessage(content='what is the total invoice count in last 3 months ?', additional_kwargs={}, response_metadata={}),
 AIMessage(content='The total invoice count in the last 3 months is 1,653.', additional_kwargs={}, response_metadata={}),
 HumanMessage(content='what is the total invoice count in last 2 months ?', additional_kwargs={}, response_metadata={}),
 AIMessage(content='The total invoice count in the last 2 months is 1018.', additional_kwargs={}, response_metadata={})]

In [46]:
#test 3

question = "1 month?"

response = chain.invoke({"question": question,"messages":history.messages, "table_details":table_details})
response

'The answer to your question "1 month?" is that there were a total of 298 invoices submitted in the last month.'

In [47]:
#test 3 verification

question = "what is the total invoice count in last 1 month ?"

response = chain.invoke({"question": question,"messages":history.messages, "table_details":table_details})
response
     

'The total invoice count in the last 1 month is 298.'