# Jax's Loops

## Lesson Goals:

By the end of this lesson, you'll be able to articulate why and when you want to use `jax`'s native `while_loop`, `fori_loop`, and `scan` instead of Python's native loops. In the process, you'll learn how to read Haskell-like type signatures, which will be useful as you explore the `jax` library further.

## Core Concepts:

- functional programming 
- reading Haskell function signatures
- `while_loop`
- `fori_loop`
- `scan`

## Concepts in Action:

- Easy: [lotka-volterra](../case_studies/lotka-volterra/README.md)

- Intermediate: [leaky_integrate_and_fire](../case_studies/leaky_integrate_and_fire/README.md)

In [None]:
import numpy as np
from typing import TypeAlias
import time
import jax.numpy as jnp

np.random.seed(42)

# Haskell-like signatures

![](../assets/haskell.png)

Type signatures are a great way to understand, abstractly, what a function takes in and what it returns. Let's walk through a few examples:

## Signature 1: Mapping functions over elements of an iterable

```haskell
map :: (a -> b) -> [a] -> [b]
```

```python
def map(func: Callable, arr):
    return ...
```
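One way to carry this signature into Python is with `typing.TypeVar`, which plays the role of Haskell's lowercase type variables `a` and `b`. The sketch below uses a hypothetical name, `typed_map`, so it doesn't shadow Python's built-in `map`; the list comprehension is one possible body, not the only one.

```python
from typing import Callable, TypeVar

# Type variables stand in for Haskell's `a` and `b`
A = TypeVar("A")
B = TypeVar("B")


def typed_map(func: Callable[[A], B], arr: list[A]) -> list[B]:
    """map :: (a -> b) -> [a] -> [b]"""
    # Apply `func` to every element, preserving the original order
    return [func(x) for x in arr]


# Here a = str and b = int, because `len` maps strings to integers
lengths = typed_map(len, ["jax", "numpy", "haskell"])
```

Read the arrows left to right: the first argument is a function from `a` to `b`, the second is a list of `a`, and the final type, `[b]`, is what comes back.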

## Signature 2: Joining Structures

```haskell
(++) :: [a] -> [a] -> [a]
```

What sort of pre-condition do we need on the inputs?

```python
def abstract_join_of_iterable(s1, s2):
    """
    TODO: answer the following
    1) what sort of pre-condition do we need? What function must type(s1) implement?
    2) assuming the pre-condition, what might this function look like?
    """
    return ...
```


## Signature 3: Filtering elements

```haskell
filter :: (a -> Bool) -> [a] -> [a]
```

```python
def filter(filter_func, iterable):
    """
    TODO: answer the following
    1) must it be a list? 
    2) implement the function
    """

```

# Looping in Jax

As mentioned before, you probably don't want to `jit` a function that contains a native Python `for` loop: tracing unrolls the loop, so the program handed to the compiler grows with the number of iterations and compilation takes longer. Thankfully, `jax` provides:

- [jax.lax.while_loop](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html#jax.lax.while_loop)
- [jax.lax.fori_loop](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html)
- [jax.lax.scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)

to circumvent this issue. 

Note: we don't necessarily see a speedup at runtime; the primary advantage of these `jax` functions is reduced compilation time, because the loop body is traced and compiled once rather than once per iteration.
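A rough way to see the difference is `jax.make_jaxpr`, which shows the program JAX traces before handing it to the compiler. This sketch counts traced equations as a proxy for compilation cost rather than timing compilation directly; `N_STEPS`, `python_loop`, and `fori_version` are hypothetical names for illustration. You should see `n_python` grow with `N_STEPS`, while `n_fori` stays small because the body appears only once.

```python
import jax
import jax.numpy as jnp

N_STEPS = 100  # hypothetical iteration count for the comparison


def python_loop(x):
    # Native Python loop: tracing unrolls it into N_STEPS copies of the body
    for _ in range(N_STEPS):
        x = x * 2.0 + 1.0
    return x


def fori_version(x):
    # Same computation; the body is traced once and wrapped in a loop primitive
    return jax.lax.fori_loop(0, N_STEPS, lambda i, x: x * 2.0 + 1.0, x)


x0 = jnp.float32(0.0)
n_python = len(jax.make_jaxpr(python_loop)(x0).jaxpr.eqns)
n_fori = len(jax.make_jaxpr(fori_version)(x0).jaxpr.eqns)
print(n_python, n_fori)
```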


## The Jax Functions we will be covering:

```haskell
while_loop :: (a -> Bool) -> (a -> a) -> a -> a
```

```haskell
fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
```

```haskell
scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
```

## Jax's Loops

```python
def while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val
```

```haskell
while_loop :: TODO
```
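Here is a small sketch using the real `jax.lax.while_loop`. The function `count_halvings` is a hypothetical example that counts how many times a value can be halved before it drops to 1 or below. The carry is a tuple, and because `body_fun` has type `a -> a`, it must return the same structure, shapes, and dtypes it receives. Unlike a Python `while` loop, the number of iterations here may depend on a traced value inside `jit`.

```python
import jax
import jax.numpy as jnp


def count_halvings(x):
    # cond_fun :: a -> Bool, where the carry `a` is (value, steps)
    def cond_fun(carry):
        value, _ = carry
        return value > 1.0

    # body_fun :: a -> a, returning the same (float32, int32) structure
    def body_fun(carry):
        value, steps = carry
        return value / 2.0, steps + 1

    _, steps = jax.lax.while_loop(cond_fun, body_fun, (x, jnp.int32(0)))
    return steps


steps = jax.jit(count_halvings)(jnp.float32(100.0))
```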

```python
def fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val
```

```haskell
fori_loop :: TODO
```
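A minimal sketch of `jax.lax.fori_loop` in use, with a hypothetical `sum_of_squares` helper. The body receives the loop index as its first argument, matching the `(Int, a) -> a` part of the signature, and the loop runs for `i` from `lower` up to, but not including, `upper`.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(n_terms):
    # body_fun :: (Int, a) -> a: takes the loop index and the carry, returns the new carry
    def body_fun(i, total):
        return total + i * i

    # Loop over i = 0, 1, ..., n_terms - 1, starting from a total of 0
    return jax.lax.fori_loop(0, n_terms, body_fun, jnp.int32(0))


result = sum_of_squares(5)  # 0 + 1 + 4 + 9 + 16
```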

```python
def scan(f, init, xs):
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
```


```haskell
scan :: TODO
```
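To see `jax.lax.scan` in action, here is a sketch of a running sum, with `step` as a hypothetical body function. The carry (`c`) holds the total so far, each element of `xs` plays the role of `a`, and each step emits a `b`; `scan` stacks those outputs into an array, mirroring the `np.stack(ys)` in the Python version.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # f :: c -> a -> (c, b): carry is the running total, x is one element of xs
    new_carry = carry + x
    # Emit the updated total as this step's output y
    return new_carry, new_carry


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
final_total, running_totals = jax.lax.scan(step, jnp.float32(0.0), xs)
```

Because the emitted outputs are collected for you, `scan` is the natural choice when you need the whole trajectory and not just the final carry.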

# Further Exercises: 

0) Read through [case_studies/leaky_integrate_and_fire/jax_leaky_integrate_and_fire_3_scan.ipynb](../case_studies/leaky_integrate_and_fire/jax_leaky_integrate_and_fire_3_scan.ipynb) for an example of using `scan` in a "real-world" scenario

1) Read through [case_studies/lotka-volterra/jax_lotka-volterra.ipynb](../case_studies/lotka-volterra/jax_lotka-volterra.ipynb), which shows the various ways you can use `fori_loop` and `scan`.

