In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from matplotlib.lines import Line2D

from climatebenchpress.compressor.plotting.plot_metrics import (
    _rename_compressors,
    _get_legend_name,
    _COMPRESSOR_ORDER,
)

# Process results

In [2]:
results_file = "metrics/all_results.csv"
df = pd.read_csv(results_file)

In [3]:
def create_data_matrix(
    df: pd.DataFrame,
    error_bound: str,
    metrics: list[str] = [
        "DSSIM",
        "MAE",
        "Max Absolute Error",
        "Spectral Error",
        "Compression Ratio [raw B / enc B]",
        "Satisfies Bound (Value)",
    ],
):
    df_filtered = df[df["Error Bound Name"] == error_bound].copy()
    df_filtered["Satisfies Bound (Value)"] = (
        df_filtered["Satisfies Bound (Value)"] * 100
    )  # Convert to percentage

    # Get unique variables and compressors
    # dataset_variables = sorted(df_filtered[['Dataset', 'Variable']].drop_duplicates().apply(lambda x: "/".join(x), axis=1).unique())
    dataset_variables = sorted(df_filtered["Variable"].unique())
    compressors = sorted(
        df_filtered["Compressor"].unique(),
        key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)),
    )

    column_labels = []
    for metric in metrics:
        for dataset_variable in dataset_variables:
            column_labels.append(f"{dataset_variable}\n{metric}")

    # Initialize the data matrix
    data_matrix = np.full((len(compressors), len(column_labels)), np.nan)

    # Fill the matrix with data
    for i, compressor in enumerate(compressors):
        for j, metric in enumerate(metrics):
            for k, dataset_variable in enumerate(dataset_variables):
                # Get data for this compressor-variable combination
                # dataset, variable = dataset_variable.split('/')
                variable = dataset_variable
                subset = df_filtered[
                    (df_filtered["Compressor"] == compressor)
                    & (df_filtered["Variable"] == variable)  # &
                    # (df_filtered['Dataset'] == dataset)
                ]
                if subset.empty:
                    print(f"No data for Compressor: {compressor}, Variable: {variable}")
                    continue

                if metric in ["DSSIM", "Spectral Error"] and variable in ["ta", "tos"]:
                    # These variables have large regions of NaN values which makes the
                    # DSSIM and Spectral Error values unreliable.
                    continue

                col_idx = j * len(dataset_variables) + k
                if metric in subset.columns:
                    values = subset[metric]
                    if len(values) == 1:
                        data_matrix[i, col_idx] = values.iloc[0]

    return data_matrix, compressors, dataset_variables

In [4]:
df = df[
    ~df["Compressor"].isin(
        [
            "bitround",
            "jpeg2000-conservative-abs",
            "stochround-conservative-abs",
            "stochround-pco-conservative-abs",
            "zfp-conservative-abs",
            "bitround-conservative-rel",
            "stochround-pco",
            "stochround",
            "zfp",
            "jpeg2000",
        ]
    )
]
df = df[~df["Dataset"].str.contains("-tiny")]
df = df[~df["Dataset"].str.contains("-chunked")]
df = _rename_compressors(df)

In [5]:
metrics = [
    "DSSIM",
    "MAE",
    "Max Absolute Error",
    "Spectral Error",
    "Compression Ratio [raw B / enc B]",
    "Satisfies Bound (Value)",
]
scorecard_data = {}
for error_bound in ["low", "mid", "high"]:
    scorecard_data[error_bound] = create_data_matrix(df, error_bound, metrics)

