# LOOPS

In this notebook we will learn the different ways to implement loops (with a fixed or a dynamic number of iterations) using JAX primitives.

## JAX imports

## Beginner
### Prerequisites
No prerequisites (although the Beginner if-else notebook helps).

### Imports

In [17]:
import jax.numpy as jnp
import numpy as np
from jax import make_jaxpr
from jax.lax import fori_loop, scan, while_loop

### Example

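As a running example, here is how one would compute the cumulative sum of the values in an array of floats in plain NumPy (we will then see how to write it with JAX):
```python
import numpy as np

def my_cumsum(xs):
    res = np.zeros_like(xs)
    for i, x in enumerate(xs):
        # at i = 0, res[i - 1] is res[-1], which is still 0 at that point
        res[i] = x + res[i - 1]
    return res
```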

We will use the following array for our tests:

In [2]:
arr = np.arange(0.0, 10.0)
print(arr)

[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]


#### FORI_LOOP

The most natural construct would seem to be `fori_loop`. Let's first have a look at what simply summing the array looks like with a plain Python loop:

In [3]:
def naive_sum(xs):
    n = xs.shape[0]
    val = 0.0
    for i in jnp.arange(n):
        val = val + xs[i]
    return val


print(naive_sum(arr))



45.0


Can we use this directly on a jax array then?

In [4]:
naive_sum(jnp.asarray(arr))

Array(45., dtype=float32)

Well, yes! So why all the fuss? Surely we can just not care, right? Let's write down a (bad) equivalent in JAX and check it out:

In [5]:
def naive_fori_loop_sum(xs):
    xs = jnp.asarray(xs)
    n = xs.shape[0]

    def body(i, val):
        # i is the iteration step
        # val is the running value
        val = val + xs[i]
        return val

    res = fori_loop(
        0,  # starting index
        n,  # total number of iterations: the last index is n-1
        body,  # the function iterated during the loop: res = body(n-1, body(n-2, body(n-3,...)))
        init_val=0.0,  # the initial value for the loop
    )
    return res


print(naive_fori_loop_sum(arr))

45.0
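Semantically, `fori_loop(lower, upper, body_fun, init_val)` behaves like the plain Python loop below, which mirrors the reference semantics given in the `jax.lax.fori_loop` docstring. `python_fori_loop` is a hypothetical name used only for illustration: the real primitive traces `body_fun` once and compiles the loop instead of running it step by step in Python.

```python
import jax.numpy as jnp
from jax.lax import fori_loop


def python_fori_loop(lower, upper, body_fun, init_val):
    # Hypothetical pure-Python reference with the same signature as lax.fori_loop
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


# a JAX array is needed: inside fori_loop, i is a tracer and cannot index a NumPy array
xs = jnp.arange(0.0, 10.0)


def body(i, val):
    # add the i-th element to the running total
    return val + xs[i]


print(python_fori_loop(0, xs.shape[0], body, 0.0))  # 45.0
print(fori_loop(0, xs.shape[0], body, 0.0))  # 45.0
```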


In [6]:
make_jaxpr(naive_sum)(jnp.asarray(arr))

{ lambda ; a:f32[10]. let
    b:i32[10] = iota[dimension=0 dtype=int32 shape=(10,) sharding=None] 
    c:i32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] b
    d:i32[] = squeeze[dimensions=(0,)] c
    e:bool[] = lt d 0
    f:i32[] = add d 10
    g:i32[] = select_n e d f
    h:f32[1] = dynamic_slice[slice_sizes=(1,)] a g
    i:f32[] = squeeze[dimensions=(0,)] h
    j:f32[] = add 0.0 i
    k:i32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] b
    l:i32[] = squeeze[dimensions=(0,)] k
    m:bool[] = lt l 0
    n:i32[] = add l 10
    o:i32[] = select_n m l n
    p:f32[1] = dynamic_slice[slice_sizes=(1,)] a o
    q:f32[] = squeeze[dimensions=(0,)] p
    r:f32[] = add j q
    s:i32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=(1,)] b
    t:i32[] = squeeze[dimensions=(0,)] s
    u:bool[] = lt t 0
    v:i32[] = add t 10
    w:i32[] = select_n u t v
    x:f32[1] = dynamic_slice[slice_sizes=(1,)] a w
    y:f32[] = squeeze[dimensions=(0,)] x
    z:f32

In [7]:
make_jaxpr(naive_fori_loop_sum)(jnp.asarray(arr))

{ lambda ; a:f32[10]. let
    _:i32[] b:f32[] = scan[
      _split_transpose=False
      jaxpr={ lambda ; c:f32[10] d:i32[] e:f32[]. let
          f:i32[] = add d 1
          g:bool[] = lt d 0
          h:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
          i:i32[] = add h 10
          j:i32[] = select_n g d i
          k:f32[1] = dynamic_slice[slice_sizes=(1,)] c j
          l:f32[] = squeeze[dimensions=(0,)] k
          m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
          n:f32[] = add m l
        in (f, n) }
      length=10
      linear=(False, False, False)
      num_carry=2
      num_consts=1
      reverse=False
      unroll=1
    ] a 0 0.0
  in (b,) }

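So basically, the difference is the number of lines of code generated! This matters for compilation time: the Python loop is unrolled when the function is traced, so the generated program grows linearly with the size of the array, whereas `fori_loop` is traced only once and emits a single loop primitive (here a `scan`, because the bounds are known at trace time), whatever the size of the array.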
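To make this concrete, the sketch below counts the top-level equations of the jaxpr produced by each approach for two array sizes. The helper names (`python_loop_sum`, `fori_loop_sum`, `n_equations`) are illustrative, and the exact counts depend on the JAX version; only the trend matters here: the unrolled loop grows with `n`, the `fori_loop` version does not.

```python
import jax.numpy as jnp
from jax import make_jaxpr
from jax.lax import fori_loop


def python_loop_sum(xs):
    # The Python loop is unrolled while tracing: one group of primitives per element
    val = 0.0
    for i in range(xs.shape[0]):
        val = val + xs[i]
    return val


def fori_loop_sum(xs):
    # fori_loop traces its body once, so it appears a single time in the jaxpr
    return fori_loop(0, xs.shape[0], lambda i, val: val + xs[i], 0.0)


def n_equations(f, n):
    # Number of top-level equations in the jaxpr traced for an input of size n
    return len(make_jaxpr(f)(jnp.ones(n)).jaxpr.eqns)


for n in (10, 100):
    print(n, n_equations(python_loop_sum, n), n_equations(fori_loop_sum, n))

# With static (Python int) bounds, fori_loop is lowered to a scan primitive
print("scan" in str(make_jaxpr(fori_loop_sum)(jnp.ones(10))))
```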

Now what if we actually wanted to return the cumulative sum array?

In [8]:
def naive_cumsum(xs):
    n = xs.shape[0]
    res = np.zeros_like(xs)
    val = 0.0
    for i in range(n):
        val = val + xs[i]
        res[i] = val
    return res


print(naive_cumsum(arr))

[ 0.  1.  3.  6. 10. 15. 21. 28. 36. 45.]


A direct translation using `fori_loop` would be:

In [9]:
def naive_fori_loop_cumsum(xs):
    xs = jnp.asarray(xs)
    n = xs.shape[0]
    res = jnp.zeros_like(xs)

    def body(i, val):
        # i is the iteration step
        # val is the running value
        val = val + xs[i]
        res[i] = val
        return val

    _ = fori_loop(
        0,  # starting index
        n,  # total number of iterations: the last index is n-1
        body,  # the function iterated during the loop: res = body(n-1, body(n-2, body(n-3,...)))
        init_val=0.0,  # the initial value for the loop
    )
    return res


print(naive_fori_loop_cumsum(arr))

TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

Well, we actually can't do that. The problem is the line
```python
res[i] = val
```

JAX arrays are immutable: to trace, differentiate and compile programs, JAX works with pure functions, that is, functions that do not update their inputs in place. We therefore need a JAX-specific way of saying that we want to update the array `res` at index `i` with the value `val`. This is done using the syntax:
```python
res = res.at[i].set(val)
```
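A quick illustration of these functional update semantics (`.at[...].set` and `.at[...].add` belong to the `jax.numpy` array API): the original array is left unchanged and a new array is returned.

```python
import jax.numpy as jnp

res = jnp.zeros(5)
# .at[i].set(v) returns a *new* array with index i replaced by v
new_res = res.at[2].set(7.0)

print(res)  # [0. 0. 0. 0. 0.]  (the original array is untouched)
print(new_res)  # [0. 0. 7. 0. 0.]

# other update methods follow the same pattern, e.g. .add
print(new_res.at[2].add(1.0))  # [0. 0. 8. 0. 0.]
```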
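While it may seem like a copy of the array is made at each iteration, inside a compiled loop such as `fori_loop` XLA can usually perform the update in place, since the previous version of the array is no longer needed. Note that the array now has to be part of the loop carry: the body returns the updated array instead of modifying a variable from the enclosing scope. Let's see what this looks like in code: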

In [15]:
def not_so_naive_fori_loop_cumsum(xs):
    xs = jnp.asarray(xs)
    n = xs.shape[0]

    def body(i, carry):
        # i is the iteration step
        # carry is the running value
        cumsum, val = carry
        val = val + xs[i]
        cumsum = cumsum.at[i].set(val)
        return cumsum, val

    res, val = fori_loop(0, n, body, init_val=(jnp.zeros_like(xs), 0.0))
    return res


print(not_so_naive_fori_loop_cumsum(arr))

[ 0.  1.  3.  6. 10. 15. 21. 28. 36. 45.]


This is getting awfully complicated for a simple cumulative sum. Thankfully, JAX has a `scan` operator that solves exactly this problem and accumulates the values for you:

In [13]:
def naive_scan_cumsum(xs):
    def body(val, x):
        val = val + x
        return (
            val,
            val,
        )  # the first output is the carry, the second one is the recorded value at the current index

    _final_val, cumsum = scan(
        body,  # function with signature (carry, input)-> (new_carry, res[i])
        0.0,  # the initial value
        xs,  # the list of inputs x to the body function
    )
    return cumsum  # in this example we just don't need the final value


print(naive_scan_cumsum(arr))

[ 0.  1.  3.  6. 10. 15. 21. 28. 36. 45.]
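To see what `scan` does, here is a pure-Python reference of its semantics, adapted from the one given in the `jax.lax.scan` docstring (restricted here to a single array input). `python_scan` is a hypothetical name for illustration; the real `scan` traces the body once and compiles the whole loop.

```python
import numpy as np
from jax.lax import scan


def python_scan(f, init, xs):
    # Hypothetical pure-Python reference for lax.scan
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)  # f: (carry, x) -> (new_carry, y)
        ys.append(y)
    return carry, np.stack(ys)  # the per-step outputs are stacked along axis 0


def body(val, x):
    val = val + x
    return val, val  # new carry, recorded value


xs = np.arange(0.0, 10.0)
final_ref, cumsum_ref = python_scan(body, 0.0, xs)
final_jax, cumsum_jax = scan(body, 0.0, xs)
print(np.allclose(cumsum_ref, cumsum_jax))  # True
```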


Finally, imagine we want to sum the values of an array until the running sum would reach a threshold:
```python
def sum_to_threshold(xs, thresh):
    res = 0.
    i = 0
    while i < len(xs):  # stop at the end of the array
        new_res = xs[i] + res
        if new_res >= thresh:
            break
        res = new_res
        i += 1
    return res
```

For this we would use a `while_loop`:

In [18]:
def while_loop_sum_to_threshold(xs, thresh):
    xs = jnp.asarray(xs)
    n = xs.shape[0]

    def cond(carry):
        i, val = carry
        return jnp.logical_and((i < n), (val + xs[i] < thresh))  # if true we continue

    def body(carry):
        i, val = carry
        return i + 1, val + xs[i]

    _, val = while_loop(cond, body, init_val=(0, 0.0))
    return val


print(while_loop_sum_to_threshold(arr, 11.0))

10.0
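Likewise, `while_loop(cond_fun, body_fun, init_val)` has the semantics of the plain Python loop below, which mirrors the reference given in the `jax.lax.while_loop` docstring (`python_while_loop` is a hypothetical name for illustration). Running both on the threshold example gives the same result:

```python
import jax.numpy as jnp
from jax.lax import while_loop


def python_while_loop(cond_fun, body_fun, init_val):
    # Hypothetical pure-Python reference for lax.while_loop
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


xs = jnp.arange(0.0, 10.0)
thresh = 11.0


def cond(carry):
    i, val = carry
    # continue while elements remain and adding the next one stays below thresh
    return jnp.logical_and(i < xs.shape[0], val + xs[i] < thresh)


def body(carry):
    i, val = carry
    return i + 1, val + xs[i]


print(python_while_loop(cond, body, (0, 0.0))[1])  # 10.0
print(while_loop(cond, body, (0, 0.0))[1])  # 10.0
```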


### Questions:

#### Q1: 
Using `scan` with a fixed number of iterations, implement the Newton square-root algorithm, given in pure Python by:
```python
def sqrt(x, x0, N):
    y = x0
    for _ in range(N):
        y = 0.5 * (y + x / y)
    return y
```

#### Q2:
Using `while_loop`, implement the searchsorted function:
```python
def searchsorted(x, arr):
    i = 0
    while i < len(arr):
        if arr[i] >= x:
            return i
        i += 1
    return i
```

## Intermediate
### Prerequisites
- Beginner loops
- Beginner automatic differentiation

### Imports

In [18]:
import jax.numpy as jnp
from jax import make_jaxpr, grad, jvp
from jax.lax import scan, while_loop

Now that we know how to write loops, how do they behave with respect to gradients?

### Example

To understand the respective behaviour of `while_loop` and `scan` (we leave `fori_loop` aside: when its bounds are known at trace time it is lowered to a `scan`, as we saw in the Beginner section, and otherwise to a `while_loop`), we will consider the Newton square-root toy example from the Beginner question Q1:

In [19]:
def while_loop_sqrt(x, x0=1.0, n_iter=10):
    def cond(carry):
        i, _val = carry
        return i < n_iter

    def body(carry):
        i, val = carry
        return i + 1, 0.5 * (val + x / val)

    _, res = while_loop(cond, body, (0, x0))
    return res


print(while_loop_sqrt(0.5))

0.70710677


In [20]:
def scan_loop_sqrt(x, x0=1.0, n_iter=10):
    def body(val, _):
        return 0.5 * (val + x / val), None

    res, _ = scan(body, x0, jnp.arange(n_iter))
    return res


print(scan_loop_sqrt(0.5))

0.70710677


The main difference is that the number of iterations of the `scan` version is known in advance, whereas `while_loop` is more flexible regarding its stopping condition (here we used the same condition in both for comparison, but one could instead stop once a given precision is reached). This has a very important consequence for gradients (and, in principle, for performance on GPU too, although at the time of writing the implementation was not fully supported on GPU).
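As an illustration of the precision-based stopping condition, here is a minimal sketch. `while_loop_sqrt_tol`, `tol` and `max_iter` are hypothetical names; the loop stops once two successive Newton iterates differ by less than `tol`, with `max_iter` as a safety cap so that it always terminates.

```python
import jax.numpy as jnp
from jax.lax import while_loop


def while_loop_sqrt_tol(x, x0=1.0, tol=1e-6, max_iter=100):
    # Hypothetical variant of while_loop_sqrt with a precision-based stopping rule
    x = jnp.asarray(x, dtype=jnp.float32)
    y0 = jnp.asarray(x0, dtype=jnp.float32)

    def newton_step(y):
        return 0.5 * (y + x / y)

    def cond(carry):
        i, y, y_prev = carry
        return jnp.logical_and(i < max_iter, jnp.abs(y - y_prev) > tol)

    def body(carry):
        i, y, _ = carry
        return i + 1, newton_step(y), y

    # carry = (iterations done, current iterate, previous iterate)
    init = (jnp.asarray(1), newton_step(y0), y0)
    n_iter, res, _ = while_loop(cond, body, init)
    return res, n_iter


res, n_iter = while_loop_sqrt_tol(0.5)
print(res, n_iter)
```

The number of iterations is now only known at runtime, which is exactly the situation in which reverse-mode differentiation becomes a problem.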

Let's first look at the code generated by the respective `while_loop` and `scan` implementations:

In [21]:
make_jaxpr(while_loop_sqrt)(0.5)

{ lambda  ; a.
  let _ b = while[ body_jaxpr={ lambda  ; a b c.
                                let d = add b 1
                                    e = div a c
                                    f = add c e
                                    g = mul f 0.5
                                in (d, g) }
                   body_nconsts=1
                   cond_jaxpr={ lambda  ; a b.
                                let c = lt a 10
                                in (c,) }
                   cond_nconsts=0 ] a 0 1.0
  in (b,) }

In [22]:
make_jaxpr(scan_loop_sqrt)(0.5)

{ lambda  ; a.
  let b = iota[ dimension=0
                dtype=int32
                shape=(10,) ] 
      c = scan[ jaxpr={ lambda  ; a b c.
                        let d = div a b
                            e = add b d
                            f = mul e 0.5
                        in (f,) }
                length=10
                linear=(False, False, False)
                num_carry=1
                num_consts=1
                reverse=False
                unroll=1 ] a 1.0 b
  in (c,) }

We can now see how this translates in terms of gradients:

In [23]:
make_jaxpr(grad(scan_loop_sqrt))(0.5)

{ lambda  ; a.
  let b = iota[ dimension=0
                dtype=int32
                shape=(10,) ] 
      _ _ c _ d =
        scan[ jaxpr={ lambda  ; f a b c d e.
                      let g = div f c
                          h = add c g
                          i = mul h 0.5
                          j = integer_pow[ y=-2 ] c
                      in (i, *, c, *, j) }
              length=10
              linear=(False, True, True, False, True, False)
              num_carry=2
              num_consts=3
              reverse=False
              unroll=1 ] a * * 1.0 * b
      _ e _ _ _ =
        scan[ jaxpr={ lambda  ; a b c d e f g.
                      let h = mul e 0.5
                          i = mul h g
                          j = mul i a
                          k = neg j
                          l = add_any h k
                          m = div h f
                          n = add_any c m
                      in (b, n, *, l, *) }
              length=10
             

In [24]:
jvp(scan_loop_sqrt, (0.25,), (1.0,))

(DeviceArray(0.5, dtype=float32), DeviceArray(1., dtype=float32))

In [25]:
jvp(while_loop_sqrt, (0.25,), (1.0,))

(DeviceArray(0.5, dtype=float32), DeviceArray(1., dtype=float32))

In [26]:
grad(scan_loop_sqrt)(0.25)

DeviceArray(1., dtype=float32)

In [27]:
grad(while_loop_sqrt)(0.25)

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

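As explained in the automatic differentiation notebook, reverse mode needs to store the intermediate values of the forward pass, and XLA needs the size of every buffer to be known at compile time. Since a `while_loop` can run for a number of iterations that is only known at runtime, JAX does not allow reverse-mode differentiation through it (forward mode, as the `jvp` calls above show, works fine). The same applies to `fori_loop` when its bounds are not known at trace time; with static bounds it is lowered to a `scan` and can be differentiated in reverse mode. Some work has been planned in JAX to allow bounded-size while loops to be implemented using `scan`.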
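The idea of a bounded while loop can be sketched by hand: run a `scan` for a fixed maximum number of steps and, once the condition becomes false, pass the carry through unchanged. `bounded_while_loop` and `bounded_sqrt` below are hypothetical helpers, not a JAX API, and this is not necessarily how JAX would implement it.

```python
import jax
import jax.numpy as jnp
from jax.lax import scan


def bounded_while_loop(cond_fun, body_fun, init_val, max_iter):
    # Hypothetical helper: emulate lax.while_loop with a scan of fixed length max_iter.
    # Once cond_fun is False the carry is passed through unchanged, so the result
    # matches while_loop whenever the loop ends within max_iter steps.
    def step(val, _):
        keep_going = cond_fun(val)
        new_val = body_fun(val)
        # select, leaf by leaf, the updated carry or the old one
        val = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_val, val
        )
        return val, None

    final_val, _ = scan(step, init_val, None, length=max_iter)
    return final_val


def bounded_sqrt(x, x0=1.0, n_iter=10):
    # Same Newton iteration as while_loop_sqrt, expressed with bounded_while_loop
    def cond(carry):
        i, _ = carry
        return i < n_iter

    def body(carry):
        i, val = carry
        return i + 1, 0.5 * (val + x / val)

    init = (jnp.int32(0), jnp.float32(x0))
    _, res = bounded_while_loop(cond, body, init, max_iter=n_iter)
    return res


print(bounded_sqrt(0.25))
print(jax.grad(bounded_sqrt)(0.25))  # reverse mode works, unlike grad(while_loop_sqrt)
```

The price of this design is that all `max_iter` steps are always executed, and reverse mode stores the carry of every step, even those after the condition has become false.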

### Questions:

#### Q1: 
Implement the following bubble sort algorithm:
```python
def bubble_sort(arr): 
    n = len(arr) 
    res = np.copy(arr)
    for i in range(n-1): 
        for j in range(0, n-i-1): 
            if res[j] > res[j+1]: 
                res[j], res[j+1] = res[j+1], res[j]
    return res   
```
What is its Jacobian?

#### Q2:
Implement the following discrete [Hidden Markov Model](https://en.wikipedia.org/wiki/Hidden_Markov_model) filter:

In [28]:
def hmm_filter(A, B, pi, ys):
    pxs = np.empty((ys.shape[0], pi.shape[0]))
    px = pi  # prior distribution over the hidden states
    for i, y in enumerate(ys):
        B_y = B[:, y]  # likelihood of observation y for each hidden state
        px = B_y * px  # unnormalised Bayes rule
        px = px / px.sum()  # normalisation
        pxs[i] = px  # registration
        px = A @ px  # prediction
    return pxs

## Advanced
### Prerequisites
- Intermediate loops 

### Imports

In [57]:
from jax import make_jaxpr, jvp, vjp
from jax.lax import associative_scan, scan

import jax.numpy as jnp
import numpy as np

We now present an additional primitive that can be useful when dealing with associative binary operations (such as summation): `associative_scan`, which computes a generalised [prefix sum](https://en.wikipedia.org/wiki/Prefix_sum). Instead of applying the operation sequentially, it applies it recursively to subsets of the inputs (a divide-and-conquer strategy). This makes it easy to parallelise, and it is natively implemented in JAX.

### Example

In practice, it implements a parallelised version (see for example the [Wikipedia article](https://en.wikipedia.org/wiki/Prefix_sum)) of the following sequential algorithm:
```python
def my_sequential_associative_scan(binary_op, xs):
    res = np.copy(xs)
    val = xs[0]
    for i, x in enumerate(xs[1:]):
        val = binary_op(val, x)
        res[i+1] = val
    return res
```

so that, for example, the cumulative sum would be written as
```python
my_sequential_associative_scan(lambda x, y: x + y, np.arange(10))
```
In pure Python, the parallel version (here Blelloch's algorithm) would typically look like:
```python
def home_made_blelloch(arr):
    # This is for illustration purposes only, and for instance doesn't take into
    # account the case when the array size is not a pure power of 2
    res = np.copy(arr)
    n = res.shape[0]
    log_n = np.log2(n).astype(int)

    # Up pass
    for d in range(log_n):
        # this loop can't be done in parallel so it defines the span complexity under
        # parallelization
        for i in range(0, n, 2 ** (d + 1)):
            # this should be done in parallel, therefore would not be taken
            # into account in the span complexity provided we have at least
            # n / 2^{d+1} cores on our GPU
            i1 = i + 2 ** d - 1
            i2 = i + 2 ** (d + 1) - 1
            res[i2] += res[i1]

    res[-1] = 0

    # Down pass
    for d in range(log_n - 1, -1, -1):
        # this loop can't be done in parallel so it defines the span complexity under
        # parallelization
        for i in range(0, n, 2 ** (d + 1)):
            # this should be done in parallel, therefore would not be taken
            # into account in the span complexity provided we have at least
            # n / 2^{d+1} cores on our GPU
            i1 = i + 2 ** d - 1
            i2 = i + 2 ** (d + 1) - 1

            res[i1], res[i2] = res[i2], res[i1] + res[i2]
    # Extra pass: add the input back to turn the exclusive scan into an inclusive one
    res += arr

    return res
```


Using JAX, this is simply written as follows:

In [52]:
def associative_cumulative_sum(xs):
    return associative_scan(lambda x, y: x + y, xs)


print(associative_cumulative_sum(np.arange(10.0)))

[ 0.  1.  3.  6. 10. 15. 21. 28. 36. 45.]
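Any associative binary operation can be used, not only addition. For instance, since `max` is associative, `associative_scan` with `jnp.maximum` computes a running maximum (a minimal example; the input values are arbitrary):

```python
import jax.numpy as jnp
import numpy as np
from jax.lax import associative_scan

xs = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0])

# max is associative, so associative_scan yields the running maximum
running_max = associative_scan(jnp.maximum, xs)
print(running_max)  # [3. 3. 4. 4. 5. 9. 9. 9.]

# sanity check against NumPy's sequential accumulate
print(np.array_equal(np.asarray(running_max), np.maximum.accumulate(np.asarray(xs))))  # True
```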


Let's look at the code generated:

In [53]:
make_jaxpr(associative_cumulative_sum)(np.arange(10.0))

{ lambda  ; a.
  let b = slice[ limit_indices=(9,)
                 start_indices=(0,)
                 strides=(2,) ] a
      c = slice[ limit_indices=(10,)
                 start_indices=(1,)
                 strides=(2,) ] a
      d = add b c
      e = slice[ limit_indices=(4,)
                 start_indices=(0,)
                 strides=(2,) ] d
      f = slice[ limit_indices=(5,)
                 start_indices=(1,)
                 strides=(2,) ] d
      g = add e f
      h = slice[ limit_indices=(1,)
                 start_indices=(0,)
                 strides=(2,) ] g
      i = slice[ limit_indices=(2,)
                 start_indices=(1,)
                 strides=(2,) ] g
      j = add h i
      k = slice[ limit_indices=(0,)
                 start_indices=(0,)
                 strides=(1,) ] j
      l = slice[ limit_indices=(2,)
                 start_indices=(2,)
                 strides=(2,) ] g
      m = add k l
      n = slice[ limit_indices=(1,)
                 start_indic

It looks just like the unrolled Python loop from before! Is that bad? No, because the depth of the generated graph only grows as $\log_2(n)$.
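One way to check this is to count the top-level equations of the traced program for growing input sizes, as a rough proxy for the size of the graph. The helper `n_equations` is illustrative and the exact counts depend on the JAX version; what matters is that the count stays far below `n`, growing with the number of recursion levels rather than with the number of elements.

```python
import jax.numpy as jnp
from jax import make_jaxpr
from jax.lax import associative_scan


def associative_cumulative_sum(xs):
    return associative_scan(lambda x, y: x + y, xs)


def n_equations(f, n):
    # number of top-level primitive applications in the traced program
    return len(make_jaxpr(f)(jnp.ones(n)).jaxpr.eqns)


for n in (8, 64, 1024):
    print(n, n_equations(associative_cumulative_sum, n))
```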

And, like `scan`, it can be differentiated in both forward and reverse mode:

In [62]:
print(jvp(associative_cumulative_sum, (np.arange(5.0),), (np.ones(5),)))

(DeviceArray([ 0.,  1.,  3.,  6., 10.], dtype=float32), DeviceArray([1., 2., 3., 4., 5.], dtype=float32))


In [65]:
val, cumsum_bwd = vjp(associative_cumulative_sum, np.arange(5.0))
print(val)
print(cumsum_bwd(np.ones(5)))

[ 0.  1.  3.  6. 10.]
(DeviceArray([5., 4., 3., 2., 1.], dtype=float32),)
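The reverse-mode result above is no accident: with $y_j = \sum_{i \le j} x_i$, the VJP with cotangent $\bar{y}$ is $\bar{x}_i = \sum_{j \ge i} \bar{y}_j$, i.e. a reverse cumulative sum of the cotangent. A quick check with a non-constant (arbitrarily chosen) cotangent:

```python
import jax.numpy as jnp
import numpy as np
from jax import vjp
from jax.lax import associative_scan


def associative_cumulative_sum(xs):
    return associative_scan(lambda x, y: x + y, xs)


xs = jnp.arange(5.0)
g = jnp.array([1.0, 0.0, 2.0, 0.0, 3.0])  # arbitrary cotangent

_, cumsum_bwd = vjp(associative_cumulative_sum, xs)
(grad_xs,) = cumsum_bwd(g)

# reverse cumulative sum: entry i is g[i] + g[i + 1] + ... + g[-1]
reverse_cumsum = jnp.cumsum(g[::-1])[::-1]
print(grad_xs)  # [6. 5. 5. 3. 3.]
```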


### Questions:

#### Q1: 
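Compare the speed of the `associative_scan` and `scan` implementations of the cumulative sum on GPU and on CPU (for instance by placing the inputs on the target device with `jax.device_put`; older versions of JAX also accepted a `device` argument to `jit`).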

#### Q2:
A simpler parallel scan is Hillis and Steele's algorithm, given in pure Python by:
```python
def home_made_hillis_and_steel(arr):
    # This is for illustration purposes only
    res = np.copy(arr)
    n = res.shape[0]
    log_n = int(np.ceil(np.log2(n)))  # round up so that sizes that are not powers of 2 are fully covered
    n_operations = 0
    for d in range(log_n):
        # this loop can't be done in parallel so it defines the span complexity under
        # parallelization
        for i in reversed(range(n)):
            # For each i, in parallel
            if i - 2 ** d >= 0:
                n_operations += 1
                res[i] += res[i - 2 ** d]
    return res, n_operations
```
Implement your own parallel version of `associative_scan` with JAX primitives, using Hillis and Steele's algorithm (or Blelloch's algorithm if you feel adventurous).