In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Make the notebook's parent directory importable (presumably where the local
# packages imported in the following cells live). Guarded so re-running the
# cell does not append duplicate entries.
module_path = os.path.abspath(os.pardir)
if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)

In [3]:
from path_explain import utils

# Device(s) handed to the project's environment helper. Kept as a named
# constant instead of an inline literal so it is easy to change per machine.
# NOTE(review): presumably this restricts visible GPUs (e.g. via
# CUDA_VISIBLE_DEVICES) — confirm in path_explain.utils.set_up_environment.
VISIBLE_DEVICES = '1'
utils.set_up_environment(visible_devices=VISIBLE_DEVICES)

In [4]:
import tensorflow as tf
import tensorflow_datasets
import numpy as np
import pandas as pd
import altair as alt
import scipy
from bert_explainer import BertExplainerTF
from path_explain.path_explainer_tf import PathExplainerTF
from transformers import *
from plot.text import text_plot
import transformers
from tqdm import tqdm
from functools import reduce

## Data and Model Loading

In [5]:
# GLUE task identifier; its processor's label list determines the width of
# the classification head built below. The model is evaluated with MSE later,
# so this is presumably a single regression output — confirm.
task = 'sts-b'
num_labels = len(
    glue_processors[task]().get_labels()
)

In [6]:
# Config and weights are read from the same directory; naming it once keeps
# the two loads from silently pointing at different checkpoints.
MODEL_DIR = '.'
# Pretrained vocabulary used to tokenize the sentence pairs.
TOKENIZER_NAME = 'bert-base-cased'

config = BertConfig.from_pretrained(MODEL_DIR, num_labels=num_labels)
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_NAME)
model = TFBertForSequenceClassification.from_pretrained(MODEL_DIR, config=config)

In [7]:
# Load GLUE STS-B (reusing the local cache when present); `info` holds the
# dataset metadata.
data, info = tensorflow_datasets.load(name='glue/stsb', with_info=True)

INFO:absl:Overwrite dataset info from restored data version.
INFO:absl:Reusing dataset glue (/homes/gws/psturm/tensorflow_datasets/glue/stsb/0.0.2)
INFO:absl:Constructing tf.data.Dataset for split None, from /homes/gws/psturm/tensorflow_datasets/glue/stsb/0.0.2


In [8]:
# Sequence length passed to the feature converter, and evaluation batch size.
MAX_SEQ_LENGTH = 128
EVAL_BATCH_SIZE = 16

# NOTE(review): train_dataset is not used in the cells visible here.
train_dataset = glue_convert_examples_to_features(
    data['train'], tokenizer, max_length=MAX_SEQ_LENGTH, task=task)
valid_dataset = glue_convert_examples_to_features(
    data['validation'], tokenizer, max_length=MAX_SEQ_LENGTH, task=task
).batch(EVAL_BATCH_SIZE)

## Model Evaluation

In [9]:
# Model outputs for every validation example, in dataset order.
valid_pred = model.predict(x=valid_dataset)

In [10]:
# Split each (features, labels) batch: keep the feature dicts per batch and
# the labels as numpy arrays, then flatten the labels into one vector.
valid_input, valid_labels = [], []
for batch_features, batch_targets in valid_dataset:
    valid_input.append(batch_features)
    valid_labels.append(batch_targets.numpy())
valid_labels_np = np.concatenate(valid_labels, axis=0)

In [38]:
# Token ids of the whole validation set stacked along the batch axis.
concat_valid_ids = tf.concat(
    [features['input_ids'] for features in valid_input],
    axis=0,
)

In [11]:
# Mean squared error of the single model output against the gold scores, and
# the fraction of label variance it explains (1 - MSE / Var[y]).
valid_mse = np.square(valid_pred[:, 0] - valid_labels_np).mean()
print('Validation MSE: {:.4f} (variance explained: {:.4f})'.format(
    valid_mse,
    1.0 - valid_mse / valid_labels_np.var(),
))

