# Imports

In [None]:
# standard library imports
# /

# related third party imports
import structlog

# local application/library specific imports
from tools.configurator import (
    get_configs_out,
    get_config_ids,
)
from tools.analyzer import (
    print_table_from_dict,
    print_df_from_dict,
    get_results_dict,
    merge_all_results,
    create_config_id_print,
    get_config_df,
    get_llm_student_preds,
)
from tools.plotter import (
    plot_llm_student_confusion,
    plot_kt_confusion,
    plot_level_correctness,
)


# structlog logger for this notebook's cells
logger = structlog.get_logger(__name__)

In [None]:
##### INPUTS #####
# Experiment whose outputs are analysed in this notebook. Earlier experiments,
# kept for quick switching (paste one of them into EXP_NAME):
#   replication_gemini_20250628-091441
#   replication_sonnet_20250629-084225
#   replication_ollama_gemma_20250628-142438
#   replication_ollama_qwen_20250627-164511
#   replication_o3_o4_20250628-143855
#   replication_cloud_others_20250624-184518
#   replication_ollama_20250624-210701
#   replication_cloud_20250624-155854
#   replication_ollama_20250624-133329
#   replication_misconceptions_20250806-170742
#   replication_misconceptions_tryout_20250807-152718
#   replication_misconceptions_20250808-171245
#   replication_misconceptions_20250811-224100
#   replication_snippets_20250812-090413
EXP_NAME = "replication_misconceptions_20250812-111219"

# Metrics forwarded as `exclude_metrics` to the results printers below.
EXCLUDE_METRICS = [
    # accuracy variants
    "val_acc_true_student",
    "val_acc_true_pred",
    # F1 variants
    "val_f1_true_student",
    "val_f1_true_pred",
    "val_f1_student_pred",
    "val_f1_kt",
]

# Forwarded as `legend_exact` (via legend_kwargs) to the analyzer helpers.
LEGEND_EXACT = True
# Forwarded to get_config_ids / get_llm_student_preds / plot_level_correctness.
PROBLEM_TYPE = "replicate"

In [None]:
# Display name for each metric column, forwarded to the analyzer as
# `metric2legend`. Entry order matches the original listing.
METRIC2LEGEND_DICT = dict(
    val_acc_student_pred="val acc",
    val_acc_true_student="val acc student -> true",
    val_acc_true_pred="val acc LLM -> true",
    val_f1_student_pred="val f1 LLM -> student",
    val_f1_true_student="val f1 student -> true",
    val_f1_true_pred="val f1 LLM -> true",
    val_acc_kt="val acc (KT)",
    val_f1_kt="val f1 LLM -> student (KT)",
    val_prop_invalid="val prop invalid",
)

In [None]:
configs = get_configs_out(EXP_NAME)
config_ids = get_config_ids(configs, problem_type=PROBLEM_TYPE)
# pair each config id with the config at the same position
config_dict = dict(zip(config_ids, configs))

# printable legend label for every config id
CONFIG2LEGEND_DICT = dict(zip(config_ids, map(create_config_id_print, config_ids)))

# shared legend options passed as **legend_kwargs to the analyzer helpers
legend_kwargs = dict(
    config2legend=CONFIG2LEGEND_DICT,
    legend_exact=LEGEND_EXACT,
    metric2legend=METRIC2LEGEND_DICT,
)

In [None]:
# merge results for all configs
# NOTE(review): run_id_dict is not referenced elsewhere in this notebook;
# presumably merge_all_results also has side effects (e.g. writing merged
# results) that get_results_dict relies on -- TODO confirm in tools.analyzer.
run_id_dict = merge_all_results(EXP_NAME, config_ids)

# Val/Test set performance

In [None]:
# Collect evaluation results for every config of the experiment.
# run_id=None is passed straight to the analyzer (presumably "all runs";
# verify in tools.analyzer).
results_dict = get_results_dict(exp_name=EXP_NAME, config_ids=config_ids, run_id=None)

# NOTE: for a paper-like table, uncomment:
# print_table_from_dict(
#     eval_dict=results_dict, exp_name=EXP_NAME,
#     exclude_metrics=EXCLUDE_METRICS, decimals=3, **legend_kwargs,
# )

In [None]:
# NOTE: print dataframe
# To also save it, add these arguments to the call below:
#     save=True,
#     save_kwargs={"fname": os.path.join("output", EXP_NAME, "results.csv")},
# This needs `import os`, which the imports cell currently does not have.
df = print_df_from_dict(
    eval_dict=results_dict,
    exp_name=EXP_NAME,
    exclude_metrics=EXCLUDE_METRICS,
    **legend_kwargs,
)

df_config = get_config_df(config_dict)

# mean: `df` has two column levels; level 1 holds the statistic
# ("mean" here, "stderr" in the next cell)
df_mean = df.xs("mean", axis=1, level=1, drop_level=True)
df_results = df_mean.merge(df_config, how="left", on="config_id")
# config columns first, followed by the metric columns in their original order
df_results = df_results.reindex(
    columns=list(df_config.columns)
    + [col for col in df_mean.columns if col not in df_config.columns]
)
df_results

In [None]:
# standard error: same column layout as the mean table (config columns first),
# built from the "stderr" statistic on column level 1
df_stderr = df.xs("stderr", axis=1, level=1, drop_level=True).merge(
    df_config, how="left", on="config_id"
)
_stderr_metric_cols = [c for c in df_stderr.columns if c not in df_config.columns]
df_stderr = df_stderr.reindex(columns=list(df_config.columns) + _stderr_metric_cols)
df_stderr

