<a href="https://colab.research.google.com/github/1pawn0/Transformers-Playground/blob/main/Notebooks/toy_transformer_model_via_pytorch_nn_module.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
from torch import nn
from torch.optim import SGD, AdamW
from transformers import AutoTokenizer
from datasets import load_dataset
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tqdm
from google.colab import userdata
# Expose the Hugging Face token stored in Colab's secret store to libraries
# that read it from the environment.
os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')

# Fix the global RNG seed so runs are repeatable.
torch.manual_seed(166320)

# Default parameter dtype used by the model definition.
dtype = torch.float32

# Prefer whatever accelerator PyTorch reports; otherwise fall back to the CPU.
if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
else:
    device = "cpu"
print(f"Using {device} device")

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")

In [None]:
# Load WikiText-2 and flatten every split into one long string each.
ds = load_dataset("Salesforce/wikitext", "wikitext-2-v1")
train_corpus, val_corpus, test_corpus = (
    ''.join(ds[split_name]['text'])
    for split_name in ('train', 'validation', 'test')
)
# The dataset object is no longer needed once the corpora are materialized.
del ds

In [None]:
# Number of tokenized windows per mini-batch; passed to the training DataLoader below.
batch_size = 16

def tokenize_large_text_corpus(text_corpus, tokenizer, max_length=512, stride=128):
    """Tokenize one large text into fixed-length windows of token ids.

    The tokenizer is called with truncation, ``"max_length"`` padding and
    ``return_overflowing_tokens=True``, so text beyond ``max_length`` is
    returned as additional windows instead of being dropped.

    Args:
        text_corpus: The text to tokenize.
        tokenizer: A callable accepting the text plus the keyword options
            below (e.g. a Hugging Face tokenizer) and returning an object
            with an ``input_ids`` attribute.
        max_length: Length each window is truncated/padded to.
        stride: Forwarded as the tokenizer's ``stride`` option (the overlap
            between consecutive windows in the Hugging Face API).

    Returns:
        The ``input_ids`` of the tokenizer output; PyTorch tensors are
        requested via ``return_tensors="pt"``.
    """
    tokenizer_options = {
        "return_tensors": "pt",
        "truncation": True,
        "padding": "max_length",
        "max_length": max_length,
        "stride": stride,
        "return_overflowing_tokens": True,
    }
    return tokenizer(text_corpus, **tokenizer_options).input_ids

# Tokenize the training corpus, wrap the windows in a dataset, and expose them
# through a shuffling loader with pinned host memory.
train_token_windows = tokenize_large_text_corpus(text_corpus=train_corpus, tokenizer=tokenizer)
train_input_ids = TensorDataset(train_token_windows)
train_input_ids_loader = DataLoader(
    dataset=train_input_ids,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)
# Drop the temporary names; the loader keeps its own reference to the dataset.
del train_token_windows
del train_input_ids


