In [None]:
import os
import seaborn as sns
import matplotlib.pyplot as plt
from classes.paths import LocalPaths
from classes.classes import ColorManager
from evaluation.plots.utils import load_wandb_runs
from dotenv import load_dotenv
from evaluation.plots.evaluation_metrics import SelectedRuntime
from evaluation.plots.utils import get_model_results
import pandas as pd
from classes.workloads import EvalWorkloads
from classes.classes import MODEL_CONFIGS
from classes.classes import ModelName
import matplotlib.patches as mpatches

load_dotenv()
sns.set_theme(style="whitegrid", font_scale=1.8)
%load_ext autoreload
%autoreload 2

In [None]:
# Load the training results of all configured models from Weights & Biases.
# The W&B coordinates come from the environment; both keys are required, so a
# missing variable fails loudly with a KeyError.
wandb_user = os.environ["WANDB_USER"]
wandb_project = os.environ["WANDB_PROJECT"]
results_csv = LocalPaths().data / "training" / "results.csv"

training_results = load_wandb_runs(
    wandb_user=wandb_user,
    wandb_project=wandb_project,
    result_dir=results_csv,
    model_confs=MODEL_CONFIGS,
)

In [None]:
# Total, per model, of the SelectedRuntime metric over all IMDB full-join-order
# workloads. The final `runtime_df` has one row per model with the columns
# 'index' (model display name) and 0 (summed metric), in MODEL_CONFIGS order.

# Row order of the result; dict.fromkeys drops repeated display names while
# keeping first-seen order.
model_names = list(dict.fromkeys(model.name.DISPLAY_NAME for model in MODEL_CONFIGS))

selected_runtimes = {}  # workload name -> {model display name -> metric value}
minimal_runtimes = []   # smallest label observed per workload
for workload in EvalWorkloads.FullJoinOrder.imdb:
    results = get_model_results(workload, MODEL_CONFIGS)

    per_model = {}
    for model in MODEL_CONFIGS:
        display_name = model.name.DISPLAY_NAME
        model_results = results[results["model"] == display_name]
        per_model[display_name] = SelectedRuntime().evaluate_metric(
            preds=model_results["prediction"],
            labels=model_results["label"],
        )
    selected_runtimes[workload.get_workload_name()] = per_model

    # Use the workload-level frame here: the previous code read the loop variable
    # leaked from the inner loop, i.e. only the rows of the last model.
    # NOTE(review): assumes every model is evaluated on the same labelled plans,
    # so the minimum is the same either way -- confirm.
    minimal_runtimes.append(results["label"].min())

# Build the models x workloads table in one go instead of enlarging an empty
# frame cell by cell; the explicit index pins the row order.
runtime_per_workload = pd.DataFrame(selected_runtimes, index=model_names)

# Sum over workloads (NaN entries are skipped) and expose the model name as a column.
runtime_df = runtime_per_workload.sum(axis=1).reset_index()

In [None]:
# Destination of the rendered two-panel figure.
path = LocalPaths().data / "plots" / "motivating_plot.pdf"

fontsize = 15
fig, (q50, runtimes) = plt.subplots(1, 2, figsize=(6, 2.2))

# ---------------------------- Prepare data ---------------------------- #
# Derived frames get their own names: the inputs produced by earlier cells are
# left untouched, so re-running this cell always yields the same figure (the
# previous version inserted one more placeholder row into runtime_df per run).

# Keep only the IMDB runs, ordered like MODEL_CONFIGS.
imdb_results = training_results[training_results.database == "imdb"]
sort_map = {model.name.DISPLAY_NAME: i for i, model in enumerate(MODEL_CONFIGS)}
imdb_results = imdb_results.sort_values(by=['model'], key=lambda x: x.map(sort_map))

# A white placeholder entry named "test" is spliced into both data sets.
# NOTE(review): it presumably acts as the visual gap between the "PG" bars and
# the "LCMs" bars; the split positions (4 rows / 2 rows) depend on how many rows
# each model contributes -- confirm.
q50_spacer = pd.DataFrame({'model': ['test'], 'val_median_q_error_50' : 0.9, 'display_name': ['test']})
training_res = pd.concat([imdb_results[0:4], q50_spacer, imdb_results[4:]])

runtime_spacer = pd.DataFrame({'index': ['test'], 0: 0})
runtime_plot_df = pd.concat([runtime_df[0:2], runtime_spacer, runtime_df[2:]])

# Work on a copy of the shared palette: adding the placeholder colour must not
# leak into ColorManager.COLOR_PALETTE, which other code reads as well.
color_palette = dict(ColorManager.COLOR_PALETTE)
color_palette["test"] = "white"

# ---------------------------- Plot Q50 ---------------------------- #
q50 = sns.barplot(x="database",
                  y="val_median_q_error_50",
                  hue="display_name",
                  data=training_res,
                  capsize=.0,
                  ax=q50,
                  width=1.0,
                  palette=color_palette,
                  errorbar=None,
                  edgecolor='black')

