
1. Building a basic NL2SQL model
2. Adding few-shot examples
3. Dynamic few-shot example selection
4. Dynamic relevant table selection (Large DB Isssue)
5. Customizing prompts
6. Adding memory to the chatbot so that it answers follow-up questions related to the database.

In [1]:
import os
#Import pandas as pd
#import streamlit as st
from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain_core.pydantic_v1 import BaseModel, Field
import openai
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from typing import List
from langchain.agents.agent_types import AgentType
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import create_sql_agent, ZeroShotAgent
from langchain.prompts.chat import ChatPromptTemplate
import langchain.globals
langchain.globals.set_verbose(True)  # Or False, depending on your needs
verbose = langchain.globals.get_verbose()

In [7]:
#!pip install langchain_openai langchain_community langchain pymysql chromadb -q

In [2]:
username = "dbWSS"
password = "LW)ppknteP'"
host = "used.database.windows.net"
database = "USED"
driver = "ODBC Driver 17 for SQL Server"

In [3]:
# Create connection string
connection_string = (
        r'mssql+pyodbc://dbWSS:*JlaIl'
        r'@useq.database.windows.net/USEQ'
        r'?driver=ODBC+Driver+17+for+SQL+Server'
s
    )


In [4]:
#db = SQLDatabase.from_uri(connection_string, schema="dbo",sample_rows_in_table_info=1,include_tables=['account','group','membership'],custom_table_info={'account':"account",'group':"group"})
db = SQLDatabase.from_uri(connection_string, schema="dbo",sample_rows_in_table_info=1,include_tables=['account','group','membership'])

###Building a basic NLSQL model

In [5]:
print(db.dialect)
print(db.get_usable_table_names())
print(db.table_info)

mssql
['account', 'group', 'membership']

CREATE TABLE dbo.[group] (
	id INTEGER NOT NULL IDENTITY(1,1), 
	job_id NVARCHAR(255) COLLATE SQL_Latin1_General_CP1_CI_AS NOT NULL, 
	group_name NVARCHAR(255) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, 
	system_id INTEGER NULL, 
	system_name NVARCHAR(150) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, 
	application_name NVARCHAR(150) COLLATE SQL_Latin1_General_CP1_CI_AS NULL DEFAULT (''), 
	platform_name NVARCHAR(50) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, 
	group_type NVARCHAR(255) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, 
	create_time DATETIME NULL, 
	createby NVARCHAR(255) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, 
	update_time DATETIME NULL, 
	updateby NVARCHAR(255) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, 
	is_privileged INTEGER NULL, 
	privileged_tier NVARCHAR(10) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, 
	description1 NVARCHAR(4000) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, 
	description2 NVARCHAR(4000) COLLATE SQL_Latin1_General

In [6]:
import os
load_dotenv()

openai.api_type = os.getenv('OPENAI_API_TYPE')
openai.api_version = os.getenv('OPENAI_API_VERSION')
openai.azure_endpoint = os.getenv('AZURE_OPENAI_ENDPOINT')
openai.api_key = os.getenv("OPENAI_API_KEY")

In [8]:
from langchain.chains import create_sql_query_chain
from langchain_openai import AzureChatOpenAI

llm = AzureChatOpenAI(
        deployment_name="nlp-gpt4",
        temperature=0.2,
        max_tokens=900,
        azure_endpoint=openai.azure_endpoint
    )

In [9]:
generate_query = create_sql_query_chain(llm, db)
query = generate_query.invoke({"question": "what is the group id of group name parisdev`"})

print(query)

SELECT TOP 1 [id] FROM dbo.[group] WHERE [group_name] = 'parisdev'


In [10]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
execute_query = QuerySQLDataBaseTool(db=db)
execute_query.invoke(query)

'[(4845,)]'

In [11]:
chain = generate_query | execute_query
chain.invoke({"question": "How many accounts are there in account table"})

'[(223,)]'

In [12]:
chain.get_prompts()[0].pretty_print()

You are an MS SQL expert. Given an input question, first create a syntactically correct MS SQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the TOP clause as per MS SQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in square brackets ([]) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CAST(GETDATE() as date) function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQL

In [13]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

rephrase_answer = answer_prompt | llm | StrOutputParser()

chain = (
    RunnablePassthrough.assign(query=generate_query).assign(
        result=itemgetter("query") | execute_query
    )
    | rephrase_answer
)

chain.invoke({"question": "How many accounts are there in account table"})

'There are 223 accounts in the account table.'

###Adding few-shot examples

In [14]:
examples = [
    {"input": "List all accounts.", "query": "SELECT * FROM [account];"},
    {
        "input": "Count account ID of account name adm?",
        "query": "SELECT COUNT(ID) FROM [account] WHERE acct_name ='bin';",
    },
    {
        "input": "List all group having platform name 'Linux Server'",
        "query": "SELECT g.* FROM [group] g JOIN [membership] m ON g.id = m.m_group_id JOIN [account] a ON a.id = m.account_id WHERE a.platform_name='Linux Server';",
    },
    {
        "input": "Consider display name as user name and provide all the user name where group name is Parisdev'",
         "query": "SELECT a.display_name AS user_name FROM [account] a INNER JOIN [membership] m ON a.id = m.account_id INNER JOIN [group] g ON m.m_group_id = g.id WHERE g.group_name = 'Parisdev';",
    },
    {
        "input": "Provide group name for account name root",
         "query": "SELECT g.group_name FROM [group] g INNER JOIN [membership] m ON g.id = m.m_group_id INNER JOIN [account] a ON m.account_id = a.id WHERE a.acct_name = 'root';",
    },
]

In [15]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate

example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}\nSQLQuery:"),
        ("ai", "{query}"),
    ]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
    # input_variables=["input","top_k"],
    input_variables=["input"],
)
print(few_shot_prompt.format(input1="How many groups are there?"))

