# IF-ELSE

In this notebook we will consider variants of a fairly simple piecewise constant function:

```python
def piecewise_constant(x, a, b, c):
    if x < a:
        return b
    else:
        return c
```

and we will try to understand how to implement them using JAX primitives.

## Beginner
### Prerequisites
No prerequisites.

### Imports

In [13]:
from jax import jit
from jax.lax import cond
import jax.numpy as jnp

### Example

We will first give an example of how to compute the absolute value using JAX:

In [14]:
def my_abs(x):
    return x if x > 0 else -x

Would it work?

In [15]:
my_abs(jnp.asarray(-5.0))

Array(5., dtype=float32, weak_type=True)

Ah yes it would! So why the fuss?

In [16]:
jit(my_abs)(-5.0)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function my_abs at /tmp/ipykernel_206241/326190723.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Actually, no, it does not. This is the fundamental trick of JAX: under `jit`, your Python code does not compute anything directly; it builds a computation graph to which the data is only passed later. The predicate `x > 0` is therefore not a boolean in the JAX world but an instruction to compute one, so Python's `if x > 0` has no value to branch on.  
So why did it work before JIT compilation? Outside the JIT context, `x > 0` carries a concrete value and can therefore be cast to a boolean.
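The difference between the two calls can be isolated with an explicit `bool(...)` cast. This is a minimal sketch; the helper name `is_positive` is ours, and it assumes a JAX version that exposes `jax.errors.TracerBoolConversionError`, as in the traceback above.

```python
import jax
import jax.numpy as jnp
from jax import jit


def is_positive(x):
    # bool(...) asks for a concrete Python value, exactly like `if x > 0` does
    return bool(x > 0)


# Eager call: the array holds a concrete value, so the cast succeeds.
eager_result = is_positive(jnp.asarray(-5.0))

# Traced call: x is an abstract tracer, so the same cast raises.
try:
    jit(is_positive)(-5.0)
    jit_raised = False
except jax.errors.TracerBoolConversionError:
    jit_raised = True

print(eager_result, jit_raised)
```

The function body is identical in both calls; only the kind of value flowing through it changes.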

So instead of the `if ... else` syntax, we need to use the JAX `cond` primitive. Written out in full detail:

In [17]:
def jax_abs(x):
    predicate = x > 0

    def true_fun(z):
        return z

    def false_fun(z):
        return -z

    result = cond(
        predicate,  # predicate for the if
        true_fun,  # function to call on operand if predicate is true
        false_fun,  # function to call on operand otherwise
        operand=x,  # operand to be passed to either true_fun or false_fun
    )

    return result


print(jax_abs(-5.0))

5.0


People familiar with C-like languages may have noticed that `cond` acts much like the ternary operator `a ? b : c`.

The same function can be written more compactly as:

In [18]:
def compact_jax_abs(x):
    return cond(x > 0, lambda z: z, lambda z: -z, operand=x)


print(compact_jax_abs(-5.0))

5.0
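Since the whole point of `cond` is to survive tracing, it is worth checking that the rewritten function works under `jit`, which the plain `if ... else` version did not. A minimal sketch; `cond_abs` and `jitted_abs` are our names, and the operand is passed positionally rather than through the `operand=` keyword.

```python
from jax import jit
from jax.lax import cond


def cond_abs(x):
    # Both branches receive the operand x; only the selected one is executed.
    return cond(x > 0, lambda z: z, lambda z: -z, x)


jitted_abs = jit(cond_abs)

neg_result = float(jitted_abs(-5.0))
pos_result = float(jitted_abs(3.0))
print(neg_result, pos_result)  # 5.0 3.0
```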


### Questions:

#### Q1: 
Consider the following implementation:

In [19]:
def other_jax_abs(x):
    return cond(x > 0, lambda z: z[0], lambda z: z[1], operand=(x, -x))


print(other_jax_abs(-5.0))

5.0


How does this differ from the `compact_jax_abs` implementation? Which one is better, and why?

#### Q2:
Implement the `piecewise_constant` function using cond only.

In [20]:
def piecewise_constant(x, a, b, c):
    pass


assert piecewise_constant(0.1, 0.0, 1.0, 2.0) == 2.0

AssertionError: 

## Intermediate
### Prerequisites
- Beginner if-else  
- Beginner vectorisation
- Numpy  

### Imports

In [22]:
from jax import vmap
from jax.lax import cond

import jax.numpy as jnp

import numpy as np

Now that we know how to compute if-else predicates for scalar inputs, how do we extend this to tensor inputs?

### Example

Of course, JAX already implements the `abs` function, so you should not have to care about this. But how would we replicate the result using only its high-level primitives?
```python
def my_abs(x):
    return jnp.abs(x)
```
We will use the following array as an input:

In [23]:
arr = np.array([-1.0, 0.0, 1.0])

Let's try to use the code we used for the Beginner level and see what happens:

In [24]:
def jax_abs(x):
    predicate = x > 0

    def true_fun(z):
        return z

    def false_fun(z):
        return -z

    result = cond(
        predicate,  # predicate for the if
        true_fun,  # function to call on operand if predicate is true
        false_fun,  # function to call on operand otherwise
        operand=x,  # operand to be passed to either true_fun or false_fun
    )

    return result


print(jax_abs(arr))

TypeError: Pred must be a scalar, got [False False  True] of shape (3,).

So we simply cannot use the same trick: `cond` needs a scalar predicate. We are left with two choices: use the NumPy-style API (just not `jnp.abs`, for the sake of the exercise) or try to be smart.


In [25]:
def jax_numpy_abs(x):
    return jnp.where(x > 0, x, -x)


print(jax_numpy_abs(arr))

[ 1. -0.  1.]


### Let's try to be smart
Alternatively, we can reuse the primitives we learned about in the vectorisation notebook:

In [26]:
vmap_abs = vmap(jax_abs)
print(vmap_abs(arr))

[ 1. -0.  1.]


In [27]:
vectorized_abs = jnp.vectorize(jax_abs, signature="()->()")
print(vectorized_abs(arr))

[ 1. -0.  1.]
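Before comparing how the three approaches are compiled, it helps to confirm that they agree numerically with the built-in `jnp.abs`. This is a minimal sketch; `scalar_abs` and the `via_*` names are ours.

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap
from jax.lax import cond


def scalar_abs(x):
    # Scalar absolute value written with cond, as in the Beginner section
    return cond(x > 0, lambda z: z, lambda z: -z, x)


values = np.array([-1.0, 0.0, 1.0])

reference = jnp.abs(values)
via_where = jnp.where(values > 0, values, -values)
via_vmap = vmap(scalar_abs)(values)
via_vectorize = jnp.vectorize(scalar_abs, signature="()->()")(values)

all_agree = all(
    bool(jnp.allclose(out, reference))
    for out in (via_where, via_vmap, via_vectorize)
)
print(all_agree)
```

The `-0.` in the printed outputs above comes from the `x > 0` predicate sending `0.0` to the negating branch; `-0.0` and `0.0` compare equal, so the check still passes.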


### Questions:

#### Q1: 
Compare the three different implementations (you can use the utility `make_jaxpr` to see the generated code). What do you think is happening in the background?

#### Q2:
Implement the vectorised `piecewise_constant` function using the three different techniques.

## Advanced
### Prerequisites
- Intermediate if-else  

### Imports

In [28]:
from jax import make_jaxpr
from jax.lax import cond, switch
import jax.numpy as jnp

import numpy as np

Now that we know how to compute if-else predicates for vectorized inputs, what happens when we have more than one condition?

The goal here is to implement an extension of the piecewise constant function:

In [29]:
def piecewise_constant(x, xs, ys):
    # len(xs) == len(ys) - 1
    # xs is assumed to be sorted in increasing order
    if x < xs[0]:
        return ys[0]
    for xi, yi in zip(xs, ys[:-1]):
        if x >= xi:
            continue
        return yi
    return ys[-1]

We will consider the following inputs:

In [30]:
arr_xs = np.array([-1.0, 0.0, 1.0])
arr_ys = np.array([0.2, 0.4, 0.0, 0.1])

arr_x = np.array([0.5, -2.0, 3.0, -1.5, -0.4])
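To have something to test JAX implementations against, the reference function can be evaluated on these inputs in plain Python. The sketch below uses a condensed but equivalent restatement of the loop (the initial `x < xs[0]` check is already covered by the first iteration); the names `piecewise_constant_reference`, `breakpoints`, `levels` and `queries` are ours.

```python
import numpy as np


def piecewise_constant_reference(x, xs, ys):
    # Return ys[i] for the first breakpoint xs[i] with x < xs[i];
    # if x is not below any breakpoint, return the last value.
    for xi, yi in zip(xs, ys[:-1]):
        if x < xi:
            return yi
    return ys[-1]


breakpoints = np.array([-1.0, 0.0, 1.0])
levels = np.array([0.2, 0.4, 0.0, 0.1])
queries = np.array([0.5, -2.0, 3.0, -1.5, -0.4])

expected = [
    float(piecewise_constant_reference(x, breakpoints, levels)) for x in queries
]
print(expected)  # [0.0, 0.2, 0.1, 0.2, 0.4]
```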

### Example

First things first, let's look at the `switch` function:

In [31]:
help(switch)

Help on function switch in module jax._src.lax.control_flow.conditionals:

switch(index, branches: 'Sequence[Callable]', *operands, operand=<object object at 0x7006827c5800>)
    Apply exactly one of the ``branches`` given by ``index``.

    If ``index`` is out of bounds, it is clamped to within bounds.

    Has the semantics of the following Python::

      def switch(index, branches, *operands):
        index = clamp(0, index, len(branches) - 1)
        return branches[index](*operands)

    Internally this wraps XLA's `Conditional
    <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
    operator. However, when transformed with :func:`~jax.vmap` to operate over a
    batch of predicates, ``cond`` is converted to :func:`~jax.lax.select`.

    Args:
      index: Integer scalar type, indicating which branch function to apply.
      branches: Sequence of functions (A -> B) to be applied based on ``index``.
      operands: Operands (A) input to whichever branch is applie
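The clamping behaviour described in the docstring is easy to check directly. A minimal sketch with three toy branches of our own choosing:

```python
from jax.lax import switch

# Three hypothetical branches that are easy to tell apart by their output
branches = [lambda z: z - 1.0, lambda z: z, lambda z: z + 1.0]

below = float(switch(-3, branches, 10.0))  # index clamped to 0
inside = float(switch(1, branches, 10.0))  # index used as-is
above = float(switch(7, branches, 10.0))   # index clamped to 2

print(below, inside, above)  # 9.0 10.0 11.0
```

An out-of-range index therefore never raises; it silently selects the first or last branch.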

So we can rewrite our absolute value function in terms of `switch`:

In [32]:
def jax_abs_switch(x):
    branches = [lambda z: -z, lambda z: z]
    index = jnp.asarray(x > 0, jnp.int32)
    return switch(index, branches, x)


print(jax_abs_switch(5.0))

5.0


So now let's have a look and see what we really did:

In [33]:
make_jaxpr(jax_abs_switch)(-5.0)

{ lambda ; a:f32[]. let
    b:bool[] = gt a 0.0
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:i32[] = clamp 0 c 1
    e:f32[] = cond[
      branches=(
        { lambda ; f:f32[]. let g:f32[] = neg f in (g,) }
        { lambda ; h:f32[]. let  in (h,) }
      )
    ] d a
  in (e,) }

In [34]:
make_jaxpr(compact_jax_abs)(-5.0)

{ lambda ; a:f32[]. let
    b:bool[] = gt a 0.0
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:f32[] = cond[
      branches=(
        { lambda ; e:f32[]. let f:f32[] = neg e in (f,) }
        { lambda ; g:f32[]. let  in (g,) }
      )
    ] c a
  in (d,) }

Well, we did nothing! Apart from the extra `clamp` on the index, the two programs are identical: `cond` is just a special case of `switch`. In particular, stacking `cond` calls in the hope of better performance would be a very bad idea.
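The docstring's remark that a `cond` under `vmap` is converted to a select can be inspected in the same way. This is a minimal sketch, assuming the primitive name printed in the jaxpr contains the substring `select` (the exact name may vary between JAX versions); `scalar_abs` is our name.

```python
import numpy as np
from jax import make_jaxpr, vmap
from jax.lax import cond


def scalar_abs(x):
    return cond(x > 0, lambda z: z, lambda z: -z, x)


values = np.array([-1.0, 0.0, 1.0])

# Scalar program: a genuine conditional with two branches
scalar_program = str(make_jaxpr(scalar_abs)(-5.0))
# Batched program: the predicate is now a vector, one entry per element
batched_program = str(make_jaxpr(vmap(scalar_abs))(values))

print(scalar_program)
print(batched_program)

uses_cond = "cond" in scalar_program
uses_select = "select" in batched_program
```

With a batched predicate, both branches are computed for every element and the result is picked elementwise, which is worth keeping in mind when the branches are expensive.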

### Questions:

#### Q1: 
Implement the generalized `piecewise_constant` using `switch`, assuming the data is sorted.

#### Q2:
If you are already familiar with the loop primitives, implement the generalized `piecewise_constant` using a loop primitive and `cond`, and compare the generated code.

#### Q3:
How would you vectorize this function? Compare the naive vmap with the `jnp.select` approach.