In [131]:
import jax
import jax.numpy as jnp
import numpy as np
import math
from flax import linen as nn
from flax.linen.initializers import constant, ones, zeros

Word Embeddings

In [132]:
class Embedding(nn.Module):
    """Learned token-embedding table.

    Attributes:
        vocab_size: Number of rows in the embedding table.
        emb_dim: Width of each embedding vector.
    """
    vocab_size: int
    emb_dim: int

    def setup(self):
        # Thin wrapper around flax's embedding lookup.
        self.embed = nn.Embed(num_embeddings=self.vocab_size, features=self.emb_dim)

    def __call__(self, x):
        """Return the embedding vector for every token id in ``x``."""
        token_vectors = self.embed(x)
        return token_vectors

Positional Encoding



In [133]:
class PositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    A table of shape [1, max_seq_len, embed_dim] is built once in ``setup``:
    even feature columns hold sines and odd feature columns hold cosines of
    the position scaled by geometrically spaced frequencies.

    Attributes:
        max_seq_len: Number of positions held in the table.
        embed_dim: Feature width of the table (may be even or odd).
    """
    max_seq_len: int
    embed_dim: int

    def setup(self):
        # Table of [SeqLen, HiddenDim]; a leading batch axis is added at the end.
        pe = np.zeros((self.max_seq_len, self.embed_dim))
        position = np.arange(0, self.max_seq_len, dtype=np.float32)[:, None]
        div_term = np.exp(np.arange(0, self.embed_dim, 2) * (-math.log(10000.0) / self.embed_dim))
        angles = position * div_term  # [SeqLen, ceil(HiddenDim / 2)]

        pe[:, 0::2] = np.sin(angles)
        # With an odd embed_dim there is one fewer cosine column than sine
        # columns, so drop the surplus frequency instead of failing on a
        # shape mismatch. For even embed_dim this slice is a no-op.
        pe[:, 1::2] = np.cos(angles[:, : self.embed_dim // 2])

        self.pe = jax.device_put(pe[None])

    def __call__(self, x):
        """Add the positional table to ``x``.

        Args:
            x: Array whose axis 1 is the sequence axis and whose last axis has
                width ``embed_dim``.

        Returns:
            ``x`` plus the first ``x.shape[1]`` rows of the table.

        Raises:
            ValueError: If the sequence is longer than ``max_seq_len``.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_seq_len:
            # Without this check the addition either fails with an opaque
            # broadcasting error or (max_seq_len == 1) silently broadcasts.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )
        return x + self.pe[:, :seq_len]

Self Attention

