# Generate explanations

In this notebook, we generate explanations for all samples of an example dataset and cache them for visualization in SemLa.

This notebook depends on the outputs from `embed-dataset.ipynb`.

In [1]:
import torch
import math
from curses import raw
from lib2to3.pgen2 import token
from flask import Flask, send_from_directory, request
from numba import jit
import numpy as np
from scipy.special import softmax
from transformers import BertModel
from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions


# Truncation length passed as `max_length` to the tokenizer by the attribution
# helpers in this notebook; also reused as `num_features` for the LIME explainer.
MAX_LENGTH = 30

In [2]:
# Device used by TextProcessor for the model and its inputs, and by the
# "vanilla_grad" branch of BertForImportanceAttribution.forward.
DEVICE = "cpu"
# Dataset selector; `num_labels` is only defined for the "banking" value.
DATASET = "banking"
if DATASET == "banking":
    # NOTE(review): presumably the number of intent classes in Banking77 — confirm.
    num_labels = 77

@jit
def inner_product_distance(a, b, tau=15):
    """Return ``exp(-sum(a * b) / tau) ** 2`` for two arrays.

    ``tau`` is the temperature dividing the (negated) inner product before
    exponentiation.
    """
    dot_product = np.sum(a * b)
    scaled = -dot_product / tau
    return np.exp(scaled) ** 2


class BertForImportanceAttribution(BertModel):
    """``BertModel`` whose ``forward`` output depends on an attribution mode.

    The mode is selected with :meth:`setMode`:

    * ``None`` (default) -- delegate to ``BertModel.forward`` unchanged.
    * ``"integrad_from_similarity"`` -- encode ``input_ids`` only and return the
      first-position hidden state; when ``precomputed_encoding`` is given,
      return its inner product with that encoding, summed over the last axis.
    * ``"integrad"`` -- return the first-position hidden state summed over its
      last axis.
    * ``"vanilla_grad"`` -- run embeddings/encoder/pooler by hand and return
      ``(model_output, embedding_output)``; ``embedding_output`` retains its
      gradient so callers can read ``embedding_output.grad`` after a backward
      pass.
    """

    # Class-level default so forward() works even if setMode() was never called
    # (e.g. right after ``from_pretrained``).
    mode = None

    def setMode(self, mode):
        """Select the behaviour of :meth:`forward` (see the class docstring)."""
        self.mode = mode

    def forward(self,
                input_ids,
                precomputed_encoding=None,
                attention_mask=None,
                token_type_ids=None,
                output_hidden_states=False,
                output_attentions=False):
        """Run the model according to the current mode.

        Raises:
            ValueError: if the current mode is not one of the supported values.
        """
        if self.mode is None:
            return super().forward(input_ids,
                                   attention_mask=attention_mask,
                                   token_type_ids=token_type_ids,
                                   output_hidden_states=output_hidden_states,
                                   output_attentions=output_attentions)

        if self.mode == "integrad_from_similarity":
            # Only input_ids are forwarded; mask and token types are ignored here.
            encoding = super().forward(input_ids).last_hidden_state[:, 0]
            if precomputed_encoding is None:
                return encoding
            # The module-level `torch` is used on purpose: a function-local
            # `import torch` would make the name local to the whole method.
            similarity = torch.inner(encoding, precomputed_encoding)
            return similarity.sum(dim=-1)

        if self.mode == "integrad":
            encoding = super().forward(input_ids).last_hidden_state[:, 0]
            return encoding.sum(dim=-1)

        if self.mode == "vanilla_grad":
            # Derive the device from the inputs instead of the global DEVICE so
            # the mask always matches wherever the caller placed the batch.
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask, input_ids.size(), device=input_ids.device)
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=None,
                token_type_ids=token_type_ids,
                inputs_embeds=None,
                past_key_values_length=0,
            )

            # Non-leaf tensor: keep its gradient so it can be read after backward().
            embedding_output.retain_grad()

            encoder_outputs = self.encoder(
                embedding_output,
                attention_mask=extended_attention_mask,
                head_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                past_key_values=None,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=output_hidden_states
            )
            sequence_output = encoder_outputs[0]
            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

            return BaseModelOutputWithPoolingAndCrossAttentions(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
            ), embedding_output

        # Previously an unknown mode fell through and silently returned None.
        raise ValueError(f"Unsupported attribution mode: {self.mode!r}")


