In [1]:
!pip install spacy




In [2]:
!python -m spacy download en_core_web_sm


Collecting en-core-web-sm==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


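This cell generates SQL queries from natural-language questions using predefined keyword patterns. A Hugging Face `transformers` pipeline is also loaded here, but the query generation below is purely rule-based and never calls it.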

In [3]:
import spacy
from transformers import pipeline

# Small English spaCy pipeline; generate_sql_from_text uses it to tokenize questions.
nlp = spacy.load("en_core_web_sm")
# NOTE(review): `optimizer` is not referenced by any code visible in this notebook,
# yet the recorded cell output shows it downloading a ~1 GB model. Confirm it is
# needed before keeping this line.
optimizer = pipeline(task="text-generation", model="facebook/bart-large")

# Maps the logical names used by the keyword rules to the identifiers that are
# interpolated into the generated SQL.
# NOTE(review): these values are placeholders (e.g. "employee_table"); the recorded
# output of the later SQLite cell shows `employees` / `salary` instead, so the mapping
# was presumably edited between runs — confirm against the real schema.
table_mapping = {
    "employee": "employee_table",
    "salary": "salary_column",
    "department": "department_column",
    "hire_date": "hire_date_column",
    "name": "name_column",
    "first_name": "first_name_column",
    "last_name": "last_name_column",
    "age": "age_column",
    "product": "product_table",
    "sales": "sales_table",
    "sales_date": "sales_date_column",
    "product_id": "product_id_column",
    "revenue": "revenue_column",
    "employee_id": "employee_id_column",
    "sales_amount": "sales_amount_column",
    "category": "category_column"
}

def generate_sql_from_text(text):
    """Translate a natural-language question into SQL using fixed keyword rules.

    The question is lower-cased and checked against an ordered chain of phrase
    rules; the first rule that matches decides the query. Table and column
    identifiers are taken from the module-level ``table_mapping``.

    Args:
        text: The question as a string.

    Returns:
        The generated SQL string, or the message
        "Sorry, I couldn't understand the query." when no rule matches or an
        error occurs while the query is being built.
    """
    doc = nlp(text.lower())
    lowered = doc.text

    try:
        # spaCy's English tokenizer splits intra-word hyphens, so "highest-paid"
        # never survives as a single token; match it on the raw text instead and
        # keep the token-level check only for the standalone word "top".
        if "highest-paid" in lowered or any(token.text == "top" for token in doc):
            query = f"SELECT * FROM {table_mapping['employee']} ORDER BY {table_mapping['salary']} DESC LIMIT 5;"

        elif "youngest" in lowered and "engineering" in lowered:
            query = f"SELECT * FROM {table_mapping['employee']} WHERE {table_mapping['department']} = 'Engineering' ORDER BY {table_mapping['age']} ASC LIMIT 1;"

        elif "hired in the year 2019" in lowered:
            query = f"SELECT * FROM {table_mapping['employee']} WHERE {table_mapping['hire_date']} >= '2019-01-01' AND {table_mapping['hire_date']} <= '2019-12-31';"

        elif "above the average salary" in lowered:
            query = f"SELECT {table_mapping['name']} FROM {table_mapping['employee']} WHERE {table_mapping['salary']} > (SELECT AVG({table_mapping['salary']}) FROM {table_mapping['employee']});"

        elif "first and last names start with the same letter" in lowered:
            query = f"SELECT {table_mapping['name']} FROM {table_mapping['employee']} WHERE SUBSTR({table_mapping['first_name']}, 1, 1) = SUBSTR({table_mapping['last_name']}, 1, 1);"

        elif "highest salary in each department" in lowered:
            query = f"SELECT {table_mapping['department']}, MAX({table_mapping['salary']}) AS max_salary_employee FROM {table_mapping['employee']} GROUP BY {table_mapping['department']};"

        elif all(keyword in lowered for keyword in ["total revenue", "product category", "past year"]):
            # NOTE(review): "past year" is hard-coded to calendar year 2023.
            query = f"""
                SELECT {table_mapping['category']}, SUM({table_mapping['revenue']}) AS total_revenue
                FROM {table_mapping['sales']} s
                JOIN {table_mapping['product']} p ON s.{table_mapping['product_id']} = p.{table_mapping['product_id']}
                WHERE s.{table_mapping['sales_date']} >= '2023-01-01' AND s.{table_mapping['sales_date']} <= '2023-12-31'
                GROUP BY {table_mapping['category']};
            """

        elif all(keyword in lowered for keyword in ["employees", "sales", "past quarter", "salary higher"]):
            # DATE_SUB / CURDATE are MySQL functions; this branch targets MySQL.
            query = f"""
                SELECT e.{table_mapping['name']}, e.{table_mapping['department']}
                FROM {table_mapping['employee']} e
                JOIN {table_mapping['sales']} s ON e.{table_mapping['employee_id']} = s.{table_mapping['employee_id']}
                WHERE s.{table_mapping['sales_date']} >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) AND e.{table_mapping['salary']} > (SELECT AVG({table_mapping['salary']}) FROM {table_mapping['employee']})
                GROUP BY e.{table_mapping['name']}, e.{table_mapping['department']};
            """

        elif "products that have never been sold" in lowered:
            query = f"""
                SELECT *
                FROM {table_mapping['product']} p
                LEFT JOIN {table_mapping['sales']} s ON p.{table_mapping['product_id']} = s.{table_mapping['product_id']}
                WHERE s.{table_mapping['product_id']} IS NULL;
            """

        elif all(keyword in lowered for keyword in ["highest sales", "each department"]):
            # The window alias is `sales_rank` rather than `rank`: RANK is a
            # reserved word in MySQL 8, so an unquoted `rank` alias fails there.
            query = f"""
                SELECT {table_mapping['department']}, {table_mapping['employee_id']}, {table_mapping['sales_amount']}
                FROM (
                    SELECT e.{table_mapping['department']}, s.{table_mapping['employee_id']}, SUM(s.{table_mapping['sales_amount']}) AS {table_mapping['sales_amount']},
                           ROW_NUMBER() OVER (PARTITION BY e.{table_mapping['department']} ORDER BY SUM(s.{table_mapping['sales_amount']}) DESC) AS sales_rank
                    FROM {table_mapping['sales']} s
                    JOIN {table_mapping['employee']} e ON s.{table_mapping['employee_id']} = e.{table_mapping['employee_id']}
                    GROUP BY e.{table_mapping['department']}, s.{table_mapping['employee_id']}
                ) AS ranked_sales
                WHERE sales_rank <= 5;
            """

        else:
            query = "Sorry, I couldn't understand the query."

    except Exception as e:
        # Best-effort: e.g. a key missing from table_mapping yields the fallback message.
        print(f"Error generating SQL query: {e}")
        query = "Sorry, I couldn't understand the query."

    return query

def suggest_query_optimizations(query):
    """Pick a canned indexing hint based on which SQL keywords occur in ``query``.

    Keyword checks are plain, case-sensitive substring tests. WHERE takes
    priority; without it GROUP BY, JOIN and ORDER BY are tried in that order.
    A query containing WHERE, GROUP BY and ORDER BY together gets the default
    "no suggestion" text.

    Args:
        query: SQL text to inspect.

    Returns:
        The hint string. If the checks raise (for example ``query`` is None),
        the error is printed and the default text is returned.
    """
    no_hint = "No specific optimizations suggested for this query pattern."

    try:
        has_where = "WHERE" in query
        has_group_by = "GROUP BY" in query
        has_join = "JOIN" in query
        has_order_by = "ORDER BY" in query
    except Exception as e:
        print(f"Error optimizing query: {e}")
        return no_hint

    if has_where:
        if not has_group_by:
            return "Consider adding indexes on columns used in WHERE clauses for better performance."
        if not has_order_by:
            return "Ensure columns in the GROUP BY clause are indexed for efficient grouping."
        return no_hint

    if has_group_by:
        return "Ensure columns in the GROUP BY clause are properly indexed for performance."
    if has_join:
        return "Consider adding indexes on columns used in JOIN conditions for faster joins."
    if has_order_by:
        return "Ensure columns used in ORDER BY clause are indexed for sorting optimization."
    return no_hint

def process_questions_from_file(filename):
    """Generate SQL and an optimization hint for every question in a text file.

    The file is read one question per line. Lines that are empty after
    stripping whitespace are skipped, so blank separator lines do not turn
    into bogus "questions".

    Args:
        filename: Path of the questions file.

    Returns:
        A list of dicts with the keys "Question", "Generated SQL Query" and
        "Optimization Suggestion" (one per non-blank line), or None when the
        file does not exist (an error message is printed in that case).
    """
    # Only the file access is guarded: a missing file is the one failure this
    # function reports itself.
    try:
        with open(filename, 'r') as file:
            questions = [line.strip() for line in file]
    except FileNotFoundError:
        print(f"Error: File '{filename}' not found.")
        return None

    results = []
    for question in questions:
        if not question:
            continue

        generated_query = generate_sql_from_text(question)
        optimization_suggestion = suggest_query_optimizations(generated_query)

        results.append({
            "Question": question,
            "Generated SQL Query": generated_query,
            "Optimization Suggestion": optimization_suggestion
        })

    return results

def print_results(results):
    """Print each result dict as a block framed by 50-character '=' rules.

    Args:
        results: Iterable of dicts with the keys "Question",
            "Generated SQL Query" and "Optimization Suggestion". A falsy value
            (None or an empty list) prints "No results to display." instead.
    """
    if not results:
        print("No results to display.")
        return

    rule = "=" * 50
    titled_sections = (
        ("Generated SQL Query:", "Generated SQL Query"),
        ("Optimization Suggestion:", "Optimization Suggestion"),
    )

    for entry in results:
        print(rule)
        print(f"Question: {entry['Question']}")
        print(rule)
        # Each titled section is: heading, rule, value, rule.
        for heading, key in titled_sections:
            print(heading)
            print(rule)
            print(f"{entry[key]}")
            print(rule)

def take_user_input_and_generate_sql():
    """Run the interactive prompt that turns typed questions into SQL.

    Commands are matched case-insensitively after surrounding whitespace is
    stripped:
      * "exit"      - leave the loop.
      * "questions" - process every question in /content/questions.txt.
      * anything else is treated as a single question.

    Empty input is ignored and the prompt is shown again.
    """
    rule = "=" * 50

    while True:
        print("\nEnter your question in natural language to generate an SQL query or type 'questions' to read from questions.txt, 'exit' to quit :")
        # Strip so that "exit " or " questions" are still recognised as commands
        # instead of being sent to the SQL generator as questions.
        user_input = input("> ").strip()
        command = user_input.lower()

        if command == "exit":
            break
        if not user_input:
            continue

        if command == "questions":
            filename = "/content/questions.txt"
            results = process_questions_from_file(filename)
            print_results(results)
            continue

        generated_query = generate_sql_from_text(user_input)

        print("\nGenerated SQL Query:")
        print(rule)
        print(f"Question: {user_input}")
        print(rule)
        print(f"{generated_query}")
        print(rule)

        optimization_suggestion = suggest_query_optimizations(generated_query)
        print("Optimization Suggestion:")
        print(rule)
        print(f"{optimization_suggestion}")
        print(rule)

def main():
    """Entry point: start the interactive question-to-SQL prompt loop."""
    take_user_input_and_generate_sql()

# Start the interactive prompt when this cell/script is executed directly
# (the recorded cell output shows the prompt appearing as soon as the cell ran).
if __name__ == "__main__":
    main()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]


Enter your question in natural language to generate an SQL query or type 'questions' to read from questions.txt, 'exit' to quit :
> questions
Question: Show me the top 5 highest-paid employees.
Generated SQL Query:
SELECT * FROM employee_table ORDER BY salary_column DESC LIMIT 5;
Optimization Suggestion:
Ensure columns used in ORDER BY clause are indexed for sorting optimization.
Question: Who is the youngest employee in the Engineering department?
Generated SQL Query:
SELECT * FROM employee_table WHERE department_column = 'Engineering' ORDER BY age_column ASC LIMIT 1;
Optimization Suggestion:
Consider adding indexes on columns used in WHERE clauses for better performance.
Question: Find all employees who were hired in the year 2019.
Generated SQL Query:
SELECT * FROM employee_table WHERE hire_date_column >= '2019-01-01' AND hire_date_column <= '2019-12-31';
Optimization Suggestion:
Consider adding indexes on columns used in WHERE clauses for better performance.
Question: Show the nam

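Next, I tried connecting a database to the query generator. (`mysql-connector-python` is installed below, but the cells that follow use Python's built-in `sqlite3` module.)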

In [4]:
!pip install mysql-connector-python

Collecting mysql-connector-python
  Downloading mysql_connector_python-8.4.0-cp310-cp310-manylinux_2_17_x86_64.whl (19.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m19.4/19.4 MB[0m [31m29.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mysql-connector-python
Successfully installed mysql-connector-python-8.4.0


My schema: a single SQLite `employees` table, created and populated with sample rows below.

In [23]:
import sqlite3

def create_database(db_file='database.db'):
    """Create (or recreate) the SQLite `employees` table and fill it with sample rows.

    Any existing `employees` table is dropped first, so the function can be
    re-run safely. The connection is always closed, even if a statement fails.

    Args:
        db_file: Path of the SQLite database file. Defaults to 'database.db'
            in the current working directory.

    Raises:
        sqlite3.Error: If the database cannot be opened or a statement fails.
    """
    conn = sqlite3.connect(db_file, timeout=10)
    try:
        cursor = conn.cursor()
        # Write-ahead logging journal mode for this database file.
        cursor.execute("PRAGMA journal_mode=WAL;")
        cursor.execute('DROP TABLE IF EXISTS employees;')

        cursor.execute('''
    CREATE TABLE employees (
        EmployeeID INTEGER PRIMARY KEY,
        FirstName TEXT,
        LastName TEXT,
        Age INTEGER,
        Department TEXT,
        Position TEXT,
        Salary INTEGER,
        HireDate TEXT,
        ManagerID INTEGER
    )
    ''')

        # (EmployeeID, FirstName, LastName, Age, Department, Position, Salary, HireDate, ManagerID)
        employees_data = [
            (1, "John", "Smith", 28, "Sales", "Manager", 85000, "2015-03-01", 21),
            (2, "Jane", "Doe", 34, "Engineering", "Developer", 95000, "2016-05-23", 6),
            (3, "Emily", "Johnson", 29, "HR", "Generalist", 70000, "2017-08-14", 7),
            (4, "Michael", "Brown", 45, "Marketing", "Coordinator", 80000, "2014-11-11", 20),
            (5, "Sarah", "Williams", 31, "Sales", "Consultant", 85000, "2013-07-25", 1),
            (6, "David", "Jones", 38, "Engineering", "Manager", 95000, "2019-09-17", 1),
            (7, "Laura", "Garcia", 26, "HR", "Manager", 70000, "2018-12-02", 21),
            (8, "James", "Miller", 39, "Marketing", "SEO", 80000, "2020-01-10", 20),
            (9, "Anna", "Davis", 27, "Sales", "Associate", 85000, "2015-03-01", 1),
            (10, "Robert", "Rodriguez", 41, "Engineering", "QA", 95000, "2016-05-23", 6),
            (11, "Linda", "Martinez", 33, "HR", "Coordinator", 70000, "2017-08-14", 7),
            (12, "William", "Hernandez", 30, "Marketing", "Analyst", 80000, "2014-11-11", 20),
            (13, "Elizabeth", "Lopez", 36, "Sales", "Analyst", 85000, "2013-07-25", 1),
            (14, "Richard", "Gonzalez", 42, "Engineering", "DevOps", 95000, "2019-09-17", 6),
            (15, "Jessica", "Wilson", 32, "HR", "Analyst", 70000, "2018-12-02", 7),
            (16, "Joseph", "Anderson", 37, "Marketing", "Associate", 80000, "2020-01-10", 20),
            (17, "Karen", "Thomas", 29, "Sales", "Coordinator", 85000, "2015-03-01", 1),
            (18, "Thomas", "Taylor", 35, "Engineering", "Technical support", 95000, "2016-05-23", 6),
            (19, "Nancy", "Moore", 40, "HR", "Recruiter", 70000, "2017-08-14", 7),
            (20, "Charles", "Jackson", 43, "Marketing", "Manager", 80000, "2014-11-11", 21),
            (21, "Alex", "Johnson", 50, "Management", "CEO", 200000, "2010-01-01", None)
        ]

        cursor.executemany('''
    INSERT INTO employees (EmployeeID, FirstName, LastName, Age, Department, Position, Salary, HireDate, ManagerID)
    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
    ''', employees_data)

        conn.commit()
    finally:
        # Release the file handle even when one of the statements above raised.
        conn.close()
    print("Database created and populated successfully.")

# Build the sample database when this cell/script is executed directly.
if __name__ == "__main__":
    create_database()


Database created and populated successfully.


In [26]:
import sqlite3
import re

def connect_database(db_file='/content/database.db'):
    """Open a connection to the SQLite database at ``db_file``.

    Args:
        db_file: Path of the database file to open.

    Returns:
        The open ``sqlite3.Connection`` (a confirmation line is printed), or
        None when sqlite3 reports an error (the error is printed instead).
    """
    try:
        connection = sqlite3.connect(db_file)
    except sqlite3.Error as err:
        print(f"Error connecting to SQLite database: {err}")
        return None
    else:
        print(f"Connected to database: {db_file}")
        return connection

def execute_query(db_connection, query):
    """Run ``query`` on ``db_connection`` and return every resulting row.

    Args:
        db_connection: An open sqlite3 connection.
        query: SQL text to execute.

    Returns:
        A list of row tuples, or None when sqlite3 raises an error (the
        failing query and the error are printed).
    """
    cur = db_connection.cursor()
    try:
        # sqlite3's Cursor.execute returns the cursor itself, so the calls chain.
        rows = cur.execute(query).fetchall()
    except sqlite3.Error as err:
        print(f"Error executing query: {query}")
        print(f"SQLite error: {err}")
        rows = None
    return rows

def read_input(file_path):
    """Read questions from a text file, one per line.

    Each line is stripped of surrounding whitespace; lines that are empty
    afterwards are dropped so blank separator lines are not treated as
    questions.

    Args:
        file_path: Path of the questions file.

    Returns:
        A list of non-empty question strings, or an empty list when the file
        cannot be read (the error is printed).
    """
    try:
        with open(file_path, 'r') as file:
            stripped_lines = (line.strip() for line in file)
            return [question for question in stripped_lines if question]
    except IOError as e:
        print(f"Error reading file: {file_path}")
        print(f"IOError: {e}")
        return []

def preprocess_query(sql_query):
    """Rewrite MySQL-style year expressions so the query works on SQLite.

    Two rewrites are applied:

    * ``YEAR(col)`` becomes ``CAST(substr(col, 1, 4) AS INTEGER)``. The cast
      matters: a bare ``substr(...)`` yields TEXT, and SQLite never considers
      TEXT equal to an integer literal such as ``2019``.
    * ``col >= 'YYYY-01-01' AND col < 'ZZZZ-01-01'`` becomes
      ``substr(col, 1, 4) = 'YYYY'``, but only when ZZZZ is exactly YYYY + 1.
      Ranges spanning several years are left untouched.

    Args:
        sql_query: SQL text to rewrite.

    Returns:
        The rewritten SQL text.
    """
    sql_query = re.sub(r"YEAR\((\w+)\)", r"CAST(substr(\1, 1, 4) AS INTEGER)", sql_query)

    def _collapse_single_year_range(match):
        # Collapse the range to a year test only if it covers exactly one calendar year.
        column, start_year, end_year = match.group(1, 2, 3)
        if int(end_year) != int(start_year) + 1:
            return match.group(0)
        return f"substr({column}, 1, 4) = '{start_year}'"

    sql_query = re.sub(
        r"(\w+)\s*>=\s*'(\d{4})-01-01'\s*AND\s*\1\s*<\s*'(\d{4})-01-01'",
        _collapse_single_year_range,
        sql_query,
    )
    return sql_query

def generate_sql_from_text(text):
    """Generate a SQLite query for a natural-language question via keyword rules.

    The lower-cased question is checked against an ordered chain of phrase
    rules; the first match decides the query. Table and column identifiers are
    taken from the module-level ``table_mapping``.

    Args:
        text: The question as a string.

    Returns:
        The generated SQL string, or the message
        "Sorry, I couldn't understand the query." when no rule matches or any
        error occurs (including ``text`` not being a string).
    """
    try:
        lowered = text.lower()

        # spaCy's English tokenizer splits intra-word hyphens, so "highest-paid"
        # never survives as a single token; match it on the raw text instead and
        # keep the token-level check only for the standalone word "top".
        if "highest-paid" in lowered or any(token.text == "top" for token in nlp(lowered)):
            query = f"SELECT * FROM {table_mapping['employee']} ORDER BY {table_mapping['salary']} DESC LIMIT 5;"

        elif "youngest" in lowered and "engineering" in lowered:
            query = f"SELECT * FROM {table_mapping['employee']} WHERE {table_mapping['department']} = 'Engineering' ORDER BY {table_mapping['age']} ASC LIMIT 1;"

        elif "hired in the year 2019" in lowered:
            query = f"SELECT * FROM {table_mapping['employee']} WHERE {table_mapping['hire_date']} >= '2019-01-01' AND {table_mapping['hire_date']} <= '2019-12-31';"

        elif "above the average salary" in lowered:
            query = f"SELECT {table_mapping['name']} FROM {table_mapping['employee']} WHERE {table_mapping['salary']} > (SELECT AVG({table_mapping['salary']}) FROM {table_mapping['employee']});"

        elif "first and last names start with the same letter" in lowered:
            query = f"SELECT {table_mapping['name']} FROM {table_mapping['employee']} WHERE SUBSTR({table_mapping['first_name']}, 1, 1) = SUBSTR({table_mapping['last_name']}, 1, 1);"

        elif "highest salary in each department" in lowered:
            query = f"SELECT {table_mapping['department']}, MAX({table_mapping['salary']}) AS max_salary_employee FROM {table_mapping['employee']} GROUP BY {table_mapping['department']};"

        elif all(keyword in lowered for keyword in ["total revenue", "product category", "past year"]):
            # NOTE(review): "past year" is hard-coded to calendar year 2023.
            query = f"""
                SELECT {table_mapping['category']}, SUM({table_mapping['revenue']}) AS total_revenue
                FROM {table_mapping['sales']} s
                JOIN {table_mapping['product']} p ON s.{table_mapping['product_id']} = p.{table_mapping['product_id']}
                WHERE s.{table_mapping['sales_date']} >= '2023-01-01' AND s.{table_mapping['sales_date']} <= '2023-12-31'
                GROUP BY {table_mapping['category']};
            """

        elif all(keyword in lowered for keyword in ["employees", "sales", "past quarter", "salary higher"]):
            # The queries from this cell are executed on SQLite, which has no
            # DATE_SUB/CURDATE; DATE('now', '-3 months') is the SQLite equivalent.
            query = f"""
                SELECT e.{table_mapping['name']}, e.{table_mapping['department']}
                FROM {table_mapping['employee']} e
                JOIN {table_mapping['sales']} s ON e.{table_mapping['employee_id']} = s.{table_mapping['employee_id']}
                WHERE s.{table_mapping['sales_date']} >= DATE('now', '-3 months') AND e.{table_mapping['salary']} > (SELECT AVG({table_mapping['salary']}) FROM {table_mapping['employee']})
                GROUP BY e.{table_mapping['name']}, e.{table_mapping['department']};
            """

        elif "products that have never been sold" in lowered:
            query = f"""
                SELECT *
                FROM {table_mapping['product']} p
                LEFT JOIN {table_mapping['sales']} s ON p.{table_mapping['product_id']} = s.{table_mapping['product_id']}
                WHERE s.{table_mapping['product_id']} IS NULL;
            """

        elif all(keyword in lowered for keyword in ["highest sales", "each department"]):
            query = f"""
                SELECT {table_mapping['department']}, {table_mapping['employee_id']}, {table_mapping['sales_amount']}
                FROM (
                    SELECT e.{table_mapping['department']}, s.{table_mapping['employee_id']}, SUM(s.{table_mapping['sales_amount']}) AS {table_mapping['sales_amount']},
                           ROW_NUMBER() OVER (PARTITION BY e.{table_mapping['department']} ORDER BY SUM(s.{table_mapping['sales_amount']}) DESC) AS rank
                    FROM {table_mapping['sales']} s
                    JOIN {table_mapping['employee']} e ON s.{table_mapping['employee_id']} = e.{table_mapping['employee_id']}
                    GROUP BY e.{table_mapping['department']}, s.{table_mapping['employee_id']}
                ) AS ranked_sales
                WHERE rank <= 5;
            """

        else:
            query = "Sorry, I couldn't understand the query."

    except Exception as e:
        # Best-effort: e.g. a key missing from table_mapping yields the fallback message.
        print(f"Error generating SQL query: {e}")
        query = "Sorry, I couldn't understand the query."

    return query

def suggest_query_optimizations(query):
    """Suggest an indexing hint for ``query`` from the SQL keywords it contains.

    Keyword checks are case-sensitive substring tests. A query with WHERE is
    judged only on its WHERE / GROUP BY / ORDER BY combination; otherwise the
    first of GROUP BY, JOIN, ORDER BY that is present selects the hint.

    Args:
        query: SQL text to inspect.

    Returns:
        The hint string, or the default "no suggestion" text when nothing
        applies or when an error occurs (the error is printed).
    """
    fallback = "No specific optimizations suggested for this query pattern."
    # Checked in order, and only for queries that have no WHERE clause.
    keyword_hints = (
        ("GROUP BY", "Ensure columns in the GROUP BY clause are properly indexed for performance."),
        ("JOIN", "Consider adding indexes on columns used in JOIN conditions for faster joins."),
        ("ORDER BY", "Ensure columns used in ORDER BY clause are indexed for sorting optimization."),
    )

    try:
        if "WHERE" in query:
            if "GROUP BY" not in query:
                return "Consider adding indexes on columns used in WHERE clauses for better performance."
            if "ORDER BY" not in query:
                return "Ensure columns in the GROUP BY clause are indexed for efficient grouping."
            return fallback

        for keyword, hint in keyword_hints:
            if keyword in query:
                return hint
        return fallback

    except Exception as e:
        print(f"Error optimizing query: {e}")
        return fallback

def display_results(file_number, question, sql_query, results, suggestions):
    """Print one question's report: header, question, SQL, rows and hint.

    Args:
        file_number: Number shown in the "Question N" header.
        question: The original question text.
        sql_query: The SQL that was generated for it.
        results: Query outcome. A list is printed row by row (or
            "No results found." when empty); any other value is printed as-is.
        suggestions: Optimization hint text.
    """
    banner = "========================================"

    print(banner)
    print(f"              Question {file_number}              ")
    print(banner + "\n")
    print(f"Question:\n{question}\n")
    print("Generated SQL Query:\n")
    print(f"{sql_query}\n")
    print("Query Results:\n")

    if not isinstance(results, list):
        # e.g. None when the query failed to execute
        print(f"{results}")
    elif not results:
        print("No results found.")
    else:
        for record in results:
            print(record)

    print("\nOptimization Suggestions:")
    print(f"{suggestions}")
    print("\n" + "="*40 + "\n")

def main():
    """Interactive loop: turn questions into SQL, run them on SQLite, show results.

    Commands are matched case-insensitively after surrounding whitespace is
    stripped:
      * "exit"      - leave the loop.
      * "questions" - process every question in /content/questions.txt.
      * anything else is treated as a single question.

    Empty input is ignored. When no SQL could be generated for a question, the
    fallback message is shown but is not sent to the database. The loop ends if
    the database connection cannot be opened.
    """
    # Sentinel returned by generate_sql_from_text when no rule matched.
    not_understood = "Sorry, I couldn't understand the query."

    while True:
        print("\nEnter 'questions' to read questions from a file or enter a sentence or 'exit' to quit:")
        # Strip so that "exit " or " questions" are still recognised as commands.
        user_input = input("> ").strip()
        command = user_input.lower()

        if command == "exit":
            break
        if not user_input:
            continue

        if command == "questions":
            input_file = '/content/questions.txt'  # Replace with your file path
            questions = read_input(input_file)
        else:
            questions = [user_input]

        db_connection = connect_database()
        if not db_connection:
            print("Database connection not established. Exiting...")
            break

        try:
            for i, question in enumerate(questions, start=1):
                sql_query = preprocess_query(generate_sql_from_text(question))

                if sql_query == not_understood:
                    # Executing the fallback message as SQL would only produce
                    # a spurious SQLite syntax error.
                    results = "Query was not executed because no SQL could be generated."
                else:
                    results = execute_query(db_connection, sql_query)
                suggestions = suggest_query_optimizations(sql_query)

                display_results(i, question, sql_query, results, suggestions)

                print("-" * 80)
        finally:
            # Close the connection even if one of the questions raised.
            db_connection.close()

# Start the interactive loop when this cell/script is executed directly.
if __name__ == "__main__":
    main()




Enter 'questions' to read questions from a file or enter a sentence or 'exit' to quit:
> questions
Connected to database: /content/database.db
              Question 1              

Question:
Show me the top 5 highest-paid employees.

Generated SQL Query:

SELECT * FROM employees ORDER BY salary DESC LIMIT 5;

Query Results:

(21, 'Alex', 'Johnson', 50, 'Management', 'CEO', 200000, '2010-01-01', None)
(2, 'Jane', 'Doe', 34, 'Engineering', 'Developer', 95000, '2016-05-23', 6)
(6, 'David', 'Jones', 38, 'Engineering', 'Manager', 95000, '2019-09-17', 1)
(10, 'Robert', 'Rodriguez', 41, 'Engineering', 'QA', 95000, '2016-05-23', 6)
(14, 'Richard', 'Gonzalez', 42, 'Engineering', 'DevOps', 95000, '2019-09-17', 6)

Optimization Suggestions:
Ensure columns used in ORDER BY clause are indexed for sorting optimization.


--------------------------------------------------------------------------------
              Question 2              

Question:
Who is the youngest employee in the Engineering