In [134]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention computed over several heads in parallel.

    The inputs are projected, split into ``num_heads`` slices of width
    ``embed_dim // num_heads``, attended independently, merged again, and
    passed through a final dense projection.

    Attributes:
        embed_dim: Total model width shared by all projections.
        num_heads: Number of heads; must divide ``embed_dim``.
    """
    embed_dim: int = 512
    num_heads: int = 8

    def setup(self):
        # Each head gets an equal slice of the model width.
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // self.num_heads

        # Bias-free input projections; every one keeps the full model width so
        # that all heads are produced by a single matmul.
        self.query = nn.Dense(self.embed_dim, use_bias=False)
        self.key = nn.Dense(self.embed_dim, use_bias=False)
        self.value = nn.Dense(self.embed_dim, use_bias=False)

        # Projection applied after the heads are merged back together.
        self.out_proj = nn.Dense(self.embed_dim)

    def __call__(self, key, query, value, mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Note the argument order: ``key`` comes first, then ``query``.

        Args:
            key: [batch_size, seq_len_k, embed_dim]
            query: [batch_size, seq_len_q, embed_dim]
            value: [batch_size, seq_len_k, embed_dim]
            mask: Optional array broadcastable to the score tensor
                [batch_size, num_heads, seq_len_q, seq_len_k]; entries equal
                to 0 are suppressed.

        Returns:
            Array of shape [batch_size, seq_len_q, embed_dim].
        """
        n_batch, len_q, len_k = query.shape[0], query.shape[1], key.shape[1]

        def split_heads(projected, length):
            # [B, L, E] -> [B, L, H, D] -> [B, H, L, D]
            per_head = projected.reshape(n_batch, length, self.num_heads, self.head_dim)
            return per_head.transpose(0, 2, 1, 3)

        q_heads = split_heads(self.query(query), len_q)
        k_heads = split_heads(self.key(key), len_k)
        v_heads = split_heads(self.value(value), len_k)

        # Similarity of every query with every key, scaled by sqrt(head_dim).
        scores = jnp.matmul(q_heads, k_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(self.head_dim)  # [B, H, Lq, Lk]

        if mask is not None:
            # The most negative finite value turns into ~0 weight after softmax.
            scores = jnp.where(mask == 0, jnp.finfo(scores.dtype).min, scores)

        weights = nn.softmax(scores, axis=-1)  # [B, H, Lq, Lk]
        mixed = jnp.matmul(weights, v_heads)   # [B, H, Lq, D]

        # Undo the head split: [B, H, Lq, D] -> [B, Lq, H, D] -> [B, Lq, E]
        merged = mixed.transpose(0, 2, 1, 3).reshape(n_batch, len_q, self.embed_dim)

        return self.out_proj(merged)

DyT

In [135]:
class DyT(nn.Module):
    """Dynamic Tanh (DyT) layer: ``weight * tanh(alpha * x) + bias``.

    Used in this file as a drop-in replacement for a normalization layer.

    Attributes:
        num_features: Size of the trailing feature axis of the input.
        alpha_init: Starting value of the learnable scalar ``alpha``.
    """
    num_features: int
    alpha_init: float = 0.5

    def setup(self):
        feature_shape = (self.num_features,)

        # Scalar controlling how steep the tanh squashing is.
        self.alpha = self.param('alpha', constant(self.alpha_init), ())

        # Per-feature scale (starts at 1) and shift (starts at 0).
        self.weight = self.param('weight', ones, feature_shape)
        self.bias = self.param('bias', zeros, feature_shape)

    def __call__(self, x):
        """Squash ``x`` into (-1, 1) with a scaled tanh, then rescale and shift."""
        squashed = nn.tanh(self.alpha * x)
        scaled = squashed * self.weight
        return scaled + self.bias

Encoder

In [136]:
class TransformerBlock(nn.Module):
    """Attention + feed-forward block using DyT in place of LayerNorm.

    Each sub-layer is wrapped as ``dropout(DyT(sublayer(x) + x))``: a residual
    connection, then DyT, then dropout.

    Attributes:
        embed_dim: Model width.
        expansion_factor: Hidden width of the feed-forward network, as a
            multiple of ``embed_dim``.
        n_heads: Number of attention heads.
        dropout_rate: Drop probability for both dropout layers. Defaults to
            0.2, the value that was previously hard-coded.
    """
    embed_dim: int
    expansion_factor: int = 4
    n_heads: int = 8
    dropout_rate: float = 0.2

    def setup(self):
        # Multi-head attention sub-layer
        self.attention = MultiHeadAttention(self.embed_dim, self.n_heads)

        # DyT layers (alternative to LayerNorm)
        self.dyt1 = DyT(self.embed_dim)  # After attention sub-layer
        self.dyt2 = DyT(self.embed_dim)  # After feed-forward sub-layer

        # Position-wise feed-forward network
        self.feed_forward = nn.Sequential([
                nn.Dense(self.embed_dim * self.expansion_factor),
                nn.relu,
                nn.Dense(self.embed_dim)
        ])

        # Dropout layers for regularization
        self.dropout1 = nn.Dropout(rate=self.dropout_rate)
        self.dropout2 = nn.Dropout(rate=self.dropout_rate)

    def __call__(self, key, query, value, deterministic=False, mask=None):
        """Run the block.

        Args:
            key: Key input to attention.
            query: Query input to attention; also the residual stream, so the
                output has the same shape as ``query``.
            value: Value input to attention.
            deterministic: If True, dropout is disabled.
            mask: Optional attention mask forwarded to the attention layer;
                positions where it equals 0 are suppressed. ``None`` (the
                default) means unmasked attention.

        Returns:
            Array with the same shape as ``query``.
        """
        # 1. Multi-head attention sub-layer with residual connection
        attention_out = self.attention(key, query, value, mask=mask)
        attention_residual_out = attention_out + query

        # 2. DyT and dropout
        dyt1_out = self.dropout1(self.dyt1(attention_residual_out), deterministic=deterministic)

        # 3. Feed-forward sub-layer with residual connection
        feed_forward_out = self.feed_forward(dyt1_out)
        feed_forward_residual_out = feed_forward_out + dyt1_out

        # 4. Final DyT and dropout
        output = self.dropout2(self.dyt2(feed_forward_residual_out), deterministic=deterministic)

        return output

In [137]:
class TransformerEncoder(nn.Module):
    """Stack of self-attention transformer blocks over embedded tokens.

    Token ids are embedded, sinusoidal positions are added, and the result is
    refined by ``num_layers`` transformer blocks in sequence.

    Attributes:
        seq_len: Number of positions in the positional table.
        vocab_size: Size of the input vocabulary.
        embed_dim: Model width.
        num_layers: Number of transformer blocks.
        expansion_factor: Feed-forward expansion factor of each block.
        n_heads: Number of attention heads in each block.
    """
    seq_len: int
    vocab_size: int
    embed_dim: int
    num_layers: int = 2
    expansion_factor: int = 4
    n_heads: int = 8

    def setup(self):
        # Token and position embeddings
        self.embedding = Embedding(self.vocab_size, self.embed_dim)
        self.positional_embedding = PositionalEmbedding(self.seq_len, self.embed_dim)

        # Every block in the stack shares the same hyper-parameters.
        block_config = dict(
            embed_dim=self.embed_dim,
            expansion_factor=self.expansion_factor,
            n_heads=self.n_heads,
        )
        self.layers = [TransformerBlock(**block_config) for _ in range(self.num_layers)]

    def __call__(self, x, deterministic=False):
        """Encode a batch of token ids.

        Args:
            x: Token ids of shape [batch_size, seq_len].
            deterministic: If True, dropout is disabled.

        Returns:
            Encoded sequence of shape [batch_size, seq_len, embed_dim].
        """
        hidden = self.positional_embedding(self.embedding(x))

        # Self-attention: the running representation serves as key, query and value.
        for block in self.layers:
            hidden = block(key=hidden, query=hidden, value=hidden, deterministic=deterministic)

        return hidden

Decoder

In [138]:
class DecoderBlock(nn.Module):
    """Transformer decoder block using DyT normalization.

    Runs masked self-attention over the decoder input, then a
    ``TransformerBlock`` that cross-attends from the decoder state to the
    encoder output and applies the feed-forward network.

    Attributes:
        embed_dim: Model width.
        expansion_factor: Feed-forward expansion factor.
        n_heads: Number of attention heads.
    """
    embed_dim: int
    expansion_factor: int = 4
    n_heads: int = 8

    def setup(self):
        # First sub-layer: masked self-attention
        self.self_attention = MultiHeadAttention(self.embed_dim, self.n_heads)
        self.dyt1 = DyT(self.embed_dim)
        self.dropout1 = nn.Dropout(rate=0.2)

        # Second sub-layer: cross-attention + feed-forward, reusing TransformerBlock
        self.cross_attention_block = TransformerBlock(
                self.embed_dim,
                self.expansion_factor,
                self.n_heads
        )

    def __call__(self, encoder_output, x, mask=None, deterministic=False):
        """Forward pass through the decoder block.

        Args:
            encoder_output: Encoder output [batch_size, src_len, embed_dim].
            x: Decoder input [batch_size, tgt_len, embed_dim].
            mask: Optional causal mask for the self-attention sub-layer.
            deterministic: If True, dropout is disabled.

        Returns:
            Decoder block output [batch_size, tgt_len, embed_dim].
        """
        # 1. Self-attention over the decoder input, with the causal mask
        attn_output = self.self_attention(x, x, x, mask=mask)

        # 2. Residual connection + DyT + dropout
        norm_output = self.dropout1(
                self.dyt1(attn_output + x),
                deterministic=deterministic
        )

        # 3. Cross-attention + feed-forward.
        # The decoder state is the QUERY; the encoder output supplies both KEY
        # and VALUE. The previous wiring (query=encoder_output,
        # value=norm_output) made the output length and the residual follow
        # the encoder, and broke whenever src_len != tgt_len because the value
        # tensor is reshaped using the key's sequence length.
        output = self.cross_attention_block(
                key=encoder_output,
                query=norm_output,
                value=encoder_output,
                deterministic=deterministic
        )

        return output

In [139]:
class TransformerDecoder(nn.Module):
    """Transformer decoder: target tokens + encoder output -> token probabilities.

    Target ids are embedded, given positional information, passed through a
    stack of decoder blocks that also see the encoder output, and finally
    projected to the target vocabulary and normalized with softmax.

    Attributes:
        target_vocab_size: Size of the target vocabulary.
        embed_dim: Model width.
        seq_len: Number of positions in the positional table.
        num_layers: Number of decoder blocks.
        expansion_factor: Feed-forward expansion factor of each block.
        n_heads: Number of attention heads in each block.
    """
    target_vocab_size: int
    embed_dim: int
    seq_len: int
    num_layers: int = 2
    expansion_factor: int = 4
    n_heads: int = 8

    def setup(self):
        # Token embedding layer converts token IDs to vectors
        self.word_embedding = nn.Embed(self.target_vocab_size, self.embed_dim)

        # Adds positional information to the embeddings
        self.positional_embedding = PositionalEmbedding(self.seq_len, self.embed_dim)

        # Stack of decoder blocks
        self.layers = [
                DecoderBlock(
                        self.embed_dim,
                        self.expansion_factor,
                        self.n_heads
                ) for _ in range(self.num_layers)
        ]

        # Final projection to vocabulary size
        self.fc_out = nn.Dense(self.target_vocab_size)

        # Dropout applied to the embedded input
        self.dropout = nn.Dropout(rate=0.2)

    def __call__(self, x, enc_out, mask=None, deterministic=False):
        """Forward pass through the decoder.

        Args:
            x: Target token IDs [batch_size, seq_len].
            enc_out: Encoder output [batch_size, src_seq_len, embed_dim].
            mask: Optional 4-D causal mask for self-attention; it is cropped to
                the current target length on its last two axes.
            deterministic: If True, dropout is disabled.

        Returns:
            Softmax probabilities (not logits) over the target vocabulary, one
            distribution per decoder position.
        """
        # Only the sequence length is needed; the unpack also enforces 2-D input.
        _, seq_len = x.shape

        # The cropped mask is identical for every layer, so build it once
        # instead of re-slicing inside the loop.
        current_mask = None
        if mask is not None:
            current_mask = mask[:, :, :seq_len, :seq_len]

        # Convert token IDs to embeddings, add positions, apply dropout
        x = self.word_embedding(x)
        x = self.positional_embedding(x)
        x = self.dropout(x, deterministic=deterministic)

        # Process through each decoder layer
        for layer in self.layers:
            x = layer(enc_out, x, mask=current_mask, deterministic=deterministic)

        # Project to vocabulary size and normalize to probabilities
        return nn.softmax(self.fc_out(x), axis=-1)

In [140]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer that uses DyT instead of LayerNorm.

    Attributes:
        embed_dim: Model width.
        src_vocab_size: Size of the source vocabulary.
        tgt_vocab_size: Size of the target vocabulary.
        seq_len: Number of positions in the positional tables.
        num_layers: Number of blocks in both the encoder and the decoder.
        expansion_factor: Feed-forward expansion factor.
        n_heads: Number of attention heads.
    """
    embed_dim: int
    src_vocab_size: int
    tgt_vocab_size: int
    seq_len: int
    num_layers: int = 2
    expansion_factor: int = 4
    n_heads: int = 8

    def setup(self):
        # Source-side stack
        self.encoder = TransformerEncoder(
            seq_len=self.seq_len,
            vocab_size=self.src_vocab_size,
            embed_dim=self.embed_dim,
            num_layers=self.num_layers,
            expansion_factor=self.expansion_factor,
            n_heads=self.n_heads,
        )

        # Target-side stack
        self.decoder = TransformerDecoder(
            target_vocab_size=self.tgt_vocab_size,
            embed_dim=self.embed_dim,
            seq_len=self.seq_len,
            num_layers=self.num_layers,
            expansion_factor=self.expansion_factor,
            n_heads=self.n_heads,
        )

    def make_trg_mask(self, trg):
        """Build a causal (lower-triangular) attention mask for the decoder.

        Position i may attend to positions 0..i; later positions are zero.

        Args:
            trg: Target token ids of shape [batch_size, seq_len].

        Returns:
            Mask of ones and zeros with shape [batch_size, 1, seq_len, seq_len].
        """
        n_batch, n_steps = trg.shape
        causal = jnp.tril(jnp.ones((n_steps, n_steps)))
        # Insert batch and head axes, then expand over the batch.
        return jnp.broadcast_to(causal[None, None, :, :], (n_batch, 1, n_steps, n_steps))

    def decode(self, src, trg, deterministic=False):
        """Greedy autoregressive decoding.

        The source is encoded once. Then, for as many steps as the source has
        tokens, the decoder is run on the sequence generated so far and the
        most probable token at the last position is appended.

        Args:
            src: Source token ids [batch_size, seq_len].
            trg: Initial target token ids [batch_size, initial_len].
            deterministic: If True, dropout is disabled.

        Returns:
            List with one array of predicted token ids (shape [batch_size])
            per generated step.
        """
        step_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src, deterministic=deterministic)

        batch_size, seq_len = src.shape
        generated = trg  # grows by one column per step
        out_labels = []

        for _ in range(seq_len):
            # Probability distribution over the vocabulary at every position
            step_probs = self.decoder(generated, enc_out, mask=step_mask, deterministic=deterministic)

            # Greedy choice at the final position
            next_token = jnp.argmax(step_probs[:, -1, :], axis=-1)
            out_labels.append(next_token)

            # Extend the running sequence and rebuild the mask for its new length
            next_column = next_token.reshape((batch_size, 1))
            generated = jnp.concatenate([generated, next_column], axis=1)
            step_mask = self.make_trg_mask(generated)

        return out_labels

    def __call__(self, src, trg, deterministic=False):
        """Teacher-forced forward pass used for training.

        Args:
            src: Source token ids [batch_size, seq_len].
            trg: Target token ids [batch_size, seq_len].
            deterministic: If True, dropout is disabled.

        Returns:
            The decoder output: softmax probabilities over the target
            vocabulary (the decoder applies softmax, so these are not logits).
        """
        causal_mask = self.make_trg_mask(trg)
        memory = self.encoder(src, deterministic=deterministic)
        return self.decoder(trg, memory, mask=causal_mask, deterministic=deterministic)

In [145]:
# Smoke test: build the model, initialise its parameters, run one forward pass.

src_vocab_size = 11
target_vocab_size = 11
num_layers = 6
seq_length = 12
embed_dim = 512
expansion_factor = 4
n_heads = 8

transformer = Transformer(
    embed_dim=embed_dim,
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=target_vocab_size,
    seq_len=seq_length,
    num_layers=num_layers,
    expansion_factor=expansion_factor,
    n_heads=n_heads,
)

# A single all-ones sequence serves as both source and target.
dummy_tokens = jnp.ones((1, seq_length), dtype=jnp.int32)

params = transformer.init(jax.random.PRNGKey(0), dummy_tokens, dummy_tokens, deterministic=True)
out = transformer.apply(params, dummy_tokens, dummy_tokens, deterministic=True)

In [142]:
# Two example sequences of 12 token ids each (ids range over 0..10).
src = jnp.array([
    [0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1],
    [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1],
])
target = jnp.array([
    [0, 1, 7, 4, 3, 5, 9, 2, 8, 10, 9, 1],
    [0, 1, 5, 6, 2, 4, 7, 6, 2, 8, 10, 1],
])

In [143]:
# Forward pass on the two-example batch with dropout disabled (deterministic=True).
out = transformer.apply(params, src, target, deterministic=True)

In [146]:
# Inspect the output shape: (batch, positions, target vocabulary).
# NOTE(review): a recorded (1, 12, 11) here matches the single-example smoke-test
# cell, not the two-example batch; a fresh top-to-bottom run should presumably
# show (2, 12, 11) — confirm with Restart & Run All.
out.shape

(1, 12, 11)