# Common functions
This notebook defines common imports and functions used in other notebooks.

In [1]:
import os

import pandas as pd
import matplotlib.pyplot as plt

In [2]:
def plot_dataset(benchmark_name: str, data_suites=None, join_names=None) -> pd.DataFrame:
    """Plot elapsed time against rows per cluster for one benchmark.

    Reads ``data/benchmark-{benchmark_name}.csv``, keeps only the rows whose
    ``result`` column equals ``"success"``, and draws one row of subplots per
    data suite: a linear x-axis on the left and a logarithmic x-axis on the
    right, each with one line per join name.

    Args:
        benchmark_name: Name used to build the CSV file path.
        data_suites: Values of the ``data_suite`` column to plot, one subplot
            row each. Defaults to every value present in the filtered data.
        join_names: Values of the ``join_name`` column to plot, one line each.
            Defaults to every value present in the filtered data.

    Returns:
        The filtered dataset, sorted by ``rows_per_cluster``, with
        ``elapsed_time`` and ``rows_per_cluster`` converted to nullable
        integers (unparseable values become ``<NA>``).

    Raises:
        ValueError: If there is no data suite to plot.
    """
    dataset = pd.read_csv(f"data/benchmark-{benchmark_name}.csv", on_bad_lines="warn")
    # Take an explicit copy: the columns are reassigned below, and writing
    # into the result of a filter triggers chained-assignment warnings.
    dataset = dataset.query('result == "success"').copy()

    for column_name in ("elapsed_time", "rows_per_cluster"):
        dataset[column_name] = pd.to_numeric(dataset[column_name], errors="coerce").astype("Int64")

    dataset = dataset.sort_values("rows_per_cluster")

    if data_suites is None:
        data_suites = dataset["data_suite"].unique()

    if join_names is None:
        join_names = dataset["join_name"].unique()

    suite_count = len(data_suites)
    if suite_count == 0:
        raise ValueError(f"no data suites to plot for benchmark {benchmark_name!r}")

    # squeeze=False keeps the axes array two-dimensional even when there is a
    # single data suite, so the per-row indexing below always works.
    plot_figure, plot_axes = plt.subplots(
        suite_count,
        2,
        figsize=(16, 4 * suite_count),
        sharex="col",
        sharey="all",
        squeeze=False,
    )

    # pyplot-level call: acts on the current axes.
    plt.ticklabel_format(useOffset=False)

    for row_index, data_suite in enumerate(data_suites):
        plot_left, plot_right = plot_axes[row_index]

        plot_left.set_title(f'{data_suite} - lin')
        plot_right.set_title(f'{data_suite} - log')

        # Boolean masks instead of an interpolated query string, so names
        # containing quotes (or non-string values) are matched correctly.
        suite_rows = dataset[dataset["data_suite"] == data_suite]

        for join_name in join_names:
            df_to_plot = suite_rows[suite_rows["join_name"] == join_name]
            plot_left.plot(df_to_plot['rows_per_cluster'], df_to_plot['elapsed_time'], label=join_name)
            plot_right.plot(df_to_plot['rows_per_cluster'], df_to_plot['elapsed_time'], label=join_name)

        # Axis scale and legend only need to be set once per subplot row.
        plot_right.semilogx()
        if len(join_names) > 0:
            plot_left.legend()

    return dataset