<a href="https://colab.research.google.com/github/QasimWani/simple-transformer/blob/main/transformers/gpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Barebones implementation of GPT

import torch
import torch.nn as nn
import numpy as np
from einops import rearrange

In [2]:
class SingleHeadedAttention(nn.Module):
  ''' Applies single-headed self-attention (SHA) with a causal mask.

  Args:
    max_seq_len: longest sequence the precomputed causal mask supports.
    d_embed: embedding width of both the input and the output.
  '''

  def __init__(self, max_seq_len, d_embed):
    super().__init__()

    self.q_proj = nn.Linear(d_embed, d_embed)
    self.k_proj = nn.Linear(d_embed, d_embed)
    self.v_proj = nn.Linear(d_embed, d_embed)

    self.w_out = nn.Sequential(
        nn.Linear(d_embed, d_embed),
        nn.Dropout(0.1)
    )

    self.attn_dropout = nn.Dropout(0.1)

    # One-time causal mask of shape (1, max_seq_len, max_seq_len); True marks a blocked position.
    # Token at index i may only attend to tokens 0..i, so the model never learns from future positions.
    mask = torch.triu(torch.ones(1, max_seq_len, max_seq_len), diagonal=1).bool()
    self.register_buffer('causal_mask', mask)

  def forward(self, x: torch.Tensor, pad_mask: torch.Tensor = None) -> torch.Tensor:
    '''
    Args:
      x: (batch_size, seq_len, d_embed), with seq_len <= max_seq_len.
      pad_mask: optional (batch_size, seq_len); 1 marks a real token, 0 marks <pad>.

    Returns:
      (batch_size, seq_len, d_embed)
    '''
    batch_size, seq_len, d_embed = x.shape

    # Step 1. Project QKV, each batch_size, seq_len, d_embed
    Q = self.q_proj(x)
    K = self.k_proj(x)
    V = self.v_proj(x)

    # Step 2. Scaled dot-product scores: batch_size, seq_len, seq_len
    # (how strongly each query position attends to each key position).
    scores = (Q @ K.transpose(-2, -1)) / (d_embed ** 0.5)
    mask = self.causal_mask[:, :seq_len, :seq_len] # 1, seq_len, seq_len

    if pad_mask is not None:
      # Key padding: no query may attend to a <pad> key. Broadcasts to batch_size, seq_len, seq_len.
      mask = mask | ~pad_mask[:, None, :].bool()

    # The mask buffer moves with the module, so it is already on x's device.
    scores = scores.masked_fill(mask, float('-inf'))
    attention_weights = torch.softmax(scores, dim=-1) # batch_size, seq_len, seq_len

    if pad_mask is not None:
      # A query whose keys are ALL masked (e.g. the leading <pad> rows of a left-padded
      # sequence) has an all -inf row, and softmax turns it into NaN. Zero those rows:
      # downstream, a zero attention weight times a NaN value is still NaN, so the NaN
      # would otherwise spread to real tokens. The causal mask alone never blocks the
      # diagonal, so this is only needed when a pad mask is applied.
      fully_masked = mask.all(dim=-1, keepdim=True)
      attention_weights = attention_weights.masked_fill(fully_masked, 0.0)

    attention_weights = self.attn_dropout(attention_weights)

    attention = attention_weights @ V # batch_size, seq_len, d_embed
    return self.w_out(attention) # batch_size, seq_len, d_embed

