In [3]:
import os
import re
import time
import json
import jellyfish
import openai
from operator import itemgetter
from typing import Optional, Union, Dict, Any, List

from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough, RunnableLambda

import clickhouse_connect

load_dotenv()

True

Functions for Extracting Valid Values from the Database

In [4]:
def query_as_list(db, query):
    """ Execute a SQL query and return a list of cleaned unique elements. """
    res = db.run(query)
    res = [el[0] for el in res if el and el[0]]  # Extract first element from tuple
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]  # Remove numeric values
    return list(set(res))

In [5]:
def get_valid_values(db):
    """Extract valid values for brand, l1_category, and l3_category from the database."""
    valid_brand = []
    valid_brand += query_as_list(db, "SELECT DISTINCT brand FROM product_view")
    valid_brand += query_as_list(db, "SELECT DISTINCT brand FROM product_wishlist_add")
    valid_brand += query_as_list(db, "SELECT DISTINCT brand FROM product_wishlist_remove")
    valid_brand += query_as_list(db, "SELECT DISTINCT brand FROM add_to_cart")
    valid_brand += query_as_list(db, "SELECT DISTINCT brand FROM remove_from_cart")
    valid_brand = list(set(valid_brand))
    
    valid_l1 = query_as_list(db, "SELECT DISTINCT l1_category FROM product_view")
    valid_l1 = list(set(valid_l1))
    
    valid_l3 = []
    valid_l3 += query_as_list(db, "SELECT DISTINCT l3_category FROM product_view")
    valid_l3 += query_as_list(db, "SELECT DISTINCT l3_category FROM product_wishlist_add")
    valid_l3 += query_as_list(db, "SELECT DISTINCT l3_category FROM product_wishlist_remove")
    valid_l3 += query_as_list(db, "SELECT DISTINCT l3_category FROM add_to_cart")
    valid_l3 += query_as_list(db, "SELECT DISTINCT l3_category FROM remove_from_cart")
    valid_l3 = list(set(valid_l3))
    
    return {
        "brand": valid_brand,
        "l1_category": valid_l1,
        "l3_category": valid_l3,
    }

Fuzzy Category Mapping and Correction

In [6]:
def get_best_correction(candidate, valid_values, threshold=0.85):
    """
    Returns a tuple (best_match, best_score) for the candidate compared to a list of valid values.
    """
    best_match = candidate
    best_score = 0
    for val in valid_values:
        score = jellyfish.jaro_winkler_similarity(candidate.lower(), val.lower())
        if score > best_score:
            best_score = score
            best_match = val
    return best_match, best_score

def process_l1_l3_candidates(extracted_l1, extracted_l3, valid_l1, valid_l3, threshold=0.85):
    """
    Processes ambiguous candidates for l1 and l3.
    If one list is empty, duplicate its candidate(s) from the other.
    For each candidate in the union, compute its best correction against both valid lists
    and choose the correction with the higher similarity score.
    Returns a mapping: candidate -> {"value": final_correction, "column": "L1_category" or "L3_category", "score": <score>}.
    """
    if not extracted_l1 and extracted_l3:
        extracted_l1 = extracted_l3.copy()
    if not extracted_l3 and extracted_l1:
        extracted_l3 = extracted_l1.copy()

    final_corrections = {}
    all_candidates = set(extracted_l1).union(set(extracted_l3))
    for candidate in all_candidates:
        corr_l1, score_l1 = get_best_correction(candidate, valid_l1, threshold)
        corr_l3, score_l3 = get_best_correction(candidate, valid_l3, threshold)
        if score_l1 >= score_l3:
            final_corrections[candidate] = {"value": corr_l1, "column": "L1_category", "score": score_l1}
        else:
            final_corrections[candidate] = {"value": corr_l3, "column": "L3_category", "score": score_l3}
    return final_corrections

