## Imports

In [None]:
import json
import spacy
from spacy.tokens import DocBin

import re
import nltk
from nltk.tokenize import word_tokenize
# from nltk.corpus import stopwords
# from nltk.stem import PorterStemmer
from nltk.stem import WordNetLemmatizer
import pandas as pd
from difflib import get_close_matches
from thefuzz import fuzz
import networkx as nx
import heapq
from collections import defaultdict
from ordered_set import OrderedSet
from collections import Counter
from word2number import w2n

In [2]:
nltk.download('punkt')
nltk.download('punkt_tab')
# nltk.download('stopwords')
nltk.download('wordnet')

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/vedanttibrewal/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/vedanttibrewal/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/vedanttibrewal/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [3]:
lemmatizer = WordNetLemmatizer()

## Sample questions

In [4]:
SQL = {
    "products": {
        "pk": ["product_id"],
        "fk": {"categories": "category_id"},
        "product_id": "INT",
        "product_name": "VARCHAR(255)",
        "description": "TEXT",
        "price": "DECIMAL(10, 2)",
        "stock_quantity": "INT",
        "category_id": "INT"
    },
    "categories": {
        "pk": ["category_id"],
        "fk": {},
        "category_id": "INT",
        "category_name": "VARCHAR(100)",
        "description": "TEXT"
    },
    "customers": {
        "pk": ["customer_id"],
        "fk": {},
        "customer_id": "INT",
        "first_name": "VARCHAR(50)",
        "last_name": "VARCHAR(50)",
        "email": "VARCHAR(100)",
        "phone_number": "VARCHAR(20)",
        "address": "TEXT"
    },
    "orders": {
        "pk": ["order_id"],
        "fk": {
            "customers": "customer_id"
        },
        "order_id": "INT",
        "customer_id": "INT",
        "order_date": "DATE",
        "total_amount": "DECIMAL(10, 2)",
        "status": "VARCHAR(20)"
    },
    "order_items": {
        "pk": ["order_item_id"],
        "fk": {
            "orders": "order_id",
            "products": "product_id"
        },
        "order_item_id":"INT",
        "order_id": "INT",
        "product_id": "INT",
        "quantity": "INT",
        "unit_price": "DECIMAL(10, 2)"
    }
}

req = {
    "products": [
        "product_id",
        "product_name",
        "price"
    ],
    "customers": [
        "customer_id",
        "first_name",
        "last_name",
        "email"
    ],
    "orders": [
        "order_id",
        "customer_id",
        "order_date",
        "total_amount"
    ]
}

In [992]:
#Healthcare Schema

SQL = {
    "admissions": {
        "pk": ["admissionid"],
        "fk": {
            "patients": "patientid",
            "insurance": "insuranceid"
        },
        "admissionid": "INT",
        "patientid": "INT",
        "insuranceid": "INT",  
        "doctor": "VARCHAR(255)",
        "hospital": "VARCHAR(255)",
        "intakedate": "DATE",
        "dischargedate": "DATE",
        "roomnumber": "INT",
        "carelevel": "VARCHAR(50)"
    },

    "patients": {
        "pk": ["patientid"],
        "fk": {},
        "patientid": "INT",
        "patientname": "VARCHAR(255)",
        "age": "INT",
        "gender": "VARCHAR(10)",
        "bloodtype": "VARCHAR(5)",
        "disease": "VARCHAR(255)"
    },

    "insurance": {
        "pk": ["insuranceid"],
        "fk": {
            "patients": "patientid"
        },
        "insuranceid": "INT",
        "patientid": "INT",
        "insuranceprovider": "VARCHAR(255)",
        "billingcost": "DECIMAL(10, 2)",
        "medication": "VARCHAR(255)",
        "testresults": "VARCHAR(50)"
    }
}

In [638]:
# Basic COUNT query
q1 = "What is the total number of customers?"
sql1 = "SELECT COUNT(*) FROM customers;"

# Simple SELECT with WHERE clause
q2 = "List all products names with a price greater than $50."
sql2 = "SELECT product_name, price FROM products WHERE price > 50;"

# JOIN operation with ORDER BY
q3 = "Show the first name of customers who have placed orders, sorted by their last name."
sql3 = """
SELECT DISTINCT c.first_name, c.last_name 
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
ORDER BY c.last_name;
"""

# Aggregation with GROUP BY
q4 = "What is the total amount of orders for each customer?"
sql4 = """
SELECT c.customer_id, c.first_name, c.last_name, SUM(o.total_amount) as total_spent
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
GROUP BY c.customer_id, c.first_name, c.last_name;
"""

# HAVING clause
q5 = "Which categories have count more than 5 products?"
sql5 = """
SELECT c.category_name, COUNT(p.product_id) as product_count
FROM categories c
JOIN products p ON c.category_id = p.category_id
GROUP BY c.category_name
HAVING COUNT(p.product_id) > 5;
"""

# LIMIT clause
q6 = "What are the top 5 most priciest products?" # "What are the top 5 most expensive products?"
sql6 = """
SELECT product_name, price
FROM products
ORDER BY price DESC
LIMIT 5;
"""

# Aggregation in WHERE clause
q7 = "Find orders with more than 3 product items."
sql7 = """
SELECT o.order_id, COUNT(oi.product_id) as item_count
FROM orders o
JOIN order_items oi ON o.order_id = oi.order_id
GROUP BY o.order_id
HAVING COUNT(oi.product_id) > 3;
"""

# Multiple JOINs
q8 = "List the top 3 customers who have spent the most on 'Electronics' category products."
sql8 = """
SELECT c.customer_id, c.first_name, c.last_name, SUM(oi.quantity * oi.unit_price) as total_spent
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
JOIN order_items oi ON o.order_id = oi.order_id
JOIN products p ON oi.product_id = p.product_id
JOIN categories cat ON p.category_id = cat.category_id
WHERE cat.category_name = 'Electronics'
GROUP BY c.customer_id, c.first_name, c.last_name
ORDER BY total_spent DESC
LIMIT 3;
"""

## Constants

In [348]:
stop_words = [
    'a', 'an', 'the',
    'i', 'me', 'my', 'mine', 'myself', 'you', 'your', 'yours', 'yourself', 'he', 'him', 'his', 'himself', 'she', 'her', 'hers', 'herself', 'it', 'its', 'itself', 'we', 'our', 'ours', 'ourselves', 'they', 'them', 'their', 'theirs', 'themselves',
    'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
    'do', 'does', 'did', 'have', 'has', 'had', 'can', 'could', 'shall', 'should', 'will', 'would', 'may', 'might', 'must',
    'about', 'across', 'after', 'against', 'along', 'around', 'at', 'before', 'behind', 'below', 'beneath', 'beside', 'during', 'into', 'near', 'off', 'out', 'through', 'toward', 'under', 'up', 'with',
    'if', 'then', 'because', 'as', 'until', 'while',
    'this', 'that', 'these', 'those', 'such', 'what', 'which', 'whose', 'whoever', 'whatever', 'whichever', 'whomever', 'either', 'neither', 'both',
    'very', 'really', 'always', 'never', 'too', 'already', 'often', 'sometimes', 'rarely', 'seldom', 'again', 'further', 'then', 'once', 'here', 'there', 'where', 'why', 'how',
    # 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'second', 'third', 'fourth', 'fifth', # find a way to implement it in limit
    'few', 'little', 'much', 'enough',
    'yes', 'no', 'not', 'okay', 'ok', 'right', 'sure', 'well', 'uh', 'um', 'oh', 'eh', 'hmm', 'just', 'ever', 'yet', 'etc', 'perhaps', 'maybe', 'list',
    'who', 'spent'
]

In [None]:


aggregate_functions = {
    'COUNT': ['count', 'number of', 'quantity of', 'total number of', 'tally', 'enumerate', 'how many'],
    'SUM': ['sum', 'total', 'aggregate', 'combined', 'add up', 'overall', 'cumulative'],
    'AVG': ['average', 'mean', 'typical', 'median', 'expected value', 'norm', 'central tendency'],
    'MAX': ['maximum', 'highest', 'most', 'top', 'peak', 'greatest', 'largest', 'biggest', 'uppermost'],
    'MIN': ['minimum', 'lowest', 'least', 'bottom', 'smallest', 'tiniest', 'least', 'floor'],
    'DISTINCT': ['unique', 'different', 'distinct', 'individual', 'separate', 'non-duplicate', 'exclusive'],
    # 'GROUP_CONCAT': ['concatenate', 'combine strings', 'join', 'merge text', 'string aggregation', 'text combination'],
    # 'FIRST': ['first', 'initial', 'earliest', 'primary', 'leading', 'opening', 'foremost'], # limit implementation along with number
    # 'LAST': ['last', 'final', 'latest', 'ultimate', 'concluding', 'terminal', 'closing']
}

