# Training GPT-2 Model with InfiniAttention Module

In [128]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from typing import Optional, Tuple, Union

### Standard GPT2LMHeadModel structure

In [129]:
config = GPT2Config()
model = GPT2LMHeadModel(config)

In [130]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

### Testing GPT-2 Original Attention Module

In [134]:
# Instance GPT-Attention Module
gpt2_att = GPT2Attention(config=config)

# Dummy data
batch_size = 1
seq_length = 5
hidden_size = config.hidden_size

hidden_states = torch.rand(batch_size, seq_length, hidden_size)
attention_mask = torch.ones(batch_size, seq_length)

# Forward
outputs = gpt2_att(hidden_states=hidden_states, attention_mask=attention_mask)

# Output
print("Output GPT-2 att:")
print(f"{outputs[0].shape=}")

Output GPT-2 att:
outputs[0].shape=torch.Size([1, 5, 768])


In [136]:
class InfiniAttentionGPT2(GPT2Attention):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__(config, is_cross_attention, layer_idx)

        # Initializing memory state for compressive memory
        self.memory_dim = config.hidden_size // config.num_attention_heads
        self.memory = nn.Parameter(
            torch.zeros(
                (1, config.num_attention_heads, self.memory_dim, self.memory_dim)
            ),
            requires_grad=False,
        )  # Memory dimension: (1, heads, d_k, d_v)

        # Initialize the beta parameter for combining A_mem and A_dot
        self.beta = nn.Parameter(torch.zeros(1))

        self.norm_term = nn.Parameter(
            torch.ones((1, config.num_attention_heads, self.memory_dim, 1)),
            requires_grad=False,
        )  # Normalization term for memory: (1, heads, d_k, 1)

    def _mem_attention(self, query, prev_memory):
        """
        Compute the attention over the compressive memory.
        """
        # Ensure query and prev_memory have compatible dimensions
        bsz, num_heads, q_len, head_dim = query.size()
        memory_output = torch.zeros(
            bsz, num_heads, q_len, head_dim, device=query.device
        )

        for i in range(num_heads):
            sigma_Q = torch.sigmoid(
                query[:, i, :, :]
            )  # query: (batch, head, seq_len, head_dim)
            memory_output[:, i, :, :] = torch.matmul(
                sigma_Q, prev_memory[0, i, :, :]
            ) / torch.matmul(
                sigma_Q, self.norm_term[0, i, :, :]
            )  # prev_memory: (1, head, d_k, d_v)

        return memory_output

    def _combine_attention(self, A_dot, A_mem):
        """
        Combine local attention A_dot with memory attention A_mem using a weighted combination.
        """
        sigmoid_beta = torch.sigmoid(self.beta)
        return (
            sigmoid_beta * A_mem + (1 - sigmoid_beta) * A_dot
        )  # Combine A_dot and A_mem: (batch, head, seq_len, head_dim)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `InfiniAttentionGPT2(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(
                self.split_size, dim=2
            )
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        # Compute the attention weights
        A_dot, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        A_mem = self._mem_attention(query, self.memory)

        # InfiniAttention: Combine local attention with memory attention
        attn_output = self._combine_attention(A_dot, A_mem)

        # GPT-2 Post-processing
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        # Infini Attention: Update memory
        sigma_K = torch.sigmoid(key)
        self.memory.data = self.memory.data + torch.matmul(
            sigma_K.transpose(-1, -2), value
        )

        # Infinit Attention: Update normalization term
        self.norm_term.data = self.norm_term.data + sigma_K.sum(dim=-2, keepdim=True)

        return outputs  # attn_output, present, (attentions)


infini_att = InfiniAttentionGPT2(config)

# Forward
outputs = infini_att(hidden_states=hidden_states, attention_mask=attention_mask)

# Output
print("Output GPT-2 infini-att:")
print(outputs)
print(outputs[0].shape)

Output GPT-2 infini-att:
(tensor([[[-0.1790, -0.0998,  0.0127,  ...,  0.0549, -0.0000,  0.0000],
         [-0.1459, -0.0298, -0.0354,  ...,  0.0625, -0.1075,  0.0449],
         [-0.1223, -0.0444, -0.0838,  ...,  0.0181, -0.0000,  0.0318],
         [-0.0000, -0.0550, -0.0654,  ...,  0.0000, -0.1357,  0.0188],
         [-0.1222, -0.0350, -0.0432,  ...,  0.0046, -0.1037,  0.0123]]],
       grad_fn=<MulBackward0>), None)
