In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import glob
import plotly.express as px

In [None]:
# --- Configuration ---
# Hardware label shown in plot titles.
hardware = "strix"

# Folder holding the sweep result CSV files.
# Previous run: "./whole_array/sweep_runs/run4/results"
DATA_FOLDER = "./whole_array/sweep_runs/run5_24_06_25/results"

# Problem dimensions read from each CSV, and the measured output column.
INPUT_COLS = ["M", "K", "N"]
OUTPUT_COL = "It1"  # execution time

In [None]:
def parse_filename(filename, hardware_id="strix"):
    """
    Parses a sweep-result filename into its hyperparameters.

    The name (minus a trailing ".csv") must have exactly seven
    underscore-separated fields, e.g. "32x32x32_4x8x4_bf16_out_8_col_peano.csv":

        field 0 -> "mkn"       ("32x32x32")
        field 1 -> "rst"       ("4x8x4")
        field 2 -> "out"       ("bf16")
        field 4 -> "cols"      ("8")
        field 6 -> "compiler"  ("peano")

    Fields 3 and 5 are not used (in the example they are the tokens "out"
    and "col").

    Args:
        filename (str): The base name of the file.
        hardware_id (str): Hardware identifier to attach. The filename does
            not encode it. Defaults to "strix".

    Returns:
        dict | None: The keys "mkn", "rst", "out", "cols", "compiler",
        "hyperparameter_set" (the five values joined with "_", e.g.
        "32x32x32_4x8x4_bf16_8_peano") and "hardware_id".
        Returns None (after printing a message) if the name does not
        have seven fields.
    """
    # Strip only a trailing ".csv", not every occurrence inside the name.
    stem = filename[: -len(".csv")] if filename.endswith(".csv") else filename
    parts = stem.split("_")
    if len(parts) != 7:
        print(
            f"Failed to parse filename '{filename}': Expected 7 parts, got {len(parts)}"
        )
        return None

    params = {
        "mkn": parts[0],
        "rst": parts[1],
        "out": parts[2],
        "cols": parts[4],
        "compiler": parts[6],
    }
    # Identifier of the whole configuration, built from the values in field order.
    params["hyperparameter_set"] = "_".join(params.values())
    params["hardware_id"] = hardware_id

    return params

In [None]:
# --- Load and Combine Data ---
# Both frames are defined up front so later cells can always reference them,
# even when no CSV files are found or none of them can be loaded.
master_df = pd.DataFrame()
master_df_nans = pd.DataFrame()
all_data = []
# Sorted for a deterministic row order (glob order depends on the filesystem).
csv_files = sorted(glob.glob(os.path.join(DATA_FOLDER, "*.csv")))

if not csv_files:
    print(f"No CSV files found in '{DATA_FOLDER}'. Please check the path.")
else:
    print(f"Found {len(csv_files)} CSV files. Loading...")
    for f_path in csv_files:
        base_filename = os.path.basename(f_path)

        # Hyperparameters are encoded in the filename; skip files that don't match.
        file_params = parse_filename(base_filename)
        if not file_params:
            print(
                f"Warning: Could not parse parameters from filename: {base_filename}"
            )
            continue

        try:
            df_temp = pd.read_csv(f_path)
        except Exception as e:
            # Best effort: report the unreadable file and keep loading the rest.
            print(f"Error loading or processing file {f_path}: {e}")
            continue

        # Attach the filename parameters as constant columns on every row.
        all_data.append(df_temp.assign(**file_params))

    if not all_data:
        print("No data loaded. master_df is empty.")
    else:
        master_df = pd.concat(all_data, ignore_index=True)
        print("\n--- Combined Data Head ---")
        print(master_df.head())
        print("\n--- Combined Data Info ---")
        master_df.info()

        # Missing or non-numeric outputs become NaN. Keep a copy that still has
        # those rows (for the NaN analysis), then drop them from the main frame.
        master_df[OUTPUT_COL] = pd.to_numeric(master_df[OUTPUT_COL], errors="coerce")
        master_df_nans = master_df.copy()
        master_df = master_df.dropna(subset=[OUTPUT_COL])

The following cell defines the plotting helper functions used by the analysis cells below.

