# Imports

In [None]:
# standard library imports
import os

# related third party imports
import structlog

# local application/library specific imports
from tools.configurator import (
    get_configs_out,
    get_config_ids,
)
from tools.analyzer import (
    print_table_from_dict,
    print_df_from_dict,
    get_results_dict,
    merge_all_results,
    create_config_id_print,
    get_config_df,
    get_llm_student_preds,
)
from tools.plotter import (
    plot_llm_student_confusion,
    plot_kt_confusion,
    plot_level_correctness,
    activate_latex,
    deactivate_latex,
    plot_llm_correctness_by_model,
)


# structlog logger for this notebook (currently only referenced in commented-out cells)
logger = structlog.get_logger(__name__)

In [None]:
##### INPUTS #####
# Experiment to analyse; earlier experiments kept for reference:
#   "replication_miscon_valsmall_kt_20250818-153513"
#   "replication_miscon_vallarge_kt_20250819-115545"
EXP_NAME = "replication_miscon_test_kt_20250819-192630"
# Prefix of the metric keys/labels (see METRIC2LEGEND_DICT).
# NOTE(review): EXP_NAME is a "test" experiment while SPLIT is "val" — confirm the
# test run stores its metrics under the "val" prefix.
SPLIT = "val"
EXCLUDE_METRICS = []  # metrics passed as exclude_metrics to the results table
LEGEND_EXACT = True  # forwarded as legend_exact to the analyzer
PROBLEM_TYPE = "replicate"  # forwarded to get_config_ids
SANS_SERIF = True  # forwarded to activate_latex when exporting paper figures
PRINT_PAPER = False  # True -> save paper figures under output/<EXP_NAME>/figures

In [None]:
# Metric key -> display label (passed to the analyzer as `metric2legend`).
# The labels become the column names used by the analysis cells below.
METRIC2LEGEND_DICT = {
    f"{SPLIT}_acc": f"{SPLIT} acc",
    f"{SPLIT}_bal_acc": f"{SPLIT} bal acc",
    # correctness labels carry no split prefix, hence plain strings
    f"{SPLIT}_student_correctness": "student correctness",
    f"{SPLIT}_llm_correctness": "llm correctness",
    f"{SPLIT}_acc_kt": f"{SPLIT} acc (KT)",
    f"{SPLIT}_bal_acc_kt": f"{SPLIT} bal acc (KT)",
    f"{SPLIT}_f1_kt": f"{SPLIT} f1 (KT)",
    f"{SPLIT}_prop_invalid": f"{SPLIT} prop invalid",
    f"{SPLIT}_f1_micro": f"{SPLIT} f1 (micro)",
    f"{SPLIT}_f1_macro": f"{SPLIT} f1 (macro)",
    f"{SPLIT}_f1_weighted": f"{SPLIT} f1 (weighted)",
}

In [None]:
# Load the experiment's configs and derive one ID per config; IDs and configs
# are paired positionally.
configs = get_configs_out(EXP_NAME)
config_ids = get_config_ids(configs, problem_type=PROBLEM_TYPE)
config_dict = dict(zip(config_ids, configs))

# Printable legend label per config ID, bundled with the other legend options
# that are forwarded to the analyzer's table printers.
CONFIG2LEGEND_DICT = {cid: create_config_id_print(cid) for cid in config_ids}
legend_kwargs = dict(
    config2legend=CONFIG2LEGEND_DICT,
    legend_exact=LEGEND_EXACT,
    metric2legend=METRIC2LEGEND_DICT,
)

In [None]:
# merge results for all configs
# NOTE(review): the returned run_id_dict is not referenced anywhere else in this notebook
run_id_dict = merge_all_results(EXP_NAME, config_ids)

# Val/Test set performance