Human: List all accounts.
SQLQuery:
AI: SELECT * FROM [account];
Human: Count account ID of account name adm?
SQLQuery:
AI: SELECT COUNT(ID) FROM [account] WHERE acct_name ='bin';
Human: List all group having platform name 'Linux Server'
SQLQuery:
AI: SELECT g.* FROM [group] g JOIN [membership] m ON g.id = m.m_group_id JOIN [account] a ON a.id = m.account_id WHERE a.platform_name='Linux Server';
Human: Consider display name as user name and provide all the user name where group name is Parisdev'
SQLQuery:
AI: SELECT a.display_name AS user_name FROM [account] a INNER JOIN [membership] m ON a.id = m.account_id INNER JOIN [group] g ON m.m_group_id = g.id WHERE g.group_name = 'Parisdev';
Human: Provide group name for account name root
SQLQuery:
AI: SELECT g.group_name FROM [group] g INNER JOIN [membership] m ON g.id = m.m_group_id INNER JOIN [account] a ON m.account_id = a.id WHERE a.acct_name = 'root';


In [16]:
print(few_shot_prompt.format(input="tell me group name of account root?"))

Human: List all accounts.
SQLQuery:
AI: SELECT * FROM [account];
Human: Count account ID of account name adm?
SQLQuery:
AI: SELECT COUNT(ID) FROM [account] WHERE acct_name ='bin';
Human: List all group having platform name 'Linux Server'
SQLQuery:
AI: SELECT g.* FROM [group] g JOIN [membership] m ON g.id = m.m_group_id JOIN [account] a ON a.id = m.account_id WHERE a.platform_name='Linux Server';
Human: Consider display name as user name and provide all the user name where group name is Parisdev'
SQLQuery:
AI: SELECT a.display_name AS user_name FROM [account] a INNER JOIN [membership] m ON a.id = m.account_id INNER JOIN [group] g ON m.m_group_id = g.id WHERE g.group_name = 'Parisdev';
Human: Provide group name for account name root
SQLQuery:
AI: SELECT g.group_name FROM [group] g INNER JOIN [membership] m ON g.id = m.m_group_id INNER JOIN [account] a ON m.account_id = a.id WHERE a.acct_name = 'root';


###Dynamic few-shot example selection

In [17]:
load_dotenv

<function dotenv.main.load_dotenv(dotenv_path: Union[str, ForwardRef('os.PathLike[str]'), NoneType] = None, stream: Optional[IO[str]] = None, verbose: bool = False, override: bool = False, interpolate: bool = True, encoding: Optional[str] = 'utf-8') -> bool>

