# Notebook used to plot all the figures of the paper
May contain dead or deprecated code.

## Installation

In [None]:
!pip install seaborn==0.13.2
!pip install git+https://github.com/killiansheriff/LovelyPlots@3cfd78fb4d5a3d8c9f89feb7742f18340ffd2fb7
!pip install brokenaxes==0.6.1

In [None]:
import warnings
import yaml
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.ticker as ticker
import matplotlib.patches as patches
import matplotlib.lines as lines
import seaborn as sns
import math
from scipy.stats import linregress, kendalltau, spearmanr, pearsonr
from brokenaxes import brokenaxes
from cycler import cycler
import lovelyplots

In [None]:
# Silence every Python warning for the rest of the notebook session.
warnings.filterwarnings("ignore")

# Apply the LovelyPlots notebook style, its "colors10" palette and mathtext rendering,
# loading the .mplstyle files straight from the installed package directory.
lovelyplots_installdir = os.path.dirname(lovelyplots.__file__)
plt.style.use(
    [
        f"file://{lovelyplots_installdir}/styles/ipynb.mplstyle",
        f"file://{lovelyplots_installdir}/styles/colors/colors10.mplstyle",
        f"file://{lovelyplots_installdir}/styles/utils/use_mathtext.mplstyle",
    ]
)

# Colour cycle installed by the plotting functions (via plt.rc("axes", prop_cycle=...)).
line_cycler = cycler(color=["#E69F00", "#56B4E9", "#009E73", "#0072B2", "#D55E00", "#CC79A7", "#F0E442"])

# Marker-only cycle: same colours, no connecting line ("none"), one distinct marker per colour.
# NOTE(review): not used in the visible cells.
marker_cycler = (
    cycler(color=["#E69F00", "#56B4E9", "#009E73", "#0072B2", "#D55E00", "#CC79A7", "#F0E442"])
    + cycler(linestyle=["none", "none", "none", "none", "none", "none", "none"])
    + cycler(marker=["4", "2", "3", "1", "+", "x", "."])
)

# Atari <a name="atari"></a>

## Loading the runs from the raw logs into a DataFrame and saving them as combined .csv files <a name="load_atari"></a>
### Training
<div style="background-color: #F9F9F9; border-left: 5px solid #CC0000; padding: 10px; margin: 20px 0;">
    <strong>If you don't have the raw logs, skip ahead to loading the combined raw-log .csv files.</strong>
</div> 

In [None]:
# If the data lives somewhere else, create a symlink to it at $PROJECT_DIR/outputs/rlconf:
# ln -s <path-to-rlconf-containing-data> $PROJECT_DIR/outputs/rlconf

# Use one reference run to discover the metric names: the union of the keys stored in its
# .tar logs becomes `all_keys`, the list of metric columns read from every run below.
root_log_dir = f"../outputs/rlconf/solve/atari-ppo/baselines/2024-04-05_23-18-42-112955"

all_keys = set()
log_subdirs = ["logs/models", "logs/minibatch", "logs/eval", "logs/epoch", "logs/batch"]
# NOTE(review): `n` is not used in the visible cells.
n = len(os.listdir(root_log_dir))

for subdir in log_subdirs:
    subdir_path = os.path.join(root_log_dir, subdir)
    if not os.path.exists(subdir_path):
        print(f"Subdirectory {subdir} does not exist.")
        continue
    for file_name in os.listdir(subdir_path):
        file_path = os.path.join(subdir_path, file_name)
        if file_path.endswith(".tar"):
            # torch.load unpickles the file: only run this on trusted logs.
            loaded_data = torch.load(file_path, map_location=torch.device("cpu"))
            # assumes each log file holds a dict (only .keys() is used) — TODO confirm
            all_keys.update(loaded_data.keys())
# Sorted list so the metric columns come out in a deterministic order.
all_keys = list(all_keys)
all_keys.sort()


def get_nested_config_value(config, nested_key):
    """
    Walk a nested dictionary along a path of keys.

    :param config: The (possibly nested) configuration dictionary.
    :param nested_key: Iterable of keys, outermost first.
    :return: The value at the end of the path, or None as soon as a step is missing
        or the current node is not a dict. An empty path returns ``config`` itself.
    """
    node = config
    for step in nested_key:
        # Guard clause: stop at the first level that cannot be descended into.
        if not isinstance(node, dict) or step not in node:
            return None
        node = node[step]
    return node


def config_matches_criteria(config, criteria):
    """
    Tell whether a configuration satisfies every criterion.

    :param config: Configuration dictionary (e.g. a loaded YAML file).
    :param criteria: Mapping from key to expected value. A key is either a single config key
        or a tuple describing a nested path; a missing path compares as None.
    :return: True when every criterion holds (vacuously True for empty criteria), else False.
    """

    def _satisfied(criterion_key, expected):
        # Bare keys are promoted to one-element paths.
        path = criterion_key if isinstance(criterion_key, tuple) else (criterion_key,)
        return not get_nested_config_value(config, path) != expected

    # all() short-circuits on the first failing criterion, like an early return.
    return all(_satisfied(key, value) for key, value in criteria.items())


def flatten_config(config, parent_key="", sep="/"):
    """
    Flatten a nested config dict into a single-level dict.

    Nested dict keys are joined with ``sep``; list elements are addressed by their index
    (e.g. ``{"a": {"b": [1]}}`` -> ``{"a/b/0": 1}``). Empty dicts and lists contribute no keys.

    :param config: The dictionary to flatten.
    :param parent_key: Prefix prepended (with ``sep``) to every key; "" means no prefix.
    :param sep: Separator placed between path components.
    :return: A new flat dictionary.
    """
    flat = {}
    for key, val in config.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(val, dict):
            flat.update(flatten_config(val, full_key, sep=sep))
        elif isinstance(val, list):
            # Each element is flattened under "<full_key><sep><index>".
            for idx, element in enumerate(val):
                flat.update(flatten_config({f"{full_key}{sep}{idx}": element}, "", sep=sep))
        else:
            flat[full_key] = val
    return flat


def find_matching_configs(root_path, criteria, flat_category="baselines"):
    """
    Find run folders under ``root_path`` whose resolved config matches ``criteria``.

    Two layouts are searched:
    - ``root_path/<flat_category>/<run>/config/config_resolved.yaml``;
    - ``root_path/<category>/<subcategory>/<run>/config/config_resolved.yaml`` for every
      other category directory.

    :param root_path: Directory containing the category subfolders.
    :param criteria: Criteria dict for ``config_matches_criteria`` (nested keys as tuples);
        an empty dict matches every run.
    :param flat_category: Name of the category whose runs sit directly below it.
    :return: List of matching run folder paths, in directory-listing order.
    """

    def _collect(parent_dir, matches):
        # Append every run folder of parent_dir whose resolved config matches the criteria.
        for run_folder in os.listdir(parent_dir):
            config_path = os.path.join(parent_dir, run_folder, "config", "config_resolved.yaml")
            if not os.path.exists(config_path):
                continue  # Not a run folder, or the run has no resolved config
            with open(config_path, "r") as file:
                config = yaml.safe_load(file)
            if config_matches_criteria(config, criteria):
                matches.append(os.path.join(parent_dir, run_folder))

    matching_paths = []
    categories = os.listdir(root_path)  # list once, reuse for the progress-bar total
    for category in tqdm(categories, total=len(categories), desc="Find matching path"):
        category_path = os.path.join(root_path, category)
        if not os.path.isdir(category_path):
            continue  # Skip if not a directory
        if category == flat_category:
            _collect(category_path, matching_paths)
        else:
            for category2 in os.listdir(category_path):
                category_path2 = os.path.join(category_path, category2)
                # A stray file here (e.g. .DS_Store) would make os.listdir raise NotADirectoryError.
                if os.path.isdir(category_path2):
                    _collect(category_path2, matching_paths)

    return matching_paths


def compile_data_debug(matching_paths, metric_keys, log_subdir="batch"):
    """
    Build a DataFrame with one row per ``.tar`` log file in ``<run>/logs/<log_subdir>``.

    Each row contains the run's flattened resolved config (columns prefixed with "config/")
    plus one column per entry of ``metric_keys``; a metric absent from a log file is None.
    Runs without a ``logs/<log_subdir>`` folder contribute no rows.

    :param matching_paths: Run folders, each containing ``config/config_resolved.yaml``.
    :param metric_keys: Keys to read from every loaded log dict.
    :param log_subdir: Sub-folder of ``logs`` to read (default "batch").
    :return: pandas DataFrame of the compiled records.
    """
    data = []

    for path in tqdm(matching_paths, total=len(matching_paths), desc="Filling df"):
        # Load and flatten the config, prefixing every key with "config/".
        config_path = os.path.join(path, "config", "config_resolved.yaml")
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        flat_config = {f"config/{key}": value for key, value in flatten_config(config).items()}

        log_folder = os.path.join(path, "logs", log_subdir)
        if not os.path.exists(log_folder):
            continue
        for log_file in os.listdir(log_folder):
            log_path = os.path.join(log_folder, log_file)
            if not log_path.endswith(".tar"):
                continue
            # NOTE: torch.load unpickles the file; only use it on trusted logs.
            loaded_data = torch.load(log_path, map_location=torch.device("cpu"))
            record = flat_config.copy()
            record.update({metric_key: loaded_data.get(metric_key, None) for metric_key in metric_keys})
            data.append(record)

    return pd.DataFrame(data)


# Gather every run under the atari-ppo "solve" outputs and save them as one combined CSV.
path_folder = f"../outputs/rlconf/solve/atari-ppo/"
# Criteria map config keys (nested paths as tuples) to expected values; {} matches every run.
criteria = {}
matching_paths = find_matching_configs(path_folder, criteria)

# One row per .tar file in each run's logs/batch folder, with the flattened config + all_keys.
atari_df = compile_data_debug(matching_paths, all_keys)
print(atari_df.head())

atari_df.to_csv("../outputs/rlconf-plotting/combined-raw-logs/atari.csv")

In [None]:
# atari_df.to_csv("../outputs/rlconf-plotting/combined-raw-logs/atari.csv")

In [None]:
# At the time of submission to RLC
# Should expect 468 runs
# 180 baselines = 6 maps * 3 epochs * 2 lr schedules * 5 seeds
# 288 interventions = 4 interventions * 3 maps * 3 epochs * (2 lr schedules * 3 seeds + 1 lr schedule * 2 seeds)

### Plasticity
<div style="background-color: #F9F9F9; border-left: 5px solid #CC0000; padding: 10px; margin: 20px 0;">
    <strong>If you don't have the raw logs, skip ahead to loading the combined raw-log .csv files.</strong>
</div> 

In [None]:
# Use one reference plasticity ("capacity") run to discover the metric names: the union of the
# keys stored in its .tar logs becomes `all_keys`, the metric columns read from every run below.
root_log_dir = f"../outputs/rlconf/capacity/atari-ppo/all/2024-04-08_11-25-26-690228"


all_keys = set()
log_subdirs = ["logs/checkpoint", "logs/minibatch", "logs/model", "logs/epoch"]
# NOTE(review): `n` is not used in the visible cells.
n = len(os.listdir(root_log_dir))

for subdir in log_subdirs:
    subdir_path = os.path.join(root_log_dir, subdir)
    if not os.path.exists(subdir_path):
        print(f"Subdirectory {subdir} does not exist.")
        continue
    for file_name in os.listdir(subdir_path):
        file_path = os.path.join(subdir_path, file_name)
        if file_path.endswith(".tar"):
            # torch.load unpickles the file: only run this on trusted logs.
            loaded_data = torch.load(file_path, map_location=torch.device("cpu"))
            # assumes each log file holds a dict (only .keys() is used) — TODO confirm
            all_keys.update(loaded_data.keys())
# Sorted list so the metric columns come out in a deterministic order.
all_keys = list(all_keys)
all_keys.sort()


def find_matching_configs(root_path, criteria, flat_category="all"):
    """
    Find run folders under ``root_path`` whose resolved config matches ``criteria``.

    Two layouts are searched:
    - ``root_path/<flat_category>/<run>/config/config_resolved.yaml``;
    - ``root_path/<category>/<subcategory>/<run>/config/config_resolved.yaml`` for every
      other category directory.

    :param root_path: Directory containing the category subfolders.
    :param criteria: Criteria dict for ``config_matches_criteria`` (nested keys as tuples);
        an empty dict matches every run.
    :param flat_category: Name of the category whose runs sit directly below it.
    :return: List of matching run folder paths, in directory-listing order.
    """

    def _collect(parent_dir, matches):
        # Append every run folder of parent_dir whose resolved config matches the criteria.
        for run_folder in os.listdir(parent_dir):
            config_path = os.path.join(parent_dir, run_folder, "config", "config_resolved.yaml")
            if not os.path.exists(config_path):
                continue  # Not a run folder, or the run has no resolved config
            with open(config_path, "r") as file:
                config = yaml.safe_load(file)
            if config_matches_criteria(config, criteria):
                matches.append(os.path.join(parent_dir, run_folder))

    matching_paths = []
    categories = os.listdir(root_path)  # list once, reuse for the progress-bar total
    for category in tqdm(categories, total=len(categories), desc="Find matching path"):
        category_path = os.path.join(root_path, category)
        if not os.path.isdir(category_path):
            continue  # Skip if not a directory
        if category == flat_category:
            _collect(category_path, matching_paths)
        else:
            for category2 in os.listdir(category_path):
                category_path2 = os.path.join(category_path, category2)
                # A stray file here (e.g. .DS_Store) would make os.listdir raise NotADirectoryError.
                if os.path.isdir(category_path2):
                    _collect(category_path2, matching_paths)

    return matching_paths


def compile_data_debug(matching_paths, metric_keys, log_subdir="checkpoint"):
    """
    Build a DataFrame with one row per ``.tar`` log file in ``<run>/logs/<log_subdir>``.

    Each row contains the run's flattened resolved config (columns prefixed with "config/")
    plus one column per entry of ``metric_keys``; a metric absent from a log file is None.
    Runs without a ``logs/<log_subdir>`` folder contribute no rows.

    :param matching_paths: Run folders, each containing ``config/config_resolved.yaml``.
    :param metric_keys: Keys to read from every loaded log dict.
    :param log_subdir: Sub-folder of ``logs`` to read (default "checkpoint").
    :return: pandas DataFrame of the compiled records.
    """
    data = []

    for path in tqdm(matching_paths, total=len(matching_paths), desc="Filling df"):
        # Load and flatten the config, prefixing every key with "config/".
        config_path = os.path.join(path, "config", "config_resolved.yaml")
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        flat_config = {f"config/{key}": value for key, value in flatten_config(config).items()}

        log_folder = os.path.join(path, "logs", log_subdir)
        if not os.path.exists(log_folder):
            continue
        for log_file in os.listdir(log_folder):
            log_path = os.path.join(log_folder, log_file)
            if not log_path.endswith(".tar"):
                continue
            # NOTE: torch.load unpickles the file; only use it on trusted logs.
            loaded_data = torch.load(log_path, map_location=torch.device("cpu"))
            record = flat_config.copy()
            record.update({metric_key: loaded_data.get(metric_key, None) for metric_key in metric_keys})
            data.append(record)

    return pd.DataFrame(data)


# Gather every run under the atari-ppo "capacity" outputs and save them as one combined CSV.
path_folder = f"../outputs/rlconf/capacity/atari-ppo/"
# Criteria map config keys (nested paths as tuples) to expected values; {} matches every run.
criteria = {}
matching_paths = find_matching_configs(path_folder, criteria)

# One row per .tar file in each run's logs/checkpoint folder, with the flattened config + all_keys.
atari_df_capacity = compile_data_debug(matching_paths, all_keys)
print(atari_df_capacity.head())
atari_df_capacity.to_csv("../outputs/rlconf-plotting/combined-raw-logs/atari-capacity.csv")

In [None]:
# atari_df_capacity.to_csv("../outputs/rlconf-plotting/combined-raw-logs/atari-capacity.csv")

In [None]:
# At the time of submission to RLC
# Should expect 468 runs
# 180 baselines = 6 maps * 3 epochs * 2 lr schedules * 5 seeds
# 288 interventions = 4 interventions * 3 maps * 3 epochs * (2 lr schedules * 3 seeds + 1 lr schedule * 2 seeds)

<div style="background-color: #F9F9F9; border-left: 5px solid #CC0000; padding: 10px; margin: 20px 0;">
    <strong>Load the combined raw-log .csv files here. If you did not generate them, instructions for downloading them are in $PROJECT_ROOT/outputs/README.md.</strong>
</div> 

In [None]:
# Combined training logs (generated by the cells above, or downloaded as described in outputs/README.md).
atari_df = pd.read_csv("../outputs/rlconf-plotting/combined-raw-logs/atari.csv")
atari_df.head()

In [None]:
# Combined plasticity ("capacity") logs (generated above, or downloaded as described in outputs/README.md).
atari_df_capacity = pd.read_csv("../outputs/rlconf-plotting/combined-raw-logs/atari-capacity.csv")
# Preview the frame that was just loaded.
atari_df_capacity.head()

## Plotting functions

In [None]:
def filter_df_by_criteria(df, criteria):
    """
    Filter a DataFrame based on matching criteria, supporting multiple possible values for each criterion,
    nested keys of varying depths, and substring matches for specific keys.

    Args:
    - df (pd.DataFrame): The DataFrame to filter. It is not modified.
    - criteria (dict): A dictionary where keys are column names or tuples of key parts (joined with "/"
                       to form the column name), and values are the criteria the column values must match.
                       A value is either a single value or a list/tuple/set of accepted values.
                       For "config/working_dir" and "config/solve_dir" the values are literal substrings
                       (not regular expressions): a row matches when its string value contains any of
                       them; non-string/missing values never match.

    Returns:
    - pd.DataFrame: A new DataFrame containing only the rows that match all the criteria
      (with all original columns, even when no row matches).
    """
    substring_columns = ("config/working_dir", "config/solve_dir")
    filtered_df = df.copy()
    for keys, value in criteria.items():
        # Support for nested keys represented as tuples in criteria
        if not isinstance(keys, tuple):
            keys = (keys,)
        adjusted_column = "/".join(keys)
        column = filtered_df[adjusted_column]
        is_multi = isinstance(value, (list, tuple, set, frozenset))

        if adjusted_column in substring_columns:
            options = list(value) if is_multi else [value]
            # Plain `in` (no regex) for both single and multiple values; astype(bool) keeps the
            # mask boolean when the frame is already empty, so pandas does not treat it as a
            # column selection.
            mask = column.apply(lambda x: isinstance(x, str) and any(opt in x for opt in options)).astype(bool)
        elif is_multi:
            mask = column.isin(list(value))
        else:
            mask = column == value
        filtered_df = filtered_df[mask]
    return filtered_df


def create_mask_for_config(df, config_row, config_keys):
    """
    Boolean mask of the rows of ``df`` that agree with ``config_row`` on every key.

    A cell agrees when it equals the reference value, or when both are missing (NaN/None),
    since NaN never compares equal to itself.
    """
    combined = pd.Series(True, index=df.index)
    for column in config_keys:
        target = config_row[column]
        column_values = df[column]
        matches = (pd.isna(column_values) & pd.isna(target)) | (column_values == target)
        combined = combined & matches
    return combined

