Set the CUDA device ID

In [None]:
# Number GPUs by PCI bus id (the order nvidia-smi uses) and expose only device 4 to this kernel.
# (Comments must stay on their own line: text after a %env assignment becomes part of the value.)
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=4

Install dependencies for computing metrics and plots:

In [None]:
# Uncomment to install. `%pip` targets the kernel's environment; the PyPI name of sklearn is `scikit-learn`.
# %pip install numpy scipy pandas seaborn matplotlib scikit-learn

## Basic imports

In [None]:
import jax
import jax.numpy as jnp
import flax
from entmax_jax import sparsemax
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
from IPython.display import display, HTML
from functools import partial
import json
from entmax_jax.activations import sparsemax, entmax15
from sklearn.metrics import roc_auc_score, average_precision_score

from meta_expl.explainers import load_explainer
from meta_expl.models import load_model
from meta_expl.data.imdb import load_data, dataloader
from meta_expl.data.movies_rationales import dataloader as movie_dataloader
from meta_expl.data.movies_rationales import load_data as movie_load_data

from evaluate_explanations import evaluate_word_level, evaluate_sentence_level, aggregate_pieces

In [None]:
# data utils
def unroll(list_of_lists):
    """Flatten exactly one level of nesting, e.g. [[a, b], [c]] -> [a, b, c]."""
    flat = []
    for inner in list_of_lists:
        flat.extend(inner)
    return flat

## Define arguments and load the data, tokenizer, and models

In [None]:
# arguments
# --- model / task ---
arch = "google/electra-small-discriminator"
task_type = "classification"
num_classes = 2

# --- tokenization and batching ---
max_len = 256
batch_size = 32

# --- run configuration ---
seed = 1  # NOTE(review): not used by the visible cells (keyseq is seeded with 11)
setup = "static_teacher"

# --- checkpoints: trained teacher model and its trained explainer ---
teacher_dir = "data/imdb-electra-models/teacher_dir"
explainer_dir = "data/imdb-electra-models/teacher_expl_dir"

In [None]:
# create dummy inputs for model instantiation
# Values should still be valid ids even if (presumably) only shapes matter here:
# a single segment means every token type is 0, and positions run 0..max_len-1 for each row.
# (These two entries were previously swapped: token types held 0..255 and positions were all 1.)
input_ids = jnp.ones((batch_size, max_len), jnp.int32)
dummy_inputs = {
    "input_ids": input_ids,
    "attention_mask": jnp.ones_like(input_ids),
    "token_type_ids": jnp.zeros_like(input_ids),
    "position_ids": jnp.broadcast_to(
        jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
    ),
}
dummy_inputs['input_ids'].shape

### load data

In [None]:
# load data
# Test split of the movie-rationales data. Later cells read the 'review', 'evidences'
# and 'label' keys of each sample.
test_data = movie_load_data(setup, "test")

In [None]:
# inspect one raw test sample
test_data[0]

### load tokenizer

In [None]:
from transformers import ElectraTokenizerFast

# tokenizer matching the teacher architecture
tokenizer = ElectraTokenizerFast.from_pretrained(arch)
vocab_size = len(tokenizer)
# ids of the special tokens: classifier, separator and padding
cls_id, sep_id, pad_id = (
    tokenizer.cls_token_id,
    tokenizer.sep_token_id,
    tokenizer.pad_token_id,
)

### load model and explainer

In [None]:
# Load the trained teacher model and its trained explainer from disk.
# `dummy_state` (returned by load_model) is the state every explainer below is built with.
teacher, teacher_params, dummy_state = load_model(teacher_dir, dummy_inputs, batch_size, max_len)
teacher_explainer, teacher_explainer_params = load_explainer(explainer_dir, dummy_inputs, state=dummy_state)

In [None]:
from meta_expl.utils import PRNGSequence
from meta_expl.explainers import create_explainer