In [3]:
def get_tokens_with_matching_case(tokens, text):
    """Map each token back to the substring of ``text`` it came from.

    Tokens are matched case-insensitively, left to right, and the returned
    strings carry the casing found in ``text``. Characters between two matches
    (whitespace or anything else) are skipped.

    Args:
        tokens: token strings in the order they occur in ``text``.
        text: the original text the tokens were produced from.

    Returns:
        A list with one entry per token. A token that cannot be found in the
        remaining text (e.g. an unknown-token placeholder) is returned
        unchanged and does not consume any text.
    """
    lowered = text.lower()  # lowercase once instead of once per token
    matched = []
    cursor = 0  # offset in `text` up to which earlier tokens have consumed it
    for token in tokens:
        needle = token.lower()
        pos = lowered.find(needle, cursor)
        if pos == -1:
            matched.append(token)
            continue
        end = pos + len(needle)
        # Slice relative to the match position; the previous `text[pos:len(token)]`
        # was only right when the match happened to start at offset 0.
        matched.append(text[pos:end])
        cursor = end
    return matched

In [4]:
def attention_importance(tokenizer, model, text, device="cuda"):
    """Score each token of ``text`` by the attention it receives.

    The attention maps of all layers and heads are summed, then summed over
    the first (attending) axis of the resulting matrix, and finally divided by
    ``num_layers * num_heads``.

    Args:
        tokenizer: tokenizer matching ``model``.
        model: model that returns ``attentions`` when called with
            ``output_attentions=True``; it is moved to ``device``.
        text: a single input string (truncated to ``MAX_LENGTH``).
        device: device on which the forward pass runs.

    Returns:
        ``(importance, tokens)`` with the first and last positions (the
        special tokens) removed from both lists.
    """
    with torch.no_grad():
        tokenized_input = tokenizer(text,
                                    max_length=MAX_LENGTH,
                                    truncation=True,
                                    return_tensors="pt")
        tokenized_input.to(device)
        model.to(device)
        outputs = model(**tokenized_input, output_attentions=True)

        # Stacked per-layer attentions; assumes each entry is
        # (batch, heads, seq, seq) as documented for Hugging Face BERT.
        attentions = torch.stack(outputs.attentions)
        num_layers, num_heads = attentions.shape[0], attentions.shape[2]

        # Select the single example explicitly rather than squeeze(), which
        # would also drop any other size-1 axis.
        aggregated = attentions[:, 0].sum(0).sum(0).detach().cpu()
        # Normalise by the real layer/head count instead of a hard-coded 144.
        attentions_importance = aggregated.sum(0) / (num_layers * num_heads)

        tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"][0])
        tokens = tokens[1:-1]
        importance = attentions_importance.tolist()[1:-1]
    return importance, tokens