In [None]:
# Function to plot in figure 1
def plot_shaded_metrics_side_by_side_plasticity_smooth(
    df, df2, group_by_cols, x_col, metrics_keys, metrics_keys2, log_key="", subplot_titles=None, title="", name_=""
):
    """
    Plot training metrics and plasticity losses side by side in one row, then show and save the figure.

    The figure has ``len(metrics_keys) + 2`` subplots in a single row. For every group of ``df``
    (grouped by ``group_by_cols``), each metric in ``metrics_keys`` is drawn against ``x_col`` as the
    per-x mean (line) with a min/max band. Keys listed in ``special_keys`` are first smoothed with an
    exponential moving average (alpha=0.05). For every group of ``df2``, each key in ``metrics_keys2``
    is drawn the same way, without smoothing, against the ``"capacity-counters/env_steps"`` column.

    NOTE(review): the axis placement is hard-coded for 4 training metrics and 2 plasticity metrics.
    Training metric 3 is drawn on ``axs[4]``, the first plasticity metric on ``axs[3]`` and the
    second on ``axs[len(metrics_keys) + 1]``.

    :param df: Training-log DataFrame with ``group_by_cols``, ``x_col`` and ``metrics_keys`` columns
        (plus ``config/optim/num_epochs`` when the legend labels are tuple-shaped, see below).
    :param df2: Plasticity-log DataFrame with ``group_by_cols``, ``"capacity-counters/env_steps"``
        and ``metrics_keys2`` columns.
    :param group_by_cols: Column name or list of column names; one curve per group.
    :param x_col: X-axis column of ``df``.
    :param metrics_keys: Training metrics, one subplot each.
    :param metrics_keys2: Plasticity metrics, shown in the two extra subplots.
    :param log_key: A subplot gets a log y-scale when its metric key is ``in log_key``
        (a substring test when ``log_key`` is a string).
    :param subplot_titles: Y-axis labels, ``len(metrics_keys) + 2`` entries. Required in practice:
        it is indexed unconditionally, so the default None fails.
    :param title: Text drawn in the top-left corner of the first subplot.
    :param name_: Output suffix; the file is ``../outputs/rlconf-plotting/plots/Figure-1-{name_}.pdf``.
        A name whose last "-"-separated part is "bis" selects a different colour cycle.
    """
    # "-bis" figures get a dedicated colour order; all others use the shared line_cycler.
    if name_.split("-")[-1] == "bis":
        color_cycler = cycler(color=["#D55E00", "#56B4E9", "#0072B2", "#009E73", "#E69F00"])
        plt.rc("axes", prop_cycle=color_cycler)
    else:
        plt.rc("axes", prop_cycle=line_cycler)
    # Unwrap a one-element list so that group names are scalars rather than 1-tuples.
    if isinstance(group_by_cols, list) and len(group_by_cols) == 1:
        group_by = group_by_cols[0]
    else:
        group_by = group_by_cols

    grouped = df.groupby(group_by)
    num_metrics = len(metrics_keys) + 2  # two extra subplots for the plasticity losses
    alpha = 0.05  # EMA smoothing factor for special_keys
    fig, axs = plt.subplots(1, num_metrics, figsize=(40, 8), sharex=True)
    # Metrics in this list are EMA-smoothed before plotting.
    special_keys = [
        "batch/perf/avg_return_raw",
        "batch/end/SVD/approximate_rank_pca/features_policy_batch",
        "batch/start/feature_stats/norm_features_preactivation_policy_batch",
        "batch/end/SVD/approximate_rank_pca/features_value_batch",
    ]
    for metric_idx, metric_key in enumerate(metrics_keys):
        for name, group in grouped:
            if metric_key in special_keys:
                # Calculate EMA for special keys: smoothing
                ema_mean = group.groupby(x_col)[metric_key].mean().ewm(alpha=alpha, adjust=False).mean()
                ema_min = group.groupby(x_col)[metric_key].min().ewm(alpha=alpha, adjust=False).mean()
                ema_max = group.groupby(x_col)[metric_key].max().ewm(alpha=alpha, adjust=False).mean()

                # Training metric 3 is shifted to axs[4]: axs[3] holds the first plasticity loss.
                if metric_idx == 3:
                    axs[metric_idx + 1].plot(
                        ema_mean.index, ema_mean, label=name if metric_idx == 0 else "", linewidth=4.0
                    )
                    axs[metric_idx + 1].fill_between(ema_mean.index, ema_min, ema_max, alpha=0.3)
                else:
                    axs[metric_idx].plot(ema_mean.index, ema_mean, label=name if metric_idx == 0 else "", linewidth=4.0)
                    axs[metric_idx].fill_between(ema_mean.index, ema_min, ema_max, alpha=0.3)
            else:
                # Unsmoothed per-x mean, with the per-x min/max as the shaded band
                mean = group.groupby(x_col)[metric_key].mean()
                min_val = group.groupby(x_col)[metric_key].min()
                max_val = group.groupby(x_col)[metric_key].max()
                if metric_idx == 3:
                    axs[metric_idx + 1].plot(mean.index, mean, label=name if metric_idx == 0 else "", linewidth=4.0)
                    axs[metric_idx + 1].fill_between(
                        mean.index, min_val, max_val, alpha=0.3
                    )  # Use min and max for shaded area
                else:
                    axs[metric_idx].plot(mean.index, mean, label=name if metric_idx == 0 else "", linewidth=4.0)
                    axs[metric_idx].fill_between(
                        mean.index, min_val, max_val, alpha=0.3
                    )  # Use min and max for shaded area
        if metric_idx == 3:
            axs[metric_idx + 1].set_xlabel("Environment steps", fontsize=25)
            axs[metric_idx + 1].set_ylabel(subplot_titles[metric_idx], fontsize=30)
            axs[metric_idx + 1].tick_params(axis="x", labelsize=30)
            axs[metric_idx + 1].tick_params(axis="y", labelsize=30)
            axs[metric_idx + 1].xaxis.get_offset_text().set_fontsize(25)
            axs[metric_idx + 1].yaxis.get_offset_text().set_fontsize(25)
        else:
            axs[metric_idx].set_xlabel("Environment steps", fontsize=25)
            axs[metric_idx].set_ylabel(subplot_titles[metric_idx], fontsize=30)
            axs[metric_idx].tick_params(axis="x", labelsize=30)
            axs[metric_idx].tick_params(axis="y", labelsize=30)
            axs[metric_idx].xaxis.get_offset_text().set_fontsize(25)
            axs[metric_idx].yaxis.get_offset_text().set_fontsize(25)

        # NOTE(review): for metric_idx == 3 the curve is on axs[4], but the log scale and grid
        # below target axs[metric_idx] (= axs[3]); axs[4] gets its grid in the metrics_keys2 loop.
        if metric_key in log_key:
            axs[metric_idx].set_yscale("log")  # Apply log scale for the specific subplot

        axs[metric_idx].grid(linestyle="dotted")

    # Plasticity losses from df2: first key on axs[3], second key on axs[len(metrics_keys) + 1].
    for metric_idx2, metric_key2 in enumerate(metrics_keys2):
        if metric_idx2 == 0:
            grouped = df2.groupby(group_by)
            for name, group in grouped:
                # Here, adjust the groupby for mean, min, and max to directly use the column name
                mean = group.groupby("capacity-counters/env_steps")[metric_key2].mean()
                min_val = group.groupby("capacity-counters/env_steps")[metric_key2].min()
                max_val = group.groupby("capacity-counters/env_steps")[metric_key2].max()

                # NOTE(review): metric_idx is left over from the loop above (its last index), so
                # these lines get an empty label whenever metrics_keys has more than one entry.
                axs[3].plot(mean.index, mean, label=name if metric_idx == 0 else "", linewidth=4.0)
                axs[3].fill_between(mean.index, min_val, max_val, alpha=0.3)  # Use min and max for shaded area
            # axs[len(metrics_keys) - 1] is axs[3] when there are 4 training metrics.
            axs[len(metrics_keys) - 1].set_xlabel("Environment steps", fontsize=25)
            axs[len(metrics_keys) - 1].set_ylabel(subplot_titles[len(metrics_keys)], fontsize=30)
            axs[len(metrics_keys) - 1].tick_params(axis="x", labelsize=30)
            axs[len(metrics_keys) - 1].tick_params(axis="y", labelsize=30)
            axs[len(metrics_keys) - 1].xaxis.get_offset_text().set_fontsize(25)
            axs[len(metrics_keys) - 1].yaxis.get_offset_text().set_fontsize(25)
            if metric_key2 in log_key:
                axs[len(metrics_keys) - 1].set_yscale("log")  # Apply log scale for the specific subplot
        else:
            grouped = df2.groupby(group_by)
            for name, group in grouped:
                # Here, adjust the groupby for mean, min, and max to directly use the column name
                mean = group.groupby("capacity-counters/env_steps")[metric_key2].mean()
                min_val = group.groupby("capacity-counters/env_steps")[metric_key2].min()
                max_val = group.groupby("capacity-counters/env_steps")[metric_key2].max()

                axs[len(metrics_keys) + 1].plot(mean.index, mean, label=name if metric_idx == 0 else "", linewidth=4.0)
                axs[len(metrics_keys) + 1].fill_between(
                    mean.index, min_val, max_val, alpha=0.3
                )  # Use min and max for shaded area
            axs[len(metrics_keys) + 1].set_xlabel("Environment steps", fontsize=25)
            axs[len(metrics_keys) + 1].set_ylabel(subplot_titles[len(metrics_keys) + 1], fontsize=30)
            axs[len(metrics_keys) + 1].tick_params(axis="x", labelsize=30)
            axs[len(metrics_keys) + 1].tick_params(axis="y", labelsize=30)
            axs[len(metrics_keys) + 1].xaxis.get_offset_text().set_fontsize(25)
            axs[len(metrics_keys) + 1].yaxis.get_offset_text().set_fontsize(25)
            if metric_key2 in log_key:
                axs[len(metrics_keys) + 1].set_yscale("log")  # Apply log scale for the specific subplot
        axs[len(metrics_keys) + metric_idx2].grid(linestyle="dotted")
    # fig.suptitle(title, fontsize=30)
    fig.tight_layout(pad=2.0)
    # Only the first subplot's curves carry labels, so its handles drive the shared legend.
    handles, labels = axs[0].get_legend_handles_labels()
    lab = []
    # Handle labels for fig 1 and additional figures
    for label in labels:
        # Tuple-shaped labels, presumably "(share_features, feature_trust_region_coef, reset_state)"
        # as grouped in figure1_atari_bis, are mapped to human-readable intervention names.
        # NOTE(review): the title text below is drawn again for every label.
        if label.startswith("(") and label.endswith(")"):
            axs[0].text(
                0.02,
                0.90,
                f"{title}, \n{df['config/optim/num_epochs'].values[0]} epochs",
                fontsize=20,
                fontweight="bold",
                transform=axs[0].transAxes,
            )
            label = label[1:-1]
            parts = label.split(",")
            share = parts[0]
            trust = float(parts[1])
            reset = parts[2]
            # parts after the first keep the leading space left by split(","), hence " True".
            if share == "True":
                lab.append(f"Share actor and critic features")
            else:
                if reset == " True":
                    lab.append("Reset Adam")
                else:
                    if trust == 1:
                        lab.append(f"Regularize last preactivation")
                    elif trust == 10:
                        lab.append(f"Regularize all preactivations")
                    else:
                        lab.append("No intervention")
        else:
            # Scalar labels are the group value (e.g. a number of epochs).
            lab.append(label + " epochs")
            axs[0].text(0.02, 0.95, f"{title}", fontsize=20, fontweight="bold", transform=axs[0].transAxes)
    legend_properties = {"weight": "bold"}  # NOTE(review): unused

    # Horizontal legend offset tuned to the number of legend entries.
    if len(lab) == 4:
        box_to_anchor = (0.25, 1.05)

    elif len(lab) == 3:
        box_to_anchor = (0.4, 1.05)
    else:
        box_to_anchor = (0.15, 1.05)
    fig.legend(
        handles,
        lab,
        loc="upper left",
        bbox_to_anchor=box_to_anchor,
        borderaxespad=0.0,
        fontsize=30,
        ncol=len(lab),
        frameon=False,
        handlelength=1,
        handletextpad=0.5,
        columnspacing=1,
    )
    plt.subplots_adjust(right=1, top=0.90)  # Adjust subplot params to make room for the legend
    plt.savefig(f"../outputs/rlconf-plotting/plots/Figure-1-{name_}.pdf", format="pdf", bbox_inches="tight")
    plt.show()


# Function to plot in figure 2
def plot_shaded_metrics_side_by_side_smooth(
    df, group_by_cols, x_col, metrics_keys, log_key="", subplot_titles=None, title="", name_=""
):
    """
    Plot one EMA-smoothed metric per subplot (one curve per group), then show and save the figure.

    For every group of ``df`` (grouped by ``group_by_cols``) and every metric, the per-``x_col`` mean
    is drawn as a line and the per-``x_col`` min/max as a shaded band. All three series are smoothed
    with an exponential moving average (alpha=0.05, adjust=False).

    :param df: DataFrame with ``group_by_cols``, ``x_col`` and ``metrics_keys`` columns.
    :param group_by_cols: Column name or list of column names; one curve per group.
    :param x_col: X-axis column.
    :param metrics_keys: Metrics to plot, one subplot each (a single metric is supported).
    :param log_key: Metric key whose subplot uses a log y-scale ("" for none).
    :param subplot_titles: Y-axis label per metric; defaults to the metric keys themselves.
    :param title: Text drawn in the top-left corner of the first subplot.
    :param name_: Output suffix; the file is ``../outputs/rlconf-plotting/plots/Figure-2-{name_}.pdf``.
    """
    plt.rc("axes", prop_cycle=line_cycler)
    # Unwrap a one-element list so that group names are scalars rather than 1-tuples.
    if isinstance(group_by_cols, list) and len(group_by_cols) == 1:
        group_by = group_by_cols[0]
    else:
        group_by = group_by_cols
    if subplot_titles is None:
        subplot_titles = metrics_keys

    grouped = df.groupby(group_by)
    num_metrics = len(metrics_keys)

    # squeeze=False keeps `axs` indexable even when only one metric is plotted.
    fig, axs = plt.subplots(1, num_metrics, figsize=(30, 10), sharex=True, squeeze=False)
    axs = axs[0]
    alpha = 0.05  # EMA smoothing factor
    for metric_idx, metric_key in enumerate(metrics_keys):
        ax = axs[metric_idx]
        for name, group in grouped:
            per_x = group.groupby(x_col)[metric_key]
            ema_mean = per_x.mean().ewm(alpha=alpha, adjust=False).mean()
            ema_min = per_x.min().ewm(alpha=alpha, adjust=False).mean()
            ema_max = per_x.max().ewm(alpha=alpha, adjust=False).mean()

            # Only the first subplot's curves are labelled; they drive the shared legend.
            ax.plot(ema_mean.index, ema_mean, label=name if metric_idx == 0 else "", linewidth=3.0)
            ax.fill_between(ema_mean.index, ema_min, ema_max, alpha=0.3)

        ax.set_xlabel("Environment steps", fontsize=35)
        ax.set_ylabel(subplot_titles[metric_idx], fontsize=35)
        ax.tick_params(axis="x", labelsize=35)
        ax.tick_params(axis="y", labelsize=35)
        ax.xaxis.get_offset_text().set_fontsize(35)
        ax.yaxis.get_offset_text().set_fontsize(35)

        if metric_key == log_key:
            ax.set_yscale("log")  # Apply log scale for the specific subplot

        ax.grid(linestyle="dotted")
    # The first subplot's labels represent all groups.
    handles, labels = axs[0].get_legend_handles_labels()
    labels = [label + " epochs" for label in labels]
    axs[0].text(0.02, 0.95, f"{title}", fontsize=20, fontweight="bold", transform=axs[0].transAxes)
    fig.legend(
        handles,
        labels,
        loc="upper left",
        bbox_to_anchor=(0.4, 1.02),
        borderaxespad=0.0,
        fontsize=35,
        ncol=3,
        frameon=False,
        handlelength=1,
        handletextpad=0.5,
        columnspacing=1,
    )
    plt.subplots_adjust(right=1)  # Adjust subplot params to make room for the legend
    plt.savefig(f"../outputs/rlconf-plotting/plots/Figure-2-{name_}.pdf", format="pdf", bbox_inches="tight")
    plt.show()

___

## Figure 1 <a name="figure_1"></a>

In [None]:
def figure1_atari(name):
    """
    Draw Figure 1 for one Atari game: one curve per number of epochs.

    Both the training table (``atari_df``) and the plasticity table (``atari_df_capacity``) are
    restricted to runs with anneal_linearly=False, reset_state=False, share_features=False and
    feature_trust_region_coef=0.0 whose run directory contains "baselines".

    :param name: Environment id, e.g. "ALE/Phoenix-v5"; "/" is replaced by "-" in the output name.
    :return: None. The figure is shown and saved by the plotting function.
    """
    baseline_filters = {
        ("config", "env", "name"): [name],
        ("config", "optim", "anneal_linearly"): [False],
        ("config", "optim", "reset_state"): [False],
        ("config", "models", "share_features"): [False],
        ("config", "loss", "policy", "kwargs", "feature_trust_region_coef"): [0.0],
    }
    # The training table is matched on working_dir, the plasticity table on solve_dir.
    training_runs = filter_df_by_criteria(atari_df, {**baseline_filters, ("config", "working_dir"): "baselines"})
    plasticity_runs = filter_df_by_criteria(
        atari_df_capacity, {**baseline_filters, ("config", "solve_dir"): "baselines"}
    )

    training_metrics = [
        "batch/perf/avg_return_raw",
        "batch/end/SVD/approximate_rank_pca/features_policy_batch",
        "batch/start/feature_stats/norm_features_preactivation_policy_batch",
        "batch/end/SVD/approximate_rank_pca/features_value_batch",
    ]
    plasticity_metrics = [
        "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/policy",
        "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/value",
    ]
    axis_labels = [
        "Episode return",
        "Feature rank policy (PCA)",
        "Norm preactivation policy",
        "Feature rank critic (PCA)",
        "Plasticity loss policy",
        "Plasticity loss critic",
    ]
    plot_shaded_metrics_side_by_side_plasticity_smooth(
        training_runs,
        plasticity_runs,
        ["config/optim/num_epochs"],
        "global_step",
        training_metrics,
        plasticity_metrics,
        "batch/start/feature_stats/norm_features_preactivation_policy_batch",
        axis_labels,
        title=name,
        name_=f"1-{name}".replace("/", "-"),
    )

In [None]:
# Figure 1 for each Atari game; each call shows the figure and saves
# ../outputs/rlconf-plotting/plots/Figure-1-1-ALE-<Game>-v5.pdf.
figure1_atari("ALE/Phoenix-v5")
figure1_atari("ALE/Qbert-v5")
figure1_atari("ALE/NameThisGame-v5")
figure1_atari("ALE/Gravitar-v5")
figure1_atari("ALE/DoubleDunk-v5")
figure1_atari("ALE/BattleZone-v5")

In [None]:
def figure1_atari_bis(name, num_epochs):
    """Figure 1 variant for one Atari game at a fixed epoch count, across intervention runs.

    Keeps ``ppo-clip`` runs of ``name`` with ``num_epochs`` epochs, no linear annealing,
    feature-trust-region coefficient in {0, 1, 10}, from the listed run directories.
    Training logs are filtered on ``working_dir`` and plasticity logs on ``solve_dir``.
    Curves are grouped by feature sharing, trust-region coefficient and optimizer reset.
    """

    def _criteria(dir_key):
        # Fresh dict/lists on every call so the two filters never share objects.
        return {
            ("config", "env", "name"): [name],
            ("config", "optim", "num_epochs"): [num_epochs],
            ("config", "optim", "anneal_linearly"): [False],
            ("config", "algo"): ["ppo-clip"],
            ("config", "loss", "policy", "kwargs", "feature_trust_region_coef"): [0, 1, 10],
            ("config", dir_key): [
                "baseline",
                "experiment",
                "optimizer",
                "regularize",
                "regularize-all-layers",
                "shared-trunk",
            ],
        }

    training_runs = filter_df_by_criteria(atari_df, _criteria("working_dir"))
    plasticity_runs = filter_df_by_criteria(atari_df_capacity, _criteria("solve_dir"))
    file_stem = f"1-{name}-{num_epochs}-epochs-bis".replace("/", "-")

    plot_shaded_metrics_side_by_side_plasticity_smooth(
        training_runs,
        plasticity_runs,
        [
            "config/models/share_features",
            "config/loss/policy/kwargs/feature_trust_region_coef",
            "config/optim/reset_state",
        ],
        "global_step",
        [
            "batch/perf/avg_return_raw",
            "batch/end/SVD/approximate_rank_pca/features_policy_batch",
            "batch/start/feature_stats/norm_features_preactivation_policy_batch",
            "batch/end/SVD/approximate_rank_pca/features_value_batch",
        ],
        [
            "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/policy",
            "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/value",
        ],
        "batch/start/feature_stats/norm_features_preactivation_policy_batch",
        [
            "Episode return",
            "Feature rank policy (PCA)",
            "Norm preactivation policy",
            "Feature rank critic (PCA)",
            "Plasticity loss policy",
            "Plasticity loss critic",
        ],
        title=name,
        name_=file_stem,
    )

