In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

sys.path.append("../")

In [3]:
from IPython.core.display import HTML

In [4]:
import numpy as np

from xbert_tasks.predictor_utils import load_predictor

from xbert_tasks.classification.models.text_classifier import TextClassifier
from xbert_tasks.classification.predictors.text_classifier_predictor import TextClassifierPredictor
from xbert_tasks.classification.dataset_readers.sst2_dataset_reader import Sst2DatasetReader

In [5]:
from xbert.engine import Engine, weight_of_evidence, difference_of_log_probabilities, calculate_correlation
from xbert import InputInstance, Config
from xbert.visualization import visualize_relevances
from xbert.occlusion.explainer_allennlp import AllenNLPIntegrateGradExplainer

In [6]:
# import torch
# from torch import Tensor
# from torch.utils.hooks import RemovableHandle
# from allennlp.data.dataset import Batch
# from allennlp.nn import util


# class AllenNLPVanillaGradExplainer:
#     def __init__(self, predictor, output_getter=None):
#         self.predictor = predictor
#         self.output_getter = output_getter

#     def _backprop(self, instances, ind, register_forward_hooks: bool = True):
#         embedding_gradients: List[Tensor] = []
#         grad_hooks: List[RemovableHandle] = self._register_embedding_gradient_hooks(embedding_gradients)
            
#         embedding_values: List[Tensor] = []
            
#         if register_forward_hooks:
#             val_hooks: List[RemovableHandle] = self._register_embedding_value_hooks(embedding_values)
        
#         model = self.predictor._model
        
#         cuda_device = model._get_prediction_device()
    
#         dataset = Batch(instances)
#         dataset.index_instances(model.vocab)
        
#         model_input = util.move_to_device(dataset.as_tensor_dict(), cuda_device)
        
#         output = model.decode(
#             model.forward(**model_input)  # type: ignore
#         )

#         if self.output_getter is not None:
#             output = self.output_getter(output)
        
#         grad_out = output.data.clone()
#         grad_out.fill_(0.0)
#         grad_out.scatter_(1, ind.unsqueeze(0).t(), 1.0)
        
#         model.zero_grad()
        
#         output.backward(grad_out)
        
#         for hook in grad_hooks:
#             hook.remove()
            
#         if register_forward_hooks:
#             for hook in val_hooks:
#                 hook.remove()
        
#         return embedding_values, embedding_gradients

#     def explain(self, instances, ind):
#         return self._backprop(instances, ind)[1]
    
#     def _register_embedding_gradient_hooks(self, embedding_gradients):

#         def backward_hook_layers(module, grad_in, grad_out):
#             embedding_gradients.append(grad_out[0])

#         backward_hooks = []
#         embedding_layer = util.find_embedding_layer(self.predictor._model)
#         embedding_layer.weight.requires_grad = True
#         backward_hooks.append(embedding_layer.register_backward_hook(backward_hook_layers))
#         return backward_hooks
    
#     def _register_embedding_value_hooks(self, embedding_values):

#         def forward_hook_layers(module, input, output):
#             embedding_values.append(output)

#         forward_hooks = []
#         embedding_layer = util.find_embedding_layer(self.predictor._model)
#         embedding_layer.weight.requires_grad = True
#         forward_hooks.append(embedding_layer.register_forward_hook(forward_hook_layers))
#         return forward_hooks
    
    
# class AllenNLPGradxInputExplainer(AllenNLPVanillaGradExplainer):
#     def __init__(self, predictor, output_getter=None):
#         super().__init__(predictor=predictor,
#                          output_getter=output_getter)

#     def explain(self, instances, ind):
#         inputs, grads = self._backprop(instances, ind)
#         return [input * grad for input, grad in zip(inputs, grads)]
    
    
# class AllenNLPSaliencyExplainer(AllenNLPVanillaGradExplainer):
#     def __init__(self, predictor, output_getter=None):
#         super().__init__(predictor=predictor,
#                          output_getter=output_getter)

#     def explain(self, instances, ind):
#         _, grads = self._backprop(instances, ind)
#         return [grad.abs() for grad in grads]