In [3]:
class MultiHeadedAttention(nn.Module):
  ''' Applies multi-headed self-attention (MHA) with a causal mask.

  Args:
    max_seq_len: longest sequence the precomputed causal mask supports.
    d_embed: embedding width; must be divisible by num_heads.
    num_heads: number of attention heads, each of width d_embed // num_heads.
  '''
  def __init__(self, max_seq_len, d_embed, num_heads):
    super().__init__()

    assert d_embed % num_heads == 0, f"d_embed ({d_embed}) needs to be divisible by num_heads ({num_heads}). Recommend parameters: d_embed = 768, num_heads = 12."

    self.qkv = nn.Linear(d_embed, d_embed * 3)
    self.w_out = nn.Sequential(
        nn.Linear(d_embed, d_embed),
        nn.Dropout(0.1)
    )

    self.attn_dropout = nn.Dropout(0.1)

    self.num_heads = num_heads
    self.head_dim = d_embed // num_heads

    # Causal mask of shape (1, 1, max_seq_len, max_seq_len), broadcast over
    # (batch_size, num_heads, seq_len, seq_len). True marks a blocked (future) position.
    mask = torch.triu(torch.ones(1, 1, max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
    self.register_buffer('causal_mask', mask)

  def forward(self, x: torch.Tensor, pad_mask: torch.Tensor = None) -> torch.Tensor:
    '''
    Args:
      x: (batch_size, seq_len, d_embed), with seq_len <= max_seq_len.
      pad_mask: optional (batch_size, seq_len); 1 marks a real token, 0 marks <pad>.

    Returns:
      (batch_size, seq_len, d_embed)
    '''
    batch_size, seq_len, d_embed = x.shape

    # Step 1. Project to QKV and split into heads. The last dim of qkv is laid out as
    # (three, num_heads, head_dim), so view it that way and unbind the "three" axis.
    # Q, K, V are each batch_size, seq_len, num_heads, head_dim.
    qkv = self.qkv(x) # batch_size, seq_len, 3 * d_embed
    Q, K, V = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).unbind(dim=2)

    # Step 2. Scaled dot-product scores per head: batch_size, num_heads, seq_len, seq_len
    scores = torch.einsum('b q h d, b k h d -> b h q k', Q, K) / (self.head_dim ** 0.5)
    mask = self.causal_mask[..., :seq_len, :seq_len] # 1, 1, seq_len, seq_len

    if pad_mask is not None:
      # Key padding: mask out <pad> KEYS so no query attends to them (broadcast over heads
      # and query positions). Query rows at <pad> positions still produce outputs; the
      # training loop in this notebook sets their targets to -100 so the loss ignores them.
      # Example, right-padded input [cat sat on <pad> <pad>] (1 = blocked):
      #   causal only          causal | key-pad
      #   [0 1 1 1 1] cat      [0 1 1 1 1]
      #   [0 0 1 1 1] sat      [0 0 1 1 1]
      #   [0 0 0 1 1] on       [0 0 0 1 1]
      #   [0 0 0 0 1] <pad>    [0 0 0 1 1]
      #   [0 0 0 0 0] <pad>    [0 0 0 1 1]
      # With right padding only the <pad> query rows change. With left padding the leading
      # <pad> query rows become fully masked; those are handled after the softmax below.
      mask = mask | ~pad_mask[:, None, None, :].bool() # batch_size, 1, seq_len, seq_len

    scores = scores.masked_fill(mask, float('-inf')) # broadcasts to batch_size, num_heads, seq_len, seq_len
    attention_weights = torch.softmax(scores, dim=-1) # batch_size, num_heads, seq_len, seq_len

    if pad_mask is not None:
      # A fully-masked query row is all -inf and softmaxes to NaN. Zero it: downstream, a zero
      # attention weight times a NaN value is still NaN, so the NaN would otherwise spread to
      # real tokens. The causal mask never blocks the diagonal, so without a pad mask no row
      # is fully masked.
      fully_masked = mask.all(dim=-1, keepdim=True)
      attention_weights = attention_weights.masked_fill(fully_masked, 0.0)

    attention_weights = self.attn_dropout(attention_weights)

    attention = torch.einsum('b h q k, b k h d -> b q h d', attention_weights, V) # batch_size, seq_len, num_heads, head_dim
    attention = attention.reshape(batch_size, seq_len, d_embed) # concatenate heads: (h d) -> d_embed

    return self.w_out(attention) # batch_size, seq_len, d_embed


In [4]:
class FFN(nn.Module):
  ''' Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear -> Dropout.

  Args:
    in_channels: input feature width.
    hidden_size: width of the inner (expanded) layer.
    out_channels: output feature width.
  '''

  def __init__(self, in_channels, hidden_size, out_channels):
    super().__init__()
    expand = nn.Linear(in_channels, hidden_size)
    project = nn.Linear(hidden_size, out_channels)
    # TODO: GPT-style models typically use GELU instead of ReLU here.
    layers = [expand, nn.ReLU(), nn.Dropout(0.1), project, nn.Dropout(0.1)]
    self.network = nn.Sequential(*layers)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    out = self.network(x)
    return out

In [5]:
class LayerNorm(nn.Module):
  ''' Layer normalization over the last dimension, with a learnable scale (gamma) and shift (beta).

  Args:
    num_features: size of the last dimension of the input (d_embed).
    eps: added to the variance for numerical stability.
  '''

  def __init__(self, num_features, eps=1e-5):
    super().__init__()

    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(num_features)) # num_features
    self.beta = nn.Parameter(torch.zeros(num_features)) # num_features

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    ''' Normalize x of shape (..., num_features); any number of leading dims is accepted. '''
    mu = x.mean(dim=-1, keepdim=True) # ..., 1
    # unbiased=False divides by N (population variance) rather than N - 1, as nn.LayerNorm does.
    var = x.var(dim=-1, unbiased=False, keepdim=True) # ..., 1

    x_hat = (x - mu) / torch.sqrt(var + self.eps) # ..., num_features
    return x_hat * self.gamma + self.beta # gamma/beta broadcast over the leading dims


