## Imports

In [None]:
import json
import spacy
from spacy.tokens import DocBin

import re
import nltk
from nltk.tokenize import word_tokenize
# from nltk.corpus import stopwords
# from nltk.stem import PorterStemmer
from nltk.stem import WordNetLemmatizer
import pandas as pd
from difflib import get_close_matches
from thefuzz import fuzz
import networkx as nx
import heapq
from collections import defaultdict
from ordered_set import OrderedSet
from collections import Counter
from word2number import w2n

In [2]:
nltk.download('punkt')
nltk.download('punkt_tab')
# nltk.download('stopwords')
nltk.download('wordnet')

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/vedanttibrewal/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/vedanttibrewal/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/vedanttibrewal/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [3]:
lemmatizer = WordNetLemmatizer()

## Sample questions

In [4]:
SQL = {
    "products": {
        "pk": ["product_id"],
        "fk": {"categories": "category_id"},
        "product_id": "INT",
        "product_name": "VARCHAR(255)",
        "description": "TEXT",
        "price": "DECIMAL(10, 2)",
        "stock_quantity": "INT",
        "category_id": "INT"
    },
    "categories": {
        "pk": ["category_id"],
        "fk": {},
        "category_id": "INT",
        "category_name": "VARCHAR(100)",
        "description": "TEXT"
    },
    "customers": {
        "pk": ["customer_id"],
        "fk": {},
        "customer_id": "INT",
        "first_name": "VARCHAR(50)",
        "last_name": "VARCHAR(50)",
        "email": "VARCHAR(100)",
        "phone_number": "VARCHAR(20)",
        "address": "TEXT"
    },
    "orders": {
        "pk": ["order_id"],
        "fk": {
            "customers": "customer_id"
        },
        "order_id": "INT",
        "customer_id": "INT",
        "order_date": "DATE",
        "total_amount": "DECIMAL(10, 2)",
        "status": "VARCHAR(20)"
    },
    "order_items": {
        "pk": ["order_item_id"],
        "fk": {
            "orders": "order_id",
            "products": "product_id"
        },
        "order_item_id":"INT",
        "order_id": "INT",
        "product_id": "INT",
        "quantity": "INT",
        "unit_price": "DECIMAL(10, 2)"
    }
}

req = {
    "products": [
        "product_id",
        "product_name",
        "price"
    ],
    "customers": [
        "customer_id",
        "first_name",
        "last_name",
        "email"
    ],
    "orders": [
        "order_id",
        "customer_id",
        "order_date",
        "total_amount"
    ]
}

In [33]:
# Basic COUNT query
q1 = "What is the total number of customers?"
sql1 = "SELECT COUNT(*) FROM customers;"

# Simple SELECT with WHERE clause
q2 = "List all products names with a price greater than $50."
sql2 = "SELECT product_name, price FROM products WHERE price > 50;"

# JOIN operation with ORDER BY
q3 = "Show the first name of customers who have placed orders, sorted by their last name."
sql3 = """
SELECT DISTINCT c.first_name, c.last_name 
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
ORDER BY c.last_name;
"""

# Aggregation with GROUP BY
q4 = "What is the total amount of orders for each customer?"
sql4 = """
SELECT c.customer_id, c.first_name, c.last_name, SUM(o.total_amount) as total_spent
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
GROUP BY c.customer_id, c.first_name, c.last_name;
"""

# HAVING clause
q5 = "Which categories have more than 5 products?"
sql5 = """
SELECT c.category_name, COUNT(p.product_id) as product_count
FROM categories c
JOIN products p ON c.category_id = p.category_id
GROUP BY c.category_name
HAVING COUNT(p.product_id) > 5;
"""

# LIMIT clause
q6 = "What are the top 5 most expensive products?"
sql6 = """
SELECT product_name, price
FROM products
ORDER BY price DESC
LIMIT 5;
"""