In [None]:
# Figure 1 "bis" for every (game, epoch count) pair; games outer, epochs inner.
for _bis_game in ("Phoenix", "Gravitar", "NameThisGame"):
    for _bis_epochs in (4, 6, 8):
        figure1_atari_bis(f"ALE/{_bis_game}-v5", _bis_epochs)

___

## Figure 2

In [None]:
def figure2_atari(name):
    """Plot entropy, policy variance and dead policy neurons for one Atari game (Figure 2).

    Keeps runs from the ``baselines`` working dir with no linear annealing, no optimizer
    state reset, unshared actor/critic features and a zero feature-trust-region
    coefficient. One curve is drawn per ``config/optim/num_epochs`` value.
    """
    baseline_runs = filter_df_by_criteria(
        atari_df,
        {
            ("config", "env", "name"): [name],
            ("config", "optim", "anneal_linearly"): [False],
            ("config", "optim", "reset_state"): [False],
            ("config", "models", "share_features"): [False],
            ("config", "loss", "policy", "kwargs", "feature_trust_region_coef"): [0.0],
            ("config", "working_dir"): "baselines",
        },
    )
    plot_shaded_metrics_side_by_side_smooth(
        baseline_runs,
        ["config/optim/num_epochs"],
        "global_step",
        [
            "batch/first_epoch/first_minibatch/loss/entropy",
            "batch/start/action_diversity/policy_variance",
            "batch/start/dead_neurons/features_policy_batch",
        ],
        "batch/start/feature_stats/norm_features_preactivation_policy_batch",
        ["Entropy", "Policy variance", "Dead neurons policy"],
        title=name,
        name_=f"2-{name}".replace("/", "-"),
    )

In [None]:
# Figure 2 for each of these Atari games, in the same order as before.
for _fig2_game in ("Phoenix", "NameThisGame", "Gravitar", "BattleZone", "DoubleDunk", "Qbert"):
    figure2_atari(f"ALE/{_fig2_game}-v5")

___

## Figure 3

In [None]:
def plot_shaded_metrics_side_by_side_plasticity_fig3(
    df, df2, group_by_cols, x_col, metrics_keys, metrics_keys2, log_key="", subplot_titles=None, title="", name_=""
):
    """Draw the Figure 3 panels: EMA-smoothed training metrics plus a plasticity-loss panel.

    Every metric has a fixed subplot column (see ``positions``). For each group of runs,
    one line is drawn per metric, with a shaded band between the per-x min and max.

    Args:
        df: Training-log dataframe; needs ``x_col``, the ``metrics_keys`` columns,
            ``config/env/name`` and ``config/optim/num_epochs``.
        df2: Plasticity dataframe; x axis is ``capacity-counters/env_steps`` and it also
            needs ``config/models/activation``.
        group_by_cols: Column(s) to group runs by; a one-element list is unwrapped.
        x_col: Column of ``df`` used as the x axis.
        metrics_keys: Metrics of ``df``. Only those in ``smoothed_keys`` are drawn, but
            every key must be in ``positions`` (it gets a grid).
        metrics_keys2: Metrics of ``df2``. Only the first is drawn (unsmoothed).
        log_key: Metric of ``metrics_keys`` whose subplot gets a log y axis.
        subplot_titles: Y-axis labels indexed by subplot position. Defaults to the raw
            metric keys.
        title: Figure suptitle.
        name_: Suffix of the saved ``Figure-3-{name_}.pdf``.
    """
    # Subplot column of every metric this figure can draw.
    positions = {
        "batch/end/SVD/approximate_rank_pca/features_policy_batch": 0,
        "batch/perf/avg_return_raw": 2,
        "batch/diff/avg_prob_ratio_below_epsilon": 3,
        "batch/end/action_diversity/policy_variance": 4,
        "batch/last_epoch/last_minibatch/loss/loss_policy": 5,
        "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/policy": 1,
    }
    # Metrics of ``df`` that are EMA-smoothed and drawn: everything in ``positions``
    # except the plasticity loss, which comes from ``df2``.
    smoothed_keys = (
        "batch/end/SVD/approximate_rank_pca/features_policy_batch",
        "batch/perf/avg_return_raw",
        "batch/diff/avg_prob_ratio_below_epsilon",
        "batch/end/action_diversity/policy_variance",
        "batch/last_epoch/last_minibatch/loss/loss_policy",
    )
    # The logged policy loss is negated so the panel shows the PPO-Clip objective.
    negated_keys = ("batch/last_epoch/last_minibatch/loss/loss_policy",)
    if subplot_titles is None:
        # Without explicit titles, label each subplot with its metric key instead of crashing.
        subplot_titles = {pos: key for key, pos in positions.items()}

    def _style_axis(ax, ylabel):
        """Apply the shared axis labels and 25-pt tick/offset fonts to one subplot."""
        ax.set_xlabel("Environment steps", fontsize=25)
        ax.set_ylabel(ylabel, fontsize=25)
        ax.tick_params(axis="x", labelsize=25)
        ax.tick_params(axis="y", labelsize=25)
        ax.xaxis.get_offset_text().set_fontsize(25)
        ax.yaxis.get_offset_text().set_fontsize(25)

    plt.rc("axes", prop_cycle=line_cycler)
    if isinstance(group_by_cols, list) and len(group_by_cols) == 1:
        group_by = group_by_cols[0]
    else:
        group_by = group_by_cols

    fig, axs = plt.subplots(1, len(metrics_keys) + 1, figsize=(40, 8), sharex=True)
    alpha = 0.05  # EMA smoothing factor
    legend_info = {}  # label -> first line drawn with that label (de-duplicated legend)

    grouped = df.groupby(group_by)
    for metric_key in metrics_keys:
        if metric_key in smoothed_keys:
            ax = axs[positions[metric_key]]
            sign = -1 if metric_key in negated_keys else 1
            for _, group in grouped:
                # Assumes configuration columns are uniform within a group.
                env_name = group["config/env/name"].iloc[0]
                num_epochs = group["config/optim/num_epochs"].iloc[0]
                label = f"{env_name}, {num_epochs} epochs"

                by_x = group.dropna(subset=[metric_key]).groupby(x_col)[metric_key]
                ema_mean = sign * by_x.mean().ewm(alpha=alpha, adjust=False).mean()
                ema_min = sign * by_x.min().ewm(alpha=alpha, adjust=False).mean()
                ema_max = sign * by_x.max().ewm(alpha=alpha, adjust=False).mean()

                (line,) = ax.plot(ema_mean.index, ema_mean, label=label, linewidth=4.0)
                ax.fill_between(ema_mean.index, ema_min, ema_max, alpha=0.3)
                _style_axis(ax, subplot_titles[positions[metric_key]])
                legend_info.setdefault(label, line)

        if metric_key == log_key:
            axs[positions[metric_key]].set_yscale("log")
        axs[positions[metric_key]].grid(linestyle="dotted")

    for idx2, metric_key2 in enumerate(metrics_keys2):
        ax = axs[positions[metric_key2]]
        # Only the first plasticity metric is drawn; any further ones just get a grid.
        if idx2 == 0:
            for _, group in df2.groupby(group_by):
                env_name = group["config/env/name"].iloc[0]
                num_epochs = group["config/optim/num_epochs"].iloc[0]
                activation = group["config/models/activation"].iloc[0]
                by_steps = group.groupby("capacity-counters/env_steps")[metric_key2]
                mean = by_steps.mean()
                ax.plot(mean.index, mean, label=f"{env_name}, {num_epochs} epochs, {activation}", linewidth=4.0)
                ax.fill_between(mean.index, by_steps.min(), by_steps.max(), alpha=0.3)  # min/max band
            _style_axis(ax, subplot_titles[positions[metric_key2]])
        ax.grid(linestyle="dotted")

    fig.suptitle(title, fontsize=30)
    fig.tight_layout(pad=2.0)
    # ``plt.legend`` targets pyplot's current axes (presumably the last-created subplot);
    # the negative x anchor shifts the shared legend leftwards above the row of panels.
    plt.legend(
        handles=list(legend_info.values()),
        labels=list(legend_info.keys()),
        loc="upper center",
        bbox_to_anchor=(-2.8, 1.2),
        borderaxespad=0.0,
        fontsize=30,
        ncol=3,
        frameon=False,
        handlelength=1,
        handletextpad=0.5,
        columnspacing=1,
    )
    plt.savefig(f"../outputs/rlconf-plotting/plots/Figure-3-{name_}.pdf", format="pdf", bbox_inches="tight")
    plt.show()

In [None]:
def figure3_atari():
    """Plot Figure 3 for three baseline runs of ALE/NameThisGame-v5.

    Training logs are restricted to seeds 7 and 25 and matched on ``working_dir``.
    Plasticity logs are matched on ``solve_dir`` with no seed filter.
    """
    baseline_dirs = (
        "baselines/2024-02-27_00-43-36-670838",
        "baselines/2024-02-27_00-43-36-908626",
        "baselines/2024-02-27_00-43-29-025809",
    )
    training_runs = filter_df_by_criteria(
        atari_df,
        {
            ("config", "working_dir"): list(baseline_dirs),
            ("config", "env", "name"): ["ALE/NameThisGame-v5"],
            ("config", "seed"): [7, 25],
        },
    )
    plasticity_runs = filter_df_by_criteria(
        atari_df_capacity,
        {
            ("config", "solve_dir"): list(baseline_dirs),
            ("config", "env", "name"): ["ALE/NameThisGame-v5"],
        },
    )

    plot_shaded_metrics_side_by_side_plasticity_fig3(
        training_runs,
        plasticity_runs,
        ["config/working_dir"],
        "global_step",
        [
            "batch/end/SVD/approximate_rank_pca/features_policy_batch",
            "batch/perf/avg_return_raw",
            "batch/diff/avg_prob_ratio_below_epsilon",
            "batch/end/action_diversity/policy_variance",
            "batch/last_epoch/last_minibatch/loss/loss_policy",
        ],
        ["capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/policy"],
        "batch/start/feature_stats/norm_features_preactivation_policy_batch",
        [
            "Rank policy (PCA)",
            "Plasticity loss policy",
            "Episode return",
            r"Avg of prob ratios < 1 - $\epsilon$",
            "Policy variance",
            "PPO-Clip objective",
        ],
        title="",
        name_="4",
    )

___

In [None]:
# Render and save Figure 3.
figure3_atari()

## Figure 4

In [None]:
def plot_correlation_single_metric(ax, points1, points2, colors, shapes, text1, text2, log=0):
    """Scatter ``points2`` against ``points1`` on ``ax``, with a regression line when possible.

    Args:
        ax: Matplotlib axes to draw on.
        points1: X values.
        points2: Y values.
        colors: Per-point colors, parallel to the points.
        shapes: Per-point marker styles, parallel to the points.
        text1: X-axis label; "Feature preactivation norm" also switches the x axis to log scale.
        text2: Y-axis label.
        log: ``log1p`` transform selector: 1 -> x only, 2 -> y only, 3 -> both, other -> none.
    """
    if log in (1, 3):
        points1 = np.log1p(points1)  # log1p keeps zeros finite
    if log in (2, 3):
        points2 = np.log1p(points2)

    # The regression and correlation stats are best-effort: on failure only the scatter is drawn.
    try:
        if len(points1) > 1 and len(points2) > 1:
            slope, intercept, *_ = linregress(points1, points2)
            tau, _ = kendalltau(points1, points2)
            rho, _ = spearmanr(points1, points2)
            sns.regplot(
                x=points1,
                y=points2,
                ci=95,
                ax=ax,
                scatter=False,
                line_kws={"color": "#E69F00", "label": f"y={slope:.2f}x+{intercept:.2f}"},
            )
            ax.set_title(f"Kendall: {tau:.2f}, Spearman: {rho:.2f}", fontsize=20)
    except Exception as err:
        print(f"Regression or correlation failed: {err}")

    # One scatter call per point so that each point can carry its own marker.
    for px, py, point_color, point_marker in zip(points1, points2, colors, shapes):
        ax.scatter(px, py, color=point_color, marker=point_marker)

    ax.set_xlabel(text1, fontsize=20)
    ax.set_ylabel(text2, fontsize=20)
    if text1 == "Feature preactivation norm":
        ax.set_xscale("log")
    ax.legend(loc="upper left", labels=[])


def determine_color_shape2(row):
    """Return the ``(color, marker)`` used to draw one run in the correlation plots.

    The color is the same for every run ("#E69F00"); the previous per-``num_epochs``
    branches all returned that value. The marker encodes the intervention, found as a
    substring of ``config/solve_dir`` when that key is present in ``row``, else of
    ``config/working_dir``.

    Args:
        row: Mapping of config values (e.g. a pandas Series) supporting ``in`` and ``[]``.

    Returns:
        Tuple ``(color, marker)`` of matplotlib color and marker strings.
    """
    color = "#E69F00"

    directory_key = "config/solve_dir" if "config/solve_dir" in row else "config/working_dir"
    directory_value = row[directory_key]

    # Checked in order: "regularize-all-layers" must come before its prefix "regularize".
    intervention_markers = (
        ("shared-trunk", "s"),  # square
        ("optimizer", "p"),  # pentagon
        ("regularize-all-layers", "P"),  # filled plus
        ("regularize", "D"),  # diamond
    )
    for keyword, marker in intervention_markers:
        if keyword in directory_value:
            return color, marker
    return color, "o"  # circle: no intervention keyword matched


def _env_file_stem(env_name):
    """Return the part of ``env_name`` used in file names ("ALE/Phoenix-v5" -> "Phoenix-v5")."""
    parts = env_name.split("/")
    return parts[0] if len(parts) == 1 else parts[1]


def plot_correlation(df, df2, metric_key1, metric_key2, metrics_key3, metric_keys, text1, texts2, log=0):
    """Draw and save one row of correlation scatter plots per environment in ``df``.

    Each environment gets its own figure with one subplot per entry of ``metric_keys``.
    Each subplot plots the window-averaged ``metric_keys[i]`` (x) against the
    window-averaged ``metric_key1`` (y), as returned by ``average_metrics_correlation``.
    Every figure is annotated, laid out and saved under its own environment's name.
    (Previously only the last figure was annotated/saved, and it was named after the
    first environment.)

    Args:
        df: Training-log dataframe with ``config/env/name`` and ``config/models/activation``.
        df2: Unused; kept for call compatibility.
        metric_key1: Metric on the y axis of every subplot.
        metric_key2: Forwarded unchanged to ``average_metrics_correlation``.
        metrics_key3: Unused; kept for call compatibility.
        metric_keys: Metrics on the x axes, one subplot each.
        text1: Y-axis label.
        texts2: X-axis labels, parallel to ``metric_keys``.
        log: Log-transform selector forwarded to ``plot_correlation_single_metric``.
    """
    names = df["config/env/name"].drop_duplicates().values
    if len(names) == 0:
        return  # nothing to plot
    # The file-name activation is taken from the whole dataframe, as before.
    activation = df["config/models/activation"].drop_duplicates().values[0]

    for name in names:
        df_filtered = df.loc[df["config/env/name"] == name]
        # squeeze=False keeps ``axes`` iterable even with a single metric.
        fig, axes = plt.subplots(nrows=1, ncols=len(metric_keys), figsize=(20, 6), squeeze=False)
        axes = axes[0]
        for i, (ax, x_metric) in enumerate(zip(axes, metric_keys)):
            points1, points2, colors, shapes = average_metrics_correlation(
                df_filtered, metric_key1, metric_key2, x_metric
            )
            plot_correlation_single_metric(ax, points2, points1, colors, shapes, texts2[i], text1, log)

        axes[0].text(0.02, 0.05, f"{name}", fontsize=17, fontweight="bold", transform=axes[0].transAxes)
        fig.tight_layout(rect=[0, 0, 0.95, 1])
        fig.savefig(
            f"../outputs/rlconf-plotting/plots/correlation_by_maps_average_prob_ratio_{_env_file_stem(name)}_{activation}.pdf"
        )
        plt.show()
    return


def average_metrics_correlation(df, metric_key1, metric_key2, metric_key3, num_lowest=20):
    """Window-average two metrics per run configuration and keep the lowest windows.

    For every distinct combination of ``config/*`` columns:
    - keep rows with a non-NaN ``metric_key1``;
    - average ``metric_key1`` and ``metric_key3`` over consecutive ``global_step`` windows
      (1e6 steps for ``*-v5`` (ALE) envs, 5e4 otherwise);
    - keep the ``num_lowest`` windows with the smallest ``metric_key1`` average (all
      windows, in step order, when there are no more than ``num_lowest``).

    Args:
        df: Run-log dataframe with ``global_step`` and ``config/*`` columns.
        metric_key1: Metric used to rank windows (returned first).
        metric_key2: Unused; kept for call compatibility.
        metric_key3: Metric averaged over the same windows (returned second).
        num_lowest: Number of lowest-``metric_key1`` windows kept per configuration.

    Returns:
        ``(metric1_values, metric3_values, colors, shapes)``: parallel lists, with
        colors/markers from ``determine_color_shape2``.
    """
    config_keys = [key for key in df.columns if key.startswith("config/")]
    unique_configs = df[config_keys].drop_duplicates()

    lowest_metric1_values = []
    lowest_metric3_values = []
    colors = []
    shapes = []
    for _, config_row in unique_configs.iterrows():
        mask = create_mask_for_config(df, config_row, config_keys)
        color, shape = determine_color_shape2(config_row)

        valid_rows = df[mask].dropna(subset=[metric_key1]).sort_values(by="global_step", ascending=True)
        if valid_rows.empty:
            continue  # no usable data for this configuration (previously crashed on range(0, nan))

        # Use this configuration's env (not the first configuration's) to pick the window size.
        if config_row["config/env/name"].split("-")[-1] == "v5":
            window_size = step_size = 1000000  # ALE
        else:
            window_size = step_size = 50000  # MuJoCo

        steps = valid_rows["global_step"]
        window_averages = []
        second_metric = []
        # "+ 1" so that rows logged exactly at the last step are not dropped when that
        # step is a multiple of ``step_size``.
        for start_step in range(0, int(steps.max()) + 1, step_size):
            window_df = valid_rows[(steps >= start_step) & (steps < start_step + window_size)]
            if window_df.empty:
                continue
            window_averages.append(window_df[metric_key1].mean())
            second_metric.append(window_df[metric_key3].mean())

        if len(window_averages) > num_lowest:
            keep = np.argsort(window_averages)[:num_lowest]
        else:
            keep = range(len(window_averages))
        lowest_metric1_values.extend(window_averages[i] for i in keep)
        lowest_metric3_values.extend(second_metric[i] for i in keep)
        colors.extend([color] * len(keep))
        shapes.extend([shape] * len(keep))
    return lowest_metric1_values, lowest_metric3_values, colors, shapes

