## Imports

In [None]:
# standard library imports
import os
import re

# related third party imports
from pprint import pprint
import numpy as np
import matplotlib.pyplot as plt
import structlog

# local application/library specific imports
from tools.configurator import (
    get_configs_out,
    get_config_ids,
)
from tools.constants import OutputType, EXCLUDE_METRICS
from tools.analyzer import (
    print_table_from_dict,
    get_train_logs,
    get_label_map,
    get_single_pred_label,
    merge_all_results,
    compute_avg_confusion_matrix,
    compute_avg_rps_per_level,
    get_results_dict,
    reorder_config_ids,
    compute_rank_inconsistencies,
    get_logit_params_history,
)
from tools.plotter import (
    plot_history,
    plot_violinplot,
    activate_latex,
    deactivate_latex,
    plot_confusion_matrix,
    plot_rank_inconsistencies,
    plot_cutpoints_history,
    plot_bias_history,
)

# Module-level structlog logger, named after this module.
logger = structlog.get_logger(__name__)

In [None]:
##### INPUTS #####
# Experiment to analyze. Keep exactly ONE name uncommented: adjacent string
# literals inside the parentheses would be silently concatenated.
EXP_NAME = (
    # RACE++
    # "race_pp_bert_bal_rps_20250517"
    "race_pp_bert_logit_20250522"

    # ARC
    # "arc_bert_bal_rps_20250518"
    # "arc_bert_logit_20250521"
)
# Config inspected by the cells that work on a single config (e.g. the
# learning-convergence plots).
CONFIG_ID = "bert_ordinal_logit_SL512_BALFalse_LR0.000029_WD0.048_FRFalse_ESTrue"  # TODO: select from config_ids
# Forwarded to activate_latex(sans_serif=...) when rendering paper figures.
SANS_SERIF = True
# When True, the "paper" cells run: they activate LaTeX and save figures as PDF.
PRINT_PAPER = False
# Forwarded to the table/plot helpers as `legend_exact`.
LEGEND_EXACT = False

In [None]:
# Human-readable legend labels, keyed by config name.
# NOTE(review): keys look like prefixes of the full config ids, and the key
# order is presumably what reorder_config_ids uses - confirm in tools.analyzer.
CONFIG2LEGEND_DICT = dict(
    random_regression="Random baseline",
    majority_regression="Majority baseline",
    bert_regression="BERT - Regression",
    bert_classification="BERT - Classification",
    bert_ordinal_or_nn="BERT - Ordinal OR-NN",
    bert_ordinal_coral="BERT - Ordinal CORAL",
    bert_ordinal_corn="BERT - Ordinal CORN",
    bert_ordinal_logit="BERT - Ordinal Logit",
)

# Legend options shared by the table and plot helpers (unpacked with ** below).
legend_kwargs = dict(config2legend=CONFIG2LEGEND_DICT, legend_exact=LEGEND_EXACT)

In [None]:
# Load the experiment's configs and index each one by its config id.
configs = get_configs_out(EXP_NAME)
config_ids = get_config_ids(configs)
config_dict = dict(zip(config_ids, configs))

# Put config_ids in the order given by the CONFIG2LEGEND_DICT keys.
config_ids = reorder_config_ids(config_ids, CONFIG2LEGEND_DICT)
pprint(config_ids)

In [None]:
# merge results for all configs
# run_id_dict is indexed by config id; each entry is iterated as that
# config's run ids by the per-run cells below.
run_id_dict = merge_all_results(EXP_NAME, config_ids)

## Test set performance
### Aggregated over levels

In [None]:
# Gather the test-set results of every config (no specific run selected),
# then print them as a table rounded to 3 decimals without EXCLUDE_METRICS.
results_dict = get_results_dict(exp_name=EXP_NAME, config_ids=config_ids, run_id=None)
print_table_from_dict(
    **legend_kwargs,
    eval_dict=results_dict,
    exp_name=EXP_NAME,
    decimals=3,
    exclude_metrics=EXCLUDE_METRICS,
)

### RPS per level

In [None]:
# Compute the per-level RPS of every config (no specific run selected)
# and print it as a table; no metric is excluded here.
rps_agg = compute_avg_rps_per_level(
    config_dict=config_dict,
    exp_name=EXP_NAME,
    config_ids=config_ids,
    run_id=None,
)
print_table_from_dict(
    **legend_kwargs,
    eval_dict=rps_agg,
    exp_name=EXP_NAME,
    exclude_metrics=[],
)

## Learning convergence

In [None]:
# Plot the training history of every run of the selected CONFIG_ID.
if CONFIG_ID in config_ids:
    for run_id in run_id_dict[CONFIG_ID]:
        # Plain string: the event has no placeholders, run_id goes in as a
        # structured key/value instead.
        logger.info("Plotting history", run_id=run_id)
        train_log, lines, eval_results = get_train_logs(
            exp_name=EXP_NAME, config_id=CONFIG_ID, run_id=run_id
        )
        plot_history(lines, metric="eval_bal_rps")  # NOTE: metrics
else:
    # Make a stale CONFIG_ID (e.g. after switching EXP_NAME) visible instead
    # of silently producing no plots.
    logger.warning(
        "CONFIG_ID not found in config_ids, skipping history plots",
        config_id=CONFIG_ID,
    )

