This notebook uses the Llama self-attention module as its running example.

### Original self-attention

Let's denote:

- $B$ is the batch size
- $S$ is the sequence length
- $D$ is the hidden size

and let $X \in \mathbb{R}^{B \times S \times D}$ be the input.

Let's denote the parameters:
- $W_Q \in \mathbb{R}^{D \times (H d_h)}$ - query projection
- $W_K \in \mathbb{R}^{D \times (H d_h)}$ - key projection
- $W_V \in \mathbb{R}^{D \times (H d_h)}$ - value projection
- $W_O \in \mathbb{R}^{(H d_h) \times D}$ - output projection
- where $H$ is the head count and $d_h$ is the individual head dimension
- also, $H_k$ is the number of key-value heads

Now let's write down the query / key / value projections:

- $Q = X W_Q, Q \in \mathbb{R}^{B \times S \times (H d_h)}$
- $K = X W_K, K \in \mathbb{R}^{B \times S \times (H d_h)}$
- $V = X W_V, V \in \mathbb{R}^{B \times S \times (H d_h)}$

Now let's split them into heads and transpose for multi-head attention:
- $Q_{transposed} = Q.view(B, S, H, d_h).transpose(1, 2), Q_{transposed} \in \mathbb{R}^{B \times H \times S \times d_h}$
- $K_{transposed} = K.view(B, S, H, d_h).transpose(1, 2), K_{transposed} \in \mathbb{R}^{B \times H \times S \times d_h}$
- $V_{transposed} = V.view(B, S, H, d_h).transpose(1, 2), V_{transposed} \in \mathbb{R}^{B \times H \times S \times d_h}$
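
The projections and the head split above can be sketched in PyTorch as follows. The sizes are toy values chosen for illustration (not a real Llama configuration), and `split_heads` is a hypothetical helper name.

```python
import torch

torch.manual_seed(0)
B, S, H, d_h = 2, 5, 4, 8      # toy sizes (assumed for illustration)
D = H * d_h                    # hidden size

X = torch.randn(B, S, D)
W_Q = torch.randn(D, H * d_h)
W_K = torch.randn(D, H * d_h)
W_V = torch.randn(D, H * d_h)

def split_heads(t, n_heads, head_dim):
    # (B, S, n_heads * head_dim) -> (B, n_heads, S, head_dim)
    b, s, _ = t.shape
    return t.view(b, s, n_heads, head_dim).transpose(1, 2)

Q_transposed = split_heads(X @ W_Q, H, d_h)
K_transposed = split_heads(X @ W_K, H, d_h)
V_transposed = split_heads(X @ W_V, H, d_h)
print(Q_transposed.shape, K_transposed.shape, V_transposed.shape)
```

Head `h` simply reads columns `h * d_h : (h + 1) * d_h` of the projected tensor, so the split moves no data until a later operation needs a contiguous layout.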

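Now let's write down RoPE (rotary position embeddings), where the two halves in $rotateHalf$ are concatenated along the last dimension and the products with $\cos(\theta)$ and $\sin(\theta)$ are elementwise:

- $rotateHalf(X) = (-X_{:, :, :, d_h/2:}, X_{:, :, :, :d_h/2})$
- $Q_{rotated} = Q_{transposed} \cos(\theta) + rotateHalf(Q_{transposed}) \sin(\theta)$
- $K_{rotated} = K_{transposed} \cos(\theta) + rotateHalf(K_{transposed}) \sin(\theta)$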

```python
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```
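
To make the rotation more concrete, here is a minimal sketch showing two properties of RoPE: the norm of a vector is preserved, and the query-key dot product depends only on the position offset. The helpers `rope_tables`, `rope` and `score` are hypothetical names, and the frequency schedule with base 10000 is an assumption following the usual RoPE convention.

```python
import torch

def rotate_half(x):
    # (x1, x2) -> (-x2, x1), splitting the last dimension in half
    d = x.shape[-1]
    return torch.cat((-x[..., d // 2:], x[..., : d // 2]), dim=-1)

def rope_tables(position, dim, base=10000.0):
    # cos/sin vectors of length dim for one integer position (angles duplicated across halves)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    angles = position * inv_freq
    angles = torch.cat((angles, angles), dim=-1)
    return angles.cos(), angles.sin()

def rope(x, position):
    cos, sin = rope_tables(position, x.shape[-1])
    return x * cos + rotate_half(x) * sin

torch.manual_seed(0)
d_h = 8
q = torch.randn(d_h, dtype=torch.float64)
k = torch.randn(d_h, dtype=torch.float64)

def score(m, n):
    # Dot product between q placed at position m and k placed at position n
    return torch.dot(rope(q, m), rope(k, n))

same_offset = torch.allclose(score(3, 1), score(10, 8))   # both pairs have m - n = 2
norm_kept = torch.allclose(rope(q, 7).norm(), q.norm())
print(same_offset, norm_kept)
```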


Assuming $H = H_k$, no key/value head repetition is needed, so attention uses $Q_{rotated}$ and $K_{rotated}$ directly.
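
For reference, when $H_k < H$ (grouped-query attention) each key/value head is shared by $H / H_k$ query heads. A minimal sketch of that repetition step, assuming the heads live on dimension 1 and using toy sizes; it is not needed for the rest of this notebook:

```python
import torch

def repeat_kv(x, n_rep):
    # (B, H_k, S, d_h) -> (B, H_k * n_rep, S, d_h); each KV head is copied n_rep times in a row
    return torch.repeat_interleave(x, repeats=n_rep, dim=1)

torch.manual_seed(0)
B, H, H_k, S, d_h = 1, 4, 2, 3, 2   # toy sizes (assumed for illustration)
K = torch.randn(B, H_k, S, d_h)
K_rep = repeat_kv(K, H // H_k)
print(K_rep.shape)
```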

Now the attention scores:

- $A = { {Q_{rotated} K_{rotated}^T} \over {\sqrt{d_h}} } + M$, where $M$ is a causal mask with large negative values at masked positions
- Now applying softmax and dropout:

  $ A_{postprocessed} = dropout(softmax(A, dim=-1), p)$, where dropout can be written as elementwise multiplication by a random binary matrix.

  So $ A_{postprocessed} = softmax(A, dim=-1) * (random(S, S) > p)$, where each entry is dropped with probability $p$ (the usual $1 / (1 - p)$ rescaling is omitted here).
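
The score, mask, softmax and dropout steps can be sketched as below. The sizes and the dropout rate are toy assumptions, the random tensors stand in for the real rotated queries and keys, and the $1 / (1 - p)$ rescaling of standard inverted dropout is included for completeness.

```python
import torch

torch.manual_seed(0)
B, H, S, d_h, p = 1, 2, 4, 8, 0.1   # toy sizes and dropout rate (assumed)
Q_rotated = torch.randn(B, H, S, d_h)
K_rotated = torch.randn(B, H, S, d_h)

# Causal mask M: 0 on and below the diagonal, a large negative value above it
M = torch.full((S, S), torch.finfo(Q_rotated.dtype).min).triu(diagonal=1)

A = Q_rotated @ K_rotated.transpose(-1, -2) / d_h**0.5 + M   # (B, H, S, S)
P = torch.softmax(A, dim=-1)                                 # each row sums to 1

# Dropout as an elementwise mask: an entry is kept with probability 1 - p
keep = (torch.rand(B, H, S, S) > p).float()
A_postprocessed = P * keep / (1 - p)
print(P[0, 0])
```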

Weighted sum:

- $O = A_{postprocessed} V_{transposed}$, where $V_{transposed}$ is the head-split form of $X W_V$
- $O_{reshaped} = O.transpose(1, 2).reshape(B, S, D)$

Then the final output is

$Y = O_{reshaped} W_O$
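
As a sanity check on the formulas, the causal softmax attention (without dropout) can be compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, which is available from PyTorch 2.0. This is a sketch with toy sizes; the random tensors stand in for $Q_{rotated}$, $K_{rotated}$ and $V_{transposed}$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, d_h = 2, 4, 6, 8       # toy sizes (assumed for illustration)
D = H * d_h
Q = torch.randn(B, H, S, d_h)   # stand-ins for Q_rotated, K_rotated, V_transposed
K = torch.randn(B, H, S, d_h)
V = torch.randn(B, H, S, d_h)
W_O = torch.randn(H * d_h, D)

# Manual attention with a causal mask (True = position may be attended to)
causal = torch.ones(S, S, dtype=torch.bool).tril()
A = (Q @ K.transpose(-1, -2) / d_h**0.5).masked_fill(~causal, float("-inf"))
O = torch.softmax(A, dim=-1) @ V                        # (B, H, S, d_h)

# Merge heads, then apply the output projection
O_reshaped = O.transpose(1, 2).reshape(B, S, H * d_h)   # (B, S, D)
Y = O_reshaped @ W_O

O_ref = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(O, O_ref, atol=1e-4), Y.shape)
```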

### Attention simplification idea

However, we can treat $A$ as linear attention (this is formally incorrect; the goal is only to check whether the idea gives a good approximation), replacing the softmax version with:

- $A = { {Q_{rotated} K_{rotated}^T} \over {\sqrt{d_h}}}$
- $A_{postprocessed} = A * (random(S, S) > p)$

Now assume that the only results we keep from the attention computation are two matrices (the query / key / value projections and other internal values remain available; it is only the attention mechanism itself that we do not want to recompute):
- $A_{postprocessed}$
- $O$

Assume this replaced attention mechanism was used during the forward pass.

Assume the loss function is $Loss(Y, X) = \sum |Y_{:, :-1, :} - X_{:, 1:, :}|$, the sum of elementwise absolute differences (with $Y$ and $X$ both of shape $(B, S, D)$).

We need the weight gradients ($W_K$, $W_Q$, $W_V$, $W_O$) to be computed by reusing the inputs, the intermediates, $A_{postprocessed}$, $O$ and the outputs.
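
The backward pass through this linearized attention can be written out step by step. The following is a sketch of the derivation under the assumptions above. The symbols $G$ (the loss gradient with respect to $Y$), $R$ (the binary dropout mask) and the shorthand $\partial Z$ for "gradient of the loss with respect to $Z$" are notation introduced here for illustration. The attention steps apply per batch element and per head, $\partial Q$, $\partial K$, $\partial V$ denote the head gradients merged back to shape $(B, S, D)$, and the weight gradients are summed over the batch.

$$
\begin{aligned}
G_{:, :-1, :} &= \operatorname{sign}(Y_{:, :-1, :} - X_{:, 1:, :}), \qquad G_{:, -1, :} = 0 \\
\partial W_O &= O_{reshaped}^T \, G, \qquad \partial O_{reshaped} = G \, W_O^T \\
\partial V_{transposed} &= A_{postprocessed}^T \, \partial O \\
\partial A &= (\partial O \, V_{transposed}^T) * R \\
\partial Q_{rotated} &= \frac{\partial A \, K_{rotated}}{\sqrt{d_h}}, \qquad \partial K_{rotated} = \frac{\partial A^T \, Q_{rotated}}{\sqrt{d_h}} \\
\partial Q_{transposed} &= \partial Q_{rotated} \cos(\theta) - rotateHalf(\partial Q_{rotated}) \sin(\theta) \\
\partial K_{transposed} &= \partial K_{rotated} \cos(\theta) - rotateHalf(\partial K_{rotated}) \sin(\theta) \\
\partial W_Q &= X^T \, \partial Q, \qquad \partial W_K = X^T \, \partial K, \qquad \partial W_V = X^T \, \partial V
\end{aligned}
$$

The minus sign in the RoPE step comes from the transpose of $rotateHalf$, which equals $-rotateHalf$, together with $\cos(\theta)$ and $\sin(\theta)$ taking the same values in both halves of the head dimension.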

In [1]:
import torch
import dataclasses

In [2]:
def rotate_half(x):
    d = x.shape[-1]
    return torch.cat((-x[..., d//2:], x[..., :d//2]), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    q_rotated = q * cos + rotate_half(q) * sin
    k_rotated = k * cos + rotate_half(k) * sin
    return q_rotated, k_rotated

def get_rotary_embeddings(seq_len: int, dim: int, theta: float = 10000.0):
    position = torch.arange(seq_len, dtype=torch.float32)
    # arange(0, dim, 2) already equals 2i, so the standard RoPE exponent is -2i / dim
    freqs = theta ** (-torch.arange(0, dim, 2).float() / dim)
    angles = position.unsqueeze(1) * freqs.unsqueeze(0)  # (S, dim/2)
    
    cos = torch.cos(angles)  # (S, dim/2)
    sin = torch.sin(angles)  # (S, dim/2)
    
    # Expand to match the shape of Q/K (B, H, S, d_h)
    cos = cos.view(1, 1, seq_len, -1)  # (1, 1, S, d_h/2)
    sin = sin.view(1, 1, seq_len, -1)
    
    # Duplicate along the last dimension so both halves used by rotate_half share the same angles
    cos = torch.cat([cos, cos], dim=-1)
    sin = torch.cat([sin, sin], dim=-1)
    return cos, sin

In [3]:
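def apply_rotary_pos_emb_backward(grad_rotated, cos, sin):
    # Gradient through RoPE (same rule for Q and K): the transpose of rotate_half
    # is -rotate_half, so the backward pass rotates by the opposite angle
    grad_transposed = grad_rotated * cos - rotate_half(grad_rotated) * sin
    return grad_transposed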
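
The sign of the RoPE backward rule is easy to get wrong, so here is a small autograd check. It is a sketch that assumes the cos/sin tables repeat the same angles across the two halves of the last dimension, as `get_rotary_embeddings` produces; `rope_backward` is a hypothetical name, and the random angles are only there to make every sine nonzero.

```python
import torch

def rotate_half(x):
    d = x.shape[-1]
    return torch.cat((-x[..., d // 2:], x[..., : d // 2]), dim=-1)

def rope_backward(grad_out, cos, sin):
    # Vector-Jacobian product of x -> x * cos + rotate_half(x) * sin
    return grad_out * cos - rotate_half(grad_out) * sin

torch.manual_seed(0)
S, d_h = 5, 8
angles = torch.randn(S, d_h // 2, dtype=torch.float64)
angles = torch.cat((angles, angles), dim=-1)   # duplicated halves
cos, sin = angles.cos(), angles.sin()

x = torch.randn(S, d_h, dtype=torch.float64, requires_grad=True)
grad_out = torch.randn(S, d_h, dtype=torch.float64)

# Forward RoPE, then ask autograd for the gradient with respect to x
y = x * cos + rotate_half(x) * sin
(autograd_grad,) = torch.autograd.grad(y, x, grad_outputs=grad_out)

print(torch.allclose(autograd_grad, rope_backward(grad_out, cos, sin)))
```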

In [4]:
@dataclasses.dataclass
class LlamaLikeLinearAttentionState:
    Y: torch.Tensor
    O_reshaped: torch.Tensor
    A_postprocessed: torch.Tensor
    V_transposed: torch.Tensor
    K_rotated: torch.Tensor
    Q_rotated: torch.Tensor
    cos: torch.Tensor
    sin: torch.Tensor


def llamalike_linear_attention_forward(X, W_Q, W_K, W_V, W_O, H, cos, sin, dropout_mask):
    B, S, D = X.shape
    d_h = D // H

    # Forward pass
    Q = torch.matmul(X, W_Q)  # (B, S, H*d_h)
    K = torch.matmul(X, W_K)
    V = torch.matmul(X, W_V)

    Q_transposed = Q.view(B, S, H, d_h).transpose(1, 2)  # (B, H, S, d_h)
    K_transposed = K.view(B, S, H, d_h).transpose(1, 2)
    V_transposed = V.view(B, S, H, d_h).transpose(1, 2)

    # Apply RoPE with precomputed cos/sin
    Q_rotated, K_rotated = apply_rotary_pos_emb(Q_transposed, K_transposed, cos, sin)

    # Linear attention scores: no softmax and no causal mask, scaled by sqrt(d_h)
    # (B, H, S, d_h) @ (B, H, d_h, S) -> (B, H, S, S)
    A = torch.matmul(Q_rotated, K_rotated.transpose(-1, -2)) / (d_h**0.5)
    A_postprocessed = A * dropout_mask  # Apply dropout mask

    O = torch.matmul(A_postprocessed, V_transposed) # (B, H, S, S) * (B, H, S, d_h) -> (B, H, S, d_h)
    O_reshaped = O.transpose(1, 2).reshape(B, S, D) # (B, H, S, d_h) -> (B, S, H, d_h) -> (B, S, D)
    Y = torch.matmul(O_reshaped, W_O) # (B, S, D) * (D, D) -> (B, S, D)

    return LlamaLikeLinearAttentionState(
        Y=Y,
        O_reshaped=O_reshaped,
        A_postprocessed=A_postprocessed,
        V_transposed=V_transposed,
        K_rotated=K_rotated,
        Q_rotated=Q_rotated,
        cos=cos,
        sin=sin
    )

In [5]:
def llamalike_linear_attention_gradients(state, X, d_h):
    # Note: W_O and dropout_mask are read from the notebook's global scope
    B, S, D = X.shape
    H = D // d_h

    dLoss_dY = torch.sign(state.Y[:, :-1, :] - X[:, 1:, :])  # Gradient of L1 loss
    manual_grad_W_O = torch.einsum('bsd,bse->de', state.O_reshaped[:, :-1, :], dLoss_dY)

    # Expand to full sequence length with zeros
    dLoss_dY_full = torch.zeros_like(state.Y)          # Shape: (B, S, D)
    dLoss_dY_full[:, :-1, :] = dLoss_dY                # Fill first S-1 positions
    
    # Gradient for O_reshaped: dLoss_dY_full @ W_O^T
    dLoss_dO_reshaped = torch.matmul(dLoss_dY_full, W_O.T)  # (B, S, D) * (D, D) -> (B, S, D)
    # Reshape and transpose to match O's original shape (B, H, S, d_h)
    dLoss_dO = dLoss_dO_reshaped.view(B, S, H, d_h).transpose(1, 2)  # (B, H, S, d_h)
    
    # Now compute dLoss/dV_transposed
    dLoss_dV_transposed = torch.matmul(state.A_postprocessed.transpose(-1, -2), dLoss_dO)  # (B, H, S, S) * (B, H, S, d_h) -> (B, H, S, d_h)
    dLoss_dV = dLoss_dV_transposed.transpose(1, 2).reshape(B, S, D)  # (B, H, S, d_h) -> (B, S, D)

    # Manual gradient for W_V
    manual_grad_W_V = torch.einsum('bsd,bse->de', X, dLoss_dV)

    dLoss_dA = torch.matmul(dLoss_dO, state.V_transposed.transpose(-1, -2)) * dropout_mask  # (B, H, S, d_h) * (B, H, d_h, S) -> (B, H, S, S)

    # Backprop through A = Q_rotated @ K_rotated^T / sqrt(d_h)
    dLoss_dQ_rotated = torch.matmul(dLoss_dA, state.K_rotated) / (d_h**0.5)  # Apply scaling here
    # Backprop through RoPE for Q (correct sign)
    dLoss_dQ_transposed = dLoss_dQ_rotated * state.cos - rotate_half(dLoss_dQ_rotated) * state.sin

    # Reshape and compute dLoss/dW_Q
    dLoss_dQ = dLoss_dQ_transposed.transpose(1, 2).reshape(B, S, D)
    manual_grad_W_Q = torch.einsum('bsd,bse->de', X, dLoss_dQ)

    # Backprop through A = Q_rotated @ K_rotated^T / sqrt(d_h): dK_rotated = dA^T @ Q_rotated / sqrt(d_h)
    dLoss_dK_rotated = torch.matmul(state.Q_rotated.transpose(-1, -2), dLoss_dA).transpose(-1, -2) / (d_h**0.5)
    # Backprop through RoPE for K: same rule as for Q (transpose of the rotation), so the sign is "-"
    dLoss_dK_transposed = dLoss_dK_rotated * state.cos - rotate_half(dLoss_dK_rotated) * state.sin
    
    # Reshape and compute dLoss/dW_K
    dLoss_dK = dLoss_dK_transposed.transpose(1, 2).reshape(B, S, D)
    manual_grad_W_K = torch.einsum('bsd,bse->de', X, dLoss_dK)
    
    return manual_grad_W_Q, manual_grad_W_K, manual_grad_W_V, manual_grad_W_O

In [6]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Inputs
B = 1
S = 2
D = 2
H = 1
d_h = 2
dropout_p = 0.5

X = torch.randn(B, S, D)  # (B, S, D)
Y_target = X[:, 1:, :]    # Loss compares Y[:, :-1] to X[:, 1:]

# Parameters
W_Q = torch.randn(D, D, requires_grad=True)
W_K = torch.randn(D, D, requires_grad=True)
W_V = torch.randn(D, D, requires_grad=True)
W_O = torch.randn(D, D, requires_grad=True)

# Fixed dropout mask (for reproducibility); it multiplies the (S, S) attention matrix
dropout_mask = (torch.rand(S, S) > dropout_p).float()  # Binary mask

# RoPE tables are per head, so their last dimension is d_h (S=2, d_h=2 here)
cos, sin = get_rotary_embeddings(seq_len=S, dim=d_h, theta=0.1)

In [7]:
state = llamalike_linear_attention_forward(X, W_Q, W_K, W_V, W_O, H, cos, sin, dropout_mask)

In [8]:
grads = llamalike_linear_attention_gradients(state, X, d_h)

In [9]:
manual_grad_W_Q, manual_grad_W_K, manual_grad_W_V, manual_grad_W_O = grads

In [10]:
loss = torch.abs(state.Y[:, :-1, :] - X[:, 1:, :]).sum()
loss.backward()

In [11]:
torch.allclose(manual_grad_W_O, W_O.grad, atol=1e-3)

True

In [12]:
torch.allclose(manual_grad_W_Q, W_Q.grad, atol=1e-3)

True

In [13]:
torch.allclose(manual_grad_W_V, W_V.grad, atol=1e-3)

True

In [14]:
torch.allclose(manual_grad_W_K, W_K.grad, atol=1e-3)

True
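
The checks above use $S = 2$ and $H = 1$, where only position 0 contributes to the loss and the RoPE sine at position 0 is zero, so several terms of the backward pass are barely exercised. The following self-contained sketch repeats the comparison with several heads and positions in float64. The sizes, the base of 10000 and the helper names `split`, `merge_heads` and `rope` are assumptions made for illustration.

```python
import torch

def rotate_half(x):
    d = x.shape[-1]
    return torch.cat((-x[..., d // 2:], x[..., : d // 2]), dim=-1)

def merge_heads(t):
    # (B, H, S, d_h) -> (B, S, H * d_h)
    b, h, s, d = t.shape
    return t.transpose(1, 2).reshape(b, s, h * d)

torch.manual_seed(0)
B, S, H, d_h = 2, 5, 3, 4            # toy sizes (assumed), several heads and positions
D = H * d_h
dt = torch.float64

X = torch.randn(B, S, D, dtype=dt)
W_Q, W_K, W_V, W_O = (torch.randn(D, D, dtype=dt, requires_grad=True) for _ in range(4))
mask = (torch.rand(S, S) > 0.5).to(dt)   # fixed binary dropout mask

# cos/sin tables of shape (S, d_h) with duplicated halves; they broadcast over (B, H, S, d_h)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, d_h, 2, dtype=dt) / d_h))
angles = torch.arange(S, dtype=dt)[:, None] * inv_freq[None, :]
angles = torch.cat((angles, angles), dim=-1)
cos, sin = angles.cos(), angles.sin()

def split(t):
    # (B, S, D) -> (B, H, S, d_h)
    return t.view(B, S, H, d_h).transpose(1, 2)

def rope(t):
    return t * cos + rotate_half(t) * sin

# Forward pass of the linearized attention, differentiated by autograd
Q_rot = rope(split(X @ W_Q))
K_rot = rope(split(X @ W_K))
V = split(X @ W_V)
A_post = (Q_rot @ K_rot.transpose(-1, -2)) / d_h**0.5 * mask
O_reshaped = merge_heads(A_post @ V)
Y = O_reshaped @ W_O
loss = (Y[:, :-1] - X[:, 1:]).abs().sum()
loss.backward()

# Manual backward pass reusing the stored intermediates
with torch.no_grad():
    dY = torch.zeros_like(Y)
    dY[:, :-1] = torch.sign(Y[:, :-1] - X[:, 1:])
    grad_W_O = torch.einsum('bsd,bse->de', O_reshaped, dY)
    dO = split(dY @ W_O.T)
    dV = A_post.transpose(-1, -2) @ dO
    dA = (dO @ V.transpose(-1, -2)) * mask
    dQ_rot = dA @ K_rot / d_h**0.5
    dK_rot = dA.transpose(-1, -2) @ Q_rot / d_h**0.5
    dQ = dQ_rot * cos - rotate_half(dQ_rot) * sin
    dK = dK_rot * cos - rotate_half(dK_rot) * sin
    grad_W_Q = torch.einsum('bsd,bse->de', X, merge_heads(dQ))
    grad_W_K = torch.einsum('bsd,bse->de', X, merge_heads(dK))
    grad_W_V = torch.einsum('bsd,bse->de', X, merge_heads(dV))

pairs = [("W_Q", grad_W_Q, W_Q), ("W_K", grad_W_K, W_K),
         ("W_V", grad_W_V, W_V), ("W_O", grad_W_O, W_O)]
for name, manual, param in pairs:
    print(name, torch.allclose(manual, param.grad))
```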