This notebook uses the Llama self-attention module as an example.

### Original self-attention

Let's denote:

- $B$ is the batch size
- $S$ is the sequence length
- $D$ is the hidden size
- $X \in \mathbb{R}^{B \times S \times D}$ is the input.

Let's denote the parameters:
- $W_Q \in \mathbb{R}^{D \times (H d_h)}$ - query projection
- $W_K \in \mathbb{R}^{D \times (H d_h)}$ - key projection
- $W_V \in \mathbb{R}^{D \times (H d_h)}$ - value projection
- $W_O \in \mathbb{R}^{(H d_h) \times D}$ - output projection
- where $H$ is the number of heads and $d_h$ is the dimension of an individual head
- also, $H_k$ is the number of key-value heads (with $H_k < H$, $W_K$ and $W_V$ have $H_k d_h$ columns instead; here $H_k = H$ is assumed)

Now let's write down the query / key / value projections:

- $Q = X W_Q, Q \in \mathbb{R}^{B \times S \times (H d_h)}$
- $K = X W_K, K \in \mathbb{R}^{B \times S \times (H d_h)}$
- $V = X W_V, V \in \mathbb{R}^{B \times S \times (H d_h)}$

Now let's split the heads and transpose them for multi-head attention:
- $Q_{transposed}$ = `Q.view(B, S, H, d_h).transpose(1, 2)`, $Q_{transposed} \in \mathbb{R}^{B \times H \times S \times d_h}$
- $K_{transposed}$ = `K.view(B, S, H, d_h).transpose(1, 2)`, $K_{transposed} \in \mathbb{R}^{B \times H \times S \times d_h}$
- $V_{transposed}$ = `V.view(B, S, H, d_h).transpose(1, 2)`, $V_{transposed} \in \mathbb{R}^{B \times H \times S \times d_h}$
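A small shape check helps make the head split concrete. The sketch below uses illustrative toy sizes (not the notebook's) and shows that head $h$ of $Q_{transposed}$ is simply the $h$-th block of $d_h$ columns of $Q$.

```python
import torch

# Toy sizes (illustrative only): D = H * d_h = 8
B, S, H, d_h = 2, 3, 2, 4
D = H * d_h

X = torch.randn(B, S, D)
W_Q = torch.randn(D, H * d_h)

Q = X @ W_Q                                           # (B, S, H*d_h)
Q_transposed = Q.view(B, S, H, d_h).transpose(1, 2)  # (B, H, S, d_h)

print(Q.shape, Q_transposed.shape)
# Head 1 of Q_transposed equals columns d_h:2*d_h of Q
print(torch.equal(Q_transposed[:, 1], Q[:, :, d_h:2 * d_h]))
```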

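Now let's write down RoPE (rotary position embeddings), where $\odot$ is elementwise multiplication and $\theta$ holds the per-position rotation angles:

- $\mathrm{rotateHalf}(X) = \mathrm{concat}(-X_{:, :, :, d_h/2:}, X_{:, :, :, :d_h/2})$ (concatenation along the last dimension)
- $Q_{rotated} = Q_{transposed} \odot \cos(\theta) + \mathrm{rotateHalf}(Q_{transposed}) \odot \sin(\theta)$
- $K_{rotated} = K_{transposed} \odot \cos(\theta) + \mathrm{rotateHalf}(K_{transposed}) \odot \sin(\theta)$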
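The formulas above rotate each pair of coordinates $(x_i, x_{i + d_h/2})$ by its own angle. The sketch below checks three consequences numerically: norms are preserved, the query-key dot product depends only on the relative offset between positions, and applying `rotate_half` twice negates the input. `rope` is a hypothetical single-vector helper using the usual base of 10000; it is not part of the notebook.

```python
import torch


def rotate_half(x):
    # (x1, x2) -> (-x2, x1), splitting the last dimension into two halves
    d = x.shape[-1]
    return torch.cat((-x[..., d // 2:], x[..., :d // 2]), dim=-1)


def rope(x, pos, base=10000.0):
    # Hypothetical helper: RoPE for one vector x at integer position `pos`,
    # with angles pos * base^(-2i/d) duplicated over the two halves
    d = x.shape[-1]
    freqs = base ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)
    angles = pos * freqs
    cos = torch.cat([angles.cos(), angles.cos()])
    sin = torch.cat([angles.sin(), angles.sin()])
    return x * cos + rotate_half(x) * sin


torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)

# Each pair is rotated, so the vector norm is preserved
print(torch.allclose(rope(q, 5).norm(), q.norm()))
# The score depends only on the offset between query and key positions (here -2)
print(torch.allclose(rope(q, 5) @ rope(k, 3), rope(q, 7) @ rope(k, 5)))
# rotate_half applied twice negates the input
print(torch.allclose(rotate_half(rotate_half(q)), -q))
```

The last property is what makes `rotate_half` an antisymmetric linear map, which matters when deriving the RoPE backward pass.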

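```python
import torch


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):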
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```


Assuming $H = H_k$, no key/value head repetition is needed, so attention uses $Q_{rotated}$ and $K_{rotated}$ directly.
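For contrast, when $H_k < H$ (grouped-query attention) each key/value head has to be shared by $H / H_k$ query heads before the scores are computed. The sketch below does this with `torch.repeat_interleave`; `repeat_kv_sketch` is a hypothetical name, and to our understanding it has the same effect as the `repeat_kv` helper in the Transformers Llama code.

```python
import torch


def repeat_kv_sketch(kv, n_rep):
    # Hypothetical helper: (B, H_k, S, d_h) -> (B, H_k * n_rep, S, d_h).
    # Each key/value head is shared by n_rep consecutive query heads.
    if n_rep == 1:
        return kv  # H == H_k: nothing to repeat
    return kv.repeat_interleave(n_rep, dim=1)


B, H, H_k, S, d_h = 1, 4, 2, 3, 2
K_rotated = torch.randn(B, H_k, S, d_h)
K_expanded = repeat_kv_sketch(K_rotated, H // H_k)

print(K_expanded.shape)
# Query heads 0 and 1 use key head 0; query heads 2 and 3 use key head 1
print(torch.equal(K_expanded[:, 1], K_rotated[:, 0]), torch.equal(K_expanded[:, 2], K_rotated[:, 1]))
```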

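Now the attention scores:

- $A = \frac{Q_{rotated} K_{rotated}^T}{\sqrt{d_h}} + M$, where $M$ is a causal mask with large negative values at masked positions
- Then softmax and dropout are applied:

  $A_{postprocessed} = \mathrm{dropout}(\mathrm{softmax}(A, \text{dim}=-1), p)$, where dropout can be written as elementwise multiplication by a random binary mask.

  So $A_{postprocessed} = \mathrm{softmax}(A, \text{dim}=-1) \odot \frac{M_{drop}}{1 - p}$, with $M_{drop} = (R > p)$ and $R \sim \mathcal{U}(0, 1)^{S \times S}$: each entry is kept with probability $1 - p$ and the kept entries are rescaled by $\frac{1}{1 - p}$.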
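The score, mask, softmax, and dropout steps can be sketched directly in PyTorch. Here the causal mask uses `-inf` above the diagonal (a large finite negative value behaves the same after softmax), and dropout is written as an explicit binary mask with the $\frac{1}{1-p}$ rescaling. Sizes are illustrative.

```python
import torch

torch.manual_seed(0)
B, H, S, d_h, p = 1, 2, 4, 8, 0.5
Q_rotated = torch.randn(B, H, S, d_h)
K_rotated = torch.randn(B, H, S, d_h)

# Causal mask M: 0 on and below the diagonal, -inf above it
M = torch.triu(torch.full((S, S), float("-inf")), diagonal=1)
A = Q_rotated @ K_rotated.transpose(-1, -2) / d_h ** 0.5 + M
P = torch.softmax(A, dim=-1)

# Dropout as a random binary mask: keep with probability 1 - p, rescale by 1 / (1 - p)
M_drop = (torch.rand(B, H, S, S) > p).float()
A_postprocessed = P * M_drop / (1 - p)

print(P.sum(dim=-1))                    # each softmax row sums to 1
print(P.triu(diagonal=1).abs().max())   # future positions get zero weight
```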

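Weighted sum:

- $O = A_{postprocessed} V_{transposed}$, where $V_{transposed}$ is the head-split $V = X W_V$; $O \in \mathbb{R}^{B \times H \times S \times d_h}$
- $O_{reshaped}$ = `O.transpose(1, 2).reshape(B, S, H * d_h)`, which has shape $(B, S, D)$ when $D = H d_h$

The final output is

$Y = O_{reshaped} W_O$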
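To sanity-check the standard (softmax) path end to end, the sketch below computes it by hand and compares the per-head output with PyTorch's `torch.nn.functional.scaled_dot_product_attention` (available since PyTorch 2.0) with `is_causal=True` and no dropout. RoPE is omitted for brevity since it only changes $Q$ and $K$ before the scores. It also shows why the heads must be transposed back before reshaping: a plain `O.reshape(B, S, D)` mixes heads and positions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, H, d_h = 2, 5, 2, 4
D = H * d_h
dt = torch.float64
X = torch.randn(B, S, D, dtype=dt)
W_Q, W_K, W_V = (torch.randn(D, H * d_h, dtype=dt) for _ in range(3))
W_O = torch.randn(H * d_h, D, dtype=dt)


def heads(T):
    # (B, S, H*d_h) -> (B, H, S, d_h)
    return T.view(B, S, H, d_h).transpose(1, 2)


# RoPE skipped here; it would be applied to Q and K before the scores
Q, K, V = heads(X @ W_Q), heads(X @ W_K), heads(X @ W_V)
M = torch.triu(torch.full((S, S), float("-inf"), dtype=dt), diagonal=1)
A = Q @ K.transpose(-1, -2) / d_h ** 0.5 + M
O = torch.softmax(A, dim=-1) @ V                       # (B, H, S, d_h)
O_reshaped = O.transpose(1, 2).reshape(B, S, H * d_h)  # merge heads per position
Y = O_reshaped @ W_O

# Reference: PyTorch's fused attention, no dropout
O_ref = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(O, O_ref))
# Reshaping without the transpose interleaves heads and positions incorrectly
print(torch.equal(O.reshape(B, S, D), O_reshaped))
```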

### Attention simplification idea

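As an approximation (formally incorrect; the point is only to check whether it is a good one), we can treat $A$ as linear attention, dropping the softmax and the causal mask and replacing:

- $A = \frac{Q_{rotated} K_{rotated}^T}{\sqrt{d_h}}$
- $A_{postprocessed} = A \odot M_{drop}$, where $M_{drop} = (R > p)$ is a random binary mask with $R \sim \mathcal{U}(0, 1)^{S \times S}$ (the $\frac{1}{1-p}$ rescaling is omitted here, as in the code below)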
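The reason this is called linear attention: without the softmax and the causal mask, the products are associative, so $(Q K^T) V$ can be computed as $Q (K^T V)$, replacing the $S \times S$ intermediate with a $d_h \times d_h$ one. The sketch below (illustrative sizes, float64) checks this numerically.

```python
import torch

torch.manual_seed(0)
B, H, S, d_h = 1, 2, 6, 4
Q = torch.randn(B, H, S, d_h, dtype=torch.float64)
K = torch.randn(B, H, S, d_h, dtype=torch.float64)
V = torch.randn(B, H, S, d_h, dtype=torch.float64)

# Quadratic order: builds an (S x S) score matrix per head
O_quadratic = (Q @ K.transpose(-1, -2) / d_h ** 0.5) @ V
# Linear order: builds a (d_h x d_h) matrix per head instead
O_linear = Q @ (K.transpose(-1, -2) @ V / d_h ** 0.5)

print(torch.allclose(O_quadratic, O_linear))
```

Note that the elementwise dropout mask $M_{drop}$ acts on the $S \times S$ matrix, so with dropout this reordering is no longer available.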

Now assume that, from the attention computation itself, we keep only two matrices (the query / key / value projections and other intermediates are still available; it is only the attention mechanism that we don't want to recompute):
- $A_{postprocessed}$
- $O$

Also assume this replaced attention mechanism was used during the forward pass.

The loss function is $\mathcal{L}(Y, X) = \sum |Y_{:, :-1, :} - X_{:, 1:, :}|$, with the absolute value taken elementwise and summed over all entries ($Y$ and $X$ are $(B, S, D)$ tensors).

We need the weight gradients ($W_Q$, $W_K$, $W_V$, $W_O$), computed by reusing the inputs, intermediates, $A_{postprocessed}$, $O$, and outputs.
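A sketch of the chain rule through the simplified attention follows. Notation introduced here: $G_Z$ denotes $\partial \mathcal{L} / \partial Z$, $\mathrm{heads}(\cdot)$ is the `view(B, S, H, d_h).transpose(1, 2)` split, $\mathrm{merge}(\cdot)$ is its inverse `transpose(1, 2).reshape(B, S, D)`, and $\sum_b$ sums over the batch. The minus sign in the RoPE step comes from $\mathrm{rotateHalf}$ being a linear map whose transpose is its negation, together with the $\sin(\theta)$ table having identical halves. Only stored forward quantities appear on the right-hand sides.

$$
\begin{aligned}
(G_Y)_{:, :-1, :} &= \operatorname{sign}(Y_{:, :-1, :} - X_{:, 1:, :}), \qquad (G_Y)_{:, -1, :} = 0 \\
\nabla_{W_O} &= \textstyle\sum_b O_{reshaped}^T G_Y, \qquad G_O = \mathrm{heads}(G_Y W_O^T) \\
\nabla_{W_V} &= \textstyle\sum_b X^T \, \mathrm{merge}(A_{postprocessed}^T G_O) \\
G_A &= (G_O V_{transposed}^T) \odot M_{drop} \\
G_{Q_{rotated}} &= \frac{G_A K_{rotated}}{\sqrt{d_h}}, \qquad G_{K_{rotated}} = \frac{G_A^T Q_{rotated}}{\sqrt{d_h}} \\
G_{Q_{transposed}} &= G_{Q_{rotated}} \odot \cos(\theta) - \mathrm{rotateHalf}(G_{Q_{rotated}}) \odot \sin(\theta) \quad \text{(same form for } K\text{)} \\
\nabla_{W_Q} &= \textstyle\sum_b X^T \, \mathrm{merge}(G_{Q_{transposed}}), \qquad \nabla_{W_K} = \textstyle\sum_b X^T \, \mathrm{merge}(G_{K_{transposed}})
\end{aligned}
$$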

In [1]:
import torch
import dataclasses

In [2]:
def rotate_half(x):
    d = x.shape[-1]
    return torch.cat((-x[..., d//2:], x[..., :d//2]), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    q_rotated = q * cos + rotate_half(q) * sin
    k_rotated = k * cos + rotate_half(k) * sin
    return q_rotated, k_rotated

def get_rotary_embeddings(seq_len: int, dim: int, theta: float = 10000.0):
    position = torch.arange(seq_len, dtype=torch.float32)
    freqs = theta ** (-2 * torch.arange(0, dim, 2).float() / dim)
    angles = position.unsqueeze(1) * freqs.unsqueeze(0)  # (S, dim/2)
    
    cos = torch.cos(angles)  # (S, dim/2)
    sin = torch.sin(angles)  # (S, dim/2)
    
    # Expand to match the shape of Q/K (B, H, S, d_h)
    cos = cos.view(1, 1, seq_len, -1)  # (1, 1, S, d_h/2)
    sin = sin.view(1, 1, seq_len, -1)
    
    # Duplicate the half-size tables so cos/sin line up with rotate_half's two halves: (1, 1, S, d_h)
    cos = torch.cat([cos, cos], dim=-1)
    sin = torch.cat([sin, sin], dim=-1)
    return cos, sin

In [3]:
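def apply_rotary_pos_emb_backward(grad_rotated, cos, sin):
    # Gradient through RoPE (same for Q and K): rotate_half is linear with transpose
    # equal to -rotate_half, and sin has identical halves, hence the minus sign
    grad_transposed = grad_rotated * cos - rotate_half(grad_rotated) * sin
    return grad_transposed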
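A quick way to confirm which sign is the correct transpose of RoPE is to compare against autograd. The sketch below redefines `rotate_half`, builds `cos`/`sin` tables with duplicated halves from random angles (an assumption made only for the test), and compares both sign choices with `torch.autograd.grad`. `rope_backward` is an illustrative name.

```python
import torch


def rotate_half(x):
    d = x.shape[-1]
    return torch.cat((-x[..., d // 2:], x[..., :d // 2]), dim=-1)


def rope_backward(grad_rotated, cos, sin):
    # Transpose of x -> x*cos + rotate_half(x)*sin, assuming sin has identical halves
    return grad_rotated * cos - rotate_half(grad_rotated) * sin


torch.manual_seed(0)
B, H, S, d_h = 1, 2, 3, 4
angles = torch.rand(S, d_h // 2, dtype=torch.float64)
cos = torch.cat([angles.cos(), angles.cos()], dim=-1).view(1, 1, S, d_h)
sin = torch.cat([angles.sin(), angles.sin()], dim=-1).view(1, 1, S, d_h)

x = torch.randn(B, H, S, d_h, dtype=torch.float64, requires_grad=True)
g = torch.randn(B, H, S, d_h, dtype=torch.float64)
y = x * cos + rotate_half(x) * sin
(grad_autograd,) = torch.autograd.grad(y, x, grad_outputs=g)

print(torch.allclose(rope_backward(g, cos, sin), grad_autograd))           # minus sign
print(torch.allclose(g * cos + rotate_half(g) * sin, grad_autograd))       # plus sign
```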

In [4]:
@dataclasses.dataclass
class LlamaLikeLinearAttentionState:
    Y: torch.Tensor
    O_reshaped: torch.Tensor
    A_postprocessed: torch.Tensor
    V_transposed: torch.Tensor
    K_rotated: torch.Tensor
    Q_rotated: torch.Tensor
    cos: torch.Tensor
    sin: torch.Tensor
    W_O: torch.Tensor
    W_Q: torch.Tensor
    W_K: torch.Tensor
    W_V: torch.Tensor
    dropout_mask: torch.Tensor


def llamalike_linear_attention_forward(X, W_Q, W_K, W_V, W_O, H, cos, sin, dropout_mask):
    B, S, D = X.shape
    d_h = D // H
    assert D == H * d_h, f"Hidden size {D} must be divisible by {H} heads"

    # Projections
    Q = X @ W_Q  # (B, S, D)
    K = X @ W_K
    V = X @ W_V

    # Reshape for multi-head attention
    Q_transposed = Q.view(B, S, H, d_h).transpose(1, 2)  # (B, H, S, d_h)
    K_transposed = K.view(B, S, H, d_h).transpose(1, 2)
    V_transposed = V.view(B, S, H, d_h).transpose(1, 2)

    # Apply RoPE
    Q_rotated, K_rotated = apply_rotary_pos_emb(Q_transposed, K_transposed, cos, sin)

    # Attention scores
    A = (Q_rotated @ K_rotated.transpose(-1, -2)) / (d_h**0.5)
    A_postprocessed = A * dropout_mask  # Apply dropout

    # Output computation
    O = A_postprocessed @ V_transposed  # (B, H, S, d_h)
    O_reshaped = O.transpose(1, 2).reshape(B, S, D)  # Dynamic reshape
    Y = O_reshaped @ W_O

    return LlamaLikeLinearAttentionState(
        Y=Y,
        O_reshaped=O_reshaped,
        A_postprocessed=A_postprocessed,
        V_transposed=V_transposed,
        K_rotated=K_rotated,
        Q_rotated=Q_rotated,
        cos=cos,
        sin=sin,
        W_O=W_O,
        W_Q=W_Q,
        W_K=W_K,
        W_V=W_V,
        dropout_mask=dropout_mask
    )

In [5]:
def llamalike_linear_attention_gradients(state, X, d_h):
    B, S, D = X.shape
    H = D // d_h
    
    # Gradient through output projection
    dLoss_dY = torch.sign(state.Y[:, :-1, :] - X[:, 1:, :])
    dLoss_dO_reshaped = torch.zeros_like(state.Y)
    dLoss_dO_reshaped[:, :-1, :] = dLoss_dY @ state.W_O.T
    
    # Gradient through attention output
    dLoss_dO = dLoss_dO_reshaped.view(B, S, H, d_h).transpose(1, 2)
    dLoss_dV_transposed = state.A_postprocessed.transpose(-1, -2) @ dLoss_dO
    dLoss_dV = dLoss_dV_transposed.transpose(1, 2).reshape(B, S, D)
    
    # Gradients for value projection (sum over batch)
    manual_grad_W_V = (X.transpose(1, 2) @ dLoss_dV).sum(dim=0)  # Sum batch

    # Backprop through attention scores
    dLoss_dA = (dLoss_dO @ state.V_transposed.transpose(-1, -2)) * state.dropout_mask
    
    # Gradients for query and key projections
    dLoss_dQ_rotated = (dLoss_dA @ state.K_rotated) / (d_h**0.5)
    dLoss_dK_rotated = (dLoss_dA.transpose(-1, -2) @ state.Q_rotated) / (d_h**0.5)
    
    # Backprop through RoPE
    dLoss_dQ_transposed = dLoss_dQ_rotated * state.cos - rotate_half(dLoss_dQ_rotated) * state.sin
    dLoss_dK_transposed = dLoss_dK_rotated * state.cos + rotate_half(dLoss_dK_rotated) * state.sin
    
    # Reshape gradients
    dLoss_dQ = dLoss_dQ_transposed.transpose(1, 2).reshape(B, S, D)
    dLoss_dK = dLoss_dK_transposed.transpose(1, 2).reshape(B, S, D)
    
    # Sum gradients over batch
    manual_grad_W_Q = (X.transpose(1, 2) @ dLoss_dQ).sum(dim=0)  # Sum batch
    manual_grad_W_K = (X.transpose(1, 2) @ dLoss_dK).sum(dim=0)  # Sum batch
    manual_grad_W_O = (state.O_reshaped[:, :-1].transpose(1, 2) @ dLoss_dY).sum(dim=0)  # Sum batch

    return manual_grad_W_Q, manual_grad_W_K, manual_grad_W_V, manual_grad_W_O

In [6]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Inputs
B = 1
S = 2
D = 8  # hidden size
H = 2
d_h = 4  # per-head dimension: D = H * d_h = 2 * 4
dropout_p = 0.5

X = torch.randn(B, S, D)
Y_target = X[:, 1:, :]

# Parameters (D=8)
W_Q = torch.randn(D, D, requires_grad=True)
W_K = torch.randn(D, D, requires_grad=True)
W_V = torch.randn(D, D, requires_grad=True)
W_O = torch.randn(D, D, requires_grad=True)

# Rotary embeddings for per-head dimension
cos, sin = get_rotary_embeddings(seq_len=S, dim=d_h, theta=0.1)

# Dropout mask of shape (B, H, S, S); each entry is kept with probability 1 - dropout_p (no rescaling)
dropout_mask = (torch.rand(B, H, S, S) > dropout_p).float()

In [7]:
state = llamalike_linear_attention_forward(X, W_Q, W_K, W_V, W_O, H, cos, sin, dropout_mask)

In [8]:
grads = llamalike_linear_attention_gradients(state, X, d_h)

In [9]:
manual_grad_W_Q, manual_grad_W_K, manual_grad_W_V, manual_grad_W_O = grads

In [10]:
loss = torch.abs(state.Y[:, :-1, :] - X[:, 1:, :]).sum()
loss.backward()

In [11]:
torch.allclose(manual_grad_W_O, W_O.grad, atol=1e-3)

True

In [12]:
torch.allclose(manual_grad_W_Q, W_Q.grad, atol=1e-3)

True

In [13]:
torch.allclose(manual_grad_W_V, W_V.grad, atol=1e-3)

True

In [14]:
torch.allclose(manual_grad_W_K, W_K.grad, atol=1e-3)

False

In [19]:
W_K.grad

tensor([[-270.9939,   82.4250,  -40.0187, -132.5305,    8.4649,    0.6771,
            5.4119,   -5.7520],
        [-186.2232,   62.7531,  -55.4063, -122.5704,  -18.5556,   -1.4843,
          -11.8633,   12.6088],
        [-127.0943,   38.5447,  -18.2565,  -61.5780,    4.4172,    0.3533,
            2.8240,   -3.0015],
        [ 273.2087,  -89.2001,   68.2043,  165.0573,   15.7969,    1.2636,
           10.0995,  -10.7342],
        [-100.1759,   29.1998,   -8.9966,  -42.4486,    8.1920,    0.6553,
            5.2374,   -5.5665],
        [ 162.9048,  -52.4037,   37.0920,   94.3819,    6.2961,    0.5036,
            4.0253,   -4.2783],
        [  -2.0288,   -1.5369,    9.5352,   10.1083,    8.6529,    0.6921,
            5.5321,   -5.8798],
        [ 227.0748,  -68.6936,   31.8294,  109.1288,   -8.5810,   -0.6864,
           -5.4861,    5.8309]])

In [15]:
manual_grad_W_K

tensor([[-267.5538,   76.0062,  -58.7720, -136.3131,   -8.4437,    5.5276,
            5.4450,   -1.7291],
        [-193.7640,   76.8235,  -14.2982, -114.2786,   18.5091,  -12.1168,
          -11.9357,    3.7904],
        [-125.2992,   35.1952,  -28.0423,  -63.5519,   -4.4061,    2.8844,
            2.8413,   -0.9023],
        [ 279.6284, -101.1786,   33.2077,  157.9982,  -15.7573,   10.3154,
           10.1612,   -3.2269],
        [ -96.8468,   22.9879,  -27.1451,  -46.1092,   -8.1714,    5.3494,
            5.2694,   -1.6734],
        [ 165.4635,  -57.1780,   23.1436,   91.5684,   -6.2803,    4.1114,
            4.0499,   -1.2861],
        [   1.4877,   -8.0983,   -9.6345,    6.2417,   -8.6312,    5.6504,
            5.5659,   -1.7675],
        [ 223.5875,  -62.1868,   50.8397,  112.9633,    8.5594,   -5.6034,
           -5.5196,    1.7528]], grad_fn=<SumBackward1>)

In [17]:
(W_K.grad - manual_grad_W_K).abs()

tensor([[3.4401e+00, 6.4188e+00, 1.8753e+01, 3.7827e+00, 1.6909e+01, 4.8505e+00,
         3.3049e-02, 4.0229e+00],
        [7.5408e+00, 1.4070e+01, 4.1108e+01, 8.2918e+00, 3.7065e+01, 1.0633e+01,
         7.2445e-02, 8.8184e+00],
        [1.7951e+00, 3.3495e+00, 9.7858e+00, 1.9739e+00, 8.8232e+00, 2.5311e+00,
         1.7246e-02, 2.0992e+00],
        [6.4197e+00, 1.1979e+01, 3.4997e+01, 7.0590e+00, 3.1554e+01, 9.0518e+00,
         6.1675e-02, 7.5074e+00],
        [3.3291e+00, 6.2118e+00, 1.8149e+01, 3.6607e+00, 1.6363e+01, 4.6941e+00,
         3.1983e-02, 3.8932e+00],
        [2.5587e+00, 4.7742e+00, 1.3948e+01, 2.8135e+00, 1.2576e+01, 3.6077e+00,
         2.4581e-02, 2.9922e+00],
        [3.5165e+00, 6.5614e+00, 1.9170e+01, 3.8667e+00, 1.7284e+01, 4.9582e+00,
         3.3782e-02, 4.1122e+00],
        [3.4872e+00, 6.5068e+00, 1.9010e+01, 3.8345e+00, 1.7140e+01, 4.9170e+00,
         3.3502e-02, 4.0780e+00]], grad_fn=<AbsBackward0>)

In [18]:
dropout_mask

tensor([[[[1., 1.],
          [1., 0.]],

         [[0., 1.],
          [1., 0.]]]])
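The mismatch is confined to $W_K$. In the gradient cell (`In [5]`) the RoPE backward for Q is `dLoss_dQ_rotated * state.cos - rotate_half(dLoss_dQ_rotated) * state.sin`, while the K line uses `+ rotate_half(dLoss_dK_rotated) * state.sin`. Since `rotate_half` is a linear map whose transpose is its negation, and the sin table has identical halves, the transpose of RoPE is $g \odot \cos(\theta) - \mathrm{rotateHalf}(g) \odot \sin(\theta)$ for both Q and K, so the `+` in the K line appears to be the cause.

The sketch below is a compact, self-contained re-implementation (not the notebook's functions) of the same linear-attention forward pass and manual backward, using the minus sign for both Q and K, and it compares all four weight gradients with autograd. It assumes float64, `S = 4`, and the usual RoPE base of 10000 instead of the notebook's `theta=0.1`; `heads`, `merge`, `rope`, and `rope_t` are illustrative helper names.

```python
import torch


def rotate_half(x):
    d = x.shape[-1]
    return torch.cat((-x[..., d // 2:], x[..., :d // 2]), dim=-1)


torch.manual_seed(42)
B, S, H, d_h = 1, 4, 2, 4
D = H * d_h
X = torch.randn(B, S, D, dtype=torch.float64)
W = {n: torch.randn(D, D, dtype=torch.float64, requires_grad=True) for n in "QKVO"}
angles = torch.arange(S, dtype=torch.float64)[:, None] * 10000.0 ** (
    -torch.arange(0, d_h, 2, dtype=torch.float64) / d_h)
cos = torch.cat([angles.cos()] * 2, dim=-1).view(1, 1, S, d_h)
sin = torch.cat([angles.sin()] * 2, dim=-1).view(1, 1, S, d_h)
mask = (torch.rand(B, H, S, S) > 0.5).to(torch.float64)


def heads(T):
    return T.view(B, S, H, d_h).transpose(1, 2)  # (B, S, D) -> (B, H, S, d_h)


def merge(T):
    return T.transpose(1, 2).reshape(B, S, D)    # (B, H, S, d_h) -> (B, S, D)


def rope(T):
    return T * cos + rotate_half(T) * sin


def rope_t(G):
    # Transpose of rope: minus sign, used for both Q and K
    return G * cos - rotate_half(G) * sin


# Forward pass (linear attention: no softmax, no causal mask), tracked by autograd
Qr, Kr, Vt = rope(heads(X @ W["Q"])), rope(heads(X @ W["K"])), heads(X @ W["V"])
A_post = (Qr @ Kr.transpose(-1, -2)) / d_h ** 0.5 * mask
O_reshaped = merge(A_post @ Vt)
Y = O_reshaped @ W["O"]
torch.abs(Y[:, :-1] - X[:, 1:]).sum().backward()

# Manual backward reusing stored intermediates (attention is not recomputed)
with torch.no_grad():
    G_Y = torch.zeros_like(Y)
    G_Y[:, :-1] = torch.sign(Y[:, :-1] - X[:, 1:])
    G_O = heads(G_Y @ W["O"].T)
    G_A = (G_O @ Vt.transpose(-1, -2)) * mask
    G_Q = merge(rope_t(G_A @ Kr / d_h ** 0.5))
    G_K = merge(rope_t(G_A.transpose(-1, -2) @ Qr / d_h ** 0.5))
    G_V = merge(A_post.transpose(-1, -2) @ G_O)
    manual = {
        "Q": (X.transpose(1, 2) @ G_Q).sum(0),
        "K": (X.transpose(1, 2) @ G_K).sum(0),
        "V": (X.transpose(1, 2) @ G_V).sum(0),
        "O": (O_reshaped.transpose(1, 2) @ G_Y).sum(0),
    }

for name in "QKVO":
    print(name, torch.allclose(manual[name], W[name].grad))
```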