# Training GPT-2 Model with InfiniAttention Module

In [None]:
!pip install datasets
!pip install accelerate -U
!pip install transformers -U

Collecting datasets
  Downloading datasets-2.19.2-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.1/542.1 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
Collecting requests>=2.32.1 (from datasets)
  Downloading requests-2.32.3-py3-none-any.whl (64 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.9/64.9 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m23.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)


In [None]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from typing import Optional, Tuple, Union

### Standard GPT2LMHeadModel structure

In [None]:
# Default GPT-2 configuration; the model repr below shows the resulting sizes.
config = GPT2Config()
# Language model built directly from the config (no `from_pretrained` call here).
model = GPT2LMHeadModel(config)

In [None]:
# Display the module tree of the model built in the previous cell.
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

### Testing GPT-2 Original Attention Module

In [None]:
# Instantiate a stand-alone GPT-2 attention module
gpt2_att = GPT2Attention(config=config)

# Dummy data spanning the full context window
batch_size = 1
seq_length = config.n_positions
hidden_size = config.hidden_size

hidden_states = torch.rand(batch_size, seq_length, hidden_size)
# The attention module adds the mask to the attention logits, so it expects an
# additive mask broadcastable to (batch, heads, query, key): 0.0 keeps a position.
# (A tokenizer-style 0/1 mask is converted to this form by GPT2Model, not here.)
attention_mask = torch.zeros(batch_size, 1, 1, seq_length)

# Forward (shape check only, so no autograd graph is needed)
with torch.no_grad():
    outputs = gpt2_att(hidden_states=hidden_states, attention_mask=attention_mask)

# Output
print("Output GPT-2 att:")
print(f"{outputs[0].shape=}")

Output GPT-2 att:
outputs[0].shape=torch.Size([1, 1024, 768])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, Conv1D
from transformers import GPT2Model, GPT2Config
from typing import Optional, Tuple, Union


