# Training GPT-2 Model with InfiniAttention Module

In [1]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from typing import Optional, Tuple, Union

### Standard GPT2LMHeadModel structure

In [2]:
config = GPT2Config()
model = GPT2LMHeadModel(config)

In [3]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

### Testing GPT-2 Original Attention Module

In [4]:
# Instance GPT-Attention Module
gpt2_att = GPT2Attention(config=config)

# Dummy data
batch_size = 1
seq_length = config.n_positions
hidden_size = config.hidden_size

hidden_states = torch.rand(batch_size, seq_length, hidden_size)
attention_mask = torch.ones(batch_size, seq_length)

# Forward
outputs = gpt2_att(hidden_states=hidden_states, attention_mask=attention_mask)

# Output
print("Output GPT-2 att:")
print(f"{outputs[0].shape=}")

Output GPT-2 att:
outputs[0].shape=torch.Size([1, 1024, 768])


In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, Conv1D
from transformers import GPT2Model, GPT2Config
from typing import Optional, Tuple, Union


class InfiniAttentionGPT2(GPT2Attention):
    def __init__(
        self,
        config,
        is_cross_attention=False,
        layer_idx=None,
        n_segments=16,
        is_causal: Optional[bool] = True,
        update: Optional[str] = "linear",
    ):
        """
        Initialize the InfiniAttentionGPT2 module.

        Args:
            config: GPT2Config object containing configuration parameters.
            is_cross_attention (bool): Flag to indicate if this is cross attention.
            layer_idx (int, optional): Index of the layer.
            n_segments (int): Number of segments for memory processing.
            is_causal (bool, optional): Flag to indicate if attention is causal.
            update (str, optional): Update strategy for memory, either 'linear' or another strategy.
        """
        super().__init__(config, is_cross_attention, layer_idx)

        # Initializing memory state for compressive memory
        self.d_head = config.hidden_size // config.num_attention_heads
        self.n_head = config.num_attention_heads

        # Initialize the beta parameter for combining A_mem and A_dot
        self.beta = nn.Parameter(torch.zeros(1), requires_grad=True)

        self.elu = nn.ELU()

        # Sequence length
        self.seq_len = config.n_positions

        self.is_causal = is_causal
        self.register_buffer(
            "causal",
            torch.tril(
                torch.ones(self.seq_len // n_segments, self.seq_len // n_segments)
            ),
        )

        # Segment size
        self.n_segments = n_segments
        self.segment_size = self.seq_len // n_segments

        # Update strategy
        self.update = update

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:

        batch_size, _, _ = hidden_states.size()

        device = hidden_states.device

        memory = torch.zeros((self.n_head, self.d_head, self.d_head), device=device)
        z = torch.zeros((self.n_head, self.d_head, 1), device=device)

        # Project hidden states to query, key, value
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        # Reshape query, key, value for segmented processing
        query = query.reshape(
            batch_size,
            self.n_head,
            self.n_segments,
            self.segment_size,
            self.d_head,
        )
        key = key.reshape(
            batch_size,
            self.n_head,
            self.n_segments,
            self.segment_size,
            self.d_head,
        )
        value = value.reshape(
            batch_size,
            self.n_head,
            self.n_segments,
            self.segment_size,
            self.d_head,
        )

        outputs = []

        for idx in range(self.n_segments):

            attention_mask_segment = attention_mask[
                :, idx * self.segment_size : (idx + 1) * self.segment_size
            ]

            sigma_q = (
                self.elu(query[:, :, idx, :, :]) + 1.0
            )  # [bsz, n_head, segment, seq_len, head_dim]
            sigma_k = (
                self.elu(key[:, :, idx, :, :]) + 1.0
            )  # [bsz, n_head, segment, seq_len, head_dim]

            A_mem = (sigma_q @ memory) / (
                (sigma_q @ z) + 1e-6
            )  # [bsz, n_head, segment, seq_len, head_dim]

            A_dot = query[:, :, idx, :, :] @ key[:, :, idx, :, :].transpose(-2, -1)

            if self.is_causal:
                A_dot.masked_fill_(self.causal == 0, float("-inf"))

            A_dot = F.softmax(
                A_dot / torch.tensor(self.d_head, device=device) ** 0.5, dim=-1
            )
            A_dot = A_dot @ value[:, :, idx, :, :]

            attention = (F.sigmoid(self.beta) * A_mem) + (
                (1 - F.sigmoid(self.beta)) * A_dot
            )

            # Update memory
            if self.update == "linear":
                memory = memory + (sigma_k.transpose(-2, -1) @ value[:, :, idx, :, :])
            else:
                delta = (sigma_k @ memory) / ((sigma_k @ z) + 1e-6)
                memory = memory + (
                    sigma_k.transpose(-2, -1) @ (value[:, :, idx, :, :] - delta)
                )

            z = z + sigma_k.sum(dim=-2, keepdim=True)

            outputs.append(attention)

        # Concatenate outputs from all segments
        final_output = torch.cat(outputs, dim=2).view(
            batch_size, self.seq_len, self.embed_dim
        )

        return (final_output, None)


# Supondo que o `config`, `hidden_states`, e `attention_mask` já tenham sido definidos.
config = GPT2Config()
config.n_positions = 1024  # Definindo um exemplo de tamanho de sequência
hidden_states = torch.randn(2, 1024, config.hidden_size)  # Exemplo de estados ocultos
attention_mask = torch.ones(2, 1024)  # Máscara de atenção de exemplo

infini_att_gpt2 = InfiniAttentionGPT2(config=config, is_causal=True)

# Forward
outputs = infini_att_gpt2(hidden_states=hidden_states, attention_mask=attention_mask)

# Output
print("Output InfiniAttention GPT-2 att:")
print(f"{outputs[0].shape=}")

Output InfiniAttention GPT-2 att:
outputs[0].shape=torch.Size([2, 1024, 768])


### Training Model

In [33]:
model_type = "gpt2"  # "gpt2" or "gpt2-infini"

In [34]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [35]:
# Training

dataset = ["Este é um carro azul.", "Esta é uma casa vermelha."] * 10

In [38]:
from transformers import GPT2Tokenizer

if model_type == "gpt2":
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(dataset, padding=True, truncation=True, return_tensors="pt")
    model = GPT2LMHeadModel(config).to(device)
else:
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2", padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(
        dataset, padding="max_length", truncation=True, return_tensors="pt"
    )

    model = GPT2LMHeadModel(config)

    # Replace the attention module with InfiniAttention
    for i, layer in enumerate(model.transformer.h):
        model.transformer.h[i].attn = InfiniAttentionGPT2(
            config, layer_idx=i, is_causal=False
        )

    model = model.to(device)

