# Open-domain MDS Experiments

This notebook organizes the analysis of the open-domain MDS experiments.

__Note__: If you are running this notebook in Colab, uncomment and run the following cell to install the project and its dependencies.



In [None]:
# %pip install "git+https://github.com/allenai/open-mds.git"

Run the following cells to import the required packages and load some helper functions

In [None]:
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats as stats
import seaborn as sns
from datasets import load_dataset

from open_mds.common import util


# Significance level: p-values below this value reject the null hypothesis
THRESHOLD = 0.01

# Upper bound on the number of studies considered per example for MS2 and Cochrane.
# Following https://aclanthology.org/2021.emnlp-main.594/, take the first 25 articles.
MAX_INCLUDED_STUDIES = 25

# Apply one consistent look to every plot & figure in the paper
sns.set_theme(context="paper", style="ticks", palette="tab10", font_scale=1.8)

# Never truncate columns or rows when a DataFrame is displayed
pd.options.display.max_columns = None
pd.options.display.max_rows = None

## Evaluate Summarization in the Open-domain Setting

Here we load the results from the document retrieval experiments to produce a table comparing baseline summarization performance to performance when the input document set is retrieved.

Point the variable `data_dir` to the location of a directory that contains the results of running the [`run_summarization.py`](../scripts/run_summarization.py) script for one or more models

In [None]:
# Directory whose subdirectories (one per dataset) hold the summarization results
data_dir = "../output/results/"
# Make sure the directory exists and contains the expected subdirectories.
# Listed with pathlib instead of `!ls` so the cell is portable across platforms and
# raises immediately (rather than just printing an error) when the path is wrong.
sorted(entry.name for entry in Path(data_dir).iterdir())

In [None]:
# One list per output column; each experiment/metric combination appends one entry to every list
results = {
    "dataset": [],
    "model": [],
    "retriever": [],
    "top_k_strategy": [],
    "metric": [],
    "baseline": [],
    "retrieval": [],
    "difference": [],
    "significant": [],
}

# Sort so that the row order of the final table does not depend on filesystem ordering
for subdir in sorted(Path(data_dir).iterdir()):
    # Only directories hold results; skip stray files (e.g. OS metadata files)
    if not subdir.is_dir():
        continue

    # Some datasets have blind test splits, and so we evaluate on the validation set
    # HuggingFace assigns a different prefix to the keys in the output json, so set that here
    metric_key_prefix = "eval" if subdir.name in {"ms2", "cochrane"} else "predict"

    # The metrics we want to record results for
    metric_columns = [
        f"{metric_key_prefix}_rouge_avg_fmeasure",
        f"{metric_key_prefix}_bertscore_f1",
    ]

    # Load the results as dataframes
    baseline_df, retrieval_df = util.load_results_dicts(
        data_dir=subdir,
        metric_columns=metric_columns,
        metric_key_prefix=metric_key_prefix,
        # Only retain data that pertains to the retrieval experiments
        load_perturbation_results=False,
        load_training_results=False,
    )

    for model_name_or_path in retrieval_df.model_name_or_path.unique():
        # The baseline rows depend only on the model, so select them once per model
        model_baseline_df = baseline_df[baseline_df.model_name_or_path == model_name_or_path]

        for retriever in ["sparse"]:
            for top_k_strat in ["max", "mean", "oracle"]:
                # Isolate the results from one experiment. The three conditions are combined into a
                # single boolean mask: chaining `df[m1][m2][m3]` forces pandas to re-align each mask
                # against an already-filtered frame, which emits a warning.
                experiment_mask = (
                    (retrieval_df.model_name_or_path == model_name_or_path)
                    & (retrieval_df[f"{metric_key_prefix}_retriever"] == retriever)
                    & (retrieval_df[f"{metric_key_prefix}_top_k_strategy"] == top_k_strat)
                )
                experiment_df = retrieval_df[experiment_mask]

                for metric in metric_columns:
                    baseline_scores = model_baseline_df[metric]
                    retrieval_scores = experiment_df[metric]
                    retrieval_scores_delta = experiment_df[f"{metric}_delta"]

                    # Sanity check that we are comparing the same number of samples
                    assert len(baseline_scores) == len(retrieval_scores) == len(retrieval_scores_delta)

                    # Report any significant differences (paired t-test, baseline vs. retrieval)
                    _, pvalue = stats.ttest_rel(baseline_scores, retrieval_scores)

                    # Collect the results we are interested in
                    metric_key = metric.removeprefix(f"{metric_key_prefix}_")
                    results["dataset"].append(subdir.name)
                    results["model"].append(model_name_or_path)
                    results["retriever"].append(retriever)
                    results["top_k_strategy"].append(top_k_strat)
                    results["metric"].append(metric_key)
                    results["baseline"].append(round(baseline_scores.mean(), 2))
                    results["retrieval"].append(round(retrieval_scores.mean(), 2))
                    results["difference"].append(round(retrieval_scores_delta.mean(), 2))
                    results["significant"].append(pvalue < THRESHOLD)

