# Imports

In [None]:
# standard library imports
# /

# related third party imports
import structlog

# local application/library specific imports
from tools.configurator import (
    get_configs_out,
    get_config_ids,
)
from tools.analyzer import (
    print_table_from_dict,
    print_df_from_dict,
    get_results_dict,
    merge_all_results,
    create_config_id_print,
    get_config_df,
    check_overlap,
)
# from tools.plotter import plot_level_correctness


# module-level structlog logger; in the visible cells it is only referenced by
# commented-out code
logger = structlog.get_logger(__name__)

In [None]:
##### INPUTS #####
# Experiment to analyse: keep exactly one name uncommented.
EXP_NAME = (
    # context
    # "roleplay_miscon_test_kt_20250820-204306"

    # no context
    # "roleplay_miscon_test_kt_nocontext_20250821-075331"

    # merged
    # "roleplay_miscon_test_kt_merged_20250821"

    # CFE
    "roleplay_cupacfe_val_20250919-112843"
)
# Split inferred from the experiment name: any occurrence of "val" selects the
# validation split, otherwise the test split is assumed.
SPLIT = "val" if "val" in EXP_NAME else "test"
# Metrics dropped from the result tables. They are prefixed with SPLIT (like
# every other metric key in this notebook) so the exclusion also applies to
# test-split experiments instead of only to validation ones.
EXCLUDE_METRICS = [
    f"{SPLIT}_acc_true_pred",
    f"{SPLIT}_f1_true_pred",
]
LEGEND_EXACT = True
PROBLEM_TYPE = "roleplay"

In [None]:
# Display label for each split-prefixed metric key, in table/legend order.
METRIC2LEGEND_DICT = dict(
    (
        (f"{SPLIT}_rmse", f"{SPLIT} RMSE"),
        (f"{SPLIT}_mae", f"{SPLIT} MAE"),
        (f"{SPLIT}_llm_correctness", "llm correctness"),
        (f"{SPLIT}_monotonicity", f"{SPLIT} monotonicity"),
        (f"{SPLIT}_prop_invalid", f"{SPLIT} prop invalid"),
        (f"{SPLIT}_distractor_alignment", f"{SPLIT} distr alignment"),
    )
)

In [None]:
# Load the experiment's configs and derive one id per config (same order).
configs = get_configs_out(EXP_NAME)
config_ids = get_config_ids(configs, problem_type=PROBLEM_TYPE)
# config id -> config object
config_dict = dict(zip(config_ids, configs))

# config id -> human-readable legend label
CONFIG2LEGEND_DICT = {cid: create_config_id_print(cid) for cid in config_ids}
# shared keyword arguments for the table/plot helpers below
legend_kwargs = dict(
    config2legend=CONFIG2LEGEND_DICT,
    legend_exact=LEGEND_EXACT,
    metric2legend=METRIC2LEGEND_DICT,
)

In [None]:
# merge results for all configs
# NOTE(review): run_id_dict is not read by any visible cell; the results below
# are loaded with run_id=None, presumably the merged output written here — confirm
run_id_dict = merge_all_results(EXP_NAME, config_ids)

# Val/Test set performance
## Complete table

In [None]:
# Evaluation results for every config of the experiment.
# NOTE(review): run_id=None presumably selects the merged results rather than a
# single run — confirm in tools.analyzer.get_results_dict
results_dict = get_results_dict(
    exp_name=EXP_NAME,
    config_ids=config_ids,
    run_id=None,
)
# # NOTE: print paper-like table with this code
# print_table_from_dict(
#     eval_dict=results_dict,
#     exp_name=EXP_NAME,
#     exclude_metrics=EXCLUDE_METRICS,
#     decimals=3,
#     **legend_kwargs,
# )

In [None]:
# NOTE: print dataframe (columns are (metric, statistic) pairs)
df = print_df_from_dict(
    eval_dict=results_dict,
    exp_name=EXP_NAME,
    exclude_metrics=EXCLUDE_METRICS,
    **legend_kwargs,
    # save=True,
    # save_kwargs={"fname": os.path.join("output", EXP_NAME, "results.csv")},
)