In [None]:
# Plotting functions
def hyperparameter_comparison_for_sampled_column(
    df: pd.DataFrame,
    y_col: str,
    hardware: str,
    sampled_column: str = "input_shape_id",
    random_state=None,
):
    """
    Bar-plots ``y_col`` per hyperparameter set for one randomly sampled value
    of ``sampled_column`` (by default, one input shape).

    Args:
        df: Data with ``sampled_column``, ``y_col`` and "hyperparameter_set" columns.
        y_col: Metric plotted on the y-axis.
        hardware: Hardware label used in the title.
        sampled_column: Column whose sampled value selects the rows to plot.
        random_state: Seed forwarded to ``Series.sample``. Pass an int to make
            the chosen value reproducible; ``None`` (default) samples randomly.

    Prints a message and returns without plotting when ``sampled_column``
    has no non-missing values.
    """
    # NaN never compares equal with ``==``, so only sample from present values.
    candidates = df[sampled_column].dropna()
    if candidates.empty:
        print(f"No values in '{sampled_column}' to sample from.")
        return

    sample_value = candidates.sample(1, random_state=random_state).iloc[0]
    print(f"\n--- Performance for {sampled_column}: {sample_value} ---")

    # Non-empty by construction: sample_value was taken from this column.
    df_specific = df[df[sampled_column] == sample_value]

    plt.figure(figsize=(12, 6))
    sns.barplot(data=df_specific, x="hyperparameter_set", y=y_col)
    plt.title(
        f"{y_col} for {sampled_column}: {sample_value} (Hardware: {hardware})"
    )
    plt.xlabel("Hyperparameter Set")
    plt.ylabel(f"{y_col}")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()


def distribution_analysis_box_plot(
    df: pd.DataFrame, y_col: str, hardware: str, whis=100
):
    """
    Box plot of ``y_col`` for each hyperparameter set.

    Args:
        df: Data with "hyperparameter_set" and ``y_col`` columns.
        y_col: Metric shown on the y-axis.
        hardware: Hardware label used in the title.
        whis: Whisker extent forwarded to ``sns.boxplot``. A number is a
            multiple of the IQR; the default 100 reaches so far that
            practically every point lies within the whiskers. ``(0, 100)``
            puts them exactly at the min and max.

    Prints a message and returns without plotting when ``df`` is empty
    (e.g. after a filter that matched nothing).
    """
    if df.empty:
        print(f"No data to plot for {y_col} distribution (Hardware: {hardware}).")
        return

    plt.figure(figsize=(14, 7))
    sns.boxplot(data=df, x="hyperparameter_set", y=y_col, whis=whis)
    plt.title(f"Overall {y_col} Distribution (Hardware: {hardware})")
    plt.xlabel("Hyperparameter Set")
    plt.ylabel(f"{y_col}")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()


def classic_line_plot(df: pd.DataFrame, x_col: str, y_col: str, hardware: str):
    """
    Line plot of ``y_col`` against ``x_col``, one line per hyperparameter set.

    Rows are sorted by ``x_col`` before plotting. ``sns.lineplot`` aggregates
    repeated x values (mean with a confidence band by default).

    Args:
        df: Data with ``x_col``, ``y_col`` and "hyperparameter_set" columns.
        x_col: Column for the x-axis.
        y_col: Column for the y-axis.
        hardware: Hardware label used in the title.

    Prints a message and returns without plotting when ``x_col`` is missing.
    """
    if x_col not in df.columns:
        print(f"Error: Column '{x_col}' not found in DataFrame.")
        return

    df_sorted = df.sort_values(by=x_col)
    plt.figure(figsize=(12, 7))
    # A scatter plot would show the individual measurements instead.
    sns.lineplot(
        data=df_sorted,
        x=x_col,
        y=y_col,
        hue="hyperparameter_set",
        alpha=0.7,
    )
    plt.title(f"{x_col} vs. {y_col} (Hardware: {hardware})")
    plt.xlabel(x_col)
    plt.ylabel(y_col)
    plt.legend(
        title="Hyperparameter Set", bbox_to_anchor=(1.05, 1), loc="upper left"
    )
    plt.tight_layout()
    plt.show()


