# Setup

## User Config

In [None]:
# Where the raw results are read from and where rendered figures are written to.
RESULTS_DIR = r"../server_results/learning_results_10_accuracy2"
SAVE_FIGURES_TO = "thesis_charts/GnuServer"  # or None to just show figures
STYLE = "thesis"  # "thesis": styled like my thesis (times new roman, etc); "latex": default latex look

USES_ORACLE = False

# Raw algorithm identifier (as stored in the results) -> display name used in charts.
# Two variants exist; USES_ORACLE selects which one is active.
_NAMES_WITH_ORACLE = {
    "APMSL('ONLY_RW')": 'APMSL (ic, rw)',
    "APMSL('RW', orpg=True)": 'APMSL (ic, rw, replay, repro)',

    "GSM('PURGE')": 'GSM (purge, rw)',
    "GSM('NO_PURGE')": 'GSM (no-purge, rw)',
}
_NAMES_WITHOUT_ORACLE = {
    "APMSL('ONLY_RW')": 'APMSL (ic, rw, term-impr)',
    "APMSL('ONLY_RW', 'GTT2')": 'APMSL (ic, rw, term-thresh2)',
    "APMSL('RW', orpg=True)": 'APMSL (ic, rw, replay, repro, term-impr)',
    "APMSL('RW', 'GTT2', orpg=True)": 'APMSL (ic, rw, replay, repro, term-thresh2)',

    "GSM('PURGE')": 'GSM (purge, rw, term-bisim)',
    "GSM('NO_PURGE')": 'GSM (no-purge, rw, term-bisim)',
}
ALGORITHM_NAMES = _NAMES_WITH_ORACLE if USES_ORACLE else _NAMES_WITHOUT_ORACLE

# Hatch pattern and matplotlib colormap per algorithm family; the family is
# matched against the start of the display name.
BASE_ALGORITHM_HATCHES = dict(APMSL="xxx", GSM="...")
BASE_ALGORITHM_COLORMAP = dict(APMSL="plasma", GSM="viridis")

In [None]:
## Imports & Misc Setup
import os
import sys
from pathlib import Path

sys.path.append(f"..")
sys.path.append(r"../../pmsat-inference")
sys.path = [r"../../AALpy"] + sys.path

import pandas as pd
import numpy as np

from evaluation.utils import print_results_info, print_results_info_per_alg, TracedMooreSUL
import evaluation.charts as charts
import evaluation.charts_pandas as charts_pd
from IPython.display import display, Markdown, Latex

# Make sure the figure output directory exists before any chart is saved.
if SAVE_FIGURES_TO:
    Path(SAVE_FIGURES_TO).mkdir(parents=True, exist_ok=True)

# Keys of the metrics in the loaded results.
PRECISION_KEY, RECALL_KEY, FSCORE_KEY = "Precision_v2", "Recall_v2", "F-Score_v2"

# Key of the accuracy metric and the name shown for it in chart titles/labels.
ACCURACY_KEY = ACCURACY_NAME = "Accuracy"


## Loading

In [None]:
# Load every result once, both as a plain collection and as a DataFrame.
results, results_df = charts.load_results(
    RESULTS_DIR,
    remove_traces_used_to_learn=True,
    is_server_results=True,
    as_pandas=True,
)
print(f"Loaded {len(results)} results!")

## Cleaning

In [None]:
def remove_result_if(result):
    """Return True if `result` should be dropped before plotting.

    Hook for ad-hoc filtering; `result` is a single result (a dict-like entry
    or a DataFrame row). Currently every result is kept.
    """
    return False


def postprocess_result(result):
    """Replace the raw algorithm name of `result` by its display name.

    The mapping comes from `ALGORITHM_NAMES`. The function is idempotent: a
    name that already is one of the display names is left untouched, so
    re-running the cleaning cell on already-cleaned results neither changes
    them nor prints a warning. Truly unknown names are reported and kept.

    Returns the (modified) `result`.
    """
    name = result["algorithm_name"]
    if name in ALGORITHM_NAMES:
        result["algorithm_name"] = ALGORITHM_NAMES[name]
    elif name not in ALGORITHM_NAMES.values():
        print(f'Warning: "{name}" not found in algorithm_names!')

    return result


def filter_results(results):
    """Drop unwanted results and post-process the remaining ones.

    Accepts either a DataFrame (one result per row) or an iterable of
    results. A DataFrame input yields a new DataFrame, anything else a list.
    """
    if isinstance(results, pd.DataFrame):
        # Row-wise apply on an empty frame does not yield a boolean mask,
        # so empty input is passed through unchanged.
        if len(results) == 0:
            return results.copy()

        mask = results.apply(lambda r: not remove_result_if(r), axis=1)
        filtered_df = results[mask].copy()
        if len(filtered_df) == 0:
            return filtered_df
        return filtered_df.apply(postprocess_result, axis=1)

    return [postprocess_result(r) for r in results if not remove_result_if(r)]

# Apply the same cleaning to both representations of the loaded results.
results = filter_results(results)
results_df = filter_results(results_df)

## Plot Config (automatic)

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
plt.style.use("default")  # initialize to default - seaborn overrides!

