# Training GPT-2 Model with InfiniAttention Module

In [None]:
# !pip install datasets
# !pip install accelerate -U
# !pip install transformers -U
# !pip install zarr

In [None]:
from google.colab import drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
zarr_file_path = (
    "/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/dataset_copy.zarr"
)
tokenizer_path = (
    "/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/tokenizer/"
)
# config_path = '/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/configs/config.json'

In [None]:
# import zarr

# zarr_store = zarr.load(zarr_file_path)

In [1]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from typing import Optional, Tuple, Union

### Standard GPT2LMHeadModel structure

In [None]:
config = GPT2Config()
# model = GPT2LMHeadModel(config)

In [None]:
import torch.nn.functional as F
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, Conv1D


class InfiniAttentionGPT2(GPT2Attention):
    def __init__(
        self,
        config,
        is_cross_attention=False,
        layer_idx=None,
        n_segments=16,
        is_causal: Optional[bool] = True,
        update: Optional[str] = "linear",
    ):
        """
        Initialize the InfiniAttentionGPT2 module.

        Args:
            config: GPT2Config object containing configuration parameters.
            is_cross_attention (bool): Flag to indicate if this is cross attention.
            layer_idx (int, optional): Index of the layer.
            n_segments (int): Number of segments for memory processing.
            is_causal (bool, optional): Flag to indicate if attention is causal.
            update (str, optional): Update strategy for memory, either 'linear' or another strategy.
        """
        super().__init__(config, is_cross_attention, layer_idx)

        # Initializing memory state for compressive memory
        self.d_head = config.hidden_size // config.num_attention_heads
        self.n_head = config.num_attention_heads

        # Initialize the beta parameter for combining A_mem and A_dot
        self.beta = nn.Parameter(torch.zeros(1), requires_grad=True)

        self.elu = nn.ELU()

        # Sequence length
        self.seq_len = config.n_positions

        self.is_causal = is_causal
        self.register_buffer(
            "causal",
            torch.tril(
                torch.ones(self.seq_len // n_segments, self.seq_len // n_segments)
            ),
        )

        # Segment size
        self.n_segments = n_segments
        self.segment_size = self.seq_len // n_segments

        # Update strategy
        self.update = update

    def _retrieve_from_memory(self, query_states):
        # Retrieve context from compressive memory using linear attention (Eq. 3)
        if self.memory is None:
            return torch.zeros_like(query_states)
        query_states = F.elu(query_states) + 1  # ELU activation
        memory_output = torch.matmul(query_states, self.memory) / self.norm_term
        return memory_output

    def _update_memory(self, key_states, value_states):
        # Update compressive memory with new key-value states (Eq. 4)
        key_states = F.elu(key_states) + 1  # ELU activation
        if self.memory is not None:
            self.memory = self.memory + torch.matmul(
                key_states.transpose(-2, -1), value_states
            )
        else:
            self.memory = torch.matmul(key_states.transpose(-2, -1), value_states)
        if self.norm_term is not None:
            self.norm_term = self.norm_term + torch.unsqueeze(
                key_states.sum(dim=-2), -2
            )
        else:
            self.norm_term = torch.unsqueeze(key_states.sum(dim=-2), -2)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        debug: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:

        self.norm_term = None
        self.memory = None

        batch_size, _, _ = hidden_states.size()

        device = hidden_states.device

        qkv = self.c_attn(hidden_states)
        query, key, value = qkv.split(self.split_size, dim=2)

        # segments = torch.tensor_split(
        #     hidden_states,
        #     list(range(self.segment_size, hidden_states.size(1), self.segment_size)),
        #     dim=1,
        # )

        segments_q = torch.tensor_split(
            query,
            list(range(self.segment_size, query.size(1), self.segment_size)),
            dim=1,
        )

        segments_k = torch.tensor_split(
            key,
            list(range(self.segment_size, key.size(1), self.segment_size)),
            dim=1,
        )

        segments_v = torch.tensor_split(
            value,
            list(range(self.segment_size, value.size(1), self.segment_size)),
            dim=1,
        )

        final_outputs = []
        final_k = []
        final_attn = []

        # print(f"{hidden_states.shape=}")
        # print(f"{self.c_attn(hidden_states).shape=}")
        # print(f"{segments_q[0].shape=}")

        for i, segment in enumerate(segments_q):
            # print(f"{segment.shape=}")
            # qkv_segment = self.c_attn(segment)
            # query, key, value = qkv_segment.split(self.split_size, dim=2)

            query = segments_q[i]
            key = segments_k[i]
            value = segments_v[i]

            query = self._split_heads(
                query, num_heads=self.n_head, attn_head_size=self.d_head
            )
            key = self._split_heads(
                key, num_heads=self.n_head, attn_head_size=self.d_head
            )
            final_k.append(key)
            value = self._split_heads(
                value, num_heads=self.n_head, attn_head_size=self.d_head
            )

            bsz, q_len, _ = segment.size()

            # print(f"{query.shape=}")
            memory_output = self._retrieve_from_memory(query)
            self._update_memory(key, value)

            # print(f"{attention_mask.shape=}")
            if attention_mask is not None:
                attention_mask_segment = attention_mask[:, :, :, : self.segment_size]
            # print(f"{attention_mask_segment.shape=}")
            attn_outputs = self._attn(
                query, key, value, attention_mask_segment, head_mask
            )
            a_dot = attn_outputs[0]
            final_attn.append(attn_outputs[1])

            # attn_output = torch.nn.functional.scaled_dot_product_attention(
            #     query_states,
            #     key_states,
            #     value_states,
            #     attn_mask=causal_mask,
            #     dropout_p=self.attention_dropout if self.training else 0.0,
            # )

            combined_output = (
                F.sigmoid(self.beta) * memory_output
                + (1 - F.sigmoid(self.beta)) * a_dot
            )

            combined_output = self._merge_heads(
                combined_output, self.n_head, self.d_head
            )
            combined_output = self.c_proj(combined_output)
            combined_output = self.resid_dropout(combined_output)

            final_outputs.append(combined_output)

        final_outputs = torch.cat(final_outputs, dim=1)
        final_k = torch.cat(final_k, dim=1)
        final_attn = torch.cat(final_attn, dim=1)
        outputs = (final_outputs, final_k)

        if output_attentions:
            outputs = outputs + (final_attn,)

        return outputs


# # Supondo que o `config`, `hidden_states`, e `attention_mask` já tenham sido definidos.
# config = GPT2Config()
# config.n_positions = 1024  # Definindo um exemplo de tamanho de sequência
# hidden_states = torch.randn(2, 1024, config.hidden_size)  # Exemplo de estados ocultos
# attention_mask = torch.ones(2, 1024)  # Exemplo de máscara de atenção
# attention_mask[:, :512] = 0  # Máscara de atenção para metade da sequência


# infini_att_gpt2 = InfiniAttentionGPT2(config=config, is_causal=False, n_segments=16)

# # Forward
# # outputs = infini_att_gpt2(
# #     hidden_states=hidden_states, attention_mask=attention_mask, debug=True
# # )

# # Output
# print("Output InfiniAttention GPT-2 att:")
# # print(f"{outputs[0].shape=}")

### Training Model

In [None]:
model_type = "gpt2-infini"  # "gpt2" or "gpt2-infini"

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Trainer

In [None]:
# gpt-2 original
# model = GPT2LMHeadModel(config).to(device)

# gpt-2 infini
model = GPT2LMHeadModel(config)

for i, layer in enumerate(model.transformer.h):
    model.transformer.h[i].attn = InfiniAttentionGPT2(config, layer_idx=i)

model = model.to(device)

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f"The model has {count_parameters(model):,} trainable parameters")

