In [1]:
# Automatically reload edited project modules (e.g. `xbert`, `mnli`) before each cell executes,
# so changes to the library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the project root and the experiments folder importable (presumably where
# `xbert` and `mnli` live, relative to the notebook's working directory).
# Only add entries that are not already present, so re-running this cell is
# idempotent instead of growing sys.path with duplicates.
sys.path.extend([extra for extra in ("../", "../experiments/") if extra not in sys.path])

In [3]:
import os

import dill
import pandas as pd
from pathlib import Path
from transformers import RobertaTokenizer, RobertaForSequenceClassification

from xbert.engine import calculate_correlation
from mnli import (read_mnli_dataset, dataset_to_input_instances, get_labels, predict,
                  MNLI_IDX2LABEL, MNLI_LABEL2IDX, OCCLUSION_STRATEGIES, GRAD_STRATEGIES, ALL_STRATEGIES)

In [33]:
from typing import List, Dict, Any


def experiment_load_relevances(experiment_dir: str,
                               relevance_filename: str = "relevances.pkl"):
    """Load every serialized relevance file found below an experiment directory.

    Each file named ``relevance_filename`` found at any depth below
    ``experiment_dir`` is deserialized with ``dill``. It is stored under the
    name of its parent directory, which downstream cells use as the strategy
    name (e.g. "unk", "grad").

    Args:
        experiment_dir: Root directory holding one task's results.
        relevance_filename: File name to look for in each sub-directory.

    Returns:
        Dict mapping parent-directory name to the deserialized relevances.
        Entries are inserted in sorted path order, so the result is
        deterministic across filesystems.

    Raises:
        FileNotFoundError: if ``experiment_dir`` is not an existing directory.
        ValueError: if two files have parent directories with the same name.
            Previously one would have silently overwritten the other.
    """
    path = Path(experiment_dir)
    if not path.is_dir():
        # Fail here rather than returning {} and surfacing later as a KeyError.
        raise FileNotFoundError(f"Experiment directory not found: {path}")

    experiment_relevances = {}
    # glob() order is filesystem-dependent; sort for reproducible results.
    for relevance_file in sorted(path.glob(f"**/{relevance_filename}")):
        name = relevance_file.parent.name
        if name in experiment_relevances:
            raise ValueError(
                f"Duplicate result directory name {name!r} under {path}: {relevance_file}")
        # NOTE: dill (like pickle) can execute arbitrary code while loading;
        # only point this at result files you produced yourself.
        with relevance_file.open("rb") as f:
            experiment_relevances[name] = dill.load(f)

    return experiment_relevances


def experiment_relevance_correlation(relevances: Dict[str, Any],
                                     strategies: List[str] = None,
                                     strategy_name_map: Dict[str, str] = None):
    """Build the pairwise correlation table between relevance strategies.

    Args:
        relevances: Mapping from strategy name (e.g. "unk", "grad") to that
            strategy's relevances, as returned by ``experiment_load_relevances``.
        strategies: Strategies to compare, in row/column order. Falls back to
            ``ALL_STRATEGIES`` when None or empty.
        strategy_name_map: Optional mapping from strategy name to the label
            used for rows and columns (e.g. "resampling" -> "OLM"). Unmapped
            strategies keep their own name.

    Returns:
        A square DataFrame indexed by "Method". The cell at row a, column b is
        ``calculate_correlation(relevances[a], relevances[b])``.

    Raises:
        KeyError: if a requested strategy has no entry in ``relevances``. The
            message lists the missing and the available names.
    """
    strategies = list(strategies or ALL_STRATEGIES)
    strategy_name_map = strategy_name_map or {}

    # Fail early with a readable message (e.g. a results sub-directory is
    # missing) instead of a bare KeyError from inside the double loop.
    missing = [strategy for strategy in strategies if strategy not in relevances]
    if missing:
        raise KeyError(f"No relevances found for strategies {missing}; "
                       f"available: {sorted(relevances)}")

    labels = [strategy_name_map.get(strategy, strategy) for strategy in strategies]

    correlations = []
    for strategy_a, label_a in zip(strategies, labels):
        row = {"Method": label_a}
        for strategy_b, label_b in zip(strategies, labels):
            row[label_b] = calculate_correlation(relevances[strategy_a], relevances[strategy_b])
        correlations.append(row)

    return pd.DataFrame(correlations).set_index("Method")


