In [14]:
%matplotlib inline
from typing import Union, Dict, Any, List, Optional

import io
import os
import pathlib

import cv2
import imageio
import numpy as np
import pandas as pd
from PIL import Image

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import gridspec
# Font configuration. Flip `serif` to switch font families; the choice is also
# encoded in `dir_postfix`, which later cells use to name the GIF output directory.
serif = False
dir_postfix = "" if serif else "sans"
plt.rcParams.update({
    "font.family": "serif" if serif else "Liberation Sans",
    "font.size": 10,
    # Type 42 (TrueType) embedding keeps text editable in exported PDF/PS files.
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})


import sentinel.bc.datasets.utils as data_utils
from sentinel.bc.ood_detection import metric_utils
from sentinel.bc.ood_detection import action_utils
from sentinel.bc.ood_detection.models import utils as model_utils

# NOTE(review): star import. Later cells use `compile_metrics`, `CWD` and
# `load_episode_frame`, which are not defined in this notebook section and presumably
# come from `results`; importing them explicitly would make the dependency visible.
from results import *
# Quantile of the demo-episode cumulative scores (`demo_frame`) used by the render
# functions below as the failure-detection threshold.
quantile = 0.95

# Video Rendering Utils

In [15]:
# Per-domain keyword arguments for `model_utils.crop_to_square` (see crop_domain_image).
CROP_DOMAIN_PARAMS = dict(
    push_chair=dict(center_y_ratio=0.5, center_x_ratio=0.55, side_ratio=0.5),
)

# Per-domain keyword arguments for `cv2.convertScaleAbs` (alpha = gain, beta = offset);
# see adjust_brightness.
BRIGHTEN_DOMAIN_PARAMS = dict(
    push_chair=dict(alpha=1.2, beta=50),
)

def crop_domain_image(image: np.ndarray, domain: str) -> np.ndarray:
    """Crop ``image`` via ``model_utils.crop_to_square`` using the settings for ``domain``.

    Raises:
        KeyError: If ``domain`` has no entry in ``CROP_DOMAIN_PARAMS``.
    """
    return model_utils.crop_to_square(image, **CROP_DOMAIN_PARAMS[domain])

def adjust_brightness(image: np.ndarray, domain: str) -> np.ndarray:
    """Rescale ``image`` with ``cv2.convertScaleAbs`` using the domain's alpha/beta settings.

    Raises:
        KeyError: If ``domain`` has no entry in ``BRIGHTEN_DOMAIN_PARAMS``.
    """
    scale_kwargs = BRIGHTEN_DOMAIN_PARAMS[domain]
    adjusted = cv2.convertScaleAbs(image, **scale_kwargs)
    return adjusted

def apply_clahe(image: np.ndarray, clip_limit: float = 2.0, tile_grid_size: tuple = (8, 8)) -> np.ndarray:
    """Boost local contrast with CLAHE applied to the lightness channel only.

    The image is converted BGR -> LAB. CLAHE (Contrast Limited Adaptive Histogram
    Equalization) is then run on the L channel, and the result is converted back
    LAB -> BGR. The A/B channels are passed through untouched.

    NOTE(review): process_image feeds this frames from an ``"rgb"`` column. If those
    are truly RGB, the BGR flags compute L with R and B swapped. The output channel
    order still matches the input because the two conversions are inverses. Confirm
    whether ``COLOR_RGB2LAB`` / ``COLOR_LAB2RGB`` were intended.
    """
    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    lightness, chroma_a, chroma_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2LAB))
    equalized_lab = cv2.merge((equalizer.apply(lightness), chroma_a, chroma_b))
    return cv2.cvtColor(equalized_lab, cv2.COLOR_LAB2BGR)

def process_image(image: np.ndarray, domain: str) -> np.ndarray:
    """Apply the preprocessing registered for ``domain`` to ``image``.

    Steps, each skipped when ``domain`` has no matching entry:
      1. crop, if ``domain`` is in ``CROP_DOMAIN_PARAMS``;
      2. brightness/contrast scaling followed by CLAHE, if ``domain`` is in
         ``BRIGHTEN_DOMAIN_PARAMS``.
    A domain with no entries returns ``image`` unchanged.
    """
    processed = crop_domain_image(image, domain) if domain in CROP_DOMAIN_PARAMS else image
    if domain not in BRIGHTEN_DOMAIN_PARAMS:
        return processed
    brightened = adjust_brightness(processed, domain)
    return apply_clahe(brightened, clip_limit=2.0, tile_grid_size=(8, 8))

## Single Method Score Function

In [16]:
def render_episode_gif(
    data: Dict[str, Any],
    split: str,
    episode: int,
    exp_key: str,
    method: str,
    title: str,
    method_color: str,
    save_dir: Union[str, pathlib.Path],
    text_mode: str = "lower",
    fps: float = 5,
    domain: Optional[str] = None,
) -> None:
    """Render a GIF of one episode: camera frame (left) next to its detector score (right).

    The cumulative score ``f"{exp_key}_cum_score"`` is normalized so that the
    threshold lands at 1/3 of the [0, 1] y-range. The threshold is the ``quantile``
    of the demo-episode scores. Every frame is annotated with rollout progress and
    with whether the score currently exceeds the threshold. The final frame is held
    for 2 seconds. Output: ``save_dir / f"{split}-{episode}-{method}.gif"``.

    Args:
        data: Nested dict indexed as ``data[split][exp_key]`` with "test_frame" and
            "demo_frame" entries.
        split: Data split to read from.
        episode: Episode id passed to ``data_utils.get_episode``.
        exp_key: Experiment key selecting the frames and the score column.
        method: Short method name used in the output filename.
        title: Figure suptitle; an empty string omits it and shrinks the figure.
        method_color: Line color of the score curve.
        save_dir: Output directory (created if missing).
        text_mode: Image annotation placement: "lower", "lower_real", "upper" or
            "upper_real".
        fps: GIF frame rate.
        domain: Optional domain name forwarded to ``process_image``.

    Raises:
        ValueError: If ``text_mode`` is unsupported or the episode has no frames.
    """
    # (time-text, detector-text) anchors as fractions of the image (width, height).
    text_positions = {
        "lower": ((0.04, 0.82), (0.04, 0.9)),
        "lower_real": ((0.04, 0.75), (0.04, 0.875)),
        "upper": ((0.04, 0.13), (0.04, 0.04)),
        "upper_real": ((0.04, 0.13), (0.04, 0.05)),
    }
    if text_mode not in text_positions:
        raise ValueError(f"Text mode {text_mode} not supported.")
    (rx1, ry1), (rx2, ry2) = text_positions[text_mode]

    # Extract data.
    test_frame = data[split][exp_key]["test_frame"]
    demo_frame = data[split][exp_key]["demo_frame"]
    episode_frame = data_utils.get_episode(test_frame, episode, use_index=False)
    thresh = np.quantile(data_utils.aggr_episode_key_data(demo_frame, f"{exp_key}_cum_score"), quantile)
    scores: np.ndarray = episode_frame[f"{exp_key}_cum_score"].values
    images = episode_frame["rgb"].values
    if len(scores) == 0:
        raise ValueError(f"Episode {episode} of split {split!r} has no frames to render.")

    # Normalize scores: [0, 3 * thresh] -> [0, 1], so the threshold sits at 1/3.
    ymin = 0
    ymax = 3 * thresh
    scores = (scores - ymin) / (ymax - ymin)
    thresh = (thresh - ymin) / (ymax - ymin)

    # Plot settings.
    thresh_color = "red"
    # max(..., 1) avoids dividing by zero on a single-frame episode.
    last_index = max(len(scores) - 1, 1)
    xticks = np.linspace(0, len(scores) - 1, 6)
    xticklabels = [f"{x / last_index:.1f}" for x in xticks]

    height = 5
    if len(title) == 0:
        height -= 0.5
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, height))
    # Facecolor restored on frames where no failure is flagged.
    ok_facecolor = ax2.get_facecolor()

    def _grab_frame() -> np.ndarray:
        """Rasterize ``fig`` to an RGBA array via an in-memory PNG.

        This replaces the old temp file in the CWD, which could collide between
        concurrent runs. It also closes the image handle that ``Image.open`` used
        to leave open.
        """
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
        buffer.seek(0)
        with Image.open(buffer) as img:
            return np.array(img)

    gif_images = []
    try:
        for i in range(len(scores)):
            ax1.clear()
            ax2.clear()

            failure_detected = scores[i] > thresh

            img_text_time = f"Rollout %: {i / len(scores):.2f}"
            if failure_detected:
                text_color = "red"
                img_text_detector = "Detector: Policy Failure"
            else:
                text_color = "green"
                img_text_detector = "Detector: Policy Ok"

            # Plot image, framed in the detector-verdict color.
            image = images[i]
            if domain is not None:
                image = process_image(image, domain)
            ax1.imshow(image)
            ax1.set_xticks([])
            ax1.set_yticks([])
            for spine in ax1.spines.values():
                spine.set_visible(True)
                spine.set_linewidth(4)
                spine.set_edgecolor(text_color)

            # Annotate image; anchors scale with the (possibly cropped) image size.
            h, w = image.shape[:2]
            ax1.text(
                w * rx1, h * ry1,
                img_text_time,
                color="black",
                fontsize=15,
                fontweight='bold',
                ha='left',
                va='top'
            )
            ax1.text(
                w * rx2, h * ry2,
                img_text_detector,
                color=text_color,
                fontsize=17,
                fontweight='bold',
                ha='left',
                va='top'
            )

            # Plot the score so far against the threshold.
            ax2.plot(scores[:i + 1], label="Cumulative Error", color=method_color, linewidth=4)
            ax2.axhline(y=thresh, color=thresh_color, linestyle='--', linewidth=5, label="Detection Threshold")
            # Axes.clear() keeps the facecolor last set via set_facecolor, so reset it
            # explicitly; otherwise the red tint would stick after the first detection.
            ax2.set_facecolor((1, 0, 0, 0.1) if failure_detected else ok_facecolor)

            # Axes.
            ax2.set_xlabel("Normalized Trajectory Time (%)", fontsize=18)
            ax2.set_xlim([0, len(scores)-1])
            ax2.set_xticks(xticks)
            ax2.set_xticklabels(xticklabels, fontsize=14)

            ax2.set_ylabel("Normalized Score ($\\eta_t$)", fontsize=18)
            ax2.set_ylim([0, 1])
            ax2.tick_params(axis="y", labelsize=14)

            # Cosmetic.
            ax2.spines['top'].set_visible(False)
            ax2.spines['right'].set_visible(False)
            ax2.spines['left'].set_linewidth(3)
            ax2.spines['bottom'].set_linewidth(3)
            ax2.legend(loc='upper left', fancybox=True, framealpha=0.7, fontsize=16, edgecolor='gray')

            # Shared title.
            if len(title) > 0:
                fig.suptitle(title, fontsize=28)

            # Adjust spacing between subplots.
            fig.tight_layout()
            fig.subplots_adjust(wspace=0.25)

            gif_images.append(_grab_frame())
    finally:
        # Release the figure even if rendering fails part-way.
        plt.close(fig)

    # Hold the final frame before the GIF loops.
    pause_duration = 2.0
    gif_images += [gif_images[-1]] * int(fps * pause_duration)

    save_dir = pathlib.Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(save_dir / f"{split}-{episode}-{method}.gif", gif_images, fps=fps, loop=0)

