In [None]:
# Multi-Head Latent Attention (MLA)

In [4]:
# Example toy embeddings (in reality they are much larger, >1000 dimensions):
# 1. "The"     : [0.3, 0.2, 0.1]
# 2. "dog"     : [0.5, 0.1, 0.3]
# 3. "ran"     : [0.8, 0.4, 0.2]
# 4. "quickly" : [0.2, 0.7, 0.1]
# 5. "around"  : [0.6, 0.3, 0.4]
# 6. "the"     : [0.3, 0.2, 0.1]
#
# Notes:
# - In reality, embeddings have very large dimensions (hundreds or thousands).
# - We don’t know exactly what each dimension represents.
# - We just decide the size of the vector (embedding dimension).
#
# Key idea:
# - Each embedding vector initially represents some notion of the word itself.
#   Example: "dog" → [0.5, 0.1, 0.3]
# - But we want these vectors to capture not only the word,
#   but also the CONTEXT in which the word appears.
#
# Example with context:
# - The embedding for "ran" should encode that "ran" refers to the "dog".
# - Importantly, we should NOT include information about future words
#   (because the model’s job is to predict them).
#
# Prediction process:
# - The model does not predict the next word directly.
# - Instead, it outputs a probability distribution over the entire vocabulary.
#   (e.g., "around" gets high probability, "airplane" gets low probability).
#
# Model design:
# 1. The transformer modifies each word’s vector so it carries context.
#    ("ran" knows it refers to "dog").
# 2. The output head (the last layer) uses this contextualized vector
#    to produce probabilities over the vocabulary.
# 3. If the model is trained well:
#    - Words that make sense in context get high probability (e.g., "around").
#    - Words that don’t fit the context get low probability (e.g., "airplane").


# The embedding vector represents the semantic meaning of a word.
#
# The key vector encodes what contextual information a word can provide to others.
# Example: the key vector of "dog" contains information that could be useful
# for words like "quickly" or "ran".
#
# The query vector encodes what type of context a word is looking for.
# Example: the query of "quickly" asks questions like:
#   - "Who is acting?" 
#   - "What action is happening?"
#
# Attention works by comparing queries against keys:
# - Query("quickly") is compared to Key("dog"), Key("ran"), etc.
# - If the model is well trained, the query of "quickly" will strongly match
#   with the keys of "ran" and "dog", since those provide the most relevant context.
#
# In summary:
# - Embedding → meaning of the word itself.
# - Key       → what context a word can provide.
# - Query     → what context a word is seeking.

# Attention mechanism: step by step
#
# 1) Compute raw scores:
#    - For each query, take the dot product with every key.
#    - This produces the *attention score matrix* (unnormalized).
#
# Example sentence: "life is short eat dessert first"
#
# Raw attention scores (simplified, arbitrary numbers for illustration):
#
#           life   is   short   eat   dessert   first
# life      2.0   1.5    1.2    0.8     0.3      0.1
# is        1.0   2.5    1.7    0.6     0.4      0.2
# short     0.9   1.1    2.8    1.0     0.7      0.3
# eat       0.4   0.5    0.9    2.2     1.5      0.8
# dessert   0.3   0.4    0.6    1.0     2.6      1.9
# first     0.2   0.3    0.5    0.8     1.2      2.7
#
# 2) Scale scores:
#    - Divide each dot product by √d (d = dimension of keys/queries).
#    - This prevents very large values when d is large.
#
# Scaled attention scores (still before masking):
#
#           life   is   short   eat   dessert   first
# life      0.5   0.4    0.3    0.2     0.1      0.0
# is        0.2   0.6    0.4    0.1     0.1      0.0
# short     0.2   0.3    0.7    0.2     0.2      0.1
# eat       0.1   0.1    0.2    0.6     0.4      0.2
# dessert   0.1   0.1    0.2    0.2     0.7      0.5
# first     0.0   0.1    0.1    0.2     0.3      0.7
#
# 3) Apply causal masking:
#    - For each query word, block attention to *future* tokens by setting those
#      scores to -∞ (which softmax will later turn into 0).
#
# Example after masking (values to the right of the diagonal are set to -∞):
#
#           life    is   short   eat   dessert   first
# life      0.5   -∞     -∞      -∞      -∞       -∞
# is        0.2    0.6   -∞      -∞      -∞       -∞
# short     0.2    0.3    0.7    -∞      -∞       -∞
# eat       0.1    0.1    0.2     0.6    -∞       -∞
# dessert   0.1    0.1    0.2     0.2     0.7     -∞
# first     0.0    0.1    0.1     0.2     0.3      0.7
#
# 4) Softmax:
#    - Apply softmax row by row to turn scores into probabilities.
#    - Probabilities sum to 1 for each query.
#    - Future tokens get exact probability 0 (thanks to masking).
#
# Final attention matrix (softmax of the masked scores above, rounded to two decimals):
#
#           life    is     short   eat    dessert   first
# life      1.00   0.00    0.00    0.00    0.00     0.00
# is        0.40   0.60    0.00    0.00    0.00     0.00
# short     0.27   0.29    0.44    0.00    0.00     0.00
# eat       0.21   0.21    0.23    0.35    0.00     0.00
# dessert   0.17   0.17    0.18    0.18    0.30     0.00
# first     0.13   0.14    0.14    0.16    0.17     0.26
#
# Interpretation:
# - Row = query word asking for context.
# - Column = key word providing context.
# - Values = probability of how much attention is given.
# - Example: "is" attends 40% to "life" and 60% to itself.


# Now we introduce a third vector per token, called the value: it is what actually
# gets added into the representation of the tokens that attend to this token.
# Intuition: the value answers "what should you take from me to encode my context?"
# Example: Value("dog") holds the contextual information that "dog" hands over.
# key (dog) [0.54, 0.36, 0.74]
# value (dog) [0.12, 0.84, 0.51]
# query (quickly) [0.22, 0.64, 0.73]
# value (quickly) [0.62, 0.24, 0.33]

In [6]:
# Query vs Key vs Value
# 1. Query (Q)
# Question: "What context do I need?"
# Example: "quickly" → its query says:
#   "I need to know who performed the action and what action happened."
#
# 2. Key (K)
# Answer: "What context can I provide?"
# Example: "dog" → its key says:
#   "I can provide information that I am the subject."
# Example: "ran" → its key says:
#   "I can provide information that I am the main action."
#
# Queries are compared against keys to decide which tokens to attend to.
#
# 3. Value (V)
# Once you know which token to look at (via keys), what you actually receive is the value.
# The value is the real content that a word contributes to the others.
#
# Examples:
# - "dog": its value contains "animal, subject, singular".
# - "ran": its value contains "past action, verb of movement".
#
# Intuition:
# - Key = the doorbell (identity: "I’m here").
# - Value = the package you receive when you open the door.
# The query rings many doorbells (keys), and based on attention weights,
# it collects a weighted mix of the values.
#
# Quick summary:
# - Query → what I am looking for.
# - Key   → how I identify myself so others can find me.
# - Value → what I give if someone attends to me.
#
# --- Mini dialogue example with the sentence: "The dog ran quickly" ---
#
# "quickly" (Query):
#   → "I need context… who did the action and what action was it?"
#
# "dog" (Key):
#   → "I can tell you that I am the subject."
#
# "dog" (Value):
#   → "Here’s my content: subject = dog, singular, actor of the action."
#
# "ran" (Key):
#   → "I can tell you I’m the verb of movement, in past tense."
#
# "ran" (Value):
#   → "Here’s my content: action = run, tense = past."
#
# "quickly" (receiving Values):
#   → "Okay, with the Values I got from ‘dog’ and ‘ran’, now I know this means:
#       ‘the dog ran quickly’."
#
# Final intuition:
# - Query → question ("what context do I need?")
# - Key   → identity ("what context can I provide?")
# - Value → package ("the actual information I contribute").

# In self-attention:
# - For each token, the query is compared with the keys of every token it is allowed
#   to see (itself and, under causal masking, the tokens before it).
# - This produces attention weights (probabilities).
# - Using these weights, a weighted combination of the values from the attended tokens is created.
# - The resulting context vector is then combined with the token’s original embedding
#   (via residual connection and normalization) to form the new representation of that token.

# After attention has produced a context vector for each token (and the residual
# connection has been applied), that vector — not Q, K, or V themselves — is passed
# through a separate neural network called a feed-forward layer.
# - This layer can be a simple MLP or something more advanced like a mixture of experts.
# - It applies a nonlinear activation function, which reshapes the information.
# Intuition:
# Attention may mix in some irrelevant context. The feed-forward layer acts like
# a filter: it strengthens important features and reduces noise or less useful context,
# making the representation of each word more meaningful. 

In [7]:
# Let's see how Q, K, and V are created.
# For each token embedding "e", we create three new vectors:
#   q = e @ Wq
#   k = e @ Wk
#   v = e @ Wv
# Here Wq, Wk, Wv are weight matrices learned during training.
#
# Are Wq, Wk, Wv always the same?
# - They are the same for all tokens within one attention head of one layer
#   (shared weights across tokens).
# - They are different for each head and each transformer layer.
# - During training, their values change as they are optimized. 
#   After training, they stay fixed.
# Important:
# - The "dot product" between Q and K happens later, when calculating attention scores.
# - The creation of Q, K, V is just a linear projection using the learned matrices.


# --- Understanding matrix multiplication vs. dot product in Attention ---
#
# 1) Creating Q, K, V:
# - Each token has its embedding vector, e.g. dog = [0.5, 0.1, 0.3].
# - To get Q, K, V we multiply the embedding with Wq, Wk, Wv:
#       Q = E @ Wq
#       K = E @ Wk
#       V = E @ Wv
# - Here E is the matrix of embeddings for all tokens (tokens are rows).
# - Matrix multiplication is row × column, so no transpose is needed.
# - The result: Q, K, V are also matrices where each row corresponds to one token.
# - All tokens are projected at the same time (batch style).
#
# 2) Computing attention scores (Q @ K^T):
# - Now we want to compare every query with every key.
# - If we try Q @ K (without transpose), the inner dimensions (d_head vs. num_tokens)
#   generally don’t match, so the product is not even defined.
# - We need K^T so that each query row can take a dot product with each key row.
#   Example: row i of Q compares with row j of K → scalar score.
# - This gives us the attention score matrix (tokens × tokens).

In [None]:
# --- Masking the attention scores (affinity matrix) ---
#
# After computing the raw attention scores = Q @ K^T, 
# we get an affinity matrix (one row per query token, one column per key token).
#
# Example with 4 tokens: "life is short eat"
#
# Raw scores (toy values):
#        life   is   short   eat
# life    2.0   1.5   1.2    0.8
# is      1.0   2.5   1.7    0.6
# short   0.9   1.1   2.8    1.0
# eat     0.4   0.5   0.9    2.2
#
# --- Why masking? ---
# In causal/decoder attention, a token should NOT look into the future.
# Example: "life" cannot attend to "is", "short", "eat".
#
# So we set all positions to the right of the diagonal = -∞ (very negative).
#
# Masked scores:
#        life    is   short   eat
# life    2.0   -∞     -∞     -∞
# is      1.0    2.5   -∞     -∞
# short   0.9    1.1   2.8    -∞
# eat     0.4    0.5   0.9    2.2
#
# --- Next step ---
# After masking, we apply softmax row by row.
# -∞ turns into probability 0.
# This ensures each token only attends to past or current tokens.

# Attention row for "quickly" = [0.1, 0.7, 0.2]
# Values = [V(dog), V(ran), V(quickly)]
#
# New "quickly" = 0.1*V(dog) + 0.7*V(ran) + 0.2*V(quickly)

In [None]:
# --- From embeddings to Q, K, V (per head) ---
# Let E be the embedding matrix of the sequence, with tokens as rows:
#   E ∈ [num_tokens x d_model]
# For one attention head with head size d_head, we use learned weight matrices:
#   Wq ∈ [d_model x d_head], Wk ∈ [d_model x d_head], Wv ∈ [d_model x d_head]
# We obtain Q, K, V by matrix multiplication (no transpose needed here):
#   Q = E @ Wq    # Q ∈ [num_tokens x d_head]
#   K = E @ Wk    # K ∈ [num_tokens x d_head]
#   V = E @ Wv    # V ∈ [num_tokens x d_head]
# Each row of Q/K/V corresponds to one token (the projection of that token’s embedding).

# --- What is out_head for a token? ---
# For one attention head and one token (e.g. "quickly"):
#   out_head(quickly) = [attention row of "quickly"] @ V
#
# Example:
#   Attention row for "quickly" = [0.1, 0.7, 0.2]
#   V = [V(dog), V(ran), V(quickly)]
#   => out_head(quickly) = 0.1*V(dog) + 0.7*V(ran) + 0.2*V(quickly)

# --- Where does the attention row come from? ---
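# Attention matrix = softmax( (Q @ K^T) / sqrt(d_head) )
#   (in a causal decoder, the mask is applied to the scaled scores before the softmax)
# - Each row = probabilities of how much one token attends to every visible token
#   (including itself).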
# - Row i corresponds to token i (e.g., "quickly").

# --- Full head output (all tokens at once) ---
# Out_head = Attention @ V
# Shapes:
#   Attention: [num_tokens x num_tokens]
#   V:         [num_tokens x d_head]
#   Out_head:  [num_tokens x d_head]

# --- Multi-head attention ---
# - Each head produces its own Out_head (a different "view" of context).
# - Concatenate along the feature dimension (head size is chosen so that H * d_head = d_model):
#     concat(out_head_1, ..., out_head_H) → [num_tokens x (H * d_head)] = [num_tokens x d_model]
# - Apply final linear projection:
#     Out = concat(...) @ W_o  → [num_tokens x d_model]

# --- Important ---
# - Q, K, V are intermediate projections used to compute Attention.
# - They are NOT concatenated; what we concatenate are the per-head outputs (Out_head).

In [None]:
# --- After multi-head: concatenate and project with output matrix ---
#
# Important correction:
# - We do NOT concatenate the Value matrices themselves.
# - We concatenate the per-head OUTPUTS (the context vectors), i.e., Out_head from each head.
#
# Steps:
# 1) For each head h:
#      Attention_h = softmax(Q_h @ K_h.T / sqrt(d_head))
#      Out_head_h  = Attention_h @ V_h        # shape: [T x d_head]
#
# 2) Concatenate all head outputs along the feature dimension:
#      Out_concat = concat(Out_head_1, ..., Out_head_H)  # shape: [T x (H * d_head)] = [T x d_model]
#
# 3) Apply the output projection (the "output conversion" matrix W_o):
#      Final = Out_concat @ W_o   # W_o ∈ [d_model x d_model], so Final ∈ [T x d_model]
#
# Intuition:
# - Each head gives a different slice of context.
# - Concatenation gathers all slices.
# - W_o mixes them back into the model’s main representation size.


In [None]:
# --- Output projection (Wo) after multi-head attention ---
#
# 1) After multi-head attention we concatenate all head outputs:
#      Out_concat ∈ [num_tokens x (H * d_head)] = [num_tokens x d_model]
#    - Each head gives part of the context.
#    - Concatenation just places them side by side.
#
# 2) To mix these pieces together, we apply the output projection matrix Wo:
#      Final = Out_concat @ Wo
#    - Wo ∈ [d_model x d_model]
#    - The result has the same size as the original embeddings: [num_tokens x d_model]
#
# 3) Where does Wo come from?
#    - Wo is a learned weight matrix, just like Wq, Wk, Wv.
#    - It is trained with backpropagation together with the rest of the model.
#
# Intuition:
# - Concatenation = collect context from all heads.
# - Wo = learn how to combine them into one coherent vector per token,
#   so the model can continue with the next layers.