# Data curation notebook: Result visualization and re-training config generation 

### Notebook Overview
At a high level, this notebook contains code for (1) loading and visualizing various types of data curation results (e.g., curated dataset quality) and (2) generating configuration files from these results, so that new policies can be trained on datasets curated by different methods. This code is only relevant *after* (base) policies have been trained and data curation methods have been applied to score individual demonstrations within datasets. That is, this notebook shows you how to load those demonstration scores (along with other data) from disk, perform some basic operations on them (e.g., combine scores from different methods), and generate simple plots visualizing curation metrics of interest. 

The notebook is organized into four sections:
1. **Sec. 1:** Contains all utility functions for loading, computing, and plotting results, along with other handy utilities. 
2. **Sec. 2:** Provides code samples for visualizing curated dataset quality for two curation tasks defined in the paper (filtering training demos and selecting holdout demos).
3. **Sec. 3:** Provides code samples for generating curation configuration files for re-training policies with curated datasets (for both the filtering and selection tasks). 
4. **Sec. 4:** Provides code samples for visualizing the performance of policies trained with curated datasets (for both the filtering and selection tasks).

We encourage users to modify the code as necessary to suit their needs!

In [None]:
# TODO: If cell below throws an error, try running this cell multiple times first.

from typing import Optional, List, Tuple, Dict, Any, Union, Callable

%matplotlib inline
from matplotlib import pyplot as plt
# Global matplotlib styling. `serif` toggles the font family used in all figures.
serif = True
plt.rcParams["font.family"] = "serif" if serif else "Liberation Sans"
plt.rcParams["font.size"] = 10
# Use font type 42 for both PDF and PostScript output.
plt.rcParams["pdf.fonttype"] = plt.rcParams["ps.fonttype"] = 42
import seaborn as sns
import matplotlib.image as mpimg
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Patch
import matplotlib.lines as mlines
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.colors as mcolors

import os
import h5py
import yaml
import hydra
import pickle
import pathlib
import omegaconf
import numpy as np
from copy import deepcopy
from functools import partial
from collections import defaultdict

current_dir = pathlib.Path.cwd()
if "notebooks" in str(current_dir):
    os.chdir(current_dir.parent)

import torch
from torch import nn

from diffusion_policy.dataset.episode_dataset import BatchEpisodeDataset
from diffusion_policy.dataset.pusht_dataset import PushTLowdimDataset
from diffusion_policy.dataset.pusht_image_dataset import PushTImageDataset
from diffusion_policy.dataset.robomimic_replay_lowdim_dataset import RobomimicReplayLowdimDataset
from diffusion_policy.dataset.robomimic_replay_image_dataset import RobomimicReplayImageDataset
from diffusion_policy.dataset.real_franka_image_dataset import RealFrankaImageDataset
from diffusion_policy.policy.diffusion_unet_lowdim_policy import DiffusionUnetLowdimPolicy

# Utility functions.
from diffusion_policy.common.trak_util import (
    DemoDatasetType,
    get_dataset_metadata,
    get_best_checkpoint,
    get_index_checkpoint,
    get_policy_from_checkpoint,
)
from diffusion_policy.common.results_util import (
    DEMO_RESULT_KEYS,
    ROLLOUT_RESULT_KEYS,
    CURATION_KEY_NAME_FN,
    get_offline_state_diversity_exp_key,
    get_online_state_similarity_exp_key,
    get_online_trak_influence_exp_key,
    get_last_n_log_keys,
)

if "notebooks" in str(current_dir):
    os.chdir(current_dir.parent)

In [None]:
# TODO: If this cell throws an error, try running the cell above multiple times first.

# Compute device and experiment directory layout (every path is rooted at `current_dir`).
device = torch.device("cpu")
config_dir = current_dir / "configs"
output_dir = current_dir / "data" / "outputs"
result_dir = current_dir / "notebooks" / "outputs" / "official"

# Training / evaluation outputs for simulated and real-world experiments.
train_dir, eval_dir = output_dir / "train", output_dir / "eval_save_episodes"
real_train_dir, real_eval_dir = output_dir / "train_real", output_dir / "eval_save_episodes_real"

# Curation configuration directory: on first use, create it together with one
# sub-directory per state type. Parent directories are intentionally not created here.
SAVE_CURATION_CONFIGS = True
curation_config_dir = config_dir / "curation"
if not curation_config_dir.exists():
    curation_config_dir.mkdir()
    for _state_subdir in ("low_dim", "image"):
        (curation_config_dir / _state_subdir).mkdir()

# Debug mode: `dbprint` forwards to `print` when enabled, otherwise it is a one-argument no-op.
DEBUG = False
if DEBUG:
    dbprint = print
else:
    dbprint = lambda x: ()

## Sec 1: Run utilities for use in subsequent cells

### Sec 1.1: Loading utilities
**Description:** Code related to loading demonstration scores from disk, along with other necessary data (e.g., datasets, checkpoints, ground-truth demonstration labels) for visualization and data curation.

In [None]:
def load_pickle(path: Union[str, pathlib.Path]) -> Optional[Dict[str, Any]]:
    """Return the object unpickled from `path`, or None if the file does not exist.

    NOTE: unpickling can execute arbitrary code; only load files you trust.
    """
    try:
        with open(path, "rb") as handle:
            return pickle.load(handle)
    except FileNotFoundError:
        return None


def get_online_trak_influence_quality_exp_key(
    metric: str = "net",
    num_rollouts: str = "all",
    method_prefix: bool = True,
    aggr_fn: str = "sum_of_sum"
) -> List[str]:
    """Get influence quality (CUPID-Quality) experiment keys.

    Builds one influence experiment key per aggregation function, all sharing the same
    `metric`, `num_rollouts`, and `method_prefix` settings.

    Returns:
        Three keys, in order: the key for `aggr_fn`, for "min_of_max", and for "max_of_min".
    """
    shared_kwargs = dict(metric=metric, num_rollouts=num_rollouts, method_prefix=method_prefix)
    return [
        get_online_trak_influence_exp_key(aggr_fn=aggr_name, **shared_kwargs)
        for aggr_name in (aggr_fn, "min_of_max", "max_of_min")
    ]


def get_load_result_kwargs(
    task: str, 
    policy: str,
    seed: int,
    train_date: str = "25.03.03",
    eval_date: str = "25.03.03",
    result_date: str = "25.03.03",
    train_ckpt: str = "latest"
) -> Dict[str, Any]:
    """Assemble the keyword-argument dictionaries consumed by the result loading functions.

    Returns:
        A dictionary with three entries:
        - "train_exp_kwargs": identifies the training experiment (`exp_date` is set to `train_date`).
        - "eval_exp_kwargs": identifies the evaluation experiment.
        - "result_exp_kwargs": identifies the score results (`result_seed` is always 0).
    """
    train_exp_kwargs = dict(
        task=task,
        policy=policy,
        train_date=train_date,
        exp_date=train_date,
        train_ckpt=train_ckpt,
        train_seed=seed,
    )
    eval_exp_kwargs = dict(
        task=task,
        policy=policy,
        eval_date=eval_date,
        train_date=train_date,
        train_seed=seed,
        train_ckpt=train_ckpt,
    )
    result_exp_kwargs = dict(result_date=result_date, result_seed=0)
    return dict(
        train_exp_kwargs=train_exp_kwargs,
        eval_exp_kwargs=eval_exp_kwargs,
        result_exp_kwargs=result_exp_kwargs,
    )


def get_eval_exp_path(
    task: str = "pusht",
    policy: str = "diffusion_unet_lowdim",
    eval_date: str = "25.03.03",
    train_date: str = "25.03.03",
    train_seed: int = 0,
    train_ckpt: str = "latest",
    real_exp: bool = False,
    **kwargs,
) -> pathlib.Path:
    """Build the evaluation experiment directory path.

    The path has the form `<root>/<eval_date>/<train_date>_train_<policy>_<task>_<train_seed>/<train_ckpt>`,
    where `<root>` is `real_eval_dir` if `real_exp` is set and `eval_dir` otherwise.
    Extra keyword arguments are accepted and ignored.
    """
    if real_exp:
        base_dir = real_eval_dir
    else:
        base_dir = eval_dir
    return base_dir.joinpath(
        eval_date,
        f"{train_date}_train_{policy}_{task}_{train_seed}",
        train_ckpt,
    )


def get_train_exp_path(
    task: str = "pusht",
    policy: str = "diffusion_unet_lowdim",
    train_date: str = "25.03.03",
    exp_date: str = "25.03.03",
    train_seed: int = 0,
    real_exp: bool = False,
    curate_dataset: bool = False,
    curation_method: Optional[str] = None,
    filter_ratio: Optional[float] = None,
    select_ratio: Optional[float] = None,
    **kwargs,
) -> pathlib.Path:
    """Build the training experiment directory path.

    The path has the form `<root>/<train_date>/<exp_date>_train_<policy>_<task>_<train_seed>`,
    where `<root>` is `real_train_dir` if `real_exp` is set and `train_dir` otherwise.
    When `curate_dataset` is set, a `-curation_<method>-filter_<ratio>-select_<ratio>` suffix
    (ratios formatted with two decimals) is appended to the experiment name; in that case
    `curation_method` must be given and both ratios must lie in [0, 1] (checked via assert).
    Extra keyword arguments are accepted and ignored.
    """
    base_dir = real_train_dir if real_exp else train_dir
    exp_name = f"{exp_date}_train_{policy}_{task}_{train_seed}"
    if curate_dataset:
        assert curation_method is not None and all(
            ratio is not None and 0.0 <= ratio <= 1.0
            for ratio in (filter_ratio, select_ratio)
        ), "Curation arguments must be set together"
        exp_name += f"-curation_{curation_method}-filter_{filter_ratio:.2f}-select_{select_ratio:.2f}"
    return base_dir / train_date / exp_name


def get_train_checkpoint_paths(
    task: str = "pusht",
    policy: str = "diffusion_unet_lowdim",
    train_date: str = "25.03.03",
    exp_date: str = "25.03.03",
    real_exp: bool = False,
    filter_latest: bool = False,
    train_seed: Optional[int] = None,
) -> List[pathlib.Path]:
    """Collect the checkpoint paths of one or several training experiments.

    Args:
        train_seed: If given, only the experiment of that seed is inspected. If None, every
            entry under `<root>/<train_date>` whose path contains
            `<exp_date>_train_<policy>_<task>` is inspected (in sorted order).
        filter_latest: If True, checkpoints whose file stem contains "latest" are removed.

    Returns:
        Checkpoint paths, sorted within each experiment's "checkpoints" directory.
    """
    if train_seed is not None:
        # A single, fully specified training run.
        exp_dirs = [
            get_train_exp_path(
                task=task,
                policy=policy,
                train_date=train_date,
                exp_date=exp_date,
                train_seed=train_seed,
                real_exp=real_exp,
            )
        ]
    else:
        # All training runs matching the experiment name prefix, regardless of seed.
        base_dir = real_train_dir if real_exp else train_dir
        name_prefix = f"{exp_date}_train_{policy}_{task}"
        exp_dirs = sorted(
            entry for entry in (base_dir / train_date).iterdir() if name_prefix in str(entry)
        )

    ckpt_paths: List[pathlib.Path] = [
        ckpt
        for exp_dir in exp_dirs
        for ckpt in sorted((exp_dir / "checkpoints").iterdir())
    ]
    if not filter_latest:
        return ckpt_paths
    return [ckpt for ckpt in ckpt_paths if "latest" not in ckpt.stem]


