[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RichardJPovinelli/Neural_Networks_Course/blob/main/Attention_Backprop_Demo.ipynb)
# Backpropagation Through Scaled Dot-Product Attention (Demo)

This notebook demonstrates **backpropagation through a tiny self-attention block**.

We will:

1. Define a toy self-attention module (single head) with small, fixed matrices.
2. Run a forward pass and define a simple scalar loss.
3. Use PyTorch autograd to compute gradients.
4. Manually compute the same gradients using the analytical formulas for attention.
5. Compare **autograd vs. manual vs. finite-difference** gradients to verify that all three agree.


In [29]:
import torch
import math

PRECISION = 2

torch.set_printoptions(precision=PRECISION, sci_mode=True)
# Choose the best available device (informational only: the tensors below are created on the CPU)
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"  # Apple Silicon GPU
else:
    device = "cpu"
print("Using PyTorch version:", torch.__version__)


Using PyTorch version: 2.9.0+cpu



## 1. Tiny Self-Attention Setup

We use:
- Sequence length $n = 3$
- Model dimension $d_{\text{model}} = d_k = d_v = 4$
- Single-head attention
- Fixed numerical values so every run is reproducible.


In [30]:

# Fixed toy inputs (3 tokens, 4-dim embeddings)
X = torch.tensor([
    [0.5, 0.2, 0.1, -0.1],
    [0.0, 0.3, -0.2, 0.2],
    [0.4, -0.1, 0.0, 0.3]
], dtype=torch.float32)

# Fixed projection matrices
W_Q = torch.tensor([
    [0.2, -0.1, 0.0, 0.3],
    [0.1,  0.0, 0.2, -0.2],
    [-0.1, 0.3, 0.1, 0.0],
    [0.0,  0.2, -0.2, 0.1]
], dtype=torch.float32)

W_K = torch.tensor([
    [0.1,  0.2, 0.0, -0.1],
    [0.0,  0.1, 0.3,  0.0],
    [0.2, -0.2, 0.1,  0.1],
    [-0.1, 0.0, 0.2,  0.2]
], dtype=torch.float32)

W_V = torch.tensor([
    [0.3,  0.1, 0.0, -0.2],
    [0.0,  0.2, 0.1,  0.0],
    [0.1, -0.1, 0.2,  0.1],
    [0.0,  0.1, -0.2, 0.3]
], dtype=torch.float32)

d_k = 4

# Simple target for the loss: same shape as A
T = torch.tensor([
    [0.10, 0.00, 0.05, -0.05],
    [0.00, 0.10, -0.05, 0.05],
    [0.05, -0.05, 0.10, 0.00]
], dtype=torch.float32)

print(f"X = \n{X}\nW_Q.shape = {W_Q.shape}\nW_K.shape = {W_K.shape}\nW_V.shape = {W_V.shape}\nT.shape = {T.shape}")


X = 
tensor([[ 5.00e-01,  2.00e-01,  1.00e-01, -1.00e-01],
        [ 0.00e+00,  3.00e-01, -2.00e-01,  2.00e-01],
        [ 4.00e-01, -1.00e-01,  0.00e+00,  3.00e-01]])
W_Q.shape = torch.Size([4, 4])
W_K.shape = torch.Size([4, 4])
W_V.shape = torch.Size([4, 4])
T.shape = torch.Size([3, 4])


Scaling by $1/\sqrt{d_k}$ is on by default. Set `SCALE = False` in the next cell to turn it off.

In [None]:
SCALE = True  # set to False to turn off scaling
scale = 1.0 / math.sqrt(d_k) if SCALE else 1.0


## 2. Forward Pass

We implement scaled dot-product self-attention:


$Q = X W_Q, \quad$
$K = X W_K, \quad$
$V = X W_V$


$S = \frac{QK^\top}{\sqrt{d_k}}, \quad$
$P = \text{softmax}(S) \text{ (row-wise)}, \quad$
$A = P V$

Loss (half the squared Frobenius norm, i.e. half the sum of squared entries of $A - T$):
$\mathcal{L} = \tfrac12 \|A - T\|_F^2.$


In [32]:
# Clone parameters with requires_grad for autograd
X_autogradient  = X.clone().detach().requires_grad_(True)
WQ_autogradient = W_Q.clone().detach().requires_grad_(True)
WK_autogradient = W_K.clone().detach().requires_grad_(True)
WV_autogradient = W_V.clone().detach().requires_grad_(True)

def attention_forward(X_, WQ_, WK_, WV_):
    Q = X_ @ WQ_
    K = X_ @ WK_
    V = X_ @ WV_
    S = (Q @ K.transpose(0, 1)) * scale
    P = torch.softmax(S, dim=-1)
    A = P @ V
    return Q, K, V, S, P, A

Q_autogradient, K_autogradient, V_autogradient, S_autogradient, P_autogradient, A_autogradient = attention_forward(X_autogradient, WQ_autogradient, WK_autogradient, WV_autogradient)
L_autogradient = 0.5 * torch.sum((A_autogradient - T)**2)

print("Attention output A:")
print(A_autogradient)
print(f"\nLoss L: {L_autogradient:.{PRECISION}e}")

Attention output A:
tensor([[ 8.67e-02,  7.33e-02, -2.00e-02, -2.34e-02],
        [ 8.69e-02,  7.33e-02, -1.98e-02, -2.36e-02],
        [ 8.68e-02,  7.32e-02, -2.01e-02, -2.33e-02]], grad_fn=<MmBackward0>)

Loss L: 2.86e-02
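
As a cross-check on the forward pass, the explicit formula can be compared with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later). This is a minimal sketch, not part of the original notebook: the random tensors are hypothetical stand-ins for $X$, $W_Q$, $W_K$, $W_V$, and it assumes the default scale $1/\sqrt{d_k}$ (i.e. `SCALE = True`).

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_k = 3, 4

# Hypothetical random stand-ins (assumption: not the notebook's fixed values)
X = torch.randn(n, d_k)
W_Q, W_K, W_V = (torch.randn(d_k, d_k) for _ in range(3))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V

# Explicit formula: S = QK^T / sqrt(d_k), P = softmax(S) row-wise, A = PV
S = (Q @ K.t()) / math.sqrt(d_k)
P = torch.softmax(S, dim=-1)
A_explicit = P @ V

# Built-in version; add batch and head dimensions, then strip them again
A_builtin = F.scaled_dot_product_attention(
    Q[None, None], K[None, None], V[None, None]
)[0, 0]

max_diff = (A_explicit - A_builtin).abs().max().item()
print(f"max |explicit - built-in| = {max_diff:.2e}")
```

The two results should agree up to float32 round-off; the built-in path does not expose $S$ or $P$, which is why the notebook keeps the explicit version.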


## 3. Backward Pass (with Autograd)

In [33]:
# Retain gradients for intermediate (non-leaf) tensors so their .grad fields are populated
# when we call backward(). This makes it easy to inspect Q_autogradient.grad, K_autogradient.grad, etc.
for t in (Q_autogradient, K_autogradient, V_autogradient, S_autogradient, P_autogradient, A_autogradient):
    # Only call retain_grad() on tensors that require grad and are not leaf tensors
    if isinstance(t, torch.Tensor) and t.requires_grad and not t.is_leaf:
        t.retain_grad()

L_autogradient.backward()

def safe_clone_grad(tensor):
    if tensor is None:
        return None
    g = tensor.grad
    if g is None:
        return None
    return g.detach().clone()

grads_auto = {
    "X": safe_clone_grad(X_autogradient),
    "W_Q": safe_clone_grad(WQ_autogradient),
    "W_K": safe_clone_grad(WK_autogradient),
    "W_V": safe_clone_grad(WV_autogradient),
    "Q": safe_clone_grad(Q_autogradient),
    "K": safe_clone_grad(K_autogradient),
    "V": safe_clone_grad(V_autogradient),
    "S": safe_clone_grad(S_autogradient),
    "P": safe_clone_grad(P_autogradient),
    "A": safe_clone_grad(A_autogradient),
}

print("\nAutograd gradient dL/dW_Q:")
print(grads_auto["W_Q"])
print("\nAutograd gradient dL/dW_K:")
print(grads_auto["W_K"])
print("\nAutograd gradient dL/dW_V:")
print(grads_auto["W_V"])


Autograd gradient dL/dW_Q:
tensor([[-2.54e-04, -6.72e-05,  6.62e-05,  1.79e-04],
        [ 1.59e-04,  3.44e-05, -6.72e-05, -9.17e-05],
        [-1.86e-04, -4.14e-05,  7.43e-05,  1.11e-04],
        [ 1.40e-04,  2.62e-05, -7.27e-05, -7.00e-05]])

Autograd gradient dL/dW_K:
tensor([[-3.46e-05, -8.76e-06, -5.94e-05, -2.93e-04],
        [-4.36e-05,  2.58e-06,  2.59e-05, -3.27e-05],
        [-2.71e-05, -5.15e-06, -3.39e-05, -1.87e-04],
        [ 7.44e-05,  4.38e-06,  2.08e-05,  2.73e-04]])

Autograd gradient dL/dW_V:
tensor([[ 3.32e-02,  5.10e-02, -4.80e-02, -2.11e-02],
        [ 1.47e-02,  2.25e-02, -2.12e-02, -9.34e-03],
        [-3.64e-03, -5.64e-03,  5.31e-03,  2.31e-03],
        [ 1.47e-02,  2.27e-02, -2.14e-02, -9.34e-03]])
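
An alternative to `retain_grad()` for inspecting intermediate gradients is `torch.autograd.grad`, which returns the requested gradients directly instead of storing them in `.grad`. The sketch below is illustrative only: the small random tensors and the zero target are assumptions, not the notebook's values.

```python
import torch

torch.manual_seed(0)

# Hypothetical small inputs (assumption: stand-ins for the notebook's tensors)
X = 0.3 * torch.randn(3, 4)
W_Q = (0.3 * torch.randn(4, 4)).requires_grad_(True)
W_K = (0.3 * torch.randn(4, 4)).requires_grad_(True)
W_V = (0.3 * torch.randn(4, 4)).requires_grad_(True)
T = torch.zeros(3, 4)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V
S = (Q @ K.t()) / 4 ** 0.5
P = torch.softmax(S, dim=-1)
A = P @ V
loss = 0.5 * torch.sum((A - T) ** 2)

# Ask for gradients w.r.t. non-leaf tensors directly; nothing is written to .grad
dS, dP, dA = torch.autograd.grad(loss, [S, P, A])

print("dL/dA equals A - T:", torch.allclose(dA, (A - T).detach(), atol=1e-6))
print("row sums of dL/dS:", dS.sum(dim=-1))
```

Each row of `dS` sums to (numerically) zero, because adding a constant to a row of $S$ leaves the softmax output unchanged.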


## 4. Manual Backpropagation

We recompute the forward pass without autograd and apply the analytical formulas.

Start with:
$\frac{\partial \mathcal{L}}{\partial A} = A - T$

Then:

1. Through $A = P V$:
$\frac{\partial \mathcal{L}}{\partial V} = P^\top \frac{\partial \mathcal{L}}{\partial A}, \quad$
$\frac{\partial \mathcal{L}}{\partial P} = \frac{\partial \mathcal{L}}{\partial A} V^\top$

2. Through the row-wise softmax (for each query $i$, where $s_i$ and $p_i$ are the $i$-th rows of $S$ and $P$ written as column vectors):
$
\frac{\partial \mathcal{L}}{\partial s_i} = J_{\text{softmax}}(p_i)^{\top}
\frac{\partial \mathcal{L}}{\partial p_i},
\quad
J_{\text{softmax}}(p_i) = \mathrm{diag}(p_i) - p_i p_i^\top
$

3. Through $S = QK^\top / \sqrt{d_k}$:
$
\frac{\partial \mathcal{L}}{\partial Q}
= \frac{1}{\sqrt{d_k}} \frac{\partial \mathcal{L}}{\partial S} K,
\quad
\frac{\partial \mathcal{L}}{\partial K}
= \frac{1}{\sqrt{d_k}}
\left(\frac{\partial \mathcal{L}}{\partial S}\right)^{\top} Q
$

4. Through projections:
   
$
\frac{\partial \mathcal{L}}{\partial W_Q} = X^\top \frac{\partial \mathcal{L}}{\partial Q},\;
\frac{\partial \mathcal{L}}{\partial W_K} = X^\top \frac{\partial \mathcal{L}}{\partial K},\;
\frac{\partial \mathcal{L}}{\partial W_V} = X^\top \frac{\partial \mathcal{L}}{\partial V}
$

$ \frac{\partial \mathcal{L}}{\partial X} = \frac{\partial \mathcal{L}}{\partial Q} W_Q^\top + \frac{\partial \mathcal{L}}{\partial K} W_K^\top + \frac{\partial \mathcal{L}}{\partial V} W_V^\top $
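
As a sketch of where the two products in step 1 come from, they can be read off in index notation (the indices $i$, $j$, $k$ are introduced here only for illustration):

$
A_{ij} = \sum_k P_{ik} V_{kj}
\;\Rightarrow\;
\frac{\partial \mathcal{L}}{\partial V_{kj}}
= \sum_i P_{ik} \frac{\partial \mathcal{L}}{\partial A_{ij}}
= \left(P^\top \frac{\partial \mathcal{L}}{\partial A}\right)_{kj},
\quad
\frac{\partial \mathcal{L}}{\partial P_{ik}}
= \sum_j \frac{\partial \mathcal{L}}{\partial A_{ij}} V_{kj}
= \left(\frac{\partial \mathcal{L}}{\partial A} V^\top\right)_{ik}
$

The same pattern applied to $Q = X W_Q$ (and likewise $K$, $V$) gives the projection formulas in step 4. Because $X$ feeds all three projections, its gradient is the sum of three terms.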

In [34]:

# Forward without autograd
with torch.no_grad():
    Q = X @ W_Q
    K = X @ W_K
    V = X @ W_V
    S = (Q @ K.t()) * scale
    P = torch.softmax(S, dim=-1)
    A = P @ V
    L = 0.5 * torch.sum((A - T)**2)

# 1) dL/dA
dA = (A - T)

# 2) Through A = P V
dV = P.t() @ dA
dP = dA @ V.t()

# 3) Softmax Jacobian row-wise
def softmax_grad_row(p_row, dp_row):
    # p_row, dp_row: 1D tensors of length n
    p = p_row.view(-1, 1)
    J = torch.diagflat(p) - p @ p.t()
    return (J.t() @ dp_row.view(-1, 1)).view(-1)

n = P.shape[0]
dS = torch.zeros_like(S)
for i in range(n):
    dS[i] = softmax_grad_row(P[i], dP[i])

# 4) Through S = (QK^T)/sqrt(d_k)
dQ = (dS @ K) * scale
dK = (dS.t() @ Q) * scale

# 5) Through projections
dW_Q = X.t() @ dQ
dW_K = X.t() @ dK
dW_V = X.t() @ dV

dX = dQ @ W_Q.t() + dK @ W_K.t() + dV @ W_V.t()

grads_manual = {
    "X": dX,
    "W_Q": dW_Q,
    "W_K": dW_K,
    "W_V": dW_V,
    "Q": dQ,
    "K": dK,
    "V": dV,
    "S": dS,
    "P": dP,
    "A": dA,
}

print(f"Loss from manual forward: {L:.{PRECISION}e}")


Loss from manual forward: 2.86e-02
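
The per-row Jacobian product in the softmax step can also be written without a loop. Since $J_{\text{softmax}}(p_i) = \mathrm{diag}(p_i) - p_i p_i^\top$ is symmetric, multiplying it by $\partial \mathcal{L} / \partial p_i$ gives $p_i \odot \left(\partial \mathcal{L} / \partial p_i - p_i^\top \, \partial \mathcal{L} / \partial p_i\right)$. The sketch below compares the loop form, this vectorized form, and autograd on hypothetical random inputs (the function names and the float64 dtype are choices made here, not part of the notebook).

```python
import torch

def softmax_backward_loop(P, dP):
    """Row-by-row Jacobian product, mirroring the loop in the manual backward pass."""
    dS = torch.zeros_like(P)
    for i in range(P.shape[0]):
        p = P[i].view(-1, 1)
        J = torch.diagflat(p) - p @ p.t()
        dS[i] = (J.t() @ dP[i].view(-1, 1)).view(-1)
    return dS

def softmax_backward_vectorized(P, dP):
    """Same result without building any Jacobian: dS = P * (dP - rowsum(P * dP))."""
    row_dot = (P * dP).sum(dim=-1, keepdim=True)
    return P * (dP - row_dot)

torch.manual_seed(0)
# Hypothetical scores and upstream gradient (assumption: random, float64)
S = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
dP = torch.randn(3, 3, dtype=torch.float64)

P = torch.softmax(S, dim=-1)
(dS_auto,) = torch.autograd.grad(P, S, grad_outputs=dP)
dS_loop = softmax_backward_loop(P.detach(), dP)
dS_vec = softmax_backward_vectorized(P.detach(), dP)

print("loop vs autograd:      ", (dS_loop - dS_auto).abs().max().item())
print("vectorized vs autograd:", (dS_vec - dS_auto).abs().max().item())
```

The vectorized form avoids materializing an $n \times n$ Jacobian per row, which matters once $n$ is larger than in this toy example.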


## 5. Compare Manual vs Autograd Gradients

In [35]:

def max_abs_diff(a, b):
    return float(torch.max(torch.abs(a - b)))

for name in ["W_Q", "W_K", "W_V", "X"]:
    diff = max_abs_diff(grads_manual[name], grads_auto[name])
    print(f"Max |manual - autograd| for {name}: {diff:.{PRECISION}e}")

print("\nAutograd dL/dW_Q:\n", grads_auto["W_Q"])
print("\nManual   dL/dW_Q:\n", grads_manual["W_Q"])


Max |manual - autograd| for W_Q: 7.28e-12
Max |manual - autograd| for W_K: 2.91e-11
Max |manual - autograd| for W_V: 0.00e+00
Max |manual - autograd| for X: 1.86e-09

Autograd dL/dW_Q:
 tensor([[-2.54e-04, -6.72e-05,  6.62e-05,  1.79e-04],
        [ 1.59e-04,  3.44e-05, -6.72e-05, -9.17e-05],
        [-1.86e-04, -4.14e-05,  7.43e-05,  1.11e-04],
        [ 1.40e-04,  2.62e-05, -7.27e-05, -7.00e-05]])

Manual   dL/dW_Q:
 tensor([[-2.54e-04, -6.72e-05,  6.62e-05,  1.79e-04],
        [ 1.59e-04,  3.44e-05, -6.72e-05, -9.17e-05],
        [-1.86e-04, -4.14e-05,  7.43e-05,  1.11e-04],
        [ 1.40e-04,  2.62e-05, -7.27e-05, -7.00e-05]])


## 6. Finite-Difference Check

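Pick one entry $(i, j)$ of $W_Q$ and approximate its gradient numerically with a central difference: perturb that entry by $\pm\epsilon$, recompute the loss, and take $\left(\mathcal{L}_{+} - \mathcal{L}_{-}\right) / (2\epsilon)$.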


In [36]:
def compute_loss_with_WQ(WQ_new):
    Q = X @ WQ_new
    K = X @ W_K
    V = X @ W_V
    S = (Q @ K.t()) * scale
    P = torch.softmax(S, dim=-1)
    A = P @ V
    return 0.5 * torch.sum((A - T)**2)

i, j = 0, 0  # test this entry
eps = 1e-4

with torch.no_grad():
    WQ_plus = W_Q.clone();  WQ_plus[i, j] += eps
    WQ_minus = W_Q.clone(); WQ_minus[i, j] -= eps
    L_plus = compute_loss_with_WQ(WQ_plus)
    L_minus = compute_loss_with_WQ(WQ_minus)
    num_grad = (L_plus - L_minus) / (2 * eps)

print(f"Finite-diff grad for W_Q[{i},{j}]: {float(num_grad):.{PRECISION}e}")
print(f"Autograd grad: {float(grads_auto['W_Q'][i,j]):.{PRECISION}e}")
print(f"Manual grad: {float(grads_manual['W_Q'][i,j]):.{PRECISION}e}")


Finite-diff grad for W_Q[0,0]: -2.51e-04
Autograd grad: -2.54e-04
Manual grad: -2.54e-04


## 7. What to Notice

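- Manual gradients match autograd up to float32 round-off (the largest difference above is about 2e-09). The finite-difference estimate agrees only to about two significant digits (-2.51e-04 vs. -2.54e-04), which is consistent with float32 round-off in the loss difference being amplified by the division by `2 * eps`.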
- Each query position's loss sends signal to **all** keys/values it attends to.
- The softmax Jacobian step is where this coupling happens: changing one score $s_{ij}$ changes every weight in row $i$ of $P$.
- The $1/\sqrt{d_k}$ factor keeps the scores from growing with $d_k$; large scores would push the softmax toward saturation, where its gradients become very small.
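
The role of the $1/\sqrt{d_k}$ factor can be illustrated with a small experiment. If the entries of a query and a key are independent with zero mean and unit variance, their dot product has variance $d_k$, so its standard deviation grows like $\sqrt{d_k}$. The sketch below assumes standard normal entries (an idealization, not the notebook's fixed matrices) and estimates the spread of raw and scaled scores for a few hypothetical values of $d_k$.

```python
import math
import torch

torch.manual_seed(0)
num_pairs = 5000  # number of random (query, key) pairs per setting
results = {}

for d_k in (4, 64, 256):
    # Assumption: i.i.d. standard normal query and key entries
    q = torch.randn(num_pairs, d_k)
    k = torch.randn(num_pairs, d_k)

    raw = (q * k).sum(dim=-1)        # unscaled dot products q . k
    scaled = raw / math.sqrt(d_k)    # scaled as in S = QK^T / sqrt(d_k)

    results[d_k] = (raw.std().item(), scaled.std().item())
    print(f"d_k={d_k:4d}  std(raw)={results[d_k][0]:6.2f}  "
          f"std(scaled)={results[d_k][1]:5.2f}")
```

Under these assumptions the raw standard deviation tracks $\sqrt{d_k}$ while the scaled one stays near 1, regardless of $d_k$.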

## 8. Do the following
- Change the target $T$ or the projection matrices. Predict qualitatively how the attention weights and gradients will shift, then re-run the cells to check.
- Toggle scaling and compare gradient norms, i.e. $S = \frac{QK^\top}{\sqrt{d_k}}$ vs. $S = QK^\top$.
- Plot heatmaps for $P$, $\frac{\partial \mathcal{L}}{\partial P}$, and $\frac{\partial \mathcal{L}}{\partial S}$.
