In [23]:
from dotenv import load_dotenv
load_dotenv()

True

In [24]:
from langchain_community.utilities.sql_database import SQLDatabase


In [25]:
sqlite_uri = 'sqlite:///./HRDataset.db' 
db = SQLDatabase.from_uri(sqlite_uri)

In [26]:
print(db.dialect)
print(db.get_usable_table_names())
print(db.table_info)
     

sqlite
['HRData']

CREATE TABLE "HRData" (
	"Employee_Name" TEXT, 
	"EmpID" INTEGER, 
	"MarriedID" INTEGER, 
	"MaritalStatusID" INTEGER, 
	"GenderID" INTEGER, 
	"EmpStatusID" INTEGER, 
	"DeptID" INTEGER, 
	"PerfScoreID" INTEGER, 
	"FromDiversityJobFairID" INTEGER, 
	"Salary" INTEGER, 
	"Termd" INTEGER, 
	"PositionID" INTEGER, 
	"Position" TEXT, 
	"State" TEXT, 
	"Zip" INTEGER, 
	"DOB" TEXT, 
	"Sex" TEXT, 
	"MaritalDesc" TEXT, 
	"CitizenDesc" TEXT, 
	"HispanicLatino" TEXT, 
	"RaceDesc" TEXT, 
	"DateofHire" TEXT, 
	"DateofTermination" TEXT, 
	"TermReason" TEXT, 
	"EmploymentStatus" TEXT, 
	"Department" TEXT, 
	"ManagerName" TEXT, 
	"ManagerID" REAL, 
	"RecruitmentSource" TEXT, 
	"PerformanceScore" TEXT, 
	"EngagementSurvey" REAL, 
	"EmpSatisfaction" INTEGER, 
	"SpecialProjectsCount" INTEGER, 
	"LastPerformanceReview_Date" TEXT, 
	"DaysLateLast30" INTEGER, 
	"Absences" INTEGER
)

