## Define Per-Set EER Results (Dev and Test Splits)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import pandas as pd
from matplotlib.path import Path
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D

In [None]:
# Model names with proper LaTeX formatting.
# These strings are used verbatim as legend labels of the spider plot;
# "$_{...}$" is a math-mode subscript.
model_audio = "e$_{\\mathrm{audio}}$"    # Using \mathrm instead of \text (\text needs the amsmath package)
model_text = "e$_{\\mathrm{text}}$"
# model_fusion = "e$_{\\mathrm{fusion}}$"
model_fusion = "VoxATtack"
e_audio_only = "ECAPA$_{\\mathrm{ours}}$"

# Dev-split scores per model, keyed by evaluation condition ('B3' ... 'LibriSpeech').
# NOTE(review): the values look like EER in percent (the plotting cell scales
# them by 0.01 and the plot section is titled "Average EER") -- confirm.
# plot_spider_chart takes its axis order from the first model's dict and looks
# up every other model by those keys, so every model needs the same key set.
# The model_audio / model_text entries are commented out, i.e. excluded from
# the plot; model_text holds the same value for every condition.
results_dev = {
    # model_audio: {'B3': 23.53, 'B4': 24.35, 'B5': 29.57, 'T8-5': 32.34, 'T10-2': 25.91, 
    #               'T12-5': 30.61, 'T25-1': 33.77, 'LibriSpeech': 2.83},
    # model_text: {'B3': 37.00, 'B4': 37.00, 'B5': 37.00, 'T8-5': 37.00, 'T10-2': 37.00, 
    #              'T12-5': 37.00, 'T25-1': 37.00, 'LibriSpeech': 37.00},
    model_fusion: {'B3': 22.95, 'B4': 24.75, 'B5': 28.17, 'T8-5': 26.00, 'T10-2': 23.01, 
                   'T12-5': 28.69, 'T25-1': 30.21, 'LibriSpeech': 4.93},
    e_audio_only: {'B3': 23.17, 'B4': 23.26, 'B5': 28.74, 'T8-5': 26.79, 'T10-2': 23.67,
                   'T12-5': 29.09, 'T25-1': 31.59, 'LibriSpeech': 3.47},
}


# Test-split scores, same models and conditions as results_dev.
# average_results() requires this dict to contain every model and key of results_dev.
results_test = {
    # model_audio: {'B3': 21.17, 'B4': 20.32, 'B5': 24.79, 'T8-5': 33.85, 'T10-2': 23.03, 
    #               'T12-5': 24.95, 'T25-1': 28.85, 'LibriSpeech': 3.81},
    # model_text: {'B3': 34.57, 'B4': 34.57, 'B5': 34.57, 'T8-5': 34.57, 'T10-2': 34.57, 
    #              'T12-5': 34.57, 'T25-1': 34.57, 'LibriSpeech': 34.57},
    model_fusion: {'B3': 20.05 , 'B4': 19.13 , 'B5': 24.02, 'T8-5': 23.18, 'T10-2': 19.83,
                   'T12-5': 24.71, 'T25-1': 27.51, 'LibriSpeech': 3.39},
    e_audio_only: {'B3': 19.94, 'B4': 20.26, 'B5': 24.80, 'T8-5': 27.65, 'T10-2': 19.62,
                   'T12-5': 25.46, 'T25-1': 28.64, 'LibriSpeech': 4.2},
}  


In [None]:
def scale_results(results: dict, factor: float = 0.01) -> dict:
    """
    Return a copy of a two-level results dict with every leaf value scaled.

    Args:
        results: Mapping of the form {category: {key: value, ...}, ...}.
        factor: Multiplier applied to each leaf value (default 0.01).

    Returns:
        A new dict with the same categories and keys, in the same order, whose
        leaf values are ``value * factor``. The input is left untouched.
    """
    scaled = {}
    for category in results:
        inner = results[category]
        scaled[category] = dict((key, inner[key] * factor) for key in inner)
    return scaled


