# Train InfiniTransformer

### Install Libs

In [None]:
!pip install datasets
!pip install transformers -U
!pip install tiktoken

In [8]:
from google.colab import drive

drive.mount("/content/drive")

ModuleNotFoundError: No module named 'google.colab'

In [2]:
# Select the colab_notebooks folder of project
# %cd "/content/drive/MyDrive/final_project/Infini-attention-Transformer"
%cd "/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/Infini-attention-Transformer"

/content/drive/MyDrive/Colab Notebooks/nlp_unicamp/final_project/Infini-attention-Transformer


In [3]:
# !pip install datasets
# # !pip install accelerate -U
# !pip install transformers -U
# !pip install tiktoken

In [3]:
import sys

sys.path.append("../")

In [7]:
import os
import requests
import math
import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F
from infini_attention_transformer.transformer_language_model import (
    TransformerLanguageModel,
)
from datasets import load_dataset, load_from_disk

In [5]:
# Hyperparameters
batch_size = 16  # How many batches per training step
context_length = 1024  # Length of the token chunk each batch
dim_input = 128  # The size of our model token embeddings
num_blocks = 4  # Number of transformer blocks
num_heads = 4  # Number of heads in Multi-head attention
learning_rate = 1e-4  # 0.001
dropout = 0.1  # Dropout rate
device = "cuda"  # Use Mac GPU if it's available.

infni = True
segment_len = 64

TORCH_SEED = 1337

# torch.set_default_device(device)

In [6]:
train_dataset = load_from_disk("data/")
train_dataset.set_format(type="torch", columns=["input_ids", "labels"])

In [7]:
train_dataset

Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 99039
})

In [8]:
# verify if all batchs has same shape
# print(set(list(map(lambda i: (i['input_ids'].shape, i['labels'].shape), train_dataset))))

In [9]:
# get 10% of train_dataset
# train_dataset = train_dataset.train_test_split(test_size=0.1)
# train_dataset = train_dataset["test"]

In [10]:
# encoding = tiktoken.get_encoding("cl100k_base")
# tokenized_text = encoding.encode(text)
# # tokenized_text = torch.load("corpus.pt")
# max_token_value = max(tokenized_text) + 1  # the maximum value of the tokenized numbers
# tokenized_text = torch.tensor(
#     tokenized_text, dtype=torch.long, device = device
# )

encoding = tiktoken.get_encoding("cl100k_base")
max_token_value = encoding.n_vocab + 1  #

In [11]:
# Split train and validation
# split_idx = int(len(tokenized_text) * 0.8)
# split_val_text = int(len(tokenized_text) * 0.9)
# print(split_val_text)
# train_data = tokenized_text[:split_idx]
# val_data = tokenized_text[split_idx:split_val_text]
# test_data = tokenized_text[split_val_text:]


from torch.utils.data import DataLoader

# 80% for training and 10% for validation
train_dataset = train_dataset.train_test_split(test_size=0.2)
val_dataset = train_dataset["test"]
train_dataset = train_dataset["train"]

# split 10% for test
val_dataset = val_dataset.train_test_split(test_size=0.5)
test_dataset = val_dataset["test"]
val_dataset = val_dataset["train"]


train_loader = DataLoader(train_dataset, batch_size=batch_size)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [12]:
print(train_dataset)
print(val_dataset)
print(test_dataset)

Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 79231
})
Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 9904
})
Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 9904
})


In [13]:
next(iter(test_loader))["input_ids"].shape

torch.Size([16, 1024])

In [14]:
model = TransformerLanguageModel(
    dim_input=dim_input,
    num_heads=num_heads,
    num_blocks=num_blocks,
    context_length=context_length,
    max_token_value=max_token_value,
    dropout=dropout,
    infini=infni,
    segment_len=segment_len,
)


model = model.to(device)


print(model)