## Single Method Only Image

In [17]:
def render_episode_gif_image(
    data: Dict[str, Any],
    split: str,
    episode: int,
    exp_key: str,
    method: str,
    method_color: str,
    save_dir: Union[str, pathlib.Path],
    text_mode: str = "lower",
    fps: float = 5,
    domain: Optional[str] = None,
) -> None:
    """Render an image-only GIF of one episode, annotated with the detector verdict.

    Each frame shows the (optionally processed) camera image. The image is framed in
    red when the normalized cumulative score ``f"{exp_key}_cum_score"`` exceeds the
    demo-based ``quantile`` threshold, and in green otherwise. The final frame is held
    for 2 seconds. Output: ``save_dir / f"{split}-{episode}-{method}-image.gif"``.

    Args:
        data: Nested dict indexed as ``data[split][exp_key]`` with "test_frame" and
            "demo_frame" entries.
        split: Data split to read from.
        episode: Episode id passed to ``data_utils.get_episode``.
        exp_key: Experiment key selecting the frames and the score column.
        method: Short method name used in the output filename.
        method_color: Unused here; kept for signature parity with the plotting variants.
        save_dir: Output directory (created if missing).
        text_mode: Image annotation placement: "lower", "lower_real", "upper" or
            "upper_real".
        fps: GIF frame rate.
        domain: Optional domain name forwarded to ``process_image``.

    Raises:
        ValueError: If ``text_mode`` is unsupported or the episode has no frames.
    """
    # (time-text, detector-text) anchors as fractions of the image (width, height).
    # The larger fonts here need a lower time-text anchor for "upper_real".
    text_positions = {
        "lower": ((0.04, 0.82), (0.04, 0.9)),
        "lower_real": ((0.04, 0.75), (0.04, 0.875)),
        "upper": ((0.04, 0.13), (0.04, 0.04)),
        "upper_real": ((0.04, 0.16), (0.04, 0.05)),
    }
    if text_mode not in text_positions:
        raise ValueError(f"Text mode {text_mode} not supported.")
    (rx1, ry1), (rx2, ry2) = text_positions[text_mode]

    # Extract data.
    test_frame = data[split][exp_key]["test_frame"]
    demo_frame = data[split][exp_key]["demo_frame"]
    episode_frame = data_utils.get_episode(test_frame, episode, use_index=False)
    thresh = np.quantile(data_utils.aggr_episode_key_data(demo_frame, f"{exp_key}_cum_score"), quantile)
    scores: np.ndarray = episode_frame[f"{exp_key}_cum_score"].values
    images = episode_frame["rgb"].values
    if len(scores) == 0:
        raise ValueError(f"Episode {episode} of split {split!r} has no frames to render.")

    # Normalize scores: [0, 3 * thresh] -> [0, 1], so the threshold sits at 1/3.
    ymin = 0
    ymax = 3 * thresh
    scores = (scores - ymin) / (ymax - ymin)
    thresh = (thresh - ymin) / (ymax - ymin)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))

    def _grab_frame() -> np.ndarray:
        """Rasterize ``fig`` to an RGBA array via an in-memory PNG.

        This avoids the temp file in the CWD and closes the image handle that
        ``Image.open`` used to leave open.
        """
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
        buffer.seek(0)
        with Image.open(buffer) as img:
            return np.array(img)

    gif_images = []
    try:
        for i in range(len(scores)):
            ax.clear()

            failure_detected = scores[i] > thresh

            img_text_time = f"Rollout %: {i / len(scores):.2f}"
            if failure_detected:
                text_color = "red"
                img_text_detector = "Policy Failure"
            else:
                text_color = "green"
                img_text_detector = "Policy Ok"

            # Plot image, framed in the detector-verdict color.
            image = images[i]
            if domain is not None:
                image = process_image(image, domain)
            ax.imshow(image)
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(True)
                spine.set_linewidth(8)
                spine.set_edgecolor(text_color)

            # Annotate image; anchors scale with the (possibly cropped) image size.
            h, w = image.shape[:2]
            ax.text(
                w * rx1, h * ry1,
                img_text_time,
                color="black",
                fontsize=26,
                fontweight='bold',
                ha='left',
                va='top'
            )
            ax.text(
                w * rx2, h * ry2,
                img_text_detector,
                color=text_color,
                fontsize=32,
                fontweight='bold',
                ha='left',
                va='top'
            )

            fig.tight_layout()

            gif_images.append(_grab_frame())
    finally:
        # Release the figure even if rendering fails part-way.
        plt.close(fig)

    # Hold the final frame before the GIF loops.
    pause_duration = 2.0
    gif_images += [gif_images[-1]] * int(fps * pause_duration)

    save_dir = pathlib.Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(save_dir / f"{split}-{episode}-{method}-image.gif", gif_images, fps=fps, loop=0)

## STAC Score Function with Actions