def process_extracted_categories(extracted_dict, combined_valid_values, threshold=0.85):
    """
    Processes extracted category values.
    For 'brand', uses normal fuzzy matching.
    For ambiguous l1_category and l3_category values, returns a mapping (candidate -> correction mapping).
    The result is a dictionary with keys "brand" and "l1_l3".
    """
    corrected = {}
    # Process brand normally.
    if "brand" in extracted_dict:
        corrected_brand = []
        for val in extracted_dict["brand"]:
            corr, _ = get_best_correction(val, combined_valid_values.get("brand", []), threshold)
            corrected_brand.append(corr)
        corrected["brand"] = corrected_brand

    # Process ambiguous l1 and l3 candidates.
    extracted_l1 = extracted_dict.get("l1_category", [])
    extracted_l3 = extracted_dict.get("l3_category", [])
    if not extracted_l1 and extracted_l3:
        extracted_l1 = extracted_l3.copy()
    if not extracted_l3 and extracted_l1:
        extracted_l3 = extracted_l1.copy()

    valid_l1 = combined_valid_values.get("l1_category", [])
    valid_l3 = combined_valid_values.get("l3_category", [])
    
    l1l3_mapping = process_l1_l3_candidates(extracted_l1, extracted_l3, valid_l1, valid_l3, threshold)
    corrected["l1_l3"] = l1l3_mapping
    return corrected

def fuzzy_prompt_replace(prompt, extracted_value, corrected_value, threshold=0.8):
    """
    Searches the prompt for an n-gram (n = number of words in extracted_value)
    that is similar to extracted_value using Jaro–Winkler similarity.
    If a match above threshold is found, replaces that substring with corrected_value.
    """
    tokens = prompt.split()
    n = len(extracted_value.split())
    best_index = None
    best_score = 0
    best_substring = None
    for i in range(len(tokens) - n + 1):
        candidate = " ".join(tokens[i:i+n])
        score = jellyfish.jaro_winkler_similarity(candidate.lower(), extracted_value.lower())
        if score > best_score:
            best_score = score
            best_index = i
            best_substring = candidate
    if best_score >= threshold and best_substring:
        new_tokens = tokens[:best_index] + [corrected_value] + tokens[best_index+n:]
        return " ".join(new_tokens)
    else:
        return prompt

def extract_categories_from_prompt(prompt):
    """
    Uses LangChain's ChatOpenAI to extract categorical values (brand, l1_category, l3_category)
    from the user prompt. The output is expected to be a JSON dictionary.
    """
    system_message = (
        "You are a helpful assistant that extracts categorical values from user prompts. "
        "Given a prompt, output a JSON object with keys 'brand', 'l1_category', and 'l3_category'. "
        "Each key should map to a list of values found in the prompt. If a category is not mentioned, "
        "output an empty list for that key. For example, if the prompt is: "
        "'get all the users who wishlied shirts from catawalk brand products and added to cart but did not checkout', "
        "then you should output: {\"brand\": [\"Catwalk\"], \"l1_category\": [\"Shirts\"], \"l3_category\": []}. "
        "Output only a valid JSON object with no extra text. "
        "make sure ```json  ``` does not come in the response "
    )
    user_message = f"Extract categories from the following prompt in JSON format:\n\n{prompt}"
    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    response = llm.invoke(f"{system_message}\n\n{user_message}")
    try:
        extracted = response.content.strip()
        extracted_dict = json.loads(extracted)
    except Exception as e:
        print("Error extracting categories from prompt:", e)
        print("Raw response:", response.content if hasattr(response, "content") else response)
        extracted_dict = {}
    return extracted_dict

def correct_prompt(original_prompt, extracted_dict, corrected_mapping, threshold=0.8):
    """
    Updates the original prompt using fuzzy replacement.
    For 'brand' and for each candidate in the ambiguous l1/l3 mapping,
    finds the best matching n-gram in the prompt and replaces it with the final corrected value,
    appending the column name in parentheses.
    """
    corrected_prompt = original_prompt
    # Process brand replacement.
    for orig, corr in zip(extracted_dict.get("brand", []), corrected_mapping.get("brand", [])):
        replacement = f"{corr} (brand)"
        corrected_prompt = fuzzy_prompt_replace(corrected_prompt, orig, replacement, threshold)
    # Process ambiguous l1/l3 candidates.
    for candidate, mapping in corrected_mapping.get("l1_l3", {}).items():
        replacement = f"{mapping['value']} ({mapping['column']})"
        corrected_prompt = fuzzy_prompt_replace(corrected_prompt, candidate, replacement, threshold)
    return corrected_prompt

Clickhouse COnnection

