<a href="https://colab.research.google.com/github/JayThibs/gpt-experiments/blob/main/notebooks/minGPT_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# End-to-End minGPT

This notebook implements minGPT to quickly get a feel for end-to-end training of a GPT model.

## Imports

In [2]:
import math
import torch
from torch import nn
import torch.nn.functional as F

## Model configuration

Next, we create a configuration class that holds all the hyperparameters of the model. Since we are implementing minGPT, this is not the same model as GPT-2 or GPT-3. However, to get those models we can simply add more layers and increase the maximum sequence length and the embedding dimension.

Those bigger models have some additional tricks for training, but the general idea is just a bigger model and more data.
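
As a rough illustration of how size grows, the sketch below counts the parameters of the architecture defined in this notebook, layer by layer. `count_params` is a hypothetical helper written for this example. It assumes exactly the layers used here: a learned positional embedding, biases on every Linear, a 4x feed-forward expansion, and an untied output layer. Its totals are therefore not the published GPT-2 or GPT-3 counts.

```python
def count_params(vocab_size: int, max_len: int, embed_dim: int, num_blocks: int) -> int:
    """Hypothetical helper: parameter count of this notebook's GPT class."""
    d = embed_dim
    embeddings = vocab_size * d + max_len * d     # tok_embed weight + pos_embed
    attn = 4 * (d * d + d)                        # key, value, query, proj (weights + biases)
    ff = (d * 4 * d + 4 * d) + (4 * d * d + d)    # the two Linear layers in the feed-forward net
    layer_norms = 2 * (2 * d)                     # ln1 and ln2, each with a weight and a bias
    per_block = attn + ff + layer_norms           # = 12*d^2 + 13*d
    head = 2 * d + (d * vocab_size + vocab_size)  # final LayerNorm + output Linear (fc)
    return embeddings + num_blocks * per_block + head

# GPT1Config sizes with the toy vocabulary used later (vocab_size=10, max_len=12)
print(count_params(vocab_size=10, max_len=12, embed_dim=768, num_blocks=12))  # 85080586
```

The per-block term is 12·d² + 13·d, so width dominates. Doubling `embed_dim` roughly quadruples the block parameters, while doubling `num_blocks` only doubles them. For the `model` built in the Model section, `sum(p.numel() for p in model.parameters())` should give the same total, because the registered `mask` is a buffer rather than a parameter.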

In [3]:
class GPTConfig:
    attn_dropout = 0.1
    embed_dropout = 0.1
    ff_dropout = 0.1

    def __init__(self, vocab_size, max_len, **kwargs):
        self.vocab_size = vocab_size
        self.max_len = max_len
        for key, value in kwargs.items():
            setattr(self, key, value)

class GPT1Config(GPTConfig):
    num_heads = 12
    num_blocks = 12
    embed_dim = 768

In [4]:
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.embed_dim
        self.max_len = config.max_len
        self.tok_embed = nn.Embedding(
            config.vocab_size, embed_dim
        )
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.max_len, embed_dim)
        )
        self.dropout = nn.Dropout(config.embed_dropout)
        self.blocks = nn.Sequential(
            *[Decoder(config) for _ in range(config.num_blocks)]
        ) # stack num_blocks transformer decoder blocks (Decoder is defined in a later cell)
        self.ln = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, config.vocab_size)

    def forward(self, x, target=None):
        # target is accepted but not used here; forward() only returns logits
        # x.shape == (batch_size, seq_len), integer token ids
        seq_len = x.size(1)
        assert seq_len <= self.max_len, "sequence longer than model capacity"

        tok_embedding = self.tok_embed(x)
        # tok_embedding.shape == (batch_size, seq_len, embed_dim)
        pos_embedding = self.pos_embed[:, :seq_len, :] # slice the learned (trainable) positional embedding to seq_len
        # pos_embedding.shape == (1, seq_len, embed_dim), broadcast over the batch
        # avoids building position ids with torch.arange() and an nn.Embedding lookup on every forward pass
        x = self.dropout(tok_embedding + pos_embedding)
        x = self.blocks(x)
        x = self.ln(x)
        x = self.fc(x) # logits
        # x.shape == (batch_size, seq_len, vocab_size)
        return x
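
The positional embedding above is a learned parameter that is sliced to the input length. A minimal sketch shows that this gives the same result as the alternative mentioned in the comments: an `nn.Embedding` table indexed with `torch.arange(seq_len)` on every call. The sizes are toy values chosen for the example. Both tables are trainable; the sliced parameter simply skips building the position ids.

```python
import torch
from torch import nn

max_len, embed_dim, seq_len = 12, 8, 5

# option A (used in GPT above): a learned (1, max_len, embed_dim) parameter, sliced per call
pos_param = nn.Parameter(torch.randn(1, max_len, embed_dim))
pos_a = pos_param[:, :seq_len, :]                  # (1, seq_len, embed_dim)

# option B: an nn.Embedding table indexed with position ids built on every call
pos_table = nn.Embedding(max_len, embed_dim)
with torch.no_grad():
    pos_table.weight.copy_(pos_param[0])           # give it the same values as option A
positions = torch.arange(seq_len)                  # tensor([0, 1, 2, 3, 4])
pos_b = pos_table(positions).unsqueeze(0)          # (1, seq_len, embed_dim)

print(pos_a.shape, torch.equal(pos_a, pos_b))
```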

In [5]:
# The Transformer Decoder Block
# Includes multi-head attention, layer normalization, 
# and a point-wise feedforward network
class Decoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.embed_dim
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(config)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(config.ff_dropout)
        )

    def forward(self, x):
        # pre-LayerNorm residual connections: each sublayer sees a normalized x, and its output is added back to x
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.embed_dim
        self.num_heads = config.num_heads
        assert embed_dim % self.num_heads == 0, "invalid heads and embedding dimension configuration, embed_dim must be divisible by num_heads with no remainder"
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.proj_dropout = nn.Dropout(config.ff_dropout)
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.max_len, config.max_len)).unsqueeze(0).unsqueeze(0)
        ) # torch.tril gives a lower-triangular matrix: row i has ones in columns 0..i, so position i attends only to itself and earlier positions
    
    def forward(self, x):
        batch_size = x.size(0)
        seq_len = x.size(1)
        # x.shape == (batch_size, seq_len, embed_dim)
        # k_t == transposed key
        # reshape splits the data across each attention head
        k_t = self.key(x).reshape(batch_size, seq_len, self.num_heads, -1).permute(0, 2, 3, 1)
        v = self.value(x).reshape(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        q = self.query(x).reshape(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        # shape == (batch_size, num_heads, seq_len, head_dim)

        attn = torch.matmul(q, k_t) / math.sqrt(q.size(-1)) # Q * K^T / √d
        # attn.shape == (batch_size, num_heads, seq_len, seq_len)
        mask = self.mask[:, :, :seq_len, :seq_len] # causal mask for the current seq_len, broadcast over batch and heads
        attn = attn.masked_fill(mask == 0, float("-inf")) # -inf scores get zero weight after the softmax
        attn = F.softmax(attn, dim=-1)
        # dropout goes after the softmax: dropping -inf scores before it would produce NaN (-inf * 0)
        attn = self.attn_dropout(attn)
        # attn.shape == (batch_size, num_heads, seq_len, seq_len)
        y = torch.matmul(attn, v) # second matmul, (seq_len, seq_len) x (seq_len, head_dim)
        # y.shape == (batch_size, num_heads, seq_len, head_dim)
        y = y.transpose(1, 2)
        # y.shape == (batch_size, seq_len, num_heads, head_dim)
        y = y.reshape(batch_size, seq_len, -1) # concatenate the heads back into embed_dim
        # y.shape == (batch_size, seq_len, embed_dim)
        y = self.proj_dropout(self.proj(y))
        return y
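
To sanity-check the attention math, the sketch below reimplements the head split, causal masking, and head merge from `MultiHeadAttention.forward` on random tensors. It then compares the result against PyTorch's fused `F.scaled_dot_product_attention(..., is_causal=True)`. It assumes PyTorch 2.0 or later, where that function exists. It omits the linear projections and dropout and uses toy sizes. `split_heads` and `merge_heads` are illustrative helper names, not part of the notebook.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, num_heads, embed_dim = 2, 5, 4, 16
head_dim = embed_dim // num_heads


def split_heads(t):
    # (batch_size, seq_len, embed_dim) -> (batch_size, num_heads, seq_len, head_dim)
    return t.reshape(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)


def merge_heads(t):
    # (batch_size, num_heads, seq_len, head_dim) -> (batch_size, seq_len, embed_dim)
    return t.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)


# random tensors standing in for self.query(x), self.key(x), self.value(x)
q, k, v = (split_heads(torch.randn(batch_size, seq_len, embed_dim)) for _ in range(3))

# manual causal attention, mirroring MultiHeadAttention.forward (no dropout)
mask = torch.tril(torch.ones(seq_len, seq_len))
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)  # k.transpose(-2, -1) plays the role of k_t
scores = scores.masked_fill(mask == 0, float("-inf"))
manual = merge_heads(F.softmax(scores, dim=-1) @ v)

# PyTorch's fused implementation (PyTorch 2.0+), default scale 1/sqrt(head_dim)
reference = merge_heads(F.scaled_dot_product_attention(q, k, v, is_causal=True))

print(torch.allclose(manual, reference, atol=1e-5))
```

The split/merge pair is an exact round trip, which is why the final `reshape` in `forward` recovers a tensor of shape (batch_size, seq_len, embed_dim).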

Here's a concrete example of the mask matrix, assuming that we have a decoder whose maximum sequence length is 5.

In [7]:
max_len = 5

mask = torch.tril(torch.ones(max_len, max_len)).unsqueeze(0).unsqueeze(0)
mask

tensor([[[[1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.]]]])

When the input is shorter than max_len, we slice out the top-left (seq_len, seq_len) block of the mask, which is then applied to every example in the batch:

In [10]:
seq_len = 3

mask = mask[:, :, :seq_len, :seq_len]
mask

tensor([[[[1., 0., 0.],
          [1., 1., 0.],
          [1., 1., 1.]]]])
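
Slicing is safe because the top-left block of a lower-triangular matrix is itself lower-triangular. The single max_len mask registered as a buffer therefore covers every shorter input. A quick check with the same sizes as above:

```python
import torch

max_len, seq_len = 5, 3

# slice the max_len mask, as MultiHeadAttention.forward does
full_mask = torch.tril(torch.ones(max_len, max_len)).unsqueeze(0).unsqueeze(0)
sliced = full_mask[:, :, :seq_len, :seq_len]

# build a fresh mask of exactly seq_len for comparison
fresh = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

print(torch.equal(sliced, fresh))  # True
```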

In [11]:
# attn.shape == (batch_size, num_heads, seq_len, seq_len)
batch_size = 3
num_heads = 2

attn = torch.randn(batch_size, num_heads, seq_len, seq_len)
attn.shape

torch.Size([3, 2, 3, 3])

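As you can see below, the (1, 1, seq_len, seq_len) mask broadcasts over all 3 examples in the batch and both attention heads. In each (seq_len, seq_len) matrix, row i holds the scores for query position i, and each row can see one more token than the row above it.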

In [12]:
attn = attn.masked_fill(mask == 0, float("-inf"))
attn

tensor([[[[-1.4991,    -inf,    -inf],
          [-0.6522, -0.2875,    -inf],
          [-0.3859, -0.2676,  0.3449]],

         [[-0.5417,    -inf,    -inf],
          [ 0.4578,  1.1385,    -inf],
          [-0.1292,  0.3617,  1.6502]]],


        [[[-0.5823,    -inf,    -inf],
          [ 1.6902, -1.3269,    -inf],
          [ 0.3759,  0.7223,  1.1238]],

         [[-0.2738,    -inf,    -inf],
          [-0.7617, -0.7891,    -inf],
          [ 0.1969,  1.7865,  0.6715]]],


        [[[-1.0717,    -inf,    -inf],
          [ 1.7455,  0.4107,    -inf],
          [-1.5003,  2.2838, -0.3378]],

         [[-1.7366,    -inf,    -inf],
          [-0.2629,  1.0714,    -inf],
          [-0.4392,  0.3253,  0.5141]]]])
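
The snippet below checks that every (example, head) matrix gets -inf in exactly the same upper-triangular positions. It rebuilds the tensors so that it runs on its own, using the same sizes as above.

```python
import torch

batch_size, num_heads, seq_len = 3, 2, 3
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)
attn = torch.randn(batch_size, num_heads, seq_len, seq_len)
attn = attn.masked_fill(mask == 0, float("-inf"))  # mask broadcasts over batch and head dims

# the -inf pattern matches the zeros of the mask in every example and every head
print((torch.isinf(attn) == (mask == 0)).all())  # tensor(True)
```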

Now, we apply a softmax:

In [13]:
F.softmax(attn, dim=-1) # applies the softmax over each row

tensor([[[[1.0000, 0.0000, 0.0000],
          [0.4098, 0.5902, 0.0000],
          [0.2380, 0.2678, 0.4942]],

         [[1.0000, 0.0000, 0.0000],
          [0.3361, 0.6639, 0.0000],
          [0.1168, 0.1909, 0.6923]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.9533, 0.0467, 0.0000],
          [0.2209, 0.3124, 0.4667]],

         [[1.0000, 0.0000, 0.0000],
          [0.5069, 0.4931, 0.0000],
          [0.1332, 0.6528, 0.2141]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.7916, 0.2084, 0.0000],
          [0.0207, 0.9129, 0.0664]],

         [[1.0000, 0.0000, 0.0000],
          [0.2084, 0.7916, 0.0000],
          [0.1742, 0.3741, 0.4518]]]])
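
Each row sums to 1 and the masked entries are exactly 0, since exp(-inf) = 0. The causal mask also never masks a whole row: the diagonal is always 1, so every position can at least attend to itself. That matters because a softmax over a row that is entirely -inf returns NaN, as this small sketch shows:

```python
import torch
import torch.nn.functional as F

neg_inf = float("-inf")

# one masked entry: it gets exactly zero weight and the rest still sum to 1
row = torch.tensor([0.5, -0.2, neg_inf])
weights = F.softmax(row, dim=-1)
print(weights, weights.sum())

# a fully masked row: every score is -inf, so the softmax is undefined
all_masked = torch.tensor([neg_inf, neg_inf, neg_inf])
print(F.softmax(all_masked, dim=-1))  # tensor([nan, nan, nan])
```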

## Model

We can now put everything together and create a basic model.

In [14]:
vocab_size = 10
max_len = 12

config = GPT1Config(vocab_size, max_len)
model = GPT(config)

This is a simple 12-layer decoder network.

In [15]:
model

GPT(
  (tok_embed): Embedding(10, 768)
  (dropout): Dropout(p=0.1, inplace=False)
  (blocks): Sequential(
    (0): Decoder(
      (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (key): Linear(in_features=768, out_features=768, bias=True)
        (value): Linear(in_features=768, out_features=768, bias=True)
        (query): Linear(in_features=768, out_features=768, bias=True)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (proj_dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate=none)
        (2): Linear(in_features=3072, out_features=768, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
    (1): Decoder(
      (ln1): LayerNorm((768,), eps=1e-05, elem

Let's test the model on a dummy input to check that it behaves as expected. First, let's pass in an invalid input whose length exceeds the model's maximum sequence length.

In [17]:
seq_len = 15

test_input = torch.randint(high=vocab_size, size=(batch_size, seq_len))
try:
    model(test_input).shape
except AssertionError as e:
    print(e)

sequence longer than model capacity


We get an appropriate assertion error, saying that the sequence is longer than the maximum sequence length that the model can process. Let’s see what happens if we pass in a valid input.

In [18]:
model(test_input[:, :max_len]).shape

torch.Size([3, 12, 10])

As expected, we get a valid output of shape (batch_size, seq_len, vocab_size).
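
`forward` accepts a `target` argument but does not use it yet. The sketch below shows one common way to turn logits of this shape into a next-token training loss. It is not the notebook's training code. It assumes that `target[:, t]` holds the token that follows position `t` (the input shifted left by one), and it uses all-zero logits as a stand-in for `model(x)`.

```python
import math

import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 3, 12, 10

# stand-ins: all-zero logits (a uniform prediction) in place of model(x), plus
# hypothetical next-token targets, i.e. target[:, t] is the token after position t
logits = torch.zeros(batch_size, seq_len, vocab_size)
target = torch.randint(high=vocab_size, size=(batch_size, seq_len))

# F.cross_entropy expects (N, C) logits and (N,) class indices, so flatten batch and time
loss = F.cross_entropy(logits.reshape(-1, vocab_size), target.reshape(-1))
print(loss.item(), math.log(vocab_size))  # both about 2.3026
```

A uniform prediction over 10 tokens gives a loss of ln(10), which is a useful sanity check for the loss at the very start of training.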