# Location: notebooks/05_parallelism/14_pmap_basics.ipynb

## Introduction to pmap: Data Parallelism in JAX

This notebook introduces JAX's `pmap` (parallel map) transformation for data parallelism. We'll learn how to distribute computations across multiple devices and handle device placement efficiently.

## Basic pmap Usage

```python
import jax
import jax.numpy as jnp
from jax import pmap
import numpy as np

# Check available devices before mapping any work onto them:
# first list the devices JAX can see, then print how many there are.
print(f"Available devices: {jax.devices()}")
print(f"Device count: {jax.device_count()}")

# Simple pmap example
def simple_add(x, y):
    """Return ``x + y``; used as the per-device function in the first pmap demo."""
    total = x + y
    return total

# Create pmapped version: pmap runs simple_add once per device, feeding each
# device one slice of every argument's leading axis.
pmapped_add = pmap(simple_add)

# Create data for multiple devices.
# pmap can only map over the devices attached to *this* process, so the
# leading axis is sized with the local device count. On a single host this is
# the same number as jax.device_count(); on multi-host setups the global count
# would exceed what pmap accepts for the mapped axis.
n_devices = jax.local_device_count()
x = jnp.arange(n_devices * 4).reshape(n_devices, 4)
y = jnp.ones((n_devices, 4))

print(f"Input x shape: {x.shape}")
print(f"Input y shape: {y.shape}")

# Each device adds its own (4,) slices; the results are stacked back along
# the leading axis, giving shape (n_devices, 4).
result = pmapped_add(x, y)
print(f"Result shape: {result.shape}")
print(f"Result: {result}")
```

## Matrix Operations with pmap

```python
def matrix_multiply(A, B):
    """Return the product of ``A`` and ``B`` as computed by ``jnp.dot``."""
    product = jnp.dot(A, B)
    return product

# One matrix product per device: pmap splits A and B along their leading axis.
pmapped_matmul = pmap(matrix_multiply)

# One (64, 32) x (32, 16) pair for every device
batch_size = n_devices
A = jax.random.normal(jax.random.PRNGKey(42), shape=(batch_size, 64, 32))
B = jax.random.normal(jax.random.PRNGKey(43), shape=(batch_size, 32, 16))

# Run all products in parallel
result = pmapped_matmul(A, B)
print(f"A shape: {A.shape}")
print(f"B shape: {B.shape}")
print(f"Result shape: {result.shape}")

# Cross-check against the same products computed one pair at a time
sequential_result = jnp.stack([jnp.dot(a, b) for a, b in zip(A, B)])
print(f"Results match: {jnp.allclose(result, sequential_result)}")
```

## Neural Network Forward Pass with pmap

```python
def init_mlp_params(key, layer_sizes):
    """Create the weights and biases of a fully connected network.

    Args:
        key: JAX PRNG key used to derive one key per layer.
        layer_sizes: Sequence of layer widths, input width first.

    Returns:
        A list with one ``{'w': ..., 'b': ...}`` dict per consecutive pair of
        widths. ``'w'`` has shape ``(in_size, out_size)`` and holds normal
        samples scaled by 0.1; ``'b'`` is a zero vector of length ``out_size``.
    """
    layer_keys = jax.random.split(key, len(layer_sizes))
    layer_dims = zip(layer_sizes[:-1], layer_sizes[1:])
    return [
        {
            # Biases start at zero, so only the first of the two sub-keys
            # derived from the layer key is consumed.
            'w': jax.random.normal(jax.random.split(layer_key)[0], (fan_in, fan_out)) * 0.1,
            'b': jnp.zeros(fan_out),
        }
        for layer_key, (fan_in, fan_out) in zip(layer_keys, layer_dims)
    ]

def mlp_forward(params, x):
    """Run ``x`` through the MLP described by ``params``.

    Every layer except the last applies ``tanh`` to its affine output; the
    last layer is affine only.

    Args:
        params: Sequence of ``{'w': ..., 'b': ...}`` layer dicts.
        x: Input array whose last dimension matches the first layer's ``'w'``.

    Returns:
        The output of the final affine layer.
    """
    hidden_layers, output_layer = params[:-1], params[-1]
    activation = x
    for layer in hidden_layers:
        pre_activation = jnp.dot(activation, layer['w']) + layer['b']
        activation = jnp.tanh(pre_activation)
    # Output layer: no non-linearity
    return jnp.dot(activation, output_layer['w']) + output_layer['b']

# Create model: a 784 -> 256 -> 128 -> 10 MLP with a fixed seed
layer_sizes = [784, 256, 128, 10]
params = init_mlp_params(jax.random.PRNGKey(0), layer_sizes)

# Create pmapped forward pass.
# in_axes=(None, 0): `params` is not split (every device gets the full
# parameter list), while `x` is split along its leading axis.
pmapped_forward = pmap(mlp_forward, in_axes=(None, 0))

# Create batch of inputs: one 784-element example per device
batch_size = n_devices
inputs = jax.random.normal(jax.random.PRNGKey(1), (batch_size, 784))

# Parallel forward pass; per-device outputs are stacked along axis 0
outputs = pmapped_forward(params, inputs)
print(f"Input shape: {inputs.shape}")
print(f"Output shape: {outputs.shape}")
```

## Handling Device Placement and Replication

```python
# Device placement utilities
def replicate_across_devices(x, num_devices=None):
    """Replicate data across all devices.

    Args:
        x: Array or array-like value (scalars and nested lists are converted
            with ``jnp.asarray``).
        num_devices: Size of the new leading axis. When omitted, the
            module-level ``n_devices`` is used.

    Returns:
        An array of shape ``(num_devices,) + x.shape`` in which every slice
        along axis 0 equals ``x``.
    """
    if num_devices is None:
        num_devices = n_devices
    x = jnp.asarray(x)
    return jnp.broadcast_to(x, (num_devices,) + x.shape)

def shard_across_devices(x, axis=0, num_devices=None):
    """Shard data across devices along specified axis.

    ``x`` is cut into ``num_devices`` equal, contiguous chunks along ``axis``
    and the chunks are stacked on a new leading axis, so slice ``i`` of the
    result is what device ``i`` should receive.

    Args:
        x: Array to split.
        axis: Axis of ``x`` to split; negative values count from the end.
        num_devices: Number of shards. When omitted, the module-level
            ``n_devices`` is used.

    Returns:
        An array of shape ``(num_devices, *shard_shape)`` where ``shard_shape``
        is ``x.shape`` with the length of ``axis`` divided by ``num_devices``.
        The underlying reshape fails if that length is not evenly divisible.
    """
    if num_devices is None:
        num_devices = n_devices
    if axis < 0:
        axis += x.ndim
    # Bring the sharded axis to the front so a plain reshape splits exactly
    # that axis (and no other) into (num_devices, chunk).
    moved = jnp.moveaxis(x, axis, 0)
    stacked = moved.reshape(num_devices, -1, *moved.shape[1:])
    # Put the chunk axis back where `axis` was, shifted by the new device axis.
    return jnp.moveaxis(stacked, 1, axis + 1)

# Example: replicate parameters, shard data
# One (64, 32) parameter matrix, copied once per device along a new axis 0.
single_param = jax.random.normal(jax.random.PRNGKey(10), (64, 32))
replicated_params = replicate_across_devices(single_param)

# A dataset with 8 rows per device, split so each device gets its own rows.
large_data = jax.random.normal(jax.random.PRNGKey(11), (n_devices * 8, 64))
sharded_data = shard_across_devices(large_data)

print(f"Original param shape: {single_param.shape}")
print(f"Replicated param shape: {replicated_params.shape}")
print(f"Original data shape: {large_data.shape}")
print(f"Sharded data shape: {sharded_data.shape}")

# Simple computation with replicated params and sharded data
def compute_with_params(params, data):
    """Project ``data`` through ``params`` and sum each projected row.

    Returns an array with one value per row of ``data``: the sum over axis 1
    of ``data @ params``.
    """
    projected = jnp.matmul(data, params)
    return jnp.sum(projected, axis=1)

# in_axes=(0, 0): both arguments are split along their leading axis, so each
# device receives its own copy of the parameters and its own shard of data.
pmapped_compute = pmap(compute_with_params, in_axes=(0, 0))
result = pmapped_compute(replicated_params, sharded_data)
print(f"Computation result shape: {result.shape}")
```

## Reduction Operations in pmap

```python
from jax import lax

def parallel_sum_and_mean(x):
    """Compute per-device and cross-device statistics of ``x``.

    Must be called inside a ``pmap`` (or similar) that binds the axis name
    ``'devices'``, because it uses ``lax.psum`` over that axis.

    Returns:
        A dict with the sum of this device's ``x`` (``'local_sum'``), the sum
        over all devices (``'global_sum'``), and the mean over all devices
        (``'global_mean'``).
    """
    axis = 'devices'

    # Per-device work
    shard_total = jnp.sum(x)

    # Collectives: add up the totals and the element counts of every device
    overall_total = lax.psum(shard_total, axis_name=axis)
    overall_count = lax.psum(x.size, axis_name=axis)

    return {
        'local_sum': shard_total,
        'global_sum': overall_total,
        'global_mean': overall_total / overall_count,
    }

# axis_name must match the name passed to lax.psum inside the mapped function.
pmapped_reduction = pmap(parallel_sum_and_mean, axis_name='devices')

# Test data: 1000 values per device
test_data = jax.random.normal(jax.random.PRNGKey(20), (n_devices, 1000))

# Every returned entry has a leading device axis. psum gives each device the
# same value, so index [0] is enough to read the global results.
results = pmapped_reduction(test_data)
print(f"Local sums: {results['local_sum']}")
print(f"Global sum (all devices): {results['global_sum'][0]}")
print(f"Global mean (all devices): {results['global_mean'][0]}")

# Verify against the same statistics computed directly on the full array
expected_sum = jnp.sum(test_data)
expected_mean = jnp.mean(test_data)
print(f"Expected sum: {expected_sum}")
print(f"Expected mean: {expected_mean}")
print(f"Sum matches: {jnp.allclose(results['global_sum'][0], expected_sum)}")
print(f"Mean matches: {jnp.allclose(results['global_mean'][0], expected_mean)}")
```

## Training Step with pmap

```python
def loss_fn(params, x, y):
    """Mean squared error between ``mlp_forward(params, x)`` and ``y``."""
    residual = mlp_forward(params, x) - y
    return jnp.mean(residual ** 2)

def training_step(params, x, y, lr=0.01, axis_name=None):
    """Single SGD training step with optional cross-device gradient averaging.

    Args:
        params: List of ``{'w': ..., 'b': ...}`` layer dicts.
        x: Inputs for this device.
        y: Targets for this device.
        lr: Learning rate of the gradient-descent update.
        axis_name: Name of the mapped axis to average gradients over. Leave as
            ``None`` when calling outside ``pmap``; inside ``pmap`` pass the
            same name given to ``pmap(..., axis_name=...)`` so that every
            device applies the same update.

    Returns:
        ``(new_params, loss)`` where ``new_params`` has the same list-of-dicts
        layout as ``params`` and ``loss`` is this device's local loss.
    """
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)

    if axis_name is not None:
        # Synchronous data parallelism: average gradients over all devices so
        # the parameter replicas stay identical after the update.
        grads = jax.lax.pmean(grads, axis_name=axis_name)

    # Plain gradient-descent update, layer by layer
    new_params = [
        {
            'w': param['w'] - lr * grad['w'],
            'b': param['b'] - lr * grad['b'],
        }
        for param, grad in zip(params, grads)
    ]

    return new_params, loss

# Create pmapped training step.
# The wrapper fixes axis_name so the step averages its gradients over the
# 'devices' axis declared by pmap; params are broadcast (None), data is split.
pmapped_train_step = pmap(
    lambda params, x, y: training_step(params, x, y, axis_name='devices'),
    in_axes=(None, 0, 0),
    axis_name='devices',
)

# Create training data: one example per device
train_x = jax.random.normal(jax.random.PRNGKey(30), (n_devices, 784))
train_y = jax.random.normal(jax.random.PRNGKey(31), (n_devices, 10))

# Perform parallel training step
new_params, losses = pmapped_train_step(params, train_x, train_y)

# Losses are per device; the updated parameters carry a leading device axis
# whose slices are identical because the gradients were averaged.
print(f"Losses across devices: {losses}")
print(f"New params type: {type(new_params)}")
print(f"First layer weight shape: {new_params[0]['w'].shape}")

# Average parameter replicas over their leading axis
def average_params_across_devices(params):
    """Average every layer's ``'w'`` and ``'b'`` over axis 0.

    Axis 0 is assumed to be the per-device axis that ``pmap`` adds to its
    outputs. The mean keeps that axis with length 1 (``keepdims=True``).

    Returns:
        A new list of ``{'w': ..., 'b': ...}`` dicts, one per input layer.
    """
    def _mean_over_leading_axis(stacked):
        return jnp.mean(stacked, axis=0, keepdims=True)

    return [
        {
            'w': _mean_over_leading_axis(layer['w']),
            'b': _mean_over_leading_axis(layer['b']),
        }
        for layer in params
    ]

# This would typically be done after psum in the training step
# averaged_params = average_params_across_devices(new_params)
```

## Performance Comparison

```python
import time

def benchmark_computation(data, n_trials=10):
    """Benchmark sequential vs parallel computation.

    The same element-wise computation is timed two ways: row by row on the
    flattened ``data`` (sequential), and once per device through ``pmap`` on
    ``data`` as given (parallel).

    Args:
        data: Array whose leading axis is the pmap (device) axis.
        n_trials: Number of timed repetitions per variant; must be >= 1.

    Returns:
        ``(sequential_time, parallel_time)``: average seconds per trial.

    Raises:
        ValueError: If ``n_trials`` is smaller than 1.
    """
    if n_trials < 1:
        raise ValueError("n_trials must be at least 1")

    def computation(x):
        return jnp.sum(x ** 2 + jnp.sin(x) * jnp.cos(x))

    pmapped_computation = pmap(computation)
    sequential_data = data.reshape(-1, data.shape[-1])

    def run_sequential():
        return [computation(row) for row in sequential_data]

    def run_parallel():
        return pmapped_computation(data)

    def average_seconds(run):
        # Warm-up run keeps one-time compilation cost out of the measurement.
        jax.block_until_ready(run())
        start = time.perf_counter()
        for _ in range(n_trials):
            # JAX dispatches work asynchronously; block so the timer covers
            # the computation itself and not just the time to enqueue it.
            jax.block_until_ready(run())
        return (time.perf_counter() - start) / n_trials

    sequential_time = average_seconds(run_sequential)
    parallel_time = average_seconds(run_parallel)

    return sequential_time, parallel_time

# Benchmark with different data sizes
# `size` is the number of elements each device receives.
sizes = [1000, 10000, 100000]
for size in sizes:
    test_data = jax.random.normal(jax.random.PRNGKey(40), (n_devices, size))
    seq_time, par_time = benchmark_computation(test_data)
    # Ratio > 1 means the pmapped version took less time per trial.
    speedup = seq_time / par_time
    
    print(f"Data size: {size}")
    print(f"  Sequential time: {seq_time:.4f}s")
    print(f"  Parallel time: {par_time:.4f}s")
    print(f"  Speedup: {speedup:.2f}x")
    print()
```

## Common pmap Patterns and Best Practices

```python
# Pattern 1: Device-specific computations
def device_specific_computation(x, device_id):
    """Scale ``x`` by ``device_id + 1`` so each device index yields a different result."""
    scale = device_id + 1
    return x * scale

# Get device IDs for pmap
# These are plain indices 0..n_devices-1 passed in as data; pmap splits the
# array along axis 0, so each device receives one scalar index.
device_ids = jnp.arange(n_devices)
test_data = jnp.ones((n_devices, 10))

pmapped_device_specific = pmap(device_specific_computation)
result = pmapped_device_specific(test_data, device_ids)
print(f"Device-specific results:\n{result}")

# Pattern 2: Conditional execution based on device
def conditional_computation(x, is_lead_device):
    """Return ``2 * sum(x)`` if ``is_lead_device`` is true, otherwise ``mean(x)``.

    The branch is chosen with ``lax.cond`` on ``is_lead_device``.
    """
    return lax.cond(
        is_lead_device,
        lambda: jnp.sum(x) * 2,   # lead device
        lambda: jnp.mean(x),      # every other device
    )

# One boolean per device: True only for index 0, which takes the "lead" branch.
lead_flags = jnp.array([i == 0 for i in range(n_devices)])
pmapped_conditional = pmap(conditional_computation)
result = pmapped_conditional(test_data, lead_flags)
print(f"Conditional results: {result}")

# Pattern 3: Efficient data movement
def efficient_data_pattern(large_array):
    """Summarize ``large_array`` locally and share only the mean across devices.

    Must be called inside a ``pmap`` (or similar) that binds the axis name
    ``'devices'``.

    Returns:
        A dict with this device's mean and standard deviation over axis 0
        (``'local_mean'``, ``'local_std'``) and the mean of the local means
        over all devices (``'global_mean'``).
    """
    per_device_mean = jnp.mean(large_array, axis=0)
    return {
        'local_mean': per_device_mean,
        # The only collective: the reduced statistic is exchanged, not the data
        'global_mean': lax.pmean(per_device_mean, axis_name='devices'),
        'local_std': jnp.std(large_array, axis=0),
    }

# axis_name must match the name used by lax.pmean inside the mapped function.
pmapped_efficient = pmap(efficient_data_pattern, axis_name='devices')
# Each device receives a (1000, 64) block of samples.
large_test_data = jax.random.normal(jax.random.PRNGKey(50), (n_devices, 1000, 64))
results = pmapped_efficient(large_test_data)
print(f"Efficient pattern results shapes:")
print(f"  Local mean: {results['local_mean'].shape}")
print(f"  Global mean: {results['global_mean'].shape}")
```

## Summary

In this notebook, we explored JAX's `pmap` for data parallelism:

**Key Concepts:**
- `pmap` distributes computation across multiple devices
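- Mapped inputs are split along their leading axis, one slice per device, so that axis must be no larger than the number of available devices
- Parameters can be shared with every device by passing `None` for them in `in_axes` (as in `in_axes=(None, 0)`), or by explicitly stacking one copy per device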
- Collective operations like `psum` and `pmean` enable cross-device communication

**Best Practices:**
- Use `pmap` for embarrassingly parallel problems
- Minimize cross-device communication for better performance
- Replicate parameters and shard data appropriately
- Consider device-specific computations when needed

**Common Patterns:**
- Training steps with gradient averaging
- Batch processing with parameter sharing
- Reduction operations across devices
- Efficient data movement strategies

`pmap` provides a simple but powerful way to scale computations across multiple devices while maintaining JAX's functional programming model and automatic differentiation capabilities.