# File: notebooks/03_neural_networks/09_cnn_minimal.ipynb

## JAX Neural Networks: Minimal CNN Implementation

This notebook implements Convolutional Neural Networks (CNNs) from scratch using JAX's low-level convolution operations. We'll cover convolutional layers, pooling operations, and build a complete CNN for image classification using `lax.conv_general_dilated`.

CNNs are fundamental for computer vision tasks, leveraging spatial locality and parameter sharing to efficiently process images and other grid-like data structures.

## Setting Up the Environment

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random, lax
from jax.nn import relu, softmax, log_softmax
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Dict, Any, List
import functools

# Enable 64-bit precision; JAX otherwise defaults to 32-bit floats and ints.
jax.config.update("jax_enable_x64", True)
print(f"JAX version: {jax.__version__}")
```

## Core Convolution Operations

### Low-Level Convolution Implementation

```python
def conv2d_basic(x, kernel, stride=1, padding='VALID'):
    """Apply a 2D convolution to a channels-last batch of images.

    Args:
        x: Input of shape (batch, height, width, in_channels).
        kernel: Filters of shape (kernel_h, kernel_w, in_channels, out_channels).
        stride: Window step, applied to both spatial axes.
        padding: 'VALID', 'SAME', or any other padding spec accepted by
            `lax.conv_general_dilated`.

    Returns:
        Array of shape (batch, out_height, out_width, out_channels).
    """
    return lax.conv_general_dilated(
        x, kernel,
        window_strides=(stride, stride),
        padding=padding,
        # Without explicit dimension numbers lax assumes NCHW inputs and OIHW
        # kernels, which misreads the channels-last layout used here.
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
    )

def test_convolution():
    """Run `conv2d_basic` on a constant image with an edge-detection kernel.

    Prints the input, kernel and output shapes plus the single output channel.

    Returns:
        Tuple `(x, kernel, output)` of the arrays used in the demonstration.
    """

    # Create test input: batch_size=1, height=5, width=5, channels=1
    x = jnp.ones((1, 5, 5, 1))

    # Create 3x3 edge detection kernel. It is built in the input's dtype
    # because lax.conv_general_dilated requires both operands to share a dtype;
    # an integer kernel against a float image is rejected.
    kernel = jnp.array(
        [[-1, -1, -1],
         [ 0,  0,  0],
         [ 1,  1,  1]],
        dtype=x.dtype,
    ).reshape(3, 3, 1, 1)  # height, width, in_channels, out_channels

    # Apply convolution
    output = conv2d_basic(x, kernel, stride=1, padding='VALID')

    print(f"Input shape: {x.shape}")
    print(f"Kernel shape: {kernel.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output:\n{output[0, :, :, 0]}")

    return x, kernel, output

x_test, kernel_test, output_test = test_convolution()
```

### Pooling Operations

```python
def max_pool2d(x, pool_size=2, stride=None):
    """Max-pool axes 1 and 2 of a 4D array with a square window.

    The window and stride are 1 on the first and last axes, so only the two
    middle axes are reduced. When `stride` is None it falls back to
    `pool_size`. Padding is 'VALID'.
    """
    step = pool_size if stride is None else stride
    window = (1, pool_size, pool_size, 1)
    strides = (1, step, step, 1)
    return lax.reduce_window(x, -jnp.inf, lax.max, window, strides, 'VALID')

def avg_pool2d(x, pool_size=2, stride=None):
    """Average-pool axes 1 and 2 of a 4D array with a square window.

    Each window is summed with `lax.reduce_window` ('VALID' padding) and the
    sum is divided by the window area. When `stride` is None it falls back to
    `pool_size`.
    """
    step = pool_size if stride is None else stride
    window_sum = lax.reduce_window(
        x, 0.0, lax.add,
        (1, pool_size, pool_size, 1),
        (1, step, step, 1),
        'VALID',
    )
    window_area = pool_size * pool_size
    return window_sum / window_area