def lime_importance(tokenizer, model, text, support_set, device=None):
    """Explain the nearest-support-set prediction for ``text`` with LIME.

    The classifier handed to LIME encodes texts with ``model`` (first-position
    hidden state), takes the inner product with the encodings of every text in
    ``support_set``, and applies a softmax with temperature 15.

    Args:
        tokenizer: tokenizer matching ``model``.
        model: encoder called with ``input_ids``/``token_type_ids``/
            ``attention_mask``; it is moved to ``device``.
        text: the sentence to explain.
        support_set: mapping with ``"text"`` and ``"label"`` sequences.
        device: target device. When ``None``, ``"cuda"`` is chosen for texts
            with fewer than 10 space-separated words if CUDA is available,
            otherwise ``"cpu"``.

    Returns:
        ``(importance, tokens)``: the LIME weights for the predicted support
        instance ordered by feature position, and the case-preserving tokens
        of ``text``.
    """
    from lime.lime_text import LimeTextExplainer

    # Defined before the closures below that read them.
    batch_limit = 1
    tau = 15

    if device is None:
        # Fall back to the CPU instead of crashing on machines without CUDA.
        prefer_cuda = len(text.split(" ")) < 10
        device = "cuda" if prefer_cuda and torch.cuda.is_available() else "cpu"
    model.to(device)

    def one_sentence_tokenize(text):
        tokens = tokenizer.tokenize(text)
        # Strip only the WordPiece continuation prefix; removing every "#"
        # would turn a literal "#" token into an empty string.
        tokens = [token[2:] if token.startswith("##") else token for token in tokens]
        return get_tokens_with_matching_case(tokens, text)

    def encode(texts):
        with torch.no_grad():
            # NOTE: this helper truncates at 50 tokens, not MAX_LENGTH.
            tokenized_xs = tokenizer.batch_encode_plus(texts, max_length=50,
                                                       truncation=True, padding=True,
                                                       return_tensors="pt")
            outputs = []
            for start in range(0, len(texts), batch_limit):
                stop = start + batch_limit
                batch = {key: tokenized_xs[key][start:stop].to(device)
                         for key in ("input_ids", "token_type_ids", "attention_mask")}
                outputs.append(model(**batch).last_hidden_state[:, 0, :])
            return torch.cat(outputs)

    def classify(texts):
        similarities = torch.inner(encode(texts), support_encodings)
        probs = torch.softmax(similarities / tau, dim=-1)
        return probs.detach().cpu().numpy()

    support_encodings = encode(support_set["text"])
    label = classify([text])[0].argmax()

    explainer = LimeTextExplainer(
        class_names=support_set["label"],
        bow=False,
        split_expression=one_sentence_tokenize,
        mask_string=tokenizer.mask_token)
    exp = explainer.explain_instance(text,
                                     classify, top_labels=5, num_samples=100,
                                     num_features=MAX_LENGTH)

    tokens = one_sentence_tokenize(text)
    # as_map() yields (feature position, weight) pairs; order them by position.
    by_position = sorted(exp.as_map()[label], key=lambda pair: pair[0])
    importance = [weight for _, weight in by_position]
    return importance, tokens


def integrad_importance(tokenizer, model, text, txt2=None, device="cuda"):
    """Attribute the similarity between ``text`` and ``txt2`` to the tokens of ``text``.

    Uses Captum's ``LayerIntegratedGradients`` on ``model.embeddings`` while the
    model is in ``"integrad_from_similarity"`` mode; the encoding of ``txt2`` is
    computed once without gradients and passed as the second model input.

    Args:
        tokenizer: tokenizer matching ``model``.
        model: a ``BertForImportanceAttribution``-style model exposing
            ``setMode`` and ``embeddings``; it is moved to ``device``.
        text: sentence whose tokens are scored (truncated to ``MAX_LENGTH``).
        txt2: reference sentence; required despite the ``None`` default.
        device: device on which the computation runs.

    Returns:
        ``(importance, tokens)`` with the first and last positions (special
        tokens) removed; both lists describe the same truncated input.

    Raises:
        ValueError: if ``txt2`` is ``None``.
    """
    from captum.attr import LayerIntegratedGradients

    if txt2 is None:
        raise ValueError("integrad_importance requires a reference text `txt2`")

    model.to(device)
    model.setMode("integrad_from_similarity")
    try:
        with torch.no_grad():
            reference_inputs = tokenizer(txt2,
                                         max_length=MAX_LENGTH,
                                         truncation=True,
                                         return_tensors="pt")
            reference_inputs.to(device)
            encoding2 = model(reference_inputs["input_ids"])

        lig = LayerIntegratedGradients(model, model.embeddings)
        tokenized_inputs = tokenizer(text,
                                     max_length=MAX_LENGTH,
                                     truncation=True,
                                     return_tensors="pt")
        tokenized_inputs.to(device)
        input_ids = tokenized_inputs["input_ids"]

        attributions_ig, _delta = lig.attribute(
            (input_ids, encoding2),
            return_convergence_delta=True,
            attribute_to_layer_input=False
        )
    finally:
        # Always leave the shared model in plain mode, even if attribution fails.
        model.setMode(None)

    # Read the tokens from the (truncated) ids that were actually attributed so
    # they stay aligned with `importance` for long inputs.
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])[1:-1]
    importance = attributions_ig.sum(-1).squeeze().tolist()[1:-1]

    return importance, tokens


