In [3]:
from langchain_community.utilities.sql_database import SQLDatabase
from dotenv import load_dotenv
import os

load_dotenv()
import os

db_user = os.getenv("db_user")
db_password = os.getenv("db_password")
db_host = os.getenv("db_host")
db_name = os.getenv("db_name")

db=SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_database}")
# db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}",sample_rows_in_table_info=1,include_tables=['customers','orders'],custom_table_info={'customers':"customer"})

In [4]:
print(db.dialect)
print(db.get_usable_table_names())
print(db.table_info)

mysql
['countries', 'departments', 'employees', 'job_history', 'jobs', 'locations', 'regions']

CREATE TABLE countries (
	country_id CHAR(2) NOT NULL, 
	country_name VARCHAR(40), 
	region_id INTEGER UNSIGNED NOT NULL, 
	PRIMARY KEY (country_id), 
	CONSTRAINT countries_ibfk_1 FOREIGN KEY(region_id) REFERENCES regions (region_id)
)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB

/*
3 rows from countries table:
country_id	country_name	region_id
AR	Argentina	2
AU	Australia	3
BE	Belgium	1
*/


CREATE TABLE departments (
	department_id INTEGER UNSIGNED NOT NULL, 
	department_name VARCHAR(30) NOT NULL, 
	manager_id INTEGER UNSIGNED, 
	location_id INTEGER UNSIGNED, 
	PRIMARY KEY (department_id), 
	CONSTRAINT departments_ibfk_1 FOREIGN KEY(location_id) REFERENCES locations (location_id), 
	CONSTRAINT departments_ibfk_2 FOREIGN KEY(manager_id) REFERENCES employees (employee_id)
)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB

/*
3 rows from departments table:
d

  metadata_table_names = [tbl.name for tbl in self._metadata.sorted_tables]
  for tbl in self._metadata.sorted_tables


In [5]:
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

model=ChatOpenAI(model='gpt-3.5-turbo',temperature=0)
generate_query=create_sql_query_chain(model,db)
query1=generate_query.invoke({'question':'which employee has employee id 100?'})
print(query1)

In [6]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

execute_query=QuerySQLDataBaseTool(db=db)
execute_query.invoke(query)

In [7]:
chain=generate_query | execute_query
chain.invoke({'question':'Get the average tenure of employees in each department.'})

In [37]:
chain.get_prompts()[0].pretty_print()

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the S

In [8]:
from operator import itemgetter
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

template=PromptTemplate.from_template(
    ''''
    "You are a MySQL expert. Given an input question, corresponding SQL query and sql result answer the user question.
    
    Question:{question}
    SQL Query: {query}
    SQL Answer: {result}
    Answer:    
    '''
)

rephrase_answer=template| model | StrOutputParser()

chain=(
    RunnablePassthrough.assign(query=generate_query)
       .assign(result=itemgetter('query') | execute_query )
      | rephrase_answer
)
chain.invoke({'question':'how many employees are there in each department.'})

## Dynamic Prompt and Examples

In [9]:
examples = [
    {
        "input": "List all customers in France with a credit limit over 20,000.",
        "query": "SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;"
    },
    {
        "input": "Get the highest payment amount made by any customer.",
        "query": "SELECT MAX(amount) FROM payments;"
    },
    {
        "input": "Show product details for products in the 'Motorcycles' product line.",
        "query": "SELECT * FROM products WHERE productLine = 'Motorcycles';"
    },
    {
        "input": "Retrieve the names of employees who report to employee number 1002.",
        "query": "SELECT firstName, lastName FROM employees WHERE reportsTo = 1002;"
    },
    {
        "input": "List all products with a stock quantity less than 7000.",
        "query": "SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;"
    },
    {
     'input':"what is price of `1968 Ford Mustang`",
     "query": "SELECT `buyPrice`, `MSRP` FROM products  WHERE `productName` = '1968 Ford Mustang' LIMIT 1;"
    },
    {
        "input": "Retrieve the names and salaries of employees who joined after January 1st, 2023.",
        "query": "SELECT firstName, lastName, salary FROM employees WHERE startDate > '2023-01-01';"
    },
    {
        "input": "List all departments with more than 10 employees.",
        "query": "SELECT departmentName FROM departments WHERE employeeCount > 10;"
    },
    {
        "input": "Show the average age of employees in each department.",
        "query": "SELECT departmentName, AVG(age) AS averageAge FROM employees GROUP BY departmentName;"
    },
    {
        "input": "Get the total number of employees hired in each year.",
        "query": "SELECT YEAR(startDate) AS hiringYear, COUNT(*) AS hires FROM employees GROUP BY hiringYear;"
    },
    {
        "input": "Retrieve the names of employees who have been promoted in the last 6 months.",
        "query": "SELECT firstName, lastName FROM employees WHERE promotionDate > DATE_SUB(CURRENT_DATE(), INTERVAL 6 MONTH);"
    },
    {
        "input": "List all employees who have exceeded their annual leave quota.",
        "query": "SELECT firstName, lastName FROM employees WHERE remainingLeave < 0;"
    },
    {
        "input": "Show the highest and lowest salaries in the company.",
        "query": "SELECT MAX(salary) AS highestSalary, MIN(salary) AS lowestSalary FROM employees;"
    },
    {
        "input": "Get the average tenure of employees in each department.",
        "query": "SELECT departmentName, AVG(DATEDIFF(CURRENT_DATE(), startDate)) AS averageTenure FROM employees GROUP BY departmentName;"
    },
    {
        "input": "Retrieve the names of employees whose birthdays are in the next month.",
        "query": "SELECT firstName, lastName FROM employees WHERE MONTH(birthDate) = MONTH(DATE_ADD(CURRENT_DATE(), INTERVAL 1 MONTH));"
    },
    {
        "input": "List all employees who have attended training sessions in the past quarter.",
        "query": "SELECT DISTINCT firstName, lastName FROM employees INNER JOIN trainingSessions ON employees.employeeID = trainingSessions.employeeID WHERE trainingDate BETWEEN DATE_SUB(CURRENT_DATE(), INTERVAL 3 MONTH) AND CURRENT_DATE();"
    },
    {
        "input": "Show the total number of employees by job title.",
        "query": "SELECT jobTitle, COUNT(*) AS employeeCount FROM employees GROUP BY jobTitle;"
    },
    {
        "input": "Retrieve the names and ages of employees who have children.",
        "query": "SELECT firstName, lastName, age FROM employees WHERE numberOfChildren > 0;"
    },
    {
        "input": "List all employees who are eligible for retirement (age > 65).",
        "query": "SELECT firstName, lastName FROM employees WHERE age > 65;"
    },
    {
        "input": "Show the average number of sick days taken by employees in each department.",
        "query": "SELECT departmentName, AVG(sickDaysTaken) AS avgSickDays FROM employees GROUP BY departmentName;"
    },
    {
        "input": "Get the total number of male and female employees.",
        "query": "SELECT gender, COUNT(*) AS employeeCount FROM employees GROUP BY gender;"
    },
    {
        "input": "Retrieve the names of employees who have received awards in the past year.",
        "query": "SELECT firstName, lastName FROM employees WHERE awardDate BETWEEN DATE_SUB(CURRENT_DATE(), INTERVAL 1 YEAR) AND CURRENT_DATE();"
    },
    {
        "input": "List all employees who have completed their probation period.",
        "query": "SELECT firstName, lastName FROM employees WHERE startDate < DATE_SUB(CURRENT_DATE(), INTERVAL 3 MONTH);"
    },
    {
        "input": "Show the department with the highest average employee satisfaction rating.",
        "query": "SELECT departmentName FROM employees GROUP BY departmentName ORDER BY AVG(satisfactionRating) DESC LIMIT 1;"
    },
    {
        "input": "Retrieve the names of employees who have taken unpaid leave in the last month.",
        "query": "SELECT firstName, lastName FROM employees WHERE unpaidLeaveStartDate BETWEEN DATE_SUB(CURRENT_DATE(), INTERVAL 1 MONTH) AND CURRENT_DATE();"
    },
    {
        "input": "List all employees who have completed mandatory compliance training.",
        "query": "SELECT firstName, lastName FROM employees WHERE complianceTrainingStatus = 'Completed';"
    }
] 

In [10]:
from langchain_core.prompts import FewShotChatMessagePromptTemplate,MessagesPlaceholder, ChatPromptTemplate

example_prompt=ChatPromptTemplate.from_messages(
    [
        ("human","{input}\n SQL:"),
        ("ai","{query}")

    ]
)


fewshot_template=FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
    input_variables=['input']
)

print(fewshot_template.format(input='how many employees are there'))

Human: List all customers in France with a credit limit over 20,000.
 SQL:
AI: SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;
Human: Get the highest payment amount made by any customer.
 SQL:
AI: SELECT MAX(amount) FROM payments;
Human: Show product details for products in the 'Motorcycles' product line.
 SQL:
AI: SELECT * FROM products WHERE productLine = 'Motorcycles';
Human: Retrieve the names of employees who report to employee number 1002.
 SQL:
AI: SELECT firstName, lastName FROM employees WHERE reportsTo = 1002;
Human: List all products with a stock quantity less than 7000.
 SQL:
AI: SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;
Human: what is price of `1968 Ford Mustang`
 SQL:
AI: SELECT `buyPrice`, `MSRP` FROM products  WHERE `productName` = '1968 Ford Mustang' LIMIT 1;
Human: Retrieve the names and salaries of employees who joined after January 1st, 2023.
 SQL:
AI: SELECT firstName, lastName, salary FROM employees W

In [11]:
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_core.example_selectors import SemanticSimilarityExampleSelector


vector_store=Chroma()
embeddings=OpenAIEmbeddings()
example_selector=SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    vector_store,
    k=2,
    input_keys=['input']
)

example_selector.select_examples({'input':'who has the highest salary'})

[{'input': 'Show the highest and lowest salaries in the company.',
  'query': 'SELECT MAX(salary) AS highestSalary, MIN(salary) AS lowestSalary FROM employees;'},
 {'input': 'Get the highest payment amount made by any customer.',
  'query': 'SELECT MAX(amount) FROM payments;'}]

In [12]:
fewshot_template=FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    example_selector=example_selector,
    input_variables=['input','top_k']
)

print(fewshot_template.format(input='how many employees are there',top_k=2))

Human: Get the total number of employees hired in each year.
 SQL:
AI: SELECT YEAR(startDate) AS hiringYear, COUNT(*) AS hires FROM employees GROUP BY hiringYear;
Human: Get the total number of male and female employees.
 SQL:
AI: SELECT gender, COUNT(*) AS employeeCount FROM employees GROUP BY gender;


In [13]:
final_prompt=ChatPromptTemplate.from_messages([
    ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
        fewshot_template,
        ("human", "{input}"),
]
)

print(final_prompt.format(input='whast the max salary',top_k=2,table_info='Some Table Info'))

System: You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.

Here is the relevant table info: Some Table Info

Below are a number of examples of questions and their corresponding SQL queries.
Human: Show the highest and lowest salaries in the company.
 SQL:
AI: SELECT MAX(salary) AS highestSalary, MIN(salary) AS lowestSalary FROM employees;
Human: Get the highest payment amount made by any customer.
 SQL:
AI: SELECT MAX(amount) FROM payments;
Human: whast the max salary


In [14]:
generate_query=create_sql_query_chain(model,db,final_prompt)

chain=(
    RunnablePassthrough.assign(query=generate_query)
       .assign(result=itemgetter('query') | execute_query )
      | rephrase_answer
)
chain.invoke({'question':'what is the max salary of employees in each department.'})

# Dynamic Table selector

In [15]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain_core.pydantic_v1 import BaseModel,Field
from typing import List
from operator import itemgetter
import pandas as pd

def get_table_details():

    table_description=pd.read_csv('table_desc.csv')
    table_docs=[]
    table_details=''
    for index,row in table_description.iterrows():
        table_details=table_details+'Table_name:'+ row['Tables']+ '\n' + 'Table_description:' + row['Description']+ '\n\n'

    return table_details



class Table(BaseModel):
    '''Table in SQL database'''
    name:str= Field(description='Name of the table in sql database')

table_details = get_table_details()
print(table_details)

Table_name:employees
Table_description:Contains employee information such as employee ID, name, job, manager ID, hire date, and salary.

Table_name:departments
Table_description:Stores department information including department ID, name, and manager ID.

Table_name:locations
Table_description:Holds location details such as location ID, street address, postal code, city, state, and country.

Table_name:countries
Table_description:Contains country information such as country ID and country name.

Table_name:regions
Table_description:Stores region information including region ID and region name.

Table_name:jobs
Table_description:Holds job information such as job ID, job title, minimum salary, and maximum salary.

Table_name:job_history
Table_description:Stores historical job data including employee ID, start date, end date, job ID, and department ID.




In [17]:

table_prompt=f'''Return the names of all the relevant SQL tables that MIGHT be relevant based on the question asked by the user,\
    The Tables are:

    {table_details}

    Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed.
'''
table_chain=create_extraction_chain_pydantic(Table,model,system_message=table_prompt)
tables = table_chain.invoke({"input": "give me details of employee and managers"})
tables



[Table(name='employees'), Table(name='departments')]

In [29]:
def get_tables(tables: List[Table]) -> List[str]:
    tables  = [table.name for table in tables]
    return tables

get_tables(tables)

['employees', 'departments']

In [30]:



select_tables={'input':itemgetter('question')} | create_extraction_chain_pydantic(Table,model,system_message=table_prompt) | get_tables

select_tables.invoke({'question':'details of managers and locations'})

['employees', 'departments', 'locations']

In [31]:
chain=(
    RunnablePassthrough.assign(tables_to_use=select_tables)|
    RunnablePassthrough.assign(query=generate_query).assign(result=itemgetter('query')|execute_query)
    |rephrase_answer
)

chain.invoke({"question": "How many department with employee count more than 5"})

  metadata_table_names = [tbl.name for tbl in self._metadata.sorted_tables]
  for tbl in self._metadata.sorted_tables


'There are 4 departments with more than 5 employees.'

# Adding Memory to Chatbot

In [36]:
final_prompt=ChatPromptTemplate.from_messages([
    ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
        fewshot_template,
        MessagesPlaceholder(variable_name="messages"),
        ("human", "{input}"),
]
)

print(final_prompt.format(input='How many locations are there?',table_info='some table info',messages=[]))

System: You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.

Here is the relevant table info: some table info

Below are a number of examples of questions and their corresponding SQL queries.
Human: List all departments with more than 10 employees.
 SQL:
AI: SELECT departmentName FROM departments WHERE employeeCount > 10;
Human: Get the total number of male and female employees.
 SQL:
AI: SELECT gender, COUNT(*) AS employeeCount FROM employees GROUP BY gender;
Human: How many locations are there?


In [40]:
from langchain.memory import ChatMessageHistory
history=ChatMessageHistory()

generate_query = create_sql_query_chain(model, db,final_prompt)
chain = (
    RunnablePassthrough.assign(tables_to_use=select_tables) |
    RunnablePassthrough.assign(query=generate_query).assign(
        result=itemgetter("query") | execute_query
    )
    | rephrase_answer
)



In [42]:
response=chain.invoke({"question": "How many employees with salary more than 12000","messages":history.messages})
response

  metadata_table_names = [tbl.name for tbl in self._metadata.sorted_tables]
  for tbl in self._metadata.sorted_tables


'There are 6 employees with a salary greater than 12000.'

In [43]:
question= "How many employees with salary more than 12000"
history.add_user_message(question)
history.add_ai_message(response)
history.messages
response = chain.invoke({"question": "Can you list there names?","messages":history.messages})
response

  metadata_table_names = [tbl.name for tbl in self._metadata.sorted_tables]
  for tbl in self._metadata.sorted_tables


'The names of employees with a salary greater than $12,000 are:\n1. Steven King\n2. Neena Kochhar\n3. Lex De Haan\n4. John Russell\n5. Karen Partners\n6. Michael Hartstein'

In [45]:
chain.schema()

{'title': 'RunnableSequence',
 'description': "Sequence of Runnables, where the output of each is the input of the next.\n\nRunnableSequence is the most important composition operator in LangChain as it is\nused in virtually every chain.\n\nA RunnableSequence can be instantiated directly or more commonly by using the `|`\noperator where either the left or right operands (or both) must be a Runnable.\n\nAny RunnableSequence automatically supports sync, async, batch.\n\nThe default implementations of `batch` and `abatch` utilize threadpools and\nasyncio gather and will be faster than naive invocation of invoke or ainvoke\nfor IO bound Runnables.\n\nBatching is implemented by invoking the batch method on each component of the\nRunnableSequence in order.\n\nA RunnableSequence preserves the streaming properties of its components, so if all\ncomponents of the sequence implement a `transform` method -- which\nis the method that implements the logic to map a streaming input to a streaming\nout