### Train the transformer!

##### Three important modules
1. `MultiHeadSelfAttention`: A self-attention implementation
2. `Block`: A transformer block which is repeated n_layer times in a GPT model
3. `GPT`: The full GPT model itself, including initial embeddings, the GPT blocks, and the token decoding logic.

The `GPT` module uses the `Block` module, which in turn uses the `MultiHeadSelfAttention` module. In the diagram below, each arrow points from a module to the module that uses it.
```
    ┌───────────────────────────────┐
    │             GPT               │
    └───────────────────────────────┘
                    ▲
    ┌───────────────┴───────────────┐
    │            Block              │
    └───────────────────────────────┘
                    ▲
    ┌───────────────┴───────────────┐
    │    MultiHeadSelfAttention     │
    └───────────────────────────────┘
```

In [3]:
from common import GPTConfig, MultiHeadSelfAttention

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Let's use a placeholder config to show how the attention layer works.
# The values are small, so the module summaries printed in the following
# cells stay short.
config = GPTConfig(
    vocab_size=10,  # number of distinct token ids (rows of the token embedding)
    n_layer=3,  # number of Block modules stacked inside GPT
    n_embd=12,  # embedding width used by the LayerNorm/Linear layers
    n_head=4,  # presumably the number of attention heads; verify in common.py
    block_size=5,  # maximum sequence length GPT.forward accepts
)

In [5]:
# Build a stand-alone attention layer from the placeholder config.
attention = MultiHeadSelfAttention(config)

In [6]:
# Printing a module lists its registered sub-modules (see the output below).
print(attention)

MultiHeadSelfAttention(
  (key): Linear(in_features=12, out_features=12, bias=True)
  (query): Linear(in_features=12, out_features=12, bias=True)
  (value): Linear(in_features=12, out_features=12, bias=True)
  (attn_drop): Dropout(p=0.1, inplace=False)
  (resid_drop): Dropout(p=0.1, inplace=False)
  (proj): Linear(in_features=12, out_features=12, bias=True)
)


In [7]:
import torch.nn as nn
from common import MultiHeadSelfAttention

