In [60]:
import pymongo
import pymysql
import pandas as pd
import nltk
from nltk import pos_tag, word_tokenize
from generate_sql_examples_final import get_mysql_metadata
from mongo_examples_testing import get_mongodb_metadata
import re
from tabulate import tabulate
from input_process import match_query_pattern

In [61]:
example_queries = {
    (0, "SELECT"): "Show me all the records in the course table",  # Select
    (1, "UPLOAD"): "Upload this data '../data/sql_data/students.csv'",  # Upload
    (2, "EXAMPLE", "SELECT", "SQL"): "Show me examples of sql queries",  # Example
    (3, "EXAMPLE", "SELECT", "GROUP BY", "SQL"): "Show me examples of sql queries using group by",  # Example group by
    (4, "SELECT", "WHERE"): "Get the details of students who work in the Calculus department",  # Where
    (5, "SELECT", "AGGREGATE"): "How many employees are there in the company",  # Select count(*)
    (6, "SELECT", "WHERE"): "Show me the names and salaries of employees earning more than $50000",  # Where
    (7, "SELECT", "ORDER BY"): "List all employees, sorted by their hire date in descending order",  # Order by
    (8, "AGGREGATE", "WHERE"): "Find the average salary of employees in the Engineering department",  # Select Avg() where
    (9, "SELECT", "GROUP BY"): "How many employees are there in each department",  # Group by
    (10, "SELECT", "HAVING"): "Show me the departments where the total salaries of employees exceed $100000",  # Having
    (11, "SELECT", "HAVING"): "List the products where the average price is less than $50",  # Having
    (12, "SELECT", "ORDER BY", "LIMIT"): "What are the top 10 highest paid employees",  # Order by limit
    (13, "SELECT", "HAVING"): "Which customers have placed more than 5 orders",  # Having
    (14, "SELECT", "GROUP BY"): "Find the total revenue generated by each product category",  # Group by
    (15, "SELECT", "WHERE", "ORDER BY"): "List all orders placed in the last 30 days, sorted by order date",  # Where Order by
    (16, "SELECT", "JOIN"): "Join the employees table with the departments table to find department names",  # Join
    (17, "SELECT", "JOIN"): "List all customers and their orders",  # Join
    (18, "SELECT", "GROUP BY"): "How many products are in stock for each supplier",  # Group by
    (19, "SELECT", "WHERE"): "Show me the products where the stock quantity is less than 10",  # Where
    (20, "UPLOAD"): "Upload the file '../data/new_products.json' into the database",  # Upload
    (21, "GROUP BY", "AGGREGATE"): "Find the maximum salary in each department",  # Group by Max()
    (22, "SELECT", "GROUP BY"): "List employees grouped by their job titles",  # Group by
    (23, "SELECT", "WHERE"): "Show me all orders where the total exceeds $1000",  # Where
    (24, "SELECT", "ORDER BY"): "List all customers sorted by their last purchase date",  # Order by
    (25, "EXAMPLE", "SELECT", "SQL"): "Show me examples of sql queries for finding duplicates",  # Example
    (26, "AGGREGATE", "WHERE"): "How many employees were hired in the last year",  # Where Count()
    (27, "GROUP BY", "AGGREGATE"): "Find the minimum price of products in each category",  # Group by Min()
    (28, "SELECT", "JOIN"): "Join the orders table with the customers table to get customer details",  # Join
    (29, "SELECT", "WHERE"): "List all records in the products table where price is greater than $20",  # Where
    (30, "SELECT", "ORDER BY", "LIMIT"): "Show me the names of employees earning the top 5 highest salaries",  # Order by limit
    (31, "GROUP BY", "AGGREGATE"): "Find the sum of sales for each region",  # Group by Sum()
    (32, "SELECT", "ORDER BY"): "List all products sorted by their price in ascending order",  # Order by
    (33, "SELECT", "EXAMPLE", "MONGODB"): "Give me examples of nosql queries.",  # Order by
}

login_info = {
    'endpoint': "localhost",
    'username': "root",
    'password': "MySQLDBP455",
    'sql_database_name': "chatdb",
    'mongo_username': 'mdmolnar',
    'mongo_password': 'AtM0nG0d1452',
    'mongo_database_name': "ChatDB"
}


