In [1]:
# B T C = Batch Size, Sequence Length (in tokens), Embedding Dimension
# The result should end up with the same shape as before: (4, 16, 8)
# Per-batch shapes of Q, K, V: (8, 16), (16, 8), (8, 16)

In [2]:
import torch

# Toy batched matmul chain to check how shapes flow through (q @ k) @ v.
# The three random tensors are drawn in q, k, v order, each with batch size 4.
q, k, v = (
    torch.randn(4, rows, cols)
    for rows, cols in ((8, 16), (16, 8), (8, 16))
)

attention_scores = torch.matmul(q, k)       # (4, 8, 16) @ (4, 16, 8) -> (4, 8, 8)
output = torch.matmul(attention_scores, v)  # (4, 8, 8) @ (4, 8, 16) -> (4, 8, 16)
attention_scores.shape  # not rendered: a cell only displays its final expression
output.shape

torch.Size([4, 8, 16])

In [3]:
from dataclasses import dataclass


@dataclass
class GPTConfig:
    """Hyperparameters for the toy GPT model in this notebook.

    As a dataclass, individual values can be overridden per instance
    (e.g. ``GPTConfig(n_embd=768, n_head=12)``). The defaults also remain
    readable as class attributes, so passing the class itself where a config
    is expected (``MultiHeadSelfAttention(GPTConfig)``) keeps working.
    """

    # Side length of the causal mask built by MultiHeadSelfAttention,
    # i.e. the longest sequence the attention layer can mask.
    block_size: int = 256
    # Presumably the tokenizer vocabulary size; not used by the code in this cell.
    vocab_size: int = 65
    # Presumably the number of transformer blocks; not used by the code in this cell.
    n_layer: int = 6
    # Number of attention heads; n_embd is split evenly across them.
    n_head: int = 4
    # Embedding width (the C in B, T, C).
    n_embd: int = 8

import torch.nn as nn 
from torch.nn import functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention over a (B, T, C) input.

    A single linear layer produces queries, keys and values; the embedding
    dimension is split evenly across ``n_head`` heads, each head attends only
    to current and earlier positions, and the concatenated head outputs go
    through a final linear projection.

    Args:
        config: object exposing ``n_embd``, ``n_head`` and ``block_size``
            attributes (a class or an instance both work).

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )

        # Fused projection: one matmul yields q, k and v side by side.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Output projection applied after the heads are merged back together.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

        # Lower-triangular causal mask, shaped (1, 1, block_size, block_size)
        # so it broadcasts over the batch and head dimensions.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size),
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (B, T, C) with C == n_embd and T <= block_size.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()
        head_dim = C // self.n_head

        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        # The number of tokens stays the same; only the embedding dimension is
        # divided among the heads. Each head looks at its own slice of the
        # embedding, a bit like applying separate filters (kernels in a CNN)
        # to different embedding dimensions.

        # (B, T, C) -> (B, T, n_head, head_dim) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        # Scale by 1/sqrt(head_dim). The scale must come from the per-head
        # feature size, not from the transposed key's last axis (which is T).
        attention_scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        # Block attention to future positions before normalising.
        attention_scores = attention_scores.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        attention_scores = F.softmax(attention_scores, dim=-1)

        y = attention_scores @ v  # (B, n_head, T, head_dim)
        # Merge heads: (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)

        return y

In [4]:
# Smoke test: push a random (B=4, T=16, C=8) batch through the attention layer.
# GPTConfig is passed as the class itself; its class-level attributes supply
# n_embd, n_head and block_size.
attention = MultiHeadSelfAttention(config=GPTConfig)
sample_input = torch.randn(4, 16, 8)
y = attention(sample_input)
y.shape

torch.Size([4, 16, 8])

In [2]:
from transformers import GPT2LMHeadModel
# Load the pretrained "gpt2" checkpoint and grab its parameter tensors by name.
model_hf = GPT2LMHeadModel.from_pretrained("gpt2")
sd_hf = model_hf.state_dict()

# List every parameter name with its shape.
# Descriptive loop names are used on purpose: at notebook (global) scope,
# `k` / `v` would overwrite the key/value tensors defined in an earlier cell.
for param_name, param in sd_hf.items():
    print(param_name, param.shape)

transformer.wte.weight torch.Size([50257, 768])
transformer.wpe.weight torch.Size([1024, 768])
transformer.h.0.ln_1.weight torch.Size([768])
transformer.h.0.ln_1.bias torch.Size([768])
transformer.h.0.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.0.attn.c_attn.bias torch.Size([2304])
transformer.h.0.attn.c_proj.weight torch.Size([768, 768])
transformer.h.0.attn.c_proj.bias torch.Size([768])
transformer.h.0.ln_2.weight torch.Size([768])
transformer.h.0.ln_2.bias torch.Size([768])
transformer.h.0.mlp.c_fc.weight torch.Size([768, 3072])
transformer.h.0.mlp.c_fc.bias torch.Size([3072])
transformer.h.0.mlp.c_proj.weight torch.Size([3072, 768])
transformer.h.0.mlp.c_proj.bias torch.Size([768])
transformer.h.1.ln_1.weight torch.Size([768])
transformer.h.1.ln_1.bias torch.Size([768])
transformer.h.1.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.1.attn.c_attn.bias torch.Size([2304])
transformer.h.1.attn.c_proj.weight torch.Size([768, 768])
transformer.h.1.attn.c_proj.bias 