# hyperparameters of every config
df_config = get_config_df(config_dict)

# mean of each metric, joined with the config columns; config columns first,
# followed by the metric columns in their original order
df_mean = df.xs("mean", axis=1, level=1, drop_level=True)
df_results = df_mean.merge(df_config, how="left", on="config_id")
df_results = df_results.reindex(
    columns=[
        *df_config.columns,
        *(col for col in df_mean.columns if col not in df_config.columns),
    ]
)
df_results

In [None]:
# standard error of each metric, joined with the config columns (config
# columns first, then the remaining columns of the merged frame)
df_stderr = df.xs("stderr", axis=1, level=1, drop_level=True).merge(
    df_config, how="left", on="config_id"
)
df_stderr = df_stderr.reindex(
    columns=[
        *df_config.columns,
        *(col for col in df_stderr.columns if col not in df_config.columns),
    ]
)
# df_stderr

## Contextual models

In [None]:
if "kt" in EXP_NAME:
    agg_dict = {
        f"{SPLIT} RMSE": "mean",
        f"{SPLIT} MAE": "mean",
        f"{SPLIT} monotonicity": "mean",
        "llm correctness": "mean",
    }
else:
    agg_dict = {
        f"{SPLIT} RMSE": "mean",
        f"{SPLIT} MAE": "mean",
        f"{SPLIT} monotonicity": "mean",
        # f"{SPLIT} distr alignment": "mean",
        "llm correctness": "mean",
    }

In [None]:
# configs with at least one example (num_examples > 0), treated as contextual
df_results_context = df_results.loc[df_results["num_examples"] > 0]

In [None]:
# mean of each agg_dict metric per value of FEATURE (contextual configs), 3 decimals
FEATURE = "context"
df_results_context.groupby(FEATURE).agg(agg_dict).round(3)

In [None]:
# mean of each agg_dict metric per value of FEATURE (contextual configs), 3 decimals
FEATURE = "num_examples"
df_results_context.groupby(FEATURE).agg(agg_dict).round(3)

In [None]:
# mean of each agg_dict metric per value of FEATURE (contextual configs), 3 decimals
FEATURE = "model"
df_results_context.groupby(FEATURE).agg(agg_dict).round(3)

In [None]:
# mean of each agg_dict metric per value of FEATURE (contextual configs), 3 decimals
FEATURE = "temp"
df_results_context.groupby(FEATURE).agg(agg_dict).round(3)

In [None]:
# mean of each agg_dict metric per value of FEATURE (contextual configs), 3 decimals
FEATURE = "prompt"
df_results_context.groupby(FEATURE).agg(agg_dict).round(3)

In [None]:
# mean of each agg_dict metric per value of FEATURE (contextual configs), 3 decimals
FEATURE = "example_selec"
df_results_context.groupby(FEATURE).agg(agg_dict).round(3)

In [None]:
# mean of each agg_dict metric per value of FEATURE (contextual configs), 3 decimals
FEATURE = "struc_output"
df_results_context.groupby(FEATURE).agg(agg_dict).round(3)

In [None]:
# Full rows of the best contextual config per model (lowest RMSE).
metric = f"{SPLIT} RMSE"

# Drop rows without a value for the metric first. Otherwise a model whose
# metric is entirely NaN makes idxmin warn and return NaN, or raise (depending
# on the pandas version), and .loc cannot look that label up.
best_indices = (
    df_results_context.dropna(subset=[metric])
    .groupby(["model"])[metric]
    .idxmin()  # NOTE: min because RMSE (lower is better)
)
best_configs = df_results_context.loc[best_indices]
best_configs

## Non-contextual models

In [None]:
# configs without examples (num_examples == 0), treated as non-contextual
df_results_nocontext = df_results.loc[df_results["num_examples"] == 0]