results_df = pd.DataFrame(results)

In [None]:
# Preview the first few rows of the collected results
results_df.head()

You may wish to subset the results dataframe by dataset, retriever, and top-k strategy

In [None]:
# Combine the conditions into one boolean mask. Chaining `df[mask1][mask2]` makes pandas re-align
# the second mask against the already-filtered frame, which emits a warning.
results_df[(results_df.dataset == "wcep") & (results_df.retriever == "sparse")]

## Counting Retrieval Errors

Here we tally the errors made by each retriever for each dataset and top-k strategy and plot the results

In [None]:
def _get_error_stats(
    *, ground_truth_inputs: List[str], retrieved_inputs: List[str], doc_sep_token: Optional[str] = None
) -> Tuple[Dict[str, int], List[List[str]]]:
    """Tally the retrieval errors made across a dataset.

    Each example's retrieved documents are compared against its ground truth documents. A retrieved
    document missing from the ground truth is an addition, a ground truth document missing from the
    retrieved set is a deletion, and each (addition, deletion) pair within one example is counted as
    a single replacement instead.

    Args:
        ground_truth_inputs: One entry per example. Either a single string containing all documents
            joined by `doc_sep_token`, or (when `doc_sep_token` is not given) an already-split
            sequence of documents/document IDs.
        retrieved_inputs: The retrieved counterpart of `ground_truth_inputs`, aligned by position.
        doc_sep_token: If provided, each entry is split into individual documents on this token.

    Returns:
        A tuple of (1) a dictionary with the total "addition", "deletion" and "replacement" counts
        and (2) for each example, the list of its sanitized retrieved documents.

    Raises:
        ValueError: If the two inputs do not contain the same number of examples.
    """
    # zip() would silently drop the unmatched tail and misreport the counts, so fail early instead
    if len(ground_truth_inputs) != len(retrieved_inputs):
        raise ValueError(
            f"Expected the same number of ground truth and retrieved examples, but got"
            f" {len(ground_truth_inputs)} and {len(retrieved_inputs)}."
        )

    error_stats = {"addition": 0, "deletion": 0, "replacement": 0}
    retrieved_docs_ids = []

    for ground_truth_docs, retrieved_docs in zip(ground_truth_inputs, retrieved_inputs):
        # Get the individual documents
        if doc_sep_token:
            ground_truth_docs = util.split_docs(ground_truth_docs, doc_sep_token=doc_sep_token)
            retrieved_docs = util.split_docs(retrieved_docs, doc_sep_token=doc_sep_token)

        # Shouldn't be necessary, but strip whitespace and lowercase the strings for more robust equality checks
        ground_truth_docs = [util.sanitize_text(doc, lowercase=True) for doc in ground_truth_docs]
        retrieved_docs = [util.sanitize_text(doc, lowercase=True) for doc in retrieved_docs]

        retrieved_docs_ids.append(retrieved_docs)

        # Membership is tested against sets (constant-time lookups), while still iterating over the
        # lists so that duplicated documents are counted exactly as before.
        # Assumes the sanitized documents are hashable (i.e. strings).
        ground_truth_lookup = set(ground_truth_docs)
        retrieved_lookup = set(retrieved_docs)
        additions = sum(doc not in ground_truth_lookup for doc in retrieved_docs)
        deletions = sum(doc not in retrieved_lookup for doc in ground_truth_docs)

        # Count all cases of 1 addition + 1 deletion as a single replacement error.
        # min() is 0 whenever either count is 0, so no explicit guard is needed.
        replacements = min(additions, deletions)
        additions -= replacements
        deletions -= replacements

        error_stats["addition"] += additions
        error_stats["deletion"] += deletions
        error_stats["replacement"] += replacements

    return error_stats, retrieved_docs_ids

Collect the error statistics. Note that this can take several minutes and will cache several GBs worth of datasets to `~/.cache/huggingface/datasets`

In [None]:
results = {"dataset": [], "retriever": [], "strategy": [], "error_count": [], "error_type": []}
# Also collect IDs of the retrieved docs to compare sparse and dense retrievers
doc_ids = {}

# Give each dataset a nicely formatted name for plotting
NICE_DATASET_NAMES = {
    "multinews": "Multi-News",
    "wcep": "WCEP-10",
    "multixscience": "Multi-XScience",
    "ms2": "MS2",
    "cochrane": "Cochrane",
}

