# Transformer Encoder

This section builds a single Transformer encoder block (`EncoderBlock`) and the full encoder stack (`Encoder`: embedding followed by N encoder blocks).


In [None]:
import torch
import torch.nn as nn
from typing import Optional, List, Tuple

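# Note: MultiHeadAttention, PositionwiseFeedForward and TransformerEmbedding are used
# below but not imported here; they must be defined (e.g. in an earlier cell or module)
# before the classes in this notebook are instantiated.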


In [None]:
class EncoderBlock(nn.Module):
    """
    Single Transformer Encoder Block.

    Implements the standard encoder layer with:
    - Multi-head self-attention with residual connection
    - Feed-forward network with residual connection
    - Layer normalization (pre-norm or post-norm)

    Args:
        d_model: Dimension of the model
        n_heads: Number of attention heads
        d_ff: Dimension of the feed-forward hidden layer
        dropout: Dropout probability. It is passed to the attention and
            feed-forward sub-modules, and is also applied to each sub-layer's
            output before the residual addition.
        activation: Activation function for the FFN ("gelu" or "relu"),
            forwarded to PositionwiseFeedForward
        pre_norm: If True, apply LayerNorm *before* each sub-layer (pre-norm).
            If False, apply it *after* each residual addition (post-norm, as in
            the original Transformer).
        layer_norm_eps: Epsilon used by both LayerNorm layers

    Example:
        >>> block = EncoderBlock(d_model=512, n_heads=8, d_ff=2048)
        >>> x = torch.randn(2, 10, 512)  # (batch, seq, d_model)
        >>> output, attn = block(x)
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = "gelu",
        pre_norm: bool = True,
        layer_norm_eps: float = 1e-6
    ):
        super().__init__()

        self.pre_norm = pre_norm

        # NOTE: submodule construction order is kept stable on purpose: it fixes
        # both the state_dict key order and the RNG draw order for weight init.

        # Multi-head self-attention
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)

        # Feed-forward network
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout, activation)

        # Layer normalization (norm1 wraps attention, norm2 wraps the FFN)
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)

        # Dropout on each sub-layer output before the residual addition
        # (stateless, so one instance can serve both sub-layers)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_attention: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass through the encoder block.

        Args:
            x: Input tensor of shape (batch, seq, d_model)
            mask: Optional attention mask, forwarded unchanged to the
                self-attention module
            return_attention: Whether to return attention weights

        Returns:
            Tuple of:
                - Output tensor with the same shape as ``x``
                - Attention weights if return_attention=True, else None
        """
        if self.pre_norm:
            # Pre-Layer Normalization: LN -> Attention -> Residual
            attn_output, attn_weights = self._self_attention_block(
                self.norm1(x), mask, return_attention
            )
            x = x + attn_output

            # LN -> FFN -> Residual
            x = x + self._ff_block(self.norm2(x))
        else:
            # Post-Layer Normalization (original Transformer):
            # Attention -> Residual -> LN
            attn_output, attn_weights = self._self_attention_block(
                x, mask, return_attention
            )
            x = self.norm1(x + attn_output)

            # FFN -> Residual -> LN
            x = self.norm2(x + self._ff_block(x))

        # Enforce the documented contract: weights are only returned when
        # requested, even if the attention module returns them unconditionally.
        return x, (attn_weights if return_attention else None)

    def _self_attention_block(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        return_attention: bool
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run self-attention (query = key = value = x) and apply dropout to its output."""
        attn_output, attn_weights = self.self_attention(
            x, x, x, mask=mask, return_attention=return_attention
        )
        return self.dropout(attn_output), attn_weights

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feed-forward network and apply dropout to its output."""
        return self.dropout(self.feed_forward(x))


In [None]:
class Encoder(nn.Module):
    """
    Full Transformer encoder: an embedding layer followed by a stack of
    ``n_layers`` identical ``EncoderBlock`` modules.

    For pre-norm stacks, a final LayerNorm is applied to the last block's
    output. Post-norm blocks already end with a LayerNorm, so no extra
    normalization is added in that case.

    Args:
        vocab_size: Number of tokens in the vocabulary
        d_model: Model (hidden) dimension
        n_heads: Attention heads per block
        n_layers: How many encoder blocks to stack
        d_ff: Hidden size of each block's feed-forward network
        max_seq_len: Maximum sequence length handed to the embedding layer
        dropout: Dropout probability used by the embedding and every block
        activation: FFN activation name forwarded to every block
        pre_norm: Use pre-layer normalization in every block (and add a final
            LayerNorm after the stack)
        layer_norm_eps: Epsilon for every LayerNorm in the encoder

    Example:
        >>> encoder = Encoder(vocab_size=10000, d_model=512, n_heads=8, n_layers=6, d_ff=2048)
        >>> tokens = torch.randint(0, 10000, (2, 50))  # (batch, seq)
        >>> output, attentions = encoder(tokens, return_attention=True)
        >>> print(output.shape)  # torch.Size([2, 50, 512])
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        n_heads: int,
        n_layers: int,
        d_ff: int,
        max_seq_len: int = 512,
        dropout: float = 0.1,
        activation: str = "gelu",
        pre_norm: bool = True,
        layer_norm_eps: float = 1e-6
    ):
        super().__init__()

        self.d_model = d_model
        self.n_layers = n_layers

        # Token + positional embedding (created first; registration order
        # matters for state_dict layout and seeded weight initialization).
        self.embedding = TransformerEmbedding(
            vocab_size=vocab_size,
            d_model=d_model,
            max_seq_len=max_seq_len,
            dropout=dropout
        )

        # Every block shares the same hyper-parameters.
        block_config = dict(
            d_model=d_model,
            n_heads=n_heads,
            d_ff=d_ff,
            dropout=dropout,
            activation=activation,
            pre_norm=pre_norm,
            layer_norm_eps=layer_norm_eps
        )
        self.layers = nn.ModuleList(
            [EncoderBlock(**block_config) for _ in range(n_layers)]
        )

        if pre_norm:
            # Pre-norm residual streams are never normalized inside the
            # blocks' outputs, so normalize once at the end of the stack.
            self.final_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        else:
            # Post-norm blocks already finish with a LayerNorm.
            self.final_norm = None

    def forward(
        self,
        src: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_attention: bool = False
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """
        Embed ``src`` and run it through every encoder block.

        Args:
            src: Token indices of shape (batch, seq)
            mask: Optional padding mask forwarded unchanged to every block.
                Documented layout is (batch, 1, 1, seq) with True marking
                positions that must not be attended to; the exact convention
                is determined by MultiHeadAttention.
            return_attention: If True, collect each block's attention weights

        Returns:
            ``(hidden, attentions)``: ``hidden`` has shape
            (batch, seq, d_model). ``attentions`` is a list with one entry per
            block that returned weights when ``return_attention`` is True,
            otherwise None.
        """
        collected: Optional[List[torch.Tensor]] = [] if return_attention else None

        hidden = self.embedding(src)

        for block in self.layers:
            hidden, block_attn = block(hidden, mask=mask, return_attention=return_attention)
            if collected is not None and block_attn is not None:
                collected.append(block_attn)

        if self.final_norm is not None:
            hidden = self.final_norm(hidden)

        return hidden, collected

    def get_embedding(self, src: torch.Tensor) -> torch.Tensor:
        """Return the embedding-layer output for ``src`` without running any encoder block."""
        return self.embedding(src)