In [None]:
# show all non-contextual configs (num_examples == 0)
df_results_nocontext

In [None]:
# skipped when there are no non-contextual configs
if not df_results_nocontext.empty:
    # mean of each agg_dict metric per value of FEATURE, 3 decimals; display()
    # is needed because an expression inside an if-block is not auto-rendered
    FEATURE = "context"
    display(df_results_nocontext.groupby(FEATURE).agg(agg_dict).round(3))

In [None]:
# skipped when there are no non-contextual configs
if not df_results_nocontext.empty:
    # mean of each agg_dict metric per value of FEATURE, 3 decimals; display()
    # is needed because an expression inside an if-block is not auto-rendered
    FEATURE = "num_examples"
    display(df_results_nocontext.groupby(FEATURE).agg(agg_dict).round(3))

In [None]:
# skipped when there are no non-contextual configs
if not df_results_nocontext.empty:
    # mean of each agg_dict metric per value of FEATURE, 3 decimals; display()
    # is needed because an expression inside an if-block is not auto-rendered
    FEATURE = "model"
    display(df_results_nocontext.groupby(FEATURE).agg(agg_dict).round(3))

In [None]:
# skipped when there are no non-contextual configs
if not df_results_nocontext.empty:
    # mean of each agg_dict metric per value of FEATURE, 3 decimals; display()
    # is needed because an expression inside an if-block is not auto-rendered
    FEATURE = "temp"
    display(df_results_nocontext.groupby(FEATURE).agg(agg_dict).round(3))

In [None]:
# skipped when there are no non-contextual configs
if not df_results_nocontext.empty:
    # mean of each agg_dict metric per value of FEATURE, 3 decimals; display()
    # is needed because an expression inside an if-block is not auto-rendered
    FEATURE = "prompt"
    display(df_results_nocontext.groupby(FEATURE).agg(agg_dict).round(3))

In [None]:
# skipped when there are no non-contextual configs
if not df_results_nocontext.empty:
    # mean of each agg_dict metric per value of FEATURE, 3 decimals; display()
    # is needed because an expression inside an if-block is not auto-rendered
    FEATURE = "example_selec"
    display(df_results_nocontext.groupby(FEATURE).agg(agg_dict).round(3))

In [None]:
# skipped when there are no non-contextual configs
if not df_results_nocontext.empty:
    # mean of each agg_dict metric per value of FEATURE, 3 decimals; display()
    # is needed because an expression inside an if-block is not auto-rendered
    FEATURE = "struc_output"
    display(df_results_nocontext.groupby(FEATURE).agg(agg_dict).round(3))

In [None]:
# Full rows of the best non-contextual config per model (lowest RMSE).
metric = f"{SPLIT} RMSE"

# Drop rows without a value for the metric first. Otherwise a model whose
# metric is entirely NaN makes idxmin warn and return NaN, or raise (depending
# on the pandas version), and .loc cannot look that label up.
best_indices = (
    df_results_nocontext.dropna(subset=[metric])
    .groupby(["model"])[metric]
    .idxmin()  # NOTE: min because RMSE (lower is better)
)
best_configs = df_results_nocontext.loc[best_indices]
best_configs

## Contextual & Non-contextual models

