# 📘 Notebook 5: JAX vs PyTorch - A Side-by-Side Comparison

Welcome to the "Rosetta Stone" of deep learning frameworks! This notebook compares JAX and PyTorch directly, showing you how to translate between them.

## 🎯 What You'll Learn (25-35 minutes)

By the end of this notebook, you'll understand:
- ✅ Core philosophy differences between JAX and PyTorch
- ✅ How to do the same operations in both frameworks
- ✅ Array creation and manipulation in both
- ✅ Automatic differentiation differences
- ✅ Model definition approaches
- ✅ Training loops in JAX vs PyTorch
- ✅ When to choose JAX vs PyTorch

## 🤔 JAX vs PyTorch: What's the Difference?

### PyTorch: The Deep Learning Framework
**Philosophy:** Object-oriented, imperative, mutable state

**Best for:**
- Traditional deep learning workflows
- Researchers who want flexibility
- Projects with complex dynamic architectures
- Quick prototyping with `nn.Module`

**Key features:**
- Built-in neural network layers (`torch.nn`)
- Automatic differentiation with `.backward()`
- Training utilities (optimizers, data loaders)
- Mutable tensors (can modify in-place)

### JAX: The NumPy++ for ML
**Philosophy:** Functional, immutable, composable transformations

**Best for:**
- Research requiring custom operations
- High-performance scientific computing
- Projects needing advanced transformations (vmap, pmap)
- When you want fine-grained control

**Key features:**
- NumPy-compatible API
- Functional transformations (jit, grad, vmap, pmap)
- Immutable arrays (safer, easier to optimize)
- Composability (combine transformations freely)

## 📊 Quick Comparison Table

| Feature | JAX | PyTorch |
|---------|-----|---------|
| **Arrays** | `jax.numpy` (immutable) | `torch.Tensor` (mutable) |
| **Gradients** | `jax.grad(fn)` (functional) | `tensor.backward()` (OOP) |
| **JIT Compilation** | `@jax.jit` decorator | `@torch.compile` (PyTorch 2.0+) |
| **Batching** | `jax.vmap` (automatic) | Manual or `torch.vmap` (newer) |
| **Neural Networks** | DIY or Flax/Haiku | `torch.nn.Module` (built-in) |
| **Optimizers** | DIY or Optax | `torch.optim` (built-in) |
| **GPU Support** | Automatic | `.to('cuda')` or `.cuda()` |
| **Parallelism** | `jax.pmap` (powerful) | `torch.distributed` |
| **Learning Curve** | Steeper (functional style) | Gentler (OOP familiar) |
| **Ecosystem** | Growing | Massive |

## 🧩 Key Philosophical Differences

### 1. Mutability
**PyTorch:** Tensors are mutable (can be modified)
```python
import torch

x = torch.tensor([1.0, 2.0])
x[0] = 10.0  # ✅ Works! Modifies x in place
```

**JAX:** Arrays are immutable (cannot be modified)
```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])
# x[0] = 10.0  # ❌ TypeError: JAX arrays do not support item assignment
x = x.at[0].set(10.0)  # ✅ Returns a new array; the original is unchanged
```
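
To make the "creates new array" point concrete, here is a minimal sketch showing that a functional update leaves the original array untouched. The names `original`, `updated`, and `bumped` are illustrative and do not appear elsewhere in this notebook.

```python
import jax.numpy as jnp

original = jnp.array([1.0, 2.0])

# .at[...].set(...) returns a new array instead of mutating in place
updated = original.at[0].set(10.0)

# .at[...] supports other functional updates too, such as add
bumped = original.at[1].add(5.0)

print(original)  # still holds the values it was created with
print(updated)
print(bumped)
```

Inside a `jax.jit`-compiled function, the compiler can often turn such updates into in-place writes, so the functional style need not cost extra copies.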

### 2. Function vs Method Style
**PyTorch:** Object-oriented (methods on tensors)
```python
loss = ((y_pred - y_true) ** 2).mean()
loss.backward()  # Method call
```

**JAX:** Functional (functions that transform functions)
```python
loss_fn = lambda params: ((predict(params, x) - y) ** 2).mean()
grad_fn = jax.grad(loss_fn)  # Function transformation
grads = grad_fn(params)
```
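
The JAX snippet above leaves `predict`, `x`, `y`, and `params` undefined. A runnable sketch might look like the following, assuming a simple linear model and toy data (all of these definitions are illustrative stand-ins, not the notebook's own model). It uses `jax.value_and_grad`, which returns the loss and its gradients in one call.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear model; predict, x, y, and params are illustrative
def predict(params, x):
    return x @ params["w"] + params["b"]

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}

def loss_fn(params):
    return ((predict(params, x) - y) ** 2).mean()

# value_and_grad transforms loss_fn into a function returning (loss, grads)
loss, grads = jax.value_and_grad(loss_fn)(params)

print(loss)
print(grads)  # a dict with the same keys as params
```

Note that the gradients come back in the same structure as the parameters, which is what makes plain dictionaries workable as model state.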

### 3. Neural Network Definition
**PyTorch:** Class-based with `nn.Module`
```python
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        
    def forward(self, x):
        return self.fc1(x)
```

**JAX:** Function-based (or use Flax/Haiku)
```python
def mlp(params, x):
    x = jnp.dot(x, params['w1']) + params['b1']
    return x
```
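
Because the JAX version is just a function, the parameters have to be created and passed in by hand. As a rough sketch, assuming the same 10-to-20 layer sizes as the PyTorch example, this could look as follows; `init_params` is a hypothetical helper introduced here, not a JAX API.

```python
import jax
import jax.numpy as jnp

def mlp(params, x):
    x = jnp.dot(x, params['w1']) + params['b1']
    return x

def init_params(key, in_dim=10, out_dim=20):
    # Hypothetical helper: small random weights, zero bias
    return {
        'w1': 0.1 * jax.random.normal(key, (in_dim, out_dim)),
        'b1': jnp.zeros(out_dim),
    }

params = init_params(jax.random.PRNGKey(0))
batch = jnp.ones((4, 10))  # 4 examples, 10 features each
out = mlp(params, batch)
print(out.shape)
```

Libraries such as Flax and Haiku automate this bookkeeping, but the underlying pattern of "parameters in, outputs out" stays the same.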

## 🎓 What's in This Notebook?

This notebook has **10 side-by-side comparisons**:

1. **Array creation** - Making tensors/arrays
2. **Basic operations** - Math, indexing, reshaping
3. **Random numbers** - Random generation differences
4. **Gradients** - Computing derivatives
5. **Simple functions** - Forward + backward pass
6. **Loss functions** - MSE, cross-entropy
7. **Model definition** - Linear layer, MLP
8. **Optimization step** - SGD update
9. **Full training loop** - Complete training example
10. **Performance** - Speed comparison

Each example shows **BOTH frameworks side-by-side** so you can see the translation!

## 🚀 Prerequisites

Before starting this notebook, you should:
- ✅ Complete Notebooks 1-3 (JAX Basics, JIT, Autodiff)
- ✅ Basic Python knowledge
- ❌ **Don't need**: Prior PyTorch experience (we explain everything!)

## 💡 Which Framework Should You Choose?

**Choose JAX if you:**
- Want fine-grained control over everything
- Need advanced transformations (vmap, pmap)
- Prefer functional programming style
- Are doing research with custom operations
- Want NumPy-like code that's GPU-ready

**Choose PyTorch if you:**
- Want a complete ecosystem (models, datasets, utilities)
- Prefer object-oriented style
- Need extensive community support
- Want to use pre-trained models (torchvision, transformers)
- Are building production applications quickly

**Good news:** After this notebook, you'll understand both! 🎉

## 🔄 Key Takeaway

**JAX and PyTorch solve the same problems with different philosophies:**
- PyTorch: "Here's a complete framework"
- JAX: "Here's NumPy + composable transformations"

Both are excellent! This notebook helps you choose and translate between them.

Let's compare them side by side! 🔬

In [None]:
# =============================================================================
# JAX VS PYTORCH - SIDE-BY-SIDE FEATURE COMPARISON
# =============================================================================

import jax
import jax.numpy as jnp
import torch
import torch.nn as nn
import numpy as np
import time

print("=" * 70)
print("FEATURE COMPARISON: JAX vs PYTORCH")
print("=" * 70)

# -----------------------------------------------------------------------------
# 1. ARRAY/TENSOR CREATION
# -----------------------------------------------------------------------------
print("\n1️⃣  ARRAY/TENSOR CREATION")
print("-" * 70)

# JAX
jax_array = jnp.array([1.0, 2.0, 3.0])
jax_zeros = jnp.zeros((3, 3))
jax_random = jax.random.normal(jax.random.PRNGKey(0), (3, 3))

# PyTorch
torch_tensor = torch.tensor([1.0, 2.0, 3.0])
torch_zeros = torch.zeros((3, 3))
torch_random = torch.randn(3, 3)

print("JAX:")
print(f"  Array: {jax_array}, type: {type(jax_array)}")
print(f"  Random array shape: {jax_random.shape}")
print("\nPyTorch:")
print(f"  Tensor: {torch_tensor}, type: {type(torch_tensor)}")
print(f"  Random tensor shape: {torch_random.shape}")

print("\n🔑 Key Difference:")
print("  JAX: Explicit random keys (no global state)")
print("  PyTorch: Global random state (torch.manual_seed)")

# -----------------------------------------------------------------------------
# 2. MUTABILITY
# -----------------------------------------------------------------------------
print("\n2️⃣  MUTABILITY")
print("-" * 70)

# PyTorch - Mutable
torch_x = torch.tensor([1.0, 2.0, 3.0])
torch_x[0] = 10.0
print(f"PyTorch (mutable): {torch_x}")

# JAX - Immutable
jax_x = jnp.array([1.0, 2.0, 3.0])
try:
    jax_x[0] = 10.0  # This will fail
except TypeError as e:
    print(f"JAX (immutable): Error - {e}")
    jax_x_updated = jax_x.at[0].set(10.0)
    print(f"JAX update method: {jax_x_updated}")

print("\n🔑 Key Difference:")
print("  JAX: Arrays are immutable (functional style)")
print("  PyTorch: Tensors are mutable (in-place operations)")

# -----------------------------------------------------------------------------
# 3. AUTOMATIC DIFFERENTIATION
# -----------------------------------------------------------------------------
print("\n3️⃣  AUTOMATIC DIFFERENTIATION")
print("-" * 70)

def simple_func(x):
    return x ** 2 + 2 * x + 1

# JAX approach - functional: grad() takes a function and returns its derivative
jax_grad_fn = jax.grad(simple_func)
x = 3.0
jax_gradient = jax_grad_fn(x)

# PyTorch approach - tensor-based
torch_x = torch.tensor(3.0, requires_grad=True)
torch_output = simple_func(torch_x)
torch_output.backward()
torch_gradient = torch_x.grad

print("Function: f(x) = x² + 2x + 1")
print(f"At x = {x}:")
print(f"  JAX gradient: {jax_gradient}")
print(f"  PyTorch gradient: {torch_gradient.item()}")

print("\n🔑 Key Difference:")
print("  JAX: Transform functions with grad()")
print("  PyTorch: Call .backward() on tensors")

# -----------------------------------------------------------------------------
# 4. JIT COMPILATION
# -----------------------------------------------------------------------------
print("\n4️⃣  JIT COMPILATION")
print("-" * 70)

def compute(x):
    for _ in range(10):
        x = jnp.sin(x) + jnp.cos(x)
    return x

# JAX JIT
jax_jit_fn = jax.jit(compute)

# PyTorch JIT (TorchScript; torch.compile is the PyTorch 2.0+ alternative)
@torch.jit.script
def torch_compute(x):
    for _ in range(10):
        x = torch.sin(x) + torch.cos(x)
    return x

# Benchmark
test_size = 10000
jax_input = jnp.ones(test_size)
torch_input = torch.ones(test_size)

# Warm up
_ = jax_jit_fn(jax_input).block_until_ready()
_ = torch_compute(torch_input)

# Time JAX (block_until_ready waits for asynchronous dispatch to finish)
start = time.perf_counter()
for _ in range(100):
    _ = jax_jit_fn(jax_input).block_until_ready()
jax_time = time.perf_counter() - start

# Time PyTorch (torch_input is a CPU tensor, so ops run synchronously)
start = time.perf_counter()
for _ in range(100):
    _ = torch_compute(torch_input)
torch_time = time.perf_counter() - start

print(f"Array size: {test_size}, 100 iterations")
print(f"  JAX JIT:     {jax_time:.4f}s")
print(f"  PyTorch JIT: {torch_time:.4f}s")
print(f"  PyTorch/JAX time ratio: {torch_time/jax_time:.2f}x")

print("\n🔑 Key Difference:")
print("  JAX: Built-in with @jax.jit decorator")
print("  PyTorch: TorchScript with @torch.jit.script")

# -----------------------------------------------------------------------------
# 5. BATCHING/VECTORIZATION
# -----------------------------------------------------------------------------
print("\n5️⃣  BATCHING/VECTORIZATION")
print("-" * 70)

def single_example_fn(x):
    """Function designed for single input."""
    return x ** 2 + 2 * x

# JAX - Automatic with vmap
jax_batched = jax.vmap(single_example_fn)
jax_batch = jnp.array([1.0, 2.0, 3.0, 4.0])
jax_result = jax_batched(jax_batch)

# PyTorch - the same function works on a batch because every op is elementwise
torch_batch = torch.tensor([1.0, 2.0, 3.0, 4.0])
torch_result = single_example_fn(torch_batch)

print(f"Batch: {jax_batch}")
print(f"  JAX vmap result: {jax_result}")
print(f"  PyTorch result:  {torch_result}")

print("\n🔑 Key Difference:")
print("  JAX: vmap for automatic batching")
print("  PyTorch: Elementwise ops and broadcasting handle most cases")
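
# A hedged sketch (not part of the original comparison): vmap matters most when
# a per-example function is NOT purely elementwise. squared_norm and the
# vmap_* names are illustrative and introduced only for this example.
import jax
import jax.numpy as jnp

def squared_norm(x):
    """Squared length of a single 1-D vector."""
    return jnp.dot(x, x)

vmap_inputs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 examples

# Calling squared_norm(vmap_inputs) directly would try a (3, 2) x (3, 2) matrix
# product and fail; vmap applies the function to each row instead.
vmap_out = jax.vmap(squared_norm)(vmap_inputs)
print(f"  JAX vmap over a non-elementwise function: {vmap_out}")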

# -----------------------------------------------------------------------------
# 6. GRADIENT OF BATCHED OPERATIONS (Per-Sample Gradients)
# -----------------------------------------------------------------------------
print("\n6️⃣  PER-SAMPLE GRADIENTS")
print("-" * 70)

def loss_fn(x):
    return jnp.sum(x ** 2)

# JAX - Easy with vmap(grad())
jax_per_sample_grad = jax.vmap(jax.grad(loss_fn))
jax_batch_input = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
jax_grads = jax_per_sample_grad(jax_batch_input)

# PyTorch - manual loop here; torch.func (formerly functorch) is the alternative
torch_batch_input = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], requires_grad=True)
torch_grads = []
for i in range(len(torch_batch_input)):
    # Each iteration builds a fresh graph, so retain_graph is not needed
    loss = torch.sum(torch_batch_input[i] ** 2)
    grad = torch.autograd.grad(loss, torch_batch_input)[0][i]
    torch_grads.append(grad)
torch_grads = torch.stack(torch_grads)

print(f"Batch of 3 samples, 2 features each")
print(f"JAX per-sample gradients:\n{jax_grads}")
print(f"PyTorch per-sample gradients:\n{torch_grads}")

print("\n🔑 Key Difference:")
print("  JAX: vmap(grad()) is natural and fast")
print("  PyTorch: Manual loops, or torch.func (formerly functorch)")
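
# A hedged sketch (not part of the original comparison): in PyTorch 2.0+,
# torch.func provides grad and vmap, so per-sample gradients can be computed
# without a Python loop. torch_loss_fn and the torch_func_* names are
# illustrative and introduced only for this example.
import torch
from torch.func import grad as torch_grad, vmap as torch_vmap

def torch_loss_fn(x):
    """Scalar loss for a single sample."""
    return torch.sum(x ** 2)

torch_func_input = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
torch_func_grads = torch_vmap(torch_grad(torch_loss_fn))(torch_func_input)
print(f"PyTorch torch.func per-sample gradients:\n{torch_func_grads}")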

# -----------------------------------------------------------------------------
# 7. RANDOM NUMBER GENERATION
# -----------------------------------------------------------------------------
print("\n7️⃣  RANDOM NUMBER GENERATION")
print("-" * 70)

# JAX - Explicit keys (functional)
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
jax_rand1 = jax.random.normal(subkey, (3,))
key, subkey = jax.random.split(key)
jax_rand2 = jax.random.normal(subkey, (3,))

# PyTorch - Global state
torch.manual_seed(42)
torch_rand1 = torch.randn(3)
torch_rand2 = torch.randn(3)

print("JAX (explicit keys):")
print(f"  Random 1: {jax_rand1}")
print(f"  Random 2: {jax_rand2}")
print("\nPyTorch (global state):")
print(f"  Random 1: {torch_rand1}")
print(f"  Random 2: {torch_rand2}")

print("\n🔑 Key Difference:")
print("  JAX: Functional RNG with explicit key splitting")
print("  PyTorch: Global RNG state by default (simpler, but results depend on call order)")
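
# A hedged sketch (not part of the original comparison): PyTorch can also avoid
# the global RNG by passing an explicit torch.Generator to sampling functions.
# gen_a, gen_b, and the torch_local* names are illustrative.
import torch

gen_a = torch.Generator().manual_seed(42)
gen_b = torch.Generator().manual_seed(42)
torch_local1 = torch.randn(3, generator=gen_a)
torch_local2 = torch.randn(3, generator=gen_b)
print(f"  Separate generators, same seed, equal draws: {torch.equal(torch_local1, torch_local2)}")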

# -----------------------------------------------------------------------------
# 8. HARDWARE ACCELERATION
# -----------------------------------------------------------------------------
print("\n8️⃣  HARDWARE ACCELERATION")
print("-" * 70)

print("JAX:")
print(f"  Default backend: {jax.default_backend()}")
print(f"  Available devices: {jax.devices()}")

print("\nPyTorch:")
print(f"  CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  Device: {torch.cuda.get_device_name(0)}")
else:
    print(f"  Device: CPU")

print("\n🔑 Key Difference:")
print("  JAX: Arrays land on the default device automatically; jax.device_put for explicit control")
print("  PyTorch: Manual .to(device) for GPU placement")

# -----------------------------------------------------------------------------
# 9. COMPOSABILITY
# -----------------------------------------------------------------------------
print("\n9️⃣  COMPOSABILITY OF TRANSFORMATIONS")
print("-" * 70)

def f(x):
    return jnp.sum(x ** 2)

# JAX - Compose transformations freely
composed = jax.jit(jax.vmap(jax.grad(f)))
batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])
result = composed(batch)

print("JAX: jax.jit(jax.vmap(jax.grad(f)))")
print(f"  Result shape: {result.shape}")
print(f"  Result:\n{result}")

print("\n🔑 Key Difference:")
print("  JAX: Free composition of jit, grad, vmap")
print("  PyTorch: Composition is possible via torch.func, but is not the default style")

# -----------------------------------------------------------------------------
# SUMMARY TABLE
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("SUMMARY: JAX vs PYTORCH")
print("=" * 70)

summary = """
Feature                 | JAX                      | PyTorch
------------------------|--------------------------|---------------------------
Paradigm                | Functional               | Object-oriented
Mutability              | Immutable arrays         | Mutable tensors
Autodiff API            | grad(function)           | tensor.backward()
JIT Compilation         | @jax.jit (XLA)           | @torch.jit.script
Batching                | vmap (automatic)         | Broadcasting (manual)
Per-sample gradients    | vmap(grad()) - easy      | Loops or torch.func
Random numbers          | Explicit keys            | Global state
Composability           | High (jit+grad+vmap)     | Moderate
Learning curve          | Steeper (functional)     | Gentler (imperative)
Ecosystem               | Growing (Flax, Optax)    | Mature (torchvision, etc)
Best for                | Research, custom algos   | Production, quick start
"""
print(summary)