No data for Compressor: sperr, Variable: pr
No data for Compressor: sperr, Variable: ta
No data for Compressor: sperr, Variable: tos
No data for Compressor: sperr, Variable: pr
No data for Compressor: sperr, Variable: ta
No data for Compressor: sperr, Variable: tos
No data for Compressor: sperr, Variable: pr
No data for Compressor: sperr, Variable: ta
No data for Compressor: sperr, Variable: tos
No data for Compressor: sperr, Variable: pr
No data for Compressor: sperr, Variable: ta
No data for Compressor: sperr, Variable: tos
No data for Compressor: sperr, Variable: pr
No data for Compressor: sperr, Variable: ta
No data for Compressor: sperr, Variable: tos
No data for Compressor: sperr, Variable: pr
No data for Compressor: sperr, Variable: ta
No data for Compressor: sperr, Variable: tos
No data for Compressor: safeguarded-sperr, Variable: pr
No data for Compressor: safeguarded-sperr, Variable: pr
No data for Compressor: safeguarded-sperr, Variable: pr
No data for Compressor: safeguarde

# Scorecard

In [6]:
METRICS2NAME = {
    "DSSIM": "dSSIM",
    "MAE": "Mean Absolute Error",
    "Compression Ratio [raw B / enc B]": "Compression Ratio",
    "Satisfies Bound (Value)": r"% of Data Points Violating the Error Bound",
}

VARIABLE2NAME = {
    "10m_u_component_of_wind": "10u",
    "10m_v_component_of_wind": "10v",
    "mean_sea_level_pressure": "msl",
}

DATASET2PREFIX = {
    "era5-hurricane": "h-",
}


def get_variable_label(variable):
    dataset, var_name = variable.split("/")
    prefix = DATASET2PREFIX.get(dataset, "")
    var_name = VARIABLE2NAME.get(var_name, var_name)
    return f"{prefix}{var_name}"