def get_policy_and_config(
    train_ckpt: Union[str, int], 
    **kwargs
) -> Union[nn.Module, Tuple[nn.Module, omegaconf.DictConfig]]:
    """Load policy and config from checkpoint.

    Args:
        train_ckpt: Checkpoint selector. An int is forwarded to `get_index_checkpoint`;
            the string "best" is resolved with `get_best_checkpoint`; any other string
            selects a checkpoint whose file stem contains it (e.g., "latest").
        **kwargs: Forwarded to `get_train_checkpoint_paths` to list candidate checkpoints.

    Returns:
        The result of `get_policy_from_checkpoint(..., return_cfg=True)` for the selected
        checkpoint, loaded on the module-level `device`.

    Raises:
        ValueError: If no checkpoint could be selected for `train_ckpt`.
    """
    checkpoints = get_train_checkpoint_paths(**kwargs)

    checkpoint = None
    if isinstance(train_ckpt, int):
        checkpoint = get_index_checkpoint(checkpoints, int(train_ckpt))
    elif isinstance(train_ckpt, str):
        if train_ckpt == "best":
            checkpoint = get_best_checkpoint(checkpoints)
        else:
            # Pick the last checkpoint whose stem contains `train_ckpt`. A plain loop that
            # reassigns on every iteration would only ever honor the final path in the list
            # and discard earlier matches, so search explicitly (from the end, to keep the
            # same result whenever the final path does match).
            checkpoint = next(
                (path for path in reversed(checkpoints) if train_ckpt in path.stem),
                None,
            )

    if checkpoint is None:
        raise ValueError(f"Checkpoint type {train_ckpt} is not supported.")
    
    print(f"Loading checkpoint {checkpoint}")
    return get_policy_from_checkpoint(checkpoint, return_cfg=True, device=device)


def get_demo_quality_labels(cfg: omegaconf.DictConfig, dataset: DemoDatasetType) -> Optional[np.ndarray]:
    """Return ground-truth quality labels for the "training" demonstrations of `dataset`.

    Labels are only available when `cfg.task.dataset_type` is "mh" and `dataset` is a Robomimic
    replay dataset (low-dim or image); otherwise None is returned. The returned array has one
    entry per demonstration selected by `dataset.train_mask`.

    Raises:
        ValueError: If `cfg.task_name` matches none of the supported tasks.
    """
    if cfg.task.dataset_type != "mh":
        return None
    if not isinstance(dataset, (RobomimicReplayLowdimDataset, RobomimicReplayImageDataset)):
        return None

    def decode_demo_idxs(names) -> np.ndarray:
        # Each entry is a byte string whose last "_"-separated token is the demo index.
        return np.array([int(name.decode().split("_")[-1]) for name in names])

    # Map each quality subset stored under the file's "mask" group to its quality value.
    with h5py.File(dataset._dataset_path) as file:
        if any(x in cfg.task_name for x in ["lift", "can", "square"]):
            # Note: Lift, Can, Square have 3 quality tiers.
            quality_by_subset = {"worse": 1.0, "okay": 2.0, "better": 3.0}
        elif "transport" in cfg.task_name:
            # ---------------------------------------------------------------------------------------------------------------------------
            # Note: Transport has 6 demo subsets, with two demonstrators per subset. Therefore, there is no single 'correct' way to
            # define demo quality. That said, the chosen quality labels only affect visualization, not the actual performance of the
            # policy after curated re-training. Thus, we use a custom definition of demo quality, which is a weighted average of the
            # two demonstrators. Other definitions are equally valid, and we encourage the user to experiment with them to get a
            # better understanding of the behavior of different curation methods, and how they rank-order demonstrations differently.
            #
            # Alternative value lists (same subset order as below):
            #   [1.0, 1.5, 2.0, 2.5, 3.0, 1.5]  # Mean of demonstrator quality.
            #   [1.0, 2.0, 2.0, 3.0, 3.0, 3.0]  # Max of demonstrator quality.
            #   [1.0, 1.0, 2.0, 2.0, 3.0, 1.0]  # Min of demonstrator quality.
            # ---------------------------------------------------------------------------------------------------------------------------
            quality_by_subset = {
                "worse": 1.0,
                "worse_okay": (1.0 * 1/3) + (2.0 * 2/3),
                "okay": 2.0,
                "okay_better": (2.0 * 1/3) + (3.0 * 2/3),
                "better": 3.0,
                "worse_better": (1.0 * 1/3) + (3.0 * 2/3),
            }
        else:
            raise ValueError(f"Task {cfg.task_name} is not supported.")
        subset_demo_idxs = [decode_demo_idxs(file["mask"][subset][:]) for subset in quality_by_subset]

    # Label every demo in the file, then keep only the training demonstrations.
    all_labels = np.zeros(len(dataset.train_mask), dtype=float)
    for demo_idxs, quality_val in zip(subset_demo_idxs, quality_by_subset.values()):
        all_labels[demo_idxs] = quality_val
    quality_labels = all_labels[dataset.train_mask]

    assert np.all(quality_labels > 0), "All demos should have a quality label greater than zero."
    return quality_labels


def get_demonstration_scores_result_data(
    eval_exp_path: pathlib.Path,
    result_date: str = "25.03.03",
    result_seed: int = 0,
) -> Dict[str, Dict[str, np.ndarray]]:
    """Load per-demonstration scores for the "train" and "holdout" splits.

    For every key in `DEMO_RESULT_KEYS`, the file `<key>.pkl` is read from
    `<eval_exp_path>/<result_date>_demonstration_scores-seed=<result_seed>`; missing files are
    skipped. A split entry that is an array is stored under `<key>`; a split entry that is a
    dictionary contributes each of its array values under `<key>-<subkey>`.

    Returns:
        A `defaultdict(dict)` mapping split name to a dictionary of score arrays.
    """
    scores_dir = eval_exp_path / f"{result_date}_demonstration_scores-seed={result_seed}"

    result_data = defaultdict(dict)
    for key in DEMO_RESULT_KEYS:
        data = load_pickle(scores_dir / f"{key}.pkl")
        if data is None:
            continue

        for split in ("train", "holdout"):
            if split not in data:
                continue

            split_result = data[split]
            if isinstance(split_result, np.ndarray):
                result_data[split][key] = split_result
            elif isinstance(split_result, dict):
                array_entries = {
                    f"{key}-{subkey}": value
                    for subkey, value in split_result.items()
                    if isinstance(value, np.ndarray)
                }
                # Only touch the split entry when there is something to store.
                if array_entries:
                    result_data[split].update(array_entries)

        dbprint(f"Loaded {key} results.")

    return result_data


def get_deminf_result_data(
    task: str,
    train_date: str,
    train_seed: int,
    policy: str,
    **kwargs: Any,
) -> Optional[Dict[str, np.ndarray]]:
    """Return DemInf scores computed over training and holdout demonstrations. 
    Note: DemInf scores are computed using their official implementation at https://github.com/jhejna/demonstration-information.

    The scores are read from `<output_dir>/deminf/<train_date>/<task>_<state>_seed<train_seed>.npz`,
    where `<state>` is "lowdim" if `policy` contains "lowdim" and "image" otherwise.
    Extra keyword arguments are accepted and ignored.

    Returns:
        A dictionary with every array stored in the archive, or None if the file does not exist.
    """
    state = "lowdim" if "lowdim" in policy else "image"
    result_file = output_dir / "deminf" / train_date / f"{task}_{state}_seed{train_seed}.npz"
    try:
        # Read all arrays into memory inside the context manager so that the underlying
        # archive file handle is closed instead of being left open.
        with np.load(str(result_file)) as npz_file:
            result_data = dict(npz_file)
    except FileNotFoundError:
        return None
    
    dbprint("Loaded offline_deminf results.")
    return result_data


def get_rollout_scores_result_data(
    eval_exp_path: pathlib.Path,
    result_date: str = "25.03.03",
    result_seed: int = 0,
) -> Dict[str, np.ndarray]:
    """Load scores computed over test rollouts.

    For every key in `ROLLOUT_RESULT_KEYS`, the file `<key>.pkl` is read from
    `<eval_exp_path>/<result_date>_rollout_scores-seed=<result_seed>`; missing files are skipped.
    The "test" entry of each loaded file must be an array and is stored under `<key>`.
    """
    scores_dir = eval_exp_path / f"{result_date}_rollout_scores-seed={result_seed}"

    result_data = {}
    for key in ROLLOUT_RESULT_KEYS:
        data = load_pickle(scores_dir / f"{key}.pkl")
        if data is not None:
            test_scores = data["test"]
            assert isinstance(test_scores, np.ndarray)
            result_data[key] = test_scores
            dbprint(f"Loaded {key} results.")

    return result_data


def load_result_data(
    train_exp_kwargs: Dict[str, Any],
    eval_exp_kwargs: Dict[str, Any],
    result_exp_kwargs: Dict[str, Any],
    return_policy: bool = False,
    return_datasets: bool = False,
    real_exp: bool = False,
) -> Dict[str, Any]:
    """Load all necessary data for experiment analysis.

    Args:
        train_exp_kwargs: Keyword arguments forwarded to `get_policy_and_config`.
        eval_exp_kwargs: Keyword arguments forwarded to `get_eval_exp_path` and `get_deminf_result_data`.
        result_exp_kwargs: Keyword arguments forwarded to the demonstration/rollout score loaders.
        return_policy: If True, the loaded policy is included in the returned metadata (else None).
        return_datasets: If True, the dataset objects are included in the returned data (else None).
        real_exp: Forwarded to the path/checkpoint helpers to select the real-experiment directories.

    Returns:
        A dictionary with "train_data", "holdout_data", and "test_data" entries (each holding
        "dataset", "metadata", and "scores"), plus a "metadata" entry holding the config, the
        (optional) policy, and the evaluation experiment path.
    """
    # Evaluation directory.
    eval_exp_path = get_eval_exp_path(**eval_exp_kwargs, real_exp=real_exp)

    # Load policy.
    policy, cfg = get_policy_and_config(**train_exp_kwargs, real_exp=real_exp)

    # Load training set and metadata.
    train_set: DemoDatasetType = hydra.utils.instantiate(cfg.task.dataset)
    train_set_metadata = get_dataset_metadata(cfg, train_set)
    train_set_metadata["demo_mask"] = train_set.train_mask
    train_set_metadata["quality_labels"] = get_demo_quality_labels(cfg, train_set)
    # Indices of the demonstrations selected by the training mask (used to validate DemInf results below).
    train_idxs = np.where(train_set.train_mask)[0]

    # Load holdout set and metadata.
    holdout_set = train_set.get_holdout_dataset()
    holdout_set_metadata = get_dataset_metadata(cfg, holdout_set)
    holdout_set_metadata["demo_mask"] = holdout_set.train_mask
    holdout_set_metadata["quality_labels"] = get_demo_quality_labels(cfg, holdout_set)

    # Load test set and metadata.
    # The test set is built from the saved evaluation episodes; it is optional and set to None if not found.
    try: 
        test_set = BatchEpisodeDataset(
            batch_size=1,
            dataset_path=eval_exp_path / "episodes",
            exec_horizon=1,
            sample_history=0,
        )
        test_set_metadata = get_dataset_metadata(cfg, test_set)
    except FileNotFoundError:
        test_set = None
        test_set_metadata = None
    
    # Load demonstration score results.
    result_data = get_demonstration_scores_result_data(eval_exp_path, **result_exp_kwargs)
    
    # Load DemInf demonstration score results.
    # These are optional; when present, their demo indices must match the training indices exactly.
    deminf_result_data = get_deminf_result_data(**eval_exp_kwargs)
    if deminf_result_data is not None:
        assert np.all(deminf_result_data["idxs"] == train_idxs), "DemInf experiments do not align."
        result_data["train"]["offline_deminf"] = deminf_result_data["scores"]

    # Load rollout results data.
    result_data["test"] = get_rollout_scores_result_data(eval_exp_path, **result_exp_kwargs)

    # Store return data.
    # Note: `result_data` is a defaultdict(dict), so a split without loaded scores yields an empty dictionary.
    return_data = {
        "train_data": {
            "dataset": train_set if return_datasets else None,
            "metadata": train_set_metadata,
            "scores": result_data["train"],
        },
        "holdout_data": {
            "dataset": holdout_set if return_datasets else None,
            "metadata": holdout_set_metadata,
            "scores": result_data["holdout"],
        },
        "test_data": {
            "dataset": test_set if return_datasets else None,
            "metadata": test_set_metadata,
            "scores": result_data["test"]
        },
        "metadata": {
            "cfg": cfg,
            "policy": policy if return_policy else None,
            "eval_exp_path": eval_exp_path,
        },
    }

    return return_data