# NOTE(review): seeded with 11, not with `seed` from the arguments cell -- confirm intent
keyseq = PRNGSequence(11)

# Untrained baseline: an attention explainer over all layers and heads (layer_idx/head_idx None).
# The config gets its own name so that `teacher_explainer_params_non_trained` only ever
# refers to the explainer parameters (it previously held this dict first).
non_trained_explainer_args = {
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': None,
    'head_idx': None,
}
teacher_explainer_non_trained, teacher_explainer_params_non_trained = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='attention_explainer',
    explainer_args=non_trained_explainer_args,
)

In [None]:
# Attention explainer restricted to a single head (layer 9, head 5).
# The config gets its own name so `best_head_teacher_explainer_params` only holds the params.
best_head_explainer_args = {
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': 9,  # or None to use all layers
    'head_idx': 5,   # or None to use all heads
}
best_head_teacher_explainer, best_head_teacher_explainer_params = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='attention_explainer',
    explainer_args=best_head_explainer_args,
)

In [None]:
# Attention explainer restricted to a single layer (layer 10, all of its heads).
# The config gets its own name so `best_layer_teacher_explainer_params` only holds the params.
best_layer_explainer_args = {
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': 10,   # or None to use all layers
    'head_idx': None,  # None: all heads of the selected layer
}
best_layer_teacher_explainer, best_layer_teacher_explainer_params = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='attention_explainer',
    explainer_args=best_layer_explainer_args,
)

In [None]:
# Gradient x input explainer. It receives the teacher's embedding-gradient function
# (built on the dummy inputs) through `model_extras`.
input_gradient_teacher_explainer, input_gradient_teacher_explainer_params = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='gradient_input_explainer',
    model_extras=dict(
        grad_fn=teacher.apply(teacher_params, dummy_inputs, method=teacher.embeddings_grad_fn),
    ),
)

In [None]:
# Integrated-gradients explainer. Like gradient x input, it is given the teacher's
# embedding-gradient function through `model_extras`.
int_gradient_teacher_explainer, int_gradient_teacher_explainer_params = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='integrated_gradients_explainer',
    model_extras=dict(
        grad_fn=teacher.apply(teacher_params, dummy_inputs, method=teacher.embeddings_grad_fn),
    ),
)

### look at the coefficients

In [None]:
# sparsemax-normalized head coefficients of the trained explainer, viewed as 12 layers x 4 heads
raw_head_coeffs = teacher_explainer_params['params']['head_coeffs']
hc = sparsemax(raw_head_coeffs).reshape((12, 4))
hc

In [None]:
# 1-based (layer, head) pairs whose coefficient is non-zero
rows, cols = hc.nonzero()
for layer_no, head_no in zip(rows, cols):
    print(layer_no + 1, head_no + 1)

In [None]:
# check the layers with the highest coefficients:
# mean head coefficient per layer as 1-based (layer, coeff) pairs, in ascending order (highest last)
layer_coeffs = hc.mean(-1).tolist()
sorted(enumerate(layer_coeffs, start=1), key=lambda pair: pair[1])

In [None]:
# Heatmap of the head coefficients: rows are layers, columns are heads.
coeffs = np.asarray(hc)
n_layers, n_heads = coeffs.shape
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(coeffs, cmap='Greens')
# One tick per column/row. Using 12 x-ticks for only 4 head columns would stretch the x axis.
ax.set_xticks(list(range(n_heads)))
ax.set_yticks(list(range(n_layers)))
ax.set_xlabel('Head')
ax.set_ylabel('Layer')
ax.set_title('Head coefficients');

## Get explanations

