# IF-ELSE

In this notebook we will consider variants of a fairly simple piecewise constant function:

```python
def piecewise_constant(x, a, b, c):
    if x < a:
        return b
    else:
        return c
```

and try to understand how to implement them using JAX primitives.


## Beginner
### Prerequisites
None

### Imports

```python
from jax.lax import cond
```

### Example

As a first example, let's compute the absolute value with JAX. In plain Python we would write:
```python
def my_abs(x):
    return x if x > 0 else -x
```

Written out in full with `cond`, this becomes:

```python
def jax_abs(x):
    predicate = x > 0

    def true_fun(z):
        return z

    def false_fun(z):
        return -z

    result = cond(
        predicate,  # predicate for the if
        true_fun,   # function to call on operand if predicate is true
        false_fun,  # function to call on operand otherwise
        operand=x,  # operand to be passed to either true_fun or false_fun
    )

    return result

print(jax_abs(-5.))
```

Readers familiar with C-like languages may have noticed that `cond` behaves much like the ternary operator `a ? b : c`.

Using lambdas, this can be written more compactly:

```python
def compact_jax_abs(x):
    return cond(x > 0,
                lambda z: z,
                lambda z: -z,
                operand=x)

print(compact_jax_abs(-5.))
```
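Why not simply keep the plain Python `my_abs`? One way to see the difference is `jax.jit`: inside a jitted function `x` is an abstract tracer with no concrete value, so Python's `if`/`else` cannot branch on it, whereas `cond` is a JAX primitive that makes the choice at run time. A minimal sketch (the names `python_abs` and `cond_abs` are ours, chosen for illustration):

```python
import jax
from jax.lax import cond

def python_abs(x):
    # Plain Python branching: needs a concrete boolean value for `x > 0`
    return x if x > 0 else -x

def cond_abs(x):
    # Branching expressed as a JAX primitive, so it can be traced under jit
    return cond(x > 0, lambda z: z, lambda z: -z, x)

python_failed = False
try:
    jax.jit(python_abs)(-5.)
except jax.errors.ConcretizationTypeError as err:
    # Converting the traced predicate `x > 0` to a Python bool is not possible
    python_failed = True
    print("python_abs under jit raised", type(err).__name__)

print(jax.jit(cond_abs)(-5.))  # 5.0
```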

### Questions:

#### Q1: 
Consider the following implementation:

```python
def other_jax_abs(x):
    return cond(x > 0,
                lambda z: z[0],
                lambda z: z[1],
                operand=(x, -x))

print(other_jax_abs(-5.))
```

```
5.0
```

What is the difference from the `compact_jax_abs` implementation? Which one is better, and why?

#### Q2:
Implement the `piecewise_constant` function using only `cond`.

```python
def piecewise_constant(x, a, b, c):
    # Exercise: replace `pass` with an implementation that uses only `cond`
    pass

assert piecewise_constant(0.1, 0., 1., 2.) == 2.
```

```
AssertionError: 
```

## Intermediate
### Prerequisites
- Beginner if-else  
- Beginner vectorisation
- Numpy  

### Imports

```python
from jax import vmap, make_jaxpr
from jax.lax import cond
import jax.numpy as jax_np

import numpy as np
```

Now that we know how to implement if-else branching for scalar inputs, how do we extend this to tensor inputs?

### Example

JAX already implements the `abs` function, so in practice you shouldn't have to worry about this, but how would we replicate the result using only JAX's high-level primitives?
```python
import jax.numpy as jax_np

def my_abs(x):
    return jax_np.abs(x)
```
We will use the following array as an input:

```python
arr = np.array([-1., 0., 1.])
```

Let's reuse the code from the Beginner level and see what happens:

```python
def jax_abs(x):
    predicate = x > 0

    def true_fun(z):
        return z

    def false_fun(z):
        return -z

    result = cond(
        predicate,  # predicate for the if
        true_fun,   # function to call on operand if predicate is true
        false_fun,  # function to call on operand otherwise
        operand=x,  # operand to be passed to either true_fun or false_fun
    )

    return result

print(jax_abs(arr))
```

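So we simply can't use the same trick: `cond` requires a scalar predicate, but here `x > 0` is a boolean array of shape `(3,)`, so the call raises a `TypeError`. We are left with two choices: we can either use the NumPy API (just not `jax_np.abs`, for the sake of the exercise) or try to be smart.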


```python
def jax_numpy_abs(x):
    return jax_np.where(x > 0, x, -x)

print(jax_numpy_abs(arr))
```

```
[ 1. -0.  1.]
```


### Let's try to be smart
Instead of calling `cond` on the whole array, we can use the primitives we learned about in the vectorisation notebook:

```python
vmap_abs = vmap(jax_abs)
print(vmap_abs(arr))
```

```
[ 1. -0.  1.]
```


```python
vectorized_abs = jax_np.vectorize(jax_abs, signature="()->()")
print(vectorized_abs(arr))
```

```
[ 1. -0.  1.]
```
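One practical difference between the last two approaches shows up with higher-rank inputs. `vmap` maps over a single axis (axis 0 by default), so on a 2-D array `jax_abs` would still receive whole rows and `cond` would reject the non-scalar predicate; nesting one `vmap` per axis works around this. The `signature="()->()"` given to `jax_np.vectorize` declares a scalar core, so it broadcasts over all dimensions. A sketch, assuming a small 2-D test matrix `mat` of our own:

```python
import numpy as np
import jax.numpy as jax_np
from jax import vmap
from jax.lax import cond

def jax_abs(x):
    # Scalar-only absolute value built on cond
    return cond(x > 0, lambda z: z, lambda z: -z, x)

mat = np.array([[-1., 2.], [3., -4.]])

# A single vmap hands each row (shape (2,)) to jax_abs, so cond raises a TypeError
try:
    vmap(jax_abs)(mat)
    single_vmap_failed = False
except TypeError:
    single_vmap_failed = True

nested = vmap(vmap(jax_abs))(mat)  # one vmap per axis
vectorized = jax_np.vectorize(jax_abs, signature="()->()")(mat)  # broadcasts over all axes
print(single_vmap_failed)
print(nested)
print(vectorized)
```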


### Questions:

#### Q1: 
Compare the three implementations. What do you think is happening behind the scenes? (You can use the `make_jaxpr` utility to inspect the generated code.)
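One way to get started, assuming the definitions from this section: print the jaxpr of each implementation side by side. The sketch below only shows how to call `make_jaxpr` on array inputs; interpreting the output is left to you.

```python
import numpy as np
import jax.numpy as jax_np
from jax import vmap, make_jaxpr
from jax.lax import cond

def jax_abs(x):
    # Scalar-only absolute value built on cond
    return cond(x > 0, lambda z: z, lambda z: -z, x)

def jax_numpy_abs(x):
    # Elementwise selection between x and -x
    return jax_np.where(x > 0, x, -x)

arr = np.array([-1., 0., 1.])
implementations = {
    "where": jax_numpy_abs,
    "vmap": vmap(jax_abs),
    "vectorize": jax_np.vectorize(jax_abs, signature="()->()"),
}

# make_jaxpr(fn) returns a function; calling it on example inputs traces fn
jaxpr_texts = {name: str(make_jaxpr(fn)(arr)) for name, fn in implementations.items()}
for name, text in jaxpr_texts.items():
    print(f"--- {name} ---")
    print(text)
```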

#### Q2:
Implement the vectorised `piecewise_constant` function using the three different techniques.

## Advanced
### Prerequisites
- Intermediate if-else  

### Imports

```python
from jax import vmap, make_jaxpr, jit
from jax.lax import cond, switch
import jax.numpy as jax_np

import numpy as np
```

Now that we know how to implement if-else branching for vectorized inputs, what happens when we have more than one condition?

The goal here is to implement an extension of the piecewise constant function:

```python
def piecewise_constant(x, xs, ys):
    # len(xs) == len(ys) - 1: ys[0] applies below xs[0], ys[i] on [xs[i-1], xs[i]),
    # and ys[-1] from xs[-1] onwards
    # xs are assumed to be sorted
    if x < xs[0]:
        return ys[0]
    for xi, yi in zip(xs, ys[:-1]):
        if x >= xi:
            continue
        return yi
    return ys[-1]
```

We will use the following inputs:

```python
arr_xs = np.array([-1., 0., 1.])
arr_ys = np.array([0.2, 0.4, 0., 0.1])

arr_x = np.array([0.5, -2., 3., -1.5, -0.4])
```
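Before writing any JAX version, it helps to have reference values to test against. A minimal sketch that applies the pure-Python `piecewise_constant` above to each element of `arr_x` (the definitions are repeated so the cell runs on its own):

```python
import numpy as np

def piecewise_constant(x, xs, ys):
    # Pure-Python reference; xs sorted, len(xs) == len(ys) - 1
    if x < xs[0]:
        return ys[0]
    for xi, yi in zip(xs, ys[:-1]):
        if x >= xi:
            continue
        return yi
    return ys[-1]

arr_xs = np.array([-1., 0., 1.])
arr_ys = np.array([0.2, 0.4, 0., 0.1])
arr_x = np.array([0.5, -2., 3., -1.5, -0.4])

# Evaluate one scalar at a time; a JAX implementation should reproduce these values
expected = np.array([piecewise_constant(x, arr_xs, arr_ys) for x in arr_x])
print(expected)  # values: 0.0, 0.2, 0.1, 0.2, 0.4
```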

### Example

First things first, let's look at the `switch` function:

```python
help(switch)
```

```
Help on function switch in module jax._src.lax.control_flow:

switch(index, branches: Sequence[Callable], operand)
    Apply exactly one of ``branches`` given by ``index``.
    
    If ``index`` is out of bounds, it is clamped to within bounds.
    
    Has the semantics of the following Python::
    
      def switch(index, branches, operand):
        index = clamp(0, index, len(branches) - 1)
        return branches[index](operand)
    
    Arguments:
      index: Integer scalar type, indicating which branch function to apply.
      branches: Sequence of functions (A -> B) to be applied based on `index`.
      operand: Operand (A) input to whichever branch is applied.
```
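The clamping rule in the docstring is easy to check directly. A small sketch with three hypothetical branches that shift the operand by -1, 0 and +1:

```python
from jax.lax import switch

branches = [lambda z: z - 1.0, lambda z: z, lambda z: z + 1.0]

# Out-of-range indices are clamped to [0, len(branches) - 1]
below = switch(-3, branches, 10.0)  # clamped to index 0 -> 9.0
inside = switch(1, branches, 10.0)  # index 1 -> 10.0
above = switch(7, branches, 10.0)   # clamped to index 2 -> 11.0
print(below, inside, above)
```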



We can therefore rewrite our absolute value function in terms of `switch`:

```python
def jax_abs_switch(x):
    # branch 0 is taken when x <= 0, branch 1 when x > 0
    branches = [lambda z: -z, lambda z: z]
    index = jax_np.asarray(x > 0, jax_np.int32)
    return switch(index, branches, x)

print(jax_abs_switch(5.))
```

```
5.0
```


Now let's look at what we actually did, and compare it with `compact_jax_abs`:

```python
make_jaxpr(jax_abs_switch)(-5.)
```

```
{ lambda  ; a.
  let b = gt a 0.0
      c = convert_element_type[ new_dtype=int32
                                old_dtype=bool ] b
      d = clamp 0 c 1
      e = cond[ branches=( { lambda  ; a.
                             let b = neg a
                             in (b,) }
                           { lambda  ; a.
                             let 
                             in (a,) } )
                linear=(False,) ] d a
  in (e,) }
```

```python
make_jaxpr(compact_jax_abs)(-5.)
```

```
{ lambda  ; a.
  let b = gt a 0.0
      c = convert_element_type[ new_dtype=int32
                                old_dtype=bool ] b
      d = cond[ branches=( { lambda  ; a.
                             let b = neg a
                             in (b,) }
                           { lambda  ; a.
                             let 
                             in (a,) } )
                linear=(False,) ] c a
  in (d,) }
```

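Well, we did nothing new! Apart from the `clamp` on the index, the two jaxprs are identical: both lower to the same `cond` primitive with a tuple of branches. So `cond` is just the two-branch special case of `switch`, and in particular nesting several `cond`s in the hope of better performance would be a bad idea, since a single `switch` already expresses a multi-way branch as one primitive.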
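To see this on a three-way choice, here is a hypothetical `sign` function written once with a single `switch` and once with two nested `cond`s. Both return the same values; printing their jaxprs should show one three-branch `cond` primitive for the `switch` version, versus one `cond` nested inside a branch of another. This is an illustration of the structure, not a benchmark.

```python
import jax.numpy as jax_np
from jax import make_jaxpr
from jax.lax import cond, switch

def sign_switch(x):
    # index 0 when x < 0, 1 when x == 0, 2 when x > 0
    index = jax_np.asarray(x >= 0, jax_np.int32) + jax_np.asarray(x > 0, jax_np.int32)
    branches = [
        lambda z: -jax_np.ones_like(z),
        lambda z: jax_np.zeros_like(z),
        lambda z: jax_np.ones_like(z),
    ]
    return switch(index, branches, x)

def sign_nested_cond(x):
    # outer cond handles x > 0, inner cond separates x < 0 from x == 0
    return cond(
        x > 0,
        lambda z: jax_np.ones_like(z),
        lambda z: cond(z < 0,
                       lambda w: -jax_np.ones_like(w),
                       lambda w: jax_np.zeros_like(w),
                       z),
        x,
    )

for v in (-5., 0., 5.):
    print(v, sign_switch(v), sign_nested_cond(v))

print(make_jaxpr(sign_switch)(-5.))
print(make_jaxpr(sign_nested_cond)(-5.))
```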

### Questions:

#### Q1: 
Implement the generalized `piecewise_constant` using `switch`, assuming `xs` is sorted.

#### Q2:
If you are already familiar with the loop primitives, implement the generalized `piecewise_constant` using a loop (for example `jax.lax.fori_loop`) and `cond`, and compare the generated code.

#### Q3:
How would you vectorize this function? Compare the naive vmap with the `jax_np.select` approach.