def create_compression_scorecard(
    data_matrix,
    compressors,
    variables,
    metrics,
    cbar=True,
    ref_compressor="sz3",
    higher_better_metrics=["DSSIM", "Compression Ratio [raw B / enc B]"],
    save_fn=None,
    compare_against_0=False,
    highlight_bigger_than_one=False,
):
    """
    Create a scorecard plot similar to the weather forecasting example

    Parameters:
    - data_matrix: 2D array with compressors as rows, metric-variable combinations as columns
    - compressors: list of compressor names
    - variables: list of variable names
    - metrics: list of metric names
    - ref_compressor: reference compressor for relative calculations
    - save_fn: filename to save plot (optional)
    """

    # Calculate relative differences vs reference compressor
    ref_idx = compressors.index(ref_compressor)
    ref_values = data_matrix[ref_idx, :]
    if compare_against_0:
        ref_values = np.zeros_like(data_matrix[ref_idx, :])

    relative_matrix = np.full_like(data_matrix, np.nan)
    if highlight_bigger_than_one:
        relative_matrix = np.sign(data_matrix) * 101
        for j in range(data_matrix.shape[1]):
            if metrics[j // len(variables)] == "Satisfies Bound (Value)":
                # For bound satisfication lower is better (less number of pixels exceeding error bound).
                relative_matrix[:, j] = -1 * relative_matrix[:, j]
    else:
        for i in range(len(compressors)):
            for j in range(data_matrix.shape[1]):
                if not np.isnan(data_matrix[i, j]) and not np.isnan(ref_values[j]):
                    ref_val = np.abs(ref_values[j])
                    if ref_val == 0.0:
                        ref_val = 1e-10  # Avoid division by zero
                    if metrics[j // len(variables)] in higher_better_metrics:
                        # Higher is better metrics
                        relative_matrix[i, j] = (
                            (ref_values[j] - data_matrix[i, j]) / ref_val * 100
                        )
                    elif metrics[j // len(variables)] == "Satisfies Bound (Value)":
                        relative_matrix[i, j] = 100 if data_matrix[i, j] != 0 else 0
                    else:
                        relative_matrix[i, j] = (
                            (data_matrix[i, j] - ref_values[j]) / ref_val * 100
                        )

    # Set up colormap - similar to original
    reds = sns.color_palette("Reds", 6)
    blues = sns.color_palette("Blues_r", 6)
    cmap = mpl.colors.ListedColormap(blues + [(0.95, 0.95, 0.95)] + reds)
    # cb_levels = [-50, -20, -10, -5, -2, -1, 1, 2, 5, 10, 20, 50]
    # cb_levels = [-75, -50, -25, -10, -5, -1, 1, 5, 10, 25, 50, 75]
    cb_levels = [-100, -75, -50, -25, -10, -1, 1, 10, 25, 50, 75, 100]

    norm = mpl.colors.BoundaryNorm(cb_levels, cmap.N, extend="both")

    # Calculate figure dimensions
    ncompressors = len(compressors)
    nvariables = len(variables)
    nmetrics = len(metrics)

    panel_width = (2.5 / 5) * nvariables
    label_width = 1.5 * panel_width
    padding_right = 0.1
    panel_height = panel_width / nvariables

    title_height = panel_height * 1.25
    cbar_height = panel_height * 2
    spacing_height = panel_height * 0.1
    spacing_width = panel_height * 0.2

    total_width = (
        label_width
        + nmetrics * panel_width
        + (nmetrics - 1) * spacing_width
        + padding_right
    )
    total_height = (
        title_height
        + cbar_height
        + ncompressors * panel_height
        + (ncompressors - 1) * spacing_height
    )

    # Create figure and gridspec
    fig = plt.figure(figsize=(total_width, total_height))
    gs = mpl.gridspec.GridSpec(
        ncompressors,
        nmetrics,
        figure=fig,
        left=label_width / total_width,
        right=1 - padding_right / total_width,
        top=1 - (title_height / total_height),
        bottom=cbar_height / total_height,
        hspace=spacing_height / panel_height,
        wspace=spacing_width / panel_width,
    )

    # Plot each panel
    for row, compressor in enumerate(compressors):
        for col, metric in enumerate(metrics):
            ax = fig.add_subplot(gs[row, col])

            # Get data for this metric (all variables)
            start_col = col * nvariables
            end_col = start_col + nvariables

            rel_values = relative_matrix[row, start_col:end_col].reshape(1, -1)
            abs_values = data_matrix[row, start_col:end_col]

            # Create heatmap
            img = ax.imshow(rel_values, aspect="auto", cmap=cmap, norm=norm)

            # Customize axes
            ax.set_xticks([])
            ax.set_xticklabels([])
            ax.set_yticks([])
            ax.set_yticklabels([])

            # Add white grid lines
            for i in range(1, nvariables):
                rect = mpl.patches.Rectangle(
                    (i - 0.5, -0.5),
                    1,
                    1,
                    linewidth=1,
                    edgecolor="lightgrey"
                    if np.isnan(abs_values[i]) and np.isnan(abs_values[i - 1])
                    else "white",
                    facecolor="none",
                )
                ax.add_patch(rect)

            # Add absolute values as text
            for i, val in enumerate(abs_values):
                # Ensure we don't have black text on dark background
                color = "black" if abs(rel_values[0, i]) < 75 else "white"
                fontsize = 10
                # Format numbers appropriately
                if metric in ["DSSIM", "Spectral Error"] and variables[i] in [
                    "ta",
                    "tos",
                ]:
                    # These variables have large regions of NaN values which makes the
                    # DSSIM and Spectral Error values unreliable.
                    text = "N/A"
                    color = "black"
                elif np.isnan(val):
                    text = "Crash"
                    color = "black"
                elif abs(val) > 10_000:
                    text = f"{val:.1e}"
                    fontsize = 8
                elif abs(val) > 10:
                    text = f"{val:.0f}"
                elif abs(val) > 1:
                    text = f"{val:.1f}"
                elif val == 1 and metric == "DSSIM":
                    text = "1"
                elif val == 0:
                    text = "0"
                elif abs(val) < 0.01:
                    text = f"{val:.1e}"
                    fontsize = 8
                else:
                    text = f"{val:.2f}"
                ax.text(
                    i, 0, text, ha="center", va="center", fontsize=fontsize, color=color
                )

                if (
                    row > 0
                    and np.isnan(val)
                    and np.isnan(data_matrix[row - 1, col * nvariables + i])
                    and compressor == f"safeguarded-{compressors[row - 1]}"
                    and not (
                        metric in ["DSSIM", "Spectral Error"]
                        and variables[i]
                        in [
                            "ta",
                            "tos",
                        ]
                    )
                ):
                    ax.annotate(
                        "",
                        xy=(i, -0.15),
                        xytext=(i, -0.9),
                        arrowprops=dict(arrowstyle="->", lw=2, color="lightgrey"),
                    )

            # Add row labels (compressor names)
            if col == 0:
                ax.set_ylabel(
                    _get_legend_name(compressor),
                    rotation=0,
                    ha="right",
                    va="center",
                    labelpad=10,
                    fontsize=14,
                )

            # Add column titles (variable names)
            if row == 0:
                # ax.set_title(VARIABLE2NAME.get(variable, variable), fontsize=10, pad=10)
                ax.set_title(METRICS2NAME.get(metric, metric), fontsize=16, pad=10)

            # Add metric labels at the top on the top row
            if row == 0:
                # ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                # ax.set_xticks(range(nmetrics))
                # ax.set_xticklabels(
                #     [METRICS2NAME.get(m, m) for m in metrics],
                #     rotation=45,
                #     ha='left', fontsize=8)
                ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                ax.set_xticks(range(nvariables))
                ax.set_xticklabels(
                    [VARIABLE2NAME.get(v, v) for v in variables],
                    rotation=45,
                    ha="left",
                    fontsize=12,
                )

            # Style spines
            for spine in ax.spines.values():
                spine.set_color("0.7")

    # Add colorbar
    if cbar and not highlight_bigger_than_one:
        rel_cbar_height = cbar_height / total_height
        cax = fig.add_axes((0.4, rel_cbar_height * 0.3, 0.5, rel_cbar_height * 0.2))
        cb = fig.colorbar(img, cax=cax, orientation="horizontal")
        cb.ax.set_xticks(cb_levels)
        if highlight_bigger_than_one:
            cb.ax.set_xlabel("Better ← |non-chunked - chunked| → Worse", fontsize=16)
        else:
            cb.ax.set_xlabel(
                f"Better ← % difference vs {_get_legend_name(ref_compressor)} → Worse",
                fontsize=16,
            )

    if highlight_bigger_than_one:
        chunking_handles = [
            Line2D(
                [],
                [],
                marker="s",
                color=cmap(101),
                linestyle="None",
                markersize=10,
                label="Not Chunked Better",
            ),
            Line2D(
                [],
                [],
                marker="s",
                color=cmap(-101),
                linestyle="None",
                markersize=10,
                label="Chunked Better",
            ),
        ]

        ax.legend(
            handles=chunking_handles,
            loc="upper left",
            ncol=2,
            bbox_to_anchor=(-0.5, -0.05),
            fontsize=16,
        )

    # plt.tight_layout()

    if save_fn:
        plt.savefig(save_fn, dpi=300, bbox_inches="tight")
        plt.close()
    else:
        plt.show()

In [7]:
for bound_name, (data_matrix, compressors, variables) in scorecard_data.items():
    print(f"Creating scorecard for {bound_name} bound...")
    # Split into two rows for better readability.
    create_compression_scorecard(
        data_matrix[:, : 3 * len(variables)],
        compressors,
        variables,
        metrics[:3],
        ref_compressor="bitround-pco",
        cbar=False,
        save_fn=f"scorecards/{bound_name}_scorecard_row1.pdf",
    )

    create_compression_scorecard(
        data_matrix[:, 3 * len(variables) :],
        compressors,
        variables,
        metrics[3:],
        ref_compressor="bitround-pco",
        save_fn=f"scorecards/{bound_name}_scorecard_row2.pdf",
    )

Creating scorecard for low bound...
Creating scorecard for mid bound...
Creating scorecard for high bound...