In [18]:
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain.embeddings.azure_openai import AzureOpenAIEmbeddings


In [17]:
pip install langchain_elasticsearch

Note: you may need to restart the kernel to use updated packages.


In [22]:
#OPENAI_EMBEDDING_DEPLOYMENT_NAME =  "text-embedding-ada-002-cybercti"
#embeddings = AzureOpenAIEmbeddings(deployment=OPENAI_EMBEDDING_DEPLOYMENT_NAME, )

In [34]:
! pip install sentence-transformers

Collecting sentence-transformers
  Using cached sentence_transformers-2.7.0-py3-none-any.whl.metadata (11 kB)
Collecting transformers<5.0.0,>=4.34.0 (from sentence-transformers)
  Using cached transformers-4.40.2-py3-none-any.whl.metadata (137 kB)
Collecting scikit-learn (from sentence-transformers)
  Using cached scikit_learn-1.4.2-cp312-cp312-win_amd64.whl.metadata (11 kB)
Collecting tokenizers<0.20,>=0.19 (from transformers<5.0.0,>=4.34.0->sentence-transformers)
  Using cached tokenizers-0.19.1-cp312-none-win_amd64.whl.metadata (6.9 kB)
Collecting safetensors>=0.4.1 (from transformers<5.0.0,>=4.34.0->sentence-transformers)
  Using cached safetensors-0.4.3-cp312-none-win_amd64.whl.metadata (3.9 kB)
Collecting threadpoolctl>=2.0.0 (from scikit-learn->sentence-transformers)
  Using cached threadpoolctl-3.5.0-py3-none-any.whl.metadata (13 kB)
Using cached sentence_transformers-2.7.0-py3-none-any.whl (171 kB)
Using cached transformers-4.40.2-py3-none-any.whl (9.0 MB)
Using cached scikit_

  You can safely remove it manually.


In [19]:
from langchain.prompts import SemanticSimilarityExampleSelector
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma


embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')

to_vectorize = [" ".join(example.values()) for example in examples]

  from .autonotebook import tqdm as notebook_tqdm


In [20]:
to_vectorize

['List all accounts. SELECT * FROM [account];',
 "Count account ID of account name adm? SELECT COUNT(ID) FROM [account] WHERE acct_name ='bin';",
 "List all group having platform name 'Linux Server' SELECT g.* FROM [group] g JOIN [membership] m ON g.id = m.m_group_id JOIN [account] a ON a.id = m.account_id WHERE a.platform_name='Linux Server';",
 "Consider display name as user name and provide all the user name where group name is Parisdev' SELECT a.display_name AS user_name FROM [account] a INNER JOIN [membership] m ON a.id = m.account_id INNER JOIN [group] g ON m.m_group_id = g.id WHERE g.group_name = 'Parisdev';",
 "Provide group name for account name root SELECT g.group_name FROM [group] g INNER JOIN [membership] m ON g.id = m.m_group_id INNER JOIN [account] a ON m.account_id = a.id WHERE a.acct_name = 'root';"]

In [21]:
vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=examples)

In [22]:
example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore,
    k=2,
)

example_selector.select_examples({"Question": "How many accounts are there?"})

[{'input': 'List all accounts.', 'query': 'SELECT * FROM [account];'},
 {'input': 'Count account ID of account name adm?',
  'query': "SELECT COUNT(ID) FROM [account] WHERE acct_name ='bin';"}]

In [23]:
### my sql based instruction prompt
mssql_prompt = """You are a MsSQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: Query to run with no pre-amble
SQLResult: Result of the SQLQuery
Answer: Final answer here

No pre-amble.
"""

In [24]:
from langchain.prompts import FewShotPromptTemplate
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mssql_prompt

print(PROMPT_SUFFIX)

Only use the following tables:
{table_info}

Question: {input}


In [25]:
from langchain.prompts.prompt import PromptTemplate

example_prompt = PromptTemplate(
    input_variables=["Question"],
    template="\nQuestion: {input}",
)

In [33]:
from langchain.prompts.prompt import PromptTemplate

#example_prompt = PromptTemplate(
    #input_variables=["Question", "SQLQuery", "SQLResult","Answer",],
    #template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\nAnswer: {Answer}",