In [6]:
class TransformerBlock(nn.Module):
  ''' Pre-norm decoder block: residual causal MHA followed by a residual feed-forward network.

  Args:
    max_seq_len: longest sequence the attention's causal mask supports.
    d_embed: embedding width.
    num_heads: number of attention heads.
  '''
  def __init__(self, max_seq_len, d_embed, num_heads):
    super().__init__()
    hidden = 4 * d_embed # conventional 4x expansion inside the FFN
    self.mha = MultiHeadedAttention(max_seq_len, d_embed, num_heads)
    self.ffn = FFN(d_embed, hidden, d_embed)

    self.ln1, self.ln2 = LayerNorm(d_embed), LayerNorm(d_embed)

  def forward(self, x: torch.Tensor, pad_mask: torch.Tensor = None) -> torch.Tensor:
    # Pre-norm: normalize the input of each sub-layer, leave the residual path untouched.
    attn_out = self.mha(self.ln1(x), pad_mask)
    h = x + attn_out
    return h + self.ffn(self.ln2(h))


In [7]:
# Using Byte Tokenizer for simplicity, but I refer the reader to my in-depth implementation
# of several tokenizers here: https://github.com/QasimWani/simple-transformer/blob/main/transformers/tokenization.ipynb

class ByteTokenizer:
  '''
  Tokenizer that maps text to its UTF-8 bytes, one token per byte.

  The advantage of a byte tokenizer is that vocab_size is bounded to 256:
  every string is representable as a list of bytes, each in range(0, 256),
  so no input is ever out-of-vocabulary. The major disadvantage is long
  sequences: a character takes one to four bytes in UTF-8, so there are at
  least as many tokens as characters (characters-per-token is at most 1.0).

  A better alternative to this is using BPE.
  '''

  def encode(self, text: str):
    ''' Return the UTF-8 bytes of `text` as a list of ints in range(0, 256). '''
    return list(text.encode('utf-8'))

  def decode(self, tokens, errors: str = 'strict'):
    '''
    Convert byte tokens back to text.

    Args:
      tokens: iterable of ints in range(0, 256), a single int, or an object
        with `.tolist()` (e.g. a torch tensor or numpy array of ids).
      errors: error handler passed to `bytes.decode`. The default 'strict'
        raises UnicodeDecodeError on invalid UTF-8, which sampled tokens can
        easily produce; use 'replace' to substitute U+FFFD instead.
    '''
    if hasattr(tokens, 'tolist'):
      # bytes() on a tensor/array is unsafe: a single-element integer tensor converts via
      # __index__ and is treated as a LENGTH (bytes(65) is 65 zero bytes), and a numpy
      # array is copied from its raw memory buffer. Go through a plain Python list instead.
      tokens = tokens.tolist()
    if isinstance(tokens, int):
      tokens = [tokens] # likewise avoid bytes(int) -> zero-filled buffer
    return bytes(tokens).decode('utf-8', errors=errors)