In [None]:
def figure4_atari(name):
    """Plot Figure 4: the below-epsilon prob-ratio average against policy feature statistics.

    Uses ``baselines`` runs of ``name`` with no linear annealing and 4, 6 or 8 epochs.
    Training logs are matched on ``working_dir``, plasticity logs on ``solve_dir``.
    """

    def _criteria(dir_key):
        return {
            ("config", "env", "name"): [name],
            ("config", dir_key): ["baselines"],
            ("config", "optim", "anneal_linearly"): False,
            ("config", "optim", "num_epochs"): [4, 6, 8],
        }

    training_runs = filter_df_by_criteria(atari_df, _criteria("working_dir"))
    plasticity_runs = filter_df_by_criteria(atari_df_capacity, _criteria("solve_dir"))

    plot_correlation(
        training_runs,
        plasticity_runs,
        "batch/diff/avg_prob_ratio_below_epsilon",
        "",
        "",
        [
            "batch/start/dead_neurons/features_policy_batch",
            "batch/start/SVD/approximate_rank_pca/features_policy_batch",
            "batch/start/feature_stats/norm_features_preactivation_policy_batch",
        ],
        r"Avg of prob ratios < 1 - $\epsilon$",
        ["Dead neurons policy", "Feature rank policy (PCA)", "Feature preactivation norm"],
        log=0,
    )

In [None]:
# Figure 4 for each of these Atari games, in the same order as before.
for _fig4_game in ("Phoenix", "NameThisGame", "Qbert", "Gravitar", "BattleZone", "DoubleDunk"):
    figure4_atari(f"ALE/{_fig4_game}-v5")

## Figure 5

In [None]:
def figure5(num, suffix, *, steps=20, lr=1.5, epsilon=0.1, seed=1):
    """Plot the toy PPO-Clip example of two states whose updates interact.

    A bias-free linear softmax policy over two actions is trained with plain SGD on the
    clipped PPO surrogate (advantage fixed to 1). Updates alternate between state ``x``
    and state ``y = num * x``. After every step, the probability ratio of action ``a_1``
    is recorded for both states. The ratios are plotted against the ``1 + epsilon`` bound
    and saved as ``toy_example_{suffix}.pdf``.

    Args:
        num: Scale linking the states (``y = num * x``); -2 was used for interference
            and 3 for a boost.
        suffix: Suffix of the saved pdf file.
        steps: Number of alternating update steps (minibatches).
        lr: SGD learning rate.
        epsilon: PPO clipping parameter.
        seed: Torch seed; it determines the initial weights and the sampled state ``x``.
    """
    torch.manual_seed(seed)
    state_dim = 1  # dimension of the toy state vectors

    class PolicyNetwork(nn.Module):
        """Bias-free linear map from a state to two action logits, followed by a softmax."""

        def __init__(self):
            super(PolicyNetwork, self).__init__()
            self.linear = nn.Linear(state_dim, 2, bias=False)

        def forward(self, state):
            logits = self.linear(state)
            return nn.functional.softmax(logits, dim=-1)

    # Creation order (model, then x, then y) fixes which random numbers each one draws.
    model = PolicyNetwork()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    feature_norm = 1
    x = feature_norm * torch.randn((1, state_dim), dtype=torch.float)
    y = feature_norm * torch.randn((1, state_dim), dtype=torch.float)
    # Overwrite the first ``copy_n`` coordinates of y (here: all of them) with num * x.
    copy_n = state_dim
    y[:, :copy_n] = num * x[:, :copy_n]

    A = 1.0  # fixed positive advantage
    old_pi_a1_x = model(x)[0, 0].item()
    old_pi_a1_y = model(y)[0, 0].item()

    def compute_loss(prob_a1, old_prob_a1):
        """PPO-Clip surrogate assuming A > 0: A * min(r, min(r, 1 + epsilon))."""
        ratio = prob_a1 / old_prob_a1
        clipped_ratio = torch.clamp(ratio, max=1 + epsilon)
        return A * torch.min(ratio, clipped_ratio)

    ratio_x_history = [1]
    ratio_y_history = [1]
    for step in range(steps):
        # Even steps update on x, odd steps on y.
        on_x = step % 2 == 0
        state = x if on_x else y
        old_pi_a1 = old_pi_a1_x if on_x else old_pi_a1_y

        prob_a1 = model(state)[0, 0]
        loss = -compute_loss(prob_a1, old_pi_a1)  # maximize the surrogate

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            ratio_x_history.append(model(x)[0, 0].item() / old_pi_a1_x)
            ratio_y_history.append(model(y)[0, 0].item() / old_pi_a1_y)

    plt.plot(ratio_x_history, label=r"$\pi_\theta(a_1| x) / \pi_\text{old}(a_1|x)$", color="#E69F00", linewidth=4)
    plt.plot(ratio_y_history, label=r"$\pi_\theta(a_1 | y) / \pi_\text{old}(a_1 | y)$", color="#56B4E9", linewidth=4)
    plt.xlabel("Minibatch")
    plt.ylabel("Ratio")

    # Clipping bound of the PPO objective.
    plt.axhline(1 + epsilon, color="r", linestyle="--", label=r"1+$\epsilon$", linewidth=4)

    plt.legend(loc=(-0.3, 1), ncol=3, frameon=False, handlelength=1, handletextpad=0.2, columnspacing=1)
    plt.grid(linestyle="dotted")

    plt.savefig(f"../outputs/rlconf-plotting/plots/toy_example_{suffix}.pdf")
    plt.show()
    return

In [None]:
# Toy example: negative scale (interference) first, then positive scale (boost).
for _toy_scale, _toy_suffix in ((-2, "neg"), (3, "pos")):
    figure5(_toy_scale, _toy_suffix)

## Figure 6

In [None]:
_PROB_RATIO_ABOVE = "batch/diff/avg_prob_ratio_above_epsilon"
_PROB_RATIO_BELOW = "batch/diff/avg_prob_ratio_below_epsilon"
_RETURN_KEY = "batch/perf/avg_return_raw"
_MAX_TIMESTEP_KEY = "batch/perf/max_timestep"


def _most_specific_keywords(text, keywords):
    """Return the keywords contained in ``text``, minus those that are substrings of another contained keyword.

    This keeps, e.g., a ".../regularize-all-layers/..." run out of the "regularize" group.
    """
    found = [keyword for keyword in keywords if keyword in text]
    return [keyword for keyword in found if not any(keyword != other and keyword in other for other in found)]


def _tail_window(df):
    """Sort ``df`` by ``global_step`` and select the rows from the first step >= 95% of the maximum onward.

    :return: ``(df_sorted, b_min, window)``, where ``b_min`` is the index label the window starts at.
    """
    threshold_step = df["global_step"].max() * 0.95
    df_sorted = df.sort_values(by="global_step", ascending=True)
    b_min = df_sorted[df_sorted["global_step"] >= threshold_step].index.min()
    # Label-based slice: row ``b_min`` and every row after it in step order.
    return df_sorted, b_min, df_sorted.loc[b_min:]


def _prob_ratio_window(filtered_df, min_points=10):
    """Select the rows of one run used for the "ProbRatio_" metric.

    Only rows where both probability-ratio statistics are logged are kept. The window is 5% of the
    run's maximum ``global_step`` wide and ends at the last such row. While it holds fewer than
    ``min_points`` rows, its end slides back to the previous distinct logged step. Sliding stops
    once every candidate end step has been tried.

    :return: ``(window, b_min, b_max)`` with the index labels bounding the window, or ``None`` when
        no row has both statistics.
    """
    df_no_nan = filtered_df.dropna(subset=[_PROB_RATIO_ABOVE, _PROB_RATIO_BELOW])
    if df_no_nan.empty:
        return None
    df_sorted = df_no_nan.sort_values(by="global_step", ascending=True)
    # Distinct logged steps, latest first: the candidate window end points.
    end_steps = df_no_nan.sort_values(by="global_step", ascending=False)["global_step"].unique()
    window_width = int(filtered_df["global_step"].max() * 0.05)

    b_max = df_sorted.index.max()
    b_min = df_sorted[df_sorted["global_step"] >= df_sorted["global_step"].max() - window_width].index.min()
    window = df_sorted.loc[b_min:]
    idx = 0
    while window[_PROB_RATIO_ABOVE].count() < min_points and idx + 1 < len(end_steps):
        idx += 1
        end_step = end_steps[idx]
        b_min = df_sorted[df_sorted["global_step"] >= end_step - window_width].index.min()
        # Keep the window end at or below the new end step.
        b_max = df_sorted[df_sorted["global_step"] <= end_step].index.max()
        window = df_sorted.loc[b_min:b_max]
    return window, b_min, b_max


def _deduplicate_returns(window, max_skipped=8):
    """Drop episode-return rows that are repeated because of redundant logging.

    A row whose return equals the last kept row's return is dropped while
    ``batch/perf/max_timestep`` keeps increasing, up to ``max_skipped`` consecutive drops.
    Otherwise it is kept. Rows with a different return are always kept.

    :return: The list of kept rows (pandas Series), in the order of ``window``.
    """
    kept = []
    skipped = 0
    previous_timestep = float("inf")
    for _, row in window.iterrows():
        if kept and row[_RETURN_KEY] == kept[-1][_RETURN_KEY]:
            if row[_MAX_TIMESTEP_KEY] <= previous_timestep or skipped >= max_skipped:
                # The timestep did not grow, or too many rows were dropped: keep this one.
                kept.append(row)
                skipped = 0
            else:
                # Same return while the timestep still grows: treated as the same episode.
                skipped += 1
        else:
            kept.append(row)
            skipped = 0
        previous_timestep = row[_MAX_TIMESTEP_KEY]
    return kept


def plot_metrics_by_keywords(df, df2, metrics, metric2, metric_name, keywords, save_name=""):
    """
    Plot one horizontal box plot per metric, grouping runs by keywords found in their directories.

    Runs of ``df`` are grouped by the keyword(s) contained in ``config/working_dir``, runs of
    ``df2`` by those contained in ``config/solve_dir``. Directories containing "algo" are ignored.
    When a matching keyword is a substring of another matching keyword (e.g. "regularize" and
    "regularize-all-layers"), the run is only assigned to the longer one. Each run contributes one
    point per metric, computed over the end of training.

    :param df: Training logs with ``config/*`` columns and ``global_step``.
    :param df2: Capacity logs with ``config/*`` columns and ``capacity-counters/env_steps``.
    :param metrics: Columns of ``df`` to plot. The special name ``"ProbRatio_"`` plots the mean ratio
        of ``batch/diff/avg_prob_ratio_above_epsilon`` to ``batch/diff/avg_prob_ratio_below_epsilon``.
    :param metric2: Column of ``df2`` plotted after ``metrics``.
    :param metric_name: Subplot titles, one per entry of ``metrics + [metric2]``.
    :param keywords: Directory keywords defining the groups (first keyword at the bottom).
    :param save_name: Saved as ``boxplot-{save_name}.pdf``. Its third "-"-separated field
        (e.g. "Gravitar" in "6-ALE-Gravitar-v5") labels the figure.
    """
    # y-axis labels, one per keyword; assumes the keyword order used by figure6_atari — TODO confirm
    # when calling with other keywords.
    group_labels = [
        "Share actor and critic features",
        "Reset Adam",
        "Regularize all preactivations",
        "Regularize last preactivation",
        "No intervention",
    ]
    colors = ["#E69F00", "#56B4E9", "#009E73", "#0072B2", "#D55E00", "#CC79A7", "#F0E442"]
    name_parts = save_name.split("-")
    # "6-ALE-Gravitar-v5" -> "Gravitar"; fall back to the whole name for other formats.
    env_label = name_parts[2] if len(name_parts) > 2 else save_name

    all_metrics = metrics + [metric2]
    aggregated_data = {keyword: {metric: [] for metric in all_metrics} for keyword in keywords}

    # Per-run bookkeeping, only read by the optional debug print below.
    return_nb, N_nb, return_steps, return_info = [], [], [], []
    ratio_nb, ratio_N_nb, ratio_steps, ratio_info, ratio_res = [], [], [], [], []

    # Training logs: one value per run and metric.
    config_keys = [key for key in df.columns if key.startswith("config/")]
    unique_configs = df[config_keys].drop_duplicates()
    for _, config_row in tqdm(unique_configs.iterrows(), total=len(unique_configs)):
        working_dir = config_row["config/working_dir"]
        if "algo" in working_dir:
            continue
        matched_keywords = _most_specific_keywords(working_dir, keywords)
        if not matched_keywords:
            continue

        mask = create_mask_for_config(df, config_row, config_keys)
        filtered_df = df.loc[mask]

        for metric in metrics:
            if metric == "ProbRatio_":
                selection = _prob_ratio_window(filtered_df)
                if selection is None:
                    continue  # no row logs both ratio statistics
                window, b_min, b_max = selection
                ratio_steps.append((window["global_step"].min(), window["global_step"].max()))
                ratio_info.append((filtered_df["config/optim/num_epochs"].values[0], working_dir))
                ratio_nb.append((window[_PROB_RATIO_ABOVE].count(), window[_PROB_RATIO_BELOW].count()))
                ratio_N_nb.append(b_max - b_min + 1)
                value = (window[_PROB_RATIO_ABOVE] / window[_PROB_RATIO_BELOW]).mean()
                ratio_res.append(value)
            elif metric == _RETURN_KEY:
                df_sorted, b_min, window = _tail_window(filtered_df)
                kept_rows = _deduplicate_returns(window)
                if not kept_rows:
                    continue
                return_steps.append((window["global_step"].min(), window["global_step"].max()))
                return_info.append((filtered_df["config/optim/num_epochs"].values[0], working_dir))
                return_nb.append(len(kept_rows))
                N_nb.append(df_sorted.index.max() - b_min + 1)
                value = pd.DataFrame(kept_rows)[_RETURN_KEY].mean()
            else:
                _, _, window = _tail_window(filtered_df)
                value = window[metric].mean()

            for keyword in matched_keywords:
                aggregated_data[keyword][metric].append(value)

    # Capacity logs: mean of ``metric2`` from (just before) 95% of the environment steps onward.
    env_steps_key = "capacity-counters/env_steps"
    config_keys2 = [key for key in df2.columns if key.startswith("config/")]
    unique_configs2 = df2[config_keys2].drop_duplicates()
    for _, config_row2 in unique_configs2.iterrows():
        solve_dir = config_row2["config/solve_dir"]
        if "algo" in solve_dir:
            continue
        matched_keywords = _most_specific_keywords(solve_dir, keywords)
        if not matched_keywords:
            continue

        mask2 = create_mask_for_config(df2, config_row2, config_keys2)
        filtered_df2 = df2[mask2]
        threshold_step = filtered_df2[env_steps_key].max() * 0.95
        df_sorted = filtered_df2.sort_values(by=env_steps_key, ascending=True)
        # Start at the label one below the first row strictly above the threshold. This is presumably
        # the preceding checkpoint, assuming consecutive integer index labels — TODO confirm.
        b_min = df_sorted[df_sorted[env_steps_key] > threshold_step].index.min() - 1
        value = df_sorted.loc[b_min:][metric2].mean()
        for keyword in matched_keywords:
            aggregated_data[keyword][metric2].append(value)

    # Uncomment to print, for each box-plot point, its step range, the number of points used and
    # the run it comes from (episode return and probability ratio).
    #     print(env_label)
    #     print(f"- Return : \n")
    #     for i in range(len(N_nb)):
    #         print(f"{return_steps[i][0]} -> {return_steps[i][1]}, {return_nb[i]}/{N_nb[i]}, {return_info[i][0]},{return_info[i][1]}")
    #     print(f"- Ratio : \n")
    #     for i in range(len(ratio_N_nb)):
    #         print(f"{ratio_steps[i][0]} -> {ratio_steps[i][1]}, {ratio_nb[i]}/{ratio_N_nb[i]}, {ratio_info[i][0]},{ratio_info[i][1]}, {ratio_res[i]}")

    if "ProbRatio_" in all_metrics:
        prob_means = [np.mean(aggregated_data[keyword]["ProbRatio_"]) for keyword in keywords]
        print(f"prob means: {prob_means}")

    meanlineprops = dict(linestyle="-.", linewidth=3, color="red")
    medianprops = dict(linestyle="-.", linewidth=3, color="black")

    fig = plt.figure(figsize=(20, 2))
    gs = plt.GridSpec(1, len(all_metrics), figure=fig)  # one column per metric

    axes = []  # one entry per metric: a regular Axes or a brokenaxes object

    for col_idx, metric in enumerate(all_metrics):
        m = [aggregated_data[keyword][metric] for keyword in keywords]
        non_empty = [data for data in m if len(data) > 0]
        g_min = min((np.min(data) for data in non_empty), default=np.nan)
        g_max = max((np.max(data) for data in non_empty), default=np.nan)
        # Large excess ratios get a broken x-axis so that the bulk of the data stays readable.
        use_broken_axes = metric == "ProbRatio_" and g_max > 2.5
        if use_broken_axes:
            bax = brokenaxes(
                xlims=[(g_min - 0.1, 2.5), (g_max - 0.1, g_max + 0.3)],
                subplot_spec=gs[0, col_idx],
                fig=fig,
                despine=False,
            )
            for s in bax.axs:  # axs is the list of sub-axes of the broken axes
                s.tick_params(axis="y", which="both", left=False, labelleft=False)
                for tick in s.yaxis.get_major_ticks():
                    tick.tick1line.set_visible(False)  # hides the inner tick lines
                    tick.tick2line.set_visible(False)  # hides the outer tick lines
                for tick in s.yaxis.get_minor_ticks():
                    tick.tick1line.set_visible(False)
                    tick.tick2line.set_visible(False)
            reduction_factor = 0.3
            for line in bax.diag_handles:
                # Shrink each break marker toward its midpoint.
                xdata, ydata = line.get_data()
                mid_x = np.mean(xdata)
                mid_y = np.mean(ydata)
                line.set_data(mid_x + reduction_factor * (xdata - mid_x), mid_y + reduction_factor * (ydata - mid_y))
                line.set_linewidth(1)
            ax = bax
            ax.set_yticklabels([])
        else:
            ax = fig.add_subplot(gs[0, col_idx])
            if metric == "batch/end/feature_stats/norm_features_preactivation_policy_batch":
                ax.set_xscale("log")
            if metric == "ProbRatio_" and env_label == "Gravitar":
                ax.set_xscale("log")

        axes.append(ax)

        bp = ax.boxplot(
            m,
            widths=0.6,
            patch_artist=True,
            vert=False,
            medianprops=medianprops,
            meanprops=meanlineprops,
            showmeans=True,
            meanline=True,
            flierprops=dict(
                marker="o", markerfacecolor="none", markeredgecolor="black", markersize=8, markeredgewidth=2
            ),
        )
        # brokenaxes draws into each sub-axis and returns one result per sub-axis: color them all.
        for box_set in bp if use_broken_axes else [bp]:
            for idx, patch in enumerate(box_set["boxes"]):
                patch.set_facecolor(colors[idx % len(colors)])

        ax.set_title(f"{metric_name[col_idx]}", fontsize=13.5)

        ax.grid(True, linestyle="--", alpha=0.5)
        if isinstance(ax, plt.Axes):
            ax.set_yticklabels([])
            ax.set_yticks([])
        if col_idx == 0:
            ax.set_yticks([y + 1 for y in range(len(m))], labels=group_labels)

    mean_line = lines.Line2D([], [], color="red", linestyle="-.", linewidth=3, label="Mean")
    median_line = lines.Line2D([], [], color="black", linestyle="-.", linewidth=3, label="Median")
    outlier_marker = lines.Line2D(
        [0],
        [0],
        marker="o",
        color="w",
        label="Outliers",
        markerfacecolor="none",
        markeredgecolor="black",
        markersize=8,
        markeredgewidth=2,
    )
    axes[-1].legend(
        handles=[mean_line, median_line, outlier_marker],
        loc="upper center",  # anchors the center of the legend at bbox_to_anchor
        bbox_to_anchor=(-2.8, 1.4),  # above the plots
        borderaxespad=0.0,
        fontsize=12,
        ncol=3,
        frameon=False,
        handlelength=1,
        handletextpad=0.5,
        columnspacing=1,
    )
    fig.text(-0.015, 0.01, f"{env_label}", ha="left", va="bottom", fontsize=12, fontweight="bold")
    plt.savefig(f"../outputs/rlconf-plotting/plots/boxplot-{save_name}.pdf")
    plt.show()