# Aggregation in WHERE clause
q7 = "Find all orders with more than 3 items."
sql7 = """
SELECT o.order_id, COUNT(oi.product_id) as item_count
FROM orders o
JOIN order_items oi ON o.order_id = oi.order_id
GROUP BY o.order_id
HAVING COUNT(oi.product_id) > 3;
"""

# Multiple JOINs
q8 = "List the top 3 customers who have spent the most on 'Electronics' category products."
sql8 = """
SELECT c.customer_id, c.first_name, c.last_name, SUM(oi.quantity * oi.unit_price) as total_spent
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
JOIN order_items oi ON o.order_id = oi.order_id
JOIN products p ON oi.product_id = p.product_id
JOIN categories cat ON p.category_id = cat.category_id
WHERE cat.category_name = 'Electronics'
GROUP BY c.customer_id, c.first_name, c.last_name
ORDER BY total_spent DESC
LIMIT 3;
"""

## Constants

In [6]:
stop_words = [
    'a', 'an', 'the',
    'i', 'me', 'my', 'mine', 'myself', 'you', 'your', 'yours', 'yourself', 'he', 'him', 'his', 'himself', 'she', 'her', 'hers', 'herself', 'it', 'its', 'itself', 'we', 'our', 'ours', 'ourselves', 'they', 'them', 'their', 'theirs', 'themselves',
    'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
    'do', 'does', 'did', 'have', 'has', 'had', 'can', 'could', 'shall', 'should', 'will', 'would', 'may', 'might', 'must',
    'about', 'across', 'after', 'against', 'along', 'around', 'at', 'before', 'behind', 'below', 'beneath', 'beside', 'during', 'into', 'near', 'off', 'out', 'through', 'toward', 'under', 'up', 'with',
    'if', 'then', 'because', 'as', 'until', 'while',
    'this', 'that', 'these', 'those', 'such', 'what', 'which', 'whose', 'whoever', 'whatever', 'whichever', 'whomever', 'either', 'neither', 'both',
    'very', 'really', 'always', 'never', 'too', 'already', 'often', 'sometimes', 'rarely', 'seldom', 'again', 'further', 'then', 'once', 'here', 'there', 'where', 'why', 'how',
    # 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'second', 'third', 'fourth', 'fifth', # find a way to implement it in limit
    'few', 'little', 'much', 'enough',
    'yes', 'no', 'not', 'okay', 'ok', 'right', 'sure', 'well', 'uh', 'um', 'oh', 'eh', 'hmm', 'just', 'ever', 'yet', 'etc', 'perhaps', 'maybe', 'list'
]

In [47]:


aggregate_functions = {
    'COUNT': ['count', 'number of', 'quantity of', 'total number of', 'tally', 'enumerate', 'how many'],
    'SUM': ['sum', 'total', 'aggregate', 'combined', 'add up', 'overall', 'cumulative'],
    'AVG': ['average', 'mean', 'typical', 'median', 'expected value', 'norm', 'central tendency'],
    'MAX': ['maximum', 'highest', 'most', 'top', 'peak', 'greatest', 'largest', 'biggest', 'uppermost'],
    'MIN': ['minimum', 'lowest', 'least', 'bottom', 'smallest', 'tiniest', 'least', 'floor'],
    'DISTINCT': ['unique', 'different', 'distinct', 'individual', 'separate', 'non-duplicate', 'exclusive'],
    'GROUP_CONCAT': ['concatenate', 'combine strings', 'join', 'merge text', 'string aggregation', 'text combination'],
    # 'FIRST': ['first', 'initial', 'earliest', 'primary', 'leading', 'opening', 'foremost'], # limit implementation along with number
    # 'LAST': ['last', 'final', 'latest', 'ultimate', 'concluding', 'terminal', 'closing']
}

