# Full Transformer Decoder Stack
Implements a full stack of **8 Transformer decoder blocks** using PyTorch.

Each decoder block contains:
- Masked Multi-Head Self-Attention (prevents peeking ahead)
- Encoder–Decoder Attention (cross-attends to encoder output)
- Residual connections + LayerNorm
- Feedforward Network (512 → 2048 → 512)

The decoder input (e.g. "I understand this") is treated as a partially generated target sentence and is passed to the decoder together with a simulated encoder output.

Output: Final token representations of shape (3, 512) — enriched through 8 rounds of self-attention, cross-attention, and feedforward computation.

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# ------------------------
# Positional Encoding Class
# ------------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to token embeddings.

    Even feature columns hold ``sin(pos * rate)`` and odd columns hold
    ``cos(pos * rate)``, where ``rate = 1 / 10000 ** (2 * (i // 2) / d_model)``
    for feature index ``i``.

    Args:
        d_model: Embedding dimension of the inputs this module is applied to.
        max_len: Longest supported sequence length; encodings are precomputed
            for positions ``0 .. max_len - 1``.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)                # (max_len, 1)
        i = torch.arange(d_model).unsqueeze(0)                       # (1, d_model)
        # Each sin/cos column pair (2k, 2k+1) shares the same frequency.
        angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / d_model)
        angle_rads = position * angle_rates                          # (max_len, d_model)

        PE = torch.zeros_like(angle_rads)
        PE[:, 0::2] = torch.sin(angle_rads[:, 0::2])
        PE[:, 1::2] = torch.cos(angle_rads[:, 1::2])

        # Buffer: follows .to(device) and is stored in state_dict, but is not trained.
        self.register_buffer('PE', PE)

    def forward(self, x):
        """Return ``x`` with the encodings for its positions added.

        Args:
            x: Tensor whose last two dimensions are ``(seq_len, d_model)``:
                either unbatched ``(seq_len, d_model)`` or batch-first
                ``(..., seq_len, d_model)``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        # The sequence axis is second-to-last, so both unbatched and batch-first
        # inputs work; the (seq_len, d_model) table broadcasts over leading dims.
        seq_len = x.size(-2)
        max_len = self.PE.size(0)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
        return x + self.PE[:seq_len]


# ------------------------
# Decoder Block
# ------------------------
class DecoderBlock(nn.Module):
    """One post-norm Transformer decoder layer.

    Three sub-layers, each followed by a residual connection and LayerNorm:
      1. masked multi-head self-attention over the decoder input,
      2. multi-head encoder-decoder (cross) attention over ``encoder_output``,
      3. position-wise feedforward network (d_model -> ffn_hidden -> d_model).

    Args:
        d_model: Model dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads, each of width ``d_model // num_heads``.
        ffn_hidden: Hidden width of the feedforward network.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model=512, num_heads=8, ffn_hidden=2048):
        super().__init__()
        # The concatenated heads must be exactly d_model wide to feed the output projections.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # 1. Masked Self-Attention: one Q/K/V projection per head
        self.W_Q_self = nn.ModuleList([nn.Linear(d_model, self.d_k) for _ in range(num_heads)])
        self.W_K_self = nn.ModuleList([nn.Linear(d_model, self.d_k) for _ in range(num_heads)])
        self.W_V_self = nn.ModuleList([nn.Linear(d_model, self.d_k) for _ in range(num_heads)])

        # 2. Encoder–Decoder Cross Attention: one Q/K/V projection per head
        self.W_Q_encdec = nn.ModuleList([nn.Linear(d_model, self.d_k) for _ in range(num_heads)])
        self.W_K_encdec = nn.ModuleList([nn.Linear(d_model, self.d_k) for _ in range(num_heads)])
        self.W_V_encdec = nn.ModuleList([nn.Linear(d_model, self.d_k) for _ in range(num_heads)])

        # Output projection of the self-attention sub-layer
        self.W_O = nn.Linear(d_model, d_model)

        # 3. Feedforward Network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ffn_hidden),
            nn.ReLU(),
            nn.Linear(ffn_hidden, d_model)
        )

        # 4. LayerNorms (one per sub-layer)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # Output projection of the cross-attention sub-layer. It is separate from
        # W_O so the two attention sub-layers do not share output weights. It is
        # created last so the random initialization of the layers above is unchanged.
        self.W_O_encdec = nn.Linear(d_model, d_model)

    def _multi_head_attention(self, W_Q, W_K, W_V, W_O, query_src, kv_src, mask=None):
        """Run scaled dot-product attention once per head, then concat and project.

        Args:
            W_Q, W_K, W_V: Per-head ModuleLists of Linear(d_model, d_k) projections.
            W_O: Output projection Linear(d_model, d_model).
            query_src: Tensor the queries are computed from.
            kv_src: Tensor the keys and values are computed from.
            mask: Optional mask broadcastable to the score matrix; entries equal
                to 0/False are blocked.

        Returns:
            Tensor shaped like ``query_src``.
        """
        heads = []
        for q_proj, k_proj, v_proj in zip(W_Q, W_K, W_V):
            Q = q_proj(query_src)
            K = k_proj(kv_src)
            V = v_proj(kv_src)

            scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
            if mask is not None:
                # Blocked positions get -inf, so softmax gives them zero weight.
                scores = scores.masked_fill(mask == 0, float('-inf'))

            attn = F.softmax(scores, dim=-1)
            heads.append(attn @ V)

        return W_O(torch.cat(heads, dim=-1))

    def forward(self, x, encoder_output, look_ahead_mask=None):
        """Apply the three decoder sub-layers to ``x``.

        Args:
            x: Decoder input whose last two dims are (tgt_len, d_model).
            encoder_output: Encoder states whose last two dims are (src_len, d_model).
            look_ahead_mask: Optional mask for self-attention, broadcastable to
                (tgt_len, tgt_len); entries equal to 0/False are blocked.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # 1. Masked self-attention + Add & Norm
        self_attn_out = self._multi_head_attention(
            self.W_Q_self, self.W_K_self, self.W_V_self, self.W_O,
            x, x, look_ahead_mask,
        )
        x = self.norm1(x + self_attn_out)

        # 2. Encoder–decoder attention + Add & Norm (queries from decoder, K/V from encoder)
        encdec_out = self._multi_head_attention(
            self.W_Q_encdec, self.W_K_encdec, self.W_V_encdec, self.W_O_encdec,
            x, encoder_output,
        )
        x = self.norm2(x + encdec_out)

        # 3. Feedforward + Add & Norm
        ffn_out = self.ffn(x)
        return self.norm3(x + ffn_out)


# ------------------------
# Decoder Stack
# ------------------------
class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` DecoderBlocks applied one after another.

    Every block has its own parameters and receives the same
    ``encoder_output`` and ``look_ahead_mask``.
    """

    def __init__(self, num_layers=8, d_model=512, num_heads=8, ffn_hidden=2048):
        super().__init__()
        blocks = [DecoderBlock(d_model, num_heads, ffn_hidden) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, encoder_output, look_ahead_mask=None):
        """Feed ``x`` through each block in order; return the final block's output."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, look_ahead_mask)
        return hidden


In [4]:
# ------------------------
# Example Usage
# ------------------------

import pandas as pd

# Toy target-side prefix: what the decoder has "generated so far"
decoder_tokens = ["I", "understand", "this"]
seq_len, d_model = len(decoder_tokens), 512

# Random vectors stand in for learned token embeddings; sinusoidal positions are added on top
pe = PositionalEncoding(d_model)
decoder_input = pe(torch.randn(seq_len, d_model))

# Stand-in for the encoder's output over the source sentence (same length here)
encoder_output = torch.randn(seq_len, d_model)

# Causal mask: row t is True only for columns 0..t, so no position sees the future
look_ahead_mask = torch.ones(seq_len, seq_len).tril().bool()

# Build the 8-layer stack and run it
decoder = TransformerDecoder(num_layers=8)
decoder_output = decoder(decoder_input, encoder_output, look_ahead_mask)

# Show the per-token output vectors as a table
df = pd.DataFrame(
    decoder_output.detach().numpy(),
    index=decoder_tokens,
    columns=[f"dim_{k}" for k in range(1, d_model + 1)],
)
df.head()

Unnamed: 0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,dim_10,...,dim_503,dim_504,dim_505,dim_506,dim_507,dim_508,dim_509,dim_510,dim_511,dim_512
I,-0.700927,0.459993,0.486396,-0.760917,-0.544988,0.681811,-0.704418,-0.278896,1.915847,-1.126197,...,0.038788,2.104369,-1.647549,0.110661,-0.030985,0.55518,-1.075598,0.119225,-0.837795,0.963089
understand,0.454692,-0.059841,0.12134,0.276973,-1.402326,1.036451,-1.022408,0.738934,0.494328,0.364696,...,-0.597994,1.023504,-2.176769,0.979234,-1.116079,1.496523,-2.32908,-1.12726,-0.072432,0.590817
this,-0.536496,0.276218,-0.451541,-1.167974,-0.995088,0.064617,-1.45694,0.229347,1.807538,-0.631978,...,-1.278745,1.378657,-1.698918,0.066068,-0.491405,0.753732,-1.628311,-1.579115,0.41148,0.160179


### Note on the Real Use Case
In a real translation task, here's what will happen:
- encoder_output will come from passing the source sentence through the encoder.
- decoder_input will be the ground-truth target sentence shifted right by one position, so it starts with a special beginning-of-sentence token. Feeding the true previous tokens, rather than the model's own predictions, is called teacher forcing.
- The model is trained to predict the next word at each position.
- Look-ahead masking ensures that the model doesn't cheat by peeking into the future.

Right now we just use random vectors to test the architecture. Later, we’ll connect it to real tokens, embeddings, a vocabulary, and a training loop.

# Transformer Decoder Block — Shapes, Concepts & Flow

### Decoder Input Embedding (with Positional Encoding)

| Step | Name                | Shape                 | Description                                  |
|------|---------------------|------------------------|----------------------------------------------|
| 1    | Decoder Input `x`   | `(seq_len, d_model)`   | Input embeddings (e.g. previous target tokens) |
| 2    | Positional Encoding | `(seq_len, d_model)`   | Adds position information via sin/cos functions of different frequencies |
| 3    | Input + PE          | `(seq_len, d_model)`   | Final input to decoder stack                 |

---

### Masked Multi-Head Self-Attention (per decoder layer)

| Step | Name               | Shape                     | Description                                          |
|------|--------------------|----------------------------|------------------------------------------------------|
| 4    | Linear Q/K/V       | `(seq_len, d_k)`           | Projects input into queries, keys, and values        |
| 5    | Attention Scores   | `(seq_len, seq_len)`       | Dot product of Q and Kᵀ, scaled                     |
| 6    | Look-Ahead Mask    | `(seq_len, seq_len)`       | Masks out future positions during training          |
| 7    | Weighted V         | `(seq_len, d_k)`           | Attention-weighted sum of values                    |
| 8    | Heads              | `num_heads × (seq_len, d_k)` | Separate attention heads                         |
| 9    | Concatenation      | `(seq_len, d_model)`       | Merge all heads                                     |
| 10   | Final Linear       | `(seq_len, d_model)`       | Project back to full model dimension                |
| 11   | Add & Norm1        | `(seq_len, d_model)`       | Residual + LayerNorm                                |

---

### Encoder–Decoder Cross Attention

| Step | Name                 | Shape                     | Description                                             |
|------|----------------------|----------------------------|---------------------------------------------------------|
| 12   | Linear Q (from Dec)  | `(seq_len, d_k)`           | Queries from decoder                                    |
| 13   | Linear K/V (from Enc)| `(src_len, d_k)`           | Keys and values from encoder output (`src_len` = source length) |
| 14   | Cross-Attn Scores    | `(seq_len, src_len)`       | Dot product between decoder queries and encoder keys    |
| 15   | Weighted V (Encoder) | `(seq_len, d_k)`           | Aggregated info from encoder                           |
| 16   | Merge Heads          | `(seq_len, d_model)`       | Combine all heads                                       |
| 17   | Final Linear         | `(seq_len, d_model)`       | Project to model dim                                    |
| 18   | Add & Norm2          | `(seq_len, d_model)`       | Residual + LayerNorm                                    |

---

### Feedforward Network

| Step | Name          | Shape                   | Description                                  |
|------|---------------|--------------------------|----------------------------------------------|
| 19   | Linear 1      | `(seq_len, ffn_hidden)`  | Expands representation (e.g., 512 → 2048)    |
| 20   | ReLU          | `(seq_len, ffn_hidden)`  | Non-linearity                                |
| 21   | Linear 2      | `(seq_len, d_model)`     | Project back to model dimension              |
| 22   | Add & Norm3   | `(seq_len, d_model)`     | Final residual + LayerNorm                   |

---

### Key Concepts Recap

- **Masked Self-Attention:** Allows each position to only attend to past (and current) tokens — essential for autoregressive decoding.
- **Cross-Attention:** Each target token attends to all source (encoder) tokens — brings in source-side context.
- **Multi-Head Attention:** Enables learning from multiple representation subspaces.
- **LayerNorm & Residuals:** Help stabilize training and preserve gradient flow through the deep stack.
- **FFN:** Adds non-linearity and dimensional transformation for token-wise refinement.