# Purpose of this file:

This notebook is just an implementation of LLMs Head Importance Calculation method. This method is described in the following paper: https://www.arxiv.org/pdf/2407.14679

In [13]:
import torch
import itertools
from typing import List
from transformers import BertModel, BertTokenizer

In [14]:
class LLMHeadImportance:
    def __init__(seld, model_name: str = "bert-base-uncased"):
        seld.model = BertModel.from_pretrained(model_name, output_attentions=True)
        seld.tokenizer = BertTokenizer.from_pretrained(model_name)
        seld.model.eval()
    
    def compute_head_importance(self, text: str) -> List[float]:

        # Tokenize the input fox
        inputs = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors='pt',
            max_length=512,
            truncation=True,
            padding='max_length'
        )

        # Forward pass through the model
        with torch.no_grad():
            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

        # Get attention weights
        attentions = outputs.attentions # This is a tuple of attention weights for each layer

        head_importance = []
        for layer_attentions in attentions:

            # Layer_attentions shape: (batch_size, num_heads, sequence_length, sequence_length)
            layer_importance = []

            for head_idx in range(layer_attentions.size(1)):

                head_attention = layer_attentions[0, head_idx, :, :]

                # Compute the frobenius norm of the attention matrix
                importance = torch.norm(head_attention, p='fro')
                layer_importance.append(importance.item())
            
            head_importance.append(layer_importance)

        return head_importance
    
    def aggregate_importance(self, scores: List[float], method: str = "mean") -> float:
        if method == "mean":
            return sum(scores) / len(scores)
        elif method == "l2_norm":
            return (sum([s**2 for s in scores])**0.5)
        elif method == "variance":
            mean = sum(scores) / len(scores) 
            return sum([(s - mean)**2 for s in scores]) / len(scores)
        else:
            raise ValueError(f"Invalid aggregation method: {method}")
        

In [15]:
if __name__ == "__main__":
    # create an importance caluculator
    importance_calculator = LLMHeadImportance("bert-base-uncased")

    # input text
    text = "The quick brown fox jumps over the lazy dog."

    # compute the importance of each head
    head_importance = importance_calculator.compute_head_importance(text)

    print("Head Importance Scores: ")
    for i, layer_importance in enumerate(head_importance):
        for j, score in enumerate(layer_importance):
            print(f"Layer {i}, Head {j}: {score:.4f}")

    # aggregate the importance scores
    print(f"Mean: {sum(itertools.chain(*head_importance)) / len(head_importance):.4f}")
    print(f"L2 Norm: {(sum([s**2 for s in itertools.chain(*head_importance)])**(0.5)):.4f}")
    print(f"Variance: {(sum([(s - sum(itertools.chain(*head_importance)) / len(head_importance))**2 for s in itertools.chain(*head_importance)]) / len(head_importance)):.4f}")



Head Importance Scores: 
Layer 0, Head 0: 8.1735
Layer 0, Head 1: 8.8312
Layer 0, Head 2: 21.6100
Layer 0, Head 3: 21.7919
Layer 0, Head 4: 16.5113
Layer 0, Head 5: 12.6093
Layer 0, Head 6: 16.9935
Layer 0, Head 7: 8.6542
Layer 0, Head 8: 11.5274
Layer 0, Head 9: 11.0889
Layer 0, Head 10: 22.2144
Layer 0, Head 11: 18.9474
Layer 1, Head 0: 19.1049
Layer 1, Head 1: 13.4629
Layer 1, Head 2: 11.1345
Layer 1, Head 3: 15.2564
Layer 1, Head 4: 13.5162
Layer 1, Head 5: 11.4204
Layer 1, Head 6: 22.2157
Layer 1, Head 7: 12.3090
Layer 1, Head 8: 10.7550
Layer 1, Head 9: 7.6027
Layer 1, Head 10: 15.6025
Layer 1, Head 11: 7.3570
Layer 2, Head 0: 19.0868
Layer 2, Head 1: 18.1804
Layer 2, Head 2: 16.5020
Layer 2, Head 3: 10.1512
Layer 2, Head 4: 15.5794
Layer 2, Head 5: 15.0479
Layer 2, Head 6: 10.2510
Layer 2, Head 7: 14.1339
Layer 2, Head 8: 11.7780
Layer 2, Head 9: 19.9995
Layer 2, Head 10: 9.6312
Layer 2, Head 11: 15.1023
Layer 3, Head 0: 15.7930
Layer 3, Head 1: 12.4936
Layer 3, Head 2: 12.7143
