In [4]:
import os, re
import pandas as pd
import numpy as np
import sqlglot
import psycopg2
from sqlglot import parse_one, exp
from utils import get_files_absolute_path_from_dir, connect_postgres, sql_to_dataframe

In [5]:
path = os.path.join(os.getcwd(), 'all_queries', 'updated_queries')
files = os.listdir(path)
print(path)
queries = get_files_absolute_path_from_dir(path)

D:\BDMA\Data Warehouses\tpcds-benchmark\all_queries\updated_queries
Total files: 99
First few files...
['D:/BDMA/Data Warehouses/tpcds-benchmark/all_queries/updated_queries/query-01.sql', 'D:/BDMA/Data Warehouses/tpcds-benchmark/all_queries/updated_queries/query-02.sql', 'D:/BDMA/Data Warehouses/tpcds-benchmark/all_queries/updated_queries/query-03.sql', 'D:/BDMA/Data Warehouses/tpcds-benchmark/all_queries/updated_queries/query-04.sql', 'D:/BDMA/Data Warehouses/tpcds-benchmark/all_queries/updated_queries/query-05.sql']


In [6]:
conditions = []
def extract_conditions(expression):
    if isinstance(expression, sqlglot.exp.And) or isinstance(expression, sqlglot.exp.Or):
        extract_conditions(expression.args['this'])
        extract_conditions(expression.args['expression'])
    elif isinstance(expression, sqlglot.exp.Exists) or isinstance(expression, sqlglot.exp.Subquery):
        return
    elif isinstance(expression, sqlglot.exp.Paren):
            extract_conditions(expression.this)
    else:
        conditions.append(expression.sql())

def remove_subquery(condition):
    idx = condition.find('SELECT')
    if idx != -1:
        return condition[:idx]
    return condition


def remove_case_when(condition):
    if "CASE" in condition:
        return 0
    return 1


for query in queries:
    with open(query, "r") as f:
        all_lines = []
        for line in f:
            cleaned_line = line.strip()
            if cleaned_line != '':
                all_lines.append(cleaned_line)
                
        one_line_query = ' '.join(all_lines)
        parsed = sqlglot.parse_one(one_line_query, dialect="postgres")
        wheres = parsed.find_all(exp.Where)
        for where in wheres:
            extract_conditions(where.this)
    


conditions = list(map(remove_subquery, filter(remove_case_when ,conditions)))
    
print(len(conditions))

1327


In [7]:
cur = connect_postgres("tpcds")

PostgreSQL server information
{'user': 'postgres', 'channel_binding': 'prefer', 'dbname': 'tpcds', 'host': 'localhost', 'port': '5432', 'options': '', 'sslmode': 'prefer', 'sslcompression': '0', 'sslcertmode': 'allow', 'sslsni': '1', 'ssl_min_protocol_version': 'TLSv1.2', 'gssencmode': 'disable', 'krbsrvname': 'postgres', 'gssdelegation': '0', 'target_session_attrs': 'any', 'load_balance_hosts': 'disable'} 

You are connected to -  ('PostgreSQL 16.4, compiled by Visual C++ build 1940, 64-bit',) 



In [8]:
df = sql_to_dataframe(cur, "SELECT attname from PG_STATS WHERE schemaname = 'public';", ["column"])

In [9]:
no_operations = {}

for column in list(df['column']):
    no_operations[column] = {'e' : 0, 'c' : 0}
    for condition in conditions:
        if column in condition:
            if re.match(r".*(<|>|\bbetween\b).*", condition, flags=re.IGNORECASE):
                no_operations[column]['c'] += 1
            elif re.match(r".*(=|\bIN\b|\bIS\b|\bLIKE\b).*", condition, flags=re.IGNORECASE):
                no_operations[column]['e'] += 1
            

print(no_operations)