class InfiniAttentionGPT2(GPT2Attention):
    """GPT-2 self-attention extended with a segment-wise compressive memory.

    The input sequence is cut into segments of ``segment_size`` tokens. For each
    segment the module (1) reads from a memory accumulated over the *previous*
    segments using linear attention, (2) adds the current segment's keys/values
    to that memory, (3) runs the regular GPT-2 dot-product attention inside the
    segment, and (4) blends both results with a learned sigmoid gate ``beta``.

    The memory is reset at the start of every ``forward`` call, so nothing is
    carried over between calls.

    This class relies on ``_split_heads``, ``_merge_heads`` and ``_attn``
    inherited from ``GPT2Attention``.
    """

    def __init__(
        self,
        config,
        is_cross_attention=False,
        layer_idx=None,
        n_segments=16,
        is_causal: Optional[bool] = True,
        update: Optional[str] = "linear",
    ):
        """
        Initialize the InfiniAttentionGPT2 module.

        Args:
            config: GPT2Config object containing configuration parameters.
            is_cross_attention (bool): Forwarded to ``GPT2Attention.__init__``.
            layer_idx (int, optional): Forwarded to ``GPT2Attention.__init__``.
            n_segments (int): Number of segments that ``config.n_positions`` is
                divided into; the segment size is ``n_positions // n_segments``.
            is_causal (bool, optional): Stored on the instance; ``forward`` does
                not read it.
            update (str, optional): Stored on the instance; ``forward`` does not
                read it (memory updates are always the additive rule in
                ``_update_memory``).

        Raises:
            ValueError: If ``n_segments`` is not positive or is larger than
                ``config.n_positions`` (which would give an empty segment).
        """
        super().__init__(config, is_cross_attention, layer_idx)

        if n_segments < 1:
            raise ValueError(f"n_segments must be >= 1, got {n_segments}")

        # Per-head geometry
        self.d_head = config.hidden_size // config.num_attention_heads
        self.n_head = config.num_attention_heads

        # Gate (pre-sigmoid) that mixes the memory read-out with local attention
        self.beta = nn.Parameter(torch.zeros(1), requires_grad=True)

        # Kept for backward compatibility of the module tree; forward uses F.elu.
        self.elu = nn.ELU()

        # Sequence length and segment geometry
        self.seq_len = config.n_positions
        self.n_segments = n_segments
        self.segment_size = self.seq_len // n_segments
        if self.segment_size < 1:
            raise ValueError(
                f"n_segments ({n_segments}) exceeds config.n_positions ({self.seq_len})"
            )

        self.is_causal = is_causal
        # Lower-triangular buffer of one segment's size. Not read by forward, but
        # kept so existing state_dicts keep loading.
        self.register_buffer(
            "causal",
            torch.tril(torch.ones(self.segment_size, self.segment_size)),
        )

        # Update strategy (stored only)
        self.update = update

        # Compressive-memory state; defined here so the helper methods can be
        # called before forward() without an AttributeError.
        self.memory = None
        self.norm_term = None

    def _retrieve_from_memory(self, query_states):
        """Read from the compressive memory with linear attention (Eq. 3).

        Computes ``sigma(Q) M / (sigma(Q) z^T)`` with ``sigma = ELU + 1``, where
        ``M`` is ``self.memory`` and ``z`` is ``self.norm_term``. Returns zeros
        shaped like ``query_states`` while the memory is still empty.
        """
        if self.memory is None or self.norm_term is None:
            return torch.zeros_like(query_states)

        sigma_q = F.elu(query_states) + 1  # strictly positive feature map
        numerator = torch.matmul(sigma_q, self.memory)
        # One normaliser per query: project sigma(Q) onto the accumulated key sum.
        # (Dividing element-wise by the key sum itself is not Eq. 3.)
        denominator = torch.matmul(sigma_q, self.norm_term.transpose(-2, -1))
        return numerator / denominator

    def _update_memory(self, key_states, value_states):
        """Accumulate the segment's keys/values into the memory (Eq. 4).

        Adds ``sigma(K)^T V`` to ``self.memory`` and the sum of ``sigma(K)`` over
        the token axis to ``self.norm_term``; both are created on first use.
        """
        sigma_k = F.elu(key_states) + 1  # same feature map as the query side
        kv_update = torch.matmul(sigma_k.transpose(-2, -1), value_states)
        key_sum = sigma_k.sum(dim=-2, keepdim=True)

        self.memory = kv_update if self.memory is None else self.memory + kv_update
        self.norm_term = (
            key_sum if self.norm_term is None else self.norm_term + key_sum
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        debug: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Run segment-wise attention with compressive memory.

        Args:
            hidden_states: Input of shape (batch, seq, hidden).
            layer_past: Accepted for interface compatibility; not used.
            attention_mask: Optional 4-D mask that is passed, one segment slice
                at a time, to the inherited ``_attn``. The last axis must index
                key positions of the full sequence. If the third axis has the
                full sequence length it is sliced per segment as well.
            head_mask: Passed through to ``_attn`` unchanged.
            use_cache: Accepted for interface compatibility; not used.
            output_attentions: When true, the per-segment attention weights are
                appended to the returned tuple.
            debug: Accepted for interface compatibility; not used.

        Returns:
            ``(output, keys)`` or ``(output, keys, attentions)``. ``output`` is
            the projected attention output for the whole sequence; ``keys`` are
            the per-head key states re-joined along the sequence axis.
        """
        # The compressive memory never outlives a single forward pass.
        self.memory = None
        self.norm_term = None

        total_len = hidden_states.size(1)

        qkv = self.c_attn(hidden_states)
        query, key, value = qkv.split(self.split_size, dim=2)

        # Chunks of `segment_size` tokens; the last chunk may be shorter.
        segments_q = query.split(self.segment_size, dim=1)
        segments_k = key.split(self.segment_size, dim=1)
        segments_v = value.split(self.segment_size, dim=1)

        gate = torch.sigmoid(self.beta)

        final_outputs = []
        final_k = []
        final_attn = []

        for idx, (q_seg, k_seg, v_seg) in enumerate(
            zip(segments_q, segments_k, segments_v)
        ):
            start = idx * self.segment_size
            stop = start + q_seg.size(1)

            q_heads = self._split_heads(
                q_seg, num_heads=self.n_head, attn_head_size=self.d_head
            )
            k_heads = self._split_heads(
                k_seg, num_heads=self.n_head, attn_head_size=self.d_head
            )
            v_heads = self._split_heads(
                v_seg, num_heads=self.n_head, attn_head_size=self.d_head
            )
            final_k.append(k_heads)

            # Read from the memory of previous segments first, then add this one.
            memory_output = self._retrieve_from_memory(q_heads)
            self._update_memory(k_heads, v_heads)

            # Slice the mask for *this* segment's positions (None stays None).
            mask_segment = None
            if attention_mask is not None:
                if attention_mask.size(-2) == total_len:
                    # Per-query mask: restrict the query axis to the segment too.
                    mask_segment = attention_mask[:, :, start:stop, start:stop]
                else:
                    mask_segment = attention_mask[:, :, :, start:stop]

            attn_outputs = self._attn(
                q_heads, k_heads, v_heads, mask_segment, head_mask
            )
            a_dot = attn_outputs[0]
            if output_attentions:
                final_attn.append(attn_outputs[1])

            combined_output = gate * memory_output + (1 - gate) * a_dot

            combined_output = self._merge_heads(
                combined_output, self.n_head, self.d_head
            )
            combined_output = self.c_proj(combined_output)
            combined_output = self.resid_dropout(combined_output)

            final_outputs.append(combined_output)

        final_outputs = torch.cat(final_outputs, dim=1)
        # Keys were split along the sequence, so they are re-joined along the
        # sequence axis (second to last), not the head axis.
        final_k = torch.cat(final_k, dim=-2)
        outputs = (final_outputs, final_k)

        if output_attentions:
            # NOTE(review): stacked along dim 1 as before; this requires every
            # segment to have the same length. Confirm the layout consumers expect.
            outputs = outputs + (torch.cat(final_attn, dim=1),)

        return outputs


# Smoke-test setup for InfiniAttentionGPT2: `config`, `hidden_states` and
# `attention_mask` are (re)defined here with example values.
config = GPT2Config()
config.n_positions = 1024  # Example sequence length
hidden_states = torch.randn(2, 1024, config.hidden_size)  # Example hidden states
attention_mask = torch.ones(2, 1024)  # Example attention mask
attention_mask[:, :512] = 0  # Zero the mask over the first half of the sequence


infini_att_gpt2 = InfiniAttentionGPT2(config=config, is_causal=False, n_segments=16)

# Forward (currently disabled)
# NOTE(review): forward() indexes `attention_mask` with four indices, so the 2-D
# mask built above would have to be expanded before this call can be re-enabled.
# outputs = infini_att_gpt2(
#     hidden_states=hidden_states, attention_mask=attention_mask, debug=True
# )

# Output (only the header is printed while the forward call above is disabled)
print("Output InfiniAttention GPT-2 att:")
# print(f"{outputs[0].shape=}")

Output InfiniAttention GPT-2 att:


### Training Model

In [None]:
# Model variant: plain GPT-2 ("gpt2") or GPT-2 with InfiniAttentionGPT2 layers ("gpt2-infini").
model_type = "gpt2-infini"  # "gpt2" or "gpt2-infini"

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Toy training corpus: eight upper-case colour names, cycled ten times (80 samples).
COLOR_NAMES = [
    "AZUL",
    "AMARELO",
    "VERMELHO",
    "VERDE",
    "ROSA",
    "ROXO",
    "LARANJA",
    "BRANCO",
]

dataset = [name for _ in range(10) for name in COLOR_NAMES]

### Trainer

In [None]:
# Build the model selected by `model_type`:
#   "gpt2"        -> unmodified GPT2LMHeadModel
#   "gpt2-infini" -> every block's self-attention replaced by InfiniAttentionGPT2
if model_type not in ("gpt2", "gpt2-infini"):
    raise ValueError(f"Unknown model_type: {model_type!r}")

model = GPT2LMHeadModel(config)

if model_type == "gpt2-infini":
    for layer_idx, block in enumerate(model.transformer.h):
        block.attn = InfiniAttentionGPT2(config, layer_idx=layer_idx)

# One transfer moves the whole model, including the replaced attention modules.
model = model.to(device)

In [None]:
# train with trainer
from transformers import GPT2Tokenizer
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from datasets import Dataset

# --- Tokenizer -------------------------------------------------------------
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Reuse the end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token
# tokenizer.add_special_tokens({"pad_token": "<PAD>"})

# --- Training data ---------------------------------------------------------
# Every sample is padded to the tokenizer's maximum length and returned as
# PyTorch tensors.
inputs = tokenizer(
    dataset,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
)

train_dataset = Dataset.from_dict(inputs)
train_dataset.set_format(columns=["input_ids", "attention_mask"], type="torch")

# Causal-LM collation (mlm=False).
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)

# --- Trainer ---------------------------------------------------------------
training_args = TrainingArguments(
    output_dir="./results",  # where results are written
    logging_dir="./logs",  # where logs are written
    logging_strategy="steps",
    logging_steps=5,  # log every 5 update steps
    num_train_epochs=3,  # passes over the training set
    per_device_train_batch_size=8,  # samples per device per step
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

In [None]:
# Run training; the returned TrainOutput is displayed as the cell result.
trainer.train()

Step,Training Loss
5,8.098
10,4.2736
15,2.0526
20,1.2788
25,0.7147
30,0.5808


TrainOutput(global_step=30, training_loss=2.833089065551758, metrics={'train_runtime': 39.4508, 'train_samples_per_second': 6.084, 'train_steps_per_second': 0.76, 'total_flos': 125420193054720.0, 'train_loss': 2.833089065551758, 'epoch': 3.0})

### Inference with the trained model

In [None]:

def generate_infini(model_infini, tokenizer, text="Este é um carro", tokens_gen=10):
    """Greedily extend ``text`` one token at a time.

    Each step re-tokenizes the current text, takes the arg-max token at the last
    position, appends it, and decodes the ids back to text with special tokens
    removed. Generation stops after ``tokens_gen`` steps, or earlier when the
    model predicts the same token on two consecutive steps.

    Args:
        model_infini: Callable model; ``model(input_ids, attention_mask=...)``
            must return an object whose element ``[0]`` holds the logits with
            shape (batch, seq, vocab).
        tokenizer: Tokenizer used for both encoding and decoding.
        text (str): Prompt to continue.
        tokens_gen (int): Maximum number of tokens to append.

    Returns:
        str: The prompt followed by the generated continuation.

    Raises:
        ValueError: If the text tokenizes to an empty sequence.
    """
    model_infini.eval()

    # Follow the model's own device instead of depending on a notebook global.
    first_param = next(model_infini.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    previous_token_id = None

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for _ in range(tokens_gen):
            inputs = tokenizer(text, return_tensors="pt", truncation=True)
            input_ids = inputs.input_ids.to(model_device)
            attention_mask = inputs.attention_mask.to(model_device)

            if input_ids.size(-1) == 0:
                raise ValueError("text must tokenize to at least one token")

            outputs = model_infini(input_ids, attention_mask=attention_mask)

            # Greedy choice at the last position
            next_token_logits = outputs[0][:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1)

            # Stop once the prediction repeats. Because the text is decoded with
            # special tokens skipped and re-tokenized every step, a predicted
            # special token is dropped from the text and predicted again here.
            if previous_token_id is not None and torch.equal(
                previous_token_id, next_token_id
            ):
                break
            previous_token_id = next_token_id

            # Append the new token and refresh the text for the next step
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
            text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    return text



In [None]:
# Complete a few colour-name prefixes with the trained model.
for prefix in ("VERM", "ROX", "AZ"):
    completion = generate_infini(model, tokenizer, text=prefix, tokens_gen=10)
    print(completion)

VERMELHO
ROXO
AZUL


In [None]:
# DEBUG GENERATE
# inputs = tokenizer("LAR",return_tensors="pt", truncation=True)
# input_ids = inputs.input_ids.to(device)
# attention_mask = inputs.attention_mask.to(device)

# model.generate(input_ids, max_new_tokens=10, attention_mask=attention_mask)

In [None]:
# Single forward pass through the trained `model` on a short prompt.
# The inputs are built in this cell: `input_ids` is otherwise only a local
# variable of generate_infini (or of the commented-out debug cell), so the cell
# would fail on a fresh Restart & Run All. `model` is deliberately not rebuilt
# here; the shape inspected in the next cell is a single tensor per layer, which
# is what InfiniAttentionGPT2 returns as its second output.
inputs = tokenizer("LAR", return_tensors="pt", truncation=True)
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device)

with torch.no_grad():
    y = model(input_ids, attention_mask=attention_mask)

In [None]:
# Shape of the first layer's cache entry. The recorded result is a single tensor
# ([1, 12, 2, 64]); presumably this is the key tensor that
# InfiniAttentionGPT2.forward returns as its second output — verify.
y.past_key_values[0].shape

torch.Size([1, 12, 2, 64])