In [None]:
# Collect the evaluation results of every config (run_id=None is passed
# straight through to get_results_dict).
results_dict = get_results_dict(
    run_id=None,
    config_ids=config_ids,
    exp_name=EXP_NAME,
)
# # NOTE: print paper-like table with this code
# print_table_from_dict(
#     eval_dict=results_dict,
#     exp_name=EXP_NAME,
#     exclude_metrics=EXCLUDE_METRICS,
#     decimals=3,
#     **legend_kwargs,
# )

In [None]:
# NOTE: print dataframe
df = print_df_from_dict(
    eval_dict=results_dict,
    exp_name=EXP_NAME,
    exclude_metrics=EXCLUDE_METRICS,
    **legend_kwargs,
    # save=True,
    # save_kwargs={"fname": os.path.join("output", EXP_NAME, "results.csv")},
)

df_config = get_config_df(config_dict)

# Per-config means: select the "mean" entries of the second column level, join the
# config hyperparameters, then order columns as hyperparameters first, metrics after.
df_mean = df.xs("mean", axis=1, level=1, drop_level=True)
df_results = df_mean.merge(df_config, how="left", on="config_id")
df_results = df_results.reindex(
    columns=list(df_config.columns)
    + [col for col in df_mean.columns if col not in df_config.columns]
)
df_results

In [None]:
# Per-config standard errors: select the "stderr" entries of the second column
# level, join the config hyperparameters, hyperparameter columns first.
df_stderr = df.xs("stderr", axis=1, level=1, drop_level=True)
df_stderr = df_stderr.merge(df_config, how="left", on="config_id")
df_stderr = df_stderr.reindex(
    columns=list(df_config.columns)
    + [col for col in df_stderr.columns if col not in df_config.columns]
)
df_stderr

## Contextual models

In [None]:
# Runs that used at least one in-context example
df_results_context = df_results.loc[df_results["num_examples"].gt(0)]
# Metrics averaged by the per-config summaries below: KT experiments lead with
# KT balanced accuracy, every other experiment with macro F1.
agg_dict = {
    (f"{SPLIT} bal acc (KT)" if "kt" in EXP_NAME else f"{SPLIT} f1 (macro)"): "mean",
    f"{SPLIT} f1 (KT)": "mean",
    "llm correctness": "mean",
}

In [None]:
# Average the selected metrics per number of in-context examples (contextual runs only)
FEATURE = "num_examples"
(
    df_results_context
    .groupby(by=FEATURE)
    .agg(agg_dict)
    .round(decimals=3)
)

In [None]:
# Average the selected metrics per model (contextual runs only)
FEATURE = "model"
(
    df_results_context
    .groupby(by=FEATURE)
    .agg(agg_dict)
    .round(decimals=3)
)

In [None]:
# Average the selected metrics per sampling temperature (contextual runs only)
FEATURE = "temp"
(
    df_results_context
    .groupby(by=FEATURE)
    .agg(agg_dict)
    .round(decimals=3)
)

In [None]:
# Average the selected metrics per prompt (contextual runs only)
FEATURE = "prompt"
(
    df_results_context
    .groupby(by=FEATURE)
    .agg(agg_dict)
    .round(decimals=3)
)

In [None]:
# Average the selected metrics per (model, prompt) pair; both keys stay as columns
FEATURE = "prompt"
(
    df_results_context
    .groupby(by=["model", FEATURE], as_index=False)
    .agg(agg_dict)
    .round(decimals=3)
)

In [None]:
# Average the selected metrics per example-selection strategy (contextual runs only)
FEATURE = "example_selec"
(
    df_results_context
    .groupby(by=FEATURE)
    .agg(agg_dict)
    .round(decimals=3)
)

In [None]:
# Best contextual config per model: the row with the highest `metric`, which is
# KT balanced accuracy for KT experiments and balanced accuracy otherwise.
metric = f"{SPLIT} bal acc (KT)" if "kt" in EXP_NAME else f"{SPLIT} bal acc"
best_indices = df_results_context.groupby(["model"])[metric].idxmax()
best_configs = df_results_context.loc[best_indices]
best_configs

