This notebook uses the Llama self-attention module as an example.

### Original self-attention

Let's denote:

- $B$: batch size
- $S$: sequence length
- $D$: hidden size

The input is $X \in \mathbb{R}^{B \times S \times D}$.

Let's denote the parameters:
- $W_Q \in \mathbb{R}^{D \times (H d_h)}$ - query projection
- $W_K \in \mathbb{R}^{D \times (H d_h)}$ - key projection
- $W_V \in \mathbb{R}^{D \times (H d_h)}$ - value projection
- $W_O \in \mathbb{R}^{(H d_h) \times D}$ - output projection
- where $H$ is the head count and $d_h$ is the individual head dimension
- also, $H_k$ is the number of key-value heads (the key and value shapes above assume $H_k = H$)

Now let's write down the query / key / value projections:

- $Q = X W_Q, Q \in \mathbb{R}^{B \times S \times (H d_h)}$
- $K = X W_K, K \in \mathbb{R}^{B \times S \times (H d_h)}$
- $V = X W_V, V \in \mathbb{R}^{B \times S \times (H d_h)}$

Now let's reshape and transpose them for multi-head attention:
- $Q_{transposed} = Q.view(B, S, H, d_h).transpose(1, 2), Q_{transposed} \in \mathbb{R}^{B \times H \times S \times d_h}$
- $K_{transposed} = K.view(B, S, H, d_h).transpose(1, 2), K_{transposed} \in \mathbb{R}^{B \times H \times S \times d_h}$
- $V_{transposed} = V.view(B, S, H, d_h).transpose(1, 2), V_{transposed} \in \mathbb{R}^{B \times H \times S \times d_h}$
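
As a quick shape check, the projection and head split above can be sketched in PyTorch. The sizes are arbitrary illustrative values (an assumption, not Llama's actual configuration), and $H_k = H$ is assumed.

```python
import torch

torch.manual_seed(0)
B, S, H, d_h = 2, 5, 4, 8   # illustrative sizes (assumption)
D = H * d_h                 # hidden size

X = torch.randn(B, S, D)
W_Q = torch.randn(D, H * d_h)

Q = X @ W_Q                                          # (B, S, H * d_h)
Q_transposed = Q.view(B, S, H, d_h).transpose(1, 2)  # (B, H, S, d_h)

# Head h of token s is the h-th contiguous d_h-sized slice of Q[b, s]
h, s = 1, 3
same = torch.equal(Q_transposed[0, h, s], Q[0, s, h * d_h:(h + 1) * d_h])
print(Q_transposed.shape, same)
```

The same `view` / `transpose` pair applies to $K$ and $V$.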

Now let's write down RoPE (rotary position embeddings), where the two halves in $rotateHalf$ are concatenated along the last dimension:

- $rotateHalf(X) = (-X_{:, :, :, d_h/2:}, X_{:, :, :, :d_h/2})$
- $Q_{rotated} = Q_{transposed} \cos(\theta) + rotateHalf(Q_{transposed}) \sin(\theta)$
- $K_{rotated} = K_{transposed} \cos(\theta) + rotateHalf(K_{transposed}) \sin(\theta)$

```python
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```
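
To illustrate why this rotation is useful, the sketch below checks that the dot product of a rotated query and key depends only on their relative position. The helper `rope`, the base of 10000 and the sizes are assumptions made for this example; they are not taken from the Llama code above.

```python
import torch

def rotate_half(x):
    d = x.shape[-1]
    return torch.cat((-x[..., d // 2:], x[..., :d // 2]), dim=-1)

def rope(x, pos, d_h, base=10000.0):
    # Angles pos * base^(-2i/d_h), duplicated so both halves share the same angle
    inv_freq = base ** (-torch.arange(0, d_h, 2).float() / d_h)
    angles = torch.cat([pos * inv_freq, pos * inv_freq], dim=-1)
    return x * torch.cos(angles) + rotate_half(x) * torch.sin(angles)

torch.manual_seed(0)
d_h = 8
q, k = torch.randn(d_h), torch.randn(d_h)

# Same relative offset (1), different absolute positions
score_a = torch.dot(rope(q, 1.0, d_h), rope(k, 0.0, d_h))
score_b = torch.dot(rope(q, 6.0, d_h), rope(k, 5.0, d_h))
print(torch.allclose(score_a, score_b, atol=1e-4))
```

Each pair of coordinates $(x_i, x_{i + d_h/2})$ is rotated by the same angle, so the vector norm is also preserved.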


Assuming $H = H_k$, no key/value head repetition occurs, so attention uses $Q_{rotated}$ and $K_{rotated}$ directly.

Now the attention scores:

- $A = { {Q_{rotated} K_{rotated}^T} \over {\sqrt{d_h}} } + M$, where $M$ is a causal mask with large negative values at masked positions
- Now apply softmax and dropout:

  $ A_{postprocessed} = dropout(softmax(A, dim=-1), p)$, where dropout can be written as elementwise multiplication by a random binary matrix.

  So $ A_{postprocessed} = softmax(A, dim=-1) * (random(S, S) > p)$ (the usual $1/(1-p)$ rescaling is omitted here).
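
The masking, softmax and dropout steps can be sketched for a single head as follows. The sizes and $p$ are illustrative assumptions. Unlike the formula above, this sketch includes the $1/(1-p)$ rescaling that standard dropout applies during training, and it keeps an element when the random draw exceeds $p$.

```python
import torch

torch.manual_seed(0)
S, p = 4, 0.5                      # illustrative values (assumption)
A = torch.randn(S, S)
causal = torch.triu(torch.full((S, S), float("-inf")), diagonal=1)  # mask M

probs = torch.softmax(A + causal, dim=-1)
keep = (torch.rand(S, S) > p).float()      # 1 = keep, 0 = drop
A_postprocessed = probs * keep / (1 - p)   # with the usual 1/(1-p) rescaling

print(probs.sum(dim=-1))   # each row sums to 1 before dropout
print(A_postprocessed)
```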

Weighted sum:

- $O = A_{postprocessed} V_{transposed}$, where $V_{transposed}$ is the per-head view of $V = X W_V$
- $O_{reshaped} = O.transpose(1, 2).reshape(B, S, D)$

Then the final output is

$Y = O_{reshaped} W_O$

### Attention simplification idea

But we can treat $A$ as linear attention (formally this is incorrect; I am just checking whether the idea is a good approximation), replacing the softmax attention with:

- $A = { {Q_{rotated} K_{rotated}^T} \over {\sqrt{d_h}}}$
- $A_{postprocessed} = A * (random(S, S) > p)$

Now assume that, after the attention computation, we keep only two matrices (the query / key / value projections and other internals are still known; it is only the attention mechanism itself that we don't want to recompute):
- $A_{postprocessed}$
- $O$

Assume this replaced attention mechanism was used during the forward pass.

Assume the loss function is $Loss(Y, X) = \sum |Y_{:, :-1, :} - X_{:, 1:, :}|$ (an elementwise L1 loss, with $Y$ and $X$ both of shape $(B, S, D)$).
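
Under this linear-attention assumption, the backward pass can be written out by hand. The following is a sketch of the chain rule, not a verified derivation for the real softmax attention. The notation $G_Z = \partial Loss / \partial Z$ and the symbol $R$ for the binary dropout mask are introduced here for convenience; reshapes between $(B, S, D)$ and $(B, H, S, d_h)$ and sums over the batch are left implicit:

- $G_Y = sign(Y_{:, :-1, :} - X_{:, 1:, :})$, padded with zeros at the last sequence position
- $G_{W_O} = O_{reshaped}^T G_Y$ and $G_{O_{reshaped}} = G_Y W_O^T$
- $G_{V_{transposed}} = A_{postprocessed}^T G_O$ and $G_A = (G_O V_{transposed}^T) * R$
- $G_{Q_{rotated}} = { {G_A K_{rotated}} \over {\sqrt{d_h}} }$ and $G_{K_{rotated}} = { {G_A^T Q_{rotated}} \over {\sqrt{d_h}} }$
- $G_{Q_{transposed}} = G_{Q_{rotated}} \cos(\theta) - rotateHalf(G_{Q_{rotated}}) \sin(\theta)$, and likewise for $K$
- $G_{W_Q} = X^T G_Q$, $G_{W_K} = X^T G_K$, $G_{W_V} = X^T G_V$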

In [1]:
import torch

# Inputs
X = torch.randn(1, 2, 2)  # (B, S, D)
B, S, D = X.shape          # read the shape only after X is defined
Y_target = X[:, 1:, :]     # Loss compares Y[:, :-1] to X[:, 1:]

# Parameters
W_Q = torch.randn(2, 2, requires_grad=True)
W_K = torch.randn(2, 2, requires_grad=True)
W_V = torch.randn(2, 2, requires_grad=True)
W_O = torch.randn(2, 2, requires_grad=True)

# Fixed dropout mask (for reproducibility)
dropout_mask = (torch.rand(2, 2) > 0.5).float()  # Binary mask

In [2]:
def rotate_half(x):
    d = x.shape[-1]
    return torch.cat((-x[..., d//2:], x[..., :d//2]), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    q_rotated = q * cos + rotate_half(q) * sin
    k_rotated = k * cos + rotate_half(k) * sin
    return q_rotated, k_rotated

In [3]:
def get_rotary_embeddings(seq_len: int, dim: int, theta: float = 10000.0):
    position = torch.arange(seq_len, dtype=torch.float32)
    # Standard RoPE frequencies theta^(-2i/dim): arange(0, dim, 2) already yields 2i,
    # so no extra factor of 2 is needed (for dim=2 the result is unchanged)
    freqs = theta ** (-torch.arange(0, dim, 2).float() / dim)
    angles = position.unsqueeze(1) * freqs.unsqueeze(0)  # (S, dim/2)
    
    cos = torch.cos(angles)  # (S, dim/2)
    sin = torch.sin(angles)  # (S, dim/2)
    
    # Expand to match the shape of Q/K (B, H, S, d_h)
    cos = cos.view(1, 1, seq_len, -1)  # (1, 1, S, d_h/2)
    sin = sin.view(1, 1, seq_len, -1)
    
    # Duplicate along the last dim so both halves used by rotate_half share the same angle
    cos = torch.cat([cos, cos], dim=-1)
    sin = torch.cat([sin, sin], dim=-1)
    return cos, sin

In [4]:
# Example for S=2, d_h=2
cos, sin = get_rotary_embeddings(seq_len=2, dim=2, theta=0.1)

In [5]:
# Forward pass
Q = torch.matmul(X, W_Q)  # (B, S, H*d_h)
K = torch.matmul(X, W_K)
V = torch.matmul(X, W_V)

Q_transposed = Q.view(1, 2, 1, 2).transpose(1, 2)  # (B, H, S, d_h)
K_transposed = K.view(1, 2, 1, 2).transpose(1, 2)
V_transposed = V.view(1, 2, 1, 2).transpose(1, 2)

In [6]:
# Apply RoPE with precomputed cos/sin
Q_rotated, K_rotated = apply_rotary_pos_emb(Q_transposed, K_transposed, cos, sin)

In [7]:
# Rest of forward pass remains the same...
A = (torch.matmul(Q_rotated, K_rotated.transpose(-1, -2)) / (2**0.5))  # Scaled by sqrt(d_h)
A_postprocessed = A * dropout_mask  # Apply dropout mask

In [8]:
O = torch.matmul(A_postprocessed, V_transposed)
O_reshaped = O.transpose(1, 2).reshape(1, 2, 2)
Y = torch.matmul(O_reshaped, W_O)

In [9]:
loss = torch.abs(Y[:, :-1, :] - X[:, 1:, :]).sum()
loss.backward()

# Autograd gradients
autograd_grad_W_O = W_O.grad.clone()
autograd_grad_W_V = W_V.grad.clone()
autograd_grad_W_K = W_K.grad.clone()
autograd_grad_W_Q = W_Q.grad.clone()

In [10]:
# After forward pass:
Q_rotated = Q_rotated.detach()
K_rotated = K_rotated.detach()
V_transposed = V_transposed.detach()
A_postprocessed = A_postprocessed.detach()
dropout_mask = dropout_mask.detach()
cos = cos.detach()
sin = sin.detach()

In [11]:
dLoss_dY = torch.sign(Y[:, :-1, :] - X[:, 1:, :])  # Gradient of L1 loss
manual_grad_W_O = torch.einsum('bsd,bse->de', O_reshaped[:, :-1, :], dLoss_dY)

In [12]:
print("∇W_O match:", torch.allclose(autograd_grad_W_O, manual_grad_W_O, atol=1e-5))

∇W_O match: True


In [44]:
# Original dLoss_dY (gradient from Y[:, :-1, :])
dLoss_dY = torch.sign(Y[:, :-1, :] - X[:, 1:, :])  # Shape: (1, 1, 2)

# Expand to full sequence length with zeros
dLoss_dY_full = torch.zeros_like(Y)                # Shape: (1, 2, 2)
dLoss_dY_full[:, :-1, :] = dLoss_dY                # Fill first S-1 positions

# Gradient for O_reshaped: dLoss_dY_full @ W_O^T
dLoss_dO_reshaped = torch.matmul(dLoss_dY_full, W_O.T)  # Shape: (1, 2, 2)

In [45]:
# Reshape and transpose to match O’s original shape (B, H, S, d_h) = (1, 1, 2, 2)
dLoss_dO = dLoss_dO_reshaped.view(1, 2, 1, 2).transpose(1, 2)  # Shape: (1, 1, 2, 2)

In [46]:
dLoss_dO

tensor([[[[-0.7708,  0.6129],
          [ 0.0000,  0.0000]]]], grad_fn=<TransposeBackward0>)

In [50]:
# Compute dLoss_dY_full (includes zeros for unused positions)
dLoss_dY = torch.sign(Y[:, :-1, :] - X[:, 1:, :])  # (B, S-1, D) = (1, 1, 2)
dLoss_dY_full = torch.zeros(1, 2, 2)               # (B, S, D) = (1, 2, 2)
dLoss_dY_full[:, :-1, :] = dLoss_dY                # Fill valid positions

# Compute dLoss_dO_reshaped = dLoss_dY_full @ W_O^T
dLoss_dO_reshaped = torch.matmul(dLoss_dY_full, W_O.T)  # (1, 2, 2)

# Reshape/transpose to match O’s shape (B, H, S, d_h) = (1, 1, 2, 2)
dLoss_dO = dLoss_dO_reshaped.view(1, 2, 1, 2).transpose(1, 2)  # (1, 1, 2, 2)

# Now compute dLoss/dV_transposed
dLoss_dV_transposed = torch.matmul(A_postprocessed.transpose(-1, -2), dLoss_dO)  # (1, 1, 2, 2)
dLoss_dV = dLoss_dV_transposed.transpose(1, 2).reshape(1, 2, 2)  # (1, 2, 2)

# Manual gradient for W_V
manual_grad_W_V = torch.einsum('bsd,bse->de', X, dLoss_dV)

print("∇W_V match:", torch.allclose(autograd_grad_W_V, manual_grad_W_V, atol=1e-5))

∇W_V match: True


In [16]:
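def apply_rotary_pos_emb_backward(grad_rotated, cos, sin):
    # Gradient through RoPE (same formula for Q and K): the transpose of rotate_half
    # is -rotate_half, so the sin term flips sign relative to the forward pass
    grad_transposed = grad_rotated * cos - rotate_half(grad_rotated) * sin
    return grad_transposed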
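
A quick way to check which sign the sine term takes in the RoPE backward pass is to compare a candidate formula against autograd. The sketch below is standalone; the sizes and the random angles are illustrative assumptions.

```python
import torch

def rotate_half(x):
    d = x.shape[-1]
    return torch.cat((-x[..., d // 2:], x[..., :d // 2]), dim=-1)

torch.manual_seed(0)
S, d_h = 3, 4  # illustrative sizes (assumption)
angles = torch.rand(S, d_h // 2)
cos = torch.cat([angles.cos(), angles.cos()], dim=-1)  # both halves share angles
sin = torch.cat([angles.sin(), angles.sin()], dim=-1)

x = torch.randn(S, d_h, requires_grad=True)
y = x * cos + rotate_half(x) * sin          # RoPE forward
g = torch.randn(S, d_h)                     # upstream gradient dLoss/dy
y.backward(g)

manual = g * cos - rotate_half(g) * sin     # candidate backward formula
print(torch.allclose(x.grad, manual, atol=1e-6))
```

The minus sign appears because `rotate_half` is an antisymmetric linear map, so its transpose is its negation.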

In [17]:
# Incorrect line (extra division by sqrt(d_h)):
# dLoss_dA = torch.matmul(dLoss_dO, V_transposed.transpose(-1, -2)) * dropout_mask / (2**0.5)

# Corrected line (remove division by sqrt(d_h)):
dLoss_dA = torch.matmul(dLoss_dO, V_transposed.transpose(-1, -2)) * dropout_mask  # Shape: (1, 1, 2, 2)

# Backprop through A = Q_rotated @ K_rotated^T / sqrt(d_h)
dLoss_dQ_rotated = torch.matmul(dLoss_dA, K_rotated) / (2**0.5)  # Apply scaling here

# Backprop through RoPE for Q (correct sign)
dLoss_dQ_transposed = dLoss_dQ_rotated * cos - rotate_half(dLoss_dQ_rotated) * sin

# Reshape and compute dLoss/dW_Q
dLoss_dQ = dLoss_dQ_transposed.transpose(1, 2).reshape(1, 2, 2)
manual_grad_W_Q = torch.einsum('bsd,bse->de', X, dLoss_dQ)

print("∇W_Q match:", torch.allclose(autograd_grad_W_Q, manual_grad_W_Q, atol=1e-5))

∇W_Q match: True


In [54]:
# Backprop through A = Q_rotated @ K_rotated^T / sqrt(d_h)
dLoss_dK_rotated = torch.matmul(Q_rotated.transpose(-1, -2), dLoss_dA) / (2**0.5)  # Apply scaling here

# Backprop through RoPE for K (FIXED SIGN HERE)
dLoss_dK_transposed = dLoss_dK_rotated * cos - rotate_half(dLoss_dK_rotated) * sin  # "-" instead of "+"
#dLoss_dK_transposed = apply_rotary_pos_emb_backward(dLoss_dK_rotated, cos, sin)

# Reshape and compute dLoss/dW_K
dLoss_dK = dLoss_dK_transposed.transpose(1, 2).reshape(1, 2, 2)
manual_grad_W_K = torch.einsum('bsd,bse->de', X, dLoss_dK)

print("∇W_K match:", torch.allclose(autograd_grad_W_K, manual_grad_W_K, atol=1e-5))

∇W_K match: False


In [55]:
autograd_grad_W_K

tensor([[ 0.0192,  0.0048],
        [-0.0047, -0.0012]])

In [56]:
manual_grad_W_K

tensor([[ 0.0158, -0.0219],
        [-0.0039,  0.0152]], grad_fn=<ViewBackward0>)
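
One possible explanation for this mismatch, offered as a hypothesis rather than a confirmed diagnosis: for $A = Q K^T / \sqrt{d_h}$ the gradient with respect to $K$ is $G_A^T Q / \sqrt{d_h}$ (writing $G_A$ for `dLoss_dA`), whereas `torch.matmul(Q_rotated.transpose(-1, -2), dLoss_dA)` computes $Q^T G_A$, which is its transpose. With $S = d_h = 2$ both have the same shape, so no error is raised. The standalone sketch below uses random tensors without RoPE or dropout (an assumption made to isolate this one step) and checks the formula against autograd.

```python
import torch

torch.manual_seed(0)
S, d_h = 3, 4  # deliberately S != d_h so a transposed result has a different shape
Q = torch.randn(S, d_h)
K = torch.randn(S, d_h, requires_grad=True)
G_A = torch.randn(S, S)                 # upstream gradient dLoss/dA

A = Q @ K.T / d_h ** 0.5
A.backward(G_A)

manual = G_A.T @ Q / d_h ** 0.5         # dLoss/dK = dA^T Q / sqrt(d_h)
print(torch.allclose(K.grad, manual, atol=1e-6))
print((Q.T @ G_A).shape, K.shape)       # the transposed product has shape (d_h, S)
```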

In [21]:
B, S, D = X.shape
H = 1
d_h = 2
# Step 1: Gradient from Loss to Y
dLoss_dY = torch.sign(Y[:, :-1, :] - X[:, 1:, :])
dLoss_dY_full = torch.zeros(B, S, D)
dLoss_dY_full[:, :-1, :] = dLoss_dY

# Step 2: Gradient from Y to O
dLoss_dO_reshaped = torch.matmul(dLoss_dY_full, W_O.T)
dLoss_dO = dLoss_dO_reshaped.view(B, S, H, d_h).transpose(1, 2)

# Step 3: Gradient from O to A_postprocessed
dLoss_dA = torch.matmul(dLoss_dO, V_transposed.transpose(-1, -2)) * dropout_mask

# Step 4: Gradient from A to K_rotated
dLoss_dK_rotated = torch.matmul(Q_rotated.transpose(-1, -2), dLoss_dA) / (d_h**0.5)

# Step 5: Gradient through RoPE
dLoss_dK_transposed = dLoss_dK_rotated * cos - rotate_half(dLoss_dK_rotated) * sin

# Step 6: Gradient from K_transposed to K
dLoss_dK = dLoss_dK_transposed.transpose(1, 2).reshape(B, S, H * d_h)

# Step 7: Gradient from K to W_K
manual_grad_W_K = torch.einsum('bsd,bse->de', X, dLoss_dK)

# Verify
print("∇W_K match:", torch.allclose(W_K.grad, manual_grad_W_K, atol=1e-5))

∇W_K match: False


In [22]:
manual_grad_W_K

tensor([[ 0.0158, -0.0219],
        [-0.0039,  0.0152]], grad_fn=<ViewBackward0>)

In [23]:
W_K.grad

tensor([[ 0.0192,  0.0048],
        [-0.0047, -0.0012]])

In [24]:
(manual_grad_W_K - W_K.grad).abs().mean()

tensor(0.0118, grad_fn=<MeanBackward0>)

In [25]:
# Step 1: Gradient from Loss to Y
dLoss_dY = torch.sign(Y[:, :-1, :] - X[:, 1:, :])
dLoss_dY_full = torch.zeros_like(Y)
dLoss_dY_full[:, :-1, :] = dLoss_dY

# Step 2: Gradient from Y to O
dLoss_dO_reshaped = torch.matmul(dLoss_dY_full, W_O.T.detach())
dLoss_dO = dLoss_dO_reshaped.view(B, S, H, d_h).transpose(1, 2)

# Step 3: Gradient from O to A_postprocessed
dLoss_dA = torch.matmul(dLoss_dO, V_transposed.transpose(-1, -2)) * dropout_mask

# Step 4: Gradient from A_postprocessed to K_rotated
dLoss_dK_rotated = torch.matmul(Q_rotated.transpose(-1, -2), dLoss_dA) / (d_h**0.5)

# Step 5: Gradient through RoPE
dLoss_dK_transposed = dLoss_dK_rotated * cos - rotate_half(dLoss_dK_rotated) * sin

# Step 6: Reshape and compute gradient for W_K
dLoss_dK = dLoss_dK_transposed.transpose(1, 2).reshape(B, S, H * d_h)
manual_grad_W_K = torch.einsum('bsd,bse->de', X.detach(), dLoss_dK)

In [26]:
W_K.grad

tensor([[ 0.0192,  0.0048],
        [-0.0047, -0.0012]])

In [27]:
manual_grad_W_K

tensor([[ 0.0158, -0.0219],
        [-0.0039,  0.0152]], grad_fn=<ViewBackward0>)

In [28]:
(manual_grad_W_K - W_K.grad).abs().mean()

tensor(0.0118, grad_fn=<MeanBackward0>)
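
When debugging a manual backward pass like this, it can help to rerun the whole chain with $S \neq d_h$ and $H > 1$, so that a transposed matrix product surfaces as a shape error instead of a silent numerical mismatch. The following is a self-contained sketch under the linear-attention assumption; the sizes, the 0.3 weight scale, the random RoPE frequencies and the helper names `split` / `merge` are all choices made for this example.

```python
import torch

def rotate_half(x):
    d = x.shape[-1]
    return torch.cat((-x[..., d // 2:], x[..., :d // 2]), dim=-1)

torch.manual_seed(0)
B, S, H, d_h = 2, 3, 2, 4          # illustrative sizes with S != d_h (assumption)
D = H * d_h
X = torch.randn(B, S, D)
W_Q, W_K, W_V, W_O = ((0.3 * torch.randn(D, D)).requires_grad_() for _ in range(4))
mask = (torch.rand(S, S) > 0.5).float()       # fixed dropout mask
angles = torch.arange(S).float().unsqueeze(1) * torch.rand(d_h // 2)  # (S, d_h/2)
cos = torch.cat([angles.cos(), angles.cos()], dim=-1)  # (S, d_h), broadcasts over B, H
sin = torch.cat([angles.sin(), angles.sin()], dim=-1)

def split(t):   # (B, S, D) -> (B, H, S, d_h)
    return t.view(B, S, H, d_h).transpose(1, 2)

def merge(t):   # (B, H, S, d_h) -> (B, S, D)
    return t.transpose(1, 2).reshape(B, S, D)

# Forward pass with linear (softmax-free) attention
Q_t, K_t, V_t = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
Q_r = Q_t * cos + rotate_half(Q_t) * sin
K_r = K_t * cos + rotate_half(K_t) * sin
A_post = (Q_r @ K_r.transpose(-1, -2) / d_h ** 0.5) * mask
O_r = merge(A_post @ V_t)
Y = O_r @ W_O
loss = (Y[:, :-1] - X[:, 1:]).abs().sum()
loss.backward()

# Manual backward pass
with torch.no_grad():
    G_Y = torch.zeros_like(Y)
    G_Y[:, :-1] = torch.sign(Y[:, :-1] - X[:, 1:])
    G_O = split(G_Y @ W_O.T)
    G_V = A_post.transpose(-1, -2) @ G_O
    G_A = (G_O @ V_t.transpose(-1, -2)) * mask
    G_Qr = G_A @ K_r / d_h ** 0.5
    G_Kr = G_A.transpose(-1, -2) @ Q_r / d_h ** 0.5
    G_Q = merge(G_Qr * cos - rotate_half(G_Qr) * sin)
    G_K = merge(G_Kr * cos - rotate_half(G_Kr) * sin)
    manual = {
        "W_O": torch.einsum("bsd,bse->de", O_r, G_Y),
        "W_V": torch.einsum("bsd,bse->de", X, merge(G_V)),
        "W_Q": torch.einsum("bsd,bse->de", X, G_Q),
        "W_K": torch.einsum("bsd,bse->de", X, G_K),
    }

weights = {"W_O": W_O, "W_V": W_V, "W_Q": W_Q, "W_K": W_K}
for name, w in weights.items():
    print(name, torch.allclose(w.grad, manual[name], rtol=1e-4, atol=1e-4))
```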

In [33]:
W_O

tensor([[-0.3479,  0.4229],
        [ 0.3712, -0.2417]], requires_grad=True)

In [34]:
Q

tensor([[[-0.3951, -1.1592],
         [ 0.0997,  0.1972]]], grad_fn=<UnsafeViewBackward0>)

Expected
```
Q = tensor([[[-0.3950, -1.1592],
            [ 0.0996,  0.1973]]])
```

In [35]:
K

tensor([[[-1.0076, -0.1648],
         [ 0.1887,  0.0459]]], grad_fn=<UnsafeViewBackward0>)

Expected
```
K = tensor([[[-1.0079, -0.1645],
            [ 0.1888,  0.0460]]])
```

In [36]:
K_transposed

tensor([[[[-1.0076, -0.1648],
          [ 0.1887,  0.0459]]]], grad_fn=<TransposeBackward0>)

Expected
```
tensor([[[[-1.0079, -0.1645],
          [ 0.1888,  0.0460]]]])
```

In [39]:
sin

tensor([[[[0.0000, 0.0000],
          [0.8415, 0.8415]]]])

In [40]:
dLoss_dO 

tensor([[[[-0.7708,  0.6129],
          [ 0.0000,  0.0000]]]], grad_fn=<TransposeBackward0>)

Expected
```
tensor([[[[ 0.7708, -0.6129],
          [ 0.0000,  0.0000]]]])
```

In [41]:
V_transposed 

tensor([[[[ 0.2880, -1.0236],
          [-0.0507,  0.2050]]]])

Expected
```
tensor([[[[ 0.2881, -1.0241],
          [-0.0508,  0.2050]]]])
```

In [42]:
dLoss_dA_before_mask

NameError: name 'dLoss_dA_before_mask' is not defined