This notebook uses the Llama self-attention module as an example.

### Original self-attention

Let's denote:

- $B$ is the batch size
- $S$ is the sequence length
- $D$ is the hidden size
- $X \in \mathbb{R}^{B \times S \times D}$ is the input

Let's denote the parameters:
- $W_Q \in \mathbb{R}^{D \times (H d_h)}$ - query projection
- $W_K \in \mathbb{R}^{D \times (H d_h)}$ - key projection
- $W_V \in \mathbb{R}^{D \times (H d_h)}$ - value projection
- $W_O \in \mathbb{R}^{(H d_h) \times D}$ - output projection
- where $H$ is the number of (query) heads and $d_h$ is the dimension of an individual head
- $H_k$ is the number of key/value heads

Now let's write down the query / key / value projections:

- $Q = X W_Q, \ Q \in \mathbb{R}^{B \times S \times (H d_h)}$
- $K = X W_K, \ K \in \mathbb{R}^{B \times S \times (H d_h)}$
- $V = X W_V, \ V \in \mathbb{R}^{B \times S \times (H d_h)}$

Now let's reshape and transpose them for multi-head attention:
- $Q_{transposed} = Q.\mathrm{view}(B, S, H, d_h).\mathrm{transpose}(1, 2), \ Q_{transposed} \in \mathbb{R}^{B \times H \times S \times d_h}$
- $K_{transposed} = K.\mathrm{view}(B, S, H, d_h).\mathrm{transpose}(1, 2), \ K_{transposed} \in \mathbb{R}^{B \times H \times S \times d_h}$
- $V_{transposed} = V.\mathrm{view}(B, S, H, d_h).\mathrm{transpose}(1, 2), \ V_{transposed} \in \mathbb{R}^{B \times H \times S \times d_h}$
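A minimal sketch of the projection and head split with arbitrary toy sizes (the sizes are assumptions, chosen so that $H d_h = D$ as in the notebook's tests). It checks that head $h$ at position $s$ is simply the $h$-th slice of width $d_h$ of the projected vector at position $s$:

```python
import torch

# Toy sizes (arbitrary); as in the notebook's tests, H * d_h == D.
B, S, H, d_h = 2, 5, 4, 8
D = H * d_h

X = torch.randn(B, S, D)
W_Q = torch.randn(D, H * d_h)

Q = X @ W_Q                                          # (B, S, H * d_h)
Q_transposed = Q.view(B, S, H, d_h).transpose(1, 2)  # (B, H, S, d_h)

# Head h at position s is the h-th slice of width d_h of Q[:, s]:
h, s = 1, 3
print(Q_transposed.shape)  # torch.Size([2, 4, 5, 8])
print(torch.equal(Q_transposed[:, h, s], Q[:, s, h * d_h:(h + 1) * d_h]))  # True
```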

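Now let's write down RoPE (rotary position embeddings); products are elementwise, and $\cos(\theta)$, $\sin(\theta)$ are broadcast over the batch and head axes:

- $\operatorname{rotateHalf}(X) = \operatorname{concat}(-X_{:, :, :, d_h/2:},\ X_{:, :, :, :d_h/2})$ along the last axis
- $Q_{rotated} = Q_{transposed} \odot \cos(\theta) + \operatorname{rotateHalf}(Q_{transposed}) \odot \sin(\theta)$
- $K_{rotated} = K_{transposed} \odot \cos(\theta) + \operatorname{rotateHalf}(K_{transposed}) \odot \sin(\theta)$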

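```python
import torch


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):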
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```
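A quick sanity check of the RoPE formula, assuming the usual frequencies $10000^{-2i/d_h}$ duplicated across both halves of the head dimension: each pair $(x_i, x_{i + d_h/2})$ is rotated by a position-dependent angle, so norms are preserved and the query-key dot product depends only on the position offset. `rope_cos_sin`, `rope` and `score` are hypothetical helper names for this sketch:

```python
import torch


def rotate_half(x):
    d = x.shape[-1]
    return torch.cat((-x[..., d // 2:], x[..., :d // 2]), dim=-1)


def rope_cos_sin(positions, dim, base=10000.0):
    """Hypothetical helper: standard RoPE angles base^(-2i/dim), duplicated over both halves."""
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)  # (dim/2,)
    angles = positions.double()[:, None] * inv_freq[None, :]                 # (S, dim/2)
    angles = torch.cat([angles, angles], dim=-1)                             # (S, dim)
    return angles.cos(), angles.sin()


def rope(x, cos, sin):
    # Elementwise form of the formula above: x * cos + rotateHalf(x) * sin
    return x * cos + rotate_half(x) * sin


torch.manual_seed(0)
d_h = 8
q = torch.randn(d_h, dtype=torch.float64)
k = torch.randn(d_h, dtype=torch.float64)


def score(m, n):
    """<RoPE(q) at position m, RoPE(k) at position n>."""
    cos, sin = rope_cos_sin(torch.tensor([m, n]), d_h)
    return torch.dot(rope(q, cos[0], sin[0]), rope(k, cos[1], sin[1]))


# Same offset m - n = 2 at different absolute positions gives the same score.
print(score(3, 1).item(), score(10, 8).item())

# RoPE is a rotation, so the norm is unchanged at every position.
cos, sin = rope_cos_sin(torch.arange(5), d_h)
print(rope(q, cos, sin).norm(dim=-1), q.norm())
```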


Assuming $H = H_k$, no key/value heads need to be repeated, so attention directly uses $Q_{rotated}$ and $K_{rotated}$.
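For reference, when $H_k < H$ (grouped-query attention) each key/value head is shared by $H / H_k$ query heads and has to be repeated before the scores are computed. A minimal sketch modeled on the `repeat_kv` helper in Hugging Face's Llama implementation (the upstream code may differ in details):

```python
import torch


def repeat_kv(kv, n_rep):
    """(B, H_k, S, d_h) -> (B, H_k * n_rep, S, d_h): each key/value head is repeated n_rep times."""
    B, H_k, S, d_h = kv.shape
    if n_rep == 1:
        return kv  # H == H_k: the case assumed in this notebook
    kv = kv[:, :, None, :, :].expand(B, H_k, n_rep, S, d_h)
    return kv.reshape(B, H_k * n_rep, S, d_h)


K = torch.randn(2, 2, 5, 4)    # H_k = 2 key/value heads
K_rep = repeat_kv(K, n_rep=3)  # H = 6 query heads
print(K_rep.shape)             # torch.Size([2, 6, 5, 4])
print(torch.equal(K_rep, K.repeat_interleave(3, dim=1)))  # True
```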

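Now the attention scores:

- $A = \frac{Q_{rotated} K_{rotated}^T}{\sqrt{d_h}} + M$, where $M$ is a causal mask with large negative values at masked (future) positions
- Then softmax and dropout:

  $A_{postprocessed} = \operatorname{dropout}(\operatorname{softmax}(A, \text{dim}=-1), p)$, where dropout can be written as an elementwise multiplication by a random keep-mask.

  So $A_{postprocessed} = \operatorname{softmax}(A, \text{dim}=-1) \odot (U > p)$ with $U \sim \mathcal{U}(0, 1)^{B \times H \times S \times S}$, i.e. each entry is kept with probability $1 - p$. (PyTorch's dropout additionally rescales the kept entries by $1/(1-p)$; that factor is omitted here and in the code below.)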
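The dropout-as-mask formulation can be checked against `torch.nn.functional.dropout`. This sketch uses a random row-stochastic matrix as a stand-in for $\operatorname{softmax}(A)$ (an assumption for illustration). It shows that the explicit mask keeps roughly a $1 - p$ fraction of entries and that PyTorch's dropout returns either $0$ or the entry scaled by $1/(1-p)$:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, p = 2, 4, 64, 0.3
P = torch.softmax(torch.randn(B, H, S, S), dim=-1)   # stand-in for softmax(A, dim=-1)

# Explicit keep-mask: each entry survives with probability 1 - p.
keep = (torch.rand(B, H, S, S) > p).float()
A_post = P * keep                        # form used in this notebook (no rescaling)
A_post_inverted = P * keep / (1 - p)     # "inverted" dropout, as nn.Dropout does in training

# F.dropout draws its own mask, but every output entry is 0 or P / (1 - p).
D_out = F.dropout(P, p=p, training=True)
zero_or_scaled = (D_out == 0) | torch.isclose(D_out, P / (1 - p))
print(zero_or_scaled.all().item(), keep.mean().item())
```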

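Weighted sum:

- $O = A_{postprocessed} V_{transposed}, \ O \in \mathbb{R}^{B \times H \times S \times d_h}$, where $V_{transposed}$ is $X W_V$ split into heads
- $O_{reshaped} = O.\mathrm{transpose}(1, 2).\mathrm{reshape}(B, S, D)$ (the head axis is moved back next to the feature axis before the heads are merged)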

Then the final output is

$Y = O_{reshaped} W_O$
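The whole original forward pass (without RoPE and dropout, to keep it short) can be sketched and compared with PyTorch's `torch.nn.functional.scaled_dot_product_attention` (available from PyTorch 2.0). `softmax_attention_forward` is a hypothetical name; scaling the weights by $1/\sqrt{D}$ only keeps the toy values well conditioned:

```python
import torch
import torch.nn.functional as F


def softmax_attention_forward(X, W_Q, W_K, W_V, W_O, H):
    """Original causal softmax attention (no RoPE, no dropout); assumes H * d_h == D."""
    B, S, D = X.shape
    d_h = D // H

    def split_heads(T):  # (B, S, D) -> (B, H, S, d_h)
        return T.view(B, S, H, d_h).transpose(1, 2)

    Q, K, V = split_heads(X @ W_Q), split_heads(X @ W_K), split_heads(X @ W_V)
    A = Q @ K.transpose(-1, -2) / d_h ** 0.5
    causal = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
    A = A.masked_fill(causal, torch.finfo(A.dtype).min)  # large negative values = mask M
    O = torch.softmax(A, dim=-1) @ V                     # (B, H, S, d_h)
    O_reshaped = O.transpose(1, 2).reshape(B, S, D)      # merge heads back
    return O_reshaped @ W_O, O, (Q, K, V)


torch.manual_seed(0)
B, S, H, d_h = 2, 6, 4, 8
D = H * d_h
X = torch.randn(B, S, D, dtype=torch.float64)
W_Q, W_K, W_V, W_O = (torch.randn(D, D, dtype=torch.float64) / D ** 0.5 for _ in range(4))

Y, O, (Q, K, V) = softmax_attention_forward(X, W_Q, W_K, W_V, W_O, H)

# Reference path through PyTorch's fused attention.
O_ref = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
Y_ref = O_ref.transpose(1, 2).reshape(B, S, D) @ W_O
print(torch.allclose(Y, Y_ref))                                                 # True
print(torch.allclose(O.reshape(B, S, D), O.transpose(1, 2).reshape(B, S, D)))  # False
```

The second comparison shows why the transpose is needed: reshaping $O$ directly to $(B, S, D)$ interleaves heads and positions instead of concatenating the heads at each position.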

### Attention simplification idea

We can (formally incorrectly; I am only checking whether the idea is a good approximation) treat $A$ as linear attention, i.e. replace the definitions above by the following (no softmax, no causal mask):

- $A = \frac{Q_{rotated} K_{rotated}^T}{\sqrt{d_h}}$
- $A_{postprocessed} = A \odot (U > p)$, with $U \sim \mathcal{U}(0, 1)^{B \times H \times S \times S}$
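Dropping the softmax makes the product associative, which is what makes linear attention cheap: without dropout, $(Q K^T) V = Q (K^T V)$, and the right-hand side never materializes an $S \times S$ matrix. An elementwise $S \times S$ dropout mask breaks that factorization, so with dropout the $S \times S$ matrix $A_{postprocessed}$ has to exist anyway. A sketch with arbitrary toy sizes:

```python
import torch

torch.manual_seed(0)
B, H, S, d_h, p = 2, 4, 32, 8, 0.2
Q, K, V = (torch.randn(B, H, S, d_h, dtype=torch.float64) for _ in range(3))
scale = d_h ** -0.5

# Without dropout the simplified attention is associative:
# (Q K^T) V costs O(S^2 d_h), Q (K^T V) costs O(S d_h^2) and never builds an S x S matrix.
O_quadratic = (Q @ K.transpose(-1, -2) * scale) @ V
O_linear = Q @ (K.transpose(-1, -2) @ V) * scale
print(torch.allclose(O_quadratic, O_linear))  # True

# An elementwise S x S dropout mask breaks the factorization.
keep = (torch.rand(B, H, S, S) > p).to(torch.float64)
O_dropout = (Q @ K.transpose(-1, -2) * scale * keep) @ V
print(torch.allclose(O_dropout, O_linear))    # False
```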

Now assume that, after the attention computation, we keep only two matrices (the query / key / value projections and other intermediates are still available; it is only the attention mechanism itself that we do not want to recompute):
- $A_{postprocessed}$
- $O$

Assume also that this simplified attention mechanism was used during the forward pass.

Assume the loss function is $\mathcal{L}(Y, X) = \sum |Y_{:, :-1, :} - X_{:, 1:, :}|$, the L1 distance to the next-position input summed over all elements ($Y$ and $X$ both have shape $(B, S, D)$).
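Since $\frac{d}{dz}|z| = \operatorname{sign}(z)$ (PyTorch also uses $0$ at $z = 0$), the loss gradient with respect to $Y$ is the sign of the residual on positions $0, \dots, S-2$ and zero on the last position, which never enters the loss. A quick autograd check with random tensors:

```python
import torch

torch.manual_seed(0)
B, S, D = 2, 5, 3
X = torch.randn(B, S, D)
Y = torch.randn(B, S, D, requires_grad=True)

loss = torch.abs(Y[:, :-1, :] - X[:, 1:, :]).sum()
loss.backward()

# Closed form: sign of the residual for positions 0..S-2, zero for the last position.
dLoss_dY = torch.zeros_like(Y)
dLoss_dY[:, :-1, :] = torch.sign(Y[:, :-1, :] - X[:, 1:, :]).detach()
print(torch.equal(Y.grad, dLoss_dY))  # True
```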

We need the weight gradients ($W_Q$, $W_K$, $W_V$, $W_O$) computed by reusing the inputs, intermediates, $A_{postprocessed}$, $O$, and outputs.
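Writing $G_Z = \partial \mathcal{L} / \partial Z$ and $M_{drop} = (U > p)$ for the dropout keep-mask (notation introduced only for this derivation), the chain rule through the simplified forward pass gives the backward pass that `llamalike_linear_attention_gradients` implements. It only needs stored tensors ($X$, $O_{reshaped}$, $A_{postprocessed}$, $M_{drop}$, $V_{transposed}$, $Q_{rotated}$, $K_{rotated}$, $\cos\theta$, $\sin\theta$ and the weights), never a recomputation of the attention scores:

$$
\begin{aligned}
G_Y[:, s, :] &= \operatorname{sign}(Y_{:, s, :} - X_{:, s+1, :}) \ \text{for } s < S - 1, \qquad G_Y[:, S-1, :] = 0 \\
\nabla_{W_O} &= \textstyle\sum_b O_{reshaped}^{(b)\top} G_Y^{(b)}, \qquad G_{O_{reshaped}} = G_Y W_O^\top \ \xrightarrow{\text{split heads}}\ G_O \in \mathbb{R}^{B \times H \times S \times d_h} \\
G_{V_{transposed}} &= A_{postprocessed}^\top G_O, \qquad G_A = \left(G_O V_{transposed}^\top\right) \odot M_{drop} \\
G_{Q_{rotated}} &= \frac{G_A K_{rotated}}{\sqrt{d_h}}, \qquad G_{K_{rotated}} = \frac{G_A^\top Q_{rotated}}{\sqrt{d_h}} \\
G_{Q_{transposed}} &= G_{Q_{rotated}} \odot \cos(\theta) - \operatorname{rotateHalf}(G_{Q_{rotated}}) \odot \sin(\theta) \quad (\text{likewise for } K) \\
\nabla_{W_Q} &= \textstyle\sum_b X^{(b)\top} G_Q^{(b)}, \quad \nabla_{W_K} = \textstyle\sum_b X^{(b)\top} G_K^{(b)}, \quad \nabla_{W_V} = \textstyle\sum_b X^{(b)\top} G_V^{(b)}
\end{aligned}
$$

Here $G_Q$, $G_K$, $G_V$ are the per-head gradients merged back to shape $(B, S, D)$, and all matrix products act on the last two axes for each batch element and head. The minus sign in the RoPE line holds because $\operatorname{rotateHalf}^\top = -\operatorname{rotateHalf}$ and $\cos(\theta)$, $\sin(\theta)$ have identical first and second halves.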

In [1]:
import torch
import dataclasses


def rotate_half(x):
    d = x.shape[-1]
    return torch.cat((-x[..., d//2:], x[..., :d//2]), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    q_rotated = q * cos + rotate_half(q) * sin
    k_rotated = k * cos + rotate_half(k) * sin
    return q_rotated, k_rotated

def get_rotary_embeddings(seq_len: int, dim: int, theta: float = 10000.0):
    position = torch.arange(seq_len, dtype=torch.float32)
    freqs = theta ** (-torch.arange(0, dim, 2).float() / dim)  # theta^(-2i/dim) for i = 0..dim/2-1
    angles = position.unsqueeze(1) * freqs.unsqueeze(0)  # (S, dim/2)
    
    cos = torch.cos(angles)  # (S, dim/2)
    sin = torch.sin(angles)  # (S, dim/2)
    
    # Expand to match the shape of Q/K (B, H, S, d_h)
    cos = cos.view(1, 1, seq_len, -1)  # (1, 1, S, d_h/2)
    sin = sin.view(1, 1, seq_len, -1)
    
    # Duplicate the frequencies so both halves match the split used by rotate_half
    cos = torch.cat([cos, cos], dim=-1)
    sin = torch.cat([sin, sin], dim=-1)
    return cos, sin


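def apply_rotary_pos_emb_backward(grad_rotated, cos, sin):
    # Adjoint of apply_rotary_pos_emb (same for Q and K): rotate_half^T = -rotate_half,
    # and cos/sin have identical halves, so the sin term flips sign.
    grad_transposed = grad_rotated * cos - rotate_half(grad_rotated) * sin
    return grad_transposed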

@dataclasses.dataclass
class LlamaLikeLinearAttentionState:
    Y: torch.Tensor
    O_reshaped: torch.Tensor
    A_postprocessed: torch.Tensor
    V_transposed: torch.Tensor
    K_rotated: torch.Tensor
    Q_rotated: torch.Tensor
    cos: torch.Tensor
    sin: torch.Tensor
    W_O: torch.Tensor
    W_Q: torch.Tensor
    W_K: torch.Tensor
    W_V: torch.Tensor
    dropout_mask: torch.Tensor


def llamalike_linear_attention_forward(X, W_Q, W_K, W_V, W_O, H, cos, sin, dropout_mask):
    B, S, D = X.shape
    d_h = D // H
    assert D == H * d_h, f"Hidden size {D} must be divisible by {H} heads"

    # Projections
    Q = X @ W_Q  # (B, S, D)
    K = X @ W_K
    V = X @ W_V

    # Reshape for multi-head attention
    Q_transposed = Q.view(B, S, H, d_h).transpose(1, 2)  # (B, H, S, d_h)
    K_transposed = K.view(B, S, H, d_h).transpose(1, 2)
    V_transposed = V.view(B, S, H, d_h).transpose(1, 2)

    # Apply RoPE
    Q_rotated, K_rotated = apply_rotary_pos_emb(Q_transposed, K_transposed, cos, sin)

    # Attention scores
    A = (Q_rotated @ K_rotated.transpose(-1, -2)) / (d_h**0.5)
    A_postprocessed = A * dropout_mask  # Apply dropout

    # Output computation
    O = A_postprocessed @ V_transposed  # (B, H, S, d_h)
    O_reshaped = O.transpose(1, 2).reshape(B, S, D)  # Merge heads: (B, S, H * d_h) == (B, S, D)
    Y = O_reshaped @ W_O

    return LlamaLikeLinearAttentionState(
        Y=Y,
        O_reshaped=O_reshaped,
        A_postprocessed=A_postprocessed,
        V_transposed=V_transposed,
        K_rotated=K_rotated,
        Q_rotated=Q_rotated,
        cos=cos,
        sin=sin,
        W_O=W_O,
        W_Q=W_Q,
        W_K=W_K,
        W_V=W_V,
        dropout_mask=dropout_mask
    )


def llamalike_linear_attention_gradients(state, X, d_h):
    B, S, D = X.shape
    H = D // d_h
    
    # Gradient through output projection
    dLoss_dY = torch.sign(state.Y[:, :-1, :] - X[:, 1:, :])
    dLoss_dO_reshaped = torch.zeros_like(state.Y)
    dLoss_dO_reshaped[:, :-1, :] = dLoss_dY @ state.W_O.T
    
    # Gradient through attention output
    dLoss_dO = dLoss_dO_reshaped.view(B, S, H, d_h).transpose(1, 2)
    dLoss_dV_transposed = state.A_postprocessed.transpose(-1, -2) @ dLoss_dO
    dLoss_dV = dLoss_dV_transposed.transpose(1, 2).reshape(B, S, D)
    
    # Gradients for value projection (sum over batch)
    manual_grad_W_V = (X.transpose(1, 2) @ dLoss_dV).sum(dim=0)  # Sum batch

    # Backprop through attention scores
    dLoss_dA = (dLoss_dO @ state.V_transposed.transpose(-1, -2)) * state.dropout_mask
    
    # Gradients for query and key projections
    dLoss_dQ_rotated = (dLoss_dA @ state.K_rotated) / (d_h**0.5)
    # dK = dA^T Q / sqrt(d_h): note the transpose of the last two dims of dLoss_dA
    dLoss_dK_rotated = (dLoss_dA.transpose(-1, -2) @ state.Q_rotated) / (d_h**0.5)
    
    # Backprop through RoPE
    dLoss_dQ_transposed = dLoss_dQ_rotated * state.cos - rotate_half(dLoss_dQ_rotated) * state.sin
    dLoss_dK_transposed = dLoss_dK_rotated * state.cos - rotate_half(dLoss_dK_rotated) * state.sin
    
    # Reshape gradients
    dLoss_dQ = dLoss_dQ_transposed.transpose(1, 2).reshape(B, S, D)
    dLoss_dK = dLoss_dK_transposed.transpose(1, 2).reshape(B, S, D)
    
    # Sum gradients over batch using einsum for clarity
    manual_grad_W_Q = torch.einsum('bsd,bse->de', X, dLoss_dQ)
    manual_grad_W_K = torch.einsum('bsd,bse->de', X, dLoss_dK)
    manual_grad_W_O = (state.O_reshaped[:, :-1].transpose(1, 2) @ dLoss_dY).sum(dim=0)
    
    return manual_grad_W_Q, manual_grad_W_K, manual_grad_W_V, manual_grad_W_O
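The sign of the RoPE backward step is easy to get wrong. Because RoPE is linear in its input, a dot-product test checks it cheaply: for a linear map $f$ with adjoint $f^\top$, $\langle f(x), g \rangle = \langle x, f^\top(g) \rangle$. The sketch below uses hypothetical names `rope_forward` / `rope_backward` mirroring `apply_rotary_pos_emb` and the RoPE lines of `llamalike_linear_attention_gradients`, with random angles duplicated across both halves as `get_rotary_embeddings` does:

```python
import torch


def rotate_half(x):
    d = x.shape[-1]
    return torch.cat((-x[..., d // 2:], x[..., :d // 2]), dim=-1)


def rope_forward(x, cos, sin):
    return x * cos + rotate_half(x) * sin


def rope_backward(g, cos, sin):
    # Adjoint of rope_forward; valid when cos/sin have identical first and second halves.
    return g * cos - rotate_half(g) * sin


torch.manual_seed(0)
B, H, S, d_h = 2, 3, 6, 8
half = torch.arange(S, dtype=torch.float64)[:, None] * torch.rand(d_h // 2, dtype=torch.float64)
angles = torch.cat([half, half], dim=-1)            # (S, d_h), duplicated halves
cos, sin = angles.cos(), angles.sin()

x = torch.randn(B, H, S, d_h, dtype=torch.float64, requires_grad=True)
g = torch.randn(B, H, S, d_h, dtype=torch.float64)  # stand-in upstream gradient

# Dot-product test: <f(x), g> must equal <x, f^T(g)>.
lhs = (rope_forward(x, cos, sin) * g).sum()
rhs = (x * rope_backward(g, cos, sin)).sum()

# Autograd's vector-Jacobian product of rope_forward with g.
(autograd_grad,) = torch.autograd.grad(lhs, x)

print(torch.allclose(lhs, rhs))                                    # True
print(torch.allclose(autograd_grad, rope_backward(g, cos, sin)))   # True
print(torch.allclose(autograd_grad, rope_forward(g, cos, sin)))    # False: the "+" variant
```

The last comparison confirms that reusing the forward formula (with `+`) does not give the gradient.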

In [2]:
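def compare_normed_grads(manual_grad, torch_grad, max_norm=1.0, atol=1e-3):
    """Compare gradient directions: both tensors are rescaled to unit L2 norm first.

    `max_norm` is not used; it is kept so the existing call sites still work.
    """
    # Normalize the manual gradient if it's not zero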
    manual_grad_norm = torch.norm(manual_grad)
    if manual_grad_norm > 0:
        manual_grad = manual_grad / manual_grad_norm
    
    # Normalize the torch gradient if it's not zero
    torch_grad_norm = torch.norm(torch_grad) 
    if torch_grad_norm > 0:
        torch_grad = torch_grad / torch_grad_norm
    # Compare normalized gradients
    return torch.allclose(manual_grad, torch_grad, atol=atol)
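`compare_normed_grads` compares gradient directions only, so it cannot detect a constant-factor error such as a misplaced $1/\sqrt{d_h}$. The sketch below restates the helper (without the unused `max_norm` argument) to illustrate this; it is why the mean absolute difference printed next to it is the stricter check:

```python
import torch


def compare_normed_grads(manual_grad, torch_grad, atol=1e-3):
    # Restated helper: compare the two gradients after scaling each to unit L2 norm.
    if torch.norm(manual_grad) > 0:
        manual_grad = manual_grad / torch.norm(manual_grad)
    if torch.norm(torch_grad) > 0:
        torch_grad = torch_grad / torch.norm(torch_grad)
    return torch.allclose(manual_grad, torch_grad, atol=atol)


torch.manual_seed(0)
g = torch.randn(16, 16)
g_scaled = g / 8 ** 0.5   # off by a constant factor (here 1 / sqrt(8))

print(compare_normed_grads(g_scaled, g))      # True: passes despite the scale error
print((g_scaled - g).abs().mean().item())     # clearly non-zero
print(compare_normed_grads(-g, g))            # False: a sign error is still caught
```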


In [3]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Inputs
for S in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
    for B in [1, 2, 4]:
        for D, d_h in [(8, 4), (16, 4), (32, 8), (64, 8), (128, 8), (256, 16)]:
            for dropout_p in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]:
                H = D // d_h
                X = torch.randn(B, S, D)
                Y_target = X[:, 1:, :]

                # Parameters (D x D, since H * d_h == D)
                W_Q = torch.randn(D, D, requires_grad=True)
                W_K = torch.randn(D, D, requires_grad=True)
                W_V = torch.randn(D, D, requires_grad=True)
                W_O = torch.randn(D, D, requires_grad=True)

                # Rotary embeddings for per-head dimension
                cos, sin = get_rotary_embeddings(seq_len=S, dim=d_h, theta=0.1)

                # Dropout keep-mask of shape (B, H, S, S); each entry is kept with probability 1 - dropout_p
                dropout_mask = (torch.rand(B, H, S, S) > dropout_p).float()

                state = llamalike_linear_attention_forward(X, W_Q, W_K, W_V, W_O, H, cos, sin, dropout_mask)
                grads = llamalike_linear_attention_gradients(state, X, d_h)
                manual_grad_W_Q, manual_grad_W_K, manual_grad_W_V, manual_grad_W_O = grads

                loss = torch.abs(state.Y[:, :-1, :] - X[:, 1:, :]).sum()
                loss.backward()

                print("B, S, D", B, S, D)
                print("d_h", d_h)
                print("dropout_p", dropout_p)
                print("W_O", compare_normed_grads(manual_grad_W_O, W_O.grad, max_norm=1.0, atol=1e-3), (manual_grad_W_O - W_O.grad).abs().mean())
                print("W_Q", compare_normed_grads(manual_grad_W_Q, W_Q.grad, max_norm=1.0, atol=1e-3), (manual_grad_W_Q - W_Q.grad).abs().mean())
                print("W_V", compare_normed_grads(manual_grad_W_V, W_V.grad, max_norm=1.0, atol=1e-3), (manual_grad_W_V - W_V.grad).abs().mean())
                print("W_K", compare_normed_grads(manual_grad_W_K, W_K.grad, max_norm=1.0, atol=1e-3), (manual_grad_W_K - W_K.grad).abs().mean())

B, S, D 1 1 8
d_h 4
dropout_p 0.0
W_O True tensor(0., grad_fn=<MeanBackward0>)
W_Q True tensor(0., grad_fn=<MeanBackward0>)
W_V True tensor(0., grad_fn=<MeanBackward0>)
W_K True tensor(0., grad_fn=<MeanBackward0>)
B, S, D 1 1 8
d_h 4
dropout_p 0.1
W_O True tensor(0., grad_fn=<MeanBackward0>)
W_Q True tensor(0., grad_fn=<MeanBackward0>)
W_V True tensor(0., grad_fn=<MeanBackward0>)
W_K True tensor(0., grad_fn=<MeanBackward0>)
B, S, D 1 1 8
d_h 4
dropout_p 0.2
W_O True tensor(0., grad_fn=<MeanBackward0>)
W_Q True tensor(0., grad_fn=<MeanBackward0>)
W_V True tensor(0., grad_fn=<MeanBackward0>)
W_K True tensor(0., grad_fn=<MeanBackward0>)
B, S, D 1 1 8
d_h 4
dropout_p 0.3
W_O True tensor(0., grad_fn=<MeanBackward0>)
W_Q True tensor(0., grad_fn=<MeanBackward0>)
W_V True tensor(0., grad_fn=<MeanBackward0>)
W_K True tensor(0., grad_fn=<MeanBackward0>)
B, S, D 1 1 8
d_h 4
dropout_p 0.4
W_O True tensor(0., grad_fn=<MeanBackward0>)
W_Q True tensor(0., grad_fn=<MeanBackward0>)
W_V True tensor(0., 