def combined_table(task_correlations, strategies):
    """Place selected rows of several correlation tables side by side.

    Args:
        task_correlations: Iterable of ``(task_name, correlation_df)`` pairs.
            Each ``correlation_df`` is indexed by display label, like the
            output of ``experiment_relevance_correlation``.
        strategies: Display labels (e.g. "OLM") whose rows are taken from
            every task's table.

    Returns:
        DataFrame whose columns form a ``(task, strategy)`` MultiIndex. Each
        column is ``correlation_df.loc[strategy]``, so the rows are the
        compared methods. An empty DataFrame is returned when there is nothing
        to combine. Previously that case raised a TypeError from
        ``MultiIndex.from_tuples([])``.

    Raises:
        KeyError: if a requested strategy is missing from a task's table. The
            message names the task and the available labels.
    """
    # Materialize once so a one-shot iterable serves every task.
    strategies = list(strategies)

    indices = []
    columns = []
    for task, correlations in task_correlations:
        missing = [strategy for strategy in strategies if strategy not in correlations.index]
        if missing:
            raise KeyError(f"Task {task!r} has no rows for {missing}; "
                           f"available: {list(correlations.index)}")
        for strategy in strategies:
            indices.append((task, strategy))
            columns.append(correlations.loc[strategy])

    if not columns:
        return pd.DataFrame()

    # Stack the selected rows, then transpose so each becomes one column.
    df = pd.DataFrame(columns).T
    df.columns = pd.MultiIndex.from_tuples(indices)

    return df

In [28]:
# Root folder containing one results sub-directory per task ("mnli", "sst2",
# "sst2_lstm", "cola"). Set the XBERT_RESULTS_DIR environment variable to point
# elsewhere instead of editing this machine-specific default.
RESULTS_DIR = os.environ.get("XBERT_RESULTS_DIR", "/home/christoph/Downloads/xbert_results/")

# Strategy name (as used for the result sub-directories) -> label shown in the tables.
STRATEGY_NAME_MAPPING = {
    "unk": "Unk",
    "delete": "Delete",
    "resampling": "OLM",
    "resampling_std": "OLM-S",
    "grad": "Grad.",
    "gradxinput": "Grad*Input",
    "saliency": "Sensitivity",
    "integratedgrad": "Integr. grad"
}

# MNLI

In [29]:
MNLI_RESULTS_PATH = os.path.join(RESULTS_DIR, "mnli")
mnli_experiment_relevances = experiment_load_relevances(MNLI_RESULTS_PATH)

# The keys of STRATEGY_NAME_MAPPING are exactly the eight strategies, in table order.
# NOTE(review): this cell emitted a RuntimeWarning (`c /= stddev`) — presumably some
# relevance vectors have zero variance; TODO check how calculate_correlation handles them.
mnli_correlation = experiment_relevance_correlation(
    mnli_experiment_relevances,
    strategies=list(STRATEGY_NAME_MAPPING),
    strategy_name_map=STRATEGY_NAME_MAPPING,
)
# LaTeX for the paper; the bare expression below renders the rich table.
print(mnli_correlation.to_latex(multicolumn_format="c", float_format="{:0.2f}".format))
mnli_correlation

  c /= stddev[:, None]
  c /= stddev[None, :]


\begin{tabular}{lrrrrrrrr}
\toprule
{} &   Unk &  Delete &   OLM &  OLM-S &  Grad. &  Grad*Input &  Sensitivity &  Integr. grad \\
Method       &       &         &       &        &        &             &              &               \\
\midrule
Unk          &  1.00 &    0.73 &  0.58 &   0.32 &   0.00 &       -0.03 &         0.22 &          0.32 \\
Delete       &  0.73 &    1.00 &  0.60 &   0.32 &   0.01 &       -0.05 &         0.23 &          0.34 \\
OLM          &  0.58 &    0.60 &  1.00 &   0.61 &   0.00 &       -0.03 &         0.27 &          0.28 \\
OLM-S        &  0.32 &    0.32 &  0.61 &   1.00 &  -0.00 &       -0.01 &         0.35 &          0.20 \\
Grad.        &  0.00 &    0.01 &  0.00 &  -0.00 &   1.00 &       -0.00 &         0.00 &          0.00 \\
Grad*Input   & -0.03 &   -0.05 & -0.03 &  -0.01 &  -0.00 &        1.00 &         0.03 &          0.00 \\
Sensitivity  &  0.22 &    0.23 &  0.27 &   0.35 &   0.00 &        0.03 &         1.00 &          0.17 \\
Integr. grad &  0.32

