In [6]:
import os
from dotenv import load_dotenv
from urllib.parse import quote, quote_plus

load_dotenv(override=True)

# Fail fast with a clear message instead of building a URL that contains the
# string "None" (or crashing inside the encoder) when a variable is missing.
_REQUIRED_DB_VARS = ("DB_USER", "DB_PASSWORD", "DB_HOST", "DB_PORT", "DB_NAME")
_missing_db_vars = [name for name in _REQUIRED_DB_VARS if os.getenv(name) is None]
if _missing_db_vars:
    raise RuntimeError(
        "Missing required environment variables: " + ", ".join(_missing_db_vars)
    )

# Build PostgreSQL connection URL safely.
# quote(..., safe="") percent-encodes every reserved character ('@', ':', '/',
# '+', spaces, ...). quote_plus would turn spaces into '+', which a parser that
# decodes with urllib.parse.unquote() keeps as a literal '+'.
encoded_user = quote(os.environ["DB_USER"], safe="")
encoded_pw = quote(os.environ["DB_PASSWORD"], safe="")
_url_host_part = f"@{os.environ['DB_HOST']}:{os.environ['DB_PORT']}/{os.environ['DB_NAME']}"
POSTGRES_URL = f"postgresql+psycopg2://{encoded_user}:{encoded_pw}{_url_host_part}"

# Optional: save it for downstream LangChain usage
os.environ["POSTGRES_URL"] = POSTGRES_URL

# Print a masked copy only: notebook outputs are saved with the file, so the
# real password must never be echoed.
print("✅ Built URL:", f"postgresql+psycopg2://{encoded_user}:***{_url_host_part}")
print("Schema:", os.getenv("DB_SCHEMA"))


✅ Built URL: postgresql+psycopg2://postgres:***@localhost:5433/postgres
Schema: hr_data


In [7]:
from langchain_community.utilities import SQLDatabase

# Schema whose tables should resolve without qualification; the PostgreSQL
# `options` connect arg sets search_path for every connection in the pool.
db_schema = os.getenv("DB_SCHEMA", "public")

db = SQLDatabase.from_uri(
    os.environ["POSTGRES_URL"],
    engine_args={"connect_args": {"options": f"-csearch_path={db_schema}"}},
)

# Label with the schema actually in use rather than a hard-coded name.
print(f"Tables in {db_schema} schema:")
print(db.get_usable_table_names())


Tables in hr_data schema:
['employee_attrition']


In [8]:
import os
from dotenv import load_dotenv

# If you keep a .env in the same folder, load it:
load_dotenv(override=True)

# setdefault only fills in variables that are not already present in os.environ.

# --- LM Studio (OpenAI-compatible) ---
os.environ.setdefault("OPENAI_API_KEY", "lm-studio")                 # any non-empty string
os.environ.setdefault("OPENAI_BASE_URL", "http://127.0.0.1:1234/v1") # or the LAN URL shown by LM Studio
os.environ.setdefault("OPENAI_MODEL", "ibm/granite-3.2-8b")          # exact model name in LM Studio

# --- PostgreSQL ---
# POSTGRES_URL is normally built from the DB_* variables in the first cell.
# The fallback below is a placeholder only: never put a real password in code.
# Example: "postgresql+psycopg2://user:password@localhost:5433/mydb"
os.environ.setdefault("POSTGRES_URL", "postgresql+psycopg2://USER:PASSWORD@localhost:5433/DATABASE")
os.environ.setdefault("DB_SCHEMA", "hr_data")


'hr_data'

In [9]:
from openai import OpenAI

# Minimal round trip to confirm the LM Studio OpenAI-compatible server answers.
client = OpenAI(
    base_url=os.environ["OPENAI_BASE_URL"],
    api_key=os.environ["OPENAI_API_KEY"],
)

ping_messages = [
    {"role": "system", "content": "Reply with a single word."},
    {"role": "user", "content": "pong"},
]

resp = client.chat.completions.create(
    model=os.environ["OPENAI_MODEL"],
    messages=ping_messages,
    temperature=0,
)

