In [None]:
# standard library imports
import os
import re

# related third party imports
import numpy as np
import matplotlib.pyplot as plt

# local application/library specific imports
from tools.configurator import (
    get_configs_out,
    get_config_ids,
)
from tools.analyzer import (
    get_labelling_progress,
    Dict2Class,
    get_train_logs,
    find_label_map,
    get_single_pred_label,
    merge_all_results,
)
from tools.plotter import (
    plot_level_acquisitions,
    plot_level_performance,
    plot_metric_vs_size,
    plot_active_gain,
    plot_history,
    plot_violinplot_racepp,
    plot_pred_parity,
)
from data_loader.build import build_hf_dataset

## Inputs

In [None]:
###### INPUTS ######
# Experiment to analyse; also the sub-folder of "output/" that figures are written to.
exp_name: str = "race_pp_merged"
# Metric to plot (presumably looked up in METRIC2LEGEND_DICT for labels).
metric: str = "test_discrete_rmse"

# Figure-export switches.
SANS_SERIF: bool = True  # font style passed to activate_latex()
PRINT_PAPER: bool = False  # True -> also save LaTeX-typeset PDF figures

In [None]:
def activate_latex(sans_serif: bool = False) -> None:
    """Activate LaTeX text rendering for matplotlib.

    Args:
        sans_serif: If True, use the Helvetica font family and load the
            ``sfmath`` LaTeX package (``cm`` option) in the preamble.
            If False, use Computer Modern Roman.

    The serif branch explicitly clears ``text.latex.preamble``. Without this,
    calling ``activate_latex(False)`` after ``activate_latex(True)`` would
    silently keep the ``sfmath`` preamble from the earlier call.
    """
    rc_updates = {"text.usetex": True}
    if sans_serif:
        rc_updates.update(
            {
                "font.family": "Helvetica",
                "text.latex.preamble": r"\usepackage[cm]{sfmath}",
            }
        )
    else:
        rc_updates.update(
            {"font.family": "Computer Modern Roman", "text.latex.preamble": ""}
        )
    plt.rcParams.update(rc_updates)


def deactivate_latex():
    """Switch matplotlib back to its built-in (non-LaTeX) text rendering.

    Resets the rcParams that ``activate_latex`` sets: turns usetex off, selects
    the DejaVu Sans font family and empties the LaTeX preamble.
    """
    non_latex_settings = {
        "text.usetex": False,
        "font.family": "DejaVu Sans",
        "text.latex.preamble": "",
    }
    plt.rcParams.update(non_latex_settings)

In [None]:
# Human-readable legend labels, keyed by metric name.
METRIC2LEGEND_DICT = dict(
    [
        ("test_discrete_rmse", "Discrete RMSE"),
        ("test_rmse", "RMSE"),
    ]
)

# Human-readable legend labels, keyed by config id: baselines first, then the
# active-learning (AL) acquisition strategies. Insertion order is kept as-is.
CONFIG2LEGEND_DICT = dict(
    [
        ("distilbert-base-uncased-regr-full_data", "Baseline - Supervised"),
        ("random-regr-full_data", "Baseline - Random"),
        ("majority-regr-full_data", "Baseline - Majority"),
        (
            "distilbert-base-uncased-regr-random-STD-N96-Q100-I500-S5000-BFalse",
            "AL - Uniform",
        ),
        (
            "distilbert-base-uncased-regr-powervariance-MCD-N96-Q100-I500-S5000-BFalse",
            "AL - PowerVariance",
        ),
        (
            "distilbert-base-uncased-regr-variance-MCD-N96-Q100-I500-S5000-BFalse",
            "AL - Variance",
        ),
    ]
)

In [None]:
configs = get_configs_out(exp_name)
config_ids = get_config_ids(configs)
print(config_ids)

# `zip` would silently drop trailing items if the two lengths differed, leaving
# some configs unreachable by id -- fail loudly instead.
if len(config_ids) != len(configs):
    raise ValueError(
        f"Got {len(config_ids)} config ids for {len(configs)} configs in '{exp_name}'."
    )
config_dict = dict(zip(config_ids, configs))
# e.g. config_dict["distilbert-base-uncased-regr-random-STD-N96-Q100-I500-S5000-BFalse"]

In [None]:
# merge results for all configs
# Called once with every config id found for this experiment (baselines included).
# NOTE(review): presumably writes the merged results that the plotting helpers
# below read -- confirm in tools.analyzer.merge_all_results.
merge_all_results(exp_name, config_ids)

