In [None]:
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from retrieval_exploration.common import util

# Apply one shared seaborn theme so that every plot & figure in the paper is styled alike
sns.set_theme(
    style="ticks",
    context="paper",
    palette="tab10",
    font_scale=1.375,
)

# Lift pandas' display limits so DataFrames are shown with every row and column
for display_option in ("display.max_columns", "display.max_rows"):
    pd.set_option(display_option, None)

In [None]:
data_dir = "../output/results"
# Make sure the directory exists and contains the expected subdirectories
!ls $data_dir

In [None]:
# Collect, for every dataset directory under `data_dir`, the training-experiment results of the
# model we report for that dataset, and keep only the rows from the "retrieved" evaluation.
training_dfs = []
pretty_ds_names = {
    "multinews": "Multi-News",
    "wcep": "WCEP-10",
    "multixscience": "Multi-XScience",
    "ms2": "MS^2",
    "cochrane": "Cochrane",
}

# Checkpoints are rescaled onto [0, MAX_FRAC_ADDITIONAL_STEPS] for the x-axis of the plot.
# NOTE(review): presumably fine-tuning ran for 3x some reference number of steps; confirm.
MAX_FRAC_ADDITIONAL_STEPS = 3

# Sort for a deterministic processing order (directory iteration order is arbitrary)
for subdir in sorted(Path(data_dir).iterdir()):
    # Skip stray files and directories that are not one of the known datasets; these would
    # otherwise fail further down (iterating a file, or looking up the pretty dataset name).
    if not subdir.is_dir() or subdir.name not in pretty_ds_names:
        continue

    if subdir.name in ["multinews", "multixscience"]:
        include_models = ["primera"]
    elif subdir.name == "wcep":
        include_models = ["lsg-bart-base"]
    else:
        include_models = ["led-base"]

    # Here, we collect the data for "checkpoint 0". For the gold evaluation, this is simply the baseline model
    # results. For the retrieved evaluation, these are the results of whatever retriever/top-k strategy we trained
    # on, which was a dense retriever with the mean top-k strategy in the paper.
    for model_dir in subdir.iterdir():
        if model_dir.name not in include_models:
            continue
        gold_checkpoint_0 = model_dir / util._TRAINING_DIR / "gold" / "checkpoint-0"
        retrieved_checkpoint_0 = model_dir / util._TRAINING_DIR / "retrieved" / "checkpoint-0"
        gold_baseline = model_dir / util._BASELINE_DIR / util._RESULTS_FILENAME
        retrieved_baseline = model_dir / util._RETRIEVAL_DIR / "dense" / "mean" / util._RESULTS_FILENAME
        for results_file, checkpoint_dir in (
            (gold_baseline, gold_checkpoint_0),
            (retrieved_baseline, retrieved_checkpoint_0),
        ):
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            if results_file.is_file():
                # Copies the file into the directory under its own basename (like `cp file dir`)
                shutil.copy(results_file, checkpoint_dir)
            else:
                # Stay best-effort like the shell `cp` this replaces: report and carry on
                print(f"WARNING: results file not found, nothing copied: {results_file}")

    # Some datasets have blind test splits, and so we evaluate on the validation set
    # HuggingFace assigns a different prefix to the keys in the output json, so set that here
    metric_key_prefix = "eval" if subdir.name in {"ms2", "cochrane"} else "predict"

    # The metrics we want to record results for
    metric_columns = [
        f"{metric_key_prefix}_rouge_avg_fmeasure",
        f"{metric_key_prefix}_bertscore_f1",
    ]

    # Load the results as dataframes
    _, training_df = util.load_results_dicts(
        data_dir=subdir,
        include_models=include_models,
        metric_columns=metric_columns,
        metric_key_prefix=metric_key_prefix,
        # Only retain data that pertains to the training experiments
        load_perturbation_results=False,
        load_retrieval_results=False,
        load_training_results=True,
    )

    # Add the name of the dataset and model. Use item assignment (not attribute assignment), which
    # always writes a column, even when the column does not exist yet.
    training_df["dataset_name"] = pretty_ds_names[subdir.name]
    training_df["model_name_or_path"] = include_models[0]
    # Remove the metric key prefix to plot everything with the same label
    training_df["rouge_avg_fmeasure_delta"] = training_df[f"{metric_key_prefix}_rouge_avg_fmeasure_delta"]
    training_df["bertscore_f1_delta"] = training_df[f"{metric_key_prefix}_bertscore_f1_delta"]
    # Deduce the fraction of additional steps from the checkpoint number, relative to the last checkpoint
    training_df["frac_additional_steps"] = (
        training_df["checkpoint"] / training_df["checkpoint"].max() * MAX_FRAC_ADDITIONAL_STEPS
    )
    # Specify the evaluation type: rows without a retriever recorded are the "gold" evaluation
    training_df["evaluation"] = [
        "gold" if retriever_is_missing else "retrieved"
        for retriever_is_missing in training_df[f"{metric_key_prefix}_retriever"].isnull()
    ]

    # Only the retrieved evaluation is plotted
    training_dfs.append(training_df[training_df["evaluation"] == "retrieved"])