comparison_operators = {
    '>': ['with greater than', 'greater than', 'with more than', 'more than', 'is above' ,'above', 'over', 'exceeding', 'surpassing', 'beyond', 'higher than', 'in excess of'],
    '<': ['with less than', 'less than', 'with fewer than', 'fewer than', 'is below' ,'below', 'under', 'beneath', 'lower than', 'not as much as', 'smaller than'],
    '=': ['equal to', 'is same as', 'same as', 'identical to', 'matching', 'equivalent to', 'corresponds to', 'is', 'for', 'with'],
    '!=': ['not equal to', 'different from', 'excluding', 'not the same as', 'dissimilar to', 'unlike', 'other than'],
    '>=': ['greater than or equal to', 'at least', 'no less than', 'minimum of', 'not below', ' starting from'],
    '<=': ['less than or equal to', 'at most', 'no more than', 'maximum of', 'not above', 'up to'],
    'BETWEEN': ['between', 'in the range of', 'within the bounds of', 'inside the limits of'],
    'IN': ['in', 'within', 'among', 'included in', 'part of', 'contained in', 'one of'],
    'NOT IN': ['not in', 'outside of', 'excluded from', 'not among', 'not part of', 'not contained in'],
    'LIKE': ['like', 'similar to', 'resembling', 'matching pattern', 'corresponding to'],
    'NOT LIKE': ['not like', 'dissimilar to', 'unlike', 'not matching pattern', 'different from pattern']
}

logical_operators = {
    'AND': ['and', 'also', 'as well as', 'in addition to', 'plus', 'together with', 'along with', 'including'],
    'OR': ['or', 'alternatively', 'either', 'otherwise', 'else', 'and/or'],
    'NOT': ['not', 'except', 'excluding', 'other than', 'but not', 'save for', 'apart from']
}

In [8]:
data_types = {
    'INTEGER': ['integer', 'int', 'whole number', 'numeric'],
    'FLOAT': ['float', 'decimal', 'real number', 'fractional number'],
    'VARCHAR': ['string', 'text', 'characters', 'alphanumeric'],
    'DATE': ['date', 'calendar day', 'day'],
    'TIMESTAMP': ['timestamp', 'date and time', 'moment', 'point in time'],
    'BOOLEAN': ['boolean', 'true/false', 'yes/no', 'binary']
}

In [840]:
# order matters
sql_clauses = {
    'FROM': ['from', 'in', 'out of', 'sourced from', 'derived from', 'based on'],
    'WHERE': ['where', 'for which', 'that have', 'meeting the condition', 'satisfying', 'fulfilling'],
    'ORDER BY': ['order by', 'ordered by', 'sort by', 'sorted by', 'arrange by', 'rank by', 'sequence by'],
    'GROUP BY': ['group by', 'categorize by', 'classify by', 'organize by', 'arrange by', 'cluster by', 'for each', 'broken down by', 'per', 'by'], # check about 'by'
    'HAVING': ['having', 'with the condition', 'subject to', 'meeting the criteria', 'have'],
    'LIMIT': ['limit', 'top', 'first', 'restrict to', 'cap at', 'only show'],
    'JOIN': ['join', 'combine', 'merge', 'connect', 'link', 'associate'],
    'UNION': ['union', 'combine', 'merge', 'incorporate', 'consolidate', 'unite'],
    'INTERSECT': ['intersect', 'in common', 'shared by', 'mutual', 'overlapping'],
    'EXCEPT': ['except', 'subtract', 'exclude', 'remove', 'omit', 'leave out'],
    'SELECT': ['show', 'list all', 'list', 'give', 'return', 'fetch', 'retrieve', 'get', 'find', 'which', 'what is the', 'what is', 'what are', 'what'], # 'display
}

In [481]:
ALL_KEYWORDS = set(list(aggregate_functions.keys())) | (set(list(logical_operators.keys()))) | (set(sql_clauses.keys())) | (set(comparison_operators.keys()))
ALL_KEYWORDS = list(map(str.lower, ALL_KEYWORDS))

In [860]:
constants = {"ALL_KEYWORDS": ALL_KEYWORDS,
             "sql_clauses": sql_clauses,
             "data_types": data_types,
             "logical_operators": logical_operators,
             "comparison_operators": comparison_operators,
             "aggregate_functions": aggregate_functions,
             "stop_words": stop_words}

## Preprocess

In [712]:
def replace_string(text, replace_dict):
    # Create a reverse mapping from values to keys
    value_to_key = {}
    for key, values in replace_dict.items():
        for value in values:
            value_to_key[value.lower()] = key.lower()
    
    # Create a regex pattern for matching words
    pattern = r'\b(' + '|'.join(re.escape(word) for word in value_to_key.keys()) + r')\b'
    
    # Function to replace matched words
    def replace_word(match):
        return value_to_key.get(match.group(0).lower(), match.group(0))
    
    # Perform the replacement
    replaced_text = re.sub(pattern, replace_word, text, flags=re.IGNORECASE)
    
    return replaced_text

def replace_numbers(token):
    try:
        return w2n.word_to_num(token)
    except:
        return token

def preprocess_text(text, constants):
    # Lowercase
    # text = text.lower() # change it to lower while comparisson
    
    # Remove punctuation
    # pattern = r'[^\w\s()*\d-]|(?<!\d)-(?!\d)'
    # text = re.sub(pattern, '', text)
    text = re.sub(r'\?', '', text)
    text = re.sub(r'\$', '', text)
    text = re.sub(r'\.$', '', text)
    print(text)
    text = re.sub(r',', '', text)
    pattern = r'\b(\d+)-(\d+)\b'
    text = re.sub(pattern, r'\1 to \2', text)



    # "name" condition
    tokens = text.split()
    result = []
    i = 0
    while i < len(tokens) - 1:  # Changed to len(tokens) - 1
        if tokens[i].lower() not in constants['ALL_KEYWORDS'] and (tokens[i+1].lower() == "name" or tokens[i+1].lower() == "names"):
            # print(f"{tokens[i]}_{tokens[i+1]}")
            result.append(f"{tokens[i]}_{tokens[i+1]}")
            i += 2
        else:
            result.append(tokens[i])
            i += 1
    if i < len(tokens):  # Add any remaining token
        result.append(tokens[i])

    i = 0
    result2 = []
    # group_by_keys = constants['sql_clauses']['GROUP BY']
    # group_by_keys = [key.split('by')[0].strip() for key in group_by_keys if key.split('by')[0] != '']
    # print("group keys", group_by_keys)

    order_by_keys = constants['sql_clauses']['ORDER BY']
    order_by_keys = [key.split('by')[0].strip() for key in order_by_keys if key.split('by')[0] != '']
    # print("order keys", order_by_keys)

    while i < len(result) - 1:
        if result[i+1].lower() =='by' and result[i].lower() in order_by_keys: #(result[i].lower() in group_by_keys):
            result2.append(f"{result[i]} {result[i+1]}")
            i += 2
        else:
            result2.append(result[i])
            i += 1
    if i < len(result):  # Add any remaining token
        result2.append(result[i])

    # print(result)
    text = ' '.join(result2)

    # substitute words
    text = replace_string(text, sql_clauses)
    text = replace_string(text, aggregate_functions)
    # from \d+ -> >=
    text = replace_string(text, comparison_operators)
    between_pattern = r'from \b(\w+|\d+)\b to \b(\w+|\d+)\b | \b(\w+|\d+)\b to \b(\w+|\d+)\b'
    text = re.sub(between_pattern, r" between(\1, \2)", text)
    text = replace_string(text, logical_operators)

    # Tokenization
    # tokens = word_tokenize(text)
    tokens = text.split()

    tokens = [str(replace_numbers(token)) for token in tokens]
    # print(tokens)

    # Remove stopwords
    #  load custom stop words
    tokens = [token for token in tokens if token.lower() not in stop_words]


    # Lemmatization (optional) // important for similarity
    tokens = [lemmatizer.lemmatize(token) for token in tokens]
    
    return tokens# ' '.join(tokens)

In [407]:
q1_tok = preprocess_text(q1, constants)
q2_tok = preprocess_text(q2, constants)
q3_tok = preprocess_text(q3, constants)
q4_tok = preprocess_text(q4, constants)
print(q1)
print(q2)
print(q4)
q1_tok, q2_tok, q3_tok, q4_tok

triggered sorted by
['Show', 'the', 'first_name', 'of', 'customers', 'who', 'have', 'placed', 'orders', 'sorted by']
What is the total number of customers?
List all products names with a price greater than $50.
What is the total amount of orders for each customer?


(['select', 'count', 'customer'],
 ['select', 'products_names', 'price', '>', '50'],
 ['select',
  'first_name',
  'of',
  'customer',
  'placed',
  'order',
  'order',
  'by',
  'last_name'],
 ['select', 'sum', 'amount', 'of', 'order', 'group', 'by', 'custome'])

## Identify the keywords

In [86]:
def remove_keywords(tokens, keywords):
    token_without_key = [token for token in tokens if token not in keywords]

    return token_without_key


In [261]:
q1_nk_tok = remove_keywords(q1_tok, ALL_KEYWORDS)
q2_nk_tok = remove_keywords(q2_tok, ALL_KEYWORDS)
q4_nk_tok = remove_keywords(q4_tok, ALL_KEYWORDS)
q1_nk_tok, q2_nk_tok, q4_nk_tok