In [None]:
# Choose which configs appear in the plots below.
#   True  -> derive the lists from every config found for this experiment
#            (dropping the random/majority baselines from CONFIG_IDS_TO_PLOT and
#            every "full_data" config from CONFIG_IDS_AL);
#   False -> use the hand-picked lists.
AUTO_SELECT_CONFIGS = False

if AUTO_SELECT_CONFIGS:
    baselines = ["random-regr-full_data", "majority-regr-full_data"]
    CONFIG_IDS_TO_PLOT = [x for x in config_ids if x not in baselines]
    CONFIG_IDS_AL = [x for x in config_ids if "full_data" not in x]
else:
    # active-learning configs, plus the fully supervised baseline for the size plots
    CONFIG_IDS_AL = [
        "distilbert-base-uncased-regr-random-STD-N96-Q100-I500-S5000-BFalse",
        "distilbert-base-uncased-regr-variance-MCD-N96-Q100-I500-S5000-BFalse",
        "distilbert-base-uncased-regr-powervariance-MCD-N96-Q100-I500-S5000-BFalse",
    ]
    CONFIG_IDS_TO_PLOT = CONFIG_IDS_AL + ["distilbert-base-uncased-regr-full_data"]

## Metric vs dataset size

In [None]:
# Same figure as the next cell, but over *every* config of this experiment
# (including the random/majority baselines). Set to True to show it.
PLOT_ALL_CONFIGS = False

if PLOT_ALL_CONFIGS:
    plot_metric_vs_size(
        exp_name=exp_name,
        metric=metric,
        config_ids=config_ids,
        run_id=None,
        config2legend=CONFIG2LEGEND_DICT,
        metric2legend=METRIC2LEGEND_DICT,
        stderr=False,
        x_axis="percent",
    )

In [None]:
# Metric vs. dataset size for the selected configs, without error bands.
size_plot_kwargs = dict(
    exp_name=exp_name,
    metric=metric,
    config_ids=CONFIG_IDS_TO_PLOT,
    run_id=None,
    config2legend=CONFIG2LEGEND_DICT,
    metric2legend=METRIC2LEGEND_DICT,
    stderr=False,
    x_axis="percent",
)
plot_metric_vs_size(**size_plot_kwargs)

In [None]:
if PRINT_PAPER:
    # Export the metric-vs-size figure twice: without and with stderr bands.
    fig_dir = os.path.join("output", exp_name, "figures")
    # matplotlib's savefig does not create missing parent directories.
    os.makedirs(fig_dir, exist_ok=True)
    activate_latex(sans_serif=SANS_SERIF)
    try:
        for show_stderr, fname_suffix in ((False, ""), (True, "_stderr")):
            plot_metric_vs_size(
                exp_name=exp_name,
                metric=metric,
                config_ids=CONFIG_IDS_TO_PLOT,
                run_id=None,
                config2legend=CONFIG2LEGEND_DICT,
                metric2legend=METRIC2LEGEND_DICT,
                stderr=show_stderr,
                x_axis="percent",
                save=True,
                savefig_kwargs={
                    "fname": os.path.join(fig_dir, f"{metric}_vs_size{fname_suffix}.pdf")
                },
            )
    finally:
        # Restore the default renderer even if a plot fails (e.g. LaTeX missing),
        # otherwise every later cell would keep rendering with usetex.
        deactivate_latex()

## Learning convergence

In [None]:
###### INPUTS ######
# Training run to inspect: config, run number and dataset size
# (presumably the number of labelled samples at that AL step -- confirm).
config_id: str = "distilbert-base-uncased-regr-powervariance-MCD-N96-Q100-I500-S5000-BFalse"
ds_size: int = 500
run_id: int = 1

In [None]:
# Load the logs of the selected run; only `lines` is needed for the history plots.
train_log, lines, eval_results = get_train_logs(
    exp_name=exp_name,
    config_id=config_id,
    run_id=run_id,
    ds_size=ds_size,
)
plot_history(lines, "eval_rmse")
plot_history(lines, "eval_discrete_rmse")

## Labeling progress

In [None]:
###### INPUTS ######
# config_id = USE FROM ABOVE
# (i.e. the `config_id` set in the "Learning convergence" inputs cell)
# run_id = None: the export cell below labels this case "over all runs"
# -- presumably aggregates across runs; confirm in tools.plotter.
run_id = None

In [None]:
# labeling progress for every config of the experiment
labelling_dict = get_labelling_progress(exp_name, config_ids)

