Introduction here

In [1]:
# importing necessary documents

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import os
import contextlib
from tqdm import tqdm
from IPython.core.display import display, HTML
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
  from IPython.core.display import display, HTML


Initializing the bert tokenizer and model

In [2]:
# Loading the model and tokenizer

model_name = "facebook/bart-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True)

# Setting the model to evaluation mode
with open(os.devnull, 'w') as fnull:
    with contextlib.redirect_stdout(fnull):
        model.eval()

In [3]:
# Defining the input texts

premise = "It rains heavily in april"
hypothesis = "Crops have lack of water in april"

In [4]:
# Finding the class labels of the model

print(model.config.id2label)

{0: 'contradiction', 1: 'neutral', 2: 'entailment'}


In [5]:
# Since our hypothesis is a contradiction, setting the target class accordingly

target_class = 0

In [6]:
# Vectorizing the text

input_vectors = tokenizer(premise, hypothesis, return_tensors="pt")
input_vectors = input_vectors["input_ids"]

In [7]:
# defining the function to compute the integrated gradients

def _compute_integrated_gradients(attention_matrix, embeddings, steps=3):
    baseline = torch.zeros_like(attention_matrix)  # No attention baseline
    scaled_inputs = [(baseline + (float(i) / steps) * (attention_matrix - baseline)) for i in range(steps + 1)]
    gradients = []
    # print(scaled_inputs[0].shape)
    for scaled_input in scaled_inputs:
        scaled_input_ = scaled_input.clone()
        embeddings.requires_grad_()
        embeddings.retain_grad()
        embeddings.retain_graph = True
        # print("--------------------------------------")
        # print(scaled_input_.shape)
        # print("--------------------------------------")
        output = model.model.encoder(inputs_embeds=embeddings, attention_mask = scaled_input_, return_dict=True)
        class_logits = model.classification_head(output.last_hidden_state[:, 0, :])
        loss = class_logits[0, target_class]
        # model.zero_grad()
        loss.backward(retain_graph=True)
        # print("00000000000000000000000000000000000000000000000000")
        # print(embeddings.shape)
        # print(embeddings.grad.mean(dim=2).shape)
        gradients.append(embeddings.grad.mean(dim=2))
    # print(gradients)
    avg_gradients = torch.mean(torch.stack(gradients), dim=0)
    attributions = (attention_matrix - baseline) * avg_gradients
    return attributions  # Aggregate per-token attributions

In [8]:
# defining the function to compute the attention attributions
inputs = []

def compute_attention_attributions():

    # generating embeddings of the inputs for forward and backward pass
    embedding_layer = model.model.shared
    embeddings = embedding_layer(input_vectors)
    embeddings.requires_grad_()
    embeddings.retain_grad()
    # return inputs
    output = model.model.encoder(inputs_embeds=embeddings, return_dict=True)
    attentions = output.attentions
    # print(len(attentions))  # List of attention tensors (layer-wise)
    # print(attentions[0].size())  # First layer attention tensor
        
    # Compute attribution scores using Integrated Gradients
    attributions = []
    for layer_attention in attentions:
        for at_head in tqdm(layer_attention[0]):
            # at_head is the attention matrix for the head
            # converting it to linear
            at_head = at_head.mean(dim=0) 
            at_head = at_head.unsqueeze(0)  # Add batch dimension
            # print(at_head.shape)
            head_attr = _compute_integrated_gradients(at_head, embeddings)
            attributions.append(head_attr)

    return attributions
        

In [9]:
text_input = "The father punished his son very badly because he has consumed a lot of alcohol"

attributions = compute_attention_attributions()

100%|██████████| 16/16 [00:20<00:00,  1.27s/it]
100%|██████████| 16/16 [00:21<00:00,  1.34s/it]
100%|██████████| 16/16 [00:23<00:00,  1.44s/it]
100%|██████████| 16/16 [00:23<00:00,  1.49s/it]
100%|██████████| 16/16 [00:24<00:00,  1.52s/it]
100%|██████████| 16/16 [00:25<00:00,  1.58s/it]
100%|██████████| 16/16 [00:26<00:00,  1.67s/it]
100%|██████████| 16/16 [00:28<00:00,  1.78s/it]
100%|██████████| 16/16 [00:28<00:00,  1.79s/it]
100%|██████████| 16/16 [00:29<00:00,  1.85s/it]
100%|██████████| 16/16 [00:30<00:00,  1.93s/it]
100%|██████████| 16/16 [00:32<00:00,  2.05s/it]


In [10]:
def normalize_attributions(attributions):
    attributions = np.array(attributions)  # Convert to NumPy array
    min_attr = attributions.min()
    max_attr = attributions.max()

    # Normalize to [0, 255]
    normalized = 255 * (attributions - min_attr) / (max_attr - min_attr + 1e-8)  # Avoid division by zero
    return normalized.astype(int)


In [11]:
for idx, layer in enumerate(attributions):
    if idx%16 != 0:
        continue
    # layer = layer.sum(dim=1)
    layer = layer.detach().clone().numpy()
    
    norm_layer = normalize_attributions(layer)
    # print(norm_layer)
    html_string = ""
    for index, val in enumerate(norm_layer[0]):
        str_print = tokenizer.decode(input_vectors.numpy()[0][index])
        if str_print[0] == "<":
            continue
        html_string += f"<span style='color:rgba(0,{val}, 0)'> {str_print} </span>"
    # print(html_string)
    display(HTML(html_string)) 

In [14]:
attributions = [attributions[i * 16:(i + 1) * 16] for i in range(12)]

In [150]:
mask = []
for layer in attributions:
    layer_mask = []
    for head in layer:
        head_copy = head.detach().clone().numpy()
        head_importance = max(normalize_attributions(head_copy)[0])
        head_mask = head_importance > 50
        # print(normalize_attributions(head_copy))
        # print(head_importance)
        head_mask = int(head_mask)
        layer_mask.append(head_mask)
    mask.append(layer_mask)
print(mask)

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0], [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]


In [174]:
import types

def prune_tree(model, mask):
    encoder = model.model.encoder
    layers = encoder.layers
    for idx, layer in enumerate(layers):
        attn_weights = layer.self_attn
        original_forward = attn_weights.forward 

        def new_forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
            outputs = original_forward(hidden_states, attention_mask, layer_head_mask=None)
            attention_scores = outputs[0]

            # Apply pruning mask
            attention_scores = attention_scores * mask[idx].unsqueeze(0).unsqueeze(0)

            return (attention_scores,) + outputs[1:]
        layer.self_attn.forward = types.MethodType(new_forward, layer.self_attn)
    return model

In [175]:
pruned_model = prune_tree(model, mask)

In [176]:
pruned_model(input_vectors)

TypeError: prune_tree.<locals>.new_forward() got an unexpected keyword argument 'layer_head_mask'