# Self-Attention Layer
Forward and backward passes, written out by hand to show the basics.

In [None]:
import torch
import torch.nn as nn
import math
# Seed PyTorch's global RNG so the torch.randn weight initialisation below is reproducible.
torch.manual_seed(0)

class SelfAttention(nn.Module):
    r"""
    Single-head self-attention with a hand-written backward pass.

    Attention = Softmax(Q @ K^T / \sqrt{d_k}) @ V, where Q = X @ Wq, K = X @ Wk,
    V = X @ Wv and d_k = D (the model width).

    The projection matrices are plain tensors rather than ``nn.Parameter``s, so
    they are not registered with the module and autograd does not track them.
    ``backward`` computes their gradients manually and stores them in
    ``self.grads``.
    """

    def __init__(self, D):
        super().__init__()

        # (D, D) projection matrices drawn from a standard normal.
        self.Wq_DD = torch.randn(D, D)
        self.Wk_DD = torch.randn(D, D)
        self.Wv_DD = torch.randn(D, D)

        # Activations saved by forward() and consumed by backward().
        self.cache = None
        # Weight gradients written by backward(), keyed by attribute name.
        self.grads = {}

    def forward(self, X_BND):
        """
        Minimal self-attention implementation to focus on gradient computation.

        Excludes support for a final projection, residual, layernorm, masking, padding, KV cache.

        Args:
            X_BND: input of shape (B, N, D); D must match the layer width.

        Returns:
            Y_BND: attention output of shape (B, N, D).
        """
        # Unpacking enforces a rank-3 input; only D is needed below.
        B, N, D = X_BND.shape

        Q_BND = torch.einsum("bnd,de->bne", X_BND, self.Wq_DD)
        K_BND = torch.einsum("bnd,de->bne", X_BND, self.Wk_DD)
        V_BND = torch.einsum("bnd,de->bne", X_BND, self.Wv_DD)

        # logits[b, n, m] = <q_n, k_m> / sqrt(D); softmax normalises over keys m.
        logits_BNM = torch.einsum("bnd,bmd->bnm", Q_BND, K_BND)
        logits_BNM = logits_BNM / math.sqrt(D)
        weights_BNM = torch.softmax(logits_BNM, dim=-1)

        Y_BND = torch.einsum("bnm,bmd->bnd", weights_BNM, V_BND)

        # Everything backward() needs to reach X and all three projections.
        self.cache = (X_BND, Q_BND, K_BND, V_BND, weights_BNM)

        return Y_BND

    def backward(self, dY_BND):
        r"""
        Backpropagate dL/dY through the layer.

        Stores dL/dWq, dL/dWk, dL/dWv in ``self.grads`` (keys "Wq_DD", "Wk_DD",
        "Wv_DD") and returns dL/dX with the same shape as the forward input.

        Raises:
            RuntimeError: if called before forward().

        Softmax Derivative Reference: https://medium.com/data-science/derivative-of-the-softmax-function-and-the-categorical-cross-entropy-loss-ffceefc081d1
        softmax: R^N -> R^N is a vector function, so its derivative is a Jacobian
        matrix (the matrix of all first-order partial derivatives).
        Trick: differentiate log(s_i) instead of s_i, which avoids the quotient rule.
        With s = softmax(z):
            d/dz_j log(s_i) = 1/s_i * ds_i/dz_j   =>   ds_i/dz_j = s_i * d/dz_j log(s_i)
            log(s_i) = log(e^{z_i} / \sum_l e^{z_l}) = z_i - log(\sum_l e^{z_l})
            d/dz_j log(s_i) = dz_i/dz_j - 1/\sum_l e^{z_l} * d/dz_j \sum_l e^{z_l}
                            = I(i == j) - e^{z_j} / \sum_l e^{z_l}
                            = I(i == j) - s_j
            ds_i/dz_j = s_i * (I(i == j) - s_j)            <- softmax Jacobian
        Vector-Jacobian product with the upstream gradient g = dL/ds:
            dL/dz_j = \sum_i g_i * s_i * (I(i == j) - s_j) = s_j * (g_j - \sum_i g_i s_i)

        Note: for cross-entropy loss L(y_true, y_pred) = - \sum_{i=1}^C y_true_i * log(y_pred_i),
        the math of dL/dz works out to dL/dz = s - y: the other dependencies cancel out.
        This makes intuitive sense because the loss depends only on the probability of
        the true class going up, so the gradient is simple. The cancellation comes from
        using log(s_i) in the loss: its derivative cancels the s_i factor in the Jacobian.
        """
        if self.cache is None:
            raise RuntimeError("backward() called before forward()")
        X_BND, Q_BND, K_BND, V_BND, weights_BNM = self.cache
        D = X_BND.shape[-1]

        # Y = weights @ V
        dV_BND = torch.einsum("bnd,bnm->bmd", dY_BND, weights_BNM)
        dweights_BNM = torch.einsum("bnd,bmd->bnm", dY_BND, V_BND)

        # Row-wise softmax VJP over the key axis m: s * (g - <g, s>).
        dot_BN1 = (dweights_BNM * weights_BNM).sum(dim=-1, keepdim=True)
        dlogits_BNM = weights_BNM * (dweights_BNM - dot_BN1)

        # logits = (Q @ K^T) / sqrt(D)
        dscores_BNM = dlogits_BNM / math.sqrt(D)
        dQ_BND = torch.einsum("bnm,bmd->bnd", dscores_BNM, K_BND)
        dK_BND = torch.einsum("bnm,bnd->bmd", dscores_BNM, Q_BND)

        # Q/K/V = X @ W: dW = X^T @ dOut summed over batch and sequence.
        self.grads = {
            "Wq_DD": torch.einsum("bnd,bne->de", X_BND, dQ_BND),
            "Wk_DD": torch.einsum("bnd,bne->de", X_BND, dK_BND),
            "Wv_DD": torch.einsum("bnd,bne->de", X_BND, dV_BND),
        }

        # X feeds all three projections, so its gradient is the sum of three paths.
        dX_BND = (
            torch.einsum("bne,de->bnd", dQ_BND, self.Wq_DD)
            + torch.einsum("bne,de->bnd", dK_BND, self.Wk_DD)
            + torch.einsum("bne,de->bnd", dV_BND, self.Wv_DD)
        )
        return dX_BND

    def _test_forward(self):
        """Check the output shape, and that a zero input gives a zero output."""
        B, N = 2, 3
        D = self.Wq_DD.shape[0]

        X_BND = torch.zeros(B, N, D)
        Y_BND = self.forward(X_BND)

        assert Y_BND.shape == (B, N, D)
        # Zero input -> V == 0 -> output is zero whatever the attention weights are.
        assert torch.allclose(Y_BND, torch.zeros(B, N, D))

    def _test_backward(self):
        """Compare forward() and backward() with torch.autograd on a random input."""
        B, N = 2, 3
        D = self.Wq_DD.shape[0]

        X_BND = torch.randn(B, N, D)
        dY_BND = torch.randn(B, N, D)

        Y_BND = self.forward(X_BND)
        dX_BND = self.backward(dY_BND)

        # Independent reference written with matmuls so autograd can differentiate it.
        X_ref = X_BND.clone().requires_grad_(True)
        Wq = self.Wq_DD.clone().requires_grad_(True)
        Wk = self.Wk_DD.clone().requires_grad_(True)
        Wv = self.Wv_DD.clone().requires_grad_(True)
        Q = X_ref @ Wq
        K = X_ref @ Wk
        V = X_ref @ Wv
        weights = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(D), dim=-1)
        Y_ref = weights @ V
        Y_ref.backward(dY_BND)

        # float32 throughout, so allow a small accumulated rounding error.
        tol = {"rtol": 1e-3, "atol": 1e-3}
        assert torch.allclose(Y_BND, Y_ref.detach(), **tol)
        assert torch.allclose(dX_BND, X_ref.grad, **tol)
        assert torch.allclose(self.grads["Wq_DD"], Wq.grad, **tol)
        assert torch.allclose(self.grads["Wk_DD"], Wk.grad, **tol)
        assert torch.allclose(self.grads["Wv_DD"], Wv.grad, **tol)

# Build a small layer and run its built-in self-checks.
B, N, D = 2, 3, 4
sa = SelfAttention(D)
sa._test_forward()
sa._test_backward()