In [140]:
aggregate_words = {
    'MIN': ['minimum', 'smallest', 'lowest', 'least', 'min'],
    'MAX': ['maximum', 'largest', 'highest', 'greatest', 'max'],
    'AVG': ['average', 'mean', 'avg'],
    'COUNT': ['count', 'number', 'total', 'many'],
    'SUM': ['sum', 'total', 'add', 'combined']
}
import string
group_words = ["group by", "aggregate", "group", "grouping", "each", "total", "categorize",
                 "partition", "classify", "segment", "cluster", "bucket", 'grouped']

where_words = {
    'LESS_THAN': ['less', 'fewer', 'below', 'under', 'lower than', 
                  'smaller than', 'not exceeding', 'underneath', '<'],
    'GREATER_THAN': ['more', 'greater', 'above', 'over', 'higher than', 
                     'exceeds', 'bigger than', 'larger than', '>'],
    'EQUAL_TO': ['equal', 'exactly', 'equals', 'same as', '=',
                 'identical to', 'matches', 'equivalent to'],
    'NOT_EQUAL_TO': ['not equal', 'different', 'not the same as', 
                     'does not equal', 'unequal', 'not identical', '!=']
}

order_words = ["order by", "sort", "sort by", "ordered", "order", "ascending", "descending", "rank", "arrange", "prioritize",
               "sequence", "order", "hierarchy", "top", "bottom", "sorted", 'biggest', 'smallest']

desc_words = ['descending', 'desc', 'biggest to smallest', 'reverse', 'biggest']

all_cols = ['all columns', 'every column', 'each column']

limit_words = ['top', 'bottom', 'highest', 'lowest', 'biggest', 'smallest', 'limit', 'only', 'list']