Validation MSE: 0.5685 (variance explained: 0.7440)


In [12]:
# One row per validation example: model output next to the gold label.
label_df = pd.DataFrame(
    data={
        'Predicted Label': valid_pred[:, 0],
        'True Label': valid_labels_np,
    },
)

In [13]:
# Violin plot of predicted scores, one facet per true label: a KDE of the
# predictions (extent presumably padding the 0-5 score range) drawn as
# horizontal, centre-stacked areas.
# NOTE(review): labels are used as nominal facets; this is only readable if
# they take few distinct values — confirm the label dtype.
(
    alt.Chart(label_df)
    .transform_density(
        'Predicted Label',
        as_=['Predicted Label', 'density'],
        extent=[-0.5, 5.5],
        groupby=['True Label'],
    )
    .mark_area(orient='horizontal')
    .encode(
        y='Predicted Label:Q',
        color='True Label:N',
        x=alt.X(
            'density:Q',
            stack='center',
            impute=None,
            title=None,
            axis=alt.Axis(labels=False, values=[0], grid=False, ticks=True),
        ),
        column=alt.Column(
            'True Label:N',
            header=alt.Header(titleOrient='bottom', labelOrient='bottom', labelPadding=0),
        ),
    )
    .properties(width=100)
    .configure_facet(spacing=0)
    .configure_view(stroke=None)
)

In [42]:
# Explain only the first validation batch (attributions are expensive).
batch_input = valid_input[0]
batch_labels = valid_labels[0]
batch_ids = batch_input['input_ids']
# Predictions for the same examples, sliced by the real batch length instead of
# a hard-coded 16 so this stays in sync with the dataset's batch size.
batch_pred = valid_pred[:len(batch_labels)]
# Every validation example's ids, used as the baseline for the expectation-based
# attributions.
batch_baseline = concat_valid_ids

In [43]:
# Wrap the fine-tuned model in the BertExplainerTF explainer (imported from the
# local bert_explainer module); its `attributions` method is called next.
explainer = BertExplainerTF(model)

In [46]:
# Attributions take minutes to compute (about 6.5 min for 16 examples), so they
# are cached on disk. The cache is keyed only by file name, so recompute
# whenever its leading dimension no longer matches the batch being explained.
ATTRIBUTIONS_PATH = 'attributions.npy'

try:
    attributions = np.load(ATTRIBUTIONS_PATH)
except FileNotFoundError:
    attributions = None

if attributions is None or attributions.shape[0] != batch_ids.shape[0]:
    attributions = explainer.attributions(inputs=batch_ids,
                                          baseline=batch_baseline,
                                          batch_size=30,
                                          num_samples=1000,
                                          use_expectation=True,
                                          output_indices=0,
                                          verbose=True)
    np.save(ATTRIBUTIONS_PATH, attributions)

100%|██████████| 16/16 [06:38<00:00, 24.93s/it]