# class AllenNLPIntegrateGradExplainer(AllenNLPVanillaGradExplainer):
#     def __init__(self, predictor, steps=100, output_getter=None):
#         super().__init__(predictor=predictor,
#                          output_getter=output_getter)
#         self.steps = steps

#     def explain(self, instances, ind):
#         grads = [0 for _ in instances]
#         inputs = []

#         for alpha in np.linspace(0, 1.0, num=self.steps, endpoint=False):
#             embedding_values: List[Tensor] = []
#             handle: RemovableHandle = self._register_embedding_value_hook(alpha, inputs)
            
#             _, grads_current = self._backprop(instances, ind, register_forward_hooks=False)
            
#             handle.remove()
            
#             grads = [grad + grad_c for grad, grad_c in zip(grads, grads_current)]

#         return [input * grad / self.steps for input, grad in zip(inputs, grads)]
    
#     def _register_embedding_value_hook(self, alpha: float, embedding_values):

#         def forward_hook(module, inputs, output):
#             # Save the input for later use. Only do so on first call.
#             if alpha == 0:
#                 embedding_values.append(output.squeeze(0).clone().detach())

#             # Scale the embedding by alpha
#             output.mul_(alpha)

#         embedding_layer = util.find_embedding_layer(self.predictor._model)
#         embedding_layer.weight.requires_grad = True
#         return embedding_layer.register_forward_hook(forward_hook)

In [7]:
# Trained SST-2 classifier artefacts and the predictor registered for it.
MODEL_DIR = "../models/sst2/"
PREDICTOR_NAME = "sst_text_classifier"

# GPU index to run on; use -1 when no GPU is available.
CUDA_DEVICE = 0

In [8]:
from allennlp.predictors import Predictor
from allennlp.models.archival import load_archive

# Only the predictor is used by later cells. It is handed the archive filename directly,
# so the archive is not also loaded via `load_archive` here — the previous separate
# `archive = load_archive(...)` deserialized the model a second time and was never used.
predictor = load_predictor(MODEL_DIR, PREDICTOR_NAME, CUDA_DEVICE, archive_filename="model.tar.gz", weights_file=None)

In [9]:
# Directory holding the GLUE SST-2 split files; override with the SST_DATASET_PATH
# environment variable instead of editing a machine-specific path.
SST_DATASET_PATH = os.environ.get("SST_DATASET_PATH", "~/Downloads/SST-2/")

# Expand "~" explicitly and join with os.path.join so the reader gets a concrete path
# regardless of its own path handling or a missing trailing separator.
dataset_instances = predictor._dataset_reader.read(
    os.path.join(os.path.expanduser(SST_DATASET_PATH), "dev.tsv")
)

872it [00:00, 47441.22it/s]


In [10]:
# Sanity check: number of SST-2 dev instances read (the recorded output is 872).
len(dataset_instances)

872

In [19]:
from collections import defaultdict

import torch
import torch.nn.functional as F
# NOTE(review): presumably disabled so the gradient explainer used in `batcher_gradient`
# can backprop through the model in eval mode (cuDNN RNN kernels only support backward in
# training mode) — confirm the model actually contains an RNN. This flag is process-wide,
# so it also disables cuDNN for every other model call in the notebook.
torch.backends.cudnn.enabled = False


def batcher(batch_instances):
    """Return, for every instance, the model's probability of that instance's gold label.

    The gold label is looked up in ``dataset_instances[instance.id]``; the prediction is
    made on ``instance.text.tokens`` (presumably a perturbed candidate built by the Engine).

    Args:
        batch_instances: objects exposing ``id`` (an index into ``dataset_instances``)
            and ``text.tokens``.

    Returns:
        One gold-label probability per instance, in input order.
    """
    label_to_index = predictor._model.vocab.get_token_to_index_vocabulary("labels")

    # Build (gold label index, predictor JSON) pairs in a single pass, in input order.
    pairs = [
        (
            label_to_index[dataset_instances[candidate.id].fields["label"].label],
            dict(text=candidate.text.tokens),
        )
        for candidate in batch_instances
    ]
    gold_indices = [gold for gold, _ in pairs]
    json_batch = [json_input for _, json_input in pairs]

    predictions = predictor.predict_batch_json(json_batch)

    return [
        prediction["class_probabilities"][gold]
        for prediction, gold in zip(predictions, gold_indices)
    ]
    
    
