In [6]:
import torch
import torch.nn as nn

# Hyper-parameters consumed by GPTModel / TransformerBlock via cfg["..."] lookups.
GPT_CONFIG_124M = dict(
    vocab_size=50257,      # rows of the token embedding and width of the output logits
    context_length=256,    # rows of the positional embedding and side of the causal mask
    emb_dim=768,           # embedding width used throughout the model
    n_heads=12,            # attention heads per transformer block
    n_layers=12,           # number of stacked transformer blocks
    drop_rate=0.1,         # probability passed to every nn.Dropout
    qkv_bias=False,        # whether the query/key/value projections get a bias term
)


In [9]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    The input is projected to queries, keys and values of width ``d_out``,
    split into ``num_heads`` heads of width ``d_out // num_heads``, attended
    with a causal (upper-triangular) mask, re-merged, and passed through an
    output projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the output; must be divisible
            by ``num_heads``.
        context_length: Side length of the causal mask, i.e. the longest
            sequence ``forward`` accepts.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections have a bias term.

    Raises:
        AssertionError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # per-head width

        # One full-width projection each; the result is split into heads in forward().
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated heads
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions to hide.
        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` the
                module was built with.
        """
        b, num_tokens, _ = x.shape

        # Guard against sequences longer than the mask. Without this check the
        # sliced mask is smaller than the score matrix: it either fails with an
        # opaque broadcasting error or, when the mask is 1x1, broadcasts and
        # silently disables causal masking.
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        # Full-width projections: (b, num_tokens, d_in) -> (b, num_tokens, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split the last dimension into heads:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Move the head axis next to the batch axis:
        # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        # Hide future positions; -inf becomes exactly 0 after the softmax.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) and normalise over the key axis.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values, then put the token axis back in front of the heads:
        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)
        # transpose() leaves the tensor non-contiguous, so make it contiguous
        # before view() merges the heads back into a single d_out-wide axis.
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec

class TransformerBlock(nn.Module):
    """A single transformer layer with two residual sub-layers.

    Each sub-layer follows the same pattern: normalise, transform (attention
    first, then the feed-forward module), apply dropout, and add the
    sub-layer's input back on.

    Args:
        cfg: Mapping that provides ``emb_dim``, ``context_length``,
            ``n_heads``, ``drop_rate`` and ``qkv_bias``. It is also passed
            unchanged to ``FeedForward`` (defined outside this block), which
            may read further keys.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg["emb_dim"]
        p_drop = cfg["drop_rate"]

        # Attention maps emb_dim -> emb_dim so the residual addition lines up.
        self.att = MultiHeadAttention(
            d_in=width,
            d_out=width,
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=p_drop,
            qkv_bias=cfg["qkv_bias"],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(width)
        self.norm2 = LayerNorm(width)
        # One dropout module is shared by both residual branches.
        self.drop_shortcut = nn.Dropout(p_drop)

    def forward(self, x):
        """Run the attention sub-layer, then the feed-forward sub-layer.

        Returns the result of the second residual addition.
        """
        # Sub-layer 1: norm -> attention -> dropout, plus the residual.
        residual = x
        attended = self.drop_shortcut(self.att(self.norm1(x)))
        x = attended + residual

        # Sub-layer 2: norm -> feed-forward -> dropout, plus the residual.
        residual = x
        transformed = self.drop_shortcut(self.ff(self.norm2(x)))
        return transformed + residual

class GPTModel(nn.Module):
    """Token and positional embeddings, a stack of transformer blocks, a final
    normalisation and a bias-free linear head that produces vocabulary logits.

    Args:
        cfg: Mapping that provides ``vocab_size``, ``emb_dim``,
            ``context_length``, ``drop_rate`` and ``n_layers``. It is also
            handed to every ``TransformerBlock``.
    """

    def __init__(self, cfg):
        super().__init__()
        vocab, width = cfg["vocab_size"], cfg["emb_dim"]

        self.tok_emb = nn.Embedding(vocab, width)
        # One learned vector per position, so at most context_length positions.
        self.pos_emb = nn.Embedding(cfg["context_length"], width)
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        layers = []
        for _ in range(cfg["n_layers"]):
            layers.append(TransformerBlock(cfg))
        self.trf_blocks = nn.Sequential(*layers)

        self.final_norm = LayerNorm(width)
        self.out_head = nn.Linear(width, vocab, bias=False)

    def forward(self, in_idx):
        """Compute logits for a batch of token-index sequences.

        Args:
            in_idx: 2-D tensor ``(batch, seq_len)`` of token indices.
                ``seq_len`` must not exceed the ``context_length`` the model
                was built with, since positions index into ``pos_emb``.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)``.
        """
        _, seq_len = in_idx.shape

        # Positions 0..seq_len-1, created on the same device as the input.
        positions = torch.arange(seq_len, device=in_idx.device)

        # (batch, seq_len, emb_dim) + (seq_len, emb_dim): broadcast over batch.
        hidden = self.tok_emb(in_idx) + self.pos_emb(positions)
        hidden = self.drop_emb(hidden)
        hidden = self.trf_blocks(hidden)
        hidden = self.final_norm(hidden)
        return self.out_head(hidden)

In [10]:
# Seed PyTorch's global RNG before building the model so that the random
# weight initialisation is repeatable from run to run.
torch.manual_seed(123)
# Build the model from the config dict. The cells defining GPT_CONFIG_124M and
# the model classes must have been executed first: the NameError recorded in
# this cell's output shows MultiHeadAttention was missing from the kernel
# namespace when this cell was last run.
model = GPTModel(GPT_CONFIG_124M)

NameError: name 'MultiHeadAttention' is not defined