#)

In [26]:
print(_mssql_prompt)

You are an MS SQL expert. Given an input question, first create a syntactically correct MS SQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the TOP clause as per MS SQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in square brackets ([]) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CAST(GETDATE() as date) function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to r

In [28]:
#vectorstore = Chroma()
#vectorstore.delete_collection()
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embeddings,
    vectorstore,
    k=1,
    input_keys=["input"],
)
example_selector.select_examples({"input": "provide all the user name where group name is root'?"})

[{'input': 'Provide group name for account name root',
  'query': "SELECT g.group_name FROM [group] g INNER JOIN [membership] m ON g.id = m.m_group_id INNER JOIN [account] a ON m.account_id = a.id WHERE a.acct_name = 'root';"}]

In [29]:
from langchain.prompts.prompt import PromptTemplate

example_prompt = PromptTemplate(
    input_variables=["Question"],
    template="\nQuestion: {input}",
)

In [31]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate

example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}\nSQLQuery:"),
        ("ai", "{query}"),
    ]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    example_selector=example_selector,    
    input_variables=["input","top_k"],
)
print(few_shot_prompt.format(input="tell me group name of account root?"))

Human: Provide group name for account name root
SQLQuery:
AI: SELECT g.group_name FROM [group] g INNER JOIN [membership] m ON g.id = m.m_group_id INNER JOIN [account] a ON m.account_id = a.id WHERE a.acct_name = 'root';


In [32]:
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    example_selector=example_selector,
    input_variables=["input","top_k"],
)
print(few_shot_prompt.format(input="tell me group name of account root?"))

Human: Provide group name for account name root
SQLQuery:
AI: SELECT g.group_name FROM [group] g INNER JOIN [membership] m ON g.id = m.m_group_id INNER JOIN [account] a ON m.account_id = a.id WHERE a.acct_name = 'root';


Customizing prompts

In [33]:
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a SQL server expert, who can execute and query SQL database to find answers based on user's question about tables available in the database and take only 3 tables for answering i.e account, membership and group table and try to join them together in order to get the result."\
         "Sometimes answers you can get from one table in that case you don't need to combine all 3 tables . First understand if requirements can be satisfying by using one table only if needed join all 3 tables [dbo].[account], [dbo].[membership] and [dbo].[group] if not getting answer from one table but do not go beyond 3 tables in the database." \
          "You should consider m_group_id column from [dbo].[membership] when joining with the [dbo].[group] table on id column and consider account_id when joining with [dbo].[account]"\
          "Given an input question, create a syntactically correct  query to run, then look at the results of the query and return the answer."\
          "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results."\
          "You can order the results by a relevant column to return the most interesting examples in the database."\
          "Never query for all the columns from a specific table, only ask for the relevant columns given the question."\
          "You have access to tools for interacting with the database."\
          "Only use the given tools. Only use the information returned by the tools to construct your final answer."\
          "You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again."\
          "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database."\
          "If the question does not seem related to the database, just return 'I don't know' as the answer."\
          "Here are some examples of user inputs and their corresponding SQL queries" \
          "After joining all the tables if not getting any relevant answer pls return 'No data available' as the answer. Don't show any list or list of user names in this case"\
          "You don't want to see the entire data present in the table in the action input but wants to see final result.Also limit the number of rows to top 10 only when generating SQL queries .consider display name to user name in the account table."),
        ("system", "Example 1: If a user asks 'What are the account details for user XYZ?', you should generate a SQL query like 'SELECT Top 10 acct_name FROM account WHERE display_name = 'XYZ'"),
        ("system", "Example 2: If a user asks 'What groups is user XYZ a member of?', you should generate a SQL query like 'SELECT Top 10 [group].name FROM [group] JOIN membership ON [group].id = membership.group_id JOIN account ON membership.account_id = account.id WHERE account.display_name  = 'XYZ' LIMIT 10'.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
        few_shot_prompt,
        ("human", "{input}"),
    ]
)
print(final_prompt.format(input="What is the user name for account name root ?",table_info="account, group,membership"))