def average_results(results_dev: dict, results_test: dict) -> dict:
    """
    Compute the element-wise average of two nested dicts with identical structure.

    Iteration, and therefore the ordering of the output, follows ``results_dev``.
    Categories or keys that exist only in ``results_test`` are ignored.

    Args:
        results_dev: First dict of the form {category: {key: value, ...}, ...}.
        results_test: Second dict; must contain every category and key that
            ``results_dev`` contains.

    Returns:
        A new dict where each value is (results_dev + results_test) / 2.

    Raises:
        KeyError: If a category of ``results_dev`` is missing from
            ``results_test``, or a key is missing from the matching sub-dict.
    """
    averaged = {}
    for category, dev_sub in results_dev.items():
        # Check the category explicitly so the error names what is actually
        # missing (and an empty dev sub-dict cannot mask a missing category).
        if category not in results_test:
            raise KeyError(f"Category {category!r} not found in results_test")
        test_sub = results_test[category]
        averaged[category] = {}
        for k, dev_val in dev_sub.items():
            # Membership test rather than `.get(k) is None`, so a missing key is
            # not conflated with a stored None value.
            if k not in test_sub:
                raise KeyError(f"Key {k!r} not found in results_test[{category!r}]")
            averaged[category][k] = (dev_val + test_sub[k]) / 2.0
    return averaged

## Spider Plot of EER Averaged over the Dev and Test Splits

In [None]:
def _spider_axis_angles(categories):
    """Return one polar angle (radians, in [0, 2*pi)) per category.

    The axes are evenly spaced. If 'LibriSpeech' is among the categories, the
    whole set is rotated so that this axis points straight down (3*pi/2).
    """
    angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False)
    if 'LibriSpeech' in categories:
        shift = 3 * np.pi / 2 - angles[categories.index('LibriSpeech')]
        angles = (angles + shift) % (2 * np.pi)
    return angles.tolist()


def _spider_label_alignment(angle, atol=1e-9):
    """Return the horizontal alignment for a tick label placed at `angle` (radians).

    Labels on the right half align 'left' and labels on the left half align
    'right', so text grows away from the plot. Labels on the vertical axis
    (top/bottom) are centred. A tolerance is used because rotated angles such
    as 3*pi/2 are not exact in floating point.
    """
    x = np.cos(angle)
    if abs(x) < atol:
        return 'center'
    return 'left' if x > 0 else 'right'


