In [1]:
%unload_ext fincantatem
%load_ext fincantatem

The fincantatem extension is not loaded.


## Example: Shape Mismatch in Neural Network Forward Pass

This example demonstrates how the JAX integration helps debug shape mismatches.
On its own, the error message reports only the mismatched contracting dimensions. **Which** arrays are involved, and what are their full shapes?

With the JAX context, you can see the shape of every array in each stack frame.


In [2]:
import jax
import jax.numpy as jnp


def mlp_forward(params, x):
    """Forward pass through a 2-layer MLP."""
    w1, b1, w2, b2 = params

    # Hidden layer
    h = jnp.dot(x, w1) + b1
    h = jax.nn.relu(h)

    # Output layer
    out = jnp.dot(h, w2) + b2
    return out


w1 = jnp.ones((128, 784))
b1 = jnp.zeros((128,))
w2 = jnp.ones((128, 10))
b2 = jnp.zeros((10,))

params = (w1, b1, w2, b2)
batch = jnp.ones((32, 784))

mlp_forward(params, batch)

TypeError: dot_general requires contracting dimensions to have the same shape, got (784,) and (128,).

# TL;DR

**Problem**: Matrix dimension mismatch in `jnp.dot(x, w1)`. You're trying to multiply `x` (32, 784) with `w1` (128, 784), but for matrix multiplication the inner dimensions must match.

**Fix**: Transpose `w1` to shape (784, 128):
```python
h = jnp.dot(x, w1.T) + b1  # equivalently: h = x @ w1.T + b1
```

Or initialize `w1` with the correct shape from the start:
```python
w1 = jnp.ones((784, 128))  # Change from (128, 784)
```

---

# Detailed Analysis

## What Went Wrong

The error occurs in the `mlp_forward` function at line 10:
```python
h = jnp.dot(x, w1) + b1
```

### The Issue

You're attempting to perform a matrix multiplication between:
- `x`: shape **(32, 784)** - batch of 32 samples, each with 784 features
- `w1`: shape **(128, 784)** - weight matrix

JAX's `jnp.dot()` (for 2D arrays) performs standard matrix multiplication, which requires:
- For `A @ B`, the last dimension of `A` must match the first dimension of `B`
- Here, `(32, 784) @ (128, 784)` tries to contract axis 1 of `x` (size 784) with axis 0 of `w1` (size 128)

### What the Error Message Tells Us

From the error details:
```
lhs = ShapedArray(float32[32,784])
rhs = ShapedArray(float32[128,784])
dimension_numbers = (((1,), (0,)), ((), ()))
lhs_contracting_shape = (784,)
rhs_contracting_shape = (128,)
```

JAX is trying to contract:
- Dimension 1 of `lhs` (size 784)
- Dimension 0 of `rhs` (size 128)

These don't match: **784 ≠ 128**, hence the error.
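The `dimension_numbers` value in the error has the form `((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))`, which is the argument taken by `jax.lax.dot_general`. As an illustration only (not the recommended fix), contracting axis 1 of *both* operands instead of axis 1 with axis 0 gives the same result as `x @ w1.T`:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((32, 784))
w1 = jnp.ones((128, 784))  # the mis-shaped weights from the example

# Contract axis 1 of x with axis 1 of w1; no batch dimensions.
# This is the same contraction as x @ w1.T.
h = lax.dot_general(x, w1, dimension_numbers=(((1,), (1,)), ((), ())))

print(h.shape)  # (32, 128)
```

The output keeps the free (non-contracted) axes in order: 32 from `x`, then 128 from `w1`.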

## Why Your Code Has This Shape

Looking at your initialization:
```python
w1 = jnp.ones((128, 784))  # Hidden layer weights
```

For a neural network layer transforming 784 input features to 128 hidden units, the weight matrix should be **(784, 128)**, not (128, 784). This way:
- Input: (batch_size, 784)
- Weights: (784, 128)
- Output: (batch_size, 128)

## Solutions

### Option 1: Fix the Initialization (Recommended)
```python
w1 = jnp.ones((784, 128))  # Shape: (input_dim, hidden_dim)
b1 = jnp.zeros(128)
w2 = jnp.ones((128, 10))   # Shape: (hidden_dim, output_dim)
b2 = jnp.zeros(10)
```

### Option 2: Transpose During Forward Pass
```python
h = jnp.dot(x, w1.T) + b1  # Transpose w1 from (128, 784) to (784, 128)
```

### Complete Fixed Code
```python
import jax
import jax.numpy as jnp


def mlp_forward(params, x):
    """Forward pass through a 2-layer MLP."""
    w1, b1, w2, b2 = params

    # Hidden layer
    h = jnp.dot(x, w1) + b1  # (32, 784) @ (784, 128) = (32, 128)
    h = jax.nn.relu(h)

    # Output layer
    y = jnp.dot(h, w2) + b2  # (32, 128) @ (128, 10) = (32, 10)
    return y

# Correct initialization: weights are (input_dim, output_dim)
w1 = jnp.ones((784, 128))
b1 = jnp.zeros(128)
w2 = jnp.ones((128, 10))
b2 = jnp.zeros(10)
params = (w1, b1, w2, b2)
batch = jnp.ones((32, 784))

mlp_forward(params, batch)  # Now works: output shape (32, 10)
```

## Key Takeaway

In neural networks, weight matrices should have shape `(input_features, output_features)` to allow standard matrix multiplication: `output = input @ weights + bias`.
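One inexpensive way to check this convention before running anything is `jax.eval_shape`, which traces a function with abstract inputs and returns only the output shapes and dtypes. In this sketch, `init_params` is a hypothetical helper (not part of the original code) and the initialization scale is arbitrary:

```python
import jax
import jax.numpy as jnp


def mlp_forward(params, x):
    """Forward pass through a 2-layer MLP (same as the fixed version above)."""
    w1, b1, w2, b2 = params
    h = jax.nn.relu(jnp.dot(x, w1) + b1)
    return jnp.dot(h, w2) + b2


def init_params(key, in_dim=784, hidden_dim=128, out_dim=10):
    """Hypothetical helper: every weight matrix is (input_features, output_features)."""
    k1, k2 = jax.random.split(key)
    w1 = jax.random.normal(k1, (in_dim, hidden_dim)) * 0.01
    b1 = jnp.zeros((hidden_dim,))
    w2 = jax.random.normal(k2, (hidden_dim, out_dim)) * 0.01
    b2 = jnp.zeros((out_dim,))
    return (w1, b1, w2, b2)


params = init_params(jax.random.PRNGKey(0))
# Abstract stand-in for a batch: only shape and dtype, no data.
batch = jax.ShapeDtypeStruct((32, 784), jnp.float32)

# eval_shape traces mlp_forward without executing it, so shape errors surface cheaply.
out = jax.eval_shape(mlp_forward, params, batch)
print(out.shape)  # (32, 10)
```

Passing the original `(128, 784)` `w1` to the same `jax.eval_shape` call raises the same `TypeError` from `dot_general`, without computing anything.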

## Example: vmap Batch Dimension Error

When using `vmap`, errors can be cryptic. The transformation context and array shapes help identify:

- Which transformation is active
- Whether arrays are tracers
- The actual vs expected shapes


In [5]:
import jax.numpy as jnp
import pickle
import jax


def pairwise_distance(x, y):
    """Compute distance between two points."""
    return jnp.sqrt(jnp.sum((x - y) ** 2))


with open("queries.pkl", "rb") as f:
    queries = pickle.load(f)

with open("references.pkl", "rb") as f:
    references = pickle.load(f)

batched_distance = jax.vmap(pairwise_distance, in_axes=(0, 0))

batched_distance(queries, references)

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 100: axis 0 of argument x of type float32[100,3];
  * one axis had size 50: axis 0 of argument y of type float32[50,3]

# TL;DR

**Problem:** You're calling `jax.vmap` with `in_axes=(0, 0)` on two arrays with different sizes along axis 0: `queries` has shape `(100, 3)` and `references` has shape `(50, 3)`. JAX's `vmap` requires every mapped axis to have the same size.

**Fix:** Change your `vmap` usage based on your intent:
- For **all pairwise distances** (a 100×50 matrix): nest two `vmap`s, one with `in_axes=(0, None)` and one with `in_axes=(None, 0)`
- For **element-wise** distances: ensure both arrays have the same leading dimension (both 100 or both 50)

---

# Detailed Analysis

## What Went Wrong

The error occurs at line 19 of your code:
```python
batched_distance(queries, references)
```

Where:
- `queries` has shape `(100, 3)` 
- `references` has shape `(50, 3)`
- `batched_distance = jax.vmap(pairwise_distance, in_axes=(0, 0))`

### Understanding `vmap` with `in_axes=(0, 0)`

When you specify `in_axes=(0, 0)`, you're telling JAX to:
1. Map over axis 0 of the first argument (`x`)
2. Map over axis 0 of the second argument (`y`)
3. Apply the function element-wise to corresponding pairs

This is similar to a zip operation: JAX expects to pair up `queries[0]` with `references[0]`, `queries[1]` with `references[1]`, and so on. For this to work, **both axes must have the same size**.

Since you have 100 queries but only 50 references, JAX cannot pair them up and raises a `ValueError`.
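The zip analogy can be made concrete. Assuming equal-sized placeholder arrays (the real contents of `queries.pkl` and `references.pkl` are not known here), this sketch shows that `vmap` with `in_axes=(0, 0)` matches an explicit Python loop over paired rows:

```python
import jax
import jax.numpy as jnp


def pairwise_distance(x, y):
    """Compute distance between two points."""
    return jnp.sqrt(jnp.sum((x - y) ** 2))


# Placeholder data: both arrays share the same leading size.
kq, kr = jax.random.split(jax.random.PRNGKey(0))
queries = jax.random.normal(kq, (50, 3))
references = jax.random.normal(kr, (50, 3))

vmapped = jax.vmap(pairwise_distance, in_axes=(0, 0))(queries, references)

# The same computation written as an explicit zip over the leading axis.
looped = jnp.stack([pairwise_distance(q, r) for q, r in zip(queries, references)])

print(vmapped.shape)  # (50,)
```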

## Solution Options

### Option 1: Compute All Pairwise Distances (Most Likely Intent)

If you want to compute the distance between **every** query and **every** reference (resulting in a 100×50 matrix):

```python
# Double vmap: outer over queries, inner over references
batched_distance = jax.vmap(
    lambda q: jax.vmap(lambda r: pairwise_distance(q, r))(references)
)
result = batched_distance(queries)  # Shape: (100, 50)

# Or more concisely:
batched_distance = jax.vmap(
    jax.vmap(pairwise_distance, in_axes=(None, 0)),
    in_axes=(0, None)
)
result = batched_distance(queries, references)  # Shape: (100, 50)
```
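For this particular function, plain NumPy-style broadcasting produces the same `(100, 50)` matrix without `vmap`. This is an alternative technique rather than what the analysis above proposes; it is sketched here on random placeholder data with the shapes from the error message, next to the nested `vmap` version for comparison:

```python
import jax
import jax.numpy as jnp


def pairwise_distance(x, y):
    """Compute distance between two points."""
    return jnp.sqrt(jnp.sum((x - y) ** 2))


# Placeholder data with the shapes from the error message.
kq, kr = jax.random.split(jax.random.PRNGKey(0))
queries = jax.random.normal(kq, (100, 3))
references = jax.random.normal(kr, (50, 3))

# Nested vmap: outer maps over queries, inner over references.
nested = jax.vmap(jax.vmap(pairwise_distance, in_axes=(None, 0)), in_axes=(0, None))(
    queries, references
)

# Broadcasting: (100, 1, 3) - (1, 50, 3) -> (100, 50, 3), then reduce the last axis.
diff = queries[:, None, :] - references[None, :, :]
broadcast = jnp.sqrt(jnp.sum(diff ** 2, axis=-1))

print(nested.shape, broadcast.shape)  # (100, 50) (100, 50)
```

The broadcasting form materializes the full `(100, 50, 3)` difference array; the `vmap` form expresses the same computation in terms of the original per-pair function.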

### Option 2: Element-wise Distances

If you want element-wise distances (`queries[i]` with `references[i]`), make sure both arrays have the same size:

```python
queries = jnp.ones((50, 3))  # Match the size
references = jnp.ones((50, 3))

batched_distance = jax.vmap(pairwise_distance, in_axes=(0, 0))
result = batched_distance(queries, references)  # Shape: (50,)
```

### Option 3: Broadcast One Array

If you want to compute distances from all queries to a **single** reference (or vice versa):

```python
# All queries to first reference
batched_distance = jax.vmap(pairwise_distance, in_axes=(0, None))
result = batched_distance(queries, references[0])  # Shape: (100,)
```

## Key Takeaway

The `in_axes` parameter controls how arrays are mapped:
- `in_axes=(0, 0)`: Element-wise mapping (requires same size)
- `in_axes=(0, None)`: Map over first arg, broadcast second
- `in_axes=(None, 0)`: Broadcast first, map over second
- Nested `vmap`: evaluate the function over every combination of inputs (a Cartesian product)

## Example: Python Control Flow Under `jit`

When Python control flow depends on array values, the failure often appears only after applying transformations like `jit`. The resulting error reflects JAX’s tracing model rather than the original intent of the code.

In these situations, additional context helps identify:

- That the function is being traced rather than executed
- Which values are tracers rather than concrete arrays
- Where Python control flow depends on traced array values
- Why the conditional cannot be resolved at trace time


In [None]:
import jax, jax.numpy as jnp


def loss(x):
    if x.sum() > 0:
        return x.sum()
    else:
        return 0.0


jax.grad(jax.jit(loss))(jnp.array([1.0, -2.0, 1.0]))

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function loss at /var/folders/jz/dlk_sncn2lvfszqx5d49t0z40000gn/T/ipykernel_49729/1709465187.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

# TL;DR

**Problem**: You're using a Python `if` statement whose condition depends on traced JAX array values (`x.sum() > 0`) inside a JIT-compiled function. While tracing the function for `jit`, JAX cannot convert a traced value to a concrete Python `bool`.

**Fix**: Replace the Python `if` statement with `jnp.where()`, which selects between two values based on an array condition without using Python control flow:

```python
def loss(x):
    return jnp.where(x.sum() > 0, x.sum(), 0.0)
```

---

# Detailed Analysis

## What Went Wrong

The error occurs in this code:

```python
def loss(x):
    if x.sum() > 0:  # ← Problem here
        return x.sum()
    else:
        return 0.0

jax.grad(jax.jit(loss))(jnp.array([1.0, -2.0, 1.0]))
```

### Root Cause

JAX uses **tracing** to compile functions with `jit`. During tracing:

1. JAX replaces concrete values with abstract "tracer" objects that track operations
2. These tracers represent the **shape and dtype** of values, not their actual content
3. Python's `if` statement requires a concrete `True` or `False` value
4. When you write `if x.sum() > 0:`, JAX tries to convert the traced boolean to a Python bool
5. This fails because the actual value isn't known yet: only the computation graph is being built

From the stack trace, you can see:
- **Frame 1** shows the error at line 4: `if x.sum() > 0:`
- **Immediate failure** shows: `TracerBoolConversionError` with `JitTracer<bool[]>` (a traced boolean, not a concrete value)
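The tracing behavior described under Root Cause can be observed directly. In this sketch, the function `summed` and the list `trace_log` are illustrative names, not part of the original code. A Python side effect inside a jitted function runs only while JAX traces it: it can see the shape of `x` but not its values, and it runs again only when the input shape (or dtype) changes.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled only while JAX traces the function


@jax.jit
def summed(x):
    # This Python line runs at trace time, when x is an abstract tracer:
    # its shape and dtype are known, its values are not.
    trace_log.append(x.shape)
    return x.sum()


summed(jnp.array([1.0, -2.0, 1.0]))  # first call: traces with shape (3,)
summed(jnp.array([5.0, 5.0, 5.0]))   # same shape and dtype: reuses the compiled version
summed(jnp.array([1.0, 2.0]))        # new shape: traces again

print(trace_log)  # [(3,), (2,)]
```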

### Why This Happens with JIT

The error specifically mentions "while tracing the function loss... for jit". When you use `jax.grad(jax.jit(loss))`:

1. To compile the function with XLA, `jax.jit` first traces it
2. Tracing runs your Python code with abstract values (tracers) in place of concrete arrays
3. The `if` statement forces Python to evaluate the boolean, which isn't possible with tracers

## The Solution

JAX provides functional alternatives to Python control flow:

### Option 1: `jnp.where()` (recommended for simple cases)

```python
def loss(x):
    return jnp.where(x.sum() > 0, x.sum(), 0.0)
```

This evaluates both branches and selects the result based on the condition.

### Option 2: `jax.lax.cond()` (for complex branches)

```python
import jax.lax as lax

def loss(x):
    return lax.cond(
        x.sum() > 0,
        lambda x: x.sum(),
        lambda x: 0.0,
        x
    )
```

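This can be more efficient when the branches are expensive, since only the selected branch executes. One caveat: under `vmap` with a batched predicate, `lax.cond` is converted to a select and both branches run.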

### Option 3: Remove `jit` (not recommended)

```python
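jax.grad(loss)(jnp.array([1.0, -2.0, 1.0]))  # Works, but without compilation
```

This works because `jax.grad` on its own still has the concrete input values available, so Python can evaluate `x.sum() > 0`. The cost is that operations run one at a time instead of as a single compiled XLA computation.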
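As a quick check of Options 1 and 2, both rewrites can be differentiated with the same `jax.grad(jax.jit(...))` composition as the failing example. Note that the example input `[1.0, -2.0, 1.0]` sums to exactly `0`, so the `else` branch is taken and the gradient is zero; a second input with a positive sum exercises the other branch. In this sketch, the `lax.cond` false branch returns `jnp.zeros((), v.dtype)` (an assumption on our part, instead of the bare `0.0` above) so both branches clearly return the same dtype.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loss_where(x):
    # Both values are computed; jnp.where selects one based on the traced condition.
    return jnp.where(x.sum() > 0, x.sum(), 0.0)


def loss_cond(x):
    # Only the selected branch runs; both branches return a scalar of x's dtype.
    return lax.cond(x.sum() > 0, lambda v: v.sum(), lambda v: jnp.zeros((), v.dtype), x)


pos = jnp.array([1.0, 2.0, 1.0])    # sum > 0: gradient of x.sum() is all ones
zero = jnp.array([1.0, -2.0, 1.0])  # sum == 0: else branch, gradient is all zeros

grads = {}
for f in (loss_where, loss_cond):
    g = jax.grad(jax.jit(f))  # same composition as the failing example
    grads[f.__name__] = (g(pos), g(zero))
    print(f.__name__, grads[f.__name__])
```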

## Key Takeaways

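1. **Don't use Python `if/elif/else` on conditions that depend on traced array values** inside jitted functions
2. Use JAX's functional equivalents: `jnp.where()`, `lax.cond()`, `lax.switch()`
3. The same applies to loops: a Python `while` whose condition depends on traced values fails the same way, so use `jax.lax.while_loop()`, `jax.lax.fori_loop()`, or `jax.lax.scan()`. A Python `for` loop over a static range is allowed under `jit`, but it is unrolled during tracing
4. This is a fundamental constraint of JAX's tracing model, not a bug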
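For a loop whose stopping condition depends on traced values, here is a minimal `jax.lax.while_loop` sketch. The function `halvings_until_below` is a hypothetical example, not from the original notebook. The loop state is a `(value, count)` tuple, and the initial carry is given explicit dtypes so the body returns exactly the types it receives.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def halvings_until_below(x, threshold):
    # Hypothetical example: count how many halvings bring x below threshold.
    # A Python `while x >= threshold:` would raise TracerBoolConversionError here.
    # Assumes threshold > 0; otherwise the loop never terminates.
    def cond_fun(state):
        value, _ = state
        return value >= threshold

    def body_fun(state):
        value, count = state
        return value / 2, count + 1

    init = (jnp.asarray(x, dtype=jnp.float32), jnp.int32(0))
    _, count = lax.while_loop(cond_fun, body_fun, init)
    return count


print(halvings_until_below(10.0, 1.0))  # 4  (10 -> 5 -> 2.5 -> 1.25 -> 0.625)
```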

## Example: Accidental Rank Collapse Under `vmap`

When using `vmap`, it is easy to accidentally change an array's rank upstream (for example, via a reduction) without triggering any immediate error. The resulting error often appears far from the real cause.

In these cases, inspecting array metadata and transformation context helps identify:

- How `vmap` is mapping over the input (and what each mapped element actually is)
- Whether values are scalars or vectors at the point of failure
- Where an unintended rank reduction occurred
- The actual vs intended shapes flowing through the transformation


In [9]:
import jax
import jax.numpy as jnp


def normalize_row(x):
    return x / jnp.linalg.norm(x, axis=0)


def batched_normalize(X):
    # Intended X: (batch, d)
    return jax.vmap(normalize_row)(X)


X_good = jnp.ones((4, 8))
X_bad = jnp.ones((4, 8)).sum(axis=1)

jax.jit(batched_normalize)(X_bad)


ValueError: axis 0 is out of bounds for array of dimension 0

# TL;DR

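**Problem**: `jnp.linalg.norm(x, axis=0)` is being called on a scalar (0-dimensional array), and `axis=0` is invalid for scalars since they have no axes.

**Root Cause**: `X_bad = jnp.ones((4, 8)).sum(axis=1)` collapses the feature axis, leaving shape `(4,)` instead of the intended `(batch, d)`. When `vmap` maps over axis 0, each `x` becomes a scalar with shape `()`, and a scalar has no axis 0 to take a norm along.

**Fix**: Remove the upstream reduction so `batched_normalize` receives a `(batch, d)` array:
```python
X = jnp.ones((4, 8))  # shape (batch, d); no .sum(axis=1)
jax.jit(batched_normalize)(X)
```
Dropping `axis=0` from `jnp.linalg.norm()` would also make the error go away, but only by dividing each scalar by its own absolute value (giving ±1), which hides the rank collapse instead of fixing it.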

---

# Detailed Analysis

## What Happened

The error occurs when JAX tries to trace your function for JIT compilation. Let's trace through the execution:

1. **Input**: `X_bad = jnp.ones((4, 8)).sum(axis=1)` creates an array with shape `(4,)` (a 1D array of 4 elements)

2. **Function call**: `jax.jit(batched_normalize)(X_bad)` is called

3. **Inside `batched_normalize`**:
   - `jax.vmap(normalize_row)(X)` maps `normalize_row` over the first axis of `X_bad`
   - Since `X_bad` has shape `(4,)`, removing the mapped axis leaves nothing: `normalize_row` is traced with an `x` of shape `()`

4. **Inside `normalize_row`**: 
   - Each `x` is now a **scalar** (0-dimensional array with shape `()`)
   - You call `jnp.linalg.norm(x, axis=0)`
   - This tries to compute the norm along axis 0 of a scalar

5. **Error**: A scalar has 0 dimensions, so `axis=0` is out of bounds
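The per-element shape that `vmap` hands to the mapped function can be checked directly. In this sketch, `probe` is a hypothetical helper that records the shape it receives while `vmap` traces it; the second call reproduces the `X_bad` construction.

```python
import jax
import jax.numpy as jnp

seen_shapes = []


def probe(x):
    # Inside vmap, x is the per-element view: the mapped axis is removed.
    seen_shapes.append(x.shape)
    return x


jax.vmap(probe)(jnp.ones((4, 8)))              # like X_good
jax.vmap(probe)(jnp.ones((4, 8)).sum(axis=1))  # like X_bad

print(seen_shapes)  # [(8,), ()]
```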

## The Core Issue

Looking at the local variables in Frame 0:
```python
x = JitTracer<float32[]>  # Note the empty brackets [] indicating shape ()
axis = 0
ndim = 0  # Zero dimensions!
```

The `canonicalize_axis` function checks `if not -num_dims <= axis < num_dims:`. With `num_dims = 0` and `axis = 0`, the chained comparison `0 <= 0 < 0` is `False`, so `not False` is `True` and the error is raised.

## Why This Happens

Your intended use case in the comment says `# Intended X: (batch, d)`, meaning a 2D array where:
- First dimension is the batch
- Second dimension is the feature vector to normalize

When you pass `X_good = jnp.ones((4, 8))` (shape `(4, 8)`):
- `vmap` maps over axis 0 (the batch dimension)
- Each `x` in `normalize_row` has shape `(8,)`
- `jnp.linalg.norm(x, axis=0)` computes the norm of this 8-element vector ✓

When you pass `X_bad = jnp.ones((4, 8)).sum(axis=1)` (shape `(4,)`):
- `vmap` maps over axis 0
- Each `x` in `normalize_row` is a **scalar** with shape `()`
- `jnp.linalg.norm(x, axis=0)` fails because scalars have no axis 0 ✗

## The Fix

The bug is upstream of `normalize_row`: `X_bad` was reduced with `.sum(axis=1)` before it reached `batched_normalize`. Pass the intended `(batch, d)` array instead:

```python
X = jnp.ones((4, 8))  # shape (batch, d), like X_good
jax.jit(batched_normalize)(X)  # each x in normalize_row has shape (8,)
```

This works because:
- Each `x` seen by `normalize_row` has shape `(8,)`, so `jnp.linalg.norm(x, axis=0)` returns the L2 norm of the row
- When vmapped over a `(4, 8)` array, it normalizes each of the 4 rows independently

For a 1-D `x`, `jnp.linalg.norm(x)` and `jnp.linalg.norm(x, axis=0)` return the same value, so removing the `axis` argument changes nothing for correctly shaped input. It only lets the `X_bad` call run without an error, silently producing ±1 for every element.
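One way to surface this class of bug at its source, sketched under the assumption that `batched_normalize` should only ever receive 2-D input, is to check `X.ndim` before mapping. Array shapes are static under `jit`, so an ordinary Python check runs at trace time. The error message wording is illustrative.

```python
import jax
import jax.numpy as jnp


def normalize_row(x):
    return x / jnp.linalg.norm(x)


def batched_normalize(X):
    # Shapes are static under jit, so this check runs at trace time and
    # reports the upstream rank collapse instead of a failure inside norm().
    if X.ndim != 2:
        raise ValueError(f"expected X with shape (batch, d), got {X.shape}")
    return jax.vmap(normalize_row)(X)


out = jax.jit(batched_normalize)(jnp.ones((4, 8)))
print(out.shape)  # (4, 8)
```

With this guard, `jax.jit(batched_normalize)(jnp.ones((4, 8)).sum(axis=1))` raises `ValueError: expected X with shape (batch, d), got (4,)` at the call site where the shape first goes wrong, rather than deep inside `jnp.linalg.norm`.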