In [1]:
import sqlglot
import pandas as pd

In [2]:
path = "../queries/1a.sql"

In [3]:
path = "../../../spark-eval/benchmark/job/1a.sql"

In [4]:
path_folder = "../../../spark-eval/benchmark/job/"

In [11]:
with open(path, 'r') as f:
    query = f.read()
    print(query)

SELECT MIN(mc.note) AS production_note,
       MIN(t.title) AS movie_title,
       MIN(t.production_year) AS movie_year
FROM company_type AS ct,
     info_type AS it,
     movie_companies AS mc,
     movie_info_idx AS mi_idx,
     title AS t
WHERE ct.kind = 'production companies'
  AND it.info = 'top 250 rank'
  AND mc.note NOT LIKE '%(as Metro-Goldwyn-Mayer Pictures)%'
  AND (mc.note LIKE '%(co-production)%'
       OR mc.note LIKE '%(presents)%')
  AND ct.id = mc.company_type_id
  AND t.id = mc.movie_id
  AND t.id = mi_idx.movie_id
  AND mc.movie_id = mi_idx.movie_id
  AND it.id = mi_idx.info_type_id;




In [4]:
# Parse the SQL query
parsed_query = sqlglot.parse_one(query)

# Extract all tables from the query using sqlglot's expressions
table_names = [table.name for table in parsed_query.find_all(sqlglot.expressions.Table)]

# Print the extracted table names
print(table_names)


['company_type', 'info_type', 'movie_companies', 'movie_info_idx', 'title']


In [5]:
#List attribtues of table
# Parse the SQL query
parsed_query = sqlglot.parse_one(query)

# Dictionary to hold table names and their corresponding attributes
table_columns = {}

# Loop through all columns in the SELECT clause
for column in parsed_query.find_all(sqlglot.expressions.Column):
    # Get the table and column names
    table_name = column.table
    column_name = column.name
    
    # Add the column name under the corresponding table name in the dictionary
    if table_name:
        if table_name not in table_columns:
            table_columns[table_name] = []
        table_columns[table_name].append(column_name)

# Print the table names and their corresponding attributes
for table, columns in table_columns.items():
    print(f"Table: {table}, Columns: {columns}")


Table: mc, Columns: ['note', 'movie_id', 'movie_id', 'company_type_id', 'note', 'note', 'note']
Table: t, Columns: ['title', 'production_year', 'id', 'id']
Table: it, Columns: ['id', 'info']
Table: mi_idx, Columns: ['info_type_id', 'movie_id', 'movie_id']
Table: ct, Columns: ['id', 'kind']


In [6]:
#Find Aggregate Attributes
# Parse the SQL query
parsed_query = sqlglot.parse_one(query)

# List to hold the aggregation columns
aggregated_columns = []

# Loop through all aggregation functions in the query
for agg_func in parsed_query.find_all(sqlglot.expressions.AggFunc):
    # Get the column being aggregated
    for column in agg_func.find_all(sqlglot.expressions.Column):
        table_name = column.table
        column_name = column.name
        aggregated_columns.append((table_name, column_name))

# Print the aggregated columns
for table, column in aggregated_columns:
    print(f"Table: {table}, Column in Aggregation: {column}")


Table: mc, Column in Aggregation: note
Table: t, Column in Aggregation: title
Table: t, Column in Aggregation: production_year


In [None]:
#Find GROUP Attributes

# Parse the SQL query
parsed_query = sqlglot.parse_one(query)

# List to hold the aggregation columns
grouped_columns = []

# Loop through all grouped functions in the query
for agg_func in parsed_query.find_all(sqlglot.expressions.Group):
    # Get the column being grouped
    for column in agg_func.find_all(sqlglot.expressions.Column):
        table_name = column.table
        column_name = column.name
        grouped_columns.append((table_name, column_name))

# Print the grouped columns
for table, column in grouped_columns:
    print(f"Table: {table}, Column in Aggregation: {column}")

In [24]:
def getNumberGroupAttributes(query):
    parsed_query = sqlglot.parse_one(query)
    count_group = 0
    
    for agg_func in parsed_query.find_all(sqlglot.expressions.Group):
        # Get the column being grouped
        for column in agg_func.find_all(sqlglot.expressions.Column):
            count_group+=1

    return count_group


def getNumberTables(query):
    parsed_query = sqlglot.parse_one(query)
    table_names = [table.name for table in parsed_query.find_all(sqlglot.expressions.Table)]
    
    return len(table_names)



import os
query_info = {} #qID, #tables, #groupingAttributes
idx = 0
# Open one of the files,
for data_file in sorted(os.listdir(path_folder)):
    if data_file.endswith(".sql"):
        q_path = path_folder + data_file
        with open(q_path, 'r') as f:
            query = f.read()
           
            num_tables = getNumberTables(query)
            num_group = getNumberGroupAttributes(query)
            query_info[idx] = (data_file, num_tables, num_group)
    idx+=1
    if(idx == 5): break

df = pd.DataFrame(query_info, index=['qID', 'c_tables', 'c_groupingAttributes'])
df.T #transpose matrix
print(df)
print(df["c_tables"].mean())
print(df["c_groupingAttributes"].mean())

       qID c_tables c_groupingAttributes
1  10a.sql        7                    0
2  10b.sql        7                    0
3  10c.sql        7                    0
4  11a.sql        8                    0
7.25
0.0
