<center>
    <h1>JaxTon</h1>
    <i>💯 JAX exercises</i>
    <br>
    <br>
    <a href='https://github.com/vopani/jaxton/blob/master/LICENSE'>
        <img src='https://img.shields.io/badge/license-Apache%202.0-blue.svg?logo=apache'>
    </a>
    <a href='https://github.com/vopani/jaxton'>
        <img src='https://img.shields.io/github/stars/vopani/jaxton?color=yellowgreen&logo=github'>
    </a>
    <a href='https://twitter.com/vopani'>
        <img src='https://img.shields.io/twitter/follow/vopani'>
    </a>
</center>

<center>
    All solutions of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find individual sets of the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>

**Prerequisites**

* JAX's configuration must be updated as shown in the code snippet below to use TPUs.
* A sample array `sample_data` will be used for some of the exercises.
* A sample function `cube` will be used for some of the exercises.

In [1]:
## install jax
!python3 -m pip install jax



In [2]:
## import packages
import jax
import jax.numpy as jnp
import os
import requests

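## setup JAX to use TPUs if available
## (legacy TPU-driver setup for JAX 0.2.x; TPU_NAME is typically of the form 'grpc://<ip>:8470')
try:
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    requests.post(url)  # ask the TPU host to switch to the nightly TPU driver
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except Exception:
    pass  # no TPU available: keep the default backend (CPU/GPU)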

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
## sample data
sample_data = jnp.array([10, 1, 24, 20, 15, 14])
sample_data

DeviceArray([10,  1, 24, 20, 15, 14], dtype=int32)

In [4]:
## sample function
def cube(x):
    return x**3

cube(2.1)

9.261000000000001

**Exercise 1: Install the `jax` package**

In [5]:
!python3 -m pip install jax



**Exercise 2: Import the `jax` package**

In [6]:
import jax

**Exercise 3: Display the version of `jax`**

In [7]:
jax.__version__

'0.2.19'

**Exercise 4: Display the default backend of `jax`**

In [8]:
jax.default_backend()

'tpu'

**Exercise 5: Display the devices of the backend**

In [9]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
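Besides listing the devices, JAX can report how many there are. The sketch below uses `jax.device_count()` (all devices across processes) and `jax.local_device_count()` (devices attached to this process); on a single-host setup like the 8-core TPU above, the two numbers should match. The variable names are illustrative.

```python
import jax

# Total number of devices visible to JAX across all processes
n_global = jax.device_count()
# Number of devices attached to this host process
n_local = jax.local_device_count()

# Each Device object reports its platform, e.g. 'cpu', 'gpu' or 'tpu'
platforms = {d.platform for d in jax.devices()}
print(n_global, n_local, platforms)
```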

**Exercise 6: Create a JAX DeviceArray with values [10, 1, 24] and assign it to `data`**

In [10]:
data = jax.numpy.array([10, 1, 24])
data

DeviceArray([10,  1, 24], dtype=int32)

**Exercise 7: Display the type of `data`**

In [11]:
type(data)

jax.interpreters.xla._DeviceArray

**Exercise 8: Display the shape of `data`**

In [12]:
data.shape

(3,)

**Exercise 9: Transfer `data` to host and assign it to `data_host`**

In [13]:
data_host = jax.device_get(data)
data_host

array([10,  1, 24], dtype=int32)
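Exercises 9 and 10 form a round trip between device and host memory. The minimal sketch below makes the types explicit: `jax.device_get` returns a plain NumPy `ndarray`, while `jax.device_put` returns a JAX array and optionally accepts a target device (here, assumed to be the first entry of `jax.devices()`).

```python
import numpy as np
import jax
import jax.numpy as jnp

data = jnp.array([10, 1, 24])

# device -> host: a regular NumPy ndarray
data_host = jax.device_get(data)
# host -> device: a JAX array on the default device
data_device = jax.device_put(data_host)
# host -> an explicitly chosen device
data_dev0 = jax.device_put(data_host, jax.devices()[0])
```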

**Exercise 10: Transfer `data_host` to device and assign it to `data_device`**

In [14]:
data_device = jax.device_put(data_host)
data_device

DeviceArray([10,  1, 24], dtype=int32)

**Exercise 11: Create a matrix with values [[10, 1, 24], [20, 15, 14]] and assign it to `data`**

In [15]:
data = jnp.array([[10, 1, 24], [20, 15, 14]])
data

DeviceArray([[10,  1, 24],
             [20, 15, 14]], dtype=int32)

**Exercise 12: Assign the transpose of `data` to `dataT`**

In [16]:
dataT = data.T
dataT

DeviceArray([[10, 20],
             [ 1, 15],
             [24, 14]], dtype=int32)

**Exercise 13: Assign the element of `data` at index [0, 2] to `value`**

In [17]:
value = data[0, 2]
value

DeviceArray(24, dtype=int32)

**Exercise 14: Update the value of `data` at index [1, 1] to `100`**

In [18]:
data = data.at[1, 1].set(100)
data

DeviceArray([[ 10,   1,  24],
             [ 20, 100,  14]], dtype=int32)

**Exercise 15: Add `41` to the value of `data` at index [0, 0]**

In [19]:
data = data.at[0, 0].add(41)
data

DeviceArray([[ 51,   1,  24],
             [ 20, 100,  14]], dtype=int32)
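Exercises 14 and 15 reassign `data` because JAX arrays are immutable: `.at[...]` always returns a new array. The sketch below shows that the original array is left untouched, that the index may be a slice, and that NumPy-style item assignment raises a `TypeError`. The names `updated`, `row_bumped` and `inplace_ok` are illustrative.

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 15, 14]])

# .at[...] returns a *new* array; `data` itself is not modified
updated = data.at[1, 1].set(100)
row_bumped = data.at[0, :].add(5)  # the index can be a slice, not just one element

# NumPy-style in-place assignment is rejected because JAX arrays are immutable
try:
    data[1, 1] = 100
    inplace_ok = True
except TypeError:
    inplace_ok = False
```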

**Exercise 16: Calculate the minimum values over axis=1 and assign them to `mins`**

In [20]:
mins = data.min(axis=1)
mins

DeviceArray([ 1, 14], dtype=int32)
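The `axis` argument chooses which dimension is reduced away: `axis=1` reduces across columns (one result per row), `axis=0` across rows (one result per column). A short sketch with the same 2x3 matrix, adding `argmin` to locate each row minimum:

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14]])

mins = data.min(axis=1)        # smallest value in each row
argmins = data.argmin(axis=1)  # column index of each row's minimum
maxs = data.max(axis=0)        # largest value in each column
```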

**Exercise 17: Select the first row of values of `data` and assign it to `data_select`**

In [21]:
data_select = data[0]
data_select

DeviceArray([51,  1, 24], dtype=int32)

**Exercise 18: Append the row `data_select` to `data`**

In [22]:
data = jnp.vstack([data, data_select])
data

DeviceArray([[ 51,   1,  24],
             [ 20, 100,  14],
             [ 51,   1,  24]], dtype=int32)
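`jnp.vstack` treats the 1-D row `data_select` as a `(1, 3)` row. The sketch below shows the equivalent, more explicit form: add a leading axis with `[None, :]` and concatenate along `axis=0`.

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14]])
data_select = data[0]  # shape (3,)

stacked = jnp.vstack([data, data_select])
# Equivalent: promote the row to shape (1, 3), then concatenate along rows
concatenated = jnp.concatenate([data, data_select[None, :]], axis=0)
```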

**Exercise 19: Multiply the matrices `data` and `dataT` and assign the result to `data_prod`**

In [23]:
data_prod = jnp.dot(data, dataT)
data_prod

DeviceArray([[1087, 1371],
             [ 636, 2096],
             [1087, 1371]], dtype=int32)
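For 2-D arrays, `jnp.dot`, `jnp.matmul` and the `@` operator all compute the same matrix product, here a `(3, 3) @ (3, 2) -> (3, 2)` result. A quick check reusing the values from the exercises:

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14], [51, 1, 24]])
dataT = jnp.array([[10, 20], [1, 15], [24, 14]])

prod_dot = jnp.dot(data, dataT)
prod_matmul = jnp.matmul(data, dataT)
prod_op = data @ dataT  # the @ operator dispatches to matmul
```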

**Exercise 20: Convert the dtype of `data_prod` to `float32`**

In [24]:
data_prod = jnp.array(data_prod, dtype=jnp.float32)
data_prod

DeviceArray([[1087., 1371.],
             [ 636., 2096.],
             [1087., 1371.]], dtype=float32)
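An equally common idiom for a dtype cast is the `.astype` method, which should give the same result as passing `dtype=` to `jnp.array`:

```python
import jax.numpy as jnp

data_prod = jnp.array([[1087, 1371], [636, 2096], [1087, 1371]])

# Same effect as jnp.array(data_prod, dtype=jnp.float32)
data_prod_f32 = data_prod.astype(jnp.float32)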

**Exercise 21: Create a pseudorandom number generator key with seed=100 and assign it to `key`**

In [25]:
key = jax.random.PRNGKey(100)
key

DeviceArray([  0, 100], dtype=uint32)

**Exercise 22: Create a subkey from `key` and assign it to `subkey`**

In [26]:
key, subkey = jax.random.split(key)
subkey

array([3011861781, 1867493174], dtype=uint32)
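JAX random functions are pure: the same key always produces the same samples, which is why the exercises split keys rather than reuse them. The sketch below illustrates this; the seed and the sample shape are arbitrary.

```python
import jax

key = jax.random.PRNGKey(100)

# Same key -> identical samples on every call
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))

# New randomness comes from splitting the key, not from calling again
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, shape=(3,))
```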

**Exercise 23: Split `key` into seven subkeys `key_1`, `key_2`, `key_3`, `key_4`, `key_5`, `key_6` and `key_7`**

In [27]:
key_1, key_2, key_3, key_4, key_5, key_6, key_7 = jax.random.split(key, num=7)
key_1

array([ 402730500, 1595431526], dtype=uint32)

**Exercise 24: Create a random permutation of `sample_data` using `key_1` and assign it to `data_permutation`**

In [28]:
data_permutation = jax.random.permutation(key_1, sample_data)
data_permutation

DeviceArray([20, 14,  1, 10, 24, 15], dtype=int32)

**Exercise 25: Choose a random element from `sample_data` using `key_2` and assign it to `random_selection`**

In [29]:
random_selection = jax.random.choice(key_2, sample_data)
random_selection

DeviceArray(1, dtype=int32)

**Exercise 26: Sample an integer between 10 (inclusive) and 24 (exclusive) using `key_3` and assign it to `sample_int`**

In [30]:
sample_int = jax.random.randint(key_3, shape=(1,), minval=10, maxval=24)
sample_int

DeviceArray([14], dtype=int32)
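In `jax.random.randint`, `minval` is inclusive and `maxval` is exclusive, so the call above can never return 24. Drawing many samples with an arbitrary key makes the half-open range visible:

```python
import jax

key = jax.random.PRNGKey(0)
# Integers drawn uniformly from the half-open range [10, 24)
samples = jax.random.randint(key, shape=(10_000,), minval=10, maxval=24)
print(int(samples.min()), int(samples.max()))
```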

**Exercise 27: Sample two values from a uniform distribution between 1 and 2 using `key_4` and assign them to `sample_uniform`**

In [31]:
sample_uniform = jax.random.uniform(key_4, shape=(2,), minval=1, maxval=2)
sample_uniform

DeviceArray([1.6274643, 1.1133162], dtype=float32)

**Exercise 28: Sample three values from a Bernoulli distribution using `key_5` and assign them to `sample_bernoulli`**

In [32]:
sample_bernoulli = jax.random.bernoulli(key_5, shape=(3,))
sample_bernoulli

DeviceArray([False,  True,  True], dtype=bool)
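`jax.random.bernoulli` uses a success probability of `p=0.5` unless told otherwise. The sketch below passes `p=0.9` and checks that roughly 90% of 10,000 draws are `True`; the key and sample size are arbitrary choices.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Each draw is True with probability 0.9
draws = jax.random.bernoulli(key, p=0.9, shape=(10_000,))
frac_true = jnp.mean(draws)  # booleans are promoted to float for the mean
```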

**Exercise 29: Sample a 2x3 matrix from a Poisson distribution with λ=100 using `key_6` and assign it to `sample_poisson`**

In [33]:
sample_poisson = jax.random.poisson(key_6, shape=(2, 3), lam=100)
sample_poisson

DeviceArray([[ 88,  82, 110],
             [ 89,  85,  98]], dtype=int32)

**Exercise 30: Sample a 2x3x4 array from a standard normal distribution using `key_7` and assign it to `sample_normal`**

In [34]:
sample_normal = jax.random.normal(key_7, shape=(2, 3, 4))
sample_normal

DeviceArray([[[ 0.25418088,  1.1962731 ,  1.3234351 ,  0.79711384],
              [-1.8524722 , -0.28634202,  0.2251514 , -0.6195333 ],
              [ 2.4013765 ,  0.07618266,  1.2277839 , -0.7562425 ]],

             [[-0.45340484,  1.1029627 , -0.39860612, -1.1235143 ],
              [-1.5689532 ,  0.4617323 , -0.5607138 , -1.7508575 ],
              [ 0.50200105, -1.4972546 , -1.6995528 ,  0.5555226 ]]],
            dtype=float32)
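`jax.random.normal` samples from the standard normal N(0, 1). To draw from N(μ, σ²), shift and scale the samples. A sketch with the illustrative values μ=5 and σ=2:

```python
import jax

key = jax.random.PRNGKey(0)

# Standard normal samples rescaled to mean mu and standard deviation sigma
mu, sigma = 5.0, 2.0
samples = mu + sigma * jax.random.normal(key, shape=(10_000,))
```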

**Exercise 31: JIT-compile the `cube` function and assign it to `cube_jit`**

In [35]:
cube_jit = jax.jit(cube)
cube_jit

<CompiledFunction at 0x7ff0b03d6500>
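`jax.jit` can also be applied as a decorator when the function is defined, which is equivalent to wrapping it afterwards. The name `cube_decorated` is illustrative.

```python
import jax

@jax.jit
def cube_decorated(x):
    # Traced and compiled by XLA on the first call for each input shape/dtype
    return x**3
```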

**Exercise 32: Display the execution time of `cube_jit` for the first run (including compilation overhead) with input=10.24**

In [36]:
%%time
cube_jit(10.24)

CPU times: user 15.1 ms, sys: 8.41 ms, total: 23.5 ms
Wall time: 19.8 ms


DeviceArray(1073.7418, dtype=float32)

**Exercise 33: Display the execution time of `cube_jit` for the second run (without compilation overhead) with input=10.24**

In [37]:
%%time
cube_jit(10.24)

CPU times: user 4.39 ms, sys: 747 µs, total: 5.14 ms
Wall time: 2.48 ms


DeviceArray(1073.7418, dtype=float32)
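JAX dispatches work asynchronously, so a wall-clock measurement like `%%time` can return before the device has finished computing. A more careful timing sketch calls `.block_until_ready()` on the result and uses `time.perf_counter()`; the timings themselves will vary by machine and are not asserted here.

```python
import time
import jax

def cube(x):
    return x**3

cube_jit = jax.jit(cube)

# First call: tracing + XLA compilation + execution
start = time.perf_counter()
cube_jit(10.24).block_until_ready()
first_call = time.perf_counter() - start

# Second call: reuses the cached compiled executable
start = time.perf_counter()
result = cube_jit(10.24).block_until_ready()
second_call = time.perf_counter() - start
print(f"first: {first_call:.4f}s, second: {second_call:.4f}s")
```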

**Exercise 34: Run `cube_jit` with input=10.24 and assign it to `cube_value`**

In [38]:
cube_value = cube_jit(10.24)
cube_value

DeviceArray(1073.7418, dtype=float32)

**Exercise 35: Run `cube_jit` on input=10.24 with JIT disabled and assign the result to `cube_value_nojit`**

In [39]:
with jax.disable_jit():
    cube_value_nojit = cube_jit(10.24)

cube_value_nojit

1073.7418240000002
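The float64 result above shows that, under `jax.disable_jit()`, the Python body of `cube` runs eagerly instead of through the compiled executable. Python side effects make the difference visible: under `jit` they run only while the function is traced, while with JIT disabled they run on every call. `trace_log` and `cube_logged` are hypothetical names for this sketch.

```python
import jax

trace_log = []

def cube_logged(x):
    trace_log.append(1)  # Python side effect: runs whenever the Python body executes
    return x**3

cube_logged_jit = jax.jit(cube_logged)
cube_logged_jit(10.24)
cube_logged_jit(10.24)  # same shape/dtype: cached executable, no re-trace
n_jit = len(trace_log)

with jax.disable_jit():
    cube_logged_jit(10.24)  # body runs eagerly in Python
    cube_logged_jit(10.24)
n_total = len(trace_log)
```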

**Exercise 36: Evaluate the shape of `cube_jit` with input=10.24 and assign it to `cube_shape`**

In [40]:
cube_shape = jax.eval_shape(cube_jit, 10.24)
cube_shape

ShapeDtypeStruct(shape=(), dtype=float32)
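`jax.eval_shape` also accepts abstract inputs described by `jax.ShapeDtypeStruct`, so output shapes can be inferred without allocating or computing anything. A sketch with an assumed 1000x1000 float32 input:

```python
import jax
import jax.numpy as jnp

def cube(x):
    return x**3

# Describe the input abstractly: no 1000x1000 buffer is created
spec = jax.ShapeDtypeStruct((1000, 1000), jnp.float32)
out = jax.eval_shape(cube, spec)
```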

**Exercise 37: Create the jaxpr of `cube_jit` with input=10.24 and assign it to `cube_jaxpr`**

In [41]:
cube_jaxpr = jax.make_jaxpr(cube_jit)(10.24)
cube_jaxpr

{ lambda  ; a.
  let b = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a.
                                 let b = integer_pow[ y=3 ] a
                                 in (b,) }
                    device=None
                    donated_invars=(False,)
                    inline=False
                    name=cube ] a
  in (b,) }
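The jaxpr above wraps the body in an `xla_call` because `cube_jit` is already jitted. Tracing the plain `cube` function instead should expose only the underlying `integer_pow` primitive, as this sketch checks:

```python
import jax

def cube(x):
    return x**3

# Without jax.jit there is no call wrapper, just the primitive operations
jaxpr = jax.make_jaxpr(cube)(10.24)
jaxpr_text = str(jaxpr)
print(jaxpr_text)
```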

**Exercise 38: Assign the XLA computation of `cube_jit` with input=10.24 to `cube_xla` and print its XLA HLO text**

In [42]:
cube_xla = jax.xla_computation(cube_jit)(10.24)
print(cube_xla.as_hlo_text())

HloModule xla_computation_cube.12

jit_cube__1.3 {
  constant.5 = pred[] constant(false)
  parameter.4 = f32[] parameter(0)
  multiply.6 = f32[] multiply(parameter.4, parameter.4)
  multiply.7 = f32[] multiply(parameter.4, multiply.6)
  ROOT tuple.8 = (f32[]) tuple(multiply.7)
}

ENTRY xla_computation_cube.12 {
  constant.2 = pred[] constant(false)
  parameter.1 = f32[] parameter(0)
  call.9 = (f32[]) call(parameter.1), to_apply=jit_cube__1.3
  get-tuple-element.10 = f32[] get-tuple-element(call.9), index=0
  ROOT tuple.11 = (f32[]) tuple(get-tuple-element.10)
}
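`jax.xla_computation` belongs to the JAX 0.2.x API used in this notebook and has been deprecated in more recent releases. A sketch of the ahead-of-time lowering API that newer versions offer instead (`jax.jit(f).lower(...)`, `.as_text()`, `.compile()`) follows. The emitted IR text differs from the HLO printed above (recent releases emit StableHLO), so only its presence is checked.

```python
import jax
import jax.numpy as jnp

def cube(x):
    return x**3

x = jnp.float32(10.24)
lowered = jax.jit(cube).lower(x)  # trace and lower, no compilation yet
ir_text = lowered.as_text()       # textual IR of the lowered computation
compiled = lowered.compile()      # XLA compilation
value = compiled(x)               # call with the same input type used for lowering
print(ir_text)
```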




**Exercise 39: Use the name `jaxton_cube_fn` internally for the `cube_jit` function and assign the named function to `cube_named_jit`**

In [43]:
cube_named_jit = jax.named_call(cube_jit, name='jaxton_cube_fn')
cube_named_jit

<function __main__.cube(x)>

**Exercise 40: Assign the XLA computation of `cube_named_jit` with input=10.24 to `cube_named_xla` and print its XLA HLO text**

In [44]:
cube_named_xla = jax.xla_computation(cube_named_jit)(10.24)
print(cube_named_xla.as_hlo_text())

HloModule xla_computation_cube__1.18

jit_cube__2.3 {
  constant.5 = pred[] constant(false)
  parameter.4 = f32[] parameter(0)
  multiply.6 = f32[] multiply(parameter.4, parameter.4)
  multiply.7 = f32[] multiply(parameter.4, multiply.6)
  ROOT tuple.8 = (f32[]) tuple(multiply.7)
}

jaxton_cube_fn.9 {
  constant.11 = pred[] constant(false)
  parameter.10 = f32[] parameter(0)
  call.12 = (f32[]) call(parameter.10), to_apply=jit_cube__2.3
  get-tuple-element.13 = f32[] get-tuple-element(call.12), index=0
  ROOT tuple.14 = (f32[]) tuple(get-tuple-element.13)
}

ENTRY xla_computation_cube__1.18 {
  constant.2 = pred[] constant(false)
  parameter.1 = f32[] parameter(0)
  call.15 = (f32[]) call(parameter.1), to_apply=jaxton_cube_fn.9
  get-tuple-element.16 = f32[] get-tuple-element(call.15), index=0
  ROOT tuple.17 = (f32[]) tuple(get-tuple-element.16)
}



