In [None]:
import os
from config import query_directory, TABLES, DB_CONFIG, connect_to_db
from typing import List, Dict, Tuple
from file_utils import read_queries_from_directory, write_query_results_to_file
import pandas as pd
from tqdm import tqdm
from sql_parser import parse_sql
from query_template import get_query_template
from collections import defaultdict
import networkx as nx
import pandas as pd
from sql_parser import parse_sql
import matplotlib.pyplot as plt
from sql_parser import parse_sql

from pathlib import Path
import warnings

# Name of the repository directory that anchors the workload paths below.
REPO_NAME = "Learned-Optimizers-Benchmarking-Suite"

try:
    # Works when running from a .py file
    _search_start = Path(__file__).resolve()
except NameError:
    # Fallback for notebooks — `__file__` is undefined there, so use the current working directory
    _search_start = Path.cwd().resolve()

# Walk up from the start path (inclusive) to the first ancestor named REPO_NAME.
REPO_ROOT = next(
    (candidate for candidate in (_search_start, *_search_start.parents) if candidate.name == REPO_NAME),
    None,
)
if REPO_ROOT is None:
    # Keep the filesystem-root fallback, but make it visible: otherwise every
    # path below silently points outside the repository.
    REPO_ROOT = Path(_search_start.anchor)
    warnings.warn(
        f"Could not find a '{REPO_NAME}' directory above {_search_start}; "
        f"falling back to {REPO_ROOT}. Workload paths below will likely not exist."
    )

# Workload query directories under workloads/imdb_pg_dataset (kept as str).
_workload_dir = REPO_ROOT / "workloads" / "imdb_pg_dataset"
job = str(_workload_dir / "job")
job_dynamic = str(_workload_dir / "job_dynamic")
job_extended = str(_workload_dir / "job_extended")

In [2]:
from query_template import get_query_template_no_correl
from config import connect_to_db
from selectivity import DatabaseCache
import os
import re
from query_template import SQLInfoExtractor
import pglast

def build_query_profile(parsed_query: dict, query_name: str, query_info: dict, query: str, alias_mapping: Dict[str, str] = None, selectivity_threshold: float = 0.05) -> dict:
    """Summarise one SQL query as a flat profile dict used to build workload distributions.

    Args:
        parsed_query: output of ``parse_sql`` for ``query``; its ``select_columns``,
            ``joins``, ``from_tables`` and ``filters`` entries are read.
        query_name: identifier stored under ``'query_name'`` (the caller passes the file name).
        query_info: output of ``get_query_template_no_correl``; its ``join_details`` list
            and its ``filters`` mapping (table -> {column -> value}) are read.
        query: raw SQL text; re-parsed with ``pglast`` to extract tables and predicates.
        alias_mapping: optional alias -> table map used to resolve ``alias.column``
            entries of ``select_columns``.
        selectivity_threshold: float selectivities ``<=`` this value are counted as
            ``low_selectivity_predicates``, larger ones as ``high_selectivity_predicates``.

    Returns:
        A dict with counts (joins, distinct tables, selected columns, predicates,
        low/high selectivity predicates), the sorted table / predicate /
        selected-column lists, the raw joins and join details, and a numbered
        ``join_enumeration``.
    """
    # Tables and predicates come from walking the pglast AST of the raw query.
    extractor = SQLInfoExtractor()
    extractor(pglast.parse_sql(query))
    info = extractor.info

    # Resolve qualified `alias.column` select entries to (table, column). Unqualified
    # entries are not listed here but still count towards num_columns_selected.
    selected_columns = []
    for column in parsed_query.get('select_columns', []):
        if '.' not in column:
            continue
        table, column_name = column.split('.', 1)
        table = (alias_mapping or {}).get(table, table)
        selected_columns.append((table, column_name))

    join_enumeration = [
        {
            'join_number': f'join{i}',
            'tables': join['tables'],
            'condition': join['condition'],
            # Observed row count, only when execution data was attached to the join.
            'actual_rows': join['execution_data'].get('actual_rows') if 'execution_data' in join else None,
        }
        for i, join in enumerate(query_info.get('join_details', []), 1)
    ]

    profile = {
        'query_name': query_name,
        'num_joins': len(parsed_query.get('joins', [])),
        'num_tables': len(set(parsed_query.get('from_tables', []))),
        'selected_columns': sorted(selected_columns),
        'num_columns_selected': len(parsed_query.get('select_columns', [])),
        'num_predicates': len(parsed_query.get('filters', [])),
        'low_selectivity_predicates': 0,
        'high_selectivity_predicates': 0,
        'tables': sorted(info.get('tables', [])),
        'predicates': sorted(info.get('predicates', [])),
        'joins': parsed_query.get('joins', []),
        'join_details': query_info.get('join_details', []),
        'join_enumeration': join_enumeration
    }

    # Only float values are treated as valid selectivity estimates; any other value
    # type is ignored (NOTE(review): confirm what non-float values signify upstream).
    for column_filters in query_info.get('filters', {}).values():
        for value in column_filters.values():
            if isinstance(value, float):
                bucket = 'low_selectivity_predicates' if value <= selectivity_threshold else 'high_selectivity_predicates'
                profile[bucket] += 1

    return profile

