# Location: notebooks/05_parallelism/16_collectives.ipynb

## Collective Operations in JAX

This notebook explores collective communication operations in JAX, including reductions, all-to-all communication, and synchronization primitives essential for distributed training and parallel algorithms.

## Basic Collective Operations

```python
import jax
import jax.numpy as jnp
from jax import lax, pmap
import numpy as np

# Total number of devices JAX reports; used as the leading (mapped) axis below.
n_devices = jax.device_count()
print("Available devices: {}".format(n_devices))

# Standard-normal samples from a fixed seed: one row of 100 values per device.
seed_key = jax.random.PRNGKey(42)
test_data = jax.random.normal(seed_key, shape=(n_devices, 100))
print("Test data shape: {}".format(test_data.shape))

def demonstrate_basic_collectives(x):
    """Reduce one device's shard locally, then combine the results across devices.

    Must be traced under a mapped axis named ``'devices'`` (for example
    through ``pmap(..., axis_name='devices')``).

    Args:
        x: The array shard held by the current device.

    Returns:
        A dict with the per-device ``local_sum`` and ``local_mean`` and the
        cross-device ``global_sum``, ``global_mean``, ``global_max`` and
        ``global_min``.
    """
    axis = 'devices'

    # Per-device scalars computed from the local shard only.
    shard_total = jnp.sum(x)
    shard_average = jnp.mean(x)

    # Each collective below combines one scalar per device. The global mean is
    # the average of the per-device means, which matches the overall mean as
    # long as every shard holds the same number of elements.
    return {
        'local_sum': shard_total,
        'global_sum': lax.psum(shard_total, axis_name=axis),
        'local_mean': shard_average,
        'global_mean': lax.pmean(shard_average, axis_name=axis),
        'global_max': lax.pmax(jnp.max(x), axis_name=axis),
        'global_min': lax.pmin(jnp.min(x), axis_name=axis),
    }

# Map the demo over the device axis and run it on the per-device rows.
pmapped_collectives = pmap(demonstrate_basic_collectives, axis_name='devices')
results = pmapped_collectives(test_data)

# Global values are identical on every device, so only entry 0 is shown for them.
print("\nCollective Results:")
for label, key, first_only in (
    ("Local sums", 'local_sum', False),
    ("Global sum (each device)", 'global_sum', True),
    ("Local means", 'local_mean', False),
    ("Global mean", 'global_mean', True),
    ("Global max", 'global_max', True),
    ("Global min", 'global_min', True),
):
    shown = results[key][0] if first_only else results[key]
    print(f"{label}: {shown}")

# Reference values computed directly on the full, un-sharded array.
expected_global_sum = jnp.sum(test_data)
expected_global_max = jnp.max(test_data)
expected_global_min = jnp.min(test_data)

print("\nVerification:")
for label, key, expected in (
    ("Sum", 'global_sum', expected_global_sum),
    ("Max", 'global_max', expected_global_max),
    ("Min", 'global_min', expected_global_min),
):
    print(f"{label} correct: {jnp.allclose(results[key][0], expected)}")
```

## Advanced Reduction Operations