The model has 124,439,820 trainable parameters


In [None]:
from datasets import Dataset
import torch

path_dataset = "/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/tokenizer/datasets/split_10/"

train_dataset = Dataset.load_from_disk(path_dataset + "train_dataset")
test_dataset = Dataset.load_from_disk(path_dataset + "test_dataset")

# 50% train data and 5% of test data of 50% train data.
train_dataset = train_dataset.select(range(int(len(train_dataset))))
test_dataset = test_dataset.select(range(4 * 16))  # 6 * 16

In [None]:
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 53930
})

In [None]:
test_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 64
})

In [None]:
# train with trainer
from transformers import GPT2Tokenizer
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from tokenizers import ByteLevelBPETokenizer
import numpy as np
from datasets import Dataset

output_dir = (
    "/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/models/output_dir"
)
logging_dir = (
    "/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/models/logs"
)
model_save_dir = (
    "/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/models/"
)

# batch_size = 16
# num_epochs = 1

# num_steps = len(train_dataset) * num_epochs // batch_size

training_args = TrainingArguments(
    learning_rate=2e-5,
    output_dir=output_dir,
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    logging_dir=logging_dir,
    logging_steps=1000,  # 5000~10000
    eval_steps=500,
    save_steps=1000,
    save_total_limit=1,
    logging_strategy="steps",
    save_strategy="steps",
    eval_strategy="steps",
    seed=42,
    eval_accumulation_steps=4,
    # fp16=True, -> Train with FP16 generate zeros/nan values in loss
    # fp16_full_eval = True,
)