In [None]:
# Best 3 contextual configs per LLM, ranked by `metric` (set in the previous cell).
# sort + groupby().head() replaces groupby().apply(nlargest): since pandas 2.2,
# apply() on the grouping columns is deprecated, and once they are excluded the
# "model" column would disappear from this table.
(
    df_results_context.sort_values(["model", metric], ascending=[True, False])
    .groupby("model")
    .head(3)
    .reset_index(drop=True)
)

## Non-contextual models

In [None]:
# Runs without in-context examples (num_examples == 0)
df_results_nocontext = df_results[df_results["num_examples"] == 0]
# Same metric choice as the contextual section, so both sets of tables are comparable
if "kt" in EXP_NAME:
    agg_dict = {f"{SPLIT} bal acc (KT)": "mean", f"{SPLIT} f1 (KT)": "mean", "llm correctness": "mean"}
else:
    agg_dict = {f"{SPLIT} f1 (macro)": "mean", f"{SPLIT} f1 (KT)": "mean", "llm correctness": "mean"}

In [None]:
if not df_results_nocontext.empty:
    # Average the selected metrics per model (runs without in-context examples)
    FEATURE = "model"
    # Jupyter only renders a cell's last top-level expression; a bare expression
    # inside this `if` would be discarded, so print the table explicitly.
    print(df_results_nocontext.groupby(FEATURE).agg(agg_dict).round(3))

In [None]:
if not df_results_nocontext.empty:
    # Average the selected metrics per temperature (runs without in-context examples)
    FEATURE = "temp"
    # Jupyter only renders a cell's last top-level expression; a bare expression
    # inside this `if` would be discarded, so print the table explicitly.
    print(df_results_nocontext.groupby(FEATURE).agg(agg_dict).round(3))

In [None]:
if not df_results_nocontext.empty:
    # Average the selected metrics per prompt (runs without in-context examples)
    FEATURE = "prompt"
    # Jupyter only renders a cell's last top-level expression; a bare expression
    # inside this `if` would be discarded, so print the table explicitly.
    print(df_results_nocontext.groupby(FEATURE).agg(agg_dict).round(3))

In [None]:
if not df_results_nocontext.empty:
    # Average the selected metrics per example-selection strategy (no in-context examples)
    FEATURE = "example_selec"
    # Jupyter only renders a cell's last top-level expression; a bare expression
    # inside this `if` would be discarded, so print the table explicitly.
    print(df_results_nocontext.groupby(FEATURE).agg(agg_dict).round(3))

In [None]:
if not df_results_nocontext.empty:
    # Best non-contextual config per (model, prompt) pair, ranked by macro F1
    best_indices = df_results_nocontext.groupby(["model", "prompt"])[f"{SPLIT} f1 (macro)"].idxmax()

    # Get the full rows for these best-performing configs
    best_configs = df_results_nocontext.loc[best_indices]
    # Jupyter only renders a cell's last top-level expression; print explicitly.
    print(best_configs)

# Confusion matrices
## LLM vs student performance

In [None]:
# # single config
# config_id = "qwen3:8b~T_0.0~SO_teacher~SP_replicate_teacher_avocado~EF_quotes~ES_studentid_random3"  # TODO
# if config_id not in config_ids:
#     logger.error(f"Config ID not available", config_id=config_id)
# else:
#     preds_dict = get_llm_student_preds(
#         exp_name=EXP_NAME,
#         config_id=config_id,
#         run_id=1,
#         split="val",
#     )
#     plot_llm_student_confusion(
#         preds_dict,
#         config_id=config_id,
#         normalize="all",
#     )


In [None]:
# # all configs
# for config_id in config_ids:
#     preds_dict = get_llm_student_preds(
#         exp_name=EXP_NAME,
#         config_id=config_id,
#         run_id=1,
#         split="val",
#         problem_type=PROBLEM_TYPE,
#     )
#     plot_llm_student_confusion(
#         preds_dict,
#         config_id=config_id,
#         normalize="all",
#     )