### Sec 1.2: Score results utilities
**Description:** Code related to performing operations on loaded demonstration scores (e.g., combining scores across methods) and computing metrics of interest, such as curated dataset quality. 

In [None]:
def sum_of_normalized_scores(
    scores: np.ndarray,
    weights: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Return the weighted sum of per-row min-max normalized scores.

    Args:
        scores: 2D array with one row per scoring method and one column per demonstration.
        weights: Optional per-row weights (defaults to uniform). Weights of the rows that are
            kept are re-normalized to sum to one.

    Returns:
        1D array with one combined score per column. Rows that sum to zero (e.g., all zeros)
        or that are constant are ignored, since they carry no ranking information and a
        constant row cannot be min-max normalized (0/0). If no row remains, an array of zeros
        is returned.
    """
    scores = np.asarray(scores, dtype=float)

    # Weights for weighted average.
    if weights is None:
        weights = np.ones(len(scores)) / len(scores)
    weights = np.asarray(weights, dtype=float)
    assert len(weights) == len(scores)

    # Remove invalid scores: rows summing to zero, and constant rows whose range is zero
    # (normalizing those would divide by zero and turn the whole result into NaN).
    row_min = scores.min(axis=1)
    row_range = scores.max(axis=1) - row_min
    mask = (scores.sum(axis=1) != 0) & (row_range != 0)
    if not mask.any():
        return np.zeros(scores.shape[1])

    kept_weights = weights[mask] / weights[mask].sum()

    # Normalize each remaining row to [0, 1], then return the weighted average.
    normalized = (scores[mask] - row_min[mask][:, None]) / row_range[mask][:, None]
    return (normalized * kept_weights[:, None]).sum(axis=0)


def compile_demo_quality_scores(
    result_data: Dict[str, Any],
    exp_keys: List[Union[str, List[str]]],
    exp_labels: List[str],
    exp_signs: Optional[List[Union[int, List[int]]]] = None,
    exp_weights: Optional[List[Union[float, List[float]]]] = None,
    split: str = "train",
) -> Tuple[np.ndarray, List[Union[str, List[str]]], List[str]]:
    """Gather per-demonstration quality scores for a list of scoring methods.

    Args:
        result_data: Loaded results; scores are read from `result_data[f"{split}_data"]["scores"]`.
        exp_keys: One entry per method. A string selects a single stored score array; a list of
            strings combines several score arrays via `sum_of_normalized_scores`.
        exp_labels: Display label for each entry of `exp_keys`.
        exp_signs: Multiplier(s) applied to the scores of each method (defaults to 1).
        exp_weights: Weight(s) used when combining several scores (defaults to 1.0; a scalar
            weight is split evenly across the combined keys).
        split: Data split to read scores from.

    Returns:
        A tuple `(scores, keys, labels)` where `scores` is a 2D array with one row per retained
        method, and `keys` / `labels` are the matching entries of `exp_keys` / `exp_labels`.
        Single-key methods without stored scores are skipped silently; methods whose scores are
        all identical are skipped with a printed notice.

    Raises:
        ValueError: If a key of a combined method has no stored scores.
    """
    num_methods = len(exp_keys)
    # Assume correct sign of scores (and unit weights) if none provided.
    signs = exp_signs if exp_signs is not None else [1 for _ in range(num_methods)]
    weights = exp_weights if exp_weights is not None else [1.0 for _ in range(num_methods)]

    retained = np.zeros(num_methods, dtype=bool)
    compiled_scores = []
    for idx, (key, sign, weight) in enumerate(zip(exp_keys, signs, weights)):
        if not isinstance(key, list):
            # Single score method.
            try:
                scores = sign * result_data[f"{split}_data"]["scores"][key]
            except KeyError:
                continue
        else:
            # Combining multiple score methods: broadcast scalar sign/weight over the sub-keys.
            sub_signs = sign if isinstance(sign, list) else [sign] * len(key)
            sub_weights = weight if isinstance(weight, list) else [float(weight) / len(key)] * len(key)
            signed_scores = []
            for subkey, subsign in zip(key, sub_signs):
                try:
                    signed_scores.append(subsign * result_data[f"{split}_data"]["scores"][subkey])
                except KeyError:
                    raise ValueError(f"Cannot handle missing key for combining scores.")
            scores = sum_of_normalized_scores(np.vstack(signed_scores), weights=np.array(sub_weights))

        assert isinstance(scores, np.ndarray)

        # Scores that are all identical cannot rank demonstrations; drop the method.
        if np.all(scores == scores[0]):
            print(f"Identical scores found for {key}.")
            continue

        retained[idx] = True
        compiled_scores.append(scores)

    kept_keys = [entry for i, entry in enumerate(exp_keys) if retained[i]]
    kept_labels = [entry for i, entry in enumerate(exp_labels) if retained[i]]

    # Always return a 2D array, even when nothing (or a flat array) was compiled.
    score_matrix = np.array(compiled_scores)
    if score_matrix.ndim == 1:
        score_matrix = score_matrix[None, :]

    return score_matrix, kept_keys, kept_labels


def compute_filtered_mean_quality_scores(
    result_data: Dict[str, Any],
    exp_keys: List[Union[str, List[str]]],
    exp_labels: List[str],
    exp_signs: Optional[List[Union[int, List[int]]]] = None,
    exp_weights: Optional[List[Union[float, List[float]]]] = None,
    num_keep: int = 16,
    normalize: bool = True,
    split: str = "train",
) -> Dict[str, np.ndarray]:
    """Compute average dataset quality as a function of the number of training demonstrations filtered.

    Demonstrations are ranked by each method's scores (highest first). Entry `j` of a returned
    curve is the mean ground-truth quality label of the demonstrations that remain after
    removing the `j` lowest-ranked ones. The last `num_keep` entries (the smallest remaining
    sets) are dropped from every curve.

    Args:
        result_data: Loaded results; requires ground-truth quality labels in
            `result_data[f"{split}_data"]["metadata"]["quality_labels"]`.
        exp_keys, exp_labels, exp_signs, exp_weights, split: See `compile_demo_quality_scores`.
        num_keep: Number of trailing curve entries to drop (0 keeps the full curve).
        normalize: If True, curves are expressed as percent change relative to the mean
            quality label of the unfiltered dataset.

    Returns:
        A dictionary mapping each retained method label to its curve, plus "Oracle" (ranking by
        the ground-truth labels) and "Random" (constant mean quality) baselines.
    """
    # Compute demo quality scores.
    demo_quality_scores, exp_keys, exp_labels = compile_demo_quality_scores(
        result_data=result_data,
        exp_keys=exp_keys,
        exp_labels=exp_labels,
        exp_signs=exp_signs,
        exp_weights=exp_weights,
        split=split,
    )

    # Sort demonstrations from highest to lowest predicted quality.
    quality_labels = result_data[f"{split}_data"]["metadata"]["quality_labels"]
    sorted_quality_labels = quality_labels[quality_labels.argsort()][::-1]
    sorted_demo_quality_scores_idx = demo_quality_scores.argsort(axis=-1)[:, ::-1]
    sorted_quality_preds = [quality_labels[x] for x in sorted_demo_quality_scores_idx]

    # Compute average data quality of the filtered set.
    def mean_quality_fn(x: np.ndarray) -> np.ndarray:
        # Running mean over the top-k demos (k = 1..N), reversed so that entry j corresponds
        # to filtering j demos. The end index is computed explicitly because `[:-num_keep]`
        # would collapse to an empty slice when num_keep == 0.
        running_mean = np.cumsum(x) / np.arange(1, len(x) + 1)
        return running_mean[::-1][: max(len(x) - num_keep, 0)]

    mean_quality_labels = mean_quality_fn(sorted_quality_labels)
    mean_quality_preds = [mean_quality_fn(x) for x in sorted_quality_preds]
    random_baseline_pred = np.ones_like(mean_quality_labels) * quality_labels.mean()

    # Return result, along with random and oracle.
    result = {k: v for k, v in zip(exp_labels, mean_quality_preds)}
    result["Oracle"] = mean_quality_labels
    result["Random"] = random_baseline_pred

    # Optionally normalize scores.
    if normalize:
        normalize_fn = lambda x: ((x / quality_labels.mean()) - 1.0) * 100
        result = {k: normalize_fn(v) for k, v in result.items()}

    return result


def compute_selected_mean_quality_scores(
    result_data: Dict[str, Any],
    exp_keys: List[Union[str, List[str]]],
    exp_labels: List[str],
    exp_signs: Optional[List[Union[int, List[int]]]] = None,
    exp_weights: Optional[List[Union[float, List[float]]]] = None,
    num_keep: int = 20,
    normalize: bool = False,
    split: str = "train",
) -> Dict[str, np.ndarray]:
    """Compute average dataset quality as a function of the number of holdout demonstrations selected.

    Demonstrations are ranked by each method's scores (highest first). Each curve holds the
    running mean of the ground-truth quality labels over the top-ranked demonstrations, with the
    first `num_keep` entries dropped.

    Args:
        result_data: Loaded results; requires ground-truth quality labels in
            `result_data[f"{split}_data"]["metadata"]["quality_labels"]`.
        exp_keys, exp_labels, exp_signs, exp_weights, split: See `compile_demo_quality_scores`.
        num_keep: Number of leading curve entries to drop.
        normalize: If True, curves are rescaled to percent of the oracle curve's [min, max] range.

    Returns:
        A dictionary mapping each retained method label to its curve, plus "Oracle" (ranking by
        the ground-truth labels) and "Random" (constant mean quality) baselines.
    """
    scores_per_method, exp_keys, exp_labels = compile_demo_quality_scores(
        result_data=result_data,
        exp_keys=exp_keys,
        exp_labels=exp_labels,
        exp_signs=exp_signs,
        exp_weights=exp_weights,
        split=split,
    )
    quality_labels = result_data[f"{split}_data"]["metadata"]["quality_labels"]

    def selected_mean_quality(ranked_labels: np.ndarray) -> np.ndarray:
        # Running mean over the top-k demos (k = 1..N), skipping the first `num_keep` entries.
        counts = np.arange(1, len(ranked_labels) + 1)
        return (np.cumsum(ranked_labels) / counts)[num_keep:]

    # Oracle: demonstrations ordered from highest to lowest ground-truth quality.
    oracle_curve = selected_mean_quality(quality_labels[quality_labels.argsort()][::-1])

    # Methods: demonstrations ordered from highest to lowest predicted quality.
    ranked_idxs_per_method = scores_per_method.argsort(axis=-1)[:, ::-1]
    result = {
        label: selected_mean_quality(quality_labels[ranked_idxs])
        for label, ranked_idxs in zip(exp_labels, ranked_idxs_per_method)
    }
    result["Oracle"] = oracle_curve
    result["Random"] = np.ones_like(oracle_curve) * quality_labels.mean()

    if not normalize:
        return result

    # Rescale every curve relative to the range spanned by the oracle curve.
    def to_percent_of_oracle_range(curve: np.ndarray) -> np.ndarray:
        return ((curve - oracle_curve.min()) / (oracle_curve.max() - oracle_curve.min())) * 100

    return {label: to_percent_of_oracle_range(curve) for label, curve in result.items()}


def get_scores_filter_idxs(
    result_data: Dict[str, Any],
    split: str,
    scores: np.ndarray,
) -> List[int]:
    """Rank the demonstration indices of a split by their scores.

    `scores` must hold one value per demonstration selected by the split's "demo_mask".

    Returns:
        Demo mask indices ordered by ascending score for the "train" split and by descending
        score for the "holdout" split.
    """
    assert split in ["train", "holdout"]
    metadata = result_data[f"{split}_data"]["metadata"]
    num_eps = metadata["num_eps"]
    valid_idxs = np.nonzero(metadata["demo_mask"] == True)[0]
    assert num_eps == len(valid_idxs) == len(scores)

    ranked_idxs = valid_idxs[scores.argsort()].tolist()
    if split == "train":
        return ranked_idxs
    return ranked_idxs[::-1]


def get_oracle_filter_idxs(
    result_data: Dict[str, Any],
    split: str,
    seed: int,
) -> List[int]:
    """Rank the demonstration indices of a split by their ground-truth quality labels.

    The indices are shuffled with a generator seeded by `seed` before sorting, so the order
    among demonstrations with equal labels depends on the seed.

    Returns:
        Demo mask indices ordered by ascending quality label for the "train" split and by
        descending quality label for the "holdout" split.
    """
    assert split in ["train", "holdout"]
    metadata = result_data[f"{split}_data"]["metadata"]
    num_eps = metadata["num_eps"]
    valid_idxs = np.nonzero(metadata["demo_mask"] == True)[0]
    quality_labels = metadata["quality_labels"]
    assert num_eps == len(valid_idxs)

    # Apply one random permutation to both the indices and their labels.
    permutation = np.arange(num_eps)
    np.random.default_rng(seed=seed).shuffle(permutation)
    shuffled_idxs: np.ndarray = valid_idxs[permutation]
    shuffled_labels: np.ndarray = quality_labels[permutation]

    # Sort by oracle.
    ranked_idxs = shuffled_idxs[shuffled_labels.argsort()].tolist()
    return ranked_idxs if split == "train" else ranked_idxs[::-1]


def get_random_filter_idxs(
    result_data: Dict[str, Any],
    split: str,
    seed: int,
) -> List[int]:
    """Return the demonstration indices of a split in random order.

    The indices selected by the split's "demo_mask" are shuffled with a generator seeded by
    `seed`; the same order is returned for both the "train" and "holdout" splits.
    """
    assert split in ["train", "holdout"]
    metadata = result_data[f"{split}_data"]["metadata"]
    num_eps = metadata["num_eps"]
    valid_idxs = np.nonzero(metadata["demo_mask"] == True)[0]
    assert num_eps == len(valid_idxs)

    # Randomly shuffle train mask indices (in place, on the freshly created index array).
    np.random.default_rng(seed=seed).shuffle(valid_idxs)
    return valid_idxs.tolist()

### Sec 1.3: Compile results utilities
**Description:** Code related to compiling metrics of interest (e.g., curated dataset quality) across tasks, methods, and seeds for easy visualization/plotting.

In [None]:
# Metric functions that may be passed (directly or wrapped in `functools.partial`) as `metric_fn`.
METRIC_FNS = (compute_filtered_mean_quality_scores, compute_selected_mean_quality_scores)


def compile_metric_across_tasks_seeds(
    split: str,
    tasks: List[str],
    seeds: List[int],
    policy: str,
    train_date: str,
    eval_date: str,
    result_date: str,
    metric_fn: Callable[[Any], Dict[str, np.ndarray]],
    real_exp: bool = False,
) -> Dict[str, Dict[str, np.ndarray]]:
    """Evaluate a metric for every task and seed, and stack the per-seed results.

    Args:
        split: Data split passed on to `metric_fn`.
        tasks: Task names to load results for.
        seeds: Training seeds to load results for.
        policy, train_date, eval_date, result_date: Identify the experiments to load.
        metric_fn: One of `METRIC_FNS`, or a `functools.partial` wrapping one of them.
        real_exp: Forwarded to `load_result_data`.

    Returns:
        A dictionary mapping task name to a dictionary that maps each method label to an array
        with one row per seed.
    """
    compiled_result = defaultdict(dict)
    for task in tasks:
        # Group the per-seed metric arrays by method label while loading.
        curves_by_method = defaultdict(list)
        for seed in seeds:
            result_data = load_result_data(
                **get_load_result_kwargs(
                    task=task,
                    policy=policy,
                    seed=seed,
                    train_date=train_date,
                    eval_date=eval_date,
                    result_date=result_date,
                ),
                real_exp=real_exp,
            )
            assert metric_fn in METRIC_FNS or (isinstance(metric_fn, partial) and metric_fn.func in METRIC_FNS)

            for method, curve in metric_fn(result_data, split=split).items():
                # Note: Demo-SCORE does not work on Lift MH, because the policy does not exhibit failures.
                if "lift" in task and "Demo-SCORE" in method:
                    continue
                curves_by_method[method].append(curve)

        # Stack results across seeds.
        compiled_result[task] = {
            method: np.vstack(curves) for method, curves in curves_by_method.items()
        }

    return compiled_result


def save_ranked_demos_to_config(
    split: str,
    tasks: List[str],
    seeds: List[int],
    policy: str,
    train_date: str,
    eval_date: str,
    result_date: str,
    compile_fn: partial,
    real_exp: bool = False,
) -> None:
    """Save yaml configuration files rank-ordering demonstrations based on predicted quality/value.

    These configuration files are used to re-train the policy with a curated dataset of
    demonstrations. One file is written per task at
    `curation_config_dir / <state> / <task> / <split>_config.yaml`, where `<state>` is
    "low_dim" if "lowdim" appears in `policy` and "image" otherwise. If the file already
    exists, its contents are loaded first and then updated, so entries for methods/seeds
    that are not recomputed here are kept.

    Args:
        split: Dataset split to rank; also used in the config file name.
        tasks: Task names; one config file is written per task.
        seeds: Seeds whose results are loaded; rankings are stored per seed.
        policy: Policy name.
        train_date: Forwarded to `get_load_result_kwargs`.
        eval_date: Forwarded to `get_load_result_kwargs`.
        result_date: Forwarded to `get_load_result_kwargs`.
        compile_fn: `functools.partial` wrapping `compile_demo_quality_scores`.
        real_exp: Forwarded to `load_result_data`.

    Raises:
        AssertionError: If `compile_fn` is not a partial of `compile_demo_quality_scores`.
        TypeError: If an existing config file does not contain a YAML mapping.
    """
    # Validate once, up front, so a bad `compile_fn` fails before anything is created on disk.
    assert isinstance(compile_fn, partial) and compile_fn.func == compile_demo_quality_scores

    state = "low_dim" if "lowdim" in policy else "image"
    for task in tasks:
        # Load curation config.
        task_curation_config_dir = curation_config_dir / state / task
        task_curation_config_dir.mkdir(parents=True, exist_ok=True)

        config_file = task_curation_config_dir / f"{split}_config.yaml"
        curation_config = defaultdict(dict)
        if config_file.exists():
            with open(config_file, mode="r") as f:
                loaded_config = yaml.safe_load(f)
            # An empty YAML file loads as None; treat it like a missing file.
            if loaded_config is not None:
                if not isinstance(loaded_config, dict):
                    raise TypeError(
                        f"Expected a mapping in {config_file}, got {type(loaded_config).__name__}"
                    )
                curation_config.update(loaded_config)

        # Iterate over seeds.
        for seed in seeds:
            load_kwargs = get_load_result_kwargs(
                task=task,
                policy=policy,
                seed=seed,
                train_date=train_date,
                eval_date=eval_date,
                result_date=result_date,
            )
            result_data = load_result_data(**load_kwargs, real_exp=real_exp)
            demo_quality_scores, _, exp_curation_keys = compile_fn(result_data, split=split)

            # Iterate over methods. The oracle ranking is only stored for "_mh" tasks.
            if "_mh" in task:
                curation_config["oracle"][seed] = get_oracle_filter_idxs(result_data, split, seed=seed)
            curation_config["random"][seed] = get_random_filter_idxs(result_data, split, seed=seed)
            for scores, curation_key in zip(demo_quality_scores, exp_curation_keys):
                # Note: Demo-SCORE does not work on Lift MH, because the policy does not exhibit failures.
                if "lift" in task and "demoscore" in curation_key:
                    continue
                curation_config[curation_key][seed] = get_scores_filter_idxs(result_data, split, scores)

        # Save curation config (converted to a plain dict so it can be safely dumped).
        with open(config_file, mode="w") as f:
            yaml.safe_dump(dict(curation_config), f)


def compile_last_n_across_tasks_seeds(
    tasks: List[str],
    seeds: List[int],
    policy: str,
    curate_dataset: bool = False,
    exp_curation_keys: Optional[List[str]] = None,
    exp_curation_labels: Optional[List[str]] = None,
    filter_ratios: Optional[List[float]] = None,
    select_ratios: Optional[List[float]] = None,
    get_last_n_kwargs: Optional[Dict[str, Any]] = None,
    curation_train_date: str = "25.03.05",
    reference_train_dates: Optional[Dict[str, str]] = None,
    real_exp: bool = False,
) -> Dict[str, Dict[str, np.ndarray]]:
    """Compile last N of a specified training metric across tasks, methods, and seeds.

    Args:
        tasks: Task names to compile results for.
        seeds: Training seeds; each seed contributes one row to the stacked arrays.
        policy: Policy name; together with the task it selects the default
            `get_last_n_kwargs` when none are given.
        curate_dataset: If True, also compile results of policies re-trained on curated
            datasets. Requires `exp_curation_keys`, `exp_curation_labels`, `filter_ratios`
            and `select_ratios`.
        exp_curation_keys: Curation method names passed to `get_train_exp_path`.
        exp_curation_labels: Result labels, paired one-to-one with `exp_curation_keys`.
        filter_ratios: Filter ratios, paired element-wise with `select_ratios`.
        select_ratios: Select ratios, paired element-wise with `filter_ratios`.
        get_last_n_kwargs: Keyword arguments forwarded to `get_last_n_log_keys`. If None,
            per-(policy, task) defaults are looked up for every task.
        curation_train_date: Train date of the curated re-training runs.
        reference_train_dates: Mapping of result label -> train date for reference runs.
        real_exp: Forwarded to `get_train_exp_path`.

    Returns:
        Mapping of task -> result label -> array stacked across seeds (one row per seed).

    Raises:
        KeyError: If `get_last_n_kwargs` is None and no defaults exist for a (policy, task).
        AssertionError: If `curate_dataset` is True but a curation argument is missing.
    """
    GET_LAST_N_KWARGS = {
        "diffusion_unet_lowdim_lift_mh": {
            "n": 10,
            "required_epochs": 1000,
        },
        "diffusion_unet_lowdim_square_mh": {
            "n": 10,
            "required_epochs": 1000,
        },
        "diffusion_unet_lowdim_transport_mh": {
            "n": 10,
            "required_epochs": 1000,
        },
    }

    if curate_dataset:
        assert (
            (exp_curation_keys is not None) and
            (exp_curation_labels is not None) and
            (filter_ratios is not None) and
            (select_ratios is not None)
        ), "Curation arguments must be set together"

    compiled_result = defaultdict(dict)
    for task in tasks:
        # Resolve the kwargs into a per-task local. Overwriting `get_last_n_kwargs` itself
        # would make every later task silently reuse the first task's defaults.
        if get_last_n_kwargs is None:
            default_key = f"{policy}_{task}"
            if default_key not in GET_LAST_N_KWARGS:
                raise KeyError(
                    f"No default get_last_n_kwargs for '{default_key}'; pass get_last_n_kwargs explicitly."
                )
            task_last_n_kwargs = GET_LAST_N_KWARGS[default_key]
        else:
            task_last_n_kwargs = get_last_n_kwargs

        task_results = []
        for seed in seeds:
            seed_results = defaultdict(list)

            # Store default results.
            if isinstance(reference_train_dates, dict):
                for ref_exp_label, ref_train_date in reference_train_dates.items():
                    train_exp_kwargs = get_load_result_kwargs(task, policy, seed, train_date=ref_train_date)["train_exp_kwargs"]
                    train_exp_path = get_train_exp_path(**train_exp_kwargs, real_exp=real_exp)
                    last_n_metrics = get_last_n_log_keys(train_exp_path, **task_last_n_kwargs)
                    seed_results[ref_exp_label].append(last_n_metrics.mean())

            # Store curation results.
            if curate_dataset:
                for curation_method, exp_label in zip(exp_curation_keys, exp_curation_labels):
                    # Note: Demo-SCORE does not work on Lift MH, because the policy does not exhibit failures.
                    if "lift" in task and "demoscore" in curation_method:
                        continue

                    for filter_ratio, select_ratio in zip(filter_ratios, select_ratios):
                        train_exp_kwargs = get_load_result_kwargs(task, policy, seed, train_date=curation_train_date)["train_exp_kwargs"]
                        train_exp_path = get_train_exp_path(
                            **train_exp_kwargs,
                            curate_dataset=True,
                            curation_method=curation_method,
                            filter_ratio=filter_ratio,
                            select_ratio=select_ratio,
                            real_exp=real_exp,
                        )
                        last_n_metrics = get_last_n_log_keys(train_exp_path, **task_last_n_kwargs)
                        seed_results[exp_label].append(last_n_metrics.mean())

            task_results.append(seed_results)

        # Aggregate results by key before stacking.
        grouped_task_results = defaultdict(list)
        for seed_results in task_results:
            for k, v in seed_results.items():
                grouped_task_results[k].append(v)

        # Stack results across seeds.
        compiled_result[task] = {k: np.vstack(v) for k, v in grouped_task_results.items()}

    return compiled_result

### Sec 1.4: Plotting utilities
**Description:** Code for plotting and visualization, including constants for colors and plot specifications.

In [None]:
# Values passed as `num_keep` to the filtered / selected mean-quality metric functions
# in the Sec. 2 cells. `NUM_KEEP_SELECTED` is also the default x-tick label offset of
# `render_mean_selected_quality_plot`.
NUM_KEEP_FILTERED = 16
NUM_KEEP_SELECTED = 40

# Default hex color per method label. The plotting functions fall back to these colors
# when no `colors` override is given, and always use them for the reference methods.
COLOR_SETTINGS = {
    # Reference methods.
    "Random": "#737373",                    # Tuned.
    "Oracle": "#252525",                    # Tuned.    
    "All Demos": "#66c2a4",                 # Tuned.
    "Base Policy": "#A5856B",               # Tuned.
    
    # Offline methods.
    "Policy Loss": "#2171b5",               # Outdated.
    "Policy Uncertainty": "#2171b5",        # Outdated.
    "Action Diversity": "#6a51a3",          # Outdated.
    "State Diversity": "#238b45",           # Outdated.
    "DemInf": "#8073ac",                    # Tuned.

    # Online methods.
    "Success Similarity": "#41b6c4",        # Tuned.
    "Demo-SCORE": "#045a8d",                # Tuned.
    
    # Our methods.
    "CUPID": "#ff8c00",                     # Tuned.
    "CUPID-Quality": "#ff5349",             # Tuned.    
}


def task_to_subtitle(task: str) -> str:
    """Map a task identifier to the subtitle shown above its plot.

    The number of underscore-separated parts selects the family of names that is tried:
    one part (PushT), two parts (RoboMimic lift/square/transport), three parts
    (RoboMimic tool hang or a hardware task) or four parts (hardware task iterations).
    Raises an AssertionError if the task is not recognized.
    """
    # Hardware tasks, matched by substring in this order.
    hardware_subtitles = (
        ("figure8", "Figure-8"),
        ("tuckbox", "TuckBox"),
        ("bookshelf", "Bookshelf"),
    )
    num_parts = task.count("_") + 1

    subtitle = None
    if num_parts == 1 and task == "pusht":
        # PushT
        subtitle = "PushT"
    elif num_parts == 2 and any(name in task for name in ("lift", "square", "transport")):
        # RoboMimic, e.g. "lift_mh" -> "Lift MH".
        env_name, dataset_type = task.split("_")
        subtitle = f"{env_name.title()} {dataset_type.upper()}"
    elif num_parts in (3, 4):
        if num_parts == 3 and "tool_hang" in task:
            # RoboMimic.
            subtitle = "ToolHang PH"
        else:
            # Hardware standard (3 parts) or hardware iterations (4 parts).
            subtitle = next(
                (title for marker, title in hardware_subtitles if marker in task),
                None,
            )

    assert subtitle is not None
    return subtitle


def get_selected_xticks_xticklabels(num_samples: int, num_selected: int) -> Tuple[np.ndarray, List[str]]:
    """Compute x-axis tick positions and their labels.

    Ticks are spaced 10 apart for up to 100 samples and 20 apart otherwise. The last tick
    is 10 times the largest even number not exceeding `num_samples / 10`. Each label is
    the tick position shifted by `num_selected`.
    """
    step = 20 if num_samples > 100 else 10

    # Round the number of tens down to an even count so the last tick is divisible by both steps.
    even_tens = int(num_samples / 10)
    even_tens -= even_tens % 2
    last_tick = 10 * even_tens

    tick_count = int(last_tick / step) + 1
    xticks = np.linspace(0, last_tick, tick_count)
    return xticks, [str(num_selected + int(tick)) for tick in xticks]


def render_mean_filtered_quality_plot(
    results: Dict[str, Dict[str, np.ndarray]],
    tasks: List[str],
    colors: Optional[Dict[str, Any]] = None,
    linestyles: Optional[Dict[str, Any]] = None,
    label_order: Optional[List[str]] = None,
) -> None:
    """Plot, per task, dataset quality against the number of training demonstrations filtered.

    One panel is drawn per entry of `results` (paired with the axes created for `tasks`).
    "Oracle" and "Random" are drawn as dashed reference lines using their first row only;
    every other method is drawn as the mean across rows with a standard-error band.
    Each task's results must contain an "Oracle" entry, which sets the axis limits.

    Args:
        results: Mapping of task -> method label -> array with one row per seed.
        tasks: Task names; their count determines the number of panels.
        colors: Optional per-method colors (defaults to `COLOR_SETTINGS`).
        linestyles: Optional per-method line styles (defaults to solid).
        label_order: Optional legend order; methods not listed are left out of the legend.
    """
    reference_methods = ["Oracle", "Random"]

    # One panel per task, laid out in a single row.
    panel_height, panel_width = 4, 5
    num_panels = len(tasks)
    fig, axes = plt.subplots(1, num_panels, figsize=(panel_width * num_panels, panel_height), dpi=300)
    if num_panels == 1:
        axes = [axes]

    legend_handles = {}
    for panel_idx, (ax, (task, methods)) in enumerate(zip(axes, results.items())):
        # Reference methods first: dashed lines taken from the first row.
        for method in (name for name in methods if name in reference_methods):
            values = methods[method]
            (line,) = ax.plot(
                np.arange(values.shape[1]),
                values[0],
                label=method,
                color=COLOR_SETTINGS[method],
                linestyle="--",
                linewidth=3,
            )
            legend_handles[method] = line

        # Remaining methods: mean across rows with a standard-error band.
        for method, values in methods.items():
            if method in reference_methods:
                continue

            color = COLOR_SETTINGS[method] if colors is None else colors[method]
            linestyle = "-" if linestyles is None else linestyles[method]

            steps = np.arange(values.shape[1])
            mean = np.mean(values, axis=0)
            sem = np.std(values, axis=0) / np.sqrt(values.shape[0])
            (line,) = ax.plot(steps, mean, label=method, color=color, linestyle=linestyle, linewidth=2)
            ax.fill_between(steps, mean - sem, mean + sem, color=color, alpha=0.2)
            legend_handles[method] = line

        # Title and axis labels (y-label on the first panel only).
        ax.set_title(task_to_subtitle(task), fontsize=16)
        ax.set_xlabel("Number of Demonstrations Filtered", fontsize=13.5)
        if panel_idx == 0:
            ax.set_ylabel("(%) Increase in Data Quality", fontsize=13.5)

        # Axis limits follow the oracle curve.
        oracle_values = methods["Oracle"]
        ax.set_xlim(0, oracle_values.shape[1])
        ax.set_ylim(-10, int(oracle_values.max()) + 1)

        # Polish spines.
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)
        for side in ("bottom", "left"):
            ax.spines[side].set_linewidth(2)

    # Legend below the figure, in two rows, optionally re-ordered.
    if isinstance(label_order, list):
        ordered_labels = reference_methods + label_order
        legend_handles = {k: legend_handles[k] for k in ordered_labels if k in legend_handles}

    fig.legend(
        legend_handles.values(),
        legend_handles.keys(),
        loc="upper center",
        fontsize=10,
        ncol=(len(legend_handles) + 1) // 2,
        bbox_to_anchor=(0.5, 0.0),
        frameon=False,
    )
    plt.tight_layout()
    plt.show()


def render_mean_selected_quality_plot(
    results: Dict[str, Dict[str, np.ndarray]],
    tasks: List[str],
    colors: Optional[Dict[str, Any]] = None,
    linestyles: Optional[Dict[str, Any]] = None,
    label_order: Optional[List[str]] = None,
    auto_ylims: bool = False,
    num_keep_selected: int = NUM_KEEP_SELECTED
) -> None:
    """Plot, per task, dataset quality against the number of holdout demonstrations selected.

    One panel is drawn per entry of `results` (paired with the axes created for `tasks`).
    "Oracle" and "Random" are drawn as dashed reference lines using their first row only;
    every other method is drawn as the mean across rows with a standard-error band.
    Each task's results must contain an "Oracle" entry, which sets the x-range and ticks.

    Args:
        results: Mapping of task -> method label -> array with one row per seed.
        tasks: Task names; their count determines the number of panels.
        colors: Optional per-method colors (defaults to `COLOR_SETTINGS`).
        linestyles: Optional per-method line styles (defaults to solid).
        label_order: Optional legend order; methods not listed are left out of the legend.
        auto_ylims: If False, fixed y-limits are applied; if True, matplotlib chooses them.
        num_keep_selected: Offset added to the x-tick labels.
    """
    reference_methods = ["Oracle", "Random"]

    # One panel per task, laid out in a single row.
    panel_height, panel_width = 4, 5
    num_panels = len(tasks)
    fig, axes = plt.subplots(1, num_panels, figsize=(panel_width * num_panels, panel_height), dpi=300)
    if num_panels == 1:
        axes = [axes]

    legend_handles = {}
    for panel_idx, (ax, (task, methods)) in enumerate(zip(axes, results.items())):
        # Reference methods first: dashed lines taken from the first row.
        for method in (name for name in methods if name in reference_methods):
            values = methods[method]
            (line,) = ax.plot(
                np.arange(values.shape[1]),
                values[0],
                label=method,
                color=COLOR_SETTINGS[method],
                linestyle="--",
                linewidth=3,
            )
            legend_handles[method] = line

        # Remaining methods: mean across rows with a standard-error band.
        for method, values in methods.items():
            if method in reference_methods:
                continue

            color = COLOR_SETTINGS[method] if colors is None else colors[method]
            linestyle = "-" if linestyles is None else linestyles[method]

            steps = np.arange(values.shape[1])
            mean = np.mean(values, axis=0)
            sem = np.std(values, axis=0) / np.sqrt(values.shape[0])
            (line,) = ax.plot(steps, mean, label=method, color=color, linestyle=linestyle, linewidth=2)
            ax.fill_between(steps, mean - sem, mean + sem, color=color, alpha=0.2)
            legend_handles[method] = line

        # Title and axis labels (y-label on the first panel only).
        ax.set_title(task_to_subtitle(task), fontsize=16)
        ax.set_xlabel("Number of Demonstrations Selected", fontsize=13.5)
        if panel_idx == 0:
            ax.set_ylabel("Average Selected Quality", fontsize=13.5)

        # Axis limits: x follows the oracle curve, y is fixed unless auto_ylims is set.
        num_steps = methods["Oracle"].shape[1]
        ax.set_xlim(0, num_steps)
        if not auto_ylims:
            lower_ylim = 2.025 if "transport" in task else 1.9
            ax.set_ylim(lower_ylim, 3.02)

        # Axis ticks, labelled with the offset number of demonstrations.
        xticks, xticklabels = get_selected_xticks_xticklabels(num_steps, num_keep_selected)
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels)

        # Polish spines.
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)
        for side in ("bottom", "left"):
            ax.spines[side].set_linewidth(2)

    # Legend below the figure, in two rows, optionally re-ordered.
    if isinstance(label_order, list):
        ordered_labels = reference_methods + label_order
        legend_handles = {k: legend_handles[k] for k in ordered_labels if k in legend_handles}

    fig.legend(
        legend_handles.values(),
        legend_handles.keys(),
        loc="upper center",
        fontsize=10,
        ncol=(len(legend_handles) + 1) // 2,
        bbox_to_anchor=(0.5, 0.0),
        frameon=False,
    )
    plt.tight_layout()
    plt.show()


def render_curation_retraining_plot(
    results: Dict[str, Dict[str, np.ndarray]],
    tasks: List[str],
    curation_ratios: List[float],
    xlabel: str = "Fraction of Dataset Filtered",
    colors: Optional[Dict[str, Any]] = None,
    label_order: Optional[List[str]] = None,
) -> None:
    """Render policy success rate as a function of the fraction of demonstrations filtered or selected.

    One panel is drawn per entry of `results` (paired with the axes created for `tasks`).
    "All Demos" and "Base Policy" are drawn as horizontal dashed lines at their mean,
    "Oracle" and "Random" as dashed curves of the mean across seeds, and every other
    method as a solid curve of the mean across seeds with a standard-error band.

    Args:
        results: Mapping of task -> method label -> array. Curves must have shape
            (num_seeds, len(curation_ratios)); horizontal references hold one value per seed.
        tasks: Task names; their count determines the number of panels.
        curation_ratios: Non-empty list of x positions of the curve points.
        xlabel: Label of the x-axis.
        colors: Optional per-method colors (defaults to `COLOR_SETTINGS`).
        label_order: Optional legend order. Methods not listed are left out of the legend,
            except "Base Policy" and "All Demos", which are always appended.

    Raises:
        AssertionError: If a result array does not have the expected shape.
    """
    def as_seed_by_ratio(method: str, values: np.ndarray) -> np.ndarray:
        """Check that `values` is a (num_seeds, num_ratios) array and return it."""
        values = np.asarray(values)
        # Checked on the unsqueezed array: squeezing would drop the seed axis for a single
        # seed (or the ratio axis for a single ratio) and reject valid input.
        assert values.ndim == 2 and values.shape[1] == len(curation_ratios), (
            f"Expected '{method}' results of shape (num_seeds, {len(curation_ratios)}), got {values.shape}"
        )
        return values

    # Figure setup.
    height, width = 4, 5
    fig, axes = plt.subplots(1, len(tasks), figsize=(width * len(tasks), height), dpi=300)
    axes = [axes] if len(tasks) == 1 else axes

    # Axis lim params.
    xmargin = 0.025
    ymargin = 0.050
    xmin = max(curation_ratios[0] - xmargin, 0)
    xmax = min(curation_ratios[-1] + xmargin, 1)
    hxmin = max(curation_ratios[0] - xmargin * 0.5, 0)
    hxmax = min(curation_ratios[-1] + xmargin * 0.5, 1)
    x = np.array(curation_ratios)

    # Plot results.
    handles_labels = {}
    for i, (ax, (task, methods)) in enumerate(zip(axes, results.items())):

        # Track for ylims.
        ymin = float("inf")
        ymax = float("-inf")

        # Horizontal reference lines ("All Demos" first, then "Base Policy").
        for method in ("All Demos", "Base Policy"):
            if method not in methods:
                continue
            values = np.asarray(methods[method])
            # A single-seed (1, 1) array squeezes to 0-d, so accept ndim <= 1.
            assert values.squeeze().ndim <= 1, (
                f"Expected one value per seed for '{method}', got shape {values.shape}"
            )
            y = values.mean()
            handles_labels[method] = ax.axhline(y=y, color=COLOR_SETTINGS[method], linestyle="--", linewidth=5)

            # Update ymin and ymax.
            ymin = min(ymin, y)
            ymax = max(ymax, y)

        # Reference curves.
        for method, values in methods.items():
            if method in ["Oracle", "Random"]:
                values = as_seed_by_ratio(method, values)
                y = values.mean(axis=0)
                line, = ax.plot(x, y, color=COLOR_SETTINGS[method], linestyle="--", linewidth=3, marker="o", markersize=8)
                handles_labels[method] = line

                # Update ymin and ymax.
                ymin = min(ymin, y.min())
                ymax = max(ymax, y.max())

        # Curation methods.
        for method, values in methods.items():
            if method not in ["All Demos", "Base Policy", "Oracle", "Random"]:
                # Custom color.
                color = colors[method] if colors is not None else COLOR_SETTINGS[method]

                values = as_seed_by_ratio(method, values)
                y = values.mean(axis=0)
                y_err = values.std(axis=0) / np.sqrt(values.shape[0])
                line, = ax.plot(x, y, color=color, linestyle="-", linewidth=3, marker="o", markersize=10, markeredgecolor='white', markeredgewidth=2)
                ax.fill_between(x, y - y_err, y + y_err, color=color, alpha=0.25)
                handles_labels[method] = line

                # Update ymin and ymax.
                ymin = min(ymin, y.min())
                ymax = max(ymax, y.max())

        # Set title and axis labels.
        ax.set_title(task_to_subtitle(task), fontsize=16)
        ax.set_xlabel(xlabel, fontsize=13.5)
        if i == 0:
            ax.set_ylabel("Success Rate", fontsize=13.5)

        # Set axis lims. The y-limits are only set when something was plotted; otherwise
        # ymin/ymax are still infinite and matplotlib would reject them.
        ax.set_xlim(xmin, xmax)
        if ymin <= ymax:
            ax.set_ylim(max(ymin - ymargin, -0.01), min(ymax + ymargin, 1.01))

        # Set axis ticks.
        ax.set_xticks(curation_ratios)
        ax.set_xticklabels([f"{r:.2f}" for r in curation_ratios])

        # Polish spines.
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_linewidth(2)
        ax.spines["left"].set_linewidth(2)

    # Set ygrid.
    for ax in axes:
        ax.hlines(y=ax.get_yticks(), xmin=hxmin, xmax=hxmax, colors="gray", linestyles="-", linewidth=0.5, alpha=0.5, zorder=0)

    # Construct legend.
    if label_order is not None and isinstance(label_order, list):
        _handles_labels = {k: handles_labels[k] for k in label_order if k in handles_labels}
        if "Base Policy" in handles_labels:
            _handles_labels["Base Policy"] = handles_labels["Base Policy"]
        if "All Demos" in handles_labels:
            _handles_labels["All Demos"] = handles_labels["All Demos"]
        handles_labels = _handles_labels

    ncol = len(handles_labels) // 2
    ncol = ncol if len(handles_labels) % 2 == 0 else ncol + 1
    fig.legend(
        handles_labels.values(),
        handles_labels.keys(),
        loc="upper center",
        fontsize=10,
        ncol=ncol,
        bbox_to_anchor=(0.5, 0.0),
        frameon=False
    )
    plt.tight_layout()
    plt.show()

## Sec 2: Visualize data quality results

### Sec 2.1: RoboMimic demo filtering (Task 1: Filter-k)

In [None]:
# Experiment settings for the Filter-k data quality plot. The date strings are
# placeholders that must be replaced before running; they are forwarded to
# `compile_metric_across_tasks_seeds` further down in this cell.
# TODO: Adjust to match your experiment dates.
eval_date="<enter_policy_eval_date>"
train_date="<enter_policy_train_date>"
result_date="default"

# TODO: Adjust to your intended policy state.
state = "lowdim" 
# state = "image"

# TODO: Adjust to your intended tasks and seeds.
tasks = ["lift_mh"]  # tasks = ["lift_mh", "square_mh", "transport_mh"]
seeds = [0, 1, 2]
# The policy name is derived from the chosen state.
policy = f"diffusion_unet_{state}"

# TODO: Adjust to your intended methods.
# Each entry maps a method label (the name the method is plotted under) to:
#   "key":          experiment key of the method's scores (passed on as `exp_keys`),
#   "curation_key": short method name (collected into `exp_curation_keys`),
#   "sign":         passed on as `exp_signs` -- presumably the direction in which scores
#                   are interpreted (e.g. -1 for a loss); TODO confirm,
#   "weight":       passed on as `exp_weights`; a scalar or a list -- semantics are
#                   defined by the metric functions, which are not shown here.
# Un-comment an entry to include that method.
exp_metadata = {

    ######################## Custom baselines. ########################
    # "Policy Loss": {
    #     "key": "offline_policy_loss",
    #     "curation_key": "policy_loss",
    #     "sign": -1,
    #     "weight": 1, 
    # },

    # "Action Diversity": {
    #     "key": "offline_action_diversity",
    #     "curation_key": "action_diversity",
    #     "sign": 1,
    #     "weight": 1, 
    # },

    # "State Diversity": {
    #     "key": get_offline_state_diversity_exp_key(
    #         embedding_name="policy",
    #         score_fn="mahal",
    #         method_prefix=True,
    #     ),
    #     "curation_key": "state_diversity",
    #     "sign": 1,
    #     "weight": 1, 
    # },

    ######################## Key baselines. ########################
    # "Demo-SCORE": {
    #     "key": "online_demo_score",
    #     "curation_key": "demoscore",
    #     "sign": 1,
    #     "weight": 1, 
    # },
    
    # "Success Similarity": {
    #     "key": get_online_state_similarity_exp_key(
    #         embedding_name="policy",
    #         score_fn="l2",
    #         aggr_fn="mean_of_mean_success",
    #         metric="net",
    #         num_rollouts="all",
    #         method_prefix=True,
    #     ),
    #     "curation_key": "state_similarity",
    #     "sign": 1,
    #     "weight": 1, 
    # },

    # "DemInf": {
    #     "key": "offline_deminf",
    #     "curation_key": "deminf",
    #     "sign": 1,
    #     "weight": 1, 
    # },

    ######################## Our methods. ########################
    "CUPID": {
        "key": get_online_trak_influence_exp_key(
            aggr_fn="sum_of_sum",
            metric="net",
            num_rollouts="all",
            method_prefix=True,
        ),
        "curation_key": "influence_sum_official",
        "sign": 1,
        "weight": 1, 
    },

    "CUPID-Quality": {
        "key": get_online_trak_influence_quality_exp_key(
            aggr_fn="sum_of_sum",
            metric="net",
            num_rollouts="all",
        ),
        "curation_key": "influence_quality_official",
        "sign": 1,
        "weight": [0.50, 0.25, 0.25], 
    },
}
# Unpack the per-method metadata into parallel lists (all in `exp_metadata` order).
exp_labels = [label for label in exp_metadata]
exp_keys = [meta["key"] for meta in exp_metadata.values()]
exp_curation_keys = [meta["curation_key"] for meta in exp_metadata.values()]
exp_signs = [meta["sign"] for meta in exp_metadata.values()]
exp_weights = [meta["weight"] for meta in exp_metadata.values()]
assert len(exp_labels) == len(exp_keys) == len(exp_curation_keys) == len(exp_signs) == len(exp_weights)

# Compile the filtered data quality metric on the training split for every task/seed.
metric_fn = partial(
    compute_filtered_mean_quality_scores,
    exp_keys=exp_keys, exp_labels=exp_labels, exp_signs=exp_signs, exp_weights=exp_weights,
    num_keep=NUM_KEEP_FILTERED,
)
filter_results = compile_metric_across_tasks_seeds(
    split="train", tasks=tasks, seeds=seeds, policy=policy,
    train_date=train_date, eval_date=eval_date, result_date=result_date,
    metric_fn=metric_fn,
)

# Plot data quality result.
render_mean_filtered_quality_plot(
    results=filter_results,
    tasks=tasks,
    label_order=[
        "Oracle",
        "Random",
        "Policy Loss",
        "State Diversity",
        "Action Diversity",
        "Demo-SCORE",
        "Success Similarity",
        "DemInf",
        "CUPID",
        "CUPID-Quality",
    ],
)

### Sec 2.2: RoboMimic demo selection (Task 2: Select-k)

In [None]:
# Experiment settings for the Select-k data quality plot. The date strings are
# placeholders that must be replaced before running; they are forwarded to
# `compile_metric_across_tasks_seeds` further down in this cell.
# TODO: Adjust to match your experiment dates.
eval_date="<enter_policy_eval_date>"
train_date="<enter_policy_train_date>"
result_date="default"

# TODO: Adjust to your intended policy state.
state = "lowdim" 
# state = "image"

# TODO: Adjust to your intended tasks and seeds.
tasks = ["lift_mh"]  # tasks = ["lift_mh", "square_mh", "transport_mh"]
seeds = [0, 1, 2]
# The policy name is derived from the chosen state.
policy = f"diffusion_unet_{state}"

# TODO: Adjust to your intended methods.
# Each entry maps a method label (the name the method is plotted under) to:
#   "key":          experiment key of the method's scores (passed on as `exp_keys`),
#   "curation_key": short method name (collected into `exp_curation_keys`),
#   "sign":         passed on as `exp_signs` -- presumably the direction in which scores
#                   are interpreted; TODO confirm,
#   "weight":       passed on as `exp_weights`; a scalar or a list -- semantics are
#                   defined by the metric functions, which are not shown here.
# Un-comment an entry to include that method.
exp_metadata = {
    
    ######################## Custom baselines. ########################
    # "Policy Uncertainty": {
    #     "key": "offline_policy_loss",
    #     "curation_key": "policy_uncertainty",
    #     "sign": 1,
    #     "weight": 1, 
    # },

    # "Action Diversity": {
    #     "key": "offline_action_diversity",
    #     "curation_key": "action_diversity",
    #     "sign": 1,
    #     "weight": 1, 
    # },

    # "State Diversity": {
    #     "key": get_offline_state_diversity_exp_key(
    #         embedding_name="policy",
    #         score_fn="mahal",
    #         method_prefix=True,
    #     ),
    #     "curation_key": "state_diversity",
    #     "sign": 1,
    #     "weight": 1, 
    # },

    ######################## Key baselines. ########################
    # "Demo-SCORE": {
    #     "key": "online_demo_score",
    #     "curation_key": "demoscore",
    #     "sign": 1,
    #     "weight": 1, 
    # },
    
    # "Success Similarity": {
    #     "key": get_online_state_similarity_exp_key(
    #         embedding_name="policy",
    #         score_fn="l2",
    #         aggr_fn="mean_of_mean_success",
    #         metric="net",
    #         num_rollouts="all",
    #         method_prefix=True,
    #     ),
    #     "curation_key": "state_similarity",
    #     "sign": 1,
    #     "weight": 1, 
    # },

    ######################## Our methods. ########################
    "CUPID": {
        "key": get_online_trak_influence_exp_key(
            aggr_fn="sum_of_sum",
            metric="net",
            num_rollouts="all",
            method_prefix=True,
        ),
        "curation_key": "influence_sum_official",
        "sign": 1,
        "weight": 1, 
    },

    "CUPID-Quality": {
        "key": get_online_trak_influence_quality_exp_key(
            aggr_fn="sum_of_sum",
            metric="net",
            num_rollouts="all",
        ),
        "curation_key": "influence_quality_official",
        "sign": 1,
        "weight": [0.50, 0.25, 0.25], 
    },
}
# Unpack the per-method metadata into parallel lists (all in `exp_metadata` order).
exp_labels = list(exp_metadata.keys())
exp_keys = [exp_metadata[k]["key"] for k in exp_labels]
exp_curation_keys = [exp_metadata[k]["curation_key"] for k in exp_labels]
exp_signs = [exp_metadata[k]["sign"] for k in exp_labels]
exp_weights = [exp_metadata[k]["weight"] for k in exp_labels]
assert len(exp_keys) == len(exp_labels) == len(exp_curation_keys) == len(exp_signs) == len(exp_weights)

# Plot data quality result.
metric_fn = partial(
    compute_selected_mean_quality_scores,
    exp_keys=exp_keys,
    exp_labels=exp_labels,
    exp_signs=exp_signs,
    exp_weights=exp_weights,
    num_keep=NUM_KEEP_SELECTED,
)
select_results = compile_metric_across_tasks_seeds(
    split="holdout",
    tasks=tasks,
    seeds=seeds,
    policy=policy,
    train_date=train_date,
    eval_date=eval_date,
    result_date=result_date,
    metric_fn=metric_fn,
)
# The legend only shows methods listed in `label_order`, so every label that can be
# enabled in `exp_metadata` above must appear here. "Policy Uncertainty" is this
# section's name for the policy-loss baseline and would otherwise be dropped.
render_mean_selected_quality_plot(
    results=select_results,
    tasks=tasks,
    label_order=[
        "Oracle",
        "Random",
        "Policy Loss",
        "Policy Uncertainty",
        "State Diversity",
        "Action Diversity",
        "Demo-SCORE",
        "Success Similarity",
        "DemInf",
        "CUPID",
        "CUPID-Quality",
    ],
)

## Sec 3: Generate config files for curated re-training

### Sec 3.1: RoboMimic demo filtering (Task 1: Filter-k)

In [None]:
# Experiment settings for generating the Filter-k curation config. The date strings
# are placeholders that must be replaced before running; they are forwarded to
# `save_ranked_demos_to_config` further down in this cell.
# TODO: Adjust to match your experiment dates.
eval_date="<enter_policy_eval_date>"
train_date="<enter_policy_train_date>"
result_date="default"

# TODO: Adjust to your intended policy state.
state = "lowdim" 
# state = "image"

# TODO: Adjust to your intended tasks and seeds.
tasks = ["lift_mh"]  # tasks = ["lift_mh", "square_mh", "transport_mh"]
seeds = [0, 1, 2]
# The policy name is derived from the chosen state.
policy = f"diffusion_unet_{state}"

# TODO: Adjust to your intended methods.
# Each entry maps a method label to:
#   "key":          experiment key of the method's scores (passed on as `exp_keys`),
#   "curation_key": short method name; in this cell it is passed on as the label, so it
#                   becomes the method's key in the saved curation config,
#   "sign":         passed on as `exp_signs` -- presumably the direction in which scores
#                   are interpreted (e.g. -1 for a loss); TODO confirm,
#   "weight":       passed on as `exp_weights`; a scalar or a list -- semantics are
#                   defined by `compile_demo_quality_scores`, which is not shown here.
# Un-comment an entry to include that method.
exp_metadata = {

    ######################## Custom baselines. ########################
    # "Policy Loss": {
    #     "key": "offline_policy_loss",
    #     "curation_key": "policy_loss",
    #     "sign": -1,
    #     "weight": 1, 
    # },

    # "Action Diversity": {
    #     "key": "offline_action_diversity",
    #     "curation_key": "action_diversity",
    #     "sign": 1,
    #     "weight": 1, 
    # },

    # "State Diversity": {
    #     "key": get_offline_state_diversity_exp_key(
    #         embedding_name="policy",
    #         score_fn="mahal",
    #         method_prefix=True,
    #     ),
    #     "curation_key": "state_diversity",
    #     "sign": 1,
    #     "weight": 1, 
    # },

    ######################## Key baselines. ########################
    # "Demo-SCORE": {
    #     "key": "online_demo_score",
    #     "curation_key": "demoscore",
    #     "sign": 1,
    #     "weight": 1, 
    # },
    
    # "Success Similarity": {
    #     "key": get_online_state_similarity_exp_key(
    #         embedding_name="policy",
    #         score_fn="l2",
    #         aggr_fn="mean_of_mean_success",
    #         metric="net",
    #         num_rollouts="all",
    #         method_prefix=True,
    #     ),
    #     "curation_key": "state_similarity",
    #     "sign": 1,
    #     "weight": 1, 
    # },

    # "DemInf": {
    #     "key": "offline_deminf",
    #     "curation_key": "deminf",
    #     "sign": 1,
    #     "weight": 1, 
    # },

    ######################## Our methods. ########################
    "CUPID": {
        "key": get_online_trak_influence_exp_key(
            aggr_fn="sum_of_sum",
            metric="net",
            num_rollouts="all",
            method_prefix=True,
        ),
        "curation_key": "influence_sum_official",
        "sign": 1,
        "weight": 1, 
    },

    "CUPID-Quality": {
        "key": get_online_trak_influence_quality_exp_key(
            aggr_fn="sum_of_sum",
            metric="net",
            num_rollouts="all",
        ),
        "curation_key": "influence_quality_official",
        "sign": 1,
        "weight": [0.50, 0.25, 0.25], 
    },
}
# Unpack the per-method metadata into parallel lists (all in `exp_metadata` order).
exp_labels = [label for label in exp_metadata]
exp_keys = [meta["key"] for meta in exp_metadata.values()]
exp_curation_keys = [meta["curation_key"] for meta in exp_metadata.values()]
exp_signs = [meta["sign"] for meta in exp_metadata.values()]
exp_weights = [meta["weight"] for meta in exp_metadata.values()]
assert len(exp_labels) == len(exp_keys) == len(exp_curation_keys) == len(exp_signs) == len(exp_weights)

# Generate curation config. The curation keys (not the display labels) are passed as
# labels, so the saved config is keyed by curation key.
compile_fn = partial(
    compile_demo_quality_scores,
    exp_keys=exp_keys, exp_labels=exp_curation_keys, exp_signs=exp_signs, exp_weights=exp_weights,
)
save_ranked_demos_to_config(
    split="train", tasks=tasks, seeds=seeds, policy=policy,
    train_date=train_date, eval_date=eval_date, result_date=result_date,
    compile_fn=compile_fn,
)

### Sec 3.2: RoboMimic demo selection (Task 2: Select-k)

In [None]:
# This cell collects per-method metadata, binds it to `compile_demo_quality_scores`,
# and passes the result to `save_ranked_demos_to_config` for the "holdout" split.

# TODO: Adjust to match your experiment dates.
eval_date = "<enter_policy_eval_date>"
train_date = "<enter_policy_train_date>"
result_date = "default"

# TODO: Adjust to your intended policy state.
state = "lowdim"
# state = "image"

# TODO: Adjust to your intended tasks and seeds.
tasks = ["lift_mh"]  # tasks = ["lift_mh", "square_mh", "transport_mh"]
seeds = [0, 1, 2]
policy = f"diffusion_unet_{state}"

# TODO: Adjust to your intended methods.
# Maps a display label to that method's metadata:
#   key          -> collected below into `exp_keys`
#   curation_key -> collected below into `exp_curation_keys`
#   sign         -> collected below into `exp_signs`
#   weight       -> collected below into `exp_weights`
# (Un)comment an entry to enable/disable the corresponding method.
exp_metadata = {

    ######################## Custom baselines. ########################
    # "Policy Uncertainty": dict(
    #     key="offline_policy_loss",
    #     curation_key="policy_uncertainty",
    #     sign=1,
    #     weight=1,
    # ),

    # "Action Diversity": dict(
    #     key="offline_action_diversity",
    #     curation_key="action_diversity",
    #     sign=1,
    #     weight=1,
    # ),

    # "State Diversity": dict(
    #     key=get_offline_state_diversity_exp_key(
    #         embedding_name="policy",
    #         score_fn="mahal",
    #         method_prefix=True,
    #     ),
    #     curation_key="state_diversity",
    #     sign=1,
    #     weight=1,
    # ),

    ######################## Key baselines. ########################
    # "Demo-SCORE": dict(
    #     key="online_demo_score",
    #     curation_key="demoscore",
    #     sign=1,
    #     weight=1,
    # ),

    # "Success Similarity": dict(
    #     key=get_online_state_similarity_exp_key(
    #         embedding_name="policy",
    #         score_fn="l2",
    #         aggr_fn="mean_of_mean_success",
    #         metric="net",
    #         num_rollouts="all",
    #         method_prefix=True,
    #     ),
    #     curation_key="state_similarity",
    #     sign=1,
    #     weight=1,
    # ),

    ######################## Our methods. ########################
    "CUPID": dict(
        key=get_online_trak_influence_exp_key(
            aggr_fn="sum_of_sum",
            metric="net",
            num_rollouts="all",
            method_prefix=True,
        ),
        curation_key="influence_sum_official",
        sign=1,
        weight=1,
    ),

    "CUPID-Quality": dict(
        # NOTE(review): unlike the other key helpers in this cell, no
        # `method_prefix` is passed here -- confirm this is intended.
        key=get_online_trak_influence_quality_exp_key(
            aggr_fn="sum_of_sum",
            metric="net",
            num_rollouts="all",
        ),
        curation_key="influence_quality_official",
        sign=1,
        # NOTE(review): a list of three weights here vs. a scalar for the other
        # methods -- presumably one weight per combined score; confirm against
        # `compile_demo_quality_scores`.
        weight=[0.50, 0.25, 0.25],
    ),
}

# Flatten the per-method metadata into parallel lists, one entry per method,
# following the insertion order of `exp_metadata`.
exp_labels = list(exp_metadata)
exp_keys = [meta["key"] for meta in exp_metadata.values()]
exp_curation_keys = [meta["curation_key"] for meta in exp_metadata.values()]
exp_signs = [meta["sign"] for meta in exp_metadata.values()]
exp_weights = [meta["weight"] for meta in exp_metadata.values()]
# Sanity check: every derived list has exactly one entry per method label.
assert all(
    len(column) == len(exp_labels)
    for column in (exp_keys, exp_curation_keys, exp_signs, exp_weights)
)

# Generate curation config.
# Note that the curation keys (not the display labels) are what gets bound to
# the `exp_labels` argument of `compile_demo_quality_scores`.
compile_fn = partial(
    compile_demo_quality_scores,
    exp_keys=exp_keys,
    exp_signs=exp_signs,
    exp_weights=exp_weights,
    exp_labels=exp_curation_keys,
)

# Hand the pre-bound compile function to the config writer for the "holdout" split.
save_ranked_demos_to_config(
    split="holdout",
    compile_fn=compile_fn,
    policy=policy,
    tasks=tasks,
    seeds=seeds,
    train_date=train_date,
    eval_date=eval_date,
    result_date=result_date,
)

## Sec 4: Visualize policy performance after curated re-training

### Sec 4.1: RoboMimic demo filtering (Task 1: Filter-k)

In [None]:
# This cell compiles results for the chosen curation methods across
# `filter_ratios` (with all-zero `select_ratios`) and renders the re-training plot.

# TODO: Adjust to match your experiment dates.
reference_train_dates = {
    # Train date of original policy.
    "All Demos": "<enter_policy_train_date>",
}
# Train date of curated policy.
curation_train_date = "<enter_policy_retrain_date>"

# TODO: Adjust to your intended policy state.
state = "lowdim"
# state = "image"

# TODO: Adjust to your intended tasks and seeds.
tasks = ["lift_mh"]  # tasks = ["lift_mh", "square_mh", "transport_mh"]
seeds = [0, 1, 2]
policy = f"diffusion_unet_{state}"

# Plot policy performance result.
# Each curation key sits next to its display label, so a method is toggled by
# (un)commenting a single line and the two derived lists cannot drift apart.
# NOTE(review): the Sec. 3 configs in this notebook use the curation key
# "policy_uncertainty" for the policy-loss baseline, not "policy_loss" --
# confirm which key your re-training runs were saved under before enabling it.
curation_label_by_key = {
    "oracle": "Oracle",
    "random": "Random",
    # "policy_loss": "Policy Loss",
    # "state_diversity": "State Diversity",
    # "action_diversity": "Action Diversity",
    # "deminf": "DemInf",
    # "state_similarity": "Success Similarity",
    # "demoscore": "Demo-SCORE",
    "influence_sum_official": "CUPID",
    "influence_quality_official": "CUPID-Quality",
}
plot_curation_keys = list(curation_label_by_key)
plot_curation_labels = list(curation_label_by_key.values())

# Ratios swept on the plot's x-axis, plus one zero select ratio per filter ratio.
filter_ratios = [0.10, 0.25, 0.50, 0.75, 0.90]
select_ratios = [0.00] * len(filter_ratios)

results = compile_last_n_across_tasks_seeds(
    policy=policy,
    tasks=tasks,
    seeds=seeds,
    curate_dataset=True,
    exp_curation_keys=plot_curation_keys,
    exp_curation_labels=plot_curation_labels,
    filter_ratios=filter_ratios,
    select_ratios=select_ratios,
    reference_train_dates=reference_train_dates,
    curation_train_date=curation_train_date,
)
# No `xlabel` is passed here, so the plotting helper's default is used.
render_curation_retraining_plot(
    results=results,
    tasks=tasks,
    label_order=plot_curation_labels,
    curation_ratios=filter_ratios,
)

### Sec 4.2: RoboMimic demo selection (Task 2: Select-k)

In [None]:
# This cell compiles results for the chosen curation methods across
# `select_ratios` (with all-zero `filter_ratios`) and renders the re-training plot.

# TODO: Adjust to match your experiment dates.
reference_train_dates = {
    # Train date of original policy.
    "Base Policy": "<enter_policy_train_date>",
    # Train date of policy trained with all 300 demos.
    "All Demos": "<optional_enter_full_policy_train_date>",
}
# Train date of curated policy.
curation_train_date = "<enter_policy_retrain_date>"

# TODO: Adjust to your intended policy state.
state = "lowdim"
# state = "image"

# TODO: Adjust to your intended tasks and seeds.
tasks = ["lift_mh"]  # tasks = ["lift_mh", "square_mh", "transport_mh"]
seeds = [0, 1, 2]
policy = f"diffusion_unet_{state}"

# Plot policy performance result.
# Each curation key sits next to its display label, so a method is toggled by
# (un)commenting a single line and the two derived lists cannot drift apart.
# NOTE(review): the Sec. 3 configs in this notebook use the curation key
# "policy_uncertainty" for the policy-loss baseline, not "policy_loss" --
# confirm which key your re-training runs were saved under before enabling it.
curation_label_by_key = {
    "oracle": "Oracle",
    "random": "Random",
    # "policy_loss": "Policy Loss",
    # "state_diversity": "State Diversity",
    # "action_diversity": "Action Diversity",
    # "state_similarity": "Success Similarity",
    # "demoscore": "Demo-SCORE",
    "influence_sum_official": "CUPID",
    "influence_quality_official": "CUPID-Quality",
}
plot_curation_keys = list(curation_label_by_key)
plot_curation_labels = list(curation_label_by_key.values())

# Ratios swept on the plot's x-axis, plus one zero filter ratio per select ratio.
select_ratios = [0.10, 0.25, 0.50, 0.75, 0.90]
filter_ratios = [0.00] * len(select_ratios)

results = compile_last_n_across_tasks_seeds(
    policy=policy,
    tasks=tasks,
    seeds=seeds,
    curate_dataset=True,
    exp_curation_keys=plot_curation_keys,
    exp_curation_labels=plot_curation_labels,
    filter_ratios=filter_ratios,
    select_ratios=select_ratios,
    reference_train_dates=reference_train_dates,
    curation_train_date=curation_train_date,
)
render_curation_retraining_plot(
    results=results,
    tasks=tasks,
    label_order=plot_curation_labels,
    curation_ratios=select_ratios,
    xlabel="Fraction of Holdout Demos Selected",
)