comparison_operators = {
    '>': ['greater than', 'more than', 'above', 'over', 'exceeding', 'surpassing', 'beyond', 'higher than', 'in excess of'],
    '<': ['less than', 'fewer than', 'below', 'under', 'beneath', 'lower than', 'not as much as', 'smaller than'],
    '=': ['equal to', 'same as', 'identical to', 'matching', 'equivalent to', 'corresponds to', 'is', 'for'],
    '!=': ['not equal to', 'different from', 'excluding', 'not the same as', 'dissimilar to', 'unlike', 'other than'],
    '>=': ['greater than or equal to', 'at least', 'no less than', 'minimum of', 'not below', ' starting from'],
    '<=': ['less than or equal to', 'at most', 'no more than', 'maximum of', 'not above', 'up to'],
    'BETWEEN': ['between', 'in the range of', 'within the bounds of', 'inside the limits of'],
    'IN': ['in', 'within', 'among', 'included in', 'part of', 'contained in', 'one of'],
    'NOT IN': ['not in', 'outside of', 'excluded from', 'not among', 'not part of', 'not contained in'],
    'LIKE': ['like', 'similar to', 'resembling', 'matching pattern', 'corresponding to'],
    'NOT LIKE': ['not like', 'dissimilar to', 'unlike', 'not matching pattern', 'different from pattern']
}

logical_operators = {
    'AND': ['and', 'also', 'as well as', 'in addition to', 'plus', 'together with', 'along with', 'including'],
    'OR': ['or', 'alternatively', 'either', 'otherwise', 'else', 'and/or'],
    'NOT': ['not', 'except', 'excluding', 'other than', 'but not', 'save for', 'apart from']
}

In [8]:
data_types = {
    'INTEGER': ['integer', 'int', 'whole number', 'numeric'],
    'FLOAT': ['float', 'decimal', 'real number', 'fractional number'],
    'VARCHAR': ['string', 'text', 'characters', 'alphanumeric'],
    'DATE': ['date', 'calendar day', 'day'],
    'TIMESTAMP': ['timestamp', 'date and time', 'moment', 'point in time'],
    'BOOLEAN': ['boolean', 'true/false', 'yes/no', 'binary']
}

In [39]:
# order matters
sql_clauses = {
    'FROM': ['from', 'in', 'out of', 'sourced from', 'derived from', 'based on'],
    'WHERE': ['where', 'for which', 'that have', 'meeting the condition', 'satisfying', 'fulfilling'],
    'GROUP BY': ['group by', 'categorize by', 'classify by', 'organize by', 'arrange by', 'cluster by', 'for each', 'broken down by', 'per', 'by'], # check about 'by'
    'HAVING': ['having', 'with the condition', 'subject to', 'meeting the criteria'],
    'ORDER BY': ['order by', 'sort by', 'arrange by', 'rank by', 'sequence by'],
    'LIMIT': ['limit', 'top', 'first', 'restrict to', 'cap at', 'only show'],
    'JOIN': ['join', 'combine', 'merge', 'connect', 'link', 'associate'],
    'UNION': ['union', 'combine', 'merge', 'incorporate', 'consolidate', 'unite'],
    'INTERSECT': ['intersect', 'in common', 'shared by', 'mutual', 'overlapping'],
    'EXCEPT': ['except', 'subtract', 'exclude', 'remove', 'omit', 'leave out'],
    'SELECT': ['show', 'list all', 'list', 'give', 'return', 'fetch', 'retrieve', 'get', 'find', 'which', 'what is the', 'what is', 'what are', 'what'], # 'display
}

In [38]:
ALL_KEYWORDS = set(list(aggregate_functions.keys())) | (set(list(logical_operators.keys()))) | (set(sql_clauses.keys())) | (set(comparison_operators.keys()))
ALL_KEYWORDS = list(map(str.lower, ALL_KEYWORDS))

In [66]:
ALL_KEYWORDS

['not',
 '<',
 'count',
 'having',
 'avg',
 'and',
 'join',
 'max',
 'select',
 'not like',
 'min',
 'union',
 'like',
 '>=',
 'order by',
 'limit',
 'from',
 'not in',
 'distinct',
 'except',
 'or',
 '=',
 'between',
 'last',
 '!=',
 'intersect',
 'group by',
 'first',
 'where',
 '>',
 'in',
 'sum',
 'group_concat',
 '<=']

