# 11. Transformers and Attention with Keras
**Location: TensorVerseHub/notebooks/04_natural_language_processing/11_transformers_attention_keras.ipynb**

In [None]:
import tensorflow as tf
import numpy as np
print("TensorFlow version:", tf.__version__)

# Transformers & Attention with tf.keras

**File Location:** `notebooks/04_natural_language_processing/11_transformers_attention_keras.ipynb`

Master Transformer architectures using `tf.keras.layers.MultiHeadAttention`: implement self-attention and positional encoding, build BERT-style models, and create state-of-the-art NLP models for text classification, language modeling, and sequence tasks.

## Learning Objectives
- Implement Transformer architecture with tf.keras.layers.MultiHeadAttention
- Build positional encoding and attention mechanisms from scratch
- Create encoder-only, decoder-only, and encoder-decoder Transformers
- Apply attention to various NLP tasks (classification, generation)
- Optimize Transformer models for production deployment
- Handle long sequences and memory-efficient attention

---

## 1. Foundation: Understanding Attention Mechanisms

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import math
import warnings
warnings.filterwarnings('ignore')

print("TensorFlow version:", tf.__version__)

# Seed both TensorFlow and NumPy so the random demo tensors below are reproducible
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Understanding the core attention mechanism
def scaled_dot_product_attention(queries, keys, values, mask=None):
    """
    Scaled dot-product attention - the heart of transformers.

        Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k) + mask * large_negative) V

    Any leading dimensions (batch, or batch and heads) are treated as batch
    dimensions by ``tf.matmul``, so the same function serves single- and
    multi-head attention.

    Args:
        queries: Query tensor (..., seq_len_q, d_k)
        keys: Key tensor (..., seq_len_k, d_k)
        values: Value tensor (..., seq_len_k, d_v)
        mask: Optional tensor broadcastable to (..., seq_len_q, seq_len_k),
            with 1 at positions that must NOT be attended to and 0 elsewhere
            (the convention of create_padding_mask / create_look_ahead_mask).
            Any numeric or bool dtype is accepted.

    Returns:
        output: Attention output (..., seq_len_q, d_v)
        attention_weights: Softmax weights (..., seq_len_q, seq_len_k),
            useful for visualization.
    """

    # Step 1: Calculate attention scores (Q * K^T)
    scores = tf.matmul(queries, keys, transpose_b=True)

    # Step 2: Scale by sqrt(d_k) to prevent softmax saturation. The scale is
    # computed in the scores' own dtype so float16/bfloat16/float64 inputs work.
    dk = tf.cast(tf.shape(keys)[-1], scores.dtype)
    scaled_scores = scores / tf.math.sqrt(dk)

    # Step 3: Push masked positions towards -inf before the softmax
    if mask is not None:
        # -1e9 is not representable in float16 (it becomes -inf, and then
        # 0 * -inf = NaN would poison the unmasked positions), so clamp the
        # penalty to the dtype's most negative finite value.
        large_negative = max(-1e9, scaled_scores.dtype.min)
        scaled_scores += tf.cast(mask, scaled_scores.dtype) * large_negative

    # Step 4: Apply softmax over the key axis to get attention weights
    attention_weights = tf.nn.softmax(scaled_scores, axis=-1)

    # Step 5: Weighted sum of the values
    output = tf.matmul(attention_weights, values)

    return output, attention_weights

# Demonstrate attention with a simple example
print("=== Attention Mechanism Demonstration ===")

# Random stand-ins for (already projected) queries, keys and values,
# drawn in the order Q, K, V
batch_size, seq_len, d_model = 2, 5, 8
qkv_shape = (batch_size, seq_len, d_model)
queries, keys, values = (tf.random.normal(qkv_shape) for _ in range(3))

# Apply attention (no mask: every query may look at every key)
attention_output, attention_weights = scaled_dot_product_attention(queries, keys, values)

print(f"Input shapes - Q: {queries.shape}, K: {keys.shape}, V: {values.shape}")
for label, tensor in (("Attention output shape", attention_output),
                      ("Attention weights shape", attention_weights)):
    print(f"{label}: {tensor.shape}")

# Visualize attention pattern
def plot_attention_weights(attention_weights, title="Attention Weights", sample=0, head=0):
    """Visualize a single attention matrix as an annotated heatmap.

    Args:
        attention_weights: Attention weights as a tf.Tensor or array-like, of
            shape (batch, seq_q, seq_k), (batch, heads, seq_q, seq_k), or a
            single (seq_q, seq_k) matrix.
        title: Figure title.
        sample: Batch index to display (ignored for 2-D input).
        head: Head index to display for 4-D input. Heads are shown one at a
            time; they are not averaged.

    Raises:
        ValueError: If the weights are not 2-, 3- or 4-dimensional.
    """
    # np.asarray handles eager tensors as well as NumPy arrays and lists
    weights = np.asarray(attention_weights)
    if weights.ndim == 4:
        weights = weights[sample, head]
    elif weights.ndim == 3:
        weights = weights[sample]
    elif weights.ndim != 2:
        raise ValueError(
            f"attention_weights must be 2-, 3- or 4-D, got shape {weights.shape}"
        )

    # Rows are queries, columns are keys
    plt.figure(figsize=(8, 6))
    sns.heatmap(weights, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=[f'Key_{i}' for i in range(weights.shape[1])],
                yticklabels=[f'Query_{i}' for i in range(weights.shape[0])])
    plt.title(title)
    plt.xlabel('Key Positions')
    plt.ylabel('Query Positions')
    plt.tight_layout()
    plt.show()

plot_attention_weights(attention_weights, title="Basic Attention Pattern")

## 2. Multi-Head Attention Implementation

In [None]:
# Custom Multi-Head Attention Layer
class CustomMultiHeadAttention(tf.keras.layers.Layer):
    """
    Custom implementation of Multi-Head Attention to understand the internals.

    The layer works in five steps:
    - Project the query input, and the key/value input (``context``, or the
      query input itself for self-attention), with three Dense layers.
    - Split each projection into ``num_heads`` heads of size
      ``embed_dim // num_heads``.
    - Attend with ``scaled_dot_product_attention``.
    - Concatenate the heads.
    - Apply a final Dense projection back to ``embed_dim``.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Raise instead of `assert`: asserts are stripped under `python -O`
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.depth = embed_dim // num_heads

        # Linear transformations for Q, K, V
        self.query_dense = tf.keras.layers.Dense(embed_dim, name='query_projection')
        self.key_dense = tf.keras.layers.Dense(embed_dim, name='key_projection')
        self.value_dense = tf.keras.layers.Dense(embed_dim, name='value_projection')

        # Output projection
        self.output_dense = tf.keras.layers.Dense(embed_dim, name='output_projection')

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq, embed_dim) into (batch, num_heads, seq, depth)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs, attention_mask=None, training=None, context=None):
        """
        Args:
            inputs: Query sequence, (batch, seq_len_q, features).
            attention_mask: Optional mask forwarded to
                scaled_dot_product_attention, where 1 marks a position that
                must NOT be attended (the opposite of the convention used by
                tf.keras.layers.MultiHeadAttention). It must broadcast to
                (batch, num_heads, seq_len_q, seq_len_k).
            training: Unused; accepted for Keras call-signature compatibility.
            context: Optional key/value sequence for cross-attention,
                (batch, seq_len_k, features). Defaults to ``inputs``
                (self-attention).

        Returns:
            (output, attention_weights). ``output`` has shape
            (batch, seq_len_q, embed_dim); ``attention_weights`` has shape
            (batch, num_heads, seq_len_q, seq_len_k).
        """
        if context is None:
            context = inputs
        batch_size = tf.shape(inputs)[0]

        # Linear transformations, then split into heads: (batch, heads, seq_len, depth)
        queries = self.split_heads(self.query_dense(inputs), batch_size)
        keys = self.split_heads(self.key_dense(context), batch_size)
        values = self.split_heads(self.value_dense(context), batch_size)

        # Scaled dot-product attention, applied to every head in parallel
        attention_output, attention_weights = scaled_dot_product_attention(
            queries, keys, values, attention_mask
        )

        # Concatenate heads: (batch, seq_len_q, heads, depth) -> (batch, seq_len_q, embed_dim)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(attention_output,
                                      (batch_size, -1, self.embed_dim))

        # Final linear transformation
        output = self.output_dense(concat_attention)

        return output, attention_weights

    def get_config(self):
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads
        })
        return config

# Compare custom implementation with tf.keras version
print("=== Multi-Head Attention Comparison ===")

embed_dim, num_heads = 64, 8
sample_input = tf.random.normal((2, 10, embed_dim))

# Hand-written layer: returns (output, per-head attention weights)
custom_mha = CustomMultiHeadAttention(embed_dim, num_heads)
custom_output, custom_weights = custom_mha(sample_input)

# Built-in layer with the same per-head width; the same tensor is passed as
# query and value (self-attention)
keras_mha = tf.keras.layers.MultiHeadAttention(
    num_heads=num_heads,
    key_dim=embed_dim // num_heads,
)
keras_output = keras_mha(sample_input, sample_input)

implementations = (
    ("Custom", custom_output, custom_mha),
    ("Keras", keras_output, keras_mha),
)
for impl_name, impl_output, _ in implementations:
    print(f"{impl_name} MHA output shape: {impl_output.shape}")
for impl_name, _, impl_layer in implementations:
    print(f"{impl_name} MHA parameters: {impl_layer.count_params():,}")

# Show the first four heads side by side (only if the layer has that many)
heads_to_show = 4
if custom_weights.shape[1] >= heads_to_show:
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    for head, ax in enumerate(axes.flatten()[:heads_to_show]):
        # First sample, one head at a time
        sns.heatmap(custom_weights[0, head].numpy(), annot=False, cmap='Blues', ax=ax, cbar=True)
        ax.set_title(f'Head {head + 1}')
        ax.set_xlabel('Key Position')
        ax.set_ylabel('Query Position')

    plt.suptitle('Multi-Head Attention Patterns', fontsize=16)
    plt.tight_layout()
    plt.show()

## 3. Positional Encoding

In [None]:
# Positional Encoding Implementation
class PositionalEncoding(tf.keras.layers.Layer):
    """
    Sinusoidal positional encoding for transformers.

    Transformers have no inherent notion of position, so this layer adds a
    fixed (non-trainable) table of sinusoids to its input:

        PE[pos, 2i]     = sin(pos / 10000^(2i / embed_dim))
        PE[pos, 2i + 1] = cos(pos / 10000^(2i / embed_dim))

    The table is created in ``build``. Its length depends on the build shape:
    - When the sequence length is static, the table covers that length.
    - When the sequence length is ``None`` (variable-length inputs), the
      table covers ``max_position`` positions.

    Calling the layer with a longer sequence than the table covers raises a
    clear error instead of an opaque shape mismatch.

    Args:
        max_position: Maximum supported sequence length.
    """

    def __init__(self, max_position=10000, **kwargs):
        super().__init__(**kwargs)
        self.max_position = max_position

    def build(self, input_shape):
        """Create the non-trainable (1, seq_len, embed_dim) encoding table.

        Raises:
            ValueError: If the static sequence length exceeds ``max_position``.
        """
        seq_len = input_shape[1]
        if seq_len is None:
            # Unknown length at build time: precompute up to the maximum
            seq_len = self.max_position
        elif seq_len > self.max_position:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position={self.max_position}"
            )
        self.seq_len = seq_len
        self.embed_dim = input_shape[-1]

        encoding_matrix = self._sinusoid_table(self.seq_len, self.embed_dim)
        self.pos_encoding = tf.Variable(
            encoding_matrix[tf.newaxis, :, :],
            trainable=False,
            name="positional_encoding"
        )

        super().build(input_shape)

    @staticmethod
    def _sinusoid_table(seq_len, embed_dim):
        """Return the (seq_len, embed_dim) float32 sinusoid table."""
        position = tf.range(seq_len, dtype=tf.float32)[:, tf.newaxis]
        # 10000^(-2i / embed_dim) for every even dimension index 2i
        div_term = tf.exp(tf.range(0, embed_dim, 2, dtype=tf.float32) *
                          -(tf.math.log(10000.0) / embed_dim))
        angles = position * div_term  # (seq_len, ceil(embed_dim / 2))

        # Stack on a new last axis and flatten. Column 2i then holds
        # sin(angle_i) and column 2i + 1 holds cos(angle_i). The final slice
        # drops the surplus cos column when embed_dim is odd.
        interleaved = tf.stack([tf.sin(angles), tf.cos(angles)], axis=-1)
        return tf.reshape(interleaved, (seq_len, -1))[:, :embed_dim]

    def call(self, inputs):
        seq_len = tf.shape(inputs)[1]
        tf.debugging.assert_less_equal(
            seq_len, self.seq_len,
            message="PositionalEncoding: input sequence is longer than the "
                    "encoding table built for this layer"
        )
        # Match the input dtype so mixed-precision (e.g. float16) inputs work
        encoding = tf.cast(self.pos_encoding[:, :seq_len, :], inputs.dtype)
        return inputs + encoding

    def get_config(self):
        config = super().get_config()
        config.update({'max_position': self.max_position})
        return config

# Visualize positional encodings
print("=== Positional Encoding Visualization ===")


def _decorate_panel(panel_title, x_label, y_label):
    """Set the title and axis labels of the current subplot."""
    plt.title(panel_title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)


# All-zero "embeddings", so the layer output equals the encoding itself
sample_seq_len, sample_embed_dim = 50, 128
sample_embeddings = tf.zeros((1, sample_seq_len, sample_embed_dim))

pos_encoder = PositionalEncoding()
pos_encoded = pos_encoder(sample_embeddings)

# Drop the leading broadcast axis of the stored table
pos_encoding_matrix = pos_encoder.pos_encoding[0].numpy()

plt.figure(figsize=(15, 8))

# Panel 1: heatmap of the first 30 positions x 64 dimensions
plt.subplot(2, 2, 1)
sns.heatmap(pos_encoding_matrix[:30, :64].T, cmap='RdBu', center=0,
            xticklabels=False, yticklabels=False)
_decorate_panel('Positional Encoding Heatmap\n(First 30 positions, 64 dimensions)',
                'Position', 'Encoding Dimension')

# Panel 2: the sinusoid carried by a handful of dimensions
plt.subplot(2, 2, 2)
dims_to_plot = [0, 1, 4, 8, 16, 32]
for dim in dims_to_plot:
    plt.plot(pos_encoding_matrix[:, dim], label=f'Dim {dim}')
_decorate_panel('Positional Encoding Patterns', 'Position', 'Encoding Value')
plt.legend()
plt.grid(True, alpha=0.3)

# Panel 3: frequency of each sin/cos pair, indexed by its even dimension
plt.subplot(2, 2, 3)
even_dims = range(0, sample_embed_dim, 2)
frequencies = [1 / (10000 ** (i / sample_embed_dim)) for i in even_dims]
plt.semilogy(even_dims, frequencies)
_decorate_panel('Frequency Spectrum of Positional Encoding',
                'Dimension Index (Even)', 'Frequency (log scale)')
plt.grid(True, alpha=0.3)

# Panel 4: pairwise dot products between position vectors
plt.subplot(2, 2, 4)
pos_similarities = pos_encoding_matrix @ pos_encoding_matrix.T
sns.heatmap(pos_similarities[:20, :20], annot=False, cmap='coolwarm', center=0)
_decorate_panel('Position Similarity Matrix\n(First 20 positions)', 'Position', 'Position')

plt.tight_layout()
plt.show()

## 4. Complete Transformer Building Blocks

In [None]:
# Transformer Encoder Block
class TransformerEncoderBlock(tf.keras.layers.Layer):
    """
    Complete (post-norm) transformer encoder block with:
    - Multi-head self-attention
    - Feed-forward network
    - Residual connections
    - Layer normalization

    Args:
        embed_dim: Width of the input/output features.
        num_heads: Number of attention heads (key_dim = embed_dim // num_heads).
        ff_dim: Hidden width of the position-wise feed-forward network.
        dropout_rate: Dropout used inside attention, inside the FFN, and on
            both residual branches.

    Mask convention:
        ``mask`` follows this notebook's convention, shared with
        ``create_padding_mask``, ``create_look_ahead_mask`` and
        ``scaled_dot_product_attention``: 1 marks a key position that must
        NOT be attended. ``tf.keras.layers.MultiHeadAttention`` expects the
        opposite (1/True = may attend), so the mask is inverted before being
        forwarded. NOTE(review): do not feed Keras-propagated masks
        (True = valid token) here; they would be inverted.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate

        # Multi-head attention layer
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim // num_heads,
            dropout=dropout_rate,
            name='multi_head_attention'
        )

        # Feed-forward network
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(ff_dim, activation='relu', name='ffn_dense1'),
            tf.keras.layers.Dropout(dropout_rate, name='ffn_dropout'),
            tf.keras.layers.Dense(embed_dim, name='ffn_dense2')
        ], name='feed_forward')

        # Layer normalization layers
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name='layernorm1')
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name='layernorm2')

        # Dropout layers
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate, name='dropout1')
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate, name='dropout2')

    @staticmethod
    def _to_keras_attention_mask(mask):
        """Convert a notebook mask (1 = blocked) into Keras form (True = may attend)."""
        if mask is None:
            return None
        return tf.logical_not(tf.cast(mask, tf.bool))

    def call(self, inputs, training=None, mask=None):
        attention_mask = self._to_keras_attention_mask(mask)

        # Multi-head self-attention with residual connection
        attention_output = self.attention(
            inputs, inputs, attention_mask=attention_mask, training=training
        )
        attention_output = self.dropout1(attention_output, training=training)
        out1 = self.layernorm1(inputs + attention_output)  # Residual connection

        # Feed-forward network with residual connection
        ffn_output = self.ffn(out1, training=training)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)  # Residual connection

        return out2

    def get_config(self):
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'dropout_rate': self.dropout_rate
        })
        return config

# Transformer Decoder Block
class TransformerDecoderBlock(tf.keras.layers.Layer):
    """
    Complete (post-norm) transformer decoder block with:
    - Masked multi-head self-attention
    - Cross-attention (encoder-decoder attention)
    - Feed-forward network
    - Residual connections and layer normalization

    Args:
        embed_dim: Width of the input/output features.
        num_heads: Number of attention heads (key_dim = embed_dim // num_heads).
        ff_dim: Hidden width of the position-wise feed-forward network.
        dropout_rate: Dropout used inside both attentions, inside the FFN,
            and on every residual branch.

    Mask convention:
        ``look_ahead_mask`` (self-attention) and ``padding_mask``
        (cross-attention) follow this notebook's convention, shared with
        ``create_look_ahead_mask`` and ``create_padding_mask``: 1 marks a key
        position that must NOT be attended. ``tf.keras.layers.MultiHeadAttention``
        expects the opposite (1/True = may attend), so both masks are
        inverted before being forwarded.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate

        # Self-attention (with look-ahead mask)
        self.self_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim // num_heads,
            dropout=dropout_rate,
            name='self_attention'
        )

        # Cross-attention (encoder-decoder attention)
        self.cross_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim // num_heads,
            dropout=dropout_rate,
            name='cross_attention'
        )

        # Feed-forward network
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(ff_dim, activation='relu'),
            tf.keras.layers.Dropout(dropout_rate),
            tf.keras.layers.Dense(embed_dim)
        ], name='feed_forward')

        # Layer normalization layers
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        # Dropout layers
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout3 = tf.keras.layers.Dropout(dropout_rate)

    @staticmethod
    def _to_keras_attention_mask(mask):
        """Convert a notebook mask (1 = blocked) into Keras form (True = may attend)."""
        if mask is None:
            return None
        return tf.logical_not(tf.cast(mask, tf.bool))

    def call(self, inputs, encoder_outputs, training=None,
             look_ahead_mask=None, padding_mask=None):
        self_attention_mask = self._to_keras_attention_mask(look_ahead_mask)
        cross_attention_mask = self._to_keras_attention_mask(padding_mask)

        # Masked self-attention over the decoder sequence
        self_attn_output = self.self_attention(
            inputs, inputs, attention_mask=self_attention_mask, training=training
        )
        self_attn_output = self.dropout1(self_attn_output, training=training)
        out1 = self.layernorm1(inputs + self_attn_output)

        # Cross-attention: decoder queries attend over encoder outputs
        cross_attn_output = self.cross_attention(
            out1, encoder_outputs, attention_mask=cross_attention_mask, training=training
        )
        cross_attn_output = self.dropout2(cross_attn_output, training=training)
        out2 = self.layernorm2(out1 + cross_attn_output)

        # Feed-forward network
        ffn_output = self.ffn(out2, training=training)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(out2 + ffn_output)

        return out3

    def get_config(self):
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'dropout_rate': self.dropout_rate
        })
        return config

# Attention mask utilities
def create_look_ahead_mask(size):
    """Return a (size, size) float32 causal mask for decoder self-attention.

    Entry [i, j] is 1 when j > i (query i must not see the later key j) and 0
    otherwise. This is the "1 = masked" convention used throughout this
    notebook.
    """
    # band_part with num_lower=-1, num_upper=0 keeps the lower triangle
    # (diagonal included); subtracting from 1 leaves the strict upper triangle.
    allowed = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return 1 - allowed

def create_padding_mask(seq, pad_token=0):
    """Return a float32 mask of shape (batch, 1, 1, seq_len) flagging padding.

    Positions equal to ``pad_token`` get 1 (masked); all others get 0. The two
    inserted axes let the mask broadcast over attention heads and query
    positions.
    """
    is_padding = tf.cast(tf.math.equal(seq, pad_token), tf.float32)
    return is_padding[:, None, None, :]

# Test transformer building blocks
print("=== Testing Transformer Building Blocks ===")

# Sample data
batch_size, seq_len, embed_dim = 2, 10, 128
num_heads, ff_dim = 8, 256

sample_input = tf.random.normal((batch_size, seq_len, embed_dim))

# Encoder block
encoder_block = TransformerEncoderBlock(embed_dim, num_heads, ff_dim, dropout_rate=0.1)
encoder_output = encoder_block(sample_input, training=True)

print("Encoder Block:")
for label, value in (
    ("Input shape", sample_input.shape),
    ("Output shape", encoder_output.shape),
    ("Parameters", f"{encoder_block.count_params():,}"),
):
    print(f"  {label}: {value}")

# Decoder block: causally masked, cross-attending over a random stand-in
# for an encoder output
decoder_block = TransformerDecoderBlock(embed_dim, num_heads, ff_dim, dropout_rate=0.1)
look_ahead_mask = create_look_ahead_mask(seq_len)
sample_encoder_output = tf.random.normal((batch_size, seq_len, embed_dim))

decoder_output = decoder_block(
    sample_input,
    sample_encoder_output,
    look_ahead_mask=look_ahead_mask,
    training=True,
)

print("\nDecoder Block:")
for label, value in (
    ("Input shape", sample_input.shape),
    ("Encoder output shape", sample_encoder_output.shape),
    ("Decoder output shape", decoder_output.shape),
    ("Parameters", f"{decoder_block.count_params():,}"),
):
    print(f"  {label}: {value}")

# Visualize look-ahead mask
position_labels = [f'Pos {i}' for i in range(seq_len)]
plt.figure(figsize=(8, 6))
sns.heatmap(look_ahead_mask.numpy(), annot=True, fmt='.0f', cmap='Reds',
            xticklabels=position_labels, yticklabels=position_labels)
plt.title('Look-Ahead Mask for Decoder\n(1 = masked, 0 = allowed)')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.show()

## 5. Complete Transformer Models

In [None]:
# BERT-style Encoder-only Transformer
class TransformerClassifier(tf.keras.Model):
    """
    BERT-style (encoder-only) transformer for text classification.

    Pipeline:
    - Token embeddings are scaled by sqrt(embed_dim) and sinusoidal
      positional encodings are added.
    - The result passes through ``num_layers`` encoder blocks.
    - The outputs of the non-padding tokens are mean-pooled (a pooled
      representation is used rather than a dedicated [CLS] token).
    - A softmax layer produces the class probabilities.

    Args:
        vocab_size: Size of the token vocabulary.
        embed_dim, num_heads, ff_dim, dropout_rate: Encoder block settings.
        num_layers: Number of encoder blocks.
        num_classes: Number of output classes.
        max_length: Longest input sequence supported by the positional encoding.
    """

    def __init__(self, vocab_size, embed_dim=128, num_heads=8, ff_dim=256,
                 num_layers=4, num_classes=3, max_length=512, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.max_length = max_length

        # Embedding layers
        self.embedding = tf.keras.layers.Embedding(
            vocab_size, embed_dim, name='token_embedding'
        )
        self.pos_encoding = PositionalEncoding(max_position=max_length)
        # Size the positional table for max_length up front. Otherwise it is
        # sized by the first batch seen, and any longer later input would fail.
        self.pos_encoding.build((None, max_length, embed_dim))

        # Transformer encoder layers
        self.encoder_layers = [
            TransformerEncoderBlock(embed_dim, num_heads, ff_dim, dropout_rate, name=f'encoder_{i}')
            for i in range(num_layers)
        ]

        # Classification head
        self.global_pool = tf.keras.layers.GlobalAveragePooling1D(name='global_pool')
        self.dropout = tf.keras.layers.Dropout(dropout_rate, name='final_dropout')
        self.classifier = tf.keras.layers.Dense(
            num_classes, activation='softmax', name='classifier'
        )

    def call(self, inputs, training=None, mask=None):
        """
        Args:
            inputs: Integer token ids, (batch, seq_len), seq_len <= max_length.
            training: Enables dropout when True.
            mask: Optional (batch, 1, 1, seq_len) mask in the notebook
                convention (1 = padding / blocked). When omitted, it is built
                with create_padding_mask, treating token id 0 as padding.

        Returns:
            Class probabilities, (batch, num_classes).
        """
        # Token embeddings
        x = self.embedding(inputs)

        # Scale embeddings (common in transformer literature)
        x *= tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))

        # Add positional encoding
        x = self.pos_encoding(x)

        # Create attention mask if not provided
        if mask is None:
            mask = create_padding_mask(inputs)

        # Pass through encoder layers
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x, training=training, mask=mask)

        if mask is not None:
            # `mask` marks padding with 1, but pooling must average only the
            # real tokens, so invert it: keep = 1 for tokens, 0 for padding.
            padding = tf.cast(tf.squeeze(mask, axis=[1, 2]), tf.bool)  # (batch, seq_len)
            keep = tf.cast(tf.logical_not(padding), x.dtype)[:, :, tf.newaxis]
            # divide_no_nan yields zeros (not NaN) for an all-padding sequence
            pooled = tf.math.divide_no_nan(
                tf.reduce_sum(x * keep, axis=1), tf.reduce_sum(keep, axis=1)
            )
        else:
            pooled = self.global_pool(x)

        pooled = self.dropout(pooled, training=training)
        return self.classifier(pooled)

# GPT-style Decoder-only Transformer
class GPTStyleTransformer(tf.keras.Model):
    """
    GPT-style decoder-only transformer for text generation.

    Pipeline:
    - Token embeddings are combined with *learned* position embeddings.
    - The result passes through ``num_layers`` self-attention blocks:
      TransformerEncoderBlock driven by a causal mask, since no
      cross-attention is needed.
    - A final layer norm is applied, followed by a projection to
      vocabulary logits.

    NOTE(review): only the causal mask is applied; padding tokens are not
    masked.

    Args:
        vocab_size: Size of the token vocabulary (and of the output logits).
        embed_dim, num_heads, ff_dim, dropout_rate: Block settings.
        num_layers: Number of self-attention blocks.
        max_length: Size of the position-embedding table, i.e. the longest
            sequence the model accepts.
    """

    def __init__(self, vocab_size, embed_dim=128, num_heads=8, ff_dim=256,
                 num_layers=6, max_length=512, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_length = max_length

        # Embedding layers
        self.token_embedding = tf.keras.layers.Embedding(vocab_size, embed_dim)
        self.position_embedding = tf.keras.layers.Embedding(max_length, embed_dim)

        # Transformer decoder layers (self-attention only, no cross-attention)
        self.decoder_layers = []
        for i in range(num_layers):
            # Use encoder blocks with causal masking instead of decoder blocks
            self.decoder_layers.append(
                TransformerEncoderBlock(embed_dim, num_heads, ff_dim, dropout_rate,
                                        name=f'decoder_{i}')
            )

        # Output layers
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.output_projection = tf.keras.layers.Dense(vocab_size, name='lm_head')

    def call(self, inputs, training=None):
        """
        Args:
            inputs: Integer token ids, (batch, seq_len), seq_len <= max_length.
            training: Enables dropout when True.

        Returns:
            Unnormalized logits over the vocabulary, (batch, seq_len, vocab_size).
        """
        seq_len = tf.shape(inputs)[1]

        # Position ids >= max_length fall outside position_embedding's table.
        # The underlying gather raises on CPU but silently returns zero
        # vectors on GPU, so reject such inputs explicitly.
        tf.debugging.assert_less_equal(
            seq_len, self.max_length,
            message="GPTStyleTransformer: sequence length exceeds max_length"
        )

        # Token embeddings
        token_emb = self.token_embedding(inputs)

        # Position embeddings (broadcast over the batch)
        positions = tf.range(start=0, limit=seq_len, delta=1)
        position_emb = self.position_embedding(positions)

        # Combine embeddings
        x = token_emb + position_emb
        x = self.dropout(x, training=training)

        # Causal mask: 1 blocks attention to later positions
        # (create_look_ahead_mask convention), with axes for batch and heads.
        causal_mask = create_look_ahead_mask(seq_len)
        causal_mask = causal_mask[tf.newaxis, tf.newaxis, :, :]

        # Pass through decoder layers
        for decoder_layer in self.decoder_layers:
            x = decoder_layer(x, training=training, mask=causal_mask)

        # Final layer norm and projection
        x = self.layer_norm(x)
        logits = self.output_projection(x)

        return logits

# Encoder-Decoder Transformer (like original Transformer paper)
class TransformerSeq2Seq(tf.keras.Model):
    """
    Complete encoder-decoder transformer for sequence-to-sequence tasks.

    Encoder and decoder share one token embedding (a single vocabulary for
    source and target) and one positional-encoding layer. Masks follow the
    notebook convention (1 = blocked), with token id 0 treated as padding:

    * encoder self-attention ignores source padding;
    * decoder self-attention combines the causal mask with target padding;
    * cross-attention ignores source padding.

    Args:
        vocab_size: Size of the shared vocabulary (and of the output logits).
        embed_dim, num_heads, ff_dim, dropout_rate: Block settings.
        num_encoder_layers, num_decoder_layers: Depth of each stack.
        max_length: Longest source or target sequence supported.

    Call args:
        inputs: ``(encoder_input, decoder_input)``, integer token ids of shape
            (batch, src_len) and (batch, tgt_len).

    Returns:
        Logits of shape (batch, tgt_len, vocab_size).
    """

    def __init__(self, vocab_size, embed_dim=128, num_heads=8, ff_dim=256,
                 num_encoder_layers=4, num_decoder_layers=4, max_length=512,
                 dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_length = max_length

        # Shared embedding layer
        self.embedding = tf.keras.layers.Embedding(vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(max_position=max_length)
        # The same positional-encoding layer serves both encoder and decoder
        # inputs. Build it for max_length positions now. Otherwise its table
        # is sized by the encoder input (the first input to reach it), and a
        # longer decoder input would fail.
        self.pos_encoding.build((None, max_length, embed_dim))

        # Encoder layers
        self.encoder_layers = [
            TransformerEncoderBlock(embed_dim, num_heads, ff_dim, dropout_rate)
            for _ in range(num_encoder_layers)
        ]

        # Decoder layers
        self.decoder_layers = [
            TransformerDecoderBlock(embed_dim, num_heads, ff_dim, dropout_rate)
            for _ in range(num_decoder_layers)
        ]

        # Output projection
        self.final_layer = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, training=None):
        encoder_input, decoder_input = inputs

        # Create masks
        enc_padding_mask = create_padding_mask(encoder_input)
        dec_padding_mask = create_padding_mask(encoder_input)  # For cross-attention

        # Block both future positions and target padding:
        # (batch, 1, 1, tgt) max (tgt, tgt) -> (batch, 1, tgt, tgt)
        look_ahead_mask = create_look_ahead_mask(tf.shape(decoder_input)[1])
        dec_target_padding_mask = create_padding_mask(decoder_input)
        combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

        # Encoder
        enc_output = self.embedding(encoder_input)
        enc_output *= tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))
        enc_output = self.pos_encoding(enc_output)

        for encoder_layer in self.encoder_layers:
            enc_output = encoder_layer(enc_output, training=training, mask=enc_padding_mask)

        # Decoder
        dec_output = self.embedding(decoder_input)
        dec_output *= tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))
        dec_output = self.pos_encoding(dec_output)

        for decoder_layer in self.decoder_layers:
            dec_output = decoder_layer(
                dec_output, enc_output, training=training,
                look_ahead_mask=combined_mask, padding_mask=dec_padding_mask
            )

        # Output projection
        final_output = self.final_layer(dec_output)
        return final_output

# Test complete transformer models
print("=== Testing Complete Transformer Models ===")

# Model parameters
vocab_size = 5000
max_length = 64  # Smaller for testing
embed_dim = 128
num_heads = 8
ff_dim = 256

# Hyperparameters shared by all three architectures
shared_hparams = dict(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    ff_dim=ff_dim,
    max_length=max_length,
    dropout_rate=0.1,
)

# 1. Encoder-only classifier
print("1. BERT-style Classifier:")
bert_classifier = TransformerClassifier(num_layers=3, num_classes=3, **shared_hparams)

sample_input = tf.random.uniform((2, max_length), maxval=vocab_size, dtype=tf.int32)
bert_output = bert_classifier(sample_input)
for label, value in (
    ("Input shape", sample_input.shape),
    ("Output shape", bert_output.shape),
    ("Parameters", f"{bert_classifier.count_params():,}"),
):
    print(f"  {label}: {value}")

# 2. Decoder-only language model (reuses the same token ids)
print("\n2. GPT-style Transformer:")
gpt_model = GPTStyleTransformer(num_layers=3, **shared_hparams)

gpt_output = gpt_model(sample_input)
for label, value in (
    ("Input shape", sample_input.shape),
    ("Output shape", gpt_output.shape),
    ("Parameters", f"{gpt_model.count_params():,}"),
):
    print(f"  {label}: {value}")

# 3. Encoder-decoder model
print("\n3. Encoder-Decoder Transformer:")
seq2seq_model = TransformerSeq2Seq(
    num_encoder_layers=2, num_decoder_layers=2, **shared_hparams
)

encoder_input = tf.random.uniform((2, max_length), maxval=vocab_size, dtype=tf.int32)
decoder_input = tf.random.uniform((2, max_length), maxval=vocab_size, dtype=tf.int32)
seq2seq_output = seq2seq_model([encoder_input, decoder_input])

for label, value in (
    ("Encoder input", encoder_input.shape),
    ("Decoder input", decoder_input.shape),
    ("Output shape", seq2seq_output.shape),
    ("Parameters", f"{seq2seq_model.count_params():,}"),
):
    print(f"  {label}: {value}")

## 6. Training and Fine-tuning Transformers

In [None]:
# Create sample text data for transformer training
def create_sample_datasets():
    """Return a small synthetic three-class text-classification corpus.

    Each class has five template sentences repeated 20 times (100 texts per
    class). Texts are grouped by class in the order tech (label 0),
    science (label 1), business (label 2).

    Returns:
        (all_texts, all_labels): parallel lists of 300 strings and int labels.
    """
    templates_by_label = {
        0: [
            "Machine learning algorithms can process vast amounts of data to identify patterns.",
            "Artificial intelligence has revolutionized many industries through automation.",
            "Deep learning models use neural networks with multiple layers.",
            "Computer vision applications can recognize objects in images.",
            "Natural language processing enables machines to understand human language."
        ],
        1: [
            "Scientists have discovered new properties of quantum particles.",
            "Climate change research shows increasing global temperatures.",
            "Genetic engineering offers promising medical treatments.",
            "Space exploration missions provide insights into the universe.",
            "Renewable energy sources are becoming more efficient."
        ],
        2: [
            "Market analysis indicates growth in the technology sector.",
            "Companies are investing heavily in digital transformation.",
            "Supply chain optimization reduces operational costs.",
            "Customer satisfaction surveys show improved service quality.",
            "Financial reports demonstrate strong quarterly performance."
        ],
    }
    repeats = 20

    all_texts = []
    all_labels = []
    for label, templates in templates_by_label.items():
        class_texts = templates * repeats
        all_texts += class_texts
        all_labels += [label] * len(class_texts)

    return all_texts, all_labels

# Text preprocessing for transformers
def create_tokenizer_and_preprocess(texts, vocab_size=5000, max_length=128):
    """Build a ``TextVectorization`` tokenizer and adapt its vocabulary to ``texts``.

    Args:
        texts: List of raw strings used to build the vocabulary.
        vocab_size: Upper bound on the vocabulary (``max_tokens``), counting the
            reserved padding and OOV entries.
        max_length: Every encoded sequence is padded/truncated to this length.

    Returns:
        tuple: ``(tokenizer, vocab_size_actual)`` where ``vocab_size_actual`` is
        the length of the adapted vocabulary (use it to size embedding layers).
    """
    tokenizer = tf.keras.layers.TextVectorization(
        max_tokens=vocab_size,
        output_sequence_length=max_length,
        pad_to_max_tokens=False,
    )

    # TextVectorization already reserves index 0 for padding ('') and index 1 for
    # the OOV token '[UNK]'. Do not adapt on bracketed markers such as "[PAD]" or
    # "[CLS]": the default 'lower_and_strip_punctuation' standardization turns
    # them into ordinary words ("pad", "cls", ...), which only adds spurious
    # vocabulary entries instead of real special tokens.
    tokenizer.adapt(texts)

    vocab = tokenizer.get_vocabulary()
    vocab_size_actual = len(vocab)

    print(f"Vocabulary size: {vocab_size_actual}")
    print(f"Sample vocabulary: {vocab[:10]}")

    return tokenizer, vocab_size_actual

# Advanced training configuration
class TransformerTrainingConfig:
    """Plain container of hyperparameters for the transformer training run.

    Every instance starts from the same fixed values; override attributes on an
    instance to experiment.
    """

    def __init__(self):
        # Architecture: width, heads, feed-forward width, depth, dropout, sequence length.
        self.embed_dim, self.num_heads, self.ff_dim = 128, 8, 256
        self.num_layers, self.dropout_rate, self.max_length = 4, 0.1, 128

        # Training loop: base rate, batching, epochs, warmup length.
        self.learning_rate, self.batch_size = 1e-4, 16
        self.epochs, self.warmup_steps = 10, 1000

        # Regularization and optimizer extras.
        self.weight_decay, self.gradient_clip_norm, self.label_smoothing = 0.01, 1.0, 0.1

def create_learning_rate_schedule(embed_dim, warmup_steps=1000):
    """Return the inverse-square-root warmup schedule ("Attention Is All You Need").

        lr(step) = embed_dim**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

    The rate grows linearly until ``step == warmup_steps`` and then decays as
    1/sqrt(step). At step 0 the rsqrt term is inf, so the minimum (and the
    rate) is 0.

    Args:
        embed_dim: Model width used for the embed_dim**-0.5 scale factor.
        warmup_steps: Number of linear-warmup steps.

    Returns:
        A ``tf.keras.optimizers.schedules.LearningRateSchedule`` instance.
    """

    class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        def __init__(self, embed_dim, warmup_steps=1000):
            super().__init__()
            # Keep the raw constructor value for get_config(); the float tensor
            # below is what __call__ uses.
            self._embed_dim_config = embed_dim
            self.embed_dim = tf.cast(embed_dim, tf.float32)
            self.warmup_steps = warmup_steps

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            decay = tf.math.rsqrt(step)
            warmup = step * (self.warmup_steps ** -1.5)

            return tf.math.rsqrt(self.embed_dim) * tf.math.minimum(decay, warmup)

        def get_config(self):
            # The LearningRateSchedule base class raises NotImplementedError here.
            # Keras needs this config whenever it serializes the optimizer, e.g.
            # when a full model is saved together with its optimizer.
            return {
                "embed_dim": self._embed_dim_config,
                "warmup_steps": self.warmup_steps,
            }

    return CustomSchedule(embed_dim, warmup_steps)

# Custom callbacks for transformer training
class TransformerTrainingCallbacks:
    """Factory for the standard callback set used when fitting transformers."""

    @staticmethod
    def get_callbacks(model_path='best_transformer.h5'):
        """Build checkpointing, early-stopping, LR-reduction and CSV-logging callbacks.

        Args:
            model_path: File the best model (by ``val_accuracy``) is written to;
                the whole model is saved, not just its weights.

        Returns:
            list: ``[ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, CSVLogger]``.

        NOTE: ReduceLROnPlateau has to overwrite the optimizer's learning rate.
        Keras optimizers built with a ``LearningRateSchedule`` reject that with a
        TypeError, so drop it from the list when training with a schedule.
        """
        best_model_saver = tf.keras.callbacks.ModelCheckpoint(
            model_path,
            monitor='val_accuracy',
            save_best_only=True,
            save_weights_only=False,
            verbose=1,
        )
        stop_when_stalled = tf.keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            patience=5,
            restore_best_weights=True,
            verbose=1,
        )
        halve_lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=3,
            min_lr=1e-7,
            verbose=1,
        )
        # append=True: repeated runs accumulate rows in the same CSV file.
        epoch_logger = tf.keras.callbacks.CSVLogger(
            'transformer_training.csv',
            append=True,
        )
        return [best_model_saver, stop_when_stalled, halve_lr_on_plateau, epoch_logger]

# Train transformer classifier
print("=== Training Transformer Classifier ===")

# Create sample data
texts, labels = create_sample_datasets()
print(f"Dataset size: {len(texts)} texts, {len(set(labels))} classes")

# Create tokenizer
# NOTE(review): the vocabulary is adapted on every text, including the validation
# and test examples. That is acceptable for this toy corpus; in real use, adapt on
# the training split only.
tokenizer, actual_vocab_size = create_tokenizer_and_preprocess(texts)

# Encode texts
encoded_texts = tokenizer(texts)
labels_array = tf.constant(labels)

# Split data 80/10/10.
# create_sample_datasets() returns the examples grouped by class (all 0s, then
# 1s, then 2s). Slicing them in order would leave the validation and test sets
# with class 2 only. Split shuffled, stratified indices instead so every split
# keeps the class balance.
all_indices = np.arange(len(texts))
train_idx, holdout_idx = train_test_split(
    all_indices, train_size=0.8, stratify=labels, random_state=42
)
val_idx, test_idx = train_test_split(
    holdout_idx,
    test_size=0.5,
    stratify=np.asarray(labels)[holdout_idx],
    random_state=42,
)
train_size = len(train_idx)
val_size = len(val_idx)

X_train = tf.gather(encoded_texts, train_idx)
y_train = tf.gather(labels_array, train_idx)
X_val = tf.gather(encoded_texts, val_idx)
y_val = tf.gather(labels_array, val_idx)
X_test = tf.gather(encoded_texts, test_idx)
y_test = tf.gather(labels_array, test_idx)

print(f"Training data: {X_train.shape}")
print(f"Validation data: {X_val.shape}")
print(f"Test data: {X_test.shape}")

# Create and configure model
config = TransformerTrainingConfig()
model = TransformerClassifier(
    vocab_size=actual_vocab_size,
    embed_dim=config.embed_dim,
    num_heads=config.num_heads,
    ff_dim=config.ff_dim,
    num_layers=config.num_layers,
    num_classes=3,
    max_length=config.max_length,
    dropout_rate=config.dropout_rate
)

# Create learning rate schedule (replaces config.learning_rate as the optimizer's rate)
lr_schedule = create_learning_rate_schedule(config.embed_dim, config.warmup_steps)

# Compile model with AdamW optimizer
optimizer = tf.keras.optimizers.AdamW(
    learning_rate=lr_schedule,
    weight_decay=config.weight_decay,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-7,
    clipnorm=config.gradient_clip_norm
)

# Label smoothing loss
def label_smoothing_loss(y_true, y_pred, smoothing=0.1):
    """Categorical cross-entropy against label-smoothed one-hot targets.

    Targets become ``one_hot * (1 - smoothing) + smoothing / num_classes``.

    Args:
        y_true: Integer class ids. Any shape with one id per row of ``y_pred``
            is accepted, e.g. (batch,) or the (batch, 1) form Keras may pass.
        y_pred: Class scores with classes on the last axis. They are treated as
            probabilities (``categorical_crossentropy`` default
            ``from_logits=False``); NOTE(review): confirm the model ends in softmax.
        smoothing: Probability mass spread uniformly over all classes.

    Returns:
        Per-example loss with shape ``y_pred.shape[:-1]``.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    num_classes = tf.shape(y_pred)[-1]

    # Align labels to y_pred's leading shape. A (batch, 1) label tensor would
    # otherwise one-hot to (batch, 1, C) and silently broadcast against
    # (batch, C) into a (batch, batch) loss.
    class_ids = tf.reshape(tf.cast(y_true, tf.int32), tf.shape(y_pred)[:-1])
    targets = tf.one_hot(class_ids, num_classes, dtype=y_pred.dtype)

    # Apply smoothing
    targets = targets * (1 - smoothing) + smoothing / tf.cast(num_classes, y_pred.dtype)

    return tf.keras.losses.categorical_crossentropy(targets, y_pred)

# NOTE(review): config.label_smoothing and label_smoothing_loss are not wired in
# here; the model trains with plain sparse categorical cross-entropy, which
# (from_logits=False by default) assumes probability outputs — confirm against
# TransformerClassifier.
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy']
)

print(f"Model compiled. Total parameters: {model.count_params():,}")

# Train model
# The optimizer's learning rate is a LearningRateSchedule (lr_schedule), and Keras
# optimizers do not allow such a rate to be overwritten. ReduceLROnPlateau would
# raise a TypeError the first time val_loss plateaus. The schedule already decays
# the rate, so leave that callback out.
# NOTE(review): ModelCheckpoint saves the full model to an .h5 file; confirm
# TransformerClassifier supports full-model HDF5 saving.
callbacks = [
    callback
    for callback in TransformerTrainingCallbacks.get_callbacks()
    if not isinstance(callback, tf.keras.callbacks.ReduceLROnPlateau)
]

history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    batch_size=config.batch_size,
    epochs=config.epochs,
    callbacks=callbacks,
    verbose=1
)

# Evaluate model
test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f"\nTest Results:")
print(f"  Test Loss: {test_loss:.4f}")
print(f"  Test Accuracy: {test_accuracy:.4f}")

# Plot training history
def plot_training_history(history):
    """Draw side-by-side training/validation curves for accuracy and loss.

    Args:
        history: Object whose ``history`` dict holds per-epoch lists under
            'accuracy', 'val_accuracy', 'loss' and 'val_loss' (the object
            returned by ``model.fit``).
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # (axis, training key, validation key, title, y-label) for each panel.
    panels = (
        (axes[0], 'accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy'),
        (axes[1], 'loss', 'val_loss', 'Model Loss', 'Loss'),
    )
    for ax, train_key, val_key, title, y_label in panels:
        ax.plot(history.history[train_key], label='Training', marker='o')
        ax.plot(history.history[val_key], label='Validation', marker='s')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

plot_training_history(history)

## 7. Advanced Transformer Techniques

In [None]:
# Memory-efficient attention mechanisms
class EfficientMultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head self-attention with a fused QKV projection.

    Queries, keys and values come from one Dense layer instead of three. With
    ``use_relative_position=True``, learned relative-position embeddings are
    added to both the attention logits and the attended values. There is one
    vector per clipped query-key offset, shared by all heads.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout_rate: Dropout rate applied to the attention weights.
        use_relative_position: Enable the relative-position terms.
        max_relative_position: Offsets are clipped to
            ``[-max_relative_position, max_relative_position]``.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, dropout_rate=0.1,
                 use_relative_position=False, max_relative_position=32, **kwargs):
        super().__init__(**kwargs)
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.use_relative_position = use_relative_position
        self.max_relative_position = max_relative_position

        self.depth = embed_dim // num_heads

        # Use single dense layer for efficiency
        self.qkv_dense = tf.keras.layers.Dense(embed_dim * 3, use_bias=False)
        self.output_dense = tf.keras.layers.Dense(embed_dim)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

        # Relative position embeddings if enabled: one row per clipped offset.
        if use_relative_position:
            self.relative_position_k = tf.keras.layers.Embedding(
                2 * max_relative_position + 1, self.depth
            )
            self.relative_position_v = tf.keras.layers.Embedding(
                2 * max_relative_position + 1, self.depth
            )

    def get_relative_positions(self, seq_len):
        """Return a (seq_len, seq_len) int tensor of clipped offsets ``i - j``,
        shifted into ``[0, 2 * max_relative_position]`` for embedding lookup."""
        positions = tf.range(seq_len)[:, tf.newaxis] - tf.range(seq_len)[tf.newaxis, :]
        positions = tf.clip_by_value(positions, -self.max_relative_position, self.max_relative_position)
        return positions + self.max_relative_position

    def call(self, inputs, training=None, mask=None):
        """Self-attend over ``inputs``.

        Args:
            inputs: Tensor of shape (batch, seq_len, embed_dim).
            training: Enables dropout on the attention weights.
            mask: Optional tensor broadcastable to (batch, num_heads, seq_len,
                seq_len). Entries equal to 1 block attention, because
                ``mask * -1e9`` is added to the logits.

        Returns:
            Tensor of shape (batch, seq_len, embed_dim).
        """
        batch_size = tf.shape(inputs)[0]
        seq_len = tf.shape(inputs)[1]

        # One projection yields Q, K, V -> (3, batch, heads, seq, depth).
        qkv = self.qkv_dense(inputs)
        qkv = tf.reshape(qkv, (batch_size, seq_len, 3, self.num_heads, self.depth))
        qkv = tf.transpose(qkv, perm=[2, 0, 3, 1, 4])

        queries, keys, values = qkv[0], qkv[1], qkv[2]

        # Attention logits: (batch, heads, seq_q, seq_k).
        scores = tf.matmul(queries, keys, transpose_b=True)

        if self.use_relative_position:
            relative_positions = self.get_relative_positions(seq_len)
            # (seq_q, seq_k, depth): one learned vector per query/key offset.
            rel_k = self.relative_position_k(relative_positions)
            # logits[b,h,q,k] += sum_d Q[b,h,q,d] * rel_k[q,k,d]
            scores += tf.einsum('bhqd,qkd->bhqk', queries, rel_k)

        # Scale and apply mask
        scores /= tf.math.sqrt(tf.cast(self.depth, scores.dtype))

        if mask is not None:
            scores += tf.cast(mask, scores.dtype) * -1e9

        attention_weights = tf.nn.softmax(scores, axis=-1)
        attention_weights = self.dropout(attention_weights, training=training)

        # Apply attention to values: (batch, heads, seq_q, depth).
        attention_output = tf.matmul(attention_weights, values)

        if self.use_relative_position:
            rel_v = self.relative_position_v(relative_positions)
            # out[b,h,q,d] += sum_k weights[b,h,q,k] * rel_v[q,k,d]
            attention_output += tf.einsum('bhqk,qkd->bhqd', attention_weights, rel_v)

        # Concatenate heads back to (batch, seq, embed_dim).
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        attention_output = tf.reshape(attention_output, (batch_size, seq_len, self.embed_dim))

        # Output projection
        output = self.output_dense(attention_output)
        return output

# Transformer with advanced techniques
class AdvancedTransformerBlock(tf.keras.layers.Layer):
    """Transformer encoder block: self-attention followed by a feed-forward network.

    Each sub-layer output passes through dropout and a residual connection.
    ``pre_norm`` chooses where LayerNormalization sits:

    * ``True``  -- normalize each sub-layer's input: ``x + Drop(F(LN(x)))``
    * ``False`` -- normalize after the residual sum (original Transformer):
      ``LN(x + Drop(F(x)))``

    Attention is an ``EfficientMultiHeadAttention`` with relative-position
    terms always enabled.

    Args:
        embed_dim: Model width (last dimension of inputs and outputs).
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        dropout_rate: Used for attention weights, inside the FFN, and on both
            residual branches.
        activation: Activation of the FFN hidden layer.
        pre_norm: Selects pre- or post-normalization (see above).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1,
                 activation='gelu', pre_norm=True, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.pre_norm = pre_norm

        self.attention = EfficientMultiHeadAttention(
            embed_dim, num_heads, dropout_rate, use_relative_position=True
        )

        expand = tf.keras.layers.Dense(ff_dim, activation=activation)
        inner_drop = tf.keras.layers.Dropout(dropout_rate)
        project = tf.keras.layers.Dense(embed_dim)
        self.ffn = tf.keras.Sequential([expand, inner_drop, project])

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, inputs, training=None, mask=None):
        if not self.pre_norm:
            # Post-normalization: residual sum first, then LayerNorm.
            attended = self.attention(inputs, training=training, mask=mask)
            hidden = self.layernorm1(inputs + self.dropout1(attended, training=training))
            fed = self.ffn(hidden, training=training)
            return self.layernorm2(hidden + self.dropout2(fed, training=training))

        # Pre-normalization: LayerNorm on each sub-layer input, raw residual path.
        attended = self.attention(self.layernorm1(inputs), training=training, mask=mask)
        hidden = inputs + self.dropout1(attended, training=training)
        fed = self.ffn(self.layernorm2(hidden), training=training)
        return hidden + self.dropout2(fed, training=training)

# Sparse attention patterns for long sequences
class SparseAttentionTransformer(tf.keras.Model):
    """Token-level transformer whose self-attention follows a static sparsity pattern.

    Args:
        vocab_size: Size of the input vocabulary and of the output logits.
        embed_dim, num_heads, ff_dim, num_layers: Transformer dimensions.
        max_length: Passed to ``PositionalEncoding`` as ``max_position``.
        attention_pattern: 'local', 'strided', or anything else for full attention.
        window_size: For 'local', each query sees keys within
            ``window_size // 2`` positions. For 'strided', the stride of the
            always-visible key columns.
    """

    # Half-width of the fixed local band used by the 'strided' pattern:
    # query i sees keys j with i - 64 <= j < i + 64.
    _STRIDED_LOCAL_RADIUS = 64

    def __init__(self, vocab_size, embed_dim=128, num_heads=8, ff_dim=256,
                 num_layers=4, max_length=1024, attention_pattern='local',
                 window_size=256, **kwargs):
        super().__init__(**kwargs)

        self.embed_dim = embed_dim
        self.attention_pattern = attention_pattern
        self.window_size = window_size

        # Embedding
        self.embedding = tf.keras.layers.Embedding(vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(max_position=max_length)

        # Transformer layers; sparsity is imposed through the attention mask.
        self.transformer_layers = [
            AdvancedTransformerBlock(embed_dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ]

        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.output_projection = tf.keras.layers.Dense(vocab_size)

    def create_sparse_attention_mask(self, seq_len):
        """Build an additive attention mask for ``seq_len`` positions.

        Returns a float32 tensor of shape (1, 1, seq_len, seq_len) holding 0.0
        where query i may attend to key j and -1e9 elsewhere:

        * 'local':   ``|i - j| <= window_size // 2``
        * 'strided': ``i - 64 <= j < i + 64`` or ``j % window_size == 0``
        * otherwise: every pair is allowed (full attention)

        The mask is built from broadcast comparisons rather than per-row scatter
        updates. It therefore costs a constant number of ops and accepts either
        a Python int or a scalar tensor for ``seq_len``, including inside
        ``tf.function``.
        """
        positions = tf.range(seq_len)
        query_pos = positions[:, tf.newaxis]
        key_pos = positions[tf.newaxis, :]
        offset = key_pos - query_pos  # j - i

        if self.attention_pattern == 'local':
            # Local attention: attend to nearby tokens
            allowed = tf.abs(offset) <= self.window_size // 2

        elif self.attention_pattern == 'strided':
            # Fixed local band plus every window_size-th key ("global" columns).
            radius = self._STRIDED_LOCAL_RADIUS
            in_band = (offset >= -radius) & (offset < radius)
            on_stride = tf.equal(key_pos % self.window_size, 0)
            allowed = in_band | on_stride

        else:  # 'full' attention
            allowed = tf.ones((seq_len, seq_len), dtype=tf.bool)

        mask = tf.where(allowed, 0.0, -1e9)
        return mask[tf.newaxis, tf.newaxis, :, :]  # Add batch and head dimensions

    def call(self, inputs, training=None):
        seq_len = tf.shape(inputs)[1]

        # Embeddings
        x = self.embedding(inputs)
        x *= tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))
        x = self.pos_encoding(x)

        # create_sparse_attention_mask() returns an additive mask (0 / -1e9),
        # but the attention layers add ``mask * -1e9`` themselves, i.e. they
        # expect 1 = blocked. Passing the additive mask through unchanged would
        # turn blocked logits into +1e18 and focus attention on exactly the
        # forbidden positions, so convert it to a binary blocking mask.
        additive_mask = self.create_sparse_attention_mask(seq_len)
        blocked = tf.cast(additive_mask < 0, tf.float32)

        # Apply transformer layers
        for layer in self.transformer_layers:
            x = layer(x, training=training, mask=blocked)

        # Output
        x = self.layer_norm(x)
        return self.output_projection(x)

# Test advanced techniques
print("=== Testing Advanced Transformer Techniques ===")

# Test efficient attention
print("1. Efficient Multi-Head Attention:")
efficient_attention = EfficientMultiHeadAttention(
    embed_dim=128, num_heads=8, use_relative_position=True
)

sample_input = tf.random.normal((2, 32, 128))
efficient_output = efficient_attention(sample_input)

print(f"  Input shape: {sample_input.shape}")
print(f"  Output shape: {efficient_output.shape}")
print(f"  Parameters: {efficient_attention.count_params():,}")

# Test advanced transformer block
print("\n2. Advanced Transformer Block:")
advanced_block = AdvancedTransformerBlock(
    embed_dim=128, num_heads=8, ff_dim=256, 
    activation='gelu', pre_norm=True
)

advanced_output = advanced_block(sample_input)
print(f"  Output shape: {advanced_output.shape}")
print(f"  Parameters: {advanced_block.count_params():,}")

# Test sparse attention transformer
print("\n3. Sparse Attention Transformer:")
sparse_transformer = SparseAttentionTransformer(
    vocab_size=5000,
    embed_dim=128,
    num_heads=8,
    ff_dim=256,
    num_layers=3,
    max_length=512,
    attention_pattern='local',
    window_size=64
)

sparse_input = tf.random.uniform((2, 128), maxval=5000, dtype=tf.int32)
sparse_output = sparse_transformer(sparse_input)

print(f"  Input shape: {sparse_input.shape}")
print(f"  Output shape: {sparse_output.shape}")
print(f"  Parameters: {sparse_transformer.count_params():,}")

# Visualize sparse attention patterns.
# With the model's own window_size=64, the local band (|i - j| <= 32) covers
# every pair in a 20-token sequence, and so does the strided pattern's fixed
# i-64 <= j < i+64 band. Plotted as-is, both heatmaps would be solid blocks.
# Temporarily switch to settings that make the sparsity visible, then restore
# the model's configuration so later calls use the settings it was built with.
saved_pattern = sparse_transformer.attention_pattern
saved_window_size = sparse_transformer.window_size

plt.figure(figsize=(12, 5))
try:
    # Local pattern: each query sees keys within window_size // 2 = 3 positions.
    seq_len = 20
    sparse_transformer.attention_pattern = 'local'
    sparse_transformer.window_size = 6
    local_mask = sparse_transformer.create_sparse_attention_mask(seq_len)
    local_mask_viz = tf.where(local_mask[0, 0] == 0, 1.0, 0.0)  # 1 = may attend

    plt.subplot(1, 2, 1)
    sns.heatmap(local_mask_viz.numpy(), cmap='Blues', cbar=True, square=True)
    plt.title('Local Sparse Attention Pattern')
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')

    # Strided pattern: the local band is 128 keys wide, so use a longer sequence
    # for the every-32nd-key columns to show up outside that band.
    strided_seq_len = 160
    sparse_transformer.attention_pattern = 'strided'
    sparse_transformer.window_size = 32
    strided_mask = sparse_transformer.create_sparse_attention_mask(strided_seq_len)
    strided_mask_viz = tf.where(strided_mask[0, 0] == 0, 1.0, 0.0)

    plt.subplot(1, 2, 2)
    sns.heatmap(strided_mask_viz.numpy(), cmap='Greens', cbar=True, square=True)
    plt.title('Strided Sparse Attention Pattern')
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
finally:
    sparse_transformer.attention_pattern = saved_pattern
    sparse_transformer.window_size = saved_window_size

plt.tight_layout()
plt.show()

## Summary

This comprehensive notebook has covered the complete implementation of Transformer architectures using tf.keras:

### Key Concepts Mastered

**1. Attention Mechanisms:**
- Scaled dot-product attention from scratch
- Multi-head attention implementation and comparison
- Attention visualization and interpretation

**2. Transformer Components:**
- Positional encoding with sinusoidal patterns
- Encoder and decoder blocks with residual connections
- Layer normalization and feed-forward networks

**3. Complete Architectures:**
- BERT-style encoder-only models for classification
- GPT-style decoder-only models for generation
- Full encoder-decoder models for sequence-to-sequence tasks

**4. Training and Optimization:**
- Custom learning rate schedules with warmup
- A label-smoothing loss function, plus AdamW with weight decay and gradient-norm clipping
- Comprehensive evaluation and visualization

**5. Advanced Techniques:**
- Memory-efficient attention mechanisms
- Relative position embeddings
- Sparse attention patterns for long sequences
- Pre-normalization and modern architectural improvements

### Practical Applications

The implementations in this notebook enable you to:
- Build state-of-the-art NLP models for classification, generation, and translation
- Handle sequences of varying lengths efficiently
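- Restrict attention with sparse patterns as a step toward longer sequences (the masks shown here limit which positions interact, but the full attention matrix is still computed)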
- Apply transfer learning principles to transformer models
- Optimize models for production deployment

### Next Steps

Continue to notebook 12 to explore TensorFlow Hub integration and pre-trained transformer models, where you'll learn to leverage existing models and fine-tune them for specific tasks.

The transformer architecture has revolutionized NLP and remains the basis of today's most advanced language models. The implementations here provide a solid starting point for understanding and building on these architectures.