In [None]:
# Mean/stderr of monotonicity and RMSE per config, joined with the config
# columns and pre-formatted as LaTeX cells "mean \gray{$\pm$ stderr}".
# Each (metric, statistic) column is selected explicitly instead of relying on
# the positional order of the second column level.
df_metric = df.loc[
    :,
    [
        (f"{SPLIT} monotonicity", "mean"),
        (f"{SPLIT} monotonicity", "stderr"),
        (f"{SPLIT} RMSE", "mean"),
        (f"{SPLIT} RMSE", "stderr"),
    ],
]
df_metric.columns = ["mean_monotonicity", "stderr_monotonicity", "mean_rmse", "stderr_rmse"]
df_metric = df_metric.merge(df_config, how="left", on="config_id")
df_metric = df_metric.reindex(
    columns=(
        list(df_config.columns)
        + [a for a in df_metric.columns if a not in df_config.columns]
    )
)
# Raw f-strings: "\g" and "\p" are invalid escape sequences in normal string
# literals (SyntaxWarning on Python 3.12+); the resulting text is unchanged.
df_metric["performance monotonicity"] = df_metric.apply(
    lambda x: rf"{x['mean_monotonicity']:.3f} \gray{{$\pm$ {x['stderr_monotonicity']:.3f}}}",
    axis=1,
)
df_metric["performance rmse"] = df_metric.apply(
    lambda x: rf"{x['mean_rmse']:.3f} \gray{{$\pm$ {x['stderr_rmse']:.3f}}}",
    axis=1,
)
# extract model family and size, e.g. "qwen3:8b" -> family "qwen3", size 8.0
df_metric["family"] = df_metric["model"].str.extract(r"^(.*?):")[0]
df_metric["size"] = (
    df_metric["model"].str.extract(r":(\d+\.?\d*)b$")[0].astype(float).round(1)
)
# create new column to group on having context or not (derived from the prompt name)
context_map = {True: "Context", False: "No context"}
df_metric["context"] = df_metric["prompt"].str.contains("_context").map(context_map)

df_metric


In [None]:
# Paper table: rows ordered by family, size and context setting; only the
# identifying columns and the formatted performance cells are kept, and the
# size gets a " B" suffix.
df_clean = (
    df_metric.sort_values(by=["family", "size", "context"], ascending=True)
    [["family", "size", "context", "performance monotonicity", "performance rmse"]]
    .assign(size=lambda frame: frame["size"].astype(str) + " B")
)

df_clean

In [None]:
# LaTeX source of the paper table. escape=False keeps the "\gray{...}" and
# "$\pm$" markup in the performance cells intact; older pandas versions
# escaped LaTeX special characters by default.
print(df_clean.to_latex(index=False, escape=False))

Check for overlap between the results (to determine which entries should be bold in the table)

In [None]:
# Monotonicity: sort configs by mean (descending), then let check_overlap add
# interval bounds and overlap flags, presumably between adjacent rows — see
# tools.analyzer.check_overlap.
df_sorted = df_metric.sort_values(by=["mean_monotonicity"], ascending=False)
df_with_overlap = check_overlap(df_sorted, "mean_monotonicity", "stderr_monotonicity")

# Show relevant columns
df_with_overlap[
    [
        "family",
        "size",
        "context",
        "mean_monotonicity",
        "stderr_monotonicity",
        "lower_bound",
        "upper_bound",
        "overlap_with_prev",
        "overlap_with_next",
    ]
]

In [None]:
# RMSE: sort configs by mean RMSE (ascending), then let check_overlap add
# interval bounds and overlap flags, presumably between adjacent rows — see
# tools.analyzer.check_overlap.
df_sorted = df_metric.sort_values(by=["mean_rmse"], ascending=True)
df_with_overlap = check_overlap(df_sorted, 'mean_rmse', 'stderr_rmse')

# Show relevant columns
df_with_overlap[['family', 'size', 'context', 'mean_rmse', 'stderr_rmse',
                'lower_bound', 'upper_bound', 'overlap_with_prev', 'overlap_with_next']]

## Additional: LLM question answering

In [None]:
# for config_id in config_ids:
#     logger.info(f"Plotting student level performance", config_id=config_id)
#     plot_student_level_performance(
#         exp_name=EXP_NAME,
#         config_id=config_id,
#         metric="val_accuracy",
#         **legend_kwargs,
#         save=False,
#     )

# Student levels

In [None]:
from typing import Optional
import os
import pandas as pd
import matplotlib.pyplot as plt
from tools.utils import ensure_dir