(['customer'],
 ['products_names', 'price', '50'],
 ['amount', 'of', 'order', 'group', 'by', 'custome'])

## Identify Group By

In [492]:
indentify_col_tables(identify_group_by(remove_keywords(preprocess_text("total orders by status", constants), ALL_KEYWORDS), SQL), SQL)

['orders']


{'orders': ['status']}

In [517]:
def identify_group_by(tokens, schema):
    result = []
    i = 0
    while i < len(tokens) - 2:  # Changed to len(tokens) - 1

        if tokens[i].lower() == "group" and tokens[i+1].lower() == "by":
            # print(f"{tokens[i]}_{tokens[i+1]}")
            for table, columns in schema.items():
                # for column in schema[table].keys():
                # for column in columns.keys():
                match = get_close_matches(tokens[i+2], list(columns.keys()))
                if match:
                    # print(tokens[i+2], ":", match)                        
                    result.append(f"{tokens[i]} {tokens[i+1]} {match[0]}")
                    break
            i += 3
        else:
            result.append(tokens[i])
            i += 1
    if i < len(tokens):  # Add any remaining token
        result.extend(tokens[i:])
    
    return result


In [415]:
print(q1)
q1_nk_tok = remove_keywords(q1_tok, ALL_KEYWORDS)
print(q1_nk_tok)
group_tok1 = identify_group_by(q1_nk_tok, SQL)
group_tok1

What is the total number of customers?
['customer']


['customer']

In [416]:
print(q2)
q2_nk_tok = remove_keywords(q2_tok, ALL_KEYWORDS)
print(q2_nk_tok)
group_tok2 = identify_group_by(q2_nk_tok, SQL)
group_tok2

List all products names with a price greater than $50.
['products_names', 'price', '50']


['products_names', 'price', '50']

In [514]:
print(q3)
q3_nk_tok = remove_keywords(q3_tok, ALL_KEYWORDS)
print(q3_nk_tok)
group_tok3 = identify_group_by(q3_nk_tok, SQL)
group_tok3

Show the first name of customers who have placed orders, sorted by their last name.
['first_name', 'of', 'customer', 'placed', 'order', 'order', 'by', 'last_name']


['first_name', 'of', 'customer', 'placed', 'order', 'order', 'by', 'last_name']

In [518]:
print(q4)
q4_nk_tok = remove_keywords(q4_tok, ALL_KEYWORDS)
print(q4_nk_tok)
group_tok4 = identify_group_by(q4_nk_tok, SQL)
group_tok4

What is the total amount of orders for each customer?
['amount', 'of', 'order', 'group', 'by', 'custome']


['amount', 'of', 'order', 'group by customer_id']

## Identify Columns and Tables

In [None]:
# group by before this
def indentify_table(tokens, schema):
    
    detected_table = [[], [], []] # [[table names], [similarities], [tokens]]
    
    for i, token in enumerate(tokens):
        if token.lower() == "from":
            table_name = tokens[i + 1]
            if table_name in schema:
                print(f"Table identified: {table_name}")
                return table_name
        
        for table in schema.keys():
            similarity = fuzz.ratio(token, table.lower())
            if token in detected_table[2]:
                # print(f"Table identified: {token} {table} {similarity}")
                # print(detected_table)
                i = detected_table[2].index(token)
                if similarity >= detected_table[1][i]:
                    # detected_table[0] = table
                    detected_table[1][i] = similarity
                    detected_table[0][i] = table
                # print(f"Table identified: {table} {similarity}")
            else:
                if similarity >= 70: # initial threshold
                    detected_table[0].append(table)
                    detected_table[1].append(similarity)
                    detected_table[2].append(token)


    for i in range(len(detected_table[0])):
        if detected_table[1][i] > 85:
            tokens.remove(detected_table[2][i])    
        
    # print("dsfsdgd",detected_table[0])

    return detected_table[0]


def indentify_col_tables(tokens, schema):
    res = dict()

    identified_table = indentify_table(tokens, schema)
    print(identified_table)
    if not tokens:
        res[identified_table[0]] = ['*']
        return res

    if identified_table:
        for table in identified_table:
            res[table] = dict()
            for token in tokens:
                for column in schema[table].keys():
                    # for column in columns.keys():
                    if column=='pk' or column=='fk':
                        continue
                    else:
                        similarity = fuzz.ratio(token, column.lower())
                        if similarity >= 50:
                            if res[table].get(token):
                                # print("case 1")
                                # print(token, ":", column, similarity)
                                old_similarity = list(res[table][token].values())[0]
                                # print("sim", old_similarity)
                                if similarity > old_similarity:
                                    res[table][token] = {column: similarity}
                            else:
                                # print("case 2")
                                # print(token, ":", column, similarity)
                                res[table][token] = {column: similarity}
                            # print(res)
                            # if res.get(identified_table):
                            #     res[identified_table].append(column)
                            # else:
                            #     res[identified_table] = [column]
    else:
        for token in tokens:
            for table, columns in schema.items():
                for column in schema[table].keys():
                # for column in columns.keys():
                    if column=='pk' or column=='fk':
                        continue
                    else:
                        match = get_close_matches(token, columns)
                        if match:
                            print("*"*4, token, ":", match[0])
                            # if token in columns: # replace by fuzzy logic
                            # print(f"Column identified: {column} in table {table}")
                            if res.get(table):
                                res[table].append(match[0])
                            else:
                                res[table] = [match[0]]

    if identified_table:
        result = {}
        for table, token in res.items():
            cols = []
            for col in token.values():
                cols.extend(col.keys())
            result[table] = cols

        result = {key: value for key, value in result.items() if value != []}
        return result
    

    res = {key: value for key, value in res.items() if value != []}    

    return res



In [421]:
print(q2)
print(q2_nk_tok)
q2_nk_tok = remove_keywords(q2_tok, ALL_KEYWORDS)
req = indentify_col_tables(group_tok2, SQL)
req

List all products names with a price greater than $50.
['products_names', 'price', '50']
['products']


{'products': ['product_name', 'price']}

In [422]:
print(q1)
print(q1_nk_tok)
group_tok1 = identify_group_by(q1_nk_tok, SQL)
req = indentify_col_tables(group_tok1, SQL)
req

What is the total number of customers?
['customer']
['customers']


{'customers': ['*']}

In [322]:
print(q4)
q4_nk_tok = remove_keywords(q4_tok, ALL_KEYWORDS)
print(q4_nk_tok)

group_tok4 = identify_group_by(q4_nk_tok, SQL)
req = indentify_col_tables(group_tok4, SQL)
req

What is the total amount of orders for each customer?
['amount', 'of', 'order', 'group', 'by', 'custome']
['orders']


{'orders': ['total_amount', 'customer_id']}

## JOIN

In [22]:
def create_graph(db_schema, directional=False):
    if directional:
        G = nx.DiGraph()
    else:
        G = nx.Graph()

    # Add nodes and edges based on the SQL dictionary
    for table, details in db_schema.items():
        G.add_node(table)  # Add table as a node
        for fk_table in details['fk']:  # Iterate through foreign keys
            G.add_edge(table, fk_table)  # Create an edge from current table to foreign key table

    # Get nodes and edges for verification
    nodes = list(G.nodes)
    edges = list(G.edges)

    # Output nodes and edges
    # print(nodes)
    # print(edges)

    return G

In [23]:
def required_tables_graph(G, start, end, required_tables):
    def dijkstra(graph, start, end):
        distances = {node: float('infinity') for node in graph}
        distances[start] = 0
        pq = [(0, start)]
        previous = {node: None for node in graph}

        while pq:
            current_distance, current_node = heapq.heappop(pq)

            if current_node == end:
                path = []
                while current_node:
                    path.append(current_node)
                    current_node = previous[current_node]
                return path[::-1], current_distance

            for neighbor in graph[current_node]:
                distance = current_distance + 1  # All edges have weight of 1
                if distance < distances[neighbor]:
                    distances[neighbor] = distance
                    previous[neighbor] = current_node
                    heapq.heappush(pq, (distance, neighbor))

        return None, float('infinity')

    required_tables = set(required_tables) - {start, end}
    best_path = None
    best_distance = float('infinity')

    def dfs(current_path, current_distance, remaining_required):
        nonlocal best_path, best_distance

        if not remaining_required:
            path, distance = dijkstra(G, current_path[-1], end)
            if path:
                total_path = current_path + path[1:]
                total_distance = current_distance + distance
                if total_distance < best_distance:
                    best_path = total_path
                    best_distance = total_distance
            return

        for node in remaining_required:
            path, distance = dijkstra(G, current_path[-1], node)
            if path:
                new_path = current_path + path[1:]
                new_distance = current_distance + distance
                new_remaining = remaining_required - {node}
                dfs(new_path, new_distance, new_remaining)

    dfs([start], 0, required_tables)

    return best_path, best_distance

# G = create_graph(SQL, directional=False)
# # Set start, end, and required nodes
# required_tables = ['categories', 'products', 'orders']
# start = required_tables[0]
# end = required_tables[2]

