Implementation

In [7]:
import numpy as np

def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis`` (the last axis by default).

    The per-slice maximum is subtracted before exponentiating; this leaves
    the result unchanged mathematically but keeps ``np.exp`` from overflowing
    on large inputs.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    numerators = np.exp(shifted)
    normalizer = np.sum(numerators, axis=axis, keepdims=True)
    return numerators / normalizer

# Softmax Jacobian: J = diag(s) − s sᵀ, i.e. J_ij = ∂s_i/∂x_j ; dL_dx = J ⋅ dL_dy
# (a matrix-vector product, not an element-wise one: dL_dx = J @ dL_dy)
# When i == j, the derivative has the same form as the single-class case: s_i(1 - s_i)
# When i != j, the derivative is -s_i * s_j. Why? Apply the quotient rule to s_i = exp(x_i) / Σ_k exp(x_k).
# J_ij -> if i == j: s_i - s_i² = s_i(1 - s_i); else: 0 - s_i * s_j

def dsoftmax(softmax_out, grad_out, axis=-1):
    """Backpropagate a gradient through softmax without forming the Jacobian.

    For one softmax vector ``s`` the Jacobian is ``J = diag(s) - s sᵀ``
    (``s sᵀ`` is the outer product), so

        dL_dx = J @ dL_dy = s * (dL_dy - sum_j(dL_dy_j * s_j))

    The sum is a dot product, i.e. a scalar per softmax vector: every
    component of ``dL_dy`` is shifted by the same amount, which depends on
    how the upstream gradient aligns with ``s``, and the result is then
    scaled element-wise by ``s``.

    Args:
        softmax_out: Output of the forward softmax (``s``).
        grad_out: Upstream gradient ``dL_dy``, same shape as ``softmax_out``.
        axis: Axis along which the forward softmax was taken. Defaults to the
            last axis, matching the default of ``softmax``.

    Returns:
        ``dL_dx``, with the same shape as ``softmax_out``.
    """
    s = softmax_out
    # Per-vector dot product <dL_dy, s>; keepdims lets it broadcast back.
    aligned = np.sum(grad_out * s, axis=axis, keepdims=True)
    return s * (grad_out - aligned)