In [41]:
constants = {"ALL_KEYWORDS": ALL_KEYWORDS,
             "sql_clauses": sql_clauses,
             "data_types": data_types,
             "logical_operators": logical_operators,
             "comparison_operators": comparison_operators,
             "aggregate_functions": aggregate_functions,
             "stop_words": stop_words}

## Preprocess

In [None]:
def replace_string(text, replace_dict):
    # Create a reverse mapping from values to keys
    value_to_key = {}
    for key, values in replace_dict.items():
        for value in values:
            value_to_key[value.lower()] = key.lower()
    
    # Create a regex pattern for matching words
    pattern = r'\b(' + '|'.join(re.escape(word) for word in value_to_key.keys()) + r')\b'
    
    # Function to replace matched words
    def replace_word(match):
        return value_to_key.get(match.group(0).lower(), match.group(0))
    
    # Perform the replacement
    replaced_text = re.sub(pattern, replace_word, text, flags=re.IGNORECASE)
    
    return replaced_text

def replace_numbers(token):
    try:
        return w2n.word_to_num(token)
    except:
        return token

def preprocess_text(text, constants):
    # Lowercase
    # text = text.lower() # change it to lower while comparisson
    
    # Remove punctuation
    # pattern = r'[^\w\s()*\d-]|(?<!\d)-(?!\d)'
    # text = re.sub(pattern, '', text)
    text = re.sub(r'\?', '', text)
    text = re.sub(r'\$', '', text)
    text = re.sub(r'.$', '', text)
    text = re.sub(r',', '', text)
    pattern = r'\b(\d+)-(\d+)\b'
    text = re.sub(pattern, r'\1 to \2', text)

    # substitute words
    text = replace_string(text, sql_clauses)
    text = replace_string(text, aggregate_functions)
    # from \d+ -> >=
    text = replace_string(text, comparison_operators)
    between_pattern = r'from \b(\w+|\d+)\b to \b(\w+|\d+)\b | \b(\w+|\d+)\b to \b(\w+|\d+)\b'
    text = re.sub(between_pattern, r" between(\1, \2)", text)
    text = replace_string(text, logical_operators)

    

    # Tokenization
    # tokens = word_tokenize(text)
    tokens = text.split()

    # "name" condition
    result = []
    i = 0
    while i < len(tokens) - 1:  # Changed to len(tokens) - 1
        if tokens[i].lower() not in constants['ALL_KEYWORDS'] and (tokens[i+1].lower() == "name" or tokens[i+1].lower() == "names"):
            # print(f"{tokens[i]}_{tokens[i+1]}")
            result.append(f"{tokens[i]}_{tokens[i+1]}")
            i += 2
        else:
            result.append(tokens[i])
            i += 1
    if i < len(tokens):  # Add any remaining token
        result.append(tokens[i])

    # print(result)
    tokens = result

    tokens = [str(replace_numbers(token)) for token in tokens]
    print(tokens)

    # Remove stopwords
    #  load custom stop words
    tokens = [token for token in tokens if token.lower() not in stop_words]

    # Lemmatization (optional) // important for similarity
    tokens = [lemmatizer.lemmatize(token) for token in tokens]
    
    return tokens# ' '.join(tokens)

In [251]:
q1_tok = preprocess_text(q1, constants)
q2_tok = preprocess_text(q2, constants)
q4_tok = preprocess_text(q4, constants)
print(q1)
print(q2)
print(q4)
q1_tok, q2_tok, q4_tok

['select', 'count', 'customer']
['select', 'products_names', 'with', 'a', 'price', '>', '50']
['select', 'sum', 'amount', 'of', 'orders', 'group', 'by', 'custome']
What is the total number of customers?
List all products names with a price greater than $50.
What is the total amount of orders for each customer?


(['select', 'count', 'customer'],
 ['select', 'products_names', 'price', '>', '50'],
 ['select', 'sum', 'amount', 'of', 'order', 'group', 'by', 'custome'])

## Identify the keywords

