### Dynamic relevant table selection (Large database use case)

1. MySql database
2. Azure OpenAI LLM (GPT-4-32k)
3. Langchain
4. MsSql database


In [1]:
#!pip install langchain_openai langchain_community langchain langchain-experimental pymysql chromadb -q
#pip install -U langsmith
#%pip install python-dotenv


**Step 1: Import required libraries**


In [12]:
import os
import pandas as pd
from typing import List
from langchain_openai import AzureChatOpenAI
from langchain.chains import create_sql_query_chain
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.memory import ChatMessageHistory

**Step 2: Load config values**


In [2]:
db_user = os.getenv("MySql_db_user")
db_password = os.getenv("MySql_db_password")
db_host = os.getenv("MySql_db_host")
db_name = os.getenv("MySql_db_name")
db_port = os.getenv("MySql_db_port")

LANGCHAIN_TRACING_V2 = "true"
LANGCHAIN_API_KEY = os.environ["LANGCHAIN_API_KEY"]

openai_api_base=os.getenv("AZURE_OPENAI_ENDPOINT") 
openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION") 
azure_deployment="gpt-4-32k"
openai_api_key = os.getenv("AZURE_OPENAI_KEY") 
openai_api_type="azure"

search_endpoint = os.getenv("AZURE_AI_SEARCH_ENDPOINT")
search_key = os.getenv("AZURE_AI_SEARCH_KEY")
index = "sqlsearchindex"

**Step 3: Connecting to MySql database server**


In [3]:
# db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}",sample_rows_in_table_info=1,include_tables=['customers','orders'],custom_table_info={'customers':"customer"})
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")

In [4]:
print(db.dialect)
print(db.get_usable_table_names())
print(db.table_info)

mysql
['customers', 'employees', 'offices', 'orderdetails', 'orders', 'payments', 'productlines', 'products']

CREATE TABLE customers (
	`customerNumber` INTEGER NOT NULL, 
	`customerName` VARCHAR(50) NOT NULL, 
	`contactLastName` VARCHAR(50) NOT NULL, 
	`contactFirstName` VARCHAR(50) NOT NULL, 
	phone VARCHAR(50) NOT NULL, 
	`addressLine1` VARCHAR(50) NOT NULL, 
	`addressLine2` VARCHAR(50), 
	city VARCHAR(50) NOT NULL, 
	state VARCHAR(50), 
	`postalCode` VARCHAR(15), 
	country VARCHAR(50) NOT NULL, 
	`salesRepEmployeeNumber` INTEGER, 
	`creditLimit` DECIMAL(10, 2), 
	PRIMARY KEY (`customerNumber`), 
	CONSTRAINT customers_ibfk_1 FOREIGN KEY(`salesRepEmployeeNumber`) REFERENCES employees (`employeeNumber`)
)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4