In [8]:
class FixedPositionalEncodings(nn.Module):
  ''' Adds fixed sinusoidal positional encodings ("Attention Is All You Need", sec. 3.5).

  PE[pos, 2i]   = sin(pos / 10000^(2i / d_embed))
  PE[pos, 2i+1] = cos(pos / 10000^(2i / d_embed))

  Args:
    max_seq_len: number of positions to precompute.
    d_embed: embedding width; must be even so sin/cos columns interleave.
  '''

  def __init__(self, max_seq_len, d_embed):
    super().__init__()
    assert d_embed % 2 == 0, f"d_embed ({d_embed}) must be even to interleave sin/cos columns"
    positional_encodings = torch.zeros(max_seq_len, d_embed) # max_seq_len, d_embed
    pos = torch.arange(0, max_seq_len).unsqueeze(1) # max_seq_len, 1
    # div_term[i] = 1 / 10_000 ** (2i / d_embed) == exp(2i * (-math.log(1e4) / d_embed))
    div_term = 1 / 10_000 ** (torch.arange(0, d_embed, 2) / d_embed) # d_embed // 2

    positional_encodings[:, 0::2] = torch.sin(pos * div_term) # even columns
    positional_encodings[:, 1::2] = torch.cos(pos * div_term) # odd columns

    # Buffer, not a parameter: follows .to(device) and is saved in state_dict, but is never trained.
    self.register_buffer('positions', positional_encodings)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    ''' Add encodings to x of shape (..., seq_len, d_embed), with seq_len <= max_seq_len. '''
    seq_len = x.size(-2)
    return x + self.positions[:seq_len] # seq_len, d_embed broadcasts over leading dims

In [16]:
class GPT(nn.Module):
  ''' Decoder-only transformer language model.

  token ids -> token embeddings + fixed sinusoidal positions -> N pre-norm
  transformer blocks -> final LayerNorm -> lm_head (weights tied to the token
  embedding table) -> logits of shape (batch_size, seq_len, vocab_size).
  '''

  def __init__(self, max_seq_len=1024, d_embed=64, vocab_size=256, num_heads=4, num_transformer_blocks=8):
    super().__init__()

    self.config = {'max_seq_len': max_seq_len, 'd_embed': d_embed, 'vocab_size': vocab_size, 'num_heads': num_heads, 'num_transformer_blocks': num_transformer_blocks}

    self.transformer = nn.Sequential(*[TransformerBlock(max_seq_len, d_embed, num_heads) for _ in range(num_transformer_blocks)])
    self.add_positional_encodings = FixedPositionalEncodings(max_seq_len, d_embed)

    # Learned lookup table: row t is the d_embed-dim representation of token id t.
    # Mapping a large vocabulary (~50k tokens for GPT-2) into a small d_embed (768 for
    # GPT-2 small) is hard, but that work is left to the transformer blocks, whose main
    # job is to learn it. This layer stays a dead-simple lookup.
    self.token_embeddings = nn.Embedding(vocab_size, d_embed)

    self.ln_final = LayerNorm(d_embed)
    self.lm_head = nn.Linear(d_embed, vocab_size, bias=False)
    # Weight tying: the output projection reuses the embedding matrix. No transpose is
    # needed because Linear computes x @ W.T and W is (vocab_size, d_embed) in both layers.
    self.lm_head.weight = self.token_embeddings.weight

  def forward(self, token_ids: torch.Tensor, pad_mask: torch.Tensor = None):
    '''
    Args:
      token_ids: (batch_size, seq_len) integer token ids, with seq_len <= max_seq_len.
      pad_mask: optional (batch_size, seq_len); 1 marks a real token, 0 marks <pad>.

    Returns:
      logits of shape (batch_size, seq_len, vocab_size).
    '''
    # Step 1. Embedding lookup (the ids are already tokenized): batch_size, seq_len, d_embed
    embeddings = self.token_embeddings(token_ids)

    # Step 2. Add fixed sinusoidal positional encodings: batch_size, seq_len, d_embed
    embeddings = self.add_positional_encodings(embeddings)

    # Step 3. Transformer blocks. Iterate manually: nn.Sequential passes a single value
    # between modules, but every block also needs pad_mask.
    for block in self.transformer:
      embeddings = block(embeddings, pad_mask)

    # Step 4. Final norm, then project d_embed -> vocab_size: batch_size, seq_len, vocab_size
    return self.lm_head(self.ln_final(embeddings))

  def get_parameter_count(self):
    ''' Return the config plus parameter counts per component.

    Per-block entries describe block 0 (every block is built identically). 'lm_head'
    reports the tied matrix that is also counted in 'token_embeddings'; 'total' counts
    every distinct parameter exactly once.
    '''
    def count(module):
      return sum(p.numel() for p in module.parameters())

    first_block = self.transformer[0]
    return {
        'config': self.config,
        'transformer_blocks': count(self.transformer),
        'per_block_mha': count(first_block.mha),
        'per_block_ffn': count(first_block.ffn),
        'per_block_layernorm': count(first_block.ln1) + count(first_block.ln2),
        'token_embeddings': count(self.token_embeddings),
        'final_ln': count(self.ln_final),
        'lm_head': count(self.lm_head), # NOTE: shared with token_embeddings, not double counted in total
        # Module.parameters() de-duplicates shared tensors, so the tied weight is counted once.
        'total': count(self),
    }