# NOTE: no random process in dataset building so seed does not matter.
# One dataset per run of `config_id`, keyed like the labelling progress ("run_<n>").
selected_cfg = config_dict[config_id]
datasets_runs = {}
for run_key in labelling_dict[config_id]:
    run_match = re.search(r"run_(\d+)", run_key)
    if run_match is None:
        raise ValueError(f"Cannot parse a run number from labelling key {run_key!r}")
    run_n = int(run_match.group(1))
    datasets_runs[run_key] = build_hf_dataset(
        # A fresh loader config per run, as before, in case the builder mutates it.
        Dict2Class(selected_cfg["LOADER"]),
        selected_cfg["MODEL"]["NUM_LABELS"],
        selected_cfg["SEED"] + run_n,
    )

Distribution of all labeled samples 

In [None]:
# NOTE: label map is same for all dataset seeds, so any run's train split works.
# Prefer "run_1" (as before) but fall back to the first available run instead of
# raising a bare KeyError when that run does not exist.
if not datasets_runs:
    raise ValueError(f"No runs found for config {config_id!r}; cannot derive the label map.")
label_run_key = "run_1" if "run_1" in datasets_runs else next(iter(datasets_runs))
label_map = find_label_map(datasets_runs[label_run_key]["train"])
plot_level_acquisitions(
    labelling_dict=labelling_dict,
    datasets=datasets_runs,
    label_map=label_map,
    config_dict=config_dict,
    config_id=config_id,
    exp_name=exp_name,
    run_id=run_id,
    x_axis="percent",
    only_acquisition=False,
)

In [None]:
# plot for all AL configs (over all runs)
if PRINT_PAPER:
    fig_dir = os.path.join("output", exp_name, "figures")
    # matplotlib's savefig does not create missing parent directories.
    os.makedirs(fig_dir, exist_ok=True)
    activate_latex(sans_serif=SANS_SERIF)
    try:
        # NOTE(review): `datasets_runs` was built from `config_id`'s loader config only;
        # this assumes every AL config shares that loader and run keys -- confirm.
        for config_id_tmp in CONFIG_IDS_AL:
            plot_level_acquisitions(
                labelling_dict=labelling_dict,
                datasets=datasets_runs,
                label_map=label_map,
                config_dict=config_dict,
                config_id=config_id_tmp,
                exp_name=exp_name,
                run_id=None,
                x_axis="percent",
                only_acquisition=False,
                save=True,
                savefig_kwargs={
                    "fname": os.path.join(fig_dir, f"labeling_progress_{config_id_tmp}.pdf")
                },
            )
    finally:
        # Restore the default renderer even if a plot fails, so later cells don't keep usetex on.
        deactivate_latex()

Distribution of samples acquired per step

In [None]:
# Same view as the cumulative plot, restricted to the samples acquired at each step.
acquisition_plot_kwargs = dict(
    labelling_dict=labelling_dict,
    datasets=datasets_runs,
    label_map=label_map,
    config_dict=config_dict,
    config_id=config_id,
    exp_name=exp_name,
    run_id=run_id,
    x_axis="percent",
    only_acquisition=True,
)
plot_level_acquisitions(**acquisition_plot_kwargs)

## Metrics per difficulty level

In [None]:
###### INPUTS ######
# config_id = USE FROM ABOVE
# (the `config_id` set in the "Learning convergence" inputs cell)
# Difficulty level to inspect, given as a string. Presumably one of the values of
# `label_map` (the export cell below passes those values as `diff_level`).
diff_level = "0"

Metric for a single difficulty level

In [None]:
# Metric restricted to the single difficulty level `diff_level`, for `config_id`.
single_level_kwargs = dict(
    experiment=exp_name,
    metric=metric,
    config_ids=[config_id],
    config_dict=config_dict,
    label_map=label_map,
    diff_level=diff_level,
    run_id=None,
    x_axis="percent",
    config2legend=CONFIG2LEGEND_DICT,
    metric2legend=METRIC2LEGEND_DICT,
)
plot_level_performance(**single_level_kwargs)

In [None]:
# plot for all AL configs, one figure per difficulty level
if PRINT_PAPER:
    fig_dir = os.path.join("output", exp_name, "figures")
    # matplotlib's savefig does not create missing parent directories.
    os.makedirs(fig_dir, exist_ok=True)
    activate_latex(sans_serif=SANS_SERIF)
    try:
        for label_str in label_map.values():
            plot_level_performance(
                experiment=exp_name,
                metric=metric,
                config_ids=CONFIG_IDS_AL,
                config_dict=config_dict,
                label_map=label_map,
                diff_level=label_str,
                run_id=None,
                x_axis="percent",
                config2legend=CONFIG2LEGEND_DICT,
                metric2legend=METRIC2LEGEND_DICT,
                save=True,
                savefig_kwargs={
                    "fname": os.path.join(fig_dir, f"difficulty_level_{label_str}.pdf")
                },
            )
    finally:
        # Restore the default renderer even if a plot fails, so later cells don't keep usetex on.
        deactivate_latex()

