# Training Experiments

This notebook organizes the analysis of the experiments that fine-tune existing summarizers in the open-domain setting.

Run the following cells to import the required packages and load some helper functions.

In [None]:
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.ticker import MultipleLocator

from open_mds.common import util

# One shared look for every plot and figure in the paper
sns.set_theme(context="paper", style="ticks", palette="tab10", font_scale=1.375)

# Never truncate the columns of a displayed DataFrame
pd.set_option("display.max_columns", None)

# Average improvement in summarization performance reported at *ACL conferences. The figures are shaded relative
# to this value so the size of our effects can be compared against typical progress in the field.
# See: https://aclanthology.org/2022.naacl-main.442/
AVG_ROUGE_IMPROVEMENT = 0.49

# Palette colors (tab10 green and red) that are referenced directly, e.g. when shading the figures
GREEN, RED = "#2ca02c", "#d62728"

In [None]:
def style_plot(
    axes,
    include_shading: bool = False,
) -> None:
    """Style the 2x2 grid of plots produced for the training experiments.

    The top row holds Δ ROUGE-Avg F1 and the bottom row holds Δ BERTScore F1. The left column is the
    "retrieved" evaluation and the right column is the "ground-truth" evaluation.

    Args:
        axes: 2x2 array of matplotlib ``Axes`` (as returned by ``plt.subplots(nrows=2, ncols=2)``).
        include_shading: If True, shade both ROUGE plots green above ``-AVG_ROUGE_IMPROVEMENT`` and red
            below it, across the full extent of the plots.
    """
    # Set axis labels
    axes[0][0].set_title("retrieved")
    axes[0][1].set_title("ground-truth")
    axes[1][0].set_xlabel("additional training epochs")
    axes[0][0].set_ylabel("Δ ROUGE-Avg F1")
    axes[1][0].set_ylabel("Δ BERTScore F1")

    # Per-axis styling
    for ax in axes.flatten():
        # Add yaxis grid
        ax.yaxis.grid(True, zorder=-1)
        ax.tick_params(left=False)
        # Add a horizontal line at 0. It will be behind the data but in front of the grid lines.
        ax.axhline(y=0, color="#7f7f7f", linestyle="--", linewidth=1.5, zorder=1)
        # Minor ticks every 0.5 epochs
        ax.xaxis.set_minor_locator(MultipleLocator(0.5))
        # Drop any per-axes legend; a single shared legend is added below
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    # Remove the x tick labels from bottom right plot
    axes[1][1].tick_params(labelbottom=False)

    # Place one legend at the top of the figure, anchored above the top-right plot
    axes[0][1].legend(
        loc="lower right",
        bbox_to_anchor=(1.03, 1.15),
        ncol=5,
        frameon=False,
        handletextpad=0.325,
        columnspacing=0.325,
        fontsize="small",
    )

    if include_shading:
        # Optionally, shade the ROUGE plots to better contextualize the results. The extent is computed once
        # across the row, so both plots are shaded identically. It is widened if needed so the threshold itself
        # is visible.
        threshold = -AVG_ROUGE_IMPROVEMENT
        x_min = min(ax.get_xlim()[0] for ax in axes[0])
        x_max = max(ax.get_xlim()[1] for ax in axes[0])
        y_min = min(min(ax.get_ylim()[0] for ax in axes[0]), threshold)
        y_max = max(max(ax.get_ylim()[1] for ax in axes[0]), threshold)
        for ax in axes[0]:
            ax.fill_between([x_min, x_max], y_max, threshold, alpha=0.1, facecolor=GREEN)
            ax.fill_between([x_min, x_max], threshold, y_min, alpha=0.06, facecolor=RED)
            # Pin the limits. Otherwise autoscaling adds margins around the new patches, leaving unshaded bands
            # at the edges of the plot.
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)

    # Additional global styling
    sns.despine(left=True)

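Point the variable `data_dir` to the location of a directory that contains the results of running the [`run_summarization.py`](../scripts/run_summarization.py) script for one or more models. The directory should contain one subdirectory per dataset (e.g. `multinews`), each of which contains one subdirectory per model.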

In [None]:
# Directory containing the output of `run_summarization.py`, with one subdirectory per dataset
data_dir = "../output/results"

# Fail early if the directory is missing, then display its subdirectories so you can confirm they are the
# expected ones (pure Python, so this also works on platforms without `ls`).
assert Path(data_dir).is_dir(), f"data_dir={data_dir!r} does not exist or is not a directory"
sorted(p.name for p in Path(data_dir).iterdir() if p.is_dir())

Lastly, you may set the following flag to control an optional modification to the plots that improves visualization.

In [None]:
# When True, the ROUGE plots are shaded green above -AVG_ROUGE_IMPROVEMENT and red below it, i.e. relative to the
# average yearly improvement in summarization performance (ROUGE only)
include_shading: bool = True

In [None]:
# Pretty names for each dataset, used in the figure. The order of this dict also fixes the order in which the
# datasets are loaded and plotted.
pretty_ds_names = {
    "multinews": "Multi-News",
    "wcep": "WCEP-10",
    "multixscience": "Multi-XScience",
    "ms2": "MS^2",
    "cochrane": "Cochrane",
}
# Only load the best model for each dataset
best_model_per_ds = {
    "multinews": "primera",
    "wcep": "lsg-bart-base",
    "multixscience": "primera",
    "ms2": "led-base",
    "cochrane": "led-base",
}
# Some datasets have blind test splits, and so we evaluate on the validation set
validation_set_datasets = {"ms2", "cochrane"}
# The retriever/top-k strategy the retrieved runs were trained with (a dense retriever with the mean top-k
# strategy in the paper). Its results serve as "checkpoint 0" of the retrieved evaluation.
trained_retriever, trained_top_k_strategy = "dense", "mean"
# Number of additional training epochs spanned by the saved checkpoints. Used to convert checkpoints into epochs.
num_additional_epochs = 3

training_dfs = []
for ds_name, pretty_ds_name in pretty_ds_names.items():
    subdir = Path(data_dir) / ds_name
    if not subdir.is_dir():
        # No results for this dataset; it is simply absent from the figure
        continue
    model_name = best_model_per_ds[ds_name]
    model_dir = subdir / model_name

    # Here, we collect the data for "checkpoint 0". For the gold evaluation, this is simply the baseline model
    # results. For the retrieved evaluation, this is the results of the retriever/top-k strategy we trained on.
    # shutil.copy raises if a results file is missing, rather than silently plotting without checkpoint 0.
    if model_dir.is_dir():
        checkpoint_0_sources = {
            "gold": model_dir / util._BASELINE_DIR / util._RESULTS_FILENAME,
            "retrieved": (
                model_dir / util._RETRIEVAL_DIR / trained_retriever / trained_top_k_strategy / util._RESULTS_FILENAME
            ),
        }
        for evaluation_type, source_results in checkpoint_0_sources.items():
            checkpoint_0 = model_dir / util._TRAINING_DIR / evaluation_type / "checkpoint-0"
            checkpoint_0.mkdir(parents=True, exist_ok=True)
            shutil.copy(source_results, checkpoint_0)

    # HuggingFace assigns a different prefix to the keys in the output json depending on the evaluated split
    metric_key_prefix = "eval" if ds_name in validation_set_datasets else "predict"

    # The metrics we want to record results for
    metric_columns = [
        f"{metric_key_prefix}_rouge_avg_fmeasure",
        f"{metric_key_prefix}_bertscore_f1",
    ]

    # Load the results as dataframes
    _, training_df = util.load_results_dicts(
        data_dir=subdir,
        include_models=[model_name],
        metric_columns=metric_columns,
        metric_key_prefix=metric_key_prefix,
        # Only retain data that pertains to the training experiments
        load_perturbation_results=False,
        load_retrieval_results=False,
        load_training_results=True,
    )

    max_checkpoint = training_df["checkpoint"].max()
    training_df = training_df.assign(
        # Add the name of the dataset and model
        dataset_name=pretty_ds_name,
        model_name_or_path=model_name,
        # Remove the metric key prefix to plot everything with the same label
        rouge_avg_fmeasure_delta=training_df[f"{metric_key_prefix}_rouge_avg_fmeasure_delta"],
        bertscore_f1_delta=training_df[f"{metric_key_prefix}_bertscore_f1_delta"],
        # Convert checkpoint numbers into (fractional) additional epochs. If only checkpoint 0 exists,
        # everything is at epoch 0 (avoids 0 / 0).
        frac_additional_steps=(
            training_df["checkpoint"] / max_checkpoint * num_additional_epochs if max_checkpoint > 0 else 0.0
        ),
        # Rows without a recorded retriever come from the gold evaluation
        evaluation=training_df[f"{metric_key_prefix}_retriever"].isna().map({True: "gold", False: "retrieved"}),
    )

    training_dfs.append(training_df)

assert training_dfs, f"No training results found for any of {list(pretty_ds_names)} in {data_dir!r}"
df = pd.concat(training_dfs, ignore_index=True)

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey="row")

# Fix the dataset order for both hue *and* style. Without `style_order`, seaborn assigns markers/dashes in order
# of appearance in the data, so the color/marker pairing would depend on the order in which results were loaded.
dataset_order = list(pretty_ds_names.values())

for i, metric in enumerate(["rouge_avg_fmeasure_delta", "bertscore_f1_delta"]):
    for j, evaluation in enumerate(["retrieved", "gold"]):
        ax = sns.lineplot(
            data=df[df["evaluation"] == evaluation],
            x="frac_additional_steps",
            y=metric,
            hue="dataset_name",
            style="dataset_name",
            hue_order=dataset_order,
            style_order=dataset_order,
            errorbar=None,
            markers=True,
            ax=axes[i][j],
        )
        # Titles and labels are set once for the whole grid by `style_plot`
        ax.set(xlabel="", ylabel="", title="")

style_plot(axes, include_shading=include_shading)
fig.subplots_adjust(top=0.90, wspace=0.0175, hspace=0.075)

# Save the figure
plots_dir = Path("../output/plots")
plots_dir.mkdir(parents=True, exist_ok=True)
filename = "training_shaded" if include_shading else "training"
fig.savefig(plots_dir / f"{filename}.svg", facecolor="white", bbox_inches="tight")