### Train the transformer!

##### Three important modules
1. `MultiHeadSelfAttention`: A self-attention implementation
2. `Block`: A transformer block, which is repeated `n_layer` times in a GPT model
3. `GPT`: The full GPT model itself, including initial embeddings, the GPT blocks, and the token decoding logic.

The `GPT` module uses the `Block` module, which in turn uses the `MultiHeadSelfAttention` module.
```
    ┌───────────────────────────────┐
    │              GPT              │
    └───────────────────────────────┘
                    ▲
    ┌───────────────┴───────────────┐
    │             Block             │
    └───────────────────────────────┘
                    ▲
    ┌───────────────┴───────────────┐
    │    MultiHeadSelfAttention     │
    └───────────────────────────────┘
```

In [3]:
from common import GPTConfig, MultiHeadSelfAttention

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# A deliberately tiny configuration, used only to demonstrate how the
# attention layer works.
placeholder_settings = {
    "vocab_size": 10,
    "n_layer": 3,
    "n_embd": 12,
    "n_head": 4,
    "block_size": 5,
}
config = GPTConfig(**placeholder_settings)

In [5]:
# Instantiate a single attention layer from the placeholder config.
attention = MultiHeadSelfAttention(config)

In [6]:
# Printing the module lists its sub-layers (linear projections and dropouts)
# together with their sizes.
print(attention)

MultiHeadSelfAttention(
  (key): Linear(in_features=12, out_features=12, bias=True)
  (query): Linear(in_features=12, out_features=12, bias=True)
  (value): Linear(in_features=12, out_features=12, bias=True)
  (attn_drop): Dropout(p=0.1, inplace=False)
  (resid_drop): Dropout(p=0.1, inplace=False)
  (proj): Linear(in_features=12, out_features=12, bias=True)
)


In [7]:
import torch.nn as nn
from common import MultiHeadSelfAttention

In [8]:
class Block(nn.Module):
    """A pre-LayerNorm Transformer block: self-attention, then an MLP.

    Each sub-layer normalizes its input and adds its output back onto the
    residual stream. The two sub-layers are applied one after the other, so
    the MLP operates on the attention-updated activations.

    Args:
        config: Configuration object. ``config.n_embd`` sets the width of
            both layer norms and of the MLP; the whole object is also passed
            to ``MultiHeadSelfAttention``.
    """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        self.attn = MultiHeadSelfAttention(config)
        # Feed-forward network: expand to 4x the embedding width, apply GELU,
        # project back to the embedding width, then apply dropout.
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        """Apply attention and then the MLP, each as a residual update.

        Args:
            x: Input activations whose last dimension is ``config.n_embd``
                (presumably ``(batch, seq, n_embd)`` -- confirm against
                ``MultiHeadSelfAttention``).

        Returns:
            ``x`` with the attention update and then the MLP update added.
        """
        # Sequential residuals: the MLP must see the output of the attention
        # sub-layer, not the block's original input.
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

In [9]:
# Build one block from the placeholder config. Leaving `block` as the cell's
# last expression makes the notebook display its module tree.
block = Block(config)
block

Block(
  (ln1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
  (ln2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
  (attn): MultiHeadSelfAttention(
    (key): Linear(in_features=12, out_features=12, bias=True)
    (query): Linear(in_features=12, out_features=12, bias=True)
    (value): Linear(in_features=12, out_features=12, bias=True)
    (attn_drop): Dropout(p=0.1, inplace=False)
    (resid_drop): Dropout(p=0.1, inplace=False)
    (proj): Linear(in_features=12, out_features=12, bias=True)
  )
  (mlp): Sequential(
    (0): Linear(in_features=12, out_features=48, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=48, out_features=12, bias=True)
    (3): Dropout(p=0.1, inplace=False)
  )
)

In [10]:
# Sanity check: the block's attention sub-layer is the MultiHeadSelfAttention
# class imported from `common`.
assert isinstance(block.attn, MultiHeadSelfAttention)

In [11]:
import torch
import torch.nn.functional as F

In [None]:
class GPT(nn.Module):
    """ The full GPT language model, with a context size of block_size """

    def __init__(self, config):
        super().__init__()

        # Input embedding stem
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)

        self.blocks = nn.Sequential(
            *[Block(config) for _ in range(config.n_layer)]
        )