Select the CUDA device to use

In [1]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=4

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=4


Install dependencies for computing metrics and plots:

In [2]:
#%pip install numpy scipy pandas seaborn matplotlib scikit-learn

## Basic imports

In [3]:
import jax
import jax.numpy as jnp
import flax
from entmax_jax import sparsemax
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
from IPython.display import display, HTML
from functools import partial
import json
from entmax_jax.activations import sparsemax, entmax15
from sklearn.metrics import roc_auc_score, average_precision_score

from meta_expl.explainers import load_explainer
from meta_expl.models import load_model
from meta_expl.data.imdb import load_data, dataloader
from meta_expl.data.movies_rationales import dataloader as movie_dataloader
from meta_expl.data.movies_rationales import load_data as movie_load_data

from evaluate_explanations import evaluate_word_level, evaluate_sentence_level, aggregate_pieces

In [4]:
# data utils
def unroll(list_of_lists):
    """Flatten one level of nesting: return a single list holding the elements of every inner iterable, in order."""
    flat = []
    for inner in list_of_lists:
        flat.extend(inner)
    return flat

## Define args and load stuff

In [5]:
# arguments
arch = "google/electra-small-discriminator"  # checkpoint name given to the tokenizer's from_pretrained
num_classes = 2
task_type = "classification"  # get_expls takes the argmax of the model outputs when this is "classification"
max_len = 256  # sequence length used for tokenization and for the dummy inputs
batch_size = 32  # batch size of the dummy inputs and of the evaluation dataloader
seed = 1
setup = "static_teacher"  # first argument given to movie_load_data

# directories handed to load_model / load_explainer
teacher_dir = 'data/imdb-electra-models/teacher_dir'
explainer_dir = 'data/imdb-electra-models/teacher_expl_dir'

In [6]:
# create dummy inputs for model instantiation
# (placeholders built from a (batch_size, max_len) array of ones; they are
# passed to load_model, load_explainer and create_explainer)
input_ids = jnp.ones((batch_size, max_len), jnp.int32)
dummy_inputs = {
    "input_ids": input_ids,
    "attention_mask": jnp.ones_like(input_ids),
    # NOTE(review): token_type_ids is a 1-D arange over the sequence length while
    # position_ids is all ones. This looks swapped relative to the usual
    # convention (zeros for token types, arange for positions) -- confirm how
    # load_model consumes these values before changing anything.
    "token_type_ids": jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
    "position_ids": jnp.ones_like(input_ids),
}
dummy_inputs['input_ids'].shape

(32, 256)

### load data

In [7]:
# load data
test_data = movie_load_data(setup, "test")



  0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
test_data[0]

{'review': 'plot : two teen couples go to a church party , drink and then drive .\nthey get into an accident .\none of the guys dies , but his girlfriend continues to see him in her life , and has nightmares .\nwhat \'s the deal ?\nwatch the movie and " sorta " find out . . .\ncritique : a mind - fuck movie for the teen generation that touches on a very cool idea , but presents it in a very bad package .\nwhich is what makes this review an even harder one to write , since i generally applaud films which attempt to break the mold , mess with your head and such ( lost highway & memento ) , but there are good and bad ways of making all types of films , and these folks just did n\'t snag this one correctly\n.\nthey seem to have taken this pretty neat concept , but executed it terribly .\nso what are the problems with the movie ?\nwell , its main problem is that it \'s simply too jumbled\n.\nit starts off " normal " but then downshifts into this " fantasy " world in which you , as an audien

### load tokenizer

In [9]:
from transformers import ElectraTokenizerFast
tokenizer = ElectraTokenizerFast.from_pretrained(arch)
vocab_size = len(tokenizer)
cls_id = tokenizer.cls_token_id
sep_id = tokenizer.sep_token_id
pad_id = tokenizer.pad_token_id

### load model and explainer

In [10]:
teacher, teacher_params, dummy_state = load_model(teacher_dir, dummy_inputs, batch_size, max_len)
teacher_explainer, teacher_explainer_params = load_explainer(explainer_dir, dummy_inputs, state=dummy_state)

In [11]:
from meta_expl.utils import PRNGSequence
from meta_expl.explainers import create_explainer
# Attention explainer created from scratch with create_explainer (instead of
# being restored with load_explainer), with layer_idx and head_idx left as None.
keyseq = PRNGSequence(11)
# NOTE: this name first holds the explainer_args dict and is then rebound, by
# the create_explainer call below, to the second value returned by that call.
teacher_explainer_params_non_trained={
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': None,
    'head_idx': None
}
teacher_explainer_non_trained, teacher_explainer_params_non_trained = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='attention_explainer',
    explainer_args=teacher_explainer_params_non_trained,
)