In [8]:
class Block(nn.Module):
    """A Transformer block: self-attention followed by a feed-forward MLP.

    LayerNorm is applied to the input of each sub-layer, and each sub-layer's
    output is added back onto the residual stream. The two sub-layers run one
    after the other, so the MLP sees the attention-updated activations.

    Args:
        config: Configuration object. This class reads ``config.n_embd`` and
            passes ``config`` on to ``MultiHeadSelfAttention``.
    """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        self.attn = MultiHeadSelfAttention(config)
        # MLP is a feed-forward neural network that expands to 4x the
        # embedding width and projects back down again.
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers to ``x``.

        Args:
            x: Input tensor whose last dimension is ``config.n_embd``.

        Returns:
            A tensor of the same shape as ``x``.
        """
        # First residual sub-layer: attention over the normalized input.
        x = x + self.attn(self.ln1(x))
        # Second residual sub-layer: the MLP must consume the result of the
        # attention step, not the original input, so this is a separate
        # statement rather than a single three-term sum.
        x = x + self.mlp(self.ln2(x))
        return x

In [9]:
# Instantiate one block; leaving `block` as the cell's last expression
# displays its module tree.
block = Block(config)
block

Block(
  (ln1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
  (ln2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
  (attn): MultiHeadSelfAttention(
    (key): Linear(in_features=12, out_features=12, bias=True)
    (query): Linear(in_features=12, out_features=12, bias=True)
    (value): Linear(in_features=12, out_features=12, bias=True)
    (attn_drop): Dropout(p=0.1, inplace=False)
    (resid_drop): Dropout(p=0.1, inplace=False)
    (proj): Linear(in_features=12, out_features=12, bias=True)
  )
  (mlp): Sequential(
    (0): Linear(in_features=12, out_features=48, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=48, out_features=12, bias=True)
    (3): Dropout(p=0.1, inplace=False)
  )
)

In [10]:
# Sanity check: the block's attention sub-module is the class imported from `common`.
assert isinstance(block.attn, MultiHeadSelfAttention)

In [11]:
import torch
import torch.nn.functional as F

In [None]:
class GPT(nn.Module):
    """The full GPT language model, with a context size of block_size.

    The model embeds token indices, adds a learned positional embedding, runs
    the result through ``config.n_layer`` ``Block`` modules, and projects the
    final (layer-normalized) activations onto the vocabulary.

    Args:
        config: Configuration object. This class reads ``vocab_size``,
            ``n_embd``, ``block_size``, ``embd_pdrop`` and ``n_layer`` from
            it and passes ``config`` on to every ``Block``.
    """

    def __init__(self, config):
        super().__init__()

        # Input embedding stem
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        # One learnable vector per position, shared across the batch
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)

        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])

        # Decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size
        # Re-initialize every Linear / Embedding / LayerNorm sub-module.
        # pos_emb is a bare Parameter, so it keeps its zero initialization.
        self.apply(self._init_weights)

        n_params = sum(p.numel() for p in self.parameters())
        print(f"number of parameters: {n_params}")

    def _init_weights(self, module):
        """Initialize one sub-module in place; used as the ``apply`` callback.

        Linear and Embedding weights are drawn from N(0, 0.02^2), Linear
        biases are zeroed, and LayerNorm is reset to weight 1 / bias 0.
        Other module types are left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_block_size(self):
        """Return the maximum sequence length the model can process."""
        return self.block_size

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Tensor of token indices with shape ``(b, t)`` where
                ``t <= block_size``.
            targets: Optional tensor of target indices with the same number
                of elements as ``idx``. Positions equal to ``-1`` are
                ignored by the loss.

        Returns:
            A ``(logits, loss)`` tuple. ``logits`` has shape
            ``(b, t, vocab_size)``; ``loss`` is the cross-entropy against
            ``targets``, or ``None`` when ``targets`` is not given.

        Raises:
            AssertionError: If ``t`` exceeds the model's block size.
        """
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        # Create token embeddings and add positional embeddings
        # Each index maps to a (learnable) vector
        token_embeddings = self.tok_emb(idx)
        # Each position maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :]

        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)

        # Decode the output of the transformer blocks
        x = self.ln_f(x)
        logits = self.head(x)

        # If we are given some desired targets also calculate the loss, e.g. during training
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            loss = None

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, stop_tokens=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        Args:
            idx: LongTensor of shape ``(b, t)`` holding the prompt tokens.
            max_new_tokens: Maximum number of tokens to append to each row.
            temperature: Divisor applied to the logits before sampling.
            top_k: If given, sampling is restricted to the ``top_k`` highest
                logits. Values larger than the vocabulary size are clamped
                to the vocabulary size.
            stop_tokens: Optional container of token ids. Generation returns
                early once every row of the batch has sampled one of these
                tokens. Rows that finish early keep being extended until the
                last row finishes, so all rows stay the same length.

        Returns:
            LongTensor of shape ``(b, t + n)`` with ``n <= max_new_tokens``:
            the prompt followed by the sampled tokens.
        """

        # Per-row flag: has this sequence produced a stop token yet?
        finished = [False] * idx.size(0)

        # === EXERCISE PART 3 START: COMPLETE THE GENERATION LOGIC ===
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = (
                idx if idx.size(1) <= self.block_size else idx[:, -self.block_size :]
            )

            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)

            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature

            # optionally crop the logits to only the top k options
            if top_k is not None:
                # torch.topk raises if k exceeds the vocabulary size, so clamp it
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float("Inf")

            # TODO: apply softmax to convert logits to (normalized) probabilities
            # using F.softmax. Remember the dim=-1 parameter.
            # probs = <TODO>
            probs = F.softmax(logits, dim=-1)  # Solution

            # TODO: sample from the distribution (if top_k=1 this is equivalent to greedy sampling)
            # using torch.multinomial. You only need to sample a single token.
            # idx_next = <TODO>
            idx_next = torch.multinomial(probs, num_samples=1)  # Solution

            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

            # stop prediction once every row has produced a stop token
            # (for a batch of one this is simply "the sampled token is a stop token")
            if stop_tokens is not None:
                sampled = idx_next.view(-1).tolist()
                finished = [
                    done or tok in stop_tokens for done, tok in zip(finished, sampled)
                ]
                if all(finished):
                    return idx
        # === EXERCISE PART 3 END: COMPLETE THE GENERATION LOGIC ===

        return idx

In [23]:
# Build the full model from the placeholder config; the constructor prints
# the total parameter count.
model = GPT(config)

number of parameters: 5976
