# File: notebooks/01_fundamentals/01_arrays_and_prng.ipynb

## JAX Fundamentals: Arrays and PRNG

Welcome to the first notebook in our JAX Neural Science Library (JAX-NSL) series! In this notebook, we'll explore the fundamental building blocks of JAX: arrays and pseudo-random number generation (PRNG). These concepts form the foundation for all scientific computing and machine learning tasks in JAX.

JAX arrays are the core data structure, similar to NumPy arrays but with additional capabilities for automatic differentiation, just-in-time compilation, and parallel execution. JAX's PRNG system provides reproducible randomness that's essential for scientific computing and machine learning.

## Setting Up the Environment

```python
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Any

# Opt in to 64-bit precision via the "jax_enable_x64" flag; the notebook
# relies on it for numerical stability
jax.config.update("jax_enable_x64", True)

# Report the installed JAX version and the devices it can run on
print(
    f"JAX version: {jax.__version__}",
    f"Available devices: {jax.devices()}",
    sep="\n",
)
```

## JAX Arrays Fundamentals

### Array Creation and Basic Operations

```python
# Three ways to build a 1-D array; the constructors mirror NumPy's, but the
# resulting arrays are immutable
arr1 = jnp.array([1, 2, 3, 4, 5])
arr2 = jnp.arange(0, 10, step=2)  # [0, 2, 4, 6, 8]
arr3 = jnp.linspace(0, 1, num=5)  # [0, 0.25, 0.5, 0.75, 1]

print(
    f"arr1: {arr1}",
    f"arr2: {arr2}",
    f"arr3: {arr3}",
    sep="\n",
)

# Metadata attributes exposed by every array
print(
    f"\nArray shape: {arr1.shape}",
    f"Array dtype: {arr1.dtype}",
    f"Array size: {arr1.size}",
    f"Array ndim: {arr1.ndim}",
    sep="\n",
)
```

### Multi-dimensional Arrays

```python
# A 2-D array built from nested Python lists, one inner list per row
matrix = jnp.array([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
])

# Constructors that fill an array of a requested shape
zeros_mat = jnp.zeros(shape=(3, 3))
ones_mat = jnp.ones(shape=(2, 4))
identity_mat = jnp.eye(3)
full_mat = jnp.full(shape=(2, 3), fill_value=7)

print(
    f"Matrix:\n{matrix}",
    f"Zeros matrix:\n{zeros_mat}",
    f"Identity matrix:\n{identity_mat}",
    sep="\n",
)
```

### Array Operations and Broadcasting

```python
# Arithmetic operators act element by element on equally shaped arrays
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

add_result = a + b
mult_result = a * b
power_result = a ** 2

print(
    f"Addition: {add_result}",
    f"Multiplication: {mult_result}",
    f"Power: {power_result}",
    sep="\n",
)

# A scalar operand is broadcast across every element of the array
scalar = 5
broadcast_add = a + scalar
print(f"Broadcasting with scalar: {broadcast_add}")

# True matrix product of two 2x2 matrices (not element-wise multiplication)
mat_a = jnp.array([[1, 2], [3, 4]])
mat_b = jnp.array([[5, 6], [7, 8]])
mat_mult = mat_a @ mat_b
print(f"Matrix multiplication:\n{mat_mult}")
```

## JAX PRNG System

### Understanding JAX PRNG Keys

```python
# Randomness in JAX is driven by an explicit key value that is passed to
# every sampling function, which makes results reproducible
key = random.PRNGKey(42)  # Seed with 42
print(
    f"PRNG Key: {key}",
    f"Key shape: {key.shape}",
    f"Key dtype: {key.dtype}",
    sep="\n",
)

# Draw five standard-normal samples from the key
random_normal = random.normal(key, (5,))
print(f"Random normal samples: {random_normal}")
```

### Key Splitting and Management

```python
# Splitting is how new, independent keys are derived from an existing one
key = random.PRNGKey(0)

# A default split yields two keys: the first is carried forward as the new
# `key`, the second is spent on a single sampling call
key, subkey1 = random.split(key)
key, subkey2 = random.split(key)

# Different subkeys give different draws
sample1 = random.normal(subkey1, (3,))
sample2 = random.normal(subkey2, (3,))

print(
    f"Sample 1: {sample1}",
    f"Sample 2: {sample2}",
    sep="\n",
)

# Request several subkeys in one call; they come back stacked in one array
key = random.PRNGKey(123)
subkeys = random.split(key, 4)
print(f"Multiple subkeys shape: {subkeys.shape}")
```

### Random Sampling Patterns

```python
# Various random distributions
key = random.PRNGKey(42)

# A key must never be used for sampling and then split again: that reuses the
# same random bits. Every sampler below therefore receives its own freshly
# split subkey, and `key` itself is only ever passed to `random.split`.

# Normal distribution
key, subkey = random.split(key)
normal_samples = random.normal(subkey, shape=(1000,))

# Uniform distribution on [-1, 1)
key, subkey = random.split(key)
uniform_samples = random.uniform(subkey, shape=(1000,), minval=-1, maxval=1)

# Categorical sampling (the sampler takes log-probabilities, hence the log)
key, subkey = random.split(key)
class_probs = jnp.array([0.3, 0.4, 0.3])
categorical_samples = random.categorical(subkey, jnp.log(class_probs), shape=(100,))

print(f"Normal samples stats: mean={jnp.mean(normal_samples):.3f}, std={jnp.std(normal_samples):.3f}")
print(f"Uniform samples range: [{jnp.min(uniform_samples):.3f}, {jnp.max(uniform_samples):.3f}]")
# minlength keeps one histogram bin per class even if the last class is never drawn
print(f"Categorical samples: {jnp.bincount(categorical_samples, minlength=class_probs.size)}")
```

## Advanced Array Manipulations

### Indexing and Slicing

```python
# A 4x5 grid holding 0..19, used for every indexing example below
arr = jnp.reshape(jnp.arange(20), (4, 5))
print(f"Original array:\n{arr}")

# Slices pick out rows, columns, and rectangular sub-blocks
print(
    f"First row: {arr[0, :]}",
    f"Last column: {arr[:, -1]}",
    f"Submatrix: {arr[1:3, 2:4]}",
    sep="\n",
)

# A boolean mask keeps only the elements where the condition holds,
# returning them as a flat 1-D array
mask = arr > 10
filtered = arr[mask]
print(f"Elements > 10: {filtered}")

# Indexing with an integer array gathers whole rows in the listed order
indices = jnp.array([0, 2, 3])
selected_rows = arr[indices]
print(f"Selected rows:\n{selected_rows}")
```

### Array Transformations

```python
# Reinterpret 12 values as a 3x4 grid, then swap its axes
original = jnp.arange(12)
reshaped = jnp.reshape(original, (3, 4))
transposed = jnp.transpose(reshaped)

print(
    f"Original: {original}",
    f"Reshaped:\n{reshaped}",
    f"Transposed:\n{transposed}",
    sep="\n",
)

# concatenate joins along an existing axis; stack adds a new leading axis
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])

pair = [a, b]
concatenated = jnp.concatenate(pair, axis=0)
stacked = jnp.stack(pair, axis=0)

print(
    f"Concatenated:\n{concatenated}",
    f"Stacked shape: {stacked.shape}",
    sep="\n",
)
```

## Practical Examples

### Monte Carlo Integration

```python
def monte_carlo_pi_estimation(key: jnp.ndarray, n_samples: int) -> jnp.ndarray:
    """Estimate π by sampling points uniformly in the square [-1, 1] x [-1, 1].

    The fraction of points that land inside the unit circle approximates the
    area ratio π / 4, so the estimate is four times that fraction.

    Args:
        key: PRNG key (e.g. from ``random.PRNGKey`` or ``random.split``). It is
            consumed by this call and should not be reused afterwards.
        n_samples: Number of random points to draw; must be positive.

    Returns:
        A scalar JAX array holding the estimate of π.

    Raises:
        ValueError: If ``n_samples`` is not positive. Averaging over zero
            points would otherwise silently produce NaN.
    """
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")

    # Each row is one (x, y) point in the square that encloses the unit circle
    points = random.uniform(key, shape=(n_samples, 2), minval=-1, maxval=1)

    # A point is inside the circle when x^2 + y^2 <= 1 (no square root needed)
    distances_squared = jnp.sum(points**2, axis=1)
    inside_circle = distances_squared <= 1.0

    # The mean of the boolean mask is the fraction of hits
    return 4.0 * jnp.mean(inside_circle)

# Estimate π at increasing sample counts, spending a fresh subkey on each run
# so that the runs do not share random numbers
key = random.PRNGKey(42)
sample_sizes = [100, 1000, 10000, 100000]

for n in sample_sizes:
    key, subkey = random.split(key)
    pi_est = monte_carlo_pi_estimation(subkey, n)
    # Absolute deviation from the reference value of π
    error = abs(jnp.pi - pi_est)
    print(f"Samples: {n:6d}, π estimate: {pi_est:.6f}, Error: {error:.6f}")
```

### Random Matrix Generation

```python
def generate_random_matrices(key: jnp.ndarray,
                             shape: Tuple[int, int],
                             matrix_type: str = "normal") -> jnp.ndarray:
    """Generate a random matrix with the requested structure.

    Every type starts from one Gaussian draw ``A`` of the given ``shape``:

    * ``"normal"``: ``A`` itself.
    * ``"orthogonal"``: the Q factor of ``A``'s QR decomposition, with each
      column's sign flipped so that R's diagonal is positive. Without that
      correction the result depends on the sign convention of the QR routine
      and is not uniformly distributed over orthogonal matrices.
    * ``"symmetric"``: ``(A + A.T) / 2``; requires a square ``shape``.
    * ``"positive_definite"``: ``A @ A.T``. This is always positive
      semi-definite, and positive definite when ``A`` has full row rank.

    Args:
        key: PRNG key consumed by this call.
        shape: ``(rows, cols)`` of the underlying Gaussian draw.
        matrix_type: One of the four type names listed above.

    Returns:
        The generated matrix as a JAX array.

    Raises:
        ValueError: If ``matrix_type`` is unknown, or if a symmetric matrix is
            requested with a non-square ``shape``.
    """
    if matrix_type not in ("normal", "orthogonal", "symmetric", "positive_definite"):
        raise ValueError(f"Unknown matrix type: {matrix_type}")

    # Fail with a clear message instead of an opaque broadcasting error
    if matrix_type == "symmetric" and shape[0] != shape[1]:
        raise ValueError(f"symmetric matrices must be square, got shape {shape}")

    A = random.normal(key, shape)

    if matrix_type == "normal":
        return A

    if matrix_type == "orthogonal":
        Q, R = jnp.linalg.qr(A)
        # Scaling column j of Q by sign(R[j, j]) keeps Q orthogonal while
        # making the factorization unique; where() avoids a zero multiplier.
        signs = jnp.where(jnp.diag(R) < 0, -1.0, 1.0)
        return Q * signs

    if matrix_type == "symmetric":
        return (A + A.T) / 2

    # Only "positive_definite" remains
    return jnp.dot(A, A.T)

# Build one 4x4 matrix of each supported type, each from its own subkey
key = random.PRNGKey(123)
shape = (4, 4)

for mat_type in ("normal", "orthogonal", "symmetric", "positive_definite"):
    key, subkey = random.split(key)
    mat = generate_random_matrices(subkey, shape, mat_type)

    print(f"\n{mat_type.capitalize()} Matrix:")
    print(f"Condition number: {jnp.linalg.cond(mat):.2f}")

    # The extra check below only applies to the orthogonal case
    if mat_type != "orthogonal":
        continue

    # For an orthogonal matrix, mat @ mat.T equals the identity; report the
    # largest absolute deviation from it
    should_be_identity = mat @ mat.T
    orthogonality_error = jnp.max(jnp.abs(should_be_identity - jnp.eye(shape[0])))
    print(f"Orthogonality error: {orthogonality_error:.2e}")
```

### Array Statistics and Analysis

```python
# Generate sample data: 1000 rows (observations) by 5 columns (features)
key = random.PRNGKey(0)
data = random.normal(key, shape=(1000, 5))
n_features = data.shape[1]

# Basic statistics, reduced over the rows so there is one value per column
print("Data Statistics:")
print(f"Shape: {data.shape}")
print(f"Mean: {jnp.mean(data, axis=0)}")
print(f"Std: {jnp.std(data, axis=0)}")
print(f"Min: {jnp.min(data, axis=0)}")
print(f"Max: {jnp.max(data, axis=0)}")

# Correlation matrix (corrcoef treats each row as a variable, hence the transpose)
correlation_matrix = jnp.corrcoef(data.T)

# Drop the diagonal (each column's correlation with itself) once, instead of
# rebuilding the mask for every statistic, and size it from the data rather
# than a hard-coded column count
off_diagonal = correlation_matrix[~jnp.eye(n_features, dtype=bool)]

print(f"\nCorrelation matrix shape: {correlation_matrix.shape}")
print(f"Off-diagonal correlations range: [{jnp.min(off_diagonal):.3f}, {jnp.max(off_diagonal):.3f}]")
```

## Performance Considerations

### Memory Layout and Efficiency

```python
import time

# Compare different array creation methods
def benchmark_array_ops():
    """Time three ways of creating a length-10000 array and print the results.

    JAX dispatches work asynchronously: a call can return before the array has
    actually been computed. Each timed call is therefore followed by
    ``block_until_ready()``; without it the timer would only measure how long
    it takes to enqueue the operation.

    Every measurement is a single cold run, so it also includes any one-time
    tracing/compilation overhead of the operation being timed.
    """
    key = random.PRNGKey(42)
    n = 10000

    def timed(make_array):
        """Return the seconds needed to build and fully materialize an array."""
        # perf_counter is a monotonic, high-resolution clock meant for intervals
        start = time.perf_counter()
        make_array().block_until_ready()
        return time.perf_counter() - start

    # Method 1: Direct creation
    time1 = timed(lambda: jnp.arange(n))

    # Method 2: From list (includes building the Python list itself)
    time2 = timed(lambda: jnp.array(list(range(n))))

    # Method 3: Random generation
    time3 = timed(lambda: random.normal(key, shape=(n,)))

    print(f"Direct arange: {time1:.4f}s")
    print(f"From list: {time2:.4f}s")
    print(f"Random generation: {time3:.4f}s")

benchmark_array_ops()
```

### PRNG Key Management Best Practices

```python
class PRNGManager:
    """Stateful convenience wrapper that hands out fresh PRNG subkeys.

    The manager holds one internal key in ``self.key`` and replaces it with a
    newly split key on every request, so successive requests return different
    subkeys without the caller having to thread keys around manually.
    """

    def __init__(self, seed: int = 42):
        """Create the internal key from ``seed``."""
        self.key = random.PRNGKey(seed)

    def get_key(self) -> jnp.ndarray:
        """Get a new subkey and update internal state"""
        self.key, subkey = random.split(self.key)
        return subkey

    def get_keys(self, n: int) -> jnp.ndarray:
        """Get ``n`` new subkeys at once and update internal state.

        Args:
            n: Number of subkeys to return; must be non-negative.

        Returns:
            An array of ``n`` keys stacked along the first axis. For ``n == 0``
            this is an empty batch that still has the dtype and trailing shape
            of a key.

        Raises:
            ValueError: If ``n`` is negative.
        """
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")
        # split already returns the n + 1 keys stacked in one array: keep the
        # first as the new internal key and slice off the rest, rather than
        # unpacking into a Python list and re-stacking it.
        new_keys = random.split(self.key, n + 1)
        self.key = new_keys[0]
        return new_keys[1:]

# Usage example: one manager seeded once, then asked for a key per array
prng = PRNGManager(seed=123)

# Each iteration requests a fresh subkey, so no two arrays share random bits
arrays = []
for i in range(3):
    key = prng.get_key()
    arr = random.normal(key, (5,))
    print(f"Array {i}: {arr}")
    arrays.append(arr)
```

## Summary

In this notebook, we've covered the fundamental concepts of JAX arrays and PRNG:

**Key Takeaways:**

1. **JAX Arrays**: Immutable arrays with a NumPy-like interface that run on CPU, GPU, and TPU
2. **PRNG System**: Explicit key management for reproducible randomness
3. **Key Splitting**: Essential pattern for managing random state in functional programming
4. **Array Operations**: Broadcasting, indexing, and transformations work similarly to NumPy
5. **Performance**: JAX arrays are optimized for scientific computing and ML workloads

**Best Practices:**
- Never reuse a PRNG key: split it first, and pass each resulting subkey to exactly one random function
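- Use appropriate data types (float32 vs float64): JAX defaults to 32-bit precision, and 64-bit must be enabled explicitly with `jax.config.update("jax_enable_x64", True)`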
- Leverage vectorization for better performance
- Understand immutability - operations create new arrays

**Next Steps:**
- The next notebook will cover automatic differentiation basics
- We'll build upon these array fundamentals to implement gradient computation
- PRNG patterns will be used throughout for initializing neural networks

These fundamentals form the foundation for all advanced JAX operations including automatic differentiation, just-in-time compilation, and parallel execution that we'll explore in subsequent notebooks.