# Location: notebooks/05_parallelism/15_pjit_and_sharding.ipynb

## Advanced Parallelism with pjit and Sharding

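This notebook covers JAX's `pjit` (parallel just-in-time) compilation and advanced sharding strategies for model parallelism, data parallelism, and mixed parallelism patterns.

> **Note:** The code uses the legacy `pjit` API: `jax.experimental.maps.mesh` and `in_axis_resources` / `out_axis_resources`. With this API, a `pjit`-compiled function must be *called* while a mesh context is active, not just defined inside one. Newer JAX releases removed these names. The modern equivalent is `jax.jit` with `in_shardings` / `out_shardings` and a `jax.sharding.Mesh`.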

## Setting up Device Mesh

```python
import jax
import jax.numpy as jnp
from jax.experimental import pjit, PartitionSpec as P
from jax.experimental.maps import mesh
import numpy as np

class _ReusableMesh:
    """Re-enterable handle for the legacy ``jax.experimental.maps.mesh``.

    In the legacy API ``mesh(devices, axis_names)`` returns a one-shot
    ``contextlib.contextmanager`` object, so entering the same object a second
    time fails. Later cells enter ``mesh_1d`` / ``mesh_2d`` many times, so this
    wrapper keeps the mesh arguments and builds a fresh ``mesh(...)`` on every
    entry.
    """

    def __init__(self, mesh_devices, axis_names):
        # Object ndarray of devices; its shape is the mesh shape.
        self.devices = np.asarray(mesh_devices, dtype=object)
        self.axis_names = tuple(axis_names)
        # One live context per nesting level, so `with` blocks may nest.
        self._active = []

    def __enter__(self):
        ctx = mesh(self.devices, self.axis_names)
        value = ctx.__enter__()
        self._active.append(ctx)
        return value

    def __exit__(self, *exc_info):
        return self._active.pop().__exit__(*exc_info)

    def __repr__(self):
        return f"Mesh(shape={self.devices.shape}, axis_names={self.axis_names})"


# Create device mesh
devices = jax.devices()
n_devices = len(devices)
print(f"Available devices: {n_devices}")

# Create 1D mesh for simple sharding (axis 'x' spans every device)
mesh_1d = _ReusableMesh(devices, ('x',))
print(f"1D Mesh: {mesh_1d}")

# Create 2D mesh if we have enough devices
if n_devices >= 4:
    devices_2d = np.array(devices[:4]).reshape(2, 2)
    mesh_2d = _ReusableMesh(devices_2d, ('data', 'model'))
    print(f"2D Mesh shape: {devices_2d.shape}")
else:
    # Fallback: 1D mesh over all devices (only the 'data' axis exists)
    mesh_2d = _ReusableMesh(devices, ('data',))
    print("Using 1D mesh as fallback")
```

## Basic pjit with Sharding Specifications

```python
def simple_computation(x, y):
    """Return ``jnp.dot(x, y)`` -- the matrix product for 2-D inputs."""
    product = jnp.dot(x, y)
    return product

# Sharding plan for `simple_computation` on the 1-D mesh axis 'x':
#   x      -> rows split across devices
#   y      -> replicated on every device
#   result -> rows split across devices (same layout as x)
with mesh_1d:
    x = jax.random.normal(jax.random.PRNGKey(42), (n_devices * 4, 64))
    y = jax.random.normal(jax.random.PRNGKey(43), (64, 32))

    pjit_simple = pjit.pjit(
        simple_computation,
        in_axis_resources=(P('x'), P()),
        out_axis_resources=P('x'),
    )
    result = pjit_simple(x, y)

    print(f"Input x shape: {x.shape}")
    print(f"Input y shape: {y.shape}")
    print(f"Output shape: {result.shape}")
    print(f"Output sharding: {result.sharding}")
```

## Matrix Operations with Different Sharding Strategies

```python
def matrix_multiply_variants():
    """Run one matmul under three sharding layouts and check they agree.

    Layouts over the 1-D mesh axis 'x':
      * row-sharded: A split by rows, B replicated, result split by rows;
      * column-sharded: A replicated, B split by columns, result split by columns;
      * replicated: every operand and the result on every device.
    """
    lhs = jax.random.normal(jax.random.PRNGKey(0), (128, 64))
    rhs = jax.random.normal(jax.random.PRNGKey(1), (64, 96))

    def matmul(a, b):
        return jnp.dot(a, b)

    # (printed label, (lhs spec, rhs spec), output spec)
    layouts = (
        ("Row-sharded", (P('x'), P()), P('x')),
        ("Column-sharded", (P(), P(None, 'x')), P(None, 'x')),
        ("Replicated", (P(), P()), P()),
    )

    with mesh_1d:
        outputs = []
        for label, in_specs, out_spec in layouts:
            sharded_matmul = pjit.pjit(
                matmul,
                in_axis_resources=in_specs,
                out_axis_resources=out_spec,
            )
            product = sharded_matmul(lhs, rhs)
            print(f"{label} result shape: {product.shape}")
            outputs.append(product)

        row_out, col_out, rep_out = outputs
        # Every layout computes the same mathematical product.
        print(f"Results match: {jnp.allclose(row_out, col_out) and jnp.allclose(col_out, rep_out)}")


matrix_multiply_variants()
```

## Neural Network Layer with Model Parallelism

```python
def create_parallel_linear_layer(input_size, output_size, mesh_context):
    """Create a model-parallel linear layer ``y = x @ w + b``.

    The weight matrix and bias are split along the output-feature dimension
    over mesh axis ``'x'``. The input is replicated, and the output stays split
    along its feature dimension.

    Args:
        input_size: Number of input features.
        output_size: Number of output features. Legacy pjit rejects uneven
            shards, so this must be divisible by the size of mesh axis ``'x'``.
        mesh_context: Re-enterable mesh context manager defining axis ``'x'``.

    Returns:
        ``(init_params, forward)``:
          * ``init_params(key)`` returns ``{'w': (input_size, output_size),
            'b': (output_size,)}``;
          * ``forward(params, x)`` runs the sharded layer.
    """

    def init_params(key):
        # Only the weight key is consumed; the bias starts at zero.
        w_key, _ = jax.random.split(key)
        w = jax.random.normal(w_key, (input_size, output_size)) * 0.1
        b = jnp.zeros(output_size)
        return {'w': w, 'b': b}

    def linear_forward(params, x):
        return jnp.dot(x, params['w']) + params['b']

    # Model parallel: split weights along output dimension
    pjit_linear = pjit.pjit(
        linear_forward,
        in_axis_resources=({
            'w': P(None, 'x'),  # Shard weights along output dim
            'b': P('x'),        # Shard bias along output dim
        }, P()),                # Replicate input
        out_axis_resources=P(None, 'x'),  # Output sharded along feature dim
    )

    def forward(params, x):
        # Legacy pjit looks up the mesh when the function is *called*, not
        # when it is created, so the mesh must be active around every call.
        with mesh_context:
            return pjit_linear(params, x)

    return init_params, forward

# Create parallel linear layer
init_fn, parallel_linear = create_parallel_linear_layer(256, 512, mesh_1d)

# Initialize parameters
params = init_fn(jax.random.PRNGKey(10))
print(f"Weight shape: {params['w'].shape}")
print(f"Bias shape: {params['b'].shape}")

# Test forward pass
x = jax.random.normal(jax.random.PRNGKey(11), (32, 256))
output = parallel_linear(params, x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
```

## Multi-Layer Network with Mixed Parallelism

```python
def create_mlp_with_sharding(layer_sizes, mesh_context):
    """Create an MLP whose layers use different sharding strategies.

    Parameter layout over mesh axis ``'x'`` (column- then row-parallel, the
    usual tensor-parallel MLP pairing):

    * first layer: column-parallel -- ``w`` split along its output dimension,
      ``b`` split the same way;
    * middle layers (networks with 3+ layers): parameters replicated;
    * last layer (networks with 2+ layers): row-parallel -- ``w`` split along
      its input dimension, ``b`` replicated.

    The input and the final output are replicated, so the output width does
    not need to be divisible by the number of devices.

    Args:
        layer_sizes: Feature sizes ``[in, hidden..., out]``; at least two
            entries. Legacy pjit rejects uneven shards, so ``layer_sizes[1]``
            (and, for multi-layer nets, ``layer_sizes[-2]``) must be divisible
            by the size of mesh axis ``'x'``.
        mesh_context: Re-enterable mesh context manager defining axis ``'x'``.

    Returns:
        ``(init_mlp, forward)``:
          * ``init_mlp(key)`` returns a list of ``{'w': (in, out), 'b': (out,)}``
            dicts;
          * ``forward(params, x)`` applies ReLU after every layer except the
            last.

    Raises:
        ValueError: if fewer than two layer sizes are given.
    """
    if len(layer_sizes) < 2:
        raise ValueError(
            f"layer_sizes needs at least 2 entries, got {list(layer_sizes)}")
    n_layers = len(layer_sizes) - 1

    def init_mlp(key):
        keys = jax.random.split(key, n_layers)
        params = []
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            # Only the weight key is consumed; biases start at zero.
            w_key, _ = jax.random.split(keys[i])
            params.append({
                'w': jax.random.normal(w_key, (in_size, out_size)) * 0.1,
                'b': jnp.zeros(out_size),
            })
        return params

    def mlp_forward(params, x):
        *hidden_layers, output_layer = params
        for layer in hidden_layers:
            x = jax.nn.relu(jnp.dot(x, layer['w']) + layer['b'])
        # The output layer stays linear (no activation), whatever the depth.
        return jnp.dot(x, output_layer['w']) + output_layer['b']

    column_parallel = {'w': P(None, 'x'), 'b': P('x')}
    replicated = {'w': P(), 'b': P()}
    row_parallel = {'w': P('x', None), 'b': P()}
    if n_layers == 1:
        param_sharding = [column_parallel]
    else:
        param_sharding = (
            [column_parallel] + [replicated] * (n_layers - 2) + [row_parallel]
        )

    pjit_mlp = pjit.pjit(
        mlp_forward,
        in_axis_resources=(param_sharding, P()),  # Params sharded, input replicated
        out_axis_resources=P(),                   # Replicated output (e.g. logits)
    )

    def forward(params, x):
        # Legacy pjit resolves the mesh at call time, so enter it per call.
        with mesh_context:
            return pjit_mlp(params, x)

    return init_mlp, forward

# Create MLP with mixed parallelism
layer_sizes = [784, 512, 256, 10]
init_mlp, parallel_mlp = create_mlp_with_sharding(layer_sizes, mesh_1d)

# Initialize and test
mlp_params = init_mlp(jax.random.PRNGKey(20))
test_input = jax.random.normal(jax.random.PRNGKey(21), (16, 784))

mlp_output = parallel_mlp(mlp_params, test_input)
print(f"MLP input shape: {test_input.shape}")
print(f"MLP output shape: {mlp_output.shape}")
```

## Advanced Sharding: 2D Parallelism

```python
# The 2D example needs the 2x2 ('data', 'model') mesh, i.e. at least 4 devices.
if n_devices >= 4:
    with mesh_2d:
        # Larger operands so every shard holds a meaningful tile.
        A_large = jax.random.normal(jax.random.PRNGKey(30), (256, 256))
        B_large = jax.random.normal(jax.random.PRNGKey(31), (256, 128))

        def advanced_matmul(A, B):
            return jnp.dot(A, B)

        # A: rows over 'data', columns over 'model'.
        # B: rows over 'model' (matching A's contraction dim), columns replicated.
        # Result: rows over 'data', columns replicated.
        pjit_2d = pjit.pjit(
            advanced_matmul,
            in_axis_resources=(P('data', 'model'), P('model', None)),
            out_axis_resources=P('data', None),
        )
        result_2d = pjit_2d(A_large, B_large)
        print(f"2D sharded result shape: {result_2d.shape}")

        # Cross-check against the plain, unsharded product.
        sequential_result = jnp.dot(A_large, B_large)
        print(f"Results match: {jnp.allclose(result_2d, sequential_result, rtol=1e-5)}")
```

## Training Step with Gradient Sharding

```python
def create_sharded_training_step(mesh_context):
    """Create a sharded SGD step for a linear model with squared-error loss.

    The weight matrix is split along its first (input-feature) dimension over
    mesh axis ``'x'``. The bias, inputs, targets and learning rate are
    replicated. Updated parameters come back with the same layout, and the
    scalar loss is replicated.

    No gradient accumulation is performed: each call computes gradients on one
    batch and applies them immediately.

    Args:
        mesh_context: Re-enterable mesh context manager defining axis ``'x'``.

    Returns:
        ``train_step(params, x, y, lr) -> (new_params, loss)``.
    """

    def loss_fn(params, x, y):
        pred = jnp.dot(x, params['w']) + params['b']
        return jnp.mean((pred - y) ** 2)

    def training_step(params, x, y, lr):
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y)

        # Plain SGD update
        new_params = {
            'w': params['w'] - lr * grads['w'],
            'b': params['b'] - lr * grads['b']
        }

        return new_params, loss

    # Shard parameters (and therefore their gradients/updates)
    param_spec = {'w': P('x', None), 'b': P()}

    pjit_train_step = pjit.pjit(
        training_step,
        in_axis_resources=(param_spec, P(), P(), P()),  # params, x, y, lr
        out_axis_resources=(param_spec, P())            # new_params, loss
    )

    def train_step(params, x, y, lr):
        # Legacy pjit resolves the mesh at call time, so enter it per call.
        with mesh_context:
            return pjit_train_step(params, x, y, lr)

    return train_step

# Create sharded training step
sharded_train_step = create_sharded_training_step(mesh_1d)

# Initialize training data
train_params = {
    'w': jax.random.normal(jax.random.PRNGKey(40), (256, 64)),
    'b': jnp.zeros(64)
}
train_x = jax.random.normal(jax.random.PRNGKey(41), (32, 256))
train_y = jax.random.normal(jax.random.PRNGKey(42), (32, 64))
learning_rate = 0.01

# Perform training step
new_params, loss = sharded_train_step(train_params, train_x, train_y, learning_rate)
print(f"Training loss: {loss}")
print(f"Parameter update successful: {not jnp.allclose(train_params['w'], new_params['w'])}")
```

## Dynamic Sharding and Resharding

```python
def demonstrate_resharding():
    """Show pjit laying out the same array in different shardings.

    ``x`` is created without an explicit sharding. Each pjit call below lays
    it out according to its ``in_axis_resources``:
      * split by rows;
      * split by columns;
      * converted from row-split to column-split by an identity function.

    The resulting sharding is printed rather than assumed.
    """
    with mesh_1d:
        # Create data without an explicit sharding; each pjit call below lays
        # it out according to its in_axis_resources.
        x = jax.random.normal(jax.random.PRNGKey(50), (n_devices * 8, 64))

        # Function that expects row-sharded input
        pjit_row_op = pjit.pjit(
            lambda x: jnp.sum(x, axis=1),
            in_axis_resources=P('x'),
            out_axis_resources=P('x')
        )

        # Function that expects column-sharded input (requires resharding)
        pjit_col_op = pjit.pjit(
            lambda x: jnp.sum(x, axis=0),
            in_axis_resources=P(None, 'x'),
            out_axis_resources=P('x')
        )

        # Apply row operation
        result1 = pjit_row_op(x)
        print(f"Row operation result shape: {result1.shape}")

        # Apply column operation (will trigger automatic resharding)
        result2 = pjit_col_op(x)
        print(f"Column operation result shape: {result2.shape}")

        # Explicit resharding using an identity pjit
        reshard_fn = pjit.pjit(
            lambda x: x,  # Identity function
            in_axis_resources=P('x'),        # Input: row-sharded
            out_axis_resources=P(None, 'x')  # Output: column-sharded
        )

        x_resharded = reshard_fn(x)
        print("Original sharding: row-wise (requested via in_axis_resources=P('x'))")
        print(f"After resharding: {x_resharded.sharding}")
        # An identity reshard moves data between devices but must not change it.
        print(f"Values unchanged: {jnp.array_equal(x, x_resharded)}")

demonstrate_resharding()
```

## Performance Analysis and Optimization

```python
import time

def benchmark_sharding_strategies(n_trials=5):
    """Time a 512x512 matmul under several sharding layouts on ``mesh_1d``.

    Each strategy gets one blocked warm-up call (which also compiles), then
    ``n_trials`` timed calls. Each timed call is blocked until its result is
    ready, so asynchronous dispatch does not hide execution time.

    Args:
        n_trials: Number of timed repetitions per strategy (default 5).

    Returns:
        Dict mapping strategy name to mean wall-clock seconds per call.

    Raises:
        ValueError: if ``n_trials`` is less than 1.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")

    # Test data
    A = jax.random.normal(jax.random.PRNGKey(60), (512, 512))
    B = jax.random.normal(jax.random.PRNGKey(61), (512, 512))

    # name -> (A spec, B spec, output spec)
    strategies = {
        'replicated': (P(), P(), P()),
        'row_sharded': (P('x'), P(), P('x')),
        'col_sharded': (P(), P(None, 'x'), P(None, 'x')),
    }

    def matmul(a, b):
        return jnp.dot(a, b)

    results = {}

    with mesh_1d:
        for name, (in1_spec, in2_spec, out_spec) in strategies.items():
            pjit_fn = pjit.pjit(
                matmul,
                in_axis_resources=(in1_spec, in2_spec),
                out_axis_resources=out_spec
            )

            # Warm-up (compiles). Block so its asynchronous execution does not
            # spill into the timed region below.
            pjit_fn(A, B).block_until_ready()

            # perf_counter is the monotonic, high-resolution benchmark clock.
            start_time = time.perf_counter()
            for _ in range(n_trials):
                pjit_fn(A, B).block_until_ready()  # Ensure completion
            avg_time = (time.perf_counter() - start_time) / n_trials
            results[name] = avg_time

            print(f"{name:15}: {avg_time:.4f}s")

    return results

performance_results = benchmark_sharding_strategies()

# Find best strategy
best_strategy = min(performance_results.items(), key=lambda x: x[1])
print(f"\nBest strategy: {best_strategy[0]} ({best_strategy[1]:.4f}s)")
```

## Memory-Efficient Large Model Patterns

```python
def create_memory_efficient_layer(input_size, output_size, mesh_context):
    """Create a linear layer ``x @ w`` with sharded weights and rematerialization.

    ``w`` is split along its first dimension over mesh axis ``'x'`` to spread
    its memory across devices. The input is replicated, and the output is split
    along its feature dimension. The matmul is wrapped in ``jax.checkpoint`` to
    illustrate rematerialization during backprop.

    Args:
        input_size: Expected ``w.shape[0]``; checked on every call.
        output_size: Expected ``w.shape[1]``; checked on every call.
        mesh_context: Re-enterable mesh context manager defining axis ``'x'``.

    Returns:
        ``layer(w, x) -> x @ w``, usable under ``jax.grad``. ``layer`` raises
        ``ValueError`` if ``w`` does not have shape
        ``(input_size, output_size)``.
    """

    @jax.checkpoint
    def rematerialized_matmul(w, x):
        # Recomputed in the backward pass instead of keeping residuals.
        return jnp.dot(x, w)

    # Shard weights to distribute memory load
    pjit_efficient = pjit.pjit(
        rematerialized_matmul,
        in_axis_resources=(P('x', None), P()),
        out_axis_resources=P(None, 'x')
    )

    def layer(w, x):
        if tuple(w.shape) != (input_size, output_size):
            raise ValueError(
                f"expected w of shape {(input_size, output_size)}, "
                f"got {tuple(w.shape)}")
        # Legacy pjit resolves the mesh at call time (also while jax.grad
        # traces this function), so enter it around every call.
        with mesh_context:
            return pjit_efficient(w, x)

    return layer

# Create memory-efficient layer
efficient_layer = create_memory_efficient_layer(1024, 2048, mesh_1d)

# Test with large tensors
large_w = jax.random.normal(jax.random.PRNGKey(70), (1024, 2048))
large_x = jax.random.normal(jax.random.PRNGKey(71), (64, 1024))

efficient_output = efficient_layer(large_w, large_x)
print(f"Efficient layer output shape: {efficient_output.shape}")

# Gradient computation test
grad_fn = jax.grad(lambda w, x: jnp.sum(efficient_layer(w, x)))
grads = grad_fn(large_w, large_x)
print(f"Gradients computed successfully: {grads.shape}")
```

## Best Practices and Common Patterns

```python
# Pattern 1: Conditional sharding based on array size
def _sum_of_squares(arr):
    """Return ``sum(arr ** 2)`` as a scalar."""
    return jnp.sum(arr ** 2)


# Built once: wrapping a fresh lambda in pjit on every call would miss pjit's
# compilation cache and recompile each time.
_sharded_sum_of_squares = pjit.pjit(
    _sum_of_squares, in_axis_resources=P('x'), out_axis_resources=P())
_replicated_sum_of_squares = pjit.pjit(
    _sum_of_squares, in_axis_resources=P(), out_axis_resources=P())


def adaptive_sharding(x, threshold=1000):
    """Return ``sum(x ** 2)``, sharding ``x`` only when that is worthwhile.

    Args:
        x: Input array.
        threshold: Arrays with more than this many elements are split along
            their first dimension over ``mesh_1d``'s axis ``'x'``. The array is
            replicated instead if it is this size or smaller, is 0-d, or has a
            first dimension not divisible by ``n_devices`` (legacy pjit rejects
            uneven shards).

    Returns:
        The replicated scalar ``sum(x ** 2)``.
    """
    use_sharding = (
        x.size > threshold and x.ndim > 0 and x.shape[0] % n_devices == 0
    )
    pjit_fn = _sharded_sum_of_squares if use_sharding else _replicated_sum_of_squares
    # Legacy pjit needs the mesh active at call time, not only at definition.
    with mesh_1d:
        return pjit_fn(x)

# Test adaptive sharding
small_array = jax.random.normal(jax.random.PRNGKey(80), (100,))
large_array = jax.random.normal(jax.random.PRNGKey(81), (10000,))

result_small = adaptive_sharding(small_array)
result_large = adaptive_sharding(large_array)
print(f"Small array result: {result_small}")
print(f"Large array result: {result_large}")

# Pattern 2: Pipeline parallelism simulation
def pipeline_stages(x):
    """Run two pjit stages back to back, each with its own sharding.

    Stage 1 applies ReLU element-wise with rows split across mesh axis 'x'
    (data-parallel style). Stage 2 averages over the row axis, consuming the
    row-split activations and producing a replicated result.

    This only illustrates per-stage sharding. It is not true pipeline
    parallelism: there is no micro-batching, and both stages run on the same
    mesh.
    """
    def relu_stage(batch):
        return jax.nn.relu(batch)

    def mean_stage(batch):
        return jnp.mean(batch, axis=0)

    with mesh_1d:
        stage1 = pjit.pjit(
            relu_stage,
            in_axis_resources=P('x'),
            out_axis_resources=P('x'),
        )
        stage2 = pjit.pjit(
            mean_stage,
            in_axis_resources=P('x'),
            out_axis_resources=P(),
        )
        activations = stage1(x)
        return stage2(activations)

# Test pipeline
pipeline_input = jax.random.normal(jax.random.PRNGKey(90), (n_devices * 16, 64))
pipeline_output = pipeline_stages(pipeline_input)
print(f"Pipeline output shape: {pipeline_output.shape}")
```

## Summary

In this notebook, we explored advanced parallelism with `pjit` and sharding:

**Key Concepts:**
- `pjit` enables fine-grained control over data and computation placement
- `PartitionSpec` defines how arrays are sharded across device dimensions
- Device meshes organize devices for different parallelism strategies
- Automatic resharding handles changes in sharding patterns

**Sharding Strategies:**
- **Data Parallelism**: Shard data, replicate parameters
- **Model Parallelism**: Shard parameters, replicate data  
- **Mixed Parallelism**: Combine data and model parallelism
- **2D Parallelism**: Use multiple mesh dimensions for complex patterns

**Best Practices:**
- Choose sharding strategy based on memory constraints and computation patterns
- Use explicit resharding for better control over data movement
- Consider memory efficiency with gradient checkpointing
- Profile different strategies to find optimal performance

**Advanced Patterns:**
- Adaptive sharding based on array sizes
- Pipeline parallelism with different sharding per stage
- Memory-efficient patterns for large models
- Dynamic resharding for complex workflows

`pjit`, whose functionality has since been folded into `jax.jit`, gives fine-grained control over how arrays and computation are partitioned across devices. It scales from a single device to large clusters while keeping automatic differentiation and XLA compilation.