def gradient_importance(tokenizer, model, text, txt2=None, device="cuda"):
    """Score the tokens of ``text`` by the gradient of its similarity to ``txt2``.

    The first-position encoding of ``txt2`` is computed without gradients in
    plain mode. ``text`` is then run in ``"vanilla_grad"`` mode, the inner
    product of the two encodings is back-propagated, and the absolute gradient
    of the embedding output is summed over the hidden axis.

    Args:
        tokenizer: tokenizer matching ``model``.
        model: a ``BertForImportanceAttribution``-style model exposing
            ``setMode``; it is moved to ``device``.
        text: sentence whose tokens are scored (truncated to ``MAX_LENGTH``).
        txt2: reference sentence; required despite the ``None`` default.
        device: device on which the computation runs.

    Returns:
        ``(importance, tokens)`` with the first and last positions (special
        tokens) removed; both lists describe the same truncated input.

    Raises:
        ValueError: if ``txt2`` is ``None``.
    """
    if txt2 is None:
        raise ValueError("gradient_importance requires a reference text `txt2`")

    model.to(device)
    model.setMode(None)
    with torch.no_grad():
        reference_inputs = tokenizer(txt2,
                                     max_length=MAX_LENGTH,
                                     truncation=True,
                                     return_tensors="pt")
        reference_inputs.to(device)
        encoding2 = model(reference_inputs["input_ids"]).last_hidden_state[:, 0]

    model.setMode("vanilla_grad")
    try:
        tokenized_inputs = tokenizer(text,
                                     max_length=MAX_LENGTH,
                                     truncation=True,
                                     return_tensors="pt")
        tokenized_inputs.to(device)

        # In "vanilla_grad" mode the model returns (outputs, embedding_output).
        outputs, embeddings = model(**tokenized_inputs,
                                    output_hidden_states=True)
        encoding = outputs.last_hidden_state[:, 0, :]

        similarity = torch.inner(encoding, encoding2).sum(dim=-1)
        similarity.backward()
    finally:
        # Always leave the shared model in plain mode, even if backward fails.
        model.setMode(None)

    importance = embeddings.grad.abs().sum(-1).squeeze().tolist()[1:-1]
    # Tokens come from the truncated ids so they stay aligned with `importance`.
    tokens = tokenizer.convert_ids_to_tokens(tokenized_inputs["input_ids"][0])[1:-1]
    return importance, tokens
    

def token_encoding_relation(tokenizer, model, txt1, txt2, device="cuda"):
    """Compute pairwise inner products between the token encodings of two texts.

    Args:
        tokenizer: tokenizer matching ``model``.
        model: encoder returning ``last_hidden_state``; it is moved to ``device``.
        txt1: first sentence (rows of ``links``), truncated to ``MAX_LENGTH``.
        txt2: second sentence (columns of ``links``), truncated to ``MAX_LENGTH``.
        device: device on which the forward passes run.

    Returns:
        A dict with ``"links"`` (nested list of inner products between every
        token of ``txt1`` and every token of ``txt2``), ``"tokens1"`` and
        ``"tokens2"``. The first and last positions (special tokens) are
        excluded everywhere, and the token lists match the matrix dimensions.
    """
    def encode(text):
        tokenized_inputs = tokenizer(text,
                                     max_length=MAX_LENGTH,
                                     truncation=True,
                                     return_tensors="pt")
        tokenized_inputs.to(device)
        outputs = model(**tokenized_inputs,
                        output_hidden_states=True)
        # Tokens are read from the truncated ids so they match the encodings.
        tokens = tokenizer.convert_ids_to_tokens(tokenized_inputs["input_ids"][0])[1:-1]
        return outputs.last_hidden_state[0, 1:-1, :], tokens

    model.to(device)
    # No gradients are needed anywhere here; keep the forward passes inside
    # no_grad so no autograd graph is built for them.
    with torch.no_grad():
        encodings1, tokens1 = encode(txt1)
        encodings2, tokens2 = encode(txt2)
        token_similarities = torch.inner(encodings1, encodings2)

    return {"links": token_similarities.tolist(),
            "tokens1": tokens1,
            "tokens2": tokens2}