In [None]:
class ClickHouseSQLDatabase:
    def __init__(self):
        """Initialize ClickHouse connection."""
        self.host = os.getenv("host")
        self.port = int(os.getenv("port", 8123))  # Default ClickHouse HTTP port
        self.user = os.getenv("user", "default")
        self.password = os.getenv("password", "")
        self.database = os.getenv("database", "default")
        os.environ['OPENAI_API_KEY'] = os.getenv("OPEN_AI_API_KEY")

        # Create ClickHouse client
        self.client = clickhouse_connect.get_client(
            host=self.host, port=self.port, username=self.user, password=self.password, database=self.database
        )

        # Load LLM for query correction
        self.llm = ChatOpenAI(model="gpt-4o", temperature=0, top_p=0.1)

    def get_usable_table_names(self):
        """Fetches the list of tables in the ClickHouse database."""
        result = self.client.query("SHOW TABLES")
        return [row[0] for row in result.result_set]

    def get_table_info(self, table_names=None):
        """
        Fetches schema details for the specified tables.
        If no table_names are provided, fetches schema for all tables.
        """
        if table_names is None:
            table_names = self.get_usable_table_names()

        table_info = ""
        for table in table_names:
            query = f"""
            SELECT name, type 
            FROM system.columns 
            WHERE database = '{self.database}' AND table = '{table}'
            """
            result = self.client.query(query)
            columns = [f"{row[0]} ({row[1]})" for row in result.result_set]
            table_info += f"Table: {table}\nColumns: {', '.join(columns)}\n\n"

        return table_info.strip()

    def run(self, query: str, retries: int = 3, fetch="all", include_columns=False):
        """
        Executes a SQL query with up to 3 retries using LLM for query correction.

        Args:
            query (str): The SQL query to execute.
            retries (int): Number of retries allowed for query correction.
            fetch (str): "all" for all rows, "one" for a single row, "many" for first 10 rows.
            include_columns (bool): If True, return result as list of dicts with column names.

        Returns:
            list | str: Query results or error message.
        """
        for attempt in range(retries):
            try:
                # Execute SQL query
                result = self.client.query(query)

                # Format the output
                if fetch == "one":
                    output = result.result_set[0] if result.result_set else None
                elif fetch == "many":
                    output = result.result_set[:10]
                else:
                    output = result.result_set  # Default: all rows

                # Include column names if required
                if include_columns:
                    columns = [col[0] for col in result.columns]
                    output = [dict(zip(columns, row)) for row in output]

                return output  # ✅ Successfully executed

            except Exception as e:
                error_message = str(e)
                print(f"⚠️ Query failed on attempt {attempt + 1}: {error_message}")

                if attempt < retries - 1:
                    # Regenerate query with LLM using the same system prompt
                    query = self.regenerate_query_with_llm(query, error_message)
                    print(f"🔄 Retrying with corrected query:\n{query}\n")
                    time.sleep(1)  # Small delay before retrying
                else:
                    print("❌ Max retries reached. Returning error message.")
                    return f"Query failed after {retries} attempts: {error_message}"

    def regenerate_query_with_llm(self, query, error_message):
        """
        Uses the same system prompt to correct SQL queries based on errors.
        Now instructs the LLM to respect column data types when generating comparisons.

        Note: Ensure that string columns are enclosed in quotes while numeric columns are not.
        """
        system_prompt = f"""
        You are a SQL expert. Given an input question or an incorrect SQL query, generate a syntactically 
        correct SQL query that runs successfully on ClickHouse. 

        **Instructions:**
        - If an error message is provided, analyze it and correct the query accordingly.
        - You will just return the **Clickhouse executable** SQL query and make sure not to add any markdown formatting.
        - Do NOT remove or alter the filter: application_id = '{os.getenv("application_id")}'.
        - Ensure all table and column names exist in the provided schema.
        - When filtering or comparing values, pay close attention to the column data types:
            - Enclose values in single quotes if the column type is a string.
            - Do not use quotes if the column type is numeric.
        - If filtering or comparing values from an array column, use `arraySum()` or `arrayJoin()`.
        - Do not return explanations; only return the corrected SQL query.

        **Database Schema:**
        {self.get_table_info()}

        **Error Message:**
        {error_message}

        **Incorrect Query:**
        {query}

        Now, generate a correct SQL query.
        """

        corrected_query = self.llm.invoke(system_prompt)
        return corrected_query.content.strip() # ✅ Returns corrected SQL query

    @property
    def dialect(self):
        """Returns a string representing the database dialect."""
        return "clickhouse"

Initiate CLickhouse instance

In [8]:
db = ClickHouseSQLDatabase()

Get the names of the brands, L1 and L3 categories

In [9]:
combined_valid_values = get_valid_values(db)