def batcher_gradient(batch_instances):
    """Compute integrated-gradient token relevances for each input instance.

    Each instance is explained w.r.t. the softmax probability of its gold label, which is
    looked up in ``dataset_instances[instance.id]``. A token's relevance is the
    explanation summed over the last dimension.

    Args:
        batch_instances: objects exposing ``id`` (an index into ``dataset_instances``)
            and ``text.tokens``.

    Returns:
        One ``defaultdict(float)`` per instance mapping ``("text", token_position)`` to
        that token's relevance.
    """
    label2idx = predictor._model.vocab.get_token_to_index_vocabulary("labels")

    # CUDA_DEVICE uses -1 to mean "no GPU"; torch rejects negative device indices, so
    # translate it into an explicit torch.device instead of passing the int through.
    device = torch.device("cpu") if CUDA_DEVICE < 0 else torch.device("cuda", CUDA_DEVICE)

    explainer = AllenNLPIntegrateGradExplainer(predictor=predictor,
                                               output_getter=lambda x: F.softmax(x["logits"], dim=-1))

    relevances = []
    for instance in batch_instances:
        true_label_idx = label2idx[dataset_instances[instance.id].fields["label"].label]

        inst = predictor._json_to_instance(dict(text=instance.text.tokens))
        target = torch.tensor([true_label_idx], dtype=torch.long, device=device)
        expl = explainer.explain([inst], ind=target)[0]

        # Sum over the last dim (presumably the embedding dim) to get one score per token.
        # reshape(-1) rather than squeeze(): squeeze() collapses a one-token sentence to a
        # 0-d tensor whose tolist() is a bare float, which enumerate() cannot iterate.
        token_relevances = expl.sum(dim=-1).reshape(-1).detach().cpu().tolist()

        relevance_dict = defaultdict(float)
        for token_idx, relevance in enumerate(token_relevances):
            relevance_dict[("text", token_idx)] = relevance
        relevances.append(relevance_dict)

    return relevances
    

# Perturbation strategies scored with `batcher` (probability of the gold label).
config_unk = Config.from_dict({
    "strategy": "unk_replacement",
    "batch_size": 128,
    "unk_token": "<unk>"
})

config_del = Config.from_dict({
    "strategy": "delete",
    "batch_size": 128,
})

# Gradient strategy scored with `batcher_gradient`, which explains instances one at a
# time; batch_size presumably only controls how many instances the Engine hands over per call.
config_gradient = Config.from_dict({
    "strategy": "gradient",
    "batch_size": 2
})

config_resample = Config.from_dict({
    "strategy": "bert_lm_sampling",
    "std": True,
    # Follow the notebook-wide device setting instead of hard-coding GPU 0, so setting
    # CUDA_DEVICE = -1 also keeps the BERT sampler off the GPU.
    # NOTE(review): assumes the sampler accepts -1 for CPU, as CUDA_DEVICE's comment implies — verify.
    "cuda_device": CUDA_DEVICE,
    "bert_model": "bert-base-uncased",
    "batch_size": 256,
    "n_samples": 100,
    "verbose": False
})

unknown_engine = Engine(config_unk, batcher)
delete_engine = Engine(config_del, batcher)
resample_engine = Engine(config_resample, batcher)
gradient_engine = Engine(config_gradient, batcher_gradient)

In [20]:
# Explain a window of `n` consecutive dev-set sentences starting at `instance_idx`.
# Each InputInstance keeps its absolute dataset index as `id_` (presumably surfaced as
# `instance.id`, which the batchers use to look up the gold label in `dataset_instances`).
instance_idx = 0
n = 5
input_instances = [
    InputInstance(id_=dataset_idx, text=[token.text for token in dataset_instance.fields["tokens"].tokens])
    for dataset_idx, dataset_instance in enumerate(dataset_instances[instance_idx: instance_idx + n], start=instance_idx)
]

In [21]:
# `input_instances` already is the window dataset_instances[instance_idx: instance_idx+n],
# indexed from 0, so it is passed as-is. Slicing it again by `instance_idx` only worked
# because instance_idx == 0; any other start index would silently drop instances.
unk_candidate_instances, unk_candidate_results = unknown_engine.run(input_instances)
del_candidate_instances, del_candidate_results = delete_engine.run(input_instances)
res_candidate_instances, res_candidate_results = resample_engine.run(input_instances)
grad_candidate_instances, grad_candidate_results = gradient_engine.run(input_instances)