def integrad_relation(tokenizer, model, txt1, txt2, device="cuda"):
    """Attribute the similarity between ``txt1`` and ``txt2`` to the tokens of ``txt1``.

    This is the same computation as ``integrad_importance`` and delegates to
    it; only the order of the returned pair differs.

    Args:
        tokenizer: tokenizer matching ``model``.
        model: model accepted by ``integrad_importance``.
        txt1: sentence whose tokens are scored.
        txt2: reference sentence.
        device: device on which the computation runs.

    Returns:
        ``(tokens, importance)`` for ``txt1``.
    """
    importance, tokens = integrad_importance(
        tokenizer, model, txt1, txt2=txt2, device=device)
    return tokens, importance

In [5]:
from datasets import load_dataset


class TextProcessor:
    """Bundle the dataset, cached prediction data, tokenizer and model.

    Provides per-sample token importances (attention, LIME, integrated
    gradients, plain gradients) and pairwise relations between two test
    samples.
    """

    # Methods accepted by importance().
    _IMPORTANCE_METHODS = ("attention", "lime", "integrad", "gradient")

    def __init__(self, dataset, num_labels):
        """Load the banking77 dataset, its prediction data and the model.

        Args:
            dataset: currently ignored; "banking77" is always loaded.
            num_labels: currently ignored.
        """
        import json
        from transformers import AutoTokenizer

        self.tau = 15
        self.datasets = dict()
        self.prediction_data = dict()

        for dataset_name in ["banking77"]:
            self.datasets[dataset_name] = load_dataset(dataset_name)

            prediction_data_file = "./results/banking77-with-BERT-Banking77/banking77-viz_data-12-clusters-label_cluster_chosen_by_majority_in-predicted-label-with-BERT-Banking77.json"
            with open(prediction_data_file, "r") as f:
                self.prediction_data[dataset_name] = json.load(f)

        # Load model
        checkpoint = "philschmid/BERT-Banking77"

        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = BertForImportanceAttribution.from_pretrained(checkpoint)
        self.model.setMode(None)
        self.model.to(DEVICE)

    def process(self, text):
        """Return ``{"encoding": [...]}`` with the encoding of ``text`` as a list."""
        encoding = self.encode(text)
        return {"encoding": encoding.tolist()}

    def encode(self, text):
        """Return the first-position hidden state of ``text`` as a NumPy array."""
        tokenized = self.tokenizer(text, return_tensors="pt")
        # The importance helpers may have moved the model to another device,
        # so follow the model rather than the global DEVICE.
        tokenized.to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**tokenized)
        return outputs.last_hidden_state.squeeze()[0].detach().cpu().numpy()

    def importance(self, dataset, index, method):
        """Compute token importances for test sample ``index``.

        Args:
            dataset: key into the loaded datasets (e.g. ``"banking77"``).
            index: position of the sample in the test split.
            method: one of ``"attention"``, ``"lime"``, ``"integrad"``,
                ``"gradient"``.

        Returns:
            ``(importance, tokens)`` as returned by the selected helper.

        Raises:
            ValueError: if ``method`` is not supported (previously ``None`` was
                returned, which failed later when unpacked).
        """
        if method not in self._IMPORTANCE_METHODS:
            raise ValueError(
                f"Unknown importance method {method!r}; "
                f"expected one of {self._IMPORTANCE_METHODS}")

        test_split = self.datasets[dataset]["test"]
        # Fetch the single row instead of materialising the whole text column.
        text = test_split[index]["text"]

        if method == "attention":
            return attention_importance(self.tokenizer, self.model, text)

        pred_data = self.prediction_data[dataset][index]
        support_set = test_split[pred_data["support_set"]]

        if method == "lime":
            return lime_importance(self.tokenizer, self.model, text, support_set)

        # The support instance with the largest stored value is used as the
        # reference text for the gradient-based methods.
        distances = np.array(pred_data["distances"]).squeeze()
        closest_text = support_set["text"][distances.argmax()]

        if method == "integrad":
            return integrad_importance(self.tokenizer, self.model, text, txt2=closest_text)
        return gradient_importance(self.tokenizer, self.model, text, txt2=closest_text)

    def importances_all(self, dataset, index):
        """Run every importance method on test sample ``index``.

        Returns:
            A dict with the four importance lists and ``"tokens"``; the tokens
            are those returned by the gradient method (the last one run).
        """
        attn_scores, tokens = self.importance(dataset, index, "attention")
        lime_scores, tokens = self.importance(dataset, index, "lime")
        integrad_scores, tokens = self.importance(dataset, index, "integrad")
        grad_scores, tokens = self.importance(dataset, index, "gradient")

        return {"tokens": tokens,
                "attn_importance": attn_scores,
                "lime_importance": lime_scores,
                "grad_importance": grad_scores,
                "integrad_importance": integrad_scores}

    def relation(self, dataset, index1, index2, reltype):
        """Relate two test samples.

        Args:
            dataset: key into the loaded datasets.
            index1: position of the first sample in the test split.
            index2: position of the second sample in the test split.
            reltype: ``"token2token"`` for pairwise token-encoding inner
                products; any other value selects integrated-gradient
                importances computed in both directions.

        Returns:
            The dict from ``token_encoding_relation``, or a dict with
            ``tokens1``/``tokens2``/``importance1``/``importance2``.
        """
        test_split = self.datasets[dataset]["test"]
        txt1 = test_split[index1]["text"]
        txt2 = test_split[index2]["text"]

        if reltype == "token2token":
            return token_encoding_relation(
                self.tokenizer,
                self.model,
                txt1,
                txt2,
            )

        tokens1, importance1 = integrad_relation(
            self.tokenizer,
            self.model,
            txt1,
            txt2
        )
        tokens2, importance2 = integrad_relation(
            self.tokenizer,
            self.model,
            txt2,
            txt1
        )
        return {"tokens1": tokens1,
                "tokens2": tokens2,
                "importance1": importance1,
                "importance2": importance2,}