def winner_on_grouped_column(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    hardware: str,
    grouped_column: str = "input_shape_id",
):
    """
    Finds the hyperparameter set with the lowest ``y_col`` in each group of
    ``grouped_column`` and plots how often each set wins.

    Args:
        df: Data with ``grouped_column``, ``y_col`` and "hyperparameter_set" columns.
        x_col: Unused; kept so existing callers keep working.
        y_col: Metric to minimise (e.g. execution time).
        hardware: Hardware label used in the plot title.
        grouped_column: Column defining the groups (default: one group per input shape).
    """
    print("\n--- Identifying Best Hyperparameter Set per Input Shape ---")
    # Rows without a measurement cannot win. An all-NaN group would otherwise
    # make idxmin return NaN (or raise, depending on the pandas version),
    # which breaks the .loc lookup below.
    valid = df.dropna(subset=[y_col])
    if valid.empty:
        print("No best performers found to summarize wins.")
        return

    # One row per group: the row with the minimum y_col.
    best_performers = valid.loc[valid.groupby(grouped_column)[y_col].idxmin()]
    print(best_performers[[grouped_column, "hyperparameter_set", y_col]].head())

    if best_performers.empty:
        print("No best performers found to summarize wins.")
        return

    # Which hyperparameter set wins most often?
    print("\n--- Count of 'Wins' per Hyperparameter Set ---")
    win_counts = best_performers["hyperparameter_set"].value_counts()

    plt.figure(figsize=(10, 6))
    win_counts.plot(kind="bar")
    plt.title(
        f"Number of Times Each Hyperparameter Set Was Fastest (Hardware: {hardware})"
    )
    plt.xlabel("Hyperparameter Set")
    plt.ylabel("Number of Input Shapes Won")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()

    print(win_counts)


def average_performance_rank(
    df: pd.DataFrame, y_col: str, hardware: str, grouped_column: str = "input_shape_id"
):
    """
    Ranks hyperparameter sets by ``y_col`` within each group of
    ``grouped_column``, then plots each set's average rank (lower is better).

    Ties share the lowest rank (``method="min"``). The input frame is not
    modified.

    Args:
        df: Data with ``grouped_column``, ``y_col`` and "hyperparameter_set" columns.
        y_col: Metric to rank by (lower values get better ranks).
        hardware: Hardware label used in the plot title.
        grouped_column: Column defining the groups in which ranks are computed.

    Returns:
        pd.Series: The average rank per hyperparameter set, sorted ascending.
        The plot is skipped when no rank could be computed.
    """
    # Computed as a standalone Series so the caller's frame (possibly a
    # filtered slice) is never mutated.
    ranks = df.groupby(grouped_column)[y_col].rank(method="min").rename("rank")
    avg_rank = ranks.groupby(df["hyperparameter_set"]).mean().sort_values()

    if not avg_rank.notna().any():
        print(f"No {y_col} values to rank (Hardware: {hardware}).")
        return avg_rank

    plt.figure(figsize=(10, 6))
    avg_rank.plot(kind="bar")
    plt.title(f"Average Performance Rank by Hyperparameter Set (Hardware: {hardware})")
    plt.xlabel("Hyperparameter Set")
    plt.ylabel("Average Rank (Lower is Better)")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()
    print(avg_rank)
    return avg_rank


def scatter_plot_3d(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    z_col: str,
    hardware: str,
):
    """
    Interactive 3D scatter (Plotly Express) of three columns, coloured by
    hyperparameter set.

    Any failure while building or showing the figure is printed rather than
    raised, so one bad column choice does not stop the notebook.
    """
    plot_title = f"3D Performance: {x_col} & {y_col} vs. {z_col} (Hardware: {hardware})"
    try:
        # Add symbol="hyperparameter_set" to also vary the marker shape, or
        # figure.update_traces(marker=dict(size=3)) to shrink dense plots.
        figure = px.scatter_3d(
            df,
            x=x_col,
            y=y_col,
            z=z_col,
            color="hyperparameter_set",
            opacity=0.7,
            title=plot_title,
        )
        figure.show()
    except Exception as err:
        print(f"Could not generate 3D plot: {err}")


def _plot_mean_by_base_config(
    df_pivot: pd.DataFrame, base_config_col: str, y_col: str, compiler_col: str, hardware: str
):
    """Bar chart of the mean y_col per base config, one bar per compiler."""
    # Averages each compiler's column over all input shapes of a base config.
    mean_perf_by_base_config = df_pivot.groupby(level=base_config_col).mean()
    if mean_perf_by_base_config.empty:
        return

    mean_perf_by_base_config.plot(kind="bar", figsize=(12, 7))
    plt.title(
        f"Average {y_col} by {compiler_col} for each Base Config (Hardware: {hardware})"
    )
    plt.xlabel("Base Hyperparameter Set")
    plt.ylabel(f"Average {y_col} (lower is better)")
    plt.xticks(rotation=45, ha="right")
    plt.legend(title=compiler_col)
    plt.tight_layout()
    plt.show()


