# 🔪 Sharp Edges ⚠️

By the end of this lesson, you will be aware of certain limitations and gotchas to keep in mind as you use JAX. This lesson integrates many of the previously covered concepts, such as **mutability** and **jit** compilation.

For a thorough rundown, please see [JAX: The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).

In [None]:
%xmode Minimal

In [None]:
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp


# State + Impurity

- What happens when we have code that has side-effects?

## Global State Modification

In [None]:
@jit
def impure_print_side_effect(x):
    print("Executing function")  # This is a side-effect
    return x

# The side-effect appears during the first call, while JAX traces the function
print("Jitted First call: ", impure_print_side_effect(4.))

# Subsequent calls with arguments of the same dtype and shape do not show the side-effect,
# because JAX now invokes the cached compiled version of the function
print("Jitted Second call: ", impure_print_side_effect(5.))

# JAX re-traces the Python function when the dtype or shape of the argument changes
print("Jitted Third call, different shape: ", impure_print_side_effect(jnp.array([4.])))

print("*" * 10)

In [None]:
@jax.jit
def pure_print_side_effect(x):
    # jax.debug.print is staged into the compiled computation,
    # so it runs on every call, not only during tracing
    jax.debug.print("Executing function")
    return x

# The message is printed on the first call...
print("Jitted First call: ", pure_print_side_effect(4.))

# ...and again on later calls, even though the cached compilation is reused
print("Jitted Second call: ", pure_print_side_effect(5.))

# A new shape triggers a re-trace, and the message is still printed
print("Jitted Third call, different shape: ", pure_print_side_effect(jnp.array([4.])))


## In-place updating arrays

In [None]:
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# In-place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)

In [None]:
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# An in-place update of a JAX array raises a TypeError, because JAX arrays are immutable
jax_array[1, :] = 1.0


In [None]:
updated_array = jax_array.at[1, :].set(1.0)  # returns a new array; jax_array is unchanged
updated_array
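`set` is only one of the indexed-update methods available through `.at[...]`. The sketch below chains `set`, `add`, and `multiply`; each call returns a fresh array, and the original `x` is never modified. The variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.zeros((3, 3), dtype=jnp.float32)

# Direct assignment is rejected because JAX arrays are immutable
raised = False
try:
    x[1, :] = 1.0
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Each .at[...] method returns a NEW array; `x` itself is never modified
y = x.at[1, :].set(1.0)       # set row 1 to 1.0
z = y.at[:, 0].add(5.0)       # add 5.0 to every element of column 0
w = z.at[0, 0].multiply(2.0)  # double a single element

print(x.sum())  # 0.0: the original array is untouched
print(w)
```

Under `jit`, XLA can often perform these updates in place when the input buffer is no longer needed, so the functional style does not necessarily cost a full copy.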

# Control Flow

In [None]:
def f(x):
    if x < 3:
        return 3. * x ** 3
    else:
        return -4 * x

# grad alone works with Python control flow: the input still has a concrete value
print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!


In [None]:
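jitted_f = jit(f)

# This raises a ConcretizationTypeError: under jit, x is an abstract tracer
# with no concrete value, so the Python `if x < 3` cannot be evaluated
print(grad(jitted_f)(2.))
print(grad(jitted_f)(4.))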
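In `f` the condition depends on `x` itself, which is both traced and differentiated, so a structured primitive such as `lax.cond` is needed. When a branch instead depends on a plain Python value that never needs to be traced, one common workaround is to mark that argument static with `static_argnums`. The function `piecewise` and its `use_cubic` flag below are hypothetical, chosen to mirror the two branches of `f`.

```python
from functools import partial

import jax

# Hypothetical example: the branch depends on `use_cubic`, a plain Python bool,
# rather than on the traced (and differentiated) input x
@partial(jax.jit, static_argnums=1)
def piecewise(x, use_cubic):
    # use_cubic is static, so this is an ordinary Python `if` evaluated at trace time;
    # JAX compiles one version of the function per distinct value of use_cubic
    if use_cubic:
        return 3.0 * x ** 3
    return -4.0 * x

# grad differentiates with respect to argument 0 (x) only
d_cubic = jax.grad(piecewise)(2.0, True)    # d/dx 3x^3 = 9x^2 = 36.0 at x = 2
d_linear = jax.grad(piecewise)(4.0, False)  # d/dx -4x = -4.0
print(d_cubic, d_linear)
```

The trade-off is recompilation: every new value of a static argument triggers a fresh trace and compile, so this suits flags with few distinct values.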

In [None]:
@jit
def jit_compat_f(x):
    return jax.lax.cond(x < 3, lambda y: 3. * y ** 3, lambda y: -4 * y, x)

print(grad(jit_compat_f)(2.))  # ok!
print(grad(jit_compat_f)(4.))  # ok!
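Another jit-compatible option for a simple two-way choice like this is `jnp.where`, sketched below with a hypothetical `where_f`. Unlike `lax.cond`, it evaluates both branches and then selects elementwise, which is wasteful for expensive branches. A known caveat is that a NaN or inf produced in the unselected branch can still leak into the gradient.

```python
import jax
import jax.numpy as jnp

@jax.jit
def where_f(x):
    # Both branches are computed; jnp.where picks one per element
    return jnp.where(x < 3, 3.0 * x ** 3, -4.0 * x)

print(jax.grad(where_f)(2.0))  # 36.0
print(jax.grad(where_f)(4.0))  # -4.0

# Because it is elementwise, the same function maps cleanly over a batch of inputs
batched = jax.vmap(jax.grad(where_f))(jnp.array([2.0, 4.0]))
print(batched)
```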
