<h1>DB Connection</h1>

In [77]:
import sqlite3
from langchain.sql_database import SQLDatabase


# Connect to the database
db_name = "retail_sales.db"

db = SQLDatabase.from_uri(
    f"sqlite:///{db_name}",
    #sample_rows_in_table_info=5,
)

table_info = db.get_table_info()
print("Table Info:", table_info)

Table Info: 
CREATE TABLE sales_tb (
	"TransactionID" INTEGER, 
	"Date" DATE, 
	"CustomerID" VARCHAR(10), 
	"Gender" VARCHAR(10), 
	"Age" INTEGER, 
	"ProductCategory" VARCHAR(50), 
	"Quantity" INTEGER, 
	"PriceperUnit" DECIMAL(10, 2), 
	"TotalAmount" DECIMAL(10, 2)
)

/*
3 rows from sales_tb table:
TransactionID	Date	CustomerID	Gender	Age	ProductCategory	Quantity	PriceperUnit	TotalAmount
1	2023-11-24	CUST001	Male	34	Beauty	3	50.00	150.00
2	2023-02-27	CUST002	Female	26	Clothing	2	500.00	1000.00
3	2023-01-13	CUST003	Male	50	Electronics	1	30.00	30.00
*/


<h1>LLM Connection</h1>

In [54]:
from langchain_community.chat_models import ChatOllama

llm = ChatOllama(
    model="phi3",
    base_url="http://localhost:11434"
)

<h1>Question to SQL query </h1>

In [58]:
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many customers are there"})
response

'Question: How many customers are there?\nSQLQuery: SELECT COUNT(DISTINCT "CustomerID") FROM sales_tb;'

In [65]:
#clean the query
import re

match = re.search(r'SQLQuery: (.*)', response)
if match:
    sql_query = match.group(1)
    print(sql_query)

SELECT COUNT(DISTINCT "CustomerID") FROM sales_tb;


In [69]:
# Execute the cleaned query
result = db.run(sql_query)
print(result)

[(29,)]


In [56]:
#Configure LangChain to Return Only SQL

chain = create_sql_query_chain(llm, db)
chain.get_prompts()[0].template = """You are a helpfull assistant, given an input question, only create a syntactically correct SQLite query to run. Do not write unnesesary things

Question: {input}
SQLQuery:"""

response = chain.invoke({"question": "How many unique customers are there for each product category"})
print(response)

SELECT ProductCategory, COUNT(DISTINCT CustomerID) AS UniqueCustomerCount
FROM Sales
GROUP BY ProductCategory;


In [71]:
# Execute the cleaned query
result = db.run(response)
print(result)

<h1>Automate the Process</h1>

In [103]:
from langchain.chains import create_sql_query_chain
from langchain.prompts import PromptTemplate

# Step 1: Create enhanced prompt
enhanced_prompt = PromptTemplate(
    input_variables=["input", "table_info", "top_k"],
    template="""You are a SQLite expert. Given an input question, create a syntactically correct SQLite query.

IMPORTANT RULES:
- Return ONLY the SQL query
- No explanations, no answers, no additional text
- Use proper SQLite syntax
- Use double quotes for column names with special characters
- Always end with semicolon
- If applicable, limit results to top {top_k} rows

Table Schema:
{table_info}

Question: {input}

"""
)

# Step 2: Create the chain
chain = create_sql_query_chain(llm, db, prompt=enhanced_prompt)

# Step 3: Enhanced execution function
def execute_query(question):
    try:
        response = chain.invoke({"question": question})
        print("###################################################")
        print(f"LLM Response: {response}")
        print("###################################################")
        
        # Simple cleaning - just remove outer quotes
        cleaned_query = response.strip().strip("'\"")
        print("###################################################")
        print(f"Cleaned Query: {cleaned_query}")
        print("###################################################")        
        
        result = db.run(cleaned_query)
        print("###################################################")        
        print(f"Result: {result}")
        print("###################################################")
        return result
        print("###################################################")
        
    except Exception as e:
        print(f"An error occurred: {e}")
        return None

def clean_and_validate_query(response):
    """Clean and validate SQL query from LLM response"""
    
    # Remove any wrapper text and extract SQL
    response = response.strip()
    
    # Try multiple extraction methods
    patterns = [
        r'SQLQuery:\s*(.*?)(?:\n|$)',
        r'Query:\s*(.*?)(?:\n|$)',
        r'(SELECT.*?;)',
        r'(INSERT.*?;)',
        r'(UPDATE.*?;)',
        r'(DELETE.*?;)'
    ]
    
    extracted_query = None
    
    for pattern in patterns:
        match = re.search(pattern, response, re.IGNORECASE | re.DOTALL)
        if match:
            extracted_query = match.group(1).strip()
            break
    
    # If no pattern matched, assume the entire response is the query
    if not extracted_query:
        extracted_query = response
    
    # Clean the query
    extracted_query = extracted_query.strip()
    extracted_query = re.sub(r'\s+', ' ', extracted_query)
    
    # Ensure proper semicolon
    if not extracted_query.endswith(';'):
        extracted_query += ';'
    
    # Basic validation
    if validate_sql_query(extracted_query):
        return extracted_query
    else:
        print(f"Invalid query detected: {extracted_query}")
        return None

# Test the improved function
execute_query("How many unique customers are there for each product category?")