In [86]:
def remove_keywords(tokens, keywords):
    token_without_key = [token for token in tokens if token not in keywords]

    return token_without_key


In [261]:
q1_nk_tok = remove_keywords(q1_tok, ALL_KEYWORDS)
q2_nk_tok = remove_keywords(q2_tok, ALL_KEYWORDS)
q4_nk_tok = remove_keywords(q4_tok, ALL_KEYWORDS)
q1_nk_tok, q2_nk_tok, q4_nk_tok

(['customer'],
 ['products_names', 'price', '50'],
 ['amount', 'of', 'order', 'group', 'by', 'custome'])

## Identify Group By

In [275]:
def identify_group_by(tokens, schema):
    result = []
    i = 0
    while i < len(tokens) - 2:  # Changed to len(tokens) - 1

        if tokens[i].lower() == "group" and tokens[i+1].lower() == "by":
            # print(f"{tokens[i]}_{tokens[i+1]}")
            for table, columns in schema.items():
                for column in schema[table].keys():
                # for column in columns.keys():
                    if column=='pk' or column=='fk':
                        continue
                    else:
                        match = get_close_matches(tokens[i+2], columns)
                        if match:
                            # print(tokens[i+2], ":", match)                        
                            result.append(f"{tokens[i]} {tokens[i+1]} {match[0]}")
            i += 3
        else:
            result.append(tokens[i])
            i += 1
    if i < len(tokens):  # Add any remaining token
        result.extend(tokens[i:])
    
    return list(dict.fromkeys(result)) # removing duplicates


In [306]:
print(q1)
q1_nk_tok = remove_keywords(q1_tok, ALL_KEYWORDS)
print(q1_nk_tok)
group_tok1 = identify_group_by(q1_nk_tok, SQL)
group_tok1

What is the total number of customers?
['customer']


['customer']

In [278]:
print(q2)
q2_nk_tok = remove_keywords(q2_tok, ALL_KEYWORDS)
print(q2_nk_tok)
group_tok2 = identify_group_by(q2_nk_tok, SQL)
group_tok2

List all products names with a price greater than $50.
['products_names', 'price', '50']


['products_names', 'price', '50']

In [298]:
fuzz.ratio("products_names", "products")

73

In [277]:
print(q4)
q4_nk_tok = remove_keywords(q4_tok, ALL_KEYWORDS)
print(q4_nk_tok)
group_tok4 = identify_group_by(q4_nk_tok, SQL)
group_tok4

What is the total amount of orders for each customer?
['amount', 'of', 'order', 'group', 'by', 'custome']


['amount', 'of', 'order', 'group by customer_id']

## Identify Columns and Tables

In [316]:
# group by before this
def indentify_table(tokens, schema):
    
    detected_table = [[], [], []] # [[table names], [similarities], [tokens]]
    
    for i, token in enumerate(tokens):
        if token.lower() == "from":
            table_name = tokens[i + 1]
            if table_name in schema:
                print(f"Table identified: {table_name}")
                return table_name
        
        for table in schema.keys():
            similarity = fuzz.ratio(token, table)
            if token in detected_table[2]:
                # print(f"Table identified: {token} {table} {similarity}")
                # print(detected_table)
                i = detected_table[2].index(token)
                if similarity >= detected_table[1][i]:
                    # detected_table[0] = table
                    detected_table[1][i] = similarity
                    detected_table[0][i] = table
                # print(f"Table identified: {table} {similarity}")
            else:
                if similarity >= 70: # initial threshold
                    detected_table[0].append(table)
                    detected_table[1].append(similarity)
                    detected_table[2].append(token)


    for i in range(len(detected_table[0])):
        if detected_table[1][i] > 85:
            tokens.remove(detected_table[2][i])    
        
    # print("dsfsdgd",detected_table[0])

    return detected_table[0]


