# Self Attention Layer
Forward and backward passes, implemented by hand from first principles

In [3]:
import torch
import torch.nn as nn
import math
torch.manual_seed(0)

class SelfAttention(nn.Module):
    r"""
    Single-head self-attention with a hand-written backward pass.

    Attention = Softmax(Q @ K^T / \sqrt{d_k}) @ V, where Q = X @ W_q,
    K = X @ W_k, V = X @ W_v and d_k = D.

    Shape suffixes on variable names: B = batch, N = query positions,
    M = key positions (M == N for self-attention), D = model width.
    """

    def __init__(self, D):
        super().__init__()

        # Registered as buffers (not Parameters) so nn.Module utilities such as
        # .to(), .double() and state_dict() see them, while autograd stays off:
        # gradients are computed manually in backward().
        # NOTE(review): for unit-variance inputs, unscaled randn weights give
        # logits whose magnitude grows with D, which can saturate the softmax;
        # a 1/sqrt(D) init scale may be preferable. Left as-is so seeded runs
        # reproduce the same weights.
        self.register_buffer("Wq_DD", torch.randn(D, D))
        self.register_buffer("Wk_DD", torch.randn(D, D))
        self.register_buffer("Wv_DD", torch.randn(D, D))

        self.cache = None  # intermediates saved by forward() for backward()
        self.grads = {}    # weight gradients written by backward()

    def forward(self, X_BND):
        """
        Minimal self-attention implementation to focus on gradient computation.

        Excludes support for a final projection, residual, layernorm, masking,
        padding and KV cache.

        Args:
            X_BND: input of shape (B, N, D); D must match the layer width.

        Returns:
            Y_BND: attention output of shape (B, N, D).

        Side effect: stores the intermediates backward() needs in self.cache.
        """
        _, _, D = X_BND.shape  # also enforces a rank-3 input

        Q_BND = torch.einsum("bnd,de->bne", X_BND, self.Wq_DD)
        K_BND = torch.einsum("bnd,de->bne", X_BND, self.Wk_DD)
        V_BND = torch.einsum("bnd,de->bne", X_BND, self.Wv_DD)

        # Scaled dot-product scores: every query n against every key m.
        logits_BNM = torch.einsum("bnd,bmd->bnm", Q_BND, K_BND)
        logits_BNM = logits_BNM / math.sqrt(D)
        weights_BNM = torch.softmax(logits_BNM, dim=-1)  # normalize over keys m

        Y_BND = torch.einsum("bnm,bmd->bnd", weights_BNM, V_BND)

        self.cache = (weights_BNM, V_BND, K_BND, Q_BND, X_BND)

        return Y_BND

    def backward(self, dY_BND):
        r"""
        Backpropagate the upstream gradient dL/dY through the layer.

        Args:
            dY_BND: upstream gradient dL/dY, same shape as forward()'s output (B, N, D).

        Returns:
            dX_BND: gradient with respect to forward()'s input, shape (B, N, D).

        Side effect: stores the weight gradients in self.grads under the keys
        "dWq_DD", "dWk_DD" and "dWv_DD".

        Raises:
            RuntimeError: if forward() has not been called first.

        Softmax derivative
        ------------------
        Reference: https://medium.com/data-science/derivative-of-the-softmax-function-and-the-categorical-cross-entropy-loss-ffceefc081d1

        softmax: R^N -> R^N is a vector function, so its derivative is a
        Jacobian matrix (all first-order partial derivatives). Trick:
        differentiate log(s_i) instead of s_i to avoid the quotient rule.

        Let s = softmax(z). Then
            d/dz_j log(s_i) = 1/s_i * ds_i/dz_j   =>   ds_i/dz_j = s_i * d/dz_j log(s_i)
            log(s_i)        = log(e^{z_i} / \sum_l e^{z_l}) = z_i - log(\sum_l e^{z_l})
            d/dz_j log(s_i) = dz_i/dz_j - 1/(\sum_l e^{z_l}) * d/dz_j \sum_l e^{z_l}
                            = I(i == j) - e^{z_j} / \sum_l e^{z_l}
                            = I(i == j) - s_j
            ds_i/dz_j       = s_i * (I(i == j) - s_j)      (the softmax Jacobian)

        Therefore, by the chain rule,
            dL/dz_j = \sum_i dL/ds_i * ds_i/dz_j
                    = \sum_i dL/ds_i * s_i * (I(i == j) - s_j)
                    = dL/ds_j * s_j * (1 - s_j) - \sum_{i != j} dL/ds_i * s_i * s_j
                      (splitting the i == j term from the i != j terms)
                    = s_j * (dL/ds_j - \sum_i dL/ds_i * s_i)
            dL/dz   = s * (dL/ds - <dL/ds, s>)

        Note: for cross-entropy L(y, s) = -\sum_{i=1}^C y_i * log(s_i) with a
        target y that sums to 1 (e.g. one-hot), dL/ds_i = -y_i / s_i. Its 1/s_i
        cancels the s_i factor of the Jacobian and <dL/ds, s> = -1, so
        dL/dz = s - y: simply predicted minus target probabilities. The log in
        the loss is what makes this cancellation so clean.
        """
        if self.cache is None:
            raise RuntimeError("backward() called before forward()")
        weights_BNM, V_BND, K_BND, Q_BND, X_BND = self.cache
        _, _, D = dY_BND.shape

        # Y = weights @ V  =>  dV = weights^T @ dY,  dweights = dY @ V^T
        dV_BND = torch.einsum("bnd,bnm->bmd", dY_BND, weights_BNM)
        dweights_BNM = torch.einsum("bnd,bmd->bnm", dY_BND, V_BND)

        # Softmax backward, row by row over the key axis m (see docstring):
        # dz = s * (g - <g, s>) with g = dL/ds = dweights_BNM, s = weights_BNM.
        # This is the vector-Jacobian product (VJP) of softmax; the softmax
        # Jacobian is symmetric, so it coincides with the JVP as well.
        row_dot_BN1 = torch.einsum("bnm,bnm->bn", dweights_BNM, weights_BNM).unsqueeze(-1)
        dlogits_BNM = weights_BNM * (dweights_BNM - row_dot_BN1)
        # forward() divided the raw scores by sqrt(D); the chain rule applies the same factor.
        dlogits_BNM = dlogits_BNM * (1.0 / math.sqrt(D))

        # scores = Q @ K^T  =>  dQ = dscores @ K,  dK = dscores^T @ Q
        dQ_BND = torch.einsum("bnm,bmd->bnd", dlogits_BNM, K_BND)
        dK_BND = torch.einsum("bnm,bnd->bmd", dlogits_BNM, Q_BND)

        # P = X @ W_p  =>  dW_p = X^T @ dP (summed over batch and positions),
        # and each projection contributes dP @ W_p^T to dX.
        dWv_DD = torch.einsum("bnd,bne->de", X_BND, dV_BND)
        dX_BND_V = torch.einsum("de,bne->bnd", self.Wv_DD, dV_BND)

        dWk_DD = torch.einsum("bnd,bne->de", X_BND, dK_BND)
        dX_BND_K = torch.einsum("de,bne->bnd", self.Wk_DD, dK_BND)

        dWq_DD = torch.einsum("bnd,bne->de", X_BND, dQ_BND)
        dX_BND_Q = torch.einsum("de,bne->bnd", self.Wq_DD, dQ_BND)

        # X feeds all three projections, so its gradient is the sum of the three paths.
        dX_BND = dX_BND_Q + dX_BND_K + dX_BND_V

        self.grads = {
            "dWq_DD": dWq_DD,
            "dWk_DD": dWk_DD,
            "dWv_DD": dWv_DD,
        }

        return dX_BND

    def _test_forward(self):
        """An all-zero input must give uniform attention weights and an all-zero output."""
        B, N, D = 2, 3, self.Wq_DD.shape[0]

        X_BND = self.Wq_DD.new_zeros(B, N, D)  # match the layer's dtype/device
        Y_BND = self.forward(X_BND)

        assert Y_BND.shape == (B, N, D)
        # X = 0 makes Q = K = V = 0: every logit is 0, so softmax is uniform (1/N)
        # and the output averages all-zero value vectors.
        weights_BNM = self.cache[0]
        assert torch.allclose(weights_BNM, torch.full_like(weights_BNM, 1.0 / N))
        assert torch.allclose(Y_BND, torch.zeros_like(Y_BND))

    def _test_backward(self):
        """Check backward() on the all-zero case and against torch.autograd on random data."""
        B, N, D = 2, 3, self.Wq_DD.shape[0]

        # A zero upstream gradient must produce a zero input gradient.
        X_BND = self.Wq_DD.new_zeros(B, N, D)
        Y_BND = self.forward(X_BND)
        assert Y_BND.shape == (B, N, D)
        dX_BND = self.backward(torch.zeros_like(Y_BND))
        assert dX_BND.shape == (B, N, D)
        assert torch.allclose(dX_BND, torch.zeros_like(dX_BND))

        # Reference gradients from autograd on the same computation. Small inputs
        # keep the softmax away from saturation so float32 tolerances hold.
        X_ref = (0.1 * torch.randn_like(X_BND)).requires_grad_()
        Wq, Wk, Wv = (
            W.detach().clone().requires_grad_()
            for W in (self.Wq_DD, self.Wk_DD, self.Wv_DD)
        )
        with torch.enable_grad():
            Q, K, V = X_ref @ Wq, X_ref @ Wk, X_ref @ Wv
            attn = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(D), dim=-1)
            Y_ref = attn @ V
            dY_BND = torch.randn_like(Y_ref)
            Y_ref.backward(dY_BND)

        Y_BND = self.forward(X_ref.detach())
        dX_BND = self.backward(dY_BND)

        tol = {"rtol": 1e-4, "atol": 1e-5}
        assert torch.allclose(Y_BND, Y_ref.detach(), **tol)
        assert torch.allclose(dX_BND, X_ref.grad, **tol)
        for key, W in (("dWq_DD", Wq), ("dWk_DD", Wk), ("dWv_DD", Wv)):
            assert torch.allclose(self.grads[key], W.grad, **tol)


# Smoke-test shapes: batch size, sequence length, model width.
B = 2
N = 3
D = 4

sa = SelfAttention(D)
# Each built-in self-check raises AssertionError on failure.
sa._test_forward()
sa._test_backward()