###################################################
LLM Response: SELECT "ProductCategory", COUNT(DISTINCT "CustomerID") AS UniqueCustomers
FROM sales_tb
GROUP BY "ProductCategory";
###################################################
###################################################
Cleaned Query: SELECT "ProductCategory", COUNT(DISTINCT "CustomerID") AS UniqueCustomers
FROM sales_tb
GROUP BY "ProductCategory";
###################################################
###################################################
Result: [('Beauty', 8), ('Clothing', 13), ('Electronics', 8)]
###################################################


"[('Beauty', 8), ('Clothing', 13), ('Electronics', 8)]"

In [89]:
q1 = "How many customers are there"
execute_query(q1)

LLM Response: SELECT COUNT(DISTINCT "CustomerID") FROM sales_tb;
###################################################
Cleaned Query: SELECT COUNT(DISTINCT "CustomerID") FROM sales_tb;
###################################################
###################################################
Result: [(29,)]


'[(29,)]'

In [91]:
q2 = "Calculate total sales amount per product category:"
execute_query(q2)

LLM Response: SELECT "ProductCategory", SUM("TotalAmount") AS TotalSalesPerCategory FROM sales_tb GROUP BY "ProductCategory";
###################################################
Cleaned Query: SELECT "ProductCategory", SUM("TotalAmount") AS TotalSalesPerCategory FROM sales_tb GROUP BY "ProductCategory";
###################################################
###################################################
Result: [('Beauty', 1455), ('Clothing', 5040), ('Electronics', 5310)]


"[('Beauty', 1455), ('Clothing', 5040), ('Electronics', 5310)]"

In [93]:
q3 = "calculates the average age of customers grouped by gender."
execute_query(q3)

LLM Response: ```sql
SELECT "Gender", AVG("Age") as AverageAge FROM sales_tb GROUP BY "Gender";
```
###################################################
Cleaned Query: ```sql
SELECT "Gender", AVG("Age") as AverageAge FROM sales_tb GROUP BY "Gender";
```
###################################################
An error occurred: (sqlite3.OperationalError) near "```sql
SELECT "Gender", AVG("Age") as AverageAge FROM sales_tb GROUP BY "Gender";
```": syntax error
[SQL: ```sql
SELECT "Gender", AVG("Age") as AverageAge FROM sales_tb GROUP BY "Gender";
```]
(Background on this error at: https://sqlalche.me/e/20/e3q8)


In [94]:
q4 = "identify the top spending customers based on their total amount spent."
execute_query(q4)

LLM Response: SELECT "CustomerID", SUM("TotalAmount") AS TotalSpent
FROM sales_tb
GROUP BY "CustomerID"
ORDER BY TotalSpent DESC
LIMIT 5;
###################################################
Cleaned Query: SELECT "CustomerID", SUM("TotalAmount") AS TotalSpent
FROM sales_tb
GROUP BY "CustomerID"
ORDER BY TotalSpent DESC
LIMIT 5;
###################################################
###################################################
Result: [('CUST015', 2000), ('CUST016', 1500), ('CUST013', 1500), ('CUST026', 1000), ('CUST002', 1000)]


"[('CUST015', 2000), ('CUST016', 1500), ('CUST013', 1500), ('CUST026', 1000), ('CUST002', 1000)]"

In [95]:
q5 = "counts the number of transactions made each month."
execute_query(q5)

LLM Response: SELECT strftime('%Y-%m', "Date") AS Month, COUNT(*) AS TransactionCount FROM sales_tb GROUP BY Month ORDER BY Month;
###################################################
Cleaned Query: SELECT strftime('%Y-%m', "Date") AS Month, COUNT(*) AS TransactionCount FROM sales_tb GROUP BY Month ORDER BY Month;
###################################################
###################################################
Result: [('2023-01', 4), ('2023-02', 4), ('2023-03', 1), ('2023-04', 5), ('2023-05', 2), ('2023-08', 3), ('2023-09', 1), ('2023-10', 4), ('2023-11', 3), ('2023-12', 2)]


"[('2023-01', 4), ('2023-02', 4), ('2023-03', 1), ('2023-04', 5), ('2023-05', 2), ('2023-08', 3), ('2023-09', 1), ('2023-10', 4), ('2023-11', 3), ('2023-12', 2)]"

In [96]:
q6 = "calculates the total sales amount and average price per unit for each product category."
execute_query(q6)

LLM Response: ```sql
SELECT "ProductCategory", SUM("TotalAmount") AS TotalSales, AVG("PriceperUnit") AS AveragePricePerUnit 
FROM sales_tb 
GROUP BY "ProductCategory";
```
###################################################
Cleaned Query: ```sql
SELECT "ProductCategory", SUM("TotalAmount") AS TotalSales, AVG("PriceperUnit") AS AveragePricePerUnit 
FROM sales_tb 
GROUP BY "ProductCategory";
```
###################################################
An error occurred: (sqlite3.OperationalError) near "```sql
SELECT "ProductCategory", SUM("TotalAmount") AS TotalSales, AVG("PriceperUnit") AS AveragePricePerUnit 
FROM sales_tb 
GROUP BY "ProductCategory";
```": syntax error
[SQL: ```sql
SELECT "ProductCategory", SUM("TotalAmount") AS TotalSales, AVG("PriceperUnit") AS AveragePricePerUnit 
FROM sales_tb 
GROUP BY "ProductCategory";
```]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
