[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Ziaeemehr/workshop_hpcpy/blob/main/notebooks/jax/random_generators.ipynb)

# JAX Random Number Generation

This notebook covers JAX's approach to random number generation (RNG), which differs from NumPy's stateful generators. JAX random functions are pure: each one takes an explicit PRNG key, and fresh randomness comes from splitting keys rather than from hidden global state.

## Why JAX RNG is Different

JAX's RNG system is:
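- **Stateless**: No hidden global state; every random function takes an explicit key
- **Reproducible**: The same key (with the same function and shape) always gives the same result
- **Parallelizable**: Independent keys can be used safely under vectorization and parallelization
- **Functional**: Pure functions throughout, so random code composes with JAX transformations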
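The contrast with NumPy can be seen in a few lines. The sketch below is illustrative only: the variable names are made up for this example, and it uses NumPy's legacy global seeding purely to show hidden state at work.

```python
import numpy as np
from jax import random

# NumPy: the global generator carries hidden state, so two identical
# calls return different numbers.
np.random.seed(0)
np_first = np.random.normal(size=3)
np_second = np.random.normal(size=3)
numpy_calls_differ = not np.allclose(np_first, np_second)

# JAX: the key is the only source of randomness, so the same key
# always gives the same numbers.
demo_key = random.PRNGKey(0)
jax_first = random.normal(demo_key, shape=(3,))
jax_second = random.normal(demo_key, shape=(3,))
jax_calls_match = bool((jax_first == jax_second).all())

print("NumPy calls differ:", numpy_calls_differ)
print("JAX calls match   :", jax_calls_match)
```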

In [None]:
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
import matplotlib.pyplot as plt
from functools import partial

# Setup for Google Colab or local environment
import os
import sys

# Check if running on Google Colab
try:
    from google.colab import drive
    IN_COLAB = True
    print("Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("Running locally")

# Clone repository if on Colab and not already cloned
if IN_COLAB:
    if not os.path.exists('/content/workshop_hpcpy'):
        print("Cloning workshop_hpcpy repository...")
        os.system('git clone https://github.com/Ziaeemehr/workshop_hpcpy.git /content/workshop_hpcpy')
    
    # Change to notebook directory
    os.chdir('/content/workshop_hpcpy/notebooks/jax')
    print(f"Working directory: {os.getcwd()}")

## 1. Creating and Using PRNG Keys

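JAX random functions take a PRNG key as their first argument, and you are responsible for passing and splitting keys yourself. Because no global state is mutated, results are reproducible and safe to compute in parallel.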

In [None]:
# Create a PRNG key with a seed
key = random.PRNGKey(0)
print(f"Key: {key}")
print(f"Key shape: {key.shape}")
print(f"Key dtype: {key.dtype}")

In [None]:
# Using the same key produces the same results
key = random.PRNGKey(42)
print("First call            :", random.normal(key, shape=(3,)))
print("Second call (same key):", random.normal(key, shape=(3,)))

## 2. Splitting Keys

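To generate different random numbers, you need to split the key. `random.split` deterministically derives new, statistically independent keys from an existing one, and it is the fundamental operation in JAX's RNG system.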

In [None]:
# Split a key into two new keys
key = random.PRNGKey(0)
key1, key2 = random.split(key)

print("Original key:", key)
print("Key 1:", key1)
print("Key 2:", key2)
print("\nRandom numbers:")
print("From key1:", random.normal(key1, shape=(3,)))
print("From key2:", random.normal(key2, shape=(3,)))

In [None]:
# Split into multiple keys at once
key = random.PRNGKey(0)
keys = random.split(key, num=5)
print(f"Split into {len(keys)} keys")
print(f"Keys shape: {keys.shape}")

# Generate random numbers from each key
for i, k in enumerate(keys):
    print(f"Key {i}: {random.normal(k, shape=(2,))}")
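Splitting is itself a pure function of the parent key, which is what keeps the whole scheme reproducible. A small check, with variable names chosen only for this sketch:

```python
import jax.numpy as jnp
from jax import random

parent = random.PRNGKey(0)

# Splitting the same parent twice yields identical children
first_a, second_a = random.split(parent)
first_b, second_b = random.split(parent)
split_is_deterministic = bool(
    jnp.array_equal(first_a, first_b) and jnp.array_equal(second_a, second_b)
)

# The two children differ from each other and from the parent
children_differ = not bool(jnp.array_equal(first_a, second_a))
child_differs_from_parent = not bool(jnp.array_equal(first_a, parent))

print(split_is_deterministic, children_differ, child_differs_from_parent)
```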

## 3. Common Pattern: Sequential Key Usage

A typical pattern is to split the key, keep the first result as the new carry key, and consume the second (`subkey`) immediately. Each subkey is used exactly once.

In [None]:
# Pattern 1: Split and use
key = random.PRNGKey(0)

key, subkey = random.split(key)
x = random.normal(subkey, shape=(1000,))

key, subkey = random.split(key)
y = random.uniform(subkey, shape=(1000,))

key, subkey = random.split(key)
z = random.exponential(subkey, shape=(1000,))

print(f"x mean: {x.mean():.3f}, std: {x.std():.3f}")
print(f"y mean: {y.mean():.3f}, range: [{y.min():.3f}, {y.max():.3f}]")
print(f"z mean: {z.mean():.3f}")
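When the sequential operations happen inside a loop, an alternative to carrying the key is to derive one key per iteration with `random.fold_in`, which mixes an integer into a base key. The following is a minimal sketch of that idea; `noisy_step` and the other names are hypothetical and exist only for this example.

```python
from jax import random

base_key = random.PRNGKey(0)

def noisy_step(step_key, value):
    """Add a small Gaussian perturbation drawn from step_key."""
    return value + 0.1 * random.normal(step_key)

value = 0.0
history = []
for step in range(5):
    # Derive a per-step key from the base key and the loop index
    step_key = random.fold_in(base_key, step)
    value = noisy_step(step_key, value)
    history.append(float(value))

# fold_in is deterministic: the same (key, index) pair gives the same key
same_key_again = bool(
    (random.fold_in(base_key, 3) == random.fold_in(base_key, 3)).all()
)
print(history)
```

One design consideration: `fold_in` lets any iteration be recomputed on its own from `base_key` and the step index, whereas the carry pattern requires replaying all earlier splits.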

## 4. Random Distributions

`jax.random` provides samplers for many common probability distributions. Each takes a key and an output `shape`.

In [None]:
key = random.PRNGKey(123)

# Uniform distribution [0, 1)
key, subkey = random.split(key)
uniform_samples = random.uniform(subkey, shape=(1000,))

# Normal (Gaussian) distribution
key, subkey = random.split(key)
normal_samples = random.normal(subkey, shape=(1000,))

# Uniform integers
key, subkey = random.split(key)
randint_samples = random.randint(subkey, shape=(1000,), minval=0, maxval=10)

# Bernoulli (coin flip)
key, subkey = random.split(key)
bernoulli_samples = random.bernoulli(subkey, p=0.7, shape=(1000,))

# Exponential
key, subkey = random.split(key)
exponential_samples = random.exponential(subkey, shape=(1000,))
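The samplers above draw from standard forms of each distribution. Other parameters are usually obtained by shifting and scaling, or through arguments such as `minval` and `maxval`. A short sketch, where `mu`, `sigma`, and the sample names are chosen for illustration:

```python
from jax import random

param_key = random.PRNGKey(123)
norm_key, unif_key = random.split(param_key)

# Normal with mean mu and standard deviation sigma: shift and scale N(0, 1)
mu, sigma = 5.0, 2.0
shifted_normal = mu + sigma * random.normal(norm_key, shape=(10_000,))

# Uniform on [-1, 1) via minval/maxval
ranged_uniform = random.uniform(
    unif_key, shape=(10_000,), minval=-1.0, maxval=1.0
)

print(f"normal : mean={shifted_normal.mean():.3f}, std={shifted_normal.std():.3f}")
print(f"uniform: min={ranged_uniform.min():.3f}, max={ranged_uniform.max():.3f}")
```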

In [None]:
# Visualize distributions
fig, axes = plt.subplots(2, 3, figsize=(10, 4))
fig.suptitle('JAX Random Distributions', fontsize=16)

axes[0, 0].hist(uniform_samples, bins=30, edgecolor='black', color='skyblue', alpha=0.7)
axes[0, 0].set_title('Uniform [0, 1)')
axes[0, 0].set_xlabel('Value')
axes[0, 0].set_ylabel('Frequency')

axes[0, 1].hist(normal_samples, bins=30, edgecolor='black', color='lightcoral', alpha=0.7)
axes[0, 1].set_title('Normal (0, 1)')
axes[0, 1].set_xlabel('Value')

axes[0, 2].hist(randint_samples, bins=10, edgecolor='black', color='lightgreen', alpha=0.7)
axes[0, 2].set_title('Randint [0, 10)')
axes[0, 2].set_xlabel('Value')

axes[1, 0].hist(bernoulli_samples.astype(int), bins=2, edgecolor='black', color='gold', alpha=0.7)
axes[1, 0].set_title('Bernoulli (p=0.7)')
axes[1, 0].set_xlabel('Value')
axes[1, 0].set_ylabel('Frequency')
axes[1, 0].set_xticks([0, 1])
axes[1, 0].set_xticklabels(['0', '1'])

axes[1, 1].hist(exponential_samples, bins=30, edgecolor='black', color='orchid', alpha=0.7)
axes[1, 1].set_title('Exponential')
axes[1, 1].set_xlabel('Value')

# Additional: Categorical distribution
key, subkey = random.split(key)
categorical_samples = random.categorical(subkey, logits=jnp.array([1.0, 2.0, 3.0]), shape=(1000,))
axes[1, 2].hist(categorical_samples, bins=3, edgecolor='black', color='lightsteelblue', alpha=0.7)
axes[1, 2].set_title('Categorical')
axes[1, 2].set_xlabel('Category')
axes[1, 2].set_xticks([0, 1, 2])
axes[1, 2].set_xticklabels(['0', '1', '2'])

plt.tight_layout()
plt.show()

## 5. Random Sampling Functions

JAX provides functions for common sampling operations.

In [None]:
key = random.PRNGKey(0)

# Permutation: shuffle an array
key, subkey = random.split(key)
arr = jnp.arange(10)
shuffled = random.permutation(subkey, arr)
print(f"Original: {arr}")
print(f"Shuffled: {shuffled}")

# Choice: sample with or without replacement
key, subkey = random.split(key)
samples = random.choice(subkey, arr, shape=(5,), replace=False)
print(f"\nRandom choice (no replacement): {samples}")

key, subkey = random.split(key)
samples = random.choice(subkey, arr, shape=(5,), replace=True)
print(f"Random choice (with replacement): {samples}")
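`random.choice` also accepts a probability vector `p` for non-uniform sampling. The sketch below draws many weighted samples and compares the empirical frequencies with the requested probabilities; the array names are illustrative.

```python
import jax.numpy as jnp
from jax import random

choice_key = random.PRNGKey(0)
items = jnp.arange(4)
probs = jnp.array([0.1, 0.2, 0.3, 0.4])

# Weighted sampling with replacement
draws = random.choice(choice_key, items, shape=(10_000,), replace=True, p=probs)

# Empirical frequency of each item
freqs = jnp.bincount(draws, length=4) / draws.size
print(f"Requested: {probs}")
print(f"Observed : {freqs}")
```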

## 6. Working with JAX Transformations

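Because keys are explicit function arguments, JAX's RNG system composes with `jit`, `vmap`, and `pmap`. Pass the key in as an argument rather than closing over it, so that each call can receive a fresh key.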

In [None]:
# Example: JIT-compiled function with randomness
@partial(jax.jit, static_argnums=(2,))
def random_layer(key, x, output_dim):
    """Simple random linear transformation."""
    input_dim = x.shape[-1]
    W_key, b_key = random.split(key)
    W = random.normal(W_key, (input_dim, output_dim))
    b = random.normal(b_key, (output_dim,))
    return jnp.dot(x, W) + b

key = random.PRNGKey(0)
x = jnp.ones((5,))
key, subkey = random.split(key)
output = random_layer(subkey, x, 3)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Output: {output}")

In [None]:
# Example: vmap with random number generation
def generate_sample(key):
    """Generate a single random sample."""
    return random.normal(key, shape=(3,))

# Generate multiple samples in parallel
key = random.PRNGKey(0)
keys = random.split(key, num=5)
samples = jax.vmap(generate_sample)(keys)
print(f"Generated samples shape: {samples.shape}")
print(f"Samples:\n{samples}")
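When a vmapped function takes both a key and data, `in_axes` controls whether each batch element gets its own key or shares one. The sketch below (with a hypothetical `add_noise` helper) shows both cases: per-row keys give different noise per row, while a single broadcast key gives every row the same noise.

```python
import jax
import jax.numpy as jnp
from jax import random

def add_noise(key, row, scale):
    """Add Gaussian noise to one row."""
    return row + scale * random.normal(key, shape=row.shape)

batch = jnp.zeros((4, 3))
row_keys = random.split(random.PRNGKey(0), num=batch.shape[0])

# One key per row: map over axis 0 of both keys and data
per_row = jax.jit(jax.vmap(add_noise, in_axes=(0, 0, None)))(row_keys, batch, 0.1)

# One shared key (in_axes=None): every row receives identical noise
shared = jax.vmap(add_noise, in_axes=(None, 0, None))(row_keys[0], batch, 0.1)

print(f"Per-row keys:\n{per_row}")
print(f"Shared key:\n{shared}")
```

Sharing a key across the batch is occasionally intended, but it is more often a bug, so it helps to make the choice explicit in `in_axes`.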

## 7. Practical Example: Dropout

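Dropout zeroes each element with probability `rate` and rescales the surviving elements by `1 / (1 - rate)`, so the expected value of the output matches the input. The implementation below draws the keep mask with `random.bernoulli`.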

In [None]:
def dropout(key, x, rate=0.5, training=True):
    """
    Apply dropout to input.
    
    Args:
        key: PRNG key
        x: Input array
        rate: Dropout rate (probability of dropping)
        training: Whether in training mode
    """
    if not training or rate == 0.0:
        return x
    
    keep_prob = 1.0 - rate
    mask = random.bernoulli(key, keep_prob, x.shape)
    return jnp.where(mask, x / keep_prob, 0)

# Test dropout
key = random.PRNGKey(0)
x = jnp.ones((10,))

key, subkey = random.split(key)
x_dropped = dropout(subkey, x, rate=0.5, training=True)
print(f"Original: {x}")
print(f"After dropout (rate=0.5): {x_dropped}")
print(f"Fraction kept: {(x_dropped > 0).sum() / len(x)}")
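The rescaling by `1 / keep_prob` is what keeps the expected activation unchanged. A quick empirical check on a large input, using a standalone copy of the same logic (named `inverted_dropout` here only to keep the sketch self-contained):

```python
import jax.numpy as jnp
from jax import random

def inverted_dropout(key, x, rate):
    """Drop elements with probability `rate`, rescale the rest."""
    keep_prob = 1.0 - rate
    mask = random.bernoulli(key, keep_prob, x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)

x_big = jnp.ones((100_000,))
out = inverted_dropout(random.PRNGKey(0), x_big, rate=0.3)

mean_after = float(out.mean())           # should stay close to 1.0
kept_fraction = float((out > 0).mean())  # should be close to 0.7

print(f"Mean after dropout: {mean_after:.3f}")
print(f"Fraction kept     : {kept_fraction:.3f}")
```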

## 8. Practical Example: Mini-batch Sampling

A common use of `random.permutation` is shuffling a dataset's indices and then slicing them into mini-batches for training.

In [None]:
def get_batches(key, data, batch_size, shuffle=True):
    """
    Generate random mini-batches from data.
    
    Args:
        key: PRNG key
        data: Data array
        batch_size: Size of each batch
        shuffle: Whether to shuffle data
    """
    n = len(data)
    indices = jnp.arange(n)
    
    if shuffle:
        indices = random.permutation(key, indices)
    
    # Generate batches
    batches = []
    for i in range(0, n, batch_size):
        batch_indices = indices[i:i+batch_size]
        batches.append(data[batch_indices])
    
    return batches

# Example usage
key = random.PRNGKey(42)
data = jnp.arange(20)
batches = get_batches(key, data, batch_size=5)

print(f"Original data: {data}")
print(f"\nNumber of batches: {len(batches)}")
for i, batch in enumerate(batches):
    print(f"Batch {i}: {batch}")
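In a training loop, each epoch normally needs its own shuffle, which means its own key. One way to arrange that, sketched here with illustrative names and a toy dataset, is to split a single key into one key per epoch up front:

```python
import jax.numpy as jnp
from jax import random

toy_data = jnp.arange(12)
n_epochs, batch_size = 3, 4

# One independent key per epoch
epoch_keys = random.split(random.PRNGKey(42), num=n_epochs)

orders = []
for epoch, epoch_key in enumerate(epoch_keys):
    # Passing an integer shuffles arange(n)
    order = random.permutation(epoch_key, toy_data.shape[0])
    orders.append(order)
    epoch_batches = toy_data[order].reshape(-1, batch_size)
    print(f"Epoch {epoch}:\n{epoch_batches}")
```

Reusing the same key for every epoch would reproduce the same batch order each time, which is usually not what is wanted during training.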

## 9. Advanced: Custom Random Functions

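Custom samplers can be built on top of the built-in ones. The example below draws from a truncated normal distribution by rejection sampling, splitting off a fresh subkey for every batch of candidates.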

In [None]:
def truncated_normal(key, shape, lower=-2.0, upper=2.0):
    """
    Generate samples from a standard normal truncated to [lower, upper].
    Simple implementation using rejection sampling (eager, not jit-compatible).
    """
    n_total = int(np.prod(shape))  # total number of samples needed
    samples = []
    key_iter = key
    
    while len(samples) < n_total:
        # Fresh subkey for every batch of candidates
        key_iter, subkey = random.split(key_iter)
        candidate = random.normal(subkey, shape=(1000,))
        valid = candidate[(candidate >= lower) & (candidate <= upper)]
        # Convert once to a Python list instead of iterating element by element
        samples.extend(valid.tolist())
    
    return jnp.array(samples[:n_total]).reshape(shape)

# Test truncated normal
key = random.PRNGKey(0)
samples = truncated_normal(key, shape=(1000,), lower=-1.0, upper=1.0)

plt.figure(figsize=(10, 6))
plt.hist(samples, bins=50, density=True, alpha=0.7, edgecolor='black')
plt.axvline(-1.0, color='r', linestyle='--', label='Truncation bounds')
plt.axvline(1.0, color='r', linestyle='--')
plt.title('Truncated Normal Distribution')
plt.xlabel('Value')
plt.ylabel('Density')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print(f"Mean: {samples.mean():.3f}")
print(f"Std: {samples.std():.3f}")
print(f"Min: {samples.min():.3f}, Max: {samples.max():.3f}")
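For this particular distribution, `jax.random` also ships a built-in sampler, `random.truncated_normal`, which avoids the Python loop and works under `jit`. A brief sketch for comparison (the variable names are illustrative):

```python
from jax import random

tn_key = random.PRNGKey(0)

# Built-in truncated standard normal on [lower, upper]
tn_samples = random.truncated_normal(
    tn_key, lower=-1.0, upper=1.0, shape=(1000,)
)

print(f"Mean: {tn_samples.mean():.3f}")
print(f"Min: {tn_samples.min():.3f}, Max: {tn_samples.max():.3f}")
```

The rejection-sampling version remains useful as a template for distributions that have no built-in sampler.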

## 10. Best Practices and Common Pitfalls

### Best Practices
1. Never reuse a key: split it and use each subkey exactly once
2. Pass keys explicitly as function arguments
3. Use the pattern: `key, subkey = random.split(key)` for sequential operations
4. Generate multiple keys at once when possible: `random.split(key, num=n)`

### Common Pitfalls

In [None]:
# WRONG: Reusing the same key
key = random.PRNGKey(0)
x1 = random.normal(key, shape=(3,))
x2 = random.normal(key, shape=(3,))  # Same as x1!
print("Reusing key (WRONG):")
print(f"x1: {x1}")
print(f"x2: {x2}")
print(f"Are they equal? {jnp.allclose(x1, x2)}")

# CORRECT: Split the key
key = random.PRNGKey(0)
key, subkey1 = random.split(key)
x1 = random.normal(subkey1, shape=(3,))
key, subkey2 = random.split(key)
x2 = random.normal(subkey2, shape=(3,))
print("\nSplitting key (CORRECT):")
print(f"x1: {x1}")
print(f"x2: {x2}")
print(f"Are they equal? {jnp.allclose(x1, x2)}")
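Key reuse matters beyond producing identical printouts: quantities that are assumed independent become perfectly correlated. As a rough illustration (names are made up for this sketch), the difference of two independent standard normal draws should have variance 2, but with a reused key it is exactly zero.

```python
import jax.numpy as jnp
from jax import random

n = 10_000
base_key = random.PRNGKey(0)

# Reused key: the two "independent" draws are the same array
a_bad = random.normal(base_key, shape=(n,))
b_bad = random.normal(base_key, shape=(n,))
var_diff_bad = float(jnp.var(a_bad - b_bad))

# Split keys: the difference of two independent N(0, 1) draws has variance 2
key_a, key_b = random.split(base_key)
a_ok = random.normal(key_a, shape=(n,))
b_ok = random.normal(key_b, shape=(n,))
var_diff_ok = float(jnp.var(a_ok - b_ok))

print(f"Variance of difference, reused key: {var_diff_bad:.3f}")
print(f"Variance of difference, split keys: {var_diff_ok:.3f}")
```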

## 11. Exercise: Monte Carlo π Estimation

Estimate π using Monte Carlo sampling with JAX's random number generation.

In [None]:
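def estimate_pi(key, n_samples):
    """
    Estimate π using the Monte Carlo method.
    Generate random points in [0,1]x[0,1] and count how many fall inside the
    quarter of the unit circle that lies in that square. The fraction inside
    approximates π/4, hence the factor of 4.
    """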
    key_x, key_y = random.split(key)
    x = random.uniform(key_x, shape=(n_samples,))
    y = random.uniform(key_y, shape=(n_samples,))
    
    # Check if points are inside the unit circle (the quarter disc within the unit square)
    inside_circle = (x**2 + y**2) <= 1.0
    pi_estimate = 4.0 * inside_circle.sum() / n_samples
    
    return pi_estimate, x, y, inside_circle

# Run estimation
key = random.PRNGKey(42)
n_samples = 10000
pi_est, x, y, inside = estimate_pi(key, n_samples)

print(f"Estimated π: {pi_est:.6f}")
print(f"Actual π: {jnp.pi:.6f}")
print(f"Error: {abs(pi_est - jnp.pi):.6f}")

# Visualize
plt.figure(figsize=(5, 5))
plt.scatter(x[inside], y[inside], c='red', s=1, alpha=0.5, label='Inside circle')
plt.scatter(x[~inside], y[~inside], c='blue', s=1, alpha=0.5, label='Outside circle')
circle = plt.Circle((0, 0), 1, fill=False, color='black', linewidth=2)
plt.gca().add_patch(circle)
plt.xlim(0, 1)
plt.ylim(0, 1)
plt.gca().set_aspect('equal')
plt.title(f'Monte Carlo π Estimation: {pi_est:.6f}')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
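A single run gives one estimate with no sense of its uncertainty. One possible extension of the exercise, sketched below with a hypothetical `pi_once` helper, is to `vmap` the estimator over many independent keys and look at the spread. For 10,000 points per trial, the standard deviation of a single estimate should be about 0.016, since each point is a Bernoulli trial with success probability π/4.

```python
import jax
from jax import random

def pi_once(key, n_samples):
    """One Monte Carlo estimate of pi from n_samples points in the unit square."""
    pts = random.uniform(key, shape=(n_samples, 2))
    inside = (pts ** 2).sum(axis=1) <= 1.0
    return 4.0 * inside.mean()

# 100 independent trials, one key each
trial_keys = random.split(random.PRNGKey(42), num=100)
estimates = jax.vmap(lambda k: pi_once(k, 10_000))(trial_keys)

pi_mean = float(estimates.mean())
pi_spread = float(estimates.std())

print(f"Mean of estimates  : {pi_mean:.5f}")
print(f"Std. dev. per trial: {pi_spread:.5f}")
```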

## Summary

Key takeaways:
1. JAX uses **stateless, functional RNG** based on PRNG keys
2. Always **split keys** before generating new random numbers
3. Use pattern: `key, subkey = random.split(key)` for sequential operations
4. JAX RNG is **deterministic** given the same key
5. Works seamlessly with **JIT, vmap, and pmap**
6. Provides many **distributions and sampling functions**

This approach ensures reproducibility, parallelizability, and integration with JAX's transformation system.