vocab_file = tokenizer_path + "vocab.json"
merges_file = tokenizer_path + "merges.txt"

tokenizer = GPT2Tokenizer(vocab_file, merges_file)
tokenizer.model_max_length = model.config.n_positions
tokenizer.pad_token = tokenizer.eos_token

bos_id = tokenizer.bos_token_id
eos_id = tokenizer.eos_token_id
pad_id = tokenizer.pad_token_id

model.resize_token_embeddings(len(tokenizer))

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)


# special tokens

# add an first column of bos value
# bos_array = np.zeros((zarr_store.shape[0], 1), dtype=np.int32)
# bos_array[:, 0] = bos_id

# add an last column of eos value
# eos_array = np.zeros((zarr_store.shape[0], 1), dtype=np.int32)
# eos_array[:, 0] = eos_id

# zarr_store = np.concatenate((bos_array, zarr_store), axis=1)
# zarr_store = np.concatenate((zarr_store, eos_array), axis=1)

# zarr_store[:, 0] = bos_id
# zarr_store[:, -1] = eos_id

# attention mask same dimension zarr_store
# attention_mask = np.ones(zarr_store.shape)

# train 95%
# train_size = int(zarr_store.shape[0] * 0.95)

# train_input_ids = zarr_store[:train_size]
# train_attention_mask = attention_mask[:train_size]

# test 5%
# test_input_ids = zarr_store[train_size:]
# test_attention_mask = attention_mask[train_size:]

# inputs_train = {"input_ids": torch.from_numpy(train_input_ids), "attention_mask": torch.from_numpy(train_attention_mask)}
# inputs_test = {"input_ids": torch.from_numpy(test_input_ids), "attention_mask": torch.from_numpy(test_attention_mask)}

# import torch
# inputs_train = torch.load(tokenizer_path + "inputs_train.pt")
# # replace bos
# inputs_train['input_ids'][:, 0] = bos_id
# # replace eos
# inputs_train['input_ids'][:, -1] = eos_id

# inputs_test = torch.load(tokenizer_path + "inputs_test.pt")
# # replace bos
# inputs_test['input_ids'][:, 0] = bos_id
# # replace eos
# inputs_test['input_ids'][:, -1] = eos_id


# # save inputs_train, inputs_test
# torch.save(inputs_train, tokenizer_path + "inputs_train.pt")
# torch.save(inputs_test, tokenizer_path + "inputs_test.pt")

In [None]:
# test data_collator

data_collator(
    train_dataset[:2],
)

In [None]:
# ### SCRIPT TO SAVE DATASET WITHOUT USING RAM
# import os
# from datasets import Dataset
# import torch

# tokenizer_path = '/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/tokenizer/'

# # Create dataset directories
# train_ds_path = os.path.join(tokenizer_path, "datasets/train")
# test_ds_path = os.path.join(tokenizer_path, "datasets/test")

# os.makedirs(train_ds_path, exist_ok=True)
# os.makedirs(test_ds_path, exist_ok=True)

# # Load tensors
# train_ids = torch.load(tokenizer_path + "inputs_train.pt")
# test_ids = torch.load(tokenizer_path + "inputs_test.pt")