def _report_pairwise_difference(
    df_pivot: pd.DataFrame, c1, c2, y_col: str, hardware: str
):
    """Histogram and summary statistics of the percentage difference of c2 relative to c1."""
    # Positive means c2 has a larger y_col than c1 (i.e. is slower, for a time metric).
    diff_pct = (((df_pivot[c2] - df_pivot[c1]) / df_pivot[c1]) * 100).rename(
        f"{c2}_vs_{c1}_diff_pct"
    )

    plt.figure(figsize=(10, 6))
    sns.histplot(diff_pct.dropna(), kde=True)
    plt.title(
        f"Distribution of Performance Difference: {c2} vs {c1}\n(Hardware: {hardware}, Identical Base Configs)"
    )
    plt.xlabel(f"Percentage Difference in {y_col} ({c2} relative to {c1})")
    plt.ylabel("Frequency")
    plt.axvline(0, color="k", linestyle="--")
    plt.tight_layout()
    plt.show()

    print(f"A positive value indicates that {c2} is slower than {c1} by that percentage: Formula used: (({c2} - {c1}) / {c1}) * 100")

    print(f"\nSummary of {c2} vs {c1} performance difference (%):")
    print(diff_pct.describe())

    # The index labels are (base config, input shape) pairs.
    print("Configurations where the performance difference is highest:")
    print(diff_pct.idxmax())
    print("Configurations where the performance difference is lowest:")
    print(diff_pct.idxmin())


def compare_compilers_for_identical_configs(
    df: pd.DataFrame,
    y_col: str,
    hardware: str,
    compiler_col: str,
    base_config_col: str = "hyperparameter_set",
    shape_id_col: str = "input_shape_id",
):
    """
    Compares compilers on runs that share the same base configuration and
    input shape.

    ``y_col`` is averaged per (base config, input shape, compiler). Only
    (base config, input shape) pairs measured by at least two compilers are
    kept. The function then bar-plots the mean ``y_col`` per base config and
    compiler. When exactly two compilers remain, it also plots and summarises
    their percentage difference.

    Args:
        df: Input data.
        y_col: Metric to compare (e.g. "execution_time").
        hardware: Hardware label used in titles.
        compiler_col: Column holding the compiler name.
        base_config_col: Column identifying the configuration without the compiler.
        shape_id_col: Column identifying the input shape.

    Prints a message and returns without plotting when fewer than two
    compilers are present or no pair is shared by two compilers.
    """
    # Missing compiler labels are not a compiler; exclude them from the count.
    unique_compilers = df[compiler_col].dropna().unique()
    if len(unique_compilers) < 2:
        print(f"Not enough compilers to compare (found: {unique_compilers}). Skipping.")
        return

    # One column per compiler. The mean collapses repeated measurements of the
    # same compiler/config/shape.
    df_pivot = df.pivot_table(
        index=[base_config_col, shape_id_col],
        columns=compiler_col,
        values=y_col,
        aggfunc="mean",
    )
    # Keep only pairs with results from at least two compilers. With more
    # than two compilers, a kept row may still lack some of them.
    df_pivot = df_pivot.dropna(thresh=2)

    if df_pivot.empty:
        print(
            "No input shapes found where multiple compilers ran with the same base hyperparameters."
        )
        return

    # A non-empty pivot after thresh=2 always has at least two compiler columns.
    compilers = df_pivot.columns.tolist()
    _plot_mean_by_base_config(df_pivot, base_config_col, y_col, compiler_col, hardware)

    if len(compilers) == 2:
        _report_pairwise_difference(df_pivot, compilers[0], compilers[1], y_col, hardware)


