[![Github](https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social)](https://github.com/labmlai/annotated_deep_learning_paper_implementations)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/basic/autoregressive_experiment.ipynb)

## Transformer Experiment

This trains a simple transformer with
[multi-headed attention](https://nn.labml.ai/transformers/mha.html),
introduced in [Attention Is All You Need](https://arxiv.org/abs/1706.03762),
on an NLP auto-regression task (with the Tiny Shakespeare dataset).

### Install the packages

In [1]:
!pip install labml-nn --quiet

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/266.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.3/266.3 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m461.9/461.9 kB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m94.6/94.6 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m50.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for fairscale (pyproject.toml) ... [?25l[?25hdone


In [2]:
!pip install tiktoken



### Imports

In [2]:
from labml import experiment
from labml_nn.transformers.basic.autoregressive_experiment import Configs

### Create an experiment

In [3]:
experiment.create(name="transformer", writers={'screen'})

### Configurations

In [4]:
conf = Configs()

Set the experiment configurations, passing a dictionary of values that override the defaults

In [5]:
# Register `conf` with the experiment; the dictionary lists the options to override.
experiment.configs(conf, {
    # Use character level tokenizer
    'tokenizer': 'character',
    # Prompt separator is blank
    'prompt_separator': '',
    # Starting prompt for sampling
    'prompt': 'It is ',
    # Use Tiny Shakespeare dataset
    'text': 'tiny_shakespeare',

    # Use a context size of $512$
    'seq_len': 512,
    # Train for 32 epochs
    'epochs': 32,
    # Batch size $16$
    'batch_size': 16,
    # Switch between training and validation for $10$ times
    # per epoch
    'inner_iterations': 10,

    # Model size
    'd_model': 256,
    'transformer.n_heads': 16,
    'transformer.ffn.d_ff': 1024,

    # Use [Noam optimizer](../../optimizers/noam.html)
    'optimizer.optimizer': 'Noam',
    'optimizer.learning_rate': 1.,
})

Set PyTorch models for loading and saving

In [None]:
experiment.add_pytorch_models({'model': conf.model})

### Start the experiment and run the training loop.

In [6]:
# Start the experiment
with experiment.start():
    conf.run()

AttributeError: module 'labml.tracker' has no attribute 'set_text'

In [7]:
!pip install tiktoken



In [8]:
from tiktoken._educational import *

# Train a BPE tokeniser on a small amount of text
# NOTE(review): `enc` is reassigned two statements below, so the encoding
# trained here is discarded once the GPT-4 encoder is loaded.
enc = train_simple_encoding()

# Visualise how the GPT-4 encoder encodes text
enc = SimpleBytePairEncoding.from_tiktoken("cl100k_base")
enc.encode("hello world aaaaaaaaaaaa")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[48;5;134mi[48;5;167mm[48;5;179mp[48;5;185mor[48;5;80mt[48;5;68m [48;5;134mcol[48;5;185mle[48;5;80mc[48;5;68mt[48;5;134mi[48;5;167mon[48;5;185ms[48;5;77m
[48;5;80m
[48;5;68mi[48;5;134mm[48;5;167mp[48;5;179mor[48;5;77mt[48;5;80m re[48;5;167mge[48;5;185mx[48;5;77m
[48;5;80m
[48;5;68mi[48;5;134mm[48;5;167mp[48;5;179mor[48;5;77mt[48;5;80m t[48;5;134mi[48;5;167mk[48;5;179mtoken[48;5;134m
[48;5;167m
[48;5;179m
[48;5;185mc[48;5;77ml[48;5;80ma[48;5;68ms[48;5;134ms[48;5;167m [48;5;179mS[48;5;185mi[48;5;77mm[48;5;80mp[48;5;68mle[48;5;167mB[48;5;179myte[48;5;80mP[48;5;68ma[48;5;134mir[48;5;179mE[48;5;185mn[48;5;77mco[48;5;68md[48;5;134ming[48;5;185m:[48;5;77m
   [48;5;167m [48;5;179mde[48;5;77mf[48;5;80m [48;5;68m_[48;5;134m_[48;5;167min[48;5;185mi[48;5;77mt[48;5;80m_[48;5;68m_[48;5;134m([48;5;167mse[48;5;185ml[48;5;77mf[48;5;80m,[48;5;68m [48;5;134m*[48;

[15339, 1917, 264, 70540, 33746]

In [9]:
# `import importlib` alone does not guarantee that the `importlib.metadata`
# submodule is loaded (it only works if something else imported it first),
# so import the submodule explicitly before using it.
import importlib
import importlib.metadata
import tiktoken

print("tiktoken version:", importlib.metadata.version("tiktoken"))

tiktoken version: 0.12.0


In [10]:
tokenizer = tiktoken.get_encoding("gpt2")

In [11]:
text = (
    "Hello, do you like tea? <|endoftext|> In the sunlit terraces"
     "of someunknownPlace."
)

integers = tokenizer.encode(text, allowed_special={"<|endoftext|>"})

print(integers)

[15496, 11, 466, 345, 588, 8887, 30, 220, 50256, 554, 262, 4252, 18250, 8812, 2114, 1659, 617, 34680, 27271, 13]


In [12]:
strings = tokenizer.decode(integers)

print(strings)

Hello, do you like tea? <|endoftext|> In the sunlit terracesof someunknownPlace.


### Data Loader

In [3]:
from torch.utils.data import Dataset, DataLoader


class GPTDatasetV1(Dataset):
    """Sliding-window dataset of (input, target) token-ID tensor pairs.

    The text is tokenized once. Every window of ``max_length`` token IDs is an
    input, and the same window shifted right by one token is its target.
    Window start offsets advance by ``stride``.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        # Encode the whole text in one go, passing "<|endoftext|>" as an allowed special token.
        ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
        assert len(ids) > max_length, "Number of tokenized inputs must at least be equal to max_length+1"

        # Start offsets of every window that still leaves room for its shifted target.
        starts = range(0, len(ids) - max_length, stride)
        self.input_ids = [torch.tensor(ids[s:s + max_length]) for s in starts]
        self.target_ids = [torch.tensor(ids[s + 1:s + max_length + 1]) for s in starts]

    def __len__(self):
        """Return the number of windows."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensors of window ``idx``."""
        pair = (self.input_ids[idx], self.target_ids[idx])
        return pair

In [4]:
def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True,
                         num_workers=0):
    """Build a ``DataLoader`` over sliding windows of ``txt``.

    The text is tokenized with tiktoken's "gpt2" encoding and wrapped in a
    ``GPTDatasetV1``; ``max_length`` and ``stride`` are forwarded to that
    dataset and the remaining arguments to ``DataLoader``. The encoding's
    vocabulary size is printed as a side effect.
    """
    gpt2_encoding = tiktoken.get_encoding("gpt2")
    print(gpt2_encoding.n_vocab)

    windows = GPTDatasetV1(txt, gpt2_encoding, max_length, stride)
    return DataLoader(
        windows,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )

In [5]:
# Notebook shell escape: install the packages used by this cell.
!pip install -q datasets tiktoken torch

from datasets import load_dataset

# use the community mirror on Hugging Face Hub
# Only the `train[:1%]` slice is requested (presumably to keep the corpus small).
ds = load_dataset("roneneldan/tinystories", split="train[:1%]")

# Join every example's "text" field into one string, separated by the
# "<|endoftext|>" marker.
txt = "<|endoftext|>".join(ex["text"] for ex in ds)

from torch.utils.data import DataLoader
import tiktoken, torch

# Build the dataloader with the `create_dataloader_v1` helper defined in an earlier cell.
dataloader = create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128)

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00004-2d5a1467fff108(…):   0%|          | 0.00/249M [00:00<?, ?B/s]

data/train-00001-of-00004-5852b56a2bd28f(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/train-00002-of-00004-a26307300439e9(…):   0%|          | 0.00/246M [00:00<?, ?B/s]

data/train-00003-of-00004-d243063613e5a0(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/validation-00000-of-00001-869c898b5(…):   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

50257


In [9]:
dataloader = create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128)

# Pull one batch to check its shape. The names are `inputs`/`targets` rather
# than `input`/`target` so the builtin `input()` is not shadowed for the rest
# of the notebook session.
data_iter = iter(dataloader)
inputs, targets = next(data_iter)
print(inputs.shape, targets.shape)

50257
torch.Size([4, 256]) torch.Size([4, 256])


### Token Embedding

In [10]:
# input_ids = torch.tensor([2, 3, 5, 1])

In [11]:
# vocab_size = 6
# output_dim = 3

# torch.manual_seed(123)
# embedding_layer = torch.nn.Embedding(vocab_size, output_dim)

In [12]:
# print(embedding_layer.weight)

# Positional Embedding

In [13]:
vocab_size = 50257
output_dim = 768

token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)

In [14]:
dataloader = create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128)
data_iter = iter(dataloader)
inputs, targets = next(data_iter)

50257


In [15]:
print("Token IDs:\n", inputs)
print("\nInputs shape:\n", inputs.shape)

Token IDs:
 tensor([[  719,  2952,   290,  ...,  2651,   284,   511],
        [41130,   607,   290,  ...,   290, 10620,    13],
        [  526,   383,  3234,  ...,  8187,   286,   465],
        [  673,  4030,   477,  ...,   284,  4859,   284]])

Inputs shape:
 torch.Size([4, 256])


In [16]:
token_embeddings = token_embedding_layer(inputs)


In [17]:
print(token_embeddings.shape)

torch.Size([4, 256, 768])


In [18]:
context_length = 256
output_dim = 768
pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)

In [19]:
pos_embeddings = pos_embedding_layer(torch.arange(context_length))
print(pos_embeddings.shape)

# uncomment & execute the following line to see how the embeddings look like
# print(pos_embeddings)

torch.Size([256, 768])


In [20]:
input_embeddings = token_embeddings + pos_embeddings
print(input_embeddings.shape)

# uncomment & execute the following line to see how the embeddings look like
# print(input_embeddings)

torch.Size([4, 256, 768])


# Attention Layer

In [40]:
attn_mask = torch.full((1024, 1024), float('-inf'))

In [41]:
attn_mask = torch.triu(attn_mask, 1)  # upper triangle = -inf


In [42]:
attn_mask.shape

torch.Size([1024, 1024])

In [45]:
attn_mask[2]

tensor([0., -inf, -inf,  ..., -inf, -inf, -inf])

In [46]:
# NOTE(review): `diagonal=1` keeps the first super-diagonal as well, so row 0
# is True for columns 0 and 1 (see the printed row below). A strictly causal
# "keep" mask would use `diagonal=0` — confirm which was intended.
attn_mask2 = torch.tril(torch.ones(1024, 1024, dtype=torch.bool), diagonal=1)

In [47]:
attn_mask2[0]

tensor([ True,  True, False,  ..., False, False, False])

In [48]:
attn_mask = torch.triu(torch.full((context_length, context_length), -float("inf")), diagonal=1)

In [51]:
# NOTE(review): chained `[:10][:10]` slices the rows twice, so the result still
# has every column; `attn_mask[:10, :10]` would select the top-left 10x10 corner.
attn_mask[:10][:10]

tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf]])

In [52]:
# `.bool()` maps the -inf entries to True and the zeros to False.
# NOTE(review): chained `[:10][:10]` slices the rows twice, so the result still
# has every column; `[:10, :10]` would select the top-left 10x10 corner.
attn_mask.bool()[:10][:10]

tensor([[False,  True,  True,  ...,  True,  True,  True],
        [False, False,  True,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True]])

In [82]:
import torch.nn as nn
import torch.nn.functional as F


class AttentionBlock(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output width; it is split evenly across ``num_heads``.
        num_heads: Number of attention heads; must divide ``d_out``.
        dropout: Dropout probability applied to the attention weights.
        is_causal: If true, position ``i`` only attends to positions ``<= i``.
        max_len: Side length of the prebuilt causal mask, i.e. the longest
            sequence a causal block accepts.
    """

    def __init__(self, d_in=768, d_out=768, num_heads=4, dropout=0.1, is_causal=True, max_len=1024):
        super().__init__()
        assert d_out % num_heads == 0, "d_out is indivisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = d_out // self.num_heads
        self.is_causal = is_causal
        self.max_len = max_len

        # One linear layer produces queries, keys and values in a single matmul.
        self.w_qkv = nn.Linear(d_in, 3 * d_out)

        self.dropout = nn.Dropout(p=dropout)

        self.out_project = nn.Linear(d_out, d_out)

        # Always register "mask" (None when not causal) so the attribute exists
        # in both modes. Entries equal to 1 mark the positions to hide.
        mask = torch.triu(torch.ones(max_len, max_len), diagonal=1) if is_causal else None
        self.register_buffer("mask", mask)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[b, seq_len, d_in]``.

        Returns:
            Tensor of shape ``[b, seq_len, d_out]``.

        Raises:
            ValueError: If the block is causal and ``seq_len`` exceeds ``max_len``.
        """
        b, seq_len, _ = x.shape
        if self.is_causal and seq_len > self.max_len:
            raise ValueError(f"seq_len={seq_len} exceeds max_len={self.max_len} of the prebuilt mask")

        # [b, seq_len, d_in] -> [b, seq_len, 3, num_heads, head_dim]
        qkv = self.w_qkv(x).view(b, seq_len, 3, self.num_heads, self.head_dim)

        # -> [3, b, num_heads, seq_len, head_dim], then split into q, k, v,
        # each of shape [b, num_heads, seq_len, head_dim]
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # [b, num_heads, seq_len, seq_len]
        attn_scores = q @ k.transpose(2, 3)

        if self.is_causal:
            hidden = self.mask.bool()[:seq_len, :seq_len]
            attn_scores = attn_scores.masked_fill(hidden, -torch.inf)

        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # [b, num_heads, seq_len, head_dim]
        gathered_output = attn_weights @ v

        # Move the head axis next to head_dim BEFORE flattening; viewing
        # [b, num_heads, seq_len, head_dim] directly as [b, seq_len, d_out]
        # would mix values from different positions into one output row.
        gathered_output = gathered_output.transpose(1, 2).contiguous()

        return self.out_project(gathered_output.view(b, seq_len, self.num_heads * self.head_dim))


In [83]:
# Pick the compute device: Apple-Silicon Metal first, then NVIDIA CUDA, else the CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

# Random stand-in embeddings of shape [batch_size, context_len, embed_dim].
batch_size, context_len, embed_dim = 8, 1024, 768
embeddings = torch.randn(batch_size, context_len, embed_dim, device=device)

Using device: cuda
PyTorch version: 2.8.0+cu126


In [84]:
mha_combined_qkv = AttentionBlock(d_in=embed_dim, d_out=embed_dim, num_heads=12, dropout=0.0).to(device)

out = mha_combined_qkv(embeddings)

In [85]:
out.shape

torch.Size([8, 1024, 768])

In [86]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionBlock(nn.Module):
    """Multi-head self-attention built on one fused QKV projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output width, split evenly across ``num_heads``.
        num_heads: Number of attention heads; must divide ``d_out``.
        dropout: Dropout probability applied to the attention weights.
        is_causal: If true, each position is blocked from attending to later ones.
        max_len: Side length of the prebuilt causal mask.
    """

    def __init__(self, d_in=768, d_out=768, num_heads=4, dropout=0.1, is_causal=True, max_len=1024):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.num_heads, self.head_dim = num_heads, d_out // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        self.is_causal, self.max_len = is_causal, max_len

        # Layers are created in the order QKV, output, dropout.
        self.w_qkv = nn.Linear(d_in, 3 * d_out, bias=True)
        self.out_project = nn.Linear(d_out, d_out, bias=True)
        self.dropout = nn.Dropout(p=dropout)

        # Non-persistent buffer: a [1, 1, max_len, max_len] bool tensor whose True
        # entries (strictly above the diagonal) get masked, or None when not causal.
        future = None
        if is_causal:
            future = torch.ones(max_len, max_len, dtype=torch.bool).triu(diagonal=1)[None, None]
        self.register_buffer("causal_mask", future, persistent=False)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[B, L, d_in]``; returns ``[B, L, d_out]``.

        Raises:
            ValueError: If the block is causal and ``L`` exceeds ``max_len``.
        """
        batch, n_tokens, _ = x.shape
        if self.is_causal and n_tokens > self.max_len:
            raise ValueError(f"seq_len={n_tokens} exceeds max_len={self.max_len} of the prebuilt mask")

        # [B, L, 3*d_out] -> [B, L, 3, H, D]; split on the "3" axis and move heads forward.
        fused = self.w_qkv(x).view(batch, n_tokens, 3, self.num_heads, self.head_dim)
        queries, keys, values = (part.transpose(1, 2) for part in fused.unbind(dim=2))  # [B, H, L, D]

        # Scaled dot-product scores, [B, H, L, L]
        logits = (queries @ keys.transpose(-2, -1)) * self.scale

        if self.is_causal:
            # [1, 1, L, L] slice broadcasts over batch and heads.
            window = self.causal_mask[..., :n_tokens, :n_tokens]
            logits = logits.masked_fill(window, torch.finfo(logits.dtype).min)

        weights = self.dropout(torch.softmax(logits, dim=-1))

        # [B, H, L, D] -> [B, L, H, D] -> [B, L, H*D]
        context = (weights @ values).transpose(1, 2).contiguous()
        return self.out_project(context.view(batch, n_tokens, self.num_heads * self.head_dim))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MHABlock(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output width, split evenly across ``n_heads``.
        n_heads: Number of attention heads; must divide ``d_out``.
        context_len: Side length of the prebuilt causal mask, i.e. the longest
            sequence a causal block accepts.
        drop_out: Dropout probability applied to the attention weights.
        qkv_bias: If true, learnable biases ``bias_q``/``bias_k``/``bias_v``
            are added to the projected queries, keys and values.
        is_causal: If true, each position is blocked from attending to later ones.
    """

    def __init__(self, d_in=768, d_out=768, n_heads=12, context_len=1024, drop_out=0.0, qkv_bias=False, is_causal=True):
        # nn.Module must be initialised before any sub-module or parameter is assigned.
        super().__init__()
        assert d_out % n_heads == 0, f"{d_out} and {n_heads} are not divisable"

        self.d_out = d_out
        self.n_heads = n_heads
        self.head_dim = self.d_out // self.n_heads
        self.context_len = context_len
        self.is_causal = is_causal

        self.scale = 1.0 / math.sqrt(self.head_dim)

        # The projection biases live in bias_q/bias_k/bias_v (controlled by
        # `qkv_bias`), so the fused linear layer itself carries no bias.
        self.w_qkv = nn.Linear(d_in, 3 * d_out, bias=False)

        self.out_proj = nn.Linear(d_out, d_out)

        self.dropout = nn.Dropout(p=drop_out)

        if qkv_bias:
            self.bias_q = nn.Parameter(torch.zeros(d_out))
            self.bias_k = nn.Parameter(torch.zeros(d_out))
            self.bias_v = nn.Parameter(torch.zeros(d_out))
        else:
            self.register_parameter("bias_q", None)
            self.register_parameter("bias_k", None)
            self.register_parameter("bias_v", None)

        if is_causal:
            # Bool mask, True strictly above the diagonal (= positions to hide),
            # stored as [1, 1, context_len, context_len] so it broadcasts over
            # batch and heads. masked_fill requires a bool mask.
            mask = torch.triu(torch.ones(context_len, context_len, dtype=torch.bool), diagonal=1)
            self.register_buffer("causal_mask", mask.unsqueeze(0).unsqueeze(0))
        else:
            self.register_buffer("causal_mask", None)

        nn.init.xavier_uniform_(self.w_qkv.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)

        nn.init.zeros_(self.out_proj.bias)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[b, L, d_in]``; returns ``[b, L, d_out]``.

        Raises:
            ValueError: If the block is causal and ``L`` exceeds ``context_len``.
        """
        b, L, _ = x.shape
        if self.is_causal and L > self.context_len:
            raise ValueError(f"seq_len={L} exceeds context_len={self.context_len} of the prebuilt mask")

        # [b, L, 3*d_out] -> [b, L, 3, n_heads, head_dim] -> [3, b, n_heads, L, head_dim]
        qkv = self.w_qkv(x).view(b, L, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.bias_q is not None:
            # Each [d_out] bias is reshaped to [1, n_heads, 1, head_dim] to match the head split.
            bias_shape = (1, self.n_heads, 1, self.head_dim)
            q = q + self.bias_q.view(bias_shape)
            k = k + self.bias_k.view(bias_shape)
            v = v + self.bias_v.view(bias_shape)

        # self.scale is already 1/sqrt(head_dim), so multiply (dividing would scale up).
        q = q * self.scale

        # Q @ K^T -> [b, n_heads, L, L]
        scores = torch.matmul(q, k.transpose(-2, -1))

        if self.is_causal:
            m = self.causal_mask[..., :L, :L]
            scores = scores.masked_fill(m, -torch.inf)

        attn_weights = self.dropout(torch.softmax(scores, dim=-1))

        # [b, n_heads, L, head_dim]
        out = torch.matmul(attn_weights, v)

        # [b, n_heads, L, head_dim] -> [b, L, n_heads, head_dim] -> [b, L, d_out]
        out = out.transpose(1, 2).contiguous().view(b, L, -1)

        return self.out_proj(out)