## KT

In [None]:
# # all configs
# for config_id in config_ids:
#     preds_dict = get_llm_student_preds(
#         exp_name=EXP_NAME,
#         config_id=config_id,
#         run_id=1,
#         split="val",
#         problem_type=PROBLEM_TYPE,
#     )
#     plot_kt_confusion(
#         preds_dict,
#         config_id=config_id,
#         normalize="all",
#     )

# Student levels

In [None]:
# # all configs
# for config_id in config_ids:
#     preds_dict = get_llm_student_preds(
#         exp_name=EXP_NAME,
#         config_id=config_id,
#         run_id=1,
#         split="val",
#         problem_type=PROBLEM_TYPE,
#     )
#     plot_level_correctness(
#         preds_dict,
#         problem_type=PROBLEM_TYPE,
#         config_id=config_id,
#     )

# Distributions
## LLM vs student correctness
### Contextual

In [None]:
# LLM correctness per Qwen3 model for contextual runs (qwen3:1.7b left out)
model_family = "qwen3"
plot_llm_correctness_by_model(
    df_results=df_results_context,
    model_family=model_family,
    exclude_models=["qwen3:1.7b"],
)

In [None]:
# LLM correctness per Llama model for contextual runs
model_family = "llama"
plot_llm_correctness_by_model(
    model_family=model_family,
    df_results=df_results_context,
)

In [None]:
LLM_FAMILIES = ["qwen3", "llama"]
EXCLUDE_MODELS = {"qwen3": ["qwen3:1.7b"], "llama": []}

if PRINT_PAPER:
    activate_latex(sans_serif=SANS_SERIF)
    # try/finally: always pair activate_latex() with deactivate_latex(), even if a
    # plot raises, so later cells are not left in the LaTeX setting.
    try:
        # one figure per LLM family
        for llm_family in LLM_FAMILIES:
            fname = os.path.join(
                "output", EXP_NAME, "figures", f"llm_correctness_{SPLIT}_{llm_family}.pdf"
            )
            plot_llm_correctness_by_model(
                df_results=df_results_context,
                model_family=llm_family,
                # families without an entry exclude nothing
                exclude_models=EXCLUDE_MODELS.get(llm_family, []),
                savename=fname,
            )
    finally:
        deactivate_latex()

### No context

In [None]:
# LLM correctness per Qwen3 model for runs without in-context examples;
# skipped when there are none, like the non-contextual tables.
model_family = "qwen3"
if not df_results_nocontext.empty:
    plot_llm_correctness_by_model(
        df_results=df_results_nocontext,
        model_family=model_family,
    )

In [None]:
# LLM correctness per Llama model for runs without in-context examples;
# skipped when there are none, like the non-contextual tables.
model_family = "llama"
if not df_results_nocontext.empty:
    plot_llm_correctness_by_model(
        df_results=df_results_nocontext,
        model_family=model_family,
    )

## Per config value

In [None]:
from typing import Optional
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns
from tools.utils import ensure_dir


In [None]:
def plot_llm_correctness(
    df_results: pd.DataFrame,
    savename: Optional[str] = None,
) -> None:
    """Plot the LLM-correctness density pooled over all configurations.

    Draws a single filled KDE of the ``"llm correctness"`` column, plus a
    dashed vertical line at the student correctness taken from the first row.

    Parameters
    ----------
    df_results : pd.DataFrame
        Results table with at least the columns ``"model"``,
        ``"llm correctness"`` and ``"student correctness"``.
    savename : Optional[str], optional
        Path to save the figure to (its directory is created first), by
        default None, in which case the figure is only shown.

    Raises
    ------
    ValueError
        If ``df_results`` has no rows.
    """
    df = df_results[["model", "llm correctness", "student correctness"]]
    if df.empty:
        raise ValueError("df_results is empty; nothing to plot")

    # assumes student correctness is identical on every row — TODO confirm
    student_correctness_scalar = df["student correctness"].iloc[0]

    _, ax = plt.subplots()

    # density; no `palette` since there is no `hue` (seaborn would ignore it and warn)
    sns.kdeplot(data=df, x="llm correctness", fill=True, ax=ax)

    ax.axvline(student_correctness_scalar, color="black", linestyle="--")

    ax.set_xlabel("LLM answer correctness")
    ax.set_ylabel("Density")
    ax.grid(True, linestyle="--")
    # get ticks in sans-serif if sans-serif is used
    ax.xaxis.get_major_formatter()._usetex = False
    ax.yaxis.get_major_formatter()._usetex = False

    if savename is not None:
        plt.tight_layout()
        ensure_dir(os.path.dirname(savename))
        plt.savefig(savename)
    plt.show()