In [None]:
# Plot the cutpoint and bias history of every run, for ordinal-logit configs only.
# The membership check mirrors the history cell above: without it a CONFIG_ID
# that is not part of this experiment would be used to index run_id_dict.
if CONFIG_ID in config_ids and OutputType.ORD_LOGIT in CONFIG_ID:
    for run_id in run_id_dict[CONFIG_ID]:
        params_history = get_logit_params_history(
            exp_name=EXP_NAME, config_id=CONFIG_ID, run_id=run_id
        )
        plot_cutpoints_history(params_history)
        plot_bias_history(params_history)

## Confusion matrix

In [None]:
###### INPUTS ######
# CONFIG_ID = USE FROM ABOVE
# RUN_ID is forwarded as `run_id` to the analyzer helpers in the cells below.
# NOTE(review): None presumably means "aggregate over all runs" - confirm in tools.analyzer.
RUN_ID = None

In [None]:
# One confusion-matrix figure per config: compute the averaged matrix
# (normalize="true"), then draw it with that config's label map.
for config_id in config_ids:
    label_map = get_label_map(EXP_NAME, config_id)

    avg_conf_matrix = compute_avg_confusion_matrix(
        config_dict=config_dict,
        exp_name=EXP_NAME,
        config_id=config_id,
        run_id=RUN_ID,
        int2label=label_map,
        normalize="true",
    )

    plot_confusion_matrix(
        **legend_kwargs,
        config_id=config_id,
        conf_matrix=avg_conf_matrix,
        int2label=label_map,
    )

In [None]:
# Paper version of the confusion matrices: rendered with LaTeX and saved as
# output/<EXP_NAME>/figures/confusion_<dataset>_<config_id>.pdf.
if PRINT_PAPER:
    activate_latex(sans_serif=SANS_SERIF)
    try:
        # Loop-invariant target directory for all figures of this experiment.
        figures_dir = os.path.join("output", EXP_NAME, "figures")
        # plot for every config_id
        for config_id in config_ids:
            label_map = get_label_map(EXP_NAME, config_id)
            # Read the dataset name from the config being plotted, not from the
            # globally selected CONFIG_ID (which may not belong to this experiment).
            dataset_name = config_dict[config_id]["LOADER"]["NAME"]
            avg_conf_matrix = compute_avg_confusion_matrix(
                exp_name=EXP_NAME,
                config_id=config_id,
                run_id=RUN_ID,
                int2label=label_map,
                config_dict=config_dict,
                normalize="true",
            )
            savefig_kwargs = {
                "fname": os.path.join(
                    figures_dir, f"confusion_{dataset_name}_{config_id}.pdf"
                )
            }
            plot_confusion_matrix(
                conf_matrix=avg_conf_matrix,
                int2label=label_map,
                config_id=config_id,
                **legend_kwargs,
                save=True,
                savefig_kwargs=savefig_kwargs,
            )
    finally:
        # Always leave LaTeX mode, even if computing or plotting fails, so
        # later cells are not rendered with LaTeX by accident.
        deactivate_latex()

## Rank inconsistencies

In [None]:
# Report and plot rank inconsistencies for the configs whose id contains one
# of these ordinal output types; all other configs are skipped.
ord_types = [OutputType.ORD_OR_NN, OutputType.ORD_CORAL, OutputType.ORD_CORN]
for config_id in config_ids:
    if not any(ord_type in config_id for ord_type in ord_types):
        continue

    sum_count, count_per_obs = compute_rank_inconsistencies(
        config_dict=config_dict,
        exp_name=EXP_NAME,
        config_id=config_id,
        run_id=RUN_ID,
    )

    # Summary statistics over the test set.
    print(f"Total # inconsistencies in test set: {sum_count:.0f}")
    print(f"Average # inconsistencies per test observation: {np.mean(count_per_obs):.4f}")
    print(f"# test observations with inconsistencies: {np.sum(count_per_obs > 0)}")

    plot_rank_inconsistencies(
        count_per_obs,
        num_classes=config_dict[config_id]["MODEL"]["NUM_LABELS"],
    )

In [None]:
# Paper version of the rank-inconsistency plot (OR-NN configs only): rendered
# with LaTeX and saved as output/<EXP_NAME>/figures/rank_inconsistencies_<config_id>.pdf.
if PRINT_PAPER:
    for config_id in config_ids:
        if OutputType.ORD_OR_NN not in config_id:
            continue
        activate_latex(sans_serif=SANS_SERIF)
        try:
            sum_count, count_per_obs = compute_rank_inconsistencies(
                exp_name=EXP_NAME,
                config_id=config_id,
                config_dict=config_dict,
                run_id=RUN_ID,
            )
            savename = os.path.join(
                "output", EXP_NAME, "figures", f"rank_inconsistencies_{config_id}.pdf"
            )
            plot_rank_inconsistencies(
                count_per_obs,
                num_classes=config_dict[config_id]["MODEL"]["NUM_LABELS"],
                savename=savename,
            )
        finally:
            # Always leave LaTeX mode, even if computing or plotting fails.
            deactivate_latex()