In [1]:
import json
import os
import sqlite3
from contextlib import closing

# Path to the SQLite database whose schema is exported.
db_path = "data/grocery_sales.db"
# Base file name without extension ("grocery_sales"); also used to name the JSON output.
db_name = os.path.splitext(os.path.basename(db_path))[0]
json_path = db_name + ".json"

# closing() guarantees the connection is released even if a query below raises;
# sqlite3's own connection context manager only commits/rolls back, it does not close.
with closing(sqlite3.connect(db_path)) as conn:
    cursor = conn.cursor()

    # List user-defined tables (SQLite's internal tables are prefixed "sqlite_").
    cursor.execute("""
        SELECT name FROM sqlite_master
        WHERE type='table' AND name NOT LIKE 'sqlite_%';
    """)
    tables = cursor.fetchall()

    # Build {table_name: {column_name: declared_type}}.
    schema_info = {}
    for (table_name,) in tables:
        # PRAGMA statements cannot take bound parameters, so escape embedded single
        # quotes before placing the table name inside the quoted literal.
        escaped_name = table_name.replace("'", "''")
        cursor.execute(f"PRAGMA table_info('{escaped_name}')")
        # Each table_info row is (cid, name, type, notnull, dflt_value, pk).
        schema_info[table_name] = {col[1]: col[2] for col in cursor.fetchall()}

    cursor.close()

# Serialize the schema and write it next to the notebook (e.g. grocery_sales.json).
schema_json = json.dumps(schema_info, indent=2)
with open(json_path, "w") as f:
    f.write(schema_json)

print("Schema exported to JSON with column names and data types.")


Schema exported to JSON with column names and data types.


In [None]:
import openai
from langchain.chat_models import ChatOpenAI


import os
from getpass import getpass

# Never hardcode credentials in a notebook. Read the Groq key from the
# GROQ_API_KEY environment variable, or prompt for it, so the secret is never
# saved into the .ipynb file. The key previously committed here must be revoked.
groq_api_key = os.environ.get("GROQ_API_KEY") or getpass("Groq API key: ")

# Groq serves an OpenAI-compatible endpoint, so the OpenAI chat client is reused
# with a different base URL.
groq_api_base = "https://api.groq.com/openai/v1"
model_name = "llama3-70b-8192"