Unnamed: 0_level_0,Unk,Delete,OLM,OLM-S,Grad.,Grad*Input,Sensitivity,Integr. grad
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Unk,1.0,0.734117,0.580645,0.316609,0.002294,-0.034539,0.22277,0.323415
Delete,0.734117,1.0,0.601821,0.324969,0.005007,-0.045025,0.227364,0.337978
OLM,0.580645,0.601821,1.0,0.61044,0.003561,-0.027769,0.272469,0.275429
OLM-S,0.316609,0.324969,0.61044,1.0,-0.002372,-0.005132,0.346814,0.197998
Grad.,0.002294,0.005007,0.003561,-0.002372,1.0,-0.000932,0.003702,0.002953
Grad*Input,-0.034539,-0.045025,-0.027769,-0.005132,-0.000932,1.0,0.029389,0.002813
Sensitivity,0.22277,0.227364,0.272469,0.346814,0.003702,0.029389,1.0,0.169672
Integr. grad,0.323415,0.337978,0.275429,0.197998,0.002953,0.002813,0.169672,1.0


# SST-2

In [30]:
SST2_RESULTS_PATH = os.path.join(RESULTS_DIR, "sst2")
sst2_experiment_relevances = experiment_load_relevances(SST2_RESULTS_PATH)

# The keys of STRATEGY_NAME_MAPPING are exactly the eight strategies, in table order.
sst2_correlation = experiment_relevance_correlation(
    sst2_experiment_relevances,
    strategies=list(STRATEGY_NAME_MAPPING),
    strategy_name_map=STRATEGY_NAME_MAPPING,
)
# LaTeX for the paper; the bare expression below renders the rich table.
print(sst2_correlation.to_latex(multicolumn_format="c", float_format="{:0.2f}".format))
sst2_correlation

\begin{tabular}{lrrrrrrrr}
\toprule
{} &   Unk &  Delete &  OLM &  OLM-S &  Grad. &  Grad*Input &  Sensitivity &  Integr. grad \\
Method       &       &         &      &        &        &             &              &               \\
\midrule
Unk          &  1.00 &    0.64 & 0.47 &   0.38 &  -0.00 &        0.03 &         0.18 &          0.36 \\
Delete       &  0.64 &    1.00 & 0.52 &   0.39 &  -0.00 &        0.01 &         0.21 &          0.37 \\
OLM          &  0.47 &    0.52 & 1.00 &   0.78 &   0.01 &        0.02 &         0.30 &          0.35 \\
OLM-S        &  0.38 &    0.39 & 0.78 &   1.00 &   0.01 &        0.01 &         0.37 &          0.30 \\
Grad.        & -0.00 &   -0.00 & 0.01 &   0.01 &   1.00 &        0.00 &         0.03 &          0.01 \\
Grad*Input   &  0.03 &    0.01 & 0.02 &   0.01 &   0.00 &        1.00 &         0.03 &          0.04 \\
Sensitivity  &  0.18 &    0.21 & 0.30 &   0.37 &   0.03 &        0.03 &         1.00 &          0.13 \\
Integr. grad &  0.36 &    0.3

Unnamed: 0_level_0,Unk,Delete,OLM,OLM-S,Grad.,Grad*Input,Sensitivity,Integr. grad
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Unk,1.0,0.636484,0.469002,0.376973,-0.003888,0.027681,0.184303,0.362277
Delete,0.636484,1.0,0.51958,0.386166,-0.002355,0.00775,0.210423,0.370621
OLM,0.469002,0.51958,1.0,0.778646,0.005716,0.017771,0.298129,0.350541
OLM-S,0.376973,0.386166,0.778646,1.0,0.005622,0.012634,0.374481,0.301163
Grad.,-0.003888,-0.002355,0.005716,0.005622,1.0,0.001884,0.030472,0.012021
Grad*Input,0.027681,0.00775,0.017771,0.012634,0.001884,1.0,0.032011,0.036633
Sensitivity,0.184303,0.210423,0.298129,0.374481,0.030472,0.032011,1.0,0.134606
Integr. grad,0.362277,0.370621,0.350541,0.301163,0.012021,0.036633,0.134606,1.0


# SST-2 (LSTM)

In [31]:
SST2_LSTM_RESULTS_PATH = os.path.join(RESULTS_DIR, "sst2_lstm")
sst2_lstm_experiment_relevances = experiment_load_relevances(SST2_LSTM_RESULTS_PATH)