In [None]:
def get_expls(data, t, t_p, t_e, t_e_p, s=None, s_p=None, s_e=None, s_e_p=None, is_grad_based=False):
    """Run a model and its explainer over `data` and collect per-example explanations.

    Args:
        data: dataset consumed by `movie_dataloader`, iterated in order without shuffling.
        t, t_p: teacher model and its parameters.
        t_e, t_e_p: teacher explainer and its parameters.
        s, s_p, s_e, s_e_p: optional student model / params / explainer / explainer params.
            When `s` is given the student is also run, but its predictions and explanations
            are NOT returned (TODO: return them or drop this branch).
        is_grad_based: if True, pass the teacher's `embeddings_grad_fn` to the explainer as
            `grad_fn`, as the gradient-based explainers require.

    Returns:
        (all_tokens, all_masks, all_explanations, all_outputs): four per-example lists,
        each truncated to the example's attention-mask length. Masks are True for tokens
        that do not start with '##'. Outputs are argmax ids when task_type is "classification".
    """
    all_tokens = []
    all_masks = []
    all_explanations = []
    all_outputs = []
    num_batches = len(data) // batch_size
    batches = movie_dataloader(data, tokenizer, batch_size=batch_size, max_len=max_len, shuffle=False)
    for batch_idx, (x, _) in enumerate(batches):
        print('{} of {}'.format(batch_idx, num_batches), end='\r')
        y_teacher, teacher_attn = t.apply(t_p, **x, deterministic=True)
        y_teacher = jnp.argmax(y_teacher, axis=-1) if task_type == "classification" else y_teacher
        if is_grad_based:
            teacher_extras = {
                "grad_fn": t.apply(t_p, x, method=t.embeddings_grad_fn)
            }
            teacher_expl, _ = t_e.apply(t_e_p, x, teacher_attn, **teacher_extras)
        else:
            teacher_expl, _ = t_e.apply(t_e_p, x, teacher_attn)

        if s is not None:
            y_student, student_attn = s.apply(s_p, **x)
            y_student = jnp.argmax(y_student, axis=-1) if task_type == "classification" else y_student
            # Use the explainer passed as `s_e`. The former reference to an undefined
            # global `student_explainer` raised NameError.
            student_expl, _ = s_e.apply(s_e_p, x, student_attn)

        # convert everything to Python lists
        batch_tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in x['input_ids'].tolist()]
        batch_expls = teacher_expl.tolist()

        # Strip padding. This assumes right padding, so the first `n` positions are real tokens.
        batch_valid_len = x['attention_mask'].sum(-1).tolist()
        for tokens, expl, n in zip(batch_tokens, batch_expls, batch_valid_len):
            tokens = tokens[:n]
            all_tokens.append(tokens)
            all_masks.append([not tk.startswith('##') for tk in tokens])
            all_explanations.append(expl[:n])
        all_outputs.extend(y_teacher.tolist())

    return all_tokens, all_masks, all_explanations, all_outputs

In [None]:
# Global counters: how many evidence spans were / were not found in their review.
total_ok = 0
total_er = 0

def find_index_sublist(v, u_pad, sep_token_id=102):
    """Locate a tokenized evidence span inside a tokenized review.

    Args:
        v: token ids of the review.
        u_pad: token ids of one evidence string as returned by the tokenizer. Only the ids
            strictly between the first element and the first `sep_token_id` are searched for.
        sep_token_id: id that terminates the evidence ids. The default 102 is presumably the
            tokenizer's `sep_token_id` (cf. `sep_id`); TODO confirm. Raises ValueError if absent.

    Returns:
        Half-open (start, end) indices of the first occurrence in `v`, or (None, None).
        As a side effect, increments the global `total_ok` or `total_er`.
    """
    global total_ok, total_er
    m = u_pad.index(sep_token_id)
    u = u_pad[1:m]
    n = len(u)
    # `+ 1` so that a span ending exactly at the last position of `v` is also found
    for i in range(len(v) - n + 1):
        if v[i:i + n] == u:
            total_ok += 1
            return i, i + n
    total_er += 1
    return None, None

def convert_evidences_to_mask(x, e):
    """Return a 0/1 list over the review ids `x`, set to 1 on every located evidence span in `e`."""
    spans = [find_index_sublist(x, evidence_ids) for evidence_ids in e]
    mask = [0] * len(x)
    for start, end in spans:
        if start is None or end is None:
            continue  # evidence not found in this review
        mask[start:end] = [1] * (end - start)
    return mask