In [12]:
# Attention explainer restricted to a single (layer_idx, head_idx) pair.
# The args dict name is rebound to the second value returned by create_explainer.
best_head_teacher_explainer_params={
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': 9,  #9, None
    'head_idx': 5,  #5, None
}
best_head_teacher_explainer, best_head_teacher_explainer_params = create_explainer(
    key=next(keyseq), 
    inputs=dummy_inputs, 
    state=dummy_state, 
    explainer_type='attention_explainer', 
    explainer_args=best_head_teacher_explainer_params
)

In [13]:
# Attention explainer restricted to one layer_idx, with head_idx left as None.
# The args dict name is rebound to the second value returned by create_explainer.
best_layer_teacher_explainer_params={
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': 10,  #9, None
    'head_idx': None,  #5, None
}
best_layer_teacher_explainer, best_layer_teacher_explainer_params = create_explainer(
    key=next(keyseq), 
    inputs=dummy_inputs, 
    state=dummy_state, 
    explainer_type='attention_explainer', 
    explainer_args=best_layer_teacher_explainer_params
)

In [14]:
# Explainer of type 'gradient_input_explainer'. The grad_fn given as a model
# extra is obtained by applying the teacher with method=teacher.embeddings_grad_fn
# on the dummy inputs.
input_gradient_teacher_explainer, input_gradient_teacher_explainer_params = create_explainer(
    key=next(keyseq), 
    inputs=dummy_inputs, 
    state=dummy_state, 
    explainer_type='gradient_input_explainer', 
    model_extras={
        "grad_fn": teacher.apply(
            teacher_params, dummy_inputs, method=teacher.embeddings_grad_fn
        )
    }
)

In [15]:
# Explainer of type 'integrated_gradients_explainer'. The grad_fn given as a
# model extra is obtained by applying the teacher with
# method=teacher.embeddings_grad_fn on the dummy inputs.
int_gradient_teacher_explainer, int_gradient_teacher_explainer_params = create_explainer(
    key=next(keyseq), 
    inputs=dummy_inputs, 
    state=dummy_state, 
    explainer_type='integrated_gradients_explainer', 
    model_extras={
        "grad_fn": teacher.apply(
            teacher_params, dummy_inputs, method=teacher.embeddings_grad_fn
        )
    }
)

### look at the coefficients

In [16]:
# Apply sparsemax to the loaded explainer's head coefficients and view the
# result as a 12 x 4 grid (presumably layers x heads -- the teacher reports
# 12 encoder layers and 4 heads later in this notebook).
hc = sparsemax(teacher_explainer_params['params']['head_coeffs']).reshape(12, 4)
hc

DeviceArray([[0.        , 0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        , 0.        ],
             [0.        , 0.19831145, 0.        , 0.        ],
             [0.        , 0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        , 0.        ],
             [0.12434827, 0.25941166, 0.        , 0.        ],
             [0.        , 0.27521744, 0.        , 0.0986265 ],
             [0.        , 0.        , 0.0440846 , 0.        ],
             [0.        , 0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        , 0.        ]],            dtype=float32)

In [17]:
# print the 1-based (row, column) position of every non-zero entry of hc
for a, b in zip(*hc.nonzero()):
    print(a+1, b+1)

3 2
7 1
7 2
8 2
8 4
9 3


In [18]:
# check the layers with the highest coefficients:
# average hc over its last axis, pair each value with a 1-based row number and
# sort in ascending order, so the largest averages appear last
layer_coeffs = hc.mean(-1).tolist()
sorted(list(zip(list(range(1, len(layer_coeffs)+1)), layer_coeffs)), key=lambda k: k[1])