```python
def advanced_reductions(x):
    """Standardize ``x`` using a mean and std computed over all devices.

    Must be traced under a mapped axis named ``'devices'``. Each device
    contributes its element count, sum and sum of squares; three ``psum``
    calls turn those into global moments.

    Args:
        x: The array shard held by the current device (floating point).

    Returns:
        A dict with ``global_mean``, ``global_std``, ``normalized_data``
        (same shape as ``x``) and ``local_stats`` (the per-device ``count``,
        ``sum`` and ``sum_squares``). When the global std is zero (constant
        data) ``normalized_data`` is all zeros rather than NaN.
    """
    # Per-device sufficient statistics.
    local_count = x.size
    local_sum = jnp.sum(x)
    local_sum_squares = jnp.sum(x ** 2)

    # Combine them across devices.
    global_count = lax.psum(local_count, axis_name='devices')
    global_sum = lax.psum(local_sum, axis_name='devices')
    global_sum_squares = lax.psum(local_sum_squares, axis_name='devices')

    global_mean = global_sum / global_count
    # E[x^2] - E[x]^2 suffers from cancellation and can round to a tiny
    # negative number; clamp at zero so the square root never yields NaN.
    global_variance = jnp.maximum(
        (global_sum_squares / global_count) - (global_mean ** 2), 0.0
    )
    global_std = jnp.sqrt(global_variance)

    # A zero std would turn the division into 0/0; divide by 1 in that case so
    # constant inputs normalize to zeros.
    safe_std = jnp.where(global_std > 0, global_std, 1.0)
    normalized_x = (x - global_mean) / safe_std

    return {
        'global_mean': global_mean,
        'global_std': global_std,
        'normalized_data': normalized_x,
        'local_stats': {
            'count': local_count,
            'sum': local_sum,
            'sum_squares': local_sum_squares
        }
    }

# Run the normalization on every device's row of test_data.
pmapped_advanced = pmap(advanced_reductions, axis_name='devices')
advanced_results = pmapped_advanced(test_data)

normalized = advanced_results['normalized_data']

# The global statistics come back once per device; entry 0 is printed.
print("Advanced Reduction Results:")
print("Global mean: {}".format(advanced_results['global_mean'][0]))
print("Global std: {}".format(advanced_results['global_std'][0]))
print("Normalized data shape: {}".format(normalized.shape))

# Flatten across devices and check the result is numerically standardized.
normalized_flat = jnp.ravel(normalized)
mean_is_zero = jnp.abs(jnp.mean(normalized_flat)) < 1e-6
std_is_one = jnp.abs(jnp.std(normalized_flat) - 1) < 1e-6
print(f"Normalized mean ≈ 0: {mean_is_zero}")
print(f"Normalized std ≈ 1: {std_is_one}")
```

## All-to-All Communication

```python
def all_to_all_example(x):
    """Show gather, broadcast, all-reduce and reduce-scatter built from JAX collectives.

    Must be traced under a mapped axis named ``'devices'``. Reads the
    module-level ``n_devices`` to size the reduce-scatter chunk.

    Args:
        x: The 1-D array shard held by the current device.

    Returns:
        A dict with ``original``, ``all_gathered``, ``broadcasted``,
        ``custom_reduced``, ``scattered`` and ``device_id``.
    """
    axis = 'devices'
    device_id = lax.axis_index(axis_name=axis)

    # All-gather: every device receives every shard, stacked along a new axis 0.
    gathered = lax.all_gather(x, axis_name=axis, axis=0)

    # Broadcast emulation: only device 0 contributes a non-zero term, so the
    # psum hands device 0's value to everyone.
    root_contribution = lax.cond(device_id == 0, lambda: jnp.sum(x), lambda: 0.0)
    from_root = lax.psum(root_contribution, axis_name=axis)

    # All-reduce of an elementwise transform of the shard.
    squares_total = lax.psum(x ** 2, axis_name=axis)

    # Reduce-scatter emulation: all-reduce first, then each device keeps its
    # own contiguous chunk. Floor division means trailing elements are not
    # assigned to any device when the length is not a multiple of n_devices.
    chunk = x.shape[0] // n_devices
    summed = lax.psum(x, axis_name=axis)
    my_piece = lax.dynamic_slice(summed, (device_id * chunk,), (chunk,))

    return {
        'original': x,
        'all_gathered': gathered,
        'broadcasted': from_root,
        'custom_reduced': squares_total,
        'scattered': my_piece,
        'device_id': device_id,
    }

# Map the example over the device axis and run it on the shared test data.
pmapped_all_to_all = pmap(all_to_all_example, axis_name='devices')
all_to_all_results = pmapped_all_to_all(test_data)

# Note: the shapes printed here are those of the stacked pmap outputs, so they
# include the leading device axis.
report_lines = [
    "All-to-All Communication Results:",
    f"Original shape per device: {all_to_all_results['original'].shape}",
    f"All-gathered shape per device: {all_to_all_results['all_gathered'].shape}",
    f"Broadcasted value (should be same): {all_to_all_results['broadcasted'][:3]}",
    f"Device IDs: {all_to_all_results['device_id']}",
    f"Scattered shape per device: {all_to_all_results['scattered'].shape}",
]
print("\n".join(report_lines))
```

## Gradient Synchronization Patterns