TransformerLanguageModel(
  (token_embedding_lookup_table): Embedding(100279, 128)
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (multi_head_attention_layer): InfiniMultiHeadAttention(
        (proj_k): Linear(in_features=128, out_features=128, bias=False)
        (proj_v): Linear(in_features=128, out_features=128, bias=False)
        (proj_q): Linear(in_features=128, out_features=128, bias=False)
        (proj_out): Linear(in_features=128, out_features=128, bias=False)
        (dropout_layer): Dropout(p=0.1, inplace=False)
      )
      (feed_forward_layer): FeedForward(
        (ffn): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=128, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
      (layer_norm_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (layer_norm_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True

In [None]:
from tqdm.notebook import tqdm


# get num of trainable parameters model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def estimate_loss():

    model.eval()

    with torch.no_grad():
        out = {}
        losses = [0 for i in range(eval_iters)]
        steps = 0
        for batch in train_loader:
            if steps == eval_iters:
                break
            x_batch = batch["input_ids"].to(device)
            y_batch = batch["labels"].to(device)

            # x_batch = x_batch
            # y_batch = y_batch

            logits, loss = model(x_batch, targets=y_batch)

            losses[steps] = loss.item()
            steps += 1
        out["train"] = sum(losses) / eval_iters

        losses = [0 for i in range(eval_iters)]
        steps = 0
        for batch in val_loader:
            if steps == eval_iters:
                break
            x_batch = batch["input_ids"].to(device)
            y_batch = batch["labels"].to(device)

            # x_batch = x_batch
            # y_batch = y_batch

            logits, loss = model(x_batch, targets=y_batch)

            losses[steps] = loss.item()
            steps += 1

        out["valid"] = sum(losses) / eval_iters

    model.train()
    return out


epochs = 1
max_iters = len(train_loader) * epochs
eval_interval = max_iters // 10  # Number of steps between evaluations
eval_iters = 20  # Number of iterations to average for evaluation

# Use AdamW optimizer
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)
# SGD
# optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
tracked_losses = list()
for epoch in tqdm(range(epochs), total=epochs, desc="Epochs", leave=False):
    step = 0
    for batch in tqdm(
        train_loader, total=len(train_loader), desc="Training Steps", leave=False
    ):
        optimizer.zero_grad()

        # print gpu consumed
        # print(f'total GPU usage:{torch.cuda.memory_allocated()/1024**3:.2f} GB')
        if step % eval_interval == 0 or step == max_iters - 1:
            losses = estimate_loss()
            tracked_losses.append(losses)
            # print(losses)
            print(
                f"Step: {step:>6} | "
                f"Training Loss: {losses['train']:.3f} | "
                f"Validation Loss: {losses['valid']:.3f} | "
                f"Training Perplexity: {math.exp(losses['train']):.3f} | "
                f"Validation Perplexity: {math.exp(losses['valid']):.3f}"
            )

        x_batch = batch["input_ids"].to(device)
        y_batch = batch["labels"].to(device)
        logits, loss = model(x_batch, y_batch)
        loss.backward()
        optimizer.step()
        step += 1


# Save the model state dictionary
os.makedirs("ckpt", exist_ok=True)
torch.save(model.state_dict(), "ckpt/model-infini-4096-64.pt")
# write tracked_losses
with open("ckpt/tracked_losses.txt", "w") as f:
    f.write(str(tracked_losses))

Epochs:   0%|          | 0/1 [00:00<?, ?it/s]

Training Steps:   0%|          | 0/4952 [00:00<?, ?it/s]

Step:      0 | Training Loss: 11.664 | Validation Loss: 11.670 | Training Perplexity: 116332.261 | Validation Perplexity: 117016.203
Step:    495 | Training Loss: 7.409 | Validation Loss: 7.434 | Training Perplexity: 1650.952 | Validation Perplexity: 1692.064
Step:    990 | Training Loss: 6.925 | Validation Loss: 6.958 | Training Perplexity: 1017.872 | Validation Perplexity: 1051.694
Step:   1485 | Training Loss: 6.372 | Validation Loss: 6.411 | Training Perplexity: 585.007 | Validation Perplexity: 608.333
Step:   1980 | Training Loss: 6.002 | Validation Loss: 6.048 | Training Perplexity: 404.057 | Validation Perplexity: 423.203
Step:   2475 | Training Loss: 5.742 | Validation Loss: 5.793 | Training Perplexity: 311.672 | Validation Perplexity: 328.153
Step:   2970 | Training Loss: 5.555 | Validation Loss: 5.611 | Training Perplexity: 258.600 | Validation Perplexity: 273.372
Step:   3465 | Training Loss: 5.419 | Validation Loss: 5.476 | Training Perplexity: 225.697 | Validation Perplexi

In [None]:
import time

time.sleep(10)

from google.colab import runtime

runtime.unassign()

In [None]:
# # Generate
# model.eval()
# # start = ''
# # start_ids = encoding.encode(start)
# idx = 4
# start_ids = test_data[idx * context_length : (idx + 1) * context_length]
# x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
# # print(x.shape)
# y = model.generate(x, max_new_tokens=int(10))
# print("------original text---------")
# print(encoding.decode(start_ids.tolist()))
# print("------only new tokens-------")
# print(encoding.decode(y[0][len(start_ids) :].tolist()))
# print("----------------------------")

In [None]:
# # prompt: Plot the train and validation loss stored in tracked_losses

# import matplotlib.pyplot as plt

# # Extract train and validation losses from tracked_losses
# train_losses = [loss["train"].cpu() for loss in tracked_losses]
# val_losses = [loss["valid"].cpu() for loss in tracked_losses]

# # Create the plot
# plt.plot(
#     list(range(0, eval_interval * len(train_losses), eval_interval)),
#     train_losses,
#     label="Train Loss",
# )
# plt.plot(
#     list(range(0, eval_interval * len(train_losses), eval_interval)),
#     val_losses,
#     label="Validation Loss",
# )

# # Add labels and title
# plt.xlabel("Step")
# plt.ylabel("Loss")
# plt.title("Train and Validation Loss over Iterations")

# # Add legend and show the plot
# plt.legend()
# plt.show()