def _plot_level_correctness_roleplay(
    df_results: pd.DataFrame,
    config_id: Optional[str] = None,
    save: bool = False,
    savefig_kwargs: Optional[dict] = None,
):
    """Plot LLM correctness per student level group for one configuration.

    Parameters
    ----------
    df_results : pd.DataFrame
        Results table with a ``config_id`` column and a column whose name
        matches ``.*_llm_group_correctness``. The cell of that column holds a
        sequence of correctness values, presumably one per student level group.
    config_id : str, optional
        Configuration to plot. If None, the config of the first row of
        ``df_results`` is used. Also used as the plot title when ``save`` is
        False.
    save : bool, optional
        Whether to save the plot, by default False
    savefig_kwargs : Optional[dict], optional
        Keyword arguments for ``plt.savefig``; must contain ``"fname"`` when
        ``save`` is True. By default None.

    Raises
    ------
    ValueError
        If ``save`` is True and ``savefig_kwargs`` has no ``"fname"``, or if
        ``config_id`` does not occur in ``df_results``.
    """
    if save and not (savefig_kwargs and "fname" in savefig_kwargs):
        raise ValueError("savefig_kwargs with a 'fname' entry is required when save=True")
    if config_id is None:
        config_id = df_results["config_id"].iloc[0]

    # rows of the requested config (previously always the first global config)
    config_rows = df_results[df_results["config_id"] == config_id]
    if config_rows.empty:
        raise ValueError(f"config_id {config_id!r} not found in df_results")
    # first matching column, first matching row
    llm_group_correctness = (
        config_rows.filter(regex=(".*_llm_group_correctness")).iloc[0, 0]
    )
    df_llm = pd.DataFrame({
        "student_level_group": range(1, len(llm_group_correctness) + 1),
        "llm_correct": llm_group_correctness
    }).set_index("student_level_group")
    print(df_llm)

    _, ax = plt.subplots()
    # Plot the Series rather than the DataFrame: DataFrame.plot ignores `label`
    # and uses the column name in the legend, Series.plot honours it.
    df_llm["llm_correct"].plot(kind="line", ax=ax, label="LLM")
    ax.set(
        xlabel="Student levels",
        ylabel="MCQ correctness",
    )
    ax.set_ylim(-0.05, 1.05)
    ax.set_title((None if save else config_id), fontsize=9)
    ax.legend(loc="upper left", fontsize=9)
    ax.grid(True, linestyle="--", alpha=0.7)
    if save:
        plt.tight_layout()
        out_dir = os.path.dirname(savefig_kwargs["fname"])
        # a bare file name has no directory to create (dirname returns "")
        if out_dir:
            ensure_dir(out_dir)
        plt.savefig(**savefig_kwargs)
    plt.show()

In [None]:
# # all configs
# for config_id in config_ids:
#     _plot_level_correctness_roleplay(
#         df_results=df_results,
#         config_id=config_id,
#     )

## Distractor alignment

In [None]:
# TODO: obtain answer proportion for all questions in valsmall set


In [None]:
# import pandas as pd

# df_interactions = pd.read_csv("../data/silver/dbe_kt22_interactions.csv")
# df_q_val = pd.read_csv("../data/gold/dbe_kt22_questions_validation.csv")

In [None]:
# df_i_val = df_interactions[df_interactions["question_id"].isin(df_q_val["question_id"])]

In [None]:
# # Get proportions as a DataFrame with options as columns
# prop_df = (df_i_val.groupby(["question_id", "student_level_group"])["student_option_id"]
#            .value_counts(normalize=True)
#            .unstack(fill_value=0.0)
#            .reset_index())
# # convert 4 columns to dict
# prop_df["dict"] = prop_df.set_index(['question_id', 'student_level_group']).to_dict('index').values()
# prop_df = prop_df.drop(columns=[1, 2, 3, 4])
# prop_df

In [None]:
# preds_dict = get_llm_student_preds(
#         exp_name=EXP_NAME,
#         config_id="qwen3:8b~T_0.0~SO_student_bool~L_5~SP_student_chocolate_level_nocontext~SS_proficiency_5_str~EFQ_quotes~EFI_quotes~ES_miscon_studentlevel_random0",
#         run_id=1,
#         split="val",
#         problem_type=PROBLEM_TYPE,
#     )