torch.Size([1, 5, 768])


### Testing accumulation in sequential forwards

In [138]:
iterations = 5

for _ in range(iterations):
    outputs = infini_att(hidden_states=hidden_states, attention_mask=attention_mask)

    print("Output GPT-2 infini-att:")
    print(outputs)

Output GPT-2 infini-att:
(tensor([[[-0.2837, -0.0251, -0.0000,  ...,  0.1189, -0.0000,  0.0151],
         [-0.2035, -0.1411, -0.0979,  ..., -0.0282, -0.1847, -0.0063],
         [-0.2418, -0.0000, -0.0882,  ..., -0.0358, -0.0000,  0.0410],
         [-0.2166, -0.0862, -0.1134,  ...,  0.0203, -0.2318,  0.0272],
         [-0.1767, -0.0512, -0.1114,  ..., -0.0530, -0.1729,  0.0075]]],
       grad_fn=<MulBackward0>), None)
Output GPT-2 infini-att:
(tensor([[[-0.2807, -0.0000, -0.0022,  ...,  0.0329, -0.1628,  0.0323],
         [-0.2505, -0.0927, -0.0610,  ..., -0.0168, -0.2041,  0.0142],
         [-0.2274, -0.1152, -0.0938,  ...,  0.0080, -0.2222,  0.0188],
         [-0.1876, -0.1098, -0.1280,  ..., -0.0021, -0.2356,  0.0268],
         [-0.2040, -0.0000, -0.1116,  ..., -0.0170, -0.2422,  0.0707]]],
       grad_fn=<MulBackward0>), None)
Output GPT-2 infini-att:
(tensor([[[-0.2542, -0.1438, -0.0895,  ...,  0.0378, -0.2081,  0.0186],
         [-0.2486, -0.1124, -0.0000,  ...,  0.0110, -0.0000, 

### Training Model

In [110]:
model = GPT2LMHeadModel(config)

# Replace the attention module with InfiniAttention
for i, layer in enumerate(model.transformer.h):
    model.transformer.h[i].attn = InfiniAttentionGPT2(config, layer_idx=i)

In [139]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [119]:
# Training

dataset = ["Este é um carro azul", "Esta é uma casa vermelha"] * 100

In [120]:
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(dataset, padding=True, truncation=True, return_tensors="pt")

In [121]:
# training

from torch.utils.data import DataLoader, Dataset


class CustomDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)


train_dataset = CustomDataset(inputs)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

In [122]:
epoch = 10
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

for _ in range(epoch):
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]

        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
        loss = outputs.loss
        loss.backward()

        print(f"Loss: {loss.item()}")
        print("Perplexity: ", torch.exp(loss).item())
        optimizer.step()

  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


Loss: 6.882748603820801
Perplexity:  975.3034057617188
Loss: 6.571261882781982
Perplexity:  714.2706298828125
Loss: 5.873733997344971
Perplexity:  355.57421875
Loss: 5.2701029777526855
Perplexity:  194.4359893798828
Loss: 4.578592300415039
Perplexity:  97.3772201538086
Loss: 4.25078010559082
Perplexity:  70.16012573242188
Loss: 3.6458864212036133
Perplexity:  38.31672286987305
Loss: 3.053598403930664
Perplexity:  21.191463470458984
Loss: 3.4408464431762695
Perplexity:  31.213367462158203
Loss: 3.0555572509765625
Perplexity:  21.233015060424805
Loss: 2.0906982421875
Perplexity:  8.09056282043457
Loss: 1.765692114830017
Perplexity:  5.845616817474365
Loss: 1.4354183673858643
Perplexity:  4.201402187347412
Loss: 1.1778771877288818
Perplexity:  3.2474730014801025
Loss: 0.7602536678314209
Perplexity:  2.1388187408447266
Loss: 0.8171280026435852
Perplexity:  2.2639882564544678
Loss: 0.5947479009628296
Perplexity:  1.8125739097595215
Loss: 0.44478729367256165
Perplexity:  1.5601582527160645
L

KeyboardInterrupt: 

In [124]:
# Inference

text = "Este é um carro"

input_ids = tokenizer(text, return_tensors="pt")["input_ids"]

output = model.generate(input_ids, max_length=100)

for out in output:
    print(tokenizer.decode(out, skip_special_tokens=True))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Este é um carro azul