/*
3 rows from HRData table:
Employee_Name	EmpID	MarriedID	MaritalStatusID	GenderID	EmpStatusID	DeptID	PerfSco

In [27]:
table_info=db.table_info

In [28]:
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
generate_query = create_sql_query_chain(llm, db)
query = generate_query.invoke({"question": "tell me the emp id for norman bates"})
# "what is price of `1968 Ford Mustang`"
print(query)

SELECT "EmpID" 
FROM "HRData" 
WHERE "Employee_Name" = 'Bates, Norman'
LIMIT 1;


In [29]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
execute_query = QuerySQLDataBaseTool(db=db)
execute_query.invoke(query)
     

'[(10061,)]'

In [30]:
chain = generate_query | execute_query
chain.invoke({"question": "How many absences does norman bates has?"})

'[(20,)]'

In [31]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

rephrase_answer = answer_prompt | llm | StrOutputParser()

chain = (
    RunnablePassthrough.assign(query=generate_query).assign(
        result=itemgetter("query") | execute_query
    )
    | rephrase_answer
)

chain.invoke({"question": "How many absences does norman bates has?"})
     

'Norman Bates has 20 absences.'

In [32]:
examples = [
    {
        "input": "Retrieve details of all active employees.",
        "query": "SELECT * FROM UpdatedMergedEmployeeInsurance WHERE EmploymentStatus = 'Active';"
    },
    {
        "input": "How many employees have a performance score of 'Exceeds'?",
        "query": "SELECT COUNT(*) FROM UpdatedMergedEmployeeInsurance WHERE PerformanceScore = 'Exceeds';"
    },
    {
        "input": "Provide names and salaries of employees hired after 2022.",
        "query": "SELECT Employee_Name, Salary FROM UpdatedMergedEmployeeInsurance WHERE DateofHire > '2022-01-01';"
    },
    {
        "input": "What is the average engagement survey score of employees in the Engineering department?",
        "query": "SELECT AVG(EngagementSurvey) FROM UpdatedMergedEmployeeInsurance WHERE Department = 'Engineering';"
    },
    {
        "input": "How many male employees earn more than $50,000?",
        "query": "SELECT COUNT(*) FROM UpdatedMergedEmployeeInsurance WHERE GenderID = 'M' AND Salary > 50000;"
    },
]


In [33]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate

example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}\nSQLQuery:"),
        ("ai", "{query}"),
    ]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
    # input_variables=["input","top_k"],
    input_variables=["input"],
)
print(few_shot_prompt.format(input1="is hassan an employee?"))

Human: Retrieve details of all active employees.
SQLQuery:
AI: SELECT * FROM UpdatedMergedEmployeeInsurance WHERE EmploymentStatus = 'Active';
Human: How many employees have a performance score of 'Exceeds'?
SQLQuery:
AI: SELECT COUNT(*) FROM UpdatedMergedEmployeeInsurance WHERE PerformanceScore = 'Exceeds';
Human: Provide names and salaries of employees hired after 2022.
SQLQuery:
AI: SELECT Employee_Name, Salary FROM UpdatedMergedEmployeeInsurance WHERE DateofHire > '2022-01-01';
Human: What is the average engagement survey score of employees in the Engineering department?
SQLQuery:
AI: SELECT AVG(EngagementSurvey) FROM UpdatedMergedEmployeeInsurance WHERE Department = 'Engineering';
Human: How many male employees earn more than $50,000?
SQLQuery:
AI: SELECT COUNT(*) FROM UpdatedMergedEmployeeInsurance WHERE GenderID = 'M' AND Salary > 50000;


In [34]:
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

vectorstore = Chroma()
vectorstore.delete_collection()
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    vectorstore,
    k=2,
    input_keys=["input"],
)
example_selector.select_examples({"input": "how many employees we have?"})
# example_selector.select_examples({"input": "How many employees?"})

[{'input': "How many employees have a performance score of 'Exceeds'?",
  'query': "SELECT COUNT(*) FROM UpdatedMergedEmployeeInsurance WHERE PerformanceScore = 'Exceeds';"},
 {'input': 'How many male employees earn more than $50,000?',
  'query': "SELECT COUNT(*) FROM UpdatedMergedEmployeeInsurance WHERE GenderID = 'M' AND Salary > 50000;"}]

In [35]:
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    example_selector=example_selector,
    input_variables=["input","top_k"],
)
print(few_shot_prompt.format(input="who is norman bates"))

Human: How many employees have a performance score of 'Exceeds'?
SQLQuery:
AI: SELECT COUNT(*) FROM UpdatedMergedEmployeeInsurance WHERE PerformanceScore = 'Exceeds';
Human: Provide names and salaries of employees hired after 2022.
SQLQuery:
AI: SELECT Employee_Name, Salary FROM UpdatedMergedEmployeeInsurance WHERE DateofHire > '2022-01-01';


In [36]:
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
        few_shot_prompt,
        ("human", "{input}"),
    ]
)
print(final_prompt.format(input="How many products are there?",table_info="some table info"))

System: You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.

Here is the relevant table info: some table info

Below are a number of examples of questions and their corresponding SQL queries.
Human: How many employees have a performance score of 'Exceeds'?
SQLQuery:
AI: SELECT COUNT(*) FROM UpdatedMergedEmployeeInsurance WHERE PerformanceScore = 'Exceeds';
Human: How many male employees earn more than $50,000?
SQLQuery:
AI: SELECT COUNT(*) FROM UpdatedMergedEmployeeInsurance WHERE GenderID = 'M' AND Salary > 50000;
Human: How many products are there?


In [37]:
generate_query = create_sql_query_chain(llm, db,final_prompt)
chain = (
RunnablePassthrough.assign(query=generate_query).assign(
    result=itemgetter("query") | execute_query
)
| rephrase_answer
)
chain.invoke({"question": "tell me attendence data of hassan"})

'The attendance data for Hassan is not provided in the SQL result.'

In [44]:
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries. Those examples are just for referecne and hsould be considered while answering follow up questions"),
        few_shot_prompt,
        MessagesPlaceholder(variable_name="messages"),
        ("human", "{input}")
    ]
)
print(final_prompt.format(input="How many columns are there?",table_info="some table info" ,messages=[]))

System: You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.

Here is the relevant table info: some table info

Below are a number of examples of questions and their corresponding SQL queries. Those examples are just for referecne and hsould be considered while answering follow up questions
Human: How many male employees earn more than $50,000?
SQLQuery:
AI: SELECT COUNT(*) FROM UpdatedMergedEmployeeInsurance WHERE GenderID = 'M' AND Salary > 50000;
Human: How many employees have a performance score of 'Exceeds'?
SQLQuery:
AI: SELECT COUNT(*) FROM UpdatedMergedEmployeeInsurance WHERE PerformanceScore = 'Exceeds';
Human: How many columns are there?


In [39]:
from langchain.memory import ChatMessageHistory
history = ChatMessageHistory()

generate_query = create_sql_query_chain(llm, db, final_prompt)

chain = (
RunnablePassthrough.assign(query=generate_query).assign(
    result=itemgetter("query") | execute_query
)
| rephrase_answer
)


In [40]:
table_info=db.table_info

In [41]:
print(table_info)


CREATE TABLE "HRData" (
	"Employee_Name" TEXT, 
	"EmpID" INTEGER, 
	"MarriedID" INTEGER, 
	"MaritalStatusID" INTEGER, 
	"GenderID" INTEGER, 
	"EmpStatusID" INTEGER, 
	"DeptID" INTEGER, 
	"PerfScoreID" INTEGER, 
	"FromDiversityJobFairID" INTEGER, 
	"Salary" INTEGER, 
	"Termd" INTEGER, 
	"PositionID" INTEGER, 
	"Position" TEXT, 
	"State" TEXT, 
	"Zip" INTEGER, 
	"DOB" TEXT, 
	"Sex" TEXT, 
	"MaritalDesc" TEXT, 
	"CitizenDesc" TEXT, 
	"HispanicLatino" TEXT, 
	"RaceDesc" TEXT, 
	"DateofHire" TEXT, 
	"DateofTermination" TEXT, 
	"TermReason" TEXT, 
	"EmploymentStatus" TEXT, 
	"Department" TEXT, 
	"ManagerName" TEXT, 
	"ManagerID" REAL, 
	"RecruitmentSource" TEXT, 
	"PerformanceScore" TEXT, 
	"EngagementSurvey" REAL, 
	"EmpSatisfaction" INTEGER, 
	"SpecialProjectsCount" INTEGER, 
	"LastPerformanceReview_Date" TEXT, 
	"DaysLateLast30" INTEGER, 
	"Absences" INTEGER
)

/*
3 rows from HRData table:
Employee_Name	EmpID	MarriedID	MaritalStatusID	GenderID	EmpStatusID	DeptID	PerfScoreID	FromDiversity

In [43]:

question = "give me the attendence record for hassan?"
response = chain.invoke({"question": question,"messages": history.messages,"table_info": table_info })
response

KeyError: "Input to ChatPromptTemplate is missing variables {'messages'}.  Expected: ['input', 'messages', 'table_info', 'top_k'] Received: ['input', 'top_k', 'table_info']"

In [22]:
print(response)

NameError: name 'response' is not defined