# Imports

In [None]:
# standard library imports
import os

# related third party imports
import structlog

# local application/library specific imports
from tools.configurator import (
    get_configs_out,
    get_config_ids,
)
from tools.analyzer import (
    print_table_from_dict,
    print_df_from_dict,
    get_results_dict,
    merge_all_results,
    create_config_id_print,
    get_config_df,
    get_llm_student_preds,
)
from tools.plotter import (
    plot_llm_student_confusion,
    plot_kt_confusion,
    plot_level_correctness,
    activate_latex,
    deactivate_latex,
    plot_llm_correctness,
)


# Module-level structlog logger, named after this module.
logger = structlog.get_logger(__name__)

In [None]:
##### INPUTS #####
# Other experiment names, kept commented for quick switching:
#   "replication_misconceptions_20250806-170742"
#   "replication_misconceptions_tryout_20250807-152718"
#   "replication_misconceptions_20250808-171245"
#   "replication_misconceptions_20250811-224100"
#   "replication_snippets_20250812-090413"
#   "replication_misconceptions_20250812-111219"
#   "replication_misconceptions_20250812-210224"
#   "replication_miscon_valsmall_20250813-155505"
#   "replication_miscon_valsmall_20250816-133549"
# Experiment analysed by this notebook.
EXP_NAME = "replication_miscon_test_nocontext_20250817-083834"

# Passed to get_config_ids and (in the commented cells) to the prediction helpers.
PROBLEM_TYPE = "replicate"

# Metrics to leave out when printing results; empty means none are excluded.
EXCLUDE_METRICS = list()
# Forwarded to the printing helpers as `legend_exact`.
LEGEND_EXACT = True

# Paper-figure export: PRINT_PAPER gates the export cell, SANS_SERIF is
# passed to activate_latex.
PRINT_PAPER = False
SANS_SERIF = True

In [None]:
# TODO: update after renaming metrics!
# Commented-out mapping with the other metric key names -- presumably the
# pre-rename keys; confirm before deleting.
# METRIC2LEGEND_DICT = {
#     "val_acc_student_pred": "val acc",
#     "val_acc_true_student": "student correctness",
#     "val_acc_true_pred": "llm correctness",
#     "val_f1_student_pred": "val f1",
#     "val_f1_true_student": "val f1 student -> true",
#     "val_f1_true_pred": "val f1 LLM -> true",
#     "val_acc_kt": "val acc (KT)",
#     "val_f1_kt": "val f1 (KT)",
#     "val_prop_invalid": "val prop invalid",
# }

# Metric key -> label shown in tables and legends. The labels double as
# column names further down (e.g. in agg_dict), so keep them in sync.
METRIC2LEGEND_DICT = {
    # accuracy and correctness
    "val_acc": "val acc",
    "val_student_correctness": "student correctness",
    "val_llm_correctness": "llm correctness",
    # KT metrics
    "val_acc_kt": "val acc (KT)",
    "val_f1_kt": "val f1 (KT)",
    # proportion of invalid outputs
    "val_prop_invalid": "val prop invalid",
    # F1 averaging variants
    "val_f1_micro": "val f1 (micro)",
    "val_f1_macro": "val f1 (macro)",
    "val_f1_weighted": "val f1 (weighted)",
}

In [None]:
# Fetch the experiment's configs and one identifier per config.
configs = get_configs_out(EXP_NAME)
config_ids = get_config_ids(configs, problem_type=PROBLEM_TYPE)
# Pair identifiers with configs by position.
config_dict = dict(zip(config_ids, configs))

# Printable label for every config identifier.
CONFIG2LEGEND_DICT = {cid: create_config_id_print(cid) for cid in config_ids}

# Legend options unpacked into the printing helpers.
legend_kwargs = {
    "config2legend": CONFIG2LEGEND_DICT,
    "legend_exact": LEGEND_EXACT,
    "metric2legend": METRIC2LEGEND_DICT,
}

In [None]:
# merge results for all configs
# NOTE(review): run_id_dict is not referenced again in the visible cells --
# presumably this call is needed for its merging side effect; confirm.
run_id_dict = merge_all_results(EXP_NAME, config_ids)

# Val/Test set performance

In [None]:
# Results for every config of the experiment. run_id=None is passed
# explicitly (NOTE(review): presumably selects the merged results -- confirm).
results_dict = get_results_dict(exp_name=EXP_NAME, config_ids=config_ids, run_id=None)

# # NOTE: print paper-like table with this code
# print_table_from_dict(
#     eval_dict=results_dict,
#     exp_name=EXP_NAME,
#     exclude_metrics=EXCLUDE_METRICS,
#     decimals=3,
#     **legend_kwargs,
# )

In [None]:
# NOTE: print dataframe
df = print_df_from_dict(
    eval_dict=results_dict,
    exp_name=EXP_NAME,
    exclude_metrics=EXCLUDE_METRICS,
    **legend_kwargs,
    # save=True,
    # save_kwargs={"fname": os.path.join("output", EXP_NAME, "results.csv")},
)

df_config = get_config_df(config_dict)