def indentify_col_tables(tokens, schema):
    res = dict()

    identified_table = indentify_table(tokens, schema)
    print(identified_table)
    if not tokens:
        res[identified_table[0]] = ['*']
        return res

    if identified_table:
        for table in identified_table:
            res[table] = dict()
            for token in tokens:
                for column in schema[table].keys():
                    # for column in columns.keys():
                    if column=='pk' or column=='fk':
                        continue
                    else:
                        similarity = fuzz.ratio(token, column)
                        if similarity >= 50:
                            if res[table].get(token):
                                # print("case 1")
                                # print(token, ":", column, similarity)
                                old_similarity = list(res[table][token].values())[0]
                                # print("sim", old_similarity)
                                if similarity > old_similarity:
                                    res[table][token] = {column: similarity}
                            else:
                                # print("case 2")
                                # print(token, ":", column, similarity)
                                res[table][token] = {column: similarity}
                            # print(res)
                            # if res.get(identified_table):
                            #     res[identified_table].append(column)
                            # else:
                            #     res[identified_table] = [column]
    else:
        for token in tokens:
            for table, columns in schema.items():
                for column in schema[table].keys():
                # for column in columns.keys():
                    if column=='pk' or column=='fk':
                        continue
                    else:
                        match = get_close_matches(token, columns)
                        if match:
                            print("*"*4, token, ":", match[0])
                            # if token in columns: # replace by fuzzy logic
                            # print(f"Column identified: {column} in table {table}")
                            if res.get(table):
                                res[table].append(match[0])
                            else:
                                res[table] = [match[0]]

    if identified_table:
        result = {}
        for table, token in res.items():
            cols = []
            for col in token.values():
                cols.extend(col.keys())
            result[table] = cols

        result = {key: value for key, value in result.items() if value != []}
        return result
    

    res = {key: value for key, value in res.items() if value != []}    

    return res



In [285]:
fuzz.ratio("unit_price", "price")

67

In [301]:
print(q2)
print(q2_nk_tok)
q2_nk_tok = remove_keywords(q2_tok, ALL_KEYWORDS)
req = indentify_col_tables(group_tok2, SQL)
req

List all products names with a price greater than $50.
['products_names', 'price', '50']
['products']
['products']


{'products': ['product_name', 'price']}

In [315]:
print(q1)
print(q1_nk_tok)
group_tok1 = identify_group_by(q1_nk_tok, SQL)
req = indentify_col_tables(group_tok1, SQL)
req

What is the total number of customers?
['customer']
dsfsdgd ['customers']
['customers']


{'customers': ['*']}

In [322]:
print(q4)
q4_nk_tok = remove_keywords(q4_tok, ALL_KEYWORDS)
print(q4_nk_tok)

group_tok4 = identify_group_by(q4_nk_tok, SQL)
req = indentify_col_tables(group_tok4, SQL)
req

What is the total amount of orders for each customer?
['amount', 'of', 'order', 'group', 'by', 'custome']
['orders']


{'orders': ['total_amount', 'customer_id']}

In [319]:
# indentify_table(q1_nk_tok, SQL)
fuzz.ratio("group by customer_id", "customer_id")

71

## JOIN

In [22]:
def create_graph(db_schema, directional=False):
    if directional:
        G = nx.DiGraph()
    else:
        G = nx.Graph()

    # Add nodes and edges based on the SQL dictionary
    for table, details in db_schema.items():
        G.add_node(table)  # Add table as a node
        for fk_table in details['fk']:  # Iterate through foreign keys
            G.add_edge(table, fk_table)  # Create an edge from current table to foreign key table

    # Get nodes and edges for verification
    nodes = list(G.nodes)
    edges = list(G.edges)

    # Output nodes and edges
    # print(nodes)
    # print(edges)

    return G