Metric for all difficulty levels

In [None]:
# Metric for every difficulty level at once (diff_level=None), for `config_id`.
all_levels_kwargs = dict(
    experiment=exp_name,
    metric=metric,
    config_ids=[config_id],
    config_dict=config_dict,
    label_map=label_map,
    diff_level=None,
    run_id=None,
    x_axis="percent",
    config2legend=CONFIG2LEGEND_DICT,
    metric2legend=METRIC2LEGEND_DICT,
)
plot_level_performance(**all_levels_kwargs)

## Predictive parity

In [None]:
# Predictive parity across all AL configs (swap in `[config_id]` to inspect one config).
# NOTE(review): `run_id` is whatever the most recently executed inputs cell set
# (None after the "Labeling progress" inputs on a top-to-bottom run).
parity_plot_kwargs = dict(
    experiment=exp_name,
    metric=metric,
    config_ids=CONFIG_IDS_AL,
    config_dict=config_dict,
    label_map=label_map,
    run_id=run_id,
    x_axis="percent",
    stderr=False,
    config2legend=CONFIG2LEGEND_DICT,
    metric2legend=METRIC2LEGEND_DICT,
)
plot_pred_parity(**parity_plot_kwargs)

## Violin plots of difficulty

In [None]:
###### INPUTS ######
# `config_id` is reused from the "Learning convergence" inputs cell.
# Run and dataset size whose test-set predictions are plotted.
ds_size: int = 10000
run_id: int = 3

In [None]:
# Load the selected run's predictions and draw them as a violin plot against `label_map`.
pred_label_args = (exp_name, config_id, run_id, ds_size)
test_pred_label = get_single_pred_label(*pred_label_args)
plot_violinplot_racepp(test_pred_label, label_map)

In [None]:
# plot for all AL configs
if PRINT_PAPER:
    fig_dir = os.path.join("output", exp_name, "figures")
    # matplotlib's savefig does not create missing parent directories.
    os.makedirs(fig_dir, exist_ok=True)
    activate_latex(sans_serif=SANS_SERIF)
    try:
        for config_id_tmp in CONFIG_IDS_AL:
            # Separate name so the interactive `test_pred_label` (for `config_id`)
            # is not silently replaced by the last AL config's predictions.
            paper_pred_label = get_single_pred_label(exp_name, config_id_tmp, run_id, ds_size)
            plot_violinplot_racepp(
                paper_pred_label,
                label_map,
                save=True,
                savefig_kwargs={
                    "fname": os.path.join(fig_dir, f"violin_plot_{config_id_tmp}.pdf")
                },
            )
    finally:
        # Restore the default renderer even if a plot fails, so later cells don't keep usetex on.
        deactivate_latex()

## Active gain

In [None]:
###### INPUTS ######
# Reference config passed as `baseline` to plot_active_gain below: the uniform-random
# acquisition strategy (shown as "AL - Uniform" in the legend).
baseline = "distilbert-base-uncased-regr-random-STD-N96-Q100-I500-S5000-BFalse"

In [None]:
# Active gain of each AL config relative to `baseline`.
# NOTE(review): with the manual selection, CONFIG_IDS_AL also contains `baseline`
# itself, so one curve is the baseline compared with itself.
active_gain_kwargs = dict(
    exp_name=exp_name,
    metric=metric,
    baseline=baseline,
    config_ids=CONFIG_IDS_AL,
    run_id=None,
    config2legend=CONFIG2LEGEND_DICT,
    metric2legend=METRIC2LEGEND_DICT,
    x_axis="percent",
)
plot_active_gain(**active_gain_kwargs)

In [None]:
if PRINT_PAPER:
    fig_dir = os.path.join("output", exp_name, "figures")
    # matplotlib's savefig does not create missing parent directories.
    os.makedirs(fig_dir, exist_ok=True)
    activate_latex(sans_serif=SANS_SERIF)
    try:
        plot_active_gain(
            exp_name=exp_name,
            metric=metric,
            baseline=baseline,
            config_ids=CONFIG_IDS_AL,
            run_id=None,
            config2legend=CONFIG2LEGEND_DICT,
            metric2legend=METRIC2LEGEND_DICT,
            x_axis="percent",
            save=True,
            savefig_kwargs={"fname": os.path.join(fig_dir, f"{metric}_active_gain.pdf")},
        )
    finally:
        # Restore the default renderer even if the plot fails, so later cells don't keep usetex on.
        deactivate_latex()