# Any short reply means the endpoint and model name are working.
print(resp.choices[0].message.content)


Correct. "Pong" is the single word response to your command. It's also a classic arcade game.


In [10]:
import os
from dotenv import load_dotenv

# Load variables from a .env file; override=True lets it replace values
# already present in os.environ.
load_dotenv(override=True)

# Sanity check that never reveals the URL itself: only report whether it is set.
pg_url = os.getenv("POSTGRES_URL")
schema = os.getenv("DB_SCHEMA", "public")

for label, value in (("POSTGRES_URL set:", bool(pg_url)), ("DB_SCHEMA:", schema)):
    print(label, value)


POSTGRES_URL set: True
DB_SCHEMA: hr_data


In [11]:
from langchain_community.utilities import SQLDatabase
import os

# Same connection as before: search_path points unqualified names at DB_SCHEMA.
db = SQLDatabase.from_uri(
    os.environ["POSTGRES_URL"],
    engine_args={
        "connect_args": {
            "options": f"-csearch_path={os.getenv('DB_SCHEMA','public')}"
        }
    },
)

# Fetch the table list once and reuse it for both the print and the peek.
tables = list(db.get_usable_table_names())
print("Tables in schema:", tables)

# (Optional) peek at the first table's structure
if tables:
    first_table = tables[:1]
    print(db.get_table_info(table_names=first_table))


Tables in schema: ['employee_attrition']

CREATE TABLE employee_attrition (
	age BIGINT, 
	attrition TEXT, 
	businesstravel TEXT, 
	dailyrate BIGINT, 
	department TEXT, 
	distancefromhome BIGINT, 
	education BIGINT, 
	educationfield TEXT, 
	employeecount BIGINT, 
	employeenumber BIGINT, 
	environmentsatisfaction BIGINT, 
	gender TEXT, 
	hourlyrate BIGINT, 
	jobinvolvement BIGINT, 
	joblevel BIGINT, 
	jobrole TEXT, 
	jobsatisfaction BIGINT, 
	maritalstatus TEXT, 
	monthlyincome BIGINT, 
	monthlyrate BIGINT, 
	numcompaniesworked BIGINT, 
	over18 TEXT, 
	overtime TEXT, 
	percentsalaryhike BIGINT, 
	performancerating BIGINT, 
	relationshipsatisfaction BIGINT, 
	standardhours BIGINT, 
	stockoptionlevel BIGINT, 
	totalworkingyears BIGINT, 
	trainingtimeslastyear BIGINT, 
	worklifebalance BIGINT, 
	yearsatcompany BIGINT, 
	yearsincurrentrole BIGINT, 
	yearssincelastpromotion BIGINT, 
	yearswithcurrmanager BIGINT
)