def plot_spider_chart(results_dict, output_path='model_comparison.pdf', figsize=(10, 8), dpi=300, use_latex=True,
                      rlim=(0, 0.32), rticks=(0, 0.1, 0.2, 0.3)):
    """
    Generate a publication-ready spider plot for model comparison with custom colors.

    Every model is drawn as a closed polygon over the same category axes. Each
    polygon gets its own color, line style and marker, plus a translucent fill.
    The figure is written to ``output_path`` and also returned.

    Parameters:
    -----------
    results_dict : dict
        Dictionary with model names as keys and results as values.
        Format: {'model_name_1': {'B3': X.X, 'T8-5': X.X, ...}, ...}
        The axis order is taken from the first model. Every model must provide
        a value for each of the first model's categories. Model names are used
        verbatim as legend labels.
    output_path : str or path-like, optional
        Path for saving the output figure. The format is taken from the file
        extension when matplotlib supports it; otherwise a PDF is written.
    figsize : tuple, optional
        Figure size in inches.
    dpi : int, optional
        Resolution for the output figure.
    use_latex : bool, optional
        Whether to use LaTeX for text rendering. Default is True. If no
        ``latex`` executable is found on PATH, a warning is printed and
        matplotlib's built-in mathtext is used instead.
    rlim : tuple of float, optional
        (min, max) limits of the radial axis. Default (0, 0.32).
    rticks : sequence of float, optional
        Radial tick positions. Default (0, 0.1, 0.2, 0.3).

    Returns:
    --------
    (fig, ax) : the matplotlib Figure and its polar Axes.

    Raises:
    -------
    ValueError
        If ``results_dict`` is empty.
    KeyError
        If a model lacks one of the first model's categories.

    Notes:
    ------
    When ``use_latex`` is True, the style settings are written to the global
    ``plt.rcParams`` and remain in effect after this function returns.
    """
    import pathlib
    import shutil

    from matplotlib.backend_bases import FigureCanvasBase

    if not results_dict:
        raise ValueError("results_dict must contain at least one model")

    # Configure LaTeX if requested
    if use_latex:
        # A missing LaTeX installation does not make rcParams.update() fail.
        # It only surfaces later, as an error while rendering text (e.g. in
        # savefig). Probe for the executable up front and fall back to mathtext.
        # (Raster backends additionally need dvipng, which is not checked here.)
        usetex = shutil.which("latex") is not None
        if not usetex:
            print("Warning: LaTeX executable not found; falling back to matplotlib mathtext.")
        plt.rcParams.update({
            "text.usetex": usetex,
            "text.latex.preamble": r"\usepackage{amsmath}",  # amsmath provides the \text command
            # NOTE(review): under usetex, matplotlib only maps LaTeX-supported font
            # names (e.g. 'Times'); these entries may fall back to Computer Modern.
            "font.family": "serif",
            "font.serif": ["Times New Roman", "DejaVu Serif"],
            "font.size": 26,
            "axes.labelsize": 26,
            "axes.titlesize": 28,
            "xtick.labelsize": 24,
            "ytick.labelsize": 24,
            "legend.fontsize": 24,
            "figure.figsize": (21, 12),
            "figure.dpi": 300,
            "savefig.dpi": 300,
            "savefig.bbox": "tight",
            "savefig.pad_inches": 0.05,
        })

    # Axis order comes from the first model; the other models are looked up by key.
    categories = list(next(iter(results_dict.values())).keys())
    df = pd.DataFrame(
        {model: [results[cat] for cat in categories]
         for model, results in results_dict.items()},
        index=categories,
    )

    angles = _spider_axis_angles(categories)
    # Repeat the first angle so each polygon closes on itself.
    closed_angles = angles + angles[:1]

    # Create polar plot
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, polar=True)

    # Set axis labels
    ax.set_xticks(angles)
    ax.set_xticklabels(categories)
    for label, angle in zip(ax.get_xticklabels(), angles):
        label.set_horizontalalignment(_spider_label_alignment(angle))
    ax.set_rlabel_position(67.5)  # radial tick labels drawn along the 67.5-degree direction
    ax.set_ylim(*rlim)
    ax.set_rticks(list(rticks))
    ax.grid(True)

    # Custom colors for the first models; seaborn's tab10 palette for any others.
    n_models = len(results_dict)
    base_custom = ['mediumpurple', 'darkorange']
    fallback = sns.color_palette("tab10", n_models)
    colors = [base_custom[i] if i < len(base_custom) else fallback[i] for i in range(n_models)]
    linestyles = ['-', '--', '-.', ':']
    markers = ['o', 's', 'D', '^', 'v']

    # Plot each model
    for i, model_name in enumerate(results_dict):
        values = df[model_name].tolist()
        values += values[:1]  # close the polygon
        ax.plot(
            closed_angles, values,
            linestyle=linestyles[i % len(linestyles)],
            marker=markers[i % len(markers)],
            color=colors[i],
            linewidth=1,
            markersize=8,
            label=model_name,
        )
        ax.fill(closed_angles, values, color=colors[i], alpha=0.1)

    # Anchor the legend's lower-right corner at the axes' lower-left corner,
    # i.e. just outside the plot area.
    ax.legend(loc='lower right', bbox_to_anchor=(0.0, 0.0))
    fig.tight_layout()

    # Honour the file extension when matplotlib supports it. Otherwise keep the
    # previous behaviour of writing a PDF.
    suffix = pathlib.PurePath(output_path).suffix.lstrip('.').lower()
    fmt = suffix if suffix in FigureCanvasBase.get_supported_filetypes() else 'pdf'
    fig.savefig(output_path, format=fmt, dpi=dpi, bbox_inches='tight')
    print(f"Spider plot saved to {output_path}")
    return fig, ax

In [None]:
# Average each model's per-condition scores over the dev and test splits,
# scale them by 0.01 (presumably percent -> fraction) and draw the combined
# spider plot. Per-split variants:
#   plot_spider_chart(scale_results(results_dev), output_path='model_comparison_dev.pdf', figsize=(10, 8))
#   plot_spider_chart(scale_results(results_test), output_path='model_comparison_test.pdf', figsize=(10, 8))
results = average_results(results_dev=results_dev, results_test=results_test)
fig, ax = plot_spider_chart(
    results_dict=scale_results(results, factor=0.01),
    output_path='model_comparison.pdf',
    figsize=(10, 8),
)