In [None]:
def figure6_atari(name):
    """Draw the Figure 6 box plots for a single Atari environment.

    :param name: Environment id such as "ALE/Gravitar-v5". It selects rows of ``atari_df`` and
        ``atari_df_capacity`` through ``filter_df_by_criteria`` and, with "/" replaced by "-",
        becomes part of the saved file name.
    """

    def select_non_annealed(frame):
        # A fresh criteria dict is built for every call.
        return filter_df_by_criteria(
            frame,
            {
                ("config", "env", "name"): [name],
                ("config", "optim", "anneal_linearly"): [False],
            },
        )

    training_runs = select_non_annealed(atari_df)
    figure_name = f"6-{name}".replace("/", "-")
    capacity_runs = select_non_annealed(atari_df_capacity)

    keywords = ["shared-trunk", "optimizer", "regularize-all-layers", "regularize", "baselines"]
    training_metrics = [
        "batch/perf/avg_return_raw",
        "ProbRatio_",
        "batch/end/dead_neurons/features_policy_batch",
        "batch/end/feature_stats/norm_features_preactivation_policy_batch",
        "batch/end/SVD/approximate_rank_pca/features_policy_batch",
    ]
    capacity_metric = "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/policy"
    titles = [
        "Episode return",
        "Excess ratio",
        "Dead neurons policy",
        "Norm preactivation policy",
        "Feature rank policy (PCA)",
        "Plasticity loss policy",
    ]
    plot_metrics_by_keywords(
        training_runs,
        capacity_runs,
        training_metrics,
        capacity_metric,
        titles,
        keywords,
        save_name=figure_name,
    )

In [None]:
# Figure 6 for three Atari games. Each call saves
# ../outputs/rlconf-plotting/plots/boxplot-6-ALE-<Game>-v5.pdf.
figure6_atari("ALE/Gravitar-v5")
figure6_atari("ALE/Phoenix-v5")
figure6_atari("ALE/NameThisGame-v5")

## Correlation between the five rank metrics

In [None]:
def distance(y1, y2, scale=512):
    """
    Return the normalized Euclidean (L2) distance between two curves sampled at the same x-coordinates.

    The L2 norm of ``y1 - y2`` is divided by ``sqrt(len(y1)) * scale``. The result is the
    root-mean-square difference expressed in units of ``scale``. Callers in this notebook use
    ``1 - distance(...)`` as a similarity score.

    :param y1: y-values of the first curve (array-like).
    :param y2: y-values of the second curve, with the same length as ``y1``.
    :param scale: Normalization constant. The default 512 is presumably the maximum attainable rank
        (feature width) of the compared rank metrics — confirm against the model configuration.
    :return: The normalized distance, or NaN for empty inputs.
    """
    y1 = np.asarray(y1, dtype=float)
    y2 = np.asarray(y2, dtype=float)
    if y1.size == 0:
        return float("nan")
    return np.linalg.norm(y1 - y2) / (np.sqrt(len(y1)) * scale)


def average_correlations_ranks(df, metrics_keys):
    """
    Average pairwise agreement between metrics over all run configurations in ``df``.

    Runs are identified by their unique combination of ``config/*`` columns. A configuration is
    skipped entirely when any metric has at most one distinct non-NaN value. For every remaining
    configuration and metric pair, rows where either metric is NaN are dropped. If both series
    still vary, Kendall's tau, Spearman's rho, Pearson's r and ``1 - distance`` are accumulated.

    Each entry is averaged over the configurations that actually contributed to that pair. A pair
    skipped for one configuration therefore does not pull its average toward zero.

    :param df: Logs with ``config/*`` columns and one column per entry of ``metrics_keys``.
    :param metrics_keys: Metric column names; they define the row/column order of the matrices.
    :return: Dict mapping "Kendall Tau", "Spearman Rho", "Pearson" and "1-distance" to symmetric
        ``(len(metrics_keys), len(metrics_keys))`` arrays. Entries with no contributing
        configuration are NaN.
    """
    config_keys = [key for key in df.columns if key.startswith("config/")]
    unique_configs = df[config_keys].drop_duplicates()

    n_metrics = len(metrics_keys)
    shape = (n_metrics, n_metrics)
    kendall_sums = np.zeros(shape)
    spearman_sums = np.zeros(shape)
    pearson_sums = np.zeros(shape)
    distance_sums = np.zeros(shape)
    # Number of configurations that contributed to each (i, j) entry.
    pair_counts = np.zeros(shape)

    for _, config_row in tqdm(unique_configs.iterrows(), total=len(unique_configs)):
        mask = create_mask_for_config(df, config_row, config_keys)
        config_df = df[mask]

        # Skip configurations where any metric is constant: their correlations are undefined.
        if any(len(np.unique(config_df[metric].dropna().values)) <= 1 for metric in metrics_keys):
            continue

        # Only the upper triangle (diagonal included) is computed; it is mirrored below.
        for i in range(n_metrics):
            for j in range(i, n_metrics):
                y1 = config_df[metrics_keys[i]].values
                y2 = config_df[metrics_keys[j]].values
                valid = ~np.isnan(y1) & ~np.isnan(y2)
                y1, y2 = y1[valid], y2[valid]
                # Joint NaN filtering can still leave a constant series; such a pair is skipped.
                if len(np.unique(y1)) <= 1 or len(np.unique(y2)) <= 1:
                    continue
                kendall_sums[i, j] += kendalltau(y1, y2)[0]
                spearman_sums[i, j] += spearmanr(y1, y2)[0]
                pearson_sums[i, j] += pearsonr(y1, y2)[0]
                distance_sums[i, j] += 1 - distance(y1, y2)
                pair_counts[i, j] += 1

    counts = np.triu(pair_counts) + np.triu(pair_counts, 1).T

    def _symmetric_average(sums):
        # Mirror the upper triangle into the lower one, then divide by the per-pair counts.
        full_sums = np.triu(sums) + np.triu(sums, 1).T
        return np.divide(full_sums, counts, out=np.full(shape, np.nan), where=counts > 0)

    return {
        "Kendall Tau": _symmetric_average(kendall_sums),
        "Spearman Rho": _symmetric_average(spearman_sums),
        "Pearson": _symmetric_average(pearson_sums),
        "1-distance": _symmetric_average(distance_sums),
    }

In [None]:
def plot_correlation_matrices(correlations, metrics_keys, name_figure=None, labels=None):
    """
    Draw each correlation matrix as an annotated heatmap in a 2x2 grid and save the figure as a PDF.

    Every heatmap gets two extra columns: a blank spacer and an "Avg ± Std" column. The latter shows
    each row's mean and standard deviation, colored like the row mean.

    :param correlations: Mapping from panel title to a square ``len(metrics_keys)`` matrix. At most
        four entries are drawn (one per panel).
    :param metrics_keys: Metric names indexing the rows and columns of the matrices.
    :param name_figure: Suffix of the output file ``correlation_{name_figure}.pdf``.
    :param labels: Display names of the metrics, used on both axes. Defaults to the five rank
        estimators (Vetterli, PCA, Kumar, Lyle, PyTorch).
    :raises ValueError: If ``labels`` does not have exactly one entry per metric.
    """
    if labels is None:
        labels = ["Vetterli", "PCA", "Kumar", "Lyle", "PyTorch"]
    n_metrics = len(metrics_keys)
    if len(labels) != n_metrics:
        raise ValueError(f"Expected {n_metrics} labels, got {len(labels)}.")

    avg_col = n_metrics  # "Avg" column of the extended frame, displayed as a blank spacer
    std_col = n_metrics + 1  # "Std" column, displays "avg ± std"
    n_columns = n_metrics + 2
    cbar_width, cbar_height, cbar_xoffset = 0.02, 0.3, 0.03

    fig, axs = plt.subplots(2, 2, figsize=(40, 30))
    axs = axs.flatten()
    plt.subplots_adjust(wspace=0.45, hspace=0.4)

    for ax, (title, matrix) in zip(axs, correlations.items()):
        df_matrix = pd.DataFrame(matrix, index=metrics_keys, columns=metrics_keys)
        row_averages = df_matrix.mean(axis=1)
        row_std_devs = df_matrix.std(axis=1)
        df_extended = df_matrix.copy()
        df_extended["Avg"] = row_averages
        # Left as NaN so that seaborn neither colors nor annotates it; it is filled in manually below.
        df_extended["Std"] = np.nan

        pos = ax.get_position()
        cbar_ax = fig.add_axes(
            [pos.x1 + cbar_xoffset, pos.y0 + (pos.height - cbar_height) / 2, cbar_width, cbar_height]
        )
        heatmap = sns.heatmap(
            df_extended,
            ax=ax,
            cmap="coolwarm",
            annot_kws={"size": 35},
            annot=True,
            fmt=".2f",
            cbar_ax=cbar_ax,
            square=False,
        )
        heatmap.collections[0].colorbar.ax.tick_params(labelsize=25)
        # Cell colors, indexed as row * n_columns + column.
        facecolors = heatmap.collections[0].get_facecolors()

        for y in range(n_metrics):
            avg_color = facecolors[y * n_columns + avg_col]
            # Move the row-average color into the "Std" column and blank out the "Avg" column.
            ax.add_patch(plt.Rectangle((std_col, y), 1, 1, fill=True, facecolor=avg_color, edgecolor="none"))
            ax.add_patch(plt.Rectangle((avg_col, y), 1, 1, fill=True, facecolor="white", edgecolor="none"))
            ax.text(
                std_col + 0.5,
                y + 0.5,
                f"{row_averages.iloc[y]:.2f}\n±{row_std_devs.iloc[y]:.2f}",
                ha="center",
                va="center",
                color="black",
                fontsize=35,
            )
            # White text on a white box: presumably hides seaborn's annotation of the "Avg" cell,
            # which is drawn above the rectangle patch.
            ax.text(
                avg_col + 0.5,
                y + 0.5,
                f"{row_averages.iloc[y]:.2f}",
                ha="center",
                va="center",
                color="white",
                fontsize=35,
                bbox=dict(facecolor="white", edgecolor="none", pad=10),
            )

        ax.set_title(title, fontsize=35)

        tick_labels = list(labels) + ["", "Avg ± Std"]
        ax.set_xticks(np.arange(len(tick_labels)) + 0.5)
        ax.set_xticklabels(tick_labels, rotation=90, ha="center", fontsize=30)
        ax.set_yticks(np.arange(n_metrics) + 0.5)
        ax.set_yticklabels(labels, fontsize=30, va="center")
    plt.savefig(f"../outputs/rlconf-plotting/plots/correlation_{name_figure}.pdf", format="pdf", bbox_inches="tight")
    plt.show()

In [None]:
def figure_avg_ranks_atari():
    """Plot the average agreement between the five policy-feature rank estimators on Atari.

    Runs are selected from ``atari_df`` with ``filter_df_by_criteria`` (six environments,
    working_dir "baselines", anneal_linearly False). The averaged correlations are drawn by
    ``plot_correlation_matrices`` under the name "avg_atari".
    """
    atari_envs = (
        "ALE/Phoenix-v5",
        "ALE/NameThisGame-v5",
        "ALE/DoubleDunk-v5",
        "ALE/Gravitar-v5",
        "ALE/BattleZone-v5",
        "ALE/Qbert-v5",
    )
    selection = {
        ("config", "env", "name"): list(atari_envs),
        ("config", "working_dir"): "baselines",
        ("config", "optim", "anneal_linearly"): [False],
    }
    baseline_runs = filter_df_by_criteria(atari_df, selection)

    rank_metrics = [
        "batch/end/SVD/effective_rank_vetterli/features_policy_batch",
        "batch/end/SVD/approximate_rank_pca/features_policy_batch",
        "batch/end/SVD/srank_kumar/features_policy_batch",
        "batch/end/SVD/feature_rank_lyle/features_policy_batch",
        "batch/end/SVD/pytorch_rank/features_policy_batch",
    ]
    plot_correlation_matrices(
        average_correlations_ranks(baseline_runs, rank_metrics),
        rank_metrics,
        "avg_atari",
    )

In [None]:
# Average agreement between the five rank estimators; saved as correlation_avg_atari.pdf.
figure_avg_ranks_atari()

In [None]:
def worst_correlations(df, metrics_keys):
    """
    Worst-case (minimum) pairwise agreement between metrics over all run configurations in ``df``.

    Runs are identified by their unique combination of ``config/*`` columns. A configuration is
    skipped when any metric has at most one distinct non-NaN value. For each remaining
    configuration and metric pair, rows where either metric is NaN are dropped. With at least two
    points left, Kendall's tau, Spearman's rho, Pearson's r and ``1 - distance`` are computed, and
    the minimum of each finite score over configurations is kept.

    :param df: Logs with ``config/*`` columns and one column per entry of ``metrics_keys``.
    :param metrics_keys: Metric column names; they define the row/column order of the matrices.
    :return: Dict mapping "Kendall Tau", "Spearman Rho", "Pearson" and "1-distance" to symmetric
        ``(len(metrics_keys), len(metrics_keys))`` arrays. The diagonal is 1.0, values are capped at
        1.0 from above, and off-diagonal entries that never received a finite score are NaN.
    """
    config_keys = [key for key in df.columns if key.startswith("config/")]
    unique_configs = df[config_keys].drop_duplicates()

    n_metrics = len(metrics_keys)
    score_names = ("Kendall Tau", "Spearman Rho", "Pearson", "1-distance")
    # Start at 1.0, an upper bound of all four scores, and lower each entry with every new score.
    worst = {name: np.ones((n_metrics, n_metrics)) for name in score_names}
    # Whether an entry has received at least one finite score.
    observed = {name: np.zeros((n_metrics, n_metrics), dtype=bool) for name in score_names}

    for _, config_row in tqdm(unique_configs.iterrows(), total=len(unique_configs)):
        mask = create_mask_for_config(df, config_row, config_keys)
        config_df = df[mask]

        # Skip configurations where any metric is constant.
        if any(len(np.unique(config_df[metric].dropna().values)) <= 1 for metric in metrics_keys):
            continue

        # Upper triangle only (i < j); it is mirrored into the lower triangle below.
        for i in range(n_metrics):
            for j in range(i + 1, n_metrics):
                y1 = config_df[metrics_keys[i]].values
                y2 = config_df[metrics_keys[j]].values
                valid = ~np.isnan(y1) & ~np.isnan(y2)
                y1, y2 = y1[valid], y2[valid]
                # With fewer than two points every score is undefined.
                if len(y1) < 2:
                    continue
                scores = {
                    "Kendall Tau": kendalltau(y1, y2)[0],
                    "Spearman Rho": spearmanr(y1, y2)[0],
                    "Pearson": pearsonr(y1, y2)[0],
                    "1-distance": 1 - distance(y1, y2),
                }
                for name, score in scores.items():
                    # NaN scores (e.g. a series left constant after NaN filtering) are ignored.
                    if np.isnan(score):
                        continue
                    worst[name][i, j] = min(worst[name][i, j], score)
                    observed[name][i, j] = True

    upper = np.triu_indices(n_metrics, k=1)
    lower = (upper[1], upper[0])
    for name in score_names:
        matrix = worst[name]
        matrix[upper] = np.where(observed[name][upper], matrix[upper], np.nan)
        matrix[lower] = matrix[upper]  # mirror upper -> lower

    return worst


def figure_worst_ranks_atari():
    """Plot the worst-case correlation matrices between policy feature-rank estimators on Atari.

    The baseline runs (``working_dir == "baselines"``, no linear annealing) of six
    ALE games are selected from ``atari_df``. ``worst_correlations`` then reduces
    them to per-pair worst correlations between five rank estimators of the policy
    features. The result is drawn by ``plot_correlation_matrices`` under the tag
    ``"worst_atari"``.
    """
    rank_keys = [
        "batch/end/SVD/effective_rank_vetterli/features_policy_batch",
        "batch/end/SVD/approximate_rank_pca/features_policy_batch",
        "batch/end/SVD/srank_kumar/features_policy_batch",
        "batch/end/SVD/feature_rank_lyle/features_policy_batch",
        "batch/end/SVD/pytorch_rank/features_policy_batch",
    ]
    games = [
        "ALE/Phoenix-v5",
        "ALE/NameThisGame-v5",
        "ALE/DoubleDunk-v5",
        "ALE/Gravitar-v5",
        "ALE/BattleZone-v5",
        "ALE/Qbert-v5",
    ]
    baseline_runs = filter_df_by_criteria(
        atari_df,
        {
            ("config", "env", "name"): games,
            ("config", "working_dir"): "baselines",
            ("config", "optim", "anneal_linearly"): [False],
        },
    )
    plot_correlation_matrices(worst_correlations(baseline_runs, rank_keys), rank_keys, "worst_atari")

In [None]:
# Render the worst-case rank-correlation matrices for the Atari baseline runs.
figure_worst_ranks_atari()

# MuJoCo

## Loading the runs from raw logs into a DataFrame and saving them as combined .csv files
### Training
<div style="background-color: #F9F9F9; border-left: 5px solid #CC0000; padding: 10px; margin: 20px 0;">
    <strong>Skip to loading the combined raw-log .csv files if you don't have the raw logs.</strong>
</div> 

In [None]:
# If the data is somewhere else create a symlink to it as your $PROJECT_DIR/outputs/rlconf
# ln -s <path-to-rlconf-containing-data> $PROJECT_DIR/outputs/rlconf

# Discover every metric key logged by one reference run: load each .tar file found in
# the listed log subdirectories and take the union of their keys. The sorted result
# (`all_keys`) is the list of metric columns passed to compile_data_debug later in this cell.
root_log_dir = f"../outputs/rlconf/solve/mujoco-ppo/control/shared-trunk/2024-03-01_15-56-01-512787"

all_keys = set()
log_subdirs = ["logs/models", "logs/minibatch", "logs/eval", "logs/epoch", "logs/batch"]
# NOTE(review): `n` is not read anywhere in this cell.
n = len(os.listdir(root_log_dir))

for subdir in log_subdirs:
    subdir_path = os.path.join(root_log_dir, subdir)
    if not os.path.exists(subdir_path):
        print(f"Subdirectory {subdir} does not exist.")
        continue
    for file_name in os.listdir(subdir_path):
        file_path = os.path.join(subdir_path, file_name)
        if file_path.endswith(".tar"):
            # Loaded onto the CPU; presumably a dict of metric name -> value (TODO confirm).
            loaded_data = torch.load(file_path, map_location=torch.device("cpu"))
            all_keys.update(loaded_data.keys())
# Sorted so the metric column order is deterministic.
all_keys = list(all_keys)
all_keys.sort()


def get_nested_config_value(config, nested_key):
    """
    Walk ``config`` along the path ``nested_key`` and return what is found there.
    :param config: The (possibly nested) configuration dictionary.
    :param nested_key: Iterable of keys describing the path, outermost first.
    :return: The value at the end of the path. Returns None as soon as a step is
        missing or the current node is not a dict. An empty path returns ``config`` itself.
    """
    node = config
    for step in nested_key:
        if not isinstance(node, dict) or step not in node:
            return None
        node = node[step]
    return node


def config_matches_criteria(config, criteria):
    """
    Check if the configuration matches the given criteria, supporting nested keys.

    A criterion matches when the config value equals the criterion value. When the
    criterion value is a list or tuple, it also matches if the config value is one
    of its elements. Criteria written with list values such as ``[False]`` (the form
    used by the plotting cells of this notebook) therefore match a scalar ``False``
    in the config, instead of never matching.

    :param config: The configuration dictionary from the YAML file.
    :param criteria: Mapping from a key (or a tuple of nested keys) to the required
        value, or to a list/tuple of accepted values.
    :return: True if the config matches every criterion (an empty dict matches
        everything), False otherwise.
    """
    for keys, expected in criteria.items():
        # Support for nested keys represented as tuples in criteria
        nested_keys = keys if isinstance(keys, tuple) else (keys,)
        actual = get_nested_config_value(config, nested_keys)
        if actual == expected:
            continue
        # A list/tuple criterion enumerates alternatives; a scalar config value can
        # never equal the list itself, so test membership instead.
        if isinstance(expected, (list, tuple)) and actual in expected:
            continue
        return False
    return True