```python
def create_synchronized_training_step():
    """Build a pmapped SGD step for a linear model with cross-device gradient averaging.

    Returns:
        A function ``step(params, x_batch, y_batch, lr)`` where

        * ``params`` is ``{'w': (out, in), 'b': (out,)}`` and is broadcast
          unchanged to every device,
        * ``x_batch`` is ``(devices, batch, in)`` and ``y_batch`` is
          ``(devices, batch, out)``; both are split along their leading axis,
        * ``lr`` is a scalar learning rate, broadcast to every device.

        It returns ``(new_params, avg_loss, metrics)``; every leaf carries a
        leading device axis. ``metrics`` holds ``local_loss``,
        ``local_grad_norm`` and ``sync_grad_norm``.
    """

    def loss_fn(params, x, y):
        # Linear model. x is (batch, in) and w is (out, in), so contract over
        # the feature axis to get (batch, out) predictions matching y.
        pred = x @ params['w'].T + params['b']
        return jnp.mean((pred - y) ** 2)

    def _global_norm(tree):
        # L2 norm over all leaves of a gradient pytree.
        return jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(tree)))

    def sync_training_step(params, x_batch, y_batch, lr):
        # Gradients of the loss on this device's shard only.
        local_loss, local_grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch)

        # Average gradients across devices. Summing with psum and dividing by
        # the number of devices would be equivalent.
        sync_grads = jax.tree_util.tree_map(
            lambda g: lax.pmean(g, axis_name='devices'), local_grads
        )

        # Every device applies the same averaged gradient to the same
        # broadcast parameters, so the updated parameters stay identical.
        new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, sync_grads)

        # Averaged loss, for monitoring only.
        avg_loss = lax.pmean(local_loss, axis_name='devices')

        return new_params, avg_loss, {
            'local_loss': local_loss,
            'local_grad_norm': _global_norm(local_grads),
            'sync_grad_norm': _global_norm(sync_grads)
        }

    # params and lr have no device axis, so they must be broadcast (None);
    # only the data batches are split along axis 0.
    return pmap(sync_training_step, axis_name='devices', in_axes=(None, 0, 0, None))

# Build the pmapped training step.
sync_train_step = create_synchronized_training_step()

# Linear-model parameters: random weights, zero bias.
params = dict(
    w=jax.random.normal(jax.random.PRNGKey(0), (10, 50)),
    b=jnp.zeros(10),
)

# One independent key per device, so every device trains on different data.
# Inputs are drawn from the key itself, targets from the first half of its split.
keys = jax.random.split(jax.random.PRNGKey(1), n_devices)
x_shards = []
y_shards = []
for k in keys:
    x_shards.append(jax.random.normal(k, (32, 50)))
    y_shards.append(jax.random.normal(jax.random.split(k)[0], (32, 10)))
x_data = jnp.stack(x_shards)
y_data = jnp.stack(y_shards)

print("Synchronized Training:")
print("Training data shapes: {}, {}".format(x_data.shape, y_data.shape))

# Run a single step with learning rate 0.01.
new_params, avg_loss, metrics = sync_train_step(params, x_data, y_data, 0.01)

print("Average loss: {}".format(avg_loss[0]))
for label, key in (
    ("Local losses", 'local_loss'),
    ("Local gradient norms", 'local_grad_norm'),
    ("Synchronized gradient norms", 'sync_grad_norm'),
):
    print(f"{label}: {metrics[key]}")

# Largest deviation of any device's parameters from device 0's copy.
w_diff = jnp.max(jnp.abs(new_params['w'] - new_params['w'][0]))
b_diff = jnp.max(jnp.abs(new_params['b'] - new_params['b'][0]))
print("Parameter synchronization - W diff: {}, B diff: {}".format(w_diff, b_diff))
```

## Custom Collective Operations