In [23]:
def required_tables_graph(G, start, end, required_tables):
    def dijkstra(graph, start, end):
        distances = {node: float('infinity') for node in graph}
        distances[start] = 0
        pq = [(0, start)]
        previous = {node: None for node in graph}

        while pq:
            current_distance, current_node = heapq.heappop(pq)

            if current_node == end:
                path = []
                while current_node:
                    path.append(current_node)
                    current_node = previous[current_node]
                return path[::-1], current_distance

            for neighbor in graph[current_node]:
                distance = current_distance + 1  # All edges have weight of 1
                if distance < distances[neighbor]:
                    distances[neighbor] = distance
                    previous[neighbor] = current_node
                    heapq.heappush(pq, (distance, neighbor))

        return None, float('infinity')

    required_tables = set(required_tables) - {start, end}
    best_path = None
    best_distance = float('infinity')

    def dfs(current_path, current_distance, remaining_required):
        nonlocal best_path, best_distance

        if not remaining_required:
            path, distance = dijkstra(G, current_path[-1], end)
            if path:
                total_path = current_path + path[1:]
                total_distance = current_distance + distance
                if total_distance < best_distance:
                    best_path = total_path
                    best_distance = total_distance
            return

        for node in remaining_required:
            path, distance = dijkstra(G, current_path[-1], node)
            if path:
                new_path = current_path + path[1:]
                new_distance = current_distance + distance
                new_remaining = remaining_required - {node}
                dfs(new_path, new_distance, new_remaining)

    dfs([start], 0, required_tables)

    return best_path, best_distance

# G = create_graph(SQL, directional=False)
# # Set start, end, and required nodes
# required_tables = ['categories', 'products', 'orders']
# start = required_tables[0]
# end = required_tables[2]

# # Find the shortest path
# required_tables, distance = required_tables_graph(G, start, end, required_tables)

# print(f"Shortest path: {' -> '.join(required_tables) if required_tables else 'No path found'}")
# # Shortest path: categories -> products -> order_items -> orders

# print(f"Total distance: {distance}")
# # Total distance: 3

In [24]:
G = create_graph(SQL, directional=True)
list(G.edges)

[('products', 'categories'),
 ('orders', 'customers'),
 ('order_items', 'orders'),
 ('order_items', 'products')]

In [25]:
def graph_sort(edges):
    counts = Counter(t[0] for t in edges)
    
    # Sort tuples based on the counts in descending order
    sorted_tuples = sorted(edges, key=lambda x: counts[x[0]], reverse=True)
    
    return sorted_tuples


In [155]:
graph_sort(G.edges)

[('order_items', 'orders'),
 ('order_items', 'products'),
 ('products', 'categories'),
 ('orders', 'customers')]

In [26]:
def join_clause(req_schema: dict, db_schema: dict):
    clause = ""

    # not required
    # max_cols = 0
    # prim_table = None
    # # change logic based on foreign key numbers
    # for table, cols in req_schema.items():
    #     if len(cols) > max_cols:
    #         max_cols = len(cols)
    #         prim_table = table
    
    # # pk_col = db_schema[prim_table]['pk']
    # print(prim_table)

    # create graph from Database schema
    db_graph = create_graph(db_schema)
    db_dir_graph = create_graph(db_schema, directional=True)

    # print(db_dir_graph)

    required_tables = list(req_schema.keys())

    print(required_tables)

    min_dist = float('inf')
    for st_table in required_tables:
        for end_table in required_tables:
            sub_graph, distance = required_tables_graph(db_graph, st_table, end_table, required_tables)
            if distance < min_dist:
                join_graph = OrderedSet(sub_graph)
                min_dist = distance

    print(join_graph)

    req_graph = []

    for edge in db_dir_graph.edges:
        # assumption binary relations between tables
        if edge[0] in join_graph and edge[1] in join_graph:
            req_graph.append(edge)

    req_graph = graph_sort(req_graph)
    print("required graph: ", req_graph)

    clause += f"FROM {req_graph[0][0]}\n"

    for i, edge in enumerate(req_graph):
        table1 = edge[0]
        table2 = edge[1]
        print(edge)
        # if table1 in db_schema[table2]['fk']:
        #     # print("1")
        #     fk_col = db_schema[table2]['fk'][table1] # foreign key corresponding to 2nd table  primary key
        #     pk_col = db_schema[table1]['pk'][0] # primary key of 2nd table
        #     clause += f"JOIN {table1} ON {table2}.{fk_col}={table1}.{pk_col} \n"
        # else:
        fk_col = db_schema[table1]['fk'][table2] # foreign key corresponding to 2nd table  primary key
        pk_col = db_schema[table2]['pk'][0] # primary key of 2nd table
        clause += f"JOIN {table2} ON {table1}.{fk_col}={table2}.{pk_col} \n"
        
    return clause

