In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D

from classes.classes import ACT_CARD_ALL_MODEL_CONFIGS
from classes.classes import QueryFormerModelConfig, E2EModelConfig

from classes.classes import MODEL_CONFIGS
from classes.classes import ScaledPostgresModelConfig
from classes.paths import LocalPaths
from classes.workloads import EvalWorkloads
from classes.workloads import JoinOrderEvalWorkload
from cross_db_benchmark.datasets.datasets import Database
from evaluation.plots.evaluation_metrics import MissedPlansFraction, MaxOverestimation, MaxUnderestimation
from evaluation.plots.evaluation_metrics import QError
from evaluation.plots.evaluation_metrics import SelectedRuntime
from evaluation.plots.evaluation_metrics import SpearmanCorrelation
from evaluation.plots.utils import get_model_results, draw_predictions, draw_metric

# Hatch patterns assigned to bar/box patches by position in the plotting cells below.
# Those cells repeat the list (e.g. `* 6`) before indexing into it; `None` means "no hatch".
STANDARD_HATCHES = ['//', '\\\\', None,  '+++', '..', '--', 'xx', 'o', '///']
# Global seaborn theme for every figure in this notebook: white grid with enlarged fonts.
sns.set_theme(style="whitegrid", font_scale=1.8)
# Base font size handed to the draw_* helpers and used for titles, tick labels and annotations.
fontsize = 14

----
## 2. Join Order Examples

In [None]:
# General settings for the join-order example figure.
# Output file of the figure (written by plt.savefig at the end of this cell).
path = LocalPaths().data / "plots" / "join_order_examples.pdf"
# Layout handed to subplot_mosaic: panel A spans the left half over all three rows,
# panels B-G form a 3x2 grid of two-column-wide cells on the right.
mosaic = """AAAABBEE
            AAAACCFF
            AAAADDGG"""
folder = "join_order_full"
# Passed as gridspec_kw to subplot_mosaic (equal row heights, wide horizontal gaps).
grid_spec = {'height_ratios': [1,1,1], 'wspace': 3, 'hspace': 0.3}
# Repeat the hatch list so it can be indexed by patch position.
hatches = STANDARD_HATCHES * 6

# Target workload
workloads = [JoinOrderEvalWorkload(database=Database("imdb"), folder=folder, wl_name="job_light_33", num_tables=4)]

# SQL text shown as the figure title (presumably the query behind job_light_33 - verify against the workload file).
title =("SELECT COUNT(*) FROM title t JOIN movie_info mi ON t.id = mi.movie_id JOIN movie_info_idx mii ON t.id = mii.movie_id  \n JOIN movie_companies mc ON t.id = mc.movie_id "
        "WHERE mii.info_type_id = 101 AND mi.info_type_id = 3  \n AND t.production_year > 2005  AND t.production_year < 2008 AND mc.company_type_id = 2;")

# Create plot: one sub-figure per workload, placed side by side.
fig = plt.figure(figsize=(11 * len(workloads), 3), dpi=100)
# squeeze=False makes subfigures() return a 2-D array even for a single workload, so the
# loop below handles one or many workloads alike (the default returns a bare SubFigure
# for a 1x1 grid, which cannot be zipped with the workload list).
figures = fig.subfigures(nrows=1, ncols=len(workloads), squeeze=False, wspace=-0.1, hspace=-0.05)

