In [12]:
import nltk
from nltk.tokenize import word_tokenize
from nltk.tag import pos_tag

In [19]:
nltk.download("punkt")
nltk.download("punkt_tab")
nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Ali\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Ali\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\Ali\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     C:\Users\Ali\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping taggers\averaged_perceptron_tagger_eng.zip.


True

In [None]:
class NLQtoSQL:
    def __init__(self):
        self.table_map = {
            "employees": ["employee", "employees", "worker", "workers", "staff"],
            "orders": ["order", "orders"],
            "products": ["product", "products", "item", "items"]
        }

        self.aggregation_map = {
            "count": ["how many", "count", "number of"],
            "sum": ["total", "sum"]
        }

    def detect_table(self, tokens):
        tokens_l = [t.lower() for t in tokens]
        for table, keywords in self.table_map.items():
            for kw in keywords:
                if kw in tokens_l:
                    return table
        return None

    def detect_aggregation(self, text):
        text_l = text.lower()
        for agg, keys in self.aggregation_map.items():
            for k in keys:
                if k in text_l:
                    return agg
        return None

    def detect_conditions(self, tokens):
        conds = []
        tokens_l = [t.lower() for t in tokens]
        if "where" in tokens_l:
            idx = tokens_l.index("where")
            condition = " ".join(tokens[idx+1:])
            conds.append(condition)
        return conds

    def build_sql(self, text):
        tokens = word_tokenize(text)
        tagged = pos_tag(tokens)

        table = self.detect_table(tokens)
        agg = self.detect_aggregation(text)
        conds = self.detect_conditions(tokens)

        if table is None:
            raise ValueError("Could not detect table from input.")

        # SELECT part
        if agg == "count":
            select_clause = "SELECT COUNT(*)"
        elif agg == "sum":
            select_clause = "SELECT SUM(*)"
        else:
            select_clause = "SELECT *"

        query = f"{select_clause} FROM {table}"

        if conds:
            query += f" WHERE {conds[0]}"

        return query + ";"

In [21]:
if __name__ == "__main__":
    parser = NLQtoSQL()
    user_question = "How many employees are there where age > 30?"
    sql = parser.build_sql(user_question)
    print(sql)

SELECT COUNT(*) FROM employees WHERE age > 30 ?;