In [18]:
def render_episode_action_gif(
    data: Dict[str, Any],
    data_frame: pd.DataFrame,
    split: str,
    episode: int,
    exp_key: str,
    method: str,
    title: str,
    method_color: str,
    save_dir: Union[str, pathlib.Path],
    text_mode: str = "lower",
    fps: float = 5,
    # 3D rendering params.
    num_robots: int = 2,
    exec_horizon: int = 4,
    action_dim: int = 7,
    robot_freq: float = 3,
    pred_horizon: Optional[int] = None,
    sample_size: Optional[int] = None,
    num_3d_plots: Optional[int] = None,
    scale_mode: str = "default",
    domain: Optional[str] = None,
) -> None:
    """Render a GIF of one episode: camera frame, detector score, and 3D action panels.

    Panels, left to right:
      * the (optionally processed) camera frame, annotated with the detector verdict;
      * the normalized cumulative score ``f"{exp_key}_cum_score"`` against the
        demo-based ``quantile`` threshold;
      * one 3D panel per robot (``num_3d_plots`` of them). Each shows, in local
        coordinates, step i's sampled trajectories ("Previous Predictions"), the
        executed actions, and step i+1's sampled trajectories shifted to start at
        the end of the executed segment ("Current Predictions").
    Output: ``save_dir / f"{split}-{episode}-{method}-3d.gif"``.

    Args:
        data: Nested dict indexed as ``data[split][exp_key]`` with "test_frame" and
            "demo_frame" entries.
        data_frame: Per-step frame with "executed_action" and "sampled_actions" columns.
        split: Data split to read from.
        episode: Episode id passed to ``data_utils.get_episode``.
        exp_key: Experiment key selecting the frames and the score column.
        method: Short method name used in the output filename.
        title: Figure suptitle; empty string omits it.
        method_color: Line color of the score curve.
        save_dir: Output directory (created if missing).
        text_mode: Image annotation placement: "lower", "lower_real", "upper" or
            "upper_real".
        fps: GIF frame rate.
        num_robots: Number of robots encoded in each action, forwarded to ``action_utils``.
        exec_horizon: Number of leading executed actions kept per step.
        action_dim: Action dimensionality forwarded to ``action_utils`` (presumably
            per robot; TODO confirm).
        robot_freq: ``velocity_to_trajectory`` integrates with ``time=1 / robot_freq``.
        pred_horizon: If set, truncate sampled predictions to this many steps.
        sample_size: If set, subsample this many predictions per step.
        num_3d_plots: Number of 3D panels; defaults to ``num_robots`` and may not exceed it.
        scale_mode: 3D axis range: "default" ([-1, 1]), "zoom" ([-0.6, 0.6]) or
            "heavy_zoom" ([-0.4, 0.4]).
        domain: Optional domain name forwarded to ``process_image``.

    Raises:
        ValueError: On an unsupported ``text_mode``/``scale_mode``, more 3D panels
            than robots, or an episode with nothing to render.
    """
    # (time-text, detector-text) anchors as fractions of the image (width, height);
    # same modes and placements as render_episode_gif.
    text_positions = {
        "lower": ((0.04, 0.82), (0.04, 0.9)),
        "lower_real": ((0.04, 0.75), (0.04, 0.875)),
        "upper": ((0.04, 0.13), (0.04, 0.04)),
        "upper_real": ((0.04, 0.13), (0.04, 0.05)),
    }
    if text_mode not in text_positions:
        raise ValueError(f"Text mode {text_mode} not supported.")
    (rx1, ry1), (rx2, ry2) = text_positions[text_mode]

    # Symmetric half-range of the 3D axes per scale mode.
    scale_limits = {"default": 1.0, "zoom": 0.6, "heavy_zoom": 0.4}
    if scale_mode not in scale_limits:
        raise ValueError(f"Scale mode {scale_mode} not supported.")
    ax3d_lims = [-scale_limits[scale_mode], scale_limits[scale_mode]]
    ax3d_ticks = np.linspace(ax3d_lims[0], ax3d_lims[1], 5)

    num_3d_plots = num_3d_plots if num_3d_plots is not None else num_robots
    if num_3d_plots > num_robots:
        raise ValueError(f"num_3d_plots ({num_3d_plots}) cannot exceed num_robots ({num_robots}).")

    # Extract data.
    test_frame = data[split][exp_key]["test_frame"]
    demo_frame = data[split][exp_key]["demo_frame"]
    episode_frame = data_utils.get_episode(test_frame, episode, use_index=False)
    thresh = np.quantile(data_utils.aggr_episode_key_data(demo_frame, f"{exp_key}_cum_score"), quantile)
    scores: np.ndarray = episode_frame[f"{exp_key}_cum_score"].values
    images = episode_frame["rgb"].values
    if len(scores) == 0:
        raise ValueError(f"Episode {episode} of split {split!r} has no frames to render.")

    # Normalize scores: [0, 3 * thresh] -> [0, 1], so the threshold sits at 1/3.
    ymin = 0
    ymax = 3 * thresh
    scores = (scores - ymin) / (ymax - ymin)
    thresh = (thresh - ymin) / (ymax - ymin)

    # Plot settings: score panel.
    thresh_color = "red"
    # max(..., 1) avoids dividing by zero on a single-frame episode.
    last_index = max(len(scores) - 1, 1)
    xticks = np.linspace(0, len(scores) - 1, 5)
    xticklabels = [f"{x / last_index:.1f}" for x in xticks]

    # Extract action data.
    executed_actions = data_frame["executed_action"].values
    sampled_actions = data_frame["sampled_actions"].values

    def extract_trajectories() -> "tuple[np.ndarray, np.ndarray]":
        """Integrate the (linear-velocity) actions of every step into trajectories.

        Returns ``(executed, sampled)``, asserted to be rank 4 and rank 5, i.e.
        ``[T, num_robots, exec_horizon, 3]`` and
        ``[T, num_robots, sample_size, pred_horizon, 3]``. For each robot and axis,
        both arrays are rescaled with the sampled trajectories' min/max so that the
        sampled trajectories span [-1, 1].
        """
        executed_trajectories = []
        sampled_trajectories = []

        for i in range(len(sampled_actions)):
            action = executed_actions[i][:exec_horizon]  # [exec_horizon, action_dim]
            action_batch = sampled_actions[i]  # [sample_size, pred_horizon, action_dim]
            if pred_horizon is not None:
                action_batch = action_batch[:, :pred_horizon]

            # Extract linear velocities.
            action = action_utils.filter_actions(
                action=action,
                num_robots=num_robots,
                action_dim=action_dim,
                ignore_gripper=True,
                ignore_rotation=True,
            )  # presumably [exec_horizon, 3 * num_robots]; TODO confirm
            action_batch = action_utils.filter_actions(
                action=action_batch,
                num_robots=num_robots,
                action_dim=action_dim,
                ignore_gripper=True,
                ignore_rotation=True,
            )  # [sample_size, pred_horizon, 3 * num_robots]

            # Subsample actions.
            if sample_size is not None:
                action_batch = action_utils.subsample_actions(
                    action=action_batch,
                    sample_size=sample_size
                )  # [sample_size, pred_horizon, 3 * num_robots]

            # Compute trajectories.
            trajectory = action_utils.velocity_to_trajectory(
                action=action,
                time=1/robot_freq,
                num_robots=num_robots,
                return_list=True,
            )  # List[[exec_horizon, 3]], len(List) = num_robots
            trajectory_batch = action_utils.velocity_to_trajectory(
                action=action_batch,
                time=1/robot_freq,
                num_robots=num_robots,
                return_list=True,
            )  # List[[sample_size, pred_horizon, 3]], len(List) = num_robots

            executed_trajectories.append(trajectory)
            sampled_trajectories.append(trajectory_batch)

        executed_trajectories = np.array(executed_trajectories)  # [T, num_robots, exec_horizon, 3]
        sampled_trajectories = np.array(sampled_trajectories)  # [T, num_robots, sample_size, pred_horizon, 3]
        assert executed_trajectories.ndim == 4 and sampled_trajectories.ndim == 5

        # Normalize trajectories between [-1, 1] (per robot, per axis). `e` and `s`
        # are views, so slice assignment writes back; `e` is rescaled first because
        # it uses the un-rescaled min/max of `s`.
        for i in range(num_robots):
            for j in range(3):
                # TODO: Decide on whether to normalize per axis or across axes.
                e = executed_trajectories[:, i, ..., j]
                s = sampled_trajectories[:, i, ..., j]
                e[:] = (e - s.min()) / (s.max() - s.min()) * 2 - 1
                s[:] = (s - s.min()) / (s.max() - s.min()) * 2 - 1

        return executed_trajectories, sampled_trajectories

    executed_trajectories, sampled_trajectories = extract_trajectories()

    # Frame i also draws step i + 1's predictions, so the number of renderable
    # frames is capped by the sampled-action rows; this avoids an IndexError on the
    # last frame when data_frame has no extra row.
    num_frames = min(len(scores), len(sampled_trajectories) - 1)
    if num_frames <= 0:
        raise ValueError(f"Episode {episode} of split {split!r} has too few action rows to render.")

    # Plot settings: 3D panels.
    curr_color = "blue"
    prev_color = "orange"
    exec_color = "green"

    # Create figure.
    fig = plt.figure()
    ncols = 2 + num_3d_plots
    fig.set_figwidth(6.5 * ncols)
    fig.set_figheight(5.5)
    # One width ratio per column (image, score, then each 3D panel). A hard-coded
    # three-entry list made GridSpec raise whenever num_3d_plots != 1.
    spec = gridspec.GridSpec(nrows=1, ncols=ncols, width_ratios=[1, 1] + [1.3] * num_3d_plots)
    ax1 = fig.add_subplot(spec[0])
    ax2 = fig.add_subplot(spec[1])
    axes_3d: List[plt.Axes] = [fig.add_subplot(spec[2 + i], projection='3d') for i in range(num_3d_plots)]
    # Facecolor restored on frames where no failure is flagged.
    ok_facecolor = ax2.get_facecolor()

    def _grab_frame() -> np.ndarray:
        """Rasterize ``fig`` to an RGBA array via an in-memory PNG.

        This avoids the temp file in the CWD and closes the image handle that
        ``Image.open`` used to leave open.
        """
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
        buffer.seek(0)
        with Image.open(buffer) as img:
            return np.array(img)

    gif_images = []
    try:
        for i in range(num_frames):
            ax1.clear()
            ax2.clear()
            for ax in axes_3d:
                ax.clear()

            failure_detected = scores[i] > thresh

            img_text_time = f"Rollout %: {i / len(scores):.2f}"
            if failure_detected:
                text_color = "red"
                img_text_detector = "Detector: Policy Failure"
            else:
                text_color = "green"
                img_text_detector = "Detector: Policy Ok"

            # Plot image, framed in the detector-verdict color.
            image = images[i]
            if domain is not None:
                image = process_image(image, domain)
            ax1.imshow(image)
            ax1.set_xticks([])
            ax1.set_yticks([])
            for spine in ax1.spines.values():
                spine.set_visible(True)
                spine.set_linewidth(4)
                spine.set_edgecolor(text_color)

            # Annotate image; anchors scale with the (possibly cropped) image size.
            h, w = image.shape[:2]
            ax1.text(
                w * rx1, h * ry1,
                img_text_time,
                color="black",
                fontsize=15,
                fontweight='bold',
                ha='left',
                va='top'
            )
            ax1.text(
                w * rx2, h * ry2,
                img_text_detector,
                color=text_color,
                fontsize=17,
                fontweight='bold',
                ha='left',
                va='top'
            )

            # Plot the score so far against the threshold.
            ax2.plot(scores[:i + 1], label="Cumulative Error", color=method_color, linewidth=4)
            ax2.axhline(y=thresh, color=thresh_color, linestyle='--', linewidth=5, label="Detection Threshold")
            # Axes.clear() keeps the facecolor last set via set_facecolor, so reset it
            # explicitly; otherwise the red tint would stick after the first detection.
            ax2.set_facecolor((1, 0, 0, 0.1) if failure_detected else ok_facecolor)

            # Axes.
            ax2.set_xlabel("Normalized Trajectory Time (%)", fontsize=18)
            ax2.set_xlim([0, len(scores)-1])
            ax2.set_xticks(xticks)
            ax2.set_xticklabels(xticklabels, fontsize=14)

            ax2.set_ylabel("Normalized Score ($\\eta_t$)", fontsize=18)
            ax2.set_ylim([0, 1])
            ax2.tick_params(axis="y", labelsize=14)

            # Cosmetic.
            ax2.spines['top'].set_visible(False)
            ax2.spines['right'].set_visible(False)
            ax2.spines['left'].set_linewidth(3)
            ax2.spines['bottom'].set_linewidth(3)
            ax2.legend(loc='upper left', fancybox=True, framealpha=0.7, fontsize=16, edgecolor='gray')

            for j in range(len(axes_3d)):
                exec_traj = executed_trajectories[i][j].copy()  # [exec_horizon, 3]
                prev_trajs = sampled_trajectories[i][j].copy()  # [sample_size, pred_horizon, 3]
                curr_trajs = sampled_trajectories[i+1][j].copy()  # [sample_size, pred_horizon, 3]

                # Express in local coordinates (each trajectory starts at the origin).
                prev_trajs -= prev_trajs[:, :1, :]
                curr_trajs -= curr_trajs[:, :1, :]
                exec_traj -= exec_traj[:1, :]

                # Plot trajectories; only the first of each group gets a legend label.
                for k, traj in enumerate(prev_trajs):
                    label = "Previous Predictions" if k == 0 else None
                    axes_3d[j].plot(traj[:, 0], traj[:, 1], traj[:, 2], color=prev_color, label=label)

                for k, traj in enumerate(curr_trajs):
                    # Start the next step's predictions where the executed segment ends.
                    traj += exec_traj[-1]
                    label = "Current Predictions" if k == 0 else None
                    axes_3d[j].plot(traj[:, 0], traj[:, 1], traj[:, 2], color=curr_color, label=label)

                # Plot executed trajectory with dotted line and thicker line width.
                axes_3d[j].plot(exec_traj[:, 0], exec_traj[:, 1], exec_traj[:, 2], color=exec_color, linestyle=':', linewidth=3, label="Executed Actions")
                axes_3d[j].scatter(exec_traj[[0, -1], 0], exec_traj[[0, -1], 1], exec_traj[[0, -1], 2], color=exec_color)

                # Axes labels.
                axes_3d[j].set_xlabel('Normalized Traj. X', fontsize=12)
                axes_3d[j].set_ylabel('Normalized Traj. Y', fontsize=12)
                axes_3d[j].set_zlabel('Normalized Traj. Z', fontsize=12)

                # Cosmetic.
                axes_3d[j].set_xlim(ax3d_lims)
                axes_3d[j].set_ylim(ax3d_lims)
                axes_3d[j].set_zlim(ax3d_lims)
                axes_3d[j].set_xticks(ax3d_ticks)
                axes_3d[j].set_yticks(ax3d_ticks)
                axes_3d[j].set_zticks(ax3d_ticks)
                axes_3d[j].legend(loc='upper left', bbox_to_anchor=(-0.25, 1.0), fancybox=True, framealpha=0.7, fontsize=16, edgecolor='gray', ncol=2)

            # Shared title.
            if len(title) > 0:
                fig.suptitle(title, fontsize=28)

            # Adjust spacing between subplots.
            fig.tight_layout()
            fig.subplots_adjust(wspace=0.3)

            gif_images.append(_grab_frame())
    finally:
        # Release the figure even if rendering fails part-way.
        plt.close(fig)

    # Hold the final frame before the GIF loops.
    pause_duration = 2.0
    gif_images += [gif_images[-1]] * int(fps * pause_duration)

    save_dir = pathlib.Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(save_dir / f"{split}-{episode}-{method}-3d.gif", gif_images, fps=fps, loop=0)

