# When should we NOT jit?

[Why can’t we just JIT everything?](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html#why-can-t-we-just-jit-everything)

There are many situations where we shouldn't `jit`; we go through a few of them below.

## Scenario 1: Python loops

One such scenario came up earlier, where `run_simulation` steps through time with a Python `for` loop:

```python
def run_simulation(
    W: Mat,
    V: Vec,

    # Neuron Parameters
    tau_m: float,
    v_reset: float,
    v_thresh: float,
    membr_R: float,

    # How long do we run for? 
    t_max: float,
    dt: float, 

):
    ...
```

The reason is that `jax.jit` traces the function, and tracing a Python loop unrolls it: every iteration is recorded as a separate copy of the loop body, a concept we briefly touched upon earlier. Unrolling the loop:

1) generates a large program, and
2) takes a long time to compile,

which we'd like to avoid. We cover how to write such loops in a JIT-friendly way in the next notebook.
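To see the unrolling directly, the sketch below traces a hypothetical toy "leak" update (not the notebook's `run_simulation`) with `jax.make_jaxpr`, which records the program the same way `jax.jit` does, and counts the recorded operations. The names `leaky_steps` and `count_ops` are made up for this illustration.

```python
import jax
import jax.numpy as jnp

def leaky_steps(v, n_steps):
    # n_steps is a plain Python int, so this loop runs in Python while JAX traces
    for _ in range(n_steps):
        v = v - 0.1 * v  # hypothetical toy "leak" update, one per step
    return v

def count_ops(n_steps):
    # make_jaxpr traces the function the same way jax.jit does and returns the
    # recorded program; each entry in .jaxpr.eqns is one primitive operation
    closed = jax.make_jaxpr(lambda v: leaky_steps(v, n_steps))(jnp.ones(3))
    return len(closed.jaxpr.eqns)

for n in (1, 10, 100):
    print(n, count_ops(n))
```

The number of recorded operations grows linearly with `n_steps`, which is exactly the "large program" that `jit` would then have to compile.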

## Scenario 2: Boolean indexing

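Under `jit`, array shapes must be fixed at trace time. If an input's shape changes between calls, JAX has to trace and compile the function again; and if an output's shape depends on the input *values*, as it does with boolean indexing, JAX cannot trace the function at all and raises an error.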
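A small sketch of the recompilation cost: a Python side effect inside a jitted function only runs while JAX traces it, so a counter reveals when a new trace happens. The function `double` and the global `trace_count` are hypothetical, used only for this demo.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, only for this demo

@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces
    return 2 * x

double(jnp.ones(3))
double(jnp.ones(3))  # same shape and dtype: reuses the cached compilation
double(jnp.ones(4))  # new shape: JAX traces and compiles again
print(trace_count)   # 2
```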

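```python
import jax
import jax.numpy as jnp

to_filter = jnp.asarray([1, 2, 3, jnp.nan, 10, 20])

def is_nan_filter(to_filter_jnp):
    ########################################
    # Your code here
    ########################################
    # NaN never compares equal to anything (not even NaN), so `x != jnp.nan`
    # would keep every element; `jnp.isnan` is the correct test.
    return to_filter_jnp[~jnp.isnan(to_filter_jnp)]

jitted_func = jax.jit(is_nan_filter)

print(is_nan_filter(to_filter))  # works eagerly: the mask is a concrete array
try:
    # Fails under jit: the output length depends on the values in the mask
    print(jitted_func(to_filter))
except Exception as e:
    print(e)


def is_nan_jit_compatible(to_filter_jnp, replace_with):
    ########################################
    # Your code here
    ########################################
    # Keep the output shape fixed: replace NaNs instead of dropping them
    nan_mask = ~jnp.isnan(to_filter_jnp)
    return jnp.where(nan_mask, to_filter_jnp, replace_with)

jitted_func = jax.jit(is_nan_jit_compatible)
print(jitted_func(to_filter, 0))
```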
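Replacing NaNs with `jnp.where` keeps the shape fixed. If the surviving values are actually needed packed at the front, one option, sketched here under the assumption that an upper bound on the output length is known in advance, is to fix the length with the `size` argument of `jnp.nonzero` and return the number of valid entries alongside the padded result. `drop_nans_padded` and `pad_value` are hypothetical names.

```python
from functools import partial

import jax
import jax.numpy as jnp

# `size` must be static because it determines the output shape
@partial(jax.jit, static_argnames="size")
def drop_nans_padded(x, size, pad_value=0.0):
    keep = ~jnp.isnan(x)
    # Indices of kept elements, padded with index 0 up to a fixed length
    (idx,) = jnp.nonzero(keep, size=size, fill_value=0)
    n_valid = keep.sum()
    # Overwrite the padded slots so they don't repeat x[0]
    values = jnp.where(jnp.arange(size) < n_valid, x[idx], pad_value)
    return values, n_valid

values, n_valid = drop_nans_padded(jnp.asarray([1, 2, 3, jnp.nan, 10, 20]), size=6)
print(values, n_valid)
```

The trade-off is that callers must carry `n_valid` around to know which entries are real.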

## Scenario 3: Conditional looping

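This example is taken directly from the "Why can’t we just JIT everything?" page linked above. The `while` condition `i < n` needs a concrete Python `bool`, but under `jit` the argument `n` is an abstract tracer whose value is unknown at trace time, so JAX raises an error.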

```python
import jax

def g(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

g_jit = jax.jit(g)

try:
    g_jit(10, 20)  # Should raise an error: `i < n` can't be turned into a bool
except Exception as e:
    print(e)
```
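Two common ways around this, sketched below: mark `n` as static so its value is known at trace time (JAX then compiles a separate version for every distinct `n`), or express the loop with `jax.lax.while_loop` so `n` can stay traced. The helper `g_while` is a hypothetical name for this illustration.

```python
import jax
import jax.numpy as jnp

def g(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

# Option 1: mark n as static. The Python loop then runs at trace time with a
# concrete n, but JAX compiles a new version for every distinct value of n.
g_jit_static = jax.jit(g, static_argnames=["n"])
print(g_jit_static(10, 20))  # 30

# Option 2 (hypothetical helper g_while): express the loop with
# jax.lax.while_loop so that n can stay a traced value.
def g_while(x, n):
    i = jax.lax.while_loop(lambda i: i < n, lambda i: i + 1, jnp.int32(0))
    return x + i

print(jax.jit(g_while)(10, 20))  # 30
```

Option 1 is simplest when `n` takes only a few values; option 2 avoids recompiling when `n` changes from call to call.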