def flatten_config(config, parent_key="", sep="/"):
    """Flatten a nested config into a single-level dict with ``sep``-joined keys.

    Nested dicts contribute their entries under the parent's path. List elements
    are addressed by their index (``"a/0"``, ``"a/1"``, ...). Empty dicts and empty
    lists contribute no entries. A falsy prefix (e.g. ``""``) is not joined.

    :param config: Mapping to flatten.
    :param parent_key: Prefix joined (with ``sep``) in front of every top-level key.
    :param sep: Separator between path components.
    :return: A new flat dict mapping joined paths to leaf values.
    """

    def _leaves(path, value):
        # Yield (flat_key, leaf_value) pairs for a value located at `path`.
        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                child = f"{path}{sep}{sub_key}" if path else sub_key
                yield from _leaves(child, sub_value)
        elif isinstance(value, list):
            for index, element in enumerate(value):
                yield from _leaves(f"{path}{sep}{index}", element)
        else:
            yield path, value

    pairs = []
    for key, value in config.items():
        top = f"{parent_key}{sep}{key}" if parent_key else key
        pairs.extend(_leaves(top, value))
    return dict(pairs)


def find_matching_configs(root_path, criteria, flat_category="baselines"):
    """
    Find run folders under ``root_path`` whose resolved config matches ``criteria``.

    Two directory layouts are searched:
      * ``root_path/<flat_category>/<run>/config/config_resolved.yaml``
      * ``root_path/<category>/<subcategory>/<run>/config/config_resolved.yaml``
        for every other category directory.
    Runs without a ``config_resolved.yaml`` are skipped, as are stray files at the
    category and subcategory levels. Directory entries are visited in sorted
    order, so the result is reproducible across file systems.

    :param root_path: Root directory of the experiment outputs.
    :param criteria: Criteria dict evaluated by ``config_matches_criteria``.
    :param flat_category: Category whose runs sit directly below it (default
        ``"baselines"``, the layout of the solve outputs).
    :return: List of matching run folder paths.
    """

    def _matching_runs(parent_dir):
        """Return the run folders directly inside ``parent_dir`` whose config matches."""
        runs = []
        for run_folder in sorted(os.listdir(parent_dir)):
            run_path = os.path.join(parent_dir, run_folder)
            config_path = os.path.join(run_path, "config", "config_resolved.yaml")
            if not os.path.exists(config_path):
                continue
            with open(config_path, "r", encoding="utf-8") as file:
                config = yaml.safe_load(file)
            if config_matches_criteria(config, criteria):
                runs.append(run_path)
        return runs

    matching_paths = []
    categories = sorted(os.listdir(root_path))
    for category in tqdm(categories, total=len(categories), desc="Find matching path"):
        category_path = os.path.join(root_path, category)
        if not os.path.isdir(category_path):
            continue  # Skip if not a directory
        if category == flat_category:
            matching_paths.extend(_matching_runs(category_path))
            continue
        for subcategory in sorted(os.listdir(category_path)):
            subcategory_path = os.path.join(category_path, subcategory)
            # Stray files here would make os.listdir raise NotADirectoryError.
            if os.path.isdir(subcategory_path):
                matching_paths.extend(_matching_runs(subcategory_path))
    return matching_paths


def compile_data_debug(matching_paths, metric_keys):
    """
    Build a DataFrame with one row per ``.tar`` log file in each run's ``logs/batch`` folder.

    Each row holds the run's flattened config, with keys prefixed by ``"config/"``,
    followed by one column per entry of ``metric_keys``. A metric missing from a log
    file is stored as None. Runs without a ``logs/batch`` folder contribute no rows.
    Log files are visited in sorted order so the row order is reproducible.

    :param matching_paths: Run folders, each containing ``config/config_resolved.yaml``.
    :param metric_keys: Keys to read from every loaded log.
    :return: pandas DataFrame of the compiled rows (empty if no logs were found).
    """
    rows = []
    for path in tqdm(matching_paths, total=len(matching_paths), desc="Filling df"):
        config_path = os.path.join(path, "config", "config_resolved.yaml")
        with open(config_path, "r", encoding="utf-8") as file:
            config = yaml.safe_load(file)
        flat_config = {f"config/{key}": value for key, value in flatten_config(config).items()}

        log_folder = os.path.join(path, "logs", "batch")
        if not os.path.exists(log_folder):
            continue
        for log_file in sorted(os.listdir(log_folder)):
            if not log_file.endswith(".tar"):
                continue
            loaded_data = torch.load(os.path.join(log_folder, log_file), map_location=torch.device("cpu"))
            # Config columns first, then the requested metrics (None when absent).
            rows.append({**flat_config, **{key: loaded_data.get(key) for key in metric_keys}})
    return pd.DataFrame(rows)


# Build the combined MuJoCo training table: collect the run folders under the solve
# outputs, read each run's logs/batch .tar files for the keys in `all_keys`, and save
# the result as CSV.
path_folder = "../outputs/rlconf/solve/mujoco-ppo/"
# Criteria keys are nested config paths given as tuples; an empty dict matches every run.
criteria = {}
matching_paths = find_matching_configs(path_folder, criteria)

mujoco_df = compile_data_debug(matching_paths, all_keys)
print(mujoco_df.head())

# NOTE: to_csv writes the index by default, so reading it back adds an "Unnamed: 0" column.
mujoco_df.to_csv("../outputs/rlconf-plotting/combined-raw-logs/mujoco.csv")

In [None]:
# mujoco_df.to_csv("../outputs/rlconf-plotting/combined-raw-logs/mujoco.csv")

In [None]:
# At the time of submission to RLC
# Should expect 624 runs
# 240 baselines = 4 maps * 3 epochs * 2 activations * 2 lr schedules * 5 seeds
# 384 interventions = 4 interventions * 2 maps * 3 epochs * 2 activations * (2 lr schedules * 3 seeds + 1 lr schedule * 2 seeds)

### Plasticity
<div style="background-color: #F9F9F9; border-left: 5px solid #CC0000; padding: 10px; margin: 20px 0;">
    <strong>Skip to loading the combined raw-log .csv files if you don't have the raw logs.</strong>
</div> 

In [None]:
# Discover every metric key logged by one reference plasticity (capacity) run: load each
# .tar file in the listed log subdirectories and take the union of their keys. The sorted
# `all_keys` is passed to compile_data_debug later in this cell.
root_log_dir = "../outputs/rlconf/capacity/mujoco-ppo/all/2024-04-08_12-52-56-917065/"

all_keys = set()
log_subdirs = ["logs/checkpoint", "logs/minibatch", "logs/model", "logs/epoch"]
# NOTE(review): `n` is not read anywhere in this cell.
n = len(os.listdir(root_log_dir))

for subdir in log_subdirs:
    subdir_path = os.path.join(root_log_dir, subdir)
    if not os.path.exists(subdir_path):
        print(f"Subdirectory {subdir} does not exist.")
        continue
    for file_name in os.listdir(subdir_path):
        file_path = os.path.join(subdir_path, file_name)
        if file_path.endswith(".tar"):
            # Loaded onto the CPU; presumably a dict of metric name -> value (TODO confirm).
            loaded_data = torch.load(file_path, map_location=torch.device("cpu"))
            all_keys.update(loaded_data.keys())
# Sorted so the metric column order is deterministic.
all_keys = list(all_keys)
all_keys.sort()


def get_nested_config_value(config, nested_key):
    """
    Walk ``config`` along the path ``nested_key`` and return what is found there.
    :param config: The (possibly nested) configuration dictionary.
    :param nested_key: Iterable of keys describing the path, outermost first.
    :return: The value at the end of the path. Returns None as soon as a step is
        missing or the current node is not a dict. An empty path returns ``config`` itself.
    """
    node = config
    for step in nested_key:
        if not isinstance(node, dict) or step not in node:
            return None
        node = node[step]
    return node


def config_matches_criteria(config, criteria):
    """
    Check if the configuration matches the given criteria, supporting nested keys.

    A criterion matches when the config value equals the criterion value. When the
    criterion value is a list or tuple, it also matches if the config value is one
    of its elements. Criteria written with list values such as ``[False]`` (the form
    used by the plotting cells of this notebook) therefore match a scalar ``False``
    in the config, instead of never matching.

    :param config: The configuration dictionary from the YAML file.
    :param criteria: Mapping from a key (or a tuple of nested keys) to the required
        value, or to a list/tuple of accepted values.
    :return: True if the config matches every criterion (an empty dict matches
        everything), False otherwise.
    """
    for keys, expected in criteria.items():
        # Support for nested keys represented as tuples in criteria
        nested_keys = keys if isinstance(keys, tuple) else (keys,)
        actual = get_nested_config_value(config, nested_keys)
        if actual == expected:
            continue
        # A list/tuple criterion enumerates alternatives; a scalar config value can
        # never equal the list itself, so test membership instead.
        if isinstance(expected, (list, tuple)) and actual in expected:
            continue
        return False
    return True


def flatten_config(config, parent_key="", sep="/"):
    """Flatten a nested config into a single-level dict with ``sep``-joined keys.

    Nested dicts contribute their entries under the parent's path. List elements
    are addressed by their index (``"a/0"``, ``"a/1"``, ...). Empty dicts and empty
    lists contribute no entries. A falsy prefix (e.g. ``""``) is not joined.

    :param config: Mapping to flatten.
    :param parent_key: Prefix joined (with ``sep``) in front of every top-level key.
    :param sep: Separator between path components.
    :return: A new flat dict mapping joined paths to leaf values.
    """

    def _leaves(path, value):
        # Yield (flat_key, leaf_value) pairs for a value located at `path`.
        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                child = f"{path}{sep}{sub_key}" if path else sub_key
                yield from _leaves(child, sub_value)
        elif isinstance(value, list):
            for index, element in enumerate(value):
                yield from _leaves(f"{path}{sep}{index}", element)
        else:
            yield path, value

    pairs = []
    for key, value in config.items():
        top = f"{parent_key}{sep}{key}" if parent_key else key
        pairs.extend(_leaves(top, value))
    return dict(pairs)


def find_matching_configs(root_path, criteria, flat_category="all"):
    """
    Find run folders under ``root_path`` whose resolved config matches ``criteria``.

    Two directory layouts are searched:
      * ``root_path/<flat_category>/<run>/config/config_resolved.yaml``
      * ``root_path/<category>/<subcategory>/<run>/config/config_resolved.yaml``
        for every other category directory.
    Runs without a ``config_resolved.yaml`` are skipped, as are stray files at the
    category and subcategory levels. Directory entries are visited in sorted
    order, so the result is reproducible across file systems.

    :param root_path: Root directory of the experiment outputs.
    :param criteria: Criteria dict evaluated by ``config_matches_criteria``.
    :param flat_category: Category whose runs sit directly below it (default
        ``"all"``, the layout of the capacity outputs).
    :return: List of matching run folder paths.
    """

    def _matching_runs(parent_dir):
        """Return the run folders directly inside ``parent_dir`` whose config matches."""
        runs = []
        for run_folder in sorted(os.listdir(parent_dir)):
            run_path = os.path.join(parent_dir, run_folder)
            config_path = os.path.join(run_path, "config", "config_resolved.yaml")
            if not os.path.exists(config_path):
                continue
            with open(config_path, "r", encoding="utf-8") as file:
                config = yaml.safe_load(file)
            if config_matches_criteria(config, criteria):
                runs.append(run_path)
        return runs

    matching_paths = []
    categories = sorted(os.listdir(root_path))
    for category in tqdm(categories, total=len(categories), desc="Find matching path"):
        category_path = os.path.join(root_path, category)
        if not os.path.isdir(category_path):
            continue  # Skip if not a directory
        if category == flat_category:
            matching_paths.extend(_matching_runs(category_path))
            continue
        for subcategory in sorted(os.listdir(category_path)):
            subcategory_path = os.path.join(category_path, subcategory)
            # Stray files here would make os.listdir raise NotADirectoryError.
            if os.path.isdir(subcategory_path):
                matching_paths.extend(_matching_runs(subcategory_path))
    return matching_paths


def compile_data_debug(matching_paths, metric_keys):
    """
    Build a DataFrame with one row per ``.tar`` log file in each run's ``logs/checkpoint`` folder.

    Each row holds the run's flattened config, with keys prefixed by ``"config/"``,
    followed by one column per entry of ``metric_keys``. A metric missing from a log
    file is stored as None. Runs without a ``logs/checkpoint`` folder contribute no
    rows. Log files are visited in sorted order so the row order is reproducible.

    :param matching_paths: Run folders, each containing ``config/config_resolved.yaml``.
    :param metric_keys: Keys to read from every loaded log.
    :return: pandas DataFrame of the compiled rows (empty if no logs were found).
    """
    rows = []
    for path in tqdm(matching_paths, total=len(matching_paths), desc="Filling df"):
        config_path = os.path.join(path, "config", "config_resolved.yaml")
        with open(config_path, "r", encoding="utf-8") as file:
            config = yaml.safe_load(file)
        flat_config = {f"config/{key}": value for key, value in flatten_config(config).items()}

        # The plasticity (capacity) runs log their metrics under logs/checkpoint.
        log_folder = os.path.join(path, "logs", "checkpoint")
        if not os.path.exists(log_folder):
            continue
        for log_file in sorted(os.listdir(log_folder)):
            if not log_file.endswith(".tar"):
                continue
            loaded_data = torch.load(os.path.join(log_folder, log_file), map_location=torch.device("cpu"))
            # Config columns first, then the requested metrics (None when absent).
            rows.append({**flat_config, **{key: loaded_data.get(key) for key in metric_keys}})
    return pd.DataFrame(rows)


# Build the combined MuJoCo plasticity (capacity) table from the runs under the capacity
# outputs and save it as CSV.
path_folder = "../outputs/rlconf/capacity/mujoco-ppo/"
# Criteria keys are nested config paths given as tuples; an empty dict matches every run.
criteria = {}
matching_paths = find_matching_configs(path_folder, criteria)

mujoco_df_capacity = compile_data_debug(matching_paths, all_keys)
# Preview the MuJoCo capacity frame that was just built.
print(mujoco_df_capacity.head())
mujoco_df_capacity.to_csv("../outputs/rlconf-plotting/combined-raw-logs/mujoco-capacity.csv")

In [None]:
# mujoco_df_capacity.to_csv("../outputs/rlconf-plotting/combined-raw-logs/mujoco-capacity.csv")

In [None]:
# At the time of submission to RLC
# Should expect 624 runs
# 240 baselines = 4 maps * 3 epochs * 2 activations * 2 lr schedules * 5 seeds
# 384 interventions = 4 interventions * 2 maps * 3 epochs * 2 activations * (2 lr schedules * 3 seeds + 1 lr schedule * 2 seeds)

<div style="background-color: #F9F9F9; border-left: 5px solid #CC0000; padding: 10px; margin: 20px 0;">
    <strong>Load the combined raw-log .csv files here. If you did not generate them, you can find instructions for downloading them in $PROJECT_ROOT/outputs/README.md.</strong>
</div> 

In [None]:
# Load the combined MuJoCo training logs (written above or downloaded).
# NOTE: the CSV was saved with its index, so it reads back with an extra "Unnamed: 0" column.
mujoco_df = pd.read_csv("../outputs/rlconf-plotting/combined-raw-logs/mujoco.csv")
mujoco_df.head()

In [None]:
# Load the combined MuJoCo plasticity (capacity) logs (written above or downloaded).
# NOTE: the CSV was saved with its index, so it reads back with an extra "Unnamed: 0" column.
mujoco_df_capacity = pd.read_csv("../outputs/rlconf-plotting/combined-raw-logs/mujoco-capacity.csv")
mujoco_df_capacity.head()

# Figure 1

In [None]:
def figure1_mujoco(name, activation):
    """Draw Figure 1 for one MuJoCo environment and activation.

    Both the training frame ``mujoco_df`` and the plasticity frame
    ``mujoco_df_capacity`` are filtered to the baseline runs of ``name`` with the
    given ``activation``. Baseline here means no linear annealing, no optimizer-state
    reset, separate actor/critic features, and a zero feature trust-region
    coefficient. The training frame is matched on ``working_dir``, the plasticity
    frame on ``solve_dir``. Runs are grouped by ``config/optim/num_epochs`` and
    plotted with ``plot_shaded_metrics_side_by_side_plasticity_smooth``. The figure
    name is ``"1-<name>-<activation>"`` with slashes replaced by dashes.
    """
    shared = {
        ("config", "env", "name"): [name],
        ("config", "optim", "anneal_linearly"): [False],
        ("config", "optim", "reset_state"): [False],
        ("config", "models", "share_features"): [False],
        ("config", "loss", "policy", "kwargs", "feature_trust_region_coef"): [0.0],
    }
    training_runs = filter_df_by_criteria(
        mujoco_df,
        {**shared, ("config", "working_dir"): "baselines", ("config", "models", "activation"): activation},
    )
    plasticity_runs = filter_df_by_criteria(
        mujoco_df_capacity,
        {**shared, ("config", "solve_dir"): "baselines", ("config", "models", "activation"): activation},
    )

    norm_preact_policy = "batch/start/feature_stats/norm_features_preactivation_policy_batch"
    plasticity_loss_policy = "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/policy"
    plasticity_loss_value = "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/value"
    training_metrics = [
        "batch/perf/avg_return_raw",
        "batch/end/SVD/approximate_rank_pca/features_policy_batch",
        norm_preact_policy,
        "batch/end/SVD/approximate_rank_pca/features_value_batch",
    ]
    panel_titles = [
        "Episode return",
        "Feature rank policy (PCA)",
        "Norm preactivation policy",
        "Feature rank critic (PCA)",
        "Plasticity loss policy",
        "Plasticity loss critic",
    ]
    plot_shaded_metrics_side_by_side_plasticity_smooth(
        training_runs,
        plasticity_runs,
        ["config/optim/num_epochs"],
        "global_step",
        training_metrics,
        [plasticity_loss_policy, plasticity_loss_value],
        [norm_preact_policy, plasticity_loss_policy],
        panel_titles,
        title=name,
        name_=f"1-{name}-{activation}".replace("/", "-"),
    )

## Tanh

In [None]:
# Figure 1 for every MuJoCo environment with Tanh networks.
figure1_mujoco("Ant-v4", "Tanh")
figure1_mujoco("Hopper-v4", "Tanh")
figure1_mujoco("Humanoid-v4", "Tanh")
figure1_mujoco("HalfCheetah-v4", "Tanh")

## ReLU

In [None]:
# Figure 1 for every MuJoCo environment with ReLU networks.
figure1_mujoco("Ant-v4", "ReLU")
figure1_mujoco("Hopper-v4", "ReLU")
figure1_mujoco("Humanoid-v4", "ReLU")
figure1_mujoco("HalfCheetah-v4", "ReLU")