# The keys of STRATEGY_NAME_MAPPING are exactly the eight strategies, in table order.
sst2_lstm_correlation = experiment_relevance_correlation(
    sst2_lstm_experiment_relevances,
    strategies=list(STRATEGY_NAME_MAPPING),
    strategy_name_map=STRATEGY_NAME_MAPPING,
)
# LaTeX for the paper; the bare expression below renders the rich table.
print(sst2_lstm_correlation.to_latex(multicolumn_format="c", float_format="{:0.2f}".format))
sst2_lstm_correlation

\begin{tabular}{lrrrrrrrr}
\toprule
{} &   Unk &  Delete &  OLM &  OLM-S &  Grad. &  Grad*Input &  Sensitivity &  Integr. grad \\
Method       &       &         &      &        &        &             &              &               \\
\midrule
Unk          &  1.00 &    0.70 & 0.60 &   0.36 &  -0.11 &        0.56 &         0.17 &          0.44 \\
Delete       &  0.70 &    1.00 & 0.68 &   0.41 &   0.06 &        0.57 &         0.26 &          0.49 \\
OLM          &  0.60 &    0.68 & 1.00 &   0.62 &   0.07 &        0.49 &         0.27 &          0.44 \\
OLM-S        &  0.36 &    0.41 & 0.62 &   1.00 &   0.10 &        0.28 &         0.45 &          0.27 \\
Grad.        & -0.11 &    0.06 & 0.07 &   0.10 &   1.00 &        0.09 &         0.19 &          0.03 \\
Grad*Input   &  0.56 &    0.57 & 0.49 &   0.28 &   0.09 &        1.00 &         0.16 &          0.44 \\
Sensitivity  &  0.17 &    0.26 & 0.27 &   0.45 &   0.19 &        0.16 &         1.00 &          0.15 \\
Integr. grad &  0.44 &    0.4

Unnamed: 0_level_0,Unk,Delete,OLM,OLM-S,Grad.,Grad*Input,Sensitivity,Integr. grad
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Unk,1.0,0.701938,0.601099,0.355563,-0.105752,0.564114,0.170479,0.437578
Delete,0.701938,1.0,0.679008,0.408872,0.059867,0.569427,0.255421,0.494065
OLM,0.601099,0.679008,1.0,0.621485,0.065762,0.491497,0.272031,0.437985
OLM-S,0.355563,0.408872,0.621485,1.0,0.096932,0.283404,0.447648,0.273982
Grad.,-0.105752,0.059867,0.065762,0.096932,1.0,0.089549,0.190294,0.03155
Grad*Input,0.564114,0.569427,0.491497,0.283404,0.089549,1.0,0.160369,0.444666
Sensitivity,0.170479,0.255421,0.272031,0.447648,0.190294,0.160369,1.0,0.153623
Integr. grad,0.437578,0.494065,0.437985,0.273982,0.03155,0.444666,0.153623,1.0


In [32]:
COLA_RESULTS_PATH = os.path.join(RESULTS_DIR, "cola")
cola_experiment_relevances = experiment_load_relevances(COLA_RESULTS_PATH)

# The keys of STRATEGY_NAME_MAPPING are exactly the eight strategies, in table order.
cola_correlation = experiment_relevance_correlation(
    cola_experiment_relevances,
    strategies=list(STRATEGY_NAME_MAPPING),
    strategy_name_map=STRATEGY_NAME_MAPPING,
)
# LaTeX for the paper; the bare expression below renders the rich table.
print(cola_correlation.to_latex(multicolumn_format="c", float_format="{:0.2f}".format))
cola_correlation

\begin{tabular}{lrrrrrrrr}
\toprule
{} &   Unk &  Delete &  OLM &  OLM-S &  Grad. &  Grad*Input &  Sensitivity &  Integr. grad \\
Method       &       &         &      &        &        &             &              &               \\
\midrule
Unk          &  1.00 &    0.35 & 0.21 &   0.12 &  -0.00 &        0.03 &         0.03 &          0.14 \\
Delete       &  0.35 &    1.00 & 0.25 &   0.15 &  -0.00 &        0.04 &         0.02 &          0.18 \\
OLM          &  0.21 &    0.25 & 1.00 &   0.56 &   0.02 &        0.02 &         0.20 &          0.15 \\
OLM-S        &  0.12 &    0.15 & 0.56 &   1.00 &   0.03 &        0.03 &         0.29 &          0.09 \\
Grad.        & -0.00 &   -0.00 & 0.02 &   0.03 &   1.00 &        0.01 &        -0.00 &         -0.01 \\
Grad*Input   &  0.03 &    0.04 & 0.02 &   0.03 &   0.01 &        1.00 &        -0.00 &          0.12 \\
Sensitivity  &  0.03 &    0.02 & 0.20 &   0.29 &  -0.00 &       -0.00 &         1.00 &          0.07 \\
Integr. grad &  0.14 &    0.1

