<a href="https://colab.research.google.com/github/1pawn0/Transformers-Playground/blob/main/Notebooks/toy_transformer_pretrain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from transformers import AutoTokenizer
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tqdm

# Pretrained tokenizer for the "albert/albert-base-v2" checkpoint; it is handed to
# tokenize_large_text() to turn the raw corpus into token-id chunks.
tokenizer = AutoTokenizer.from_pretrained(
    "albert/albert-base-v2",
)

In [None]:
!wget https://raw.githubusercontent.com/karpathy/ng-video-lecture/refs/heads/master/input.txt

In [None]:
from torch.utils.data import TensorDataset, DataLoader


def tokenize_large_text(text, tokenizer, max_length=512, stride=256):
    """Tokenize a long text into fixed-length chunks of token ids.

    The tokenizer is called with truncation, padding to ``max_length`` and
    ``return_overflowing_tokens=True``, so text that does not fit into one
    sequence is returned as additional rows instead of being discarded.

    Args:
        text: The raw text to tokenize.
        tokenizer: Callable accepting the Hugging Face tokenizer keyword
            arguments used below and returning an object with an
            ``input_ids`` tensor attribute.
        max_length: Length every returned row is truncated/padded to.
        stride: Forwarded to the tokenizer as ``stride`` (the overlap between
            consecutive overflow chunks, per the Hugging Face docs).

    Returns:
        The ``input_ids`` tensor as ``torch.int16`` when every id fits in
        int16, otherwise as ``torch.int32``.
    """
    encodings = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
    )
    input_ids = encodings.input_ids

    # int16 halves the storage cost, but casting an id above 32767 would
    # silently wrap to a negative number and corrupt the dataset. Only use the
    # compact dtype when it is lossless.
    int16_max = torch.iinfo(torch.int16).max
    if input_ids.numel() > 0 and int(input_ids.max()) > int16_max:
        return input_ids.to(torch.int32)
    return input_ids.to(torch.int16)


# Tokenize the whole corpus in one go; the file is only needed while reading.
with open("input.txt", "r", encoding="utf-8") as f:
    all_chunks = tokenize_large_text(f.read(), tokenizer)

# Every chunk except the last one goes into the dataset
# (presumably because the final overflow chunk is mostly padding -- TODO confirm).
input_ids_ds = TensorDataset(all_chunks[:-1])

# Shuffled mini-batches of 16 chunks each.
input_ids_loader = DataLoader(dataset=input_ids_ds, batch_size=16, shuffle=True)


In [None]:
# Fix the RNG so every run starts from the same random parameters.
torch.manual_seed(0)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
dtype = torch.float32
embedding_dim = 64
vocab_size = 30000


def _trainable_randn(*shape):
    """Return a standard-normal leaf tensor of ``shape`` that tracks gradients."""
    return torch.randn(*shape, requires_grad=True, dtype=dtype, device=device)


# NOTE: the creation order below determines which random numbers each tensor
# receives from the seeded generator, so keep it stable.

# Token-id -> vector lookup table; row 0 is the padding row.
embedding = nn.Embedding(
    num_embeddings=vocab_size,
    embedding_dim=embedding_dim,
    padding_idx=0,
    dtype=dtype,
    device=device,
)

# Query / key / value projection matrices, followed by their biases.
W_q, W_k, W_v = (_trainable_randn(embedding_dim, embedding_dim) for _ in range(3))
b_q, b_k, b_v = (_trainable_randn(embedding_dim) for _ in range(3))

# Output projection from the embedding space to one logit per vocabulary entry.
W_out = _trainable_randn(embedding_dim, vocab_size)
b_out = _trainable_randn(vocab_size)


In [None]:
# One pass over the data: single-head causal self-attention followed by a
# next-token prediction loss, optimised with hand-written SGD.
learning_rate = torch.tensor(1e-2, dtype=dtype, device=device)

# Every tensor that receives a gradient and is updated by the SGD step below.
trainable_tensors = [embedding.weight, W_q, W_k, W_v, b_q, b_k, b_v, W_out, b_out]

# Wrapping the loader (not the enumerate object) lets tqdm see its length and
# show a real progress total.
for batch_idx, input_id_batch in enumerate(tqdm(input_ids_loader)):
    # The loader yields CPU tensors; move them once so that the embedding
    # lookup and the loss targets live on the same device as the parameters.
    token_ids = input_id_batch[0].to(device)

    # Token ids -> vectors; expected shape (batch, seq, embedding_dim).
    x = embedding(token_ids.to(torch.int32))

    # Single-head scaled dot-product attention.
    Q = x @ W_q + b_q
    K = x @ W_k + b_k
    V = x @ W_v + b_v
    scores = torch.matmul(Q, K.transpose(-2, -1)).div(embedding_dim**0.5)

    # True above the diagonal: position i must not attend to positions > i.
    causal_mask = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool, device=device), diagonal=1)
    scores = scores.masked_fill(causal_mask, float("-inf"))
    probabilities = torch.softmax(scores, dim=-1)
    attention_output = probabilities @ V

    # Residual connection followed by layer normalisation.
    x = x + attention_output
    x = torch.nn.functional.layer_norm(x, [embedding_dim])

    # Project the normalised residual stream to vocabulary logits. Using
    # `attention_output` here would throw away the residual + layer norm above.
    logits = x @ W_out + b_out

    # Next-token objective: the logits at position t are scored against the
    # token at position t + 1, so drop the last logit and the first target.
    targets = token_ids[:, 1:]
    logits = logits[:, :-1, :]

    # Greedy predictions; not used by the loss, kept for inspection.
    predicted_tokens = torch.argmax(logits, dim=-1)

    # Targets equal to 0 are skipped (0 is also the embedding's padding_idx).
    loss = torch.nn.functional.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1).to(torch.long), ignore_index=0)
    print(f"{batch_idx}. Cross-entropy Loss: {loss.item():.2f}")
    loss.backward()

    # Plain SGD step, then reset the gradients so they do not accumulate into
    # the next batch.
    with torch.no_grad():
        for tensor in trainable_tensors:
            tensor.sub_(learning_rate * tensor.grad)
            tensor.grad.zero_()
