# Training GPT-2 Model with InfiniAttention Module

In [50]:
# Directory that holds the tokenizer files (vocab.json / merges.txt).
tokenizer_path = "../tokenizer/"

# Paths kept for reference only; they are not used while commented out.
# zarr_file_path = './dataset_copy.zarr'
# config_path = '/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/configs/config.json'

In [51]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from typing import Optional, Tuple, Union

### Standard GPT2LMHeadModel structure

In [52]:
# GPT-2 configuration built with the library defaults; the models created
# further down are all instantiated from this object.
config = GPT2Config()
# model = GPT2LMHeadModel(config)

In [53]:
import torch.nn.functional as F
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, Conv1D


class InfiniAttentionGPT2(GPT2Attention):
    """GPT-2 attention layer extended with a compressive memory.

    The projected queries/keys/values are cut into consecutive segments of
    ``segment_size`` positions. Every segment is processed with the inherited
    dot-product attention (``_attn``) restricted to that segment, and in
    addition reads from a memory that accumulates the keys/values of the
    segments that came before it in the same forward call. The two results are
    blended with a learned gate ``sigmoid(beta)``.

    The memory lives on the module (``self.memory`` / ``self.norm_term``) and
    is reset at the start of every ``forward`` call.
    """

    def __init__(
        self,
        config,
        is_cross_attention=False,
        layer_idx=None,
        n_segments=16,
        is_causal: Optional[bool] = True,
        update: Optional[str] = "linear",
    ):
        """
        Initialize the InfiniAttentionGPT2 module.

        Args:
            config: GPT2Config object containing configuration parameters.
            is_cross_attention (bool): Flag to indicate if this is cross attention.
            layer_idx (int, optional): Index of the layer.
            n_segments (int): Number of segments ``config.n_positions`` is
                divided into; the segment size is ``n_positions // n_segments``.
            is_causal (bool, optional): Stored on the module; not read by the
                methods of this class.
            update (str, optional): Name of the memory update strategy. Stored
                on the module; only the additive ("linear") update is
                implemented here.

        Raises:
            ValueError: If ``n_segments`` does not yield a segment of at least
                one position.
        """
        super().__init__(config, is_cross_attention, layer_idx)

        # Per-head geometry used when splitting/merging heads.
        self.d_head = config.hidden_size // config.num_attention_heads
        self.n_head = config.num_attention_heads

        # Gate logit for combining A_mem and A_dot; sigmoid(0) = 0.5 at init.
        self.beta = nn.Parameter(torch.zeros(1), requires_grad=True)

        self.elu = nn.ELU()

        # Maximum sequence length and the resulting segment size.
        self.seq_len = config.n_positions
        if n_segments < 1 or self.seq_len // n_segments < 1:
            raise ValueError(
                f"n_segments must be in [1, {self.seq_len}], got {n_segments}"
            )
        self.n_segments = n_segments
        self.segment_size = self.seq_len // n_segments

        self.is_causal = is_causal
        # Lower-triangular mask of one segment. It is a persistent buffer, so
        # it is part of the state dict and its shape depends on n_segments.
        self.register_buffer(
            "causal",
            torch.tril(torch.ones(self.segment_size, self.segment_size)),
        )

        # Update strategy
        self.update = update

        # Compressive-memory state; defined here so the helper methods can be
        # called before the first forward pass.
        self.memory = None
        self.norm_term = None

    def _retrieve_from_memory(self, query_states):
        """Read from the compressive memory with linear attention (Eq. 3).

        Computes ``sigma(Q) M / (sigma(Q) z)`` with ``sigma = ELU + 1``.
        Returns zeros shaped like ``query_states`` while the memory is empty.
        """
        if self.memory is None:
            return torch.zeros_like(query_states)
        sigma_q = F.elu(query_states) + 1
        numerator = torch.matmul(sigma_q, self.memory)
        # The normaliser is the dot product of every query with the
        # accumulated key sum, i.e. one scalar per query position, not an
        # element-wise division by the key sum itself.
        denominator = torch.matmul(sigma_q, self.norm_term.transpose(-2, -1))
        # sigma(.) is non-negative, so clamping only guards against underflow.
        denominator = denominator.clamp_min(torch.finfo(denominator.dtype).eps)
        return numerator / denominator

    def _update_memory(self, key_states, value_states):
        """Add the current segment's keys/values to the memory (Eq. 4).

        ``memory`` accumulates ``sigma(K)^T V`` and ``norm_term`` accumulates
        the sum of ``sigma(K)`` over the positions of the segment.
        """
        sigma_k = F.elu(key_states) + 1
        kv_update = torch.matmul(sigma_k.transpose(-2, -1), value_states)
        key_sum = sigma_k.sum(dim=-2, keepdim=True)

        self.memory = kv_update if self.memory is None else self.memory + kv_update
        self.norm_term = (
            key_sum if self.norm_term is None else self.norm_term + key_sum
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        debug: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Run segment-wise attention with compressive memory.

        Args:
            hidden_states: 3-D input; it is projected by ``c_attn`` and split
                into segments along dimension 1.
            layer_past: Accepted for interface compatibility; not used.
            attention_mask: Optional mask forwarded to ``_attn``. Its last
                axis is sliced to the positions of the current segment
                (assumes the last axis indexes key positions — TODO confirm
                against the caller).
            head_mask: Forwarded unchanged to ``_attn``.
            use_cache: Accepted for interface compatibility; not used.
            output_attentions: When true, the per-segment attention weights
                are appended to the result.
            debug: Accepted for interface compatibility; not used.

        Returns:
            ``(output, keys)`` or ``(output, keys, attn_weights)``. ``output``
            is the projected result for all positions, ``keys`` are the
            per-head keys of all segments joined along the sequence axis, and
            ``attn_weights`` are the per-segment weights stacked along
            dimension 1.
        """
        # The memory only spans the segments of the current call.
        self.norm_term = None
        self.memory = None

        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        # Same cut points for q, k and v; the last segment may be shorter.
        boundaries = list(
            range(self.segment_size, hidden_states.size(1), self.segment_size)
        )
        segments_q = torch.tensor_split(query, boundaries, dim=1)
        segments_k = torch.tensor_split(key, boundaries, dim=1)
        segments_v = torch.tensor_split(value, boundaries, dim=1)

        # The gate does not depend on the segment.
        gate = torch.sigmoid(self.beta)

        merged_segments = []
        segment_keys = []
        segment_attn = []
        start = 0

        for q_seg, k_seg, v_seg in zip(segments_q, segments_k, segments_v):
            end = start + q_seg.size(1)

            q_heads = self._split_heads(
                q_seg, num_heads=self.n_head, attn_head_size=self.d_head
            )
            k_heads = self._split_heads(
                k_seg, num_heads=self.n_head, attn_head_size=self.d_head
            )
            v_heads = self._split_heads(
                v_seg, num_heads=self.n_head, attn_head_size=self.d_head
            )
            segment_keys.append(k_heads)

            # Read what previous segments stored, then store this segment.
            memory_output = self._retrieve_from_memory(q_heads)
            self._update_memory(k_heads, v_heads)

            # Mask columns belonging to this segment; stays None when the
            # caller supplied no mask.
            mask_segment = None
            if attention_mask is not None:
                mask_segment = attention_mask[..., start:end]

            attn_outputs = self._attn(
                q_heads, k_heads, v_heads, mask_segment, head_mask
            )
            a_dot = attn_outputs[0]
            if output_attentions:
                segment_attn.append(attn_outputs[1])

            combined = gate * memory_output + (1 - gate) * a_dot
            merged_segments.append(
                self._merge_heads(combined, self.n_head, self.d_head)
            )
            start = end

        # The output projection acts per position, so it is applied once to
        # the re-assembled sequence instead of once per segment.
        final_outputs = torch.cat(merged_segments, dim=1)
        final_outputs = self.resid_dropout(self.c_proj(final_outputs))

        # Keys are (..., seq, head_dim) after the head split: join on seq.
        final_k = torch.cat(segment_keys, dim=-2)
        outputs = (final_outputs, final_k)

        if output_attentions:
            outputs = outputs + (torch.cat(segment_attn, dim=1),)

        return outputs

### Training Model

In [54]:
# Label of the model variant being trained.
# NOTE(review): no code visible in this notebook reads this value — confirm it is still needed.
model_type = "gpt2-infini"  # "gpt2" or "gpt2-infini"

In [55]:
# Train on the first GPU when CUDA is available; otherwise fall back to the
# CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Trainer

In [56]:
# Baseline alternative (stock GPT-2 attention), kept for reference:
# model = GPT2LMHeadModel(config).to(device)

# Infini variant: build a stock GPT-2 and swap the attention module of every
# transformer block for InfiniAttentionGPT2 (128 segments per sequence).
model = GPT2LMHeadModel(config)

for i, layer in enumerate(model.transformer.h):
    layer.attn = InfiniAttentionGPT2(config, layer_idx=i, n_segments=128)

# Move the model only after the replacement modules are in place.
model = model.to(device)

In [57]:
def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the trainable-parameter count with thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 124,439,820 trainable parameters


In [58]:
from datasets import Dataset
import torch

path_dataset = '../datasets/split_10/'

train_dataset = Dataset.load_from_disk(path_dataset + "train_dataset")
test_dataset = Dataset.load_from_disk(path_dataset + "test_dataset")

# The whole training split is used. To train on a subset, shrink the range,
# e.g. train_dataset.select(range(160)).
train_dataset = train_dataset.select(range(len(train_dataset)))

# Evaluate on at most 4 batches of 16 examples (previously 6 * 16). The size
# is clamped so a smaller test split does not make select() fail.
test_dataset = test_dataset.select(range(min(4 * 16, len(test_dataset))))

In [59]:
# Display the training dataset summary (features and row count).
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 53930
})

In [60]:
# Display the evaluation dataset summary (features and row count).
test_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 64
})

In [61]:
# train with trainer
from transformers import GPT2Tokenizer
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from tokenizers import ByteLevelBPETokenizer
import numpy as np
from datasets import Dataset

# Output locations: checkpoints, logs, and the final saved artefacts.
output_dir = '../models/output_dir'
logging_dir = '../models/logs'
model_save_dir = '../models/'

# Earlier manual step computation, kept for reference:
# batch_size = 16
# num_epochs = 1
# num_steps = len(train_dataset) * num_epochs // batch_size

training_args = TrainingArguments(
    # Where things are written.
    output_dir=output_dir,
    logging_dir=logging_dir,
    # Optimisation schedule.
    learning_rate=1e-4,
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    seed=42,
    # Log, evaluate and checkpoint on the same 1000-step cadence
    # (5000~10000 was considered for logging).
    logging_strategy="steps",
    logging_steps=1000,
    logging_first_step=True,
    eval_strategy="steps",
    eval_steps=1000,
    eval_accumulation_steps=4,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=1,
    # fp16 stays disabled: training in FP16 produced zero/NaN loss values.
    # fp16=True,
    # fp16_full_eval = True,
)

# Tokenizer files live next to each other in tokenizer_path.
vocab_file = f"{tokenizer_path}vocab.json"
merges_file = f"{tokenizer_path}merges.txt"

tokenizer = GPT2Tokenizer(vocab_file, merges_file)
# Cap tokenized inputs at the model's maximum number of positions.
tokenizer.model_max_length = model.config.n_positions
# No dedicated padding token is defined, so the end-of-sequence token is reused.
tokenizer.pad_token = tokenizer.eos_token

# Special-token ids, read after the pad token has been assigned.
bos_id = tokenizer.bos_token_id
eos_id = tokenizer.eos_token_id
pad_id = tokenizer.pad_token_id

# Make the embedding matrix match the tokenizer's vocabulary size.
model.resize_token_embeddings(len(tokenizer))

# mlm=False selects the causal (non-masked) language-modelling collator mode.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)



In [62]:
# import perplexity
import torch.nn.functional as F

def compute_metrics(eval_pred):
    """Compute next-token cross-entropy loss and perplexity for evaluation.

    Args:
        eval_pred: Pair ``(logits, labels)``. ``logits`` has the vocabulary on
            its last axis and the sequence on the axis before it; ``labels``
            has the sequence on its last axis. Either may be a tensor or
            anything ``torch.tensor`` accepts (e.g. a NumPy array). When the
            predictions arrive as a tuple, its first element is taken as the
            logits.

    Returns:
        Dict with ``"eval_loss"`` (mean cross-entropy) and
        ``"eval_perplexity"`` (``exp`` of that loss), both Python floats.
        Label positions equal to -100 are ignored, which is the default
        ``ignore_index`` of ``F.cross_entropy``.
    """
    logits, labels = eval_pred

    # Models that return extra outputs hand over a tuple; logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]

    if not isinstance(logits, torch.Tensor):
        logits = torch.tensor(logits)

    if not isinstance(labels, torch.Tensor):
        labels = torch.tensor(labels)

    # In a causal LM the logits at position t score the token at t + 1, so
    # drop the last logit and the first label before comparing them.
    # reshape (not view) because the slices are not contiguous.
    shift_logits = logits[..., :-1, :].reshape(-1, logits.size(-1))
    shift_labels = labels[..., 1:].reshape(-1).long()

    loss = F.cross_entropy(shift_logits.float(), shift_labels)
    perplexity = torch.exp(loss).item()  # Ensure perplexity is a scalar

    return {
        "eval_loss": loss.item(),
        "eval_perplexity": perplexity
    }

In [63]:
# Wire the model, the training configuration, both datasets, the batch
# collator and the evaluation metric into a single Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [64]:
# Overwrite the GPU count stored on the training arguments (a private attribute).
# NOTE(review): presumably this forces single-GPU training on a multi-GPU host — confirm.
trainer.args._n_gpu = 1

In [65]:
# Start training; the run recorded below was stopped with a KeyboardInterrupt.
trainer.train()

Step,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# Alternative: save through the trainer.
# trainer.save_model(model_save_dir + 'trainer/')

# Save only the model weights (state dict); the load cell later in this
# notebook reads them back from the same path.
torch.save(model.state_dict(), "../models/infini.pt")

In [None]:
# trainer.model.save_pretrained(model_save_dir)
# Write the tokenizer files into the model directory; the paths of the
# written files are returned (shown in the cell output).
tokenizer.save_pretrained(model_save_dir)

('/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/models/tokenizer_config.json',
 '/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/models/special_tokens_map.json',
 '/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/models/vocab.json',
 '/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/models/merges.txt',
 '/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/models/added_tokens.json')

In [None]:
import json

# Persist the trainer's log history as pretty-printed JSON.
with open("../models/metric.json", "w") as metrics_file:
    metrics_file.write(json.dumps(trainer.state.log_history, indent=2))

### Load saved model

In [None]:
model = GPT2LMHeadModel(config)

# Rebuild the attention layers exactly as they were built for training.
# n_segments must be the training value (128): it fixes both the segment size
# and the shape of the persistent ``causal`` buffer stored in the checkpoint,
# so any other value makes load_state_dict fail with a size mismatch.
for i, layer in enumerate(model.transformer.h):
    layer.attn = InfiniAttentionGPT2(config, layer_idx=i, n_segments=128)

# Place the model on the device the inference helpers send their inputs to,
# and map the checkpoint there so it also loads on a machine without the GPU
# it was saved from.
model = model.to(device)
model.load_state_dict(torch.load("../models/infini.pt", map_location=device))

Parameter containing:
tensor([0.0005], requires_grad=True)
Parameter containing:
tensor([-0.0007], requires_grad=True)
Parameter containing:
tensor([-0.0003], requires_grad=True)
Parameter containing:
tensor([-0.0005], requires_grad=True)
Parameter containing:
tensor([-0.0003], requires_grad=True)
Parameter containing:
tensor([0.0002], requires_grad=True)
Parameter containing:
tensor([0.0002], requires_grad=True)
Parameter containing:
tensor([0.0004], requires_grad=True)
Parameter containing:
tensor([0.0004], requires_grad=True)
Parameter containing:
tensor([0.0004], requires_grad=True)
Parameter containing:
tensor([0.0003], requires_grad=True)
Parameter containing:
tensor([0.0005], requires_grad=True)


### Inference with the trained model

In [None]:

def generate_infini(model_infini, tokenizer, text="Este é um carro", tokens_gen=10):
    """Greedily extend ``text`` by up to ``tokens_gen`` tokens.

    The prompt is tokenized once; each step appends the arg-max token of the
    last position and feeds the grown sequence back to the model.

    Generation stops early when the model proposes the same token twice in a
    row, or proposes the tokenizer's end-of-sequence token (when it defines
    one). The token that triggers the stop is not appended.

    Args:
        model_infini: Callable model taking ``(input_ids, attention_mask=...)``
            whose first output holds the logits. It is switched to eval mode.
        tokenizer: Callable returning ``input_ids`` / ``attention_mask``
            tensors, with a ``decode`` method.
        text: Prompt to continue.
        tokens_gen: Maximum number of tokens to generate.

    Returns:
        The decoded prompt plus generated tokens, or ``text`` unchanged when
        nothing was generated.
    """
    model_infini.eval()

    if tokens_gen <= 0:
        return text

    # Run on whatever device the model lives on instead of a global.
    first_param = next(model_infini.parameters(), None)
    run_device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )

    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    input_ids = inputs.input_ids.to(run_device)
    attention_mask = inputs.attention_mask.to(run_device)

    # Longest context the model accepts, when its config declares one.
    max_context = getattr(
        getattr(model_infini, "config", None), "n_positions", None
    )
    eos_token_id = getattr(tokenizer, "eos_token_id", None)

    previous_token_id = None
    generated_any = False

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(tokens_gen):
            context_ids, context_mask = input_ids, attention_mask
            if max_context is not None and input_ids.size(-1) > max_context:
                # Keep the most recent positions once the sequence outgrows
                # the model's context window.
                context_ids = input_ids[:, -max_context:]
                context_mask = attention_mask[:, -max_context:]

            outputs = model_infini(context_ids, attention_mask=context_mask)

            # Greedy choice from the logits of the last position.
            next_token_id = torch.argmax(outputs[0][:, -1, :], dim=-1)

            if previous_token_id is not None and torch.equal(
                previous_token_id, next_token_id
            ):
                break
            if eos_token_id is not None and bool(
                (next_token_id == eos_token_id).all()
            ):
                break
            previous_token_id = next_token_id

            # Grow the sequence and its mask by the accepted token.
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))],
                dim=-1,
            )
            generated_any = True

    if generated_any:
        text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    return text



In [None]:
# Print a 10-token greedy continuation for each sample prompt.
for prompt in ("Meu nome é Pe", "Um carro pass", "Música"):
    print(generate_infini(model, tokenizer, text=prompt, tokens_gen=10))