In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [7]:
class FeedForward(nn.Module):
  """Two-layer MLP applied over the last dimension of the input.

  The input is projected from ``d_model`` up to ``4 * d_model`` features,
  passed through GELU, and projected back down to ``d_model``.

  Args:
    d_model: Size of the last dimension of both the input and the output.
  """

  def __init__(self, d_model):
    super().__init__()
    hidden_dim = 4 * d_model
    # Layers are created in the same order they are applied.
    expand = nn.Linear(d_model, hidden_dim)
    activation = nn.GELU()
    contract = nn.Linear(hidden_dim, d_model)
    self.net = nn.Sequential(expand, activation, contract)

  def forward(self, x):
    """Return the MLP output; it has the same shape as ``x``."""
    projected = self.net(x)
    return projected


In [8]:
class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention.

    A single bias-free linear layer produces queries, keys and values, which
    are split into ``num_heads`` heads of width ``d_model // num_heads``.
    Attention scores are scaled by ``sqrt(d_head)`` and masked so that each
    position only attends to itself and earlier positions. The heads are then
    concatenated and passed through a bias-free output projection.

    Args:
        d_model: Feature size of the input and output (last dimension).
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(..., seq_len, d_model)``. Any number of
               leading batch dimensions is accepted, including none
               (the plain ``(seq_len, d_model)`` case).

        Returns:
            Tensor with the same shape as ``x``.
        """
        *lead_dims, seq_len, d_model = x.shape

        qkv = self.qkv(x)
        Q, K, V = qkv.chunk(3, dim=-1)

        # split heads: (..., seq_len, d_model) -> (..., num_heads, seq_len, d_head)
        def split_heads(t):
            t = t.reshape(*lead_dims, seq_len, self.num_heads, self.d_head)
            return t.transpose(-3, -2)

        Q, K, V = split_heads(Q), split_heads(K), split_heads(V)

        # scaled dot-product scores: (..., num_heads, seq_len, seq_len)
        scores = (Q @ K.transpose(-2, -1)) / (self.d_head ** 0.5)

        # causal mask: True where the key index is strictly after the query
        # index. Built on x.device so masked_fill also works when the input
        # is not on the CPU.
        positions = torch.arange(seq_len, device=x.device)
        future = positions.unsqueeze(0) > positions.unsqueeze(1)
        scores = scores.masked_fill(future, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        out = attn @ V

        # merge heads: (..., num_heads, seq_len, d_head) -> (..., seq_len, d_model)
        out = out.transpose(-3, -2).reshape(*lead_dims, seq_len, d_model)

        return self.out_proj(out)


In [9]:
class GPTDecoderBlock(nn.Module):
  """Decoder block with two residual sub-layers, each preceded by LayerNorm.

  The first sub-layer is self-attention, the second is the feed-forward
  network. In both, the normalized input goes through the sub-layer and the
  result is added back onto the un-normalized input.

  Args:
    d_model: Feature size used by both LayerNorms and both sub-layers.
    num_heads: Number of heads passed to the attention sub-layer.
  """

  def __init__(self, d_model, num_heads):
    super().__init__()
    # Attention sub-layer and its pre-normalization.
    self.ln1 = nn.LayerNorm(d_model)
    self.attn = MultiHeadSelfAttention(d_model, num_heads)
    # Feed-forward sub-layer and its pre-normalization.
    self.ln2 = nn.LayerNorm(d_model)
    self.mlp = FeedForward(d_model)

  def forward(self, x):
    """Return ``x`` after the attention and feed-forward residual updates."""
    attn_update = self.attn(self.ln1(x))
    after_attn = x + attn_update

    mlp_update = self.mlp(self.ln2(after_attn))
    return after_attn + mlp_update

In [10]:
# Smoke test: send a random input of 6 positions x 32 features through one
# decoder block and print the output shape.
X = torch.randn((6, 32))

block = GPTDecoderBlock(num_heads=4, d_model=32)

out = block(X)
print(out.shape)

torch.Size([6, 32])