In [None]:
def plot_correctness_by_config(
    df_results: pd.DataFrame,
    config_value: str,
    savename: Optional[str] = None,
    hyperparamkey2legend: Optional[dict] = None,
    hyperparamvalue2legend: Optional[dict] = None,
) -> None:
    """Plot the LLM-correctness density split by one configuration column.

    Draws one filled KDE of ``"llm correctness"`` per distinct value of
    ``config_value``, plus a dashed vertical line at the student correctness
    taken from the first row.

    Parameters
    ----------
    df_results : pd.DataFrame
        Results table with at least the columns ``"model"``, ``"prompt"``,
        ``"llm correctness"``, ``"student correctness"`` and ``config_value``.
        A ``"prompt_persona"`` column is derived from ``"prompt"``, so it can
        be used as ``config_value`` too.
    config_value : str
        Column whose values define the KDE hue groups.
    savename : Optional[str], optional
        Path to save the figure to (its directory is created first), by
        default None, in which case the figure is only shown.
    hyperparamkey2legend : Optional[dict], optional
        Maps ``config_value`` to the legend title; falls back to the column
        name itself, by default None.
    hyperparamvalue2legend : Optional[dict], optional
        Maps a column name to a ``{raw value: legend label}`` dict. Values
        without a label keep their raw value, by default None.

    Raises
    ------
    ValueError
        If ``df_results`` has no rows.
    """
    df = df_results.copy()
    # persona = the prompt name up to its first underscore
    df["prompt_persona"] = df["prompt"].str.extract(r"^([^_]+)", expand=False)
    # .copy() so relabelling below writes to an independent frame (no SettingWithCopyWarning)
    df = df[["model", config_value, "llm correctness", "student correctness"]].copy()

    # replace raw config values by their legend labels; the `or {}` makes the
    # default None safe, and unmapped values are kept instead of becoming NaN
    value2legend = (hyperparamvalue2legend or {}).get(config_value)
    if value2legend is not None:
        df[config_value] = df[config_value].map(lambda value: value2legend.get(value, value))

    if df.empty:
        raise ValueError("df_results is empty; nothing to plot")
    # assumes student correctness is identical on every row — TODO confirm
    student_correctness_scalar = df["student correctness"].iloc[0]

    _, ax = plt.subplots()

    # density
    sns.kdeplot(data=df, x="llm correctness", hue=config_value, fill=True, ax=ax, palette="tab10")

    ax.axvline(student_correctness_scalar, color="black", linestyle="--")

    ax.set_xlabel("LLM answer correctness")
    ax.set_ylabel("Density")
    ax.grid(True, linestyle="--")
    # update legend title (seaborn may not create a legend, e.g. without hue levels)
    hyperparam_print = (hyperparamkey2legend or {}).get(config_value, config_value)
    legend = ax.get_legend()
    if legend is not None:
        legend.set_title(hyperparam_print)
    # get ticks in sans-serif if sans-serif is used
    ax.xaxis.get_major_formatter()._usetex = False
    ax.yaxis.get_major_formatter()._usetex = False

    if savename is not None:
        plt.tight_layout()
        ensure_dir(os.path.dirname(savename))
        plt.savefig(savename)
    plt.show()