# mean: take the "mean" entries from the second column level, attach the
# config columns, then order the columns config-first, metrics after.
df_mean = df.xs("mean", axis=1, level=1, drop_level=True)
df_results = df_mean.merge(df_config, how="left", on="config_id").reindex(
    columns=[
        *df_config.columns,
        *(col for col in df_mean.columns if col not in df_config.columns),
    ]
)
df_results

In [None]:
# standard error: take the "stderr" entries from the second column level and
# attach the config columns.
df_stderr = df.xs("stderr", axis=1, level=1, drop_level=True).merge(
    df_config, how="left", on="config_id"
)
# Order the columns config-first; the remaining columns are read from the
# merged frame.
df_stderr = df_stderr.reindex(
    columns=[
        *df_config.columns,
        *(col for col in df_stderr.columns if col not in df_config.columns),
    ]
)
df_stderr

In [None]:
# Keep only the rows whose config uses at least one example.
has_examples = df_results["num_examples"] > 0
df_results_context = df_results[has_examples]

# Metrics summarised in the group-by cells below, each aggregated by its mean.
agg_dict = dict.fromkeys(
    ["val f1 (macro)", "val f1 (KT)", "llm correctness"],
    "mean",
)

In [None]:
# Mean of the metrics in agg_dict for each value of one config column.
# df_results_context is already limited to rows with num_examples > 0.
FEATURE = "num_examples"
(
    df_results_context
    .groupby(FEATURE)
    .agg(agg_dict)
    .round(3)
)

In [None]:
# Mean of the metrics in agg_dict per model.
FEATURE = "model"
(
    df_results_context
    .groupby(FEATURE)
    .agg(agg_dict)
    .round(3)
)

In [None]:
# Mean of the metrics in agg_dict per value of the "temp" column.
FEATURE = "temp"
(
    df_results_context
    .groupby(FEATURE)
    .agg(agg_dict)
    .round(3)
)

In [None]:
# Mean of the metrics in agg_dict per prompt.
FEATURE = "prompt"
(
    df_results_context
    .groupby(FEATURE)
    .agg(agg_dict)
    .round(3)
)

In [None]:
# Mean of the metrics in agg_dict per value of the "example_selec" column.
FEATURE = "example_selec"
(
    df_results_context
    .groupby(FEATURE)
    .agg(agg_dict)
    .round(3)
)

# Confusion matrices
## LLM vs student performance

In [None]:
# # single config
# config_id = "qwen3:8b~T_0.0~SO_teacher~SP_replicate_teacher_avocado~EF_quotes~ES_studentid_random3"  # TODO
# if config_id not in config_ids:
#     logger.error(f"Config ID not available", config_id=config_id)
# else:
#     preds_dict = get_llm_student_preds(
#         exp_name=EXP_NAME,
#         config_id=config_id,
#         run_id=1,
#         split="val",
#     )
#     plot_llm_student_confusion(
#         preds_dict,
#         config_id=config_id,
#         normalize="all",
#     )


In [None]:
# # all configs
# for config_id in config_ids:
#     preds_dict = get_llm_student_preds(
#         exp_name=EXP_NAME,
#         config_id=config_id,
#         run_id=1,
#         split="val",
#         problem_type=PROBLEM_TYPE,
#     )
#     plot_llm_student_confusion(
#         preds_dict,
#         config_id=config_id,
#         normalize="all",
#     )

## KT

In [None]:
# # all configs
# for config_id in config_ids:
#     preds_dict = get_llm_student_preds(
#         exp_name=EXP_NAME,
#         config_id=config_id,
#         run_id=1,
#         split="val",
#         problem_type=PROBLEM_TYPE,
#     )
#     plot_kt_confusion(
#         preds_dict,
#         config_id=config_id,
#         normalize="all",
#     )

# Student levels

In [None]:
# # all configs
# for config_id in config_ids:
#     preds_dict = get_llm_student_preds(
#         exp_name=EXP_NAME,
#         config_id=config_id,
#         run_id=1,
#         split="val",
#         problem_type=PROBLEM_TYPE,
#     )
#     plot_level_correctness(
#         preds_dict,
#         problem_type=PROBLEM_TYPE,
#         config_id=config_id,
#     )

## LLM vs student correctness

In [None]:
# LLM correctness plot for the "qwen3" model family (shown inline, not saved).
model_family = "qwen3"
plot_llm_correctness(df_results=df_results, model_family=model_family)

In [None]:
# LLM correctness plot for the "llama" model family (shown inline, not saved).
model_family = "llama"
plot_llm_correctness(df_results=df_results, model_family=model_family)

In [None]:
# Model families for which a paper figure is exported.
LLM_FAMILIES = ["qwen3", "llama"]

if PRINT_PAPER:
    # All figures of this experiment are written to the same directory.
    figures_dir = os.path.join("output", EXP_NAME, "figures")

    activate_latex(sans_serif=SANS_SERIF)
    try:
        # one PDF per LLM family
        for llm_family in LLM_FAMILIES:
            fname = os.path.join(figures_dir, f"llm_correctness_{llm_family}.pdf")
            plot_llm_correctness(
                df_results=df_results,
                model_family=llm_family,
                savename=fname,
            )
    finally:
        # Always undo the activate_latex() setup, even when a plot raises,
        # so it is not left switched on for the cells that run afterwards.
        deactivate_latex()