for workload, figure in zip(workloads, figures.flat):
    # NOTE(review): recent matplotlib documents the dict returned by subplot_mosaic as ordered
    # by axes position (left-to-right, top-to-bottom), not by mosaic label - confirm that the
    # unpacking order below assigns each metric to the intended panel.
    subplots = figure.subplot_mosaic(mosaic, gridspec_kw=grid_spec).values()
    [prediction_ax, q_error_ax, runtime_ax, missed_plan_ax, spearmans_ax, underest_ax, overest_ax] = subplots

    # Results of all models for this workload, sorted by their label column.
    results = get_model_results(workload, MODEL_CONFIGS)
    results = results.sort_values(by='label')

    # Large panel with the predictions, then one small panel per metric.
    draw_predictions(workload, results, MODEL_CONFIGS, prediction_ax, fontsize)
    draw_metric(results, MODEL_CONFIGS, q_error_ax, QError(), fontsize)
    draw_metric(results, MODEL_CONFIGS, spearmans_ax, SpearmanCorrelation(), fontsize)
    draw_metric(results, MODEL_CONFIGS, missed_plan_ax, MissedPlansFraction(), fontsize)
    draw_metric(results, MODEL_CONFIGS, overest_ax, MaxOverestimation(), fontsize)
    draw_metric(results, MODEL_CONFIGS, underest_ax, MaxUnderestimation(), fontsize)
    draw_metric(results, MODEL_CONFIGS, runtime_ax, SelectedRuntime(display_name="Selected\nRuntime(s)"), fontsize)

    # Hatch the patches of every metric panel by their position.
    for plot in [q_error_ax, runtime_ax, missed_plan_ax, spearmans_ax, underest_ax, overest_ax]:
        for i, patch in enumerate(plot.patches):
            patch.set_hatch(hatches[i])

    # Configure runtime prediction plot
    prediction_ax.set_xlabel('Join Enumeration', fontsize=fontsize)
    prediction_ax.annotate(xy=(0.1, 0.03),
                           text='Most LCMs fail in ranking join orders',
                           xycoords='axes fraction',
                           fontsize=fontsize)

    # Configure q-error plot
    q_error_ax.set_ylim(1, 3)
    q_error_ax.set_yticks([1, 2, 3], labels=[1, 2, 3], fontsize=fontsize)
    q_error_ax.minorticks_off()

    # Configure selected-runtime plot: the horizontal line marks the smallest label in the
    # results and is annotated as the optimal runtime just right of the axes.
    runtime_ax.set_ylim(0, 6)
    runtime_ax.axhline(y=results['label'].min(), linestyle='-', color='black', linewidth=2, zorder=100)
    runtime_ax.annotate(text='Optimal\nRuntime',
                        xy=(1.05, 0.21),
                        xycoords='axes fraction',
                        fontsize=fontsize * 0.75,
                        ha='left',
                        va='bottom')

    # Configure underestimation plot
    underest_ax.set_yticks([1, 3, 5], labels=[1, 3, 5], fontsize=fontsize)
    underest_ax.minorticks_off()

    # Configure overestimation plot
    overest_ax.set_yticks([1, 5, 10], labels=[1, 5, 10], fontsize=fontsize)
    overest_ax.minorticks_off()

    # Add circled letters to the subplots, in the order the axes were returned.
    for plot, letter in zip(list(subplots), "ABCDEFG"):
        plot.annotate(
            letter,
            xy=(0.05, 0.95),
            xycoords='axes fraction',
            fontsize=9,
            ha='center',
            va='center',
            bbox=dict(boxstyle='circle,pad=0.2', edgecolor='black', facecolor='white'))

    # Show the SQL text above the sub-figure in a monospace font.
    figure.suptitle(title, fontsize=fontsize * 0.7, fontproperties={'family': 'monospace'}, x=0.4, y=1.05, horizontalalignment='center')
  
# Shared legend: a line entry for the real runtime, followed by one black-edged colour
# patch per model configuration.
legend_handles = [Line2D([0], [0], color='black', lw=1, linestyle='-', label='Real Runtime')]
for model_config in MODEL_CONFIGS:
    model_patch = mpatches.Patch(color=model_config.color(), label=model_config.name.DISPLAY_NAME)
    model_patch.set_edgecolor('black')
    legend_handles.append(model_patch)

# The legend is attached to the q-error axes but anchored far outside of it.
legend = q_error_ax.legend(
    handles=legend_handles,
    fontsize=fontsize,
    ncol=1,
    loc='center left',
    bbox_to_anchor=(-5.1, -0.8),
    labelspacing=0.2,
    edgecolor='white',
)

# Thicken the line samples shown in the legend.
for legend_line in legend.get_lines():
    legend_line.set_linewidth(6.0)

# Give each legend patch the hatch at its position in the hatch list.
for position, legend_patch in enumerate(legend.get_patches()):
    legend_patch.set_hatch(hatches[position])

fig.align_labels()
plt.savefig(path, bbox_inches='tight')

In [None]:
# Spearman correlation of three selected models on the example workload.
# `results` is the frame left behind by the plotting cell above.
metric = SpearmanCorrelation()
model_configs = [ScaledPostgresModelConfig(), QueryFormerModelConfig(), E2EModelConfig()]