In [6]:
# Shared processor instance used by the generation loop at the end of the notebook.
# Constructing it loads the dataset, the cached prediction data and the model.
text_processor = TextProcessor(DATASET, num_labels)

Using custom data configuration default
Reusing dataset banking77 (/home/mojo/.cache/huggingface/datasets/banking77/default/1.1.0/aec0289529599d4572d76ab00c8944cb84f88410ad0c9e7da26189d31f62a55b)
Some weights of the model checkpoint at philschmid/BERT-Banking77 were not used when initializing BertForImportanceAttribution: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertForImportanceAttribution from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForImportanceAttribution from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
from datasets import Dataset, DatasetDict
from sklearn import preprocessing

# Per-dataset caches keyed by dataset name: the raw dataset splits and the
# prediction/visualisation records read from disk.
all_datasets = {}
prediction_data = {}
for dataset in ("banking77",):
    raw_datasets = load_dataset(dataset)

    # Label encoder fitted on the train + test labels.
    le = preprocessing.LabelEncoder()
    combined_labels = raw_datasets["train"]["label"] + raw_datasets["test"]["label"]
    le.fit(combined_labels)

    all_datasets[dataset] = raw_datasets

    import json
    prediction_data_file = "results/banking77-with-BERT-Banking77/banking77-viz_data-12-clusters-label_cluster_chosen_by_majority_in-predicted-label-with-BERT-Banking77.json"
    with open(prediction_data_file, "r") as f:
        prediction_data[dataset] = json.load(f)

Using custom data configuration default
Reusing dataset banking77 (/home/mojo/.cache/huggingface/datasets/banking77/default/1.1.0/aec0289529599d4572d76ab00c8944cb84f88410ad0c9e7da26189d31f62a55b)