```python
def custom_collectives(x):
    """Build weighted-average, geometric-mean, median and arg-max reductions from primitives.

    Must be traced under a mapped axis named ``'devices'``.

    Args:
        x: The array shard held by the current device.

    Returns:
        A dict with ``device_id``, ``weight``, ``weighted_average``,
        ``geometric_mean`` (elementwise across devices, of ``|x| + 1e-8``),
        ``approx_median`` (upper median of all gathered values),
        ``global_max``, ``max_device`` (lowest device index whose shard
        contains the global maximum) and ``has_max``.
    """
    axis = 'devices'
    device_id = lax.axis_index(axis_name=axis)
    # Summing the constant 1 over the axis yields the number of participants,
    # so the function does not depend on a module-level device count.
    axis_size = lax.psum(1, axis_name=axis)

    # Weighted average: the weight grows linearly with the device index.
    weight = (device_id + 1) / axis_size
    weighted_sum = lax.psum(x * weight, axis_name=axis)
    weight_sum = lax.psum(weight, axis_name=axis)
    weighted_average = weighted_sum / weight_sum

    # Geometric mean via the mean of logs; the small offset avoids log(0).
    log_values = jnp.log(jnp.abs(x) + 1e-8)
    geometric_mean = jnp.exp(lax.pmean(log_values, axis_name=axis))

    # Median approximation: gather everything, sort, take the middle element
    # (the upper of the two middle values when the total count is even).
    all_values = lax.all_gather(x, axis_name=axis, axis=0)
    sorted_values = jnp.sort(all_values.reshape(-1))
    approx_median = sorted_values[sorted_values.shape[0] // 2]

    # Consensus on where the maximum lives. The global max is one of the local
    # maxima, so exact equality identifies the holders.
    local_max = jnp.max(x)
    global_max = lax.pmax(local_max, axis_name=axis)
    has_global_max = local_max == global_max
    # Several devices can hold the same maximum; summing their indices would
    # produce a meaningless id. Non-holders report the out-of-range value
    # axis_size, and the minimum picks the lowest holder.
    max_device = lax.pmin(
        jnp.where(has_global_max, device_id, axis_size), axis_name=axis
    )

    return {
        'device_id': device_id,
        'weight': weight,
        'weighted_average': weighted_average,
        'geometric_mean': geometric_mean,
        'approx_median': approx_median,
        'global_max': global_max,
        'max_device': max_device,
        'has_max': has_global_max
    }

# Map the custom reductions over the device axis and run them on test_data.
pmapped_custom = pmap(custom_collectives, axis_name='devices')
custom_results = pmapped_custom(test_data)

# Per-device quantities are printed in full; for the others only device 0's
# copy is shown.
print("Custom Collective Operations:")
for label, key, first_only in (
    ("Device weights", 'weight', False),
    ("Weighted average", 'weighted_average', True),
    ("Geometric mean", 'geometric_mean', True),
    ("Approximate median", 'approx_median', True),
    ("Global maximum", 'global_max', True),
    ("Device with max", 'max_device', True),
    ("Devices with max", 'has_max', False),
):
    shown = custom_results[key][0] if first_only else custom_results[key]
    print(f"{label}: {shown}")
```

## Barrier and Synchronization

```python
def synchronization_example(x):
    """Run uneven per-device work, then combine results through collectives.

    Must be traced under a mapped axis named ``'devices'``.

    Args:
        x: The floating-point array shard held by the current device.

    Returns:
        A dict with ``device_id``, ``computation_factor`` (number of local
        iterations, ``device_id + 1``), ``local_result_norm``,
        ``barrier_count`` (number of devices on the axis),
        ``synchronized_norm``, ``shared_result`` (device 0's value, shared
        with everyone) and ``is_leader``.
    """
    device_id = lax.axis_index(axis_name='devices')

    # Phase 1: each device iterates a different number of times.
    computation_factor = device_id + 1

    def _one_round(_, value):
        return jnp.sin(value) + jnp.cos(value)

    # device_id is a traced value, so a Python range() cannot consume it;
    # fori_loop accepts a traced upper bound.
    local_result = lax.fori_loop(0, computation_factor, _one_round, x)

    # Number of devices on the axis. NOTE(review): a psum over the constant 1
    # is likely resolved without communication, so it counts participants
    # rather than acting as a barrier - confirm. The collective that needs
    # every device's data is the pmean below.
    barrier_value = lax.psum(1, axis_name='devices')

    # Phase 2: average the local results over all devices.
    synchronized_input = lax.pmean(local_result, axis_name='devices')

    # Phase 3: device 0 acts as leader. Both branches return the same dtype so
    # the conditional stays type-stable.
    is_leader = device_id == 0
    leader_computation = lax.cond(
        is_leader,
        lambda: jnp.sum(synchronized_input ** 2),
        lambda: jnp.zeros((), dtype=synchronized_input.dtype)
    )

    # Only the leader contributes a non-zero term, so the sum broadcasts it.
    shared_result = lax.psum(leader_computation, axis_name='devices')

    return {
        'device_id': device_id,
        'computation_factor': computation_factor,
        'local_result_norm': jnp.linalg.norm(local_result),
        'barrier_count': barrier_value,
        'synchronized_norm': jnp.linalg.norm(synchronized_input),
        'shared_result': shared_result,
        'is_leader': is_leader
    }

# Map the synchronization demo over the device axis and run it on test_data.
pmapped_sync = pmap(synchronization_example, axis_name='devices')
sync_results = pmapped_sync(test_data)

barrier_seen = sync_results['barrier_count'][0]
shared_value = sync_results['shared_result'][0]

print("Synchronization Example:")
print("Computation factors: {}".format(sync_results['computation_factor']))
print("Local result norms: {}".format(sync_results['local_result_norm']))
# The count should match the number of devices taking part in the pmap.
print("Barrier count: {} (should equal {})".format(barrier_seen, n_devices))
print("Synchronized norms: {}".format(sync_results['synchronized_norm']))
print("Shared result: {}".format(shared_value))
print("Leaders: {}".format(sync_results['is_leader']))
```