In [233]:
req

{'orders': ['total_amount'], 'customers': []}

In [234]:
if len(req.keys()) > 1:
    print(join_clause(req, SQL))
else:
    print("No Join required")

['orders', 'customers']
OrderedSet(['orders', 'customers'])
required graph:  [('orders', 'customers')]
('orders', 'customers')
FROM orders
JOIN customers ON orders.customer_id=customers.customer_id 



In [174]:
SQL2 = {
    "books": {
        "pk": ["book_id"],
        "fk": {"authors": "author_id"},
        "book_id": "INT",
        "title": "VARCHAR(255)",
        "author_id": "INT",
        "isbn": "VARCHAR(13)",
        "publication_year": "INT",
        "price": "DECIMAL(10, 2)",
        "stock_quantity": "INT"
    },
    "authors": {
        "pk": ["author_id"],
        "fk": {},
        "author_id": "INT",
        "first_name": "VARCHAR(50)",
        "last_name": "VARCHAR(50)",
        "birth_date": "DATE",
        "nationality": "VARCHAR(50)",
        "biography": "TEXT"
    }
}

req2 = {
    "books": [
        "book_id",
        "title",
        "stock_quantity"
    ],
    "authors": [
        "author_id",
        "first_name",
        "last_name",
    ]
}

print(join_clause(req2, SQL2))

['books', 'authors']
OrderedSet(['books', 'authors'])
required graph:  [('books', 'authors')]
('books', 'authors')
FROM books
JOIN authors ON books.author_id=authors.author_id 



## Identify Condition

In [82]:
print(q1)
print(preprocess_text(q1, constants))
print((q2))
print(preprocess_text(q2, constants))
print((q3))
print(preprocess_text(q3, constants))
print((q4))
print(preprocess_text(q4, constants))
print((q5))
print(preprocess_text(q5, constants))
print((q6))
print(preprocess_text(q6, constants))
print((q7))
print(preprocess_text(q7, constants))
print((q8))
print(preprocess_text(q8, constants))

What is the total number of customers?
['select', 'count', 'customer']
List all products names with a price greater than $50.
products_names
['select', 'products_names', 'price', '>', '50']
Show the first name of customers who have placed orders, sorted by their last name.
['select', 'limit', 'name', 'of', 'customer', 'who', 'placed', 'order', 'sorted', 'group', 'by', 'last', 'name']
What is the total amount of orders for each customer?
['select', 'sum', 'amount', 'of', 'order', 'group', 'by', 'custome']
Which categories have more than 5 products?
['select', 'category', '>', '5', 'product']
What are the top 5 most expensive products?
['select', 'limit', '5', 'max', 'expensive', 'product']
Find all orders with more than 3 items.
['select', 'all', 'order', '>', '3', 'item']
List the top 3 customers who have spent the most on 'Electronics' category products.
['select', 'limit', '3', 'customer', 'who', 'spent', 'max', 'on', "'Electronics'", 'category', 'product']


In [60]:
# import re

# difference between where and having
def identify_where_condition(tokens, req):
    conditions = []
    for i, token in enumerate(tokens):
        if token.lower() == "where":
            conditions.append(" ".join(tokens[i+1:])) # refine with 
            break

    print(f"Conditions: {conditions}")
    # Example Output: Conditions: ["age > 30"]


def indentify_having_condition(tokens):
    # need to check for aggregate functions in tokens
    conditions = []
    for i, token in enumerate(tokens):
        if token.lower() == "having":
            conditions.append(" ".join(tokens[i+1:])) # refine with 
            break


## identify 