def test_pooling():
    """Print a 4x4 ramp image next to its 2x2 max- and average-pooled versions."""

    # Values 0..15 laid out row by row make the pooled results easy to check.
    ramp = jnp.arange(16).reshape(1, 4, 4, 1).astype(jnp.float32)

    pooled_max = max_pool2d(ramp, pool_size=2, stride=2)
    pooled_avg = avg_pool2d(ramp, pool_size=2, stride=2)

    report = (
        "Pooling Operations Test:",
        f"Input (4x4):\n{ramp[0, :, :, 0]}",
        f"Max pooled (2x2):\n{pooled_max[0, :, :, 0]}",
        f"Avg pooled (2x2):\n{pooled_avg[0, :, :, 0]}",
    )
    for line in report:
        print(line)

test_pooling()
```

## CNN Layer Components

### Convolutional Layer Class

```python
class ConvLayer:
    """2D convolutional layer for channels-last (NHWC) inputs.

    The layer holds only its configuration; parameters are created by
    `init_params` and passed explicitly to `forward`.
    """

    # Layout of (input, kernel, output). lax defaults to NCHW/OIHW, which does
    # not match the (batch, height, width, channels) arrays used by this layer.
    _DIMENSION_NUMBERS = ('NHWC', 'HWIO', 'NHWC')

    def __init__(self, out_channels, kernel_size=3, stride=1, padding='SAME', activation='relu'):
        """Store the layer configuration.

        Args:
            out_channels: Number of filters.
            kernel_size: Side length of the square kernel.
            stride: Window step, applied to both spatial axes.
            padding: Padding spec forwarded to `lax.conv_general_dilated`.
            activation: 'relu' applies ReLU; any other value applies the
                identity.
        """
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.activation = relu if activation == 'relu' else lambda x: x

    def init_params(self, key, input_shape):
        """Create He-initialized weights and a zero bias.

        Args:
            key: JAX PRNG key used to sample the kernel.
            input_shape: (batch, height, width, in_channels); only
                `in_channels` is used.

        Returns:
            Dict with 'W' of shape (kernel, kernel, in_channels, out_channels)
            and 'b' of shape (out_channels,).
        """
        _, _, _, in_channels = input_shape
        kernel_shape = (self.kernel_size, self.kernel_size, in_channels, self.out_channels)

        # He initialization: std = sqrt(2 / fan_in)
        fan_in = self.kernel_size * self.kernel_size * in_channels
        std = jnp.sqrt(2.0 / fan_in)

        W = random.normal(key, kernel_shape) * std
        b = jnp.zeros(self.out_channels)

        return {'W': W, 'b': b}

    def forward(self, params, x):
        """Convolve `x`, add the bias and apply the activation.

        Args:
            params: Dict with 'W' and 'b' as produced by `init_params`.
            x: Input of shape (batch, height, width, in_channels).

        Returns:
            Activated output of shape (batch, out_height, out_width, out_channels).
        """
        conv_out = lax.conv_general_dilated(
            x, params['W'],
            window_strides=(self.stride, self.stride),
            padding=self.padding,
            dimension_numbers=self._DIMENSION_NUMBERS,
        )

        # The bias has shape (out_channels,), so it broadcasts over the batch
        # and spatial axes of the channels-last output.
        conv_out = conv_out + params['b']

        return self.activation(conv_out)

    def output_shape(self, input_shape):
        """Return the shape `forward` produces for an input of `input_shape`.

        'SAME' padding yields ceil(size / stride) per spatial axis; any other
        padding value is treated as 'VALID'.
        """
        batch_size, height, width, _ = input_shape

        if self.padding == 'SAME':
            # Ceiling division: floor would under-count whenever the size is
            # not a multiple of the stride.
            out_height = -(-height // self.stride)
            out_width = -(-width // self.stride)
        else:  # VALID
            out_height = (height - self.kernel_size) // self.stride + 1
            out_width = (width - self.kernel_size) // self.stride + 1

        return (batch_size, out_height, out_width, self.out_channels)

# Test ConvLayer
def test_conv_layer():
    """Check that ConvLayer.forward matches ConvLayer.output_shape.

    Builds a 16-filter 'SAME' layer for a batch of 28x28 single-channel
    images, runs one forward pass on random input and prints the shapes.

    Returns:
        Tuple `(conv_layer, params)`.
    """
    root_key = random.PRNGKey(42)
    input_shape = (32, 28, 28, 1)  # Batch of MNIST-like images

    layer = ConvLayer(out_channels=16, kernel_size=3, stride=1, padding='SAME')
    layer_params = layer.init_params(root_key, input_shape)

    # Random input drawn from the second half of the split key
    _, data_key = random.split(root_key)
    x = random.normal(data_key, input_shape)

    actual = layer.forward(layer_params, x).shape
    expected = layer.output_shape(input_shape)

    print("Conv Layer Test:")
    print(f"Input shape: {x.shape}")
    print(f"Expected output shape: {expected}")
    print(f"Actual output shape: {actual}")
    print(f"Shapes match: {actual == expected}")

    return layer, layer_params

conv_layer, conv_params = test_conv_layer()
```

### Pooling Layer Class

```python
class PoolingLayer:
    """Parameter-free max or average pooling over a square window."""

    def __init__(self, pool_size=2, stride=None, pool_type='max'):
        """Store the window size, stride (defaults to `pool_size`) and type."""
        self.pool_size = pool_size
        self.stride = pool_size if stride is None else stride
        self.pool_type = pool_type

    def forward(self, x):
        """Pool `x` with the configured operation.

        Raises:
            ValueError: If `pool_type` is neither 'max' nor 'avg'.
        """
        if self.pool_type not in ('max', 'avg'):
            raise ValueError(f"Unknown pooling type: {self.pool_type}")

        pool_fn = max_pool2d if self.pool_type == 'max' else avg_pool2d
        return pool_fn(x, self.pool_size, self.stride)

    def output_shape(self, input_shape):
        """Return the pooled shape for a (batch, height, width, channels) input."""
        batch_size, height, width, channels = input_shape

        def pooled(extent):
            # Window positions that fit along one spatial axis
            return (extent - self.pool_size) // self.stride + 1

        return (batch_size, pooled(height), pooled(width), channels)

# Test PoolingLayer
def test_pooling_layer():
    """Check that PoolingLayer.forward matches PoolingLayer.output_shape.

    Runs 2x2 max pooling with stride 2 on random (32, 28, 28, 16) input and
    prints the expected and actual shapes.
    """
    input_shape = (32, 28, 28, 16)
    layer = PoolingLayer(pool_size=2, stride=2, pool_type='max')

    x = random.normal(random.PRNGKey(0), input_shape)
    actual = layer.forward(x).shape
    expected = layer.output_shape(input_shape)

    print("Pooling Layer Test:")
    print(f"Input shape: {x.shape}")
    print(f"Expected output shape: {expected}")
    print(f"Actual output shape: {actual}")
    print(f"Shapes match: {actual == expected}")

test_pooling_layer()
```

## Complete CNN Implementation

### CNN Architecture

```python
class SimpleCNN:
    """Three conv + max-pool stages followed by one dense classifier layer."""

    def __init__(self, num_classes=10):
        """Build the conv/pool stages; `num_classes` sets the logit count."""
        self.num_classes = num_classes

        # Channel width doubles at every stage: 32 -> 64 -> 128
        self.conv1 = ConvLayer(out_channels=32, kernel_size=3, padding='SAME')
        self.conv2 = ConvLayer(out_channels=64, kernel_size=3, padding='SAME')
        self.conv3 = ConvLayer(out_channels=128, kernel_size=3, padding='SAME')

        # Each conv is followed by a 2x2 max pool with stride 2
        self.pool1 = PoolingLayer(pool_size=2, stride=2, pool_type='max')
        self.pool2 = PoolingLayer(pool_size=2, stride=2, pool_type='max')
        self.pool3 = PoolingLayer(pool_size=2, stride=2, pool_type='max')

    def _stages(self):
        """Return (param_name, conv_layer, pool_layer) triples in forward order."""
        return (
            ('conv1', self.conv1, self.pool1),
            ('conv2', self.conv2, self.pool2),
            ('conv3', self.conv3, self.pool3),
        )

    def init_params(self, key, input_shape):
        """Initialize parameters for every conv stage and the dense layer.

        Args:
            key: JAX PRNG key; split into one key per parameterized layer.
            input_shape: (batch, height, width, channels) of the network input.

        Returns:
            Dict with entries 'conv1', 'conv2', 'conv3' and 'dense', each
            holding 'W' and 'b'.
        """
        layer_keys = random.split(key, 4)  # three conv layers + one dense layer
        params = {}

        # Propagate the shape stage by stage so each conv sees its real input
        shape = input_shape
        for stage_key, (name, conv, pool) in zip(layer_keys, self._stages()):
            params[name] = conv.init_params(stage_key, shape)
            shape = pool.output_shape(conv.output_shape(shape))

        # Everything except the batch axis is flattened into the dense input
        _, out_height, out_width, out_channels = shape
        flattened_size = out_height * out_width * out_channels

        dense_std = jnp.sqrt(2.0 / flattened_size)
        dense_W = random.normal(layer_keys[3], (flattened_size, self.num_classes)) * dense_std
        params['dense'] = {'W': dense_W, 'b': jnp.zeros(self.num_classes)}

        return params

    def forward(self, params, x):
        """Return the logits computed from `x` using `params`."""
        for name, conv, pool in self._stages():
            x = pool.forward(conv.forward(params[name], x))

        # Flatten per example, then apply the dense classifier
        flat = x.reshape(x.shape[0], -1)
        dense = params['dense']
        return flat @ dense['W'] + dense['b']

    def __call__(self, params, x):
        """Alias for `forward`, so the model can be called like a function."""
        return self.forward(params, x)

# Test CNN
def test_cnn():
    """Run SimpleCNN on a random RGB batch and print shapes and parameter counts.

    Returns:
        Tuple `(cnn, params)` with the model and its initialized parameters.
    """
    root_key = random.PRNGKey(123)
    input_shape = (8, 32, 32, 3)  # Small batch of RGB images

    cnn = SimpleCNN(num_classes=10)
    params = cnn.init_params(root_key, input_shape)

    # Random input drawn from the second half of the split key
    _, data_key = random.split(root_key)
    x = random.normal(data_key, input_shape)
    logits = cnn.forward(params, x)

    print("CNN Test:")
    print(f"Input shape: {x.shape}")
    print(f"Output logits shape: {logits.shape}")

    # Softmax over the class axis should give rows that sum to one
    probs = softmax(logits)
    rows_normalized = jnp.allclose(jnp.sum(probs, axis=1), 1.0)
    print(f"Output probabilities shape: {probs.shape}")
    print(f"Probabilities sum to 1: {rows_normalized}")

    # Per-layer and total parameter counts
    layer_sizes = {
        name: sum(p.size for p in layer.values())
        for name, layer in params.items()
    }
    for name, size in layer_sizes.items():
        print(f"{name}: {size:,} parameters")

    total_params = sum(layer_sizes.values())
    print(f"Total parameters: {total_params:,}")

    return cnn, params

cnn, cnn_params = test_cnn()
```

## Training Implementation

### Loss Functions and Training Loop

```python
def cross_entropy_loss(logits, labels):
    """Mean cross-entropy between `logits` and one-hot (or soft) `labels`.

    Log-probabilities come from `log_softmax` over the last axis; the
    label-weighted sum is taken over axis 1 and averaged over the batch.
    """
    log_probs = log_softmax(logits)
    weighted_log_probs = jnp.sum(labels * log_probs, axis=1)
    return -jnp.mean(weighted_log_probs)

def accuracy(logits, labels):
    """Fraction of rows whose arg-max logit matches the arg-max label."""
    hits = jnp.argmax(logits, axis=1) == jnp.argmax(labels, axis=1)
    return jnp.mean(hits)

class Adam:
    """Adam optimizer operating on arbitrary parameter pytrees.

    The optimizer is stateless: moment estimates and the step counter live in
    the state dict returned by `init_state` and `update`.
    """

    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        """Store the step size, moment decay rates and numerical epsilon."""
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps

    def init_state(self, params):
        """Return zeroed first/second moments shaped like `params`, with step 0."""
        # jax.tree_util.tree_map is used because the top-level jax.tree_map
        # alias is deprecated and absent from recent JAX releases.
        tree_map = jax.tree_util.tree_map
        return {
            'm': tree_map(jnp.zeros_like, params),
            'v': tree_map(jnp.zeros_like, params),
            'step': 0
        }

    def update(self, grads, state, params):
        """Apply one Adam step.

        Args:
            grads: Gradient pytree with the same structure as `params`.
            state: Optimizer state from `init_state` or a previous `update`.
            params: Current parameter pytree.

        Returns:
            Tuple `(new_params, new_state)`.
        """
        tree_map = jax.tree_util.tree_map
        step = state['step'] + 1

        # Exponential moving averages of the gradient and its square
        m = tree_map(
            lambda m_prev, g: self.beta1 * m_prev + (1 - self.beta1) * g,
            state['m'], grads
        )
        v = tree_map(
            lambda v_prev, g: self.beta2 * v_prev + (1 - self.beta2) * g**2,
            state['v'], grads
        )

        # Bias correction compensates for the zero-initialized moments
        m_hat = tree_map(lambda m_val: m_val / (1 - self.beta1**step), m)
        v_hat = tree_map(lambda v_val: v_val / (1 - self.beta2**step), v)

        # Parameter update
        new_params = tree_map(
            lambda p, m_val, v_val: p - self.learning_rate * m_val / (jnp.sqrt(v_val) + self.eps),
            params, m_hat, v_hat
        )

        new_state = {'m': m, 'v': v, 'step': step}
        return new_params, new_state

def train_cnn(cnn, train_data, test_data, num_epochs=10, batch_size=32, learning_rate=0.001):
    """Train `cnn` with mini-batch Adam and evaluate after every epoch.

    Args:
        cnn: Model exposing `init_params(key, input_shape)` and
            `__call__(params, x)` returning logits.
        train_data: Tuple `(X_train, y_train)`; labels are one-hot.
        test_data: Tuple `(X_test, y_test)` evaluated in one batch per epoch.
        num_epochs: Number of passes over the training data.
        batch_size: Examples per step; a trailing partial batch is dropped.
        learning_rate: Adam step size.

    Returns:
        Tuple `(params, history)`, where `history` maps 'train_losses',
        'test_losses' and 'test_accuracies' to lists with one entry per epoch.

    Raises:
        ValueError: If the training set holds fewer than `batch_size` examples,
            since no optimization step could run.
    """

    X_train, y_train = train_data
    X_test, y_test = test_data

    n_train = len(X_train)
    n_batches = n_train // batch_size
    if n_batches == 0:
        raise ValueError(
            f"Training set has {n_train} samples, fewer than batch_size={batch_size}"
        )

    # A PRNG key must not be reused, so initialization and every epoch's
    # shuffle each get their own subkey.
    key = random.PRNGKey(42)
    key, init_key = random.split(key)
    params = cnn.init_params(init_key, (batch_size,) + X_train.shape[1:])

    optimizer = Adam(learning_rate=learning_rate)
    opt_state = optimizer.init_state(params)

    # JIT compile training step
    @jit
    def train_step(params, opt_state, batch_x, batch_y):
        def loss_fn(params):
            logits = cnn(params, batch_x)
            return cross_entropy_loss(logits, batch_y)

        loss, grads = jax.value_and_grad(loss_fn)(params)
        new_params, new_opt_state = optimizer.update(grads, opt_state, params)
        return new_params, new_opt_state, loss

    # JIT compile evaluation
    @jit
    def eval_step(params, x, y):
        logits = cnn(params, x)
        loss = cross_entropy_loss(logits, y)
        acc = accuracy(logits, y)
        return loss, acc

    history = {'train_losses': [], 'test_losses': [], 'test_accuracies': []}

    for epoch in range(num_epochs):
        # Shuffle training data with a fresh subkey
        key, perm_key = random.split(key)
        perm = random.permutation(perm_key, n_train)

        X_shuffled = X_train[perm]
        y_shuffled = y_train[perm]

        # Training batches
        epoch_losses = []
        for i in range(n_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size

            batch_x = X_shuffled[start_idx:end_idx]
            batch_y = y_shuffled[start_idx:end_idx]

            params, opt_state, batch_loss = train_step(params, opt_state, batch_x, batch_y)
            epoch_losses.append(batch_loss)

        # Record metrics
        avg_train_loss = jnp.mean(jnp.array(epoch_losses))
        test_loss, test_acc = eval_step(params, X_test, y_test)

        history['train_losses'].append(avg_train_loss)
        history['test_losses'].append(test_loss)
        history['test_accuracies'].append(test_acc)

        print(f"Epoch {epoch+1:2d}: train_loss={avg_train_loss:.4f}, "
              f"test_loss={test_loss:.4f}, test_acc={test_acc:.4f}")

    return params, history
```

## Synthetic Dataset Example

### Create and Train on Synthetic Data

```python
def create_synthetic_image_data(key, n_samples=1000, image_size=32, n_classes=5):
    """Create a synthetic image classification dataset.

    Every image is low-amplitude Gaussian noise with a class-specific pattern
    drawn on top. Classes 0-3 have fixed patterns; class 4 and any higher
    class index get random bright spots.

    Args:
        key: JAX PRNG key; split per class and then per sample.
        n_samples: Total sample budget; each class receives
            `n_samples // n_classes` images, so a remainder is dropped.
        image_size: Height and width of the square RGB images.
        n_classes: Number of classes.

    Returns:
        Tuple `(images, labels)`: images of shape
        (n, image_size, image_size, 3) and one-hot labels of shape
        (n, n_classes). Samples are grouped by class in ascending order, so
        shuffle before making a train/test split.
    """
    n_class_samples = n_samples // n_classes

    # Split keys once per class and once per sample batch rather than on every
    # loop iteration; the derived keys are the same.
    class_keys = random.split(key, n_classes) if n_class_samples > 0 else []

    # The filled-circle mask depends only on image_size, so build it once.
    center = image_size // 2
    yy, xx = jnp.mgrid[:image_size, :image_size]
    disk = (xx - center)**2 + (yy - center)**2 < (image_size//4)**2
    disk_color = jnp.array([1.0, 1.0, 0.0])

    images = []
    labels = []

    for class_idx, class_key in enumerate(class_keys):
        for sample_key in random.split(class_key, n_class_samples):
            # Base noise
            img = 0.1 * random.normal(sample_key, (image_size, image_size, 3))

            # Add class-specific pattern
            if class_idx == 0:  # Horizontal stripes (red channel)
                img = img.at[::4, :, 0].set(1.0)
            elif class_idx == 1:  # Vertical stripes (green channel)
                img = img.at[:, ::4, 1].set(1.0)
            elif class_idx == 2:  # Sparse checkerboard of dots (blue channel)
                img = img.at[::8, ::8, 2].set(1.0)
                img = img.at[4::8, 4::8, 2].set(1.0)
            elif class_idx == 3:  # Filled yellow circle
                img = jnp.where(disk[..., None], disk_color.astype(img.dtype), img)
            else:  # Random bright spots
                bright_key = random.split(sample_key)[1]
                bright_locs = random.randint(bright_key, (10, 2), 0, image_size)
                # Set all ten spots in a single scatter
                img = img.at[bright_locs[:, 0], bright_locs[:, 1], :].set(1.0)

            images.append(img)
            labels.append(class_idx)

    images = jnp.array(images)
    labels = jax.nn.one_hot(jnp.array(labels), n_classes)

    return images, labels

# Create synthetic dataset
key = random.PRNGKey(0)
data_key, shuffle_key = random.split(key)
X_synth, y_synth = create_synthetic_image_data(data_key, n_samples=500, image_size=32, n_classes=5)

# The generator returns samples grouped by class, so shuffle before splitting.
# Otherwise the held-out 20% would contain only the last class and the training
# set would never see it.
shuffle_order = random.permutation(shuffle_key, len(X_synth))
X_synth = X_synth[shuffle_order]
y_synth = y_synth[shuffle_order]

# Split train/test
split_idx = int(0.8 * len(X_synth))
X_train = X_synth[:split_idx]
y_train = y_synth[:split_idx]
X_test = X_synth[split_idx:]
y_test = y_synth[split_idx:]

print(f"Synthetic Dataset:")
print(f"Training: {X_train.shape[0]} samples")
print(f"Testing: {X_test.shape[0]} samples")
print(f"Image shape: {X_train.shape[1:]}")

# Train CNN on synthetic data
print(f"\nTraining CNN...")
cnn_synthetic = SimpleCNN(num_classes=5)
trained_params, training_history = train_cnn(
    cnn_synthetic,
    (X_train, y_train),
    (X_test, y_test),
    num_epochs=15,
    batch_size=16,
    learning_rate=0.001
)

print(f"\nFinal test accuracy: {training_history['test_accuracies'][-1]:.4f}")
```

## Feature Visualization

### Visualize Learned Filters

```python
def visualize_conv_filters(params, layer_name='conv1', max_filters=8):
    """Plot the first-input-channel slice of a conv layer's filters.

    Args:
        params: Parameter dict; `params[layer_name]['W']` must have shape
            (height, width, in_channels, out_channels).
        layer_name: Key of the layer to visualize.
        max_filters: Upper bound on the number of filters shown.
    """

    filters = params[layer_name]['W']  # Shape: (H, W, in_channels, out_channels)
    n_filters = min(filters.shape[-1], max_filters)

    # Size the grid from the filter count so more than eight filters fit;
    # the default of eight still gives a 2x4 grid at 12x6 inches.
    n_cols = 4
    n_rows = max(1, -(-n_filters // n_cols))
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False
    )
    axes = axes.ravel()

    for i, ax in enumerate(axes):
        # Turn every axis off, so unused grid cells stay blank
        ax.axis('off')
        if i >= n_filters:
            continue

        filter_weights = filters[:, :, 0, i]  # Take first input channel
        ax.imshow(filter_weights, cmap='RdBu', vmin=-0.5, vmax=0.5)
        ax.set_title(f'Filter {i+1}')

    plt.suptitle(f'{layer_name.upper()} Learned Filters')
    plt.tight_layout()
    plt.show()

# Visualize learned filters
visualize_conv_filters(trained_params, 'conv1', max_filters=8)
```

## Summary

In this notebook, we've implemented a complete CNN from scratch in JAX:

**Core Components:**

1. **Convolution Operations**: Using `lax.conv_general_dilated` for efficient convolutions
2. **Pooling Operations**: Max and average pooling with `lax.reduce_window`  
3. **Layer Classes**: Modular ConvLayer and PoolingLayer implementations
4. **CNN Architecture**: Complete SimpleCNN with multiple conv-pool blocks

**Key Features:**
- Proper weight initialization (He initialization for ReLU)
- Flexible layer configuration (kernel size, stride, padding)
- Efficient JAX operations with JIT compilation
- Complete training loop with Adam optimizer

**JAX Advantages:**
- Automatic differentiation through convolutions
- JIT compilation for performance
- Functional programming with immutable parameters
- Easy vectorization for batch processing

**Training Insights:**
- CNNs learn hierarchical features (edges → patterns → objects)
- Pooling reduces spatial dimensions while preserving important features
- Proper initialization is crucial for training stability
- Batch processing enables efficient GPU utilization

**Next Steps:**
- The next notebook will implement attention mechanisms
- We'll explore self-attention and transformer architectures
- Understanding CNNs provides a foundation for modern computer vision

This minimal CNN implementation demonstrates JAX's capability for implementing complex neural architectures while maintaining clarity and performance.