In [1]:
import sqlite3
import pandas as pd
from operator import itemgetter

from langchain_ollama import ChatOllama
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain.chains import create_sql_query_chain
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

In [None]:
### Build the local SQLite database from the CSV export

data = pd.read_csv('employee_attrition_data.csv')
conn = sqlite3.connect("employee_records.db")

try:
    # Write the DataFrame to the database as the employee_master table.
    # if_exists="replace" overwrites any previous copy of the table, so this
    # cell can be re-run safely; index=False keeps the pandas index out of
    # the table.
    data.to_sql("employee_master", conn, if_exists="replace", index=False)
finally:
    # Close the connection even when the write fails, so a failed run does
    # not leave an open handle on the database file in the kernel.
    conn.close()

In [3]:

# Chat model served by Ollama; temperature is pinned to 0.0.
llm = ChatOllama(
    model='llama3:instruct',
    temperature=0.0,
)

# LangChain wrapper around the SQLite file used throughout this notebook.
db = SQLDatabase.from_uri("sqlite:///employee_records.db")

# Table description text that the prompt below presents to the model as the
# schema it must stick to.
table_info = db.get_table_info()

In [4]:


# Prompt template for the pass that turns a question plus a draft query into a
# bare SQL statement.
#
# This is deliberately a plain string, not an f-string: {table_info},
# {question} and {query} are all placeholders resolved by PromptTemplate.
# Interpolating table_info with Python first would make the template parser
# treat any "{" or "}" inside the schema text (for example in sample rows) as
# a template variable.
template_str = """
You are an expert SQL assistant.

Your task is to take a user's natural language question and return only a syntactically correct SQL query using ONLY the table and column names exactly as shown.

Use the schema below:
{table_info}

## Example:
Question: Find employees with above-average performance rating but below-average income.
SQL Query: SELECT * FROM employee_master WHERE "PerformanceRating" > (SELECT AVG("PerformanceRating") FROM employee_master) AND "MonthlyIncome" < (SELECT AVG("MonthlyIncome") FROM employee_master);

## Rules:
- Use only the tables/columns from the schema.
- Return ONLY the SQL query (no markdown, no extra text, no formatting).
- Do not need Here is the SQL query to answer this question:
- Do not need to say I am an AI language model.
- Only return the SQL query without any '''


Question: {question}
SQL Query: {query}
"""

# Bind the schema text once as a partial variable; it is inserted as a value
# at format time, so callers still only supply "question" and "query".
answer_prompt = PromptTemplate.from_template(template_str).partial(table_info=table_info)

# Tool that runs SQL against the database. It stays defined for ad-hoc use but
# is intentionally not part of the chain: the prompt only has {question} and
# {query} placeholders, so an execution result would be discarded, and running
# the unreviewed first-pass model output against the database is unnecessary.
execute_query = QuerySQLDatabaseTool(db=db)

# First pass: draft a SQL query from the natural-language question.
write_query = create_sql_query_chain(llm, db)


# Input: {"question": ...}. RunnablePassthrough.assign keeps "question" in the
# dict and adds "query"; answer_prompt then asks the model to return only the
# SQL text, and StrOutputParser unwraps the chat message into a plain string.
chain = (
    RunnablePassthrough.assign(query=write_query)
    | answer_prompt
    | llm
    | StrOutputParser()
)


sql_query = chain.invoke({"question": "Find the attrition rate by department."})
print(sql_query)


SELECT "Department", COUNT(CASE WHEN "Attrition" = 'Yes' THEN 1 ELSE NULL END) AS AttritionCount, 
       COUNT(*) AS TotalEmployees
FROM employee_master
GROUP BY "Department"
ORDER BY AttritionCount DESC;


In [5]:
# The SQL text comes from the language model, so treat it as untrusted: open
# the database through a SQLite URI in read-only mode so a generated statement
# cannot modify or drop tables. Read-only mode also raises if the file is
# missing instead of silently creating an empty database.
conn = sqlite3.connect("file:employee_records.db?mode=ro", uri=True)

query = sql_query

try:
    df = pd.read_sql_query(query, conn)
finally:
    # Generated SQL can be invalid; close the connection whether or not the
    # query succeeds.
    conn.close()

df.head()

Unnamed: 0,Department,AttritionCount,TotalEmployees
0,Research & Development,133,961
1,Sales,92,446
2,Human Resources,12,63