## Hierarchical Communication Patterns

```python
def hierarchical_communication(x):
    """Simulate a two-level (group, then global) reduction over the ``'devices'`` axis.

    Devices are paired into conceptual groups of two; the first device of each
    pair acts as group leader. Only leaders contribute to the final sum, and
    that sum is visible on every device. The grouping is purely arithmetic on
    the device index: no collective here is restricted to a group.

    Must be traced under a mapped axis named ``'devices'``.

    Args:
        x: The array shard held by the current device.

    Returns:
        A dict with ``device_id``, ``group_id``, ``local_id_in_group``,
        ``is_group_leader``, ``inter_group_value`` (the shard mean on leaders,
        zero elsewhere) and ``final_result`` (sum of the leaders' values).
    """
    device_id = lax.axis_index(axis_name='devices')

    # Partition the device indices into consecutive pairs.
    group_size = 2
    group_id = device_id // group_size
    local_id_in_group = device_id % group_size

    # Step 1 (reduce within groups) is not performed: an all-device psum whose
    # result was discarded used to stand in for it and has been removed.
    # NOTE(review): a real intra-group reduction would presumably use the
    # axis_index_groups argument of the collectives - confirm before adopting.

    # Step 2: the first device of every pair is the group leader.
    is_group_leader = local_id_in_group == 0

    # Leaders contribute their shard mean; everyone else contributes a zero of
    # the same dtype so both branches of the conditional agree.
    local_mean = jnp.mean(x)
    inter_group_value = lax.cond(
        is_group_leader,
        lambda: local_mean,
        lambda: jnp.zeros_like(local_mean)
    )

    # Because non-leaders add zero, the all-device sum equals the sum over leaders.
    global_leader_sum = lax.psum(inter_group_value, axis_name='devices')

    # Step 3: psum already leaves the same value on every device.
    final_result = global_leader_sum

    return {
        'device_id': device_id,
        'group_id': group_id,
        'local_id_in_group': local_id_in_group,
        'is_group_leader': is_group_leader,
        'inter_group_value': inter_group_value,
        'final_result': final_result
    }

# Map the hierarchy demo over the device axis and run it on test_data.
pmapped_hierarchical = pmap(hierarchical_communication, axis_name='devices')
hier_results = pmapped_hierarchical(test_data)

print("Hierarchical Communication:")
for label, key in (
    ("Device IDs", 'device_id'),
    ("Group IDs", 'group_id'),
    ("Local IDs in group", 'local_id_in_group'),
    ("Group leaders", 'is_group_leader'),
    ("Inter-group values", 'inter_group_value'),
):
    print(f"{label}: {hier_results[key]}")
# The final result is the same on every device, so device 0's copy is shown.
print("Final result: {}".format(hier_results['final_result'][0]))
```

## Communication-Efficient Training Patterns

