<center>
    <h1>JaxTon</h1>
    <i>💯 JAX exercises</i>
    <br>
    <br>
    <a href='https://github.com/vopani/jaxton/blob/master/LICENSE'>
        <img src='https://img.shields.io/badge/license-Apache%202.0-blue.svg?logo=apache'>
    </a>
    <a href='https://github.com/vopani/jaxton'>
        <img src='https://img.shields.io/github/stars/vopani/jaxton?color=yellowgreen&logo=github'>
    </a>
    <a href='https://twitter.com/vopani'>
        <img src='https://img.shields.io/twitter/follow/vopani'>
    </a>
</center>

<center>
    All solutions of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find individual sets of the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>

**Prerequisites**

* To use TPUs, update the JAX configuration as shown in the code snippet below.

In [1]:
## install jax
!python3 -m pip install jax



In [2]:
## import packages
import jax
import jax.numpy as jnp
import os
import requests

## setup JAX to use TPUs if available (requests the TPU driver when TPU_NAME is set)
try:
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except Exception:
    ## TPU_NAME is not set or the driver request failed: keep the default backend
    pass

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

**Exercise 1: Install the `jax` package**

In [3]:
!python3 -m pip install jax



**Exercise 2: Import the `jax` package**

In [4]:
import jax

**Exercise 3: Display the version of `jax`**

In [5]:
jax.__version__

'0.2.19'

**Exercise 4: Display the default backend of `jax`**

In [6]:
jax.default_backend()

'tpu'
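`jax.default_backend()` returns the platform name of the default backend as a string, such as `'cpu'`, `'gpu'` or `'tpu'`. A minimal sketch that branches on it, assuming you only need to tell an accelerator apart from the CPU; the helper `describe_backend` is hypothetical:

```python
import jax

def describe_backend() -> str:
    """Hypothetical helper: report whether JAX is running on an accelerator."""
    backend = jax.default_backend()  # e.g. 'cpu', 'gpu' or 'tpu'
    if backend == "cpu":
        return "running on CPU only"
    return f"running on accelerator backend: {backend}"

print(describe_backend())
```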

**Exercise 5: Display the devices of the backend**

In [7]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
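Besides listing the devices, JAX exposes device counts and per-device attributes. A short sketch; on the TPU above `jax.device_count()` would be 8, while a CPU-only machine typically reports a single device:

```python
import jax

# number of devices across all processes, and on this host only
n_total = jax.device_count()
n_local = jax.local_device_count()
print(n_total, n_local)

for d in jax.devices():
    # every device object exposes its id and platform name
    print(d.id, d.platform)
```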

**Exercise 6: Create a JAX DeviceArray with values [10, 1, 24] and assign it to `data`**

In [8]:
data = jax.numpy.array([10, 1, 24])
data

DeviceArray([10,  1, 24], dtype=int32)
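The `int32` dtype in the output is not arbitrary: by default JAX disables 64-bit types, so Python integers become `int32` and Python floats become `float32`. A minimal check, assuming the `jax_enable_x64` flag has not been turned on:

```python
import jax.numpy as jnp

ints = jnp.array([10, 1, 24])
floats = jnp.array([10.0, 1.0, 24.0])

# with the default configuration (jax_enable_x64 off), 64-bit types are not used
print(ints.dtype)    # int32
print(floats.dtype)  # float32
```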

**Exercise 7: Display the type of `data`**

In [9]:
type(data)

jax.interpreters.xla._DeviceArray
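The concrete class name shown here is an internal detail that has changed across JAX releases; newer versions report a different class. A version-independent way to check whether a value is a JAX array is `isinstance` against `jnp.ndarray`, sketched below:

```python
import jax.numpy as jnp

data = jnp.array([10, 1, 24])

# jnp.ndarray matches JAX arrays regardless of the internal class name
print(isinstance(data, jnp.ndarray))  # True
```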

**Exercise 8: Display the shape of `data`**

In [10]:
data.shape

(3,)
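The shape `(3,)` is a tuple with one entry per axis. Two related attributes are `ndim` (the number of axes) and `size` (the total number of elements), as in this short sketch:

```python
import jax.numpy as jnp

data = jnp.array([10, 1, 24])

print(data.shape)  # (3,)  one axis with three elements
print(data.ndim)   # 1     number of axes, i.e. len(data.shape)
print(data.size)   # 3     total number of elements
```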

**Exercise 9: Transfer `data` to host and assign it to `data_host`**

In [11]:
data_host = jax.device_get(data)
data_host

array([10,  1, 24], dtype=int32)

**Exercise 10: Transfer `data_host` to device and assign it to `data_device`**

In [12]:
data_device = jax.device_put(data_host)
data_device

DeviceArray([10,  1, 24], dtype=int32)
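`jax.device_get` returns a NumPy array in host memory, and `jax.device_put` copies it back to the default device or to an explicitly chosen one. A minimal round-trip sketch; targeting `jax.devices()[0]` is just one possible choice of device:

```python
import numpy as np
import jax
import jax.numpy as jnp

data = jnp.array([10, 1, 24])

data_host = jax.device_get(data)  # NumPy array on the host
print(type(data_host))            # <class 'numpy.ndarray'>

# copy back to a specific device (here the first one JAX reports)
data_device = jax.device_put(data_host, jax.devices()[0])

# the values survive the round trip unchanged
print(np.array_equal(np.asarray(data_device), data_host))  # True
```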

**Exercise 11: Create a matrix with values [[10, 1, 24], [20, 15, 14]] and assign it to `data`**

In [13]:
data = jnp.array([[10, 1, 24], [20, 15, 14]])
data

DeviceArray([[10,  1, 24],
             [20, 15, 14]], dtype=int32)

**Exercise 12: Assign the transpose of `data` to `dataT`**

In [14]:
dataT = data.T
dataT

DeviceArray([[10, 20],
             [ 1, 15],
             [24, 14]], dtype=int32)
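For a 2-D array, `.T` swaps the two axes, turning the `(2, 3)` matrix into a `(3, 2)` one. `jnp.transpose` and `jnp.swapaxes` express the same operation, as this sketch checks:

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 15, 14]])

# three equivalent ways to transpose a 2-D array
t1 = data.T
t2 = jnp.transpose(data)
t3 = jnp.swapaxes(data, 0, 1)

print(t1.shape)  # (3, 2)
print(bool(jnp.array_equal(t1, t2)) and bool(jnp.array_equal(t1, t3)))  # True
```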

**Exercise 13: Assign the element of `data` at index [0, 2] to `value`**

In [15]:
value = data[0, 2]
value

DeviceArray(24, dtype=int32)
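Indexing a single element returns a zero-dimensional JAX array (note the `DeviceArray(24, ...)` output), not a Python `int`. When a plain Python number is needed, convert it explicitly, for example with `.item()`:

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 15, 14]])
value = data[0, 2]

print(value.shape)               # ()  a zero-dimensional array
py_value = value.item()          # Python int, copied to the host
print(type(py_value), py_value)  # <class 'int'> 24
```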

**Exercise 14: Update the value of `data` at index [1, 1] to `100`**

In [16]:
data = data.at[1, 1].set(100)
data

DeviceArray([[ 10,   1,  24],
             [ 20, 100,  14]], dtype=int32)
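JAX arrays are immutable, which is why the update goes through the `.at[...]` syntax and rebinds `data`. NumPy-style item assignment raises a `TypeError`, while `.at[...].set` returns a new array and leaves the original untouched. A small sketch:

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 15, 14]])

try:
    data[1, 1] = 100  # in-place assignment is not supported
except TypeError as err:
    print("TypeError:", err)

updated = data.at[1, 1].set(100)            # returns a new array
print(int(data[1, 1]), int(updated[1, 1]))  # 15 100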

**Exercise 15: Add `41` to the value of `data` at index [0, 0]**

In [17]:
data = data.at[0, 0].add(41)
data

DeviceArray([[ 51,   1,  24],
             [ 20, 100,  14]], dtype=int32)
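Besides `.set` and `.add`, the `.at[...]` property offers other functional updates such as `.multiply`, `.min` and `.max`, and the index can select a whole row or slice. A sketch using the method names of recent JAX releases (older releases may name some of them differently):

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 100, 14]])

added = data.at[0, 0].add(41)            # 10 + 41 -> 51
doubled_row = data.at[1].multiply(2)     # every element of the second row times 2
capped = data.at[1].min(50)              # second row: each element becomes min(element, 50)

print(int(added[0, 0]))        # 51
print(doubled_row[1].tolist()) # [40, 200, 28]
print(capped[1].tolist())      # [20, 50, 14]
```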

**Exercise 16: Calculate the minimum values over axis=1 and assign them to `mins`**

In [18]:
mins = data.min(axis=1)
mins

DeviceArray([ 1, 14], dtype=int32)
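The `axis` argument chooses which dimension is collapsed: `axis=1` reduces across the columns of each row, `axis=0` reduces down each column, and omitting it reduces over all elements. `argmin` returns the position of the minimum instead of its value. A sketch on the same values of `data`:

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14]])

print(data.min(axis=1).tolist())     # [1, 14]      smallest value in each row
print(data.min(axis=0).tolist())     # [20, 1, 14]  smallest value in each column
print(int(data.min()))               # 1            smallest value overall
print(data.argmin(axis=1).tolist())  # [1, 2]       column index of each row minimum
```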

**Exercise 17: Select the first row of `data` and assign it to `data_select`**

In [19]:
data_select = data[0]
data_select

DeviceArray([51,  1, 24], dtype=int32)
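A single integer index drops that axis, so `data[0]` is a 1-D row of shape `(3,)`. A slice keeps the axis, and `data[:, 0]` selects a column instead, as sketched here:

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14]])

row = data[0]       # first row, shape (3,)
row_2d = data[0:1]  # first row kept as a 2-D slice, shape (1, 3)
col = data[:, 0]    # first column, shape (2,)

print(row.tolist(), row.shape)  # [51, 1, 24] (3,)
print(row_2d.shape)             # (1, 3)
print(col.tolist())             # [51, 20]
```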

**Exercise 18: Append the row `data_select` to `data`**

In [20]:
data = jnp.vstack([data, data_select])
data

DeviceArray([[ 51,   1,  24],
             [ 20, 100,  14],
             [ 51,   1,  24]], dtype=int32)
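`jnp.vstack` treats a 1-D array of length N as a row of shape `(1, N)` before stacking, which is why a `(2, 3)` matrix and a `(3,)` row combine into a `(3, 3)` result. The same append can be written with `jnp.concatenate` by adding the row axis explicitly:

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14]])
data_select = data[0]

stacked = jnp.vstack([data, data_select])
# equivalent: give the row a leading axis of length 1, then join along axis 0
concatenated = jnp.concatenate([data, data_select[None, :]], axis=0)

print(stacked.shape)                                 # (3, 3)
print(bool(jnp.array_equal(stacked, concatenated)))  # True
```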

**Exercise 19: Multiply the matrices `data` and `dataT` and assign the product to `data_prod`**

In [21]:
data_prod = jnp.dot(data, dataT)
data_prod

DeviceArray([[1087, 1371],
             [ 636, 2096],
             [1087, 1371]], dtype=int32)
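For 2-D inputs `jnp.dot` is ordinary matrix multiplication, so the inner dimensions must agree: `(3, 3)` times `(3, 2)` gives `(3, 2)`. Note that `dataT` is the transpose of the original two-row matrix, taken before the later updates. The `@` operator and `jnp.matmul` give the same result, and the first entry can be checked by hand as `51*10 + 1*1 + 24*24 = 1087`:

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14], [51, 1, 24]])  # shape (3, 3)
dataT = jnp.array([[10, 20], [1, 15], [24, 14]])             # shape (3, 2)

p1 = jnp.dot(data, dataT)
p2 = data @ dataT
p3 = jnp.matmul(data, dataT)

print(p1.shape)  # (3, 2)
print(bool(jnp.array_equal(p1, p2)) and bool(jnp.array_equal(p1, p3)))  # True
print(int(p1[0, 0]))  # 1087
```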

**Exercise 20: Convert the dtype of `data_prod` to `float32`**

In [22]:
data_prod = jnp.array(data_prod, dtype=jnp.float32)
data_prod

DeviceArray([[1087., 1371.],
             [ 636., 2096.],
             [1087., 1371.]], dtype=float32)
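The same conversion is more commonly written with the `astype` method, which returns a new array of the requested dtype and leaves the input unchanged:

```python
import jax.numpy as jnp

data_prod = jnp.array([[1087, 1371], [636, 2096], [1087, 1371]])  # int32 by default

as_float = data_prod.astype(jnp.float32)

print(data_prod.dtype, as_float.dtype)  # int32 float32
```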