/*
3 rows from employee_attrition table:
age	attrition	businesstravel	dailyrat

In [12]:
from langchain_openai import ChatOpenAI
import os

# Chat model served by LM Studio through its OpenAI-compatible API.
# (A later cell rebuilds `llm` with temperature=0.0 for SQL generation.)
lm_model = os.environ.get("OPENAI_MODEL", "ibm/granite-3.2-8b")
lm_base_url = os.environ["OPENAI_BASE_URL"]
lm_api_key = os.environ.get("OPENAI_API_KEY", "lm-studio")

llm = ChatOpenAI(model=lm_model, base_url=lm_base_url, api_key=lm_api_key, temperature=0.1)


In [13]:
import langchain

# Record the installed LangChain version so results can be reproduced.
langchain_version = langchain.__version__
print(langchain_version)


1.0.1


In [16]:
import re

# Keywords that must never appear (as whole words) in SQL produced by the model.
# "INTO" rejects PostgreSQL's SELECT ... INTO, which creates a new table.
BLOCKED = ("INSERT","UPDATE","DELETE","DROP","ALTER","TRUNCATE","CREATE",
           "GRANT","REVOKE","COPY","MERGE","CALL","INTO")
# Whole-word match, so identifiers such as "last_updated" or "created_at" are
# not rejected the way a plain substring test would reject them.
_BLOCKED_RE = re.compile(r"\b(" + "|".join(BLOCKED) + r")\b")
# Read-only queries start with SELECT or with a WITH (common table expression).
_QUERY_START_RE = re.compile(r"(SELECT|WITH)\b")


def ask(question: str):
    """Generate SQL from a natural-language question, check it, and run it.

    Uses the globals ``chain``, ``schema_info`` and ``extract_sql`` (defined in
    the following cell) and ``db``. They are looked up when ``ask`` is called,
    so that cell must have been run first.

    Args:
        question: Natural-language question about the data.

    Returns:
        Whatever ``db.run`` returns for the generated query.

    Raises:
        ValueError: if the generated SQL is not a single SELECT/WITH statement
            or contains a blocked keyword.
    """
    raw = chain.invoke({"schema": schema_info, "question": question})
    sql = extract_sql(raw)
    up = sql.upper()

    is_query = _QUERY_START_RE.match(up) is not None
    # Any ';' left after dropping trailing ones means a second statement was appended.
    single_statement = ";" not in up.rstrip().rstrip(";")
    if not (is_query and single_statement) or _BLOCKED_RE.search(up):
        raise ValueError(f"Refusing to run unsafe SQL:\n{sql}\n\nRaw model output:\n{raw}")

    print("----- Generated SQL -----\n", sql)
    result = db.run(sql)
    print("\n----- Result -----\n", result)
    return result


In [17]:
import os
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
import re

# 1) LLM (LM Studio Granite) -- temperature 0.0 for SQL generation
_llm_config = {
    "model": os.environ["OPENAI_MODEL"],
    "base_url": os.environ["OPENAI_BASE_URL"],
    "api_key": os.environ["OPENAI_API_KEY"],
    "temperature": 0.0,
}
llm = ChatOpenAI(**_llm_config)

# 2) Schema text for the tables visible on the current search_path:
#    CREATE TABLE statements plus a few sample rows.
schema_info = db.get_table_info()

# 3) Prompt: behavioural rules, then the schema, then the user's question.
_system_rules = (
    "You are an expert PostgreSQL query writer. "
    "Return ONLY one SQL query. Use ONLY the tables/columns from the provided schema. "
    "Never modify data; SELECT-only. Prefer explicit joins and qualified columns."
)
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _system_rules),
        ("system", "SCHEMA:\n{schema}"),
        ("user", "Question: {question}\nReturn only SQL. No explanations."),
    ]
)

# 4) Chain: prompt -> LLM -> plain text
chain = prompt | llm | StrOutputParser()

# 5) Helper to extract SQL from response
sql_block = re.compile(r"```sql\s*(.*?)```", re.IGNORECASE | re.DOTALL)
# Fallback for fences without a language tag (```\n...\n```), which would
# otherwise leave the backticks in the SQL and make it fail validation.
_any_block = re.compile(r"```\s*(.*?)```", re.DOTALL)


def extract_sql(text: str) -> str:
    """Return the SQL statement contained in an LLM response.

    Prefers a ```sql fenced block, then any ``` fenced block, and otherwise
    uses the whole response. Surrounding whitespace and leading/trailing
    semicolons are removed, including whitespace left before a trailing ';'.
    """
    m = sql_block.search(text) or _any_block.search(text)
    sql = m.group(1) if m else text
    return sql.strip().strip(";").strip()


In [18]:
# End-to-end check of ask(): the returned value is displayed as the cell output.
ask(question="How many employees are in each department? Show top 5 by count.")


----- Generated SQL -----
 SELECT department, COUNT(employeenumber) as employee_count
FROM employee_attrition
GROUP BY department
ORDER BY employee_count DESC
LIMIT 5

----- Result -----
 [('Research & Development', 961), ('Sales', 446), ('Human Resources', 63)]


"[('Research & Development', 961), ('Sales', 446), ('Human Resources', 63)]"

# Clean Text-to-SQL Agent

This agent:
1. Takes a natural language query
2. Generates a SQL query using the HR data dictionary and KPI formulas as context, and rejects any query that fails a SELECT-only safety check
3. Executes the query
4. Returns results as a pandas DataFrame

In [None]:
import pandas as pd
import re
from typing import Dict, Any
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# HR Data Dictionary Context
# Passed to the LLM verbatim, so table and column names must match the live
# schema exactly (compare with db.get_table_info()).
DATA_DICTIONARY = """
TABLE: employee_attrition
Description: HR analytics data collection for employee attrition

COLUMNS:
- age (int): Employee's age in years
- attrition (text): Whether employee left the company (Yes/No) - CATEGORICAL
- businesstravel (text): Frequency of business travel - CATEGORICAL
- dailyrate (int): Daily salary rate
- department (text): Employee's department - CATEGORICAL
- distancefromhome (int): Distance from home to workplace in miles
- education (int): Education level 1-5 - CATEGORICAL
- educationfield (text): Field of study - CATEGORICAL
- employeecount (int): Employee count field
- employeenumber (int): Unique employee identifier
- environmentsatisfaction (int): Work environment satisfaction 1-4 - CATEGORICAL
- gender (text): Male/Female - CATEGORICAL
- hourlyrate (int): Hourly wage rate
- jobinvolvement (int): Job involvement level 1-4 - CATEGORICAL
- joblevel (int): Position level 1-5 - CATEGORICAL
- jobrole (text): Specific job title - CATEGORICAL
- jobsatisfaction (int): Job satisfaction level 1-4 - CATEGORICAL
- maritalstatus (text): Single/Married/Divorced - CATEGORICAL
- monthlyincome (int): Monthly salary
- monthlyrate (int): Monthly billing rate
- numcompaniesworked (int): Number of previous employers
- over18 (text): Whether the employee is over 18 - CATEGORICAL
- overtime (text): Works overtime Yes/No - CATEGORICAL
- percentsalaryhike (int): Percentage salary increase
- performancerating (int): Performance rating 1-4 - CATEGORICAL
- relationshipsatisfaction (int): Workplace relationship satisfaction 1-4 - CATEGORICAL
- standardhours (int): Standard working hours
- stockoptionlevel (int): Stock option level 0-3 - CATEGORICAL
- totalworkingyears (int): Total years of work experience
- trainingtimeslastyear (int): Number of training sessions last year
- worklifebalance (int): Work-life balance rating 1-4 - CATEGORICAL
- yearsatcompany (int): Years at the company
- yearsincurrentrole (int): Years in current role
- yearssincelastpromotion (int): Years since last promotion
- yearswithcurrmanager (int): Years with current manager
"""

# KPI definitions written as valid PostgreSQL. Two pitfalls are avoided here:
# COUNT(attrition='Yes') counts every non-NULL row (true AND false), and
# bigint / bigint is integer division, so rates are scaled by 100.0 first.
KPI_FORMULAS = """
KEY HR METRICS (PostgreSQL syntax):
1. Attrition Rate (%) = 100.0 * COUNT(*) FILTER (WHERE attrition = 'Yes') / COUNT(*)
2. Average Tenure = AVG(yearsatcompany)
3. Gender Pay Gap (%) = 100.0 * (AVG(monthlyincome) FILTER (WHERE gender = 'Male') - AVG(monthlyincome) FILTER (WHERE gender = 'Female')) / AVG(monthlyincome) FILTER (WHERE gender = 'Male')
4. Overtime Rate (%) = 100.0 * COUNT(*) FILTER (WHERE overtime = 'Yes') / COUNT(*)
5. Promotion Rate = Based on yearssincelastpromotion
"""

print("✅ Context loaded: Data Dictionary and KPI Formulas")

In [None]:
class TextToSQLAgent:
    """
    Clean Text-to-SQL Agent for HR Analytics

    Flow: User Query → SQL Generation → Validation → Execution → DataFrame

    Generated SQL is checked by ``_validate_sql`` before execution. That check
    is a keyword filter, not a sandbox; for real protection, connect ``db``
    with a read-only database role.
    """

    # Keywords rejected when they appear as whole words in generated SQL.
    # Whole-word matching avoids false positives on values such as
    # 'Sales Executive' (which contains "EXEC") or identifiers such as
    # "last_updated". "INTO" blocks PostgreSQL's SELECT ... INTO, which
    # creates a new table.
    _BLOCKED_KEYWORDS = (
        "INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE", "CREATE",
        "EXEC", "EXECUTE", "GRANT", "REVOKE", "COPY", "MERGE", "CALL", "INTO",
    )
    _BLOCKED_RE = re.compile(
        r"\b(" + "|".join(_BLOCKED_KEYWORDS) + r")\b", re.IGNORECASE
    )
    # A read-only query starts with SELECT, or with WITH for a common table expression.
    _QUERY_START_RE = re.compile(r"(SELECT|WITH)\b", re.IGNORECASE)
    # Markdown fences: a ```sql block is preferred, then any ``` block.
    _SQL_FENCE_RE = re.compile(r"```sql\s*(.*?)```", re.IGNORECASE | re.DOTALL)
    _ANY_FENCE_RE = re.compile(r"```\s*(.*?)```", re.DOTALL)

    def __init__(self, db, llm):
        """
        Initialize the agent with database connection and LLM.

        Args:
            db: LangChain SQLDatabase instance
            llm: LangChain ChatOpenAI instance
        """
        self.db = db
        self.llm = llm
        # Schema text (CREATE TABLE statements and sample rows) sent with every prompt.
        self.schema_info = db.get_table_info()
        self._setup_chain()

    def _setup_chain(self):
        """Build the prompt template and the prompt -> LLM -> text chain."""
        self.prompt = ChatPromptTemplate.from_messages([
            ("system",
             "You are an expert PostgreSQL query writer for HR analytics.\n\n"
             "IMPORTANT RULES:\n"
             "1. Return ONLY a valid PostgreSQL SELECT query\n"
             "2. Use ONLY tables and columns from the schema provided\n"
             "3. Never use INSERT, UPDATE, DELETE, DROP, ALTER, TRUNCATE, or CREATE\n"
             "4. Use explicit JOINs with qualified column names (table.column)\n"
             "5. For aggregations, use proper GROUP BY clauses\n"
             "6. Column names are lowercase in the database\n\n"
             "DATABASE SCHEMA:\n{schema}\n\n"
             "DATA DICTIONARY:\n{data_dict}\n\n"
             "HR KPI FORMULAS:\n{kpi_formulas}\n"),
            ("user",
             "Generate a PostgreSQL query for this question:\n{question}\n\n"
             "Return ONLY the SQL query, no explanations or markdown.")
        ])

        self.chain = self.prompt | self.llm | StrOutputParser()

    def _extract_sql(self, text: str) -> str:
        """Extract the SQL statement from an LLM response.

        Prefers a ```sql fenced block, then any ``` fenced block, and otherwise
        uses the whole response. Surrounding whitespace and leading/trailing
        semicolons are removed, including whitespace left before a trailing ';'.
        """
        match = self._SQL_FENCE_RE.search(text) or self._ANY_FENCE_RE.search(text)
        sql = match.group(1) if match else text
        return sql.strip().strip(";").strip()

    def _validate_sql(self, sql: str) -> bool:
        """Check that SQL is a single read-only query.

        Args:
            sql: SQL text, typically the output of ``_extract_sql``.

        Returns:
            True when the query passes every check.

        Raises:
            ValueError: if the query does not start with SELECT/WITH, contains
                more than one statement, or uses a blocked keyword.
        """
        statement = sql.strip()

        # Must be a SELECT query (optionally introduced by a WITH clause)
        if not self._QUERY_START_RE.match(statement):
            raise ValueError("Only SELECT queries are allowed")

        # A ';' left after dropping trailing ones means a second statement follows.
        if ";" in statement.rstrip(";"):
            raise ValueError("Multiple SQL statements are not allowed")

        # Check for dangerous keywords (whole words only)
        blocked = self._BLOCKED_RE.search(statement)
        if blocked:
            keyword = blocked.group(1).upper()
            raise ValueError(f"Unsafe SQL detected: {keyword} operation not allowed")

        return True

    def generate_sql(self, question: str) -> str:
        """
        Generate SQL query from natural language question.

        Args:
            question: Natural language query

        Returns:
            SQL query string that has passed ``_validate_sql``

        Raises:
            ValueError: if the generated SQL fails validation
        """
        response = self.chain.invoke({
            "schema": self.schema_info,
            "data_dict": DATA_DICTIONARY,
            "kpi_formulas": KPI_FORMULAS,
            "question": question
        })

        sql = self._extract_sql(response)
        self._validate_sql(sql)

        return sql

    def execute_sql(self, sql: str) -> pd.DataFrame:
        """
        Execute SQL query and return results as DataFrame.

        Note: this method does not call ``_validate_sql``; only SQL produced by
        ``generate_sql`` has been validated.

        Args:
            sql: SQL query string

        Returns:
            pandas DataFrame with query results
        """
        # `_engine` is the SQLAlchemy engine held by LangChain's SQLDatabase.
        # NOTE(review): it is a private attribute and may change between versions.
        with self.db._engine.connect() as conn:
            df = pd.read_sql(sql, conn)

        return df

    def query(self, question: str, verbose: bool = True) -> pd.DataFrame:
        """
        Complete pipeline: Question → SQL → DataFrame

        Args:
            question: Natural language query
            verbose: If True, print generated SQL

        Returns:
            pandas DataFrame with results

        Raises:
            Any exception from generation, validation or execution is printed
            and then re-raised.
        """
        try:
            # Generate SQL
            sql = self.generate_sql(question)

            if verbose:
                print("=" * 60)
                print("GENERATED SQL:")
                print("=" * 60)
                print(sql)
                print("=" * 60)

            # Execute and return DataFrame
            df = self.execute_sql(sql)

            if verbose:
                print(f"\n✅ Query executed successfully. Returned {len(df)} rows.\n")

            return df

        except Exception as e:
            print(f"❌ Error: {str(e)}")
            raise

print("✅ TextToSQLAgent class defined")

In [None]:
# Initialize the Text-to-SQL Agent
agent = TextToSQLAgent(db=db, llm=llm)

print("✅ Text-to-SQL Agent initialized and ready!")

## Test the Agent

Now let's test the agent with various HR queries:

In [None]:
# Example 1: head-count per department
df1 = agent.query(
    "How many employees are in each department?",
    verbose=True,
)
df1

In [None]:
# Example 2: attrition rate broken down by department
df2 = agent.query(
    "What is the attrition rate for each department?",
    verbose=True,
)
df2

In [None]:
# Example 3: average monthly income by job role and gender
df3 = agent.query(
    "Show me the average monthly income by job role and gender",
    verbose=True,
)
df3

In [None]:
# Example 4: overtime vs. non-overtime workers and their attrition
df4 = agent.query(
    "How many employees work overtime and what's their attrition rate compared to non-overtime workers?",
    verbose=True,
)
df4

## Usage Summary

The `TextToSQLAgent` provides a clean interface:

```python
# Simple usage: user query → DataFrame
df = agent.query("Your natural language question here")

# Without verbose output
df = agent.query("Your question", verbose=False)

# Just generate SQL without executing
sql = agent.generate_sql("Your question")

# Execute pre-written SQL
df = agent.execute_sql("SELECT * FROM employee_attrition LIMIT 10")
```

**Features:**
- ✅ Automatic SQL generation from natural language
- ✅ Built-in data dictionary and KPI formulas in context
- ✅ SQL safety validation (SELECT-only, no dangerous operations)
- ✅ Direct DataFrame output
- ✅ Verbose mode to see generated SQL
- ✅ Clean error handling