# # Find the shortest path
# required_tables, distance = required_tables_graph(G, start, end, required_tables)

# print(f"Shortest path: {' -> '.join(required_tables) if required_tables else 'No path found'}")
# # Shortest path: categories -> products -> order_items -> orders

# print(f"Total distance: {distance}")
# # Total distance: 3

In [24]:
G = create_graph(SQL, directional=True)
list(G.edges)

[('products', 'categories'),
 ('orders', 'customers'),
 ('order_items', 'orders'),
 ('order_items', 'products')]

In [25]:
def graph_sort(edges):
    counts = Counter(t[0] for t in edges)
    
    # Sort tuples based on the counts in descending order
    sorted_tuples = sorted(edges, key=lambda x: counts[x[0]], reverse=True)
    
    return sorted_tuples


In [155]:
graph_sort(G.edges)

[('order_items', 'orders'),
 ('order_items', 'products'),
 ('products', 'categories'),
 ('orders', 'customers')]

In [26]:
def join_clause(req_schema: dict, db_schema: dict):
    clause = ""

    # not required
    # max_cols = 0
    # prim_table = None
    # # change logic based on foreign key numbers
    # for table, cols in req_schema.items():
    #     if len(cols) > max_cols:
    #         max_cols = len(cols)
    #         prim_table = table
    
    # # pk_col = db_schema[prim_table]['pk']
    # print(prim_table)

    # create graph from Database schema
    db_graph = create_graph(db_schema)
    db_dir_graph = create_graph(db_schema, directional=True)

    # print(db_dir_graph)

    required_tables = list(req_schema.keys())

    print(required_tables)

    min_dist = float('inf')
    for st_table in required_tables:
        for end_table in required_tables:
            sub_graph, distance = required_tables_graph(db_graph, st_table, end_table, required_tables)
            if distance < min_dist:
                join_graph = OrderedSet(sub_graph)
                min_dist = distance

    print(join_graph)

    req_graph = []

    for edge in db_dir_graph.edges:
        # assumption binary relations between tables
        if edge[0] in join_graph and edge[1] in join_graph:
            req_graph.append(edge)

    req_graph = graph_sort(req_graph)
    print("required graph: ", req_graph)

    clause += f"FROM {req_graph[0][0]}\n"

    for i, edge in enumerate(req_graph):
        table1 = edge[0]
        table2 = edge[1]
        print(edge)
        # if table1 in db_schema[table2]['fk']:
        #     # print("1")
        #     fk_col = db_schema[table2]['fk'][table1] # foreign key corresponding to 2nd table  primary key
        #     pk_col = db_schema[table1]['pk'][0] # primary key of 2nd table
        #     clause += f"JOIN {table1} ON {table2}.{fk_col}={table1}.{pk_col} \n"
        # else:
        fk_col = db_schema[table1]['fk'][table2] # foreign key corresponding to 2nd table  primary key
        pk_col = db_schema[table2]['pk'][0] # primary key of 2nd table
        clause += f"JOIN {table2} ON {table1}.{fk_col}={table2}.{pk_col} \n"
        
    return clause

In [233]:
req

{'orders': ['total_amount'], 'customers': []}

In [234]:
if len(req.keys()) > 1:
    print(join_clause(req, SQL))
else:
    print("No Join required")

['orders', 'customers']
OrderedSet(['orders', 'customers'])
required graph:  [('orders', 'customers')]
('orders', 'customers')
FROM orders
JOIN customers ON orders.customer_id=customers.customer_id 



In [174]:
SQL2 = {
    "books": {
        "pk": ["book_id"],
        "fk": {"authors": "author_id"},
        "book_id": "INT",
        "title": "VARCHAR(255)",
        "author_id": "INT",
        "isbn": "VARCHAR(13)",
        "publication_year": "INT",
        "price": "DECIMAL(10, 2)",
        "stock_quantity": "INT"
    },
    "authors": {
        "pk": ["author_id"],
        "fk": {},
        "author_id": "INT",
        "first_name": "VARCHAR(50)",
        "last_name": "VARCHAR(50)",
        "birth_date": "DATE",
        "nationality": "VARCHAR(50)",
        "biography": "TEXT"
    }
}

req2 = {
    "books": [
        "book_id",
        "title",
        "stock_quantity"
    ],
    "authors": [
        "author_id",
        "first_name",
        "last_name",
    ]
}

print(join_clause(req2, SQL2))

['books', 'authors']
OrderedSet(['books', 'authors'])
required graph:  [('books', 'authors')]
('books', 'authors')
FROM books
JOIN authors ON books.author_id=authors.author_id 



## Identify Limit

In [342]:
def indentify_limit(tokens):
    for i, token in enumerate(tokens):
        if token.lower() == "limit" or token.lower() == "limits":
            if tokens[i+1].isdigit():
                return f"LIMIT {tokens[i+1]}"
            else:
                return f"LIMIT 1"
    
    return -1

In [347]:
print(q6)
q6_nk_tok = preprocess_text(q6, constants)
print(q6_nk_tok)
limit_tok6 = indentify_limit(q6_nk_tok)
limit_tok6

What are the top 5 most expensive products?
['select', 'limit', '5', 'max', 'expensive', 'product']


'LIMIT 5'

## Identify Condition

In [462]:
print(q1)
print(preprocess_text(q1, constants))
print((q2))
print(preprocess_text(q2, constants))
print((q3))
print(preprocess_text(q3, constants))
print((q4))
print(preprocess_text(q4, constants))
print((q5))
print(preprocess_text(q5, constants))
print((q6))
print(preprocess_text(q6, constants))
print((q7))
print(preprocess_text(q7, constants))
print((q8))
print(preprocess_text(q8, constants))

What is the total number of customers?
['select', 'count', 'customer']
List all products names with a price greater than $50.
['select', 'products_names', 'price', '>', '50']
Show the first name of customers who have placed orders, sorted by their last name.
['select', 'first_name', 'of', 'customer', 'placed', 'order', 'order', 'by', 'last_name']
What is the total amount of orders for each customer?
['select', 'sum', 'amount', 'of', 'order', 'group', 'by', 'custome']
Which categories have more than 5 products?
['select', 'category', '>', '5', 'product']
What are the top 5 most priciest products?
['select', 'limit', '5', 'max', 'priciest', 'product']
Find all orders with more than 3 items.
['select', 'all', 'order', '>', '3', 'item']
List the top 3 customers who have spent the most on 'Electronics' category products.
['select', 'limit', '3', 'customer', 'max', 'on', "'Electronics'", 'category', 'product']


In [None]:
# import re

def identify_condition(tokens, req):
    conditions = {"where": [], "having": []}
    # "where" or "having" is in the prompt
    for i, token in enumerate(tokens):
        if token.lower() == "where":
            conditions['where'].append(" ".join(tokens[i+1:])) # refine with 
            break

        if token.lower() == "having":
            conditions['having'].append(" ".join(tokens[i+1:])) # refine with 
            break

        
    print(f"Conditions: {conditions}")
    # Example Output: Conditions: ["age > 30"]


In [None]:
def extract_conditions_with_clauses(tokens, schema):
    """
    Extracts WHERE and HAVING conditions from a natural language query based on schema.

    Parameters:
    - input_query (str): The natural language query.
    - schema (dict): A dictionary containing table names as keys and list of column names as values.

    Returns:
    - dict: A dictionary with separate lists of WHERE and HAVING conditions.
    """
    # Define regex patterns for conditions and aggregate functions
    comparison_pattern = r'(\w+|\w+\s*\(.*?\))\s*(=|>|<|>=|<=|!=)\s*("[^"]*"|\'[^\']*\'|\d+(\.\d+)?)'
    aggregate_functions = ["SUM", "AVG", "COUNT", "MAX", "MIN"]

    # Find all matches in the input query
    
    input_query = (' ').join(tokens)
    matches = re.findall(comparison_pattern, input_query)

    where_conditions = []
    having_conditions = []

    for match in matches:
        column_or_expression = match[0]
        operator = match[1]
        value = match[2]

        # Remove quotes from value if they exist
        value = value.strip('"').strip("'")

        # Check if the column_or_expression contains an aggregate function
        is_aggregate = any(func in column_or_expression.upper() for func in aggregate_functions)

        # If it is an aggregate function, add it to HAVING
        if is_aggregate or "COUNT" in input_query.upper():  # Handle implied aggregates
            having_conditions.append({
                "column_or_expression": column_or_expression if column_or_expression.upper() != "COUNT" else "COUNT(*)",
                "operator": operator,
                "value": value
            })
        else:
            # Check if the column is part of the schema for WHERE clause
            is_where = any(
                column_or_expression in columns for columns in schema.values()
            )
            if is_where:
                where_conditions.append({
                    "column": column_or_expression,
                    "operator": operator,
                    "value": value
                })
            else:
                # If not found in schema, treat it as part of HAVING by default
                having_conditions.append({
                    "column_or_expression": column_or_expression,
                    "operator": operator,
                    "value": value
                })

    return {
        "where": where_conditions,
        "having": having_conditions
    }