def execute_sql_query(user_input, keywords, login_info):
    exclude = string.punctuation.translate(str.maketrans('', '', "_()"))
    user_list = user_input.translate(str.maketrans('', '', exclude)).split()

    print(user_list)
    
    mdata = get_mysql_metadata(login_info)

    tables = [table for table in mdata.keys()]

    assoc_tables = []

    for word in user_list:
        if word in tables:
            assoc_tables.append(word)
    print('Generating query for table:', assoc_tables)

    assoc_cols = []
    col_idx_mdata = []
    for i in range(len(assoc_tables)):
        table_columns = []
        for j in range(len(mdata[assoc_tables[i]])):
            table_columns.append(mdata[assoc_tables[i]][j]['name'])

        print('Columns in associated table:', table_columns)
        for word in user_list:
            if word in table_columns:
                assoc_cols.append(word)
                idx = table_columns.index(word)
                col_idx_mdata.append(idx)

    unique_vals = []
    for i in assoc_tables:
        for j in col_idx_mdata:
            unique_vals += mdata[i][j]['unique_values']

    unique_vals = [str(item) if isinstance(item, (int, float)) else item for item in unique_vals]

    for i, word in enumerate(user_list):
        try:
            bigram = word + f' {user_list[i+1]}'
        except:
            pass
        if bigram in all_cols:
            assoc_cols.append('*')
    if len(assoc_cols) == 0:
        assoc_cols.append('*')
    print('Creating query with columns:', assoc_cols)

    query = ''
    from_in_query = False

    if 'SELECT' in keywords:
        if '*' in assoc_cols:
            query += 'SELECT * '
        else:
            query += 'SELECT ' + ', '.join(assoc_cols) + ' '
    
    elif 'AGGREGATE' in keywords:
        processes = []

        for key, value in aggregate_words.items():
            for i, word in enumerate(user_list):
                if word.lower() in value:
                    processes.append(key)
                    if user_list[i+1] in assoc_cols:
                        processes.append(user_list[i+1])
                    elif user_list[i+2] in assoc_cols:
                        processes.append(user_list[i+2])


        if len(processes)%2 ==1:
            processes.append('*')
        
        query += 'SELECT '

        for key, values in aggregate_words.items():
            if key in processes:
                idx = processes.index(key)

                if idx == len(processes)-2:
                    query += f'{key}({processes[idx+1]})'
                else:
                    query += f'{key}({processes[idx+1]}), '
                assoc_cols.append(f'{key}({processes[idx+1]})')

        query += ' '

    if not from_in_query:
        query += 'FROM ' + ', '.join(assoc_tables) + ' '
        from_in_query = True

    if 'JOIN' in keywords:
        pass

    if 'WHERE' in keywords:
        processes = []

        for key, values in where_words.items():
            for i, word in enumerate(user_list):
                if i < len(user_list)-1:
                    bigram = word + ' ' + user_list[i+1]
                else:
                    bigram = word
        
                if word.lower() in values:
                    if user_list[i-1] in assoc_cols:
                        processes.append(user_list[i-1])
                    elif user_list[i-2] in assoc_cols:
                        processes.append(user_list[i-2])
                    processes.append(key)
                    if user_list[i+1] in unique_vals:
                        processes.append(user_list[i+1])
                    elif user_list[i+2] in unique_vals:
                        processes.append(user_list[i+2])
                    elif f'{user_list[i+1]} {user_list[i+2]}' in unique_vals:
                        processes.append(f'{user_list[i+1]} {user_list[i+2]}')
                    elif f'{user_list[i+2]} {user_list[i+3]}' in unique_vals:
                        processes.append(f'{user_list[i+2]} {user_list[i+3]}')
                                       
                elif bigram.lower() in values:
                    if user_list[i-1] in assoc_cols:
                        processes.append(user_list[i-1])
                    elif user_list[i-2] in assoc_cols:
                        processes.append(user_list[i-2])
                    processes.append(key)
                    if user_list[i+2] in unique_vals:
                        processes.append(user_list[i+2])
                    elif user_list[i+3] in unique_vals:
                        processes.append(user_list[i+3])
                    elif f'{user_list[i+2]} {user_list[i+3]}' in unique_vals:
                        processes.append(f'{user_list[i+2]} {user_list[i+3]}')
                    elif f'{user_list[i+3]} {user_list[i+4]}' in unique_vals:
                        processes.append(f'{user_list[i+3]} {user_list[i+4]}')


        query += 'WHERE '
        if len(processes) > 3:
            for idx in range(1, len(processes)//3):
                processes.insert(idx*3, 'AND')
        
        print("WHERE:", processes)

        for process_idx in range(0, len(processes), 4):
            try:
                processes[process_idx+2] = int(processes[process_idx+2])
                if processes[process_idx+1] == 'GREATER_THAN':
                    query += f'{processes[process_idx]} > {processes[process_idx+2]}'
                elif processes[process_idx+1] == 'LESS_THAN':
                    query += f'{processes[process_idx]} < {processes[process_idx+2]}'
                elif processes[process_idx+1] == 'EQUAL_TO':
                    query += f'{processes[process_idx]} = {processes[process_idx+2]}'
                elif processes[process_idx+1] == 'NOT_EQUAL_TO':
                    query += f'{processes[process_idx]} != {processes[process_idx+2]}'
            except:
                if processes[process_idx+1] == 'GREATER_THAN':
                    query += f'{processes[process_idx]} > {processes[process_idx+2]}'
                elif processes[process_idx+1] == 'LESS_THAN':
                    query += f'{processes[process_idx]} < {processes[process_idx+2]}'
                elif processes[process_idx+1] == 'EQUAL_TO':
                    query += f"{processes[process_idx]} = \'{processes[process_idx+2]}\'"
                elif processes[process_idx+1] == 'NOT_EQUAL_TO':
                    query += f"{processes[process_idx]} != \'{processes[process_idx+2]}\'"
            try:
                if processes[process_idx+3] == 'AND':
                    query += ' AND '
            except:
                pass
        
        query += ' '

    if 'GROUP BY' in keywords:
        group_by = []

        for i, word in enumerate(user_list):
            try:
                bigram = word + f' {user_list[i+1]}'
            except:
                pass
            if bigram in group_words:
                if user_list[i+1] in assoc_cols:
                    group_by.append(user_list[i+1])
                elif user_list[i+2] in assoc_cols:
                    group_by.append(user_list[i+2])
                elif user_list[i+3] in assoc_cols:
                    group_by.append(user_list[i+3])
            elif word in group_words:
                if user_list[i+1] in assoc_cols:
                    group_by.append(user_list[i+1])
                elif user_list[i+2] in assoc_cols:
                    group_by.append(user_list[i+2])
                elif user_list[i+3] in assoc_cols:
                    group_by.append(user_list[i+3])
        print('gb:', group_by)

        for i, col in enumerate(group_by):
            if i == 0:
                query += f'GROUP BY {col}'
            else:
                query += f', {col}'
        query += ' '
        query = query.replace("SELECT", f"SELECT {col},")
    
    if 'HAVING' in keywords:
        pass
    
    if 'ORDER BY' in keywords:
        asc = 'ASC'
        order = []

        for i, word in enumerate(user_list):
            try:
                bigram = word + f' {user_list[i+1]}'
            except:
                pass
            try:
                if bigram in order_words:
                    if user_list[i+2] in assoc_cols:
                        order.append(user_list[i+2])
                    elif user_list[i+3] in assoc_cols:
                        order.append(user_list[i+3])
                elif word in order_words:
                    if user_list[i+1] in assoc_cols:
                        order.append(user_list[i+1])
                    elif user_list[i+2] in assoc_cols:
                        order.append(user_list[i+2])
                    elif user_list[i+3] in assoc_cols:
                        order.append(user_list[i+3])
            except:
                pass
                
        for i, word in enumerate(user_list):
            try:
                bigram = word.lower() + f' {user_list[i+1]}'.lower()
            except:
                pass
            try:
                trigram = bigram + f' {user_list[i+2]}'
            except:
                pass
            if word.lower() in desc_words or bigram in desc_words or trigram in desc_words:
                asc = 'DESC'

        query += f'ORDER BY ' + ', '.join(order) + f' {asc} '
    
    
    if 'LIMIT' in keywords:
        query += 'LIMIT '
        for i, word in enumerate(user_list):
            if word in limit_words:
                try:
                    query += f'{int(user_list[i+1])}'
                except:
                    pass
                try:
                    query += f'{int(user_list[i+2])}'
                except:
                    pass
                try:
                    query += f'{int(user_list[i+3])}'
                except:
                    pass
                try:
                    query += f'{int(user_list[i-1])}'
                except:
                    pass
    
    query += ';'
    
    connection = pymysql.connect(
        host = login_info['endpoint'],
        user = login_info['username'],
        password = login_info['password'],
        database = login_info['sql_database_name']
    )

    cursor = connection.cursor()
    print("Executing query:", query)
    try:
        cursor.execute(query)
        columns = [desc[0] for desc in cursor.description]
        rows = cursor.fetchall()

        print(tabulate(rows, headers=columns, tablefmt="pretty"))

        cursor.close()
        connection.close()
    except:
        print('Error executing query')

    return

    # print(query)

user_input = 'Show me how many rows where review is greater than 1 from generalinfo'  # work on this
user_input = 'Show me how many rows in generalinfo'
user_input =  "How many restaurants have each review from generalinfo"
# user_input = 'Show me how many rows where course_id is greater than 104 from course, sort by course_id'
user_input = 'show me the 5 biggest street_num, filter to where city is san francisco and use the location table'
user_input = 'show me all columns where food_type is burgers and city is alameda from generalinfo'
user_input = "show me average review where label is baskin robbins from generalinfo"
user_input = 'Show me average review where food_type is ice cream group by label from generalinfo limit 5'
user_input = 'show me columns group by label from generalinfo'
user_input = 'show me the biggest 5 food types by average review of restaurants from each food_type in generalinfo rank them by AVG(review)'
user_input = 'show me all columns where food_type is burgers and city is alameda from generalinfo'
print(user_input)
keywords = match_query_pattern(user_input)
print('kywds', keywords)

execute_sql_query(user_input, keywords, login_info)

show me all columns where food_type is burgers and city is alameda from generalinfo
kywds ['WHERE', 'SELECT']
['show', 'me', 'all', 'columns', 'where', 'food_type', 'is', 'burgers', 'and', 'city', 'is', 'alameda', 'from', 'generalinfo']
Generating query for table: ['generalinfo']
Columns in associated table: ['id_restaurant', 'label', 'food_type', 'city', 'review']
Creating query with columns: ['food_type', 'city', '*']
WHERE: []
Executing query: SELECT * FROM generalinfo WHERE  ;
Error executing query


In [None]:
import numpy as np

def sql_or_nosql(user_input, login_info):
    decision = []

    user_list = user_input.replace(",", "").split()

    sql_mdata = get_mysql_metadata(login_info)
    mongo_mdata = get_mongodb_metadata(login_info)

    for word in user_list:
        if word in sql_mdata.keys():
            decision.append(1)
        elif word in mongo_mdata.keys():
            decision.append(0)

    if np.mean(decision) > 0.5:
        return 'SQL'
    elif len(decision) > 0:
        return'MONGODB'
    else:
        return 'UNDEFINED'

user_input = 'show me the 5 biggest street_num, filter to where city is san francisco and use the location table'

sql_or_nosql(user_input, login_info)

'SQL'

In [None]:
# Show ten rows command for sql

def show_table(n_rows, table, login_info):
    try:
        # Establish a database connection
        connection = pymysql.connect(
            host=login_info['endpoint'],
            user=login_info['username'],
            password=login_info['password'],
            db=login_info['sql_database_name']
        )
        cursor = connection.cursor()

        # Construct the SQL query
        query = f"SELECT * FROM {table} LIMIT {n_rows};"

        # Execute the query
        cursor.execute(query)
        rows = cursor.fetchall()
        column_names = [desc[0] for desc in cursor.description]

        # Display the data
        if rows:
            print(f"\nShowing {len(rows)} rows from table \'{table}\':\n")

            # Attempt to use pandas for a pretty table display
            try:
                df = pd.DataFrame(rows, columns=column_names)
                print(df)
            except ImportError:
                # Fallback to line-by-line print
                print(f"{' | '.join(column_names)}")
                print("-" * 80)
                for row in rows:
                    print(" | ".join(map(str, row)))
        else:
            print(f"No data found in table `{table}`.")
        
        connection.close()

    except pymysql.Error as e:
        print(f"An error occurred: {e}")
        return 
    
    return

show_table(5, 'courses', login_info)


Showing 5 rows from table 'courses':

   CourseID                CourseName  InstructorID InstructorName  \
0       101           Data Structures             2      Dr. Brown   
1       102                  Calculus             3      Dr. Smith   
2       103          Database Systems             2      Dr. Brown   
3       104            Linear Algebra             3      Dr. Smith   
4       105  Introduction to Business             4      Dr. White   

   CreditHours  
0            3  
1            4  
2            3  
3            3  
4            3  


In [None]:
## Show n rows for mongodb

import pymongo
import pandas as pd
pd.set_option('display.max_columns', None)

def show_collection(n_docs, collection_name, login_info):

    try:

        mongo_username = login_info['mongo_username']
        mongo_password = login_info['mongo_password']

        connection_string = f'mongodb+srv://{mongo_username}:{mongo_password}@cluster0.tgu2d.mongodb.net/'

        # Establish a MongoDB connection
        client = pymongo.MongoClient(connection_string)
        db = client['ChatDB']
        collection = db[collection_name]

        # Fetch the documents
        documents = list(collection.find().limit(n_docs))

        # Display the documents
        if documents:
            print(f"\nShowing {len(documents)} documents from collection `{collection_name}`:")

            for i, doc in enumerate(documents, start=1):
                print(f"{i}: {doc}")
        else:
            print(f"No documents found in collection `{collection_name}`.")

        # Close the connection
        client.close()

    except Exception as e:
        print(f"An error occurred: {e}")

show_collection(5, 'products', login_info)

No documents found in collection `products`.


In [None]:
mdata = get_mongodb_metadata(login_info)
print(mdata.keys())
for k, v in mdata.items():
    print(k, v[0])

dict_keys(['UW_std_course', 'UW_std_advisedBY', 'UW_std_taughtBy', 'UW_std_person'])
UW_std_course {'name': '_id', 'type': 'ObjectId', 'primary_key': True, 'unique_values': []}
UW_std_advisedBY {'name': '_id', 'type': 'ObjectId', 'primary_key': True, 'unique_values': []}
UW_std_taughtBy {'name': '_id', 'type': 'ObjectId', 'primary_key': True, 'unique_values': []}
UW_std_person {'name': '_id', 'type': 'ObjectId', 'primary_key': True, 'unique_values': []}


In [None]:
assoc_tables = ['UW_std_course']
user_list = user_input.replace(',', '').split()
assoc_cols = []
col_idx_mdata = []
for i in range(len(assoc_tables)):
    table_columns = []
    for j in range(len(mdata[assoc_tables[i]])):
        table_columns.append(mdata[assoc_tables[i]][j]['name'])

    print('tcols:', table_columns)
    for word in user_list:
        if word in table_columns:
            assoc_cols.append(word)
            idx = table_columns.index(word)
            col_idx_mdata.append(idx)


tcols: ['_id', 'course_id', 'courseLevel']


In [119]:
from pymongo import MongoClient
from bson.son import SON

def execute_mongo_query(user_input, keywords, login_info):
    user_list = user_input.translate(str.maketrans('', '', string.punctuation.replace('_', ''))).split()

    # Connect to MongoDB
    connection_string = f'mongodb+srv://{login_info['mongo_username']}:{login_info['mongo_password']}@cluster0.tgu2d.mongodb.net/'
    mdata = get_mongodb_metadata(login_info)
    client = pymongo.MongoClient(connection_string)
    db = client[login_info['mongo_database_name']]

    collections = db.list_collection_names()

    assoc_collections = [word for word in user_list if word in collections]
    if not assoc_collections:
        print("No matching collections found.")
        return

    assoc_cols = []
    col_idx_mdata = []
    for i in range(len(assoc_collections)):
        table_columns = []
        for j in range(len(mdata[assoc_collections[i]])):
            table_columns.append(mdata[assoc_collections[i]][j]['name'])

        print('tcols:', table_columns)
        for word in user_list:
            if word in table_columns:
                assoc_cols.append(word)
                idx = table_columns.index(word)
                col_idx_mdata.append(idx)

    unique_vals = []
    for i in assoc_collections:
        for j in col_idx_mdata:
            unique_vals += mdata[i][j]['unique_values']

    unique_vals = [str(item) if isinstance(item, (int, float)) else item for item in unique_vals]

    collection = db[assoc_collections[0]]

    pipeline = []
    print(assoc_collections)
    # Handle WHERE conditions (filters)
    if 'WHERE' in keywords:
        match_stage = {}
        for key, values in where_words.items():
            for i, word in enumerate(user_list):
                try:
                    bigram = word + f' {user_list[i+1]}'
                except:
                    pass
                if word.lower() in values:
                    if user_list[i-1] in assoc_cols:
                        field = user_list[i - 1]
                    elif user_list[i-2] in assoc_cols:
                        field = user_list[i - 2]
                    if user_list[i+1] in unique_vals:
                        value = user_list[i + 1]
                    elif user_list[i+2] in unique_vals:
                        value = user_list[i + 2]
                    elif user_list[i+3] in unique_vals:
                        value = user_list[i + 3]
                    elif f'{user_list[i+1]} {user_list[i+2]}' in unique_vals:
                        value = f'{user_list[i+1]} {user_list[i+2]}'
                    elif f'{user_list[i+2]} {user_list[i+3]}' in unique_vals:
                        value = f'{user_list[i+2]} {user_list[i+3]}'
                    else:
                        value = 100
                    if key == 'LESS_THAN':
                        match_stage[field] = {"$lt": value}
                    elif key == 'GREATER_THAN':
                        match_stage[field] = {"$gt": value}
                    elif key == 'EQUAL_TO':
                        match_stage[field] = value
                    elif key == 'NOT_EQUAL_TO':
                        match_stage[field] = {"$ne": value}
        if match_stage:
            pipeline.append({"$match": match_stage})
        print("Match", match_stage)
    
    if 'SELECT' in keywords or 'AGGREGATE' in keywords:
        for i, word in enumerate(user_list):
            try:
                bigram = word + f' {user_list[i+1]}'
            except:
                pass
            if bigram in all_cols:
                projection = {"_id": 0}
                break
            elif assoc_cols:
                projection = {col: 1 for col in assoc_cols}
        pipeline.append({"$project": projection})   

    # Handle GROUP BY
    if 'GROUP BY' in keywords:
        group_by_field = None
        for i, word in enumerate(user_list):
            try:
                bigram = word + f' {user_list[i+1]}'
            except:
                pass
            if bigram in group_words:
                if user_list[i+1] in assoc_cols:
                    group_by_field = user_list[i+1]
                elif user_list[i+2] in assoc_cols:
                    group_by_field = user_list[i+2]
                elif user_list[i+3] in assoc_cols:
                    group_by_field = user_list[i+3]
            elif word in group_words:
                if user_list[i+1] in assoc_cols:
                    group_by_field = user_list[i+1]
                elif user_list[i+2] in assoc_cols:
                    group_by_field = user_list[i+2]
                elif user_list[i+3] in assoc_cols:
                    group_by_field = user_list[i+3]
        print("GB:", group_by_field)
        processes = []

        if 'AGGREGATE' in keywords:
            for key, values in aggregate_words.items():
                for i, word in enumerate(user_list):
                    if word.lower() in values:
                        processes.append(key)
                        if user_list[i+1] in assoc_cols:
                            processes.append(f'${user_list[i+1]}')
                        elif user_list[i+2] in assoc_cols:
                            processes.append(f'${user_list[i+2]}')

        group_stage = {
            "$group": {
                "_id": f"${group_by_field}"
            }
        }

        processes = ['SUM' if item == 'COUNT' else item for item in processes]
        if len(processes) % 2 == 1:
            processes.append(1)
        print("AGG:", processes)

        for i, item in enumerate(processes):
            if item in aggregate_words.keys():
                try:
                    alias = f'{item.lower()}_{processes[i+1].replace('$', '')}'
                except:
                    alias = f'{item.lower()}_{processes[i+1]}'
                try:
                    int(processes[i+1])
                    group_stage['$group'][alias] = {f'${item.lower()}': processes[i+1]}
                except:
                    group_stage['$group'][alias] = {f'${item.lower()}': f'{processes[i+1]}'}

        pipeline.append(group_stage)
        if not any('$project' in stage for stage in pipeline):
            pipeline = [stage for stage in pipeline if '$project' not in stage]
    
    elif 'AGGREGATE' in keywords:
        processes = []

        for key, value in aggregate_words.items():
            for i, word in enumerate(user_list):
                if word in value:
                    processes.append(key)
                    if user_list[i+1] in assoc_cols:
                        processes.append(user_list[i+1])
                    elif user_list[i+2] in assoc_cols:
                        processes.append(user_list[i+2])
        
        processes = ['SUM' if item == 'COUNT' else item for item in processes]
        group_stage = {
            "$group": {
                "_id": None
            }
        }
        if len(processes) % 2 == 1:
            processes.append(1)
        print("AGG:", processes)

        for i, item in enumerate(processes):
            if item in aggregate_words.keys():
                alias = f'{item.lower()}_{processes[i+1]}'
                group_stage['$group'][alias] = {f'${item.lower()}': f'${processes[i+1]}'}

        pipeline.append(group_stage)
        if not any('$project' in stage for stage in pipeline):
            pipeline = [stage for stage in pipeline if '$project' not in stage]

    # Handle ORDER BY
    if 'ORDER BY' in keywords:
        sort_order = 1  # Default to ascending
        for i, word in enumerate(user_list):
            try:
                bigram = word + f' {user_list[i+1]}'
            except:
                pass
            if word in desc_words or bigram in desc_words:
                sort_order = -1
            if bigram in order_words:
                if user_list[i+2] in assoc_cols:
                    sort_field = user_list[i+2]
                elif user_list[i+3] in assoc_cols:
                    sort_field = user_list[i+3]
            elif word in order_words:
                if user_list[i+1] in assoc_cols:
                    sort_field = user_list[i+1]
                elif user_list[i+2] in assoc_cols:
                    sort_field = user_list[i+2]
                elif user_list[i+3] in assoc_cols:
                    sort_field = user_list[i+3]

        pipeline.append({"$sort": {sort_field: sort_order}})
    
    # Handle LIMIT
    if 'LIMIT' in keywords:
        for i, word in enumerate(user_list):
            if word in limit_words:
                try:
                    pipeline.append({"$limit": int(user_list[i+1])})
                except:
                    pass
                try:
                    pipeline.append({"$limit": int(user_list[i+2])})
                except:
                    pass
                try:
                    pipeline.append({"$limit": int(user_list[i+3])})
                except:
                    pass
                try:
                    pipeline.append({"$limit": int(user_list[i-1])})
                except:
                    if not any([True for stage in pipeline if '$limit' in stage]):
                        pipeline.append({"$limit": 5})
                


    print("Executing MongoDB pipeline:", pipeline)
    try:
        result = list(collection.aggregate(pipeline))
        headers = result[0].keys()
        data = [row.values() for row in result]
        print(tabulate(data, headers=headers, tablefmt="pretty"))
    except Exception as e:
        print("Error executing query:", e)

    client.close()

user_input = 'Show all columns where the course_id is greater than 100 from UW_std_course only show 6'
# user_input = 'Show all columns where inPhase is equal to Post_Quals from UW_std_person, only show 6 and by descending _id'
# user_input = 'Show me the average p_id for each inPhase from UW_std_person and sort it by _id'
user_input = 'show me the average p_id from UW_std_person where student equals 1'
user_input = 'show me every column from UW_std_person'
user_input = 'show me average p_id, maximum p_id, minimum p_id, and count of all rows from UW_std_person grouped by hasPosition'
keywords = match_query_pattern(user_input)
print(keywords, user_input)

execute_mongo_query(user_input, keywords, login_info)

['GROUP BY', 'AGGREGATE'] show me average p_id, maximum p_id, minimum p_id, and count of all rows from UW_std_person grouped by hasPosition
tcols: ['_id', 'p_id', 'professor', 'student', 'hasPosition', 'inPhase', 'yearsInProgram']
['UW_std_person']
GB: hasPosition
AGG: ['MIN', '$p_id', 'MAX', '$p_id', 'AVG', '$p_id', 'SUM', 1]
Executing MongoDB pipeline: [{'$project': {'p_id': 1, 'hasPosition': 1}}, {'$group': {'_id': '$hasPosition', 'min_p_id': {'$min': '$p_id'}, 'max_p_id': {'$max': '$p_id'}, 'avg_p_id': {'$avg': '$p_id'}, 'sum_1': {'$sum': 1}}}]
+-------------+----------+----------+--------------------+-------+
|     _id     | min_p_id | max_p_id |      avg_p_id      | sum_1 |
+-------------+----------+----------+--------------------+-------+
|   Faculty   |    5     |   415    |      215.025       |  40   |
|      0      |    3     |   435    | 223.89823008849558 |  226  |
| Faculty_adj |    7     |   349    |       145.0        |   6   |
| Faculty_aff |   103    |   293    | 229.3

In [117]:
connection_string = f'mongodb+srv://{login_info['mongo_username']}:{login_info['mongo_password']}@cluster0.tgu2d.mongodb.net/'
mdata = get_mongodb_metadata(login_info)
client = pymongo.MongoClient(connection_string)
db = client[login_info['mongo_database_name']]
collection = db['UW_std_person']
pipeline =  [{'$project': {'p_id': 1, 'hasPosition': 1}}, {'$group': {'_id': '$hasPosition', 'min_p_id': {'$min': '$p_id'}, 'max_p_id': {'$max': '$p_id'}, 'avg_p_id': {'$avg': '$p_id'}, 'sum_1': {'$sum': 1}}}]
result = list(collection.aggregate(pipeline))
headers = result[0].keys()
data = [row.values() for row in result]
print(tabulate(data, headers=headers, tablefmt="pretty"))


+-------------+----------+----------+--------------------+-------+
|     _id     | min_p_id | max_p_id |      avg_p_id      | sum_1 |
+-------------+----------+----------+--------------------+-------+
| Faculty_aff |   103    |   293    | 229.33333333333334 |   3   |
| Faculty_adj |    7     |   349    |       145.0        |   6   |
| Faculty_eme |    22    |   375    | 231.33333333333334 |   3   |
|      0      |    3     |   435    | 223.89823008849558 |  226  |
|   Faculty   |    5     |   415    |      215.025       |  40   |
+-------------+----------+----------+--------------------+-------+


In [None]:
queries = ['Find all students who have yearsInProgram equal to Year_12 years from UW_std_ person',
           'Find all faculty members from UW_std_person',
           'Find all students who have completed their qualifying exams from UW_std_person',
           'Find all students who are currently in their first year of the program',
           'Find all records from person where their inPhase is Post_Generals ',
           'Find all students who have a faculty position (either "Faculty_aff", "Faculty_eme", etc.), and are in their third year',
           'Find all students who are either in the "Post_Generals" or "Post_Quals" phase',
           'Find all students who are in the program for less than 5 years',
           'Find all students who are not in a faculty position',
           "Find all students who have the 'student' field marked as '1' and have been in the program for 8 years or more",
           'Find all students who have been in the program for more than 5 years',
           "Find all students or faculty members in a specific phase, say 'Post_Quals', who have been in the program for more than 10 years",
           "Find all records where the person is either a student or a professor",
           "Find all students who have a faculty position ('Faculty' type) and have been in the program for more than 5 years",
           "Find all students who are in their last year of the program ('Year_12')",
           "Find all students who are not in any faculty position ('hasPosition' field is '0' or 'Faculty_aff', etc.)",
           'Find all students who are in the "Pre_Quals" phase and have a faculty position']


# for input in queries:
#     input += ' limit 5'
#     keywords = match_query_pattern(input)
#     print(keywords, input)
#     print(execute_mongo_query(input, keywords, login_info))
#     print(execute_sql_query(input, keywords, login_info))