In [None]:
class ToyTransformerModel(nn.Module):
    """Minimal decoder-only Transformer language model with a single block.

    The network consists of learned token and position embeddings, one
    single-head causal self-attention layer, one GELU feed-forward layer
    (hidden width ``4 * d_model``), a LayerNorm after each residual sum,
    and a bias-free linear head that projects to vocabulary logits.

    Parameters are created with the module-level ``dtype`` and ``device``
    that are in effect when the model is instantiated.

    Args:
        d_model: Width of the embeddings and of every hidden state.
        num_embeddings: Vocabulary size (rows of the token embedding table
            and output width of ``lm_head``).
        max_sequence_len: Longest sequence the position table can encode.
        tokenizer: Tokenizer kept on the instance for convenience; it is
            not used by ``forward``.
    """

    def __init__(self, d_model=128, num_embeddings=30000, max_sequence_len=512, tokenizer: AutoTokenizer = tokenizer):
        super().__init__()
        self.d_model = d_model
        self.num_embeddings = num_embeddings
        self.max_sequence_len = max_sequence_len
        self.tokenizer = tokenizer
        # Token id 0 is treated as padding: its embedding stays zero and gets
        # no gradient. NOTE(review): assumes the tokenizer's pad id is 0 -
        # confirm if a different tokenizer is plugged in.
        self.token_embeddings = nn.Embedding(num_embeddings, d_model, padding_idx=0, dtype=dtype, device=device)
        # No padding_idx here: index 0 is the first real position, so its
        # vector must be trainable like every other position.
        self.position_embeddings = nn.Embedding(max_sequence_len, d_model, dtype=dtype, device=device)
        self.W_q = nn.Linear(d_model, d_model, dtype=dtype, device=device)
        self.W_k = nn.Linear(d_model, d_model, dtype=dtype, device=device)
        self.W_v = nn.Linear(d_model, d_model, dtype=dtype, device=device)
        self.W_out = nn.Linear(d_model, d_model, dtype=dtype, device=device)
        self.ff1 = nn.Linear(d_model, 4 * d_model, dtype=dtype, device=device)
        self.ff2 = nn.Linear(4 * d_model, d_model, dtype=dtype, device=device)
        self.ln1 = nn.LayerNorm(d_model, dtype=dtype, device=device)
        self.ln2 = nn.LayerNorm(d_model, dtype=dtype, device=device)
        self.lm_head = nn.Linear(d_model, num_embeddings, bias=False, dtype=dtype, device=device)
        self.gelu = nn.GELU()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_ids):
        """Compute next-token logits for a batch of token-id sequences.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, num_embeddings)`` with the
            unnormalized logits for every position.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_sequence_len``.
        """
        _, seq_len = input_ids.shape
        if seq_len > self.max_sequence_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_sequence_len {self.max_sequence_len}"
            )
        # Build helper tensors on the same device as the input rather than on
        # a module-level global, so the model keeps working after .to(...).
        input_device = input_ids.device

        token_emb = self.token_embeddings(input_ids)
        positions = torch.arange(seq_len, device=input_device)
        # (seq_len, d_model) broadcasts over the batch dimension.
        x = token_emb + self.position_embeddings(positions)

        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        # Scaled dot-product attention with a single head of width d_model.
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_model ** 0.5)
        # Causal mask: True strictly above the diagonal hides future positions.
        mask = torch.triu(torch.ones(seq_len, seq_len, device=input_device), diagonal=1).bool()
        attn_scores = attn_scores.masked_fill(mask, float('-inf'))
        attn_weights = self.softmax(attn_scores)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = self.W_out(attn_output)

        # Residual connection followed by LayerNorm.
        x = self.ln1(x + attn_output)

        ff_output = self.ff2(self.gelu(self.ff1(x)))

        x = self.ln2(x + ff_output)

        logits = self.lm_head(x)

        return logits

# Instantiate the model with a vocabulary sized to the tokenizer and make sure
# it lives on the selected device.
model = ToyTransformerModel(
    tokenizer=tokenizer,
    num_embeddings=tokenizer.vocab_size,
    max_sequence_len=512,
    d_model=128,
).to(device)

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

# Training cell: rebuilds the model from scratch and trains it as a causal
# language model on the windows served by train_input_ids_loader.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32

model = ToyTransformerModel(
    d_model=128,
    num_embeddings=tokenizer.vocab_size,
    max_sequence_len=512,
    tokenizer=tokenizer
).to(device)

num_epochs = 10
learning_rate = 1e-3
weight_decay = 0.01

optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Windows are padded to a fixed length by the tokenizer; padded targets carry
# no information, so they are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# The scheduler is stepped once per batch, so the cosine period is the total
# number of optimizer steps across all epochs.
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs * len(train_input_ids_loader))

model.train()
for epoch in range(num_epochs):
    total_loss = 0
    num_batches = 0

    progress_bar = tqdm(train_input_ids_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in progress_bar:
        # The loader pins host memory, so the copy can overlap with compute.
        input_ids = batch[0].to(device, non_blocking=True)

        # Next-token prediction: position t is trained to predict token t+1.
        inputs = input_ids[:, :-1]
        targets = input_ids[:, 1:]

        optimizer.zero_grad()

        logits = model(inputs)

        loss = loss_fn(
            logits.reshape(-1, model.num_embeddings),
            targets.reshape(-1)
        )

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step()

        # Read the scalar once; every .item() call forces a device sync.
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1

        progress_bar.set_postfix({
            'loss': f'{loss_value:.4f}',
            'avg_loss': f'{total_loss/num_batches:.4f}'
        })

    # Guard against an empty loader so the epoch summary cannot divide by zero.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"\nEpoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")

    if (epoch + 1) % 2 == 0:
        # Save everything needed to resume, including the LR schedule position.
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_loss,
        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pt')
        print(f"Checkpoint saved: checkpoint_epoch_{epoch+1}.pt")

torch.save(model.state_dict(), 'toy_transformer_final.pt')
print("\nTraining complete! Final model saved as 'toy_transformer_final.pt'")