# Extract conditions
q5_tok = preprocess_text(q5, constants)
group_tok5 = identify_group_by(q5_tok, SQL)
req = indentify_col_tables(group_tok5, SQL)
conditions = extract_conditions_with_clauses(group_tok5, req)
print(conditions)


['categories', 'products']
{'where': [], 'having': [{'column_or_expression': 'COUNT(*)', 'operator': '>', 'value': '5'}]}


In [498]:
group_tok5

['select', 'all', '>', '3', 'item']

## identify order by

In [None]:
def indentify_order_by(tokens, req):
    result = []
    
    # if "order by" appears in tokens
    i = 0
    while i < len(tokens) - 2:
        if tokens[i].lower() == "order" and tokens[i+1].lower() == "by":
            # print(f"{tokens[i]}_{tokens[i+1]}")
            for table, columns in req.items():
                # for columns in req[table].values():
                # for column in columns.keys():
                match = get_close_matches(tokens[i+2], columns)
                if match:
                    if "max" in tokens:
                        result.append(f"{tokens[i]} {tokens[i+1]} {match[0]} DESC")    
                    else:
                        result.append(f"{tokens[i]} {tokens[i+1]} {match[0]}")
            i += 3
        else:
            result.append(tokens[i])
            i += 1
    if i < len(tokens):  # Add any remaining token
        result.extend(tokens[i:])

    print(list(req.values())[0])
    
    # if just "max" appears in tokens
    for token in result:
        match = get_close_matches(token, list(req.values())[0])
        if match:
            print(match)
            if "max" in result:
                print(f"{match[0]} desc") # TODO: add this to template (to be created)
            else:
                print(f"{match[0]}") # TODO: add this to template (to be created)

    return list(dict.fromkeys(result)) # removing duplicates

In [424]:
print((q3))
print(preprocess_text(q3, constants))
print((q6))
print(preprocess_text(q6, constants))
print((q8))
print(preprocess_text(q8, constants))

Show the first name of customers who have placed orders, sorted by their last name.
triggered sorted by
['Show', 'the', 'first_name', 'of', 'customers', 'who', 'have', 'placed', 'orders', 'sorted by']
['select', 'first_name', 'of', 'customer', 'placed', 'order', 'order', 'by', 'last_name']
What are the top 5 most expensive products?
['select', 'limit', '5', 'max', 'expensive', 'product']
List the top 3 customers who have spent the most on 'Electronics' category products.
['select', 'limit', '3', 'customer', 'max', 'on', "'Electronics'", 'category', 'product']


In [460]:
temp_token = preprocess_text(q6, constants)
print(temp_token)
temp_nk_token = remove_keywords(temp_token, ALL_KEYWORDS)
group_tok = identify_group_by(temp_nk_token, SQL)
req = indentify_col_tables(group_tok, SQL)
indentify_order_by(temp_token, req)
# req

['select', 'limit', '5', 'max', 'priciest', 'product']
['products']
['price']
['price']
order by price desc


['select', 'limit', '5', 'max', 'priciest', 'product']

In [426]:
temp_token = preprocess_text(q3, constants)
temp_nk_token = remove_keywords(temp_token, ALL_KEYWORDS)
group_tok = identify_group_by(temp_nk_token, SQL)
req = indentify_col_tables(group_tok, SQL)
indentify_order_by(temp_token, req)

triggered sorted by
['Show', 'the', 'first_name', 'of', 'customers', 'who', 'have', 'placed', 'orders', 'sorted by']
['customers', 'orders']


['select',
 'first_name',
 'of',
 'customer',
 'placed',
 'order',
 'order by last_name']

In [433]:
q6 = "What are the top 5 most priciest products?"

In [443]:
get_close_matches('priciest', ['price'])

['price']

## Pipeline

preprocess -> remove keywords -> group by -> columns & table identification -> join -> order by -> conditions -> where | having

In [993]:
QUERY_TEMPLATE = {
    "SELECT": "",
    "FROM": -1,
    "WHERE": [], 
    'HAVING': [],
    "GROUP BY": -1,
    "ORDER BY": -1,
    "LIMIT": -1
}

CONSTANTS = constants

DB_SCHEMA = SQL

