# File: notebooks/03_neural_networks/10_attention_from_scratch.ipynb

## JAX Neural Networks: Attention from Scratch

This notebook implements attention mechanisms from scratch in JAX, including scaled dot-product attention, multi-head attention, and a basic transformer layer. We'll build these fundamental components that power modern language models and vision transformers.

Attention mechanisms let a model weight the parts of the input that are most relevant to each position it processes. This makes long-range dependencies easier to capture, and the resulting attention weights can be inspected to see which positions the model drew on.
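Here is a toy illustration with hand-picked numbers that are not taken from the notebook. One query scores three keys, and the softmax turns those scores into weights. The output is the weighted average of the values, dominated by the value whose key best matches the query.

```python
import jax.numpy as jnp
from jax.nn import softmax

# Toy numbers: one 2-d query and three key/value pairs
q = jnp.array([1.0, 0.0])
keys = jnp.array([[1.0, 0.0],    # points the same way as q
                  [0.0, 1.0],    # orthogonal to q
                  [-1.0, 0.0]])  # points the opposite way
values = jnp.array([[10.0], [20.0], [30.0]])

# Dot-product similarity, scaled by sqrt(d_k) with d_k = 2
scores = keys @ q / jnp.sqrt(2.0)
# Softmax turns scores into non-negative weights that sum to 1
weights = softmax(scores)   # roughly [0.576, 0.284, 0.140]
# Output is the weighted average of the value rows
output = weights @ values

print("weights:", weights)
print("output: ", output)
```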

## Setting Up the Environment

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random, lax
from jax.nn import softmax, gelu
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Dict, Any, Optional
import functools
import math

jax.config.update("jax_enable_x64", True)
print(f"JAX version: {jax.__version__}")
```

## Scaled Dot-Product Attention
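Scaled dot-product attention computes `softmax(Q @ K.T / sqrt(d_k)) @ V`. The `sqrt(d_k)` divisor matters for the following reason. Suppose the entries of a query and a key are independent, with mean 0 and variance 1. Their dot product then has variance `d_k`, and scores that large push the softmax toward a near one-hot distribution with very small gradients. The quick empirical check below assumes standard-normal entries, which is an assumption made only for this demo.

```python
import jax.numpy as jnp
from jax import random

# Assumption for the demo: query/key entries are i.i.d. N(0, 1)
d_k, n_samples = 64, 10_000
k_q, k_k = random.split(random.PRNGKey(0))
q = random.normal(k_q, (n_samples, d_k))
k = random.normal(k_k, (n_samples, d_k))

raw_scores = jnp.sum(q * k, axis=-1)        # one dot product per sample
scaled_scores = raw_scores / jnp.sqrt(d_k)  # the scaling used in attention

# Var(q . k) is about d_k; dividing by sqrt(d_k) brings it back to about 1
print(f"variance of raw scores:    {jnp.var(raw_scores):.2f}")
print(f"variance of scaled scores: {jnp.var(scaled_scores):.3f}")
```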

### Basic Attention Implementation

```python
def scaled_dot_product_attention(query, key, value, mask=None, dropout_rate=0.0, key_dropout=None, training=True):
    """
    Scaled dot-product attention mechanism
    
    Args:
        query: Query tensor of shape (..., seq_len_q, d_k)
        key: Key tensor of shape (..., seq_len_k, d_k)  
        value: Value tensor of shape (..., seq_len_v, d_v)
        mask: Optional boolean mask broadcastable to (..., seq_len_q, seq_len_k); True marks positions that may be attended to
        dropout_rate: Dropout rate for attention weights
        key_dropout: Random key for dropout
        training: Whether in training mode
    
    Returns:
        output: Attention output of shape (..., seq_len_q, d_v)
        attention_weights: Attention weights of shape (..., seq_len_q, seq_len_k)
    """
    
    # Compute attention scores
    d_k = query.shape[-1]
    scores = jnp.matmul(query, jnp.swapaxes(key, -2, -1)) / jnp.sqrt(d_k)
    
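    # Apply mask if provided (True = may attend). A large finite negative value is used
    # instead of -inf so that a fully masked row yields uniform weights rather than NaN.
    if mask is not None:
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)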
    
    # Apply softmax to get attention weights
    attention_weights = softmax(scores, axis=-1)
    
    # Apply dropout to attention weights
    if training and dropout_rate > 0.0 and key_dropout is not None:
        keep_prob = 1.0 - dropout_rate
        mask_dropout = random.bernoulli(key_dropout, keep_prob, attention_weights.shape)
        attention_weights = jnp.where(mask_dropout, attention_weights / keep_prob, 0.0)
    
    # Apply attention to values
    output = jnp.matmul(attention_weights, value)
    
    return output, attention_weights

def test_basic_attention():
    """Test basic attention mechanism"""
    
    key = random.PRNGKey(42)
    batch_size, seq_len, d_model = 2, 8, 64
    
    # Create random query, key, value tensors from independent PRNG keys
    k_q, k_k, k_v = random.split(key, 3)
    query = random.normal(k_q, (batch_size, seq_len, d_model))
    key_tensor = random.normal(k_k, (batch_size, seq_len, d_model))
    value = random.normal(k_v, (batch_size, seq_len, d_model))
    
    # Apply attention
    output, weights = scaled_dot_product_attention(query, key_tensor, value)
    
    print("Basic Attention Test:")
    print(f"Input shapes: Q{query.shape}, K{key_tensor.shape}, V{value.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention weights shape: {weights.shape}")
    print(f"Attention weights sum along last axis: {jnp.sum(weights, axis=-1)[0, 0]:.6f}")
    
    return output, weights

attention_output, attention_weights = test_basic_attention()
```
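The dropout branch uses *inverted* dropout. Surviving weights are divided by `keep_prob`, so each weight keeps its expected value. An individual dropped-out row, however, no longer sums to 1. The Monte Carlo check below tests that expectation on an arbitrary row of weights, using a local `inverted_dropout` helper written only for this demo.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.nn import softmax

# One row of attention weights from arbitrary example scores
weights = softmax(jnp.array([2.0, 1.0, 0.5, -1.0]))

dropout_rate = 0.1
keep_prob = 1.0 - dropout_rate

def inverted_dropout(key):
    # Same rule as in scaled_dot_product_attention: zero out, then rescale survivors
    keep = random.bernoulli(key, keep_prob, weights.shape)
    return jnp.where(keep, weights / keep_prob, 0.0)

# Many independent dropout draws, one PRNG key each
n_trials = 20_000
samples = jax.vmap(inverted_dropout)(random.split(random.PRNGKey(0), n_trials))

print("original weights:     ", weights)
print("mean after dropout:   ", samples.mean(axis=0))
print("row sums (first five):", samples.sum(axis=-1)[:5])
```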

### Causal (Masked) Attention

```python
def create_causal_mask(seq_len):
    """Create causal mask for autoregressive attention"""
    mask = jnp.tril(jnp.ones((seq_len, seq_len)))
    return mask == 1  # Convert to boolean mask

def create_padding_mask(lengths, max_len):
    """Create padding mask for variable length sequences"""
    positions = jnp.arange(max_len)[None, :]  # (1, max_len)
    lengths = lengths[:, None]  # (batch_size, 1)
    return positions < lengths  # (batch_size, max_len)

def test_masked_attention():
    """Test attention with causal masking"""
    
    key = random.PRNGKey(123)
    batch_size, seq_len, d_model = 1, 6, 32
    
    # Create input tensors
    query = random.normal(key, (batch_size, seq_len, d_model))
    key_tensor = query  # Self-attention
    value = query
    
    # Create causal mask
    causal_mask = create_causal_mask(seq_len)
    causal_mask = jnp.expand_dims(causal_mask, 0)  # Add batch dimension
    
    # Apply masked attention
    output, weights = scaled_dot_product_attention(query, key_tensor, value, mask=causal_mask)
    
    print("Masked Attention Test:")
    print(f"Causal mask shape: {causal_mask.shape}")
    print(f"Causal mask:\n{causal_mask[0].astype(int)}")
    print(f"Attention weights (position 0): {weights[0, 0, :]}")
    print(f"Attention weights (position 3): {weights[0, 3, :]}")
    
    # Verify causality: every position should give zero weight to later (future) positions
    future_attention = weights[0, 2, 3:]  # Position 2 attending to positions 3+
    print(f"Future attention (should be 0): {future_attention}")

test_masked_attention()
```
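`create_padding_mask` is defined above but never used. The sketch below shows how its `(batch, max_len)` output could be broadcast over query positions and combined with the causal mask. It relies on a hypothetical `attention_weights` helper written for this demo. For `MultiHeadAttention`, one more axis would be needed for the heads, for example `(batch, 1, 1, max_len)`. Masked scores are filled with a large finite negative number rather than `-inf`: if every key in a row were masked, `-inf` would produce NaN after the softmax.

```python
import jax.numpy as jnp
from jax import random
from jax.nn import softmax

def attention_weights(q, k, mask):
    # Hypothetical helper for this demo: scaled scores -> masked softmax
    scores = q @ jnp.swapaxes(k, -2, -1) / jnp.sqrt(q.shape[-1])
    # Finite fill value: a fully masked row gives uniform weights instead of NaN
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    return softmax(scores, axis=-1)

batch_size, max_len, d_k = 2, 5, 8
lengths = jnp.array([5, 3])  # the second sequence ends with 2 padding positions

# Padding mask over keys: (batch, max_len) -> (batch, 1, max_len) to broadcast over queries
padding_mask = (jnp.arange(max_len)[None, :] < lengths[:, None])[:, None, :]
# Causal mask: (1, max_len, max_len), shared by every batch element
causal_mask = jnp.tril(jnp.ones((max_len, max_len), dtype=bool))[None, :, :]
# A key is visible only if it is both a real token and not in the future
combined_mask = padding_mask & causal_mask  # (batch, max_len, max_len)

x = random.normal(random.PRNGKey(0), (batch_size, max_len, d_k))
w = attention_weights(x, x, combined_mask)
print(w[1].round(3))
```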

## Multi-Head Attention

### Multi-Head Attention Implementation

```python
class MultiHeadAttention:
    """Multi-head attention mechanism"""
    
    def __init__(self, d_model, num_heads, dropout_rate=0.1):
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.dropout_rate = dropout_rate
    
    def init_params(self, key):
        """Initialize parameters for multi-head attention"""
        keys = random.split(key, 4)
        
        # Xavier/Glorot initialization
        def init_linear(key, shape):
            fan_in, fan_out = shape[0], shape[1]
            limit = jnp.sqrt(6.0 / (fan_in + fan_out))
            return random.uniform(key, shape, minval=-limit, maxval=limit)
        
        params = {
            'W_q': init_linear(keys[0], (self.d_model, self.d_model)),
            'W_k': init_linear(keys[1], (self.d_model, self.d_model)),
            'W_v': init_linear(keys[2], (self.d_model, self.d_model)),
            'W_o': init_linear(keys[3], (self.d_model, self.d_model))
        }
        
        return params
    
    def split_heads(self, x):
        """Split the last dimension into (num_heads, d_k)"""
        batch_size, seq_len, d_model = x.shape
        x = x.reshape(batch_size, seq_len, self.num_heads, self.d_k)
        return jnp.transpose(x, (0, 2, 1, 3))  # (batch_size, num_heads, seq_len, d_k)
    
    def combine_heads(self, x):
        """Combine heads back into single dimension"""
        batch_size, num_heads, seq_len, d_k = x.shape
        x = jnp.transpose(x, (0, 2, 1, 3))  # (batch_size, seq_len, num_heads, d_k)
        return x.reshape(batch_size, seq_len, self.d_model)
    
    def forward(self, params, query, key, value, mask=None, key_dropout=None, training=True):
        """Forward pass through multi-head attention"""
        
        batch_size, seq_len, _ = query.shape
        
        # Linear transformations
        Q = jnp.dot(query, params['W_q'])
        K = jnp.dot(key, params['W_k'])  
        V = jnp.dot(value, params['W_v'])
        
        # Split into multiple heads
        Q = self.split_heads(Q)  # (batch_size, num_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # Apply scaled dot-product attention
        attention_output, attention_weights = scaled_dot_product_attention(
            Q, K, V, mask, self.dropout_rate, key_dropout, training
        )
        
        # Combine heads
        attention_output = self.combine_heads(attention_output)
        
        # Final linear transformation
        output = jnp.dot(attention_output, params['W_o'])
        
        return output, attention_weights
    
    def __call__(self, params, query, key, value, mask=None, key_dropout=None, training=True):
        """Make the class callable"""
        return self.forward(params, query, key, value, mask, key_dropout, training)

def test_multihead_attention():
    """Test multi-head attention implementation"""
    
    key = random.PRNGKey(456)
    batch_size, seq_len, d_model = 2, 10, 128
    num_heads = 8
    
    # Create multi-head attention layer
    mha = MultiHeadAttention(d_model, num_heads, dropout_rate=0.1)
    params = mha.init_params(key)
    
    # Create input tensors
    x = random.normal(random.split(key)[1], (batch_size, seq_len, d_model))
    
    # Self-attention
    output, weights = mha(params, x, x, x, training=False)
    
    print("Multi-Head Attention Test:")
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention weights shape: {weights.shape}")
    print(f"Number of heads: {num_heads}")
    print(f"d_k per head: {mha.d_k}")
    
    # Test with causal mask
    causal_mask = create_causal_mask(seq_len)
    causal_mask = jnp.expand_dims(causal_mask, (0, 1))  # Add batch and head dimensions
    
    output_masked, weights_masked = mha(params, x, x, x, mask=causal_mask, training=False)
    print(f"Masked output shape: {output_masked.shape}")
    
    return mha, params

mha, mha_params = test_multihead_attention()
```
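Splitting heads after one `d_model × d_model` projection gives the same result as giving each head its own `d_model × d_k` projection, namely a block of `d_k` columns of `W_q`. The small check below uses `jnp.einsum` with random weights rather than the parameters initialized above.

```python
import jax.numpy as jnp
from jax import random

batch_size, seq_len, d_model, num_heads = 2, 5, 16, 4
d_k = d_model // num_heads

k_x, k_w = random.split(random.PRNGKey(0))
x = random.normal(k_x, (batch_size, seq_len, d_model))
W_q = random.normal(k_w, (d_model, d_model)) / jnp.sqrt(d_model)

# Route 1: one d_model x d_model projection, then split heads
# (the same reshape/transpose as MultiHeadAttention.split_heads)
Q = (x @ W_q).reshape(batch_size, seq_len, num_heads, d_k).transpose(0, 2, 1, 3)

# Route 2: treat W_q as num_heads separate (d_model, d_k) matrices
W_q_heads = W_q.reshape(d_model, num_heads, d_k)   # column block h -> head h
Q_per_head = jnp.einsum('bsd,dhk->bhsk', x, W_q_heads)

print("shapes:", Q.shape, Q_per_head.shape)
print("max abs difference:", jnp.max(jnp.abs(Q - Q_per_head)))
```

This is why `forward` can use a single fused matmul per projection: the heads share their input but not their weights.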

## Position Encoding

### Positional Encoding Implementation

```python
def positional_encoding(seq_len, d_model):
    """Generate sinusoidal positional encoding (sin on even dims, cos on odd dims; assumes d_model is even)"""
    
    position = jnp.arange(seq_len)[:, None]  # (seq_len, 1)
    div_term = jnp.exp(jnp.arange(0, d_model, 2) * -(jnp.log(10000.0) / d_model))
    
    pe = jnp.zeros((seq_len, d_model))
    pe = pe.at[:, 0::2].set(jnp.sin(position * div_term))
    pe = pe.at[:, 1::2].set(jnp.cos(position * div_term))
    
    return pe

def learned_positional_encoding(key, seq_len, d_model):
    """Initialize a learnable positional embedding table (small random values until it is trained)"""
    return random.normal(key, (seq_len, d_model)) * 0.1

def test_positional_encoding():
    """Test positional encoding implementations"""
    
    seq_len, d_model = 50, 128
    
    # Sinusoidal encoding
    sin_pe = positional_encoding(seq_len, d_model)
    
    # Learned encoding
    key = random.PRNGKey(789)
    learned_pe = learned_positional_encoding(key, seq_len, d_model)
    
    print("Positional Encoding Test:")
    print(f"Sinusoidal PE shape: {sin_pe.shape}")
    print(f"Learned PE shape: {learned_pe.shape}")
    
    # Visualize first few dimensions
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    # Plot sinusoidal encoding
    ax1.imshow(sin_pe[:, :20].T, cmap='RdBu', aspect='auto')
    ax1.set_title('Sinusoidal Positional Encoding')
    ax1.set_xlabel('Position')
    ax1.set_ylabel('Dimension')
    
    # Plot learned encoding (random initialization; nothing has been trained yet)
    ax2.imshow(learned_pe[:, :20].T, cmap='RdBu', aspect='auto')
    ax2.set_title('Learned Positional Encoding (at init)')
    ax2.set_xlabel('Position')
    ax2.set_ylabel('Dimension')
    
    plt.tight_layout()
    plt.show()
    
    return sin_pe, learned_pe

sin_pe, learned_pe = test_positional_encoding()
```
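The sinusoidal encoding has a useful property: for a fixed offset k, PE(p + k) is a linear function of PE(p). Each (sin, cos) pair with frequency ω is rotated by the angle ωk, and that angle does not depend on p. The numerical check below reuses the same construction. Here `omega` is a name introduced for the per-pair frequencies, and `d_model` must be even.

```python
import jax.numpy as jnp

def positional_encoding(seq_len, d_model):
    # Same construction as above: sin on even dims, cos on odd dims (d_model even)
    position = jnp.arange(seq_len)[:, None]
    div_term = jnp.exp(jnp.arange(0, d_model, 2) * -(jnp.log(10000.0) / d_model))
    pe = jnp.zeros((seq_len, d_model))
    pe = pe.at[:, 0::2].set(jnp.sin(position * div_term))
    pe = pe.at[:, 1::2].set(jnp.cos(position * div_term))
    return pe

seq_len, d_model, offset = 50, 16, 7
pe = positional_encoding(seq_len, d_model)
# The per-pair frequencies used above
omega = jnp.exp(jnp.arange(0, d_model, 2) * -(jnp.log(10000.0) / d_model))

sin_p, cos_p = pe[:-offset, 0::2], pe[:-offset, 1::2]     # values at positions p
c, s = jnp.cos(omega * offset), jnp.sin(omega * offset)  # depend only on the offset

# Angle-addition formulas = a rotation by omega * offset of each (sin, cos) pair
sin_shifted = sin_p * c + cos_p * s   # sin(a + b) = sin a cos b + cos a sin b
cos_shifted = cos_p * c - sin_p * s   # cos(a + b) = cos a cos b - sin a sin b

err = max(float(jnp.max(jnp.abs(sin_shifted - pe[offset:, 0::2]))),
          float(jnp.max(jnp.abs(cos_shifted - pe[offset:, 1::2]))))
print(f"max error vs. PE(p + {offset}): {err:.2e}")
```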

## Transformer Layer Components

### Feed-Forward Network

```python
class FeedForward:
    """Position-wise feed-forward network"""
    
    def __init__(self, d_model, d_ff, dropout_rate=0.1, activation='gelu'):
        self.d_model = d_model
        self.d_ff = d_ff
        self.dropout_rate = dropout_rate
        self.activation = gelu if activation == 'gelu' else jax.nn.relu
    
    def init_params(self, key):
        """Initialize feed-forward parameters"""
        keys = random.split(key, 2)
        
        # Xavier initialization
        def init_linear(key, shape):
            fan_in, fan_out = shape[0], shape[1]
            limit = jnp.sqrt(6.0 / (fan_in + fan_out))
            return random.uniform(key, shape, minval=-limit, maxval=limit)
        
        params = {
            'W1': init_linear(keys[0], (self.d_model, self.d_ff)),
            'b1': jnp.zeros(self.d_ff),
            'W2': init_linear(keys[1], (self.d_ff, self.d_model)),
            'b2': jnp.zeros(self.d_model)
        }
        
        return params
    
    def forward(self, params, x, key_dropout=None, training=True):
        """Forward pass through feed-forward network"""
        
        # First linear layer + activation
        hidden = jnp.dot(x, params['W1']) + params['b1']
        hidden = self.activation(hidden)
        
        # Dropout
        if training and self.dropout_rate > 0.0 and key_dropout is not None:
            keep_prob = 1.0 - self.dropout_rate
            dropout_mask = random.bernoulli(key_dropout, keep_prob, hidden.shape)
            hidden = jnp.where(dropout_mask, hidden / keep_prob, 0.0)
        
        # Second linear layer
        output = jnp.dot(hidden, params['W2']) + params['b2']
        
        return output
    
    def __call__(self, params, x, key_dropout=None, training=True):
        """Make the class callable"""
        return self.forward(params, x, key_dropout, training)

def test_feedforward():
    """Test feed-forward network"""
    
    key = random.PRNGKey(111)
    batch_size, seq_len, d_model, d_ff = 2, 10, 128, 512
    
    ff = FeedForward(d_model, d_ff, dropout_rate=0.1)
    params = ff.init_params(key)
    
    x = random.normal(random.split(key)[1], (batch_size, seq_len, d_model))
    output = ff(params, x, training=False)
    
    print("Feed-Forward Test:")
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"d_ff (hidden size): {d_ff}")
    
    return ff, params

ff, ff_params = test_feedforward()
```
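"Position-wise" means the same two-layer MLP is applied to each position's `d_model` vector independently, so no information moves between positions inside this block. The quick check below uses `jax.vmap` with small random weights rather than the parameters initialized above. `ffn` is a local stand-in for `FeedForward.forward` without dropout.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.nn import gelu

def ffn(params, x):
    # Same computation as FeedForward.forward, without dropout
    hidden = gelu(x @ params['W1'] + params['b1'])
    return hidden @ params['W2'] + params['b2']

batch_size, seq_len, d_model, d_ff = 2, 6, 8, 32
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
params = {
    'W1': random.normal(k1, (d_model, d_ff)) * 0.1,
    'b1': jnp.zeros(d_ff),
    'W2': random.normal(k2, (d_ff, d_model)) * 0.1,
    'b2': jnp.zeros(d_model),
}
x = random.normal(k3, (batch_size, seq_len, d_model))

# Whole tensor at once vs. one d_model vector at a time (vmap over batch, then sequence)
out_full = ffn(params, x)
out_per_position = jax.vmap(jax.vmap(lambda v: ffn(params, v)))(x)

# Changing the token at position 0 leaves every other position's output unchanged
x_perturbed = x.at[:, 0].add(1.0)
out_perturbed = ffn(params, x_perturbed)

print("same result per position:", jnp.allclose(out_full, out_per_position, atol=1e-5))
print("positions 1.. unchanged: ", jnp.allclose(out_full[:, 1:], out_perturbed[:, 1:]))
```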

### Layer Normalization

```python
def layer_norm(x, gamma, beta, eps=1e-6):
    """Layer normalization"""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    normalized = (x - mean) / jnp.sqrt(var + eps)
    return gamma * normalized + beta

def init_layer_norm(d_model):
    """Initialize layer norm parameters"""
    gamma = jnp.ones(d_model)
    beta = jnp.zeros(d_model)
    return {'gamma': gamma, 'beta': beta}

def test_layer_norm():
    """Test layer normalization"""
    
    key = random.PRNGKey(222)
    batch_size, seq_len, d_model = 2, 10, 64
    
    x = random.normal(key, (batch_size, seq_len, d_model))
    ln_params = init_layer_norm(d_model)
    
    normalized = layer_norm(x, ln_params['gamma'], ln_params['beta'])
    
    print("Layer Normalization Test:")
    print(f"Input mean: {jnp.mean(x):.4f}, std: {jnp.std(x):.4f}")
    print(f"Output mean: {jnp.mean(normalized):.4f}, std: {jnp.std(normalized):.4f}")
    print(f"Per-sample output mean: {jnp.mean(normalized, axis=-1)[0, 0]:.6f}")
    print(f"Per-sample output std: {jnp.std(normalized, axis=-1)[0, 0]:.6f}")

test_layer_norm()
```
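Layer normalization standardizes each example over its feature axis. As a result, its output stays the same (up to the small `eps`) when all of an example's features are shifted by a constant or multiplied by a positive constant. The quick check below uses arbitrary per-example shifts and scales.

```python
import jax.numpy as jnp
from jax import random

def layer_norm(x, gamma, beta, eps=1e-6):
    # Same computation as layer_norm above
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + eps) + beta

d_model = 64
x = random.normal(random.PRNGKey(0), (4, d_model))
gamma, beta = jnp.ones(d_model), jnp.zeros(d_model)

# One positive scale and one shift per example (arbitrary demo values)
scale = jnp.array([[0.5], [2.0], [10.0], [3.0]])
shift = jnp.array([[1.0], [-4.0], [0.0], [7.5]])

y_ref = layer_norm(x, gamma, beta)
y_affine = layer_norm(scale * x + shift, gamma, beta)
print("max difference:", jnp.max(jnp.abs(y_ref - y_affine)))
```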

## Complete Transformer Layer

### Transformer Encoder Layer

```python
class TransformerLayer:
    """Complete transformer encoder layer"""
    
    def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1):
        self.d_model = d_model
        self.mha = MultiHeadAttention(d_model, num_heads, dropout_rate)
        self.ff = FeedForward(d_model, d_ff, dropout_rate)
        self.dropout_rate = dropout_rate
    
    def init_params(self, key):
        """Initialize all layer parameters"""
        keys = random.split(key, 4)
        
        params = {
            'mha': self.mha.init_params(keys[0]),
            'ff': self.ff.init_params(keys[1]),
            'ln1': init_layer_norm(self.d_model),
            'ln2': init_layer_norm(self.d_model)
        }
        
        return params
    
    def forward(self, params, x, mask=None, key_dropout=None, training=True):
        """Forward pass through transformer layer"""
        
        if key_dropout is not None:
            key_dropout, key_dropout_ff = random.split(key_dropout)
        else:
            key_dropout_ff = None
        
        # Multi-head attention with residual connection and layer norm
        attn_output, attn_weights = self.mha(
            params['mha'], x, x, x, mask, key_dropout, training
        )
        
        # Residual connection + layer norm (Post-LN variant: normalize after adding the residual)
        x = layer_norm(x + attn_output, params['ln1']['gamma'], params['ln1']['beta'])
        
        # Feed-forward with residual connection and layer norm
        ff_output = self.ff(params['ff'], x, key_dropout_ff, training)
        x = layer_norm(x + ff_output, params['ln2']['gamma'], params['ln2']['beta'])
        
        return x, attn_weights
    
    def __call__(self, params, x, mask=None, key_dropout=None, training=True):
        """Make the class callable"""
        return self.forward(params, x, mask, key_dropout, training)

def test_transformer_layer():
    """Test complete transformer layer"""
    
    key = random.PRNGKey(333)
    batch_size, seq_len, d_model = 2, 12, 256
    num_heads, d_ff = 8, 1024
    
    transformer = TransformerLayer(d_model, num_heads, d_ff, dropout_rate=0.1)
    params = transformer.init_params(key)
    
    # Create input with positional encoding
    x = random.normal(random.split(key)[1], (batch_size, seq_len, d_model))
    pos_encoding = positional_encoding(seq_len, d_model)
    x_with_pos = x + pos_encoding[None, :, :]  # Add batch dimension to pos encoding
    
    # Forward pass
    output, attn_weights = transformer(params, x_with_pos, training=False)
    
    print("Transformer Layer Test:")
    print(f"Input shape: {x_with_pos.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention weights shape: {attn_weights.shape}")
    print(f"Input/output same shape: {x_with_pos.shape == output.shape}")
    
    # Test with causal mask
    causal_mask = create_causal_mask(seq_len)
    causal_mask = jnp.expand_dims(causal_mask, (0, 1))
    
    output_masked, attn_weights_masked = transformer(
        params, x_with_pos, mask=causal_mask, training=False
    )
    
    print(f"Masked output shape: {output_masked.shape}")
    
    # Count parameters
    total_params = 0
    for component, component_params in params.items():
        component_total = sum(p.size for p in jax.tree_util.tree_leaves(component_params))
        total_params += component_total
        print(f"{component}: {component_total:,} parameters")
    
    print(f"Total parameters: {total_params:,}")
    
    return transformer, params

transformer, transformer_params = test_transformer_layer()
```
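The layer above normalizes *after* each residual addition. This is Post-LN, the ordering used in the original Transformer. Many later models use Pre-LN instead, which normalizes the sublayer input and leaves the residual path as a plain sum. Stacks of Pre-LN layers typically add one final layer norm after the last layer. The minimal sketch below contrasts the two orderings. A generic `sublayer` function stands in for attention or the feed-forward network, and `gamma = 1`, `beta = 0` for brevity. `post_ln_block` and `pre_ln_block` are illustrative names.

```python
import jax.numpy as jnp
from jax import random

def layer_norm(x, eps=1e-6):
    # gamma = 1, beta = 0 for brevity
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def post_ln_block(x, sublayer):
    # Ordering used by TransformerLayer above: residual add, then normalize
    return layer_norm(x + sublayer(x))

def pre_ln_block(x, sublayer):
    # Pre-LN ordering: normalize the sublayer input; the residual path is untouched
    return x + sublayer(layer_norm(x))

x = 3.0 * random.normal(random.PRNGKey(0), (2, 4, 8)) + 1.0

# With a sublayer that outputs zeros, Pre-LN is the identity while Post-LN renormalizes x
zero_sublayer = lambda h: jnp.zeros_like(h)
print("pre-LN is identity: ", jnp.allclose(pre_ln_block(x, zero_sublayer), x))
print("post-LN is identity:", jnp.allclose(post_ln_block(x, zero_sublayer), x))
```

The clean identity path through a Pre-LN block is often given as one reason Pre-LN stacks are easier to train at depth.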

## Simple Sequence-to-Sequence Task

### Training a Transformer on Copy Task

```python
def create_copy_task_data(key, seq_len=10, vocab_size=20, n_samples=1000):
    """Create simple copy task dataset"""
    
    # Generate random sequences
    sequences = random.randint(key, (n_samples, seq_len), 0, vocab_size)
    
    # Input: [SOS, seq], Target: [seq, EOS]
    sos_token = vocab_size  # Special start token
    eos_token = vocab_size + 1  # Special end token
    
    # Input sequences with SOS
    inputs = jnp.concatenate([
        jnp.full((n_samples, 1), sos_token),
        sequences
    ], axis=1)
    
    # Target sequences with EOS
    targets = jnp.concatenate([
        sequences,
        jnp.full((n_samples, 1), eos_token)
    ], axis=1)
    
    return inputs, targets

def train_copy_task(num_steps=20, learning_rate=0.001):
    """Train one transformer layer, plus its embedding and output matrices, on the copy task"""
    
    key = random.PRNGKey(42)
    seq_len, vocab_size = 8, 16
    d_model, num_heads, d_ff = 64, 4, 256
    n_tokens = vocab_size + 2  # regular tokens + SOS + EOS
    
    # Independent keys for every random component (never reuse a PRNG key)
    init_key, embed_key, out_key, data_key = random.split(key, 4)
    
    # Create transformer
    transformer = TransformerLayer(d_model, num_heads, d_ff)
    
    # Keep all trainable parameters in one pytree so a single gradient call updates
    # the transformer, the embedding matrix, and the output projection together
    params = {
        'transformer': transformer.init_params(init_key),
        'embedding': random.normal(embed_key, (n_tokens, d_model)) * 0.1,
        'output': random.normal(out_key, (d_model, n_tokens)) * 0.1,
    }
    
    # Generate data
    inputs, targets = create_copy_task_data(data_key, seq_len, vocab_size, n_samples=500)
    
    def model_fn(params, inputs):
        # Embed inputs: (batch, seq_len + 1, d_model)
        embedded = params['embedding'][inputs]
        
        # Add positional encoding
        pos_enc = positional_encoding(embedded.shape[1], d_model)
        embedded = embedded + pos_enc[None, :, :]
        
        # Forward through transformer; no dropout key is supplied, so dropout is off
        output, _ = transformer(params['transformer'], embedded, training=False)
        
        # Project to vocabulary logits: (batch, seq_len + 1, n_tokens)
        return jnp.dot(output, params['output'])
    
    def loss_fn(params, inputs, targets):
        # Mean cross-entropy over every position in every sequence
        logits = model_fn(params, inputs)
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        targets_one_hot = jax.nn.one_hot(targets, n_tokens)
        return -jnp.mean(jnp.sum(targets_one_hot * log_probs, axis=-1))
    
    @jit
    def sgd_step(params, inputs, targets):
        # One full-batch plain-SGD step; the returned loss is measured before the update
        loss, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
        params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
        return params, loss
    
    # Split data
    n_train = 400
    train_inputs, test_inputs = inputs[:n_train], inputs[n_train:]
    train_targets, test_targets = targets[:n_train], targets[n_train:]
    
    print("Training Copy Task:")
    print(f"Vocab size: {vocab_size}, Seq length: {seq_len}")
    print(f"Training samples: {n_train}, Test samples: {len(test_inputs)}")
    
    # Each step uses the full training set, so one step is one epoch. With plain SGD at
    # lr=0.001, 20 steps should give only a modest drop in loss; raise num_steps and/or
    # learning_rate (or use an adaptive optimizer) before expecting correct copies.
    for step in range(num_steps):
        params, train_loss = sgd_step(params, train_inputs, train_targets)
        
        if step % 5 == 0:
            test_loss = loss_fn(params, test_inputs, test_targets)
            print(f"Step {step:3d}: train_loss={train_loss:.4f}, test_loss={test_loss:.4f}")
    
    # Test model predictions
    test_logits = model_fn(params, test_inputs[:5])
    predictions = jnp.argmax(test_logits, axis=-1)
    
    print("\nSample Predictions:")
    for i in range(3):
        print(f"Input:  {test_inputs[i, 1:].tolist()}")  # Remove SOS
        print(f"Target: {test_targets[i, :-1].tolist()}")  # Remove EOS
        print(f"Pred:   {predictions[i, :-1].tolist()}")  # Drop the EOS position
        print()

train_copy_task()
```
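It helps to spell out what the model is actually asked to do. With inputs `[SOS, seq]` and targets `[seq, EOS]`, the target at position i is the input token at position i + 1. An encoder layer without a causal mask can therefore solve the task by attending one position ahead. Under a causal mask, the next token would be hidden and, because the sequences are random, unpredictable. The check below verifies that alignment, reproducing `create_copy_task_data` so the cell runs on its own.

```python
import jax.numpy as jnp
from jax import random

def create_copy_task_data(key, seq_len=10, vocab_size=20, n_samples=1000):
    # Same construction as above: inputs = [SOS, seq], targets = [seq, EOS]
    sequences = random.randint(key, (n_samples, seq_len), 0, vocab_size)
    sos, eos = vocab_size, vocab_size + 1
    inputs = jnp.concatenate([jnp.full((n_samples, 1), sos), sequences], axis=1)
    targets = jnp.concatenate([sequences, jnp.full((n_samples, 1), eos)], axis=1)
    return inputs, targets

inputs, targets = create_copy_task_data(random.PRNGKey(0), seq_len=8, vocab_size=16, n_samples=4)

print("inputs[0]: ", inputs[0].tolist())
print("targets[0]:", targets[0].tolist())
# The target at position i equals the input at position i + 1 (all but the last position)
print("targets[:, :-1] == inputs[:, 1:]:", bool(jnp.all(targets[:, :-1] == inputs[:, 1:])))
```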

## Summary

In this notebook, we've implemented attention mechanisms from scratch in JAX:

**Core Components:**

1. **Scaled Dot-Product Attention**: Foundation of all attention mechanisms
2. **Multi-Head Attention**: Parallel attention heads for different representation subspaces
3. **Positional Encoding**: Sinusoidal and learned position representations
4. **Feed-Forward Networks**: Position-wise fully connected layers
5. **Layer Normalization**: Stabilization technique for deep networks

**Transformer Architecture:**
- Complete transformer encoder layer with residual connections
- Post-layer normalization (Post-LN) variant, as in the original Transformer
- Causal masking for autoregressive generation
- Xavier/Glorot-uniform initialization for the projection weights

**Key Insights:**
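- Attention weights are computed from the input itself, so which positions a token draws on can change from one input to the next
- Multiple heads attend in different learned subspaces in parallel, which lets them pick up different kinds of relationships
- Self-attention on its own is permutation-equivariant, so positional encoding is needed to give the model any notion of token order
- Residual connections and layer normalization keep activations and gradients well-behaved as layers are stacked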

**JAX Implementation Benefits:**
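- Attention reduces to batched matrix multiplications, which XLA runs efficiently on CPU, GPU, or TPU
- `grad` differentiates through the full attention computation, including masking and softmax
- `jax.jit` can compile the forward pass or an entire training step
- Parameters are plain pytrees (nested dicts of arrays) passed explicitly to pure functions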

**Training Observations:**
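- The copy task is a minimal test of sequence modeling: the target at position i is the input token at position i + 1
- The layer runs without a causal mask, so each position can attend to the next input token; under a causal mask the task would be unlearnable, because the sequences are random
- Without positional encoding, attention could not tell which position is "next", so the task depends on it
- Initialization scale (Xavier/Glorot projections, small random embeddings) can affect how stable early training is
- Causal masking is what an autoregressive (language-modeling) version of this setup would use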

**Next Steps:**
- The next notebook will cover optimizers in detail
- We'll explore advanced optimization techniques for training
- The attention, feed-forward, and normalization pieces built here are the building blocks of full transformer encoders and decoders

This attention implementation provides the foundation for transformer models, which have revolutionized natural language processing and are increasingly used in computer vision and other domains.