def plot_execution_time_vs_MxKxN(
    df: pd.DataFrame,
    y_col: str,
    hardware: str,
    MxKxN_col_name: str,
    hyperparam_set_col: str = "hyperparameter_set",
    log_x: bool = True,
    log_y: bool = True,
):
    """
    Plots the mean of ``y_col`` against matrix size, one line per
    hyperparameter set.

    Despite the name, any numeric metric can be plotted (this notebook uses
    it for both "execution_time" and "gflops").

    Args:
        df: Input data.
        y_col: Metric averaged per (hyperparameter set, size) and plotted on y.
        hardware: Hardware label used in the title.
        MxKxN_col_name: Column holding the matrix size (x-axis).
        hyperparam_set_col: Column used for the line colour (hue).
        log_x: Use a logarithmic x-axis (default True).
        log_y: Use a logarithmic y-axis (default True). Log axes cannot
            display zero or negative values.

    Prints a message and returns without plotting when a required column is
    missing or there is no data.
    """
    for col in (MxKxN_col_name, y_col, hyperparam_set_col):
        if col not in df.columns:
            print(f"Error: Column '{col}' not found in DataFrame.")
            return

    # Average repeated measurements of the same size and hyperparameter set,
    # then sort by size so each line is drawn left to right.
    df_agg = (
        df.groupby([hyperparam_set_col, MxKxN_col_name])[y_col]
        .mean()
        .reset_index()
        .sort_values(by=MxKxN_col_name)
    )

    if df_agg.empty:
        print(f"No aggregated data to plot for time vs {MxKxN_col_name}.")
        return

    plt.figure(figsize=(14, 8))
    # errorbar=None: the data is already aggregated, so there is no spread to show.
    sns.lineplot(
        data=df_agg,
        x=MxKxN_col_name,
        y=y_col,
        hue=hyperparam_set_col,
        marker="o",
        errorbar=None,
    )

    plt.title(
        f"Average {y_col} vs. {MxKxN_col_name} by Hyperparameter Set (Hardware: {hardware})"
    )
    plt.xlabel(f"{MxKxN_col_name}")
    plt.ylabel(f"Average {y_col}")
    plt.legend(title="Hyperparameter Set", bbox_to_anchor=(1.05, 1), loc="upper left")
    if log_x:
        plt.xscale("log")
    if log_y:
        plt.yscale("log")
    plt.grid(True, which="both", ls="-", alpha=0.5)
    plt.tight_layout()
    plt.show()


def plot_nan_occurrences_by_input(
    df: pd.DataFrame,
    output_col: str,
    input_col: str,
    hardware_id: str,
    top_n: int = 20,
):
    """
    Counts the rows whose ``output_col`` is NaN, grouped by ``input_col``, and
    shows the ``top_n`` most frequent ``input_col`` values as horizontal bars.

    Args:
        df (pd.DataFrame): The DataFrame to analyze.
        output_col (str): Column checked for NaN values.
        input_col (str): Column the NaN rows are counted by.
        hardware_id (str): Hardware label used in the title.
        top_n (int): Maximum number of input values to display.

    Prints a message and returns without plotting when a column is missing,
    there are no NaN rows, or the NaN rows cannot be counted by ``input_col``.
    """
    column_checks = (
        (output_col, f"Error: Output column '{output_col}' not found in DataFrame."),
        (input_col, f"Error: Input col '{input_col}' not found in DataFrame."),
    )
    for column, message in column_checks:
        if column not in df.columns:
            print(message)
            return

    failed_rows = df.loc[df[output_col].isna()]
    if failed_rows.empty:
        print(f"No NaN values found in '{output_col}' for hardware '{hardware_id}'.")
        return

    counts = failed_rows[input_col].value_counts().nlargest(top_n)
    if counts.empty:
        # Possible when input_col is itself NaN on every failing row, because
        # value_counts() drops NaN by default.
        print(f"Found NaN values, but could not aggregate by '{input_col}'.")
        return

    shown = min(top_n, len(counts))
    print(f"Top {shown} input shapes with the most NaN values in '{output_col}':")
    print(counts)

    # Grow the figure height with the number of bars.
    plt.figure(figsize=(12, max(6, len(counts) * 0.4)))
    sns.barplot(x=counts.values, y=counts.index, orient="h")
    plt.title(
        f"Top {shown} Input Shapes with Most NaN in '{output_col}'\n(Hardware: {hardware_id})"
    )
    plt.xlabel(f"Number of NaN Occurrences in '{output_col}'")
    plt.ylabel(f"{input_col}")
    plt.tight_layout()
    plt.show()