## Multiple Method Score Function

In [19]:
def render_episode_gif_multi(
    data: Dict[str, Any],
    split: str,
    episode: int,
    exp_keys: List[str],
    title: str,
    methods: List[str],
    titles: List[str],
    method_colors: List[str],
    save_dir: Union[str, pathlib.Path],
    fps: float = 5,
    domain: Optional[str] = None,
) -> None:
    """Render a GIF comparing several detectors side by side on one episode.

    The left panel shows the (optionally processed) camera frame. Each entry of
    ``exp_keys`` then gets a panel plotting its normalized cumulative score against
    its own demo-based ``quantile`` threshold. A panel is tinted red on frames where
    its score exceeds that threshold. All panels are truncated to the shortest score
    sequence, and the final frame is held for 2 seconds.
    Output: ``save_dir / f"{split}-{episode}-{'-'.join(methods)}.gif"``.

    Args:
        data: Nested dict indexed as ``data[split][exp_key]`` with "test_frame" and
            "demo_frame" entries.
        split: Data split to read from.
        episode: Episode id passed to ``data_utils.get_episode``.
        exp_keys: Experiment keys, one score panel each.
        title: Figure suptitle; empty string omits it and shrinks the figure.
        methods: Short method names joined into the output filename.
        titles: Per-panel titles parallel to ``exp_keys`` (empty string = no title).
        method_colors: Per-panel curve colors parallel to ``exp_keys``.
        save_dir: Output directory (created if missing).
        fps: GIF frame rate.
        domain: Optional domain name forwarded to ``process_image``.

    Raises:
        ValueError: If ``exp_keys`` is empty, ``titles``/``method_colors`` have fewer
            entries than ``exp_keys``, or the episode has no frames.
    """
    if len(exp_keys) == 0:
        raise ValueError("render_episode_gif_multi needs at least one exp_key.")
    if len(titles) < len(exp_keys) or len(method_colors) < len(exp_keys):
        raise ValueError("titles and method_colors must provide one entry per exp_key.")

    # Extract and normalize each detector's scores independently.
    thresh = []
    scores = []
    for exp_key in exp_keys:
        test_frame = data[split][exp_key]["test_frame"]
        demo_frame = data[split][exp_key]["demo_frame"]
        episode_frame = data_utils.get_episode(test_frame, episode, use_index=False)
        raw_thresh = np.quantile(data_utils.aggr_episode_key_data(demo_frame, f"{exp_key}_cum_score"), quantile)
        raw_scores = episode_frame[f"{exp_key}_cum_score"].values

        # Normalize: [0, 3 * thresh] -> [0, 1], so every threshold sits at 1/3.
        ymin = 0
        ymax = 3 * raw_thresh
        scores.append((raw_scores - ymin) / (ymax - ymin))
        thresh.append((raw_thresh - ymin) / (ymax - ymin))

    # Camera frames are taken from the last exp_key's episode frame.
    # NOTE(review): assumes every exp_key shares the same "rgb" frames for this episode.
    images = episode_frame["rgb"].values
    length = min(len(s) for s in scores)
    if length == 0:
        raise ValueError(f"Episode {episode} of split {split!r} has no frames to render.")
    nplots = len(exp_keys)

    # Plot settings.
    thresh_color = "red"
    # max(..., 1) avoids dividing by zero on a single-frame episode.
    last_index = max(length - 1, 1)
    xticks = np.linspace(0, length - 1, 6)
    xticklabels = [f"{x / last_index:.1f}" for x in xticks]

    height = 5
    if len(title) == 0:
        height -= 0.5
    fig, axes = plt.subplots(1, 1 + nplots, figsize=(4.75 * (1 + nplots), height))
    # Facecolors restored on frames where a panel flags no failure.
    ok_facecolors = [ax.get_facecolor() for ax in axes]

    def _grab_frame() -> np.ndarray:
        """Rasterize ``fig`` to an RGBA array via an in-memory PNG.

        This avoids the temp file in the CWD and closes the image handle that
        ``Image.open`` used to leave open.
        """
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
        buffer.seek(0)
        with Image.open(buffer) as img:
            return np.array(img)

    gif_images = []
    try:
        for i in range(length):
            for ax in axes:
                ax.clear()

            # Plot image.
            image = images[i]
            if domain is not None:
                image = process_image(image, domain)
            axes[0].imshow(image)
            axes[0].set_xticks([])
            axes[0].set_yticks([])
            for spine in axes[0].spines.values():
                spine.set_visible(True)
                spine.set_linewidth(2)

            # Plot each detector's score so far against its threshold.
            for j in range(nplots):
                ax = axes[j + 1]
                ax.plot(scores[j][:i + 1], label="Cumulative Error", color=method_colors[j], linewidth=4)
                ax.axhline(y=thresh[j], color=thresh_color, linestyle='--', linewidth=5, label="Detection Threshold")
                # Axes.clear() keeps the facecolor last set via set_facecolor, so reset
                # it explicitly; otherwise the red tint would stick after a detection.
                failure_detected = scores[j][i] > thresh[j]
                ax.set_facecolor((1, 0, 0, 0.1) if failure_detected else ok_facecolors[j + 1])

                # Title and axes.
                if len(titles[j]) > 0:
                    ax.set_title(f"{titles[j]}", fontsize=20)
                ax.set_xlabel("Normalized Trajectory Time (%)", fontsize=18)
                ax.set_xlim([0, length-1])
                ax.set_xticks(xticks)
                ax.set_xticklabels(xticklabels, fontsize=14)

                ax.set_ylabel("Normalized Score ($\\eta_t$)", fontsize=18)
                ax.set_ylim([0, 1])
                ax.tick_params(axis="y", labelsize=14)

                # Cosmetic.
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)
                ax.spines['left'].set_linewidth(3)
                ax.spines['bottom'].set_linewidth(3)
                ax.legend(loc='upper left', fancybox=True, framealpha=0.7, fontsize=16, edgecolor='gray')

            # Shared title.
            if len(title) > 0:
                fig.suptitle(title, fontsize=32)

            # Adjust spacing between subplots.
            fig.tight_layout()
            fig.subplots_adjust(wspace=0.3)

            gif_images.append(_grab_frame())
    finally:
        # Release the figure even if rendering fails part-way.
        plt.close(fig)

    # Hold the final frame before the GIF loops.
    pause_duration = 2.0
    gif_images += [gif_images[-1]] * int(fps * pause_duration)

    save_dir = pathlib.Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(save_dir / f"{split}-{episode}-{'-'.join(methods)}.gif", gif_images, fps=fps, loop=0)