System: You are a SQL server expert, who can execute and query SQL database to find answers based on user's question about tables available in the database and take only 3 tables for answering i.e account, membership and group table and try to join them together in order to get the result.Sometimes answers you can get from one table in that case you don't need to combine all 3 tables . First understand if requirements can be satisfying by using one table only if needed join all 3 tables [dbo].[account], [dbo].[membership] and [dbo].[group] if not getting answer from one table but do not go beyond 3 tables in the database.You should consider m_group_id column from [dbo].[membership] when joining with the [dbo].[group] table on id column and consider account_id when joining with [dbo].[account]Given an input question, create a syntactically correct  query to run, then look at the results of the query and return the answer.Unless the user specifies a specific number of examples they wish 

In [34]:
generate_query = create_sql_query_chain(llm, db, final_prompt)

def debug_print_query(context):
    print("SQL Query:", context["query"])
    return context

chain = (
    RunnablePassthrough.assign(query=generate_query)
    | debug_print_query
)

chain = chain.assign(result=itemgetter("query") | execute_query) | rephrase_answer

chain.invoke({"question": "Provide group name for account name root"})

SQL Query: SELECT TOP 10 g.group_name 
FROM dbo.[group] g 
JOIN dbo.[membership] m ON g.id = m.m_group_id 
JOIN dbo.[account] a ON m.account_id = a.id 
WHERE a.acct_name = 'root' 
ORDER BY g.group_name;


"The group names for the account name root are 'abrt', 'adm', 'adm', 'avdefs', 'bin', 'bin', 'bj692vf', 'cgred', 'chrony', and 'CN=Account Operators,CN=Builtin,DC=ADVCRD,DC=sbp,DC=local'."

In [26]:
#generate_query = create_sql_query_chain(llm, db,final_prompt)
#def debug_print_query(context):
    #print("SQL Query:", context["query"])
    #return context


#chain = (
#RunnablePassthrough.assign(query=generate_query).assign(
    #result=itemgetter("query") | execute_query
#)
#| rephrase_answer
#)
#chain.invoke({"question": "provide name of user where account name is bin"})

###Dynamic relevant table selection

In [48]:
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List
import pandas as pd

def get_table_details():
    # Read the CSV file into a DataFrame
    table_description = pd.read_csv("database_table_descriptions.csv")
    table_docs = []

    # Iterate over the DataFrame rows to create Document objects
    table_details = ""
    for index, row in table_description.iterrows():
        table_details = table_details + "Table Name:" + row['Table'] + "\n" + "Table Description:" + row['Description'] + "\n\n"

    return table_details


class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of table in SQL database.")

# table_names = "\n".join(db.get_usable_table_names())
table_details = get_table_details()
print(table_details)

FileNotFoundError: [Errno 2] No such file or directory: 'database_table_descriptions.csv'

In [28]:
table_details_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_details}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

table_chain = create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt)
tables = table_chain.invoke({"input": "Provide platform name for account name root"})
tables

[]

In [29]:
def get_tables(tables: List[Table]) -> List[str]:
    tables  = [table.name for table in tables]
    return tables

select_table = {"input": itemgetter("question")} | create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt) | get_tables
select_table.invoke({"question": "Tell me the count of account name root"})

['account']

In [30]:
def debug_print_query(context):
    print("SQL Query:", context["query"])
    return context

chain = (
    RunnablePassthrough.assign(table_names_to_use=select_table)
    | RunnablePassthrough.assign(query=generate_query)
    | debug_print_query
)

chain = chain.assign(result=itemgetter("query") | execute_query) | rephrase_answer

chain.invoke({"question": "provide user name for account name root"})

SQL Query: SELECT Top 10 display_name FROM [dbo].[account] WHERE acct_name = 'root'


'The user names for the account name root are Vincent, Steve Gonzales, Jennifer Lee, Jesse Boyd, and Kimberly Glass.'

In [31]:
#chain = (
#RunnablePassthrough.assign(table_names_to_use=select_table) |
#RunnablePassthrough.assign(query=generate_query).assign(
    #result=itemgetter("query") | execute_query
#)
#| rephrase_answer
#)
#chain.invoke({"question": "provide user name for account root"})

In [32]:
chain.invoke({"question": "Provide entitlement name for Linux server platform from membership table"})

SQL Query: SELECT TOP 10 entitlement_name FROM [dbo].[membership] WHERE platform_name = 'Linux Server';