In [None]:
# --- Analysis toggles ---
# Set a flag to True to run that optional analysis (instead of commenting
# code in and out). The defaults reproduce the analyses enabled so far.
SHOW_SAMPLE_SHAPE_BARPLOT = False
SHOW_DISTRIBUTION_BOXPLOTS = False
SHOW_WIN_COUNTS = False
SHOW_AVERAGE_RANK = True
SHOW_LINE_PLOT = False
SHOW_3D_SCATTER = False
SHOW_GFLOPS_VS_SIZE = True
SHOW_PEANO_GFLOPS_BY_EMULATION = False
SHOW_CHESS_GFLOPS_BY_COLS = True
SHOW_COMPILER_COMPARISON = False
SHOW_EXECUTION_TIME_VS_SIZE = False
SHOW_NAN_ANALYSIS = False

# Per the section titles below, runs with this "rst" value use emulation.
EMULATION_RST = "8x8x8"

if not master_df.empty:
    # --- Basic Exploration & Preparation ---
    print("\n--- Unique Hyperparameter Sets ---")
    print(master_df["hyperparameter_set"].unique())

    # Unique identifier for each input shape (M_K_N) for easier grouping.
    master_df["input_shape_id"] = master_df[INPUT_COLS].apply(
        lambda row: "_".join(row.astype(str)), axis=1
    )

    # Give the raw output column (OUTPUT_COL) a descriptive name.
    master_df = master_df.rename(columns={OUTPUT_COL: "execution_time"})

    # A matmul performs 2*M*K*N floating-point operations (multiply + add per
    # MAC). Dividing by execution_time * 1000 gives GFLOP/s when
    # execution_time is in microseconds (assumed; the shuffle plot below
    # labels times in "us").
    flops = 2.0 * master_df["M"] * master_df["K"] * master_df["N"]
    master_df["gflops"] = flops / (master_df["execution_time"] * 1000)

    # Execution time per floating-point operation (the column name keeps its
    # historical "mac_normalized" label).
    master_df["mac_normalized_execution_time"] = master_df["execution_time"] / flops

    # Problem size used as the x-axis of the size-scaling plots.
    master_df["MxKxN"] = master_df["M"] * master_df["K"] * master_df["N"]

    print("\n--- Edited Data Head ---")
    print(master_df.head())

    # --- Visualizations ---
    sns.set_theme(style="whitegrid")

    # Working copy for the plots. No filter is applied; add one here
    # (e.g. by compiler) to narrow the whole analysis.
    df_analysis = master_df.copy()

    if df_analysis.empty:
        print("No data to analyze for current dataframe.")
    else:
        if SHOW_SAMPLE_SHAPE_BARPLOT:
            # Bar plot of every hyperparameter set for one randomly sampled input shape.
            hyperparameter_comparison_for_sampled_column(
                df_analysis[df_analysis["compiler"] == "peano"],
                "execution_time",
                hardware,
                sampled_column="input_shape_id",
            )

        if SHOW_DISTRIBUTION_BOXPLOTS:
            print("\n--- Overall Performance Distribution by Hyperparameter Set ---")
            distribution_analysis_box_plot(df_analysis, "execution_time", hardware)

            print("\n--- Overall Mac Normalized Above Matrix Size of 10e9 Performance Distribution by Hyperparameter Set ---")
            distribution_analysis_box_plot(
                df_analysis[df_analysis["MxKxN"] > 10e9],
                "mac_normalized_execution_time",
                hardware,
            )

            print("\n--- Overall Gflops Distribution by Hyperparameter Set ---")
            distribution_analysis_box_plot(df_analysis, "gflops", hardware)

        if SHOW_WIN_COUNTS:
            # Best hyperparameter set per input shape, and how often each set wins.
            winner_on_grouped_column(
                df_analysis,
                "input_shape_id",
                "execution_time",
                hardware,
                grouped_column="input_shape_id",
            )

        if SHOW_AVERAGE_RANK:
            # For each input shape, rank hyperparameter sets by execution time,
            # then average the ranks per set. Lower average rank is better.
            # Copies are passed so the ranking never writes into df_analysis
            # or into a filtered slice of it.
            print("\n--- Average Performance Rank (Lower is Better) ---")
            average_performance_rank(
                df_analysis.copy(), "execution_time", hardware, grouped_column="input_shape_id"
            )

            chess_df = df_analysis[df_analysis["compiler"] == "chess"]
            print("\n--- Average Performance Rank: chess, without emulation ---")
            average_performance_rank(
                chess_df[chess_df["rst"] != EMULATION_RST].copy(),
                "execution_time",
                hardware,
                grouped_column="input_shape_id",
            )
            print("\n--- Average Performance Rank: chess, emulation only ---")
            average_performance_rank(
                chess_df[chess_df["rst"] == EMULATION_RST].copy(),
                "execution_time",
                hardware,
                grouped_column="input_shape_id",
            )

        if SHOW_LINE_PLOT:
            print("\n--- Matrix Size vs Execution Time by Hyperparameter Set ---")
            classic_line_plot(df_analysis, "MxKxN", "execution_time", hardware)

        if SHOW_3D_SCATTER:
            print("\n--- 3D Scatter Plot: Inputs vs. Output (Plotly) ---")
            scatter_plot_3d(df_analysis, "MxKxN", "compiler", "execution_time", hardware)

        if SHOW_GFLOPS_VS_SIZE:
            print(f"\n--- Gflops vs. matrix size (Hardware: {hardware}) ---")
            plot_execution_time_vs_MxKxN(
                df_analysis,
                "gflops",
                hardware,
                MxKxN_col_name="MxKxN",
                hyperparam_set_col="hyperparameter_set",
            )

        if SHOW_PEANO_GFLOPS_BY_EMULATION:
            peano_df = df_analysis[df_analysis["compiler"] == "peano"]
            peano_emulated = peano_df[peano_df["rst"] == EMULATION_RST]

            print(f"\n--- Gflops With Emulation vs. matrix size and only Peano (Hardware: {hardware}) ---")
            plot_execution_time_vs_MxKxN(
                peano_emulated,
                "gflops",
                hardware,
                MxKxN_col_name="MxKxN",
                hyperparam_set_col="hyperparameter_set",
            )
            print("Max value")
            print(peano_emulated["gflops"].max())

            print(f"\n--- Gflops Without Emulation vs. matrix size and only on Peano (Hardware: {hardware}) ---")
            plot_execution_time_vs_MxKxN(
                peano_df[peano_df["rst"] != EMULATION_RST],
                "gflops",
                hardware,
                MxKxN_col_name="MxKxN",
                hyperparam_set_col="hyperparameter_set",
            )

        if SHOW_CHESS_GFLOPS_BY_COLS:
            chess_df = df_analysis[df_analysis["compiler"] == "chess"]
            # "cols" comes from the filename, so it is compared as a string.
            for n_cols in ("4", "8"):
                print(
                    f"\n--- Gflops vs. matrix size and only on Chess and {n_cols} cols (Hardware: {hardware}) ---"
                )
                plot_execution_time_vs_MxKxN(
                    chess_df[chess_df["cols"] == n_cols],
                    "gflops",
                    hardware,
                    MxKxN_col_name="MxKxN",
                    hyperparam_set_col="hyperparameter_set",
                )

        if SHOW_COMPILER_COMPARISON:
            # Base config = the hyperparameter set without the compiler field.
            df_analysis["hyperparameter_set_compiler_comparison"] = df_analysis["mkn"].str.cat(
                [df_analysis["rst"], df_analysis["out"], df_analysis["cols"]], sep="_"
            )
            comparisons = [
                ("", df_analysis),
                (" (Only emulation)", df_analysis[df_analysis["rst"] == EMULATION_RST]),
                (" Without Emulation", df_analysis[df_analysis["rst"] != EMULATION_RST]),
            ]
            for label, subset in comparisons:
                print(
                    f"\n--- Compiler Comparison for Identical Base Configurations{label} (Hardware: {hardware}) ---"
                )
                compare_compilers_for_identical_configs(
                    subset,
                    "execution_time",
                    hardware,
                    compiler_col="compiler",
                    base_config_col="hyperparameter_set_compiler_comparison",
                    shape_id_col="input_shape_id",
                )

        if SHOW_EXECUTION_TIME_VS_SIZE:
            print(f"\n--- Execution Time vs. matrix size (Hardware: {hardware}) ---")
            plot_execution_time_vs_MxKxN(
                df_analysis,
                "execution_time",
                hardware,
                MxKxN_col_name="MxKxN",
                hyperparam_set_col="hyperparameter_set",
            )

        if SHOW_NAN_ANALYSIS:
            # master_df_nans is the combined data before rows with a NaN
            # OUTPUT_COL were dropped (and before the rename above).
            master_df_nans["input_shape_id"] = master_df_nans[INPUT_COLS].apply(
                lambda row: "_".join(row.astype(str)), axis=1
            )
            nan_views = [
                ("NaN Occurrences by Input Shape", master_df_nans, "input_shape_id"),
                ("NaN Occurrences by Input Cols", master_df_nans, "cols"),
                (
                    "NaN Occurrences by Input Shape (only 8 cols)",
                    master_df_nans[master_df_nans["cols"] == "8"],
                    "input_shape_id",
                ),
            ]
            for title, frame, input_col in nan_views:
                print(f"\n--- {title} (Hardware: {hardware}) ---")
                plot_nan_occurrences_by_input(
                    frame,
                    output_col=OUTPUT_COL,
                    input_col=input_col,
                    hardware_id=hardware,
                    top_n=20,
                )

