In [None]:
import joblib
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

In [None]:
# Plot titles keyed by the approach identifiers found in the correlation files.
# The six regularised variants share one naming scheme
# ('<family>-<suffix>' -> '<family> + <label>'), so they are generated from a
# single suffix table. Resulting keys, in order:
# BS, ER-A, ER-R, ER-IxG, ER-C-A, ER-C-R, ER-C-IxG
title_map = {
    'BS': 'Baseline',
    **{
        f'{family}-{suffix}': f'{family} + {label}'
        for family in ('ER', 'ER-C')
        for suffix, label in (('A', 'Att'), ('R', 'AttR'), ('IxG', 'IxG'))
    },
}

# Short display names for the attribution techniques, used for tick labels and
# column titles in the figures below.
techniques_map = dict(
    attentions='Att',
    rollout='AttR',
    IxG='IxG',
    alti_aggregated='ALTI',
    decompx='DX',
    decompx_classifier='DX-C',
)

### Approaches

In [None]:
def load_data_approaches_averages(corr_fn, ds_name, techniques_to_plot):
    """Load a correlations file and average every technique-pair entry per approach.

    Reads ``correlations_{ds_name}_{corr_fn}.joblib`` (path relative to the
    working directory) and uses its ``'approaches'`` and ``'techniques'``
    entries.

    Args:
        corr_fn: Correlation name used in the file name (e.g. ``'kendall'``).
        ds_name: Dataset name used in the file name (e.g. ``'movies'``).
        techniques_to_plot: Techniques to keep as the first level below each
            approach.

    Returns:
        dict: ``{approach: {technique_1: {technique_2: mean}}}`` where
        ``technique_1`` ranges over ``techniques_to_plot`` and ``technique_2``
        over the keys of the file's ``'techniques'`` entry (in that order) that
        are present for the given approach and ``technique_1``. ``mean`` is
        ``np.mean`` of the stored values.

    Raises:
        KeyError: If an entry of ``techniques_to_plot`` is missing for an
            approach.
    """
    # NOTE: joblib.load unpickles the file; only use it on trusted artefacts.
    data = joblib.load(f"correlations_{ds_name}_{corr_fn}.joblib")
    technique_names = list(data['techniques'])

    average_correlations = {}
    for approach, per_technique in data['approaches'].items():
        average_correlations[approach] = {
            technique_1: {
                technique_2: np.mean(per_technique[technique_1][technique_2])
                for technique_2 in technique_names
                if technique_2 in per_technique[technique_1]
            }
            for technique_1 in techniques_to_plot
        }

    return average_correlations

In [None]:
# The plotting context and the row techniques are the same for every dataset,
# so they are set once instead of on every loop iteration.
sns.set_context("paper", rc={"font.size": 16, "axes.titlesize": 14, "axes.labelsize": 14, "xtick.labelsize": 14, "ytick.labelsize": 14})
techniques_to_plot = ['attentions', 'rollout', 'IxG']

for corr_fn, ds_name in [('kendall', 'sst_dev'), ('kendall', 'movies'), ('kendall', 'yelp-50')]:

    # Load data: {approach: {row technique: {column technique: mean correlation}}}
    approaches_correlations = load_data_approaches_averages(corr_fn, ds_name, techniques_to_plot)

    # Plot one heatmap per approach, side by side; "movies" uses a smaller canvas
    figsize = (10, 4) if ds_name == "movies" else (17, 8)
    fig, axs = plt.subplots(1, len(approaches_correlations), figsize=figsize, sharey=False)
    for i, (approach, correlations) in enumerate(approaches_correlations.items()):
        # pd.DataFrame puts the outer keys on the columns; transpose so that the
        # techniques_to_plot entries become the rows, and scale to percentages.
        df = (100 * pd.DataFrame(correlations)).T
        hm = sns.heatmap(
            df,
            annot=True,
            cmap='Blues',
            square=True,
            ax=axs[i],
            vmin=-10,
            vmax=100,
            cbar=False,
            fmt=".0f",
            linewidth=0.01,
            annot_kws={"size": 12}
        )
        axs[i].set_title(title_map[approach])

        # Short display names on the axes; row labels only on the left-most panel
        hm.set_xticklabels([techniques_map[label.get_text()] for label in hm.get_xticklabels()])
        if i == 0:
            hm.set_yticklabels([techniques_map[label.get_text()] for label in hm.get_yticklabels()])
        else:
            hm.set_yticklabels([])

    #fig.tight_layout()
    #fig.savefig(f"../figures/sa_attributions_correlations_techniques_{corr_fn}_{ds_name}.pdf", dpi=500, bbox_inches='tight')

### Techniques

In [None]:
# Correlation name used by the cells below; it fills the `{corr_fn}` slot of the
# `correlations_{ds_name}_{corr_fn}.joblib` file name when data is loaded.
corr_fn = 'kendall'

