# Perturbation Experiments

This notebook organizes the analysis of the experiments that simulate document retrieval errors by perturbing the input documents.

__Note__: if you are running this notebook in Colab, uncomment and run the following cell to install the project and its dependencies.

In [None]:
# Colab only: uncomment the next line to install the project (and its dependencies) from GitHub
# %pip install "git+https://github.com/allenai/open-mds.git"

Run the following cells to import the required packages and load some helper functions

In [None]:
from pathlib import Path
from typing import List

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from open_mds.common import util

# Apply one consistent seaborn theme to every plot & figure in the paper
sns.set_theme(
    context="paper",
    style="ticks",
    palette="tab10",
    font_scale=1.375,
)

# Never truncate the columns when displaying a DataFrame
pd.options.display.max_columns = None

# Average improvement in summarization performance in *ACL conferences. It is used as the threshold
# when shading the figures, to put the results of our experiments in context.
# See: https://aclanthology.org/2022.naacl-main.442/
AVG_ROUGE_IMPROVEMENT = 0.49

# Palette colors that are referenced directly (e.g. when shading the figure)
GREEN = "#2ca02c"
RED = "#d62728"

In [None]:
def style_plot(
    axes,
    strategies: List[str],
    bins: List[float],
    xticklabels: List[str],
    include_shading: bool = False,
    rescale_axis: bool = False,
) -> None:
    """Style the grid of plots produced for the perturbation experiments.

    Args:
        axes: 2D array of matplotlib axes. The top row is labelled as ROUGE and the second row as
            BERTScore. There must be at least two columns and one column per entry in `strategies`.
        strategies: Names of the document selection strategies, used as the column titles.
        bins: Bin edges of the x-axis. Only its length is used, to determine the x-extent of the shading.
        xticklabels: Tick labels for the x-axis of the bottom-left plot.
        include_shading: If True, shade the top-row plots green above `-AVG_ROUGE_IMPROVEMENT` and
            red below it.
        rescale_axis: If True, switch the y-axis of every plot to a symlog scale.
    """
    # Column titles (one per selection strategy) go on the top row only
    for i, strategy in enumerate(strategies):
        axes[0][i].set_title(strategy)

    # Axis labels only need to be set once, on the left-most column
    axes[-1][0].set_xlabel(r"% of input documents perturbed")
    axes[0][0].set_ylabel("Δ ROUGE-Avg F1")
    axes[1][0].set_ylabel("Δ BERTScore F1")

    for ax in axes.flatten():
        # Add yaxis grid
        ax.yaxis.grid(True)
        # Remove xticklabels (they are re-added to the bottom-left plot below)
        ax.set_xticklabels([])
        ax.tick_params(left=False)
        # Drop the per-axis legend; a single shared legend is placed at the top of the figure
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    # Place one legend at the top of the plot
    axes[0][1].legend(
        loc="lower right",
        bbox_to_anchor=(1.03, 1.15),
        ncol=5,
        frameon=False,
        handletextpad=0.325,
        columnspacing=0.325,
        fontsize="small",
    )

    # Optionally, shade the ROUGE plots to better contextualize the results
    if include_shading:
        threshold = -AVG_ROUGE_IMPROVEMENT
        y_min = min(ax.get_ylim()[0] for ax in axes[0])
        y_max = max(ax.get_ylim()[1] for ax in axes[0])
        # One x position per bin (the bins are plotted at integer positions 0, 1, 2, ...)
        x_positions = list(range(len(bins) - 1))
        for ax in axes[0]:
            # Only shade a band when the y-range actually reaches that side of the threshold,
            # otherwise the band would be drawn on the wrong side of it
            if y_max > threshold:
                ax.fill_between(x_positions, y_max, threshold, alpha=0.1, facecolor=GREEN)
            if y_min < threshold:
                ax.fill_between(x_positions, threshold, y_min, alpha=0.06, facecolor=RED)

    # Optionally, rescale the y-axis to better discern between the various perturbations
    if rescale_axis:
        for ax in axes.flatten():
            ax.set_yscale("symlog")

    # Add bottom left xticklabels. Pin the tick positions first so that each label is attached to a
    # known location (assumes one tick per label at positions 0..len(xticklabels) - 1, the same
    # integer positions the shading above relies on).
    bottom_left_ax = axes[-1][0]
    bottom_left_ax.set_xticks(list(range(len(xticklabels))))
    bottom_left_ax.set_xticklabels(xticklabels, rotation=45, ha="right")

    # Additional global styling
    sns.despine(left=True)

