In [1]:
pip install openai



In [2]:
import re
from datetime import datetime
from typing import Dict, List
from pprint import pprint
from openai import OpenAI
import os
import json

In [3]:
from google.colab import userdata
import os
from openai import OpenAI

# Access the API key from Colab's Secrets Manager
client = OpenAI(
  api_key=userdata.get('OPENAI_API_KEY')
)

In [4]:
SYSTEM_PROMPT = """
You are a query planner assistant. Your job is to convert natural language questions into structured query plans.

Use the following JSON format:
{
  "intent": "ranking" | "trend" | "filter" | "comparison",
  "table": "sales_data",
  "filters": [ { "field": ..., "operator": ..., "value": ... } ],
  "metrics": [ { "name": ..., "aggregation": ..., "alias": ... } ],
  "group_by": [...],
  "sort": [ { "field": ..., "order": ... } ],
  "limit": ...,
  "original_user_query": ...
}
Only return valid JSON.
"""

In [5]:
def get_query_plan_with_llm(user_query: str) -> dict:
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",  # or gpt-3.5-turbo
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_query}
        ],
        temperature=0.2
    )
    content = response.choices[0].message.content.strip()

    try:
        plan = json.loads(content)
        return plan
    except json.JSONDecodeError:
        print("❌ Failed to parse LLM output as JSON.")
        print(content)
        return {}

In [6]:
user_query = "Show me the last 15 products with highest sales in March 2024"

In [7]:
plan = get_query_plan_with_llm(user_query)

from pprint import pprint
pprint(plan)

{'filters': [{'field': 'date',
              'operator': 'between',
              'value': ['2024-03-01', '2024-03-31']}],
 'group_by': ['product'],
 'intent': 'ranking',
 'limit': 15,
 'metrics': [{'aggregation': 'sum', 'alias': 'total_sales', 'name': 'product'}],
 'original_user_query': 'Show me the last 15 products with highest sales in '
                        'March 2024',
 'sort': [{'field': 'total_sales', 'order': 'desc'}],
 'table': 'sales_data'}


In [8]:
 def extract_query_info(prompt: str) -> Dict:
    # Simple intent rules
    intent = "ranking" if "top" in prompt.lower() else "unknown"

    # Extract limit
    match_limit = re.search(r"top (\d+)", prompt.lower())
    limit = int(match_limit.group(1)) if match_limit else None

    # Extract time range (very simple version)
    if "q2 2024" in prompt.lower():
        start_date = "2024-04-01"
        end_date = "2024-06-30"
    else:
        start_date, end_date = None, None

    return {
        "intent": intent,
        "limit": limit,
        "date_range": [start_date, end_date],
        "metrics": ["sales_volume"],
        "group_by": ["product_id"],
        "table": "sales_data"
    }

parsed_info = extract_query_info(user_query)
pprint(parsed_info)

{'date_range': [None, None],
 'group_by': ['product_id'],
 'intent': 'unknown',
 'limit': None,
 'metrics': ['sales_volume'],
 'table': 'sales_data'}


In [9]:
query_plan = {
    "original_user_query": user_query,
    "intent": parsed_info["intent"],
    "filters": [
        {
            "field": "sale_date",
            "operator": "between",
            "value": parsed_info["date_range"]
        }
    ],
    "metrics": [
        {
            "name": "sales_volume",
            "aggregation": "sum",
            "alias": "total_sales"
        }
    ],
    "group_by": parsed_info["group_by"],
    "sort": [
        {
            "field": "total_sales",
            "order": "desc"
        }
    ],
    "limit": parsed_info["limit"],
    "table": parsed_info["table"]
}

pprint(query_plan)


{'filters': [{'field': 'sale_date',
              'operator': 'between',
              'value': [None, None]}],
 'group_by': ['product_id'],
 'intent': 'unknown',
 'limit': None,
 'metrics': [{'aggregation': 'sum',
              'alias': 'total_sales',
              'name': 'sales_volume'}],
 'original_user_query': 'Show me the last 15 products with highest sales in '
                        'March 2024',
 'sort': [{'field': 'total_sales', 'order': 'desc'}],
 'table': 'sales_data'}


In [10]:
def build_sql(plan: Dict) -> str:
    metric = plan["metrics"][0]
    date_start, date_end = plan["filters"][0]["value"]
    sql = f"""
    SELECT {plan['group_by'][0]}, SUM({metric['name']}) AS {metric['alias']}
    FROM {plan['table']}
    WHERE sale_date BETWEEN '{date_start}' AND '{date_end}'
    GROUP BY {plan['group_by'][0]}
    ORDER BY {metric['alias']} DESC
    LIMIT {plan['limit']};
    """.strip()
    return sql

print(build_sql(query_plan))


SELECT product_id, SUM(sales_volume) AS total_sales
    FROM sales_data
    WHERE sale_date BETWEEN 'None' AND 'None'
    GROUP BY product_id
    ORDER BY total_sales DESC
    LIMIT None;