def process_queries_in_directory(directory: str, db_cache: DatabaseCache):
    """Parse, template and profile every ``.sql`` file in ``directory``.

    Files are visited in sorted filename order so the returned lists (and any
    DataFrame built from them) are reproducible; ``os.listdir`` order is arbitrary.

    Args:
        directory: folder containing one SQL query per ``.sql`` file.
        db_cache: cache passed through to ``get_query_template_no_correl``.

    Returns:
        ``(results, query_profiles)``. ``results`` is a list of ``(filename, query_info)``
        pairs and ``query_profiles`` holds the matching ``build_query_profile`` dicts.
        Empty files are skipped. Files that cannot be read are reported and skipped.
        Errors raised while parsing or templating a query propagate to the caller.
    """
    results = []
    query_profiles = []
    sql_files = sorted(name for name in os.listdir(directory) if name.endswith(".sql"))
    for filename in tqdm(sql_files, desc="Processing SQL files", unit="file"):
        filepath = os.path.join(directory, filename)
        # Guard only the read, so this message is never attached to an unrelated
        # OSError raised later while parsing or templating the query.
        try:
            with open(filepath, "r", encoding="utf-8") as file:
                query = file.read().strip()
        except OSError as e:
            print(f"Error reading file {filename}: {e}")
            continue
        if not query:
            continue

        parsed_query = parse_sql(query, split_parentheses=True)
        aliases = parsed_query.get('aliases', {})
        query_info = get_query_template_no_correl(parsed_query, db_cache, alias_mapping=aliases, original_query=query)
        profile = build_query_profile(parsed_query, filename, query_info, query, alias_mapping=aliases)
        query_profiles.append(profile)
        results.append((filename, query_info))

    return results, query_profiles

In [3]:
# Preload every table into the in-memory cache, then release the connection.
# try/finally closes the connection even if preloading raises.
# NOTE(review): query processing in later cells still emits pandas read_sql warnings,
# so something keeps querying the database; confirm it does not reuse this closed `conn`.
conn = connect_to_db()
try:
    db_cache = DatabaseCache(conn)
    db_cache.preload_all_tables()
finally:
    conn.close()

Loaded 901,343 rows from aka_name
Loaded 361,472 rows from aka_title
Loaded 36,244,344 rows from cast_info
Loaded 3,140,339 rows from char_name
Loaded 4 rows from comp_cast_type
Loaded 234,997 rows from company_name
Loaded 4 rows from company_type
Loaded 135,086 rows from complete_cast
Loaded 113 rows from info_type
Loaded 134,170 rows from keyword
Loaded 7 rows from kind_type
Loaded 18 rows from link_type
Loaded 2,609,129 rows from movie_companies
Loaded 14,835,720 rows from movie_info
Loaded 1,380,035 rows from movie_info_idx
Loaded 4,523,930 rows from movie_keyword
Loaded 29,997 rows from movie_link
Loaded 4,167,491 rows from name
Loaded 2,963,664 rows from person_info
Loaded 12 rows from role_type
Loaded 2,528,312 rows from title


In [4]:
# Profile the original workload (`query_directory` comes from config).
results, query_profiles_job = process_queries_in_directory(query_directory, db_cache)
# Fail fast: an empty or wrong directory would otherwise yield an empty frame and
# empty distributions further down.
assert query_profiles_job, f"No non-empty .sql queries found in {query_directory}"
job_original_profiles = pd.DataFrame(query_profiles_job)

  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = p

In [5]:
# Imported explicitly because nothing else in this notebook defines it; without this,
# a fresh Restart & Run All raises NameError here.
# NOTE(review): assumes config defines `query_directory_variation` — confirm.
from config import query_directory_variation

# `results` / `query_profiles_job` are reassigned here and now hold the modified workload.
results, query_profiles_job = process_queries_in_directory(query_directory_variation, db_cache)
assert query_profiles_job, f"No non-empty .sql queries found in {query_directory_variation}"
job_modified_profiles = pd.DataFrame(query_profiles_job)

  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = p

In [6]:
from typing import List, Dict, Any
from collections import defaultdict
from distribution import MetadataDistribution

# Profile fields turned into one distribution each and compared between workloads.
TYPES = ["tables", "predicates", "selected_columns", "join_enumeration"]

def build_workload_distributions(query_profiles: List[Dict[str, Any]], use_predefined_bins: bool = False, config: Dict = None):
    """
    Build one MetadataDistribution per feature type in TYPES.

    For each feature type, the distribution receives that field from every profile
    (``[]`` when a profile lacks it), in profile order. Values are passed through
    as-is; they are not converted or made hashable here.

    Args:
        query_profiles: profile dicts such as those produced by build_query_profile.
        use_predefined_bins: when True and ``config["bin_values"]`` has an entry for a
            feature type, that entry is passed to MetadataDistribution as ``slots``.
        config: optional settings; only ``config["bin_values"]`` is consulted.

    Returns:
        Dict mapping each feature type to its MetadataDistribution.
    """
    profiles = list(query_profiles)

    predefined_bins = {}
    if use_predefined_bins and config and "bin_values" in config:
        predefined_bins = config["bin_values"]

    distributions = {}
    for feature in TYPES:
        feature_values = [profile.get(feature, []) for profile in profiles]
        if feature in predefined_bins:
            distributions[feature] = MetadataDistribution(feature, feature_values, slots=predefined_bins[feature])
        else:
            distributions[feature] = MetadataDistribution(feature, feature_values)

    return distributions