In [None]:
# Hatch pattern drawn on the histogram bars (and legend patches) of each approach.
hatches = {
    'ER-A': '//',
    'ER-R': '\\\\',
    'ER-IxG': '--',
    'ER-C-A': '..',
    'ER-C-R': 'xx',
    'ER-C-IxG': 'oo',
}

# Up to ten colours from seaborn's "colorblind" palette; the plotting cells pick
# one by the position of the approach in `approaches_to_plot`.
colors = sns.color_palette("colorblind")[:10]

In [None]:
def load_data_techniques(corr_fn, ds_name):
    """Return the ``'techniques'`` entry of a stored correlations file.

    Args:
        corr_fn: Correlation name used in the file name (e.g. ``'kendall'``).
        ds_name: Dataset name used in the file name (e.g. ``'yelp-50'``).

    Returns:
        Whatever is stored under ``'techniques'`` in
        ``correlations_{ds_name}_{corr_fn}.joblib``. The plotting cells index
        it as ``[technique]['BS'][approach]``.
    """
    # NOTE: joblib.load unpickles the file; only use it on trusted artefacts.
    return joblib.load(f"correlations_{ds_name}_{corr_fn}.joblib")['techniques']

**Main Text**

In [None]:
sns.set_style('darkgrid')

approaches_to_plot = ['ER-A', 'ER-C-A', 'ER-C-R', 'ER-C-IxG']
techniques_to_plot_summarized = ['attentions', 'alti_aggregated', 'decompx_classifier']

ds_name = 'yelp-50'
data_techniques = load_data_techniques(corr_fn, ds_name)

# Grid of histograms: one row per approach, one column per technique
fig, axs = plt.subplots(len(approaches_to_plot), len(techniques_to_plot_summarized), figsize=(9, 6), sharey=True, sharex=True)
for i, approach in enumerate(approaches_to_plot):
    for j, technique in enumerate(techniques_to_plot_summarized):
        ax = axs[i, j]
        kwargs = {'ax': ax, 'alpha': 0.6, 'bins': 10}

        # Gray reference histogram ("BS vs BS")
        sns.histplot(data_techniques[technique]['BS']['BS'], color='gray', **kwargs)
        # Remember how many bars the reference histogram drew, so that exactly
        # the bars added by the next call get hatched (no assumption that both
        # histograms produce the same number of bars).
        n_reference_bars = len(ax.patches)
        sns.histplot(data_techniques[technique]['BS'][approach], color=colors[i], **kwargs)

        # Apply the approach hatch pattern to the second histogram only
        for bar in ax.patches[n_reference_bars:]:
            bar.set_hatch(hatches[approach])

        # Set title and labels
        axs[0, j].set_title(techniques_map[technique])
        axs[-1, j].set_xlabel("Correlation")


# Build the legend from the same list, colours and hatches that drive the plot,
# so the entries cannot drift from what is drawn.
legend_entries = [mpatches.Patch(facecolor='gray', edgecolor='white', label='BS vs BS')]
legend_entries += [
    mpatches.Patch(facecolor=colors[i], edgecolor='white', hatch=hatches[approach], label=f"BS vs {title_map[approach]}")
    for i, approach in enumerate(approaches_to_plot)
]

# Anchor the legend explicitly to the bottom-right subplot (the axes that
# plt.legend would have picked as the current one); the offsets are in that
# subplot's axes coordinates and place the legend below the grid.
axs[-1, -1].legend(handles=legend_entries, loc='upper center', bbox_to_anchor=(-0.8, -0.6), ncol=3, fontsize=14, frameon=False)

#fig.savefig(f"../figures/sa_attributions_correlations_approaches_{ds_name}_{corr_fn}.pdf", dpi=72, bbox_inches='tight')

**Appendix**

In [None]:
sns.set_style('darkgrid')

approaches_to_plot = ['ER-A', 'ER-R', 'ER-IxG', 'ER-C-A', 'ER-C-R', 'ER-C-IxG']
techniques_to_plot = ['attentions', 'rollout', 'IxG', 'alti_aggregated', 'decompx', 'decompx_classifier']

# Load data
corr_fn = 'kendall'
# NOTE(review): ds_name ('sst-dev') only feeds the commented-out output file
# name at the end of this cell, while the data is loaded with the key 'sst_dev'.
# Confirm that the hyphenated figure name is intended, or use one spelling.
ds_name = 'sst-dev'
data_techniques = load_data_techniques(corr_fn, 'sst_dev')