```python
def create_communication_efficient_training():
    """Build a training step that only communicates every ``sync_frequency`` steps.

    Between synchronization steps each device applies its own local gradient,
    so parameters may drift apart; on synchronization steps both the gradients
    and the updated parameters are averaged across devices.

    Returns:
        A function ``step(params, x_batch, y_batch, lr, step_count,
        sync_frequency=4)`` where

        * ``params`` is ``{'w': ..., 'b': ...}``; ``w`` may be ``(out, in)``
          (it is then replicated to every device) or already carry a leading
          device axis ``(devices, out, in)``, e.g. when it is the output of a
          previous call,
        * ``x_batch`` is ``(devices, batch, in)`` and ``y_batch`` is
          ``(devices, batch, out)``,
        * ``lr``, ``step_count`` and ``sync_frequency`` are scalars broadcast
          to every device.

        It returns ``(params, loss, metrics)``, each leaf with a leading
        device axis; ``metrics`` holds ``should_sync``, ``grad_norm`` and
        ``step_count``.
    """

    def loss_fn(p, x, y):
        # x is (batch, in) and w is (out, in): contract over the feature axis.
        pred = x @ p['w'].T + p['b']
        return jnp.mean((pred - y) ** 2)

    def efficient_training_step(params, x_batch, y_batch, lr, step_count, sync_frequency):
        # Gradients of the loss on this device's shard only.
        loss, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch)

        # step_count is the same on every device, so all devices take the same
        # branch of the conditionals below.
        should_sync = (step_count % sync_frequency) == 0

        def sync_grads():
            return jax.tree_util.tree_map(
                lambda g: lax.pmean(g, axis_name='devices'), grads
            )

        def keep_local_grads():
            return grads

        final_grads = lax.cond(should_sync, sync_grads, keep_local_grads)

        # Gradient-descent update.
        new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, final_grads)

        # On sync steps also average the parameters, pulling drifted copies
        # back together.
        def sync_params():
            return jax.tree_util.tree_map(
                lambda p: lax.pmean(p, axis_name='devices'), new_params
            )

        def keep_params():
            return new_params

        final_params = lax.cond(should_sync, sync_params, keep_params)

        return final_params, loss, {
            'should_sync': should_sync,
            'grad_norm': jnp.sqrt(
                sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grads))
            ),
            'step_count': step_count
        }

    # Parameters are per-device (they can diverge between syncs) and the data
    # is split along axis 0; the three scalars are broadcast.
    pmapped_step = pmap(
        efficient_training_step,
        axis_name='devices',
        in_axes=(0, 0, 0, None, None, None),
    )

    def train_step(params, x_batch, y_batch, lr, step_count, sync_frequency=4):
        # A 2-D weight matrix has no device axis yet: replicate the parameters
        # so the first call accepts freshly initialized values.
        if params['w'].ndim == 2:
            n_replicas = x_batch.shape[0]
            params = jax.tree_util.tree_map(
                lambda p: jnp.broadcast_to(p, (n_replicas,) + p.shape), params
            )
        # pmap maps keyword arguments over axis 0, which a scalar cannot
        # satisfy, so sync_frequency is forwarded positionally.
        return pmapped_step(params, x_batch, y_batch, lr, step_count, sync_frequency)

    return train_step

# Build the training step that synchronizes only periodically.
efficient_train_step = create_communication_efficient_training()

# Fresh linear-model parameters: random weights, zero bias.
params = dict(
    w=jax.random.normal(jax.random.PRNGKey(100), (5, 20)),
    b=jnp.zeros(5),
)

print("Communication-Efficient Training:")
for step in range(8):
    # New per-device data every step: inputs come from each device's key,
    # targets from the first half of that key's split.
    keys = jax.random.split(jax.random.PRNGKey(step + 200), n_devices)
    x_shards = []
    y_shards = []
    for k in keys:
        x_shards.append(jax.random.normal(k, (16, 20)))
        y_shards.append(jax.random.normal(jax.random.split(k)[0], (16, 5)))
    x_batch = jnp.stack(x_shards)
    y_batch = jnp.stack(y_shards)

    # The returned parameters are fed back in on the next iteration.
    params, loss, metrics = efficient_train_step(
        params, x_batch, y_batch, 0.01, step, sync_frequency=3
    )

    first_loss = loss[0]
    synced = metrics['should_sync'][0]
    first_grad_norm = metrics['grad_norm'][0]
    print(f"Step {step}: Loss={first_loss:.4f}, Sync={synced}, GradNorm={first_grad_norm:.4f}")

# Largest deviation of any device's weights from device 0's copy.
param_sync_diff = jnp.max(jnp.abs(params['w'] - params['w'][0]))
print("Final parameter difference across devices: {:.6f}".format(param_sync_diff))
```

