# Text-to-SQL Agent for HR Analytics

This notebook implements a clean text-to-SQL agent that:
1. Takes natural language queries
2. Generates safe SQL queries using HR context and data dictionary
3. Executes queries against PostgreSQL
4. Returns results as pandas DataFrames

**Flow:** `User Query → SQL Generation → Execution → DataFrame`

## 1. Setup Environment & Database Connection

In [1]:
import os
from dotenv import load_dotenv
from urllib.parse import quote_plus

# Load environment variables
load_dotenv(override=True)

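# Build PostgreSQL connection URL safely (URL-encode the password).
# os.environ[...] raises a KeyError naming the missing variable, whereas
# quote_plus(None) would fail with a less helpful TypeError.
encoded_pw = quote_plus(os.environ["DB_PASSWORD"])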
POSTGRES_URL = (
    f"postgresql+psycopg2://{os.getenv('DB_USER')}:{encoded_pw}"
    f"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}"
)

# Save for downstream usage
os.environ["POSTGRES_URL"] = POSTGRES_URL

print("✅ Environment variables loaded")
print("Schema:", os.getenv("DB_SCHEMA"))

✅ Environment variables loaded
Schema: hr_data


In [2]:
from langchain_community.utilities import SQLDatabase

# Connect to PostgreSQL with hr_data schema
db = SQLDatabase.from_uri(
    os.environ["POSTGRES_URL"],
    engine_args={"connect_args": {"options": f"-csearch_path={os.getenv('DB_SCHEMA','public')}"}}
)

print("✅ Connected to database")
print("Tables available:", db.get_usable_table_names())

✅ Connected to database
Tables available: ['employee_attrition']


## 2. Configure LLM (LM Studio)

In [3]:
from langchain_openai import ChatOpenAI

# Initialize LLM pointing to LM Studio
llm = ChatOpenAI(
    model=os.environ.get("OPENAI_MODEL", "ibm/granite-3.2-8b"),
    base_url=os.environ.get("OPENAI_BASE_URL", "http://127.0.0.1:1234/v1"),
    api_key=os.environ.get("OPENAI_API_KEY", "lm-studio"),
    temperature=0.0,
)

print("✅ LLM initialized:", os.environ.get("OPENAI_MODEL", "ibm/granite-3.2-8b"))

✅ LLM initialized: ibm/granite-3.2-8b


## 3. Load HR Data Dictionary & KPI Context

In [4]:
import pandas as pd
import re
from typing import Dict, Any
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# HR Data Dictionary Context
DATA_DICTIONARY = """
TABLE: employee_attrition
Description: HR analytics data collection for employee attrition

COLUMNS:
- age (int): Employee's age in years
- attrition (text): Whether employee left the company (Yes/No) - CATEGORICAL
- businesstravel (text): Frequency of business travel - CATEGORICAL
  VALUES: 'Travel_Rarely', 'Travel_Frequently', 'Non-Travel'
  NOTE: 'Non-Travel' means employees who don't travel for business (closest to work-from-home/office-based)
- dailyrate (int): Daily salary rate
- department (text): Employee's department - CATEGORICAL
- distancefromhome (int): Distance from home to workplace in miles
- education (int): Education level 1-5 - CATEGORICAL
- educationfield (text): Field of study - CATEGORICAL
- employeenumber (int): Unique employee identifier
- environmentsatisfaction (int): Work environment satisfaction 1-4 - CATEGORICAL
- gender (text): Male/Female - CATEGORICAL
  VALUES: 'Male', 'Female'
- hourlyrate (int): Hourly wage rate
- jobinvolvement (int): Job involvement level 1-4 - CATEGORICAL
- joblevel (int): Position level 1-5 - CATEGORICAL
- jobrole (text): Specific job title - CATEGORICAL
- jobsatisfaction (int): Job satisfaction level 1-4 - CATEGORICAL
- maritalstatus (text): Single/Married/Divorced - CATEGORICAL
- monthlyincome (int): Monthly salary
- monthlyrate (int): Monthly billing rate
- numcompaniesworked (int): Number of previous employers
- overtime (text): Works overtime Yes/No - CATEGORICAL
  VALUES: 'Yes', 'No'
- percentsalaryhike (int): Percentage salary increase
- performancerating (int): Performance rating 1-4 - CATEGORICAL
- relationshipsatisfaction (int): Workplace relationship satisfaction 1-4 - CATEGORICAL
- stockoptionlevel (int): Stock option level 0-3 - CATEGORICAL
- totalworkingyears (int): Total years of work experience
- trainingtimeslastyear (int): Number of training sessions last year
- worklifebalance (int): Work-life balance rating 1-4 - CATEGORICAL
- yearsatcompany (int): Years at the company
- yearsincurrentrole (int): Years in current role
- yearssincelastpromotion (int): Years since last promotion
- yearswithcurrmanager (int): Years with current manager
"""

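KPI_FORMULAS = """
KEY HR METRICS (PostgreSQL; cast counts to numeric so the division is not integer division):
1. Attrition Rate = ROUND(COUNT(CASE WHEN attrition='Yes' THEN 1 END)::numeric / COUNT(*)::numeric * 100, 2)
2. Average Tenure = ROUND(AVG(yearsatcompany)::numeric, 2)
3. Gender Pay Gap = ROUND((AVG(CASE WHEN gender='Male' THEN monthlyincome END) - AVG(CASE WHEN gender='Female' THEN monthlyincome END)) / AVG(CASE WHEN gender='Male' THEN monthlyincome END) * 100, 2)
4. Overtime Rate = ROUND(COUNT(CASE WHEN overtime='Yes' THEN 1 END)::numeric / COUNT(*)::numeric * 100, 2)
5. Promotion Rate = Based on yearssincelastpromotion
"""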

print("✅ Context loaded: Data Dictionary and KPI Formulas")

✅ Context loaded: Data Dictionary and KPI Formulas
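These formulas are also easy to compute outside the database, which gives an independent way to spot-check what the agent returns (In [15] below does the same kind of check in SQL). The following is a minimal pure-Python sketch over a list of row dictionaries; the function names, the `Row` type and the sample rows are illustrative assumptions, not part of the agent. Promotion Rate is left out because its formula is not pinned down above.

```python
from statistics import mean
from typing import Mapping, Sequence

Row = Mapping[str, object]  # one employee record, keyed by lowercase column name


def _pct(part: int, whole: int) -> float:
    """Percentage rounded to 2 decimals, like ROUND(... * 100, 2) in the SQL."""
    if whole == 0:
        raise ValueError("cannot compute a rate over zero rows")
    return round(part / whole * 100, 2)


def attrition_rate(rows: Sequence[Row]) -> float:
    """KPI 1: share of employees with attrition = 'Yes'."""
    return _pct(sum(1 for r in rows if r["attrition"] == "Yes"), len(rows))


def average_tenure(rows: Sequence[Row]) -> float:
    """KPI 2: mean of yearsatcompany."""
    return round(float(mean(r["yearsatcompany"] for r in rows)), 2)


def gender_pay_gap(rows: Sequence[Row]) -> float:
    """KPI 3: (male average - female average) / male average * 100."""
    male = mean(r["monthlyincome"] for r in rows if r["gender"] == "Male")
    female = mean(r["monthlyincome"] for r in rows if r["gender"] == "Female")
    return round((male - female) / male * 100, 2)


def overtime_rate(rows: Sequence[Row]) -> float:
    """KPI 4: share of employees with overtime = 'Yes'."""
    return _pct(sum(1 for r in rows if r["overtime"] == "Yes"), len(rows))


# Made-up rows for illustration only (not the real dataset)
sample = [
    {"attrition": "Yes", "gender": "Male", "overtime": "Yes", "monthlyincome": 6000, "yearsatcompany": 2},
    {"attrition": "No", "gender": "Male", "overtime": "No", "monthlyincome": 6000, "yearsatcompany": 8},
    {"attrition": "No", "gender": "Female", "overtime": "No", "monthlyincome": 5000, "yearsatcompany": 5},
    {"attrition": "No", "gender": "Female", "overtime": "Yes", "monthlyincome": 5000, "yearsatcompany": 5},
]
print(attrition_rate(sample), average_tenure(sample), gender_pay_gap(sample), overtime_rate(sample))
# 25.0 5.0 16.67 50.0
```

Running the same functions over rows fetched with `agent.execute_sql("SELECT * FROM employee_attrition")` (converted with `df.to_dict("records")`) would give a reference value to compare against the agent's answer.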


## 4. Text-to-SQL Agent Implementation

In [5]:
class TextToSQLAgent:
    """
    Clean Text-to-SQL Agent for HR Analytics
    
    Flow: User Query → SQL Generation → Execution → DataFrame
    """
    
    def __init__(self, db, llm):
        """
        Initialize the agent with database connection and LLM.
        
        Args:
            db: LangChain SQLDatabase instance
            llm: LangChain ChatOpenAI instance
        """
        self.db = db
        self.llm = llm
        self.schema_info = db.get_table_info()
        self._setup_chain()
        
    def _setup_chain(self):
        """Setup the LangChain prompt and chain."""
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", 
             "You are an EXPERT PostgreSQL query writer for HR analytics. Your ONLY job is to write correct, efficient SQL queries.\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "CRITICAL RULES - FOLLOW EXACTLY:\n"
             "═══════════════════════════════════════════════════════════\n"
             "1. Return ONLY a valid PostgreSQL SELECT query - NO explanations, NO markdown, NO comments\n"
             "2. Use ONLY tables and columns from the schema provided below\n"
             "3. NEVER use INSERT, UPDATE, DELETE, DROP, ALTER, TRUNCATE, CREATE, or EXEC\n"
             "4. ALL column names are LOWERCASE in the database\n"
             "5. Table name is 'employee_attrition' (lowercase)\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "PERCENTAGE/RATE CALCULATIONS - EXTREMELY IMPORTANT:\n"
             "═══════════════════════════════════════════════════════════\n"
             "PostgreSQL does INTEGER DIVISION by default. You MUST cast to numeric!\n\n"
             
             "✅ CORRECT PATTERNS:\n"
             "Pattern 1 - Using ::numeric casting:\n"
             "  ROUND((COUNT(CASE WHEN condition THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2)\n\n"
             
             "Pattern 2 - Using CAST:\n"
             "  ROUND((CAST(COUNT(CASE WHEN condition THEN 1 END) AS numeric) / CAST(COUNT(*) AS numeric)) * 100, 2)\n\n"
             
             "Pattern 3 - Using 100.0 to force decimal:\n"
             "  ROUND((COUNT(CASE WHEN condition THEN 1 END) * 100.0 / COUNT(*)), 2)\n\n"
             
             "❌ WRONG (Will return 0 due to integer division):\n"
             "  (COUNT(CASE WHEN condition THEN 1 END) / COUNT(*)) * 100\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "FILTERING FOR SPECIFIC GROUPS - CRITICAL:\n"
             "═══════════════════════════════════════════════════════════\n"
             
             "When asked for a rate/percentage for a SPECIFIC GROUP:\n"
             "✅ CORRECT - Use WHERE clause to filter the group FIRST:\n"
             "  Example: \"What is the male attrition rate?\"\n"
             "  SELECT \n"
             "    ROUND((COUNT(CASE WHEN attrition='Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as male_attrition_rate\n"
             "  FROM employee_attrition\n"
             "  WHERE gender = 'Male'\n\n"
             
             "  Example: \"What is the attrition rate for employees who work overtime?\"\n"
             "  SELECT \n"
             "    ROUND((COUNT(CASE WHEN attrition='Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as overtime_attrition_rate\n"
             "  FROM employee_attrition\n"
             "  WHERE overtime = 'Yes'\n\n"
             
             "❌ WRONG - Do NOT put the filter condition in CASE WHEN:\n"
             "  SELECT (COUNT(CASE WHEN attrition='Yes' AND gender='Male' THEN 1 END) / COUNT(*)) * 100\n"
             "  FROM employee_attrition\n"
             "  -- This divides by ALL employees, not just males!\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "COMPARISON QUERIES:\n"
             "═══════════════════════════════════════════════════════════\n"
             
             "When comparing groups (e.g., \"overtime vs non-overtime\", \"male vs female\"):\n"
             "✅ CORRECT - Use GROUP BY:\n"
             "  SELECT \n"
             "    overtime,\n"
             "    COUNT(*) as total,\n"
             "    COUNT(CASE WHEN attrition='Yes' THEN 1 END) as left,\n"
             "    ROUND((COUNT(CASE WHEN attrition='Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as attrition_rate\n"
             "  FROM employee_attrition\n"
             "  GROUP BY overtime\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "IMPACT/CORRELATION QUERIES:\n"
             "═══════════════════════════════════════════════════════════\n"
             
             "When asked \"How does X impact Y?\" or \"How does X affect Y?\":\n"
             "✅ CORRECT - Group by X, calculate Y for each group:\n"
             "  Example: \"How does job satisfaction impact attrition?\"\n"
             "  SELECT \n"
             "    jobsatisfaction,\n"
             "    COUNT(*) as total_employees,\n"
             "    COUNT(CASE WHEN attrition='Yes' THEN 1 END) as employees_left,\n"
             "    ROUND((COUNT(CASE WHEN attrition='Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as attrition_rate\n"
             "  FROM employee_attrition\n"
             "  GROUP BY jobsatisfaction\n"
             "  ORDER BY jobsatisfaction\n\n"
             
             "  Example: \"How does education level affect monthly income?\"\n"
             "  SELECT \n"
             "    education,\n"
             "    COUNT(*) as employee_count,\n"
             "    ROUND(AVG(monthlyincome)::numeric, 2) as avg_income,\n"
             "    MIN(monthlyincome) as min_income,\n"
             "    MAX(monthlyincome) as max_income\n"
             "  FROM employee_attrition\n"
             "  GROUP BY education\n"
             "  ORDER BY education\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "AGGREGATION BEST PRACTICES:\n"
             "═══════════════════════════════════════════════════════════\n"
             "- Always include COUNT(*) to show sample size\n"
             "- Use ROUND() for decimals: ROUND(value, 2)\n"
             "- Order results logically (ORDER BY)\n"
             "- Use descriptive column aliases (AS meaningful_name)\n"
             "- For categorical columns, always GROUP BY them\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "CATEGORICAL COLUMN VALUES (Case-sensitive!):\n"
             "═══════════════════════════════════════════════════════════\n"
             "- attrition: 'Yes', 'No'\n"
             "- gender: 'Male', 'Female'\n"
             "- overtime: 'Yes', 'No'\n"
             "- businesstravel: 'Travel_Rarely', 'Travel_Frequently', 'Non-Travel'\n"
             "- maritalstatus: 'Single', 'Married', 'Divorced'\n"
             "- department: Use actual values from data\n"
             "- jobrole: Use actual values from data\n"
             "- educationfield: Use actual values from data\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "RATING COLUMNS (1-4 or 1-5 scales):\n"
             "═══════════════════════════════════════════════════════════\n"
             "These are integers representing ratings:\n"
             "- jobsatisfaction (1-4): 1=Low, 4=High\n"
             "- environmentsatisfaction (1-4)\n"
             "- relationshipsatisfaction (1-4)\n"
             "- worklifebalance (1-4)\n"
             "- performancerating (1-4)\n"
             "- jobinvolvement (1-4)\n"
             "- education (1-5)\n"
             "- joblevel (1-5)\n"
             "- stockoptionlevel (0-3)\n\n"
             
             "For questions like \"low satisfaction\" or \"high satisfaction\":\n"
             "- Low: values 1-2\n"
             "- High: values 3-4\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "COMMON QUESTION PATTERNS & SOLUTIONS:\n"
             "═══════════════════════════════════════════════════════════\n\n"
             
             "Pattern: \"What is the [metric] for [specific group]?\"\n"
             "Solution: Use WHERE to filter group, then calculate metric\n\n"
             
             "Pattern: \"How does [factor] impact [outcome]?\"\n"
             "Solution: GROUP BY factor, calculate outcome for each group\n\n"
             
             "Pattern: \"Compare [metric] between [group1] and [group2]\"\n"
             "Solution: GROUP BY the grouping column, calculate metric\n\n"
             
             "Pattern: \"What is the average [column]?\"\n"
             "Solution: SELECT AVG(column)::numeric\n\n"
             
             "Pattern: \"Show me distribution of [column]\"\n"
             "Solution: SELECT column, COUNT(*) GROUP BY column\n\n"
             
             "Pattern: \"Top/Bottom N by [metric]\"\n"
             "Solution: Use ORDER BY metric DESC/ASC LIMIT N\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "DATABASE SCHEMA:\n"
             "═══════════════════════════════════════════════════════════\n"
             "{schema}\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "DATA DICTIONARY:\n"
             "═══════════════════════════════════════════════════════════\n"
             "{data_dict}\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "HR KPI FORMULAS:\n"
             "═══════════════════════════════════════════════════════════\n"
             "{kpi_formulas}\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "REMEMBER:\n"
             "═══════════════════════════════════════════════════════════\n"
             "1. ALWAYS cast to ::numeric for divisions\n"
             "2. Use WHERE to filter specific groups BEFORE calculating rates\n"
             "3. Use GROUP BY to compare groups or show impact\n"
             "4. Return ONLY the SQL query - no markdown, no explanations\n"
             "5. Test your logic: denominators should match the filtered group\n"
             "═══════════════════════════════════════════════════════════\n"),
            ("user", 
             "Generate a PostgreSQL query for this question:\n{question}\n\n"
             "Return ONLY the SQL query. No markdown formatting. No explanations.")
        ])
        
        self.chain = self.prompt | self.llm | StrOutputParser()
    
    def _extract_sql(self, text: str) -> str:
        """Extract SQL from LLM response, handling markdown code blocks."""
        # Try to extract from ```sql ... ``` blocks
        sql_block = re.search(r"```sql\s*(.*?)```", text, re.IGNORECASE | re.DOTALL)
        if sql_block:
            sql = sql_block.group(1)
        else:
            # Try generic ``` ... ``` blocks
            code_block = re.search(r"```\s*(.*?)```", text, re.DOTALL)
            sql = code_block.group(1) if code_block else text
        
        return sql.strip().strip(';')
    
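    def _validate_sql(self, sql: str) -> bool:
        """Validate that SQL is a single, read-only SELECT statement."""
        sql_upper = sql.upper().strip()
        
        # Must be a SELECT query
        if not sql_upper.startswith("SELECT"):
            raise ValueError("Only SELECT queries are allowed")
        
        # Reject stacked statements such as "SELECT 1; DROP TABLE ..."
        # (_extract_sql already strips a trailing semicolon)
        if ";" in sql_upper:
            raise ValueError("Multiple SQL statements are not allowed")
        
        # Match whole words only, so identifiers such as "created_at" or
        # "is_deleted" do not trigger a false positive
        dangerous_keywords = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER",
                              "TRUNCATE", "CREATE", "EXEC", "EXECUTE"]
        for keyword in dangerous_keywords:
            if re.search(rf"\b{keyword}\b", sql_upper):
                raise ValueError(f"Unsafe SQL detected: {keyword} operation not allowed")
        
        return True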
    
    def generate_sql(self, question: str) -> str:
        """
        Generate SQL query from natural language question.
        
        Args:
            question: Natural language query
            
        Returns:
            SQL query string
        """
        response = self.chain.invoke({
            "schema": self.schema_info,
            "data_dict": DATA_DICTIONARY,
            "kpi_formulas": KPI_FORMULAS,
            "question": question
        })
        
        sql = self._extract_sql(response)
        self._validate_sql(sql)
        
        return sql
    
    def execute_sql(self, sql: str) -> pd.DataFrame:
        """
        Execute SQL query and return results as DataFrame.
        
        Args:
            sql: SQL query string
            
        Returns:
            pandas DataFrame with query results
        """
        # Get SQLAlchemy engine from LangChain db
        with self.db._engine.connect() as conn:
            df = pd.read_sql(sql, conn)
        
        return df
    
    def query(self, question: str, verbose: bool = True) -> pd.DataFrame:
        """
        Complete pipeline: Question → SQL → DataFrame
        
        Args:
            question: Natural language query
            verbose: If True, print generated SQL
            
        Returns:
            pandas DataFrame with results
        """
        try:
            # Generate SQL
            sql = self.generate_sql(question)
            
            if verbose:
                print("=" * 60)
                print("GENERATED SQL:")
                print("=" * 60)
                print(sql)
                print("=" * 60)
            
            # Execute and return DataFrame
            df = self.execute_sql(sql)
            
            if verbose:
                print(f"\n✅ Query executed successfully. Returned {len(df)} rows.\n")
            
            return df
            
        except Exception as e:
            print(f"❌ Error: {str(e)}")
            raise

print("✅ TextToSQLAgent class defined with ENHANCED PROMPT (v2.0)")

✅ TextToSQLAgent class defined with ENHANCED PROMPT (v2.0)


## 5. Initialize the Agent

In [6]:
# Initialize the Text-to-SQL Agent
agent = TextToSQLAgent(db=db, llm=llm)

print("✅ Text-to-SQL Agent initialized and ready!")

✅ Text-to-SQL Agent initialized and ready!


## 6. Test Examples

Now let's test the agent with various HR analytics queries:

In [7]:
# Example 1: Department-wise employee count
df1 = agent.query("How many employees are in each department?")
df1

GENERATED SQL:
SELECT
    department,
    COUNT(*) as employee_count
FROM
    employee_attrition
GROUP BY
    department
ORDER BY
    employee_count DESC

✅ Query executed successfully. Returned 3 rows.



|   | department | employee_count |
|---|---|---|
| 0 | Research & Development | 961 |
| 1 | Sales | 446 |
| 2 | Human Resources | 63 |


In [8]:
# Example 2: Attrition rate by department
df2 = agent.query("What is the attrition rate for each department?")
df2

GENERATED SQL:
SELECT
    department,
    COUNT(*) as total_employees,
    COUNT(CASE WHEN attrition = 'Yes' THEN 1 END) as employees_left,
    ROUND((COUNT(CASE WHEN attrition = 'Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as attrition_rate
FROM employee_attrition
GROUP BY department
ORDER BY attrition_rate DESC

✅ Query executed successfully. Returned 3 rows.



|   | department | total_employees | employees_left | attrition_rate |
|---|---|---|---|---|
| 0 | Sales | 446 | 92 | 20.63 |
| 1 | Human Resources | 63 | 12 | 19.05 |
| 2 | Research & Development | 961 | 133 | 13.84 |


In [9]:
# Example 3: Average salary by job role and gender
df3 = agent.query("Show me the average monthly income by job role and gender")
df3

GENERATED SQL:
SELECT
    jobrole,
    gender,
    ROUND(AVG(monthlyincome)::numeric, 2) as avg_monthly_income
FROM
    employee_attrition
GROUP BY
    jobrole,
    gender
ORDER BY
    jobrole,
    gender

✅ Query executed successfully. Returned 18 rows.



|   | jobrole | gender | avg_monthly_income |
|---|---|---|---|
| 0 | Healthcare Representative | Female | 7433.8 |
| 1 | Healthcare Representative | Male | 7589.3 |
| 2 | Human Resources | Female | 4540.69 |
| 3 | Human Resources | Male | 4100.22 |
| 4 | Laboratory Technician | Female | 3246.91 |
| 5 | Laboratory Technician | Male | 3232.41 |
| 6 | Manager | Female | 16915.28 |
| 7 | Manager | Male | 17409.33 |
| 8 | Manufacturing Director | Female | 7409.17 |
| 9 | Manufacturing Director | Male | 7182.67 |


In [10]:
# Example 4: Overtime and attrition analysis
df4 = agent.query("How many employees work overtime and what's their attrition rate compared to non-overtime workers?")
df4

GENERATED SQL:
SELECT
    overtime,
    COUNT(*) as total_employees,
    COUNT(CASE WHEN attrition = 'Yes' THEN 1 END) as left,
    ROUND((COUNT(CASE WHEN attrition = 'Yes' THEN 1 END)::numeric / COUNT(*)) * 100, 2) as attrition_rate
FROM employee_attrition
GROUP BY overtime
ORDER BY overtime

✅ Query executed successfully. Returned 2 rows.



|   | overtime | total_employees | left | attrition_rate |
|---|---|---|---|---|
| 0 | No | 1054 | 110 | 10.44 |
| 1 | Yes | 416 | 127 | 30.53 |


In [11]:
# Example 5: Work-life balance analysis
df5 = agent.query("Show me the average work-life balance rating by department and its correlation with attrition")
df5

GENERATED SQL:
SELECT
    department,
    ROUND(AVG(worklifebalance)::numeric, 2) as avg_worklifebalance,
    ROUND((COUNT(CASE WHEN attrition='Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as attrition_rate
FROM employee_attrition
GROUP BY department
ORDER BY avg_worklifebalance DESC

✅ Query executed successfully. Returned 3 rows.



|   | department | avg_worklifebalance | attrition_rate |
|---|---|---|---|
| 0 | Human Resources | 2.92 | 19.05 |
| 1 | Sales | 2.82 | 20.63 |
| 2 | Research & Development | 2.73 | 13.84 |

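The query above puts per-department averages next to attrition rates, which is not a correlation: with only three rows it is hard to read any relationship from them. If an actual correlation is wanted, PostgreSQL's built-in `corr(Y, X)` aggregate computes the Pearson coefficient, and correlating the rating with a 0/1 attrition flag gives a point-biserial correlation within each department. The sketch below pairs such a query (not executed here) with a pure-Python Pearson function for cross-checking; `CORRELATION_SQL` and `pearson` are hypothetical names.

```python
from math import sqrt

# PostgreSQL query (assumption: not run in this notebook). The CASE expression
# turns attrition into a 0/1 flag; corr() returns double precision, so it is
# cast to numeric before ROUND.
CORRELATION_SQL = """
SELECT
    department,
    COUNT(*) AS total_employees,
    ROUND(corr(worklifebalance, CASE WHEN attrition = 'Yes' THEN 1 ELSE 0 END)::numeric, 3)
        AS worklifebalance_attrition_corr
FROM employee_attrition
GROUP BY department
ORDER BY department
"""


def pearson(xs, ys):
    """Pearson correlation of two equal-length numeric sequences."""
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    if var_x == 0 or var_y == 0:
        return float("nan")  # undefined when one variable is constant
    return cov / sqrt(var_x * var_y)


# Toy data: lower work-life balance ratings paired with leaving (1 = left)
ratings = [1, 2, 3, 4]
left = [1, 1, 0, 0]
print(round(pearson(ratings, left), 3))  # -0.894
```

A negative value would mean that employees who left tend to report lower work-life balance; with few employees per group the coefficient is noisy, so the `total_employees` column is worth keeping next to it.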

In [12]:
# Example 6: Gender pay gap analysis
df6 = agent.query("Calculate the gender pay gap for each department")
df6

GENERATED SQL:
SELECT
    department,
    ROUND((AVG(monthlyincome::numeric) * 100 - AVG(CASE WHEN gender = 'Female' THEN monthlyincome::numeric ELSE NULL END)::numeric) / AVG(monthlyincome::numeric), 2) as gender_pay_gap
FROM
    employee_attrition
GROUP BY
    department
ORDER BY
    gender_pay_gap DESC

✅ Query executed successfully. Returned 3 rows.



|   | department | gender_pay_gap |
|---|---|---|
| 0 | Sales | 99.0 |
| 1 | Research & Development | 98.96 |
| 2 | Human Resources | 98.91 |

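The Example 6 result should not be trusted. The generated expression computes `(AVG(all incomes) * 100 - AVG(female incomes)) / AVG(all incomes)`, which equals 100 minus the female-to-overall income ratio; it never compares male and female pay, which is why every department lands near 99. A query that follows KPI formula 3 (male average minus female average, divided by the male average, times 100) is sketched below. To keep the sketch runnable without a PostgreSQL server, it is checked against an in-memory SQLite table with made-up rows; the SQL uses only `AVG`, `CASE` and `ROUND`, and the assumption that it behaves the same way on PostgreSQL (where `AVG` of an integer column returns `numeric`) is worth re-checking on your server.

```python
import sqlite3

# KPI formula 3 per department: (male avg - female avg) / male avg * 100
GENDER_PAY_GAP_SQL = """
SELECT
    department,
    ROUND(
        (AVG(CASE WHEN gender = 'Male' THEN monthlyincome END)
         - AVG(CASE WHEN gender = 'Female' THEN monthlyincome END))
        / AVG(CASE WHEN gender = 'Male' THEN monthlyincome END) * 100,
        2
    ) AS gender_pay_gap_pct
FROM employee_attrition
GROUP BY department
ORDER BY department
"""

# Made-up rows, only to check the arithmetic (not the real dataset)
rows = [
    ("Sales", "Male", 5000), ("Sales", "Male", 7000),
    ("Sales", "Female", 4500), ("Sales", "Female", 5500),
    ("Research & Development", "Male", 4000),
    ("Research & Development", "Female", 4000),
]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE employee_attrition (department TEXT, gender TEXT, monthlyincome INTEGER)")
conn.executemany("INSERT INTO employee_attrition VALUES (?, ?, ?)", rows)
result = dict(conn.execute(GENDER_PAY_GAP_SQL).fetchall())
conn.close()
print(result)  # {'Research & Development': 0.0, 'Sales': 16.67}
```

In Sales the male average is 6,000 and the female average 5,000, so the gap is 1,000 / 6,000 × 100 ≈ 16.67%; a positive value means men earn more on average.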

## 7. Usage Guide

### Simple Usage

```python
# Complete pipeline: user query → DataFrame
df = agent.query("Your natural language question here")
```

### Advanced Usage

```python
# Without verbose output
df = agent.query("Your question", verbose=False)

# Just generate SQL without executing
sql = agent.generate_sql("Your question")
print(sql)

# Execute pre-written SQL (note: execute_sql() skips the _validate_sql safety check)
df = agent.execute_sql("SELECT * FROM employee_attrition LIMIT 10")
```

### Features

- ✅ **Automatic SQL generation** from natural language
- ✅ **Built-in data dictionary** and KPI formulas in context
- ✅ **SQL safety validation** (SELECT-only, blocks dangerous operations)
- ✅ **Direct DataFrame output** for easy analysis
- ✅ **Verbose mode** to see generated SQL
- ✅ **Clean error handling** with informative messages

### Security Features

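SQL generated by `generate_sql()` (and therefore by `query()`) is validated before it runs; `execute_sql()` skips this check, so only pass it SQL you have reviewed. The validator blocks: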
- INSERT, UPDATE, DELETE operations
- DROP, ALTER, TRUNCATE, CREATE operations
- Non-SELECT queries
- EXEC commands
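The keyword check is a text filter, so it is worth pairing with a guard inside the database. One option, sketched below under the assumption that the database account is not otherwise restricted, is to pass PostgreSQL's `default_transaction_read_only` setting through libpq's `options` connection parameter, next to the `search_path` that Section 1 already sets. A session can still switch this setting off, so it complements, rather than replaces, connecting with a role that only has `SELECT` privileges. The helper name `build_connect_args` is hypothetical.

```python
def build_connect_args(schema: str = "public", read_only: bool = True) -> dict:
    """Build psycopg2 connect_args that set search_path and, optionally,
    make every transaction in the session read-only by default."""
    options = [f"-csearch_path={schema}"]
    if read_only:
        # The server then rejects writes such as INSERT, UPDATE, DELETE and most DDL
        options.append("-cdefault_transaction_read_only=on")
    # libpq passes the space-separated "-c name=value" settings to the server at connect time
    return {"options": " ".join(options)}


connect_args = build_connect_args("hr_data")
print(connect_args)
# {'options': '-csearch_path=hr_data -cdefault_transaction_read_only=on'}

# Usage with the Section 1 connection (not executed here):
# db = SQLDatabase.from_uri(os.environ["POSTGRES_URL"],
#                           engine_args={"connect_args": connect_args})
```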

## 8. Try Your Own Queries

Use the cell below to ask your own questions:

In [13]:
# Your custom query here
df = agent.query("Your question here")
df

GENERATED SQL:
SELECT
    department,
    COUNT(*) as total_employees,
    COUNT(CASE WHEN attrition = 'Yes' THEN 1 END) as employees_left,
    ROUND((COUNT(CASE WHEN attrition = 'Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as attrition_rate
FROM employee_attrition
GROUP BY department
ORDER BY attrition_rate DESC

✅ Query executed successfully. Returned 3 rows.



|   | department | total_employees | employees_left | attrition_rate |
|---|---|---|---|---|
| 0 | Sales | 446 | 92 | 20.63 |
| 1 | Human Resources | 63 | 12 | 19.05 |
| 2 | Research & Development | 961 | 133 | 13.84 |


In [14]:
# Your custom query here
df = agent.query("what is the male attrition rate?")
df

GENERATED SQL:
SELECT
    ROUND((COUNT(CASE WHEN gender = 'Male' AND attrition = 'Yes' THEN 1 END)::numeric / COUNT(CASE WHEN gender = 'Male' THEN 1 END)::numeric) * 100, 2) as male_attrition_rate
FROM employee_attrition

✅ Query executed successfully. Returned 1 rows.



|   | male_attrition_rate |
|---|---|
| 0 | 17.01 |


In [15]:
# Let's verify the data - check male employees and their attrition
test_sql = """
SELECT 
    gender,
    COUNT(*) as total_employees,
    COUNT(CASE WHEN attrition = 'Yes' THEN 1 END) as left_company,
    ROUND((COUNT(CASE WHEN attrition = 'Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as attrition_rate
FROM employee_attrition
GROUP BY gender
"""

df_test = agent.execute_sql(test_sql)
print("Attrition by Gender:")
df_test

Attrition by Gender:


|   | gender | total_employees | left_company | attrition_rate |
|---|---|---|---|---|
| 0 | Female | 588 | 87 | 14.8 |
| 1 | Male | 882 | 150 | 17.01 |


### 🔍 Analysis of the Issue

In an earlier run, before the system prompt was enhanced (v2.0), the LLM generated this SQL for the same question:
```sql
SELECT (COUNT(CASE WHEN attrition = 'Yes' AND gender = 'Male' THEN 1 END) / COUNT(*)) * 100
FROM employee_attrition
```

**Problems:**
1. **Wrong denominator**: it divides by ALL 1,470 employees, not just the 882 males
2. **Integer division**: `COUNT()` returns an integer, and PostgreSQL truncates integer division (150 / 1470 = 0), so the query returns 0

**Correct approach:**
- Filter for males first with a WHERE clause (or make the denominator conditional as well, as the In [14] query does)
- Cast to numeric for proper division
- Or use the `COUNT(*) FILTER (WHERE ...)` syntax

**Actual Result:** Male attrition rate is **17.01%** (150 out of 882 male employees left)
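Both problems can be reproduced without a PostgreSQL server. The sketch below loads synthetic rows with the same counts as In [15] (882 men of whom 150 left, 588 women of whom 87 left) into an in-memory SQLite table and runs three variants of the query. It relies on SQLite, like PostgreSQL, truncating the result when both operands of `/` are integers; the rows are synthetic, not the real table.

```python
import sqlite3

# Synthetic rows matching the In [15] counts: 882 men (150 left), 588 women (87 left)
rows = (
    [("Male", "Yes")] * 150 + [("Male", "No")] * (882 - 150)
    + [("Female", "Yes")] * 87 + [("Female", "No")] * (588 - 87)
)

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE employee_attrition (gender TEXT, attrition TEXT)")
conn.executemany("INSERT INTO employee_attrition VALUES (?, ?)", rows)

queries = {
    # Original bug: integer division AND a denominator over all 1,470 employees
    "both_bugs": """SELECT (COUNT(CASE WHEN attrition = 'Yes' AND gender = 'Male' THEN 1 END)
                    / COUNT(*)) * 100 FROM employee_attrition""",
    # Decimal division, but still the wrong denominator
    "wrong_denominator": """SELECT ROUND(COUNT(CASE WHEN attrition = 'Yes' AND gender = 'Male' THEN 1 END)
                            * 100.0 / COUNT(*), 2) FROM employee_attrition""",
    # Filter the group first, then divide with a decimal operand
    "correct": """SELECT ROUND(COUNT(CASE WHEN attrition = 'Yes' THEN 1 END) * 100.0 / COUNT(*), 2)
                  FROM employee_attrition WHERE gender = 'Male'""",
}
results = {name: conn.execute(sql).fetchone()[0] for name, sql in queries.items()}
conn.close()
print(results)  # {'both_bugs': 0, 'wrong_denominator': 10.2, 'correct': 17.01}
```

Fixing only the division type still gives 10.2% (150 / 1,470), so both fixes are needed to reach the 17.01% shown above.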

In [16]:
# Test again with the improved agent
df_male_attrition = agent.query("What is the male attrition rate?")
df_male_attrition

GENERATED SQL:
SELECT
    ROUND((COUNT(CASE WHEN gender = 'Male' AND attrition = 'Yes' THEN 1 END)::numeric / COUNT(CASE WHEN gender = 'Male' THEN 1 END)::numeric) * 100, 2) as male_attrition_rate
FROM employee_attrition

✅ Query executed successfully. Returned 1 rows.



|   | male_attrition_rate |
|---|---|
| 0 | 17.01 |


In [22]:
# Your custom query here
df = agent.query("What is the work from home attrition rate?")
df

GENERATED SQL:
SELECT
    ROUND((COUNT(CASE WHEN attrition = 'Yes' AND businesstravel = 'Non-Travel' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as work_from_home_attrition_rate
FROM employee_attrition

✅ Query executed successfully. Returned 1 rows.



|   | work_from_home_attrition_rate |
|---|---|
| 0 | 0.82 |


In [23]:
# Let's check what values exist in the businesstravel column
check_sql = """
SELECT 
    businesstravel,
    COUNT(*) as count,
    COUNT(CASE WHEN attrition = 'Yes' THEN 1 END) as attrition_count,
    ROUND((COUNT(CASE WHEN attrition = 'Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as attrition_rate
FROM employee_attrition
GROUP BY businesstravel
ORDER BY count DESC
"""

df_business_travel = agent.execute_sql(check_sql)
print("Business Travel Categories and Attrition:")
df_business_travel

Business Travel Categories and Attrition:


|   | businesstravel | count | attrition_count | attrition_rate |
|---|---|---|---|---|
| 0 | Travel_Rarely | 1043 | 156 | 14.96 |
| 1 | Travel_Frequently | 277 | 69 | 24.91 |
| 2 | Non-Travel | 150 | 12 | 8.0 |


### 🔍 Important Note about "Work From Home"

The database **does not have a "work from home" category**. The `businesstravel` column has:
- `Travel_Rarely` - Employees who travel occasionally
- `Travel_Frequently` - Employees who travel often
- `Non-Travel` - Employees who don't travel (closest to office-based/remote work)

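If you're asking about "work from home attrition," you most likely mean **`Non-Travel`** employees.

The 0.82% returned in In [22] is also wrong for a second reason: the generated SQL put the `businesstravel = 'Non-Travel'` filter inside `CASE WHEN` and divided by `COUNT(*)` over all 1,470 employees (12 / 1,470 ≈ 0.82%). That is exactly the mistake the prompt warns against; dividing by the 150 Non-Travel employees gives the correct 8.00%.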

In [24]:
# Correct query: attrition rate for Non-Travel employees
df_non_travel = agent.query("What is the attrition rate for employees with businesstravel = 'Non-Travel'?")
df_non_travel

GENERATED SQL:
SELECT
    ROUND((COUNT(CASE WHEN attrition = 'Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as non_travel_attrition_rate
FROM employee_attrition
WHERE businesstravel = 'Non-Travel'

✅ Query executed successfully. Returned 1 rows.



|   | non_travel_attrition_rate |
|---|---|
| 0 | 8.0 |


### ✅ Results Summary

**Business Travel Categories & Attrition Rates:**
- **Non-Travel**: 8.00% attrition (12 out of 150 employees)
- **Travel_Rarely**: 14.96% attrition (156 out of 1,043 employees)
- **Travel_Frequently**: 24.91% attrition (69 out of 277 employees)

**Key Insight:** Employees who travel frequently have the **highest attrition rate (24.91%)**, while non-traveling employees have the **lowest (8.00%)**. This suggests that frequent business travel may be a factor in employee turnover.
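The wrong-denominator pattern slipped through in In [22] even with the enhanced prompt, so a cheap post-generation check can help flag it before results are shown. The heuristic below (a hypothetical helper, not part of `TextToSQLAgent`) looks for a `COUNT(CASE WHEN ... AND ... THEN 1 END)` numerator divided by a bare `COUNT(*)`. It is regex-based, so it will miss rewritten forms (for example `FILTER`, `SUM(CASE ...)` or `ELSE NULL END`) and can flag legitimate queries; a hit is a reason to review the SQL, not proof of a bug.

```python
import re

# Heuristic pattern: a numerator that combines the outcome with a group filter
# (CASE WHEN ... AND ... THEN 1 END) divided by a bare COUNT(*) over the whole table.
_MISMATCHED_RATE = re.compile(
    r"COUNT\(CASE WHEN [^()]*? AND [^()]*? THEN 1 END\)"  # numerator with an extra AND filter
    r"(?:::NUMERIC)?"                                     # optional ::numeric cast
    r"(?:\s*\*\s*100(?:\.0+)?)?"                          # optional "* 100.0" before dividing
    r"\s*/\s*COUNT\(\*\)"                                 # denominator counts every row
)


def has_mismatched_denominator(sql: str) -> bool:
    """Return True if the SQL looks like it divides a sub-group count by the whole table."""
    normalized = " ".join(sql.split()).upper()  # collapse whitespace and ignore case
    return bool(_MISMATCHED_RATE.search(normalized))


# SQL copied from In [22] (wrong denominator) and In [24] (correct)
in_22 = """SELECT ROUND((COUNT(CASE WHEN attrition = 'Yes' AND businesstravel = 'Non-Travel' THEN 1 END)::numeric
    / COUNT(*)::numeric) * 100, 2) as work_from_home_attrition_rate FROM employee_attrition"""
in_24 = """SELECT ROUND((COUNT(CASE WHEN attrition = 'Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2)
    as non_travel_attrition_rate FROM employee_attrition WHERE businesstravel = 'Non-Travel'"""

print(has_mismatched_denominator(in_22))  # True
print(has_mismatched_denominator(in_24))  # False
```

In the agent, such a check could run after `_validate_sql` and print a warning in verbose mode rather than raising, since it is only a heuristic.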

In [26]:
# Your custom query here
df = agent.query("how does job satisfaction impact attrition rates by department?")
df

GENERATED SQL:
SELECT
    jobsatisfaction,
    department,
    COUNT(*) as total_employees,
    COUNT(CASE WHEN attrition = 'Yes' THEN 1 END) as employees_left,
    ROUND((COUNT(CASE WHEN attrition = 'Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as attrition_rate
FROM employee_attrition
GROUP BY jobsatisfaction, department
ORDER BY department

✅ Query executed successfully. Returned 12 rows.



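|   | jobsatisfaction | department | total_employees | employees_left | attrition_rate |
|---|-----------------|------------|-----------------|----------------|----------------|
| 0 | 2 | Human Resources | 20 | 2 | 10.0 |
| 1 | 3 | Human Resources | 15 | 3 | 20.0 |
| 2 | 1 | Human Resources | 11 | 5 | 45.45 |
| 3 | 4 | Human Resources | 17 | 2 | 11.76 |
| 4 | 3 | Research & Development | 300 | 43 | 14.33 |
| 5 | 4 | Research & Development | 295 | 28 | 9.49 |
| 6 | 2 | Research & Development | 174 | 24 | 13.79 |
| 7 | 1 | Research & Development | 192 | 38 | 19.79 |
| 8 | 1 | Sales | 86 | 23 | 26.74 |
| 9 | 2 | Sales | 86 | 20 | 23.26 |
| … | … | … | … | … | … |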
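Because the generated SQL orders only by `department`, satisfaction levels come out in arbitrary order within each department (`ORDER BY department, jobsatisfaction` would fix that). Pivoting the long result into a satisfaction × department matrix makes the pattern easier to read: in Human Resources and R&D, level 1 has the highest rate. The sketch below rebuilds the matrix from the 10 rows printed above; the two Sales rows (levels 3 and 4) cut off in this export appear as NaN.

```python
import pandas as pd

# The 10 result rows printed above: (jobsatisfaction, department, attrition_rate)
rows = [
    (2, "Human Resources", 10.0), (3, "Human Resources", 20.0),
    (1, "Human Resources", 45.45), (4, "Human Resources", 11.76),
    (3, "Research & Development", 14.33), (4, "Research & Development", 9.49),
    (2, "Research & Development", 13.79), (1, "Research & Development", 19.79),
    (1, "Sales", 26.74), (2, "Sales", 23.26),
]
long_df = pd.DataFrame(rows, columns=["jobsatisfaction", "department", "attrition_rate"])

# One row per satisfaction level (1-4), one column per department;
# combinations absent from the input become NaN
matrix = long_df.pivot(
    index="jobsatisfaction", columns="department", values="attrition_rate"
).sort_index()
print(matrix)
```

With the live query result, the same pivot can feed a heatmap, e.g. `px.imshow(matrix, text_auto=True)`, which matches the "2D categorical data" case in the visualization prompt defined in the next section.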


## 9. Dynamic Plotly Visualization Generator

This section adds automatic visualization generation:
- Takes the DataFrame from SQL query results
- Uses LLM to generate appropriate Plotly visualization code
- Executes the generated code to display interactive charts

**Flow:** `DataFrame → Data Summary → LLM → Python Code → Visualization`

In [46]:
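import re
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from io import StringIO
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser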

class PlotlyVisualizationGenerator:
    """
    Generates Plotly visualization code dynamically using LLM based on DataFrame content.
    
    Flow: DataFrame → Data Summary → LLM → Python Code → Visualization
    """
    
    def __init__(self, llm):
        """
        Initialize with LLM instance.
        
        Args:
            llm: LangChain ChatOpenAI instance
        """
        self.llm = llm
        self._setup_prompt()
    
    def _setup_prompt(self):
        """Setup the visualization generation prompt."""
        self.prompt = ChatPromptTemplate.from_messages([
            ("system",
             "You are an EXPERT Python Plotly visualization developer. Your job is to generate COMPLETE, "
             "EXECUTABLE Python code that creates appropriate Plotly visualizations.\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "CRITICAL RULES:\n"
             "═══════════════════════════════════════════════════════════\n"
             "1. Return ONLY executable Python code - NO markdown, NO explanations, NO comments before/after\n"
             "2. Code must be complete and ready to execute\n"
             "3. Assume 'df' variable already exists with the data\n"
             "4. Import statements should be included if needed (plotly.express as px, plotly.graph_objects as go)\n"
             "5. The code MUST create a variable called 'fig' containing the Plotly figure\n"
             "6. End with 'fig.show()' to display the visualization\n"
             "7. Choose the MOST APPROPRIATE chart type based on the data\n"
             "8. Use clear titles, labels, and formatting\n"
             "9. Add hover data for interactivity when relevant\n"
             "10. Use color schemes that are professional and accessible\n"
             "11. IMPORTANT: Use ONLY columns that exist in df - check the data summary!\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "CHART TYPE SELECTION GUIDE:\n"
             "═══════════════════════════════════════════════════════════\n"
             "- **Bar Chart**: Comparing categories (departments, job roles, etc.)\n"
             "- **Grouped/Stacked Bar**: Comparing categories with subcategories (gender by department)\n"
             "- **Line Chart**: Trends over time or ordered categories\n"
             "- **Scatter Plot**: Correlation between two numerical variables\n"
             "- **Box Plot**: Distribution and outliers in numerical data\n"
             "- **Heatmap**: Correlation matrix or 2D categorical data\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "COMMON PATTERNS:\n"
             "═══════════════════════════════════════════════════════════\n\n"
             
             "**Pattern 1: Simple Bar Chart**\n"
             "```python\n"
             "import plotly.express as px\n"
             "fig = px.bar(df, x='category_column', y='value_column', \n"
             "             title='Title Here',\n"
             "             labels={{'category_column': 'Label', 'value_column': 'Label'}})\n"
             "fig.update_layout(xaxis_tickangle=-45)\n"
             "fig.show()\n"
             "```\n\n"
             
             "**Pattern 2: Grouped Bar Chart**\n"
             "```python\n"
             "import plotly.express as px\n"
             "fig = px.bar(df, x='category', y='value', color='group',\n"
             "             title='Title Here', barmode='group')\n"
             "fig.show()\n"
             "```\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "BEST PRACTICES:\n"
             "═══════════════════════════════════════════════════════════\n"
             "- Always add meaningful titles and axis labels\n"
             "- Use texttemplate to show values on bars/points when appropriate\n"
             "- Sort data logically (by value, alphabetically, or naturally)\n"
             "- Use appropriate color scales (RdYlGn for good/bad, Blues for intensity)\n"
             "- Add hover_data for additional context\n"
             "- Format percentages with .2f or .1f\n"
             "- Rotate x-axis labels if they're long (xaxis_tickangle=-45)\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "DATA SUMMARY:\n"
             "═══════════════════════════════════════════════════════════\n"
             "{data_summary}\n\n"
             
             "═══════════════════════════════════════════════════════════\n"
             "DATA DICTIONARY CONTEXT:\n"
             "═══════════════════════════════════════════════════════════\n"
             "{data_dictionary}\n\n"
             
             "Remember: Return ONLY Python code. No markdown code blocks, no explanations.\n"),
            ("user",
             "Create a Plotly visualization for this data:\n\n"
             "Original Question: {original_question}\n\n"
             "Generate complete, executable Python code that creates the most appropriate visualization.\n"
             "The code must create a 'fig' variable and end with 'fig.show()'.")
        ])
        
        self.chain = self.prompt | self.llm | StrOutputParser()
    
    def _get_data_summary(self, df: pd.DataFrame) -> str:
        """
        Generate a summary of the DataFrame for the LLM.
        
        Args:
            df: pandas DataFrame
            
        Returns:
            String summary with head, info, and basic stats
        """
        # Capture df.info() output
        buffer = StringIO()
        df.info(buf=buffer)
        info_str = buffer.getvalue()
        
        # Build summary
        summary = f"""
DATAFRAME SHAPE: {df.shape[0]} rows × {df.shape[1]} columns

COLUMN NAMES AND TYPES:
{info_str}

FIRST 5 ROWS:
{df.head().to_string()}

BASIC STATISTICS:
{df.describe().to_string()}
"""
        return summary
    
    def _extract_code(self, text: str) -> str:
        """Extract Python code from the LLM response."""
        # Prefer the first fenced block; an optional "python"/"py" language tag
        # after the opening fence is dropped rather than kept as a line of code
        fenced = re.search(r"```[ \t]*(?:python|py)?[ \t]*\n?(.*?)```", text, re.IGNORECASE | re.DOTALL)
        if fenced:
            return fenced.group(1).strip()
        
        # Return as-is if no code blocks found
        return text.strip()
    
    def generate_code(self, df: pd.DataFrame, original_question: str = "") -> str:
        """
        Generate Plotly visualization code based on DataFrame.
        
        Args:
            df: pandas DataFrame with data to visualize
            original_question: The original user question (for context)
            
        Returns:
            Python code string
        """
        data_summary = self._get_data_summary(df)
        
        response = self.chain.invoke({
            "data_summary": data_summary,
            "data_dictionary": DATA_DICTIONARY,
            "original_question": original_question or "Visualize this data"
        })
        
        code = self._extract_code(response)
        return code
    
    def visualize(self, df: pd.DataFrame, original_question: str = "", verbose: bool = True):
        """
        Complete pipeline: DataFrame → Generate Code → Execute → Show Plot
        
        Args:
            df: pandas DataFrame with data
            original_question: Original user question for context
            verbose: If True, print generated code
            
        Returns:
            The Plotly figure object
        """
        try:
            # Special handling for single numeric values (1 row, 1 column):
            # render a gauge/indicator directly instead of asking the LLM
            if df.shape == (1, 1) and pd.api.types.is_number(df.iloc[0, 0]):
                value = float(df.iloc[0, 0])  # psycopg2 returns NUMERIC as Decimal
                column_name = str(df.columns[0])
                title = column_name.replace('_', ' ').title()

                # Determine if it's a percentage/rate (0-100 scale)
                is_percentage = 'rate' in column_name.lower() or 'percent' in column_name.lower()
                suffix = '%' if is_percentage else ''
                # Rates use a fixed 0-100 axis; other values get an axis sized to the value
                axis_max = 100 if is_percentage else max(value * 1.2, 1)
                # Green/yellow/red bands only make sense on the 0-100 rate scale
                steps = [
                    {'range': [0, 10], 'color': 'lightgreen'},
                    {'range': [10, 20], 'color': 'lightyellow'},
                    {'range': [20, 100], 'color': 'lightcoral'},
                ] if is_percentage else []

                # Build the chart code as a string so it can be printed like LLM output;
                # repr() keeps titles containing quotes and numeric values valid Python
                code = f"""import plotly.graph_objects as go

value = {value!r}
title_text = {title!r}

fig = go.Figure(go.Indicator(
    mode='number+gauge',
    value=value,
    title={{'text': title_text, 'font': {{'size': 24}}}},
    number={{'suffix': {suffix!r}, 'font': {{'size': 48}}}},
    gauge={{
        'axis': {{'range': [0, {axis_max!r}]}},
        'bar': {{'color': 'darkblue'}},
        'steps': {steps!r},
        'threshold': {{
            'line': {{'color': 'red', 'width': 4}},
            'thickness': 0.75,
            'value': value
        }}
    }}
))
fig.update_layout(height=400)
fig.show()
"""
            else:
                # Generate code using LLM for multi-row/multi-column data
                code = self.generate_code(df, original_question)
            
            if verbose:
                print("=" * 60)
                print("GENERATED PLOTLY CODE:")
                print("=" * 60)
                print(code)
                print("=" * 60)
                print()
            
            # Execute code
            # Create a namespace with df and necessary imports
            namespace = {
                'df': df,
                'px': px,
                'go': go,
                'pd': pd
            }
            
            exec(code, namespace)
            
            # Return the figure (should be stored in 'fig' variable)
            if 'fig' in namespace:
                if verbose:
                    print("✅ Visualization generated successfully!\n")
                return namespace['fig']
            else:
                raise ValueError("Generated code did not create a 'fig' variable")
                
        except Exception as e:
            print(f"❌ Error generating visualization: {str(e)}")
            raise

print("✅ PlotlyVisualizationGenerator class defined")

✅ PlotlyVisualizationGenerator class defined
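`_get_data_summary()` sends `info()`, the first rows and `describe()` to the LLM on every call. Two edge cases are worth guarding: `describe()` raises `ValueError` when a query returns no columns, and very wide results inflate the prompt. A hypothetical, size-capped variant could look like this (`safe_data_summary`, `max_cols` and `max_rows` are illustrative names, not part of the class); `include='all'` also summarizes text columns such as `department`.

```python
import pandas as pd
from io import StringIO


def safe_data_summary(df: pd.DataFrame, max_cols: int = 20, max_rows: int = 5) -> str:
    """Hypothetical, size-capped variant of _get_data_summary."""
    if df.shape[1] == 0:
        # describe() would raise ValueError on a column-less DataFrame
        return "DATAFRAME IS EMPTY: the query returned no columns"

    shown = df.iloc[:, :max_cols]  # cap width so the prompt stays small
    buffer = StringIO()
    shown.info(buf=buffer)

    parts = [
        f"DATAFRAME SHAPE: {df.shape[0]} rows × {df.shape[1]} columns",
        f"COLUMN NAMES AND TYPES (first {shown.shape[1]}):\n{buffer.getvalue()}",
        f"FIRST {max_rows} ROWS:\n{shown.head(max_rows).to_string()}",
    ]
    if len(df) > 0:
        # Statistics for numeric and text columns alike
        parts.append(f"BASIC STATISTICS:\n{shown.describe(include='all').to_string()}")
    return "\n\n".join(parts)
```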


In [48]:
# Initialize the Plotly Visualization Generator
viz_generator = PlotlyVisualizationGenerator(llm=llm)

print("✅ Plotly Visualization Generator initialized and ready!")

✅ Plotly Visualization Generator initialized and ready!


## 10. Complete Text-to-SQL + Visualization Pipeline

Now let's create a complete function that:
1. Takes a natural language question
2. Generates and executes SQL query
3. Automatically creates an appropriate visualization

In [41]:
def query_and_visualize(question: str, verbose: bool = True):
    """
    Complete pipeline: Natural Language → SQL → DataFrame → Visualization
    
    Args:
        question: Natural language question about HR data
        verbose: If True, print SQL and generated code
        
    Returns:
        tuple: (DataFrame, Plotly Figure)
    """
    print("🔍 Processing question:", question)
    print()
    
    # Step 1: Generate SQL and get DataFrame
    df = agent.query(question, verbose=verbose)
    
    print(f"📊 Retrieved {len(df)} rows with {len(df.columns)} columns")
    print()
    
    # Step 2: Generate and display visualization
    fig = viz_generator.visualize(df, original_question=question, verbose=verbose)
    
    return df, fig

print("✅ Complete pipeline function 'query_and_visualize()' ready!")

✅ Complete pipeline function 'query_and_visualize()' ready!


## 11. Test the Complete Pipeline

Let's test with various questions to see automatic visualizations:

In [52]:
# Example 1: Department-wise attrition rate
df1, fig1 = query_and_visualize("What is the attrition rate of male and female employees?")

🔍 Processing question: What is the attrition rate of male and female employees?

GENERATED SQL:
SELECT
    gender,
    COUNT(*) as total_employees,
    COUNT(CASE WHEN attrition = 'Yes' THEN 1 END) as left,
    ROUND((COUNT(CASE WHEN attrition = 'Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as attrition_rate
FROM employee_attrition
GROUP BY gender
ORDER BY gender

✅ Query executed successfully. Returned 2 rows.

📊 Retrieved 2 rows with 4 columns

GENERATED PLOTLY CODE:
import plotly.express as px

# Create a grouped bar chart for attrition rate by gender
fig = px.bar(df, x='gender', y='attrition_rat

✅ Visualization generated successfully!



In [44]:
# Example 2: Job satisfaction impact on attrition
df2, fig2 = query_and_visualize("How does job satisfaction impact attrition?")

🔍 Processing question: How does job satisfaction impact attrition?

GENERATED SQL:
SELECT
    jobsatisfaction,
    COUNT(*) as total_employees,
    COUNT(CASE WHEN attrition='Yes' THEN 1 END) as employees_left,
    ROUND((COUNT(CASE WHEN attrition='Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as attrition_rate
FROM employee_attrition
GROUP BY jobsatisfaction
ORDER BY jobsatisfaction

✅ Query executed successfully. Returned 4 rows.

📊 Retrieved 4 rows with 4 columns

GENERATED PLOTLY CODE:
import plotly.express as px

# Create scatter plot with jobsatisfaction on x-ax

✅ Visualization generated successfully!



In [None]:
# Example 3: Overtime impact on attrition
df3, fig3 = query_and_visualize("Compare attrition rates between overtime and non-overtime employees")

In [None]:
# Example 4: Gender pay analysis by department
df4, fig4 = query_and_visualize("Show average monthly income by department and gender")

In [None]:
# Example 5: Business travel and attrition
df5, fig5 = query_and_visualize("How does business travel frequency affect attrition rates?")

## 12. Try Your Own Question + Visualization

Use the cell below to ask your own question and get automatic visualization:

In [53]:
# Your custom question here
my_question = "what is the attrition rate for male and female employees?"
df_custom, fig_custom = query_and_visualize(my_question)

🔍 Processing question: what is the attrition rate for male and female employees?

GENERATED SQL:
SELECT
    gender,
    COUNT(*) as total_employees,
    COUNT(CASE WHEN attrition = 'Yes' THEN 1 END) as left,
    ROUND((COUNT(CASE WHEN attrition = 'Yes' THEN 1 END)::numeric / COUNT(*)::numeric) * 100, 2) as attrition_rate
FROM employee_attrition
GROUP BY gender
ORDER BY gender

✅ Query executed successfully. Returned 2 rows.

📊 Retrieved 2 rows with 4 columns

GENERATED PLOTLY CODE:
import plotly.express as px

# Create a grouped bar chart for attrition rate by gender
fig = px.bar(df, x='gender', y='attrition_ra

✅ Visualization generated successfully!



## 13. Advanced Usage - Manual Control

You can also use the components separately for more control:

In [None]:
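from IPython.display import display

# Step 1: Get data from SQL query
question = "Show me attrition rate by job role"
df_manual = agent.query(question)

print("Data retrieved:")
display(df_manual)

# Step 2: Generate visualization code (without executing)
viz_code = viz_generator.generate_code(df_manual, original_question=question)

print("\nGenerated visualization code:")
print("=" * 60)
print(viz_code)
print("=" * 60)

# Step 3: Execute the visualization manually.
# The generated code refers to `df`, so bind it to df_manual explicitly;
# a bare exec(viz_code) would plot whatever global `df` held last.
exec(viz_code, {"df": df_manual, "px": px, "go": go, "pd": pd})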

## 14. Summary & Features

### 🎯 What We Built

**Complete Text-to-SQL + Visualization Pipeline:**
1. **TextToSQLAgent**: Converts natural language → SQL → DataFrame
2. **PlotlyVisualizationGenerator**: Converts DataFrame → Python Code → Interactive Chart
3. **Integrated Pipeline**: One function call gets you data + visualization

### ✨ Key Features

**Text-to-SQL Agent:**
- Natural language to PostgreSQL queries
- Built-in HR data dictionary and KPI context
- SQL safety validation (SELECT-only)
- Smart handling of percentages and rates
- Proper GROUP BY for comparisons

**Visualization Generator:**
- Automatic chart type selection (bar, grouped bar, line, scatter, etc.), with a gauge indicator for single-value results that skips the LLM call
- Context-aware based on data structure
- Professional styling and formatting
- Interactive Plotly charts
- Executable Python code generation

**Security:**
- No INSERT/UPDATE/DELETE allowed
- SQL injection protection
- Validated query execution
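The security checks listed here cover the SQL side. `visualize()` passes the LLM's Plotly code straight to `exec()` in the notebook process, so a lightweight static screen before that call can reject obvious misuse, such as imports outside Plotly/pandas/NumPy or calls to `open`/`eval`. This is a hypothetical sketch (`check_generated_code`, `ALLOWED_IMPORTS` and `BLOCKED_NAMES` are illustrative), and an allowlist/denylist screen like this is not a sandbox; running the code in a separate, restricted process is the more robust option.

```python
import ast

# Hypothetical allowlist: top-level modules generated chart code may import
ALLOWED_IMPORTS = {"plotly", "pandas", "numpy"}
# Hypothetical denylist: names chart code has no reason to touch
BLOCKED_NAMES = {
    "exec", "eval", "compile", "open", "input", "breakpoint", "__import__",
    "__builtins__", "globals", "locals", "vars", "getattr", "setattr", "delattr",
}


def check_generated_code(code: str) -> list:
    """Return a list of problems; an empty list means the basic checks passed.

    ast.parse raises SyntaxError for malformed (e.g. truncated) LLM output.
    """
    problems = []
    for node in ast.walk(ast.parse(code)):
        # Collect module names from `import x` and `from x import y`
        if isinstance(node, ast.Import):
            modules = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            modules = [node.module or ""]  # relative imports have module=None
        else:
            modules = []
        for module in modules:
            if module.split(".")[0] not in ALLOWED_IMPORTS:
                problems.append(f"import of '{module}' is not allowed")
        # Bare names such as open(...) or eval(...)
        if isinstance(node, ast.Name) and node.id in BLOCKED_NAMES:
            problems.append(f"use of '{node.id}' is not allowed")
        # Dunder attributes such as ().__class__ are a common escape route
        if isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            problems.append(f"dunder attribute '{node.attr}' is not allowed")
    return problems
```

Inside `visualize()`, the check would sit just before `exec(code, namespace)`: if `check_generated_code(code)` returns a non-empty list, raise `ValueError` with those messages instead of executing.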

### 📊 Usage Patterns

```python
# Simple: Question → Data + Chart
df, fig = query_and_visualize("Your question here")

# SQL Only: Question → Data
df = agent.query("Your question here")

# Visualization Only: Data → Chart
fig = viz_generator.visualize(df, "context about the data")

# Just generate code (don't execute)
code = viz_generator.generate_code(df, "context")
```

### 🚀 Next Steps

- Test with your own questions
- Modify visualization styles in generated code
- Extend to other databases or datasets
- Add more chart types to the prompt
- Export visualizations as images or HTML