In [None]:
# Function to plot in figure 1
def plot_shaded_metrics_side_by_side_plasticity_smooth_mujoco(
    df, df2, group_by_cols, x_col, metrics_keys, metrics_keys2, log_key="", subplot_titles=None, title="", name_=""
):
    """Plot training metrics and plasticity losses side by side (MuJoCo Figure 1 layout).

    Draws one row of ``len(metrics_keys) + 2`` panels. For each group of
    ``group_by_cols``, every training metric from ``df`` is drawn against ``x_col``
    as a mean line with a min-max band. The keys listed in ``special_keys`` are
    first smoothed with an exponential moving average (alpha=0.05). The plasticity
    losses in ``metrics_keys2`` come from ``df2`` and are drawn against
    ``"capacity-counters/env_steps"``.

    Panel layout (built for four training metrics and two plasticity losses):
    training metrics 0-2 -> panels 0-2, first plasticity loss -> panel
    ``len(metrics_keys) - 1``, training metric 3 -> panel 4, remaining plasticity
    losses -> panel ``len(metrics_keys) + 1``.

    :param df: Training-log DataFrame.
    :param df2: Plasticity-log DataFrame with a ``"capacity-counters/env_steps"`` column.
    :param group_by_cols: Column(s) defining the plotted groups; a one-element list is unwrapped.
    :param x_col: x-axis column of ``df``.
    :param metrics_keys: Training metric columns of ``df``.
    :param metrics_keys2: Plasticity loss columns of ``df2``.
    :param log_key: Key(s) whose panel gets a log y-scale. The test is ``key in log_key``,
        so a string matches by substring and a list by membership.
    :param subplot_titles: y-axis labels, ordered as the training metrics then the plasticity losses.
    :param title: Environment name written in the first panel.
    :param name_: Output file suffix; a trailing ``"-bis"`` selects an alternative colour cycle.

    Saves ``../outputs/rlconf-plotting/plots/Figure-1-<name_>.pdf`` and shows the figure.
    """
    if name_.split("-")[-1] == "bis":
        color_cycler = cycler(color=["#D55E00", "#56B4E9", "#0072B2", "#009E73", "#E69F00"])
        plt.rc("axes", prop_cycle=color_cycler)
    else:
        plt.rc("axes", prop_cycle=line_cycler)
    if isinstance(group_by_cols, list) and len(group_by_cols) == 1:
        group_by = group_by_cols[0]
    else:
        group_by = group_by_cols

    alpha = 0.05
    special_keys = [
        "batch/perf/avg_return_raw",
        "batch/end/SVD/approximate_rank_pca/features_policy_batch",
        "batch/start/feature_stats/norm_features_preactivation_policy_batch",
        "batch/end/SVD/approximate_rank_pca/features_value_batch",
    ]
    fig, axs = plt.subplots(1, len(metrics_keys) + 2, figsize=(40, 8), sharex=True)

    def _metric_axis(idx):
        # Training metric 3 is shifted one panel right to make room for the policy plasticity loss.
        return axs[idx + 1] if idx == 3 else axs[idx]

    def _style_axis(ax, ylabel):
        # Shared x/y labels, tick sizes and offset-text sizes used by every panel.
        ax.set_xlabel("Environment steps", fontsize=25)
        ax.set_ylabel(ylabel, fontsize=30)
        ax.tick_params(axis="x", labelsize=30)
        ax.tick_params(axis="y", labelsize=30)
        ax.xaxis.get_offset_text().set_fontsize(25)
        ax.yaxis.get_offset_text().set_fontsize(25)

    def _plot_band(ax, center, low, high, label=""):
        # Central line plus a shaded [low, high] band.
        ax.plot(center.index, center, label=label, linewidth=4.0)
        ax.fill_between(center.index, low, high, alpha=0.3)

    def _intervention_label(group_label):
        # group_label is str() of the group-key tuple, e.g. "(False, 1.0, True, False)".
        # Assumes the order (share_features, feature_trust_region_coef, reset_state,
        # feature_trust_all_layers) used by figure1_mujoco_bis. Parts are stripped
        # because split(",") leaves a leading space on every part after the first.
        parts = [part.strip() for part in group_label[1:-1].split(",")]
        share, trust, reset, all_layers = parts[0], float(parts[1]), parts[2], parts[3]
        if share == "True":
            return "Share actor and critic features"
        if reset == "True":
            return "Reset Adam"
        # NOTE(review): only coefficient 1 counts as regularization; other non-zero
        # coefficients (e.g. 10) fall through to "No intervention" — confirm intended.
        if trust == 1:
            return "Regularize last preactivation" if all_layers == "False" else "Regularize all preactivations"
        return "No intervention"

    # Training metrics from df.
    grouped = df.groupby(group_by)
    for metric_idx, metric_key in enumerate(metrics_keys):
        ax = _metric_axis(metric_idx)
        for name, group in grouped:
            by_x = group.groupby(x_col)[metric_key]
            center, low, high = by_x.mean(), by_x.min(), by_x.max()
            if metric_key in special_keys:
                # Smooth the noisy training curves with an EMA before plotting.
                center, low, high = (s.ewm(alpha=alpha, adjust=False).mean() for s in (center, low, high))
            # Only the first panel carries labels; the figure legend is built from it.
            _plot_band(ax, center, low, high, label=name if metric_idx == 0 else "")
        _style_axis(ax, subplot_titles[metric_idx])
        if metric_key in log_key:
            ax.set_yscale("log")  # Apply log scale on the panel that holds this metric
        ax.grid(linestyle="dotted")

    # Plasticity losses from df2.
    x_plasticity = "capacity-counters/env_steps"
    grouped2 = df2.groupby(group_by)
    for metric_idx2, metric_key2 in enumerate(metrics_keys2):
        title_idx = len(metrics_keys) if metric_idx2 == 0 else len(metrics_keys) + 1
        ax = axs[len(metrics_keys) - 1] if metric_idx2 == 0 else axs[len(metrics_keys) + 1]
        for _, group in grouped2:
            by_x = group.groupby(x_plasticity)[metric_key2]
            _plot_band(ax, by_x.mean(), by_x.min(), by_x.max())
        _style_axis(ax, subplot_titles[title_idx])
        if metric_key2 in log_key:
            ax.set_yscale("log")
        ax.grid(linestyle="dotted")

    fig.tight_layout(pad=2.0)
    handles, labels = axs[0].get_legend_handles_labels()
    print(labels)

    def _is_tuple_label(label):
        return label.startswith("(") and label.endswith(")")

    # Tuple labels come from multi-column groupings (interventions); plain labels are epoch counts.
    lab = [_intervention_label(label) if _is_tuple_label(label) else label + " epochs" for label in labels]

    # The panel title is drawn once, not once per legend entry.
    if any(_is_tuple_label(label) for label in labels):
        axs[0].text(
            0.02,
            0.90,
            f"{title}\n{df['config/optim/num_epochs'].values[0]} epochs",
            fontsize=20,
            fontweight="bold",
            transform=axs[0].transAxes,
        )
    if any(not _is_tuple_label(label) for label in labels):
        axs[0].text(0.02, 0.95, f"{title}", fontsize=20, fontweight="bold", transform=axs[0].transAxes)

    if len(lab) == 4:
        box_to_anchor = (0.25, 1.05)
    elif len(lab) == 3:
        box_to_anchor = (0.4, 1.05)
    else:
        box_to_anchor = (0.15, 1.05)
    fig.legend(
        handles,
        lab,
        loc="upper left",
        bbox_to_anchor=box_to_anchor,
        borderaxespad=0.0,
        fontsize=30,
        ncol=len(lab),
        frameon=False,
        handlelength=1,
        handletextpad=0.5,
        columnspacing=1,
    )
    plt.subplots_adjust(right=1, top=0.90)  # Adjust subplot params to make room for the legend
    plt.savefig(f"../outputs/rlconf-plotting/plots/Figure-1-{name_}.pdf", format="pdf", bbox_inches="tight")
    plt.show()


def figure1_mujoco_bis(name, num_epochs, activation):
    """Draw the intervention-comparison version of Figure 1 for one MuJoCo setting.

    Selects runs of environment ``name`` trained for ``num_epochs`` epochs with the
    given ``activation``: no linear annealing, algo ``ppo-clip``, and a feature
    trust-region coefficient of 0, 1 or 10. Runs come from the listed experiment
    directories of ``mujoco_df`` (matched on ``working_dir``) and
    ``mujoco_df_capacity`` (matched on ``solve_dir``). They are grouped by
    (share_features, feature_trust_region_coef, reset_state,
    feature_trust_all_layers) and plotted with
    ``plot_shaded_metrics_side_by_side_plasticity_smooth_mujoco``, which receives
    the preactivation-norm key as ``log_key``. The output name is
    ``"2-<name>-<num_epochs>-<activation>-epochs-bis"`` with slashes replaced by dashes.
    """
    # NOTE(review): this list says "baseline" while figure1_mujoco filters on "baselines";
    # confirm the no-intervention runs are meant to be matched here.
    experiment_dirs = [
        "baseline",
        "experiment",
        "optimizer",
        "regularize",
        "regularize-all-layers",
        "shared-trunk",
    ]
    common = {
        ("config", "env", "name"): [name],
        ("config", "optim", "num_epochs"): [num_epochs],
        ("config", "optim", "anneal_linearly"): [False],
        ("config", "algo"): ["ppo-clip"],
        ("config", "loss", "policy", "kwargs", "feature_trust_region_coef"): [0, 1, 10],
        ("config", "models", "activation"): activation,
    }
    training_runs = filter_df_by_criteria(mujoco_df, {**common, ("config", "working_dir"): experiment_dirs})
    plasticity_runs = filter_df_by_criteria(
        mujoco_df_capacity, {**common, ("config", "solve_dir"): list(experiment_dirs)}
    )

    intervention_cols = [
        "config/models/share_features",
        "config/loss/policy/kwargs/feature_trust_region_coef",
        "config/optim/reset_state",
        "config/loss/policy/kwargs/feature_trust_all_layers",
    ]
    training_metrics = [
        "batch/perf/avg_return_raw",
        "batch/end/SVD/approximate_rank_pca/features_policy_batch",
        "batch/start/feature_stats/norm_features_preactivation_policy_batch",
        "batch/end/SVD/approximate_rank_pca/features_value_batch",
    ]
    plasticity_metrics = [
        "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/policy",
        "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/value",
    ]
    panel_titles = [
        "Episode return",
        "Feature rank policy (PCA)",
        "Norm preactivation policy",
        "Feature rank critic (PCA)",
        "Plasticity loss policy",
        "Plasticity loss critic",
    ]
    plot_shaded_metrics_side_by_side_plasticity_smooth_mujoco(
        training_runs,
        plasticity_runs,
        intervention_cols,
        "global_step",
        training_metrics,
        plasticity_metrics,
        "batch/start/feature_stats/norm_features_preactivation_policy_batch",
        panel_titles,
        title=name,
        name_=f"2-{name}-{num_epochs}-{activation}-epochs-bis".replace("/", "-"),
    )

In [None]:
# Intervention comparison (Figure 1 "bis") for Hopper and Humanoid at 10/15/20 epochs, Tanh networks.
figure1_mujoco_bis("Hopper-v4", 10, "Tanh")
figure1_mujoco_bis("Hopper-v4", 15, "Tanh")
figure1_mujoco_bis("Hopper-v4", 20, "Tanh")
figure1_mujoco_bis("Humanoid-v4", 10, "Tanh")
figure1_mujoco_bis("Humanoid-v4", 15, "Tanh")
figure1_mujoco_bis("Humanoid-v4", 20, "Tanh")

In [None]:
# Intervention comparison (Figure 1 "bis") for Hopper and Humanoid at 10/15/20 epochs, ReLU networks.
figure1_mujoco_bis("Hopper-v4", 10, "ReLU")
figure1_mujoco_bis("Hopper-v4", 15, "ReLU")
figure1_mujoco_bis("Hopper-v4", 20, "ReLU")
figure1_mujoco_bis("Humanoid-v4", 10, "ReLU")
figure1_mujoco_bis("Humanoid-v4", 15, "ReLU")
figure1_mujoco_bis("Humanoid-v4", 20, "ReLU")

## Figure 2

In [None]:
def figure2_mujoco(name, activation):
    """Draw Figure 2 (entropy, policy variance, dead neurons) for one MuJoCo setting.

    Filters ``mujoco_df`` to the baseline runs of ``name`` with the given
    ``activation``: no linear annealing, no optimizer-state reset, separate
    actor/critic features, zero feature trust-region coefficient, and
    ``working_dir == "baselines"``. Runs are grouped by ``config/optim/num_epochs``
    and plotted with ``plot_shaded_metrics_side_by_side_smooth``, which receives the
    policy-variance key as ``log_key``.
    """
    baseline_filter = {
        ("config", "env", "name"): [name],
        ("config", "optim", "anneal_linearly"): [False],
        ("config", "optim", "reset_state"): [False],
        ("config", "models", "share_features"): [False],
        ("config", "loss", "policy", "kwargs", "feature_trust_region_coef"): [0.0],
        ("config", "working_dir"): "baselines",
        ("config", "models", "activation"): activation,
    }
    policy_variance = "batch/start/action_diversity/policy_variance"
    panel_metrics = [
        "batch/first_epoch/first_minibatch/loss/entropy",
        policy_variance,
        "batch/start/dead_neurons/features_policy_batch",
    ]
    # NOTE(review): the name prefix "1-" is the same one figure1_mujoco uses; whether the
    # saved files collide depends on plot_shaded_metrics_side_by_side_smooth's file pattern.
    plot_shaded_metrics_side_by_side_smooth(
        filter_df_by_criteria(mujoco_df, baseline_filter),
        ["config/optim/num_epochs"],
        "global_step",
        panel_metrics,
        policy_variance,
        ["Entropy", "Policy variance", "Dead neurons policy"],
        title=name,
        name_=f"1-{name}-{activation}".replace("/", "-"),
    )

## ReLU

In [None]:
# Figure 2 for every MuJoCo environment with ReLU networks.
figure2_mujoco("HalfCheetah-v4", "ReLU")
figure2_mujoco("Humanoid-v4", "ReLU")
figure2_mujoco("Ant-v4", "ReLU")
figure2_mujoco("Hopper-v4", "ReLU")

## Tanh

In [None]:
# Figure 2 for every MuJoCo environment with Tanh networks.
figure2_mujoco("HalfCheetah-v4", "Tanh")
figure2_mujoco("Humanoid-v4", "Tanh")
figure2_mujoco("Ant-v4", "Tanh")
figure2_mujoco("Hopper-v4", "Tanh")

## Figure 3

In [None]:
def _fig3_style_axis(ax, ylabel):
    """Apply the shared Figure 3 axis labels and 25-pt font sizes to ``ax``."""
    ax.set_xlabel("Environment steps", fontsize=25)
    ax.set_ylabel(ylabel, fontsize=25)
    ax.tick_params(axis="x", labelsize=25)
    ax.tick_params(axis="y", labelsize=25)
    # The scientific-notation offset labels (e.g. "1e6") are not covered by tick_params.
    ax.xaxis.get_offset_text().set_fontsize(25)
    ax.yaxis.get_offset_text().set_fontsize(25)


def _fig3_ema_band(group, x_col, metric_key, alpha, negate=False):
    """Return EMA-smoothed (mean, min, max) of ``metric_key`` aggregated per ``x_col`` value.

    With ``negate`` the three smoothed series are sign-flipped; this is how the
    logged policy loss is displayed as the PPO-Clip objective.
    """
    per_step = group.groupby(x_col)[metric_key]
    band = tuple(
        stat.ewm(alpha=alpha, adjust=False).mean() for stat in (per_step.mean(), per_step.min(), per_step.max())
    )
    if negate:
        band = tuple(-series for series in band)
    return band


def plot_shaded_metrics_side_by_side_plasticity_fig3(
    df, df2, group_by_cols, x_col, metrics_keys, metrics_keys2, log_key="", subplot_titles=None, title="", name_=""
):
    """Draw the Figure 3 row: training metrics from ``df`` plus a plasticity metric from ``df2``.

    The figure has ``len(metrics_keys) + 1`` panels. Each metric is placed by the
    fixed ``positions`` mapping below, so only the keys listed there are supported
    (others raise ``KeyError``). Training curves are EMA-smoothed (alpha=0.05) with
    a min/max band; the plasticity curve is raw mean with a min/max band.

    :param df: Training logs; needs ``x_col``, the metric columns, ``config/env/name``,
        ``config/optim/num_epochs`` and ``config/seed``.
    :param df2: Capacity logs indexed in time by ``capacity-counters/env_steps``.
    :param group_by_cols: Column(s) identifying one curve; a one-element list is unwrapped.
    :param x_col: Step column of ``df`` used as the x axis.
    :param metrics_keys: Training metric keys to draw from ``df``.
    :param metrics_keys2: Capacity metric keys; only the FIRST one is drawn, the
        others only get log scale/grid applied.
    :param log_key: Metric keys whose y axis is log-scaled (membership via ``in``).
    :param subplot_titles: Y-axis labels indexed by panel position; defaults to the metric keys.
    :param title: Figure suptitle.
    :param name_: Suffix of the saved ``Figure-3-<name_>.pdf``.
    """
    plt.rc("axes", prop_cycle=line_cycler)
    if isinstance(group_by_cols, list) and len(group_by_cols) == 1:
        group_by = group_by_cols[0]
    else:
        group_by = group_by_cols

    num_metrics = len(metrics_keys) + 1
    fig, axs = plt.subplots(1, num_metrics, figsize=(40, 8), sharex=True)
    alpha = 0.05
    capacity_key = "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/policy"
    policy_loss_key = "batch/last_epoch/last_minibatch/loss/loss_policy"
    positions = {
        "batch/end/SVD/approximate_rank_pca/features_policy_batch": 0,
        "batch/perf/avg_return_raw": 2,
        "batch/diff/avg_prob_ratio_below_epsilon": 3,
        "batch/end/action_diversity/policy_variance": 4,
        policy_loss_key: 5,
        capacity_key: 1,
    }
    # Every known key except the capacity one comes from the training logs.
    training_keys = set(positions) - {capacity_key}
    log_keys = log_key or ()
    if subplot_titles is None:
        # Previously this crashed with a TypeError; fall back to the raw metric keys.
        subplot_titles = [""] * (max(positions.values()) + 1)
        for metric_label, pos in positions.items():
            subplot_titles[pos] = metric_label

    legend_info = {}  # label -> first line drawn with it (unique legend entries)
    grouped = df.groupby(group_by)
    for metric_key in metrics_keys:
        ax = axs[positions[metric_key]]
        if metric_key in training_keys:
            for _, group in grouped:
                # Assumes one env / epoch count / seed per group.
                env_name = group["config/env/name"].iloc[0]
                num_epochs = group["config/optim/num_epochs"].iloc[0]
                seed = group["config/seed"].iloc[0]
                label = f"{env_name}, {num_epochs} epochs, seed={seed}"

                group = group.dropna(subset=[metric_key])
                ema_mean, ema_min, ema_max = _fig3_ema_band(
                    group, x_col, metric_key, alpha, negate=(metric_key == policy_loss_key)
                )
                (line,) = ax.plot(ema_mean.index, ema_mean, label=label, linewidth=4.0)
                ax.fill_between(ema_mean.index, ema_min, ema_max, alpha=0.3)
                legend_info.setdefault(label, line)
            _fig3_style_axis(ax, subplot_titles[positions[metric_key]])

        if metric_key in log_keys:
            ax.set_yscale("log")
        if metric_key == policy_loss_key:
            ax.set_ylim(bottom=-0.1, top=0.1)
        ax.grid(linestyle="dotted")

    capacity_step_col = "capacity-counters/env_steps"
    for metric_idx2, metric_key2 in enumerate(metrics_keys2):
        ax = axs[positions[metric_key2]]
        if metric_idx2 == 0:
            for _, group in df2.groupby(group_by):
                env_name = group["config/env/name"].iloc[0]
                num_epochs = group["config/optim/num_epochs"].iloc[0]
                activation = group["config/models/activation"].iloc[0]
                label = f"{env_name}, {num_epochs} epochs, {activation}"
                per_step = group.groupby(capacity_step_col)[metric_key2]
                mean = per_step.mean()
                ax.plot(mean.index, mean, label=label, linewidth=4.0)
                ax.fill_between(mean.index, per_step.min(), per_step.max(), alpha=0.3)
            _fig3_style_axis(ax, subplot_titles[positions[metric_key2]])
        if metric_key2 in log_keys:
            ax.set_yscale("log")
        ax.grid(linestyle="dotted")

    fig.suptitle(title, fontsize=30)
    fig.tight_layout(pad=2.0)
    # Attached to the current (last) axes and shifted left/up so it sits above the row.
    plt.legend(
        handles=list(legend_info.values()),
        labels=list(legend_info.keys()),
        loc="upper center",
        bbox_to_anchor=(-2.8, 1.2),
        borderaxespad=0.0,
        fontsize=25,
        ncol=4,
        frameon=False,
        handlelength=1,
        handletextpad=0.5,
        columnspacing=1,
    )
    out_dir = "../outputs/rlconf-plotting/plots"
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f"{out_dir}/Figure-3-{name_}.pdf", format="pdf", bbox_inches="tight")
    plt.show()