In [994]:
class QueryGenerator:
    def __init__(self):
        self.db_schema = DB_SCHEMA
        self.req_schema = dict()
        self.constants = CONSTANTS
        self.query_template = QUERY_TEMPLATE

    def replace_string(self, text, replace_dict):
        # Create a reverse mapping from values to keys
        value_to_key = {}
        for key, values in replace_dict.items():
            for value in values:
                value_to_key[value.lower()] = key.lower()
        
        # Create a regex pattern for matching words
        pattern = r'\b(' + '|'.join(re.escape(word) for word in value_to_key.keys()) + r')\b'
        
        # Function to replace matched words
        def replace_word(match):
            return value_to_key.get(match.group(0).lower(), match.group(0))
        
        # Perform the replacement
        replaced_text = re.sub(pattern, replace_word, text, flags=re.IGNORECASE)
    
        return replaced_text

    def replace_numbers(self, token):
        try:
            return w2n.word_to_num(token)
        except:
            return token
        
    def merge_before_token(self, tokens, merge_key):

        result = []
        i = 0

        while i < len(tokens) - 1:  # Changed to len(tokens) - 1
            if tokens[i].lower() not in self.constants['ALL_KEYWORDS'] and (tokens[i + 1].lower() == merge_key or tokens[i + 1].lower() == f"{merge_key}s"):
                print("condition trig")
                # print(f"{tokens[i]}_{tokens[i+1]}")
                result.append(f"{tokens[i]}_{tokens[i+1]}")
                i += 2
            else:
                result.append(tokens[i])
                i += 1
        if i < len(tokens):  # Add any remaining token
            result.append(tokens[i])

        print("merged", result)

        return result

    def preprocess_text(self, text):
        # Lowercase
        # text = text.lower() # change it to lower while comparisson
        
        # Remove punctuation
        # pattern = r'[^\w\s()*\d-]|(?<!\d)-(?!\d)'
        # text = re.sub(pattern, '', text)
        text = re.sub(r'\?', '', text)
        text = re.sub(r'\$', '', text)
        text = re.sub(r'\.$', '', text)
        text = re.sub(r',', '', text)
        pattern = r'\b(\d+)-(\d+)\b'
        text = re.sub(pattern, r'\1 to \2', text)


        tokens = text.split()

        # "name" condition
        # result = []
        # i = 0
        # while i < len(tokens) - 1:  # Changed to len(tokens) - 1
        #     if tokens[i].lower() not in self.constants['ALL_KEYWORDS'] and (tokens[i+1].lower() == "name" or tokens[i+1].lower() == "names"):
        #         # print(f"{tokens[i]}_{tokens[i+1]}")
        #         result.append(f"{tokens[i]}_{tokens[i+1]}")
        #         i += 2
        #     else:
        #         result.append(tokens[i])
        #         i += 1
        # if i < len(tokens):  # Add any remaining token
        #     result.append(tokens[i])

        result = self.merge_before_token(tokens, merge_key="name")

        result = self.merge_before_token(result, merge_key="date")

        i = 0
        result2 = []
        # group_by_keys = self.constants['sql_clauses']['GROUP BY']
        # group_by_keys = [key.split('by')[0].strip() for key in group_by_keys if key.split('by')[0] != '']
        # print("group keys", group_by_keys)

        order_by_keys = self.constants['sql_clauses']['ORDER BY']
        order_by_keys = [key.split('by')[0].strip() for key in order_by_keys if key.split('by')[0] != '']
        # print("order keys", order_by_keys)

        while i < len(result) - 1:
            if result[i+1].lower() =='by' and result[i].lower() in order_by_keys: #(result[i].lower() in group_by_keys):
                result2.append(f"{result[i]} {result[i+1]}")
                i += 2
            else:
                result2.append(result[i])
                i += 1
        if i < len(result):  # Add any remaining token
            result2.append(result[i])

        # print(result)
        text = ' '.join(result2)

        # substitute words
        text = self.replace_string(text, self.constants['sql_clauses'])
        text = self.replace_string(text, self.constants['aggregate_functions'])
        # from \d+ -> >=
        text = self.replace_string(text, self.constants['comparison_operators'])
        between_pattern = r'from \b(\w+|\d+)\b to \b(\w+|\d+)\b | \b(\w+|\d+)\b to \b(\w+|\d+)\b'
        text = re.sub(between_pattern, r" between(\1, \2)", text)
        text = self.replace_string(text, self.constants['logical_operators'])

        # Tokenization
        # tokens = word_tokenize(text)
        tokens = text.split()

        tokens = [str(self.replace_numbers(token)) for token in tokens]
        # print(tokens)

        # Remove stopwords
        #  load custom stop words
        tokens = [token for token in tokens if token.lower() not in self.constants['stop_words']]
        tokens = [token for token in tokens if token.lower() not in self.constants['stop_words']]
        tokens = [token for token in tokens if token.lower() not in self.constants['stop_words']]

        # Lemmatization (optional) // important for similarity
        tokens = [lemmatizer.lemmatize(token) for token in tokens]
        
        return tokens# ' '.join(tokens)
    
    def remove_keywords(self, tokens):
        token_without_key = [token for token in tokens if token not in self.constants['ALL_KEYWORDS']]

        return token_without_key
    
    def identify_group_by(self, tokens):
        result = []
        i = 0
        while i < len(tokens) - 2:  # Changed to len(tokens) - 1

            if tokens[i].lower() == "group" and tokens[i+1].lower() == "by":
                # print(f"{tokens[i]}_{tokens[i+1]}")
                for table, columns in self.db_schema.items():
                    # for column in self.db_schema[table].keys():
                    # for column in columns.keys():
                    match = get_close_matches(tokens[i+2], columns)
                    if match:
                        print(tokens[i+2], ":", match)     
                        self.query_template["GROUP BY"] = f"{match[0]}"                   
                        result.append(f"{tokens[i]} {tokens[i+1]} {match[0]}")
                        break
                i += 3
            else:
                result.append(tokens[i])
                i += 1
        if i < len(tokens):  # Add any remaining token
            result.extend(tokens[i:])

        # self.query_template["GROUP BY"] = -1
        print("gleice: ", result)
        return result
    
        # group by before this
    def indentify_table(self, tokens):
        
        detected_table = [[], [], []] # [[table names], [similarities], [tokens]]
        
        for i, token in enumerate(tokens):
            if token.lower() == "from":
                table_name = tokens[i + 1]
                if table_name in self.db_schema:
                    print(f"Table identified: {table_name}")
                    return table_name
            
            for table in self.db_schema.keys():
                similarity = fuzz.ratio(token, table.lower())
                if token in detected_table[2]:
                    # print(f"Table identified: {token} {table} {similarity}")
                    # print(detected_table)
                    i = detected_table[2].index(token)
                    if similarity >= detected_table[1][i]:
                        # detected_table[0] = table
                        detected_table[1][i] = similarity
                        detected_table[0][i] = table
                    # print(f"Table identified: {table} {similarity}")
                else:
                    if similarity >= 70: # initial threshold
                        detected_table[0].append(table)
                        detected_table[1].append(similarity)
                        detected_table[2].append(token)


        for i in range(len(detected_table[0])):
            if detected_table[1][i] > 85:
                tokens.remove(detected_table[2][i])    
            
        # print("dsfsdgd",detected_table[0])

        return detected_table[0]


    def indentify_col_tables(self, tokens):
        res = dict()

        identified_table = self.indentify_table(tokens)
        print("table",identified_table)
        if not tokens:
            res[identified_table[0]] = ['*']
            return res

        if identified_table:
            for table in identified_table:
                res[table] = dict()
                for token in tokens:
                    for column in self.db_schema[table].keys():
                        # for column in columns.keys():
                        if column=='pk' or column=='fk':
                            continue
                        else:
                            similarity = fuzz.ratio(token, column.lower())
                            if similarity >= 55:
                                if res[table].get(token):
                                    # print("case 1")
                                    # print(token, ":", column, similarity)
                                    old_similarity = list(res[table][token].values())[0]
                                    # print("sim", old_similarity)
                                    if similarity > old_similarity:
                                        res[table][token] = {column: similarity}
                                else:
                                    # print("case 2")
                                    # print(token, ":", column, similarity)
                                    res[table][token] = {column: similarity}
                                # print(res)
                                # if res.get(identified_table):
                                #     res[identified_table].append(column)
                                # else:
                                #     res[identified_table] = [column]
        else:
            for token in tokens:
                for table, columns in self.db_schema.items():
                    # for column in self.db_schema[table].keys():
                    # # for column in columns.keys():
                    #     if column=='pk' or column=='fk':
                    #         continue
                    #     else:
                    
                    cols = list(map(str.lower, list(columns.keys())))
                            
                    match = get_close_matches(token, cols)
                    # print("abracadabra: ", token, columns)
                    if match:
                        # print("match found:",match)
                        # print("*"*4, token, ":", match[0])
                        # if token in columns: # replace by fuzzy logic
                        # print(f"Column identified: {column} in table {table}")
                        if res.get(table):
                            res[table].add(match[0])
                        else:
                            res[table] = set(match)

        if identified_table:
            result = {}
            for table, token in res.items():
                cols = set()
                for col in token.values():
                    cols.update(col.keys())
                result[table] = cols

            result = {key: value for key, value in result.items() if value != set()}
            return result
        

        res = {key: value for key, value in res.items() if value != set()}    

        return res


    # JOIN
    def create_graph(self, directional=False):
        if directional:
            G = nx.DiGraph()
        else:
            G = nx.Graph()

        # Add nodes and edges based on the SQL dictionary
        for table, details in self.db_schema.items():
            G.add_node(table)  # Add table as a node
            for fk_table in details['fk']:  # Iterate through foreign keys
                G.add_edge(table, fk_table)  # Create an edge from current table to foreign key table

        # Get nodes and edges for verification
        nodes = list(G.nodes)
        edges = list(G.edges)

        # Output nodes and edges
        # print(nodes)
        # print(edges)

        return G
    
    def required_tables_graph(self, G, start, end, required_tables):
        def dijkstra(graph, start, end):
            distances = {node: float('infinity') for node in graph}
            distances[start] = 0
            pq = [(0, start)]
            previous = {node: None for node in graph}

            while pq:
                current_distance, current_node = heapq.heappop(pq)

                if current_node == end:
                    path = []
                    while current_node:
                        path.append(current_node)
                        current_node = previous[current_node]
                    return path[::-1], current_distance

                for neighbor in graph[current_node]:
                    distance = current_distance + 1  # All edges have weight of 1
                    if distance < distances[neighbor]:
                        distances[neighbor] = distance
                        previous[neighbor] = current_node
                        heapq.heappush(pq, (distance, neighbor))

            return None, float('infinity')

        required_tables = set(required_tables) - {start, end}
        best_path = None
        best_distance = float('infinity')

        def dfs(current_path, current_distance, remaining_required):
            nonlocal best_path, best_distance

            if not remaining_required:
                path, distance = dijkstra(G, current_path[-1], end)
                if path:
                    total_path = current_path + path[1:]
                    total_distance = current_distance + distance
                    if total_distance < best_distance:
                        best_path = total_path
                        best_distance = total_distance
                return

            for node in remaining_required:
                path, distance = dijkstra(G, current_path[-1], node)
                if path:
                    new_path = current_path + path[1:]
                    new_distance = current_distance + distance
                    new_remaining = remaining_required - {node}
                    dfs(new_path, new_distance, new_remaining)

        dfs([start], 0, required_tables)

        return best_path, best_distance
    
    def graph_sort(self, edges):
        counts = Counter(t[0] for t in edges)
        
        # Sort tuples based on the counts in descending order
        sorted_tuples = sorted(edges, key=lambda x: counts[x[0]], reverse=True)
        
        return sorted_tuples

    def remove_duplicate_keys(self, edges):

        filtered_result = dict()

        for edge in edges:

            # Extract the keys from the dictionaries
            table0 = self.req_schema[edge[0]]
            table1 = self.req_schema[edge[1]]

            # Create a new dictionary with only the edge[0] key
            # and its value as the set of keys that are unique to edge[0]
            filtered_result[edge[0]]= table0 - table1
            if table1 - table0:
                filtered_result[edge[1]]= table1 - table0


            # If the result is an empty set, we want to keep the original edge[0] value
            if not filtered_result[edge[0]]:
                filtered_result[edge[0]] = self.req_schema[edge[0]]

        print("updated schema:",filtered_result)

        # Get all unique keys across all tables
        # all_keys = set().union(*self.req_schema.values())
        
        # # Create a dictionary to store the count of each key
        # key_count = {key: sum(1 for table_keys in self.req_schema.values() if key in table_keys) for key in all_keys}
        
        # # Create the result dictionary
        # result = {}
        
        # for table, keys in self.req_schema.items():
        #     # Keep only the keys that are unique to this table
        #     unique_keys = {key for key in keys if key_count[key] == 1}
            
        #     # If there are no unique keys, keep all original keys for this table
        #     if not unique_keys:
        #         unique_keys = keys
            
        #     # Add the table and its unique keys to the result
        #     result[table] = unique_keys
        
        return filtered_result

    def join_clause(self):
        clause = ""

        # not required
        # max_cols = 0
        # prim_table = None
        # # change logic based on foreign key numbers
        # for table, cols in req_schema.items():
        #     if len(cols) > max_cols:
        #         max_cols = len(cols)
        #         prim_table = table
        
        # # pk_col = db_schema[prim_table]['pk']
        # print(prim_table)

        # create graph from Database schema
        db_graph = self.create_graph()
        db_dir_graph = self.create_graph(directional=True)

        # print(db_dir_graph)

        required_tables = list(self.req_schema.keys())

        print(required_tables)

        min_dist = float('inf')
        for st_table in required_tables:
            for end_table in required_tables:
                sub_graph, distance = self.required_tables_graph(db_graph, st_table, end_table, required_tables)
                if distance < min_dist:
                    join_graph = OrderedSet(sub_graph)
                    min_dist = distance

        print(join_graph)


        req_graph = []

        for edge in db_dir_graph.edges:
            # assumption binary relations between tables
            if edge[0] in join_graph and edge[1] in join_graph:
                req_graph.append(edge)

        req_graph = self.graph_sort(req_graph)
        print("required graph: ", req_graph)

        self.req_schema = self.remove_duplicate_keys(req_graph)

        if len(self.req_schema) >1 :

            clause += f"{req_graph[0][0]}\n"

            for i, edge in enumerate(req_graph):
                table1 = edge[0]
                table2 = edge[1]
                # print(edge)
                fk_col = self.db_schema[table1]['fk'][table2] # foreign key corresponding to 2nd table  primary key
                pk_col = self.db_schema[table2]['pk'][0] # primary key of 2nd table
                clause += f"JOIN {table2} ON {table1}.{fk_col}={table2}.{pk_col} \n"

            self.query_template['FROM'] = clause
        else:
            clause += f"{req_graph[0][0]}\n"
            self.query_template['FROM'] = clause

            
        return clause

    def aggregate_parser(self, tokens):

        res = []
        i = 0
        thresh = 50
        pot_col = ""
        while i< len(tokens):
            if tokens[i].upper() in list(self.constants['aggregate_functions'].keys()):
                for table in self.db_schema.keys():
                    print(table)
                    cols = list(self.db_schema[table].keys()) # too check with missing cols from prev functions
                    print(cols)
                    for col in cols:
                        similarity = fuzz.ratio(tokens[i+1], col)
                        if similarity>thresh:
                            thresh = similarity
                            pot_col = f"{table.lower()}.{col.lower()}"
                            print(thresh, pot_col)
                res.append(f"{tokens[i].upper()}({pot_col})")
                i+=2
                    # i+=1
            else:
                res.append(tokens[i])
                i+=1
            # i+=1
            
        print("agg_parser",res)

        return res

    # CONDITION - WHERE & HAVING    
    def extract_conditions(self, tokens):
        """
        Extracts WHERE and HAVING conditions from a natural language query based on keywords 'where' and 'having',
        or based on comparison operators (=, >, <, <=, >=).

        Parameters:
        - input_query (str): The natural language query.
        - req_schema (dict): A dictionary containing table names as keys and list of column names as values.

        Returns:
        - dict: A dictionary with separate lists of WHERE and HAVING conditions, and aggregate functions.
        """

        tokens_copy = tokens.copy()
        print(tokens_copy)
        # Define regex patterns for comparison operators and aggregate functions
        comparison_pattern = r'(=|>|<|>=|<=|!=)'  # To locate comparison operators
        # aggregate_functions = ["SUM", "AVG", "COUNT", "MAX", "MIN"]

        # Initialize results
        where_conditions = []
        having_conditions = []

        # Split input query into tokens for easier parsing
        # tokens = input_query.split()
        input_query = (' ').join(tokens_copy)
        print(input_query)

        # Detect keywords and their positions
        where_pos = input_query.lower().find("where")
        having_pos = input_query.lower().find("having")

        # Helper function to find nearest column and value around a comparison operator
        def find_column_and_value(tokens_copy, operator_index, req_schema):
            column = None
            value = None

            # Look before and after the operator for a column and a value
            if operator_index > 0:
                potential_column = tokens_copy[operator_index - 1]
                formatted_keys = [f'{column.lower()}' 
                        for table, columns in req_schema.items() 
                        for column in columns]
                print(req_schema)
                print("inside_col_cal_func", potential_column,formatted_keys)
                match = get_close_matches(potential_column.lower(), formatted_keys)
                print(match)
                if match:
                    column = potential_column

            if operator_index < len(tokens_copy) - 1:
                potential_value = tokens_copy[operator_index + 1]
                # Check if it's a quoted value or a number
                if re.match(r"^'[^']*'$|^\d+(\.\d+)?$|^(?:\d{2}[/-]\d{2}[/-]\d{4})$", potential_value):
                    value = potential_value.strip("'").strip('"')
                # print("%%%%%",value)

            return column, value

        res = self.aggregate_parser(tokens_copy)

        agg_flag = False

        for token in res:
            for agg in aggregate_functions.keys():
                if agg.lower() in token.lower():
                    agg_flag = True

        for i, token in enumerate(res):
            if re.match(comparison_pattern, token):  # If it's a comparison operator
                column, value = find_column_and_value(res, i, self.req_schema)
                print("col and val", column, value)

                if column and value:
                    condition = {
                        "column": column,
                        "operator": token,
                        "value": value
                    }
                    print("cond:", condition)
                    print("where:", where_pos)
                    print("having:", having_pos)

                    # Check if the condition belongs to WHERE or HAVING based on position
                    if where_pos != -1 and having_pos == -1: # iff "where" in tokens
                        where_conditions.append(condition)
                    elif having_pos != -1 and where_pos == -1: # iff "where" in tokens
                        having_conditions.append(condition)
                    elif agg_flag: # iff "aggregate_function" in tokens
                        having_conditions.append(condition)
                    else: # default
                        where_conditions.append(condition)

        # # Detect aggregate functions
        # for func in aggregate_functions:
        #     match = re.search(rf'\b{func}\s*\((.*?)\)', input_query, re.IGNORECASE)
        #     if match:
        #         aggregate_query = {
        #             "function": func.upper(),
        #             "column_or_expression": match.group(1).strip() or "*"
        #         }
        #         break  # Only detect the first aggregate function

        self.query_template['WHERE'] = where_conditions
        self.query_template['HAVING'] = having_conditions

        return res, {"where": where_conditions, "having": having_conditions}

    def indentify_order_by(self, tokens):
        result = []
        
        # if "order by" appears in tokens
        i = 0
        while i < len(tokens) - 2:
            if tokens[i].lower() == "order" and tokens[i+1].lower() == "by":
                # print(f"{tokens[i]}_{tokens[i+1]}")
                for table, columns in self.req_schema.items():
                    # for columns in self.req_schema[table].values():
                    # for column in columns.keys():
                    match = get_close_matches(tokens[i+2], columns)
                    if match:
                        if "max" in tokens:
                            self.query_template["ORDER BY"] = f"{table}.{match[0]} DESC"
                            result.append(f"{tokens[i]} {tokens[i+1]} {table}.{match[0]} DESC")    
                        else:
                            self.query_template["ORDER BY"] = f"{table}.{match[0]}"
                            result.append(f"{tokens[i]} {tokens[i+1]} {table}.{match[0]}")
                i += 3
            else:
                result.append(tokens[i])
                i += 1
        if i < len(tokens):  # Add any remaining token
            result.extend(tokens[i:])

        print(self.req_schema)

        # print(list(self.req_schema.values())[0])
        
        # if just "max" appears in tokens
        for token in result:
            for table, cols in self.req_schema.items():
                match = get_close_matches(token, cols)
                if match:
                    print(match)
                    if "max" in result:
                        self.query_template["ORDER BY"] = f"{table}.{match[0]} DESC"
                        print(f"{match[0]} desc") 
                    else:
                        self.query_template["ORDER BY"] = f"{table}.{match[0]}"
                        print(f"{match[0]}")

        return list(dict.fromkeys(result)) # removing duplicates
    
    # LIMIT
    def indentify_limit(self, tokens):
        limit = -1
        for i, token in enumerate(tokens):
            if token.lower() == "limit" or token.lower() == "limits":
                if tokens[i+1].isdigit():
                    limit = tokens[i+1]
                    self.query_template["LIMIT"] = tokens[i+1]
                    return
                else:
                    limit = 1
                    self.query_template["LIMIT"] = 1
        
        # self.query_template["LIMIT"] = -1

        return limit


    

    def select_validator(self, tokens):
        # making sure all the required columns are present in select clause
        
        # check group by and having condition
        self.query_template['SELECT'] = ""
        select_cols = set()

        agg_funcs = list(aggregate_functions.keys())

        for token in tokens:
            for agg in aggregate_functions:
                if agg in token:
                    select_cols.add(token)

        for table, cols in self.req_schema.items():
            for col in cols:
                select_cols.add(f"{table}.{col}")

        group_by_col = self.query_template['GROUP BY']
        if group_by_col:
            for table, cols in self.req_schema.items():
                if group_by_col in cols:
                    select_cols.add(f"{table}.{group_by_col}")
                    self.query_template['GROUP BY'] = f"{table}.{group_by_col}"

        
        having_col = self.query_template['HAVING']
        if group_by_col and having_col:
            for table, cols in self.req_schema.items():
                if group_by_col in cols:
                    select_cols.add(f"{table}.{group_by_col}")
        

        for cols in select_cols:
            self.query_template['SELECT']+=f"{cols}, "

        self.query_template['SELECT'] = self.query_template['SELECT'].strip(', ')

        return

    # GRAPH GENERATOR - seperate function/class
    # def graph_generator():
    #     pass

    # template generator - not required
    def query_template_generator(self, input) -> dict:

        # preprocess -> remove keywords -> group by -> columns & table identification 
        # -> join -> order by -> conditions -> where | having

        self.query_template = QUERY_TEMPLATE.copy()

        self.input = input
        print("******input",self.input)

        self.tokens = self.preprocess_text(input)
        print("******tokens",self.tokens)
        self.tokens_nk = self.remove_keywords(self.tokens)
        print("******tokens_nk",self.tokens_nk)

        self.tokens_group = self.identify_group_by(self.tokens_nk)
        print("******tokens_group",self.tokens_group)
        self.req_schema = self.indentify_col_tables(self.tokens_group)
        print("******req_schema",self.req_schema)

        if len(self.req_schema)>1:
            self.from_clause = self.join_clause()
            print("******",self.from_clause)
        else:
            for table in self.req_schema.keys():
                self.query_template['FROM'] = table

        self.agg_tokens, self.conditions = self.extract_conditions(self.tokens)
        print("******conditions",self.conditions)
        print("******agg_tokens",self.agg_tokens)


        self.order_by = self.indentify_order_by(self.tokens)
        print("******order_by",self.order_by)

        self.limit = self.indentify_limit(self.tokens)
        print("******limit",self.limit)

        self.select = self.select_validator(self.agg_tokens)


        return self.query_template

    # Driver function - Query generator
    # def sql_parser(self, parsed_dict: dict):
    #     query = ""

    #     return query