## Performance Analysis of Collectives

```python
import time

def benchmark_collective_operations(sizes=(100, 1000, 10000), n_trials=10):
    """Time ``psum``, ``pmean``, ``pmax`` and ``all_gather`` for several per-device sizes.

    Args:
        sizes: Per-device array lengths to benchmark.
        n_trials: Number of timed calls per operation and size; must be >= 1.

    Returns:
        ``{size: {op_name: average_seconds_per_call}}``.

    Raises:
        ValueError: If ``n_trials`` is smaller than 1.
    """
    if n_trials < 1:
        raise ValueError("n_trials must be at least 1")

    # Queried locally so the function does not rely on a module-level constant.
    device_total = jax.device_count()

    # Collectives under test; each is mapped over the 'devices' axis.
    collective_ops = {
        'psum': lambda x: lax.psum(x, axis_name='devices'),
        'pmean': lambda x: lax.pmean(x, axis_name='devices'),
        'pmax': lambda x: lax.pmax(x, axis_name='devices'),
        'all_gather': lambda x: lax.all_gather(x, axis_name='devices', axis=0),
    }

    results = {}

    for size in sizes:
        print(f"\nBenchmarking with array size: {size}")
        test_array = jax.random.normal(jax.random.PRNGKey(300), (device_total, size))

        size_results = {}
        for op_name, op_fn in collective_ops.items():
            benchmark_fn = pmap(op_fn, axis_name='devices')

            # Warm-up call: wait for it to finish so compilation and the first
            # execution cannot leak into the timed region.
            benchmark_fn(test_array).block_until_ready()

            # Dispatch is asynchronous, so block on every result; perf_counter
            # is monotonic and intended for measuring short intervals.
            start_time = time.perf_counter()
            for _ in range(n_trials):
                benchmark_fn(test_array).block_until_ready()
            avg_time = (time.perf_counter() - start_time) / n_trials

            size_results[op_name] = avg_time
            print(f"  {op_name:12}: {avg_time:.4f}s")

        results[size] = size_results

    return results

# Collect average timings for every operation and array size.
benchmark_results = benchmark_collective_operations()

# Ratio of average times between consecutive array sizes, per operation.
print("\nScaling Analysis:")
for op_name in ('psum', 'pmean', 'pmax', 'all_gather'):
    t_100 = benchmark_results[100][op_name]
    t_1k = benchmark_results[1000][op_name]
    t_10k = benchmark_results[10000][op_name]
    scaling_100_to_1k = t_1k / t_100
    scaling_1k_to_10k = t_10k / t_1k

    print("{:12}: 100->1K: {:.2f}x, 1K->10K: {:.2f}x".format(
        op_name, scaling_100_to_1k, scaling_1k_to_10k))
```

## Summary

In this notebook, we explored collective communication operations in JAX:

**Basic Collectives:**
- `psum`: Sum across devices
- `pmean`: Average across devices  
- `pmax` / `pmin`: Maximum/minimum across devices
- `all_gather`: Gather data from all devices

**Advanced Patterns:**
- Gradient synchronization for distributed training
- Custom collective operations for specific algorithms
- Hierarchical communication patterns
- Communication-efficient training strategies

**Key Applications:**
- **Distributed Training**: Gradient averaging and parameter synchronization
- **Statistics Computation**: Global mean, variance, and other statistics
- **Consensus Algorithms**: Agreement on values across devices
- **Load Balancing**: Distributing work based on global information

**Performance Considerations:**
- Communication overhead scales with data size and device count
- Reduce communication frequency when possible
- Use hierarchical patterns for large device counts
- Consider bandwidth and latency trade-offs

**Best Practices:**
- Synchronize gradients rather than parameters when possible
- Use `pmean` instead of `psum` + manual division
- Batch collective operations to reduce overhead
- Profile communication patterns to identify bottlenecks

Collective operations are essential for scaling JAX computations across multiple devices while maintaining correctness and efficiency in distributed algorithms.