# Build a token-level 0/1 gold mask per test review from its evidence strings.
gold_tok_kwargs = dict(
    padding="max_length",
    truncation=True,
    return_tensors="jax",
    max_length=max_len,
)
all_gold_explanations = []
for sample in test_data:
    # The tokenizer call below needs at least one evidence string. The placeholder is not
    # expected to occur in a review, so the gold mask then stays all zeros.
    if len(sample['evidences']) == 0:
        sample['evidences'] = ['justarandomwordhere']
    # NOTE: this mutates `test_data` in place; get_expls later sees the cleaned reviews
    sample['review'] = sample['review'].replace('\n', ' ')
    review_enc = tokenizer(sample['review'], **gold_tok_kwargs)
    evidence_enc = tokenizer(sample['evidences'], **gold_tok_kwargs)
    gold_mask = convert_evidences_to_mask(
        review_enc['input_ids'][0].tolist(), evidence_enc['input_ids'].tolist()
    )
    # keep only non-padding positions, the same truncation get_expls applies
    n_valid = review_enc['attention_mask'].sum(-1).tolist()[0]
    all_gold_explanations.append(gold_mask[:n_valid])

In [None]:
# evidence spans found / not found in their review, and the fraction found
total_ok, total_er, total_ok / (total_ok+total_er)

In [None]:
# explanations of the trained (meta-learned) explainer on the test set
valid_tokens, valid_masks, valid_explanations, valid_outputs = get_expls(
    test_data, teacher, teacher_params, teacher_explainer, teacher_explainer_params
)
# sanity check: each list has one entry per test example
[len(seq) for seq in (valid_tokens, valid_masks, valid_explanations, valid_outputs)]

### Aggregate word-piece scores into word-level scores

In [None]:
import torch
from utils import aggregate_pieces

def get_piece_explanations(all_tokens, all_fp_masks, all_explanations, reduction):
    """Aggregate word-piece scores into word-level scores, one list per example.

    Each example's scores and mask go to the project helper `aggregate_pieces` with
    `reduction` (e.g. 'first', 'sum', 'mean', 'max'). The mask presumably marks
    word-start pieces; TODO confirm against `aggregate_pieces`. The first and last
    aggregated entries (presumably the [CLS]/[SEP] positions) are dropped.
    `all_tokens` only takes part in the zip and is not otherwise read.
    """
    word_level = []
    for _tokens, expl, fp_mask in zip(all_tokens, all_explanations, all_fp_masks):
        aggregated = aggregate_pieces(torch.tensor(expl), torch.tensor(fp_mask), reduction)
        word_level.append(aggregated.tolist()[1:-1])
    return word_level

In [None]:
# how piece scores are merged into a word score: first, sum, mean, max
reduction = 'sum'
# True for tokens that start a word, i.e. that are not '##' continuation pieces
valid_masks = [
    [not token.startswith('##') for token in seq]
    for seq in valid_tokens
]
all_expls = get_piece_explanations(valid_tokens, valid_masks, valid_explanations, reduction=reduction)
# Gold masks are 0/1 per piece. 'max' presumably marks a word as gold if any of its pieces is.
all_gold_expls = get_piece_explanations(valid_tokens, valid_masks, all_gold_explanations, reduction='max')

## Evaluate explanations against word-level gold rationales

In [None]:
# total of the word-level gold scores over the test set
sum(sum(word_scores) for word_scores in all_gold_expls)

In [None]:
# gold labels/rationales vs. predictions of the trained explainer
gold_scores = [sample['label'] for sample in test_data]
gold_expls, pred_scores, pred_expls = all_gold_expls, valid_outputs, all_expls

In [None]:
# classification accuracy of the teacher on the test set
print('Acc:', np.mean(np.asarray(gold_scores) == np.asarray(pred_scores)))