In [995]:
qg = QueryGenerator()

q = "Find the patients names, admissions roomnumber  and the insurance provider"

qg.query_template_generator(q)

******input Find the patients names, admissions roomnumber  and the insurance provider
condition trig
merged ['Find', 'the', 'patients_names', 'admissions', 'roomnumber', 'and', 'the', 'insurance', 'provider']
merged ['Find', 'the', 'patients_names', 'admissions', 'roomnumber', 'and', 'the', 'insurance', 'provider']
******tokens ['select', 'patients_names', 'admission', 'roomnumber', 'and', 'insurance', 'provider']
******tokens_nk ['patients_names', 'admission', 'roomnumber', 'insurance', 'provider']
gleice:  ['patients_names', 'admission', 'roomnumber', 'insurance', 'provider']
******tokens_group ['patients_names', 'admission', 'roomnumber', 'insurance', 'provider']
table ['patients', 'admissions', 'insurance']
******req_schema {'patients': {'patientname'}, 'admissions': {'patientid', 'roomnumber'}, 'insurance': {'insuranceprovider', 'patientid'}}
['patients', 'admissions', 'insurance']
OrderedSet(['patients', 'insurance', 'admissions'])
required graph:  [('admissions', 'patients'), (

{'SELECT': 'insurance.patientid, admissions.roomnumber, insurance.insuranceprovider, patients.patientname',
 'FROM': 'admissions\nJOIN patients ON admissions.patientid=patients.patientid \nJOIN insurance ON admissions.insuranceid=insurance.insuranceid \nJOIN patients ON insurance.patientid=patients.patientid \n',
 'WHERE': [],
 'HAVING': [],
 'GROUP BY': -1,
 'ORDER BY': 'insurance.insuranceprovider',
 'LIMIT': -1}

In [987]:
merge_key = "name"

"name" == merge_key or "name" == f"{merge_key}s"

True

In [913]:
get_close_matches("order", ["order_date"])

['order_date']

In [961]:
class SQLGenerator(QueryGenerator):
    def __init__(self):
        super().__init__()
    
    def sql_parser(self, input):
        query_template = super().query_template_generator(self, input)

        query = ""

        for key, values in query_template.items():
            if query_template[key] != -1 or query_template[key]:
                if key!="WHERE" and key != "HAVING":
                    print(key)
                    query += f"{key.upper()} {values}\n"
                    # print(query)
                else:

                    print("cond")
                    for cond in query_template[key]:
                        print((" ").join(cond.values()))
                        query += f"{key} {(" ").join(cond.values())} AND "
                    
                    query = query[:-3] # removing extra AND
                    query +=  "\n"

        query = query.strip()
        query += ";"

        return query


In [966]:
check_dict = {'SELECT': 'products.stock_quantity, order_items.quantity',
 'FROM': 'order_items\nJOIN products ON order_items.product_id=products.product_id \n',
 'WHERE': [{'column': 'quantity', 'operator': '>', 'value': '30'}],
 'HAVING': [],
 'GROUP BY': -1,
 'ORDER BY': 'products.stock_quantity DESC',
 'LIMIT': 3}

query = ""

for key, values in check_dict.items():
    if check_dict[key] != -1 or check_dict[key]:
        if key=="FROM":
            query += f"{key.upper()} {values}"
        elif key!="WHERE" and key != "HAVING":
            print(key)
            query += f"{key.upper()} {values}\n"
            # print(query)
        else:

            print("cond")
            for cond in check_dict[key]:
                print((" ").join(cond.values()))
                query += f"{key} {(" ").join(cond.values())} AND "
            
            query = query[:-3] # removing extra AND
            query +=  "\n"

query = query.strip()
query += ";"
print(query)

SELECT
cond
quantity > 30
cond
GROUP BY
ORDER BY
LIMIT
SELECT products.stock_quantity, order_items.quantity
FROM order_items
JOIN products ON order_items.product_id=products.product_id 
WHERE quantity > 30
GROUP BY -1
ORDER BY products.stock_quantity DESC
LIMIT 3;


In [None]:
class NoSQLGenerator(QueryGenerator):
    def __init__(self):
        super().__init__()
        

    def mongod_parser(self, input):
        query_template = super().query_template_generator(self, input=input)

        mongodb_query = "db"
        prim_table = query_template['FROM'].split()[0]

        conditions

        if query_template['GROUP BY']==-1:
            mongodb_query += f".{prim_table}"
            if "DISTINCT" in query_template['SELECT']:
                monogodb_query += f"distinct"


        



In [970]:
def sql_to_mongodb_query(sql_dict):
    """
    Converts a SQL-like dictionary to a MongoDB aggregation pipeline in dictionary format.

    Args:
        sql_dict (dict): SQL-like query dictionary.

    Returns:
        dict: MongoDB aggregation pipeline in dictionary format.
    """
    # Initialize MongoDB query components
    match = {}
    project = {}
    sort = {}
    group = None
    limit = None

    # Process WHERE clause into $match
    where_clauses = sql_dict.get('WHERE', [])
    for condition in where_clauses:
        column = condition['column']
        operator = condition['operator']
        value = condition['value']

        # Map SQL operators to MongoDB operators
        operator_map = {
            '=': '$eq',
            '>': '$gt',
            '<': '$lt',
            '>=': '$gte',
            '<=': '$lte',
            '<>': '$ne'
        }

        mongo_operator = operator_map.get(operator)
        if mongo_operator:
            match[column] = {mongo_operator: int(value) if value.isdigit() else value}

    # Handle SELECT clause as $project
    select_columns = sql_dict.get('SELECT', '').split(', ')
    for column in select_columns:
        project[column.strip()] = 1

    # Process GROUP BY into $group
    group_by = sql_dict.get('GROUP BY')
    if group_by != -1:  # Check if GROUP BY is specified
        group = {"_id": {col.strip(): f"${col.strip()}" for col in group_by.split(', ') if col.strip() != ''}}
        # Add SELECT columns to $group for aggregation
        for column in select_columns:
            column = column.strip()
            if column not in group["_id"]:
                group[column] = {"$first": f"${column}"}  # Default to $first for other fields

    # Process ORDER BY into $sort
    order_by = sql_dict.get('ORDER BY', '')
    if order_by:
        column, direction = order_by.split(' ')
        sort[column.strip()] = -1 if direction.strip().upper() == 'DESC' else 1

    # Handle LIMIT
    limit = sql_dict.get('LIMIT')

    # Construct MongoDB query dictionary
    query_dict = {}
    if match:
        query_dict['match'] = match
    if group:
        query_dict['group'] = group
    if project:
        query_dict['project'] = project
    if sort:
        query_dict['sort'] = sort
    if limit:
        query_dict['limit'] = int(limit)

    return query_dict


# Example usage
sql_dict = {
    'SELECT': 'products.stock_quantity, order_items.quantity',
    'FROM': 'order_items\nJOIN products ON order_items.product_id=products.product_id \n',
    'WHERE': [{'column': 'quantity', 'operator': '>', 'value': '30'}],
    'HAVING': [],
    'GROUP BY': 'products.stock_quantity',
    'ORDER BY': 'products.stock_quantity DESC',
    'LIMIT': 3
}

mongodb_query = sql_to_mongodb_query(sql_dict)
print(mongodb_query)


{'match': {'quantity': {'$gt': 30}}, 'group': {'_id': {'products.stock_quantity': '$products.stock_quantity'}, 'order_items.quantity': {'$first': '$order_items.quantity'}}, 'project': {'products.stock_quantity': 1, 'order_items.quantity': 1}, 'sort': {'products.stock_quantity': -1}, 'limit': 3}