In [47]:
def check_completeness(index, batch=None, attrs=None, labels=None,
                       predict_fn=None, text_tokenizer=None):
    """Print a completeness check for the attributions of one batch example.

    Compares the model output at the example minus the output at an all-zero
    `input_ids` baseline (same masks) with the sum of the example's
    attributions. For a method satisfying completeness these should roughly
    agree. Also prints the decoded sentence pair and its labels.

    NOTE(review): the attributions in this notebook are computed with
    `use_expectation=True` over validation-set ids as the baseline, so the
    matching reference would be the expected output over that distribution,
    not the single zero-id input used here. This may explain the large gaps
    in the printed results — confirm before drawing conclusions.

    Args:
        index: Position of the example within the batch.
        batch: Dict with 'input_ids', 'attention_mask' and 'token_type_ids'
            (tensors or arrays, sliced along the first axis). Defaults to the
            global ``batch_input``.
        attrs: Attributions indexed by example along the first axis. Defaults
            to the global ``attributions``.
        labels: True similarity scores per example. Defaults to the global
            ``batch_labels``.
        predict_fn: Callable taking a feature dict and returning a sequence
            whose first element is a 2-D output read at ``[0, 0]``. Defaults to
            the global ``model``.
        text_tokenizer: Object with ``decode``, ``pad_token_id``, ``cls_token``
            and ``sep_token``. Defaults to the global ``tokenizer``.

    Returns:
        Tuple ``(output_difference, sum_of_attributions)`` as floats.
    """
    # Fall back to the notebook's globals so `check_completeness(i)` still works.
    if batch is None:
        batch = batch_input
    if attrs is None:
        attrs = attributions
    if labels is None:
        labels = batch_labels
    if predict_fn is None:
        predict_fn = model
    if text_tokenizer is None:
        text_tokenizer = tokenizer

    current_input = {
        'input_ids': batch['input_ids'][index:index + 1],
        'attention_mask': batch['attention_mask'][index:index + 1],
        'token_type_ids': batch['token_type_ids'][index:index + 1],
    }
    # Zero ids with the same shape and dtype as the input (instead of a
    # hard-coded (1, 128) int64 array), keeping the input's masks.
    current_baseline = dict(
        current_input,
        input_ids=np.zeros_like(np.asarray(current_input['input_ids'])),
    )

    current_value = float(np.asarray(predict_fn(current_input)[0])[0, 0])
    baseline_value = float(np.asarray(predict_fn(current_baseline)[0])[0, 0])
    output_difference = current_value - baseline_value
    sum_attr = float(np.sum(attrs[index]))

    # Drop padding using the tokenizer's own pad id rather than assuming 0.
    encoded_sentence = np.asarray(batch['input_ids'])[index]
    encoded_sentence = encoded_sentence[encoded_sentence != text_tokenizer.pad_token_id]
    decoded_sentence = text_tokenizer.decode(encoded_sentence)

    # Strip the CLS marker and surrounding whitespace instead of fixed slicing.
    sentences = decoded_sentence.split(text_tokenizer.sep_token)
    first_sentence = sentences[0].replace(text_tokenizer.cls_token, '', 1).strip()
    second_sentence = sentences[1].strip() if len(sentences) > 1 else ''
    label = labels[index]

    print('1) ' + first_sentence)
    print('2) ' + second_sentence)
    print('True similarity: {} - Predicted similarity: {:.4f}'.format(label, current_value))
    print('Output difference:\t{:.4f} ({:.4f} - {:.4f})'.format(output_difference,
                                                                current_value,
                                                                baseline_value))
    print('Sum of attributions:\t{:.4f}'.format(sum_attr))
    print('-------------------------')
    return output_difference, sum_attr

In [48]:
# Run the completeness check on every example of the explained batch; the
# bound comes from the batch itself rather than a hard-coded 16.
for i in range(len(batch_labels)):
    check_completeness(i)

1) Representatives for Puretunes could not immediately be reached for comment Wednesday. 
2) Puretunes representatives could not be located Thursday to comment on the suit. 
True similarity: 3 - Predicted similarity: 3.7143)
Output difference:	1.4649 (3.7143 - 2.249)
Sum of attributions:	-0.0377
-------------------------
1) North Korea Nuclear Test Sparks Worry 
2) North Korea nuclear test 
True similarity: 2 - Predicted similarity: 2.8583)
Output difference:	0.9098 (2.8583 - 1.949)
Sum of attributions:	-0.7320
-------------------------
1) A man in red swim trunks playing volleyball. 
2) A man and woman sitting on a motorcycle. 
True similarity: 0 - Predicted similarity: -0.0216)
Output difference:	-2.1308 (-0.0216 - 2.109)
Sum of attributions:	0.0436
-------------------------
1) A group of people are at a convention waving American flags. 
2) A crowd of people at an outdoor event 
True similarity: 2 - Predicted similarity: 1.6829)
Output difference:	-0.4440 (1.6829 - 2.127)
Sum of att