- This notebook demonstrates advanced JAX concepts:
  - Automatic differentiation (grad, value_and_grad, higher-order derivatives)
  - PyTrees for handling nested data structures
  - tree_map for applying functions over PyTrees
  - Efficiency with vectorized operations and batching
- Includes practical code examples and comments


#### Automatic differentiation

In [None]:
import jax
import jax.numpy as jnp

def f(x):
    return x**2 + 3*x

# jax.grad takes a function and returns a NEW function
dfdx = jax.grad(f)
ddfddx = jax.grad(dfdx) # Second derivative!

print(f"f(2.0):   {f(2.0)}")      # 10.0
print(f"f'(2.0):  {dfdx(2.0)}")   # 2*2 + 3 = 7.0
print(f"f''(2.0): {ddfddx(2.0)}") # 2.0
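
By default, `jax.grad` differentiates with respect to the first argument only. For a function of several inputs, the `argnums` parameter selects which arguments to differentiate. The sketch below uses a hypothetical two-argument function `g` that is not part of the notebook above:

```python
import jax

# Hypothetical example function: g(x, y) = x*y + y**2
def g(x, y):
    return x * y + y ** 2

# Default: derivative with respect to the first argument (x) only.
dg_dx = jax.grad(g)(2.0, 3.0)

# argnums=(0, 1) returns a tuple of partial derivatives (dg/dx, dg/dy).
dg_dx_both, dg_dy = jax.grad(g, argnums=(0, 1))(2.0, 3.0)

print(dg_dx)  # dg/dx = y = 3.0
print(dg_dy)  # dg/dy = x + 2*y = 8.0
```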

#### Value and gradient

In [None]:
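# jax.value_and_grad(f) returns a new function that evaluates f and its
# gradient together, returning the tuple (value, gradient).
val, grad_val = jax.value_and_grad(f)(2.0)
print(f"Loss: {val}, Grad: {grad_val}")  # Loss: 10.0, Grad: 7.0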
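
A common reason to want both the value and the gradient is a training loop, where the value is logged and the gradient drives the update. The following is a minimal sketch of gradient descent on the same `f`; the step size `lr` and the number of steps are assumptions chosen for illustration:

```python
import jax

def f(x):
    return x**2 + 3*x

value_and_grad_f = jax.value_and_grad(f)

x = 2.0   # starting point
lr = 0.1  # assumed step size
for step in range(100):
    val, g = value_and_grad_f(x)  # value and gradient at the current x
    x = x - lr * g                # gradient descent step

print(x)  # approaches the minimizer of f, x = -1.5
```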

#### PyTrees

In [None]:
def mechanics(x):
    return x**2 + 1.0

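# If we have a batch of particles, standard Python loops are slow.
# jax.vmap maps the function over the leading axis of the batch instead.
batch = jnp.arange(5.0)
batched_out = jax.vmap(mechanics)(batch)  # x**2 + 1 for each of 0.0 .. 4.0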

# Imagine this is a Neural Network's parameters
params = {
    'layer1': {'w': jnp.array([1.0, 2.0]), 'b': jnp.array([0.1])},
    'layer2': {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([0.2])}
}

# We want to double every weight. We can't do params * 2,
# because params is a dict, so we use tree_map.
doubled_params = jax.tree_util.tree_map(lambda x: x * 2, params)

print(doubled_params['layer1']['w']) # [2. 4.]

# ⚡ Crucial Concept: The PyTree & `tree_map`

## 1. What is a PyTree?
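In JAX and Flax, a **PyTree** is any nested structure of Python containers whose innermost values (the "leaves") are arrays or scalars. A single array on its own also counts as a PyTree with one leaf. It is a fancy name for the nested structures we use every day in Python: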
* A `list` is a PyTree.
* A `tuple` is a PyTree.
* A `dict` is a PyTree.
* A `dict` of `lists` of `tuples`... is a PyTree.

**Why do we care?**
Flax stores your neural network parameters as a nested dictionary (a PyTree).
```python
params = {
    'layer1': {'w': jnp.array([...]), 'b': jnp.array([...])},
    'layer2': {'w': jnp.array([...]), 'b': jnp.array([...])}
}
```
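
To see what JAX treats as the leaves of such a dictionary, it can help to flatten it and rebuild it. The sketch below uses small illustrative values for the arrays (an assumption, since the snippet above elides them) together with `jax.tree_util.tree_leaves`, `tree_flatten`, and `tree_unflatten`:

```python
import jax
import jax.numpy as jnp

# Illustrative values only; real parameters would be larger arrays.
params = {
    'layer1': {'w': jnp.array([1.0, 2.0]), 'b': jnp.array([0.1])},
    'layer2': {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([0.2])},
}

# The leaves are the arrays; the dicts only provide structure.
leaves = jax.tree_util.tree_leaves(params)
print(len(leaves))  # 4

# Flattening separates the leaves from the structure (the "treedef"),
# and unflattening puts them back together.
leaves, treedef = jax.tree_util.tree_flatten(params)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(treedef)
```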

## 2. The Problem
Standard Python math operators (`+`, `-`, `*`) do not work on dictionaries. You cannot simply multiply your model parameters by 2.

```python
# ❌ This crashes
# params * 2 
# TypeError: unsupported operand type(s) for *: 'dict' and 'int'
```

## 3. The Solution: `jax.tree_util.tree_map`
JAX provides a utility that allows you to apply a function to every "leaf" (array) in the tree, while preserving the structure.

`tree_map(function, tree)` says:
> *"Travel down to every leaf. Apply `function` to that leaf. Rebuild the tree exactly as it was."*

```python
# ✅ The JAX way
# Apply "x * 2" to every array inside the dictionary
doubled_params = jax.tree_util.tree_map(lambda x: x * 2, params)
```
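
The mapped function does not have to be arithmetic. Two uses that come up often are inspecting the shape of every parameter and creating a tree of zeros with the same structure. A small sketch, again with illustrative values for `params`:

```python
import jax
import jax.numpy as jnp

# Illustrative values only.
params = {
    'layer1': {'w': jnp.array([1.0, 2.0]), 'b': jnp.array([0.1])},
    'layer2': {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([0.2])},
}

# Replace every array by its shape; the nesting is preserved.
shapes = jax.tree_util.tree_map(lambda x: x.shape, params)
print(shapes['layer1']['w'])  # (2,)

# Build a tree of zeros with the same structure and shapes.
zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
```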

## 4. The "Killer App": Gradient Descent
When you train a model, you get a gradient tree (`grads`) that has the exact same structure as your parameter tree (`params`). To update your weights, you need to subtract the gradients from the parameters leaf-by-leaf.

`tree_map` can take **multiple trees** as input, as long as they have the same structure.

```python
# SGD update rule: params = params - lr * grads, with lr = 0.1.
# `grads` is assumed to come from jax.grad of a loss with respect to `params`.
updated_params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g,  # The function takes (param, grad)
    params,                    # Tree 1 (p)
    grads                      # Tree 2 (g)
)
```
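
For a version that runs end to end, the sketch below builds a hypothetical one-layer linear model with a mean-squared-error loss. The model, the data, and the learning rate are assumptions made up for illustration. Because `params` is the first argument of `loss_fn`, `jax.value_and_grad` returns gradients with the same tree structure as `params`:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model: pred = x @ w + b
params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}

def loss_fn(params, x, y):
    pred = x @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)

# Made-up data: two samples with two features each.
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([2.0, 3.0])

# grads is a dict with the same keys and shapes as params.
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)

lr = 0.1  # assumed learning rate
updated_params = jax.tree_util.tree_map(
    lambda p, g: p - lr * g, params, grads
)

new_loss = loss_fn(updated_params, x, y)
print(loss, new_loss)  # the loss decreases after one step
```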