# Build distributions for both workloads
# Build per-feature distributions for both workloads — original first, then modified,
# so MetadataDistribution's printed messages appear in the same order.
dists_job_original, dists_job_modified = (
    build_workload_distributions(profiles_frame.to_dict(orient="records"))
    for profiles_frame in (job_original_profiles, job_modified_profiles)
)

MetadataDistribution: tables with 113 items
MetadataDistribution: predicates with 113 items
MetadataDistribution: selected_columns with 113 items
MetadataDistribution: join_enumeration with 113 items
MetadataDistribution: tables with 113 items
MetadataDistribution: predicates with 113 items
MetadataDistribution: selected_columns with 113 items
MetadataDistribution: join_enumeration with 113 items


In [7]:
from collections import Counter
from scipy.spatial.distance import jensenshannon
import numpy as np

def compute_workload_distance_from_dists(
    dists1: Dict[str, "MetadataDistribution"],
    dists2: Dict[str, "MetadataDistribution"],
    keys: List[str] = None,
    base: float = None,
) -> float:
    """Mean Jensen-Shannon *distance* between two workloads' feature distributions.

    A feature type is compared only if both ``dists1`` and ``dists2`` contain it.
    The comparison works as follows:
    - The mappings returned by ``.get()`` on both sides are aligned on the union
      of their bins.
    - The weights are normalised.
    - The result is compared with ``scipy.spatial.distance.jensenshannon``.

    scipy returns the JS distance, i.e. the square root of the JS divergence.
    With the default natural-log ``base`` it lies in ``[0, sqrt(ln 2)]``; pass
    ``base=2`` to bound it in ``[0, 1]``.

    Args:
        dists1, dists2: feature type -> object whose ``.get()`` returns a mapping
            of bin -> weight.
        keys: feature types to compare; ``None`` means the module-level ``TYPES``.
        base: logarithm base forwarded to ``jensenshannon``.

    Returns:
        Mean distance over the compared feature types, or 0.0 if none was compared.
        Two empty distributions count as identical (0). One empty side counts as
        maximally distant.
    """
    import json

    def _bin_key(bin_value) -> str:
        # Bins can be composite (e.g. tuples); canonical JSON makes equal bins
        # from both workloads align. default=str covers non-JSON types.
        return json.dumps(bin_value, sort_keys=True, default=str)

    def _serialize(weights) -> Dict[str, float]:
        # Sum rather than overwrite so bins that serialise identically keep all their mass.
        merged: Dict[str, float] = {}
        for bin_value, weight in weights.items():
            serialized = _bin_key(bin_value)
            merged[serialized] = merged.get(serialized, 0.0) + weight
        return merged

    if keys is None:
        keys = TYPES

    # JS distance between two distributions with disjoint support, in the chosen base.
    max_distance = float(np.sqrt(np.log(2) / np.log(base))) if base else float(np.sqrt(np.log(2)))

    distances = []
    for key in keys:
        dist1_obj = dists1.get(key)
        dist2_obj = dists2.get(key)
        if dist1_obj is None or dist2_obj is None:
            continue

        p_weights = _serialize(dist1_obj.get())
        q_weights = _serialize(dist2_obj.get())
        all_bins = sorted(set(p_weights) | set(q_weights))

        p = np.array([p_weights.get(b, 0.0) for b in all_bins], dtype=float)
        q = np.array([q_weights.get(b, 0.0) for b in all_bins], dtype=float)
        p_total, q_total = p.sum(), q.sum()

        if p_total == 0 and q_total == 0:
            jsd = 0.0  # both empty: identical, instead of scipy's NaN
        elif p_total == 0 or q_total == 0:
            jsd = max_distance  # one side has no mass at all
        else:
            jsd = float(jensenshannon(p / p_total, q / q_total, base=base))

        print(f"Jensen-Shannon distance for {key}: {jsd}")
        distances.append(jsd)

    return float(np.mean(distances)) if distances else 0.0

# Mean Jensen-Shannon *distance* across feature types; scipy's jensenshannon returns
# the square root of the divergence, so label it as a distance.
distance = compute_workload_distance_from_dists(dists_job_original, dists_job_modified)
print("Workload distance (Jensen-Shannon distance):", distance)

Jensen-Shannon Divergence for tables: 0.0
Jensen-Shannon Divergence for predicates: 0.0
Jensen-Shannon Divergence for selected_columns: 0.15664029935700752
Jensen-Shannon Divergence for join_enumeration: 0.0
Workload distance (Jensen-Shannon Divergence): 0.03916007483925188