# color map creation
# Every algorithm family gets its own colormap; the versions of a family are
# spread over it, skipping CM_OFFSET // 2 entries at the start of the map.
CM_OFFSET = 2
ALL_ALGORITHMS_IN_RESULTS = sorted(results_df["algorithm_name"].unique().tolist())
ALGORITHM_COLORS = {}
ALGORITHM_HATCHES = {}
for base_alg, cmap_name in BASE_ALGORITHM_COLORMAP.items():
    alg_versions = [name for name in ALL_ALGORITHMS_IN_RESULTS if name.startswith(base_alg)]
    family_cmap = plt.get_cmap(cmap_name, lut=len(alg_versions) + CM_OFFSET)
    family_hatch = BASE_ALGORITHM_HATCHES.get(base_alg, '')
    for cmap_index, alg_version in enumerate(alg_versions, start=CM_OFFSET // 2):
        ALGORITHM_COLORS[alg_version] = family_cmap(cmap_index)
        ALGORITHM_HATCHES[alg_version] = family_hatch

# rcParams overrides per supported STYLE value; any other STYLE leaves the
# matplotlib defaults untouched.
_STYLE_RC_PARAMS = {
    "latex": {
        "text.usetex": False,  # True if LaTeX is installed
        "font.family": "serif",
        "font.serif": ["CMU Serif"],  # download from https://ctan.org/pkg/cm-unicode
        "axes.labelsize": 12,
        "font.size": 12,
        "legend.fontsize": 10,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "axes.titlesize": 14
    },
    "thesis": {
        # Fonts
        "font.family": "serif",
        "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],
        "font.size": 11,                # Match thesis body text
        "axes.titlesize": 13,           # Section-style figure titles
        "axes.labelsize": 11,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,

        # Axes & Lines
        "axes.linewidth": 1.0,
        "lines.linewidth": 1.5,
        "lines.markersize": 6,

        # Legend
        "legend.fontsize": 10,
        "legend.title_fontsize": 10,
        "legend.handlelength": 1.8,
        "legend.handleheight": 0.8,
        "legend.borderaxespad": 0.8,
        "legend.borderpad": 0.5,
        "legend.labelspacing": 0.4,
        "legend.handletextpad": 0.5,
        "legend.columnspacing": 1.2,
        "legend.fancybox": False,

        # Figure
        # "figure.dpi": 300,
        "savefig.dpi": 300,
        # "figure.figsize": (9, 6),  # Good for 2-column layout
        "figure.constrained_layout.use": True,

        # Ticks
        "xtick.major.size": 4,
        "xtick.major.width": 0.8,
        "ytick.major.size": 4,
        "ytick.major.width": 0.8,

        # PDF output
        "pdf.fonttype": 42,  # Ensures text remains text in PDFs
    },
}

if STYLE in _STYLE_RC_PARAMS:
    mpl.rcParams.update(_STYLE_RC_PARAMS[STYLE])

# Charts

## Correctness

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib.patches as mpatches

def bar_chart_pd(
    ax,
    df: pd.DataFrame,
    key: str,
    agg_method: str,
    group_by: list[str],
    title: str,
    xlabel: str,
    ylabel: str,
    y_as_percentage: bool = False,
    x_as_percentage: bool = False,
    show_num_results_on_bar: bool = False,
    legend: bool = True,
    bar_label_fontsize: int = 10,
    bar_label_decimal_digits: int = 2,
    simple: bool = False,
) -> tuple[list, list]:
    """Draws a bar chart on the given `ax`, and returns legend handles and labels.

    `df[key]` is aggregated with `agg_method` per `group_by` combination. The
    last `group_by` entry becomes the bar series (one bar per value), the
    remaining entries form the groups along the x axis.

    Args:
        ax: Matplotlib axes to draw on.
        df: Results, one row per result.
        key: Column to aggregate.
        agg_method: Aggregation passed to pandas `agg` (e.g. "mean").
        group_by: Columns to group by; see above for their roles.
        title, xlabel, ylabel: Texts of the chart.
        y_as_percentage: Multiply the aggregated values by 100 and label bars
            and y ticks with a "%" suffix.
        x_as_percentage: Append "%" to the x tick labels.
        show_num_results_on_bar: Additionally write "n=<count>" (number of rows
            in the group) into each bar. Requires exactly two `group_by` entries.
        legend: Whether to add an algorithm legend to `ax`.
        bar_label_fontsize: Font size of the labels on/above the bars.
        bar_label_decimal_digits: Decimal digits of non-percentage bar labels.
        simple: Only plot the bars, skip every visual adjustment.

    Returns:
        `(handles, labels)` for the bar series that are known algorithms
        (entries of `ALL_ALGORITHMS_IN_RESULTS`); both empty if `simple`.
    """
    bar_width = 0.92

    pivot_df = df.groupby(by=group_by)[key].agg(agg_method).unstack().astype(float)
    if y_as_percentage:
        pivot_df *= 100

    pivot_df.plot(kind="bar", ax=ax, width=bar_width, xlabel=xlabel, ylabel=ylabel, rot=0, legend=False, title=title)
    if simple:
        # don't do anything except plotting - no visual adjustments
        return [], []

    # Tighten the x range so only a small padding remains next to the outer bars.
    num_groups = len(pivot_df.index)
    margin = bar_width / 2  # half bar width
    padding = 0.1  # extra space on sides
    ax.set_xlim(-margin - padding, num_groups - 1 + margin + padding)

    def percent_label(value, *args, **kwargs):
        # Extra positional args are accepted so this also works as FuncFormatter callback.
        suffix = kwargs.get('suffix', '%')
        return (f"{value:.0f}" if value % 1 == 0 else f"{value:.1f}") + suffix

    # percentage formatting
    if y_as_percentage:
        for bars in ax.containers:
            ax.bar_label(bars, labels=[percent_label(v) for v in bars.datavalues], fontsize=bar_label_fontsize)
        ax.yaxis.set_major_formatter(mtick.FuncFormatter(percent_label))
    else:
        for bars in ax.containers:
            ax.bar_label(bars, fmt=f"%.{bar_label_decimal_digits}f", fontsize=bar_label_fontsize)

    if x_as_percentage:
        ax.set_xticklabels([percent_label(x) for x in pivot_df.index])
    else:
        ax.set_xticklabels([x for x in pivot_df.index])

    # number of results on bar
    if show_num_results_on_bar:
        assert len(group_by) == 2, f"Only works with exactly two group-by entries!"
        result_counts = {
            group_key: len(group)
            for group_key, group in df.groupby(by=group_by)
        }
        for bars in ax.containers:
            col_key = bars.get_label()
            count_labels = []
            for row_key in pivot_df.index:
                row_key_tuple = row_key if isinstance(row_key, tuple) else (row_key, )
                count = result_counts.get(row_key_tuple + (col_key,), 0)
                count_labels.append(f"n={count}")

            ax.bar_label(bars, labels=count_labels, fontsize=bar_label_fontsize, label_type='center',
                         backgroundcolor='white', bbox=dict(alpha=0.75,
                                                            color="white",
                                                            boxstyle="square, pad=0.05",
                                                            capstyle='round'),
                         clip_on=False)

    handles = []
    labels = []

    # hatches and colors
    for bars in ax.containers:
        algorithm = bars.get_label()
        if algorithm not in ALL_ALGORITHMS_IN_RESULTS:
            continue

        hatch = ALGORITHM_HATCHES[algorithm]
        color = ALGORITHM_COLORS[algorithm]

        # Legend style is derived per algorithm, not per bar, so it is also
        # defined (and correct) when a container holds no bars.
        if hatch:
            # denser hatch in the (small) legend patch than on the bars
            patch_kwargs = dict(facecolor="none", edgecolor=color, hatch=hatch + (hatch[0] * 2))
        else:
            patch_kwargs = dict(facecolor=color)

        for bar in bars:
            if hatch:
                bar.set_facecolor("none")
                bar.set_hatch(hatch)
                bar.set_edgecolor(color)
                bar.set_linewidth(1.2)
            else:
                bar.set_facecolor(color)

        if algorithm not in labels:
            handles.append(mpatches.Patch(label=algorithm, **patch_kwargs))
            labels.append(algorithm)

    if legend:
        ax.legend(handles=handles, labels=labels, title="Algorithm", loc="best")

    return handles, labels

In [None]:
fig, (bisim_ax, accuracy_ax) = plt.subplots(2, 1, figsize=(9, 8), sharex=True)

# Both panels share the grouping: glitch percentage on the x axis, one bar per algorithm.
per_glitch_kwargs = dict(
    agg_method="mean",
    group_by=["glitch_percent", "algorithm_name"],
    xlabel='Glitch percentage',
    x_as_percentage=True,
)

handles, labels = bar_chart_pd(
    bisim_ax, results_df, "learned_correctly",
    ylabel="Percentage of bisimilar results", y_as_percentage=True,
    title="Bisimilarity", legend=True,
    **per_glitch_kwargs,
)

bar_chart_pd(
    accuracy_ax, results_df, ACCURACY_KEY,
    ylabel=f"Mean {ACCURACY_NAME} over all results", y_as_percentage=False,
    title=f"{ACCURACY_NAME}", legend=False,
    **per_glitch_kwargs,
)

bisim_ax.set_ylim(0, 105)
accuracy_ax.set_ylim(0, 1.05)

if SAVE_FIGURES_TO:
    fig.savefig(Path(SAVE_FIGURES_TO) / 'correctness.png')
else:
    plt.show()

In [None]:
# fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(9, 8), sharex=True)
# handles, labels = bar_chart_pd(
#     ax1, results_df, PRECISION_KEY, agg_method="mean", group_by=["glitch_percent", "algorithm_name"],
#     xlabel='Glitch percentage', ylabel="Precision", y_as_percentage=False, x_as_percentage=True,
#     title="Precision", legend=True,
# )
# 
# bar_chart_pd(
#     ax2, results_df, RECALL_KEY, agg_method="mean", group_by=["glitch_percent", "algorithm_name"],
#     xlabel='Glitch percentage', ylabel="Recall", x_as_percentage=True,
#     y_as_percentage=False, title="Recall",
#     legend=False,
# )
# 
# ax1.set_ylim(0, 1.05)
# ax2.set_ylim(0, 1.05)
# 
# if SAVE_FIGURES_TO:
#     plt.savefig(Path(SAVE_FIGURES_TO)/'precision_and_recall.png')
# else:
#     plt.show()

In [None]:
# How many GSM (purge) results at 1% glitches were (not) learned correctly.
# The display name is looked up via ALGORITHM_NAMES so that the filter keeps
# matching regardless of which name variant the config selects.
gsm_purge_name = ALGORITHM_NAMES["GSM('PURGE')"]
results_df.loc[
    (results_df["algorithm_name"] == gsm_purge_name) &
    (results_df["glitch_percent"] == 1.0),
    "learned_correctly",
].value_counts()

In [None]:
# For every result whose learned model is not input complete, print the
# 'num_additional_traces_preprocessing_input_completeness' entry of each item
# in its detailed learning info (None if the entry is absent).
not_input_complete = results_df[results_df["learned_model_input_complete"] == False]
for _, not_ic_result in not_input_complete.iterrows():
    learning_info = dict(not_ic_result["detailed_learning_info"])
    for info_key, info in learning_info.items():
        print(f"{info_key}: {info.get('num_additional_traces_preprocessing_input_completeness', None)}")

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
handles, labels = bar_chart_pd(
    ax1, results_df, "timed_out", agg_method="mean", group_by=["glitch_percent", "algorithm_name"],
    xlabel='Glitch percentage', ylabel="Percentage of timed-out results", y_as_percentage=True, x_as_percentage=True,
    title="", legend=False, bar_label_fontsize=6
)

bar_chart_pd(
    ax2, results_df, "timed_out", agg_method="mean", group_by=["original_automaton_size", "algorithm_name"],
    xlabel='Number of states in the ground truth model', ylabel="Percentage of timed-out results", y_as_percentage=True,
    title="", legend=False, x_as_percentage=False, bar_label_fontsize=6
)

fig.suptitle("Time-Outs", y=1.25)

# Shared legend above both panels: a single row for up to four algorithms,
# otherwise two rows. The column count has to be an integer (a plain `/ 2`
# yields a float, which legend() does not reliably accept), hence the
# rounded-up integer division.
num_algorithms = len(ALL_ALGORITHMS_IN_RESULTS)
legend_columns = max(1, (num_algorithms + 1) // 2)
if legend_columns <= 2 and num_algorithms <= 5:
    legend_columns = max(1, num_algorithms)

fig.legend(
    handles=handles,
    labels=labels,
    title="Algorithm",
    loc="upper center",
    bbox_to_anchor=(0.5, 1.2),
    ncol=legend_columns,
)

ax1.set_ylim(0, 105)
ax2.set_ylim(0, 105)

ax1.set_title("Per glitch percentage", fontsize=12)
ax2.set_title("Per number of states", fontsize=12)

if SAVE_FIGURES_TO:
    # bbox_inches='tight' keeps the legend/suptitle that are placed above the axes area
    plt.savefig(Path(SAVE_FIGURES_TO)/'timeouts.png', bbox_inches='tight')
else:
    plt.show()

In [None]:
fig, (bisim_ax, accuracy_ax) = plt.subplots(2, 1, figsize=(9, 7), sharex=True)

# Only the results that ran into the time-out.
timed_out_results = results_df[results_df['timed_out']]

# Shared by both panels: per glitch percentage and algorithm, with the group size on each bar.
timed_out_kwargs = dict(
    agg_method="mean",
    group_by=["glitch_percent", "algorithm_name"],
    xlabel='Glitch percentage',
    x_as_percentage=True,
    show_num_results_on_bar=True,
)

handles, labels = bar_chart_pd(
    bisim_ax, timed_out_results, "learned_correctly",
    ylabel="Percentage of bisimilar timed-out results", y_as_percentage=True,
    title="Bisimilarity of timed-out results", legend=True,
    **timed_out_kwargs,
)

bar_chart_pd(
    accuracy_ax, timed_out_results, ACCURACY_KEY,
    ylabel="Mean Accuracy over timed-out results", y_as_percentage=False,
    title="Accuracy of timed-out results", legend=False,
    **timed_out_kwargs,
)

bisim_ax.set_ylim(0, 105)
accuracy_ax.set_ylim(0, 1.05)

if SAVE_FIGURES_TO:
    fig.savefig(Path(SAVE_FIGURES_TO) / 'timed_out_correctness.png')
else:
    plt.show()

In [None]:
fig, (bisim_ax, accuracy_ax) = plt.subplots(2, 1, figsize=(9, 7), sharex=True)

# Only the results that finished without running into the time-out.
not_timed_out_results = results_df[results_df['timed_out'].eq(False)]

# Shared by both panels: per glitch percentage and algorithm, with the group size on each bar.
not_timed_out_kwargs = dict(
    agg_method="mean",
    group_by=["glitch_percent", "algorithm_name"],
    xlabel='Glitch percentage',
    x_as_percentage=True,
    show_num_results_on_bar=True,
)

handles, labels = bar_chart_pd(
    bisim_ax, not_timed_out_results, "learned_correctly",
    ylabel="Percentage of bisimilar non-timed-out results", y_as_percentage=True,
    title="Bisimilarity of non-timed-out results", legend=True,
    **not_timed_out_kwargs,
)

bar_chart_pd(
    accuracy_ax, not_timed_out_results, ACCURACY_KEY,
    ylabel="Mean Accuracy over non-timed-out results", y_as_percentage=False,
    title="Accuracy of non-timed-out results", legend=False,
    **not_timed_out_kwargs,
)

bisim_ax.set_ylim(0, 105)
accuracy_ax.set_ylim(0, 1.05)

print(labels)

if SAVE_FIGURES_TO:
    fig.savefig(Path(SAVE_FIGURES_TO) / 'not_timed_out_correctness.png')
else:
    plt.show()

In [None]:
# Number of non-timed-out results per algorithm and glitch percentage.
for alg_name in not_timed_out_results["algorithm_name"].unique():
    print(alg_name)
    is_algorithm = not_timed_out_results["algorithm_name"] == alg_name
    for gp in [1, 5, 10]:
        print(f"{gp}%")
        # Both conditions are combined into one mask; chaining two boolean
        # selections makes pandas reindex the second mask and emit a warning.
        print((is_algorithm & (not_timed_out_results["glitch_percent"] == gp)).sum())

In [None]:
# fig, ax1 = plt.subplots(1, 1, figsize=(9, 5), sharex=True)
# 
# not_bisim_out_results = results_df[results_df['learned_correctly'] == False]
# 
# bar_chart_pd(
#     ax1, not_bisim_out_results, ACCURACY_KEY, agg_method="mean", group_by=["glitch_percent", "algorithm_name"],
#     xlabel='Glitch percentage', ylabel="Mean Accuracy over non-bisimilar results", x_as_percentage=True,
#     y_as_percentage=False, title="Accuracy of non-bisimilar results",
#     legend=True,
# )
# 
# ax1.set_ylim(0, 1.05)
# 
# if SAVE_FIGURES_TO:
#     plt.savefig(Path(SAVE_FIGURES_TO)/'not_bisim_correctness.png')
# else:
#     plt.show()

In [None]:
fig, (bisim_ax, accuracy_ax) = plt.subplots(2, 1, figsize=(9, 8), sharex=True)

# Both panels share the grouping: model size on the x axis, one bar per algorithm.
per_size_kwargs = dict(
    agg_method="mean",
    group_by=["original_automaton_size", "algorithm_name"],
    xlabel='Number of states in the ground truth model',
    x_as_percentage=False,
)

handles, labels = bar_chart_pd(
    bisim_ax, results_df, "learned_correctly",
    ylabel="Percentage of bisimilar results", y_as_percentage=True,
    title="Bisimilarity", legend=True,
    **per_size_kwargs,
)

bar_chart_pd(
    accuracy_ax, results_df, ACCURACY_KEY,
    ylabel="Mean Accuracy over all results", y_as_percentage=False,
    title="Accuracy", legend=False,
    **per_size_kwargs,
)

bisim_ax.set_ylim(0, 105)
accuracy_ax.set_ylim(0, 1.05)

if SAVE_FIGURES_TO:
    fig.savefig(Path(SAVE_FIGURES_TO) / 'correctness_over_states.png')
else:
    plt.show()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Box plot of the accuracy per run, grouped by glitch percentage and algorithm.
ax = sns.boxplot(
    data=results_df,
    x="glitch_percent",
    y=ACCURACY_KEY,
    hue="algorithm_name",
    # fill=False,  # this colors the whiskers, but not the boxes itself - maybe i can use this and apply hatch myself?
    palette=ALGORITHM_COLORS,
    hue_order=ALL_ALGORITHMS_IN_RESULTS,
    flierprops=dict(marker='o', markersize=3, alpha=0.3),
)
plt.title("Accuracy per Run")

ax.set_ylabel("Accuracy")
ax.set_xlabel("Glitch percentage")

def percent_label(label):
    """Format a numeric tick label text (e.g. "5.0") as integer percentage ("5%")."""
    return f"{int(float(label))}%"

ax.set_xticklabels([percent_label(t.get_text()) for t in ax.get_xticklabels()])

legend = ax.legend(title="Algorithm")

# set hatches (& colors) # CAUTION: this might not work with other matplotlib / seaborn versions!
# Assumes the box patches are ordered algorithm by algorithm (hue order), each
# with one box per glitch percentage.
patches = [p for p in ax.patches if type(p) is mpl.patches.PathPatch]
num_groups = len(results_df["glitch_percent"].unique())
for alg_idx, alg in enumerate(ALL_ALGORITHMS_IN_RESULTS):
    hatch = ALGORITHM_HATCHES[alg]
    color = ALGORITHM_COLORS[alg]
    for patch in patches[(alg_idx * num_groups):((alg_idx + 1) * num_groups)]:
        if hatch:
            # patch.set_fill(True)
            patch.set_facecolor("none")
            patch.set_hatch(hatch)
            patch.set_edgecolor(color)
            patch.set_linewidth(1.2)

# fix legend for hatches
# Legend entries follow hue_order, so each hatch is looked up by algorithm name
# instead of relying on the insertion order of ALGORITHM_HATCHES. Algorithms
# without a hatch keep their filled legend patch, matching the boxes above.
for legend_patch, alg in zip(legend.get_patches(), ALL_ALGORITHMS_IN_RESULTS):
    hatch = ALGORITHM_HATCHES[alg]
    if not hatch:
        continue
    legend_patch.set_hatch(hatch + hatch[0] * 2)
    legend_patch.set_edgecolor(legend_patch.get_facecolor())
    legend_patch.set_facecolor('none')

if SAVE_FIGURES_TO:
    plt.savefig(Path(SAVE_FIGURES_TO)/'accuracy_boxplot.png')
else:
    plt.show()

In [None]:
# Quartiles of the accuracy per algorithm and glitch percentage, one column per quartile.
accuracy_per_alg_and_glitch = results_df.groupby(['algorithm_name', 'glitch_percent'])[ACCURACY_KEY]
quartile_names = {0.25: 'Q1', 0.5: 'Median', 0.75: 'Q3'}

quartiles = (
    accuracy_per_alg_and_glitch
    .quantile(list(quartile_names))
    .unstack(level=-1)  # the quantile level is the innermost one
    .rename(columns=quartile_names)
    .reset_index()
)

quartiles

In [None]:
def compute_boxplot_stats(group):
    """Compute box-plot statistics of the accuracy values in `group`.

    Whiskers are the most extreme values still within 1.5 * IQR of the
    quartiles; every value beyond a whisker is listed as an outlier.

    Returns a Series with Q1, Q3, Median, IQR, both whiskers, the number of
    outliers and the outlier values themselves.
    """
    scores = group[ACCURACY_KEY]

    q1 = scores.quantile(0.25)
    q3 = scores.quantile(0.75)
    iqr = q3 - q1

    lower_fence = q1 - 1.5 * iqr
    upper_fence = q3 + 1.5 * iqr
    lower_whisker = scores[scores >= lower_fence].min()
    upper_whisker = scores[scores <= upper_fence].max()

    beyond_whiskers = (scores < lower_whisker) | (scores > upper_whisker)
    outliers = scores[beyond_whiskers].tolist()

    stats = {
        'Q1': q1,
        'Q3': q3,
        'Median': scores.median(),
        'IQR': iqr,
        'Lower Whisker': lower_whisker,
        'Upper Whisker': upper_whisker,
        "Number of Outliers": len(outliers),
        'Outliers': outliers,
    }
    return pd.Series(stats)

# One row of box-plot statistics per (algorithm, glitch percentage) combination.
runs_per_alg_and_glitch = results_df.groupby(['algorithm_name', 'glitch_percent'])
boxplot_stats = runs_per_alg_and_glitch.apply(compute_boxplot_stats).reset_index()
boxplot_stats

## Efficiency

In [None]:
# fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(9, 8), sharex=True)
# handles, labels = bar_chart_pd(
#     ax1, results_df, "steps_learning", agg_method="mean", group_by=["glitch_percent", "algorithm_name"],
#     xlabel='Glitch percentage', ylabel="Mean number of steps per run", x_as_percentage=True,
#     title="Mean number of steps", legend=True, bar_label_decimal_digits=1,
# )
# 
# bar_chart_pd(
#     ax2, results_df, "queries_learning", agg_method="mean", group_by=["glitch_percent", "algorithm_name"],
#     xlabel='Glitch percentage', ylabel="Mean number of queries per run",
#     title="Mean number of queries", legend=False, bar_label_decimal_digits=1, x_as_percentage=True,
# )

# Mean number of learning steps per glitch percentage and algorithm.
fig, steps_ax = plt.subplots(1, 1, figsize=(9, 5), sharex=True)
handles, labels = bar_chart_pd(
    steps_ax,
    results_df,
    "steps_learning",
    agg_method="mean",
    group_by=["glitch_percent", "algorithm_name"],
    title="Mean number of steps",
    xlabel='Glitch percentage',
    ylabel="Mean number of steps per run",
    x_as_percentage=True,
    legend=True,
    bar_label_decimal_digits=1,
)

if SAVE_FIGURES_TO:
    fig.savefig(Path(SAVE_FIGURES_TO) / 'steps.png')
else:
    plt.show()

In [None]:
# Same as the steps chart, but restricted to results that did not time out;
# the number of results behind every bar is printed on it.
fig, steps_ax = plt.subplots(1, 1, figsize=(9, 5), sharex=True)
handles, labels = bar_chart_pd(
    steps_ax,
    not_timed_out_results,
    "steps_learning",
    agg_method="mean",
    group_by=["glitch_percent", "algorithm_name"],
    title="Mean number of steps (without timeouts)",
    xlabel='Glitch percentage',
    ylabel="Mean number of steps per run in non-timed-out results",
    x_as_percentage=True,
    legend=True,
    bar_label_decimal_digits=1,
    show_num_results_on_bar=True,
)

if SAVE_FIGURES_TO:
    fig.savefig(Path(SAVE_FIGURES_TO) / 'steps_without_timeouts.png')
else:
    plt.show()

In [None]:
# "Information" = accuracy gained per step / per query spent on learning.
results_df["InformationPerStep"] = results_df[ACCURACY_KEY].div(results_df["steps_learning"])
results_df["InformationPerQuery"] = results_df[ACCURACY_KEY].div(results_df["queries_learning"])

fig, (per_step_ax, per_query_ax) = plt.subplots(2, 1, figsize=(9, 8), sharex=True)

information_kwargs = dict(
    agg_method="mean",
    group_by=["glitch_percent", "algorithm_name"],
    xlabel='Glitch percentage',
    x_as_percentage=True,
    bar_label_decimal_digits=5,
)

handles, labels = bar_chart_pd(
    per_step_ax, results_df, "InformationPerStep",
    ylabel="Information per step", title="Information per step", legend=True,
    **information_kwargs,
)

bar_chart_pd(
    per_query_ax, results_df, "InformationPerQuery",
    ylabel="Information per query", title="Information per query", legend=False,
    **information_kwargs,
)

save_this_figure = False  # saving is switched off for this chart; it is only shown
if save_this_figure:
    fig.savefig(Path(SAVE_FIGURES_TO) / 'info_per_steps_and_queries.png')
else:
    plt.show()

In [None]:
def relplot(df, columns: list[str]):
    """Scatter steps vs. accuracy, with one facet per algorithm in `columns`.

    Points are coloured and styled by glitch percentage. The legend seaborn
    creates is replaced by a single-row legend above the whole figure.

    Args:
        df: Results with (at least) the columns "glitch_percent",
            "steps_learning", "algorithm_name" and the accuracy column.
            It is not modified.
        columns: Algorithm names to show, in facet order.

    Returns:
        The seaborn grid object holding the figure.
    """
    FACETS_PER_ROW = 2
    MANUAL_LEGEND = True
    GLITCH_PERCENTAGE_NAME = "Glitch percentage"
    SHOW_GRID = False

    # The formatted helper column doubles as legend title. It is added to a
    # copy so the caller's DataFrame does not silently gain a column.
    plot_df = df.assign(**{
        GLITCH_PERCENTAGE_NAME: df["glitch_percent"].apply(lambda x: f"{int(x)}%")
    })

    kwargs = dict()
    if MANUAL_LEGEND:
        kwargs["facet_kws"] = dict(legend_out=False)

    g = sns.relplot(
        plot_df,
        x="steps_learning", y=ACCURACY_KEY,
        col="algorithm_name", col_wrap=FACETS_PER_ROW, col_order=columns,
        hue=GLITCH_PERCENTAGE_NAME, style=GLITCH_PERCENTAGE_NAME, alpha=0.5, palette=["green", "blue", "red"],
        height=4, aspect=1.2, **kwargs,
    )
    g.figure.suptitle("Steps vs. Accuracy", y=1.04)
    g.set_titles(col_template="{col_name}")
    g.set_axis_labels("Steps", "Accuracy")

    if MANUAL_LEGEND:
        # remove the existing legend (if seaborn created one)
        existing_legend = getattr(g, "_legend", None)
        if existing_legend is not None:
            existing_legend.remove()

        handles, labels = g.axes[0].get_legend_handles_labels()

        # Insert title into first entry as a dummy handle
        from matplotlib.lines import Line2D
        title_handle = Line2D([], [], linestyle='None')  # No marker, empty handle
        handles.insert(0, title_handle)
        labels.insert(0, f"{GLITCH_PERCENTAGE_NAME}:")

        # Add custom legend
        g.figure.legend(
            handles,
            labels,
            loc="upper center",
            bbox_to_anchor=(0.5, 1.03),
            bbox_transform=g.figure.transFigure,
            ncol=len(labels),  # All on one row, including the title
            frameon=True,
            handletextpad=0.5,
            # columnspacing=1.2,
        )

    if SHOW_GRID:
        for ax in g.axes.flatten():
            ax.grid(True, which='major', linestyle='--', linewidth=0.5, alpha=0.7)

    return g


# One figure with all algorithms, one restricted to the APMSL variants.
apmsl_algorithms = [name for name in ALL_ALGORITHMS_IN_RESULTS if name.startswith("APMSL")]

g = relplot(results_df, columns=ALL_ALGORITHMS_IN_RESULTS)
g2 = relplot(results_df, columns=apmsl_algorithms)

if SAVE_FIGURES_TO:
    figure_dir = Path(SAVE_FIGURES_TO)
    g.savefig(figure_dir / 'steps_vs_accuracy.png')
    g2.savefig(figure_dir / 'steps_vs_accuracy_apmsl.png')

In [None]:
import seaborn as sns

# All runs in a single scatter plot: colour = algorithm, marker = glitch percentage.
fig, ax = plt.subplots(figsize=(9, 6))
scatter_kwargs = dict(
    x="steps_learning",
    y=ACCURACY_KEY,
    hue="algorithm_name",
    style="glitch_percent",
    # size="queries_learning",
    palette=ALGORITHM_COLORS,
)
sns.scatterplot(data=results_df, ax=ax, **scatter_kwargs)
ax.set_title("Accuracy vs. Steps")
# ax.set_xlabel("Steps")
# ax.set_ylabel("Accuracy")
# ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax.legend()
fig.tight_layout()

In [None]:
group_cols = ["algorithm_name", "glitch_percent", "original_automaton"]
target_metric = ACCURACY_KEY

# unique combos; kept under a name because an expression in the middle of a
# cell is neither displayed nor stored
unique_run_combos = results_df[group_cols].drop_duplicates()

# One group per unique (algorithm, glitch percentage, model) combination,
# i.e. the repeated runs of the same experiment.
metric_per_combo = results_df.groupby(group_cols)[target_metric]

divergence = (
    metric_per_combo.std()
    .reset_index()
    .rename(columns={target_metric: "within_group_std"})
)

stats_per_unique_run_combo = metric_per_combo.agg(
    mean="mean", std="std", min="min", max="max",
    range=lambda x: x.max() - x.min()
).reset_index()

In [None]:
# Give every ground truth model an integer index, ordered by number of states
# (then number of outputs), to use as x position.
unique_models_df = (
    results_df[["original_automaton", "original_automaton_size", "original_automaton_num_outputs"]]
    .drop_duplicates()
    .sort_values(by=["original_automaton_size", "original_automaton_num_outputs"])
    .reset_index(drop=True)
)
unique_models_df["model_index"] = np.arange(len(unique_models_df))
model_to_int = dict(zip(unique_models_df["original_automaton"], unique_models_df["model_index"]))
results_df["model_index"] = results_df["original_automaton"].map(model_to_int)

# One panel per algorithm, two panels per row. The grid is sized from the
# number of algorithms instead of assuming exactly six of them.
num_algorithms = len(ALL_ALGORITHMS_IN_RESULTS)
ncols = 2
nrows = max(1, -(-num_algorithms // ncols))  # ceiling division
fig, axes = plt.subplots(
    nrows=nrows, ncols=ncols, figsize=(15, 4 * nrows), sharex=True, sharey=True, squeeze=False,
)
axes = axes.flatten()

# x positions (between two model indices) where the number of states changes
sizes = unique_models_df["original_automaton_size"].values
change_indices = np.where(np.diff(sizes) != 0)[0] + 0.5

for i, algorithm in enumerate(ALL_ALGORITHMS_IN_RESULTS):
    ax = axes[i]
    subset = results_df[results_df["algorithm_name"] == algorithm]

    sns.scatterplot(
        data=subset,
        x="model_index",
        y=ACCURACY_KEY,
        hue="glitch_percent",
        ax=ax,
        palette="plasma",
        alpha=0.6,
        legend=(i == 0)  # only show legend once
    )

    ax.set_ylim(0, 1.05)
    ax.set_xlabel("")  # common x label set below
    ax.set_ylabel("Accuracy")
    ax.set_title(algorithm)

    # Draw vertical lines for size changes
    for pos in change_indices:
        ax.axvline(x=pos, color="gray", linestyle="--", alpha=0.7)

# Hide panels that have no algorithm (odd number of algorithms) and show the
# x tick labels on the panel directly above instead.
for unused_idx in range(num_algorithms, len(axes)):
    axes[unused_idx].set_visible(False)
    if unused_idx - ncols >= 0:
        axes[unused_idx - ncols].tick_params(labelbottom=True)

# Common x-label and adjust layout  # TODO very manual...
fig.text(0.5, -0.04, "Model Index (sorted by number of states)", ha='center', fontsize=12)
fig.text(0.04, 0.5, "Accuracy", va='center', rotation='vertical', fontsize=12)

if SAVE_FIGURES_TO:
    plt.savefig(Path(SAVE_FIGURES_TO)/'accuracy_per_model_scatterplot.png')
else:
    plt.show()


In [None]:
# "Fictional runtime": measured total time plus an assumed fixed duration for
# every learning step. One panel per assumed step duration (in seconds).
durations = [0, 1, 2, 30]
fig, axes = plt.subplots(len(durations), 1, figsize=(9, 3 * len(durations)), sharex=True)

for panel_idx, (panel_ax, sec_per_step) in enumerate(zip(axes, durations)):
    col_name = f"FictionalRuntime_{sec_per_step}Sec"
    step_time = results_df["steps_learning"] * sec_per_step
    results_df[col_name] = results_df["total_time"] + step_time

    bar_chart_pd(
        panel_ax,
        results_df,
        col_name,
        agg_method="mean",
        group_by=["glitch_percent", "algorithm_name"],
        title=f"Mean Fictional Runtime ({sec_per_step} sec/step)",
        xlabel='Glitch percentage',
        ylabel="Mean fictional runtime (sec)",
        x_as_percentage=True,
        legend=(panel_idx == 0),  # legend only on the first panel
        bar_label_decimal_digits=0,
        bar_label_fontsize=9,
    )

if SAVE_FIGURES_TO:
    fig.savefig(Path(SAVE_FIGURES_TO) / 'fictional_runtime.png')
else:
    plt.show()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

step_durations = np.arange(0, 31)  # from 0 to 30 seconds

# Mean fictional runtime (total time + steps * assumed seconds per step) per
# glitch percentage and algorithm, for every assumed step duration. Only the
# needed columns are combined; the full DataFrame is not copied per duration.
group_keys = [results_df["glitch_percent"], results_df["algorithm_name"]]
plot_data = []
for sec_per_step in step_durations:
    fictional_runtime = (
        results_df["total_time"] + results_df["steps_learning"] * sec_per_step
    ).rename("fictional_runtime")

    grouped = fictional_runtime.groupby(group_keys).mean().reset_index()
    grouped["sec_per_step"] = sec_per_step
    plot_data.append(grouped)

combined_df = pd.concat(plot_data)

# unique_glitches = sorted(combined_df["glitch_percent"].unique())
unique_glitches = [1, 5]
n = len(unique_glitches)

# squeeze=False keeps `axes` an array even for a single glitch percentage
fig, axes = plt.subplots(n, 1, figsize=(10, 4 * n), sharex=True, squeeze=False)
axes = axes.flatten()

for ax, glitch in zip(axes, unique_glitches):
    subset = combined_df[combined_df["glitch_percent"] == glitch]

    for algorithm_name, group in subset.groupby("algorithm_name"):
        ax.plot(group["sec_per_step"], group["fictional_runtime"], label=algorithm_name, color=ALGORITHM_COLORS[algorithm_name])

    ax.set_title(f"Fictional runtime vs. per-step duration ({glitch:.0f}% glitches)")
    ax.set_ylabel("Fictional runtime (sec)")
    ax.grid(True)
    ax.legend()

axes[-1].set_xlabel("Assumed time per step (sec)")

if SAVE_FIGURES_TO:
    plt.savefig(Path(SAVE_FIGURES_TO)/'fictional_runtime_linecharts_by_glitch.png')
else:
    plt.show()

In [None]:
# Mean number of learning rounds per glitch percentage and algorithm.
fig, rounds_ax = plt.subplots(1, 1, figsize=(9, 5), sharex=True)
handles, labels = bar_chart_pd(
    rounds_ax,
    results_df,
    "learning_rounds",
    agg_method="mean",
    group_by=["glitch_percent", "algorithm_name"],
    title="Mean number of learning rounds",
    xlabel='Glitch percentage',
    ylabel="Mean number of learning rounds per run",
    x_as_percentage=True,
    legend=True,
    bar_label_decimal_digits=1,
)

save_this_figure = False  # saving is switched off for this chart; it is only shown
if save_this_figure:
    fig.savefig(Path(SAVE_FIGURES_TO) / 'learning_rounds.png')
else:
    plt.show()

In [None]:
# Mean number of steps (top) and queries (bottom) spent in the EQ oracle,
# per glitch percentage and algorithm.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(9, 8), sharex=True)
handles, labels = bar_chart_pd(
    ax1, results_df, "steps_eq_oracle", agg_method="mean", group_by=["glitch_percent", "algorithm_name"],
    xlabel='Glitch percentage', ylabel="Mean number of steps in the EQ oracle per round", x_as_percentage=True,
    title="Mean number of EQ oracle steps", legend=True, bar_label_decimal_digits=1,
)

bar_chart_pd(
    ax2, results_df, "queries_eq_oracle", agg_method="mean", group_by=["glitch_percent", "algorithm_name"],
    xlabel='Glitch percentage', ylabel="Mean number of queries in the EQ oracle per round", x_as_percentage=True,
    title="Mean number of EQ oracle queries", legend=True, bar_label_decimal_digits=1,
)

# Saving is switched off for this chart. If it gets enabled, the figure goes
# to its own file (it must not overwrite the learning-rounds chart) and only
# when an output directory is configured.
save_this_figure = False
if save_this_figure and SAVE_FIGURES_TO:
    plt.savefig(Path(SAVE_FIGURES_TO)/'eq_oracle_steps_and_queries.png')
else:
    plt.show()