In [10]:
# Extract proper nouns from relevant tables
proper_nouns = []
proper_nouns += query_as_list(db, "SELECT DISTINCT brand FROM product_view")
proper_nouns += query_as_list(db, "SELECT DISTINCT l1_category FROM product_view")
proper_nouns += query_as_list(db, "SELECT DISTINCT l3_category FROM product_view")
proper_nouns += query_as_list(db, "SELECT DISTINCT brand FROM product_wishlist_add")
proper_nouns += query_as_list(db, "SELECT DISTINCT l3_category FROM product_wishlist_add")
proper_nouns += query_as_list(db, "SELECT DISTINCT brand FROM product_wishlist_remove")
proper_nouns += query_as_list(db, "SELECT DISTINCT l3_category FROM product_wishlist_remove")
proper_nouns += query_as_list(db, "SELECT DISTINCT brand FROM add_to_cart")
proper_nouns += query_as_list(db, "SELECT DISTINCT l3_category FROM add_to_cart")
proper_nouns += query_as_list(db, "SELECT DISTINCT brand FROM remove_from_cart")
proper_nouns += query_as_list(db, "SELECT DISTINCT l3_category FROM remove_from_cart")

proper_nouns = list(set(proper_nouns))
print(f"Extracted {len(proper_nouns)} unique proper nouns")

Extracted 1731 unique proper nouns


Query Correction

In [11]:
sql_agent_llm = ChatOpenAI(model="gpt-4o", temperature=0)

def correct_query_using_llm(query, proper_nouns):
    """
    Uses an LLM to correct misspelled words in a query based on known database terms.
    """
    system_prompt = f"""
    You are an expert in natural language correction for SQL queries. Your job is to correct any misspelled words 
    in the given question while preserving its original meaning. 

    - Only correct words if they appear to be misspellings of a known brand, product, or category.
    - Do NOT change common words like "how," "many," "added," "to," etc.
    - Ensure that the corrected query is as close as possible to the original.
    - Use the provided list of valid words to make corrections.

    List of valid brand names and categories: {proper_nouns}
    """

    corrected_query = sql_agent_llm.invoke(f"{system_prompt}\n\nQuery: {query}\nCorrected Query:")
    return corrected_query.content.strip()

Enforcing Application ID

In [12]:
def enforce_application_id_filter(sql_query: str, application_id: str) -> str:
    """
    Ensures that the SQL query includes a filter for `application_id`.
    """
    application_filter = f"application_id = '{application_id}'"

    # Check if WHERE clause exists
    if re.search(r"\bWHERE\b", sql_query, re.IGNORECASE):
        # Append to existing WHERE clause
        sql_query = re.sub(r"(\bWHERE\b)", r"\1 " + application_filter + " AND", sql_query, flags=re.IGNORECASE)
    else:
        # Add new WHERE clause
        sql_query = re.sub(r"(\bFROM\b\s+\w+)", r"\1 WHERE " + application_filter, sql_query, flags=re.IGNORECASE)

    return sql_query

Define the SQL Chain

In [13]:
def create_sql_query_chain(
    llm: ChatOpenAI,
    db: ClickHouseSQLDatabase,
    prompt: Optional[BasePromptTemplate] = None,
    k: int = 5,
) -> Runnable:
    """
    Creates a LangChain runnable that generates SQL queries, enforces application_id filtering,
    and applies retries for errors.
    """
    if prompt is None:
        raise ValueError("A valid SQL generation prompt must be provided.")

    application_id = os.getenv("application_id")
    if not application_id:
        raise ValueError("Missing application_id. Ensure it is set in the environment variables.")

    inputs = {
        "input": lambda x: x["question"] + "\nSQLQuery: ",
        "table_info": lambda x: db.get_table_info(table_names=x.get("table_names_to_use")),
        "application_id": lambda _: application_id,
    }

    return (
        RunnablePassthrough.assign(**inputs)
        | prompt.partial(top_k=str(k))
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
        | (lambda query: enforce_application_id_filter(query, application_id))
    )