# STAC Detection Results

## Real World: Push Chair Domain Gifs

In [20]:
# Experiment keys compared in the push-chair GIFs, and the curve color for each.
kl_exp_key = "pred_horizon_16_sample_size_256_error_fn_kde_kl_all_rev_eig"
var_exp_key = "pred_horizon_16_sample_size_256_action_space_all"
exp_colors = {kl_exp_key: "orange", var_exp_key: "gray"}

In [21]:
parent_dir = CWD / ".." / f"gifs_{dir_postfix}" / "push_chair"  # Root for all push-chair GIFs.

# Metrics for both detectors, plus the per-episode test/demo frames the renderers read.
push_chair_metrics_gif = compile_metrics(
    domain="0914_push_chair_4",
    splits=["test"],
    exp_keys=[kl_exp_key, var_exp_key],
    return_test_frame=True,
    return_demo_frame=True,
    return_test_data=True,
)

### Test Split: Success

In [22]:
split = "test"
episodes = [5, 7, 9, 12]
fps = 2
title = ""  # No suptitle on these single-method GIFs.
save_dir = parent_dir / "ood_success"

# Everything except the episode id is shared across renders (reverse-KL detector).
gif_kwargs = dict(
    data=push_chair_metrics_gif,
    split=split,
    exp_key=kl_exp_key,
    method="rev-kl",
    title=title,
    method_color=exp_colors[kl_exp_key],
    save_dir=save_dir,
    text_mode="upper_real",
    fps=fps,
    domain="push_chair",
)
for episode in episodes:
    render_episode_gif(episode=episode, **gif_kwargs)

In [23]:
split = "test"
episodes = [5, 7, 9, 12]
fps = 2
title = "Out-of-Distribution Diffusion Policy Success"
save_dir = parent_dir / "ood_success"

# Shared detector, styling and 3D action-rendering settings.
gif_kwargs = dict(
    data=push_chair_metrics_gif,
    split=split,
    exp_key=kl_exp_key,
    method="rev-kl",
    title=title,
    method_color=exp_colors[kl_exp_key],
    save_dir=save_dir,
    text_mode="lower",
    fps=fps,
    num_robots=1,
    exec_horizon=4,
    action_dim=7,
    robot_freq=3,
    sample_size=128,
    pred_horizon=8,
    domain="push_chair",
)
for episode in episodes:
    # The raw episode frame supplies the executed and sampled actions.
    episode_frame = load_episode_frame(domain="0914_push_chair_4", split=split, episode=episode)
    render_episode_action_gif(data_frame=episode_frame, episode=episode, **gif_kwargs)

#### Test Split: Success Multi

In [24]:
split = "test"
episodes = [5, 7, 9, 12]
fps = 2
title = "Out-of-Distribution Success: Push Chair Diffusion Policy"
save_dir = parent_dir / "ood_success"

# Reverse-KL (ours) vs. output-variance baseline, side by side.
compared_keys = [kl_exp_key, var_exp_key]
gif_kwargs = dict(
    data=push_chair_metrics_gif,
    split=split,
    exp_keys=compared_keys,
    title=title,
    methods=["rev-kl", "var"],
    titles=["STAC Rev. KL (Ours)", "Diffusion Output Variance"],
    method_colors=[exp_colors[key] for key in compared_keys],
    save_dir=save_dir,
    fps=fps,
    domain="push_chair",
)
for episode in episodes:
    render_episode_gif_multi(episode=episode, **gif_kwargs)

### Test Split: Failure

In [25]:
split = "test"
episodes = [0, 1, 2, 10, 11, 13, 14, 17, 18, 19]
fps = 2
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure"

# Everything except the episode id is shared across renders (reverse-KL detector).
gif_kwargs = dict(
    data=push_chair_metrics_gif,
    split=split,
    exp_key=kl_exp_key,
    method="rev-kl",
    title=title,
    method_color=exp_colors[kl_exp_key],
    save_dir=save_dir,
    text_mode="upper_real",
    fps=fps,
    domain="push_chair",
)
for episode in episodes:
    render_episode_gif(episode=episode, **gif_kwargs)

In [26]:
split = "test"
episodes = [0, 1, 2, 10, 11, 13, 14, 17, 18, 19]
fps = 2
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure"

# Shared detector, styling and 3D action-rendering settings.
gif_kwargs = dict(
    data=push_chair_metrics_gif,
    split=split,
    exp_key=kl_exp_key,
    method="rev-kl",
    title=title,
    method_color=exp_colors[kl_exp_key],
    save_dir=save_dir,
    text_mode="lower",
    fps=fps,
    num_robots=1,
    exec_horizon=4,
    action_dim=7,
    robot_freq=3,
    sample_size=128,
    pred_horizon=8,
    domain="push_chair",
)
for episode in episodes:
    # The raw episode frame supplies the executed and sampled actions.
    episode_frame = load_episode_frame(domain="0914_push_chair_4", split=split, episode=episode)
    render_episode_action_gif(data_frame=episode_frame, episode=episode, **gif_kwargs)

#### Test Split: Failure Multi

In [27]:
split = "test"
episodes = [0, 1, 2, 10, 11, 13, 14, 17, 18, 19]
fps = 2
title = "Out-of-Distribution Failure: Push Chair Diffusion Policy"
save_dir = parent_dir / "ood_failure"

# Reverse-KL (ours) vs. output-variance baseline, side by side.
compared_keys = [kl_exp_key, var_exp_key]
gif_kwargs = dict(
    data=push_chair_metrics_gif,
    split=split,
    exp_keys=compared_keys,
    title=title,
    methods=["rev-kl", "var"],
    titles=["STAC Rev. KL (Ours)", "Diffusion Output Variance"],
    method_colors=[exp_colors[key] for key in compared_keys],
    save_dir=save_dir,
    fps=fps,
    domain="push_chair",
)
for episode in episodes:
    render_episode_gif_multi(episode=episode, **gif_kwargs)

#### Test Split: Failure Image Only

In [28]:
split = "test"
episodes = [2]
fps = 2
save_dir = parent_dir / "ood_failure"

# Image-only GIFs (no score panel) for the reverse-KL detector.
gif_kwargs = dict(
    data=push_chair_metrics_gif,
    split=split,
    exp_key=kl_exp_key,
    method="rev-kl",
    method_color=exp_colors[kl_exp_key],
    save_dir=save_dir,
    text_mode="upper_real",
    fps=fps,
    domain="push_chair",
)
for episode in episodes:
    render_episode_gif_image(episode=episode, **gif_kwargs)

## PushT Domain Gifs

In [29]:
# PushT experiment keys. Colors are keyed by the variables so the two cannot drift apart.
mmd_exp_key = "pred_horizon_16_sample_size_256_error_fn_mmd_rbf_all_median"
var_exp_key = "pred_horizon_16_sample_size_256_action_space_all"
rec_exp_key = "loss_fn_action_rec_all_sample_size_4"
exp_colors = {
    mmd_exp_key: "orange",
    var_exp_key: "gray",
    rec_exp_key: "#6a51a3",
}

In [30]:
# Compile PushT metrics; also request the test data and test/demo frames used for rendering.
pusht_metrics_gif = compile_metrics(
    domain="0525_pusht_8",
    splits=["na", "hh"],
    exp_keys=[mmd_exp_key, var_exp_key, rec_exp_key],
    return_test_data=True,
    return_test_frame=True,
    return_demo_frame=True,
)

# All PushT GIFs below are written beneath this directory.
gifs_root = CWD / ".." / f"gifs_{dir_postfix}"
parent_dir = gifs_root / "pusht"

### NA Split: Success

In [31]:
# PushT "na" split, episode 49: one single-method GIF per compared method.
split = "na"
episode = 49
fps = 3
title = "In-Distribution Diffusion Policy Success"
save_dir = parent_dir / "id_success"

for gif_exp_key, gif_method in (
    (mmd_exp_key, "mmd"),
    (var_exp_key, "var"),
    (rec_exp_key, "rec"),
):
    render_episode_gif(
        data=pusht_metrics_gif,
        split=split,
        episode=episode,
        exp_key=gif_exp_key,
        method=gif_method,
        title=title,
        method_color=exp_colors[gif_exp_key],
        save_dir=save_dir,
        text_mode="upper",
        fps=fps,
    )

#### NA Split: Success Multi