# Legend title for each hyperparameter column
HYPERPARAMKEY2LEGEND_DICT = dict(
    temp="Temperature",
    prompt="Prompt",
    example_selec="Example selection",
    num_examples="Number of examples",
    prompt_persona="Prompt persona",
)
# Legend labels for individual values of selected hyperparameter columns
HYPERPARAMVALUE2LEGEND_DICT = dict(
    prompt_persona=dict(teacher="Teacher", student="Student"),
    example_selec={
        "miscon_studentid_random": "Random",
        "miscon_studentid_kc_exact": "Knowledge Concept",
    },
)


In [None]:
# Contextual LLM-correctness densities, one curve per prompt persona
CONFIG_HUE = "prompt_persona"
plot_correctness_by_config(
    df_results_context,
    CONFIG_HUE,
    hyperparamkey2legend=HYPERPARAMKEY2LEGEND_DICT,
    hyperparamvalue2legend=HYPERPARAMVALUE2LEGEND_DICT,
)

In [None]:
# Contextual LLM-correctness densities, one curve per temperature
CONFIG_HUE = "temp"
plot_correctness_by_config(
    df_results_context,
    CONFIG_HUE,
    hyperparamkey2legend=HYPERPARAMKEY2LEGEND_DICT,
    hyperparamvalue2legend=HYPERPARAMVALUE2LEGEND_DICT,
)

In [None]:
# Contextual LLM-correctness densities, one curve per number of in-context examples
CONFIG_HUE = "num_examples"
plot_correctness_by_config(
    df_results_context,
    CONFIG_HUE,
    hyperparamkey2legend=HYPERPARAMKEY2LEGEND_DICT,
    hyperparamvalue2legend=HYPERPARAMVALUE2LEGEND_DICT,
)

In [None]:
# Contextual LLM-correctness densities, one curve per example-selection strategy
CONFIG_HUE = "example_selec"
plot_correctness_by_config(
    df_results_context,
    CONFIG_HUE,
    hyperparamkey2legend=HYPERPARAMKEY2LEGEND_DICT,
    hyperparamvalue2legend=HYPERPARAMVALUE2LEGEND_DICT,
)

In [None]:
# Contextual LLM-correctness density pooled over every config
plot_llm_correctness(df_results_context)

In [None]:
CONFIG_HUES = ["prompt_persona", "temp", "example_selec", "num_examples"]
HYPERPARAMKEY2LEGEND_DICT = {
    "temp": "Temperature",
    "prompt": "Prompt",
    "example_selec": "Example selection",
    "num_examples": "Number of examples",
    "prompt_persona": "Prompt persona",
}
HYPERPARAMVALUE2LEGEND_DICT = {
    "prompt_persona": {"teacher": "Teacher", "student": "Student"},
    "example_selec": {
        "miscon_studentid_random": "Random",
        "miscon_studentid_kc_exact": "Knowledge Concept",
    },
}

if PRINT_PAPER:
    activate_latex(sans_serif=SANS_SERIF)
    # try/finally: always pair activate_latex() with deactivate_latex(), even if a
    # plot raises, so later cells are not left in the LaTeX setting.
    try:
        # density pooled over all contextual configs
        fname = os.path.join("output", EXP_NAME, "figures", f"llm_correctness_{SPLIT}_all.pdf")
        plot_llm_correctness(
            df_results=df_results_context,
            savename=fname,
        )

        # one figure per config hue
        for config_hue in CONFIG_HUES:
            fname = os.path.join(
                "output", EXP_NAME, "figures", f"llm_correctness_{SPLIT}_{config_hue}.pdf"
            )
            plot_correctness_by_config(
                df_results=df_results_context,
                config_value=config_hue,
                hyperparamkey2legend=HYPERPARAMKEY2LEGEND_DICT,
                hyperparamvalue2legend=HYPERPARAMVALUE2LEGEND_DICT,
                savename=fname,
            )
    finally:
        deactivate_latex()