## Transition: From Recurrent Memory to Relational Sequence Modeling

In the previous notebook, we analyzed LSTMs through a set of analytical lenses and showed that their limitations arise from core architectural assumptions rather than optimization or data constraints. These limitations consistently pointed to the same underlying causes: forced sequential dependency, fixed-capacity memory, premature importance decisions, compressed representations, destructive memory access, and poor scalability to long sequences.

This notebook builds directly on those insights. Instead of modifying recurrence or improving gating mechanisms, we explore an alternative architectural direction that removes the need for timestep-to-timestep dependency altogether. The goal is not to replace recurrence with a specific named model, but to derive a new class of sequence architectures that satisfy the architectural requirements identified by the lens-based analysis.

We begin by constructing a sequence model in which memory scales with input length, relevance is computed dynamically at usage time, interactions between sequence elements are content-based rather than position-based, and computation can be performed in parallel across all positions. Only after deriving these properties from first principles will we connect them to existing architectural realizations.

# Deriving Self-Attention — From the Very First Step

## Step 0: What problem are we *actually* solving?

We are given a sequence:

$$
x_1, x_2, x_3, \dots, x_T
$$

Example (keep this fixed in your head):

> “The keys to the cabinet are missing”

Our goal is **not**:

* to predict the next token (yet)
* to name an architecture

Our goal is only this:

> **For each position, build a representation that contains the information it actually needs from the rest of the sequence.**

That’s it.

---

## Step 1: What LSTM did wrong (minimal statement)

LSTM tried to do this:

$$
\text{many tokens} \;\rightarrow\; \text{one memory vector}
$$

That forced:

* compression
* guessing importance early
* sequential dependency

So we impose our **first rule**:

> **We will NOT compress the entire sequence into one vector.**
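A minimal PyTorch sketch of that compression, assuming arbitrary toy sizes and random inputs: `nn.LSTM` returns a per-step output, but the memory it carries forward (`h_n`, `c_n`) has the same fixed size whether the sequence has 5 or 500 steps.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_in, d_hidden = 8, 16  # assumed toy sizes
lstm = nn.LSTM(input_size=d_in, hidden_size=d_hidden, batch_first=True)

for T in (5, 50, 500):
    x = torch.randn(1, T, d_in)  # (batch, time, features), random stand-in input
    outputs, (h_n, c_n) = lstm(x)
    # outputs: one vector per step; h_n / c_n: the single carried memory,
    # shape (num_layers, batch, d_hidden) no matter how large T is
    print(T, tuple(outputs.shape), tuple(h_n.shape), tuple(c_n.shape))
```

The sequence length only shows up in `outputs`; everything the LSTM can pass forward to later steps must fit in `(h_n, c_n)`.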

---

## Step 2: New starting assumption (very basic)

Instead of one memory, we keep **one vector per token**.

So we do the most boring thing possible:

$$
x_i \;\rightarrow\; h_i \in \mathbb{R}^d
$$

Now we have:

$$
h_1, h_2, h_3, \dots, h_T
$$

At this point:

* no recurrence
* no order logic
* no interaction

Just vectors.
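A sketch of Step 2, assuming a toy word-level vocabulary built from the example sentence and a randomly initialised `nn.Embedding` standing in for a learned table: each token becomes its own row of `H`, and no row depends on any other token.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

sentence = "The keys to the cabinet are missing".split()
vocab = {w: i for i, w in enumerate(sorted(set(sentence)))}  # toy vocabulary (assumption)
ids = torch.tensor([vocab[w] for w in sentence])             # (T,) token ids

d = 4  # assumed model width
embed = nn.Embedding(num_embeddings=len(vocab), embedding_dim=d)

H = embed(ids)   # (T, d): row i is h_i, one vector per token
print(H.shape)   # torch.Size([7, 4])

# No interaction: h_1 is identical whether or not the other tokens are present
print(torch.equal(embed(ids[:1])[0], H[0]))  # True
```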

---

## Step 3: Why this is still useless

Right now:

* $h_1$ knows nothing about $h_2$
* $h_5$ knows nothing about $h_2$

But language requires this:

> “When deciding something at position $i$, I must look at other positions.”

So we ask the **first real question**:

> **How should one token look at other tokens?**

---

## Step 4: The most basic form of “looking”

To “look at” another token means:

> Measure **how relevant** token $j$ is to token $i$

So we want a **number**:

$$
\text{relevance}(i, j) \in \mathbb{R}
$$

What can we use?

The simplest thing in linear algebra:

* dot product

So first attempt:

$$
\text{relevance}(i, j) = h_i^\top h_j
$$

This says:

* if vectors align → relevant
* if not → irrelevant
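With the raw dot product, all pairwise relevances are a single matrix product, `H @ H.T`. A small sketch with a random `H` (an assumption) also exposes the limitation discussed next: the matrix is symmetric, so token $i$ is exactly as relevant to $j$ as $j$ is to $i$.

```python
import torch

torch.manual_seed(0)

T, d = 7, 4
H = torch.randn(T, d)  # rows are h_1 ... h_T (random stand-ins)

relevance = H @ H.T    # relevance[i, j] = h_i^T h_j, shape (T, T)

# The same vector is used for asking and answering, so relevance is forced to be symmetric
print(torch.allclose(relevance, relevance.T))  # True
print(relevance.shape)                         # torch.Size([7, 7])
```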

---

## Step 5: Why raw dot product is not enough

Problem:

* every token uses the same vector $h_i$ for both roles, so relevance is forced to be symmetric: $h_i^\top h_j = h_j^\top h_i$
* a token has no way to say **“I’m asking”** vs **“I’m providing”**

So we introduce a **very small idea**:

> A token should behave differently when *asking* than when *answering*.

---

## Step 6: Split roles (this is the key mental shift)

Each token will produce **three different vectors**:

1. One to **ask** questions
2. One to **advertise what it contains**
3. One to **send information**

So for token $i$:

$$
\begin{aligned}
\text{Query: } & q_i = W_Q h_i \\
\text{Key: } & k_i = W_K h_i \\
\text{Value: } & v_i = W_V h_i
\end{aligned}
$$

This is not fancy.
It’s just **three linear layers**.
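Step 6 as a sketch: three bias-free `nn.Linear` layers (the sizes `d`, `d_k`, `d_v` are assumptions). `nn.Linear` stores its weight as `(out_features, in_features)` and computes `h @ weight.T`, which applies $q_i = W_Q h_i$ to every row at once.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

T, d, d_k, d_v = 7, 4, 3, 5  # assumed sizes; d_v is free to differ from d_k
H = torch.randn(T, d)        # one vector per token

W_Q = nn.Linear(d, d_k, bias=False)  # "asking" role
W_K = nn.Linear(d, d_k, bias=False)  # "advertising" role, same size as queries so they can be compared
W_V = nn.Linear(d, d_v, bias=False)  # "sending" role

Q, K, V = W_Q(H), W_K(H), W_V(H)
print(Q.shape, K.shape, V.shape)  # torch.Size([7, 3]) torch.Size([7, 3]) torch.Size([7, 5])

# Row i of Q equals W_Q h_i, the per-token formula
print(torch.allclose(Q[2], W_Q.weight @ H[2]))  # True
```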

---

## Step 7: Now define relevance properly

Token $i$ asks:

> “Which other tokens matter to me?”

So relevance becomes:

$$
s_{ij} = q_i^\top k_j
$$

Interpretation:

* $q_i$: what I want
* $k_j$: what token $j$ offers
* dot product → match quality
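With separate query and key projections, the score matrix $S = QK^\top$ no longer has to be symmetric. A sketch with random weights (an assumption), stored as $W_Q, W_K \in \mathbb{R}^{d_k \times d}$ to match $q_i = W_Q h_i$:

```python
import torch

torch.manual_seed(0)

T, d, d_k = 7, 4, 3
H = torch.randn(T, d)
W_Q = torch.randn(d_k, d)  # q_i = W_Q h_i
W_K = torch.randn(d_k, d)  # k_j = W_K h_j

Q, K = H @ W_Q.T, H @ W_K.T  # row i of Q is q_i, row j of K is k_j
S = Q @ K.T                  # S[i, j] = q_i^T k_j: "what i wants" vs "what j offers"

# One entry checked against the per-pair definition
print(torch.allclose(S[1, 4], Q[1] @ K[4]))  # True
# Generally not symmetric: i can care about j more than j cares about i
print(torch.allclose(S, S.T))
```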

---

## Step 8: Turn relevance into importance weights

Raw scores have no fixed scale: they can be any real number, and nothing ties one token's scores together.

We want:

* higher score → more influence
* total influence = 1

So we normalize:

$$
\alpha_{ij} =
\frac{e^{s_{ij}}}{\sum_{m=1}^{T} e^{s_{im}}}
$$

Now:

$$
\sum_{j=1}^{T} \alpha_{ij} = 1
$$

These are **importance weights**.
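A sketch of Step 8 on a random score matrix (an assumption): the explicit formula, applied row by row, matches `torch.softmax(..., dim=-1)`, and every row sums to 1. Subtracting each row's maximum first does not change the result (the factor cancels between numerator and denominator) but keeps `exp` from overflowing.

```python
import torch

torch.manual_seed(0)

S = torch.randn(5, 5) * 3  # stand-in scores s_ij

# alpha_ij = exp(s_ij) / sum_m exp(s_im), row by row
S_shifted = S - S.max(dim=-1, keepdim=True).values  # same softmax, numerically safer
alpha = torch.exp(S_shifted) / torch.exp(S_shifted).sum(dim=-1, keepdim=True)

print(torch.allclose(alpha, torch.softmax(S, dim=-1)))   # True
print(torch.allclose(alpha.sum(dim=-1), torch.ones(5)))  # True: each row sums to 1
```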

---

## Step 9: Finally, build the new representation

Token $i$ now **collects information**:

$$
z_i = \sum_{j=1}^{T} \alpha_{ij} v_j
$$

Meaning:

* look at all tokens
* take what matters
* ignore the rest

This is the **entire mechanism**.
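Step 9 for a single token, as a sketch with random weights and values (assumptions): the explicit sum $z_i = \sum_j \alpha_{ij} v_j$ is the same vector as row $i$ of the matrix product `alpha @ V`.

```python
import torch

torch.manual_seed(0)

T, d_v = 5, 3
alpha = torch.softmax(torch.randn(T, T), dim=-1)  # importance weights, each row sums to 1
V = torch.randn(T, d_v)                           # one value vector per token

i = 3
z_i = sum(alpha[i, j] * V[j] for j in range(T))   # z_i = sum_j alpha_ij v_j
Z = alpha @ V                                     # every token at once

print(torch.allclose(z_i, Z[i]))  # True
```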

---

## Step 10: Stop and reflect (important)

Notice what we achieved **without naming anything**:

* no compression
* no recurrence
* no guessing early
* no destructive memory
* no long gradient paths

Each $z_i$ is:

> “token $i$ + exactly the context it needs”

---

## Step 11: Write it compactly (only now)

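Let $H \in \mathbb{R}^{T \times d}$ be the matrix whose $i$-th row is $h_i^\top$, and:

$$
Q = H W_Q^\top,\quad K = H W_K^\top,\quad V = H W_V^\top
$$

so that row $i$ of $Q$ is $q_i^\top = (W_Q h_i)^\top$, and likewise for $K$ and $V$. Then, with the softmax applied to each row:

$$
Z = \text{softmax}(QK^\top)\,V
$$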

This is just **Steps 6–9 written together**.
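To check that the compact form really is Steps 6 to 9 written together, here is a sketch comparing a per-token loop with the single matrix expression. The random inputs and weights, and the helper name `self_attention`, are illustrative assumptions. Many implementations, including the original Transformer, also divide the scores by $\sqrt{d_k}$ before the softmax; that scaling is left out here to match the derivation.

```python
import torch

torch.manual_seed(0)

T, d, d_k, d_v = 6, 4, 3, 5
H = torch.randn(T, d)
W_Q, W_K, W_V = torch.randn(d_k, d), torch.randn(d_k, d), torch.randn(d_v, d)

def self_attention(H, W_Q, W_K, W_V):
    """Z = softmax(Q K^T) V with Q = H W_Q^T, K = H W_K^T, V = H W_V^T (no scaling, no mask)."""
    Q, K, V = H @ W_Q.T, H @ W_K.T, H @ W_V.T
    return torch.softmax(Q @ K.T, dim=-1) @ V  # softmax over each row

# Steps 6-9, one token at a time
rows = []
for i in range(T):
    q_i = W_Q @ H[i]                                           # Step 6: query
    s_i = torch.stack([q_i @ (W_K @ H[j]) for j in range(T)])  # Step 7: s_ij
    alpha_i = torch.softmax(s_i, dim=0)                        # Step 8: weights
    rows.append(sum(alpha_i[j] * (W_V @ H[j]) for j in range(T)))  # Step 9: z_i
Z_loop = torch.stack(rows)

Z = self_attention(H, W_Q, W_K, W_V)
print(Z.shape)                               # torch.Size([6, 5])
print(torch.allclose(Z, Z_loop, atol=1e-4))  # True
```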

---

## Step 12: Only now — give it a name

> A mechanism where each element in a sequence builds its representation by dynamically weighting other elements in the same sequence is called **self-attention**.

Not because it’s popular.
Because it is a direct construction that satisfies the constraints we derived from LSTM’s failures.

---

## One-line intuition (keep this)

> LSTM remembers by overwriting; self-attention remembers by keeping everything and choosing later.



```python
import torch
import torch.nn as nn

# Character-level Tiny Shakespeare corpus (file expected in the working directory)
with open("tiny_shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Vocabulary: every distinct character gets an integer id
chars = sorted(set(text))
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for c, i in stoi.items()}

data = torch.tensor([stoi[c] for c in text], dtype=torch.long)
seq_length = 50  # number of characters per sequence

# 90/10 train/validation split
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

def get_batch(split="train", batch_size=64):
    """Sample random windows; Y is X shifted one character to the right."""
    source = train_data if split == "train" else val_data
    ix = torch.randint(len(source) - seq_length - 1, (batch_size,))
    X = torch.stack([source[i:i + seq_length] for i in ix])
    Y = torch.stack([source[i + 1:i + seq_length + 1] for i in ix])
    return X, Y

x_dummy, y_dummy = get_batch(batch_size=1)
```

```python
# Decode the sampled window back to characters.
# Named input_text/target_text so the full corpus stored in `text` is not overwritten.
input_text = ''.join(itos[num.item()] for num in x_dummy[0])
target_text = ''.join(itos[num.item()] for num in y_dummy[0])
print("This is text: ", input_text)
print("This is target: ", target_text)

print("-" * 100)
print("Step-0: Now let's build a representation for each input character: embeddings")
n_embedding_dims = 2
embeddings = torch.rand(len(chars), n_embedding_dims)  # random stand-in for a learned embedding table
x_enc = embeddings[x_dummy]  # shape (1, seq_length, 2): each input char is now a 2-D vector

print(f"> Before embeddings, the input sequence looks like this:\n          {x_dummy[0][:10]}")
print(f"> After embeddings, each input character is a 2-D vector:\n          {x_enc[0][:10]}")
print("-" * 100)
print("> Now we have: h1, h2, h3, h4, h5, ..., h_T")
print("> At this point: ")
print("    * no recurrence")
print("    * no order logic")
print("    * no interaction")
print("NOTE: This is still useless, because h5 doesn't know anything about h2, and so on.")
print("But language requires this:\n > “When deciding something at position i, I must look at other positions.”")
print("\nNow let's build a mechanism that gets information from the other positions as well.")
print("-" * 100)
print(f"\nThese are our first 5 embedding vectors: {x_enc[0][:5]}\nTo check whether the token at index 3 {x_enc[0][:5][3]} is relevant to the token at index 2 {x_enc[0][:5][2]},\nwe compute:\n{x_enc[0][:5][3]} @ {x_enc[0][:5][2]} = {(x_enc[0][:5][3] @ x_enc[0][:5][2]):.3f}")
print("> A higher dot product (relative to the other scores) → higher relevance after normalization")
print("    * if vectors align → relevant")
print("    * if not → irrelevant")
print("-" * 100)
print("Problem:")
print("    * all tokens use the same representation")
print("    * a token has no way to say “I’m asking” vs “I’m providing”")
print("So we use 3 different matrices, so that a token behaves differently when asking for information vs. giving it.")
print("They are:\n1) Query: asks which tokens are relevant to me\n2) Key: responds to a query by advertising what this token contains, i.e. whether it is relevant to that query")
print("3) Value: the information a token contributes; the querying token adds it to update its own embedding with context from the relevant keys")
print("> A token should behave differently when asking than when answering.")

print("-" * 100)

# Create the 3 linear projections (weight matrices, row-vector convention: x @ W)
n_query_dims = n_embedding_dims  # size of each query vector (equal to the embedding size here, by choice)
print("Each token's embedding is projected to a query vector, so W_q.shape = (n_embedding_dims, n_query_dims)")
print("Each query is compared with a key by dot product, so keys need the same size: W_k.shape = (n_embedding_dims, n_query_dims)")
n_key_dims = n_query_dims  # keys must match queries in size for q @ k to be defined
n_values_dims = 3  # value size is free; chosen != n_embedding_dims to illustrate the point below
w_q = torch.rand(n_embedding_dims, n_query_dims)
w_k = torch.rand(n_embedding_dims, n_key_dims)
w_v = torch.rand(n_embedding_dims, n_values_dims)
sentence = ''.join(itos[num.item()] for num in x_dummy[0][:5])
print("-" * 100)
print(f"Sentence: '{sentence}'")
for i, char in enumerate(sentence):
    print(f"index = {i} | char = {char}")
print("-" * 100)

# Suppose the token at index 3 has to check whether the token at index 2 is relevant.
# Token 2 exposes what it can answer through its key: x_enc[0][:5][2] @ w_k
print("Suppose the token at index 3 has to check whether the token at index 2 is relevant. Then:")
print(f"Token 3 raises a query: {x_enc[0][:5][3] @ w_q}")
print(f"Token 2 exposes a key: {x_enc[0][:5][2] @ w_k}")
print(f"To check relevance we compute: {x_enc[0][:5][3] @ w_q} @ {x_enc[0][:5][2] @ w_k} = {((x_enc[0][:5][3] @ w_q) @ (x_enc[0][:5][2] @ w_k)):.3f}")
query_3 = x_enc[0][:5][3] @ w_q  # token 3 asks: "is anyone relevant to me?"
key_2 = x_enc[0][:5][2] @ w_k    # token 2 advertises what it contains
# Matching the query against the key gives the score:
score_3_2 = query_3 @ key_2      # how important token 2 is to token 3


# Compute this for every pair among the first 5 tokens (only j <= i is listed here);
# relevance_i_j = q_i . k_j = how relevant token j is to token i
relevance_4_4 = (x_enc[0][:5][4] @ w_q) @ (x_enc[0][:5][4] @ w_k)  # relevance of token 4 to itself
relevance_4_3 = (x_enc[0][:5][4] @ w_q) @ (x_enc[0][:5][3] @ w_k)  # relevance of token 3 to token 4
relevance_4_2 = (x_enc[0][:5][4] @ w_q) @ (x_enc[0][:5][2] @ w_k)  # relevance of token 2 to token 4
relevance_4_1 = (x_enc[0][:5][4] @ w_q) @ (x_enc[0][:5][1] @ w_k)  # relevance of token 1 to token 4
relevance_4_0 = (x_enc[0][:5][4] @ w_q) @ (x_enc[0][:5][0] @ w_k)  # relevance of token 0 to token 4
# -------------------------------------------------------------------------------------
relevance_3_3 = (x_enc[0][:5][3] @ w_q) @ (x_enc[0][:5][3] @ w_k)  # relevance of token 3 to itself
relevance_3_2 = (x_enc[0][:5][3] @ w_q) @ (x_enc[0][:5][2] @ w_k)  # relevance of token 2 to token 3
relevance_3_1 = (x_enc[0][:5][3] @ w_q) @ (x_enc[0][:5][1] @ w_k)  # relevance of token 1 to token 3
relevance_3_0 = (x_enc[0][:5][3] @ w_q) @ (x_enc[0][:5][0] @ w_k)  # relevance of token 0 to token 3
# -------------------------------------------------------------------------------------
relevance_2_2 = (x_enc[0][:5][2] @ w_q) @ (x_enc[0][:5][2] @ w_k)  # relevance of token 2 to itself
relevance_2_1 = (x_enc[0][:5][2] @ w_q) @ (x_enc[0][:5][1] @ w_k)  # relevance of token 1 to token 2
relevance_2_0 = (x_enc[0][:5][2] @ w_q) @ (x_enc[0][:5][0] @ w_k)  # relevance of token 0 to token 2
# -------------------------------------------------------------------------------------
relevance_1_1 = (x_enc[0][:5][1] @ w_q) @ (x_enc[0][:5][1] @ w_k)  # relevance of token 1 to itself
relevance_1_0 = (x_enc[0][:5][1] @ w_q) @ (x_enc[0][:5][0] @ w_k)  # relevance of token 0 to token 1
# -------------------------------------------------------------------------------------
relevance_0_0 = (x_enc[0][:5][0] @ w_q) @ (x_enc[0][:5][0] @ w_k)  # relevance of token 0 to itself

print(f"For token {sentence[4]}: ")
print(f"> Relevance of token 4 ({sentence[4]}) to token 4 ({sentence[4]}) = {relevance_4_4:.3f}")
print(f"> Relevance of token 3 ({sentence[3]}) to token 4 ({sentence[4]}) = {relevance_4_3:.3f}")
print(f"> Relevance of token 2 ({sentence[2]}) to token 4 ({sentence[4]}) = {relevance_4_2:.3f}")
print(f"> Relevance of token 1 ({sentence[1]}) to token 4 ({sentence[4]}) = {relevance_4_1:.3f}")
print(f"> Relevance of token 0 ({sentence[0]}) to token 4 ({sentence[4]}) = {relevance_4_0:.3f}")
print("-" * 100)

print(f"For token {sentence[3]}: ")
print(f"> Relevance of token 3 ({sentence[3]}) to token 3 ({sentence[3]}) = {relevance_3_3:.3f}")
print(f"> Relevance of token 2 ({sentence[2]}) to token 3 ({sentence[3]}) = {relevance_3_2:.3f}")
print(f"> Relevance of token 1 ({sentence[1]}) to token 3 ({sentence[3]}) = {relevance_3_1:.3f}")
print(f"> Relevance of token 0 ({sentence[0]}) to token 3 ({sentence[3]}) = {relevance_3_0:.3f}")
print("-" * 100)

print(f"For token {sentence[2]}: ")
print(f"> Relevance of token 2 ({sentence[2]}) to token 2 ({sentence[2]}) = {relevance_2_2:.3f}")
print(f"> Relevance of token 1 ({sentence[1]}) to token 2 ({sentence[2]}) = {relevance_2_1:.3f}")
print(f"> Relevance of token 0 ({sentence[0]}) to token 2 ({sentence[2]}) = {relevance_2_0:.3f}")
print("-" * 100)

print(f"For token {sentence[1]}: ")
print(f"> Relevance of token 1 ({sentence[1]}) to token 1 ({sentence[1]}) = {relevance_1_1:.3f}")
print(f"> Relevance of token 0 ({sentence[0]}) to token 1 ({sentence[1]}) = {relevance_1_0:.3f}")
print("-" * 100)

print(f"For token {sentence[0]}: ")
print(f"> Relevance of token 0 ({sentence[0]}) to token 0 ({sentence[0]}) = {relevance_0_0:.3f}")
print("-" * 100)

print("Or we can compute all the scores at once:")
print("query = x_enc[0][:5] @ w_q")
print("key = x_enc[0][:5] @ w_k")
print("scores = query @ key.T")
query = x_enc[0][:5] @ w_q   # (5, n_query_dims): one query per token
key = x_enc[0][:5] @ w_k     # (5, n_key_dims): one key per token
value = x_enc[0][:5] @ w_v   # (5, n_values_dims): one value per token
scores = query @ key.T       # (5, 5): every pair (i, j), including j > i (no causal mask)
exp_scores = torch.exp(scores)
normalized_score = exp_scores / exp_scores.sum(dim=1, keepdim=True)  # row-wise softmax
Z = normalized_score @ value  # (5, n_values_dims): one context-aware vector per token
print("Now scores[i][j] denotes how relevant token j is to token i")
print(f"> Relevance of token 3 ({sentence[3]}) to token 4 ({sentence[4]}) = scores[4][3] = {scores[4][3]:.3f}")

arg_max = torch.argmax(normalized_score[3]).item()
print(f"> Sentence = `{sentence}`")
print(f"> Token 3 = `{sentence[3]}`")
print(f"> Token {arg_max} = `{sentence[arg_max]}`")
print(f"normalized_score[3] = {normalized_score[3]}")
print("> normalized_score[3] shows how much token 3 cares about every token in the sequence (including itself)")
print(f"> How much does token 3 care about token 1? normalized_score[3][1] = {normalized_score[3][1]:.4f}")
print(f"> Token 3 cares most about token {arg_max},")
print(f"> which means token 3 takes most of its information from token {arg_max},\n   but also uses information from the rest of the tokens; nothing is discarded completely (unless its weight ≈ 0)")
print("> How do we use the information from other tokens?\nWe use the value matrix for that:\nZ = normalized_score @ value")
print("`normalized_score @ value` makes every token rewrite itself using information from all other tokens, weighted by relevance")
print("> To update the representation of token 3, we do this:")
print(f"normalized_score[3] = {normalized_score[3]}")
print(f"value = {value}")
z_3 = normalized_score[3] @ value  # same as Z[3]
print(f"We compute normalized_score[3] @ value = {z_3}. What does this mean?")
print("This is the new, context-aware representation of token 3")
print(f"> New context-aware representation of token 3 (context-enriched embedding) = {z_3}")
print(f"> Old representation of token 3 (original embedding, no context) = {x_enc[0][3]}")
print("NOTE: Notice that the new context-rich embedding lives in a 3-D space, while the original embedding was 2-D.")
print("> The reason: the old representation lives in the embedding space, where n_embedding_dims = 2,")
print("and the new context-aware representation lives in the value space, where n_values_dims = 3")
print("> n_embedding_dims != n_values_dims here is an architectural design choice")
print("> We can optionally project 3-D --> 2-D (value space --> embedding space) with one more linear projection layer")
print("> If we stack layers with residual connections (x + attention(x)), the attention output must be mapped back to the embedding space. This is NOT OPTIONAL ANYMORE")
print()
print(" ------------------------------------------------------------ Multi-head Attention --------------------------------------------------------------------------------- ")
print("> But wait: in a language like English there are many kinds of relationships we want to discover.")
print("Instead of forcing a single head to learn every relationship in one 'n_embedding_dims'-dimensional space,")
print("we project the input sequence into 'h' different subspaces, each of dimension 'n_embedding_dims/h'.")
print("> In each subspace the model learns its relationships with its own learnable parameters; all heads operate on the same input sequence, and their outputs are then concatenated.")
print("> Benefit? The number of parameters stays the same: after concatenating 'h' outputs of dimension n_embedding_dims/h each,")
print("we get 'n_embedding_dims/h * h' = 'n_embedding_dims' again, so the Q/K/V projections have as many parameters as a single full-width head, while each head can learn a different pattern and concatenation preserves all of them")
print("> Conclusion: Multi-head attention works because it allows multiple independent relational views of the same sequence to coexist, each with its own projections, and preserves all of them through concatenation.")
print()
print("------------------------------------------------------------ Positional Embeddings ------------------------------------------------------------")
print("But wait: even after learning all token-to-token dependencies, we ignored a fundamental property of language: position.")
print("Self-attention can learn relationships between words, but without positional information it has no notion of order.")
print("For example:")
print("S-1: The dog bit the man.")
print("S-2: The man bit the dog.")
print("Both sentences contain the same words, but their meanings are completely different due to word order.")
print("Without positional encodings, a self-attention model treats the input as a set, not a sequence, and cannot distinguish these cases.")
print("> Therefore, we must explicitly inject positional information so the model can learn relationships with respect to position.")
print("> Where the injection happens:")
print("> Given token embeddings: ")
print(">>> X ∈ R^{T × n_embedding_dims}")
print("> and positional encodings of the same shape:")
print(">>> P ∈ R^{T × n_embedding_dims}")
print("we form, ")
print("X_pos = X + P")
print("> All attention (Q, K, V) is computed from X_pos, not X.")
print("Q = X_pos @ W_q")
print("K = X_pos @ W_k")
print("V = X_pos @ W_v")
print("\n> There are two common ways to do this:")
print("Method 1: Learned Positional Embeddings")
print("> Create an embedding table for positions")
print("> One vector per position index")
print(">>> pos_embedding = nn.Embedding(max_len, n_embedding_dims)")
print(">>> For a sequence of length T:")
print(">>> positions = [0, 1, 2, ..., T-1]")
print(">>> P = pos_embedding(positions)")
print(">>> X_pos = X + P")
print("Method 2: Fixed (Sinusoidal) Positional Encodings")
print("Here, positions are not learned.")
print("For position pos and dimension index i (with d_model = n_embedding_dims):")
print("P[pos, 2i]   = sin(pos / 10000^(2i / d_model))")
print("P[pos, 2i+1] = cos(pos / 10000^(2i / d_model))")
print("Then:\nX_pos = X + P")
print("Now,\n * Each position has a unique pattern\n * Relative positions can be inferred via linear operations\n * No extra learned parameters")
print("> Why? Because sine and cosine allow linear layers to infer relative offsets between positions, even though the encoding itself is absolute.")
print("Example (Method 1, learned positional embeddings): ")
# Method 1: learned positional embeddings
# One row per position index (not per character); torch.rand stands in for a trained table
pos_embeddings = torch.rand(seq_length, n_embedding_dims)
pos_idx = torch.arange(seq_length)  # position indices 0, 1, ..., seq_length - 1
print(f"Embeddings shape: {embeddings.shape}")
print(f"Positional embeddings shape: {pos_embeddings.shape}")
P = pos_embeddings[pos_idx]  # (seq_length, n_embedding_dims): positional information only
x_enc = x_enc.squeeze(dim=0)  # drop the batch dim (batch_size=1) so x_enc matches P's shape
x_pos = x_enc + P  # token content + position information
print("We define 'x_pos = x_enc + P', which contains the positional information as well")
print(f"First 2 rows of x_enc = {x_enc[:2]}")
print(f"First 2 rows of x_pos = {x_pos[:2]}")
print("> NOTE: Notice how `x_pos` is enriched with positional information as well?")
print("Note: Sinusoidal positional encodings are useful because shifting a position by Δ")
print("      corresponds to a fixed linear transformation of its encoding (a rotation within each sin/cos pair), allowing attention layers to infer relative positions using only linear operations.")
```

This is text:  ty to my elders.

KATHARINA:
Of all thy suitors, h
This is target:  y to my elders.

KATHARINA:
Of all thy suitors, he
----------------------------------------------------------------------------------------------------
Step-0: Now lets build the representation for each of the input char: Embeddings
> Before Embeddings: Input Sequence looks something like this:
          tensor([58, 63,  1, 58, 53,  1, 51, 63,  1, 43])
> After Embeddings: Now each Input Sequence is a 2-D vector representation in 2D Space:
          tensor([[0.6645, 0.2370],
        [0.7891, 0.9039],
        [0.0207, 0.0243],
        [0.6645, 0.2370],
        [0.4648, 0.0068],
        [0.0207, 0.0243],
        [0.5886, 0.1974],
        [0.7891, 0.9039],
        [0.0207, 0.0243],
        [0.5165, 0.7936]])
----------------------------------------------------------------------------------------------------
> Now we have: h1, h2, h3, h4, h5, ..... h_t
> At this point: 
    * no recurrence
    * no order logi