"The entitlement name for Linux server platform from the membership table is 'XYZ'."

###Adding memory to the chatbot so that it answers follow-up questions related to the database.






In [33]:
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a SQL expert. Given an input question, create a syntactically correct SQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries. Those examples are just for referecne and hsould be considered while answering follow up questions"),
        few_shot_prompt,
        MessagesPlaceholder(variable_name="messages"),
        ("human", "{input}"),
    ]
)
print(final_prompt.format(input="How many accounts are there?",table_info="account",messages=[]))

System: You are a SQL expert. Given an input question, create a syntactically correct SQL query to run. Unless otherwise specificed.

Here is the relevant table info: account

Below are a number of examples of questions and their corresponding SQL queries. Those examples are just for referecne and hsould be considered while answering follow up questions
Human: List all accounts.
SQLQuery:
AI: SELECT * FROM [account];
Human: How many accounts are there?


In [34]:
from langchain.memory import ChatMessageHistory
history = ChatMessageHistory()

generate_query = create_sql_query_chain(llm, db, final_prompt)

def debug_print_query(context):
    print("SQL Query:", context["query"])
    return context

chain = (
    RunnablePassthrough.assign(table_names_to_use=select_table)
    | RunnablePassthrough.assign(query=generate_query)
    | debug_print_query
)

chain = chain.assign(result=itemgetter("query") | execute_query) | rephrase_answer


In [35]:
#from langchain.memory import ChatMessageHistory
#history = ChatMessageHistory()

#generate_query = create_sql_query_chain(llm, db,final_prompt)

#chain = (
#RunnablePassthrough.assign(table_names_to_use=select_table) |
#RunnablePassthrough.assign(query=generate_query).assign(
    #result=itemgetter("query") | execute_query
#)
#| rephrase_answer
#)


In [36]:
question = "How many accounts are there?"
response = chain.invoke({"question": question,"messages":history.messages})
response

SQL Query: SELECT COUNT(*) FROM [account];


'There are 223 accounts.'

In [37]:
history.add_user_message(question)
history.add_ai_message(response)


In [38]:
history.messages

[HumanMessage(content='How many accounts are there?'),
 AIMessage(content='There are 223 accounts.')]

In [39]:
response = chain.invoke({"question": "Provide group name for account name root?","messages":history.messages})
response

SQL Query: SELECT g.group_name FROM [group] g INNER JOIN [membership] m ON g.id = m.m_group_id INNER JOIN [account] a ON m.account_id = a.id WHERE a.acct_name = 'root';


"The account name 'root' is associated with the following group names: 'CN=Account Operators,CN=Builtin,DC=ADVCRD,DC=sbp,DC=local', 'CN=Administrators,CN=Builtin,DC=ADVCRD,DC=sbp,DC=local', 'CN=Cert Publishers,CN=Users,DC=ADVCRD,DC=sbp,DC=local', 'CN=Certificate Service DCOM Access,CN=Builtin,DC=ADVCRD,DC=sbp,DC=local', 'CN=Cyberark-Test-Group,OU=Test Objects,DC=ADVCRD,DC=sbp,DC=local', 'CN=Denied RODC Password Replication Group,CN=Users,DC=ADVCRD,DC=sbp,DC=local', 'CN=DnsAdmins,CN=Users,DC=ADVCRD,DC=sbp,DC=local', 'CN=Domain Admins,CN=Users,DC=ADVCRD,DC=sbp,DC=local', 'CN=Enterprise Admins,CN=Users,DC=ADVCRD,DC=sbp,DC=local', 'CN=Group Policy Creator Owners,CN=Users,DC=ADVCRD,DC=sbp,DC=local', 'CN=Guests,CN=Builtin,DC=ADVCRD,DC=sbp,DC=local', 'CN=IT-CTPSQLDBA-Team,CN=Users,DC=ADVCRD,DC=sbp,DC=local', 'CN=PAM-COE-BeyondTrust-Server-Admin-GPO,OU=BeyondTrust,OU=Groups,DC=ADVCRD,DC=sbp,DC=local', 'CN=PAM-COE-CyberArk-Admin,OU=CyberArk,OU=Groups,DC=ADVCRD,DC=sbp,DC=local', 'CN=PAM-COE-Cybe