# temperature=0.0 minimises sampling randomness for the correction/SQL steps.
llm = ChatOpenAI(
    model_name=model_name,
    temperature=0.0,
    openai_api_key=groq_api_key,
    openai_api_base=groq_api_base,
)

  llm = ChatOpenAI(


In [None]:
# Importing necessary libraries for file handling, serialization, vector search, and embeddings
import os
import pickle
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document

# Directory where the FAISS schema index is persisted between sessions.
SCHEMA_INDEX_DIR: str = "schema_index"
# Sentence-embedding model shared by index building and loading; both must use
# the same model so stored vectors and query vectors are comparable.
MODEL_NAME: str = "sentence-transformers/all-MiniLM-L6-v2"

# -----------------------------------------------
# Build or Load Schema Embeddings
# -----------------------------------------------

def build_schema_index(schema_dict, save_path=SCHEMA_INDEX_DIR):
    """Embed every ``table.column`` pair of a schema and save a FAISS index.

    Args:
        schema_dict: Mapping of table name -> iterable of column names. A
            ``{table: {column: type}}`` dict works too; only the column keys are used.
        save_path: Directory the FAISS index is written to (created if missing).

    Raises:
        ValueError: If the schema yields no ``table.column`` entries, because an
            empty vector index cannot be built.
    """
    # Flatten the schema into "table.column" strings, one document per column.
    schema_lines = [f"{table}.{col}" for table, cols in schema_dict.items() for col in cols]
    # Validate before loading the (expensive) embedding model.
    if not schema_lines:
        raise ValueError("schema_dict contains no columns; cannot build an empty schema index")

    embedding_model = HuggingFaceEmbeddings(model_name=MODEL_NAME)
    documents = [Document(page_content=line) for line in schema_lines]

    # Embed the documents and build the FAISS vector index.
    db = FAISS.from_documents(documents, embedding_model)

    # Persist the index so later cells can reload it with load_schema_index().
    os.makedirs(save_path, exist_ok=True)
    db.save_local(folder_path=save_path)
    print(f"Schema index saved to: {save_path}")

def load_schema_index(load_path=SCHEMA_INDEX_DIR):
    """Reopen the FAISS schema index written by ``build_schema_index``.

    The same ``MODEL_NAME`` embeddings used at build time are recreated, so query
    vectors are compared in the same space as the stored ones.

    Args:
        load_path: Directory the index was saved to.

    Returns:
        The loaded FAISS vector store.
    """
    embedder = HuggingFaceEmbeddings(model_name=MODEL_NAME)
    # LangChain requires this explicit opt-in because the saved index relies on
    # pickle. Only load indexes this notebook built itself; never point load_path
    # at an untrusted directory.
    vector_store = FAISS.load_local(
        folder_path=load_path,
        embeddings=embedder,
        allow_dangerous_deserialization=True,
    )
    return vector_store

# -----------------------------------------------
# Query Similar Schema Entries
# -----------------------------------------------

def search_schema(index, query, top_k=5):
    """Return the ``top_k`` schema entries most similar to ``query``.

    Args:
        index: Vector store exposing ``similarity_search(query, k=...)``, such as
            the FAISS index returned by ``load_schema_index``.
        query: Natural-language text to match against the schema entries.
        top_k: Maximum number of entries to return.

    Returns:
        list[str]: The ``page_content`` of each hit, in the order the index
        returned them.
    """
    hits = index.similarity_search(query, k=top_k)
    matched = []
    for hit in hits:
        matched.append(hit.page_content)
    return matched


In [None]:
# Load the schema written by the export cell: {table_name: {column_name: data_type}}.
with open('grocery_sales.json', 'r') as schema_file:
    schema_grocery_sales = json.load(schema_file)

In [None]:
import os

# Set to True to force a rebuild, e.g. after the database schema changes.
REBUILD_SCHEMA_INDEX = False

# Build the FAISS schema index only when it is missing (or a rebuild is forced),
# so Restart & Run All does not re-embed the schema on every run and no cell
# has to be commented out by hand.
if REBUILD_SCHEMA_INDEX or not os.path.isdir(SCHEMA_INDEX_DIR):
    build_schema_index(schema_grocery_sales)


  embedding_model = HuggingFaceEmbeddings(model_name=MODEL_NAME)





  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'
Schema index saved to: schema_index


In [None]:
from langchain.prompts import PromptTemplate

def llm_correct_prompt(query: str, alias_dict: dict, llm) -> str:
    """
    Uses the LLM to correct typos/shortforms in ``query`` against the database schema.

    Args:
        query: The raw user question.
        alias_dict: Schema mapping (e.g. ``{table: {column: datatype}}``) shown to
            the LLM as context for the correction.
        llm: Chat model exposing ``invoke(prompt)`` that returns an object with
            a ``.content`` string.

    Returns:
        The corrected query string only (the ``corrected_query`` field of the
        JSON object the LLM is asked to return).

    Raises:
        ValueError: If the response contains no JSON object, the JSON is malformed
            (``json.JSONDecodeError``), or ``corrected_query`` is not a non-empty string.
        KeyError: If the JSON object lacks ``corrected_query``.
    """
    # Imported locally: ``re`` is otherwise only imported in a later cell, so this
    # function would raise NameError if called before that cell has run.
    import json
    import re

    # Pretty-print the schema so the LLM sees tables and columns clearly.
    db_schema_str = json.dumps(alias_dict, indent=2)

    # Doubled braces are literal braces in the template, not placeholders.
    prompt_temp = PromptTemplate.from_template("""
You are a Data Assistant and SQL expert. Your task is to correct user prompts for spelling mistakes or shortforms based on the database schema.

The schema is in the format:
"Table name": {{
  "Column name": "datatype",
  "Column name": "datatype"    
}}

User prompt:
{query}

Database schema:
{db_schema}

Return your response in **only** this JSON format:
{{ "corrected_query": "..." }}
""")

    prompt = prompt_temp.format(query=query, db_schema=db_schema_str)
    response = llm.invoke(prompt)

    raw_output = response.content.strip()
    print("🔍 Raw LLM Output:\n", raw_output)

    try:
        # Remove markdown code-fence markers the model may wrap the JSON in.
        cleaned_output = re.sub(r"```(json)?", "", raw_output).strip()

        # Take everything from the first "{" to the last "}" as the JSON payload.
        json_match = re.search(r'\{.*\}', cleaned_output, re.DOTALL)
        if not json_match:
            raise ValueError("No valid JSON found in response")

        parsed = json.loads(json_match.group(0))
        corrected = parsed["corrected_query"]
        if not isinstance(corrected, str) or not corrected.strip():
            raise ValueError("LLM returned an empty or non-string corrected_query")
        return corrected

    except Exception:
        print("❌ Could not parse LLM output properly.")
        # Bare raise re-raises the active exception unchanged.
        raise


In [12]:
from collections import defaultdict
def format_tableschema(raw_input):
    """Group newline-separated ``table.column`` lines into ``{table: [columns]}``.

    Blank lines are skipped. Each line is split on its first ``.`` only, so column
    names that themselves contain dots stay intact (matching
    ``formating_columns_for_input``). Duplicate columns collapse, and each
    table's columns are returned sorted.

    Args:
        raw_input (str): Text with one ``table.column`` entry per line.

    Returns:
        dict: Table name -> sorted list of unique column names.

    Raises:
        ValueError: If a non-blank line has no ``.`` separator.
    """
    # Sets de-duplicate columns that appear more than once.
    table_columns = defaultdict(set)

    for line_number, line in enumerate(raw_input.strip().splitlines(), start=1):
        entry = line.strip()
        if not entry:
            continue
        table, sep, column = entry.partition('.')
        if not sep:
            raise ValueError(f"line {line_number} is not in 'table.column' form: {entry!r}")
        table_columns[table].add(column)

    # Convert each set to a sorted list for a stable, JSON-friendly result.
    return {table: sorted(columns) for table, columns in table_columns.items()}


In [None]:
from collections import defaultdict

def formating_columns_for_input(column_list):
    """
    Groups ``'table.column'`` strings by table name, preserving input order.

    Entries without a ``.`` are ignored. Only the first ``.`` separates table
    from column, so ``'a.b.c'`` becomes table ``'a'`` with column ``'b.c'``.

    Args:
        column_list (list): Strings in the format 'table.column'.

    Returns:
        dict: Table names as keys, lists of column names (in encounter order) as values.
    """
    grouped = {}
    for entry in column_list:
        table, separator, column = entry.partition('.')
        if not separator:
            # No dot at all: not a table.column reference, skip it.
            continue
        grouped.setdefault(table, []).append(column)
    return grouped


In [None]:
from langchain.sql_database import SQLDatabase
import sqlite3
import re
from langchain.prompts import PromptTemplate
from langchain.chains.sql_database.query import create_sql_query_chain

def _extract_sql(query_response: str) -> str:
    """Pull the first SELECT statement out of the SQL chain's raw text output.

    Text after the last ``SQLQuery:`` marker is searched first, so prose before
    the marker (e.g. "to select the product...") is not mistaken for SQL. The
    whole response is searched as a fallback. ``SELECT`` may be followed by any
    whitespace, including a newline. A missing terminating ``;`` is tolerated by
    taking the rest of the text. Returns ``""`` when nothing starting with
    SELECT is found.
    """
    _, marker, tail = query_response.rpartition("SQLQuery:")
    candidates = (tail, query_response) if marker else (query_response,)
    for candidate in candidates:
        match = re.search(r"\bSELECT\b.*?(?:;|$)", candidate, re.DOTALL | re.IGNORECASE)
        if match:
            return match.group(0).strip()
    return ""


def query_generation(given_prompt, db, llm, schema_dict, schema_index=None):
    """Answer a natural-language question about the database via LLM-generated SQL.

    Pipeline:
        1. Correct typos and shortforms in the question against ``schema_dict``.
        2. Retrieve the most relevant ``table.column`` entries from the FAISS index.
        3. Ask the LLM for a SELECT statement.
        4. Run the statement on ``db``.
        5. Ask the LLM to phrase the result as an answer.

    Args:
        given_prompt (str): The user's question, possibly with typos/shortforms.
        db: LangChain ``SQLDatabase`` the generated query is executed against.
        llm: Chat model exposing ``invoke``. It is used for correction, SQL
            generation and the final answer.
        schema_dict (dict): Schema mapping used for typo correction.
        schema_index: Optional pre-loaded FAISS schema index. When ``None`` it is
            loaded from disk on every call, which also re-creates the embedding model.

    Returns:
        str: The LLM's natural-language answer.
    """
    # Step 1: Correct spelling mistakes or shortforms in the prompt using the schema.
    corrected_prompt = llm_correct_prompt(given_prompt, schema_dict, llm)

    # Step 2: Retrieve the most relevant schema entries and group them by table.
    index = schema_index if schema_index is not None else load_schema_index()
    matches = search_schema(index, corrected_prompt)
    table_schema_formatted = formating_columns_for_input(matches)

    # Step 3: Custom prompt instructing the LLM how to write the SQL query.
    custom_prompt = PromptTemplate.from_template("""
    *** You are an expert at writing SQL queries based on natural language questions.
    There can be multiple tables in the database, and you can use any of them to answer the question. You can also join tables if needed. You can also use aggregate functions like COUNT, SUM, AVG, etc. to answer the question. You can also use GROUP BY and ORDER BY clauses if needed.
    You have full liberty to use any SQL functions or clauses to answer the question. You can also use subqueries if needed. You can also use DISTINCT keyword if needed. You can also use WHERE clause to filter the results. You can also use HAVING clause to filter the results after aggregation. You can also use LIMIT clause to limit the number of results returned.
    In short you can do anything to answer the question. ***

    Use the following table schema:
    {table_info}

    The user wants to retrieve up to {top_k} results, but **do not add a LIMIT clause unless explicitly instructed**.

    Question: {input}

    Return the SQL query in the following format:

    SQLQuery: <your SQL query here>;
    """)

    # Step 4: Build the LangChain SQL query-generation chain.
    sql_chain = create_sql_query_chain(llm=llm, db=db, prompt=custom_prompt)

    # Step 5: Generate the query.
    # NOTE(review): create_sql_query_chain may fill "input"/"table_info" itself
    # (from "question" and db.get_table_info()), in which case the FAISS-filtered
    # schema passed here is overridden. Confirm against the installed LangChain version.
    write_query = sql_chain.invoke({
        "input": corrected_prompt,
        "top_k": 250,
        "table_info": table_schema_formatted,
        "question": corrected_prompt,
    })

    # Step 6: Extract the SQL statement from the LLM's response.
    cleaned_query = _extract_sql(write_query)
    print("Formatted query:", cleaned_query)

    # Step 7: Run the query. If nothing was extracted, send no empty statement to
    # the database; the empty result lets the answer prompt say "no details found".
    result = db.run(cleaned_query) if cleaned_query else ""

    # Step 8: Ask the LLM for a human-readable answer based on the SQL output.
    answer_prompt = PromptTemplate.from_template(
        """Given the following instructions in *** ***, the user question, the corresponding SQL query, and the SQL result, answer the user question.
        *** Don't summarize the SQL result unless explicitly asked to. ***
        *** If there is any error in the SQL query, interpret it and answer the user accordingly without mentioning that there is an error in the system. ***
        *** If there is no result or the result is vague, simply say no details found. ***
    Question: {question}
    SQL Query: {query}
    SQL Result: {result}
    Answer:"""
    )

    final_answer_prompt = answer_prompt.format(question=corrected_prompt, query=cleaned_query, result=result)
    response = llm.invoke(final_answer_prompt)

    # Return only the text content of the chat response.
    return response.content


In [None]:
# Use the same project-relative path as the schema-export cell rather than a
# machine-specific absolute Windows path, so the notebook runs on other machines.
# The path is resolved against the current working directory.
groc_db_path = "data/grocery_sales.db"                        # Grocery database path
groc_db = SQLDatabase.from_uri(f"sqlite:///{groc_db_path}")   # LangChain wrapper over the SQLite file

In [None]:
# Misspelled question ("prdct", "sal"); query_generation corrects it before generating SQL.
given_prompt = "Which prdct has highest sal"
out = query_generation(
    given_prompt=given_prompt,
    db=groc_db,
    llm=llm,
    schema_dict=schema_grocery_sales,
)

🔍 Raw LLM Output:
 { "corrected_query": "Which product has the highest sale" }
Formatted query: SELECT p.ProductName 
FROM sales s 
JOIN products p ON s.ProductID = p.ProductID 
GROUP BY p.ProductName 
ORDER BY SUM(TotalPrice) DESC;


In [34]:
# Display the natural-language answer returned by query_generation in the previous cell.
out

'The product with the highest sale is Zucchini - Yellow.'