In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability, print library versions, and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   In Colab: Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Queries, Keys, and Values: Building the QKV Projection from Scratch

*Part 1 of the Vizuara series on Self-Attention from First Principles*
*Estimated time: 35 minutes*

## 1. Why Does This Matter?

Modern language models such as GPT, BERT, LLaMA, and Gemini are built on the Transformer architecture, whose central mechanism is **self-attention**. At the very core of self-attention are three simple objects: the **Query**, the **Key**, and the **Value**.

Before we can understand how Transformers "think," we need to understand how every word in a sentence gets transformed into these three representations. That is what this notebook is about.

By the end of this notebook, you will have:
- Built the QKV projection from scratch in PyTorch
- Computed Q, K, V vectors by hand and verified them against PyTorch
- Visualized how different weight matrices create different projections
- Understood *why* we need three separate projections instead of one

Let us get started.

## 2. Building Intuition

Before touching any code or equations, let us understand what Q, K, and V actually *mean*.

Think about how you search for something in a library.

- **Query (Q):** This is your search question. You walk into the library and say, "I need a book about marine biology." The Query encodes *what you are looking for*.

- **Key (K):** This is the label on each book's spine. Every book in the library has a label that says, "Here is what I am about — marine biology, quantum physics, French cooking." The Key encodes *what each item offers*.

- **Value (V):** This is the actual content inside the book. Once you find a matching book (the Key matches your Query), you open it and read the Value — the actual information.

Here is the beautiful part: in self-attention, **every word plays all three roles simultaneously**. The word "mat" has its own Query (when it wants to find relevant context), its own Key (so other words can find it), and its own Value (the information it contributes when found).

### Think About This

If every word just used its raw embedding as both the "search question" and the "label," what would go wrong? Why do we need *separate* learned transformations for Q, K, and V?

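The answer: a word's role as a *searcher* (Query) differs from its role as a *searchable item* (Key), and both differ from the information it hands over once it is found (Value). The word "it" searching for its referent is a different operation from the word "it" being found by other words. If a single matrix produced both queries and keys, the score from word $i$ to word $j$ would always equal the score from $j$ to $i$, so "it" could not attend strongly to "cat" unless "cat" attended equally strongly to "it". Separate weight matrices let the model learn these roles independently.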
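A small experiment makes the shared-matrix problem concrete. The sketch below uses random stand-in embeddings and illustrative names (`W_shared`, `scores_shared`); it is not part of the notebook's pipeline. With one matrix $W$ for both roles, the score matrix $(XW)(XW)^\top = X W W^\top X^\top$ is symmetric, so the score from word $i$ to word $j$ always equals the score from $j$ to $i$. Separate $W^Q$ and $W^K$ remove that constraint.

```python
import torch

torch.manual_seed(0)
n, d_model, d_k = 5, 4, 2
X = torch.randn(n, d_model)            # random stand-in embeddings, one row per word

# Case 1: one shared matrix plays both roles (W_Q == W_K)
W_shared = torch.randn(d_model, d_k)
scores_shared = (X @ W_shared) @ (X @ W_shared).T

# Case 2: separate query and key matrices
W_Q = torch.randn(d_model, d_k)
W_K = torch.randn(d_model, d_k)
scores_separate = (X @ W_Q) @ (X @ W_K).T

# Entry [i, j] is the dot product of word i's query with word j's key
print("shared W   -> symmetric:", torch.allclose(scores_shared, scores_shared.T, atol=1e-6))
print("separate W -> symmetric:", torch.allclose(scores_separate, scores_separate.T, atol=1e-6))
```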

## 3. The Mathematics

Each word enters the Transformer as an **embedding vector** — a numerical representation of its meaning. We then create Q, K, V by multiplying this embedding by three separate learned weight matrices:

$$Q = X W^Q, \quad K = X W^K, \quad V = X W^V$$

**What this equation says computationally:** Take the input embedding matrix $X$ (where each row is one word's embedding) and multiply it by three separate weight matrices. The result is three new matrices, Q, K, and V, each with one row per word.

If $X$ has shape $(n, d_{\text{model}})$, where $n$ is the number of words and $d_{\text{model}}$ is the embedding dimension, and $W^Q$ and $W^K$ each have shape $(d_{\text{model}}, d_k)$, then Q and K each have shape $(n, d_k)$. Q and K must share $d_k$ because every query is later compared with every key by a dot product. $W^V$ may in general have a different output width $d_v$, giving V the shape $(n, d_v)$; in this notebook we use $d_v = d_k$, so Q, K, and V all have shape $(n, d_k)$.
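A quick shape check on random tensors confirms the dimension bookkeeping. The value width `d_v = 3` is an arbitrary choice made only to show that it may differ from `d_k`, and the product `Q @ K.T` appears only to show why Q and K must share `d_k`.

```python
import torch

torch.manual_seed(0)
n, d_model, d_k, d_v = 6, 4, 2, 3      # d_v = 3 is an illustrative choice, not a requirement
X = torch.randn(n, d_model)            # random stand-in embeddings
W_Q = torch.randn(d_model, d_k)
W_K = torch.randn(d_model, d_k)
W_V = torch.randn(d_model, d_v)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V
scores = Q @ K.T                       # (n, d_k) @ (d_k, n) -> (n, n); needs matching d_k

print("Q:", tuple(Q.shape), "K:", tuple(K.shape), "V:", tuple(V.shape), "Q @ K.T:", tuple(scores.shape))
```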

The weight matrices $W^Q$, $W^K$, and $W^V$ are **learned** during training. The model figures out the best way to create Queries, Keys, and Values by seeing millions of sentences and adjusting these weights through backpropagation.
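To see what "learned through backpropagation" means mechanically, here is a minimal sketch with a hypothetical toy loss (the mean squared size of the query vectors); real Transformers are trained on a language-modeling loss instead. Marking the weight tensor with `requires_grad=True` lets autograd compute $\partial \text{loss} / \partial W^Q$, which has the same shape as $W^Q$, and a gradient-descent step moves the weights in the direction that lowers the loss.

```python
import torch

torch.manual_seed(0)
X = torch.randn(6, 4)                          # random stand-in embeddings
W_Q = torch.randn(4, 2, requires_grad=True)    # learnable query projection

# Hypothetical toy objective for illustration only: shrink all query vectors
loss = (X @ W_Q).pow(2).mean()
loss.backward()                                # fills W_Q.grad

print("gradient shape:", tuple(W_Q.grad.shape))   # same shape as W_Q

# One plain gradient-descent step (what an optimizer such as SGD automates)
with torch.no_grad():
    W_Q_new = W_Q - 0.1 * W_Q.grad
loss_after = (X @ W_Q_new).pow(2).mean()
print(f"loss before: {loss.item():.4f}, after one step: {loss_after.item():.4f}")
```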

## 4. Let's Build It — Component by Component

### 4.1 Word Embeddings

Let us start with the very first step: representing words as vectors. In a real Transformer, embeddings are learned during training. For our educational example, we will create simple embeddings by hand.
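In a real model the hand-written dictionary is replaced by a trainable lookup table. A minimal sketch using PyTorch's `nn.Embedding`, assuming a toy vocabulary built from our six words (real models use a tokenizer, so `vocab` and `token_ids` here are hypothetical): each word maps to an integer id, and the id selects one row of a weight matrix that is updated during training. Because this toy vocabulary is case-sensitive, "The" and "the" receive separate rows.

```python
import torch
import torch.nn as nn

sentence = ["The", "cat", "sat", "on", "the", "mat"]
# Hypothetical toy vocabulary: word -> integer id, in order of first appearance
vocab = {word: idx for idx, word in enumerate(dict.fromkeys(sentence))}
token_ids = torch.tensor([vocab[w] for w in sentence])

torch.manual_seed(0)
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=4)  # trainable lookup table
X_learned = embedding(token_ids)       # shape (6, 4): one embedding row per word

print("token ids:", token_ids.tolist())
print("shape:", tuple(X_learned.shape), "trainable:", X_learned.requires_grad)
```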

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np

# Our tiny vocabulary
sentence = ["The", "cat", "sat", "on", "the", "mat"]

# Create simple 4-dimensional embeddings for each word
# In a real model, these would be learned; here we define them manually
embeddings = {
    "The": [1.0, 0.0, 0.5, 0.2],
    "cat": [0.3, 1.0, 0.1, 0.7],
    "sat": [0.5, 0.2, 1.0, 0.3],
    "on":  [0.1, 0.4, 0.3, 0.9],
    "the": [0.9, 0.1, 0.4, 0.2],
    "mat": [1.0, 0.5, -0.3, 0.8],
}

# Build the input matrix X: shape (6, 4)
X = torch.tensor([embeddings[word] for word in sentence], dtype=torch.float32)

print("Input matrix X (6 words, 4 dimensions):")
print(X)
print(f"\nShape: {X.shape}")
print(f"Each row is one word's embedding")

Now let us see what these embeddings look like visually.

In [None]:
# Visualize the embedding matrix
fig, ax = plt.subplots(figsize=(8, 5))
im = ax.imshow(X.numpy(), cmap='RdBu_r', aspect='auto', vmin=-1, vmax=1)
ax.set_xticks(range(4))
ax.set_xticklabels([f'Dim {i}' for i in range(4)])
ax.set_yticks(range(6))
ax.set_yticklabels(sentence)
ax.set_title('Input Embeddings X', fontsize=14, fontweight='bold')
ax.set_xlabel('Embedding Dimension')
ax.set_ylabel('Word')

# Add text annotations
for i in range(6):
    for j in range(4):
        ax.text(j, i, f'{X[i, j].item():.1f}', ha='center', va='center',
                color='white' if abs(X[i, j].item()) > 0.5 else 'black', fontsize=10)

plt.colorbar(im, ax=ax, label='Value')
plt.tight_layout()
plt.show()

### 4.2 The Weight Matrices

Now let us create the three weight matrices: $W^Q$, $W^K$, and $W^V$. These project our 4-dimensional embeddings down to 2-dimensional Q, K, V vectors (i.e., $d_k = 2$).

In [None]:
# Define weight matrices manually
# W^Q: projects embeddings into "query space"
W_Q = torch.tensor([
    [0.2, 0.1],
    [0.4, -0.3],
    [0.1, 0.5],
    [-0.2, 0.3]
], dtype=torch.float32)

# W^K: projects embeddings into "key space"
W_K = torch.tensor([
    [0.3, -0.1],
    [0.1, 0.4],
    [-0.2, 0.2],
    [0.5, 0.1]
], dtype=torch.float32)

# W^V: projects embeddings into "value space"
W_V = torch.tensor([
    [0.1, 0.3],
    [-0.2, 0.5],
    [0.4, -0.1],
    [0.2, 0.2]
], dtype=torch.float32)

print(f"W_Q shape: {W_Q.shape}  (d_model=4 -> d_k=2)")
print(f"W_K shape: {W_K.shape}")
print(f"W_V shape: {W_V.shape}")

### 4.3 Computing Q, K, V

This is the core step. We multiply the input embeddings by each weight matrix.

In [None]:
# Compute Q, K, V
Q = X @ W_Q  # Matrix multiplication: (6, 4) @ (4, 2) = (6, 2)
K = X @ W_K
V = X @ W_V

print("Queries Q:")
print(Q)
print(f"\nKeys K:")
print(K)
print(f"\nValues V:")
print(V)

print(f"\nAll three have shape: {Q.shape} — 6 words, each with a 2D vector")
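Because the projection is a plain matrix product, row $i$ of Q depends only on row $i$ of X. The sketch below checks this on random data and also shows a consequence: reordering the words simply reorders the rows of Q in the same way, so the projection by itself carries no information about word order (Transformers add position information separately). Names such as `Q_rowwise` and `perm` are illustrative.

```python
import torch

torch.manual_seed(0)
n, d_model, d_k = 6, 4, 2
X = torch.randn(n, d_model)            # random stand-in embeddings
W_Q = torch.randn(d_model, d_k)

Q = X @ W_Q

# Row i of Q is computed from row i of X alone: Q[i] = X[i] @ W_Q
Q_rowwise = torch.stack([X[i] @ W_Q for i in range(n)])

# Shuffling the words shuffles the rows of Q in exactly the same way
perm = torch.randperm(n)
Q_perm = X[perm] @ W_Q

print("row by row matches:", torch.allclose(Q, Q_rowwise, atol=1e-6))
print("permutation carries through:", torch.allclose(Q_perm, Q[perm], atol=1e-6))
```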

Let us verify one computation by hand — the Query for "mat" (the last word):

In [None]:
# Manual computation for "mat"
x_mat = X[5]  # [1.0, 0.5, -0.3, 0.8]
q_mat_manual = torch.stack([
    x_mat[0]*W_Q[0,0] + x_mat[1]*W_Q[1,0] + x_mat[2]*W_Q[2,0] + x_mat[3]*W_Q[3,0],  # dot(x_mat, column 0 of W_Q)
    x_mat[0]*W_Q[0,1] + x_mat[1]*W_Q[1,1] + x_mat[2]*W_Q[2,1] + x_mat[3]*W_Q[3,1],  # dot(x_mat, column 1 of W_Q)
])

print(f"x_mat = {x_mat.tolist()}")
print(f"\nManual computation:")
print(f"  q[0] = (1.0)(0.2) + (0.5)(0.4) + (-0.3)(0.1) + (0.8)(-0.2)")
print(f"       = 0.2 + 0.2 - 0.03 - 0.16 = {q_mat_manual[0].item():.4f}")
print(f"  q[1] = (1.0)(0.1) + (0.5)(-0.3) + (-0.3)(0.5) + (0.8)(0.3)")
print(f"       = 0.1 - 0.15 - 0.15 + 0.24 = {q_mat_manual[1].item():.4f}")
print(f"\nPyTorch result: {Q[5].tolist()}")
print(f"Manual result:  {q_mat_manual.tolist()}")
print(f"Match: {torch.allclose(Q[5], q_mat_manual)}")

In [None]:
# Visualize Q, K, V as 2D scatter plots
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for ax, matrix, title, color in zip(axes,
                                      [Q, K, V],
                                      ['Queries (Q)', 'Keys (K)', 'Values (V)'],
                                      ['#e74c3c', '#2ecc71', '#3498db']):
    data = matrix.detach().numpy()
    ax.scatter(data[:, 0], data[:, 1], c=color, s=100, edgecolors='black', linewidth=1, zorder=5)
    for i, word in enumerate(sentence):
        ax.annotate(word, (data[i, 0], data[i, 1]),
                    textcoords="offset points", xytext=(8, 8),
                    fontsize=11, fontweight='bold')
    ax.set_xlabel('Dimension 0', fontsize=11)
    ax.set_ylabel('Dimension 1', fontsize=11)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='gray', linewidth=0.5)
    ax.axvline(x=0, color='gray', linewidth=0.5)

plt.tight_layout()
plt.suptitle('QKV Projections of "The cat sat on the mat"', y=1.02, fontsize=14, fontweight='bold')
plt.show()

Notice how the same word lands at a different position in Q-space, K-space, and V-space. Our weight matrices were chosen by hand rather than learned, but each one already maps the same embedding to a different point because each weights the four embedding dimensions differently.
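One property is worth checking numerically: since each projection is linear, words with similar embeddings stay close in every space. "The" and "the" have nearly identical embeddings, so their projections should nearly coincide, while "The" and "cat" should stay farther apart. The sketch below re-creates the relevant embeddings and the three weight matrices from the cells above, and also reports each matrix's largest singular value, which bounds how much that projection can stretch the distance between two embeddings.

```python
import torch

# Embeddings for three words and the three hand-picked weight matrices, copied from the cells above
emb = {
    "The": torch.tensor([1.0, 0.0, 0.5, 0.2]),
    "the": torch.tensor([0.9, 0.1, 0.4, 0.2]),
    "cat": torch.tensor([0.3, 1.0, 0.1, 0.7]),
}
weights = {
    "Q": torch.tensor([[0.2, 0.1], [0.4, -0.3], [0.1, 0.5], [-0.2, 0.3]]),
    "K": torch.tensor([[0.3, -0.1], [0.1, 0.4], [-0.2, 0.2], [0.5, 0.1]]),
    "V": torch.tensor([[0.1, 0.3], [-0.2, 0.5], [0.4, -0.1], [0.2, 0.2]]),
}

emb_dist = torch.dist(emb["The"], emb["the"]).item()   # distance before projection
results = {}
for name, W in weights.items():
    d_close = torch.dist(emb["The"] @ W, emb["the"] @ W).item()   # "The" vs "the"
    d_far = torch.dist(emb["The"] @ W, emb["cat"] @ W).item()     # "The" vs "cat"
    sigma_max = torch.linalg.matrix_norm(W, ord=2).item()         # largest singular value of W
    results[name] = (d_close, d_far, sigma_max)
    print(f"{name}-space: dist(The, the) = {d_close:.3f}   dist(The, cat) = {d_far:.3f}   "
          f"bound sigma_max * {emb_dist:.3f} = {sigma_max * emb_dist:.3f}")
```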

## 5. Your Turn

### TODO 1: Implement the QKV Projection as a PyTorch Module

In [None]:
class QKVProjection(nn.Module):
    """
    Compute Query, Key, and Value projections from input embeddings.

    Args:
        d_model: Dimension of input embeddings
        d_k: Dimension of Q, K, V vectors (projection dimension)
    """
    def __init__(self, d_model, d_k):
        super().__init__()
        # ============ TODO ============
        # Create three nn.Linear layers (without bias) for W_Q, W_K, W_V
        # Each should project from d_model to d_k
        # Hint: nn.Linear(in_features, out_features, bias=False)
        # ==============================

        self.W_q = ???  # YOUR CODE HERE
        self.W_k = ???  # YOUR CODE HERE
        self.W_v = ???  # YOUR CODE HERE

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)

        Returns:
            Q, K, V: Each of shape (batch_size, seq_len, d_k)
        """
        # ============ TODO ============
        # Apply each linear layer to x to get Q, K, V
        # ==============================

        Q = ???  # YOUR CODE HERE
        K = ???  # YOUR CODE HERE
        V = ???  # YOUR CODE HERE

        return Q, K, V

In [None]:
# Verification
torch.manual_seed(42)
qkv = QKVProjection(d_model=64, d_k=16)
test_input = torch.randn(2, 10, 64)  # batch=2, seq_len=10, d_model=64
Q_test, K_test, V_test = qkv(test_input)

assert Q_test.shape == (2, 10, 16), f"Expected Q shape (2, 10, 16), got {Q_test.shape}"
assert K_test.shape == (2, 10, 16), f"Expected K shape (2, 10, 16), got {K_test.shape}"
assert V_test.shape == (2, 10, 16), f"Expected V shape (2, 10, 16), got {V_test.shape}"

# Q, K, V should differ (each comes from its own weight matrix)
assert not torch.allclose(Q_test, K_test), "Q and K should differ (different projections)"
assert not torch.allclose(K_test, V_test), "K and V should differ (different projections)"
print("All assertions passed!")
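If the verification fails, it can help to check how `nn.Linear` treats a 3-D input. A minimal sketch (layer sizes match the test above; the names are illustrative): the layer acts on the last dimension only, and with `bias=False` it computes `x @ weight.T`, where `weight` is stored with shape `(out_features, in_features)`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(64, 16, bias=False)   # weight is stored as (out_features, in_features) = (16, 64)
x = torch.randn(2, 10, 64)              # (batch_size, seq_len, d_model)

y = layer(x)                            # acts on the last dimension; leading dims pass through
y_manual = x @ layer.weight.T           # the same computation written as a matmul

print("weight:", tuple(layer.weight.shape), "output:", tuple(y.shape))
print("matches x @ weight.T:", torch.allclose(y, y_manual, atol=1e-6))
```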

### TODO 2: Experiment with Different Projection Dimensions

In [None]:
def compare_projections(X, d_model, d_k_values):
    """
    Visualize how the projection dimension d_k affects the pairwise cosine
    similarity of the query vectors (tick labels use the global `sentence`).

    Args:
        X: Input embeddings of shape (seq_len, d_model)
        d_model: Input dimension
        d_k_values: List of projection dimensions to compare
    """
    # ============ TODO ============
    # For each d_k in d_k_values:
    #   1. Create random weight matrices W_Q of shape (d_model, d_k)
    #   2. Compute Q = X @ W_Q
    #   3. Compute the pairwise cosine similarity between all Q vectors
    #   4. Store the similarity matrix
    #
    # Then plot the similarity matrices side by side
    #
    # Hints:
    #   - Use torch.randn(d_model, d_k) for random weights
    #   - Use F.cosine_similarity with appropriate broadcasting,
    #     or compute manually: sim = Q @ Q.T / (norms * norms.T)
    # ==============================

    fig, axes = plt.subplots(1, len(d_k_values), figsize=(5*len(d_k_values), 4))
    if len(d_k_values) == 1:
        axes = [axes]

    for ax, d_k in zip(axes, d_k_values):
        # Reference solution (try writing your own version first): project, then cosine similarity
        W_Q = torch.randn(d_model, d_k)
        Q = X @ W_Q
        norms = Q.norm(dim=1, keepdim=True)
        sim = (Q @ Q.T) / (norms @ norms.T + 1e-8)

        im = ax.imshow(sim.detach().numpy(), cmap='coolwarm', vmin=-1, vmax=1)
        ax.set_title(f'd_k = {d_k}', fontsize=12, fontweight='bold')
        ax.set_xticks(range(len(sentence)))
        ax.set_xticklabels(sentence, rotation=45)
        ax.set_yticks(range(len(sentence)))
        ax.set_yticklabels(sentence)
        plt.colorbar(im, ax=ax, shrink=0.8)

    plt.suptitle('Query Similarity for Different Projection Dimensions', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Run the comparison
compare_projections(X, d_model=4, d_k_values=[1, 2, 4])
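The hint in TODO 2 mentions `F.cosine_similarity` with broadcasting. Here is a sketch on random stand-in data that compares that route with the manual formula used in `compare_projections`: unsqueezing Q to shapes `(6, 1, d_k)` and `(1, 6, d_k)` yields all pairwise similarities in one call. The `d_k = 1` case also explains the leftmost plot: one-dimensional vectors can only point in one of two directions, so every cosine similarity is exactly +1 or -1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(6, 4)                       # random stand-in embeddings (6 words, d_model = 4)

results = {}
for d_k in [1, 2, 4]:
    Q = X @ torch.randn(4, d_k)             # random query projection
    # Broadcast (6, 1, d_k) against (1, 6, d_k): the result is the (6, 6) matrix of all pairs
    sim = F.cosine_similarity(Q.unsqueeze(1), Q.unsqueeze(0), dim=-1)
    # Manual formula from compare_projections
    norms = Q.norm(dim=1, keepdim=True)
    sim_manual = (Q @ Q.T) / (norms @ norms.T + 1e-8)
    results[d_k] = (sim, sim_manual)
    print(f"d_k={d_k}: shape {tuple(sim.shape)}, "
          f"max |difference| = {(sim - sim_manual).abs().max().item():.2e}")

# With d_k = 1 every entry has magnitude 1 (it is either +1 or -1)
print("d_k=1 |sim| range:", results[1][0].abs().min().item(), results[1][0].abs().max().item())
```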

## 6. Putting It All Together

Let us now build a complete QKV projection that uses `nn.Linear` (the PyTorch way) and verify it matches our manual computation.

In [None]:
class QKVProjectionComplete(nn.Module):
    """QKV projection built from three bias-free nn.Linear layers (weights are set manually below)."""

    def __init__(self, d_model, d_k):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_k, bias=False)

    def forward(self, x):
        return self.W_q(x), self.W_k(x), self.W_v(x)

# Create the module and set weights manually to match our earlier example
qkv_module = QKVProjectionComplete(d_model=4, d_k=2)

# nn.Linear stores weights transposed: weight shape is (d_k, d_model)
with torch.no_grad():
    qkv_module.W_q.weight.copy_(W_Q.T)
    qkv_module.W_k.weight.copy_(W_K.T)
    qkv_module.W_v.weight.copy_(W_V.T)

# Compute Q, K, V using the module
Q_mod, K_mod, V_mod = qkv_module(X)

print("Module Q matches manual Q:", torch.allclose(Q_mod, Q, atol=1e-6))
print("Module K matches manual K:", torch.allclose(K_mod, K, atol=1e-6))
print("Module V matches manual V:", torch.allclose(V_mod, V, atol=1e-6))

print("\nQ from module:")
print(Q_mod)
print("\nQ from manual:")
print(Q)
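Many implementations go one step further and fuse the three projections into a single `nn.Linear` with `3 * d_k` outputs, which replaces three matrix multiplications with one larger one. This is a common design choice rather than a requirement. The sketch below (random weights, illustrative names) stacks the three weight matrices row-wise and splits the output with `Tensor.chunk` to recover the same Q, K, and V.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_k = 4, 2
W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_k, bias=False)

# One layer with 3 * d_k outputs; its weight is the three (d_k, d_model) weights stacked row-wise
fused = nn.Linear(d_model, 3 * d_k, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([W_q.weight, W_k.weight, W_v.weight], dim=0))

x = torch.randn(6, d_model)                  # random stand-in embeddings
Q_f, K_f, V_f = fused(x).chunk(3, dim=-1)    # split the last dim into three d_k-wide pieces

print("Q matches:", torch.allclose(Q_f, W_q(x), atol=1e-6))
print("K matches:", torch.allclose(K_f, W_k(x), atol=1e-6))
print("V matches:", torch.allclose(V_f, W_v(x), atol=1e-6))
```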

## 7. Training and Results

Let us see what happens when we *learn* the projections through training. We will set up a simple task: given a pair of words $(i, j)$, predict whether word $i$ should attend to word $j$ (1) or not (0). The prediction is the sigmoid of the scaled dot product between word $i$'s query and word $j$'s key, so only $W^Q$ and $W^K$ are learned here; $W^V$ plays no role in this toy task.
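The training cell turns each query-key dot product into a probability with `torch.sigmoid` and then applies `F.binary_cross_entropy`. For reference, PyTorch also offers `F.binary_cross_entropy_with_logits`, which takes the raw score and fuses the two steps in a numerically more stable way. The sketch below, on random stand-in vectors, shows that the two agree for ordinary scores and disagree for an extreme logit, where `sigmoid` rounds to exactly 1.0 in float32 and `F.binary_cross_entropy` has to clamp the resulting log(0).

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 2
q, k = torch.randn(d_k), torch.randn(d_k)   # random stand-in query and key
target = torch.tensor(1.0)

# Scaled dot product used as the raw score (logit)
logit = (q * k).sum() / math.sqrt(d_k)
loss_two_step = F.binary_cross_entropy(torch.sigmoid(logit), target)
loss_fused = F.binary_cross_entropy_with_logits(logit, target)
print(f"ordinary logit: {loss_two_step.item():.6f} vs {loss_fused.item():.6f}")

# Extreme logit with the opposite target: sigmoid(50.0) rounds to 1.0 in float32
big, zero = torch.tensor(50.0), torch.tensor(0.0)
loss_big_two_step = F.binary_cross_entropy(torch.sigmoid(big), zero)   # log(0) gets clamped
loss_big_fused = F.binary_cross_entropy_with_logits(big, zero)         # about 50
print(f"extreme logit:  {loss_big_two_step.item():.2f} vs {loss_big_fused.item():.2f}")
```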

In [None]:
# Toy training task: learn which word pairs should attend to each other
# Ground truth is hand-picked: e.g. cat->sat and cat->mat should attend (1), The->on should not (0)

torch.manual_seed(42)

d_model = 4
d_k = 2

# Learnable Q and K projections
W_q_learn = nn.Linear(d_model, d_k, bias=False)
W_k_learn = nn.Linear(d_model, d_k, bias=False)

optimizer = torch.optim.Adam(list(W_q_learn.parameters()) + list(W_k_learn.parameters()), lr=0.01)

# Training data: (word_i, word_j, should_attend)
pairs = [
    (1, 2, 1.0),  # cat-sat: should attend
    (5, 1, 1.0),  # mat-cat: should attend
    (0, 3, 0.0),  # The-on: should not attend
    (3, 2, 0.0),  # on-sat: should not attend
    (1, 5, 1.0),  # cat-mat: should attend
    (0, 4, 0.0),  # The-the: should not attend
]

losses = []

for epoch in range(200):
    total_loss = 0
    for i, j, target in pairs:
        q_i = W_q_learn(X[i].unsqueeze(0))  # Query for word i
        k_j = W_k_learn(X[j].unsqueeze(0))  # Key for word j

        # Attention score = dot product, passed through sigmoid
        score = torch.sigmoid((q_i * k_j).sum() / math.sqrt(d_k))
        loss = F.binary_cross_entropy(score, torch.tensor(target))
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    losses.append(total_loss / len(pairs))

print(f"Final loss: {losses[-1]:.4f}")
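The training loop divides each dot product by `math.sqrt(d_k)`. The usual rationale comes from scaled dot-product attention: if query and key entries are independent with zero mean and unit variance (an idealizing assumption), their dot product has variance $d_k$, so unscaled scores grow with the projection size and push the sigmoid (or, in full attention, the softmax) toward saturation. A quick empirical check on random vectors:

```python
import math
import torch

torch.manual_seed(0)
n_samples = 10_000
results = {}
for d_k in [2, 16, 64]:
    q = torch.randn(n_samples, d_k)      # random queries with unit-variance entries
    k = torch.randn(n_samples, d_k)      # random keys with unit-variance entries
    dots = (q * k).sum(dim=1)            # one dot product per sample
    results[d_k] = (dots.var().item(), (dots / math.sqrt(d_k)).var().item())
    print(f"d_k={d_k:3d}  var(q.k)={results[d_k][0]:7.2f}  "
          f"var(q.k / sqrt(d_k))={results[d_k][1]:.2f}")
```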

In [None]:
# Plot the training curve
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(losses, color='#3498db', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('Learning QKV Projections', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Check what the model learned
print("Learned attention predictions:")
for i, j, target in pairs:
    q_i = W_q_learn(X[i].unsqueeze(0))
    k_j = W_k_learn(X[j].unsqueeze(0))
    score = torch.sigmoid((q_i * k_j).sum() / math.sqrt(d_k))
    pred = "ATTEND" if score.item() > 0.5 else "IGNORE"
    correct = "correct" if (score.item() > 0.5) == (target > 0.5) else "WRONG"
    print(f"  {sentence[i]:4s} -> {sentence[j]:4s}: score={score.item():.3f}  "
          f"predict={pred:6s}  target={'ATTEND' if target > 0.5 else 'IGNORE':6s}  [{correct}]")

## 8. Final Output

In [None]:
# Final comprehensive visualization: Q and K projections in learned space
with torch.no_grad():
    Q_learned = W_q_learn(X)
    K_learned = W_k_learn(X)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

for ax, data, title, color in zip(axes,
                                    [Q_learned, K_learned],
                                    ['Learned Queries', 'Learned Keys'],
                                    ['#e74c3c', '#2ecc71']):
    pts = data.numpy()
    ax.scatter(pts[:, 0], pts[:, 1], c=color, s=150, edgecolors='black', linewidth=1.5, zorder=5)
    for i, word in enumerate(sentence):
        ax.annotate(word, (pts[i, 0], pts[i, 1]),
                    textcoords="offset points", xytext=(10, 10),
                    fontsize=12, fontweight='bold',
                    arrowprops=dict(arrowstyle='->', color='gray', lw=0.5))
    ax.set_xlabel('Dimension 0', fontsize=12)
    ax.set_ylabel('Dimension 1', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)

plt.suptitle('Learned Queries and Keys (pair i -> j scores high when dot(Q[i], K[j]) is large)',
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print("\nCongratulations! You have built QKV projections from scratch!")
print("Next up: using these Q, K, V vectors to compute attention scores.")

## 9. Reflection and Next Steps

### Reflection Questions
1. Why do we need three separate weight matrices instead of reusing the same one for Q, K, and V? What would happen if $W^Q = W^K$?
2. How does the choice of $d_k$ (the projection dimension) affect the model's ability to distinguish between words?
3. In our training example, the model was trained so that "cat" and "mat" attend to each other. Since each score depends on both projections, what do the learned matrices $W^Q$ and $W^K$ actually encode about these words?

### Optional Challenges
1. Add a **bias term** to the projections (use `nn.Linear(d_model, d_k, bias=True)`) and see how it affects the learned representations.
2. Try different embedding dimensions and projection dimensions. At what point does the model fail to learn the attention patterns?