In [None]:
# word-level plausibility of the trained explainer; the return value goes to `_` so the cell shows no result
_ = evaluate_word_level(gold_expls, pred_expls)

In [None]:
def eval_plausibility(data, t, t_p, t_e, t_e_p, is_grad_based=False, gold_expls=None, reduction='sum'):
    """Explain `data`, then print accuracy and word-level plausibility against gold rationales.

    Args:
        data: dataset passed to `get_expls`. Each sample must have a 'label' key.
        t, t_p, t_e, t_e_p: model, model params, explainer, explainer params.
        is_grad_based: forwarded to `get_expls` for gradient-based explainers.
        gold_expls: word-level gold rationales aligned with `data`. Defaults to the
            notebook-level `all_gold_expls`, which is built from `test_data`, so pass it
            explicitly when `data` is anything else.
        reduction: how word-piece scores are merged into word scores (default 'sum').

    Returns:
        The word-level predicted explanations.
    """
    valid_tokens, valid_masks, valid_explanations, valid_outputs = get_expls(
        data, t, t_p, t_e, t_e_p, is_grad_based=is_grad_based
    )
    pred_expls = get_piece_explanations(
        valid_tokens, valid_masks, valid_explanations, reduction=reduction
    )
    if gold_expls is None:
        gold_expls = all_gold_expls
    gold_scores = [inp['label'] for inp in data]
    pred_scores = valid_outputs
    print('Acc:', np.mean(np.array(gold_scores) == np.array(pred_scores)))
    evaluate_word_level(gold_expls, pred_expls)
    return pred_expls
    

### meta-learned explainer

In [None]:
# trained (meta-learned) explainer loaded from `explainer_dir`
expls_mtl = eval_plausibility(
    data=test_data,
    t=teacher,
    t_p=teacher_params,
    t_e=teacher_explainer,
    t_e_p=teacher_explainer_params,
)

### all attention layers and heads

In [None]:
# untrained attention explainer over all layers and heads
expls_all_attn = eval_plausibility(
    data=test_data,
    t=teacher,
    t_p=teacher_params,
    t_e=teacher_explainer_non_trained,
    t_e_p=teacher_explainer_params_non_trained,
)

### gradient x input

In [None]:
# gradient x input; the result is assigned so the full nested score list is not dumped as output
expls_grad_input = eval_plausibility(
    test_data,
    teacher,
    teacher_params,
    input_gradient_teacher_explainer,
    input_gradient_teacher_explainer_params,
    is_grad_based=True,
)

### integrated gradients

In [None]:
# integrated gradients; the result is assigned so the full nested score list is not dumped as output
expls_int_grad = eval_plausibility(
    test_data,
    teacher,
    teacher_params,
    int_gradient_teacher_explainer,
    int_gradient_teacher_explainer_params,
    is_grad_based=True,
)

### best attention layer

In [None]:
# best single layer (layer 10); the result is assigned so the full nested score list is not dumped
expls_best_layer = eval_plausibility(
    test_data,
    teacher,
    teacher_params,
    best_layer_teacher_explainer,
    best_layer_teacher_explainer_params,
)

### best attention head

In [None]:
# best single head (layer 9, head 5); the result is assigned so the full nested score list is not dumped
expls_best_head = eval_plausibility(
    test_data,
    teacher,
    teacher_params,
    best_head_teacher_explainer,
    best_head_teacher_explainer_params,
)

### last layer attention

In [None]:
# softmax-normalized scalar-mix coefficients of the teacher
flax.linen.softmax(teacher_params['params']['scalarmix']['coeffs'])  # first item is the embedding layer

In [None]:
# position of the largest scalar-mix coefficient (position 0 is the embedding layer)
flax.linen.softmax(teacher_params['params']['scalarmix']['coeffs']).argmax()