q50.set_ylim(1, 4)
q50.set_xlim(-0.53, 0.53)
q50.set_yscale("log")
q50.set_ylabel("Relative Error", fontsize=fontsize)
q50.set_xlabel("", fontsize=fontsize)
q50.set_yticks([1, 2, 3, 4, 5], labels=[1, 2, 3, 4, 5], minor=False, fontsize=fontsize * 0.8)
q50.set_xticklabels([])
q50.grid(False, axis='x', which='both')
q50.grid(True, axis='y', which='both', linewidth=0.3)
q50.set_title("Cost Estimation", fontsize=fontsize)
q50.get_legend().remove()

# ---------------------------- Plot runtimes ---------------------------- #
# The palette copy is passed explicitly because the hue levels include the
# placeholder "test" entry.
sns.barplot(data=runtime_plot_df,
            x='index',
            palette=color_palette,
            hue='index',
            y=0,
            ax=runtimes,
            width=1,
            edgecolor='black')
runtimes.set_xlabel('')
runtimes.set_title('Query Optimization', fontsize=fontsize)
runtimes.set_xticklabels([])
runtimes.set_ylabel("Total Runtime (s)", fontsize=fontsize)
# Resize the tick labels without freezing their text: set_yticklabels() on a
# non-fixed locator warns and would keep stale labels if the limits changed.
runtimes.tick_params(axis='y', labelsize=fontsize * 0.8)

# Group captions below both panels, positioned in axes-fraction coordinates.
for caption, x_pos in (("LCMs", 0.55), ("PG", 0.05)):
    for panel in (runtimes, q50):
        panel.annotate(caption, xy=(x_pos, -0.12), fontsize=fontsize, xycoords='axes fraction')
runtimes.set_xlim(-1, 10)

# ---------------------------- Bar highlighting ---------------------------- #
# Give the first two bars of each panel a thicker black outline drawn on top,
# and make every bar fully opaque.
for panel in (q50, runtimes):
    bars = panel.patches  # all patches currently attached to the axes
    for idx in (0, 1):
        bars[idx].set_edgecolor('black')
        bars[idx].set_linewidth(2)
        bars[idx].set_zorder(2)
    for bar in bars:
        bar.set_alpha(1)

# ---------------------------- Legend ---------------------------- #
# One legend entry per configured model; the first two entries are replaced by
# the relabelled PostgreSQL baselines.
legend_patches = []
for config in MODEL_CONFIGS:
    legend_patches.append(mpatches.Patch(color=config.color(), label=config.name.DISPLAY_NAME))
legend_patches[0] = mpatches.Patch(label="Trad. (PG10)", color=ColorManager.COLOR_PALETTE[ModelName.POSTGRES.DISPLAY_NAME])
legend_patches[1] = mpatches.Patch(label="Trad. (PG16)", color=ColorManager.COLOR_PALETTE[ModelName.POSTGRES_V16.DISPLAY_NAME])

for position, entry in enumerate(legend_patches):
    entry.set_edgecolor('black')
    if position >= 1:
        # The very first entry keeps its default alpha.
        entry.set_alpha(1)
    if position < 2:
        # Thicker outline for the two baseline entries, matching their bars.
        entry.set_linewidth(2.5)

# Hatch pattern per legend entry / bar, by position.
hatches = ['//', '\\\\', None,  '+++', '..', '--', 'xx', 'o', '///']

# The legend is attached to the right panel but anchored far to the left of it.
legend = runtimes.legend(
    handles=legend_patches,
    loc='center right',
    bbox_to_anchor=(-1.9, 0.5),
    edgecolor='white',
    labelspacing=0.1,
    fontsize=fontsize,
)

# Hatch the patches owned by the legend, indexing `hatches` by position.
for idx, legend_patch in enumerate(legend.get_patches()):
    legend_patch.set_hatch(hatches[idx])

# ---------------------------- Panel markers ---------------------------- #
# Circled "A" / "B" labels near the top-left corner of each panel.
for panel, marker, position in ((q50, "A", (-0.2, 1.08)), (runtimes, "B", (-0.25, 1.07))):
    panel.annotate(
        marker,
        xy=position,
        xycoords='axes fraction',
        fontsize=fontsize,
        ha='center',
        va='center',
        bbox=dict(boxstyle='circle,pad=0.2', edgecolor='black', facecolor='white'),
    )

# ---------------------------- Bar hatches ---------------------------- #
for q50_bar, pattern in zip(q50.patches, hatches):
    q50_bar.set_hatch(pattern)

# The runtime panel gets one extra unhatched slot at position 2 before its
# patterns are applied.
hatches.insert(2, None)
for runtime_bar, pattern in zip(runtimes.patches, hatches):
    runtime_bar.set_hatch(pattern)

# ---------------------------- Layout & export ---------------------------- #
fig.subplots_adjust(wspace=0.6)
fig.align_labels()
fig.savefig(path, bbox_inches='tight')