In [8]:
def get_contrastive_idxs(entry, data):
    """Pick the "fact" and "contrast" instances for one prediction record.

    Args:
        entry: record with ``"support_set"`` (indices into ``data``),
            ``"distances"`` (one score per support instance, flat or nested),
            ``"prediction_label_idx"`` and ``"ground_truth_label_idx"``.
        data: sequence of records, each with ``"ground_truth_label_idx"``.

    Returns:
        ``(fact_idx, contrast_idx)``. ``fact_idx`` is the support instance with
        the largest score. For a correct prediction ``contrast_idx`` is the
        instance with the second-largest score; for an error it is the first
        instance in ``support_set`` whose ground-truth label equals the
        entry's ground-truth label.

    Raises:
        IndexError: if the support set is too small, or if the prediction is
            an error and no support instance carries the ground-truth label.
    """
    support_set = entry["support_set"]

    # Flatten the scores so nested and flat inputs behave alike; without it a
    # one-element support set squeezed to a scalar and could not be iterated.
    scores = torch.tensor(entry["distances"]).reshape(-1)
    ranking = torch.topk(scores, len(support_set)).indices.tolist()
    support_set_sorted = [support_set[pos] for pos in ranking]
    fact_idx = support_set_sorted[0]

    ground_truth = entry["ground_truth_label_idx"]
    if entry["prediction_label_idx"] == ground_truth:
        # if it's not an error, the contrast instance is the second closest
        return fact_idx, support_set_sorted[1]

    # if it's an error, the contrast instance is the ground_truth instance
    # NOTE(review): this scans `support_set` in its stored order, not by score.
    for idx in support_set:
        if data[idx]["ground_truth_label_idx"] == ground_truth:
            return fact_idx, idx
    raise IndexError(
        "no support-set instance has the ground-truth label "
        f"{ground_truth!r}")

In [None]:
from tqdm.auto import tqdm
import json
import torch
import os

demo_data_dir = "./results/banking77-with-BERT-Banking77/explanations"

# makedirs also creates missing parent directories and tolerates an existing one.
os.makedirs(demo_data_dir, exist_ok=True)


def _dump_explanations(out_dir, importances, tok2tok_relations, tok2sim_relations):
    """Write the explanation caches collected so far as JSON files in ``out_dir``."""
    payloads = {
        "importances.json": importances,
        "token2token_relations.json": tok2tok_relations,
        "token2similarity_relations.json": tok2sim_relations,
        # Number of examples actually stored; the loop index was one too small
        # and undefined when the dataset was empty.
        "num_examples.json": {"count": len(importances)},
    }
    for filename, payload in payloads.items():
        with open(f"{out_dir}/{filename}", "w") as f:
            json.dump(payload, f)


for dataset_name, data in prediction_data.items():
    if dataset_name in ["clinc", "banking"]:
        continue
    print(dataset_name, len(data))
    importances = []
    tok2tok_relations = []
    tok2sim_relations = []

    for i in tqdm(range(len(data))):
        entry = data[i]
        fact_idx, contrast_idx = get_contrastive_idxs(entry, data)
        is_error = entry["prediction_label_idx"] != entry["ground_truth_label_idx"]

        print(i, fact_idx, contrast_idx, "is_error:", is_error, "|", entry["text"], )
        print("\t\t\t", data[fact_idx]["text"])
        print("\t\t\t", data[contrast_idx]["text"])

        importance_all = text_processor.importances_all(dataset_name, i)

        # "left" relates the sample to its contrast instance, "right" to its fact.
        tok2tok_rel_left = text_processor.relation(
            dataset_name, i, contrast_idx, "token2token")
        tok2tok_rel_right = text_processor.relation(
            dataset_name, i, fact_idx, "token2token")
        tok2sim_rel_left = text_processor.relation(
            dataset_name, i, contrast_idx, "token2sim")
        tok2sim_rel_right = text_processor.relation(
            dataset_name, i, fact_idx, "token2sim")

        importances.append(importance_all)
        tok2tok_relations.append(dict(left=tok2tok_rel_left,
                                      right=tok2tok_rel_right))
        tok2sim_relations.append(dict(left=tok2sim_rel_left,
                                      right=tok2sim_rel_right))
        print(tok2tok_rel_left["tokens1"], tok2tok_rel_left["tokens2"])
        print(tok2sim_rel_left["tokens1"], tok2sim_rel_right["tokens2"])

        # Checkpoint every 500 samples so a crash does not lose all progress.
        if (i % 500 == 0):
            _dump_explanations(demo_data_dir, importances,
                               tok2tok_relations, tok2sim_relations)

    _dump_explanations(demo_data_dir, importances,
                       tok2tok_relations, tok2sim_relations)