In [None]:
# inspect average performance per config value
# Metrics absent from df_results (possible when EXP_NAME points to another
# experiment) are skipped with a warning instead of raising a KeyError.
FEATURE = "num_examples"
AGG_METRICS = ["val acc", "val acc (KT)", "val_monotonicity"]
_missing = [m for m in AGG_METRICS if m not in df_results.columns]
if _missing:
    logger.warning("Metrics not available, skipping", metrics=_missing, feature=FEATURE)
df_results.groupby(FEATURE)[[m for m in AGG_METRICS if m not in _missing]].mean().round(3)

In [None]:
# inspect average performance per config value
# Metrics absent from df_results (possible when EXP_NAME points to another
# experiment) are skipped with a warning instead of raising a KeyError.
FEATURE = "model"
AGG_METRICS = ["val acc", "val acc (KT)", "val_monotonicity"]
_missing = [m for m in AGG_METRICS if m not in df_results.columns]
if _missing:
    logger.warning("Metrics not available, skipping", metrics=_missing, feature=FEATURE)
df_results.groupby(FEATURE)[[m for m in AGG_METRICS if m not in _missing]].mean().round(3)

In [None]:
# inspect average performance per config value
# Metrics absent from df_results (possible when EXP_NAME points to another
# experiment) are skipped with a warning instead of raising a KeyError.
FEATURE = "temp"
AGG_METRICS = ["val acc", "val acc (KT)", "val_monotonicity"]
_missing = [m for m in AGG_METRICS if m not in df_results.columns]
if _missing:
    logger.warning("Metrics not available, skipping", metrics=_missing, feature=FEATURE)
df_results.groupby(FEATURE)[[m for m in AGG_METRICS if m not in _missing]].mean().round(3)

In [None]:
# inspect average performance per config value
# Metrics absent from df_results (possible when EXP_NAME points to another
# experiment) are skipped with a warning instead of raising a KeyError.
FEATURE = "prompt"
AGG_METRICS = ["val acc", "val acc (KT)", "val_monotonicity"]
_missing = [m for m in AGG_METRICS if m not in df_results.columns]
if _missing:
    logger.warning("Metrics not available, skipping", metrics=_missing, feature=FEATURE)
df_results.groupby(FEATURE)[[m for m in AGG_METRICS if m not in _missing]].mean().round(3)

In [None]:
# inspect average performance per config value
# Metrics absent from df_results (possible when EXP_NAME points to another
# experiment) are skipped with a warning instead of raising a KeyError.
FEATURE = "example_selec"
AGG_METRICS = ["val acc", "val acc (KT)", "val_monotonicity"]
_missing = [m for m in AGG_METRICS if m not in df_results.columns]
if _missing:
    logger.warning("Metrics not available, skipping", metrics=_missing, feature=FEATURE)
df_results.groupby(FEATURE)[[m for m in AGG_METRICS if m not in _missing]].mean().round(3)

# Confusion matrices
## LLM vs student performance

In [None]:
# single config: set SINGLE_CONFIG_ID to one of `config_ids` to plot only that
# config, e.g.
# "qwen3:8b~T_0.0~SO_teacher~SP_replicate_teacher_avocado~EF_quotes~ES_studentid_random3"
# While it is None, this cell does nothing.
SINGLE_CONFIG_ID = None  # TODO
if SINGLE_CONFIG_ID is not None:
    if SINGLE_CONFIG_ID not in config_ids:
        logger.error("Config ID not available", config_id=SINGLE_CONFIG_ID)
    else:
        preds_dict = get_llm_student_preds(
            exp_name=EXP_NAME,
            config_id=SINGLE_CONFIG_ID,
            run_id=1,
            split="val",
            # passed like in the all-configs cells (was missing here)
            problem_type=PROBLEM_TYPE,
        )
        plot_llm_student_confusion(
            preds_dict,
            config_id=SINGLE_CONFIG_ID,
            normalize="all",
        )


In [None]:
# all configs: LLM-vs-student confusion matrix per config (run 1, val split)
_preds_kwargs = dict(exp_name=EXP_NAME, run_id=1, split="val", problem_type=PROBLEM_TYPE)
for config_id in config_ids:
    preds_dict = get_llm_student_preds(config_id=config_id, **_preds_kwargs)
    plot_llm_student_confusion(preds_dict, config_id=config_id, normalize="all")

## KT

In [None]:
# all configs: KT confusion matrix per config (run 1, val split)
_preds_kwargs = dict(exp_name=EXP_NAME, run_id=1, split="val", problem_type=PROBLEM_TYPE)
for config_id in config_ids:
    preds_dict = get_llm_student_preds(config_id=config_id, **_preds_kwargs)
    plot_kt_confusion(preds_dict, config_id=config_id, normalize="all")

# Student levels

In [None]:
# all configs: correctness per student level for each config (run 1, val split)
_preds_kwargs = dict(exp_name=EXP_NAME, run_id=1, split="val", problem_type=PROBLEM_TYPE)
for config_id in config_ids:
    preds_dict = get_llm_student_preds(config_id=config_id, **_preds_kwargs)
    plot_level_correctness(preds_dict, problem_type=PROBLEM_TYPE, config_id=config_id)