In [None]:
# import numpy as np


# def alignment_score_single(
#     y_true: int, y_llm: int, dict_props: dict
# ) -> float:
#     """Calculate the alignment score for a single question.

#     Parameters
#     ----------
#     y_true : int
#         The true answer.
#     y_llm : int
#         The LLM predicted answer.
#     dict_props : dict
#         A dictionary mapping answer options to student proportions.

#     Returns
#     -------
#     float
#         The alignment score.
#     """  
#     llm_answer_incorrect = y_true != y_llm
#     if llm_answer_incorrect:
#         dict_tmp = dict_props.copy()
#         prop_answer_llm = dict_tmp[y_llm]
#         # remove correct idx from dict
#         dict_tmp.pop(y_true, None)
#         idx_most_popular_distractor = max(dict_tmp, key=dict_tmp.get)
#         prop_most_popular_distractor = dict_tmp[idx_most_popular_distractor]
#         # calculate score
#         try:
#             score = prop_answer_llm / prop_most_popular_distractor
#         except ZeroDivisionError:
#             score = 0.0
#     else:
#         score = np.nan
#     return score

# alignment_score_single(
#     y_true=3,
#     y_llm=4,
#     dict_props={1: 0.07692307692307693, 2: 0.2692307692307692, 3: 0.6538461538461539, 4: 0.0}
# )

In [None]:
# from operator import itemgetter
# import numpy as np
# from numpy.typing import NDArray
# import pandas as pd


# def eval_distractor_alignment(
#     y_true_array: NDArray,
#     y_llm_array: NDArray,
#     student_level_group_array: NDArray,
#     question_id_array: NDArray,
#     student_scale_map: dict,
#     prop_df: pd.DataFrame,
# ) -> float:
#     """Evaluate the alignment of distractor answers.

#     Parameters
#     ----------
#     y_true_array : NDArray
#         The true answers.
#     y_llm_array : NDArray
#         The LLM predicted answers.
#     student_level_group_array : NDArray
#         The student level groups.
#     question_id_array : NDArray
#         The question IDs.
#     student_scale_map : dict
#         A mapping from student IDs to their scale.
#     prop_df : pd.DataFrame
#         A DataFrame containing the student proportions of each answer option.

#     Returns
#     -------
#     float
#         The mean alignment score.
#     """
#     dict_inverse = {v: int(k) for k, v in student_scale_map.items()}
#     student_level_group_array_int = np.array(
#         itemgetter(*student_level_group_array)(dict_inverse)
#     )
#     print(student_level_group_array_int)

#     scores = []
#     for y_true, y_llm, student_level_group, question_id in zip(
#         y_true_array, y_llm_array, student_level_group_array_int, question_id_array
#     ):
#         dict_tmp = (
#             prop_df[
#                 (prop_df["question_id"] == question_id)
#                 & (prop_df["student_level_group"] == student_level_group)
#             ]["dict"]
#             .item()
#             .copy()
#         )
#         score = alignment_score_single(y_true=y_true, y_llm=y_llm, dict_props=dict_tmp)
#         scores.append(score)

#     # compute mean and ignore NaNs
#     mean_score = np.nanmean(scores) if scores else 0.0
#     return mean_score


# eval_distractor_alignment(
#     y_true_array=preds_dict["y_true"],
#     y_llm_array=preds_dict["y_pred"],
#     student_level_group_array=preds_dict["student_level_group"],
#     question_id_array=preds_dict["question_ids"],
#     student_scale_map=preds_dict["student_scale_map"],
#     prop_df=prop_df,
# )

In [None]:
# df_prop = pd.read_csv("../data/platinum/dbe_kt22_proportions_val.csv")
# df_prop["dict"] = df_prop["dict"].apply(eval)