scores = {}
for model in model_configs:
    display_name = model.name.DISPLAY_NAME
    # Rows of this model only; filtered once and reused for predictions and labels.
    model_rows = results[results["model"] == display_name]
    scores[display_name] = metric.evaluate_metric(preds=model_rows["prediction"],
                                                  labels=model_rows['label'])

# One row per model, one column named after the metric.
extract = pd.DataFrame(scores, index=[metric.metric_name]).T
extract

----
## 3. Join Order (Full Benchmark)

In [None]:
# Per-metric result tables: one row per model, one column per workload.
spearman_df = pd.DataFrame()
runtime_df = pd.DataFrame()
missed_plans_df = pd.DataFrame()
overest_df  = pd.DataFrame()
underest_df = pd.DataFrame()
# Smallest label value of each workload (summed later for the reference line).
minimal_runtimes = []
# Placeholder row that is concatenated into every table as a visual separator.
blank_row = pd.DataFrame({'Index': ['empty']})

model_configs = ACT_CARD_ALL_MODEL_CONFIGS


def _insert_blank_row(metric_df, position=9):
    """Return ``metric_df`` with an all-zero separator row inserted at ``position``.

    The separator is built from the module-level ``blank_row`` frame and then
    overwritten with zeros, so it shows up as an empty slot in the plots.

    Args:
        metric_df: Table with one row per model.
        position: Row offset of the separator; the first ``position`` rows stay
            in front of it, all remaining rows follow it.

    Returns:
        A new DataFrame; ``metric_df`` itself is left untouched.
    """
    padded = pd.concat([metric_df.iloc[:position], blank_row, metric_df.iloc[position:]])
    padded.iloc[position] = 0
    return padded


# Metric class -> table it fills. A fresh metric instance is created for every evaluation.
metric_tables = [
    (SpearmanCorrelation, spearman_df),
    (SelectedRuntime, runtime_df),
    (MissedPlansFraction, missed_plans_df),
    (MaxOverestimation, overest_df),
    (MaxUnderestimation, underest_df),
]

for workload in EvalWorkloads.FullJoinOrder.imdb:
    results = get_model_results(workload, model_configs)
    workload_name = workload.get_workload_name()
    for model in model_configs:
        model_name = model.name.DISPLAY_NAME
        model_results = results[results["model"] == model_name]
        for metric_cls, table in metric_tables:
            # .loc enlarges the table in place, so the module-level frames are filled directly.
            table.loc[model_name, workload_name] = metric_cls().evaluate_metric(
                preds=model_results["prediction"], labels=model_results['label'])

    # Take the minimum over all rows of this workload instead of the slice of whichever
    # model happened to be processed last in the inner loop.
    minimal_runtimes.append(results["label"].min())

# Separate the first nine models from the remaining ones with an empty slot.
spearman_df = _insert_blank_row(spearman_df)
runtime_df = _insert_blank_row(runtime_df)
missed_plans_df = _insert_blank_row(missed_plans_df)
overest_df = _insert_blank_row(overest_df)
underest_df = _insert_blank_row(underest_df)

In [None]:
# Output file of the full-benchmark figure.
path = LocalPaths().data / "plots" / "join_order_full.pdf"
# One hatch per plotted slot, repeated twice so positional indexing has headroom.
hatches = STANDARD_HATCHES + ['xx', '--||', '...', 'xx', '---']
hatches = hatches * 2

# Create a color palette based on model names
color_mapping = {model.name.DISPLAY_NAME: model.color() for model in ACT_CARD_ALL_MODEL_CONFIGS}
color_mapping['0'] = 'white'  # Set the color for the blank category


def _draw_metric_boxplot(ax, metric_df, panel_title, ylim, log_scale=False):
    """Draw one notched boxplot panel with one box per row of ``metric_df``.

    Args:
        ax: Axes to draw into.
        metric_df: Table with one row per model; it is transposed so that every
            model becomes one box.
        panel_title: Title shown above the panel.
        ylim: ``(bottom, top)`` limits of the y axis.
        log_scale: When true, switch the y axis to a logarithmic scale after the
            limits have been set.
    """
    sns.boxplot(data=metric_df.T, palette=color_mapping, ax=ax, width=1,
                medianprops={"color": "black", "linewidth": 2}, notch=True)
    ax.set_xlabel('')
    ax.set_title(panel_title, fontsize=fontsize)
    ax.xaxis.set_ticklabels([])
    ax.set_ylim(*ylim)
    if log_scale:
        ax.set_yscale('log')
    ax.tick_params(axis='y', which='major', pad=0, labelsize=fontsize)