# Plot: one row per approach, one column per technique
fig, axs = plt.subplots(len(approaches_to_plot), len(techniques_to_plot), figsize=(2.5*len(techniques_to_plot), 8), sharey=True, sharex=True)
for i, approach in enumerate(approaches_to_plot):
    for j, technique in enumerate(techniques_to_plot):
        ax = axs[i, j]
        # Look the two series up once; they are used for the bars and the means
        reference_corrs = data_techniques[technique]['BS']['BS']
        approach_corrs = data_techniques[technique]['BS'][approach]

        kwargs = {'ax': ax, 'alpha': 0.6, 'bins': 10}
        sns.histplot(reference_corrs, color='gray', **kwargs)
        # Remember how many bars the gray histogram drew, so that exactly the
        # bars added by the next call get hatched (no assumption that both
        # histograms produce the same number of bars).
        n_reference_bars = len(ax.patches)
        sns.histplot(approach_corrs, color=colors[i], **kwargs)

        # Apply the approach specific pattern to the second histogram only
        for bar in ax.patches[n_reference_bars:]:
            bar.set_hatch(hatches[approach])

        # Set title and labels
        axs[0, j].set_title(techniques_map[technique])
        axs[-1, j].set_xlabel("Correlation")

        # Add the mean of each series as text in the top-left corner
        corr_text_kwargs = {'transform': ax.transAxes, 'fontsize': 8, 'verticalalignment': 'top'}
        ax.text(0.025, 0.95, f"BS: {np.mean(reference_corrs):.2f}", bbox=dict(facecolor='white', alpha=0.0), **corr_text_kwargs)
        ax.text(0.025, 0.75, f"ER: {np.mean(approach_corrs):.2f}", bbox=dict(facecolor='white', alpha=0.0), **corr_text_kwargs)

#fig.savefig(f"../figures/sa_attributions_correlations_approaches_{ds_name}_{corr_fn}_all.pdf", dpi=72, bbox_inches='tight')

In [None]:
approaches_to_plot = ['ER-A', 'ER-R', 'ER-IxG', 'ER-C-A', 'ER-C-R', 'ER-C-IxG']
corr_fn = 'kendall'

# Techniques plotted for each dataset. A mapping (instead of if/elif without an
# else) guarantees every dataset in the loop has its own explicit list.
techniques_by_dataset = {
    'movies': ['attentions', 'rollout', 'IxG'],
    'yelp-50': ['attentions', 'rollout', 'IxG', 'alti_aggregated', 'decompx', 'decompx_classifier'],
}

for ds_name, techniques_to_plot in techniques_by_dataset.items():
    print(ds_name)

    # Load data
    data_techniques = load_data_techniques(corr_fn, ds_name)

    # Plot: one row per approach, one column per technique
    fig, axs = plt.subplots(len(approaches_to_plot), len(techniques_to_plot), figsize=(2.5*len(techniques_to_plot), 8), sharey=True, sharex=True)
    for i, approach in enumerate(approaches_to_plot):
        for j, technique in enumerate(techniques_to_plot):
            ax = axs[i, j]
            # Look the two series up once; they are used for the bars and the means
            reference_corrs = data_techniques[technique]['BS']['BS']
            approach_corrs = data_techniques[technique]['BS'][approach]

            kwargs = {'ax': ax, 'alpha': 0.6, 'bins': 10}
            sns.histplot(reference_corrs, color='gray', **kwargs)
            # Remember how many bars the gray histogram drew, so that exactly
            # the bars added by the next call get hatched (no assumption that
            # both histograms produce the same number of bars).
            n_reference_bars = len(ax.patches)
            sns.histplot(approach_corrs, color=colors[i], **kwargs)

            # Apply the approach hatch pattern to the second histogram only
            for bar in ax.patches[n_reference_bars:]:
                bar.set_hatch(hatches[approach])

            # Set title and labels
            axs[0, j].set_title(techniques_map[technique])
            axs[-1, j].set_xlabel("Correlation")

            # Add the mean of each series as text in the top-left corner
            corr_text_kwargs = {'transform': ax.transAxes, 'fontsize': 8, 'verticalalignment': 'top'}
            ax.text(0.025, 0.95, f"BS: {np.mean(reference_corrs):.2f}", bbox=dict(facecolor='white', alpha=0.0), **corr_text_kwargs)
            ax.text(0.025, 0.75, f"ER: {np.mean(approach_corrs):.2f}", bbox=dict(facecolor='white', alpha=0.0), **corr_text_kwargs)

    # Add legend (only on the yelp-50 figure)
    if ds_name == 'yelp-50':
        # Build the entries from the same list, colours and hatches that drive
        # the plot, so the legend cannot drift from what is drawn.
        legend_entries = [mpatches.Patch(facecolor='gray', edgecolor='white', label='BS vs BS')]
        legend_entries += [
            mpatches.Patch(facecolor=colors[i], edgecolor='white', hatch=hatches[approach], label=f"BS vs {title_map[approach]}")
            for i, approach in enumerate(approaches_to_plot)
        ]

        # Anchor the legend explicitly to the bottom-right subplot (the axes
        # that plt.legend would have picked as the current one); the offsets
        # are in that subplot's axes coordinates.
        axs[-1, -1].legend(handles=legend_entries, loc='upper center', bbox_to_anchor=(-2.5, -0.6), ncol=4, fontsize=14, frameon=False)

    #fig.savefig(f"../figures/sa_attributions_correlations_approaches_{ds_name}_{corr_fn}_all.pdf", dpi=72, bbox_inches='tight')