In [None]:
# Attention explainer restricted to the last encoder layer (index 11), all heads.
# Dedicated names keep the layer-10 `best_layer_teacher_explainer` from being overwritten.
last_layer_explainer_args = {
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': 11,
    'head_idx': None,
}
last_layer_teacher_explainer, last_layer_teacher_explainer_params = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='attention_explainer',
    explainer_args=last_layer_explainer_args,
)
expls_last_layer = eval_plausibility(
    test_data,
    teacher,
    teacher_params,
    last_layer_teacher_explainer,
    last_layer_teacher_explainer_params,
)

In [None]:
# number of encoder layers and attention heads (the loops below hard-code 12 and 4)
teacher.num_encoder_layers, teacher.num_heads

In [None]:
# Plausibility of each single encoder layer's attention (head_idx=None: all heads of the layer).
# Distinct names keep the `best_layer_teacher_explainer` defined earlier from being overwritten.
for layer_id in range(12):  # 12 encoder layers, presumably == teacher.num_encoder_layers
    print('-------------')
    print('layer:', layer_id)
    layer_explainer_args = {
        'normalize_head_coeffs': 'sparsemax',
        'normalizer_fn': 'softmax',
        'aggregator_idx': 'mean',
        'aggregator_dim': 'row',
        'init_fn': 'uniform',
        'layer_idx': layer_id,
        'head_idx': None,
    }
    layer_explainer, layer_explainer_params = create_explainer(
        key=next(keyseq),
        inputs=dummy_inputs,
        state=dummy_state,
        explainer_type='attention_explainer',
        explainer_args=layer_explainer_args,
    )
    eval_plausibility(
        test_data,
        teacher,
        teacher_params,
        layer_explainer,
        layer_explainer_params,
    )

In [None]:
# scratch check of the 0/1 output of jax.random.bernoulli (used by the random baseline below)
# NOTE(review): this consumes a key from `keyseq`, shifting the keys every later cell receives
jax.random.bernoulli(next(keyseq), p=0.5, shape=(10,)).astype(int).tolist()

In [None]:
# Trivial baselines that put the explainers' word-level scores in perspective.
t_1 = sum(sum(expl) for expl in gold_expls)  # sum of gold word scores (a count if they are 0/1)
t_n = sum(len(expl) for expl in gold_expls)  # total number of words

print(t_1/t_n)

# baseline 1: mark every word as part of the rationale
fake_pred_expls = [[1 for _ in expl] for expl in gold_expls]
_ = evaluate_word_level(gold_expls, fake_pred_expls)

# baseline 2: random 0/1 marks drawn at the gold positive rate (one fresh key per example)
positive_rate = t_1 / t_n
fake_pred_expls2 = [
    jax.random.bernoulli(next(keyseq), p=positive_rate, shape=(len(expl),)).astype(int).tolist()
    for expl in gold_expls
]
_ = evaluate_word_level(gold_expls, fake_pred_expls2)

In [None]:
# Plausibility of every single attention head.
# Distinct names keep the `best_layer_teacher_explainer` defined earlier from being overwritten.
for layer_id in range(12):  # 12 encoder layers, presumably == teacher.num_encoder_layers
    for head_id in range(4):  # 4 heads per layer, presumably == teacher.num_heads
        print('---------------------')
        print('layer {} | head {}'.format(layer_id, head_id))
        head_explainer_args = {
            'normalize_head_coeffs': 'sparsemax',
            'normalizer_fn': 'softmax',
            'aggregator_idx': 'mean',
            'aggregator_dim': 'row',
            'init_fn': 'uniform',
            'layer_idx': layer_id,
            'head_idx': head_id,
        }
        head_explainer, head_explainer_params = create_explainer(
            key=next(keyseq),
            inputs=dummy_inputs,
            state=dummy_state,
            explainer_type='attention_explainer',
            explainer_args=head_explainer_args,
        )
        eval_plausibility(
            test_data,
            teacher,
            teacher_params,
            head_explainer,
            head_explainer_params,
        )