/*
3 rows from customers table:
customerNumber	customerName	contactLastName	contactFirstName	phone	addressLine1	addressLine2	city	state	postalCode	country	salesRepEmployeeNumber	creditLimit
103	Atelier graphique	Schmit

**Step 4: Creating LLM instance**


In [5]:
llm = AzureChatOpenAI(
    azure_endpoint=openai_api_base,
    openai_api_version=openai_api_version,
    azure_deployment=azure_deployment,
    openai_api_key=openai_api_key,
    openai_api_type=openai_api_type,
    temperature=0,
)

**Step 5: Initializing generate and execute query**


In [6]:
generate_query = create_sql_query_chain(llm, db)
execute_query = QuerySQLDataBaseTool(db=db)

**Step 6: Prompt template to generate a prompt to rephrase the sql result to natural language**


In [7]:
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

rephrase_answer = answer_prompt | llm | StrOutputParser()

**Step 7: Dynamic relevant table selection**


In [16]:
def get_table_details():
    # Read the CSV file into a DataFrame
    table_description = pd.read_csv("database_table_descriptions.csv")
    # Iterate over the DataFrame rows to create Document objects
    table_details = ""
    for index, row in table_description.iterrows():
        table_details = table_details + "Table Name:" + row['Table'] + "\n" + "Table Description:" + row['Description'] + "\n\n"
    return table_details

class Table(BaseModel):
    """Table in SQL database."""
    name: str = Field(description="Name of table in SQL database.")

class Tables(BaseModel):
    """Tables in SQL database."""
    # Creates a model so that we can extract multiple entities.
    name: List[Table]
    
table_details = get_table_details()
print(table_details)

Table Name:productlines
Table Description:Stores information about the different product lines offered by the company, including a unique name, textual description, HTML description, and image. Categorizes products into different lines.

Table Name:products
Table Description:Contains details of each product sold by the company, including code, name, product line, scale, vendor, description, stock quantity, buy price, and MSRP. Linked to the productlines table.

Table Name:offices
Table Description:Holds data on the company's sales offices, including office code, city, phone number, address, state, country, postal code, and territory. Each office is uniquely identified by its office code.

Table Name:employees
Table Description:Stores information about employees, including number, last name, first name, job title, contact info, and office code. Links to offices and maps organizational structure through the reportsTo attribute.

Table Name:customers
Table Description:Captures data on cus

In [17]:
table_details_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Return the names of ALL the table names that MIGHT be relevant to the user question. Remember to include ALL POTENTIALLY RELEVANT table names, even if you're not sure that they're needed.The tables are:  {table_details}",
        ),
        MessagesPlaceholder(variable_name="messages"),
        ("human", "{input}"),
    ]
)

In [18]:
def extract_table_names(tables: Tables) -> List[str]:
    return [table.name for table in tables.name]
#extraction approach based on OpenAI Functions
select_table = table_details_prompt | llm.with_structured_output(schema=Tables) | extract_table_names
question = "give me details of customer and their order count."
select_table.invoke({"input": question, "table_details": table_details,"messages":[]})
#e.g.,['customers', 'orders']

['customers', 'orders']

In [19]:
chain = (
RunnablePassthrough.assign(table_names_to_use=select_table) |
RunnablePassthrough.assign(query=generate_query).assign(
    result=itemgetter("query") | execute_query
)
| rephrase_answer
)
question = "How many customers with order count more than 5"
chain.invoke({"question": question,"input": question,"table_details": table_details,"messages":[]})

'2'

**End of dynamic relevant table selection demo**


In [13]:
#chain.invoke({"question": "Can you list their names?"})
question = "Can you list their names?"
chain.invoke({"question": question, "input": question,"table_details": table_details})

'The names are Diane Murphy, Mary Patterson, Jeff Firrelli, William Patterson, and Gerard Bondur.'

**Adding few-shot prompt**


In [15]:
examples = [
    {
        "input": "List all customers in France with a credit limit over 20,000.",
        "query": "SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;"
    },
    {
        "input": "Get the highest payment amount made by any customer.",
        "query": "SELECT MAX(amount) FROM payments;"
    },
    {
        "input": "Show product details for products in the 'Motorcycles' product line.",
        "query": "SELECT * FROM products WHERE productLine = 'Motorcycles';"
    },
    {
        "input": "Retrieve the names of employees who report to employee number 1002.",
        "query": "SELECT firstName, lastName FROM employees WHERE reportsTo = 1002;"
    },
    {
        "input": "List all products with a stock quantity less than 7000.",
        "query": "SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;"
    },
    {
     'input':"what is price of `1968 Ford Mustang`",
     "query": "SELECT `buyPrice`, `MSRP` FROM products  WHERE `productName` = '1968 Ford Mustang' LIMIT 1;"
    }
]

In [16]:
example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}\nSQLQuery:"),
        ("ai", "{query}"),
    ]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
    # input_variables=["input","top_k"],
    input_variables=["input","top_k"],
)
print(few_shot_prompt.format(input1="How many products are there?"))

Human: List all customers in France with a credit limit over 20,000.
SQLQuery:
AI: SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;
Human: Get the highest payment amount made by any customer.
SQLQuery:
AI: SELECT MAX(amount) FROM payments;
Human: Show product details for products in the 'Motorcycles' product line.
SQLQuery:
AI: SELECT * FROM products WHERE productLine = 'Motorcycles';
Human: Retrieve the names of employees who report to employee number 1002.
SQLQuery:
AI: SELECT firstName, lastName FROM employees WHERE reportsTo = 1002;
Human: List all products with a stock quantity less than 7000.
SQLQuery:
AI: SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;
Human: what is price of `1968 Ford Mustang`
SQLQuery:
AI: SELECT `buyPrice`, `MSRP` FROM products  WHERE `productName` = '1968 Ford Mustang' LIMIT 1;


In [17]:
# This is the embedding class used to produce embeddings which are used to measure semantic similarity.
embeddings = AzureOpenAIEmbeddings(
    azure_deployment="text-embedding-ada-002",
    openai_api_version=openai_api_version,
    azure_endpoint=openai_api_base,
    api_key=openai_api_key,
)

vectorstore = Chroma()

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embeddings,
    vectorstore,
    k=2,
    input_keys=["input"],
)
example_selector.select_examples({"input": "how many employees we have?"})
# example_selector.select_examples({"input": "How many employees?"})

[{'input': 'Retrieve the names of employees who report to employee number 1002.',
  'query': 'SELECT firstName, lastName FROM employees WHERE reportsTo = 1002;'},
 {'input': 'List all customers in France with a credit limit over 20,000.',
  'query': "SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;"}]

**Adding memory to the chatbot so that it answers follow-up questions related to the database.**


In [40]:
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries. Those examples are just for referecne and hsould be considered while answering follow up questions"),
        few_shot_prompt,
        MessagesPlaceholder(variable_name="messages"),
        ("human", "{input}"),
    ]
)
print(final_prompt.format(input="How many products are there?",table_info="some table info",messages=[]))

System: You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.

Here is the relevant table info: some table info

Below are a number of examples of questions and their corresponding SQL queries. Those examples are just for referecne and hsould be considered while answering follow up questions
Human: List all customers in France with a credit limit over 20,000.
SQLQuery:
AI: SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;
Human: Get the highest payment amount made by any customer.
SQLQuery:
AI: SELECT MAX(amount) FROM payments;
Human: Show product details for products in the 'Motorcycles' product line.
SQLQuery:
AI: SELECT * FROM products WHERE productLine = 'Motorcycles';
Human: Retrieve the names of employees who report to employee number 1002.
SQLQuery:
AI: SELECT firstName, lastName FROM employees WHERE reportsTo = 1002;
Human: List all products with a stock quantity less than 7000.
S

In [22]:

history = ChatMessageHistory()

generate_query = create_sql_query_chain(llm, db,final_prompt)

chain = (
RunnablePassthrough.assign(table_names_to_use=select_table) |
RunnablePassthrough.assign(query=generate_query).assign(
    result=itemgetter("query") | execute_query
)
| rephrase_answer
)

In [23]:
question = "How many customers with order count more than 5"
response = chain.invoke({"question": question,"input": question,"table_details":table_details,"messages":history.messages})
response

'2'

In [24]:
history.add_user_message(question)
history.add_ai_message(response)
history.messages

[HumanMessage(content='How many customers with order count more than 5'),
 AIMessage(content='2')]

In [25]:
question = "Can you list there names?"
response = chain.invoke({"question": question,"input": question,"table_details":table_details,"messages":history.messages})
response

'The names are Mini Gifts Distributors Ltd. and Euro+ Shopping Channel.'