# LOOPS

In this notebook we will learn the different ways to implement loops (with fixed size or dynamic size) using JAX primitives.

## Beginner
### Prerequisites
None (though the Beginner if-else notebook is a helpful warm-up).

### Imports

In [26]:
from jax import make_jaxpr
from jax.lax import while_loop, fori_loop, scan
import jax.numpy as jnp
from jax.ops import index_update

import numpy as np

### Example

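Our running example is the cumulative sum of the values in an array of floats. In plain Python/NumPy, which we will then port to JAX, it reads:
```python
def my_cumsum(xs):
    res = np.zeros_like(xs)
    running = 0.  # explicit running total instead of relying on res[-1] being 0 at i = 0
    for i, x in enumerate(xs):
        running = running + x
        res[i] = running
    return res
```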

We will use the following array for our tests:

In [38]:
arr = np.arange(0., 10.)
print(arr)

[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]


#### FORI_LOOP

The most natural syntax would seem to be `fori_loop`. Let's first look at what simply summing the array looks like in plain Python:

In [39]:
def naive_sum(xs):
    n = xs.shape[0]
    val = 0.
    for i in range(n):
        val = val + xs[i]
    return val

print(naive_sum(arr))

45.0


In [40]:
def naive_fori_loop_sum(xs):
    xs = jnp.asarray(xs)  # a raw NumPy array cannot be indexed by the traced index i
    n = xs.shape[0]
    def body(i, val):
        # i is the iteration step
        # val is the running value
        val = val + xs[i]
        return val
    
    res = fori_loop(0,  # starting index (inclusive)
                    n,  # upper bound (exclusive): the last index is n-1
                    body,  # the function iterated during the loop: res = body(n-1, body(n-2, body(n-3,...)))
                    init_val=0.  # the initial value for the loop
                   )
    return res

print(naive_fori_loop_sum(arr))

45.0
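
For reference, the semantics of `fori_loop` can be sketched as an ordinary Python loop. The helper name `fori_loop_reference` below is hypothetical (it is not part of JAX); it is only meant to illustrate how the carried value is threaded through `body`, under the assumption that the bounds are plain Python integers.

```python
import jax.numpy as jnp
from jax.lax import fori_loop


def fori_loop_reference(lower, upper, body, init_val):
    # Hypothetical pure-Python stand-in for jax.lax.fori_loop:
    # body is applied once per index in [lower, upper), threading the carry.
    val = init_val
    for i in range(lower, upper):
        val = body(i, val)
    return val


xs = jnp.arange(0., 10.)


def body(i, val):
    return val + xs[i]


init = jnp.zeros((), xs.dtype)  # explicit dtype so the carry type never changes
expected = fori_loop_reference(0, xs.shape[0], body, init)
actual = fori_loop(0, xs.shape[0], body, init)
print(expected, actual)
```

Both calls should agree; the difference is that `fori_loop` traces `body` once instead of unrolling a Python loop.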


Now what if we actually wanted to return the cumulative sum array?

In [41]:
def naive_cumsum(xs):
    n = xs.shape[0]
    res = np.zeros_like(xs)
    val = 0.
    for i in range(n):
        val = val + xs[i]
        res[i] = val
    return res

print(naive_cumsum(arr))

[ 0.  1.  3.  6. 10. 15. 21. 28. 36. 45.]


A naive translation to `fori_loop` would look like this:

In [42]:
def naive_fori_loop_cumsum(xs):
    n = xs.shape[0]
    res = jnp.zeros_like(xs)
    def body(i, val):
        # i is the iteration step
        # val is the running value
        val = val + xs[i]
        res[i] = val
        return val
    
    _ = fori_loop(0,  # starting index
                  n,  # total number of iterations: the last index is n-1
                  body,  # the function iterated during the loop: res = body(n-1, body(n-2, body(n-3,...)))
                  init_val=0.  # the initial value for the loop
                 )
    return res

print(naive_fori_loop_cumsum(arr))

Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>.

This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.

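This does not work. The traceback above is raised as soon as the NumPy array `xs` is indexed with the traced index `i`; converting `xs` with `jnp.asarray` avoids that, but the deeper problem is the in-place assignment: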
```python
res[i] = val
```

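JAX arrays are immutable, and JAX transformations (tracing, `jit`, `grad`) expect pure functions, that is, functions without side effects such as in-place updates. To write into an array inside the loop we therefore need a JAX-specific functional update that returns a new array: `index_update` (removed from later JAX releases in favour of the equivalent `x.at[idx].set(y)` syntax).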

In [43]:
def not_so_naive_fori_loop_cumsum(xs):
    xs = jnp.asarray(xs)
    n = xs.shape[0]
    res = jnp.zeros_like(xs)
    def body(i, carry):
        # i is the iteration step
        # carry is the running value
        cumsum, val = carry
        val = val + xs[i]
        cumsum = index_update(cumsum, i, val)
        return cumsum, val
    
    res, val = fori_loop(0, n, body, init_val=(jnp.zeros_like(xs), 0.))
    return res

print(not_so_naive_fori_loop_cumsum(arr))

[ 0.  1.  3.  6. 10. 15. 21. 28. 36. 45.]
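
In JAX releases where `jax.ops.index_update` is no longer available, the same functional update can be written with the `.at[...]` property of JAX arrays. The sketch below is a hypothetical variant of the function above (the name `at_set_fori_loop_cumsum` is made up for this example) and assumes a 1-D floating-point input.

```python
import jax.numpy as jnp
from jax.lax import fori_loop


def at_set_fori_loop_cumsum(xs):
    # Hypothetical variant of not_so_naive_fori_loop_cumsum using .at[...].set
    xs = jnp.asarray(xs)
    n = xs.shape[0]

    def body(i, carry):
        cumsum, val = carry
        val = val + xs[i]
        # .at[i].set(val) returns a new array; cumsum itself is never mutated
        cumsum = cumsum.at[i].set(val)
        return cumsum, val

    # explicit dtypes keep the carry structure identical across iterations
    init = (jnp.zeros_like(xs), jnp.zeros((), xs.dtype))
    cumsum, _ = fori_loop(0, n, body, init)
    return cumsum


result = at_set_fori_loop_cumsum(jnp.arange(0., 10.))
print(result)
```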


This is becoming awfully complicated for a simple cumulative sum. Thankfully, JAX has a `scan` operator that solves exactly this problem:

In [44]:
def naive_scan_cumsum(xs):
    def body(val, x):
        val = val + x
        return val, val  # the first output is the carry, the second one is the recorded value at the current index
    
    _final_val, cumsum = scan(body,  # function with signature (carry, input)-> (new_carry, res[i])
                              0.,  # the initial value
                              xs  # the list of inputs x to the body function
                             )
    return cumsum  # in this example we just don't need the final value

print(naive_scan_cumsum(arr))

[ 0.  1.  3.  6. 10. 15. 21. 28. 36. 45.]
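
To make the `(carry, input) -> (new_carry, output)` contract concrete, here is a rough pure-Python sketch of what `scan` computes for a 1-D input. The helper `scan_reference` is hypothetical and ignores everything `scan` does beyond this simple case (nested carries, `length`, `reverse`, and so on).

```python
import jax.numpy as jnp
from jax.lax import scan


def scan_reference(f, init, xs):
    # Hypothetical pure-Python stand-in for jax.lax.scan on a 1-D array:
    # thread the carry through f and stack the per-step outputs.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)


def body(val, x):
    val = val + x
    return val, val  # (new carry, value recorded at this step)


xs = jnp.arange(0., 10.)
init = jnp.zeros((), xs.dtype)
ref_final, ref_cumsum = scan_reference(body, init, xs)
final, cumsum = scan(body, init, xs)
print(ref_cumsum)
print(cumsum)
```

The final carry is the total sum, and the stacked outputs form the cumulative sum.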


Finally, imagine we want to sum the values until the running total would exceed a threshold:
```python
def sum_to_threshold(xs, thresh):
    res = 0.
    i = 0
    while i < len(xs):  # stop at the end of the array if the threshold is never exceeded
        new_res = xs[i] + res
        if new_res > thresh:
            break
        res = new_res
        i += 1
    return res
```

then we would use a while loop:

In [66]:
def while_loop_sum_to_threshold(xs, thresh):
    xs = jnp.asarray(xs)
    n = xs.shape[0]
    
    def cond(carry):
        i, val = carry
        return jnp.logical_and((i < n), 
                               (val + xs[i] <= thresh))  # if true we continue (same rule as the Python version)
    
    
    def body(carry):        
        i, val = carry
        return i + 1, val + xs[i]
    
    _, val = while_loop(cond, 
                        body,
                        init_val=(0, 0.)
                        )
    return val

print(while_loop_sum_to_threshold(arr, 11.))

10.0
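
Because the number of iterations depends on the data, `while_loop` is the primitive that still works when the threshold is only known at run time. As a hedged sketch (the name `jitted_sum_to_threshold` is hypothetical), the same loop can be compiled once with `jit` and reused for different thresholds. It continues while the running sum stays at or below the threshold, which is the stopping rule of the pure-Python reference.

```python
import jax
import jax.numpy as jnp
from jax.lax import while_loop


@jax.jit
def jitted_sum_to_threshold(xs, thresh):
    n = xs.shape[0]  # static under jit: shapes are known at trace time

    def cond(carry):
        i, val = carry
        # when i == n the lookup xs[i] is clamped to the last element,
        # and the (i < n) test already stops the loop
        return jnp.logical_and(i < n, val + xs[i] <= thresh)

    def body(carry):
        i, val = carry
        return i + 1, val + xs[i]

    init = (jnp.int32(0), jnp.zeros((), xs.dtype))
    _, val = while_loop(cond, body, init)
    return val


xs = jnp.arange(0., 10.)
low = jitted_sum_to_threshold(xs, 11.)    # stops before the sum would reach 15
high = jitted_sum_to_threshold(xs, 100.)  # threshold never reached: full sum
print(low, high)
```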


### Questions:

#### Q1: 
Using `scan` with a fixed number of iterations, implement Newton's square root algorithm, which in pure Python is given by:
```python
def sqrt(x, x0, N):
    y = x0
    for _ in range(N):
        y = 0.5 * (y + x / y)
    return y
```

#### Q2:
Using `while_loop` implement the searchsorted function:
```python
def searchsorted(x, arr):
    i = 0
    while i < len(arr):
        if arr[i] >= x:
            return i
        i += 1
    return i
            
```

## Intermediate
### Prerequisites
- Beginner if-else  
- Beginner vectorisation
- Numpy  

### Imports

In [29]:
from jax import vmap
from jax.lax import cond
import jax.numpy as jax_np

import numpy as np

Now that we know how to compute if-else predicates for scalar inputs, how do we extend this to tensor inputs?

### Example

JAX already implements the `abs` function, so in practice you shouldn't have to care about this. But how would we replicate the result using only its high-level primitives?
```python
def my_abs(x):
    return jax_np.abs(x)
```
We will use the following array as an input:

In [39]:
arr = np.array([-1., 0., 1.])

Let's try to use the code we used for the Beginner level and see what happens:

In [40]:
def jax_abs(x):
    predicate = x > 0
    
    def true_fun(z):
        return z
    
    def false_fun(z):
        return -z
    
    result = cond(
        predicate,  # predicate for the if
        true_fun, # function to call on operand if predicate is true
        false_fun,  # function to call on operand otherwise
        operand=x  # operand to be passed to either true_fun or false_fun
                 )
    
    return result

In [41]:
print(jax_abs(arr))

TypeError: Pred must be a scalar, got [False False  True] of shape (3,).

So we simply can't use the same trick: `cond` requires a scalar predicate. We are left with two choices: we can either use the NumPy-style API of `jax.numpy` (just not `abs` itself, for the sake of the exercise) or try to be smart.


In [42]:
def jax_numpy_abs(x):
    return jax_np.where(x>0, x, -x)

print(jax_numpy_abs(arr))

[ 1. -0.  1.]
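
The `-0.` in the output comes from the strict `x > 0` test: an input of `0.` fails the test and is routed through the `-x` branch. A minimal sketch of the difference, assuming a non-strict comparison is acceptable for your use case:

```python
import jax.numpy as jnp

x = jnp.array([-1., 0., 1.])

strict = jnp.where(x > 0, x, -x)       # 0. takes the -x branch and becomes -0.
non_strict = jnp.where(x >= 0, x, -x)  # 0. takes the x branch and stays 0.

# signbit distinguishes -0. from 0., which compare equal otherwise
print(jnp.signbit(strict), jnp.signbit(non_strict))
```

Note that `where` evaluates both `x` and `-x` for every element and then selects between them, which is a different model from `cond`, where only one branch is executed.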


### Let's try to be smart
We can reuse the scalar `jax_abs` with the primitives we learned about in the vectorisation notebook:

In [53]:
vmap_abs = vmap(jax_abs)
print(vmap_abs(arr))

[ 1. -0.  1.]


In [54]:
vectorized_abs = jax_np.vectorize(jax_abs, signature=("()->()"))
print(vectorized_abs(arr))

[ 1. -0.  1.]


### Questions:

#### Q1: 
Compare the three different implementations (you can use the utility make_jaxpr to see the code generated), what do you think is happening in the background?

#### Q2:
Implement the vectorised `piecewise_constant` function using the three different techniques.

## Advanced
### Prerequisites
- Intermediate if-else  

### Imports

In [110]:
from jax import vmap, make_jaxpr, jit
from jax.lax import cond, switch
import jax.numpy as jax_np

import numpy as np

Now that we know how to compute if-else predicates for vectorized inputs, what happens when we have more than one condition?

The goal here is to implement an extension of the piecewise constant function:

In [85]:
def piecewise_constant(x, xs, ys):
    # len(ys) == len(xs) + 1
    # xs are considered to be sorted
    if x < xs[0]:
        return ys[0]
    for xi, yi in zip(xs, ys[:-1]):
        if x >= xi:
            continue
        return yi
    return ys[-1]

For which we will consider the following inputs

In [86]:
arr_xs = np.array([-1., 0., 1.])
arr_ys = np.array([0.2, 0.4, 0., 0.1])

arr_x = np.array([0.5, -2., 3., -1.5, -0.4])
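
As a sanity check before porting to JAX, the pure-Python reference can be evaluated on these inputs. The block below repeats the function so that it is self-contained; the list name `reference_values` is only introduced for this example.

```python
import numpy as np


def piecewise_constant(x, xs, ys):
    # same logic as the reference implementation above
    if x < xs[0]:
        return ys[0]
    for xi, yi in zip(xs, ys[:-1]):
        if x >= xi:
            continue
        return yi
    return ys[-1]


arr_xs = np.array([-1., 0., 1.])
arr_ys = np.array([0.2, 0.4, 0., 0.1])
arr_x = np.array([0.5, -2., 3., -1.5, -0.4])

# one value per query point: ys[k] where k counts the breakpoints <= x
reference_values = [float(piecewise_constant(x, arr_xs, arr_ys)) for x in arr_x]
print(reference_values)  # [0.0, 0.2, 0.1, 0.2, 0.4]
```

Any JAX implementation should reproduce these values.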

### Example

First things first, let's look at the `switch` function:

In [87]:
help(switch)

Help on function switch in module jax._src.lax.control_flow:

switch(index, branches: Sequence[Callable], operand)
    Apply exactly one of ``branches`` given by ``index``.
    
    If ``index`` is out of bounds, it is clamped to within bounds.
    
    Has the semantics of the following Python::
    
      def switch(index, branches, operand):
        index = clamp(0, index, len(branches) - 1)
        return branches[index](operand)
    
    Arguments:
      index: Integer scalar type, indicating which branch function to apply.
      branches: Sequence of functions (A -> B) to be applied based on `index`.
      operand: Operand (A) input to whichever branch is applied.



So we can rewrite our absolute value function in terms of the `switch` function:

In [107]:
def jax_abs_switch(x):
    branches = [lambda z: -z, lambda z: z]
    index = jax_np.asarray(x > 0, jax_np.int32)
    return switch(index, branches, x)

print(jax_abs_switch(5.))

5.0


So now let's have a look and see what we really did:

In [108]:
make_jaxpr(jax_abs_switch)(-5.)

{ lambda  ; a.
  let b = gt a 0.0
      c = convert_element_type[ new_dtype=int32
                                old_dtype=bool ] b
      d = clamp 0 c 1
      e = cond[ branches=( { lambda  ; a.
                             let b = neg a
                             in (b,) }
                           { lambda  ; a.
                             let 
                             in (a,) } )
                linear=(False,) ] d a
  in (e,) }

In [109]:
make_jaxpr(compact_jax_abs)(-5.)

{ lambda  ; a.
  let b = gt a 0.0
      c = convert_element_type[ new_dtype=int32
                                old_dtype=bool ] b
      d = cond[ branches=( { lambda  ; a.
                             let b = neg a
                             in (b,) }
                           { lambda  ; a.
                             let 
                             in (a,) } )
                linear=(False,) ] c a
  in (d,) }

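Well, we did nothing new! Apart from the extra `clamp` on the index, the two jaxprs are identical: `cond` is just a special case of `switch` with two branches. In particular, stacking `cond` calls in the hope of better performance than a single `switch` would be a very bad idea.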
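
The comparison can also be made programmatically by listing the primitives that appear in each jaxpr. This is a hedged sketch: `abs_with_cond` and `abs_with_switch` are hypothetical stand-ins written for this comparison, and the exact list of surrounding primitives may differ between JAX versions.

```python
import jax.numpy as jnp
from jax import make_jaxpr
from jax.lax import cond, switch


def abs_with_cond(x):
    # hypothetical cond-based implementation, written for this comparison
    return cond(x > 0, lambda z: z, lambda z: -z, x)


def abs_with_switch(x):
    index = jnp.asarray(x > 0, jnp.int32)
    return switch(index, [lambda z: -z, lambda z: z], x)


def primitive_names(fun, *args):
    # names of the top-level primitives in the traced program
    return [eqn.primitive.name for eqn in make_jaxpr(fun)(*args).jaxpr.eqns]


cond_prims = primitive_names(abs_with_cond, -5.)
switch_prims = primitive_names(abs_with_switch, -5.)
print(cond_prims)
print(switch_prims)
```

Both lists should contain the same `cond` primitive, which is the point made above.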

### Questions:

#### Q1: 
Implement the generalized `piecewise_constant` using `switch`, assuming the data is sorted.

#### Q2:
If you are already familiar with the loop primitives, implement the generalized `piecewise_constant` using loops and `cond`, and compare the generated code.

#### Q3:
How would you vectorize this function? Compare the naive vmap with the `jax_np.select` approach.