In [39]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [40]:
from torch.utils.data import DataLoader
from datasets import Dataset

train_dataset = Dataset.from_dict(inputs)
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

### Train loop

In [41]:
from tqdm.auto import tqdm

epoch = 1
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

for epoch_idx in range(epoch):
    print(f"Epoch: {epoch_idx+1}/{epoch}")
    epoch_loss = 0.0

    for batch in tqdm(
        train_loader, desc=f"Training Epoch {epoch_idx+1}/{epoch}", leave=False
    ):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
        loss = outputs.loss
        loss.backward()

        optimizer.step()

        # Acumula a perda para monitoramento
        epoch_loss += loss.item()

    # Calcula a perda média da época
    avg_epoch_loss = epoch_loss / len(train_loader)
    print(f"Loss: {avg_epoch_loss:.4f}")
    print(f"Perplexity: {torch.exp(torch.tensor(avg_epoch_loss)):.4f}")

Epoch: 1/1


Training Epoch 1/1:   0%|          | 0/10 [00:00<?, ?it/s]

Loss: 4.2348
Perplexity: 69.0465


In [42]:
# Inference

text = "Este é um carro"

input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)

output = model.generate(
    input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id
)

for out in output:
    print(tokenizer.decode(out, skip_special_tokens=True))

Este é um carro azul.