# # Function to save datasets in batches
# def save_dataset_in_batches(ids, path, batch_size=100000):
#     total_batches = (len(ids['input_ids']) + batch_size - 1) // batch_size  # Compute number of batches
#     for i in range(total_batches):
#         start_idx = i * batch_size
#         end_idx = min((i + 1) * batch_size, len(ids['input_ids']))
#         batch = {key: value[start_idx:end_idx] for key, value in ids.items()}
#         dataset = Dataset.from_dict(batch)
#         dataset.save_to_disk(os.path.join(path, f"batch_{i:03d}"))

# # Save datasets
# save_dataset_in_batches(train_ids, train_ds_path)
# save_dataset_in_batches(test_ids, test_ds_path)

In [None]:
# import os
# from datasets import Dataset, load_from_disk
# from tqdm import tqdm
# import datasets

# # Function to iteratively concatenate batch datasets in a directory into a single dataset
# def concatenate_batches_iteratively(directory):
#     batch_files = sorted([os.path.join(directory, f) for f in os.listdir(directory) if f.startswith("batch_")])
#     cumulative_dataset = None

#     # Use tqdm to display a progress bar
#     for batch_file in tqdm(batch_files, desc="Loading and concatenating batches"):
#         current_batch = load_from_disk(batch_file)
#         # current_batch.set_format(type="torch", columns=["input_ids", "attention_mask"])
#         if cumulative_dataset is None:
#             cumulative_dataset = current_batch
#         else:
#             # Concatenate the current batch with the cumulative dataset
#             cumulative_dataset = datasets.concatenate_datasets([cumulative_dataset, current_batch])

#     return cumulative_dataset

# # Path to the directories containing the batch files
# # tokenizer_path = "path/to/your/tokenizer/"  # Set this to the correct path
# train_ds_path = os.path.join(tokenizer_path, "datasets/train")
# test_ds_path = os.path.join(tokenizer_path, "datasets/test")

# # Concatenate batches into a single dataset for both train and test, using the iterative function
# train_dataset = concatenate_batches_iteratively(train_ds_path)
# test_dataset = concatenate_batches_iteratively(test_ds_path)

# # Example of usage
# print("Combined train dataset:", train_dataset)
# print("Combined test dataset:", test_dataset)

In [None]:
# save datasets
# train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
# train_dataset.save_to_disk(tokenizer_path + "train_dataset")

# test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
# test_dataset.save_to_disk(tokenizer_path + "test_dataset")

In [None]:
# # from datasets import Dataset
# # import torch
# # # load dataset

# path_dataset = '/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/tokenizer/datasets/split_10/'

# train_dataset = Dataset.load_from_disk(path_dataset + "train_dataset")
# test_dataset = Dataset.load_from_disk(path_dataset + "test_dataset")

# test_dataset = test_dataset.select(range(int(len(test_dataset) * 0.25)))


# # train_dataset = Dataset.from_dict(torch.load(tokenizer_path + "inputs_train.pt"))
# # train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
# # train_dataset.save_to_disk(tokenizer_path + "train_dataset")

# # train_dataset = Dataset.from_dict(torch.load(tokenizer_path + "inputs_test.pt"))
# # train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
# # train_dataset.save_to_disk(tokenizer_path + "test_dataset")

# # # replace column bos
# # train_dataset['input_ids'][:, 0] = bos_id
# # # replace column eos
# # train_dataset['input_ids'][:, -1] = eos_id

# # test_dataset = Dataset.from_dict(torch.load(tokenizer_path + "inputs_test.pt"))
# # test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

# # # replace column bos
# # test_dataset['input_ids'][:, 0] = bos_id
# # # replace column eos
# # test_dataset['input_ids'][:, -1] = eos_id

# # # Save datasets
# # # train_dataset.save_to_disk(tokenizer_path + "train_dataset")
# # # test_dataset.save_to_disk(tokenizer_path + "test_dataset")

In [None]:
# # import datasets

# dataset = datasets.concatenate_datasets([train_dataset, test_dataset])

# # reduce 50%
# dataset = dataset.select(range(int(len(dataset) * 0.5)))

