# 📘 Notebook 4: Vectorization with `vmap` - Automatic Batching Magic

Welcome to one of JAX's most powerful features! `vmap` (vectorizing map) automatically makes your single-example code work on batches - no loops needed!

## 🎯 What You'll Learn (25-35 minutes)

By the end of this notebook, you'll understand:
- ✅ What vectorization means (and why it's essential)
- ✅ Why loops are slow and batching is fast
- ✅ How `jax.vmap()` automatically batches operations
- ✅ Specifying which axes to batch over (`in_axes`)
- ✅ Broadcasting, nested `vmap`, and controlling output axes (`out_axes`)
- ✅ Combining `vmap` with `jit` (for speed) and `grad` (for per-sample gradients)
- ✅ Practical application: per-sample gradient computation

## 🤔 What is Vectorization?

### The Problem: Loops are Slow
Machine learning workloads process thousands or millions of examples, and looping over them one at a time in Python is painfully slow:

```python
# SLOW: Loop over 1000 examples
results = []
for i in range(1000):
    result = model(data[i])
    results.append(result)
```

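This is slow because:
- Every Python iteration adds interpreter and dispatch overhead
- The GPU only sees one small operation at a time, so its parallelism goes unused
- Each operation is launched separately instead of as one batched operation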

### The Solution: Vectorization (Batching)
Process ALL examples at once using array operations:

```python
# FAST: Process all 1000 examples simultaneously
results = model(data)  # Shape: (1000, ...)
```

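This is often **10-100x faster** (the exact gain depends on hardware, batch size, and the function) because:
- It runs as a single optimized array operation
- The GPU processes examples in parallel
- Python overhead is paid once, not once per example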
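As a minimal sketch of the difference, assume a hypothetical `model` that is just an elementwise function, so it already accepts a whole batch. Looping over rows and calling it once on the full array give the same numbers; only the number of dispatched operations differs.

```python
import jax.numpy as jnp

def model(x):
    """Hypothetical stand-in for a model: a simple elementwise function."""
    return jnp.tanh(2.0 * x + 1.0)

# 200 examples with 4 features each
data = jnp.linspace(-1.0, 1.0, 800).reshape(200, 4)

# Loop version: 200 separate calls, each on a single example
loop_results = jnp.stack([model(data[i]) for i in range(data.shape[0])])

# Batched version: one call on the whole (200, 4) array
batched_results = model(data)

print(loop_results.shape, batched_results.shape)    # (200, 4) (200, 4)
print(jnp.allclose(loop_results, batched_results))  # True
```

This toy model happens to broadcast on its own; many real single-example functions do not, which is the gap `vmap` fills.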

### JAX's `vmap`: Automatic Vectorization! 🎉

**The Problem:** You wrote a function for ONE example. Now you need it for a BATCH.

**The Old Way:** Manually rewrite with loops or complex broadcasting.

**The JAX Way:** Wrap your function with `vmap()` - automatic batching!

```python
import jax
import jax.numpy as jnp

# Function for ONE example
def predict_one(x):
    return x ** 2 + 2 * x

# Automatically works on BATCHES
predict_batch = jax.vmap(predict_one)

single_example = 3.0
batch_examples = jnp.array([1.0, 2.0, 3.0, 4.0])

print(predict_one(single_example))   # 15.0 (single result)
print(predict_batch(batch_examples)) # [ 3.  8. 15. 24.] (batch results, no loop!)
```

**Magic!** You write single-example code, and JAX makes it batch-ready!
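One way to think about what `vmap` does: `jax.vmap(f)(xs)` gives the same result as applying `f` to each row of `xs` and stacking the outputs, without the Python loop. The sketch below uses a made-up `normalize` function to show this, and also shows why simply calling a single-example function on a batch is not always safe: `jnp.linalg.norm` of a 2D array is the norm of the whole matrix, not of each row.

```python
import jax
import jax.numpy as jnp

def normalize(v):
    """Scale ONE vector to unit length."""
    return v / jnp.linalg.norm(v)

batch = jnp.array([[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]])

# Reference semantics: apply the function to each row and stack the results
looped = jnp.stack([normalize(row) for row in batch])

# vmap produces the same result without a Python loop
vmapped = jax.vmap(normalize)(batch)

# Calling the single-example function on the whole batch is NOT equivalent:
# here jnp.linalg.norm computes the Frobenius norm of the entire matrix
wrong = normalize(batch)

print(vmapped)
print(jnp.allclose(looped, vmapped))  # True
print(jnp.allclose(wrong, vmapped))   # False
```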

## 📚 Key Concepts Explained

### 1. What is a "Batch"?
**Definition:** Multiple examples processed together as one array.

**Example:** Instead of 5 separate images, one array of shape `(5, height, width, channels)`

**Why?** GPUs (and TPUs) are built for large, parallel array operations, so processing many examples at once usually gives a big speedup!
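Here is a small sketch of that image example, assuming random pixel values and a made-up per-image function `mean_color`. The function is written for a single `(height, width, channels)` image, and `vmap` applies it across the leading batch axis.

```python
import jax
import jax.numpy as jnp

def mean_color(image):
    """Average value of each channel for ONE image of shape (height, width, channels)."""
    return image.mean(axis=(0, 1))

# A batch of 5 random 28x28 RGB "images": shape (5, 28, 28, 3)
key = jax.random.PRNGKey(0)
images = jax.random.uniform(key, (5, 28, 28, 3))

per_image = jax.vmap(mean_color)(images)
print(images.shape, "->", per_image.shape)  # (5, 28, 28, 3) -> (5, 3)

# Same numbers as averaging over the height and width axes of the batched array
print(jnp.allclose(per_image, images.mean(axis=(1, 2))))  # True
```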

### 2. Batch Dimension (Axis)
**What is it?** The dimension representing different examples.

**Usually:** The first dimension (axis 0)
- Shape `(32, 10)` → 32 examples, each with 10 features
- Shape `(100, 28, 28)` → 100 images, each 28x28 pixels
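A quick way to see what `vmap` does with the batch axis: inside the mapped function, that axis is gone, so each call sees a single example. In this sketch (`summarize` is just an illustrative name), the shape check runs while `vmap` traces the function and confirms each call receives a `(10,)` vector from the `(32, 10)` batch.

```python
import jax
import jax.numpy as jnp

def summarize(features):
    # Inside vmap the batch axis is removed: each call sees ONE example of shape (10,)
    assert features.shape == (10,)
    return jnp.sum(features ** 2)

batch = jnp.ones((32, 10))  # 32 examples, 10 features each
out = jax.vmap(summarize)(batch)
print(out.shape)  # (32,)
```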

### 3. Broadcasting
**What is it?** Automatically expanding arrays to match shapes.

**Example:** Adding scalar to array
```python
import jax.numpy as jnp

array = jnp.array([1, 2, 3])
result = array + 10  # Broadcasting! → [11, 12, 13]
```
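Broadcasting and `vmap` overlap for simple elementwise code, but they differ when a function has its own shape rules. For example, `jnp.dot` on two 2D arrays is a matrix product, not a row-by-row dot product. The sketch below (with made-up vectors) computes row-wise dot products with `vmap` and, for comparison, with hand-written broadcasting.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 vectors
b = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # 3 vectors

# jnp.dot on 2D arrays is a matrix product: all 3x3 pairings of rows
all_pairs = jnp.dot(a, b.T)

# Row-wise dot products via vmap: row i of a paired with row i of b
rowwise = jax.vmap(jnp.dot)(a, b)

# The same result via broadcasting, written by hand
rowwise_broadcast = jnp.sum(a * b, axis=1)

print(all_pairs.shape)  # (3, 3)
print(rowwise)          # [ 1.  4. 11.]
```

With `vmap` you keep calling `jnp.dot` on single vectors and let JAX handle the batch axis, instead of working out the equivalent broadcasting expression yourself.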

### 4. `in_axes` Parameter
**What is it?** Tells `vmap` which axis is the batch dimension.

**Common values:**
- `in_axes=0` → First axis is batch (default)
- `in_axes=1` → Second axis is batch
- `in_axes=None` → Don't batch this argument (same for all examples)

**Example:**
```python
import jax
import jax.numpy as jnp

def dot_product(x, weights):
    return jnp.dot(x, weights)

# Batch over x (different for each example); weights are shared by all
batch_dot = jax.vmap(dot_product, in_axes=(0, None))
```
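To see these `in_axes` values in action, here is a small sketch with made-up data: `(0, None)` maps over the rows of `x_batch` while reusing one `weights` vector, and `in_axes=1` treats columns as the examples instead.

```python
import jax
import jax.numpy as jnp

def dot_product(x, weights):
    return jnp.dot(x, weights)

x_batch = jnp.array([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0]])  # 2 examples, 3 features each
weights = jnp.array([0.5, 0.5, 1.0])    # shared by every example

# in_axes=(0, None): map over rows of x_batch, reuse the same weights
batch_dot = jax.vmap(dot_product, in_axes=(0, None))
print(batch_dot(x_batch, weights))  # [ 4.5 10.5]

# in_axes=1: map over columns instead of rows
column_sums = jax.vmap(jnp.sum, in_axes=1)(x_batch)
print(column_sums)  # [5. 7. 9.]
```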

## 🎓 What's in This Notebook?

This notebook has **8 worked examples**:

1. **Basic vmap** - Vectorizing a single-input function
2. **Multiple inputs** - Batching some arguments while keeping others fixed (`in_axes`)
3. **Matrix-vector products** - One matrix applied to a batch of vectors
4. **Per-sample gradients** - Combining `vmap` + `grad`
5. **Performance comparison** - Python loop vs `vmap` speed test
6. **Nested vmap** - Pairwise distances between two sets of points
7. **Output axes** - Controlling how results are stacked with `out_axes`
8. **Combining vmap + jit** - The performance combo

## 🚀 Prerequisites

Before starting this notebook, you should:
- ✅ Complete Notebook 1 (JAX Basics)
- ✅ Ideally complete Notebook 2 (JIT Compilation); helpful but not required
- ✅ Understand array shapes and dimensions
- ✅ Know what a batch is (or learn it in this notebook!)

## ⚡ Performance Impact

**Illustrative orders of magnitude** (actual timings depend on your hardware, batch size, and function; the benchmarks in the code below measure them on your machine):

| Method | Approx. time | Speedup |
|--------|--------------|---------|
| Python loop | ~100 ms | 1x (baseline) |
| `vmap` only | ~10 ms | ~10x faster |
| `jit` + `vmap` | ~1 ms | **~100x faster** |

**Key Insight:** `vmap` + `jit` together is a superpower combination!
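As a minimal sketch of how the transformations compose (the linear model and data here are made up for illustration), `grad` differentiates the single-example loss, `vmap` maps that gradient over examples, and `jit` compiles the whole batched computation. The last lines check the result against the analytic gradient of a squared error, 2(w·x - y)x.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    """Squared error for ONE example of a linear model."""
    return (jnp.dot(w, x) - y) ** 2

# grad: gradient w.r.t. w for one example
# vmap: map over examples (w shared, x and y batched)
# jit:  compile the whole batched computation
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]])
ys = jnp.array([0.0, 1.0, 2.0])

grads = per_example_grads(w, xs, ys)
print(grads.shape)  # (3, 2): one gradient per example

# Analytic check: d/dw (w.x - y)^2 = 2 (w.x - y) x
residuals = xs @ w - ys
print(jnp.allclose(grads, 2 * residuals[:, None] * xs))  # True
```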

## 💡 Key Takeaway

**Write code for ONE example. Use `vmap` to automatically handle BATCHES.**

This is how you write clean, readable code that's also lightning fast!

Let's see automatic batching in action! 🚀

In [None]:
# =============================================================================
# VECTORIZATION WITH VMAP
# =============================================================================

import jax
import jax.numpy as jnp
import numpy as np
import time

print("=" * 70)
print("AUTOMATIC BATCHING WITH VMAP")
print("=" * 70)

# -----------------------------------------------------------------------------
# Example 1: Basic vmap - Single Input Batching
# -----------------------------------------------------------------------------
print("\n1️⃣  BASIC VMAP")
print("-" * 70)

def square(x):
    """Operates on a SINGLE number"""
    return x ** 2

# Create vectorized version
vectorized_square = jax.vmap(square)

# Apply to batch
batch = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
result = vectorized_square(batch)

print("Function for single input: square(x) = x^2")
print(f"Input batch: {batch}")
print(f"Output: {result}")
print("vmap automatically applies the function to each element!")

# -----------------------------------------------------------------------------
# Example 2: Multiple Inputs with Different Batching
# -----------------------------------------------------------------------------
print("\n2️⃣  VMAP WITH MULTIPLE INPUTS")
print("-" * 70)

def weighted_sum(x, weight):
    """Compute weighted sum: x * weight"""
    return x * weight

# Batch over first argument, keep second fixed
# in_axes=(0, None) means: map over axis 0 of x, don't map over weight
batched_fn = jax.vmap(weighted_sum, in_axes=(0, None))

x_batch = jnp.array([1.0, 2.0, 3.0])
weight = 2.5

result = batched_fn(x_batch, weight)
print(f"Batch: {x_batch}")
print(f"Fixed weight: {weight}")
print(f"Result: {result}")
print("in_axes=(0, None) batches first arg, keeps second fixed")

# Batch over both arguments
batched_fn_both = jax.vmap(weighted_sum, in_axes=(0, 0))
weights_batch = jnp.array([1.0, 2.0, 3.0])
result_both = batched_fn_both(x_batch, weights_batch)
print(f"\nWith batched weights: {weights_batch}")
print(f"Result: {result_both}")

# -----------------------------------------------------------------------------
# Example 3: Matrix-Vector Product with vmap
# -----------------------------------------------------------------------------
print("\n3️⃣  MATRIX-VECTOR PRODUCTS")
print("-" * 70)

def matvec(matrix, vector):
    """Single matrix-vector product"""
    return jnp.dot(matrix, vector)

# Batch over vectors (multiple vectors, same matrix)
batch_matvec = jax.vmap(matvec, in_axes=(None, 0))

matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
vectors = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

results = batch_matvec(matrix, vectors)
print(f"Matrix:\n{matrix}")
print(f"Vector batch (shape {vectors.shape}):\n{vectors}")
print(f"Results (shape {results.shape}):\n{results}")

# -----------------------------------------------------------------------------
# Example 4: Per-Sample Gradients (The Killer App!)
# -----------------------------------------------------------------------------
print("\n4️⃣  PER-SAMPLE GRADIENTS")
print("-" * 70)

def loss_single_sample(params, x, y):
    """Loss for ONE training example"""
    prediction = jnp.dot(params, x)
    return (prediction - y) ** 2

# Create function that computes gradient for one sample
grad_fn_single = jax.grad(loss_single_sample)

# Vectorize it to compute per-sample gradients for entire batch!
# in_axes=(None, 0, 0) means: same params, batch over x and y
per_sample_grads = jax.vmap(grad_fn_single, in_axes=(None, 0, 0))

# Test data
params = jnp.array([1.0, 2.0, 3.0])
x_batch = jnp.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 1.0]
])
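# Targets chosen so some predictions are off; with perfect predictions
# every residual (and therefore every gradient) would be zero
y_batch = jnp.array([0.0, 2.0, 4.0, 5.0])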

grads = per_sample_grads(params, x_batch, y_batch)
print(f"Parameters: {params}")
print(f"Batch size: {len(x_batch)}")
print(f"Per-sample gradients (shape {grads.shape}):")
print(grads)
print("\nEach row is the gradient for one training example!")
print("This is critical for differential privacy and some RL algorithms.")

# -----------------------------------------------------------------------------
# Example 5: Performance Comparison - Loop vs vmap
# -----------------------------------------------------------------------------
print("\n5️⃣  PERFORMANCE: LOOP vs VMAP")
print("-" * 70)

def compute_single(x):
    """Some computation on a single input"""
    return jnp.sum(jnp.sin(x) ** 2 + jnp.cos(x) ** 2)

# Manual loop version
def loop_version(batch):
    results = []
    for x in batch:
        results.append(compute_single(x))
    return jnp.array(results)

# vmap version
vmap_version = jax.vmap(compute_single)

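# Generate test batch
batch_size = 1000
test_batch = jnp.ones((batch_size, 100))
n_runs = 10  # the eager Python loop is slow, so keep the repetition count small

# Warm up both versions (first calls pay one-time dispatch/compilation costs).
# Note: vmap alone does not JIT-compile; it issues batched array operations.
_ = loop_version(test_batch).block_until_ready()
_ = vmap_version(test_batch).block_until_ready()

# JAX dispatches work asynchronously, so block_until_ready() is needed
# to time the actual computation rather than just the enqueueing.

# Benchmark loop
start = time.perf_counter()
for _ in range(n_runs):
    _ = loop_version(test_batch).block_until_ready()
time_loop = time.perf_counter() - start

# Benchmark vmap
start = time.perf_counter()
for _ in range(n_runs):
    _ = vmap_version(test_batch).block_until_ready()
time_vmap = time.perf_counter() - start

print(f"Batch size: {batch_size}, runs: {n_runs}")
print(f"Loop version:  {time_loop:.4f} seconds")
print(f"vmap version:  {time_vmap:.4f} seconds")
print(f"Speedup: {time_loop/time_vmap:.2f}x faster with vmap!")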

# -----------------------------------------------------------------------------
# Example 6: Nested vmap - Batching Over Multiple Dimensions
# -----------------------------------------------------------------------------
print("\n6️⃣  NESTED VMAP")
print("-" * 70)

def pairwise_distance(x1, x2):
    """Distance between two vectors"""
    return jnp.sqrt(jnp.sum((x1 - x2) ** 2))

# Compute distances between every vector in set1 and every vector in set2
# Inner vmap: fix x1, map over the rows of set2       -> shape (len(set2),)
# Outer vmap: map over the rows of set1, reuse set2   -> shape (len(set1), len(set2))
pairwise_distances = jax.vmap(
    jax.vmap(pairwise_distance, in_axes=(None, 0)),
    in_axes=(0, None),
)

set1 = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
set2 = jnp.array([[1.0, 1.0], [2.0, 2.0]])

distances = pairwise_distances(set1, set2)
print(f"Set 1 (shape {set1.shape}):\n{set1}")
print(f"Set 2 (shape {set2.shape}):\n{set2}")
print(f"Pairwise distances (shape {distances.shape}):\n{distances}")
print("Each row shows distances from one point in set1 to all points in set2")

# -----------------------------------------------------------------------------
# Example 7: vmap with Different Output Axes
# -----------------------------------------------------------------------------
print("\n7️⃣  CONTROLLING OUTPUT AXES")
print("-" * 70)

def create_matrix(scale):
    """Create a 2x2 matrix based on scale"""
    return jnp.array([[scale, 0], [0, scale]])

# Default: out_axes=0 (stack along first dimension)
vmap_default = jax.vmap(create_matrix)
scales = jnp.array([1.0, 2.0, 3.0])
result_default = vmap_default(scales)

print(f"Scales: {scales}")
print(f"Output shape with out_axes=0 (default): {result_default.shape}")
print(f"Result:\n{result_default}")

# With out_axes=1: stack along second dimension
vmap_axis1 = jax.vmap(create_matrix, out_axes=1)
result_axis1 = vmap_axis1(scales)
print(f"\nOutput shape with out_axes=1: {result_axis1.shape}")

# -----------------------------------------------------------------------------
# Example 8: Combining JIT and vmap
# -----------------------------------------------------------------------------
print("\n8️⃣  JIT + VMAP COMBO")
print("-" * 70)

def expensive_computation(x):
    """Some complex computation"""
    result = x
    for _ in range(10):
        result = jnp.sin(result) + jnp.cos(result)
    return result

# Combine transformations: JIT the vmapped function
fast_batch_compute = jax.jit(jax.vmap(expensive_computation))

batch = jnp.ones(1000)

# Warm up (the first call triggers compilation)
_ = fast_batch_compute(batch).block_until_ready()

# Benchmark
start = time.perf_counter()
for _ in range(1000):
    _ = fast_batch_compute(batch).block_until_ready()
time_taken = time.perf_counter() - start

print("Combined jax.jit(jax.vmap(function)) for maximum performance")
print(f"Time for 1000 iterations on batch of 1000: {time_taken:.4f}s")
print("Both JIT compilation and automatic batching working together!")

print("\n" + "=" * 70)
print("KEY POINTS - VMAP")
print("=" * 70)
print("""
✅ vmap automatically batches functions - no manual loops
✅ Write single-example code; vmap handles the batching
✅ in_axes controls which axis of each argument to map over (0, 1, None, ...)
✅ Much faster than Python loops (batched array ops instead of per-example dispatch)
✅ Perfect for per-sample gradients: vmap(grad(loss))
✅ Nest vmap to batch over multiple dimensions (e.g., pairwise distances)
✅ Combine with JIT for maximum performance
✅ out_axes controls where the batch axis appears in the output
✅ Cleaner code: the intent is clear, with no index juggling
""")