# In [None]:
import numpy as np

def softmax_stable(x, axis=-1):
    """
    Numerically stable softmax of `x` along `axis`.

    The maximum along `axis` is subtracted before exponentiating, so the
    largest exponent is exp(0) = 1 and large inputs cannot overflow; the
    shift cancels in the ratio and leaves the result unchanged.
    Returns an array shaped like `x` whose slices along `axis` sum to 1.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    numer = np.exp(shifted)
    denom = np.sum(numer, axis=axis, keepdims=True)
    return numer / denom

class ScaledDotProductAttention:
    """
    Core attention: softmax( Q K^T / sqrt(d_k) ) V
    Supports optional mask (e.g., causal or padding).

    Parameters
    ----------
    dropout : float
        Probability of zeroing each attention weight, in [0, 1).
        Surviving weights are rescaled by 1 / (1 - dropout).
        0.0 disables dropout. Dropout is applied on every call; there is
        no separate train/eval switch.
    seed : int
        Seed of the private generator used to draw dropout masks.

    Raises
    ------
    ValueError
        If `dropout` is outside [0, 1).
    """
    def __init__(self, dropout=0.0, seed=42):
        # dropout == 1 would make the rescale in __call__ compute 0 / 0 and
        # turn every weight into NaN; negative values would silently
        # disable dropout. Reject both up front.
        if not 0.0 <= dropout < 1.0:
            raise ValueError(
                "dropout must be in [0, 1), got {!r}".format(dropout)
            )
        self.dropout = dropout
        self.rng = np.random.default_rng(seed)

    def __call__(self, Q, K, V, mask=None):
        """
        Q, K, V : arrays of shape (B, H, T, d_k)
        mask    : None or array broadcastable to (B, H, T, T);
                  truthy = keep, falsy = mask out.

        Returns (out, attn): out is (B, H, T, d_k); attn is the
        (B, H, T, T) weight matrix actually used, i.e. after dropout.
        """
        # Q,K,V: (B, H, T, d_k)
        d_k = Q.shape[-1]
        scores = (Q @ np.swapaxes(K, -1, -2)) / np.sqrt(d_k)   # (B,H,T,T)

        if mask is not None:
            # mask: broadcastable to (B,H,T,T). Masked positions get a large
            # finite negative score (-1e9) rather than -inf, so a fully
            # masked row softmaxes to uniform weights instead of NaN.
            scores = np.where(mask, scores, -1e9)

        attn = softmax_stable(scores, axis=-1)                 # (B,H,T,T)

        if self.dropout > 0.0:
            # Inverted dropout: zero each weight with probability `dropout`
            # and rescale the survivors so the expected value is unchanged.
            keep_prob = 1.0 - self.dropout
            drop = self.rng.random(attn.shape) >= self.dropout
            attn = attn * drop / keep_prob

        out = attn @ V                                         # (B,H,T,d_k)
        return out, attn


class MultiHeadSelfAttention:
    """
    NumPy-only multi-head self-attention layer.

    The input is projected to queries, keys and values, split into
    `num_heads` heads of width d_model // num_heads, run through scaled
    dot-product attention, merged back and projected once more.

    Shapes:
      X: (B, T, d_model)
      Output: (B, T, d_model)
    """
    def __init__(self, d_model, num_heads, dropout=0.0, seed=42):
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads."
        self.d_model, self.num_heads = d_model, num_heads
        self.d_k = d_model // num_heads
        self.attn = ScaledDotProductAttention(dropout=dropout, seed=seed)
        self.rng = np.random.default_rng(seed)

        # Weights: uniform in [-sqrt(1/d_model), sqrt(1/d_model)), drawn
        # from self.rng in the order W_q, W_k, W_v, W_o.
        bound = np.sqrt(1.0 / d_model)
        square = (d_model, d_model)
        self.W_q, self.W_k, self.W_v, self.W_o = (
            self.rng.uniform(-bound, bound, size=square) for _ in range(4)
        )

        # Biases start at zero (often omitted in attention); each is its
        # own array.
        self.b_q, self.b_k, self.b_v, self.b_o = (
            np.zeros((d_model,)) for _ in range(4)
        )

    # ----- helpers -----
    def _project(self, X, W, b):
        # Affine map on the last axis: (B,T,d_model) -> (B,T,d_model)
        projected = X @ W
        return projected + b

    def _split_heads(self, X):
        # (B,T,d_model) -> (B,H,T,d_k): cut the feature axis into heads,
        # then move the head axis in front of the time axis.
        batch, steps, _ = X.shape
        per_head = X.reshape(batch, steps, self.num_heads, self.d_k)
        return per_head.transpose(0, 2, 1, 3)

    def _merge_heads(self, X):
        # (B,H,T,d_k) -> (B,T,d_model): inverse of _split_heads.
        batch, heads, steps, width = X.shape
        time_major = X.transpose(0, 2, 1, 3)
        return time_major.reshape(batch, steps, heads * width)

    # ----- forward -----
    def __call__(self, X, mask=None):
        """
        Run self-attention over X.

        X    : (B, T, d_model)
        mask : None or boolean array broadcastable to (B, H, T, T)
               True = keep, False = masked out before the softmax.
               For a causal mask, mask[i,j]=False if j>i.

        Returns (out, attn_weights) with shapes (B, T, d_model) and
        (B, H, T, T).
        """
        _, _, width = X.shape
        assert width == self.d_model, "Bad d_model."

        # 1-2) Linear projection, then split into heads: (B,H,T,d_k)
        Qh = self._split_heads(self._project(X, self.W_q, self.b_q))
        Kh = self._split_heads(self._project(X, self.W_k, self.b_k))
        Vh = self._split_heads(self._project(X, self.W_v, self.b_v))

        # 3) Scaled dot-product attention: (B,H,T,d_k), (B,H,T,T)
        context, attn_weights = self.attn(Qh, Kh, Vh, mask=mask)

        # 4) Merge heads and apply the output projection: (B,T,d_model)
        merged = self._merge_heads(context)
        return self._project(merged, self.W_o, self.b_o), attn_weights

    # ----- masks convenience -----
    @staticmethod
    def causal_mask(B, T, H):
        # Lower-triangular boolean mask broadcast to (B,H,T,T):
        # True where query i may attend to key j, i.e. i >= j.
        lower = np.tril(np.ones((T, T), dtype=bool))
        target_shape = (B, H, T, T)
        return np.broadcast_to(lower, target_shape)

    @staticmethod
    def padding_mask(pad_mask, H):
        """
        Build a key-padding mask.

        pad_mask: (B,T) boolean, True for real tokens, False for pads.
        Returns a mask broadcast to (B,H,T,T) that blocks attending to
        pad keys; queries are never blocked.
        """
        batch, steps = pad_mask.shape
        # Add singleton head and query axes so only the key axis varies.
        keys_only = pad_mask[:, np.newaxis, np.newaxis, :]   # (B,1,1,T)
        return np.broadcast_to(keys_only, (batch, H, steps, steps))

# In [None]:
# Demo: B=2 batches, T=4 tokens, d_model=8, H=2 heads
np.random.seed(0)
batch_size, seq_len, model_dim, n_heads = 2, 4, 8, 2
X = np.random.randn(batch_size, seq_len, model_dim)

mha = MultiHeadSelfAttention(model_dim, n_heads, dropout=0.0)
# causal example: each token attends only to itself and earlier tokens
mask = MultiHeadSelfAttention.causal_mask(batch_size, seq_len, n_heads)
Y, A = mha(X, mask=mask)
print("Output shape:", Y.shape)        # (2,4,8)
print("Attention shape:", A.shape)     # (2,2,4,4)