# # 98% train, 2% test
# train_dataset = dataset.train_test_split(test_size=0.02)

# test_dataset = train_dataset.pop("test")
# train_dataset = train_dataset["train"]

In [None]:
# train_dataset.save_to_disk('/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/tokenizer/datasets/split_50/train_dataset')
# test_dataset.save_to_disk('/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/tokenizer/datasets/split_50/test_dataset')

In [None]:
# import perplexity
import torch.nn.functional as F


def compute_metrics(eval_pred):
    logits, labels = eval_pred

    if not isinstance(logits, torch.Tensor):
        logits = torch.tensor(logits)

    if not isinstance(labels, torch.Tensor):
        labels = torch.tensor(labels)

    # GPT-2 LM Head Cross Entropy Loss
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))

    # GPT-2 LM Head Perplexity
    perplexity = torch.exp(loss)
    return {"eval_loss": loss.item(), "eval_perplexity": perplexity}

In [None]:
# # # DEBUG
# # # get 10% of train_dataset and 1% of test_dataset
# train_dataset = train_dataset.select(range(int(len(train_dataset) * 0.1)))
# test_dataset = test_dataset.select(range(int(len(test_dataset) * 0.5)))

In [None]:
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

Step,Training Loss,Validation Loss,Perplexity
500,No log,7.997762,2953.017822
1000,8.077600,7.707036,6190.458008
1500,8.077600,7.433079,6643.856445
2000,7.435400,7.227094,6549.385742
2500,7.435400,7.056933,8408.976562
3000,7.044600,6.926828,9471.450195
3500,7.044600,6.817666,9934.422852
4000,6.795400,6.716465,10777.864258
4500,6.795400,6.633573,10195.775391
5000,6.607000,6.567185,11136.291016


TrainOutput(global_step=16855, training_loss=6.425810335397508, metrics={'train_runtime': 23512.7091, 'train_samples_per_second': 11.468, 'train_steps_per_second': 0.717, 'total_flos': 1.409148127383552e+17, 'train_loss': 6.425810335397508, 'epoch': 5.0})

In [None]:
# get all params model

for layer in model.transformer.h:
    print(layer.attn.beta)

In [None]:
# save model with torch
torch.save(model.state_dict(), model_save_dir + "model_gpt2_infini.pt")

# load
torch.load(model_save_dir + "model_gpt2_infini.pt")

In [None]:
# save trainer
trainer.save_model(model_save_dir + "trainer/")

In [None]:
trainer.model.save_pretrained(model_save_dir)
tokenizer.save_pretrained(model_save_dir)

('/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/models/tokenizer_config.json',
 '/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/models/special_tokens_map.json',
 '/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/models/vocab.json',
 '/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/models/merges.txt',
 '/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/models/added_tokens.json')

In [None]:
import time

time.sleep(5)

from google.colab import runtime

runtime.unassign()

### Inference trainer

In [None]:
def generate_infini(model_infini, tokenizer, text="Este é um carro", tokens_gen=10):

    model_infini.eval()

    previous_token_id = None

    for _ in range(tokens_gen):

        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)

        outputs = model_infini(input_ids, attention_mask=attention_mask)

        # get next token

        next_token_logits = outputs[0][:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1)

        if previous_token_id == next_token_id:
            break
        else:
            previous_token_id = next_token_id

        # add to input_ids

        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
        text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    return text

In [None]:
print(generate_infini(model, tokenizer, text="Meu nome é Pe", tokens_gen=10))
print(generate_infini(model, tokenizer, text="Um carro pass", tokens_gen=10))
print(generate_infini(model, tokenizer, text="Música", tokens_gen=10))

In [None]:
# DEBUG GENERATE
# inputs = tokenizer("LAR",return_tensors="pt", truncation=True)
# input_ids = inputs.input_ids.to(device)
# attention_mask = inputs.attention_mask.to(device)

# model.generate(input_ids, max_new_tokens=10, attention_mask=attention_mask)

In [None]:
model = GPT2LMHeadModel.from_config(config)
y = model(input_ids, attention_mask=attention_mask)

In [None]:
y.past_key_values[0].shape