# Full Transformer Encoder Stack

Implements a full stack of **8 Transformer encoder blocks** using PyTorch.

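Each block consists of:
- Multi-head self-attention (8 heads, d_k = 64)
- A residual connection + LayerNorm after each of the two sub-layers (post-LN: `LayerNorm(x + sublayer(x))`)
- A position-wise feedforward network (512 → 2048 → 512, with ReLU)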

The sentence `"I understand this"` (3 tokens) is represented by random 512-dimensional embeddings. Sinusoidal positional encodings are added to these embeddings, and the result is passed through the stack.

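Output: final token representations of shape `(3, 512)`. Each token's vector has gone through 8 rounds of context-aware self-attention and feedforward refinement. Because the embeddings and all weights are randomly initialized (untrained), the specific values are illustrative only; what matters here is the data flow and the shapes.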


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# ------------------------
# Positional Encoding Function (Same as in previous notebooks)
# ------------------------
def add_positional_encoding(X):
    """Return ``X`` plus sinusoidal positional encodings.

    For position ``pos`` and pair index ``k`` the encoding is::

        PE[pos, 2k]     = sin(pos / 10000^(2k / d_model))
        PE[pos, 2k + 1] = cos(pos / 10000^(2k / d_model))

    Args:
        X: Tensor of shape ``(..., seq_len, d_model)``. Any leading (e.g.
            batch) dimensions are broadcast over, so every sequence in the
            batch receives the same ``(seq_len, d_model)`` table.

    Returns:
        A new tensor ``X + PE`` (``X`` itself is not modified). The table is
        built on ``X``'s device, so CPU and GPU inputs both work; standard
        PyTorch type promotion applies to the addition.

    Raises:
        ValueError: If ``X`` has fewer than two dimensions.
    """
    if X.dim() < 2:
        raise ValueError(
            f"expected X of shape (..., seq_len, d_model), got {tuple(X.shape)}"
        )
    seq_len, d_model = X.shape[-2], X.shape[-1]

    # Build the table on X's device; a CPU-only table would make the final
    # addition fail for GPU inputs.
    position = torch.arange(seq_len, device=X.device).unsqueeze(1)  # (seq_len, 1)
    i = torch.arange(d_model, device=X.device).unsqueeze(0)         # (1, d_model)
    # Columns 2k and 2k + 1 share the same frequency 1 / 10000^(2k / d_model).
    angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / d_model)
    angle_rads = position * angle_rates                              # (seq_len, d_model)

    PE = torch.zeros_like(angle_rads)
    PE[:, 0::2] = torch.sin(angle_rads[:, 0::2])  # even columns: sine
    PE[:, 1::2] = torch.cos(angle_rads[:, 1::2])  # odd columns: cosine

    # PE broadcasts over any leading dimensions of X.
    return X + PE

# ------------------------
# Encoder Block (Self-Attn + FFN + LayerNorm)
# ------------------------
class EncoderBlock(nn.Module):
    """One post-LayerNorm Transformer encoder layer.

    ``forward`` computes::

        x = LayerNorm(x + MultiHeadSelfAttention(x))
        x = LayerNorm(x + FFN(x))

    where ``FFN`` is ``Linear(d_model, ffn_hidden) -> ReLU -> Linear(ffn_hidden, d_model)``.
    Each head has its own Q/K/V projections (``d_model -> d_k``). The heads'
    outputs are concatenated and mixed by ``W_O``.

    Args:
        d_model: Width of the token representations.
        num_heads: Number of attention heads; must be positive and divide ``d_model``.
        ffn_hidden: Hidden width of the feedforward sub-layer.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model=512, num_heads=8, ffn_hidden=2048):
        super().__init__()
        # Validate up front. If num_heads does not divide d_model, the
        # concatenated heads are num_heads * (d_model // num_heads) wide. That
        # no longer matches W_O's input, and forward() would fail with an
        # opaque shape error.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Separate Linear projections per head (same layout as before, so
        # parameter names and initialization order are unchanged).
        self.W_Q = nn.ModuleList([nn.Linear(d_model, self.d_k) for _ in range(num_heads)])
        self.W_K = nn.ModuleList([nn.Linear(d_model, self.d_k) for _ in range(num_heads)])
        self.W_V = nn.ModuleList([nn.Linear(d_model, self.d_k) for _ in range(num_heads)])
        self.W_O = nn.Linear(d_model, d_model)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, ffn_hidden),
            nn.ReLU(),
            nn.Linear(ffn_hidden, d_model),
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the self-attention and feedforward sub-layers to ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, d_model)``. The Linear layers act
                on the last dimension, and attention runs across dimension -2.

        Returns:
            Tensor with the same shape as ``x``.
        """
        scale = math.sqrt(self.d_k)
        heads = []
        for W_q, W_k, W_v in zip(self.W_Q, self.W_K, self.W_V):
            Q, K, V = W_q(x), W_k(x), W_v(x)
            # Scaled dot-product attention; no mask, so every token attends
            # to every token.
            scores = Q @ K.transpose(-2, -1) / scale
            attn = F.softmax(scores, dim=-1)
            heads.append(attn @ V)

        attn_out = self.W_O(torch.cat(heads, dim=-1))

        # Add & Norm (post-LN)
        x = self.norm1(x + attn_out)

        # Position-wise feedforward, then Add & Norm
        return self.norm2(x + self.ffn(x))

# ------------------------
# Full Transformer Encoder Stack (8 layers)
# ------------------------
class TransformerEncoder(nn.Module):
    """A stack of ``EncoderBlock`` layers applied one after another.

    Args:
        num_layers: Number of encoder blocks in the stack.
        d_model: Width of the token representations.
        num_heads: Attention heads per block.
        ffn_hidden: Hidden width of each block's feedforward sub-layer.
    """

    def __init__(self, num_layers=8, d_model=512, num_heads=8, ffn_hidden=2048):
        super().__init__()
        # Every block is a fresh instance, so each layer has its own parameters.
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(EncoderBlock(d_model, num_heads, ffn_hidden))

    def forward(self, x):
        """Feed ``x`` through every block in order and return the last output."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

# ------------------------
# Example Run
# ------------------------
tokens = ["I", "understand", "this"]
seq_len = len(tokens)
d_model = 512

# Fix the seed so the random embeddings and the (untrained) encoder weights,
# and therefore the table below, are identical on every run.
torch.manual_seed(0)

# Simulate input embeddings (there is no trained embedding table here)
X = torch.randn(seq_len, d_model)
X = add_positional_encoding(X)

# Run through the 8-layer encoder. Pass d_model explicitly so that changing it
# above keeps the input width and the encoder width in sync.
encoder = TransformerEncoder(num_layers=8, d_model=d_model)
# Inference only: no_grad avoids building an autograd graph we never use.
with torch.no_grad():
    encoder_output = encoder(X)  # (seq_len, d_model) == (3, 512)

# Visualize result
import pandas as pd
df = pd.DataFrame(encoder_output.detach().numpy(), index=tokens, columns=[f"dim_{i+1}" for i in range(d_model)])
df.head()
