<center>
    <h1>JaxTon</h1>
    <i>💯 JAX exercises</i>
    <br>
    <br>
    <a href='https://github.com/vopani/jaxton/blob/master/LICENSE'>
        <img src='https://img.shields.io/badge/license-Apache%202.0-blue.svg?logo=apache'>
    </a>
    <a href='https://github.com/vopani/jaxton'>
        <img src='https://img.shields.io/github/stars/vopani/jaxton?color=yellowgreen&logo=github'>
    </a>
    <a href='https://twitter.com/vopani'>
        <img src='https://img.shields.io/twitter/follow/vopani'>
    </a>
</center>

<center>
    All exercises of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find individual sets of the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>

**Prerequisites**

* To use TPUs, update the JAX configuration as shown in the code snippet below.
* A sample array `sample_data` will be used for some of the exercises.
* Sample functions `square`, `cube` and `areas` will be used for some of the exercises.

In [None]:
## install jax
!python3 -m pip install jax

In [None]:
## import packages
import jax
import jax.numpy as jnp
import math
import os
import requests

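## setup JAX to use TPUs if available
try:
    ## TPU_NAME is only set on TPU runtimes; a missing variable raises KeyError
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except Exception:
    ## no TPU configured or the setup failed: keep the default backend
    pass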

jax.devices()

In [None]:
## sample data
sample_data = jnp.array([10, 1, 24, 20, 15, 14])
sample_data

In [None]:
## sample square function
def square(x):
    return x**2

square(2)

In [None]:
## sample cube function
def cube(x):
    return x**3

cube(2)

In [None]:
## sample areas function
def areas(x):
    return [math.sqrt(3)*x**2/4, x**2, math.pi*x**2]

areas(2)

**Exercise 1: Install the `jax` package**

**Exercise 2: Import the `jax` package**

**Exercise 3: Display the version of `jax`**

**Exercise 4: Display the default backend of `jax`**

**Exercise 5: Display the devices of the backend**

**Exercise 6: Create a JAX DeviceArray with values [10, 1, 24] and assign it to `data`**

**Exercise 7: Display the type of `data`**

**Exercise 8: Display the shape of `data`**

**Exercise 9: Transfer `data` to host and assign it to `data_host`**

**Exercise 10: Transfer `data_host` to device and assign it to `data_device`**
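
As a hint for the exercises above, the following is a minimal sketch of one possible approach rather than the reference solution. It assumes `jax.device_get` and `jax.device_put` are the intended transfer functions, and the printed backend and devices depend on the machine it runs on.

```python
import jax
import jax.numpy as jnp
import numpy as np

print(jax.__version__)        # installed JAX version
print(jax.default_backend())  # e.g. 'cpu', 'gpu' or 'tpu'
print(jax.devices())          # devices of the default backend

data = jnp.array([10, 1, 24])
print(type(data), data.shape)

# device -> host: returns a NumPy array
data_host = jax.device_get(data)

# host -> device: returns a JAX array on the default device
data_device = jax.device_put(data_host)
```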

**Exercise 11: Create a matrix with values [[10, 1, 24], [20, 15, 14]] and assign it to `data`**

**Exercise 12: Assign the transpose of `data` to `dataT`**

**Exercise 13: Assign the element of `data` at index [0, 2] to `value`**

**Exercise 14: Update the value of `data` at index [1, 1] to `100`**

**Exercise 15: Add `41` to the value of `data` at index [0, 0]**

**Exercise 16: Calculate the minimum values of `data` over axis=1 and assign them to `mins`**

**Exercise 17: Select the first row of values of `data` and assign it to `data_select`**

**Exercise 18: Append the row `data_select` to `data`**

**Exercise 19: Multiply the matrices `data` and `dataT` and assign it to `data_prod`**

**Exercise 20: Convert the dtype of `data_prod` to `float32`**
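
The matrix exercises above hinge on the fact that JAX arrays are immutable, so in-place updates go through the `.at[...]` syntax, which returns a new array. The sketch below is one possible approach, not the reference solution. It assumes `jnp.vstack` for appending a row and `jnp.dot` for the matrix product.

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 15, 14]])
dataT = data.T
value = data[0, 2]

# functional updates: each call returns an updated copy
data = data.at[1, 1].set(100)
data = data.at[0, 0].add(41)

mins = data.min(axis=1)
data_select = data[0]

# appending a row gives a (3, 3) matrix
data = jnp.vstack([data, data_select])

# (3, 3) @ (3, 2) -> (3, 2); dataT was taken before the updates
data_prod = jnp.dot(data, dataT)
data_prod = data_prod.astype(jnp.float32)
```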

**Exercise 21: Create a pseudorandom number generator key with seed=100 and assign it to `key`**

**Exercise 22: Create a subkey from `key` and assign it to `subkey`**

**Exercise 23: Split `key` into seven subkeys `key_1`, `key_2`, `key_3`, `key_4`, `key_5`, `key_6` and `key_7`**

**Exercise 24: Create a random permutation of `sample_data` using `key_1` and assign it to `data_permutation`**

**Exercise 25: Choose a random element from `sample_data` using `key_2` and assign it to `random_selection`**

**Exercise 26: Sample an integer between 10 and 24 using `key_3` and assign it to `sample_int`**

**Exercise 27: Sample two values from a uniform distribution between 1 and 2 using `key_4` and assign them to `sample_uniform`**

**Exercise 28: Sample three values from a Bernoulli distribution using `key_5` and assign them to `sample_bernoulli`**

**Exercise 29: Sample a 2x3 matrix from a Poisson distribution with λ=100 using `key_6` and assign it to `sample_poisson`**

**Exercise 30: Sample a 2x3x4 array from a normal distribution using `key_7` and assign it to `sample_normal`**
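
JAX random numbers are driven by explicit keys, and each key should be used only once. The sketch below shows one possible way through the random-number exercises above, not the reference solution. It makes two assumptions: the subkey is created with the usual `key, subkey = jax.random.split(key)` pattern (which also replaces `key`), and "between 10 and 24" is read as inclusive, hence `maxval=25`.

```python
import jax
import jax.numpy as jnp

sample_data = jnp.array([10, 1, 24, 20, 15, 14])

key = jax.random.PRNGKey(100)
key, subkey = jax.random.split(key)
key_1, key_2, key_3, key_4, key_5, key_6, key_7 = jax.random.split(key, num=7)

data_permutation = jax.random.permutation(key_1, sample_data)
random_selection = jax.random.choice(key_2, sample_data)

# maxval is exclusive, so 25 allows 24 to be drawn (assumed inclusive range)
sample_int = jax.random.randint(key_3, shape=(), minval=10, maxval=25)

sample_uniform = jax.random.uniform(key_4, shape=(2,), minval=1.0, maxval=2.0)
sample_bernoulli = jax.random.bernoulli(key_5, shape=(3,))
sample_poisson = jax.random.poisson(key_6, lam=100, shape=(2, 3))
sample_normal = jax.random.normal(key_7, shape=(2, 3, 4))
```

Reusing the same key always gives the same sample, which is why each draw gets its own key.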

**Exercise 31: JIT-compile the `cube` function and assign it to `cube_jit`**

**Exercise 32: Display the execution time of `cube_jit` for first run (with overhead) with input=10.24**

**Exercise 33: Display the execution time of `cube_jit` for the second run (without overhead) with input=10.24**

**Exercise 34: Run `cube_jit` with input=10.24 and assign it to `cube_value`**

**Exercise 35: Run `cube_jit` with jit disabled and input=10.24 and assign it to `cube_value_nojit`**

**Exercise 36: Evaluate the shape of `cube_jit` with input=10.24 and assign it to `cube_shape`**

**Exercise 37: Create the jaxpr of `cube_jit` with input=10.24 and assign it to `cube_jaxpr`**

**Exercise 38: Assign the XLA computation of `cube_jit` with input=10.24 to `cube_xla` and print its XLA HLO text**

**Exercise 39: Use the name `jaxton_cube_fn` internally for the `cube_jit` function and assign the named function to `cube_named_jit`**

**Exercise 40: Assign the XLA computation of `cube_named_jit` with input=10.24 to `cube_named_xla` and print its XLA HLO text**
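
The sketch below covers the JIT basics from the exercises above (compiling, timing, disabling JIT, shape evaluation and jaxpr); it is one possible approach, not the reference solution. The XLA and naming exercises are left out on purpose because the relevant APIs differ between JAX releases. Timing with `time.perf_counter` is an assumption; a notebook would more likely use `%timeit`.

```python
import time
import jax

def cube(x):
    return x**3

cube_jit = jax.jit(cube)

# first call: includes tracing and compilation overhead
start = time.perf_counter()
cube_jit(10.24).block_until_ready()
first_run = time.perf_counter() - start

# second call: reuses the compiled function
start = time.perf_counter()
cube_value = cube_jit(10.24).block_until_ready()
second_run = time.perf_counter() - start
print(first_run, second_run)

# run the same function with JIT turned off
with jax.disable_jit():
    cube_value_nojit = cube_jit(10.24)

# shape and dtype of the output, without running the computation
cube_shape = jax.eval_shape(cube_jit, 10.24)

# intermediate representation that JAX traces the function into
cube_jaxpr = jax.make_jaxpr(cube_jit)(10.24)
print(cube_jaxpr)
```

`block_until_ready()` matters for timing because JAX dispatches work asynchronously.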

**Exercise 41: Calculate the cumulative sum of `sample_data` using the associative scan operator and assign it to `data_cumsum`**

**Exercise 42: Calculate the cumulative sum of `sample_data` in reverse order using the associative scan operator and assign it to `data_cumsum_reverse`**

**Exercise 43: Create a JIT-compiled lambda function that outputs `square` of input if it is even and `cube` of input if it is odd using the cond operator and assign it to `parity_ifelse`**

**Exercise 44: Run `parity_ifelse` with the first element of `data_cumsum` and assign it to `parity_1`**

**Exercise 45: Run `parity_ifelse` with the second element of `data_cumsum` and assign it to `parity_2`**

**Exercise 46: Create a JIT-compiled lambda function that outputs `square` of input if it is even and `cube` of input if it is odd using the switch operator and assign it to `parity_switch`**

**Exercise 47: Run `parity_switch` with the fourth element of `data_cumsum` and assign it to `parity_4`**

**Exercise 48: Run `parity_switch` with the fifth element of `data_cumsum` and assign it to `parity_5`**

**Exercise 49: Calculate the sum of the first four elements of `data_cumsum` using the for operator and assign it to `sum_four`**

**Exercise 50: Keep subtracting 25 from `sum_four` until the result is negative using the while operator and assign it to `subtract_until_negative`**
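
The control-flow exercises above map onto the structured primitives in `jax.lax`. The sketch below is one possible approach, not the reference solution. It assumes the "for" and "while" operators mean `lax.fori_loop` and `lax.while_loop`, and that the parity test uses `x % 2`.

```python
import jax
import jax.numpy as jnp
from jax import lax

sample_data = jnp.array([10, 1, 24, 20, 15, 14])

def square(x):
    return x**2

def cube(x):
    return x**3

data_cumsum = lax.associative_scan(jnp.add, sample_data)
data_cumsum_reverse = lax.associative_scan(jnp.add, sample_data, reverse=True)

# cond: true branch (square) for even input, false branch (cube) for odd
parity_ifelse = jax.jit(lambda x: lax.cond(x % 2 == 0, square, cube, x))
parity_1 = parity_ifelse(data_cumsum[0])
parity_2 = parity_ifelse(data_cumsum[1])

# switch: index 0 (even) -> square, index 1 (odd) -> cube
parity_switch = jax.jit(lambda x: lax.switch(x % 2, [square, cube], x))
parity_4 = parity_switch(data_cumsum[3])
parity_5 = parity_switch(data_cumsum[4])

# fori_loop: the carry starts at 0 with the same dtype as the data
init = jnp.zeros((), dtype=data_cumsum.dtype)
sum_four = lax.fori_loop(0, 4, lambda i, total: total + data_cumsum[i], init)

# while_loop: subtract 25 while the value is still non-negative
subtract_until_negative = lax.while_loop(lambda x: x >= 0, lambda x: x - 25, sum_four)
```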

**Exercise 51: JIT-compile the derivative of `cube` and assign it to `derivative_cube`**

**Exercise 52: Run `derivative_cube` with value=7**

**Exercise 53: JIT-compile the value and derivative of `cube` together, assign it to `value_and_derivative_cube` and run it with value=7**

**Exercise 54: JIT-compile the second order derivative of `cube`, assign it to `derivative_cube_2` and run it with value=7**

**Exercise 55: JIT-compile the hessian of `cube`, assign it to `hessian_cube` and run it with value=7**

**Exercise 56: JIT-compile `areas`, assign it to `jit_areas` and run it with value=9**

**Exercise 57: Compute the Jacobian of `areas` using forward-mode automatic differentiation, assign it to `jacfwd_areas` and run it with value=9**

**Exercise 58: Compute the Jacobian of `areas` using reverse-mode automatic differentiation, assign it to `jacrev_areas` and run it with value=9**

**Exercise 59: Compute the Jacobian-vector product of `cube` at value=7 with vector=9 and assign it to `jvp_cube`**

**Exercise 60: Compute the linear approximation of `areas` with value=5, assign it to `areas_linear` and run it with value=9**
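
The differentiation exercises above can be approached with JAX's transformation functions, as in the sketch below; it is one possible approach, not the reference solution. It assumes float inputs (`jax.grad` rejects integers) and reads "linear approximation" as `jax.linearize`, whose returned function is applied to the tangent value 9.

```python
import math
import jax

def cube(x):
    return x**3

def areas(x):
    return [math.sqrt(3)*x**2/4, x**2, math.pi*x**2]

derivative_cube = jax.jit(jax.grad(cube))
d1 = derivative_cube(7.0)

value_and_derivative_cube = jax.jit(jax.value_and_grad(cube))
v, d = value_and_derivative_cube(7.0)

derivative_cube_2 = jax.jit(jax.grad(jax.grad(cube)))
d2 = derivative_cube_2(7.0)

hessian_cube = jax.jit(jax.hessian(cube))
h = hessian_cube(7.0)

jit_areas = jax.jit(areas)
a = jit_areas(9.0)

# Jacobians of a list-valued function come back as a list of derivatives
jacfwd_areas = jax.jacfwd(areas)
jf = jacfwd_areas(9.0)
jacrev_areas = jax.jacrev(areas)
jr = jacrev_areas(9.0)

# jvp returns (primal output, tangent output)
_, jvp_cube = jax.jvp(cube, (7.0,), (9.0,))

# linearize returns (primal output, linear function of the tangent)
_, areas_linear = jax.linearize(areas, 5.0)
lin = areas_linear(9.0)
```

For a scalar function the Hessian is just the second derivative, so `hessian_cube` and `derivative_cube_2` should agree.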