# Create the boxplot and barplot panels; the unpacking order fixes their left-to-right position.
fig, (ax2, ax3, ax1, ax4, ax5) = plt.subplots(1, 5, figsize=(16, 3))

_draw_metric_boxplot(ax1, spearman_df, 'Spearman Correlation', ylim=(-1, 1))

# Total selected runtime per model as bars, with the summed per-workload minimum as reference line.
sns.barplot(data=runtime_df.sum(axis=1).reset_index(), x='index', hue='index', palette=color_mapping, y=0, ax=ax2, edgecolor='black', width=1)
ax2.set_xlabel('')
ax2.set_title('Total Runtime (s)', fontsize=fontsize)
ax2.xaxis.set_ticklabels([])
ax2.set_ylabel("")
ax2.axhline(y=sum(minimal_runtimes), color='black', linestyle='-', linewidth=2)
ax2.tick_params(axis='y', which='major', pad=0, labelsize=fontsize)

_draw_metric_boxplot(ax3, missed_plans_df, 'Surpassed Plans (%)', ylim=(0, 100))
_draw_metric_boxplot(ax4, underest_df, 'Underestimation', ylim=(0.9, 25), log_scale=True)
_draw_metric_boxplot(ax5, overest_df, 'Overestimation', ylim=(0.9, 25), log_scale=True)

for ax in (ax1, ax2, ax3, ax4, ax5):
    # Group captions below the axis for the two blocks of models.
    ax.annotate('Act. Card.', xy=(0.88, -0.06), xycoords='axes fraction', fontsize=fontsize * 0.9, ha='center', va='center', color='black')
    ax.annotate('Est. Card.', xy=(0.3, -0.06), xycoords='axes fraction', fontsize=fontsize * 0.9, ha='center', va='center', color='black')

    # Hatch and outline every bar/box by its position on the axis.
    for i, patch in enumerate(ax.patches):
        patch.set_hatch(hatches[i])
        patch.set_edgecolor('black')

# Create legend patches: one black-edged colour patch per model.
legend_patches = [mpatches.Patch(color=model_config.color(), label=model_config.name.DISPLAY_NAME) for model_config in model_configs]
for p in legend_patches:
    p.set_edgecolor('black')
# Reference line entry first, then two invisible entries that act as vertical spacers.
legend_patches.insert(0, Line2D([0], [0], color='black', lw=4, linestyle='-', label='Optimal Runtime'))
legend_patches.insert(10, plt.Line2D([], [], linewidth=0))
legend_patches.insert(3, plt.Line2D([], [], linewidth=0))

handles = ax2.legend(handles=legend_patches,
           loc='center right',
           bbox_to_anchor=(-0.2, 0.5),
           edgecolor='white',
           labelspacing=0.05,
           fontsize=fontsize * 0.9)

# The plotted tables contain an extra all-zero separator row at position 9 that has no legend
# entry. Skip exactly that slot so every legend patch gets the same hatch as its bar/box;
# dropping slot 8 instead would shift the hatch of the ninth model. A copy is used so that
# the shared `hatches` list is not mutated.
blank_position = 9
legend_hatches = hatches[:blank_position] + hatches[blank_position + 1:]
for i, patch in enumerate(handles.get_patches()):
    patch.set_hatch(legend_hatches[i])
    patch.set_edgecolor('black')


# Shade the right-hand group of models and fix the visible x range of every panel.
for ax in (ax2, ax3, ax1, ax4, ax5):
    ax.axvspan(xmin=9.5, xmax=14.5, alpha=0.2, color='gray', zorder=-100)
    ax.set_xlim(-0.6, 13.6)

fig.align_labels()
plt.savefig(path, bbox_inches='tight')

In [None]:
# Display the sum of the per-workload minimum labels collected above; the same value is
# drawn as the 'Optimal Runtime' reference line in the total-runtime panel.
sum(minimal_runtimes)