<center>
    <h1>JaxTon</h1>
    <i>💯 JAX exercises</i>
    <br>
    <br>
    <a href='https://github.com/vopani/jaxton/blob/master/LICENSE'>
        <img src='https://img.shields.io/badge/license-Apache%202.0-blue.svg?logo=apache'>
    </a>
    <a href='https://github.com/vopani/jaxton'>
        <img src='https://img.shields.io/github/stars/vopani/jaxton?color=yellowgreen&logo=github'>
    </a>
    <a href='https://twitter.com/vopani'>
        <img src='https://img.shields.io/twitter/follow/vopani'>
    </a>
</center>

<center>
    This is Set 5: Control Flows (Solutions 41-50) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>

**Prerequisites**

* JAX must be configured as shown in the code snippet below in order to use TPUs.
* A sample array `sample_data` will be used for some of the exercises.
* Sample functions `square` and `cube` will be used for the exercises.

In [1]:
!python3 -m pip install jax



In [2]:
import jax
import jax.numpy as jnp
import os
import requests

try:
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except Exception:
    # TPU_NAME is unset or the TPU driver is unreachable: keep the default backend
    pass

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
## sample data
sample_data = jnp.array([10, 1, 24, 20, 15, 14])
sample_data

DeviceArray([10,  1, 24, 20, 15, 14], dtype=int32)

In [4]:
## sample square function
def square(x):
    return x**2

square(2)

4

In [5]:
## sample cube function
def cube(x):
    return x**3

cube(2)

8

**Exercise 41: Calculate the cumulative sum of `sample_data` using the associative scan operator and assign it to `data_cumsum`**

In [6]:
data_cumsum = jax.lax.associative_scan(jnp.add, sample_data)
data_cumsum

DeviceArray([10, 11, 35, 55, 70, 84], dtype=int32)
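`jax.lax.associative_scan` accepts any associative binary function, not only `jnp.add`. The minimal sketch below reuses the same array to compute a running maximum with `jnp.maximum` and checks the cumulative sum against `jnp.cumsum`. The function must be associative because the scan combines elements in a tree order rather than strictly left to right; a non-associative function such as subtraction would not reproduce a sequential scan.

```python
import jax
import jax.numpy as jnp

sample_data = jnp.array([10, 1, 24, 20, 15, 14])

# Running sum: the same call as in the exercise
running_sum = jax.lax.associative_scan(jnp.add, sample_data)

# Running maximum: jnp.maximum is also associative, so it is a valid scan operator
running_max = jax.lax.associative_scan(jnp.maximum, sample_data)

# The associative-scan result should agree with the built-in cumulative sum
assert jnp.array_equal(running_sum, jnp.cumsum(sample_data))

print(running_max)  # [10 10 24 24 24 24]
```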

**Exercise 42: Calculate the cumulative sum of `sample_data` in reverse order using the associative scan operator and assign it to `data_cumsum_reverse`**

In [7]:
data_cumsum_reverse = jax.lax.associative_scan(jnp.add, sample_data, reverse=True)
data_cumsum_reverse

DeviceArray([84, 74, 73, 49, 29, 14], dtype=int32)
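With `reverse=True`, entry `i` is the sum of elements `i` through the end of the array (a suffix sum). As a quick sketch, this matches flipping the array, scanning forward, and flipping the result back:

```python
import jax
import jax.numpy as jnp

sample_data = jnp.array([10, 1, 24, 20, 15, 14])

reverse_scan = jax.lax.associative_scan(jnp.add, sample_data, reverse=True)

# Flip, scan forward, flip back: this should produce the same suffix sums
flip_scan_flip = jnp.flip(jax.lax.associative_scan(jnp.add, jnp.flip(sample_data)))

assert jnp.array_equal(reverse_scan, flip_scan_flip)
print(reverse_scan)  # [84 74 73 49 29 14]
```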

**Exercise 43: Create a JIT-compiled lambda function that outputs the `square` of the input if it is even and the `cube` of the input if it is odd using the cond operator and assign it to `parity_ifelse`**

In [8]:
parity_ifelse = jax.jit(lambda x: jax.lax.cond(jnp.remainder(x, 2) == 0, square, cube, x))
parity_ifelse

<CompiledFunction at 0x7f8700fba500>
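The reason for `jax.lax.cond` here: inside `jax.jit`, `x` is an abstract tracer, so a plain Python `if` on its value cannot be decided while the function is being traced. The sketch below contrasts the two approaches; `parity_python_if` and `parity_cond` are illustrative names, and it assumes (as in current JAX releases) that the failure surfaces as `jax.errors.ConcretizationTypeError` or a subclass of it.

```python
import jax
import jax.numpy as jnp

def square(x):
    return x**2

def cube(x):
    return x**3

@jax.jit
def parity_python_if(x):
    # Python control flow on a traced value: fails when jit traces the function
    if jnp.remainder(x, 2) == 0:
        return square(x)
    return cube(x)

@jax.jit
def parity_cond(x):
    # lax.cond stages both branches into the compiled program and picks one at run time
    return jax.lax.cond(jnp.remainder(x, 2) == 0, square, cube, x)

try:
    parity_python_if(10)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

print(python_if_failed)  # True
print(parity_cond(10))   # 100
```

Both branch functions passed to `cond` must return outputs with the same shape and dtype; `square` and `cube` both map an `int32` scalar to an `int32` scalar, so the pair is valid.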

**Exercise 44: Run `parity_ifelse` with the first element of `data_cumsum` and assign it to `parity_1`**

In [9]:
parity_1 = parity_ifelse(data_cumsum[0])
parity_1

DeviceArray(100, dtype=int32)

**Exercise 45: Run `parity_ifelse` with the second element of `data_cumsum` and assign it to `parity_2`**

In [10]:
parity_2 = parity_ifelse(data_cumsum[1])
parity_2

DeviceArray(1331, dtype=int32)
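Exercises 44 and 45 call `parity_ifelse` one element at a time. As a sketch that goes beyond the exercise (it uses `jax.vmap`, which this set does not otherwise cover), the same scalar function can be applied to all of `data_cumsum` at once. Under `vmap` the predicate becomes batched, so `cond` evaluates both branches and selects the result elementwise.

```python
import jax
import jax.numpy as jnp

def square(x):
    return x**2

def cube(x):
    return x**3

parity_ifelse = jax.jit(lambda x: jax.lax.cond(jnp.remainder(x, 2) == 0, square, cube, x))

data_cumsum = jax.lax.associative_scan(jnp.add, jnp.array([10, 1, 24, 20, 15, 14]))

# vmap adds a batch dimension, so the scalar function is applied to every element
parity_all = jax.vmap(parity_ifelse)(data_cumsum)
print(parity_all.tolist())  # [100, 1331, 42875, 166375, 4900, 7056]
```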

**Exercise 46: Create a JIT-compiled lambda function that outputs the `square` of the input if it is even and the `cube` of the input if it is odd using the switch operator and assign it to `parity_switch`**

In [11]:
parity_switch = jax.jit(lambda x: jax.lax.switch(jnp.remainder(x, 2), [square, cube], x))
parity_switch

<CompiledFunction at 0x7f86f05e3230>
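`jax.lax.switch` selects a branch by integer index, so `jnp.remainder(x, 2)` sends even inputs to branch 0 (`square`) and odd inputs to branch 1 (`cube`). The sketch below extends the idea to three branches keyed on `x mod 3`, using a hypothetical `identity` branch for illustration, and shows that `switch` clamps out-of-range indices to `[0, len(branches) - 1]` instead of raising.

```python
import jax
import jax.numpy as jnp

def square(x):
    return x**2

def cube(x):
    return x**3

def identity(x):  # hypothetical third branch, for illustration only
    return x

# Three branches selected by x mod 3
mod3_switch = jax.jit(lambda x: jax.lax.switch(jnp.remainder(x, 3), [square, cube, identity], x))

print(mod3_switch(9))   # 81   (9 % 3 == 0 -> square)
print(mod3_switch(10))  # 1000 (10 % 3 == 1 -> cube)
print(mod3_switch(11))  # 11   (11 % 3 == 2 -> identity)

# Out-of-range indices are clamped rather than raising an error
print(jax.lax.switch(5, [square, cube], 2))   # 8 (clamped to branch 1)
print(jax.lax.switch(-1, [square, cube], 2))  # 4 (clamped to branch 0)
```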

**Exercise 47: Run `parity_switch` with the fourth element of `data_cumsum` and assign it to `parity_4`**

In [12]:
parity_4 = parity_switch(data_cumsum[3])
parity_4

DeviceArray(166375, dtype=int32)

**Exercise 48: Run `parity_switch` with the fifth element of `data_cumsum` and assign it to `parity_5`**

In [13]:
parity_5 = parity_switch(data_cumsum[4])
parity_5

DeviceArray(4900, dtype=int32)
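Since `jnp.remainder(x, 2) == 0` holds exactly when `jnp.remainder(x, 2)` is index 0, `parity_switch` and `parity_ifelse` should agree on every integer input. A quick sketch of that check over a small range, including negative numbers:

```python
import jax
import jax.numpy as jnp

def square(x):
    return x**2

def cube(x):
    return x**3

parity_ifelse = jax.jit(lambda x: jax.lax.cond(jnp.remainder(x, 2) == 0, square, cube, x))
parity_switch = jax.jit(lambda x: jax.lax.switch(jnp.remainder(x, 2), [square, cube], x))

inputs = jnp.arange(-5, 6)  # -5, -4, ..., 5

via_cond = jax.vmap(parity_ifelse)(inputs)
via_switch = jax.vmap(parity_switch)(inputs)

print(jnp.array_equal(via_cond, via_switch))  # True
```

The agreement on negative inputs relies on `jnp.remainder` following Python's `%` semantics (a non-negative result for a positive divisor). With `jnp.fmod`, a negative odd input gives -1, which `switch` would clamp to branch 0 and wrongly square.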

**Exercise 49: Calculate the sum of the first four elements of `data_cumsum` using the `fori_loop` operator and assign it to `sum_four`**

In [14]:
sum_four = jax.lax.fori_loop(0, 4, lambda i, x: x+data_cumsum[i], 0)
sum_four

DeviceArray(111, dtype=int32)
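`fori_loop(lower, upper, body_fun, init_val)` calls `body_fun(i, carry)` for each `i` in `[lower, upper)` and threads the carry through, much like a Python `for` loop over `range`. The sketch below (with a hypothetical helper `sum_first_n`) makes the upper bound a traced argument under `jit`, which `fori_loop` supports by lowering to a `while_loop`; with static Python-integer bounds it can lower to a `scan` instead, which also supports reverse-mode differentiation.

```python
import jax
import jax.numpy as jnp

data_cumsum = jax.lax.associative_scan(jnp.add, jnp.array([10, 1, 24, 20, 15, 14]))

@jax.jit
def sum_first_n(data, n):
    # Hypothetical helper: sum data[0], ..., data[n-1] with a loop bound known only at run time
    def body(i, total):
        return total + data[i]
    return jax.lax.fori_loop(0, n, body, jnp.zeros((), data.dtype))

# The same computation as a plain Python loop, for comparison
def sum_first_n_python(data, n):
    total = 0
    for i in range(n):
        total = total + int(data[i])
    return total

print(sum_first_n(data_cumsum, 4))         # 111
print(sum_first_n_python(data_cumsum, 4))  # 111
```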

**Exercise 50: Keep subtracting 25 from `sum_four` until the result is negative using the `while_loop` operator and assign it to `subtract_until_negative`**

In [15]:
subtract_until_negative = jax.lax.while_loop(lambda x: x >= 0, lambda x: x - 25, sum_four)
subtract_until_negative

DeviceArray(-14, dtype=int32)
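`while_loop(cond_fun, body_fun, init_val)` keeps applying `body_fun` to the carry while `cond_fun` returns `True`. The carry can be any pytree, so the sketch below (hypothetical helper `subtract_and_count`) carries a `(value, steps)` tuple to also count how many subtractions were needed.

```python
import jax
import jax.numpy as jnp

@jax.jit
def subtract_and_count(start):
    # Hypothetical helper: the loop carry is a (value, number_of_subtractions) tuple
    def cond_fun(carry):
        value, _ = carry
        return value >= 0  # keep looping while the value is not yet negative

    def body_fun(carry):
        value, steps = carry
        return value - 25, steps + 1

    init = (jnp.asarray(start, dtype=jnp.int32), jnp.int32(0))
    return jax.lax.while_loop(cond_fun, body_fun, init)

value, steps = subtract_and_count(111)
print(value, steps)  # -14 5
```

The loop condition decides what happens at exactly zero: starting from 100, `value >= 0` steps through 0 and ends at -25 after five subtractions, whereas `value > 0` would stop at 0, which is not negative.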

<center>
    This completes Set 5: Control Flows (Solutions 41-50) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>