In [32]:
# PushT "na" split, episode 49: all three methods side by side.
split = "na"
episode = 49
fps = 3
title = "In-Distribution Success: PushT Diffusion Policy"
save_dir = parent_dir / "id_success"

# One row per panel: (experiment key, method id, panel title).
multi_panels = (
    (mmd_exp_key, "mmd", "STAC MMD (Ours)"),
    (var_exp_key, "var", "Diffusion Output Variance"),
    (rec_exp_key, "rec", "Diffusion Reconstruction"),
)

render_episode_gif_multi(
    data=pusht_metrics_gif,
    split=split,
    episode=episode,
    exp_keys=[key for key, _, _ in multi_panels],
    title=title,
    methods=[method for _, method, _ in multi_panels],
    titles=[panel_title for _, _, panel_title in multi_panels],
    method_colors=[exp_colors[key] for key, _, _ in multi_panels],
    save_dir=save_dir,
    fps=fps,
)

### HH Split: Success

In [33]:
# PushT "hh" split, episode 43: one single-method GIF per compared method.
split = "hh"
episode = 43
fps = 3
title = "Out-of-Distribution Diffusion Policy Success"
save_dir = parent_dir / "ood_success"

for gif_exp_key, gif_method in (
    (mmd_exp_key, "mmd"),
    (var_exp_key, "var"),
    (rec_exp_key, "rec"),
):
    render_episode_gif(
        data=pusht_metrics_gif,
        split=split,
        episode=episode,
        exp_key=gif_exp_key,
        method=gif_method,
        title=title,
        method_color=exp_colors[gif_exp_key],
        save_dir=save_dir,
        text_mode="upper",
        fps=fps,
    )

#### HH Split: Success Multi

In [34]:
# PushT "hh" split, episode 43: all three methods side by side.
split = "hh"
episode = 43
fps = 3
title = "Out-of-Distribution Success: PushT Diffusion Policy"
save_dir = parent_dir / "ood_success"

# One row per panel: (experiment key, method id, panel title).
multi_panels = (
    (mmd_exp_key, "mmd", "STAC MMD (Ours)"),
    (var_exp_key, "var", "Diffusion Output Variance"),
    (rec_exp_key, "rec", "Diffusion Reconstruction"),
)

render_episode_gif_multi(
    data=pusht_metrics_gif,
    split=split,
    episode=episode,
    exp_keys=[key for key, _, _ in multi_panels],
    title=title,
    methods=[method for _, method, _ in multi_panels],
    titles=[panel_title for _, _, panel_title in multi_panels],
    method_colors=[exp_colors[key] for key, _, _ in multi_panels],
    save_dir=save_dir,
    fps=fps,
)

### HH Split: Failure

In [35]:
# PushT "hh" split, episode 4: one single-method GIF per compared method.
split = "hh"
episode = 4
fps = 3
# Title deliberately left blank for this figure
# (previously "Out-of-Distribution Diffusion Policy Failure").
title = ""
save_dir = parent_dir / "ood_failure"

for gif_exp_key, gif_method in (
    (mmd_exp_key, "mmd"),
    (var_exp_key, "var"),
    (rec_exp_key, "rec"),
):
    render_episode_gif(
        data=pusht_metrics_gif,
        split=split,
        episode=episode,
        exp_key=gif_exp_key,
        method=gif_method,
        title=title,
        method_color=exp_colors[gif_exp_key],
        save_dir=save_dir,
        text_mode="upper",
        fps=fps,
    )

#### HH Split: Failure Multi

In [36]:
# PushT "hh" split, episode 4: all three methods side by side.
split = "hh"
episode = 4
fps = 3
title = "Out-of-Distribution Failure: PushT Diffusion Policy"
save_dir = parent_dir / "ood_failure"

# One row per panel: (experiment key, method id, panel title).
multi_panels = (
    (mmd_exp_key, "mmd", "STAC MMD (Ours)"),
    (var_exp_key, "var", "Diffusion Output Variance"),
    (rec_exp_key, "rec", "Diffusion Reconstruction"),
)

render_episode_gif_multi(
    data=pusht_metrics_gif,
    split=split,
    episode=episode,
    exp_keys=[key for key, _, _ in multi_panels],
    title=title,
    methods=[method for _, method, _ in multi_panels],
    titles=[panel_title for _, _, panel_title in multi_panels],
    method_colors=[exp_colors[key] for key, _, _ in multi_panels],
    save_dir=save_dir,
    fps=fps,
)

## Close Box Domain Gifs

In [37]:
# Close Box experiment keys. Colors are keyed by the variables so the two cannot drift apart.
mmd_exp_key = "pred_horizon_16_sample_size_32_error_fn_mmd_rbf_all"
var_exp_key = "pred_horizon_16_sample_size_32_action_space_all"
emb_exp_key = "embedding_clip_feat_score_fn_mahal"
exp_colors = {
    mmd_exp_key: "orange",
    var_exp_key: "gray",
    emb_exp_key: "#41b6c4",
}

In [38]:
# Compile Close Box metrics; also request the test data and test/demo frames used for rendering.
close_metrics_gif = compile_metrics(
    domain="0527_close_4",
    splits=["na", "hh"],
    exp_keys=[mmd_exp_key, var_exp_key, emb_exp_key],
    return_test_data=True,
    return_test_frame=True,
    return_demo_frame=True,
)

# All Close Box GIFs below are written beneath this directory.
gifs_root = CWD / ".." / f"gifs_{dir_postfix}"
parent_dir = gifs_root / "close"

### HH Split: Success

In [39]:
# Close Box "hh" split, episode 21: one single-method GIF per compared method.
split = "hh"
episode = 21
fps = 5
title = "Out-of-Distribution Diffusion Policy Success"
save_dir = parent_dir / "ood_success"

for gif_exp_key, gif_method in (
    (mmd_exp_key, "mmd"),
    (var_exp_key, "var"),
    (emb_exp_key, "emb"),
):
    render_episode_gif(
        data=close_metrics_gif,
        split=split,
        episode=episode,
        exp_key=gif_exp_key,
        method=gif_method,
        title=title,
        method_color=exp_colors[gif_exp_key],
        save_dir=save_dir,
        text_mode="lower",
        fps=fps,
    )

In [40]:
# Close Box "hh" split, episode 21: MMD GIF with the 3D predicted-action panel.
split = "hh"
episode = 21
fps = 2
title = "Out-of-Distribution Diffusion Policy Success"
save_dir = parent_dir / "ood_success"

episode_frame = load_episode_frame(domain="0527_close_4", split=split, episode=episode)

# Robot/policy settings forwarded unchanged to the 3D action renderer.
action_render_kwargs = {
    "num_robots": 2,
    "exec_horizon": 4,
    "action_dim": 7,
    "robot_freq": 3,
    "sample_size": 32,
    "pred_horizon": 12,
    "num_3d_plots": 1,
    "scale_mode": "zoom",
}

render_episode_action_gif(
    data=close_metrics_gif,
    data_frame=episode_frame,
    split=split,
    episode=episode,
    exp_key=mmd_exp_key,
    method="mmd",
    title=title,
    method_color=exp_colors[mmd_exp_key],
    save_dir=save_dir,
    text_mode="lower",
    fps=fps,
    **action_render_kwargs,
)

#### HH Split: Success Multi

In [41]:
# Close Box "hh" split, episode 21: all three methods side by side.
split = "hh"
episode = 21
fps = 5
title = "Out-of-Distribution Success: Close Box Diffusion Policy"
save_dir = parent_dir / "ood_success"

# One row per panel: (experiment key, method id, panel title).
multi_panels = (
    (mmd_exp_key, "mmd", "STAC MMD (Ours)"),
    (var_exp_key, "var", "Diffusion Output Variance"),
    (emb_exp_key, "emb", "CLIP Embedding Similarity"),
)

render_episode_gif_multi(
    data=close_metrics_gif,
    split=split,
    episode=episode,
    exp_keys=[key for key, _, _ in multi_panels],
    title=title,
    methods=[method for _, method, _ in multi_panels],
    titles=[panel_title for _, _, panel_title in multi_panels],
    method_colors=[exp_colors[key] for key, _, _ in multi_panels],
    save_dir=save_dir,
    fps=fps,
)

### HH Split: Jitter Failure

In [42]:
# Close Box "hh" split, episode 33 (jitter failure): one single-method GIF per method.
split = "hh"
episode = 33
fps = 5
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure_jitter"

for gif_exp_key, gif_method in (
    (mmd_exp_key, "mmd"),
    (var_exp_key, "var"),
    (emb_exp_key, "emb"),
):
    render_episode_gif(
        data=close_metrics_gif,
        split=split,
        episode=episode,
        exp_key=gif_exp_key,
        method=gif_method,
        title=title,
        method_color=exp_colors[gif_exp_key],
        save_dir=save_dir,
        text_mode="lower",
        fps=fps,
    )

In [43]:
# Close Box "hh" split, episode 33 (jitter failure): MMD GIF with the 3D action panel.
split = "hh"
episode = 33
fps = 2
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure_jitter"

episode_frame = load_episode_frame(domain="0527_close_4", split=split, episode=episode)

# Robot/policy settings forwarded unchanged to the 3D action renderer.
action_render_kwargs = {
    "num_robots": 2,
    "exec_horizon": 4,
    "action_dim": 7,
    "robot_freq": 3,
    "sample_size": 32,
    "pred_horizon": 12,
    "num_3d_plots": 1,
    "scale_mode": "zoom",
}

