In [1]:
import pandas as pd
from sqlalchemy import create_engine
import pymysql
import sys
import os
from tabulate import tabulate
import random
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer
from nltk import pos_tag
import re

# Fetch the NLTK data packages this notebook asks for. The captured cell output
# shows each call reporting "already up-to-date" once the package is cached.
# NOTE(review): presumably the perceptron tagger backs the `pos_tag` call made in
# `preprocess`; the punkt/stopwords packages look like they are only needed by
# code that is currently commented out — confirm before trimming this list.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords')
nltk.download('averaged_perceptron_tagger_eng')



[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package stopwords to /home/ubuntu/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /home/ubuntu/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


True

In [25]:
def upload_sql(connection):
    """Upload a CSV file chosen by the user into a MySQL table.

    Prompts for a CSV path, lets the user pick an existing database (or create
    a new one), asks for confirmation before replacing an existing table, and
    writes the CSV with ``DataFrame.to_sql``. The table is named after the CSV
    file (base name without its extension).

    Args:
        connection: Accepted for call-site compatibility only. The upload runs
            on its own short-lived server connection, so the caller's
            connection is never used or closed here.

    Returns:
        None. Progress and problems are reported with ``print``.
    """
    file_path = input("Enter the CSV filename to upload to SQL: ")

    if not os.path.isfile(file_path):
        print(f"Error: File '{file_path}' not found. Check for any typos")
        return

    table = os.path.splitext(os.path.basename(file_path))[0]
    # NOTE(review): credentials are hard-coded; consider reading them from the
    # environment or prompting with getpass before sharing this notebook.
    user = 'root'
    password = 'Dsci-551'
    host = 'localhost'
    port = 3306

    upload_connection = pymysql.connect(user = user, password = password, host = host, port = port)
    try:
        # The cursor context manager closes the cursor on every exit path,
        # including the early return when the user declines to replace a table.
        with upload_connection.cursor() as cursor:
            cursor.execute("Show Databases")
            databases = [row[0] for row in cursor.fetchall()]

            print("\nAvailable Databases: ")
            for i, name in enumerate(databases, start = 1):
                print(f"{i}, {name}")

            while True:
                db_choice = input("\n Enter the name of database you want to upload the CSV to (or type 'New' to create a new database): ").strip()

                if db_choice in databases:
                    db = db_choice
                    break
                elif db_choice.lower() == 'new':
                    new_db = input("Enter the name of the new database: ").strip()
                    if not new_db:
                        print("Database name cannot be empty.")
                        continue
                    # Double any backtick so the name cannot break out of the
                    # quoted identifier.
                    quoted_db = new_db.replace('`', '``')
                    cursor.execute(f"Create database if not exists `{quoted_db}`")
                    db = new_db
                    break
                else:
                    print("Did not recognize input. Please enter one of the existing databases or 'New'")

            # Exact, parameterized lookup. `SHOW TABLES LIKE` treats '_' and '%'
            # in the table name as wildcards and was built by string formatting.
            cursor.execute(
                "Select 1 from information_schema.tables "
                "where table_schema = %s and table_name = %s limit 1",
                (db, table),
            )
            table_exists = cursor.fetchone() is not None

            if table_exists:
                while True:
                    replace = input(f"The table '{table}' already exists. Do you want to replace it? (yes/no): ").strip().lower()
                    if replace == 'yes':
                        break
                    elif replace == 'no':
                        print(f"Table '{table}' not replaced")
                        return
                    else:
                        print("Invalid input. Please enter 'yes' or 'no' only")

        engine = create_engine(f'mysql+pymysql://{user}:{password}@{host}:{port}/{db}')
        try:
            data = pd.read_csv(file_path, encoding = 'ISO-8859-1')
            data.to_sql(name = table, con = engine, if_exists = 'replace', index = False)
        finally:
            # Release the pooled connections held by the temporary engine.
            engine.dispose()

        print(f"'{file_path}' has been successfully uploaded as '{table}' table in database '{db}'")
    finally:
        upload_connection.close()
    

In [26]:
def connect_to_SQL(selected_db = None):
    """Open a MySQL connection bound to a database.

    When ``selected_db`` is falsy, the available databases are listed and the
    user is prompted until an existing name is entered. The listing uses a
    temporary server-level connection that is always closed before returning.

    Args:
        selected_db: Name of the database to connect to, or None to prompt.

    Returns:
        tuple: ``(connection, selected_db)`` where ``connection`` is a pymysql
        connection opened with ``database = selected_db``.
    """
    # NOTE(review): credentials are hard-coded; consider reading them from the
    # environment or prompting with getpass before sharing this notebook.
    user = 'root'
    password = 'Dsci-551'
    host = 'localhost'
    port = 3306

    if not selected_db:
        # Only needed to enumerate databases; skipped when the caller already
        # knows which database it wants.
        listing_connection = pymysql.connect(user = user, password = password, host = host, port = port)
        try:
            with listing_connection.cursor() as cursor:
                cursor.execute("Show Databases")
                dbs = [row[0] for row in cursor.fetchall()]

            print("\nAvailable Databases: ")
            for i, name in enumerate(dbs, start = 1):
                print(f"{i}, {name}")

            while True:
                db_choice = input("\n Enter the name of database you want to use.: ").strip()

                if db_choice in dbs:
                    selected_db = db_choice
                    break
                else:
                    print("Did not recognize input. Please enter one of the existing databases.\n")
        finally:
            # Runs even if the prompt is interrupted, so the socket never leaks.
            listing_connection.close()

    connection = pymysql.connect(user = user, password = password, host = host, port = port, database = selected_db)

    print(f"Connected to database: {selected_db}")
    return connection, selected_db
            

   

   


In [27]:
def preprocess(user_input):
    """Tokenize a user question into lower-cased words and comparison operators.

    Comparison operators (``<=``, ``>=``, ``=``, ``<``, ``>``) are kept as their
    own tokens. A word may contain ``/``, ``.`` or ``-`` between word
    characters (e.g. ``10.5``, ``2024-01-01``), but trailing punctuation is not
    part of the token, so a sentence-ending period no longer turns ``price``
    into ``price.``.

    Args:
        user_input: Raw text typed by the user.

    Returns:
        tuple: ``(filtered_tokens, tagged_tokens)`` — the lower-cased token list
        and the result of ``pos_tag`` applied to that list.
    """
    # Separators are only accepted when followed by more word characters.
    tokens = re.findall(r'<=|>=|=|<|>|\w+(?:[/.\-]+\w+)*', user_input)
    filtered_tokens = [token.lower() for token in tokens]
    tagged_tokens = pos_tag(filtered_tokens)
    return filtered_tokens, tagged_tokens

In [28]:
def column_types(cursor, table):
    """List a table's columns and split them into numeric and text columns.

    The type reported by ``SHOW COLUMNS`` can carry a length/precision or
    attributes (``int(11)``, ``varchar(255)``, ``decimal(10,2)``,
    ``int unsigned``), so only the base type name is compared.

    Args:
        cursor: Open DB-API cursor on the database that owns ``table``.
        table: Table name.

    Returns:
        tuple: ``(columns, numeric_columns, categorical_columns)`` — three lists
        of column names in table order. Columns of any other type (dates, blobs,
        ...) appear only in ``columns``.
    """
    numeric_types = {'int', 'float', 'double', 'decimal', 'smallint', 'tinyint', 'mediumint', 'bigint'}
    text_types = {'varchar', 'char', 'text', 'tinytext', 'mediumtext', 'longtext',
                  'enum', 'nchar', 'nvarchar', 'ntext'}

    # Double any backtick so the name stays inside the quoted identifier.
    quoted_table = table.replace('`', '``')
    cursor.execute(f"Show Columns From `{quoted_table}`")

    columns, numeric_columns, categorical_columns = [], [], []
    for row in cursor.fetchall():
        name = row[0].strip()
        raw_type = row[1]
        if isinstance(raw_type, (bytes, bytearray)):
            raw_type = raw_type.decode()
        # Keep only the base type: everything before the first '(' or space.
        base_type = re.split(r'[\s(]', raw_type.strip().lower(), maxsplit = 1)[0]

        columns.append(name)
        if base_type in numeric_types:
            numeric_columns.append(name)
        elif base_type in text_types:
            categorical_columns.append(name)

    return columns, numeric_columns, categorical_columns

In [29]:
def get_column_matches(user_input, cursor):
    """Find, per table, the columns whose names appear in the user's text.

    Tokens produced by ``preprocess`` are lower-cased, so column names are
    compared case-insensitively; the returned names keep their original
    spelling.

    Args:
        user_input: Raw text typed by the user.
        cursor: Open DB-API cursor on the selected database.

    Returns:
        dict: ``{table_name: [matching column names]}`` containing only tables
        with at least one match, in ``SHOW TABLES`` order.
    """
    filtered_tokens, _ = preprocess(user_input)
    token_set = set(filtered_tokens)

    # Read every table name first: column_types reuses this cursor and would
    # otherwise discard the pending result set.
    cursor.execute("Show Tables")
    tables = [row[0] for row in cursor.fetchall()]

    matched_columns = {}
    for table in tables:
        columns, _, _ = column_types(cursor, table)
        matching_columns = [col for col in columns if col.lower() in token_set]
        if matching_columns:
            matched_columns[table] = matching_columns
    return matched_columns

    

In [30]:
def get_table(user_input, cursor):
    """Pick the table the user's question most likely refers to.

    Tables with at least one column mentioned in the text are considered first.
    Among them, a table whose name also appears in the text wins, then the
    table with the most mentioned columns; remaining ties keep ``SHOW TABLES``
    order. If no column matches, the longest table name that appears in the
    text is used.

    Args:
        user_input: Raw text typed by the user.
        cursor: Open DB-API cursor on the selected database.

    Returns:
        tuple: ``(table, columns, numeric_columns, categorical_columns)``, or
        four ``None`` values when no table can be identified.
    """
    lowered_input = user_input.lower()

    matched_columns = get_column_matches(user_input, cursor)
    if matched_columns:
        # max() returns the first maximal key, so ties keep the original order.
        best_table = max(
            matched_columns,
            key = lambda name: (name.lower() in lowered_input, len(matched_columns[name])),
        )
        columns, numeric_columns, categorical_columns = column_types(cursor, best_table)
        return best_table, columns, numeric_columns, categorical_columns

    cursor.execute("Show Tables")
    named_tables = [row[0] for row in cursor.fetchall() if row[0].lower() in lowered_input]

    if named_tables:
        # Prefer the longest name so "order_items" is not shadowed by "order".
        best_table = max(named_tables, key = len)
        columns, numeric_columns, categorical_columns = column_types(cursor, best_table)
        return best_table, columns, numeric_columns, categorical_columns

    return None, None, None, None

In [31]:
def sql_sample_data(connection):
    """Print every table's attributes, then up to five rows of a chosen table.

    The database name shown in the messages is read from the connection itself
    (``SELECT DATABASE()``) rather than from a notebook-level global, so the
    function works regardless of which cells have been run.

    Args:
        connection: pymysql connection bound to a database. It is re-opened via
            ``ping(reconnect = True)`` if it was closed, and is left open on
            return (only the cursor created here is closed).

    Returns:
        None. All results are printed.
    """
    if not connection.open:
        connection.ping(reconnect = True)

    cursor = connection.cursor()
    try:
        cursor.execute("Select Database()")
        db_row = cursor.fetchone()
        db_name = db_row[0] if db_row else None

        cursor.execute("Show Tables")
        tables = [row[0] for row in cursor.fetchall()]

        if not tables:
            print(f"No tables found in '{db_name}'. Please upload some.")
            return

        print(f"Tables and attributes in '{db_name}':")

        table_columns = {}
        for table in tables:
            quoted_table = table.replace('`', '``')
            cursor.execute(f"Show columns from `{quoted_table}`")
            columns = [col[0] for col in cursor.fetchall()]
            table_columns[table] = columns
            print(f"\n {table}")
            print(f"Attributes:", ", ".join(columns))
            print()

        while True:
            selected_table = input("\nEnter name of table you want to view sample data from: ").strip()

            if selected_table in table_columns:
                break
            else:
                print("Table input not recognized. Please enter a table name from above")

        quoted_selected = selected_table.replace('`', '``')
        cursor.execute(f"Select * from `{quoted_selected}` limit 5")
        sample = cursor.fetchall()

        if sample:
            df = pd.DataFrame(list(sample), columns = table_columns[selected_table])
            print(f"\n Sample data from '{selected_table}': ")
            print(tabulate(df, headers = 'keys', tablefmt = 'psql', showindex = False))
        else:
            # Report the table the user picked, not the last one listed above.
            print(f"No data found in '{selected_table}'")
    finally:
        cursor.close()
    
    

In [32]:
def _build_sample_queries(cursor, tables, num_queries, random_queries):
    """Prompt the user and build the "NL Description / Query" strings.

    Private helper of ``gen_sample_queries``; it owns all prompting, token
    parsing and SQL text construction but never executes the generated SQL.

    Args:
        cursor: Open DB-API cursor on the selected database.
        tables: Non-empty list of table names in that database.
        num_queries: How many queries to produce.
        random_queries: True to generate random samples, False to translate the
            user's question.

    Returns:
        list[str]: One ``"NL Description: ...\\n\\nQuery: ..."`` string per
        query that could be built (possibly empty).
    """
    queries = ['having', 'group', 'max', 'min', 'sum', 'avg', 'count', 'where', 'order', 'select']
    aggregate_kinds = ('group', 'max', 'min', 'sum', 'avg', 'count')
    conditions = ['>', '<', '=', '<=', '>=']
    # "is"/"are" are mapped to '=', so "is greater than" yields both '=' and
    # '>'. Looking for '=' last makes the more specific operator win.
    condition_priority = ['<=', '>=', '>', '<', '=']
    condition_text = {'>': 'greater than', '<': 'less than', '=': 'equal to', '<=': 'less than or equal to', '>=': 'greater than or equal to'}
    agg_functions = ['AVG', 'SUM', 'MAX', 'MIN', 'COUNT']
    agg_text = {'AVG': 'average', 'SUM': 'sum', 'MAX': 'maximum', 'MIN': 'minimum', 'COUNT': 'count'}
    limit_text = {'ASC': 'Bottom', 'DESC': 'Top'}
    keyword_mapping = {'less than or equal to': '<=', 'greater than or equal to': '>=', 'greater than': '>', 'less than': '<',
                       'equal to': '=', 'is': '=', 'are': '=', 'average': 'AVG', 'maximum': 'MAX', 'minimum': 'MIN', 'total': 'SUM',
                       'find': 'select', 'show': 'select', 'grouped': 'group'}

    def quote(identifier):
        """Backtick-quote an identifier, doubling embedded backticks."""
        return "`" + str(identifier).replace("`", "``") + "`"

    def sql_literal(value):
        """Single-quote a value as a SQL string literal."""
        return "'" + str(value).replace("\\", "\\\\").replace("'", "''") + "'"

    def column_list(cols):
        """Quoted, comma-separated select list; '*' when no column is given."""
        return ', '.join(quote(col) for col in cols) if cols else '*'

    def describe_columns(cols):
        """Plain-text counterpart of column_list for the NL description."""
        return ', '.join(cols) if cols else '*'

    def make_entry(nl, sql):
        return f"NL Description: {nl}\n\nQuery: {sql}"

    def map_keywords(tokens):
        """Replace (possibly multi-word) phrases from keyword_mapping by their symbol."""
        mapped = []
        i = 0
        while i < len(tokens):
            for phrase, value in keyword_mapping.items():
                phrase_tokens = phrase.split()
                if tokens[i:i + len(phrase_tokens)] == phrase_tokens:
                    mapped.append(value)
                    i += len(phrase_tokens)
                    break
            else:
                mapped.append(tokens[i])
                i += 1
        return mapped

    def tokenize(text, table_columns):
        """Tokenize text, map keywords, and restore the real spelling of column names."""
        tokens, _ = preprocess(text)
        tokens = map_keywords(tokens)
        # Tokens are lower-cased; map them back so `token in columns` works for
        # columns that contain uppercase letters. Query keywords are left alone.
        canonical = {col.lower(): col for col in table_columns}
        return [tok if tok in queries else canonical.get(tok, tok) for tok in tokens]

    def detect_kinds(tokens):
        """Query kinds named in the tokens, in the fixed order of `queries`."""
        lowered = {tok.lower() for tok in tokens}
        return [kind for kind in queries if kind in lowered]

    def first_condition(tokens):
        return next((cond for cond in condition_priority if cond in tokens), None)

    def first_number(tokens):
        return next((tok for tok in tokens if tok.replace('.', '', 1).isdigit()), None)

    def first_agg(tokens):
        return next((tok.upper() for tok in tokens if tok.upper() in agg_functions), None)

    def value_after_equals(tokens):
        """First non-operator token following the first '=', else None."""
        if '=' not in tokens:
            return None
        rest = [tok for tok in tokens[tokens.index('=') + 1:] if tok not in conditions]
        return rest[0] if rest else None

    def clause_tokens(tokens, keywords):
        """Split tokens into the span following each keyword, up to the next keyword."""
        positions = sorted((tokens.index(word), word) for word in keywords if word in tokens)
        segments = {word: [] for word in keywords}
        for n, (start, word) in enumerate(positions):
            end = positions[n + 1][0] if n + 1 < len(positions) else len(tokens)
            segments[word] = tokens[start + 1:end]
        return segments

    def condition_value(column, table):
        """Random operator plus a random value inside the column's [MIN, MAX] range."""
        condition = random.choice(conditions)
        cursor.execute(f"Select MIN({quote(column)}), MAX({quote(column)}) from {quote(table)}")
        min_value, max_value = cursor.fetchone()

        if min_value is None or max_value is None:
            print(f"Error: Value was None")
            return condition, None

        # float() also accepts Decimal, which random.uniform cannot mix with floats.
        low, high = float(min_value), float(max_value)
        value = random.uniform(low, high)
        if low.is_integer() and high.is_integer():
            value = round(value)
        return condition, value

    def random_categorical_value(column, table):
        cursor.execute(f"Select Distinct {quote(column)} from {quote(table)} limit 20;")
        unique_values = [row[0] for row in cursor.fetchall() if row[0] is not None]
        return random.choice(unique_values) if unique_values else 'unknown'

    # ---- prompt and table selection -------------------------------------
    if random_queries:
        user_input = input("What kind of query would you like to see or would you like ChatDB to generate one?").strip().lower()
        table_choice = random.choice(tables)
        columns, numeric_columns, categorical_columns = column_types(cursor, table_choice)
    else:
        user_input = input("Ask a question: ").strip()
        table_choice, columns, numeric_columns, categorical_columns = get_table(user_input, cursor)
        while table_choice is None:
            print("Unable to find table based on user input")
            user_input = input("Rewrite your question: ")
            table_choice, columns, numeric_columns, categorical_columns = get_table(user_input, cursor)

    filtered_tokens = tokenize(user_input, columns)
    specific_query = detect_kinds(filtered_tokens)

    # ---- which kinds of query to build ----------------------------------
    if random_queries:
        if 'any' in filtered_tokens or not specific_query:
            print("Generating any query. You selected 'any' or had no specific query identified.")
            selected_queries = random.sample(queries, k = min(num_queries, len(queries)))
        else:
            selected_queries = random.choices(specific_query, k = num_queries)
    else:
        while not specific_query:
            print("Unable to identify query based on user input\n")
            user_input = input("Rewrite your question: ")
            # Same pipeline as the first attempt, including keyword mapping.
            filtered_tokens = tokenize(user_input, columns)
            specific_query = detect_kinds(filtered_tokens)
        selected_queries = specific_query[:num_queries]

    mentions_aggregate = any(tok.upper() in ['GROUP'] + agg_functions for tok in filtered_tokens)
    sample_queries = []

    for kind in selected_queries:

        # where + having combined
        if 'where' in specific_query and 'having' in specific_query and numeric_columns and categorical_columns:
            where_col = where_condition = where_value = None
            group_col = having_col = agg_function = None
            having_condition = having_value = None

            if random_queries:
                where_col = random.choice(categorical_columns + numeric_columns)
                if where_col in numeric_columns:
                    where_condition, where_value = condition_value(where_col, table_choice)
                else:
                    where_condition = '='
                    where_value = random_categorical_value(where_col, table_choice)

                having_col = random.choice(numeric_columns)
                group_col = random.choice(categorical_columns)
                agg_function = random.choice(agg_functions)
                having_condition, having_value = condition_value(having_col, table_choice)
            else:
                segments = clause_tokens(filtered_tokens, ('where', 'group', 'having'))
                where_tokens = segments['where']
                group_tokens = segments['group']
                having_tokens = segments['having']

                where_col = next((col for col in columns if col in where_tokens), None)
                if where_col in numeric_columns:
                    where_condition = first_condition(where_tokens)
                    where_value = first_number(where_tokens)
                elif where_col in categorical_columns:
                    where_condition = '='
                    where_value = value_after_equals(where_tokens)

                group_col = next((col for col in categorical_columns if col in group_tokens), None)
                having_col = next((col for col in numeric_columns if col in having_tokens), None)
                agg_function = first_agg(having_tokens)
                having_condition = first_condition(having_tokens)
                having_value = first_number(having_tokens)

            required = (where_col, where_condition, group_col, having_col, agg_function, having_condition)
            # Values are compared with None because 0 is a legitimate value.
            if all(required) and where_value is not None and having_value is not None:
                if where_col in categorical_columns:
                    where_value = sql_literal(where_value)

                alias = quote(f"{agg_text[agg_function]}_{having_col}")
                sql = (f"Select {quote(group_col)}, {agg_function}({quote(having_col)}) as {alias} "
                       f"From {quote(table_choice)} where {quote(where_col)} {where_condition} {where_value} "
                       f"Group by {quote(group_col)} having {agg_function}({quote(having_col)}) {having_condition} {having_value};")
                nl = (f"{agg_text[agg_function]} of {having_col} in {table_choice} grouped by {group_col} "
                      f"having {agg_text[agg_function]} of {having_col} {condition_text[having_condition]} {having_value} "
                      f"where {where_col} is {condition_text[where_condition]} {where_value}")
                sample_queries.append(make_entry(nl, sql))
            else:
                print("Failed to recognize column, condition, or value for Where and having clause.")

        # having
        elif kind == 'having' and numeric_columns and categorical_columns:
            if random_queries:
                agg_col = random.choice(numeric_columns)
                group_col = random.choice(categorical_columns)
                agg_function = random.choice(agg_functions)
                condition, value = condition_value(agg_col, table_choice)
            else:
                agg_col = next((col for col in numeric_columns if col in filtered_tokens), None)
                group_col = next((col for col in categorical_columns if col in filtered_tokens), None)
                agg_function = first_agg(filtered_tokens)
                condition = first_condition(filtered_tokens)
                value = first_number(filtered_tokens)

            if agg_col and group_col and agg_function and condition and value is not None:
                alias = quote(f"{agg_text[agg_function]}_{agg_col}")
                sql = (f"Select {quote(group_col)}, {agg_function}({quote(agg_col)}) as {alias} "
                       f"from {quote(table_choice)} group by {quote(group_col)} "
                       f"having {agg_function}({quote(agg_col)}) {condition} {value};")
                nl = f"{agg_text[agg_function]} {agg_col} in {table_choice} group by {group_col} having {agg_text[agg_function]}_{agg_col} {condition_text[condition]} {value}"
                sample_queries.append(make_entry(nl, sql))
            else:
                print("Failed to recognize columns or function for having clause")

        # group by and aggregate functions combined; also taken when the kind
        # itself is an aggregate (e.g. sampled at random) even if the text
        # does not mention one.
        elif (kind in aggregate_kinds or mentions_aggregate) and numeric_columns and categorical_columns:
            condition = None
            value = None
            if random_queries:
                if kind.upper() in agg_functions:
                    agg_function = kind.upper()
                else:
                    agg_function = first_agg(filtered_tokens) or random.choice(agg_functions)

                if agg_function == 'COUNT':
                    agg_col = random.choice(numeric_columns + categorical_columns)
                    if agg_col in numeric_columns:
                        condition, value = condition_value(agg_col, table_choice)
                else:
                    agg_col = random.choice(numeric_columns)
                group_col = random.choice(categorical_columns)
            else:
                agg_function = first_agg(filtered_tokens)
                countable = numeric_columns + categorical_columns
                if 'group' in filtered_tokens:
                    group_index = filtered_tokens.index('group')
                    agg_col = next((tok for tok in filtered_tokens[:group_index] if tok in countable), None)
                    group_col = next((tok for tok in filtered_tokens[group_index + 1:] if tok in categorical_columns), None)
                else:
                    group_col = None
                    candidates = countable if agg_function == 'COUNT' else numeric_columns
                    agg_col = next((tok for tok in filtered_tokens if tok in candidates), None)
                if agg_function == 'COUNT':
                    condition = first_condition(filtered_tokens)
                    value = first_number(filtered_tokens)

            if agg_function == 'COUNT' and agg_col in numeric_columns and group_col and condition and value is not None:
                alias = quote(f"count_{agg_col}")
                sql = (f"Select {quote(group_col)}, COUNT({quote(agg_col)}) as {alias} from {quote(table_choice)} "
                       f"where {quote(agg_col)} {condition} {value} group by {quote(group_col)};")
                nl = f"Count of {agg_col} in {table_choice} group by {group_col} where {agg_col} is {condition_text[condition]} {value}"
                sample_queries.append(make_entry(nl, sql))

            elif agg_function and agg_col and group_col:
                alias = quote(f"{agg_function}_{agg_col}")
                sql = (f"Select {quote(group_col)}, {agg_function}({quote(agg_col)}) as {alias} "
                       f"from {quote(table_choice)} group by {quote(group_col)};")
                nl = f"{agg_text[agg_function]} of {agg_col} group by {group_col}"
                sample_queries.append(make_entry(nl, sql))

            elif agg_function and agg_col and (agg_function == 'COUNT' or agg_col in numeric_columns):
                alias = quote(f"{agg_function}_{agg_col}")
                sql = f"Select {agg_function}({quote(agg_col)}) as {alias} from {quote(table_choice)};"
                nl = f"{agg_text[agg_function]} of {agg_col} in {table_choice}"
                sample_queries.append(make_entry(nl, sql))

            else:
                print("Failed to recognize columns or function for GROUP or Aggregate clause")

        # where
        elif kind == 'where':
            agg_col = None
            condition = None
            value = None
            selected_cols = []
            if random_queries:
                agg_col = random.choice(columns)
                if agg_col in numeric_columns:
                    condition, value = condition_value(agg_col, table_choice)
                elif agg_col in categorical_columns:
                    condition = '='
                    value = random_categorical_value(agg_col, table_choice)
                selected_cols = random.sample(columns, k = min(2, len(columns)))
            elif 'where' in filtered_tokens:
                where_index = filtered_tokens.index('where')
                tokens_before = filtered_tokens[:where_index]
                tokens_after = filtered_tokens[where_index + 1:]
                selected_cols = [tok for tok in tokens_before if tok in numeric_columns + categorical_columns]
                # The filtered column is the first real column after "where".
                agg_col = next((tok for tok in tokens_after if tok in columns), None)

                if agg_col in numeric_columns:
                    condition = first_condition(tokens_after)
                    value = first_number(tokens_after)
                elif agg_col in categorical_columns:
                    condition = '='
                    value = value_after_equals(tokens_after)

            if agg_col and condition and value is not None:
                if agg_col in categorical_columns:
                    value = sql_literal(value)

                sql = f"Select {column_list(selected_cols)} from {quote(table_choice)} where {quote(agg_col)} {condition} {value};"
                nl = f"Select {describe_columns(selected_cols)} from {table_choice} where {agg_col} is {condition_text[condition]} {value}"
                sample_queries.append(make_entry(nl, sql))
            else:
                print("Failed to recognize column, condition or value for WHERE clause")

        # order by
        elif kind == 'order' and numeric_columns:
            selected_cols = []
            order_col = None
            value = None
            order_type = 'ASC'  # SQL's own default when no direction is requested
            if random_queries:
                selected_cols = random.sample(columns, k = min(2, len(columns)))
                order_col = random.choice(numeric_columns)
                value = random.randint(1, 5)
                order_type = random.choice(['ASC', 'DESC'])
            else:
                if 'order' in filtered_tokens:
                    order_index = filtered_tokens.index('order')
                    columns_part = filtered_tokens[:order_index]
                    order_part = filtered_tokens[order_index:]
                    selected_cols = [col for col in numeric_columns + categorical_columns if col in columns_part]
                    order_col = next((tok for tok in order_part if tok in numeric_columns), None)

                value = next((int(tok) for tok in filtered_tokens if tok.isdigit()), None)

                if any(tok in filtered_tokens for tok in ['top', 'descending', 'desc']):
                    order_type = 'DESC'

            if order_col:
                limit_clause = f" LIMIT {value}" if value is not None else ""
                sql = (f"Select {column_list(selected_cols)} from {quote(table_choice)} "
                       f"order by {quote(order_col)} {order_type}{limit_clause};")
                if value is not None:
                    nl = f"{limit_text[order_type]} {value} records ordered by {order_col}"
                else:
                    nl = f"All records ordered by {order_col} ({order_type})"
                sample_queries.append(make_entry(nl, sql))
            else:
                print("Failed to recognize column for ORDER BY clause")

        # plain select
        elif kind == 'select':
            if random_queries:
                selected_cols = random.sample(columns, k = min(4, len(columns)))
            elif 'all' in filtered_tokens:
                selected_cols = []
            else:
                selected_cols = [col for col in columns if col in filtered_tokens]
                if len(selected_cols) == len(columns):
                    selected_cols = []  # every column requested -> '*'

            sql = f"Select {column_list(selected_cols)} from {quote(table_choice)};"
            nl = f"Select {describe_columns(selected_cols)} from {table_choice}"
            sample_queries.append(make_entry(nl, sql))

        else:
            print(f"Unable to generate a '{kind}' query for table '{table_choice}'")

    return sample_queries


def gen_sample_queries(connection, num_queries = 1, random_queries = True):
    """Generate example SQL queries, at random or from a natural-language question.

    In random mode a table is picked at random and the user may name the kind
    of query they want to see; otherwise the user's question is parsed for a
    table, columns, operators and values. The generated SQL is printed together
    with a plain-English description; it is not executed.

    Args:
        connection: pymysql connection bound to a database. It is re-opened via
            ``ping(reconnect = True)`` if it was closed and is left open on
            return (only the cursor created here is closed).
        num_queries: Number of queries to generate.
        random_queries: True for random sample queries, False to answer a
            natural-language question.

    Returns:
        list[str]: One ``"NL Description: ...\\n\\nQuery: ..."`` string per
        generated query; empty when the database has no tables or nothing
        could be generated.
    """
    if not connection.open:
        connection.ping(reconnect = True)

    cursor = connection.cursor()
    try:
        cursor.execute("Show Tables")
        tables = [row[0] for row in cursor.fetchall()]

        if not tables:
            print("No tables in this database")
            return []

        sample_queries = _build_sample_queries(cursor, tables, num_queries, random_queries)
    finally:
        cursor.close()

    if not sample_queries:
        print("No Queries could be generated")
    elif random_queries:
        print("\nSample Queries\n")
        for i, entry in enumerate(sample_queries, 1):
            print(f"Sample Query {i} \n {entry}\n")
    else:
        print("\nQuery")
        for entry in sample_queries:
            print(f"{entry}\n")

    return sample_queries
    

    

In [None]:
if __name__ == "__main__":
    # Interactive ChatDB menu. Option 1 selects the working database; options
    # 2-5 are refused until a database has been selected; option 6 exits.

    print("Welcome to ChatDB\n")
    selected_db = None
    connection = None

    try:
        while True:

            print("Please select a numbered option: \n")
            print("1. Select or Change Database\n")
            print("2. Upload Data\n")
            print("3. View Table Attributes and Sample Data\n")
            print("4. Generate Sample Queries\n")
            print("5. Answer Natural Language Questions\n")
            print("6. End Program\n")

            choice = input("Enter your choice ('1', '2', '3', '4', '5', '6'): ").strip()

            if choice == '1':
                new_connection, selected_db = connect_to_SQL()
                # Release the previous connection instead of leaking it when
                # the user switches databases.
                if connection is not None and connection.open:
                    connection.close()
                connection = new_connection

            elif choice == '6':
                break

            elif choice in ('2', '3', '4', '5'):
                if not selected_db:
                    print("Please select a database using option '1'")
                elif choice == '2':
                    upload_sql(connection)
                elif choice == '3':
                    sql_sample_data(connection)
                elif choice == '4':
                    gen_sample_queries(connection, num_queries = 3, random_queries = True)
                else:
                    gen_sample_queries(connection, num_queries = 1, random_queries = False)

            else:
                print("Invalid choice. Please Enter one of the options in quotes above")
    finally:
        # A helper may already have closed the connection; closing a closed
        # pymysql connection raises, so check `open` first.
        if connection is not None and connection.open:
            connection.close()

Welcome to ChatDB

Please select a numbered option: 

1. Select or Change Database

2. Upload Data

3. View Table Attributes and Sample Data

4. Generate Sample Queries

5. Answer Natural Language Questions

6. End Program

