# 🔪 Sharp Edges ⚠️

## Lesson Goals:

By the end of this lesson, you will be aware of certain limitations and gotchas to keep in mind as you use JAX. The lesson draws on several previously covered concepts, such as **mutability** and **jit**, and tries not only to highlight roadblocks but also to provide solutions.

For a thorough rundown, see [JAX: The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).

## Core Concepts:

- functional programming and purity
- `jax.debug.print`

In [None]:
%xmode Minimal

In [None]:
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp


# State + Impurity

- What happens when our code has side effects?

## Sharp Edge: Global State Modification via `print`

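When we print, we alter global state, which means the operation isn't "pure". Under `jit`, such side effects run only while JAX traces the function, not every time the compiled version executes.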

In [None]:
@jax.jit
def impure_print_side_effect(x):
    print("Executing function")  # This is a side effect
    return x

# The side effect appears during the first run, while JAX traces the function
print("Jitted First call: ", impure_print_side_effect(4.))

# Subsequent runs with arguments of the same type and shape do not show the side effect,
# because JAX now invokes a cached compilation of the function
print("Jitted Second call: ", impure_print_side_effect(5.))

# JAX re-runs (re-traces) the Python function when the type or shape of the argument changes
print("Jitted Third call, different type: ", impure_print_side_effect([4.]))

print("*" * 10)

## Solution: Use `jax.debug.print`
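`jax.debug.print` is staged into the compiled computation, so it runs on every call and formats the runtime value rather than an abstract tracer. Here is a sketch on a hypothetical `double` function, using a named format field:

```python
import jax

@jax.jit
def double(x):
    # Part of the compiled program: fires on every call,
    # and shows the concrete value of x rather than a tracer
    jax.debug.print("double got x = {x}", x=x)
    return 2.0 * x

double(3.0)  # prints the message
double(4.0)  # prints again, even though the compiled version is cached
```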

In [None]:
#################################################################################
@jax.jit
def pure_print_side_effect(x):
    # TODO: your function here
    raise NotImplementedError
    return x

# With `jax.debug.print`, the output should appear on the first run...
print("Jitted First call: ", pure_print_side_effect(4.))

# ...and also on subsequent runs, even though JAX invokes a cached compilation
print("Jitted Second call: ", pure_print_side_effect(5.))

# ...and when the type or shape of the argument changes
print("Jitted Third call, different type: ", pure_print_side_effect([4.]))


## Sharp Edge: In-place updating arrays

As with `print`, an in-place update is a side effect (a mutation). NumPy allows it, but JAX arrays are immutable, so JAX does not.

In [None]:
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)

## Solution: Use `at[...].set(X)`

In [None]:
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# In-place update of a JAX array raises a TypeError: JAX arrays are immutable!
jax_array[1, :] = 1.0


In [None]:
# `.at[...].set(...)` returns a NEW array; `jax_array` itself is unchanged
jax_array.at[1, :].set(1.0)
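Besides `set`, the `.at[...]` property supports other out-of-place updates such as `add` and `multiply`. The sketch below (variable names are illustrative) checks that each call returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

jax_array = jnp.zeros((3, 3), dtype=jnp.float32)

# Each .at[...] call returns a NEW array; nothing is modified in place
updated = jax_array.at[1, :].set(1.0)   # out-of-place x[1, :] = 1
added = updated.at[0, 0].add(5.0)       # out-of-place x[0, 0] += 5
scaled = added.at[:, 2].multiply(2.0)   # out-of-place x[:, 2] *= 2

print(jax_array.sum())  # 0.0: the original is unchanged
print(scaled)
```

According to the JAX documentation, inside `jit` the compiler can often perform such updates in place when the input array is not reused afterwards, so the functional style does not necessarily cost a copy.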

# Control Flow

## Sharp Edge: if-statement

In [None]:
def f(x):
    if x < 3:
        return 3. * x ** 3
    else:
        return -4 * x

print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!


In [None]:
jitted_f = jit(f)

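# Raises an error: under `jit`, `x` is an abstract tracer, so `x < 3` has no concrete True/False value
print(grad(jitted_f)(2.))
print(grad(jitted_f)(4.))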

## Solution: use `jax.lax.cond`
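`jax.lax.cond(pred, true_fun, false_fun, *operands)` traces both branches and selects one at runtime, so the predicate may be a traced value. Both branches must return outputs of the same shape and dtype. A sketch on a hypothetical `piecewise` function:

```python
from jax import grad, jit, lax

@jit
def piecewise(x):
    # lax.cond(pred, true_fun, false_fun, *operands)
    return lax.cond(
        x > 0.0,
        lambda v: v ** 2,    # branch taken when x > 0
        lambda v: -2.0 * v,  # branch taken otherwise
        x,
    )

print(grad(piecewise)(3.0))   # 6.0, the derivative of x**2 at 3
print(grad(piecewise)(-1.0))  # -2.0
```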

In [None]:
@jit
def jit_compat_f(x):
    # TODO: your function here that reimplements `f` but with `jax.lax.cond`
    raise NotImplementedError

print(grad(jit_compat_f)(2.))  # ok!
print(grad(jit_compat_f)(4.))  # ok!


## Sharp Edge: Tracing specific values

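JAX's `jit` has trouble with functions whose control flow depends on the *values* of their inputs. For example, the loop below runs `n` times, so tracing it requires knowing the concrete value of `n`: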

In [None]:
def run_for_n(x, n):
    accum = 0
    for i in range(n):
        accum = accum + x
    return accum

print(run_for_n(5, 3))

In [None]:
jitted_run_n = jit(run_for_n)

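# Raises an error: `n` is traced like any other argument, so `range(n)` cannot get a concrete integer
print(jitted_run_n(5, n=3))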

## Solution: static arguments

By marking an argument as static, JAX does not trace it at all: its concrete Python value is treated as a compile-time constant and becomes part of the compilation cache key. **However**, every new value of `n` forces the function to be recompiled.
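A sketch with a hypothetical `repeated_multiply` function and `trace_count` counter shows both halves of that trade-off: a Python `for` loop over a static argument traces fine, and only a new static value triggers a retrace. It uses `static_argnames`; `static_argnums` selects static arguments by position instead.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: increments only when JAX (re)traces

@partial(jax.jit, static_argnames="power")
def repeated_multiply(x, power):
    global trace_count
    trace_count += 1
    result = jnp.ones_like(x)
    for _ in range(power):  # fine: `power` is a concrete Python int here
        result = result * x
    return result

repeated_multiply(2.0, power=3)  # traces for power=3
repeated_multiply(5.0, power=3)  # cached: same static value
repeated_multiply(2.0, power=4)  # new static value: retraces
print(trace_count)  # 2
```

Static arguments must be hashable, since their values form part of the cache key.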

In [None]:
# TODO: a small change to `jitted_run_n = jit(run_for_n)` will make this run :) 

print(jitted_run_n(5, n=3))