In [3]:
from gpt2_attentions_aproximations import GPT2Wrapper
import torch
from transformers import GPT2Tokenizer, GPT2Model

# Load the pretrained GPT2 model and tokenizer
model_name = "gpt2"
model = GPT2Model.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Input sequence
text = "Hello, how are you?"

# Tokenize the input sequence into a batch containing this single sequence of token ids.
input_ids = tokenizer.encode(text, return_tensors="pt")

# All-ones mask: a single unpadded sequence, so every position is attended to.
attention_mask = torch.ones_like(input_ids)

# Forward pass through the original GPT2 model, also returning per-layer attention maps.
with torch.no_grad():
    original_outputs = model(input_ids, attention_mask=attention_mask, output_attentions=True)

# First layer's attention maps for the first (only) sequence in the batch.
# Per the transformers docs each `attentions[i]` is (batch, num_heads, seq_len, seq_len),
# so indexing [0][0] selects layer 0 / batch item 0 and keeps ALL heads.
original_attention = original_outputs.attentions[0][0]

# Wrap the model with the approximate attention (MeanGPT2Attention, per the original notes).
# NOTE(review): GPT2Wrapper lives in gpt2_attentions_aproximations and is not visible here;
# confirm whether it modifies `model` in place (the original outputs above are computed first).
modified_model = GPT2Wrapper(model)

# Forward pass through the modified GPT2 model
with torch.no_grad():
    modified_outputs = modified_model(input_ids, attention_mask=attention_mask, output_attentions=True)

# Same selection as for the original: first layer, first batch item, all heads.
modified_attention_full = modified_outputs.attentions[0][0]

# Per-row means of the first-layer attention maps, reshaped with a singleton dim 1.
modified_attention = modified_attention_full.mean(dim=-1).unsqueeze(1)
print("Modified Attention:")
print(modified_attention)

# Per-row means of the original first-layer attention maps, same reshaping.
mean_original_attention = original_attention.mean(dim=-1).unsqueeze(1)
print("Mean Original Attention:")
print(mean_original_attention)

# Caveat: attention weights are softmax outputs, so every row sums to 1 and its mean is
# always 1 / seq_len. This comparison therefore holds for ANY row-normalized attention map
# and does not show that the approximation matches the original.
mean_equal = torch.allclose(modified_attention, mean_original_attention)
print("Mean Attention Equal:", mean_equal)

# A comparison that can actually fail: element-wise difference of the full attention maps.
# The shape is checked first because the wrapper's attention output shape is not visible here.
if modified_attention_full.shape == original_attention.shape:
    max_abs_diff = (modified_attention_full - original_attention).abs().max().item()
    print("Max Abs Attention Difference:", max_abs_diff)
else:
    print(
        "Attention shapes differ:",
        tuple(original_attention.shape),
        tuple(modified_attention_full.shape),
    )


Modified Attention:
tensor([[[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]]])
Mean Original Attention:
tensor([[[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

        [[0.1667, 0.1667, 0.1667, 0.1667