In [None]:
# Stack the per-dataset frames into one frame with a fresh 0..n-1 index
df = pd.concat(
    training_dfs,
    ignore_index=True,
)

In [None]:
# Two stacked panels sharing the x-axis: one line plot per metric, one line per dataset
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(4.2, 6.4))

# Column to plot -> y-axis label, in top-to-bottom panel order
metric_labels = {
    "rouge_avg_fmeasure_delta": "Δ ROUGE-Avg F1",
    "bertscore_f1_delta": "Δ BERTScore F1",
}

for i, (metric, ylabel) in enumerate(metric_labels.items()):
    g = sns.lineplot(
        data=df,
        x="frac_additional_steps",
        y=metric,
        style="dataset_name",
        hue="dataset_name",
        # Pass a concrete list rather than a dict view
        hue_order=list(pretty_ds_names.values()),
        # NOTE(review): newer seaborn releases deprecate `ci` in favour of `errorbar=None`;
        # kept as-is for compatibility with the seaborn version used here, confirm before changing.
        ci=None,
        markers=True,
        ax=axes[i],
    )
    g.set(
        # Only the bottom panel carries the (shared) x-axis label
        xlabel="Fraction of additional training steps" if i == 1 else None,
        ylabel=ylabel,
    )

# Additional per-axis styling
for ax in axes.flatten():
    # Add yaxis grid
    ax.yaxis.grid(True)
    ax.tick_params(left=False)

# Keep a single legend (on the top panel); drop the bottom one through the public API
bottom_legend = axes[1].get_legend()
if bottom_legend is not None:
    bottom_legend.remove()

# Place the top panel's legend just above the axes, right-aligned, in three compact columns
axes[0].legend(
    loc="lower right",
    bbox_to_anchor=(1.0, 1.0),
    ncol=3,
    frameon=False,
    handletextpad=0.325,
    columnspacing=0.325,
    fontsize="small",
)

# Additional global styling
sns.despine(left=True)

# Threshold (used negated below) separating the green and red shaded regions of the ROUGE panel.
# NOTE(review): the provenance of 0.49 is not shown in this notebook; confirm and cite its source.
AVG_ROUGE_IMPROVEMENT = 0.49
# Optionally, shade the ROUGE plots to better contextualize the results
y_min, y_max = axes[0].get_ylim()
shade_x = list(range(4))
# Green from the current upper y-limit down to -AVG_ROUGE_IMPROVEMENT ...
axes[0].fill_between(shade_x, y_max, -AVG_ROUGE_IMPROVEMENT, alpha=0.0875, facecolor="green")
# ... and red from there down to the current lower y-limit
axes[0].fill_between(shade_x, -AVG_ROUGE_IMPROVEMENT, y_min, alpha=0.04, facecolor="red")

plt.subplots_adjust(top=0.90, hspace=0.1)

# Make sure the output directory exists so saving does not fail on a fresh checkout
plots_dir = Path("../output/plots")
plots_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(plots_dir / "training.svg", facecolor="white", bbox_inches="tight")