render_episode_action_gif(
    data=close_metrics_gif,
    data_frame=episode_frame,
    split=split,
    episode=episode,
    exp_key=mmd_exp_key,
    method="mmd",
    title=title,
    method_color=exp_colors[mmd_exp_key],
    save_dir=save_dir,
    text_mode="lower",
    fps=fps,
    **action_render_kwargs,
)

#### HH Split: Jitter Failure Multi

In [44]:
# Close Box "hh" split, episode 33 (jitter failure): all three methods side by side.
split = "hh"
episode = 33
fps = 5
title = "Out-of-Distribution Failure: Close Box Diffusion Policy"
save_dir = parent_dir / "ood_failure_jitter"

# One row per panel: (experiment key, method id, panel title).
multi_panels = (
    (mmd_exp_key, "mmd", "STAC MMD (Ours)"),
    (var_exp_key, "var", "Diffusion Output Variance"),
    (emb_exp_key, "emb", "CLIP Embedding Similarity"),
)

render_episode_gif_multi(
    data=close_metrics_gif,
    split=split,
    episode=episode,
    exp_keys=[key for key, _, _ in multi_panels],
    title=title,
    methods=[method for _, method, _ in multi_panels],
    titles=[panel_title for _, _, panel_title in multi_panels],
    method_colors=[exp_colors[key] for key, _, _ in multi_panels],
    save_dir=save_dir,
    fps=fps,
)

### HH Split: Collision Failure

In [45]:
# Close Box "hh" split, episode 24 (collision failure): one single-method GIF per method.
split = "hh"
episode = 24
fps = 5
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure_collision"

for gif_exp_key, gif_method in (
    (mmd_exp_key, "mmd"),
    (var_exp_key, "var"),
    (emb_exp_key, "emb"),
):
    render_episode_gif(
        data=close_metrics_gif,
        split=split,
        episode=episode,
        exp_key=gif_exp_key,
        method=gif_method,
        title=title,
        method_color=exp_colors[gif_exp_key],
        save_dir=save_dir,
        text_mode="lower",
        fps=fps,
    )

In [46]:
# Close Box "hh" split, episode 24 (collision failure): MMD GIF with the 3D action panel.
split = "hh"
episode = 24
fps = 2
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure_collision"

episode_frame = load_episode_frame(domain="0527_close_4", split=split, episode=episode)

# Robot/policy settings forwarded unchanged to the 3D action renderer.
action_render_kwargs = {
    "num_robots": 2,
    "exec_horizon": 4,
    "action_dim": 7,
    "robot_freq": 3,
    "sample_size": 32,
    "pred_horizon": 12,
    "num_3d_plots": 1,
    "scale_mode": "zoom",
}

render_episode_action_gif(
    data=close_metrics_gif,
    data_frame=episode_frame,
    split=split,
    episode=episode,
    exp_key=mmd_exp_key,
    method="mmd",
    title=title,
    method_color=exp_colors[mmd_exp_key],
    save_dir=save_dir,
    text_mode="lower",
    fps=fps,
    **action_render_kwargs,
)

### HH Split: Erratic Failure

In [47]:
# Close Box "hh" split, episode 2 (erratic failure): one single-method GIF per method.
split = "hh"
episode = 2
fps = 5
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure_erratic"

for gif_exp_key, gif_method in (
    (mmd_exp_key, "mmd"),
    (var_exp_key, "var"),
    (emb_exp_key, "emb"),
):
    render_episode_gif(
        data=close_metrics_gif,
        split=split,
        episode=episode,
        exp_key=gif_exp_key,
        method=gif_method,
        title=title,
        method_color=exp_colors[gif_exp_key],
        save_dir=save_dir,
        text_mode="lower",
        fps=fps,
    )

In [48]:
# Close Box "hh" split, episode 2 (erratic failure): MMD GIF with the 3D action panel.
split = "hh"
episode = 2
fps = 2
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure_erratic"

episode_frame = load_episode_frame(domain="0527_close_4", split=split, episode=episode)

# Robot/policy settings forwarded unchanged to the 3D action renderer.
action_render_kwargs = {
    "num_robots": 2,
    "exec_horizon": 4,
    "action_dim": 7,
    "robot_freq": 3,
    "sample_size": 32,
    "pred_horizon": 12,
    "num_3d_plots": 1,
    "scale_mode": "zoom",
}

render_episode_action_gif(
    data=close_metrics_gif,
    data_frame=episode_frame,
    split=split,
    episode=episode,
    exp_key=mmd_exp_key,
    method="mmd",
    title=title,
    method_color=exp_colors[mmd_exp_key],
    save_dir=save_dir,
    text_mode="lower",
    fps=fps,
    **action_render_kwargs,
)

#### HH Split: Erratic Failure Multi

In [49]:
# Close Box "hh" split, episode 2 (erratic failure): all three methods side by side.
split = "hh"
episode = 2
fps = 5
title = "Out-of-Distribution Failure: Close Box Diffusion Policy"
save_dir = parent_dir / "ood_failure_erratic"

# One row per panel: (experiment key, method id, panel title).
multi_panels = (
    (mmd_exp_key, "mmd", "STAC MMD (Ours)"),
    (var_exp_key, "var", "Diffusion Output Variance"),
    (emb_exp_key, "emb", "CLIP Embedding Similarity"),
)

render_episode_gif_multi(
    data=close_metrics_gif,
    split=split,
    episode=episode,
    exp_keys=[key for key, _, _ in multi_panels],
    title=title,
    methods=[method for _, method, _ in multi_panels],
    titles=[panel_title for _, _, panel_title in multi_panels],
    method_colors=[exp_colors[key] for key, _, _ in multi_panels],
    save_dir=save_dir,
    fps=fps,
)

## Smooth Domain Gifs

### Cover Domain: Updated

In [50]:
# Cover (updated) experiment keys. Colors are keyed by the variables so they stay in sync.
mmd_exp_key = "pred_horizon_16_sample_size_32_error_fn_mmd_rbf_all"
var_exp_key = "pred_horizon_16_sample_size_32_action_space_all"
exp_colors = {
    mmd_exp_key: "orange",
    var_exp_key: "gray",
}

In [51]:
# Compile Cover metrics; also request the test data and test/demo frames used for rendering.
cover_metrics_gif = compile_metrics(
    domain="0914_cover_4",
    splits=["na", "ss"],
    exp_keys=[mmd_exp_key, var_exp_key],
    return_test_data=True,
    return_test_frame=True,
    return_demo_frame=True,
)

# All Cover GIFs below are written beneath this directory.
gifs_root = CWD / ".." / f"gifs_{dir_postfix}"
parent_dir = gifs_root / "cover"

### Cover NA: Policy Success

In [52]:
# Cover "na" split: MMD GIFs for each listed episode.
split = "na"
fps = 3
title = "In-Distribution Success: Cover Object"
save_dir = parent_dir / "id_success"
episodes = [10, 29]

# Everything except the episode index is the same for every render.
cover_gif_kwargs = {
    "data": cover_metrics_gif,
    "split": split,
    "exp_key": mmd_exp_key,
    "method": "mmd",
    "title": title,
    "method_color": exp_colors[mmd_exp_key],
    "save_dir": save_dir,
    "text_mode": "upper",
    "fps": fps,
}

for episode in episodes:
    render_episode_gif(episode=episode, **cover_gif_kwargs)

### Cover SS: Policy Failure

In [53]:
# Cover "ss" split: MMD GIFs for each listed episode.
split = "ss"
fps = 3
# Title deliberately left blank. Earlier drafts used
# "Out-of-Distribution Failure: Cover Object" and "Task Progression Failure: Cover Object".
title = ""
save_dir = parent_dir / "ood_failure"
episodes = [15, 4, 43]

# Everything except the episode index is the same for every render.
cover_gif_kwargs = {
    "data": cover_metrics_gif,
    "split": split,
    "exp_key": mmd_exp_key,
    "method": "mmd",
    "title": title,
    "method_color": exp_colors[mmd_exp_key],
    "save_dir": save_dir,
    "text_mode": "upper",
    "fps": fps,
}

for episode in episodes:
    render_episode_gif(episode=episode, **cover_gif_kwargs)

### Cover Domain: Archive

In [54]:
# Cover (archive) experiment keys. Colors are keyed by the variables so they stay in sync.
mmd_exp_key = "pred_horizon_16_sample_size_32_error_fn_mmd_rbf_all"
var_exp_key = "pred_horizon_16_sample_size_32_action_space_all"
emb_exp_key = "embedding_clip_feat_score_fn_mahal"
exp_colors = {
    mmd_exp_key: "orange",
    var_exp_key: "gray",
    emb_exp_key: "#41b6c4",
}

In [55]:
# Compile archived Cover metrics; also request the test data and test/demo frames for rendering.
cover_metrics_gif = compile_metrics(
    domain="0527_cover_4_abl",
    splits=["ss"],
    exp_keys=[mmd_exp_key, var_exp_key, emb_exp_key],
    return_test_data=True,
    return_test_frame=True,
    return_demo_frame=True,
)

# Archived Cover GIFs are written beneath this directory.
gifs_root = CWD / ".." / f"gifs_{dir_postfix}"
parent_dir = gifs_root / "cover_archive"

#### Cover SS: Policy Success

In [56]:
# Archived Cover "ss" split, episode 13: one single-method GIF per method.
# NOTE(review): this "ss" episode is titled/saved as in-distribution; confirm intended.
split = "ss"
episode = 13
fps = 5
title = "In-Distribution Diffusion Policy Success"
save_dir = parent_dir / "id_success"

