In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from math import sqrt


@dataclass
class ModelConfig:
    """Hyperparameters shared by the model components in this notebook.

    Only ``d_model``, ``n_heads`` and ``kv_heads`` are consumed by the code
    visible in this notebook section; the remaining fields are described
    from their names and should be confirmed against the cells that use them.
    """

    # --- attention / width -------------------------------------------------
    # Model width: input feature size of the Q/K/V projections.
    d_model: int = 1_024
    # Number of query heads; head_dim is derived as d_model // n_heads.
    n_heads: int = 16

    # --- remaining architecture knobs (not referenced in the visible cells) -
    # Presumably the feed-forward hidden width -- TODO confirm.
    d_ff: int = 2_816
    # Presumably the token vocabulary size -- TODO confirm.
    vocab_size: int = 32_000
    # Presumably the encoder / decoder stack depths -- TODO confirm.
    num_encoder_layers: int = 6
    num_decoder_layers: int = 3
    # Presumably the rotary position embedding base -- TODO confirm.
    rope_theta: float = 10_000.0
    # Presumably a dropout probability (0.0 would disable it) -- TODO confirm.
    dropout: float = 0.0

    # --- grouped-query attention -------------------------------------------
    # Number of key/value heads; each has the same head_dim as a query head.
    kv_heads: int = 4



In [4]:
class QKVProjection(nn.Module):
    """
    Projects input x into Q, K, and V for Multi-Head Attention.
    Supports GQA (Grouped Query Attention): queries are split into
    ``cfg.n_heads`` heads, keys and values into ``cfg.kv_heads`` heads, and
    every head has the same size ``head_dim = cfg.d_model // cfg.n_heads``.

    Args:
        cfg: configuration object providing ``d_model``, ``n_heads`` and
            ``kv_heads``.

    Raises:
        ValueError: if ``n_heads`` or ``kv_heads`` is not positive, if
            ``d_model`` is not divisible by ``n_heads``, or if ``n_heads`` is
            not divisible by ``kv_heads``.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        d_model = cfg.d_model
        n_heads = cfg.n_heads
        kv_heads = cfg.kv_heads

        # Fail fast on inconsistent head settings instead of building layers
        # with silently wrong sizes.
        if n_heads <= 0 or kv_heads <= 0:
            raise ValueError(
                f"n_heads and kv_heads must be positive, "
                f"got n_heads={n_heads}, kv_heads={kv_heads}"
            )
        # Without this check the floor division below would quietly drop the
        # remainder, making the Q projection narrower than d_model.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        # Grouped-query attention assigns an equal number of query heads to
        # every K/V head, which requires an exact multiple.
        if n_heads % kv_heads != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be divisible by kv_heads ({kv_heads})"
            )

        self.cfg = cfg

        # head dimension based on full heads
        self.head_dim = d_model // n_heads

        # Query always has full heads
        self.W_q = nn.Linear(d_model, n_heads * self.head_dim, bias=False)

        # K/V have fewer heads but same head_dim
        kv_dim = kv_heads * self.head_dim
        self.W_k = nn.Linear(d_model, kv_dim, bias=False)
        self.W_v = nn.Linear(d_model, kv_dim, bias=False)

    def forward(self, x):
        """
        Project ``x`` and split the feature axis into heads.

        Args:
            x: 3-D tensor whose last dimension is ``cfg.d_model``; the first
                two dimensions are treated as (batch, seq).

        Returns:
            Tuple ``(Q, K, V)`` with shapes ``(B, S, n_heads, head_dim)``,
            ``(B, S, kv_heads, head_dim)`` and ``(B, S, kv_heads, head_dim)``.
        """
        batch, seq, _ = x.shape
        cfg = self.cfg
        hdim = self.head_dim

        # projections
        Q = self.W_q(x)                                # (B, S, n_heads * hdim)
        K = self.W_k(x)                                # (B, S, kv_heads * hdim)
        V = self.W_v(x)                                # (B, S, kv_heads * hdim)

        # reshape into heads
        Q = Q.view(batch, seq, cfg.n_heads, hdim)      # (B, S, n_heads, hdim)
        K = K.view(batch, seq, cfg.kv_heads, hdim)     # (B, S, kv_heads, hdim)
        V = V.view(batch, seq, cfg.kv_heads, hdim)     # (B, S, kv_heads, hdim)

        return Q, K, V


In [5]:
# Smoke test: build the projection from the default config and print the
# per-head shapes it produces for a small random batch.
cfg = ModelConfig()
proj = QKVProjection(cfg)

# Dummy input: 2 sequences of 5 positions, each with d_model features.
x = torch.randn((2, 5, cfg.d_model))

Q, K, V = proj(x)

# Q is split into n_heads heads; K and V into kv_heads heads.
print("Q shape:", Q.size())
print("K shape:", K.size())
print("V shape:", V.size())


Q shape: torch.Size([2, 5, 16, 64])
K shape: torch.Size([2, 5, 4, 64])
V shape: torch.Size([2, 5, 4, 64])