Unnamed: 0_level_0,Unk,Delete,OLM,OLM-S,Grad.,Grad*Input,Sensitivity,Integr. grad
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Unk,1.0,0.351959,0.212183,0.115955,-0.002829,0.026229,0.026221,0.136655
Delete,0.351959,1.0,0.249682,0.147698,-0.002535,0.03633,0.018538,0.175759
OLM,0.212183,0.249682,1.0,0.564801,0.017446,0.020401,0.203886,0.151937
OLM-S,0.115955,0.147698,0.564801,1.0,0.030685,0.025178,0.287795,0.09272
Grad.,-0.002829,-0.002535,0.017446,0.030685,1.0,0.008709,-0.000928,-0.007255
Grad*Input,0.026229,0.03633,0.020401,0.025178,0.008709,1.0,-0.00107,0.116754
Sensitivity,0.026221,0.018538,0.203886,0.287795,-0.000928,-0.00107,1.0,0.068817
Integr. grad,0.136655,0.175759,0.151937,0.09272,-0.007255,0.116754,0.068817,1.0


In [35]:
# OLM and OLM-S columns of each task's correlation table, side by side
# (SST-2 (LSTM) is not included). Labels come from STRATEGY_NAME_MAPPING so they
# always match the row labels of the per-task tables.
combined_df = combined_table(
    task_correlations=[
        ("MNLI", mnli_correlation),
        ("SST-2", sst2_correlation),
        ("CoLA", cola_correlation),
    ],
    strategies=[STRATEGY_NAME_MAPPING["resampling"], STRATEGY_NAME_MAPPING["resampling_std"]],
)
print(combined_df.to_latex(multicolumn_format="c", float_format="{:0.2f}".format))
combined_df

\begin{tabular}{lrrrrrr}
\toprule
{} & \multicolumn{2}{c}{MNLI} & \multicolumn{2}{c}{SST-2} & \multicolumn{2}{c}{CoLA} \\
{} &   OLM & OLM-S &   OLM & OLM-S &  OLM & OLM-S \\
\midrule
Unk          &  0.58 &  0.32 &  0.47 &  0.38 & 0.21 &  0.12 \\
Delete       &  0.60 &  0.32 &  0.52 &  0.39 & 0.25 &  0.15 \\
OLM          &  1.00 &  0.61 &  1.00 &  0.78 & 1.00 &  0.56 \\
OLM-S        &  0.61 &  1.00 &  0.78 &  1.00 & 0.56 &  1.00 \\
Grad.        &  0.00 & -0.00 &  0.01 &  0.01 & 0.02 &  0.03 \\
Grad*Input   & -0.03 & -0.01 &  0.02 &  0.01 & 0.02 &  0.03 \\
Sensitivity  &  0.27 &  0.35 &  0.30 &  0.37 & 0.20 &  0.29 \\
Integr. grad &  0.28 &  0.20 &  0.35 &  0.30 & 0.15 &  0.09 \\
\bottomrule
\end{tabular}



Unnamed: 0_level_0,MNLI,MNLI,SST-2,SST-2,CoLA,CoLA
Unnamed: 0_level_1,OLM,OLM-S,OLM,OLM-S,OLM,OLM-S
Unk,0.580645,0.316609,0.469002,0.376973,0.212183,0.115955
Delete,0.601821,0.324969,0.51958,0.386166,0.249682,0.147698
OLM,1.0,0.61044,1.0,0.778646,1.0,0.564801
OLM-S,0.61044,1.0,0.778646,1.0,0.564801,1.0
Grad.,0.003561,-0.002372,0.005716,0.005622,0.017446,0.030685
Grad*Input,-0.027769,-0.005132,0.017771,0.012634,0.020401,0.025178
Sensitivity,0.272469,0.346814,0.298129,0.374481,0.203886,0.287795
Integr. grad,0.275429,0.197998,0.350541,0.301163,0.151937,0.09272