In [None]:
system_prompt = """You are a SQL expert. Given an input question, create a syntactically \
correct CLICKHOUSE executable SQL query to run. Unless otherwise specified, do not return more than \
{top_k} rows.

While making the query take into consideration the proper query given by = {corrected_query}

You can join relevant tables to get the best query possible.
Here is the relevant table info: {table_info}

Ensure that every query includes the filter `application_id = '{application_id}'`.

Follow the below instructions:

1) You will just return the SQL query and make sure not to add ```sql before or after the query.
2) Make sure to generate a **Clickhouse executable** SQL query
3) If filtering or comparing values from an array column, use `arraySum()` or `arrayJoin()`.
4) When comparing or filtering values, use the column type information from the schema:
   - Enclose string values in single quotes.
   - Do not enclose numeric values in quotes.
5) In the user prompt you are given which column to consider in the brackets of each values, for example Loungewear (L3_category) from Raymond (brand).
6) For checking if a value exists in an array, use `has(array, value)` instead of `ANY()`
7) For applying filters or checking values convert everything into lowercase, for example lower(brand) = lower(Gucci)
8) please dont use any tables anme staring with ".inner_id."in the query. 
"""

prompt = ChatPromptTemplate.from_messages([("system", system_prompt), ("human", "{input}")])

query_chain = create_sql_query_chain(sql_agent_llm, db, prompt=prompt, k=15)


New Full Correction Function (Integrating Category Mapping)


In [15]:
def full_correction(query):
    # Step 1: Correct misspellings using the original LLM-based method.
    corrected_query = correct_query_using_llm(query, proper_nouns)
    
    # Step 2: Extract categories from the corrected query.
    extracted = extract_categories_from_prompt(corrected_query)
    if not extracted.get("l1_category") and extracted.get("l3_category"):
        extracted["l1_category"] = extracted["l3_category"].copy()
    if not extracted.get("l3_category") and extracted.get("l1_category"):
        extracted["l3_category"] = extracted["l1_category"].copy()
    
    # Step 3: Process the extracted categories using the valid values from the DB.
    mapping = process_extracted_categories(extracted, combined_valid_values, threshold=0.85)
    
    # Step 4: Use fuzzy replacement to correct the prompt with proper category mapping.
    final_query = correct_prompt(corrected_query, extracted, mapping, threshold=0.8)
    return final_query


In [16]:
correction_chain = (
    RunnableLambda(lambda input: input["question"])  # Extract the original question.
    | RunnableLambda(lambda query: full_correction(query))  # Full correction: spelling + category mapping.
    | RunnableLambda(lambda corrected: {"question": corrected, "corrected_query": corrected})
)

In [None]:
chain = RunnablePassthrough.assign(
    corrected_query=correction_chain,
    table_info=lambda input: db.get_table_info()
) | query_chain

In [26]:
question = "i want to see the top 100 users who spent more than 2000 on tshirts, chinas, loungewaer, casaual sherts , footwer , clothing together"

In [27]:
corrected_query = correction_chain.invoke({"question": question})

In [28]:
print(corrected_query)

{'question': 'I want to see the top 100 users who spent more than 2000 on T-Shirts (L3_category) Chinos (L3_category) Loungewear (L3_category) Casual Shirts (L3_category) Footwear (L1_category) Clothing (L1_category) together.', 'corrected_query': 'I want to see the top 100 users who spent more than 2000 on T-Shirts (L3_category) Chinos (L3_category) Loungewear (L3_category) Casual Shirts (L3_category) Footwear (L1_category) Clothing (L1_category) together.'}


In [29]:
result = chain.invoke({"question": question})
print(result)

SELECT 
    user_id, 
    SUM(price * quantity) AS total_spent
FROM 
    add_to_cart
WHERE application_id = '000000000000000000000001' AND 
    lower(l3_category) IN ('t-shirts', 'chinos', 'loungewear', 'casual shirts') 
    OR lower(l1_category) IN ('footwear', 'clothing')
    AND application_id = '000000000000000000000001'
GROUP BY 
    user_id
HAVING 
    total_spent > 2000
ORDER BY 
    total_spent DESC
LIMIT 100;


In [25]:
db.run(result)

[('', 111634.0),
 ('679f96e219b98690a755e180', 5975.0),
 ('67cc12f427c572dbdc6a33c7', 2898.0),
 ('67ab38d1c9329259c2438419', 2249.0)]

In [30]:
question = "i want to see the top 100 users who spent more than 2000 in the last 7 days."

In [31]:
result = chain.invoke({"question": question})
print(result)

SELECT user_id, SUM(order_value) AS total_spent
FROM sales_summary
WHERE application_id = '000000000000000000000001' AND order_created_at >= now() - INTERVAL 7 DAY
AND total_order_value > 2000
AND application_id = '000000000000000000000001'
GROUP BY user_id
ORDER BY total_spent DESC
LIMIT 100;