else:
    print("Master DataFrame is empty. No analysis performed.")

print("\n--- Visualization end ---")

In [None]:
def plot_stacked_bar(
    df,
    x_col,
    y_cols,
    title=None,
    xlabel=None,
    ylabel=None,
    figsize=(12, 8)
):
    """
    Averages ``y_cols`` for each value of ``x_col`` and shows the averages as
    a stacked bar plot (one bar per ``x_col`` value, in ascending order).

    Args:
        df (pd.DataFrame): The input DataFrame.
        x_col (str): Column to group by; its values label the x-axis.
        y_cols (str or list of str): Column(s) to average and stack. A single
            column name is accepted as well.
        title (str, optional): Plot title. Defaults to
            "Average <y_cols> by <x_col>".
        xlabel (str, optional): x-axis label. Defaults to "Problem Size".
        ylabel (str, optional): y-axis label. Defaults to "Average Value".
        figsize (tuple, optional): Figure size. Defaults to (12, 8).

    Returns:
        The matplotlib Axes of the plot, or None when there is nothing to plot
        (a message is printed instead).
    """
    # Accept a single column name; ", ".join on a str would split its characters.
    if isinstance(y_cols, str):
        y_cols = [y_cols]

    # groupby sorts by x_col (ascending), which sets the x-axis order.
    agg_df = df.groupby(x_col)[y_cols].mean()
    if agg_df.empty:
        print(f"No data to plot for {', '.join(y_cols)} by {x_col}.")
        return None

    ax = agg_df.plot(
        kind="bar",
        stacked=True,
        figsize=figsize,
        width=0.8,
        edgecolor="black",
    )

    ax.set_title(title or f"Average {', '.join(y_cols)} by {x_col}", fontsize=16, pad=20)
    ax.set_xlabel(xlabel or "Problem Size", fontsize=12)
    ax.set_ylabel(ylabel or "Average Value", fontsize=12)
    # Rotated labels stay readable for long problem-size values.
    plt.setp(ax.get_xticklabels(), rotation=30, ha="right")
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    # Legend outside the axes so it does not cover the bars.
    ax.legend(title="Metrics", bbox_to_anchor=(1.05, 1), loc="upper left")

    plt.tight_layout()
    plt.show()
    return ax