for dataset_name in ["multinews", "wcep", "multixscience", "ms2", "cochrane"]:
    doc_ids[dataset_name] = {"sparse": {}, "dense": {}}

    # In some cases we replaced the validation set with retrieved results
    split = "validation" if dataset_name in {"ms2", "cochrane"} else "test"

    for retriever in ["sparse", "dense"]:
        for strategy in ["max", "mean", "oracle"]:
            retrieved_dataset = load_dataset(f"allenai/{dataset_name}_{retriever}_{strategy}")[split]

            if dataset_name in {"multinews", "wcep"}:
                # These datasets store all input documents in one string, joined by a separator token
                doc_sep_token = (
                    util.DOC_SEP_TOKENS["multi_news"]
                    if dataset_name == "multinews"
                    else util.DOC_SEP_TOKENS["ccdv/WCEP-10"]
                )
                ground_truth_dataset = load_dataset(
                    "multi_news" if dataset_name == "multinews" else "ccdv/WCEP-10", split=split
                )
                error_stats, retrieved_docs_ids = _get_error_stats(
                    ground_truth_inputs=ground_truth_dataset["document"],
                    retrieved_inputs=retrieved_dataset["document"],
                    doc_sep_token=doc_sep_token,
                )
            elif dataset_name == "multixscience":
                ground_truth_dataset = load_dataset("multi_x_science_sum", split=split)
                error_stats, retrieved_docs_ids = _get_error_stats(
                    ground_truth_inputs=[example["abstract"] for example in ground_truth_dataset["ref_abstract"]],
                    retrieved_inputs=[example["abstract"] for example in retrieved_dataset["ref_abstract"]],
                )
            elif dataset_name in {"ms2", "cochrane"}:
                ground_truth_dataset = load_dataset("allenai/mslr2022", name=dataset_name, split=split)
                # Following https://aclanthology.org/2021.emnlp-main.594/, take the first 25 articles.
                # Uses the MAX_INCLUDED_STUDIES constant from the setup cell (the previous
                # `_MAX_INCLUDED_STUDIES` name was never defined and raised a NameError).
                error_stats, retrieved_docs_ids = _get_error_stats(
                    ground_truth_inputs=[example["pmid"][:MAX_INCLUDED_STUDIES] for example in ground_truth_dataset],
                    retrieved_inputs=[example["pmid"][:MAX_INCLUDED_STUDIES] for example in retrieved_dataset],
                )
            else:
                raise ValueError(f"Unrecognized dataset_name: {dataset_name}.")

            doc_ids[dataset_name][retriever][strategy] = retrieved_docs_ids

            nice_dataset_name = NICE_DATASET_NAMES[dataset_name]

            # Collect the error stats for each dataset in a way amenable to plotting:
            # one row per error type, all sharing the same dataset/retriever/strategy
            results["dataset"].extend([nice_dataset_name] * len(error_stats))
            results["retriever"].extend([retriever] * len(error_stats))
            results["strategy"].extend([strategy] * len(error_stats))
            results["error_count"].extend(
                [error_stats["addition"], error_stats["deletion"], error_stats["replacement"]]
            )
            results["error_type"].extend(["addition", "deletion", "replacement"])

results_df = pd.DataFrame(results)

In [None]:
# Display error counts in the 100s. `assign` returns a new frame, so `results_df` keeps the raw counts.
df = results_df.assign(error_count=results_df.error_count / 100)

# Setup the grid: one row per retriever, one column per dataset
g = sns.FacetGrid(
    df,
    row="retriever",
    col="dataset",
    sharex=True,
    sharey=False,
    row_order=["sparse", "dense"],
    col_order=["Multi-News", "WCEP-10", "Multi-XScience", "MS2", "Cochrane"],
    margin_titles=True,
)

# Plot the barplots
g.map_dataframe(
    sns.barplot,
    x="strategy",
    y="error_count",
    hue="error_type",
    order=["max", "mean", "oracle"],
    palette="tab10",
)

# Setup a legend
g.add_legend(loc="lower center", bbox_to_anchor=(0.4, 1.0), frameon=False, ncol=5, columnspacing=0.8)
# Setup global axis titles
g.set_axis_labels("", "")
g.fig.supylabel("Absolute Error Count (100s)", x=0.025, horizontalalignment="left", verticalalignment="center")
# Change the default subplot title format, see: https://wckdouglas.github.io/2016/12/seaborn_annoying_title
g.set_titles(row_template=r"{row_name}", col_template=r"{col_name}")
# Rotate the x-axis labels
for ax in g.axes.flat:
    plt.setp(ax.get_xticklabels(), rotation=45)

# Save the figure
plots_dir = Path("../output/plots")
plots_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(plots_dir / "retrieval_errors.svg", facecolor="white", bbox_inches="tight")