# (experiment key, method id, text placement); MMD uses "upper", the others "lower".
for gif_exp_key, gif_method, gif_text_mode in (
    (mmd_exp_key, "mmd", "upper"),
    (var_exp_key, "var", "lower"),
    (emb_exp_key, "emb", "lower"),
):
    render_episode_gif(
        data=cover_metrics_gif,
        split=split,
        episode=episode,
        exp_key=gif_exp_key,
        method=gif_method,
        title=title,
        method_color=exp_colors[gif_exp_key],
        save_dir=save_dir,
        text_mode=gif_text_mode,
        fps=fps,
    )

In [57]:
# Archived Cover "ss" split, episode 13: MMD GIF with the 3D predicted-action panel.
split = "ss"
episode = 13
fps = 2
title = "In-Distribution Diffusion Policy Success"
save_dir = parent_dir / "id_success"

episode_frame = load_episode_frame(domain="0527_cover_4_abl", split=split, episode=episode)

# Robot/policy settings forwarded unchanged to the 3D action renderer.
action_render_kwargs = {
    "num_robots": 2,
    "exec_horizon": 4,
    "action_dim": 7,
    "robot_freq": 3,
    "sample_size": 32,
    "pred_horizon": 16,
    "num_3d_plots": 1,
    "scale_mode": "default",
}

render_episode_action_gif(
    data=cover_metrics_gif,
    data_frame=episode_frame,
    split=split,
    episode=episode,
    exp_key=mmd_exp_key,
    method="mmd",
    title=title,
    method_color=exp_colors[mmd_exp_key],
    save_dir=save_dir,
    text_mode="lower",
    fps=fps,
    **action_render_kwargs,
)

#### Cover SS: Policy Failure (Not Detected)

In [58]:
# Archived Cover "ss" split, episode 48: one single-method GIF per method.
split = "ss"
episode = 48
fps = 5
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure"

# (experiment key, method id, text placement); MMD uses "upper", the others "lower".
for gif_exp_key, gif_method, gif_text_mode in (
    (mmd_exp_key, "mmd", "upper"),
    (var_exp_key, "var", "lower"),
    (emb_exp_key, "emb", "lower"),
):
    render_episode_gif(
        data=cover_metrics_gif,
        split=split,
        episode=episode,
        exp_key=gif_exp_key,
        method=gif_method,
        title=title,
        method_color=exp_colors[gif_exp_key],
        save_dir=save_dir,
        text_mode=gif_text_mode,
        fps=fps,
    )

In [59]:
# Archived Cover "ss" split, episode 48: MMD GIF with the 3D predicted-action panel.
split = "ss"
episode = 48
fps = 2
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure"

episode_frame = load_episode_frame(domain="0527_cover_4_abl", split=split, episode=episode)

# Robot/policy settings forwarded unchanged to the 3D action renderer.
action_render_kwargs = {
    "num_robots": 2,
    "exec_horizon": 4,
    "action_dim": 7,
    "robot_freq": 3,
    "sample_size": 32,
    "pred_horizon": 16,
    "num_3d_plots": 1,
    "scale_mode": "default",
}

render_episode_action_gif(
    data=cover_metrics_gif,
    data_frame=episode_frame,
    split=split,
    episode=episode,
    exp_key=mmd_exp_key,
    method="mmd",
    title=title,
    method_color=exp_colors[mmd_exp_key],
    save_dir=save_dir,
    text_mode="lower",
    fps=fps,
    **action_render_kwargs,
)

#### Cover SS: Policy Failure (Not Detected) Multi

In [62]:
# Archived Cover "ss" split, episode 48: all three methods side by side.
split = "ss"
episode = 48
fps = 5
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure"

# One row per panel: (experiment key, method id, panel title).
multi_panels = (
    (mmd_exp_key, "mmd", "STAC MMD (Ours)"),
    (var_exp_key, "var", "Diffusion Output Variance"),
    (emb_exp_key, "emb", "CLIP Embedding Similarity"),
)

render_episode_gif_multi(
    data=cover_metrics_gif,
    split=split,
    episode=episode,
    exp_keys=[key for key, _, _ in multi_panels],
    title=title,
    methods=[method for _, method, _ in multi_panels],
    titles=[panel_title for _, _, panel_title in multi_panels],
    method_colors=[exp_colors[key] for key, _, _ in multi_panels],
    save_dir=save_dir,
    fps=fps,
)

#### Undetected Failure Case: Cover Object

In [63]:
# Archived Cover "ss" split, episode 36: MMD GIF with the 3D action panel for a failure case.
split = "ss"
episode = 36
fps = 2
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure"

episode_frame = load_episode_frame(domain="0527_cover_4_abl", split=split, episode=episode)

# Robot/policy settings forwarded unchanged to the 3D action renderer.
action_render_kwargs = {
    "num_robots": 2,
    "exec_horizon": 4,
    "action_dim": 7,
    "robot_freq": 3,
    "sample_size": 32,
    "pred_horizon": 16,
    "num_3d_plots": 1,
    "scale_mode": "default",
}

render_episode_action_gif(
    data=cover_metrics_gif,
    data_frame=episode_frame,
    split=split,
    episode=episode,
    exp_key=mmd_exp_key,
    method="mmd",
    title=title,
    method_color=exp_colors[mmd_exp_key],
    save_dir=save_dir,
    text_mode="lower",
    fps=fps,
    **action_render_kwargs,
)

### Close Domain

In [64]:
# Smooth-split ("ss") Close Box metrics, scored with the MMD / variance / CLIP-embedding runs.
# The experiment keys and colors are declared here explicitly. Without them, this section
# silently depends on whichever domain cell last assigned these globals, which breaks
# out-of-order or partial re-runs. The values match the Close Box and Cover archive
# keys used elsewhere in this notebook.
mmd_exp_key = "pred_horizon_16_sample_size_32_error_fn_mmd_rbf_all"
var_exp_key = "pred_horizon_16_sample_size_32_action_space_all"
emb_exp_key = "embedding_clip_feat_score_fn_mahal"
exp_colors = {
    mmd_exp_key: "orange",
    var_exp_key: "gray",
    emb_exp_key: "#41b6c4",
}

# Also request the test data and test/demo frames needed by the GIF renderers.
close_metrics_gif = compile_metrics(
    domain="0525_close_4_abl",
    splits=["ss"],
    exp_keys=[mmd_exp_key, var_exp_key, emb_exp_key],
    return_test_data=True,
    return_test_frame=True,
    return_demo_frame=True,
)
parent_dir = CWD / ".." / f"gifs_{dir_postfix}" / "close"

#### Close SS: Policy Failure

In [65]:
# Close Box "ss" split, episode 42: one single-method GIF per compared method.
split = "ss"
episode = 42
fps = 5
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure_smooth"

for gif_exp_key, gif_method in (
    (mmd_exp_key, "mmd"),
    (var_exp_key, "var"),
    (emb_exp_key, "emb"),
):
    render_episode_gif(
        data=close_metrics_gif,
        split=split,
        episode=episode,
        exp_key=gif_exp_key,
        method=gif_method,
        title=title,
        method_color=exp_colors[gif_exp_key],
        save_dir=save_dir,
        text_mode="lower",
        fps=fps,
    )

In [66]:
# Close Box "ss" split, episode 42: MMD GIF with the 3D predicted-action panel.
split = "ss"
episode = 42
fps = 2
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure_smooth"

episode_frame = load_episode_frame(domain="0525_close_4_abl", split=split, episode=episode)

# Robot/policy settings forwarded unchanged to the 3D action renderer.
action_render_kwargs = {
    "num_robots": 2,
    "exec_horizon": 4,
    "action_dim": 7,
    "robot_freq": 3,
    "sample_size": 32,
    "pred_horizon": 12,
    "num_3d_plots": 1,
    "scale_mode": "zoom",
}

render_episode_action_gif(
    data=close_metrics_gif,
    data_frame=episode_frame,
    split=split,
    episode=episode,
    exp_key=mmd_exp_key,
    method="mmd",
    title=title,
    method_color=exp_colors[mmd_exp_key],
    save_dir=save_dir,
    text_mode="lower",
    fps=fps,
    **action_render_kwargs,
)

#### Undetected Failure Case: Close Box

In [67]:
# Close Box "ss" split, episode 7: MMD GIF with the 3D action panel for a failure case.
split = "ss"
episode = 7
fps = 2
title = "Out-of-Distribution Diffusion Policy Failure"
save_dir = parent_dir / "ood_failure_smooth"

episode_frame = load_episode_frame(domain="0525_close_4_abl", split=split, episode=episode)

# Robot/policy settings forwarded unchanged to the 3D action renderer.
action_render_kwargs = {
    "num_robots": 2,
    "exec_horizon": 4,
    "action_dim": 7,
    "robot_freq": 3,
    "sample_size": 32,
    "pred_horizon": 12,
    "num_3d_plots": 1,
    "scale_mode": "zoom",
}

render_episode_action_gif(
    data=close_metrics_gif,
    data_frame=episode_frame,
    split=split,
    episode=episode,
    exp_key=mmd_exp_key,
    method="mmd",
    title=title,
    method_color=exp_colors[mmd_exp_key],
    save_dir=save_dir,
    text_mode="lower",
    fps=fps,
    **action_render_kwargs,
)