Point the variable `data_dir` to the location of a directory that contains the results of running the [`run_summarization.py`](../scripts/run_summarization.py) script for one or more models

In [None]:
data_dir = "../output/results"
# Make sure the directory exists and contains the expected subdirectories
!ls $data_dir

Lastly, you may also set a couple of flags that control modifications to the plots to improve visualization.

In [None]:
# If True, shade the ROUGE plots (top row): green above -AVG_ROUGE_IMPROVEMENT, i.e. the average
# yearly improvement in summarization performance, and red below it
include_shading = True
# If True, rescale the y-axis of every plot with symlog scaling to better discern between the
# various perturbations
rescale_axis = True

In [None]:
# The two document selection strategies we will compare
strategies = ["random", "oracle"]

# We will bin the data into equal-width bins, which are defined here
bins = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
xticklabels = ["0-20", "20-40", "40-60", "60-80", "80-100"]

# Create the output directory once, up front, rather than on every iteration
plots_dir = Path("../output/plots")
plots_dir.mkdir(parents=True, exist_ok=True)

# Visit the subdirectories in a deterministic order, skipping any stray files in `data_dir`
for subdir in sorted(path for path in Path(data_dir).iterdir() if path.is_dir()):
    # Some datasets have blind test splits, and so we evaluate on the validation set
    # HuggingFace assigns a different prefix to the keys in the output json, so set that here
    metric_key_prefix = "eval" if subdir.name in {"ms2", "cochrane"} else "predict"

    # The metrics we want to plot the delta for
    metric_columns = [
        f"{metric_key_prefix}_rouge_avg_fmeasure",
        f"{metric_key_prefix}_bertscore_f1",
    ]

    # Load the results as dataframes
    baseline_df, perturbed_df = util.load_results_dicts(
        data_dir=subdir,
        metric_columns=metric_columns,
        metric_key_prefix=metric_key_prefix,
        # Only retain data that pertains to the perturbation experiments
        load_retrieval_results=False,
        load_training_results=False,
    )

    for model_name_or_path in perturbed_df.model_name_or_path.unique():
        # Keep this model's rows. Don't include sorting, as it is displayed in its own separate figure
        is_model = perturbed_df.model_name_or_path == model_name_or_path
        is_not_sorting = perturbed_df[f"{metric_key_prefix}_perturbation"] != "sorting"
        model_df = perturbed_df[is_model & is_not_sorting]

        # Bin the data into equal-width bins. `assign` returns a new frame, which avoids writing to a
        # slice of `perturbed_df` (and the SettingWithCopyWarning that comes with it).
        # NOTE(review): with the `pd.cut` defaults (`right=True`, `include_lowest=False`) a value of
        # exactly 0.0 falls outside every bin and becomes NaN -- confirm that this is intended.
        model_df = model_df.assign(
            frac_docs_perturbed=pd.cut(model_df.frac_docs_perturbed, bins=bins, labels=xticklabels)
        )

        # Set up the figure axes: one row per metric, one column per selection strategy
        fig, axes = plt.subplots(nrows=2, ncols=2, sharey="row")

        # Plot the data
        for i, metric in enumerate(metric_columns):
            for j, strategy in enumerate(strategies):
                g = sns.lineplot(
                    data=model_df[model_df[f"{metric_key_prefix}_selection_strategy"] == strategy],
                    x="frac_docs_perturbed",
                    y=f"{metric}_delta",
                    hue=f"{metric_key_prefix}_perturbation",
                    hue_order=["addition", "duplication", "deletion", "replacement", "backtranslation"],
                    style=f"{metric_key_prefix}_perturbation",
                    estimator="mean",
                    errorbar=("ci", 68),
                    n_boot=1000,
                    markers=True,
                    ax=axes[i][j],
                )
                # Labels and titles are added for the figure as a whole by `style_plot`
                g.set(xlabel="", ylabel="", title="")

        style_plot(
            axes,
            strategies=strategies,
            bins=bins,
            xticklabels=xticklabels,
            include_shading=include_shading,
            rescale_axis=rescale_axis,
        )

        # Save the figure, encoding the model name and the active flags in the filename
        filename = model_name_or_path.replace("/", "_").replace("-", "_")
        if include_shading:
            filename += "_shaded"
        if rescale_axis:
            filename += "_rescaled"

        fig.subplots_adjust(top=0.90, wspace=0.0175, hspace=0.075)
        fig.savefig(plots_dir / f"{filename}.svg", facecolor="white", bbox_inches="tight")

        # Display the figure, then release it so that open figures don't accumulate across iterations
        plt.show()
        plt.close(fig)