100%|██████████| 5/5 [00:00<00:00, 28263.50it/s]
1it [00:00, 120.94it/s]
100%|██████████| 5/5 [00:00<00:00, 25236.49it/s]
1it [00:00, 141.87it/s]
100%|██████████| 5/5 [00:00<00:00,  8.66it/s]
5it [00:00, 64.86it/s]               
100%|██████████| 5/5 [00:00<00:00, 101311.69it/s]
3it [00:04,  1.57s/it]                       


In [22]:
# Turn each engine's candidates and batcher results into relevance scores, one result per
# method; these feed the visualization and correlation cells that follow.
unk_relevances = unknown_engine.relevances(unk_candidate_instances, unk_candidate_results)
del_relevances = delete_engine.relevances(del_candidate_instances, del_candidate_results)
res_relevances = resample_engine.relevances(res_candidate_instances, res_candidate_results)
grad_relevances = gradient_engine.relevances(grad_candidate_instances, grad_candidate_results)

In [23]:
# Gold labels and the model's predicted labels for the same window of dev-set instances,
# passed to `visualize_relevances` as per-sentence annotations.
labels_true = [instance.fields["label"].label for instance in dataset_instances[instance_idx: instance_idx+n]]
labels_pred = [predictor.predict_instance(instance)["label"] for instance in dataset_instances[instance_idx: instance_idx+n]]

In [24]:
# Resampling (BERT LM) relevances. `input_instances` already is the selected window
# (indexed from 0), so it is passed as-is rather than re-sliced by `instance_idx`.
HTML(visualize_relevances(input_instances, res_relevances, labels_true, labels_pred))

In [18]:
# NOTE(review): repeats the resampling (`res_relevances`) visualization shown elsewhere in
# this notebook (stale execution count) — likely a leftover cell; consider deleting it.
# `input_instances` already is the selected window, so it is passed as-is.
HTML(visualize_relevances(input_instances, res_relevances, labels_true, labels_pred))

In [18]:
# <unk>-replacement relevances. `input_instances` already is the selected window
# (indexed from 0), so it is passed as-is rather than re-sliced by `instance_idx`.
HTML(visualize_relevances(input_instances, unk_relevances, labels_true, labels_pred))

In [19]:
# Token-deletion relevances. `input_instances` already is the selected window
# (indexed from 0), so it is passed as-is rather than re-sliced by `instance_idx`.
HTML(visualize_relevances(input_instances, del_relevances, labels_true, labels_pred))

In [20]:
# NOTE(review): another repeat of the resampling (`res_relevances`) visualization (stale
# execution count) — likely a leftover cell; consider deleting it.
# `input_instances` already is the selected window, so it is passed as-is.
HTML(visualize_relevances(input_instances, res_relevances, labels_true, labels_pred))

In [21]:
# Integrated-gradient relevances. `input_instances` already is the selected window
# (indexed from 0), so it is passed as-is rather than re-sliced by `instance_idx`.
HTML(visualize_relevances(input_instances, grad_relevances, labels_true, labels_pred))

In [25]:
# Pairwise agreement (via `calculate_correlation`) between the four relevance methods.
# Each printed line names the pair it compares, so results can be read without matching
# bare numbers to call order; pairs are computed in the same order as before.
relevances_by_method = {
    "unk": unk_relevances,
    "del": del_relevances,
    "res": res_relevances,
    "grad": grad_relevances,
}
method_pairs = [
    ("unk", "res"),
    ("unk", "grad"),
    ("unk", "del"),
    ("del", "res"),
    ("del", "grad"),
    ("res", "grad"),
]
for method_a, method_b in method_pairs:
    correlation = calculate_correlation(relevances_by_method[method_a], relevances_by_method[method_b])
    print(f"{method_a:>4} vs {method_b:<4}: {correlation}")

0.3632763330793503
0.3827816202907428
0.5675043686385648
0.49687429693119034
0.5988491751629121
0.5616808325624799
