In [24]:
import torch
import torch.nn as nn
from transformers import GPT2Model, GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from typing import Optional, Tuple, Union

config = GPT2Config()

### Testing GPT-2 Original Attention Module

In [25]:
# Instance GPT-Attention Module
gpt2_att = GPT2Attention(config=config)

# Dummy data
batch_size = 1
seq_length = 5
hidden_size = config.hidden_size

hidden_states = torch.rand(batch_size, seq_length, hidden_size)
attention_mask = torch.ones(batch_size, seq_length)

# Forward
outputs = gpt2_att(hidden_states=hidden_states, attention_mask=attention_mask)

# Output
print("Output GPT-2:")
print(outputs)
print(outputs[0].shape)

Output GPT-2:
(tensor([[[ 0.0533, -0.0408, -0.0861,  ..., -0.1164,  0.0557,  0.0489],
         [ 0.0451, -0.0000, -0.0000,  ..., -0.0208,  0.0768,  0.1233],
         [ 0.0403, -0.0626, -0.0000,  ..., -0.0879,  0.2348,  0.0582],
         [ 0.1063, -0.1927, -0.2302,  ..., -0.0375,  0.1769,  0.0764],
         [-0.0045, -0.1324, -0.1975,  ..., -0.0925,  0.2451,  0.0033]]],
       grad_fn=<MulBackward0>), None)
torch.Size([1, 5, 768])


In [15]:
class InfiniAttentionGPT2(GPT2Attention):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__(config, is_cross_attention, layer_idx)

        # Initializing memory state for compressive memory
        self.memory_dim = config.hidden_size // config.num_attention_heads
        self.memory = nn.Parameter(
            torch.zeros((self.memory_dim, self.memory_dim))
        )  # Memory dimension: (d_k, d_v)

        # Initialize the beta parameter for combining A_mem and A_dot
        self.beta = nn.Parameter(torch.zeros(1))

        self.norm_term = nn.Parameter(
            torch.ones((self.memory_dim, 1)), requires_grad=False
        )

    def _mem_attention(self, query, prev_memory):
        """
        Compute the attention over the compressive memory.
        """
        # Ensure query and prev_memory have compatible dimensions
        bsz, num_heads, q_len, head_dim = query.size()
        memory_output = torch.zeros(
            bsz, num_heads, q_len, head_dim, device=query.device
        )

        for i in range(num_heads):
            sigma_Q = torch.sigmoid(query[:, i, :, :])
            memory_output[:, i, :, :] = (
                torch.matmul(sigma_Q, prev_memory) / self.norm_term
            )

        return memory_output

    def _combine_attention(self, A_dot, A_mem):
        """
        Combine local attention A_dot with memory attention A_mem using a weighted combination.
        """
        sigmoid_beta = torch.sigmoid(self.beta)
        return sigmoid_beta * A_mem + (1 - sigmoid_beta) * A_dot

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `InfiniAttentionGPT2(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(
                self.split_size, dim=2
            )
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        # Compute the attention weights
        A_dot, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
        A_mem = self._mem_attention(query, self.memory)

        # DEBUG
        if A_dot.size() != A_mem.size():
            A_mem = A_mem.expand_as(A_dot)

        # InfiniAttention: Combine local attention with memory attention
        attn_output = self._combine_attention(A_dot, A_mem)
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        # Update the memory state
        sigma_K = torch.sigmoid(key)
        self.memory.data = self.memory.data + torch.matmul(
            sigma_K.transpose(-1, -2), value
        )
        self.norm_term.data = self.norm_term.data + torch.unsqueeze(
            sigma_K.sum(dim=-2), -2
        )

        return outputs  # attn_output, present, (attentions)


# Load config
config = GPT2Config()
model = GPT2Model(config)

# Replace the attention module with InfiniAttention
for i, layer in enumerate(model.h):
    model.h[i].attn = InfiniAttentionGPT2(config, layer_idx=i)

# Dummy test
input_ids = torch.tensor([[1, 2, 3, 4, 5]])
outputs = model(input_ids)
print("Output com InfiniAttention do GPT-2:")
print(outputs)

RuntimeError: The size of tensor a (5) must match the size of tensor b (64) at non-singleton dimension 1