class SelfAttention:
    """Single-head scaled dot-product self-attention with a manual backward pass.

    Computes ``softmax(Q Kᵀ / sqrt(embed_dim)) V`` where ``Q``, ``K`` and ``V``
    are linear projections of the input. The input may be one sequence of
    shape ``(seq_len, embed_dim)`` or carry leading batch dimensions,
    ``(..., seq_len, embed_dim)``.

    ``backward`` accumulates parameter gradients into ``dW_q``, ``dW_k`` and
    ``dW_v`` (it adds to them); call ``zero_grad`` to reset them.
    """

    def __init__(self, embed_dim):
        """Create the three ``(embed_dim, embed_dim)`` projection matrices.

        Weights are drawn from a standard normal scaled by 0.1; the gradient
        buffers start at zero.
        """
        self.embed_dim = embed_dim
        # Weight matrices
        self.W_q = np.random.randn(embed_dim, embed_dim) * 0.1
        self.W_k = np.random.randn(embed_dim, embed_dim) * 0.1
        self.W_v = np.random.randn(embed_dim, embed_dim) * 0.1
        # Gradient accumulators, one per weight matrix
        self.dW_q = np.zeros_like(self.W_q)
        self.dW_k = np.zeros_like(self.W_k)
        self.dW_v = np.zeros_like(self.W_v)

    def forward(self, x):
        """Apply self-attention to ``x`` and cache what ``backward`` needs.

        Args:
            x: Array of shape ``(seq_len, embed_dim)`` or
                ``(..., seq_len, embed_dim)``.

        Returns:
            Array with the same shape as ``x``.
        """
        self.x = x
        self.Q = x @ self.W_q  # Same shape as x
        self.K = x @ self.W_k
        self.V = x @ self.W_v
        # Divide by sqrt(embed_dim) for scaling to stabilize gradients;
        # computed once here and reused by backward.
        self.scale = np.sqrt(self.embed_dim)
        # (..., seq_len, seq_len): how much attention token i pays to token j.
        # swapaxes transposes only the last two axes, so batch axes survive.
        scores = self.Q @ np.swapaxes(self.K, -1, -2) / self.scale
        # Normalize over the key axis (the last one) so each row sums to 1.
        self.weights = softmax(scores, axis=-1)
        return self.weights @ self.V  # same shape as x

    def backward(self, d_out):
        """Backpropagate ``d_out`` through the most recent ``forward`` call.

        Adds the parameter gradients to ``dW_q``, ``dW_k`` and ``dW_v``.

        Args:
            d_out: Gradient of the loss w.r.t. the output of ``forward``;
                same shape as that output.

        Returns:
            Gradient of the loss w.r.t. the input ``x``; same shape as ``x``.
        """
        # output = weights @ V
        d_weights = d_out @ np.swapaxes(self.V, -1, -2)  # (..., seq_len, seq_len)
        dV = np.swapaxes(self.weights, -1, -2) @ d_out
        # Through the softmax, then through the 1/sqrt(embed_dim) scaling.
        dscores = dsoftmax(self.weights, d_weights) / self.scale
        # scores = Q @ Kᵀ
        dQ = dscores @ self.K
        dK = np.swapaxes(dscores, -1, -2) @ self.Q

        # Q = x @ W_q (likewise K, V). Flattening any leading batch axes sums
        # the weight gradient over every token; for 2-D input this is x.T @ dQ.
        feat = self.x.shape[-1]
        x_flat_t = self.x.reshape(-1, feat).T
        self.dW_q += x_flat_t @ dQ.reshape(-1, feat)
        self.dW_k += x_flat_t @ dK.reshape(-1, feat)
        self.dW_v += x_flat_t @ dV.reshape(-1, feat)

        # x feeds all three projections, so its gradient is the sum of the
        # three paths.
        dx_q = dQ @ self.W_q.T
        dx_k = dK @ self.W_k.T
        dx_v = dV @ self.W_v.T
        return dx_q + dx_k + dx_v

    def zero_grad(self):
        """Reset the accumulated weight gradients to zero in place."""
        self.dW_q.fill(0)
        self.dW_k.fill(0)
        self.dW_v.fill(0)


Testing Gradient flow

In [8]:
# Seed first so both the input and the layer's initial weights are reproducible.
np.random.seed(0)
x = np.random.randn(4, 8)  # 4 tokens, 8 features each
attn = SelfAttention(embed_dim=8)

for epoch in range(100):
    # Forward pass. The dummy loss is the sum of all outputs, so its gradient
    # w.r.t. the output is an array of ones.
    out = attn.forward(x)
    loss = out.sum()
    dout = np.ones_like(out)

    # backward() accumulates, so clear the old gradients before calling it.
    attn.zero_grad()
    dx = attn.backward(dout)

    # SGD step with learning rate 0.01, applied in place to each projection.
    for weight, grad in (
        (attn.W_q, attn.dW_q),
        (attn.W_k, attn.dW_k),
        (attn.W_v, attn.dW_v),
    ):
        weight -= 0.01 * grad

    if not epoch % 10:
        print(f"Epoch {epoch} | Loss: {loss:.4f}")


Epoch 0 | Loss: -1.3954
Epoch 10 | Loss: -41.1197
Epoch 20 | Loss: -156.6300
Epoch 30 | Loss: -347.9961
Epoch 40 | Loss: -535.4889
Epoch 50 | Loss: -722.7095
Epoch 60 | Loss: -909.8796
Epoch 70 | Loss: -1097.0331
Epoch 80 | Loss: -1284.1794
Epoch 90 | Loss: -1471.3220
