# Transformer Architecture: Encoder-Decoder & Positional Encoding

Attention removes recurrence bottlenecks, but the transformer architecture adds structure—positional encodings, stacked encoder and decoder blocks, and masking strategies. This notebook assembles those pieces so you can build both encoder-decoder and decoder-only models.

## Learning Objectives

- Implement sinusoidal and learned positional encodings.
- Build transformer encoder and decoder layers using attention blocks.
- Apply padding and causal masks correctly.
- Construct a small decoder-only model capable of greedy generation.

## Reusing Attention Blocks

The transformer relies on multi-head attention and residual feed-forward networks. We package the implementation here so the notebook is self-contained.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt

def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute softmax(q @ k^T / sqrt(d_k)) @ v.

    Args:
        q: Query tensor whose last dimension is d_k, e.g. (batch, heads, q_len, d_k).
        k: Key tensor with the same last dimension as ``q``.
        v: Value tensor; its second-to-last dimension must match that of ``k``.
        mask: Optional tensor broadcastable to the score shape (..., q_len, k_len).
            Positions where ``mask == 0`` (or ``False``) are excluded from attention.

    Returns:
        (output, weights): the attended values and the attention weights.
        A query row whose keys are *all* masked gets zero weights (and therefore a
        zero output) instead of the NaNs a softmax over an all ``-inf`` row produces.
    """
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5
    blocked = None
    if mask is not None:
        blocked = mask == 0
        scores = scores.masked_fill(blocked, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    if blocked is not None:
        # softmax over a row that is entirely -inf returns NaN, which would poison every
        # downstream value; such rows have nothing to attend to, so give them zero weight.
        fully_blocked = blocked.all(dim=-1, keepdim=True)
        weights = weights.masked_fill(fully_blocked, 0.0)
    output = weights @ v
    return output, weights

class MultiHeadAttention(nn.Module):
    """Multi-head attention with separate query, key, value and output projections.

    When ``context`` is omitted the module performs self-attention on ``x``;
    otherwise queries come from ``x`` and keys/values come from ``context``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Projections are created in the order q, k, v, out so seeded initialisation
        # draws the same random numbers.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, length, embed_dim) into (batch, num_heads, length, head_dim)."""
        per_head = projected.view(batch_size, -1, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when ``context`` is None).

        Returns the projected attention output and the per-head attention weights.
        ``mask`` is forwarded unchanged to ``scaled_dot_product_attention``.
        """
        source = context if context is not None else x
        batch_size = x.size(0)

        queries = self._split_heads(self.q_proj(x), batch_size)
        keys = self._split_heads(self.k_proj(source), batch_size)
        values = self._split_heads(self.v_proj(source), batch_size)

        heads, weights = scaled_dot_product_attention(queries, keys, values, mask)
        # Undo the head split: (batch, heads, length, head_dim) -> (batch, length, embed_dim).
        merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        return self.out_proj(merged), weights

class AttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, feed-forward.

    Every sublayer normalises its input with its own LayerNorm, runs the sublayer,
    applies dropout, and adds the result back onto the un-normalised residual stream.
    Cross-attention runs only when the block was built with ``cross_attention=True``
    AND a ``context`` tensor is passed to ``forward``.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, cross_attention=False, dropout=0.1):
        super().__init__()
        # Submodules are created in the original order to keep seeded init identical.
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.cross_attn = MultiHeadAttention(embed_dim, num_heads) if cross_attention else None
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        ff_layers = [
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
        ]
        self.ff = nn.Sequential(*ff_layers)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, context_mask=None):
        """Apply the block; ``mask`` gates self-attention, ``context_mask`` gates cross-attention."""
        attended, _ = self.self_attn(self.norm1(x), mask=mask)
        x = x + self.dropout(attended)

        wants_cross = self.cross_attn is not None and context is not None
        if wants_cross:
            crossed, _ = self.cross_attn(self.norm2(x), context=context, mask=context_mask)
            x = x + self.dropout(crossed)

        return x + self.dropout(self.ff(self.norm3(x)))


## Positional Encodings

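Self-attention is permutation-equivariant: without positional encodings, a transformer cannot distinguish between permutations of the same tokens. The sinusoidal encoding below gives each position a fixed pattern of sines and cosines at geometrically spaced frequencies. This pattern is added to the token embeddings.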

In [None]:
class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sine/cosine position signals (Vaswani et al., 2017) to the input.

    Feature 2i holds sin(pos / 10000^(2i/d)) and feature 2i+1 holds the matching
    cos term. The table is precomputed for ``max_len`` positions and registered as a
    non-trainable buffer.

    Args:
        embed_dim: Feature size ``d``. Odd sizes are supported: the last (even-indexed)
            feature gets a sine term without a cosine partner.
        max_len: Number of positions precomputed.
    """

    def __init__(self, embed_dim, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd embed_dim has one fewer odd-indexed column than even-indexed column,
        # so only the first embed_dim // 2 frequencies get a cosine partner.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Expects a batch-first input (batch, seq_len, embed_dim) with seq_len <= max_len.
        """
        return x + self.pe[:, : x.size(1)]

# Visualise the encoding: adding it to an all-zero input (50 positions x 32 features)
# leaves only the positional signal. Columns 0/1 are the sin/cos pair at the highest
# frequency and columns 2/3 the next pair, so these four oscillate fastest.
encoding = SinusoidalPositionalEncoding(32)
tokens = torch.zeros(1, 50, 32)
encoded = encoding(tokens)[0].detach()  # drop the batch dim -> (50, 32)
plt.plot(encoded[:, :4])
plt.title("First four sinusoidal dimensions")
plt.show()


### Learned Positional Embeddings

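Learned positional embeddings are a trainable table with one vector per position. They work well when sequence lengths stay within a known range, as in many language-modeling setups. Unlike the sinusoidal encoding, however, they cannot represent positions beyond `max_len`.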

In [None]:
class LearnedPositionalEncoding(nn.Module):
    """Add a trainable vector per position to the input.

    The table has shape (1, max_len, embed_dim), is initialised from N(0, 0.02^2),
    and is sliced to the input's sequence length, so sequences longer than
    ``max_len`` cannot be encoded.
    """

    def __init__(self, embed_dim, max_len=512):
        super().__init__()
        table = torch.empty(1, max_len, embed_dim)
        self.positional_embedding = nn.Parameter(table)
        nn.init.normal_(self.positional_embedding, mean=0.0, std=0.02)

    def forward(self, x):
        """Return ``x`` plus the learned vectors for its first ``x.shape[1]`` positions."""
        seq_len = x.shape[1]
        return x + self.positional_embedding[:, :seq_len]

# Sanity check: the learned table is sliced to the input length (10 of its 512 rows),
# so the output keeps the input shape (1, 10, 32).
learned_encoding = LearnedPositionalEncoding(32)
print(learned_encoding(torch.zeros(1, 10, 32)).shape)


## Encoder and Decoder Layers

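Encoder layers contain a self-attention sublayer and a feed-forward sublayer. Decoder layers restrict their self-attention with a causal mask and add a cross-attention sublayer that attends over the encoder outputs (the *memory*). In this implementation, the masks are supplied by the caller.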

In [None]:
class EncoderLayer(nn.Module):
    """One encoder layer: a pre-norm block with self-attention and a feed-forward sublayer."""

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.block = AttentionBlock(
            embed_dim,
            num_heads,
            ff_dim,
            cross_attention=False,
            dropout=dropout,
        )

    def forward(self, src, src_mask=None):
        """Run self-attention + feed-forward over ``src``; ``src_mask`` gates self-attention."""
        encoded = self.block(src, context=None, mask=src_mask)
        return encoded

class DecoderLayer(nn.Module):
    """One encoder-decoder decoder layer: self-attention, cross-attention over ``memory``, feed-forward.

    Masks come from the caller; this layer does not build a causal mask itself. If
    ``memory`` is None the underlying block skips its cross-attention sublayer.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.block = AttentionBlock(
            embed_dim,
            num_heads,
            ff_dim,
            cross_attention=True,
            dropout=dropout,
        )

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """Decode ``tgt`` against ``memory``; ``tgt_mask`` gates self-attention, ``memory_mask`` cross-attention."""
        decoded = self.block(
            tgt,
            context=memory,
            mask=tgt_mask,
            context_mask=memory_mask,
        )
        return decoded

# Shape check: encode 2 source sequences of length 7, then run 2 target sequences of
# length 6 through a decoder layer that cross-attends to that memory. No masks are
# passed, so this only verifies shapes; the output keeps the target shape (2, 6, 32).
encoder_layer = EncoderLayer(32, 4, 64)
decoder_layer = DecoderLayer(32, 4, 64)
memory = encoder_layer(torch.randn(2, 7, 32))
out = decoder_layer(torch.randn(2, 6, 32), memory)
print(out.shape)


## Transformer Stack

We now combine positional encodings, encoder/decoder layers, and masking into a full encoder-decoder transformer.

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder transformer with shared token embeddings and sinusoidal positions.

    Source and target tokens share one ``nn.Embedding``. Both stacks are built from
    pre-norm layers, which add each sublayer's output onto an un-normalised residual
    stream. Each stack therefore ends with its own LayerNorm, as in ``nn.Transformer``:
      - ``encoder_norm`` normalises the memory before decoder layers cross-attend to it;
      - ``norm`` normalises the decoder output before the vocabulary projection.

    ``encode`` and ``decode`` are exposed separately so inference can encode the source
    once and reuse the memory across decoding steps.
    """

    def __init__(self, vocab_size, embed_dim=32, num_heads=4, ff_dim=64, num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = SinusoidalPositionalEncoding(embed_dim)
        self.encoders = nn.ModuleList([
            EncoderLayer(embed_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)
        ])
        self.decoders = nn.ModuleList([
            DecoderLayer(embed_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)
        ])
        self.encoder_norm = nn.LayerNorm(embed_dim)
        self.norm = nn.LayerNorm(embed_dim)
        self.output = nn.Linear(embed_dim, vocab_size)

    def make_causal_mask(self, size, device):
        """Return a (1, 1, size, size) bool mask that is True where query i may attend key j <= i."""
        return torch.tril(torch.ones(1, 1, size, size, device=device, dtype=torch.bool))

    def encode(self, src, src_mask=None):
        """Embed ``src`` (batch, src_len) and run the encoder stack; returns the normalised memory."""
        x = self.positional_encoding(self.embedding(src))
        for layer in self.encoders:
            x = layer(x, src_mask)
        return self.encoder_norm(x)

    def decode(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """Run ``tgt`` (batch, tgt_len) through the decoder stack against ``memory``.

        When ``tgt_mask`` is None a causal mask is built on ``tgt``'s device, so target
        position i only attends to target positions <= i. Returns vocabulary logits.
        """
        if tgt_mask is None:
            tgt_mask = self.make_causal_mask(tgt.size(1), tgt.device)
        x = self.positional_encoding(self.embedding(tgt))
        for layer in self.decoders:
            x = layer(x, memory, tgt_mask, memory_mask)
        return self.output(self.norm(x))

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src`` and decode ``tgt`` against it.

        Args:
            src: Source token ids, (batch, src_len).
            tgt: Target token ids, (batch, tgt_len).
            src_mask: Optional mask over source keys (entries equal to 0 are not
                attended to). It gates both encoder self-attention and decoder
                cross-attention, so it is presumably a key-padding mask such as
                (batch, 1, 1, src_len); confirm against the caller.
            tgt_mask: Optional decoder self-attention mask; defaults to a causal mask.

        Returns:
            Logits of shape (batch, tgt_len, vocab_size).
        """
        memory = self.encode(src, src_mask)
        return self.decode(tgt, memory, tgt_mask=tgt_mask, memory_mask=src_mask)

# Random token ids: 2 source sequences of length 7 and 2 target sequences of length 6.
# The printed logits shape is (batch, tgt_len, vocab_size) = (2, 6, 100).
model = Transformer(vocab_size=100)
src = torch.randint(0, 100, (2, 7))
tgt = torch.randint(0, 100, (2, 6))
print(model(src, tgt).shape)


## Mini Task – Label Smoothing

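Label smoothing improves generalization by preventing the model from becoming overconfident. Implement a label-smoothed cross-entropy loss. The target distribution places probability $1 - \epsilon$ on the true class and spreads $\epsilon$ uniformly over the remaining $K - 1$ classes. The loss is the cross-entropy between this distribution and the model's predicted log-probabilities.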

In [None]:
def label_smoothed_nll_loss(logits, targets, smoothing=0.1):
    """Exercise stub: cross-entropy against a label-smoothed target distribution.

    Steps to implement:
      1. Convert ``logits`` to log-probabilities with a log-softmax over the last dim.
      2. Build the smoothed targets: ``1 - smoothing`` on the true class and
         ``smoothing`` spread evenly over the remaining classes.
      3. Return the mean over positions of the cross-entropy between the smoothed
         targets and the log-probabilities.
    """
    # TODO: replace this with your implementation.
    raise NotImplementedError


In [None]:
def label_smoothed_nll_loss(logits, targets, smoothing=0.1, ignore_index=None):
    """Cross-entropy against a label-smoothed target distribution.

    The target distribution puts ``1 - smoothing`` on the true class and spreads
    ``smoothing`` evenly over the other ``num_classes - 1`` classes.

    Args:
        logits: Unnormalised scores with classes on the last dimension, e.g.
            (batch, seq_len, num_classes) or (N, num_classes).
        targets: Integer class indices shaped like ``logits`` without its last dim.
        smoothing: Probability mass moved off the true class.
        ignore_index: Optional target value (e.g. a padding id) whose positions are
            excluded from the average. ``None`` (default) keeps every position.

    Returns:
        Scalar mean loss over the kept positions (0 if every position is ignored).
    """
    num_classes = logits.size(-1)
    log_probs = torch.log_softmax(logits, dim=-1)

    keep = None
    if ignore_index is not None:
        keep = targets != ignore_index
        # Ignored targets may be out of range (e.g. -100); point them at class 0 so
        # scatter_ stays valid. Their loss is discarded below.
        targets = targets.masked_fill(~keep, 0)

    # With a single class there is nothing to smooth towards (and no division by zero).
    off_value = smoothing / (num_classes - 1) if num_classes > 1 else 0.0
    with torch.no_grad():
        true_dist = torch.full_like(log_probs, off_value)
        # Scatter along the last dim so any number of leading dims is supported.
        true_dist.scatter_(-1, targets.unsqueeze(-1), 1.0 - smoothing)

    per_position = -(true_dist * log_probs).sum(dim=-1)
    if keep is None:
        return per_position.mean()
    return per_position.masked_fill(~keep, 0.0).sum() / keep.sum().clamp(min=1)

# Random logits over 50 classes for a (batch=2, seq=6) grid, with random integer targets;
# prints the scalar smoothed loss.
logits = torch.randn(2, 6, 50)
targets = torch.randint(0, 50, (2, 6))
print(label_smoothed_nll_loss(logits, targets))


## Comprehensive Exercise – Mini Decoder-Only Transformer

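Implement a decoder-only model (GPT-style) with token embeddings, learned positional encodings, stacked decoder blocks, weight tying between embeddings and output projection, and greedy generation. Because there is no encoder, each block uses only causally masked self-attention—no cross-attention.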

In [None]:
class MiniGPT(nn.Module):
    """Exercise stub for a decoder-only (GPT-style) language model.

    ``forward`` and ``generate`` raise NotImplementedError until they are filled in.
    """

    def __init__(self, vocab_size, embed_dim=64, num_heads=4, ff_dim=128, num_layers=3, max_len=64, dropout=0.1):
        super().__init__()
        # TODO: create token embeddings, positional encodings and decoder layers,
        # then tie the output projection's weight to the token embedding.

    def forward(self, tokens):
        """TODO: build a causal mask, run the decoder layers, and return logits."""
        raise NotImplementedError

    def generate(self, prompt, max_new_tokens=20):
        """TODO: greedily append ``max_new_tokens`` argmax tokens to ``prompt``."""
        raise NotImplementedError


In [None]:
class MiniGPT(nn.Module):
    """Decoder-only (GPT-style) language model with tied input/output embeddings.

    Each layer is a pre-norm ``AttentionBlock`` with *only* causally masked
    self-attention. A decoder-only model has no encoder memory to cross-attend to.
    Reusing ``DecoderLayer`` with the layer input as its own, unmasked memory would
    let every position cross-attend to future tokens and leak the training targets.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        embed_dim, num_heads, ff_dim, num_layers, dropout: Model hyper-parameters.
        max_len: Length of the learned positional table, i.e. the longest sequence
            ``forward`` can embed. ``generate`` keeps only the most recent ``max_len``
            tokens as context.
    """

    def __init__(self, vocab_size, embed_dim=64, num_heads=4, ff_dim=128, num_layers=3, max_len=64, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.position_embed = nn.Parameter(torch.zeros(1, max_len, embed_dim))
        nn.init.normal_(self.position_embed, mean=0.0, std=0.02)
        self.decoders = nn.ModuleList([
            AttentionBlock(embed_dim, num_heads, ff_dim, cross_attention=False, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.output = nn.Linear(embed_dim, vocab_size, bias=False)
        # Weight tying: both weights are (vocab_size, embed_dim), so they can share storage.
        self.output.weight = self.token_embed.weight
        self.dropout = nn.Dropout(dropout)

    def forward(self, tokens):
        """Return next-token logits (batch, seq_len, vocab_size) for ``tokens`` (batch, seq_len)."""
        _, seq_len = tokens.size()
        x = self.token_embed(tokens) + self.position_embed[:, :seq_len]
        x = self.dropout(x)
        # Lower-triangular mask: position i may attend only to positions <= i.
        causal_mask = torch.tril(
            torch.ones(1, 1, seq_len, seq_len, device=tokens.device, dtype=torch.bool)
        )
        for layer in self.decoders:
            x = layer(x, mask=causal_mask)
        return self.output(self.norm(x))

    @torch.no_grad()
    def generate(self, prompt, max_new_tokens=20):
        """Greedily append ``max_new_tokens`` tokens to ``prompt`` (batch, prompt_len).

        Runs in eval mode (dropout off) and restores the previous train/eval mode when
        done. Each step feeds only the last ``max_len`` tokens, so the output may grow
        past the positional table's length.
        """
        was_training = self.training
        self.eval()
        try:
            generated = prompt.clone()
            for _ in range(max_new_tokens):
                context = generated[:, -self.max_len:]
                logits = self.forward(context)
                next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                generated = torch.cat([generated, next_token], dim=1)
        finally:
            self.train(was_training)
        return generated

# Smoke test: one forward pass, then a short greedy continuation of the first four
# tokens of each sample, exercising generate().
model = MiniGPT(vocab_size=200)
sample = torch.randint(0, 200, (2, 12))
print(model(sample).shape)  # (batch, seq_len, vocab_size) = (2, 12, 200)
continuation = model.generate(sample[:, :4], max_new_tokens=8)
print(continuation.shape)  # (batch, 4 + 8) = (2, 12)


## Further Reading

- Vaswani et al. (2017) “Attention Is All You Need”
- Annotated Transformer (Harvard NLP)
- Radford et al. (2019) “Language Models are Unsupervised Multitask Learners” (GPT-2) for decoder-only training heuristics
- PyTorch `nn.Transformer` tutorial for reference implementations