In [None]:
def figure3_mujoco():
    """Plot Figure 3 (rank, plasticity, return, prob ratios, variance, objective) for four baseline runs.

    The same four run directories select the training logs (by ``working_dir``)
    and the capacity logs (by ``solve_dir``).
    """
    run_dirs = [
        "baselines/2024-02-27_19-45-33-114578",
        "baselines/2024-02-27_19-44-56-470648",
        "baselines/2024-02-27_19-44-15-056169",
        "baselines/2024-02-27_09-15-25-975421",
    ]
    # Separate list copies so the two filters never share a mutable value.
    training_runs = filter_df_by_criteria(mujoco_df, {("config", "working_dir"): list(run_dirs)})
    capacity_runs = filter_df_by_criteria(mujoco_df_capacity, {("config", "solve_dir"): list(run_dirs)})

    capacity_key = "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/policy"
    training_keys = [
        "batch/end/SVD/approximate_rank_pca/features_policy_batch",
        "batch/perf/avg_return_raw",
        "batch/diff/avg_prob_ratio_below_epsilon",
        "batch/end/action_diversity/policy_variance",
        "batch/last_epoch/last_minibatch/loss/loss_policy",
    ]
    log_scaled = [
        "batch/start/feature_stats/norm_features_preactivation_policy_batch",
        "batch/end/action_diversity/policy_variance",
        capacity_key,
    ]
    panel_titles = [
        "Rank policy (PCA)",
        "Plasticity loss policy",
        "Episode return",
        r"Avg of prob ratios < 1 - $\epsilon$",
        "Policy variance",
        "PPO-Clip objective",
    ]
    plot_shaded_metrics_side_by_side_plasticity_fig3(
        training_runs,
        capacity_runs,
        ["config/working_dir"],
        "global_step",
        training_keys,
        [capacity_key],
        log_scaled,
        panel_titles,
        title="",
        name_="5-mujoco",
    )

In [None]:
# Render Figure 3; figure3_mujoco passes name_="5-mujoco" to the plotting helper.
figure3_mujoco()

## Figure 5

In [None]:
def figure5_mujoco(name, activation):
    """Plot Figure 5 correlations between the prob-ratio metric and policy-feature metrics.

    Uses non-annealed baseline runs with 10, 15 or 20 epochs.

    :param name: Environment name filter value, e.g. ``["Ant-v4"]`` (passed as-is to the filter).
    :param activation: Network activation, e.g. ``"ReLU"`` or ``"Tanh"``.
    """

    def baseline_filter(dir_column):
        # Identical criteria for both log tables; only the run-directory column differs.
        return {
            ("config", "env", "name"): name,
            ("config", dir_column): ["baselines"],
            ("config", "optim", "anneal_linearly"): False,
            ("config", "optim", "num_epochs"): [10, 15, 20],
            ("config", "models", "activation"): activation,
        }

    training_runs = filter_df_by_criteria(mujoco_df, baseline_filter("working_dir"))
    capacity_runs = filter_df_by_criteria(mujoco_df_capacity, baseline_filter("solve_dir"))

    feature_metrics = [
        "batch/start/dead_neurons/features_policy_batch",
        "batch/start/SVD/approximate_rank_pca/features_policy_batch",
        "batch/start/feature_stats/norm_features_preactivation_policy_batch",
    ]
    feature_labels = ["Dead neurons policy", "Feature rank policy (PCA)", "Feature preactivation norm"]
    plot_correlation(
        training_runs,
        capacity_runs,
        "batch/diff/avg_prob_ratio_below_epsilon",
        "batch/diff/min_prob_ratio_below_epsilon",
        "",
        feature_metrics,
        r"Avg of prob ratios < 1 - $\epsilon$",
        feature_labels,
        log=0,
    )

## ReLU

In [None]:
# Figure 5 for every MuJoCo env with ReLU networks (env names are passed as one-element lists).
for _env in ("Ant-v4", "Humanoid-v4", "Hopper-v4", "HalfCheetah-v4"):
    figure5_mujoco([_env], "ReLU")

## Tanh

In [None]:
# Figure 5 for every MuJoCo env with Tanh networks (env names are passed as one-element lists).
for _env in ("Ant-v4", "Humanoid-v4", "Hopper-v4", "HalfCheetah-v4"):
    figure5_mujoco([_env], "Tanh")

## Figure 6

In [None]:
# Y-axis row labels for the Figure 6 box plots. The order matches the
# ``config_keywords`` list in figure6_mujoco:
# shared-trunk, optimizer, regularize-all-layers, regularize, baselines.
key = [
    "Share actor and critic features",
    "Reset Adam",
    "Regularize all preactivations",
    "Regularize last preactivation",
    "No intervention",
]


def _fig6_buckets(config_row, dir_column, keywords):
    """Return the aggregation bucket for every keyword contained in ``config_row[dir_column]``.

    Runs whose directory contains ``"algo"`` get no bucket. If the run's
    ``feature_trust_all_layers`` flag equals True, every match is filed under
    ``"regularize-all-layers"``. A directory matching several keywords yields
    one bucket per match, so the run is counted once per match.
    """
    run_dir = config_row[dir_column]
    if "algo" in run_dir:
        return []
    matches = [keyword for keyword in keywords if keyword in run_dir]
    if not matches:
        return []
    if config_row["config/loss/policy/kwargs/feature_trust_all_layers"] == True:  # noqa: E712 (keep == semantics)
        return ["regularize-all-layers"] * len(matches)
    return matches


def _fig6_tail_window(df, step_col, threshold):
    """Sort ``df`` by ``step_col`` and return the rows from the first index label with step >= ``threshold`` on."""
    df_sorted = df.sort_values(by=step_col, ascending=True)
    b_min = df_sorted[df_sorted[step_col] >= threshold].index.min()
    return df_sorted.loc[b_min:]


def _fig6_tail_mean(df, step_col, metric):
    """Return the mean of ``metric`` over the rows whose ``step_col`` is >= 95% of its maximum."""
    return _fig6_tail_window(df, step_col, df[step_col].max() * 0.95)[metric].mean()


def _fig6_tail_return(filtered_df):
    """Return the mean episode return over the last 5% of ``global_step``, de-duplicated.

    A row repeating the previous kept row's ``avg_return_raw`` is skipped, unless
    its ``max_timestep`` did not increase or 8 repeats in a row were already skipped.
    Presumably the repeats are re-logs of the same finished episodes; confirm
    against the logger.
    """
    ret_col = "batch/perf/avg_return_raw"
    timestep_col = "batch/perf/max_timestep"
    tail = _fig6_tail_window(filtered_df, "global_step", filtered_df["global_step"].max() * 0.95)
    kept = []
    skipped_in_a_row = 0
    last_timestep = float("inf")
    for _, row in tail.iterrows():
        if kept and row[ret_col] == kept[-1][ret_col]:
            if row[timestep_col] <= last_timestep or skipped_in_a_row >= 8:
                skipped_in_a_row = 0
                kept.append(row)
            else:
                skipped_in_a_row += 1
        else:
            kept.append(row)
            skipped_in_a_row = 0
        last_timestep = row[timestep_col]
    if not kept:
        return float("nan")
    return pd.DataFrame(kept).sort_values(by="global_step", ascending=False)[ret_col].mean()


def _fig6_tail_prob_ratio(filtered_df, min_points=10):
    """Return the mean clipped above/below-epsilon prob-ratio quotient near the end of a run.

    The window is 5% of the run's final ``global_step`` wide and ends at the last
    step where both ratios are logged. While it holds fewer than ``min_points``
    values, its end is moved back one logged step at a time. If no earlier step
    is left, the last window is used as-is.
    """
    above = "batch/diff/avg_prob_ratio_above_epsilon"
    below = "batch/diff/avg_prob_ratio_below_epsilon"
    logged = filtered_df.dropna(subset=[above, below])
    if logged.empty:
        return float("nan")
    df_sorted = logged.sort_values(by="global_step", ascending=True)
    steps = df_sorted["global_step"]
    last_step = steps.max()
    width = int(filtered_df["global_step"].max() * 0.05)
    # Candidate window end points, latest first.
    end_steps = logged.sort_values(by="global_step", ascending=False)["global_step"].values.tolist()

    window = df_sorted.loc[df_sorted[steps >= last_step - width].index.min() :]
    idx = 0
    while window[above].count() < min_points:
        if idx + 1 >= len(end_steps):
            # Bug fix: the original indexed past the end of end_steps here (IndexError).
            break
        idx += 1
        end_step = end_steps[idx]
        b_min = df_sorted[steps >= end_step - width].index.min()
        b_max = df_sorted[steps <= end_step].index.max()
        window = df_sorted.loc[b_min:b_max]
        # Only true when every logged step is identical: no point in sliding further.
        if last_step <= steps.min():
            break

    ratio_above = np.clip(window[above], a_min=-np.inf, a_max=10e12)
    ratio_below = np.clip(window[below], a_max=np.inf, a_min=10e-12)
    return (ratio_above / ratio_below).mean()


def plot_metrics_by_keywords_mujoco(
    df,
    df2,
    metrics,
    metric2,
    metric_name,
    keywords,
    steps_window=20000,
    save_name="",
    keyword_labels=None,
    verbose=False,
):
    """Box-plot end-of-training metrics per intervention keyword, one panel per metric.

    Each unique ``config/*`` combination of ``df`` (a run) is assigned to every
    keyword contained in its ``config/working_dir``. Runs of ``df2`` are assigned
    the same way via ``config/solve_dir`` (see ``_fig6_buckets``).

    Per run and metric, the value is:

    * ``"ProbRatio"``: the clipped above/below prob-ratio quotient near the end of
      training (see ``_fig6_tail_prob_ratio``);
    * ``"batch/perf/avg_return_raw"``: the de-duplicated mean return over the last
      5% of ``global_step``;
    * any other metric: its mean over the last 5% of ``global_step``.

    ``metric2`` is read from ``df2`` and averaged over the last 5% of
    ``capacity-counters/env_steps``.

    :param df: Training logs.
    :param df2: Capacity (plasticity) logs.
    :param metrics: Metric keys from ``df``; ``"ProbRatio"`` is a derived pseudo-metric.
    :param metric2: Metric key from ``df2``, drawn in the last panel.
    :param metric_name: Panel titles, one per entry of ``metrics + [metric2]``.
    :param keywords: Directory keywords defining the box rows, in order. It must
        include ``"regularize-all-layers"`` if any run has that flag set.
    :param steps_window: Unused; kept for backward compatibility.
    :param save_name: Output suffix; its second ``-``-separated field is drawn as the row label.
    :param keyword_labels: Y tick labels for the rows; defaults to the module-level ``key`` list.
    :param verbose: Print the per-run values that enter the boxes.
    """
    all_metrics = metrics + [metric2]
    aggregated_data = {keyword: {metric: [] for metric in all_metrics} for keyword in keywords}

    # Training logs: compute each run's values once, then file them under every matching bucket.
    config_keys = [col for col in df.columns if col.startswith("config/")]
    for _, config_row in df[config_keys].drop_duplicates().iterrows():
        filtered_df = df.loc[create_mask_for_config(df, config_row, config_keys)].copy()
        buckets = _fig6_buckets(config_row, "config/working_dir", keywords)
        if not buckets:
            continue
        values = {}
        for metric in metrics:
            if metric == "ProbRatio":
                values[metric] = _fig6_tail_prob_ratio(filtered_df)
            elif metric == "batch/perf/avg_return_raw":
                values[metric] = _fig6_tail_return(filtered_df)
            else:
                values[metric] = _fig6_tail_mean(filtered_df, "global_step", metric)
        if verbose:
            print(f"{config_row['config/working_dir']} -> {buckets}: {values}")
        for bucket in buckets:
            for metric in metrics:
                aggregated_data[bucket][metric].append(values[metric])

    # Capacity logs: one value of metric2 per run.
    config_keys2 = [col for col in df2.columns if col.startswith("config/")]
    for _, config_row2 in df2[config_keys2].drop_duplicates().iterrows():
        filtered_df2 = df2[create_mask_for_config(df2, config_row2, config_keys2)].copy()
        buckets = _fig6_buckets(config_row2, "config/solve_dir", keywords)
        if not buckets:
            continue
        value = _fig6_tail_mean(filtered_df2, "capacity-counters/env_steps", metric2)
        if verbose:
            print(f"{config_row2['config/solve_dir']} -> {buckets}: {metric2}={value}")
        for bucket in buckets:
            aggregated_data[bucket][metric2].append(value)

    if keyword_labels is None:
        keyword_labels = key
    log_scaled = {
        "ProbRatio",
        "batch/end/feature_stats/norm_features_preactivation_policy_batch",
        "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/policy",
    }
    colors = ["#E69F00", "#56B4E9", "#009E73", "#0072B2", "#D55E00", "#CC79A7", "#F0E442"]
    meanlineprops = dict(linestyle="-.", linewidth=3, color="red")
    medianprops = dict(linestyle="-.", linewidth=3, color="black")
    flierprops = dict(marker="o", markerfacecolor="none", markeredgecolor="black", markersize=8, markeredgewidth=2)

    fig, axes = plt.subplots(1, len(all_metrics), figsize=(20, 2), squeeze=False)
    axes = axes[0]
    for col_idx, metric in enumerate(all_metrics):
        ax = axes[col_idx]
        data = [aggregated_data[keyword][metric] for keyword in keywords]
        bp = ax.boxplot(
            data,
            widths=0.6,
            patch_artist=True,
            vert=False,
            medianprops=medianprops,
            meanprops=meanlineprops,
            showmeans=True,
            meanline=True,
            flierprops=flierprops,
        )
        if metric in log_scaled:
            ax.set_xscale("log")
        for patch, color in zip(bp["boxes"], colors):
            patch.set_facecolor(color)
        ax.set_title(f"{metric_name[col_idx]}", fontsize=13.5)
        if col_idx == 0:
            ax.set_yticks([y + 1 for y in range(len(data))], labels=keyword_labels)
        else:
            # Only the first panel carries row labels.
            ax.set_yticklabels([])
            ax.set_yticks([])
        ax.grid(True, linestyle="--", alpha=0.5)

    mean_line = lines.Line2D([], [], color="red", linestyle="-.", linewidth=3, label="Mean")
    median_line = lines.Line2D([], [], color="black", linestyle="-.", linewidth=3, label="Median")
    outlier_marker = lines.Line2D(
        [0],
        [0],
        marker="o",
        color="w",
        label="Outliers",
        markerfacecolor="none",
        markeredgecolor="black",
        markersize=8,
        markeredgewidth=2,
    )
    axes[-1].legend(
        handles=[mean_line, median_line, outlier_marker],
        loc="upper center",
        bbox_to_anchor=(-2.8, 1.4),
        borderaxespad=0.0,
        fontsize=12,
        ncol=3,
        frameon=False,
        handlelength=1,
        handletextpad=0.5,
        columnspacing=1,
    )
    # save_name looks like "6-<env>-<activation>"; fall back to the whole name if it has no "-".
    name_parts = save_name.split("-")
    row_label = name_parts[1] if len(name_parts) > 1 else save_name
    fig.text(-0.015, 0.01, f"{row_label}", ha="left", va="bottom", fontsize=12, fontweight="bold")
    out_dir = "../outputs/rlconf-plotting/plots"
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f"{out_dir}/boxplot-{save_name}.pdf")

    plt.show()

## Tanh

In [None]:
def figure6_mujoco(name, activation):
    """Draw the Figure 6 intervention box plots for one MuJoCo env and activation.

    :param name: MuJoCo environment name, e.g. ``"Hopper-v4"``.
    :param activation: Network activation, e.g. ``"ReLU"`` or ``"Tanh"``.
    """

    def run_filter():
        # A fresh dict per call, as in the original two separate literals.
        return {
            ("config", "env", "name"): [name],
            ("config", "models", "activation"): activation,
            ("config", "optim", "anneal_linearly"): [False],
        }

    training_runs = filter_df_by_criteria(mujoco_df, run_filter())
    save_as = f"6-{name}-{activation}".replace("/", "-")
    capacity_runs = filter_df_by_criteria(mujoco_df_capacity, run_filter())

    plot_metrics_by_keywords_mujoco(
        training_runs,
        capacity_runs,
        [
            "batch/perf/avg_return_raw",
            "ProbRatio",
            "batch/end/dead_neurons/features_policy_batch",
            "batch/end/feature_stats/norm_features_preactivation_policy_batch",
            "batch/end/SVD/approximate_rank_pca/features_policy_batch",
        ],
        "capacity-checkpoint/last_capacity-epoch/last_capacity-minibatch/loss/policy",
        [
            "Episode return",
            "Excess ratio",
            "Dead neurons policy",
            "Norm preactivation policy",
            "Feature rank policy (PCA)",
            "Plasticity loss policy",
        ],
        ["shared-trunk", "optimizer", "regularize-all-layers", "regularize", "baselines"],
        save_name=save_as,
    )

In [None]:
# Figure 6 box plots with Tanh networks.
for _env in ("Humanoid-v4", "Hopper-v4"):
    figure6_mujoco(_env, "Tanh")

## ReLU

In [None]:
# Figure 6 box plots with ReLU networks.
for _env in ("Humanoid-v4", "Hopper-v4"):
    figure6_mujoco(_env, "ReLU")

## Correlation between the five feature-rank metrics

In [None]:
# The five rank estimators of the policy features, logged at the end of each batch.
metrics_keys2 = [
    "batch/end/SVD/effective_rank_vetterli/features_policy_batch",
    "batch/end/SVD/approximate_rank_pca/features_policy_batch",
    "batch/end/SVD/srank_kumar/features_policy_batch",
    "batch/end/SVD/feature_rank_lyle/features_policy_batch",
    "batch/end/SVD/pytorch_rank/features_policy_batch",
]


def _baseline_rank_criteria(activation):
    """Filter for non-annealed baseline runs using ``activation``."""
    return {
        ("config", "working_dir"): "baselines",
        ("config", "models", "activation"): activation,
        ("config", "optim", "anneal_linearly"): [False],
    }


criteria_tanh = _baseline_rank_criteria("Tanh")
criteria_relu = _baseline_rank_criteria("ReLU")

filtered_df_tanh = filter_df_by_criteria(mujoco_df, criteria_tanh)
filtered_df_relu = filter_df_by_criteria(mujoco_df, criteria_relu)

correlations_tanh = average_correlations_ranks(filtered_df_tanh, metrics_keys2)
correlations_relu = average_correlations_ranks(filtered_df_relu, metrics_keys2)
plot_correlation_matrices(correlations_tanh, metrics_keys2, "avg_mujoco_tanh")
plot_correlation_matrices(correlations_relu, metrics_keys2, "avg_mujoco_relu")

In [None]:
# Worst-case (per-run minimum) rank correlations, per activation.
worst_correlations_score_tanh = worst_correlations(filtered_df_tanh, metrics_keys2)
worst_correlations_score_relu = worst_correlations(filtered_df_relu, metrics_keys2)
for _scores, _tag in (
    (worst_correlations_score_tanh, "worst_mujoco_tanh"),
    (worst_correlations_score_relu, "worst_mujoco_relu"),
):
    plot_correlation_matrices(_scores, metrics_keys2, _tag)