print("\n--- Shuffle time estimation ---")
# Configuration whose CPU shuffle time is estimated (its compiler field is "chess").
SHUFFLE_CONFIG = "64x64x64_8x8x8_bfp16_8_chess"
SHUFFLE_REQUIRED_COLS = {"compiler", "hyperparameter_set", "MxKxN", "execution_time", "shuffle_time"}

# df_analysis only exists when the analysis cell above found data; older runs
# may also lack a shuffle_time column.
shuffle_source = globals().get("df_analysis", pd.DataFrame())
missing_cols = SHUFFLE_REQUIRED_COLS.difference(shuffle_source.columns)
if missing_cols:
    df_shuffle = pd.DataFrame()
    print(f"Skipping shuffle time estimation: missing columns {sorted(missing_cols)}.")
else:
    df_shuffle = shuffle_source[
        (shuffle_source["compiler"] == "chess")
        & (shuffle_source["hyperparameter_set"] == SHUFFLE_CONFIG)
    ].dropna(subset=["shuffle_time"]).copy()

    if df_shuffle.empty:
        print(f"No shuffle_time measurements found for {SHUFFLE_CONFIG}.")
    else:
        print(df_shuffle.head())
        # Stack device execution time and CPU shuffle time per problem size.
        plot_stacked_bar(
            df_shuffle,
            "MxKxN",
            ["execution_time", "shuffle_time"],
            title="Total execution time including CPU shuffle time",
            xlabel="MxKxN",
            ylabel="Average Time (us)",
        )