{'ca_address_sk': {'e': 46, 'c': 0}, 'ca_address_id': {'e': 0, 'c': 0}, 'ca_street_number': {'e': 0, 'c': 0}, 'ca_street_name': {'e': 0, 'c': 0}, 'ca_street_type': {'e': 0, 'c': 0}, 'ca_suite_number': {'e': 0, 'c': 0}, 'ca_city': {'e': 1, 'c': 2}, 'ca_county': {'e': 7, 'c': 0}, 'ca_state': {'e': 18, 'c': 0}, 'ca_zip': {'e': 5, 'c': 1}, 'ca_country': {'e': 9, 'c': 1}, 'ca_gmt_offset': {'e': 12, 'c': 0}, 'ca_location_type': {'e': 0, 'c': 0}, 'cd_demo_sk': {'e': 21, 'c': 0}, 'cd_gender': {'e': 4, 'c': 0}, 'cd_marital_status': {'e': 18, 'c': 1}, 'cd_education_status': {'e': 18, 'c': 0}, 'cd_purchase_estimate': {'e': 0, 'c': 0}, 'cd_credit_rating': {'e': 0, 'c': 0}, 'cd_dep_count': {'e': 0, 'c': 0}, 'cd_dep_employed_count': {'e': 0, 'c': 0}, 'cd_dep_college_count': {'e': 0, 'c': 0}, 'd_date_sk': {'e': 155, 'c': 0}, 'd_date_id': {'e': 0, 'c': 0}, 'd_date': {'e': 167, 'c': 27}, 'd_month_seq': {'e': 3, 'c': 22}, 'd_week_seq': {'e': 13, 'c': 0}, 'd_quarter_seq': {'e': 0, 'c': 0}, 'd_year': {'e'

In [10]:
analytics = pd.DataFrame([
    {'attname': key, 'equalities': value['e'], 'comparisons': value['c']}
    for key, value in no_operations.items()
])
analytics

Unnamed: 0,attname,equalities,comparisons
0,ca_address_sk,46,0
1,ca_address_id,0,0
2,ca_street_number,0,0
3,ca_street_name,0,0
4,ca_street_type,0,0
...,...,...,...
308,ss_ext_tax,0,0
309,ss_coupon_amt,0,6
310,ss_net_paid,0,1
311,ss_net_paid_inc_tax,0,0


In [11]:
# Set the threshold parameter
threshold = 5

In [12]:
# decide which columns need indexes and what type

def decide_index_type(row, threshold):
    e = row['equalities']
    c = row['comparisons']
    if e > threshold and c > threshold:
        return 'btree'
    elif e == 0 and c > threshold:
        return 'btree'
    elif e > threshold and c == 0:
        return 'hash'
    else:
        return None

analytics['index_type'] = analytics.apply(lambda row: decide_index_type(row, threshold), axis=1)

# columns that need indexing
index_candidates = analytics[analytics['index_type'].notnull()]
index_candidates.reset_index(drop=True, inplace=True)
index_candidates

Unnamed: 0,attname,equalities,comparisons,index_type
0,ca_address_sk,46,0,hash
1,ca_county,7,0,hash
2,ca_state,18,0,hash
3,ca_gmt_offset,12,0,hash
4,cd_demo_sk,21,0,hash
5,cd_education_status,18,0,hash
6,d_date_sk,155,0,hash
7,d_date,167,27,btree
8,d_week_seq,13,0,hash
9,d_year,95,6,btree


In [13]:
# list of columns to find tables for
columns_to_find = list(index_candidates['attname'])

columns_in_clause = "', '".join(columns_to_find)

query = f"""
SELECT table_name, column_name
FROM information_schema.columns
WHERE table_schema = 'public'
AND column_name IN ('{columns_in_clause}');
"""

cur.execute(query)
results = cur.fetchall()

tables_columns = pd.DataFrame(results, columns=['table_name', 'attname'])
tables_columns

Unnamed: 0,table_name,attname
0,web_sales,ws_order_number
1,date_dim,d_date
2,date_dim,d_week_seq
3,date_dim,d_year
4,date_dim,d_moy
5,item,i_item_sk
6,store_sales,ss_ticket_number
7,store_sales,ss_quantity
8,store_sales,ss_wholesale_cost
9,store_sales,ss_list_price


In [14]:
index_candidates = index_candidates[['attname', 'equalities', 'comparisons', 'index_type']]
tables_columns = tables_columns[['table_name', 'attname']]

# merge with suffixes
index_candidates = index_candidates.merge(tables_columns, on='attname', how='inner', suffixes=('', '_from_tables'))
index_candidates

Unnamed: 0,attname,equalities,comparisons,index_type,table_name
0,ca_address_sk,46,0,hash,customer_address
1,ca_county,7,0,hash,customer_address
2,ca_state,18,0,hash,customer_address
3,ca_gmt_offset,12,0,hash,customer_address
4,cd_demo_sk,21,0,hash,customer_demographics
5,cd_education_status,18,0,hash,customer_demographics
6,d_date_sk,155,0,hash,date_dim
7,d_date,167,27,btree,date_dim
8,d_week_seq,13,0,hash,date_dim
9,d_year,95,6,btree,date_dim


In [15]:
# n_distinct from pg_stats for all columns
query = """
SELECT tablename, attname, n_distinct
FROM pg_stats
WHERE schemaname = 'public';
"""

cur.execute(query)
results = cur.fetchall()

# df with n_distinct
pg_stats_df = pd.DataFrame(results, columns=['table_name', 'attname', 'n_distinct'])

# merge with index_candidates
index_candidates = index_candidates.merge(pg_stats_df, on=['table_name', 'attname'], how='left')
index_candidates

Unnamed: 0,attname,equalities,comparisons,index_type,table_name,n_distinct
0,ca_address_sk,46,0,hash,customer_address,-1.0
1,ca_county,7,0,hash,customer_address,1846.0
2,ca_state,18,0,hash,customer_address,51.0
3,ca_gmt_offset,12,0,hash,customer_address,6.0
4,cd_demo_sk,21,0,hash,customer_demographics,-1.0
5,cd_education_status,18,0,hash,customer_demographics,7.0
6,d_date_sk,155,0,hash,date_dim,-1.0
7,d_date,167,27,btree,date_dim,-1.0
8,d_week_seq,13,0,hash,date_dim,-0.142863
9,d_year,95,6,btree,date_dim,201.0


In [16]:
# estimated row counts from pg_class
table_names = index_candidates['table_name'].unique()
table_names_list = "', '".join(table_names)

query = f"""
SELECT relname AS table_name, reltuples AS row_count
FROM pg_class
WHERE relname IN ('{table_names_list}');
"""

cur.execute(query)
results = cur.fetchall()
row_counts_df = pd.DataFrame(results, columns=['table_name', 'row_count'])

# merge with index_candidates
index_candidates = index_candidates.merge(row_counts_df, on='table_name', how='left')
index_candidates

Unnamed: 0,attname,equalities,comparisons,index_type,table_name,n_distinct,row_count
0,ca_address_sk,46,0,hash,customer_address,-1.0,50000.0
1,ca_county,7,0,hash,customer_address,1846.0,50000.0
2,ca_state,18,0,hash,customer_address,51.0,50000.0
3,ca_gmt_offset,12,0,hash,customer_address,6.0,50000.0
4,cd_demo_sk,21,0,hash,customer_demographics,-1.0,1920800.0
5,cd_education_status,18,0,hash,customer_demographics,7.0,1920800.0
6,d_date_sk,155,0,hash,date_dim,-1.0,73049.0
7,d_date,167,27,btree,date_dim,-1.0,73049.0
8,d_week_seq,13,0,hash,date_dim,-0.142863,73049.0
9,d_year,95,6,btree,date_dim,201.0,73049.0


In [17]:
# estimated number of distinct values (based on n_distinct and row_count)
def compute_estimated_distinct(row):
    n_distinct = row['n_distinct']
    if n_distinct >= 0:
        return n_distinct
    else:
        # negative n_distinct -> a fraction of row_count
        return abs(n_distinct) * row['row_count']

index_candidates['estimated_distinct'] = index_candidates.apply(compute_estimated_distinct, axis=1)
index_candidates

Unnamed: 0,attname,equalities,comparisons,index_type,table_name,n_distinct,row_count,estimated_distinct
0,ca_address_sk,46,0,hash,customer_address,-1.0,50000.0,50000.0
1,ca_county,7,0,hash,customer_address,1846.0,50000.0,1846.0
2,ca_state,18,0,hash,customer_address,51.0,50000.0,51.0
3,ca_gmt_offset,12,0,hash,customer_address,6.0,50000.0,6.0
4,cd_demo_sk,21,0,hash,customer_demographics,-1.0,1920800.0,1920800.0
5,cd_education_status,18,0,hash,customer_demographics,7.0,1920800.0,7.0
6,d_date_sk,155,0,hash,date_dim,-1.0,73049.0,73049.0
7,d_date,167,27,btree,date_dim,-1.0,73049.0,73049.0
8,d_week_seq,13,0,hash,date_dim,-0.142863,73049.0,10436.0
9,d_year,95,6,btree,date_dim,201.0,73049.0,201.0


In [18]:
# filter out columns -> this can be adjusted
index_candidates = index_candidates[index_candidates['estimated_distinct'] > 1000]
index_candidates.reset_index(drop=True, inplace=True)
index_candidates

Unnamed: 0,attname,equalities,comparisons,index_type,table_name,n_distinct,row_count,estimated_distinct
0,ca_address_sk,46,0,hash,customer_address,-1.0,50000.0,50000.0
1,ca_county,7,0,hash,customer_address,1846.0,50000.0,1846.0
2,cd_demo_sk,21,0,hash,customer_demographics,-1.0,1920800.0,1920800.0
3,d_date_sk,155,0,hash,date_dim,-1.0,73049.0,73049.0
4,d_date,167,27,btree,date_dim,-1.0,73049.0,73049.0
5,d_week_seq,13,0,hash,date_dim,-0.142863,73049.0,10436.0
6,t_time_sk,14,0,hash,time_dim,-1.0,86400.0,86400.0
7,i_item_sk,76,0,hash,item,-1.0,18000.0,18000.0
8,i_item_id,7,0,hash,item,-0.5,18000.0,9000.0
9,i_current_price,0,10,btree,item,-0.149333,18000.0,2688.0


In [19]:
# create indexes
sql_commands = []

for idx, row in index_candidates.iterrows():
    table_name = row['table_name']
    column_name = row['attname']
    index_type = row['index_type']
    index_name = f'idx_{table_name}_{column_name}_{index_type}'
    # index_name does not exceed 63 characters
    index_name = index_name[:63]
    sql = f'CREATE INDEX {index_name} ON {table_name} USING {index_type} ({column_name});'
    sql_commands.append(sql)

for sql in sql_commands:
    print(sql)

CREATE INDEX idx_customer_address_ca_address_sk_hash ON customer_address USING hash (ca_address_sk);
CREATE INDEX idx_customer_address_ca_county_hash ON customer_address USING hash (ca_county);
CREATE INDEX idx_customer_demographics_cd_demo_sk_hash ON customer_demographics USING hash (cd_demo_sk);
CREATE INDEX idx_date_dim_d_date_sk_hash ON date_dim USING hash (d_date_sk);
CREATE INDEX idx_date_dim_d_date_btree ON date_dim USING btree (d_date);
CREATE INDEX idx_date_dim_d_week_seq_hash ON date_dim USING hash (d_week_seq);
CREATE INDEX idx_time_dim_t_time_sk_hash ON time_dim USING hash (t_time_sk);
CREATE INDEX idx_item_i_item_sk_hash ON item USING hash (i_item_sk);
CREATE INDEX idx_item_i_item_id_hash ON item USING hash (i_item_id);
CREATE INDEX idx_item_i_current_price_btree ON item USING btree (i_current_price);
CREATE INDEX idx_customer_c_customer_sk_hash ON customer USING hash (c_customer_sk);
CREATE INDEX idx_customer_c_current_cdemo_sk_hash ON customer USING hash (c_current_cdemo

In [20]:
output_dir = 'index_setup'
output_file = os.path.join(output_dir, 'generated_indexes.sql')

# write SQL commands to the output file
with open(output_file, 'w') as f:
    for sql in sql_commands:
        f.write(sql + '\n')

print(f'SQL commands written to {output_file}')

SQL commands written to index_setup\generated_indexes.sql


In [22]:
cur.close()