[(1, 0.0),
 (2, 0.0),
 (4, 0.0),
 (5, 0.0),
 (6, 0.0),
 (10, 0.0),
 (11, 0.0),
 (12, 0.0),
 (9, 0.011021151207387447),
 (3, 0.04957786202430725),
 (8, 0.09346098452806473),
 (7, 0.09593997895717621)]

In [None]:
# Heat-map of the normalized head coefficients: one row per layer, one column per head.
coeffs = np.asarray(hc)
num_layers, num_heads = coeffs.shape
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(coeffs, cmap='Greens')
# Tick positions follow the array shape. Hard-coding 12 x-ticks on an image
# with fewer columns expands the x-limits beyond the data.
ax.set_xticks(list(range(num_heads)))
ax.set_yticks(list(range(num_layers)))
ax.set_xlabel('Head')
ax.set_ylabel('Layer')
ax.set_title('Head coefficients')

## Get explanations

In [30]:
def get_expls(data, t, t_p, t_e, t_e_p, s=None, s_p=None, s_e=None, s_e_p=None, is_grad_based=False):
    """Run the teacher (and optionally a student) over `data` and collect the teacher explanations.

    Reads the notebook-level names `movie_dataloader`, `tokenizer`, `batch_size`,
    `max_len` and `task_type`.

    Args:
        data: dataset handed to `movie_dataloader`.
        t, t_p: teacher model and its parameters.
        t_e, t_e_p: teacher explainer and its parameters.
        s, s_p: optional student model and its parameters.
        s_e, s_e_p: optional student explainer and its parameters; `s_e` is
            required whenever `s` is given.
        is_grad_based: when True, a `grad_fn` obtained from
            `t.embeddings_grad_fn` is passed to the teacher explainer.

    Returns:
        A tuple `(all_tokens, all_masks, all_explanations, all_outputs)` of
        per-example lists. Tokens, masks and explanations are cut to the number
        of attended positions (sum of the attention mask). A mask entry is False
        for tokens starting with '##'. Outputs are the teacher outputs (argmax
        over the last axis when `task_type == "classification"`).

    Raises:
        ValueError: if a student model is given without a student explainer.
    """
    if s is not None and s_e is None:
        raise ValueError("get_expls: a student explainer (s_e) is required when a student model (s) is given")

    all_tokens = []
    all_masks = []
    all_explanations = []
    all_outputs = []
    loader = movie_dataloader(data, tokenizer, batch_size=batch_size, max_len=max_len, shuffle=False)
    for i, (x, y) in enumerate(loader):
        print('{} of {}'.format(i, len(data)//batch_size), end='\r')
        y_teacher, teacher_attn = t.apply(t_p, **x, deterministic=True)
        y_teacher = jnp.argmax(y_teacher, axis=-1) if task_type == "classification" else y_teacher
        if is_grad_based:
            teacher_extras = {
                "grad_fn": t.apply(t_p, x, method=t.embeddings_grad_fn)
            }
            teacher_expl, _ = t_e.apply(t_e_p, x, teacher_attn, **teacher_extras)
        else:
            teacher_expl, _ = t_e.apply(t_e_p, x, teacher_attn)

        if s is not None:
            # The student forward pass and explanation are computed but are not
            # part of the returned values.
            y_student, student_attn = s.apply(s_p, **x)
            y_student = jnp.argmax(y_student, axis=-1) if task_type == "classification" else y_student
            # use the explainer passed as an argument, not a notebook global
            student_expl, _ = s_e.apply(s_e_p, x, student_attn)

        # convert everything to lists
        batch_tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in x['input_ids'].tolist()]
        batch_masks = [[not tk.startswith('##') for tk in tokens] for tokens in batch_tokens]
        batch_expls = teacher_expl.tolist()

        # keep only the attended positions of every example
        batch_valid_len = x['attention_mask'].sum(-1).tolist()
        for j, n in enumerate(batch_valid_len):
            batch_tokens[j] = batch_tokens[j][:n]
            batch_masks[j] = batch_masks[j][:n]
            batch_expls[j] = batch_expls[j][:n]

        all_tokens.extend(batch_tokens)
        all_masks.extend(batch_masks)
        all_explanations.extend(batch_expls)
        all_outputs.extend(y_teacher.tolist())

    return all_tokens, all_masks, all_explanations, all_outputs

In [20]:
# running counters of evidence spans that were / were not located in a review
total_ok = 0
total_er = 0

def find_index_sublist(v, u_pad, sep_token_id=102):
    """Locate a tokenized evidence span inside a list of token ids.

    Args:
        v: list of token ids to search in.
        u_pad: token ids of one evidence. The first id is skipped and everything
            from the first occurrence of `sep_token_id` onwards is dropped
            before searching.
        sep_token_id: id that terminates the evidence in `u_pad`. Defaults to
            102, the value that used to be hard-coded here.

    Returns:
        `(start, end)` such that `v[start:end]` equals the evidence ids (first
        match), or `(None, None)` when there is no match. The module-level
        counters `total_ok` / `total_er` are incremented accordingly.

    Raises:
        ValueError: if `sep_token_id` does not occur in `u_pad`.
    """
    global total_ok, total_er
    m = u_pad.index(sep_token_id)
    u = u_pad[1:m]
    n = len(u)
    # "+ 1" so that a match ending exactly at the end of `v` is also tested
    for i in range(len(v) - n + 1):
        if v[i:i+n] == u:
            total_ok += 1
            return i, i+n
    total_er += 1
    return None, None

def convert_evidences_to_mask(x, e):
    """Return a 0/1 list as long as `x` with 1 at every position covered by an
    evidence of `e` that `find_index_sublist` was able to locate in `x`."""
    spans = [find_index_sublist(x, evidence) for evidence in e]
    mask = [0] * len(x)
    for start, end in spans:
        # evidences that were not found come back as (None, None)
        if start is None or end is None:
            continue
        for pos in range(start, end):
            mask[pos] = 1
    return mask

# Build one gold 0/1 mask per test example by locating every tokenized
# evidence inside the tokenized review.
all_gold_explanations = []
for i, sample in enumerate(test_data):
    if len(sample['evidences']) == 0:
        # placeholder so that the evidences list is never empty
        # (presumably not expected to match anything in the review)
        sample['evidences'] = ['justarandomwordhere']
    # NOTE: this rewrites the review text stored in test_data in place
    sample['review'] = sample['review'].replace('\n', ' ')
    # NOTE(review): review and evidences are both padded/truncated to max_len,
    # so an evidence that lies beyond the first max_len review tokens cannot be
    # matched -- this may explain part of the low match rate reported in the
    # next cell; confirm.
    x = tokenizer(
        sample['review'],
        padding="max_length",
        truncation=True,
        return_tensors="jax",
        max_length=max_len,
    )
    e = tokenizer(
        sample['evidences'],
        padding="max_length",
        truncation=True,
        return_tensors="jax",
        max_length=max_len,
    )
    # z = convert_evidences_to_mask(sample['review'], sample['evidences'])
    z = convert_evidences_to_mask(x['input_ids'][0].tolist(), e['input_ids'].tolist())
    # keep only the attended positions (sum of the attention mask)
    n = x['attention_mask'].sum(-1).tolist()[0]
    z = z[:n]
    all_gold_explanations.append(z)

In [21]:
total_ok, total_er, total_ok / (total_ok+total_er)

(4138, 13325, 0.23695814006757143)

In [27]:
# Explanations of the loaded (teacher_explainer) explainer on the test data;
# the last line displays the length of each returned list as a sanity check.
valid_tokens, valid_masks, valid_explanations, valid_outputs = get_expls(
    test_data, teacher, teacher_params, teacher_explainer, teacher_explainer_params 
)
list(map(len, [valid_tokens, valid_masks, valid_explanations, valid_outputs]))

63 of 62

[1999, 1999, 1999, 1999]

### Aggregate word-piece scores into word-level scores

In [28]:
import torch
from utils import aggregate_pieces

def get_piece_explanations(all_tokens, all_fp_masks, all_explanations, reduction):
    """Aggregate word-piece scores per example with `aggregate_pieces`.

    For every example the scores and the first-piece mask are converted to
    torch tensors and handed to `aggregate_pieces` together with `reduction`;
    the first and last aggregated entries (the positions of the leading and
    trailing special tokens) are then dropped.

    Returns a list with one list of aggregated scores per example.
    """
    def _aggregate(expl, fp_mask):
        # aggregate word pieces scores with the torch helper
        scores = aggregate_pieces(torch.tensor(expl), torch.tensor(fp_mask), reduction)
        # drop the first and last entries
        return scores.tolist()[1:-1]

    return [
        _aggregate(expl, fp_mask)
        for _tokens, expl, fp_mask in zip(all_tokens, all_explanations, all_fp_masks)
    ]

In [51]:
reduction = 'sum'  # first, sum, mean, max
# first-piece masks recomputed from the tokens: False for tokens starting with '##'
valid_masks = [[not tk.startswith('##') for tk in tokens] for tokens in valid_tokens]
# predicted explanations, aggregated with `reduction`
all_expls = get_piece_explanations(
    valid_tokens, valid_masks, valid_explanations, reduction=reduction
)
# gold masks, aggregated with 'max'
# (assumes all_gold_explanations is in the same order as valid_tokens -- confirm)
all_gold_expls = get_piece_explanations(
    valid_tokens, valid_masks, all_gold_explanations, reduction='max'
)

## Evaluating explanations by comparing them with word-level human rationales

In [55]:
sum(map(sum, all_gold_expls))

38983.0

In [31]:
# gold labels / gold explanations vs. teacher predictions / teacher explanations
gold_scores = [inp['label'] for inp in test_data]
gold_expls = all_gold_expls
pred_scores = valid_outputs
pred_expls = all_expls

In [32]:
print('Acc:', np.mean(np.array(gold_scores) == np.array(pred_scores)))

Acc: 0.8049024512256128


In [33]:
_ = evaluate_word_level(gold_expls, pred_expls)

AUC score: 0.7315
AP score: 0.3060
Recall at top-K: 0.2849


In [34]:
def eval_plausibility(data, t, t_p, t_e, t_e_p, is_grad_based=False):
    """Evaluate one explainer against the gold explanations.

    Collects explanations with `get_expls`, aggregates them with
    `get_piece_explanations` (reduction 'sum'), prints the accuracy of the
    model outputs against the 'label' field of `data`, and calls
    `evaluate_word_level` with the notebook-level `all_gold_expls`.

    Returns the aggregated predicted explanations.
    """
    tokens, fp_masks, raw_expls, predictions = get_expls(
        data, t, t_p, t_e, t_e_p, is_grad_based=is_grad_based
    )
    piece_expls = get_piece_explanations(tokens, fp_masks, raw_expls, reduction='sum')

    labels = np.array([example['label'] for example in data])
    accuracy = np.mean(labels == np.array(predictions))
    print('Acc:', accuracy)

    evaluate_word_level(all_gold_expls, piece_expls)
    return piece_expls
    

### meta-learned explainer

In [35]:
expls_mtl = eval_plausibility(
    test_data,
    teacher, 
    teacher_params, 
    teacher_explainer, 
    teacher_explainer_params, 
)

Acc: 0.8049024512256128
AUC score: 0.7315
AP score: 0.3060
Recall at top-K: 0.2849


### all attention layers and heads

In [38]:
expls_all_attn = eval_plausibility(
    test_data,
    teacher, 
    teacher_params, 
    teacher_explainer_non_trained, 
    teacher_explainer_params_non_trained,

)

Acc: 0.8049024512256128
AUC score: 0.6824
AP score: 0.2114
Recall at top-K: 0.1899


### gradient x input

In [39]:
eval_plausibility(
    test_data,
    teacher, 
    teacher_params, 
    input_gradient_teacher_explainer, 
    input_gradient_teacher_explainer_params,
    is_grad_based=True
)

Acc: 0.8049024512256128
AUC score: 0.5092
AP score: 0.1385
Recall at top-K: 0.1258


### integrated gradients

In [40]:
eval_plausibility(
    test_data,
    teacher, 
    teacher_params, 
    int_gradient_teacher_explainer, 
    int_gradient_teacher_explainer_params,
    is_grad_based=True
)

Acc: 0.8049024512256128
AUC score: 0.5287
AP score: 0.1355
Recall at top-K: 0.1201


### best attention layer

In [56]:
eval_plausibility(
    test_data,
    teacher, 
    teacher_params, 
    best_layer_teacher_explainer, 
    best_layer_teacher_explainer_params,
)

Acc: 0.8049024512256128
AUC score: 0.5898
AP score: 0.1589
Recall at top-K: 0.1243


### best attention head

In [57]:
eval_plausibility(
    test_data,
    teacher, 
    teacher_params, 
    best_head_teacher_explainer, 
    best_head_teacher_explainer_params,
)

Acc: 0.8049024512256128
AUC score: 0.6038
AP score: 0.1796
Recall at top-K: 0.1475


### last layer attention

In [58]:
flax.linen.softmax(teacher_params['params']['scalarmix']['coeffs'])  # first item is the embedding layer

DeviceArray([0.0760358 , 0.07595138, 0.07588542, 0.07590534, 0.07589095,
             0.07576945, 0.07577606, 0.07599818, 0.07814347, 0.07859489,
             0.07850471, 0.07855242, 0.07899193], dtype=float32)

In [59]:
flax.linen.softmax(teacher_params['params']['scalarmix']['coeffs']).argmax()

DeviceArray(12, dtype=int32)

In [60]:
# Attention explainer restricted to layer_idx 11, then evaluated.
# NOTE: this rebinds best_layer_teacher_explainer / best_layer_teacher_explainer_params
# defined in an earlier cell.
# NOTE(review): the scalar-mix argmax above is 12 with the embedding layer
# counted first, so 11 is presumably the matching encoder layer -- confirm the
# indexing convention of create_explainer.
best_layer_teacher_explainer_params={
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': 11,
    'head_idx': None,
}
best_layer_teacher_explainer, best_layer_teacher_explainer_params = create_explainer(
    key=next(keyseq), 
    inputs=dummy_inputs, 
    state=dummy_state, 
    explainer_type='attention_explainer', 
    explainer_args=best_layer_teacher_explainer_params
)
eval_plausibility(
    test_data,
    teacher, 
    teacher_params, 
    best_layer_teacher_explainer, 
    best_layer_teacher_explainer_params,
)

Acc: 0.8049024512256128
AUC score: 0.6067
AP score: 0.1599
Recall at top-K: 0.1299


In [62]:
teacher.num_encoder_layers, teacher.num_heads

(12, 4)

In [63]:
# Sweep over layer indices 0..11: create an attention explainer restricted to
# that layer (head_idx None) and evaluate it.
# NOTE: every iteration rebinds best_layer_teacher_explainer(_params), so after
# the loop these names refer to the last layer evaluated.
for layer_id in range(12):
    print('-------------')
    print('layer:', layer_id)
    best_layer_teacher_explainer_params={
        'normalize_head_coeffs': 'sparsemax',
        'normalizer_fn': 'softmax',
        'aggregator_idx': 'mean',
        'aggregator_dim': 'row',
        'init_fn': 'uniform',
        'layer_idx': layer_id,
        'head_idx': None,
    }
    best_layer_teacher_explainer, best_layer_teacher_explainer_params = create_explainer(
        key=next(keyseq), 
        inputs=dummy_inputs, 
        state=dummy_state, 
        explainer_type='attention_explainer', 
        explainer_args=best_layer_teacher_explainer_params
    )
    eval_plausibility(
        test_data,
        teacher, 
        teacher_params, 
        best_layer_teacher_explainer, 
        best_layer_teacher_explainer_params,
    )

-------------
layer: 0
Acc: 0.8049024512256128
AUC score: 0.5189
AP score: 0.1374
Recall at top-K: 0.1215
-------------
layer: 1
Acc: 0.8049024512256128
AUC score: 0.4608
AP score: 0.1036
Recall at top-K: 0.0725
-------------
layer: 2
Acc: 0.8049024512256128
AUC score: 0.5433
AP score: 0.1431
Recall at top-K: 0.1316
-------------
layer: 3
Acc: 0.8049024512256128
AUC score: 0.4959
AP score: 0.1183
Recall at top-K: 0.0942
-------------
layer: 4
Acc: 0.8049024512256128
AUC score: 0.4962
AP score: 0.1222
Recall at top-K: 0.0954
-------------
layer: 5
Acc: 0.8049024512256128
AUC score: 0.5049
AP score: 0.1291
Recall at top-K: 0.1097
-------------
layer: 6
Acc: 0.8049024512256128
AUC score: 0.6512
AP score: 0.2028
Recall at top-K: 0.1884
-------------
layer: 7
Acc: 0.8049024512256128
AUC score: 0.7463
AP score: 0.3267
Recall at top-K: 0.2943
-------------
layer: 8
Acc: 0.8049024512256128
AUC score: 0.7510
AP score: 0.3144
Recall at top-K: 0.2850
-------------
layer: 9
Acc: 0.8049024512256128

In [72]:
jax.random.bernoulli(next(keyseq), p=0.5, shape=(10,)).astype(int).tolist()

[0, 1, 0, 1, 1, 0, 0, 0, 1, 1]

In [74]:
# Trivial baselines for the word-level metrics.
t_1 = sum(map(sum, gold_expls))  # sum of all gold mask values
t_n = sum(map(len, gold_expls))  # total number of gold mask entries

# fraction of gold entries that are set
print(t_1/t_n)

# baseline 1: a constant score of 1 for every position
fake_pred_expls = [
    [1]*len(ge) for ge in gold_expls 
]
_ = evaluate_word_level(gold_expls, fake_pred_expls)

# baseline 2: random 0/1 scores drawn with probability equal to the gold fraction
fake_pred_expls2 = [
    jax.random.bernoulli(next(keyseq), p=t_1/t_n, shape=(len(ge),)).astype(int).tolist() for ge in gold_expls 
]
_ = evaluate_word_level(gold_expls, fake_pred_expls2)

0.08240203093326118
AUC score: 0.5000
AP score: 0.1024
Recall at top-K: 0.1059
AUC score: 0.4989
AP score: 0.1056
Recall at top-K: 0.1055


In [75]:
# Sweep over every (layer_id, head_id) pair in 12 x 4: create an attention
# explainer restricted to that single head and evaluate it.
# NOTE: the best_layer_* names are reused here although each explainer covers a
# single head; they are rebound on every iteration.
for layer_id in range(12):
    for head_id in range(4):
        print('---------------------')
        print('layer {} | head {}'.format(layer_id, head_id))
        best_layer_teacher_explainer_params={
            'normalize_head_coeffs': 'sparsemax',
            'normalizer_fn': 'softmax',
            'aggregator_idx': 'mean',
            'aggregator_dim': 'row',
            'init_fn': 'uniform',
            'layer_idx': layer_id,
            'head_idx': head_id,
        }
        best_layer_teacher_explainer, best_layer_teacher_explainer_params = create_explainer(
            key=next(keyseq), 
            inputs=dummy_inputs, 
            state=dummy_state, 
            explainer_type='attention_explainer', 
            explainer_args=best_layer_teacher_explainer_params
        )
        eval_plausibility(
            test_data,
            teacher, 
            teacher_params, 
            best_layer_teacher_explainer, 
            best_layer_teacher_explainer_params,
        )

---------------------
layer 0 | head 0
Acc: 0.8049024512256128
AUC score: 0.5453
AP score: 0.1451
Recall at top-K: 0.1274
---------------------
layer 0 | head 1
Acc: 0.8049024512256128
AUC score: 0.5217
AP score: 0.1387
Recall at top-K: 0.1198
---------------------
layer 0 | head 2
Acc: 0.8049024512256128
AUC score: 0.4895
AP score: 0.1204
Recall at top-K: 0.0965
---------------------
layer 0 | head 3
Acc: 0.8049024512256128
AUC score: 0.5033
AP score: 0.1382
Recall at top-K: 0.1227
---------------------
layer 1 | head 0
Acc: 0.8049024512256128
AUC score: 0.5290
AP score: 0.1260
Recall at top-K: 0.1046
---------------------
layer 1 | head 1
Acc: 0.8049024512256128
AUC score: 0.4627
AP score: 0.1108
Recall at top-K: 0.0889
---------------------
layer 1 | head 2
Acc: 0.8049024512256128
AUC score: 0.4661
AP score: 0.1039
Recall at top-K: 0.0716
---------------------
layer 1 | head 3
Acc: 0.8049024512256128
AUC score: 0.4578
AP score: 0.1068
Recall at top-K: 0.0774
---------------------
la