<a href="https://colab.research.google.com/github/IvaroEkel/AI-Spielplatz/blob/main/attention_QKV_cat_jumped_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Understanding the Attention Mechanism

This notebook explains the attention mechanism step-by-step using a simplified version of the sentence:

> **"The cat, which had been sleeping on the couch, suddenly jumped."**

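We'll focus on how the word **"jumped"** attends to the other words using the Query-Key-Value mechanism. To keep things simple, we keep only five words of the sentence (*cat, sleeping, couch, suddenly, jumped*) and give each one a hand-picked 2-D vector. That same vector serves as the word's query, key and value; a real transformer would first multiply it by learned matrices $W_Q$, $W_K$ and $W_V$. For a query $q$, keys $K$ and values $V$, the attention output is

$$\text{Attention}(q, K, V) = \text{softmax}\!\left(\frac{q K^\top}{\sqrt{d_k}}\right) V,$$

where $d_k = 2$ is the vector dimension.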

In [None]:
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Simplified example with toy 2-D vectors for five words of the sentence.
words = ["cat", "sleeping", "couch", "suddenly", "jumped"]
vectors = {
    "cat": np.array([1.0, 0.0]),
    "sleeping": np.array([0.8, 0.3]),
    "couch": np.array([0.4, 0.6]),
    "suddenly": np.array([0.1, 0.9]),
    "jumped": np.array([0.9, 0.1])
}

# Use the vector for 'jumped' as the Query. In this toy example there are no
# learned projection matrices: each raw word vector doubles as query, key and value.
q = vectors["jumped"]

# Stack the keys into a (num_words, d_k) matrix, one row per word in `words` order,
# so that scores[i] lines up with words[i].
keys = np.stack([vectors[w] for w in words])
d_k = keys.shape[1]

# Scaled dot-product scores: q K^T / sqrt(d_k)  (here d_k = 2).
# The sqrt(d_k) scaling keeps the scores from growing with the vector dimension,
# which would otherwise push the softmax towards a one-hot distribution.
scores = keys @ q / np.sqrt(d_k)

def softmax(x, axis=-1):
    """Numerically stable softmax.

    Parameters
    ----------
    x : array_like
        Input scores. A 1-D input yields a single probability vector; for N-D
        input the normalisation is applied independently along ``axis``.
    axis : int, optional
        Axis along which to normalise (default: the last axis). Ignored for
        0-d (scalar) input.

    Returns
    -------
    numpy.ndarray
        Non-negative values with the same shape as ``x`` that sum to 1
        along ``axis``.
    """
    x = np.asarray(x)
    # A 0-d input has no axis to reduce over, so normalise over its single value.
    reduce_axis = None if x.ndim == 0 else axis
    # Softmax is invariant to subtracting a constant; removing the max keeps
    # np.exp from overflowing on large scores.
    e_x = np.exp(x - np.max(x, axis=reduce_axis, keepdims=True))
    return e_x / e_x.sum(axis=reduce_axis, keepdims=True)

# Compute attention weights: softmax(q K^T / sqrt(d_k)). `scores` already holds
# the scaled dot products, so only the softmax normalisation remains.
weights = softmax(scores)

# Compute the attended output vector: sum_i weights[i] * v_i, i.e. a weighted
# average of the value vectors (the weights sum to 1). The raw word vectors act
# as the values; stacking them in `words` order lines each row up with `weights`.
values = np.stack([vectors[w] for w in words])
attended_output = weights @ values
# print the attended output
print(f"Attended output vector: {attended_output}")




In [None]:
# Plot each toy word vector as an arrow from the origin, labelled with its word.
fig, ax = plt.subplots()
for word, vec in vectors.items():
    # length_includes_head=True makes the arrow tip land exactly on `vec`;
    # with matplotlib's default (False) the head is drawn beyond the end point.
    ax.arrow(0, 0, vec[0], vec[1], head_width=0.05, head_length=0.1,
             length_includes_head=True, fc='k', ec='k')
    # Place the label slightly past the tip so it does not overlap the arrowhead.
    ax.text(vec[0]*1.15, vec[1]*1.15, word)
# Equal aspect ratio so the angles between vectors (which drive the dot
# products) are shown undistorted; set once after all arrows are added.
ax.axis('equal')
ax.set_title("Toy word vectors")
ax.set_xlabel("dimension 0")
ax.set_ylabel("dimension 1")
plt.show()


In [None]:
# Plot attention weights (one bar per key word, in `words` order).
fig, ax = plt.subplots(figsize=(8, 4))
ax.bar(words, weights, color='skyblue')
ax.set_title("Attention Weights from 'jumped'")
ax.set_xlabel("Key word")
ax.set_ylabel("Weight")
ax.grid(axis='y')
ax.set_axisbelow(True)  # draw the gridlines behind the bars
plt.show()


# Show weights
# NOTE: the query label uses words[-1]; this matches the query only because
# 'jumped' (the word chosen as the query) is the last entry of `words`.
print(f"Query = '{words[-1]}'")
for word, w in zip(words, weights):
    print(f"{word:>10}: {w:.3f}")