In [17]:
# Use the GPU when one is available; fall back to CPU instead of failing in .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT(max_seq_len=256, d_embed=128, vocab_size=256, num_heads=4, num_transformer_blocks=2).to(device)

In [19]:
# Helpful illustrations of scaling laws:
# 1. Parameter count: https://claude.ai/public/artifacts/af13cc77-c008-4436-b61c-129d4e9c66f2
# 2. FLOPs: screenshot at https://youtu.be/SQ3fZ1sAqXI?si=dnzKwbDe9ArzvfFz&t=334

# Per-component parameter breakdown of `model`; the last expression is shown as the cell output.
parameter_breakdown = model.get_parameter_count()
parameter_breakdown

{'config': {'max_seq_len': 256,
  'd_embed': 128,
  'vocab_size': 256,
  'num_heads': 4,
  'num_transformer_blocks': 2},
 'transformer_blocks': 396544,
 'per_block_mha': 66048,
 'per_block_ffn': 131712,
 'per_block_layernorm': 512,
 'token_embeddings': 32768,
 'final_ln': 256,
 'lm_head': 32768,
 'total': 429568}

In [20]:
# Now, let's train the model

def train_one_epoch(dataloader, model, optimizer, max_grad_norm=1.0):
  '''
  Run one epoch of next-token-prediction training.

  Args:
    dataloader: iterable of (token_ids, attention_mask) pairs, each (batch_size, seq_len);
      attention_mask uses 1 for real tokens and 0 for <pad>.
    model: called as model(token_ids, attention_mask) and expected to return logits
      of shape (batch_size, seq_len, vocab_size).
    optimizer: optimizer over the model's parameters.
    max_grad_norm: gradients are clipped to this global L2 norm before each step.

  Returns:
    Mean training loss over the batches that had at least one non-pad target,
    or NaN if no batch did.
  '''
  model.train()
  # Padded targets are set to -100 below so the loss ignores them.
  criterion = nn.CrossEntropyLoss(ignore_index=-100)
  # Batches are moved to wherever the model's parameters live (a no-op if already there).
  device = next(model.parameters()).device

  total_loss, num_batches = 0.0, 0
  for token_ids, attention_mask in dataloader:
    token_ids = token_ids.to(device)
    attention_mask = attention_mask.to(device)
    optimizer.zero_grad()

    logits = model(token_ids, attention_mask) # batch_size, seq_len, vocab_size

    # Next-token objective: position t predicts token t + 1.
    # For the sequence [the cat sat on the mat], inputs [the cat sat] are scored
    # against targets [cat sat on]. So drop the last logit and the first token.
    targets = token_ids[:, 1:].clone() # batch_size, seq_len - 1
    logits = logits[:, :-1, :] # batch_size, seq_len - 1, vocab_size
    targets[attention_mask[:, 1:] == 0] = -100 # ignore targets that are <pad>

    if not (targets != -100).any():
      # Every target is padding: the mean loss would be 0/0 = NaN, and backward + step
      # would write NaN into the parameters. Skip the batch.
      continue

    loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()

    # Accumulate on-device; a single host sync happens at the end.
    total_loss = total_loss + loss.detach()
    num_batches += 1

  return float(total_loss) / num_batches if num_batches else float('nan')